import contextlib
import torch
import torch.nn as nn
from torch.cuda.amp import autocast as autocast
from transformers import AutoModelForCausalLM, AutoTokenizer
# from torch_scatter import scatter
# Prompt prefix; its token embeddings are placed at the start of every sequence.
BOS = '<s>[INST]'
# Marker whose token ids are appended directly after the question token ids.
EOS_USER = '[/INST]'
# Marker whose token ids are appended after the label token ids during training.
# NOTE(review): Llama-style tokenizers usually spell end-of-sequence as '</s>';
# confirm that '[/s]' is the intended text here.
EOS = '[/s]'
# Target value written on every non-answer position (prompt and left padding)
# when training labels are built; -100 is the default ``ignore_index`` of
# PyTorch's cross-entropy loss.
IGNORE_INDEX = -100
class LLAGA(torch.nn.Module):
    r"""`"LLaGA: Large Language and Graph Assistant"
    <https://arxiv.org/abs/2402.08170>`_ paper.

    Wraps a frozen causal language model together with a trainable two-layer
    linear projector that maps pre-computed graph embeddings into the
    language model's input-embedding space.

    Parameters
    ----------
    args: object
        Configuration namespace. The attributes read by this class are
        ``max_txt_len``, ``max_new_tokens``, ``llm_model_path``,
        ``llm_frozen`` and ``mm_hidden_size``.
    **kwargs
        Accepted for call-site compatibility; currently ignored.

    Raises
    ------
    ValueError
        If ``args.llm_frozen`` is not ``True`` (only a frozen language model
        is supported).
    """

    def __init__(
        self,
        args,
        **kwargs
    ):
        super().__init__()
        self.max_txt_len = args.max_txt_len
        self.max_new_tokens = args.max_new_tokens

        # Fail fast: validate the configuration before the expensive model
        # load. ``== True`` is kept on purpose so that exactly the same values
        # as before are accepted (e.g. the string "True" is still rejected).
        if not (args.llm_frozen == True):  # noqa: E712
            raise ValueError(
                "LLAGA only supports a frozen language model; "
                f"got args.llm_frozen={args.llm_frozen!r}"
            )

        print('Loading LLAMA')
        # Stored under its own name so the caller's ``**kwargs`` is not
        # shadowed by the loader options.
        # NOTE(review): ``max_memory`` only lists accelerator index 1 --
        # confirm this matches the deployment hardware.
        load_kwargs = {
            "max_memory": {1: '60GiB'},
            "device_map": "auto",
            "revision": "main",
        }

        self.tokenizer = AutoTokenizer.from_pretrained(
            args.llm_model_path,
            use_fast=False,
            revision=load_kwargs["revision"],
        )
        self.tokenizer.pad_token_id = 0
        self.tokenizer.padding_side = 'left'

        model = AutoModelForCausalLM.from_pretrained(
            args.llm_model_path,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            **load_kwargs
        )

        print("Freezing LLAMA!")
        for param in model.parameters():
            param.requires_grad = False

        # Registration order matters: ``device`` reports the device of the
        # first registered parameter, which must stay the language model's.
        self.model = model
        print('Finish loading LLAMA!')

        # Match the projector output to the language model's embedding width
        # instead of assuming 4096; fall back to 4096 if the config does not
        # expose ``hidden_size``.
        llm_hidden_size = getattr(self.model.config, "hidden_size", 4096)
        self.projector = nn.Sequential(
            nn.Linear(args.mm_hidden_size, 2048),
            nn.Linear(2048, llm_hidden_size),
        ).to(self.model.device)

        self.word_embedding = self.model.model.get_input_embeddings()

    @property
    def device(self):
        """Device of the first registered parameter (CPU if there is none)."""
        first_param = next(self.parameters(), None)
        if first_param is None:
            return torch.device("cpu")
        return first_param.device

    def maybe_autocast(self, dtype=torch.bfloat16):
        """Return an autocast context manager suited to the module's device.

        On CPU a no-op context is returned; on any other device CUDA autocast
        is enabled with ``dtype`` (``torch.bfloat16`` by default).
        """
        enable_autocast = self.device != torch.device("cpu")
        if enable_autocast:
            return torch.cuda.amp.autocast(dtype=dtype)
        return contextlib.nullcontext()

    @staticmethod
    def _scatter_mean(src, index):
        """Average the rows of ``src`` that share the same value in ``index``.

        Row ``g`` of the result is the mean of ``src[i]`` over all ``i`` with
        ``index[i] == g``; groups that receive no row are left at zero. The
        result has ``index.max() + 1`` rows (zero rows for an empty index).
        """
        index = index.long()
        num_groups = int(index.max().item()) + 1 if index.numel() > 0 else 0
        sums = src.new_zeros((num_groups,) + tuple(src.shape[1:]))
        sums.index_add_(0, index, src)
        # Clamp so that empty groups divide by one instead of zero.
        counts = torch.bincount(index, minlength=num_groups).clamp(min=1)
        counts = counts.to(src.dtype).view(-1, *([1] * (src.dim() - 1)))
        return sums / counts

    def encode_graphs(self, samples):
        """Encode ``samples['graph']`` and mean-pool node embeddings per graph.

        Requires a ``graph_encoder`` module to have been attached to this
        instance; ``__init__`` does not create one.

        Raises
        ------
        AttributeError
            If no ``graph_encoder`` is attached.
        """
        graph_encoder = getattr(self, "graph_encoder", None)
        if graph_encoder is None:
            raise AttributeError(
                "encode_graphs requires a `graph_encoder` module to be "
                "attached to this LLAGA instance; none is created in __init__"
            )
        graphs = samples['graph'].to(self.model.device)
        n_embeds, _ = graph_encoder(graphs.x, graphs.edge_index.long(), graphs.edge_attr)
        # mean pooling
        return self._scatter_mean(n_embeds, graphs.batch)

    def _special_token_embeds(self):
        """Return ``(bos_embeds, pad_embeds)`` from the word-embedding layer."""
        bos_ids = self.tokenizer(BOS, add_special_tokens=False, return_tensors='pt').input_ids[0]
        bos_embeds = self.word_embedding(bos_ids.to(self.device))
        pad_id = torch.tensor(self.tokenizer.pad_token_id).to(self.device)
        pad_embeds = self.word_embedding(pad_id).unsqueeze(0)
        return bos_embeds, pad_embeds

    def _project_graph_embeds(self, graph_emb):
        """Zero-pad the per-sample graph embeddings to equal length and project them.

        NOTE(review): rows added by the zero padding also pass through the
        projector and later receive attention-mask value 1 -- confirm this is
        intended.
        """
        padded = torch.nn.utils.rnn.pad_sequence(graph_emb, batch_first=True, padding_value=0)
        return self.projector(padded.to(self.device))

    def _embed_prompt(self, bos_embeds, graph_embeds, token_ids):
        """Concatenate BOS, graph and text-token embeddings for one sample."""
        token_embeds = self.word_embedding(torch.tensor(token_ids).to(self.model.device))
        return torch.cat([bos_embeds, graph_embeds, token_embeds], dim=0)

    def _left_pad_and_stack(self, sequences, pad_embeds):
        """Left-pad embedding sequences to a common length and stack them.

        Returns ``(inputs_embeds, attention_mask, pad_lengths)`` where
        ``pad_lengths[i]`` is the number of padding rows put in front of
        sample ``i``.
        """
        max_length = max(seq.shape[0] for seq in sequences)
        padded = []
        mask = []
        pad_lengths = []
        for seq in sequences:
            pad_length = max_length - seq.shape[0]
            padded.append(torch.cat([pad_embeds.repeat(pad_length, 1), seq]))
            mask.append([0] * pad_length + [1] * seq.shape[0])
            pad_lengths.append(pad_length)
        inputs_embeds = torch.stack(padded, dim=0).to(self.model.device)
        attention_mask = torch.tensor(mask).to(self.model.device)
        return inputs_embeds, attention_mask, pad_lengths

    def forward(self, samples):
        """Compute the language-modelling loss for a batch.

        ``samples`` is indexed with the keys ``"question"``, ``"label"``,
        ``"graph_emb"`` and ``"id"`` (the latter only for the batch size).
        Each sequence is ``BOS + graph + question + EOS_USER + label + EOS``;
        only the label and EOS positions contribute to the loss.

        Returns
        -------
        The ``loss`` attribute of the language model's output.
        """
        # encode questions, labels and special tokens
        questions = self.tokenizer(samples["question"], add_special_tokens=False)
        labels = self.tokenizer(samples["label"], add_special_tokens=False)
        eos_ids = self.tokenizer(EOS, add_special_tokens=False).input_ids
        eos_user_ids = self.tokenizer(EOS_USER, add_special_tokens=False).input_ids
        bos_embeds, pad_embeds = self._special_token_embeds()
        graph_embeds = self._project_graph_embeds(samples["graph_emb"])

        batch_size = len(samples['id'])
        sequences = []
        targets = []
        for i in range(batch_size):
            answer_ids = labels.input_ids[i][:self.max_new_tokens] + eos_ids
            prompt_ids = questions.input_ids[i] + eos_user_ids + answer_ids
            seq = self._embed_prompt(bos_embeds, graph_embeds[i], prompt_ids)
            sequences.append(seq)
            # Everything before the answer is masked out of the loss.
            targets.append([IGNORE_INDEX] * (seq.shape[0] - len(answer_ids)) + answer_ids)

        inputs_embeds, attention_mask, pad_lengths = self._left_pad_and_stack(sequences, pad_embeds)
        # Padding positions are masked out of the loss as well.
        targets = [[IGNORE_INDEX] * pad + tgt for pad, tgt in zip(pad_lengths, targets)]
        label_input_ids = torch.tensor(targets).to(self.model.device)

        with self.maybe_autocast():
            outputs = self.model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=label_input_ids,
            )
        return outputs.loss

    def inference(self, samples):
        """Generate an answer for every sample in the batch.

        ``samples`` is indexed with the keys ``"question"``, ``"label"``,
        ``"graph_emb"`` and ``"id"``. Each prompt is
        ``BOS + graph + question + EOS_USER``.

        Returns
        -------
        dict
            ``{'id', 'pred', 'label', 'question'}`` where ``pred`` holds the
            decoded generations and the other entries are copied from
            ``samples``.
        """
        # encode questions and special tokens
        questions = self.tokenizer(samples["question"], add_special_tokens=False)
        eos_user_ids = self.tokenizer(EOS_USER, add_special_tokens=False).input_ids
        bos_embeds, pad_embeds = self._special_token_embeds()
        graph_embeds = self._project_graph_embeds(samples["graph_emb"])

        batch_size = len(samples['id'])
        sequences = [
            self._embed_prompt(bos_embeds, graph_embeds[i], questions.input_ids[i] + eos_user_ids)
            for i in range(batch_size)
        ]
        inputs_embeds, attention_mask, _ = self._left_pad_and_stack(sequences, pad_embeds)

        with self.maybe_autocast():
            outputs = self.model.generate(
                inputs_embeds=inputs_embeds,
                max_new_tokens=self.max_new_tokens,
                attention_mask=attention_mask,
                use_cache=True
            )
        pred = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return {'id': samples['id'],
                'pred': pred,
                'label': samples['label'],
                'question': samples['question'],
                }

    def print_trainable_params(self):
        """Return ``(trainable_params, all_param)`` element counts.

        Despite its name this method prints nothing; it only returns the
        number of parameter elements that require gradients and the total
        number of parameter elements.
        """
        trainable_params = 0
        all_param = 0
        for param in self.parameters():
            num_params = param.numel()
            all_param += num_params
            if param.requires_grad:
                trainable_params += num_params
        return trainable_params, all_param