import gzip
import html
import os
from functools import lru_cache
import ftfy
import regex as re
from typing import Optional, List, Tuple, Union
import logging
import torch
from lightning.pytorch import LightningModule
from transformers import (get_cosine_schedule_with_warmup,
AutoConfig, AutoModelForCausalLM,
LlamaConfig, LlamaModel, LlamaForCausalLM)
# BaseModelOutputWithPast, CausalLMOutputWithPast)
from transformers.configuration_utils import PretrainedConfig
from torch.optim import AdamW
import torch.nn as nn
import transformers
import torch.nn.functional as F
from torch_geometric.data import Data
# from torch_geometric.data import Data, remove_self_loops, add_self_loops, degree
from torch_geometric.utils import add_remaining_self_loops
# from torch_scatter import scatter_add
from torch_geometric.nn.conv import MessagePassing
import json
import os.path as osp
import glob
from collections import OrderedDict
# Special placeholder tokens that mark graph content inside LLM prompts.
DEFAULT_GRAPH_TOKEN = "<graph>"
DEFAULT_GRAPH_PATCH_TOKEN = "<g_patch>"
DEFAULT_G_START_TOKEN = "<g_start>"
DEFAULT_G_END_TOKEN = "<g_end>"
############# basic conv layers #############
# Initializer shortcuts. `nn.init.uniform` (no trailing underscore) is a deprecated
# alias of the in-place `nn.init.uniform_`, so bind the non-deprecated function.
init = nn.init.xavier_uniform_
uniformInit = nn.init.uniform_
import torch as t
import math
def gcn_conv(h, edge_index):
    """
    Parameter-free GCN propagation (symmetric-normalised neighbour aggregation).

    Existing self loops are dropped and exactly one self loop per node is
    added. Each edge ``src -> dst`` then gets weight
    ``deg(src) ** -0.5 * deg(dst) ** -0.5``, where degrees count incoming
    edges (occurrences in ``dst``). Every node receives the weighted sum of
    the features of its in-neighbours, which equals ``A_hat.T @ h`` for the
    normalised adjacency ``A_hat``.

    Parameters
    ----------
    h : torch.Tensor
        Node features of shape ``(num_nodes, num_features)``.
    edge_index : torch.LongTensor
        Edge list of shape ``(2, num_edges)`` with rows ``(src, dst)``.

    Returns
    -------
    torch.Tensor
        Aggregated features with the same shape, dtype and device as ``h``.
    """
    num_nodes = h.size(0)
    edge_index = edge_index.to(h.device)

    # Replace any existing self loops with exactly one loop per node.
    not_loop = edge_index[0] != edge_index[1]
    loops = torch.arange(num_nodes, device=h.device, dtype=edge_index.dtype)
    edge_index = torch.cat([edge_index[:, not_loop], torch.stack([loops, loops])], dim=1)
    src, dst = edge_index

    # Degrees and weights in float32 for precision, independent of h's dtype.
    deg = torch.bincount(dst, minlength=num_nodes).to(torch.float32)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt.masked_fill_(torch.isinf(deg_inv_sqrt), 0.0)
    edge_weight = (deg_inv_sqrt[src] * deg_inv_sqrt[dst]).to(h.dtype)

    # Messages travel src -> dst. Summing h[src] back into src would only rescale
    # each node and never mix features across edges.
    messages = h[src] * edge_weight.unsqueeze(-1)
    return torch.zeros_like(h).index_add_(0, dst, messages)
