# Source code for ggfm.models.llaga

import contextlib
import torch
import torch.nn as nn
from torch.cuda.amp import autocast as autocast
from transformers import AutoModelForCausalLM, AutoTokenizer
# from torch_scatter import scatter

# Llama-2 chat-style prompt markers. forward() lays each training sequence out as
#   BOS + <graph tokens> + question + EOS_USER + answer + EOS
BOS = '<s>[INST]'
EOS_USER = '[/INST]'
# End-of-sequence marker appended after the answer. It must be the tokenizer's
# real EOS token ('</s>', the counterpart of the '<s>' in BOS); the previous
# value '[/s]' is not a special token for Llama tokenizers, so answers were not
# terminated by the model's EOS token.
EOS = '</s>'

# Label value for positions that must not contribute to the loss (prompt and
# padding); -100 is the ignore_index used by the Hugging Face causal-LM loss.
IGNORE_INDEX = -100


class LLAGA(torch.nn.Module):
    r"""`"LLaGA: Large Language and Graph Assistant" <https://arxiv.org/abs/2402.08170>`_ paper.

    Wraps a frozen causal LLM loaded from ``args.llm_model_path`` together
    with a two-layer linear projector that maps precomputed graph embeddings
    (``samples["graph_emb"]``) into the LLM's token-embedding space.

    Parameters
    ----------
    args: namespace-like
        Configuration object. Attributes read: ``max_txt_len`` (stored only),
        ``max_new_tokens``, ``llm_model_path``, ``llm_frozen`` (must be truthy)
        and ``mm_hidden_size`` (input width of the projector).
    **kwargs:
        Accepted for interface compatibility; ignored. The
        ``from_pretrained`` loading options are fixed inside ``__init__``.

    Raises
    ------
    ValueError
        If ``args.llm_frozen`` is falsy; only a frozen LLM is supported.
    """

    def __init__(
            self,
            args,
            **kwargs
    ):
        super().__init__()
        # Validate before the (expensive) model download/load.
        if not args.llm_frozen:
            raise ValueError("LLAGA only supports a frozen LLM; set args.llm_frozen to True.")

        self.max_txt_len = args.max_txt_len  # stored; not read by the methods below
        self.max_new_tokens = args.max_new_tokens

        print('Loading LLAMA')
        # NOTE(review): max_memory only lists GPU index 1, so device_map="auto"
        # places the model there; confirm this hard-coded placement is intended.
        load_kwargs = {
            "max_memory": {1: '60GiB'},
            "device_map": "auto",
            "revision": "main",
        }

        self.tokenizer = AutoTokenizer.from_pretrained(
            args.llm_model_path, use_fast=False, revision=load_kwargs["revision"]
        )
        self.tokenizer.pad_token_id = 0
        # Matches the manual left padding performed in forward()/inference().
        self.tokenizer.padding_side = 'left'

        model = AutoModelForCausalLM.from_pretrained(
            args.llm_model_path,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            **load_kwargs,
        )

        print("Freezing LLAMA!")
        model.requires_grad_(False)

        self.model = model
        print('Finish loading LLAMA!')

        # Project into the LLM's embedding width (4096 for Llama-7B) instead of
        # a hard-coded constant, so the concatenation with word embeddings in
        # forward()/inference() always lines up.
        llm_hidden_size = self.model.get_input_embeddings().embedding_dim
        self.projector = nn.Sequential(
            nn.Linear(args.mm_hidden_size, 2048),
            nn.Linear(2048, llm_hidden_size),
        ).to(self.model.device)

        self.word_embedding = self.model.get_input_embeddings()

    @property
    def device(self):
        """Device of the first registered parameter (the LLM's, once loaded)."""
        return next(self.parameters()).device

    def maybe_autocast(self, dtype=torch.bfloat16):
        """Return an autocast context on CUDA, or a no-op context otherwise.

        Parameters
        ----------
        dtype: torch.dtype
            Autocast dtype used when the module lives on a CUDA device.
        """
        if self.device.type == "cuda":
            return torch.autocast(device_type="cuda", dtype=dtype)
        return contextlib.nullcontext()

    def encode_graphs(self, samples):
        """Encode ``samples['graph']`` and mean-pool node embeddings per graph.

        Requires a ``graph_encoder`` attribute to be attached to the module;
        ``__init__`` does not create one, so calling this without it raises
        ``AttributeError``. The encoder is called as
        ``graph_encoder(x, edge_index, edge_attr)`` and must return a pair whose
        first item holds one embedding row per node (assumed 2-D — confirm).

        Returns
        -------
        torch.Tensor
            One pooled row per graph index found in ``graphs.batch``.
        """
        graphs = samples['graph'].to(self.device)
        n_embeds, _ = self.graph_encoder(graphs.x, graphs.edge_index.long(), graphs.edge_attr)

        # Mean pooling per graph (replaces torch_scatter.scatter(reduce='mean'),
        # whose import is not available in this module).
        batch = graphs.batch
        num_graphs = int(batch.max()) + 1 if batch.numel() > 0 else 0
        sums = n_embeds.new_zeros((num_graphs, n_embeds.shape[-1])).index_add_(0, batch, n_embeds)
        counts = torch.bincount(batch, minlength=num_graphs).clamp_(min=1)
        return sums / counts.unsqueeze(-1).to(n_embeds.dtype)

    def _token_ids(self, text):
        """Tokenize a fixed marker string without adding special tokens."""
        return self.tokenizer(text, add_special_tokens=False).input_ids

    def _special_embeds(self):
        """Return (BOS marker embeddings, single pad-token embedding row)."""
        bos_ids = self.tokenizer(BOS, add_special_tokens=False, return_tensors='pt').input_ids[0]
        bos_embeds = self.word_embedding(bos_ids.to(self.device))
        pad_embeds = self.word_embedding(
            torch.tensor(self.tokenizer.pad_token_id).to(self.device)
        ).unsqueeze(0)
        return bos_embeds, pad_embeds

    def _project_graphs(self, graph_emb):
        """Project each sample's graph token embeddings, dropping batch padding.

        Samples are padded together for a single projector call, then each
        projection is sliced back to its true length: the projector maps zero
        padding to its bias, and those rows would otherwise enter the prompt
        with attention mask 1.
        """
        lengths = [g.shape[0] for g in graph_emb]
        padded = torch.nn.utils.rnn.pad_sequence(graph_emb, batch_first=True, padding_value=0)
        projected = self.projector(padded.to(self.device))
        return [projected[i, :n] for i, n in enumerate(lengths)]

    @staticmethod
    def _left_pad(batch_embeds, pad_embeds, batch_labels=None):
        """Left-pad a batch of (seq_len, hidden) embeddings to a common length.

        Returns the stacked embeddings, the attention mask (0 on padding) and,
        when ``batch_labels`` is given, the labels left-padded with IGNORE_INDEX.
        """
        max_length = max(embeds.shape[0] for embeds in batch_embeds)
        padded, masks, labels = [], [], []
        for i, embeds in enumerate(batch_embeds):
            pad_length = max_length - embeds.shape[0]
            padded.append(torch.cat([pad_embeds.repeat(pad_length, 1), embeds]))
            masks.append([0] * pad_length + [1] * embeds.shape[0])
            if batch_labels is not None:
                labels.append([IGNORE_INDEX] * pad_length + batch_labels[i])
        return torch.stack(padded, dim=0), torch.tensor(masks), labels

    def forward(self, samples):
        """Compute the language-modeling loss for a batch.

        ``samples`` must provide ``"question"``, ``"label"`` (lists of strings),
        ``"graph_emb"`` (per-sample graph token embeddings) and ``"id"`` (its
        length is the batch size). Only answer tokens plus EOS are supervised;
        all other positions are labelled IGNORE_INDEX.
        """
        questions = self.tokenizer(samples["question"], add_special_tokens=False)
        labels = self.tokenizer(samples["label"], add_special_tokens=False)

        eos_ids = self._token_ids(EOS)
        eos_user_ids = self._token_ids(EOS_USER)
        bos_embeds, pad_embeds = self._special_embeds()
        graph_embeds = self._project_graphs(samples["graph_emb"])

        batch_size = len(samples['id'])
        batch_embeds = []
        batch_labels = []
        for i in range(batch_size):
            # Answer (truncated to max_new_tokens) followed by EOS.
            answer_ids = labels.input_ids[i][:self.max_new_tokens] + eos_ids
            input_ids = questions.input_ids[i] + eos_user_ids + answer_ids
            text_embeds = self.word_embedding(torch.tensor(input_ids).to(self.model.device))
            embeds = torch.cat([bos_embeds, graph_embeds[i], text_embeds], dim=0)
            batch_embeds.append(embeds)
            # Supervise only the trailing answer tokens.
            batch_labels.append([IGNORE_INDEX] * (embeds.shape[0] - len(answer_ids)) + answer_ids)

        inputs_embeds, attention_mask, padded_labels = self._left_pad(
            batch_embeds, pad_embeds, batch_labels
        )
        inputs_embeds = inputs_embeds.to(self.model.device)
        attention_mask = attention_mask.to(self.model.device)
        label_input_ids = torch.tensor(padded_labels).to(self.model.device)

        with self.maybe_autocast():
            outputs = self.model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=label_input_ids,
            )

        return outputs.loss

    def inference(self, samples):
        """Generate answers for a batch.

        Builds BOS + graph tokens + question + EOS_USER prompts, left-pads them
        and calls ``self.model.generate``.

        Returns
        -------
        dict
            Keys ``'id'``, ``'pred'`` (decoded strings), ``'label'`` and
            ``'question'``, the non-prediction values copied from ``samples``.
        """
        questions = self.tokenizer(samples["question"], add_special_tokens=False)

        eos_user_ids = self._token_ids(EOS_USER)
        bos_embeds, pad_embeds = self._special_embeds()
        graph_embeds = self._project_graphs(samples["graph_emb"])

        batch_size = len(samples['id'])
        batch_embeds = []
        for i in range(batch_size):
            input_ids = questions.input_ids[i] + eos_user_ids
            text_embeds = self.word_embedding(torch.tensor(input_ids).to(self.model.device))
            batch_embeds.append(torch.cat([bos_embeds, graph_embeds[i], text_embeds], dim=0))

        inputs_embeds, attention_mask, _ = self._left_pad(batch_embeds, pad_embeds)
        inputs_embeds = inputs_embeds.to(self.model.device)
        attention_mask = attention_mask.to(self.model.device)

        with self.maybe_autocast():
            outputs = self.model.generate(
                inputs_embeds=inputs_embeds,
                max_new_tokens=self.max_new_tokens,
                attention_mask=attention_mask,
                use_cache=True,
            )
        pred = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

        return {'id': samples['id'],
                'pred': pred,
                'label': samples['label'],
                'question': samples['question'],
                }

    def print_trainable_params(self):
        """Return ``(trainable_params, all_params)`` element counts (nothing is printed)."""
        params = list(self.parameters())
        all_param = sum(param.numel() for param in params)
        trainable_params = sum(param.numel() for param in params if param.requires_grad)
        return trainable_params, all_param