# Implementation of MPNN, which can become MLP or GCN depending on whether using message passing
class MPNN(nn.Module):
    r"""
    Message Passing Neural Network (MPNN) layer.

    With ``if_param`` the layer owns linear layers. With ``use_conv`` (in
    ``forward``) it applies ``gcn_conv`` propagation. It therefore behaves as
    an MLP, a parameter-free propagation stack, or a GCN.

    Parameters
    ----------
    in_channels : int
        Number of input features
    hidden_channels : int
        Number of hidden features
    out_channels : int
        Number of output features
    **kwargs
        ``dropout`` (float), ``num_layers`` (int) and ``if_param`` (bool,
        whether to create the linear layers).
    """

    def __init__(self, in_channels, hidden_channels, out_channels, **kwargs):
        super(MPNN, self).__init__()
        # Other code reads token ids from `graph_tower.config`.
        self.config = PretrainedConfig()
        self.dropout = kwargs.get('dropout')
        self.num_layers = kwargs.get('num_layers')
        self.ff_bias = True  # use bias for FF layers by default
        self.bns = nn.BatchNorm1d(hidden_channels, affine=False, track_running_stats=False)
        self.activation = F.relu
        self.if_param = kwargs.get('if_param')
        # Always defined (possibly empty) so reset_parameters() is safe to call.
        self.fcs = nn.ModuleList()
        if self.if_param:
            self.fcs.append(nn.Linear(in_channels, hidden_channels, bias=self.ff_bias))
            for _ in range(self.num_layers - 2):
                self.fcs.append(nn.Linear(hidden_channels, hidden_channels, bias=self.ff_bias))
            self.fcs.append(nn.Linear(hidden_channels, out_channels, bias=self.ff_bias))
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise the linear weights (gain 1.414) and zero the biases; no-op without layers."""
        for mlp in self.fcs:
            nn.init.xavier_uniform_(mlp.weight, gain=1.414)
            if mlp.bias is not None:
                nn.init.zeros_(mlp.bias)

    def forward(self, g, use_conv=True):
        r"""
        Forward pass of the MPNN layer.

        Parameters
        ----------
        g
            Graph with ``graph_node`` (node features) and ``edge_index``.
        use_conv : bool
            Whether to apply ``gcn_conv`` propagation in every layer.

        Returns
        -------
        torch.Tensor
            Node representations.
        """
        x = g.graph_node
        edge_index = g.edge_index
        # Without linear layers the module owns no parameters; use the input's device then.
        first_param = next(self.parameters(), None)
        device = first_param.device if first_param is not None else x.device
        x = x.to(device)
        edge_index = edge_index.to(device)
        for i in range(self.num_layers - 1):
            if self.if_param:
                x = x @ self.fcs[i].weight.t()
            if use_conv:
                x = gcn_conv(x, edge_index)  # Optionally replace 'gcn_conv' with other conv functions
            if self.ff_bias and self.if_param:
                x = x + self.fcs[i].bias
            try:
                x = self.activation(self.bns(x))
            except (RuntimeError, ValueError):
                # BatchNorm can refuse the batch (e.g. a single node in training mode);
                # fall back to the plain activation.
                x = self.activation(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        if self.if_param:
            x = x @ self.fcs[-1].weight.t()
        if use_conv:
            x = gcn_conv(x, edge_index)
        if self.ff_bias and self.if_param:
            x = x + self.fcs[-1].bias
        return x
def PositionalEncoding(q_len, d_model, normalize=True):
    r"""
    Sinusoidal positional-encoding table.

    Even columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``
    with ``w_k = exp(-2k * ln(10000) / d_model)``.

    Parameters
    ----------
    q_len : int
        Number of positions (rows).
    d_model : int
        Encoding width (columns); may be odd.
    normalize : bool
        If True, shift to zero mean and scale to a standard deviation of 0.1.

    Returns
    -------
    torch.Tensor
        Float tensor of shape ``(q_len, d_model)``.
    """
    pe = t.zeros(q_len, d_model)
    position = t.arange(0, q_len).unsqueeze(1)
    div_term = t.exp(t.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
    pe[:, 0::2] = t.sin(position * div_term)
    # For odd d_model there is one fewer cosine column than sine columns.
    pe[:, 1::2] = t.cos(position * div_term[: d_model // 2])
    if normalize:
        pe = pe - pe.mean()
        pe = pe / (pe.std() * 10)
    return pe
def pos_encoding(pe, learn_pe, nvar, d_model):
    r"""
    Build a positional-encoding parameter.

    Parameters
    ----------
    pe : str or None
        Encoding type.
        ``None``: random ``(nvar, d_model)`` table, never learned.
        ``'zero'``: uniform(-0.02, 0.02) of shape ``(nvar, 1)``.
        ``'zeros'``: uniform(-0.02, 0.02) of shape ``(nvar, d_model)``.
        ``'normal'``/``'gauss'``: N(0, 0.1) of shape ``(nvar, 1)``.
        ``'uniform'``: uniform(0, 0.1) of shape ``(nvar, 1)``.
        ``'sincos'``: normalised sinusoidal table ``(nvar, d_model)``.
    learn_pe : bool
        Whether the parameter requires gradients (ignored for ``None``).
    nvar : int
        Number of positions.
    d_model : int
        Encoding width.

    Returns
    -------
    nn.Parameter

    Raises
    ------
    ValueError
        If ``pe`` is not one of the types above.
    """
    if pe is None:
        W_pos = t.empty((nvar, d_model))
        nn.init.uniform_(W_pos, -0.02, 0.02)
        learn_pe = False
    elif pe == 'zero':
        W_pos = t.empty((nvar, 1))
        nn.init.uniform_(W_pos, -0.02, 0.02)
    elif pe == 'zeros':
        W_pos = t.empty((nvar, d_model))
        nn.init.uniform_(W_pos, -0.02, 0.02)
    elif pe in ('normal', 'gauss'):
        W_pos = t.zeros((nvar, 1))
        t.nn.init.normal_(W_pos, mean=0.0, std=0.1)
    elif pe == 'uniform':
        W_pos = t.zeros((nvar, 1))
        nn.init.uniform_(W_pos, a=0.0, b=0.1)
    elif pe == 'sincos':
        W_pos = PositionalEncoding(nvar, d_model, normalize=True)
    else:
        raise ValueError(
            f"{pe} is not a valid pe (positional encoder). Available types: "
            "'gauss'=='normal', 'zeros', 'zero', 'uniform', 'sincos', None."
        )
    return nn.Parameter(W_pos, requires_grad=learn_pe)
class graph_transformer(nn.Module):
    r"""
    Graph Transformer encoder.

    Pipeline: input projection ``W_P``, optional learned positional offset
    ``W_pos``, dropout, a stack of ``GTLayer`` modules, and the output
    projection ``inverW_P``. ``args`` supplies ``gt_layers``,
    ``att_d_model``, ``gnn_input``, ``gnn_output`` and ``if_pos`` (plus the
    fields read by ``GTLayer``).
    """

    def __init__(self, args):
        super(graph_transformer, self).__init__()
        self.config = PretrainedConfig()
        self.gtLayers = nn.Sequential(*(GTLayer(args) for _ in range(args.gt_layers)))
        self.W_pos = pos_encoding('zeros', True, 1, args.att_d_model)
        self.W_P = nn.Linear(args.gnn_input, args.att_d_model)
        self.dropout = nn.Dropout(0.1)
        self.inverW_P = nn.Linear(args.att_d_model, args.gnn_output)
        self.args = args

    def forward(self, g):
        r"""
        Encode the nodes of ``g``.

        Parameters
        ----------
        g
            Graph exposing ``.to(device)`` and ``.graph_node``; each
            ``GTLayer`` additionally reads ``g.edge_index``.

        Returns
        -------
        torch.Tensor
            Per-node outputs with ``args.gnn_output`` features.
        """
        target_device = next(self.parameters()).device
        g = g.to(target_device)
        hidden = self.W_P(g.graph_node)
        if self.args.if_pos:
            hidden = hidden + self.W_pos
        hidden = self.dropout(hidden)
        for layer in self.gtLayers:
            hidden = layer(g, hidden)
        return self.inverW_P(hidden)
def Mv2Samedevice(vars):
    """Return a new list with every tensor of ``vars`` moved to the device of ``vars[0]``."""
    moved = []
    for tensor in vars:
        moved.append(tensor.to(vars[0].device))
    return moved
class GTLayer(nn.Module):
    r"""
    Sparse multi-head graph-attention layer with a residual connection.

    For each edge ``(row, col)`` of ``g.edge_index`` the query comes from node
    ``row`` and the key/value from node ``col``. Per-head scores are clamped
    to [-10, 10], exponentiated and normalised over all edges that share the
    same ``row``. The weighted values are summed back into ``row`` and the
    input is added. ``args`` supplies ``att_d_model``, ``head`` and
    ``att_norm`` (optional LayerNorm on the output).
    """

    def __init__(self, args):
        super(GTLayer, self).__init__()
        width = args.att_d_model
        self.qTrans = nn.Parameter(init(t.empty(width, width)))
        self.kTrans = nn.Parameter(init(t.empty(width, width)))
        self.vTrans = nn.Parameter(init(t.empty(width, width)))
        if args.att_norm:
            self.norm = nn.LayerNorm(width, eps=1e-6)
        self.args = args

    def forward(self, g, embeds):
        num_heads = self.args.head
        d_model = self.args.att_d_model
        head_dim = d_model // num_heads
        rows, cols = g.edge_index
        num_nodes = embeds.shape[0]

        src_feats = embeds[rows, :]
        dst_feats = embeds[cols, :]
        num_edges = src_feats.shape[0]

        q = (src_feats @ self.qTrans).view([num_edges, num_heads, head_dim])
        k = (dst_feats @ self.kTrans).view([num_edges, num_heads, head_dim])
        v = (dst_feats @ self.vTrans).view([num_edges, num_heads, head_dim])

        # Edge-wise softmax over the edges leaving each row node (epsilon avoids 0/0).
        scores = t.exp(t.clamp(t.einsum('ehd, ehd -> eh', q, k), -10.0, 10.0))
        rows = rows.to(scores.device)
        denom = t.zeros([num_nodes, num_heads]).to(scores.device, dtype=scores.dtype)
        denom = denom.index_add_(0, rows, scores)[rows, :]
        weights = scores / (denom + 1e-8)

        messages = t.einsum('eh, ehd -> ehd', weights, v).view([num_edges, d_model])
        rows = rows.to(messages.device)
        out = t.zeros([num_nodes, d_model]).to(messages.device, dtype=messages.dtype)
        out = out.index_add_(0, rows, messages) + embeds
        if self.args.att_norm:
            out = self.norm(out)
        return out
@lru_cache()
def default_bpe():
    """Absolute path of ``bpe_simple_vocab_16e6.txt.gz`` in this module's directory (cached)."""
    here = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(here, "bpe_simple_vocab_16e6.txt.gz")
@lru_cache()
def bytes_to_unicode():
    """
    Map every byte value 0-255 to a printable unicode character.

    Byte-level BPE works on unicode strings, so each raw byte needs a
    reversible stand-in character. Without such a table, a vocabulary would
    need thousands of unicode symbols just to avoid UNKs. Bytes that are
    already printable map to themselves. The remaining bytes (whitespace,
    control and a few others the BPE code cannot handle) are assigned
    ``chr(256 + k)`` in ascending byte order.

    Returns
    -------
    dict[int, str]
        Printable bytes come first in insertion order, then the remapped
        ones. The result is cached, so all callers share one dict.
    """
    printable = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    table = {byte: chr(byte) for byte in printable}
    offset = 0
    for byte in range(2 ** 8):
        if byte not in table:
            table[byte] = chr(2 ** 8 + offset)
            offset += 1
    return table
def get_pairs(word):
    """Return the set of adjacent symbol pairs in ``word``.

    ``word`` is a sequence (usually a tuple) of symbols, which are
    variable-length strings. Words with fewer than two symbols have no pairs
    and yield an empty set; an empty word no longer raises ``IndexError``.
    """
    return set(zip(word, word[1:]))
def basic_clean(text):
    r"""
    Normalise raw text before tokenisation.

    Repairs broken unicode with ``ftfy.fix_text``, applies
    ``html.unescape`` twice (so doubly escaped entities are decoded) and
    strips surrounding whitespace.

    Parameters
    ----------
    text : str

    Returns
    -------
    str
    """
    cleaned = ftfy.fix_text(text)
    for _ in range(2):
        cleaned = html.unescape(cleaned)
    return cleaned.strip()
def whitespace_clean(text):
    """Collapse every run of whitespace to a single space and trim both ends."""
    return re.sub(r'\s+', ' ', text).strip()
class SimpleTokenizer(object):
    r"""
    Byte-level BPE tokenizer (CLIP-style).

    Parameters
    ----------
    bpe_path : str, optional
        Gzipped BPE merges file. When omitted, ``default_bpe()`` is used.
        The default is resolved when the tokenizer is constructed, not when
        the class is defined.
    """

    def __init__(self, bpe_path: Optional[str] = None):
        if bpe_path is None:
            bpe_path = default_bpe()
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Read the merges file (one "a b" pair per line); the context manager closes it.
        with gzip.open(bpe_path) as f:
            merges = f.read().decode("utf-8").split('\n')
        # Skip the first line and keep at most 49152 - 256 - 2 merges.
        merges = merges[1:49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        # Vocabulary: byte symbols, their end-of-word variants, merged symbols, special tokens.
        vocab = list(self.byte_encoder.values())
        vocab = vocab + [v + '</w>' for v in vocab]
        vocab.extend(''.join(merge) for merge in merges)
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # Memo of bpe() results; unbounded, so it grows with the number of distinct words seen.
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        """Apply the learned merges to ``token``; return its symbols joined by spaces."""
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)
        if not pairs:
            return token + '</w>'
        while True:
            # Merge the pair with the lowest rank (earliest learned) first.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:  # `first` does not occur again
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j
                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean, lowercase and BPE-encode ``text`` into a list of token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Map token ids back to text; invalid UTF-8 is replaced, ``</w>`` becomes a space."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
# _tokenizer = SimpleTokenizer()
class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16.

    Normalisation is computed in float32 and the result is cast back to the
    input's dtype.
    """

    def forward(self, x: torch.Tensor):
        normalized = nn.LayerNorm.forward(self, x.to(torch.float32))
        return normalized.to(x.dtype)
class QuickGELU(nn.Module):
    """Sigmoid approximation of GELU: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(x * 1.702)
        return gate * x
class ResidualAttentionBlock(nn.Module):
    r"""
    Pre-LayerNorm transformer block: self-attention followed by a 4x-wide
    QuickGELU MLP, each inside a residual connection.

    Parameters
    ----------
    d_model : int
        Dimension of the model
    n_head : int
        Number of heads
    attn_mask : torch.Tensor
        Additive attention mask, or None
    """

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        hidden = d_model * 4
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, hidden)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(hidden, d_model)),
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        """Self-attention on ``x``; the stored mask is converted to x's dtype/device and kept."""
        if self.attn_mask is not None:
            self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
        attended, _ = self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)
        return attended

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))
class Transformer(nn.Module):
    r"""
    Stack of ``ResidualAttentionBlock`` layers applied in sequence.

    Parameters
    ----------
    width : int
        Width of the model
    layers : int
        Number of layers
    heads : int
        Number of heads
    attn_mask : torch.Tensor
        Attention mask shared by all blocks
    """

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        blocks = [ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        for block in self.resblocks:
            x = block(x)
        return x
class GNN(MessagePassing):
    r"""
    Two-layer graph encoder used as the CLIP graph tower (``gnn_type='gcn'``).

    Weights live in ``self.vars`` as ``[W1, b1, W2, b2]``, with
    ``W1: (gnn_hid, gnn_input)`` and ``W2: (gnn_output, gnn_hid)``.
    ``forward`` may instead use an externally supplied list with the same
    layout.
    """

    def __init__(self, args, **kwargs):
        super(GNN, self).__init__(aggr='add', **kwargs)
        self.config = PretrainedConfig()
        self.vars = nn.ParameterList()
        w = nn.Parameter(torch.ones([args.gnn_hid, args.gnn_input]))
        torch.nn.init.xavier_uniform_(w)
        self.vars.append(w)
        self.vars.append(nn.Parameter(torch.zeros(args.gnn_hid)))
        w = nn.Parameter(torch.ones([args.gnn_output, args.gnn_hid]))
        torch.nn.init.xavier_uniform_(w)
        self.vars.append(w)
        self.vars.append(nn.Parameter(torch.zeros(args.gnn_output)))

    @staticmethod
    def norm(edge_index, num_nodes, improved=False, dtype=None):
        """
        Add missing self loops (weight 1, or 2 if ``improved``) and compute
        symmetric GCN edge weights ``d^-1/2[row] * w * d^-1/2[col]``. Here
        ``d`` is the weighted degree accumulated over ``row``.

        Returns
        -------
        tuple
            ``(edge_index_with_loops, edge_weights)``.
        """
        edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, device=edge_index.device)
        fill_value = 1.0 if not improved else 2.0
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)
        row, col = edge_index
        # Plain-torch degree sum; the torch_scatter `scatter_add` import is not available here.
        deg = edge_weight.new_zeros(num_nodes).scatter_add_(0, row, edge_weight)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, g, vars=None):
        """
        Encode ``g`` (with ``graph_node`` features and ``edge_index``).

        Each of the two layers propagates, then applies a linear map. A
        leaky ReLU follows the first layer only. ``vars`` overrides
        ``self.vars`` when given.
        """
        device = self.parameters()[0].device
        g = g.to(device)
        edge_index = g.edge_index
        x = g.graph_node
        if vars is None:
            vars = self.vars
        improved = False
        layers = [(vars[0], vars[1]), (vars[2], vars[3])]
        for layer_idx, (w, b) in enumerate(layers):
            edge_index, norm = self.norm(edge_index, x.size(self.node_dim), improved, x.dtype)
            # NOTE(review): this class does not override `message`, so MessagePassing's default
            # (returning x_j) is used and `norm` does not weight the messages. Left unchanged
            # because existing checkpoints may depend on it — confirm before adding weighting.
            x = self.propagate(edge_index, x=x, norm=norm)
            x = F.linear(x, w.to(x.device), b.to(x.device))
            if layer_idx == 0:
                x = F.leaky_relu(x)
        return x

    def parameters(self, recurse=True):
        """Return ``self.vars``. ``recurse`` is accepted for ``nn.Module`` compatibility and ignored."""
        return self.vars
def Mv2SameDevice(var_list):
    r"""
    Move every tensor in ``var_list`` onto the device of ``var_list[0]``.

    Parameters
    ----------
    var_list : list of torch.Tensor
        Modified in place.

    Returns
    -------
    list
        The same list object.
    """
    for position, item in enumerate(var_list[1:], start=1):
        var_list[position] = item.to(var_list[0].device)
    return var_list
class CLIP(nn.Module):
    r"""
    CLIP-style model that pairs a graph encoder (``self.gnn``) with a text
    transformer.

    ``args`` must provide:

    - ``context_length``, ``edge_coef``, ``gnn_type`` (``'gcn'`` or ``'gt'``)
    - ``transformer_width``, ``transformer_layers``, ``transformer_heads``
    - ``vocab_size``, ``embed_dim``, ``lr``, ``neigh_num``
    - the fields read by the chosen graph encoder.
    """

    def __init__(self, args):
        super().__init__()
        self.context_length = args.context_length
        self.args = args
        self.edge_coef = args.edge_coef
        if args.gnn_type == 'gcn':
            self.gnn = GNN(args)
        elif args.gnn_type == 'gt':
            self.gnn = graph_transformer(args)
        else:
            # Fail fast: without a graph encoder `self.gnn` and `self.dtype` would be missing.
            raise ValueError(f"Unsupported gnn_type {args.gnn_type!r}; expected 'gcn' or 'gt'")
        self.transformer = Transformer(
            width=args.transformer_width,
            layers=args.transformer_layers,
            heads=args.transformer_heads,
            attn_mask=self.build_attention_mask(),
        )
        self.vocab_size = args.vocab_size
        # One embedding per possible token id.
        self.token_embedding = nn.Embedding(args.vocab_size, args.transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, args.transformer_width))
        self.ln_final = LayerNorm(args.transformer_width)
        self.text_projection = nn.Parameter(torch.empty(args.transformer_width, args.embed_dim))
        # Working dtype for the text branch, taken from the graph encoder's parameters.
        if args.gnn_type == 'gcn':
            self.dtype = self.gnn.vars[0].dtype
        else:
            self.dtype = self.gnn.W_pos.dtype
        # torch.nn has no `optim` attribute; the optimizer lives in torch.optim.
        # NOTE(review): ln_final's parameters are not in the optimizer — confirm this is intended.
        self.optim = torch.optim.Adam([
            {'params': self.token_embedding.weight},
            {'params': self.positional_embedding},
            {'params': self.transformer.parameters()},
            {'params': self.text_projection},
            {'params': self.gnn.parameters()},
        ], lr=args.lr)
        self.initialize_parameters()

    def initialize_parameters(self):
        """Normal-initialise the embeddings, transformer projections and text projection in place."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)
        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        """Causal additive mask of shape ``(context_length, context_length)``: -inf above the diagonal, 0 elsewhere."""
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the diagonal and everything below it
        return mask

    def encode_image(self, idx_train, g):
        """Encode graph ``g`` and return the embeddings of the nodes indexed by ``idx_train``."""
        embs = self.gnn(g)
        return embs[idx_train.to(embs.device)]

    def encode_text(self, text):
        """Encode token ids ``text`` of shape (batch, n_ctx) into projected text features."""
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)
        # Take the features at the end-of-text token, which has the highest id in each
        # sequence, so there is no need to shorten the context length.
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
        x = x @ self.text_projection
        return x

    def forward(self, g, s_n, t_n, s_n_text, t_n_text, training=True):
        """
        Compute L2-normalised features for the contrastive losses.

        Parameters
        ----------
        g
            Graph passed to the graph encoder.
        s_n
            Indices of the source nodes.
        t_n, training
            Unused; kept for interface compatibility.
        s_n_text
            Token ids of the source-node texts.
        t_n_text
            Token ids of ``neigh_num`` neighbour texts per source node.

        Returns
        -------
        tuple
            ``(s_image_features, s_text_features, t_text_features, labels)``,
            where ``labels = arange(batch)`` is on the features' device. The
            loss itself is computed by the caller.
        """
        s_image_features = self.encode_image(s_n, g)
        s_text_features = self.encode_text(s_n_text)
        t_text_features = self.encode_text(t_n_text)
        # Group as (batch, neigh_num, feature_dim) and average over the neighbours.
        t_text_features = t_text_features.reshape(s_image_features.shape[0], self.args.neigh_num, -1)
        t_text_features = torch.mean(t_text_features, dim=1, keepdim=False)
        # L2-normalise so dot products are cosine similarities.
        s_image_features = s_image_features / s_image_features.norm(dim=-1, keepdim=True)
        s_text_features = s_text_features / s_text_features.norm(dim=-1, keepdim=True)
        t_text_features = t_text_features / t_text_features.norm(dim=-1, keepdim=True)
        labels = torch.arange(s_image_features.shape[0], device=s_image_features.device)
        return s_image_features, s_text_features, t_text_features, labels
@lru_cache(maxsize=1)
def _default_tokenizer():
    """Create the shared default ``SimpleTokenizer`` on first use and cache it."""
    return SimpleTokenizer()


def tokenize(texts: Union[str, List[str]], context_length: int = 128, truncate: bool = True) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize
    context_length : int
        Number of token slots per text (default 128)
    truncate: bool
        Whether to truncate the text in case its encoding is longer than the
        context length. When truncating, the last kept token is replaced by
        the end-of-text token.

    Returns
    -------
    torch.LongTensor
        Shape ``[number of input strings, context_length]``: start-of-text,
        text tokens, end-of-text, then zero padding.

    Raises
    ------
    RuntimeError
        If a text is too long and ``truncate`` is False.
    """
    if isinstance(texts, str):
        texts = [texts]
    # Built lazily (and cached) rather than at import time, so importing the
    # module does not require the BPE merges file.
    tokenizer = _default_tokenizer()
    sot_token = tokenizer.encoder["<|startoftext|>"]
    eot_token = tokenizer.encoder["<|endoftext|>"]
    result = torch.zeros(len(texts), context_length, dtype=torch.long)
    for i, text in enumerate(texts):
        tokens = [sot_token] + tokenizer.encode(text) + [eot_token]
        if len(tokens) > context_length:
            if not truncate:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
            tokens = tokens[:context_length]
            tokens[-1] = eot_token
        result[i, :len(tokens)] = torch.tensor(tokens)
    return result
###### GraphGPT conv layers ######
# class GraphLlamaConfig(LlamaConfig):
# model_type = "GraphLlama"
class GraphPretrainConfig:
    """Expose the entries of a configuration dict as attributes (``cfg.key``)."""

    def __init__(self, dictionary):
        for name in dictionary:
            setattr(self, name, dictionary[name])
def find_all_linear_names(model):
    r"""
    Collect the short names of all ``torch.nn.Linear`` submodules, for
    example to use as LoRA target modules.

    Parameters
    ----------
    model : torch.nn.Module

    Returns
    -------
    list of str
        Unique last components of the dotted module names. ``'lm_head'`` is
        excluded (needed for 16-bit).
    """
    found = {
        name.split('.')[-1]
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
    }
    found.discard('lm_head')
    return list(found)
def load_model_pretrained(model_name, pretrain_model_path):
    r"""
    Build a model from ``config.json`` and load its pickled state dict.

    Parameters
    ----------
    model_name : type
        Model class (e.g. ``CLIP``), called as ``model_name(args)``.
    pretrain_model_path : str
        Directory containing ``config.json`` and at least one ``*.pkl``
        checkpoint. The first file in sorted order is used.

    Returns
    -------
    tuple
        ``(model, args)``: the loaded model and its ``GraphPretrainConfig``.

    Raises
    ------
    FileNotFoundError
        If ``config.json`` or a ``*.pkl`` checkpoint is missing.
    """
    print("load args from pretrain model:" + pretrain_model_path)
    config_path = osp.join(pretrain_model_path, 'config.json')
    if not osp.exists(config_path):
        raise FileNotFoundError(f'config.json missing in {pretrain_model_path}')
    with open(config_path, 'r') as f:
        config_dict = json.load(f)
    args = GraphPretrainConfig(config_dict)
    model = model_name(args)
    # Sorted so the chosen checkpoint does not depend on filesystem listing order.
    pkl_files = sorted(glob.glob(osp.join(pretrain_model_path, '*.pkl')))
    if not pkl_files:
        raise FileNotFoundError(f'no *.pkl checkpoint found in {pretrain_model_path}')
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    # map_location='cpu' lets GPU-saved checkpoints load anywhere; load_state_dict
    # copies the values into the model's own tensors.
    state_dict = torch.load(pkl_files[0], map_location='cpu')
    # The model may not define `logit_scale`, so drop it to keep strict loading working.
    state_dict.pop('logit_scale', None)
    print('loading graph pre train model')
    model.load_state_dict(state_dict)
    return model, args
def transfer_param_tograph(clip_graph, gnn):
    r"""
    Copy the graph-encoder weights of ``clip_graph.gnn`` into ``gnn``.

    ``clip_graph`` is printed first.

    Parameters
    ----------
    clip_graph
        Object whose ``gnn`` attribute is an ``nn.Module``.
    gnn : nn.Module
        Module with a matching state dict; updated in place.

    Returns
    -------
    nn.Module
        ``gnn`` itself.
    """
    print(clip_graph)
    gnn.load_state_dict(clip_graph.gnn.state_dict())
    return gnn
class GraphLlamaModel(LlamaModel):
    r"""
    LLaMA model whose input embeddings can be augmented with graph features.

    When ``config.graph_tower`` names a known encoder, a graph tower is built
    in ``__init__``, or later by ``initialize_graph_modules``. In ``forward``:

    1. the tower encodes ``graph_data``;
    2. ``graph_projector`` maps the node features to ``config.hidden_size``;
    3. those rows replace the graph placeholder tokens of the embedded input.
    """

    # Tower names that restore a graph transformer from a pretrained CLIP checkpoint.
    _CLIP_GT_TOWERS = ("clip_gt", "clip_gt_arxiv", "clip_gt_arxiv_pub")

    def __init__(self, config: LlamaConfig):
        super(GraphLlamaModel, self).__init__(config)
        if hasattr(config, "graph_tower"):
            print("config.graph_tower:" + config.graph_tower)
            graph_tower = self._build_graph_tower(config)
            # Unknown tower names leave the model without a graph tower.
            if graph_tower is not None:
                self.graph_tower = graph_tower
        if hasattr(config, "use_graph_proj"):
            self.graph_projector = nn.Linear(config.graph_hidden_size, config.hidden_size)

    @classmethod
    def _build_graph_tower(cls, config):
        """Construct the graph encoder named by ``config.graph_tower``; return None if the name is unknown."""
        name = config.graph_tower
        if name == 'MPNN':
            return MPNN(in_channels=config.graph_hidden_size,
                        hidden_channels=config.graph_hidden_size * 2,
                        out_channels=config.graph_hidden_size, dropout=0.1, num_layers=2,
                        if_param=False)
        if name == "clip_gcn_arxiv":
            clip_graph, args = load_model_pretrained(CLIP, config.pretrain_graph_model_path)
            return transfer_param_tograph(clip_graph, GNN(args))
        if name in cls._CLIP_GT_TOWERS:
            clip_graph, args = load_model_pretrained(CLIP, config.pretrain_graph_model_path)
            return transfer_param_tograph(clip_graph, graph_transformer(args))
        return None

    def get_graph_tower(self):
        """Return the graph tower (unwrapping the single-element FSDP list), or None."""
        graph_tower = getattr(self, 'graph_tower', None)
        if type(graph_tower) is list:
            graph_tower = graph_tower[0]
        return graph_tower

    def initialize_graph_modules(self, graph_tower, graph_select_layer,
                                 pretrain_graph_mlp_adapter=None, fsdp=None):  # TODO: modify this function
        """
        Ensure a frozen graph tower and a graph projector exist, and record the settings in ``self.config``.

        Raises
        ------
        ValueError
            If no tower exists yet and ``graph_tower`` is not a known name.
        """
        self.config.graph_tower = graph_tower
        print("graph_tower:", graph_tower)
        if not hasattr(self, 'graph_tower'):
            graph_tower = self._build_graph_tower(self.config)
            if graph_tower is None:
                raise ValueError(f"Unknown graph_tower: {self.config.graph_tower!r}")
        else:
            print("self graph_tower:", self.graph_tower)
            # get_graph_tower() unwraps an FSDP-style list so requires_grad_ works.
            graph_tower = self.get_graph_tower()
        graph_tower.requires_grad_(False)
        if fsdp is not None and len(fsdp) > 0:
            self.graph_tower = [graph_tower]
        else:
            self.graph_tower = graph_tower
        self.config.use_graph_proj = True
        self.config.graph_select_layer = graph_select_layer
        if not hasattr(self, 'graph_projector'):
            self.graph_projector = nn.Linear(self.config.graph_hidden_size, self.config.hidden_size)
        if pretrain_graph_mlp_adapter is not None:
            graph_projector_weights = torch.load(pretrain_graph_mlp_adapter, map_location='cpu')
            self.graph_projector.load_state_dict({k.split('.')[-1]: v for k, v in graph_projector_weights.items()})

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        graph_data: Optional[Data] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Embed ``input_ids``, splice in graph features, and run the LLaMA decoder.

        ``graph_data`` is a list holding either one ``Data`` per graph, or
        one dict with ``'graph_1'`` and ``'graph_2'`` entries per sample.
        Graph features are spliced in only when all of the following hold:

        - a graph tower exists;
        - ``graph_data`` and ``input_ids`` are given;
        - the call is not a single-token step (``input_ids.shape[1] == 1``)
          outside training mode.

        Returns whatever ``LlamaModel.forward`` returns.
        """
        # HACK: replace back original embeddings for LLaVA pretraining.
        # When this attribute is set, embeddings of ordinary text tokens are detached.
        orig_embeds_params = getattr(self, 'orig_embeds_params', None)
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        graph_tower = self.get_graph_tower()
        if (graph_tower is not None and graph_data is not None and input_ids is not None
                and (input_ids.shape[1] != 1 or self.training)):
            graph_node_features = self._encode_graphs(graph_tower, graph_data)
            inputs_embeds = self._merge_graph_features(
                input_ids, inputs_embeds, graph_node_features, graph_tower.config,
                detach_text=orig_embeds_params is not None)
        return super(GraphLlamaModel, self).forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def _encode_graphs(self, graph_tower, graph_data):
        """Run the graph tower (without gradients) on every graph, then apply ``graph_projector``."""
        if type(graph_data) is not list:
            raise ValueError(f'graph_node_reps is expected to be a list but got {type(graph_data)}')
        graph_node_features = []
        with torch.no_grad():
            if isinstance(graph_data[0], Data):
                for g in graph_data:
                    graph_node_features.append(graph_tower(g))
            elif isinstance(graph_data[0], dict):
                for g_dict in graph_data:
                    graph_node_features.append(graph_tower(g_dict['graph_1']))
                    graph_node_features.append(graph_tower(g_dict['graph_2']))
            else:
                raise ValueError(
                    f'graph_data elements are expected to be Data or dict but got {type(graph_data[0])}')
        return [self.graph_projector(node_feature) for node_feature in graph_node_features]

    def _merge_graph_features(self, input_ids, inputs_embeds, graph_node_features, tower_config, detach_text):
        """Return the batch of input embeddings with graph features written over the graph placeholder tokens."""
        # Zero-weighted projector output added to text-only samples, presumably so the
        # projector takes part in the autograd graph for every sample. Its width must
        # match the projector input (config.graph_hidden_size), not a hard-coded 128.
        dummy_graph_features = torch.zeros(256, self.config.graph_hidden_size,
                                           device=inputs_embeds.device, dtype=inputs_embeds.dtype)
        dummy_graph_features = self.graph_projector(dummy_graph_features)
        new_input_embeds = []
        cur_graph_idx = 0
        for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
            if (cur_input_ids == tower_config.graph_patch_token).sum() == 0:
                # Text-only sample: unchanged apart from the zero term, but it still consumes one graph slot.
                new_input_embeds.append(cur_input_embeds + (0. * dummy_graph_features).sum())
                cur_graph_idx += 1
                continue
            if tower_config.use_graph_start_end:
                cur_new_input_embeds, cur_graph_idx = self._splice_between_start_end(
                    cur_input_ids, cur_input_embeds, graph_node_features, cur_graph_idx,
                    tower_config, detach_text)
            else:
                cur_new_input_embeds = self._splice_patch_tokens(
                    cur_input_ids, cur_input_embeds, graph_node_features[cur_graph_idx],
                    tower_config.graph_patch_token, detach_text)
                cur_graph_idx += 1
            new_input_embeds.append(cur_new_input_embeds)
        assert cur_graph_idx == len(graph_node_features)
        return torch.stack(new_input_embeds, dim=0)

    @staticmethod
    def _splice_between_start_end(cur_input_ids, cur_input_embeds, graph_node_features, cur_graph_idx,
                                  tower_config, detach_text):
        """
        Replace the tokens between each graph start/end pair with the next graph's features.

        Returns ``(new_embeddings, advanced_graph_index)``.
        """
        start_token = tower_config.graph_start_token
        end_token = tower_config.graph_end_token
        if (cur_input_ids == start_token).sum() != (cur_input_ids == end_token).sum():
            raise ValueError("The number of graph start tokens and graph end tokens should be the same.")
        # Build the sequence left to right so every graph in the sample is kept; rebuilding
        # from the unmodified embeddings for each graph would keep only the last one.
        segments = []
        cursor = 0
        for start_pos in torch.where(cur_input_ids == start_token)[0].tolist():
            cur_graph_features = graph_node_features[cur_graph_idx].to(device=cur_input_embeds.device)
            num_patches = cur_graph_features.shape[0]
            end_pos = start_pos + num_patches + 1
            if cur_input_ids[end_pos] != end_token:
                raise ValueError("The graph end token should follow the graph start token.")
            text = cur_input_embeds[cursor:start_pos]
            segments.append(text.detach() if detach_text else text)
            segments.append(cur_input_embeds[start_pos:start_pos + 1])  # graph start token
            segments.append(cur_graph_features)
            segments.append(cur_input_embeds[end_pos:end_pos + 1])  # graph end token
            cursor = end_pos + 1
            cur_graph_idx += 1
        tail = cur_input_embeds[cursor:]
        segments.append(tail.detach() if detach_text else tail)
        return torch.cat(segments, dim=0), cur_graph_idx

    @staticmethod
    def _splice_patch_tokens(cur_input_ids, cur_input_embeds, cur_graph_features, patch_token, detach_text):
        """Replace a consecutive run of graph patch tokens with ``cur_graph_features``."""
        num_patches = cur_graph_features.shape[0]
        if (cur_input_ids == patch_token).sum() != num_patches:
            raise ValueError(
                "The number of graph patch tokens should be the same as the number of graph patches.")
        masked_indices = torch.where(cur_input_ids == patch_token)[0]
        mask_index_start = masked_indices[0]
        expected = torch.arange(mask_index_start, mask_index_start + num_patches,
                                device=masked_indices.device, dtype=masked_indices.dtype)
        if (masked_indices != expected).any():
            raise ValueError("The graph patch tokens should be consecutive.")
        before = cur_input_embeds[:mask_index_start]
        after = cur_input_embeds[mask_index_start + num_patches:]
        if detach_text:
            before, after = before.detach(), after.detach()
        cur_graph_features = cur_graph_features.to(device=cur_input_embeds.device)
        return torch.cat((before, cur_graph_features, after), dim=0)
class GraphLlamaForCausalLM(LlamaForCausalLM):
    r"""
    Graph Llama model for causal language modeling.

    Wraps ``GraphLlamaModel`` (the decoder that splices graph-node features into
    the token-embedding sequence) with a bias-free linear language-modelling head.
    """

    # config_class = GraphLlamaConfig

    def __init__(self, config):
        # Call the grandparent initializer, skipping LlamaForCausalLM.__init__,
        # so the decoder is built as GraphLlamaModel below rather than by the
        # parent class.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = GraphLlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the underlying ``GraphLlamaModel`` decoder."""
        return self.model

    def get_graph_tower(self):
        """Return the graph encoder, as exposed by the decoder's ``get_graph_tower``."""
        return self.get_model().get_graph_tower()

    def get_vision_tower(self):
        """Return the decoder's ``graph_tower`` attribute (vision-style alias).

        If the tower is stored wrapped in a list, the first element is returned.
        """
        model = self.get_model()
        graph_tower = model.graph_tower
        if type(graph_tower) is list:
            graph_tower = graph_tower[0]
        return graph_tower

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        graph_data: Optional[Data] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the graph-aware decoder and the LM head, optionally computing the loss.

        Parameters
        ----------
        input_ids, attention_mask, past_key_values, inputs_embeds, use_cache,
        output_attentions, output_hidden_states
            Forwarded unchanged to ``GraphLlamaModel``.
        labels
            Target token ids. When given, logits at position ``i`` are scored
            against ``labels[..., i + 1]`` with a default ``nn.CrossEntropyLoss``.
        graph_data
            Graph input forwarded to ``GraphLlamaModel``.
        return_dict
            Return a ``CausalLMOutputWithPast`` if true, else a tuple
            ``(loss?, logits, *decoder_extras)``. Defaults to
            ``config.use_return_dict``.
        """
        # Imported locally because the top-of-file import of
        # CausalLMOutputWithPast is commented out; without this the
        # return_dict=True path raises NameError.
        from transformers.modeling_outputs import CausalLMOutputWithPast

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            graph_data=graph_data,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model/pipeline parallelism: labels may live on another device.
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def initialize_graph_tokenizer(self, use_graph_start_end, tokenizer, device,
                                   tune_graph_mlp_adapter=False, pretrain_graph_mlp_adapter=None):
        """Add graph special tokens to *tokenizer* and resize the embeddings to match.

        Always adds the graph-patch token and records its id on the graph
        tower's config as ``graph_patch_token``.

        Parameters
        ----------
        use_graph_start_end
            Also add the graph start/end tokens. Their ids are recorded on the
            graph tower's config, and their new embedding rows are set to the
            mean of the pre-existing rows.
        tokenizer
            Tokenizer to extend; it is mutated in place.
        device
            Device on which the snapshot of the input embeddings
            (``orig_embeds_params``) is stored when *tune_graph_mlp_adapter* is set.
        tune_graph_mlp_adapter
            Make the input embeddings trainable and freeze the output
            embeddings. A copy of the input embeddings is also kept as the
            decoder's ``orig_embeds_params``. Only applies with *use_graph_start_end*.
        pretrain_graph_mlp_adapter
            Path to a checkpoint with a ``'model.embed_tokens.weight'`` entry.
            The start/end embedding rows are copied from it. Only applies with
            *use_graph_start_end*.

        Raises
        ------
        ValueError
            If a pretrained adapter is given but exactly two new start/end
            tokens were not added, or the checkpoint's embedding shape is
            incompatible.
        """
        graph_config = self.get_graph_tower().config
        graph_config.use_graph_start_end = use_graph_start_end

        tokenizer.add_tokens([DEFAULT_GRAPH_PATCH_TOKEN], special_tokens=True)
        self.resize_token_embeddings(len(tokenizer))

        if use_graph_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_G_START_TOKEN, DEFAULT_G_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))
            graph_config.graph_start_token, graph_config.graph_end_token = tokenizer.convert_tokens_to_ids(
                [DEFAULT_G_START_TOKEN, DEFAULT_G_END_TOKEN])

            # Bind the (post-resize) tables unconditionally. Previously they were
            # only bound when num_new_tokens > 0, leaving the pretrained-adapter
            # branch below with an unbound name.
            input_embeddings = self.get_input_embeddings().weight.data
            output_embeddings = self.get_output_embeddings().weight.data

            if num_new_tokens > 0:
                # Initialise the freshly added rows to the mean of the existing rows.
                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if tune_graph_mlp_adapter:
                # Snapshot of the input embeddings; the decoder's forward uses
                # orig_embeds_params to decide whether to detach text embeddings.
                self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)]
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if pretrain_graph_mlp_adapter:
                # Explicit check instead of `assert`, which is stripped under -O.
                if num_new_tokens != 2:
                    raise ValueError(
                        f"Expected 2 new graph start/end tokens when loading a pretrained adapter, "
                        f"got {num_new_tokens}.")
                # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
                mm_projector_weights = torch.load(pretrain_graph_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")

        graph_config.graph_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_GRAPH_PATCH_TOKEN])[0]
# AutoConfig.register("GraphLlama", GraphLlamaConfig)
# AutoModelForCausalLM.register(GraphLlamaConfig, GraphLlamaForCausalLM)
[docs]class GraphGPT_pl(LightningModule):
r"""
Graph GPT model
"""
def __init__(self,
             training_args, model_args, data_args, tokenizer,
             **kwargs,
             ):
    """Build the (optionally quantised and/or LoRA-adapted) Graph Llama model.

    Parameters
    ----------
    training_args, model_args, data_args
        Argument namespaces read for precision (``fp16``/``bf16``/``bits``),
        model location, graph-tower settings, LoRA settings and
        graph-token options. Several fields of *training_args* and
        *data_args* are also written back (e.g. ``use_graph_start_end``,
        ``is_graph``).
    tokenizer
        Tokenizer that the graph special tokens are added to (mutated in place).
    **kwargs
        Accepted and ignored.
    """
    super().__init__()
    self.training_args = training_args
    self.model_args = model_args
    self.data_args = data_args
    compute_dtype = (
        torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))

    bnb_model_from_pretrained_args = {}
    # Load the LLM in 4- or 8-bit via bitsandbytes when requested.
    if training_args.bits in [4, 8]:
        from transformers import BitsAndBytesConfig
        from peft import prepare_model_for_int8_training
        bnb_model_from_pretrained_args.update(dict(
            device_map={"": training_args.device},
            load_in_4bit=training_args.bits == 4,
            load_in_8bit=training_args.bits == 8,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=training_args.bits == 4,
                load_in_8bit=training_args.bits == 8,
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=training_args.double_quant,
                bnb_4bit_quant_type=training_args.quant_type  # {'fp4', 'nf4'}
            )
        ))

    if model_args.graph_tower is not None:
        self.model = GraphLlamaForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            **bnb_model_from_pretrained_args
        )  # TODO: add real Graph Llama model
        # Only meaningful when a graph tower is configured. On the plain-Llama
        # branch graph_tower is None, so `str + None` would raise TypeError.
        self.model.config.pretrain_graph_model_path = (
            self.model.config.pretrain_graph_model_path + model_args.graph_tower)
    else:
        self.model = transformers.LlamaForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            **bnb_model_from_pretrained_args
        )
    self.model.config.use_cache = False

    if model_args.freeze_backbone:
        self.model.model.requires_grad_(False)

    if training_args.bits in [4, 8]:
        self.model.config.torch_dtype = (
            torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
        self.model = prepare_model_for_int8_training(self.model,
                                                     use_gradient_checkpointing=training_args.gradient_checkpointing)

    if training_args.gradient_checkpointing and model_args.graph_tower is None:
        if hasattr(self.model, "enable_input_require_grads"):
            self.model.enable_input_require_grads()
        else:
            # Fallback: force embedding outputs to require grad so gradient
            # checkpointing has something to backprop through.
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            self.model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    if training_args.lora_enable:
        from peft import LoraConfig, get_peft_model
        lora_config = LoraConfig(
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            target_modules=find_all_linear_names(self.model),
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
            task_type="CAUSAL_LM",
        )
        if training_args.bits == 16:
            if training_args.bf16:
                self.model.to(torch.bfloat16)
            if training_args.fp16:
                self.model.to(torch.float16)
        logging.warning("Adding LoRA adapters...")
        # NOTE(review): the PeftModel wrapper returned here is discarded and
        # self.model stays unwrapped. Per the PEFT docs the LoRA layers are
        # injected into the passed model's modules, so training still sees
        # them, but PeftModel helpers (adapter-only save) are unavailable.
        # Confirm this is intended.
        get_peft_model(self.model, lora_config)

    if model_args.graph_tower is not None:
        self.model.get_model().initialize_graph_modules(
            graph_tower=model_args.graph_tower,
            graph_select_layer=model_args.graph_select_layer,
            pretrain_graph_mlp_adapter=model_args.pretrain_graph_mlp_adapter,
            fsdp=None
        )
        self.model.get_graph_tower().to(dtype=compute_dtype)
        data_args.is_graph = True

        self.model.config.tune_graph_mlp_adapter = training_args.tune_graph_mlp_adapter = model_args.tune_graph_mlp_adapter
        if model_args.tune_graph_mlp_adapter:
            # Train only the graph projector.
            self.model.requires_grad_(False)
            for p in self.model.get_model().graph_projector.parameters():
                p.requires_grad = True

        self.model.config.freeze_graph_mlp_adapter = training_args.freeze_graph_mlp_adapter
        if training_args.freeze_graph_mlp_adapter:
            for p in self.model.get_model().graph_projector.parameters():
                p.requires_grad = False

        if training_args.bits in [4, 8]:
            self.model.get_model().graph_projector.to(dtype=compute_dtype, device=training_args.device)

        self.model.config.use_graph_start_end = data_args.use_graph_start_end = model_args.use_graph_start_end
        training_args.use_graph_start_end = model_args.use_graph_start_end
        self.model.config.sep_graph_conv_front = data_args.sep_graph_conv_front
        # NOTE(review): device is hard-coded to 'cuda' here, unlike other
        # placements that use training_args.device.
        self.model.initialize_graph_tokenizer(use_graph_start_end=model_args.use_graph_start_end,
                                              tokenizer=tokenizer, device='cuda',
                                              tune_graph_mlp_adapter=model_args.tune_graph_mlp_adapter,
                                              pretrain_graph_mlp_adapter=model_args.pretrain_graph_mlp_adapter)

    if training_args.bits in [4, 8]:
        from peft.tuners.lora import LoraLayer
        # Normalise dtypes after quantisation: LoRA layers -> bf16, norms -> fp32,
        # fp32 head/embedding weights -> bf16 (when training in bf16).
        for name, module in self.model.named_modules():
            if isinstance(module, LoraLayer):
                if training_args.bf16:
                    module.to(torch.bfloat16)
            if 'norm' in name:
                module.to(torch.float32)
            if 'lm_head' in name or 'embed_tokens' in name:
                if hasattr(module, 'weight'):
                    if training_args.bf16 and module.weight.dtype == torch.float32:
                        module.to(torch.bfloat16)

    print('************************** parameters: #',
          sum(p.numel() for p in self.model.parameters() if p.requires_grad))
    tuned_params = [name for name, param in self.model.named_parameters() if param.requires_grad]
    print(tuned_params)
def training_step(self, batch, batch_idx):
    """Compute and log the training loss for one batch.

    The batch dict is unpacked as keyword arguments into ``self.model``, and
    the ``'loss'`` entry of the result is logged (per step and per epoch,
    synced across processes) and returned.
    """
    # Batch size is taken from the number of input_id rows.
    batch_size = len(batch["input_ids"])
    train_loss = self.model(**batch)['loss']
    self.log_dict(
        {'train_loss': train_loss.item()},
        on_step=True,
        on_epoch=True,
        prog_bar=True,
        sync_dist=True,
        batch_size=batch_size,
    )
    return train_loss