Source code for ggfm.models.gpt_gnn


import math
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from ggfm.conv.hgt_conv import HGTConv
from gensim.parsing.preprocessing import *


class GPT_GNN(nn.Module):
    r"""Pre-training wrapper from the `"GPT-GNN: Generative Pre-Training of
    Graph Neural Networks" <https://arxiv.org/abs/2006.15437>`_ paper.

    Parameters
    ----------
    gnn: class:`ggfm.models`
        The GNN encoder that :meth:`forward` delegates to. When it is
        ``None`` the constructor returns right after ``nn.Module`` set-up and
        none of the attributes below are created.
    rem_edge_list: dict
        The remaining edge list after sampling, indexed as
        ``rem_edge_list[source_type][relation_type]``. One :class:`Matcher`
        and one empty negative queue are created per such pair.
    attr_decoder: `ggfm.models`
        Attribute decoder used by :meth:`text_loss` and :meth:`feat_loss`.
    neg_samp_num: int
        Maximum number of negative samples for each target node.
    device: int
        Device handed to ``Tensor.to`` when allocating the negative queues.
    neg_queue_size: int, optional
        Max size of the negative adaptive embedding queue.
        (default: :obj:`0`)
    """

    def __init__(self, gnn, rem_edge_list, attr_decoder, neg_samp_num, device, neg_queue_size=0):
        super().__init__()
        if gnn is None:
            # Nothing to wrap: leave the module as a bare nn.Module.
            return

        self.gnn = gnn
        # The matchers live in plain dicts for lookup; this ModuleList is what
        # registers them as sub-modules.
        self.params = nn.ModuleList()
        self.neg_queue_size = neg_queue_size
        self.link_dec_dict = {}
        self.neg_queue = {}

        for source_type in rem_edge_list:
            decoders = self.link_dec_dict[source_type] = {}
            queues = self.neg_queue[source_type] = {}
            for relation_type in rem_edge_list[source_type]:
                print(source_type, relation_type)
                matcher = Matcher(gnn.n_hid, gnn.n_hid)
                queues[relation_type] = torch.FloatTensor([]).to(device)
                decoders[relation_type] = matcher
                self.params.append(matcher)

        self.attr_decoder = attr_decoder
        self.init_emb = nn.Parameter(torch.randn(gnn.in_dim))
        # Per-element losses; text_loss averages them itself.
        self.ce = nn.CrossEntropyLoss(reduction='none')
        self.neg_samp_num = neg_samp_num

    def neg_sample(self, souce_node_list, pos_node_list):
        """Pick up to ``self.neg_samp_num`` nodes that are not positives.

        ``souce_node_list`` is shuffled **in place** with
        ``np.random.shuffle`` and then scanned in that order; scanning stops
        as soon as the number of collected negatives equals
        ``self.neg_samp_num``.

        Parameters
        ----------
        souce_node_list: sequence
            Candidate node ids (mutated by the shuffle).
        pos_node_list: iterable
            Node ids that must not be returned.

        Returns
        -------
        list
            The selected negative node ids.
        """
        np.random.shuffle(souce_node_list)
        positives = set(pos_node_list)
        negatives = []
        for candidate in souce_node_list:
            if candidate not in positives:
                negatives.append(candidate)
            # Checked on every iteration, not only after an append.
            if len(negatives) == self.neg_samp_num:
                break
        return negatives

    def forward(self, node_feature, node_type, edge_time, edge_index, edge_type):
        """Run the wrapped GNN on the given graph tensors and return its output."""
        return self.gnn(node_feature, node_type, edge_time, edge_index, edge_type)

    def text_loss(self, reps, texts, w2v_model, device):
        """Mean masked cross-entropy of the attribute decoder on ``texts``.

        Parameters
        ----------
        reps: torch.Tensor
            Representations repeated along a new leading axis (one copy per
            time step) before being passed to ``self.attr_decoder``.
        texts: iterable of str
            Raw texts; tokenised with ``preprocess_string`` (presumably the
            gensim helper pulled in by the module's star import).
        w2v_model:
            Object exposing ``wv.key_to_index``; it must contain ``'eos'``.
        device:
            Device the index and mask tensors are moved to.

        Returns
        -------
        torch.Tensor
            Scalar mean of the per-token cross-entropy over unmasked positions.
        """

        def parse_text(texts, w2v_model, device):
            vocab = w2v_model.wv.key_to_index
            pad = vocab['eos']
            # text to tokens (words missing from the vocabulary are dropped)
            idxs = [
                [
                    vocab[word]
                    for word in ['bos'] + preprocess_string(text) + ['eos']
                    if word in vocab
                ]
                for text in texts
            ]
            # int() keeps the list repetitions below as plain Python ops.
            mxl = int(np.max([len(s) for s in idxs])) + 1

            inp_idxs, out_idxs, masks = [], [], []
            for idx in idxs:
                n_tok = len(idx)
                # Inputs are the tokens, targets are the tokens shifted by one.
                inp_idxs.append(idx + [pad] * (mxl - n_tok - 1))
                out_idxs.append(idx[1:] + [pad] * (mxl - n_tok))
                masks.append([1] * n_tok + [0] * (mxl - n_tok - 1))

            # Transpose so the first axis is the token position.
            inp_t = torch.LongTensor(inp_idxs).transpose(0, 1).to(device)
            out_t = torch.LongTensor(out_idxs).transpose(0, 1).to(device)
            mask_t = torch.BoolTensor(masks).transpose(0, 1).to(device)
            return inp_t, out_t, mask_t

        inp_idxs, out_idxs, masks = parse_text(texts, w2v_model, device)
        tiled_reps = reps.repeat(inp_idxs.shape[0], 1, 1)
        pred_prob = self.attr_decoder(inp_idxs, tiled_reps)
        token_losses = self.ce(pred_prob[masks], out_idxs[masks])
        return token_losses.mean()

    def feat_loss(self, reps, out):
        """Return the negated mean of ``self.attr_decoder(reps, out)``."""
        decoded = self.attr_decoder(reps, out)
        return -decoded.mean()
class Classifier(nn.Module):
    r"""Classifier for graph node classification task.

    A single linear layer followed by a log-softmax over the last dimension.

    Parameters
    ----------
    n_hid: int
        Input size.
    n_out: int
        Output size.
    """

    def __init__(self, n_hid, n_out):
        super().__init__()
        self.n_hid, self.n_out = n_hid, n_out
        self.linear = nn.Linear(n_hid, n_out)

    def forward(self, x):
        """Return log-probabilities for ``x``.

        NOTE: ``squeeze()`` removes *every* size-1 dimension of the linear
        output before the log-softmax, so e.g. a batch of one comes back
        without its batch dimension.
        """
        logits = self.linear(x).squeeze()
        return torch.log_softmax(logits, dim=-1)

    def __repr__(self):
        return f'{self.__class__.__name__}(n_hid={self.n_hid}, n_out={self.n_out})'
class Matcher(nn.Module):
    r"""Matching between a pair of nodes to conduct link prediction.

    The first input is passed through a linear layer and dropout, then scored
    against the second input either by temperature-scaled cosine similarity
    or by a scaled dot product.

    Parameters
    ----------
    n_hid: int
        Input size.
    n_out: int
        Output size.
    temperature: float, optional
        Temperature dividing the cosine similarity. (default: :obj:`0.1`)
    """

    def __init__(self, n_hid, n_out, temperature=0.1):
        super().__init__()
        self.n_hid = n_hid
        self.temperature = temperature
        self.cache = None
        # Scale factor for the un-normalised (dot-product) score.
        self.sqrt_hd = math.sqrt(n_out)
        self.linear = nn.Linear(n_hid, n_out)
        self.drop = nn.Dropout(0.2)
        self.cosine = nn.CosineSimilarity(dim=1)

    def forward(self, x, ty, use_norm=True):
        """Score the projected ``x`` against ``ty``.

        Parameters
        ----------
        x: torch.Tensor
            Input projected by the linear layer (followed by dropout).
        ty: torch.Tensor
            Target embeddings compared with the projection.
        use_norm: bool, optional
            If true, return cosine similarity (along dim 1) divided by the
            temperature; otherwise return the dot product over the last
            dimension divided by ``sqrt(n_out)``. (default: :obj:`True`)
        """
        projected = self.drop(self.linear(x))
        if not use_norm:
            dot = (projected * ty).sum(dim=-1)
            return dot / self.sqrt_hd
        return self.cosine(projected, ty) / self.temperature

    def __repr__(self):
        return f'{self.__class__.__name__}(n_hid={self.n_hid})'
class HGT(nn.Module):
    r"""The Heterogeneous Graph Transformer (HGT) operator from the
    `"Heterogeneous Graph Transformer" <https://arxiv.org/abs/2003.01332>`_
    paper.

    Node features are first mapped to ``n_hid`` by one linear adapter per node
    type, then passed through a stack of :class:`HGTConv` layers.

    Parameters
    ----------
    in_dim: int
        Size of the raw node features.
    n_hid: int
        Hidden size used by the adapters and every convolution layer.
    num_types: int
        Number of node types (one adapter is created for each).
    num_relations: int
        Number of relation types, forwarded to :class:`HGTConv`.
    n_heads: int
        Number of attention heads, forwarded to :class:`HGTConv`.
    n_layers: int
        ``n_layers - 1`` layers are built with ``use_norm=prev_norm`` and one
        final layer with ``use_norm=last_norm``.
    dropout: float, optional
        Dropout rate for the adapted features and the convolution layers.
        (default: :obj:`0.2`)
    prev_norm: bool, optional
        ``use_norm`` flag of all layers but the last. (default: :obj:`False`)
    last_norm: bool, optional
        ``use_norm`` flag of the last layer. (default: :obj:`False`)
    """

    def __init__(self, in_dim, n_hid, num_types, num_relations, n_heads, n_layers,
                 dropout=0.2, prev_norm=False, last_norm=False):
        super().__init__()
        self.gcs = nn.ModuleList()
        self.num_types = num_types
        self.in_dim = in_dim
        self.n_hid = n_hid
        self.adapt_ws = nn.ModuleList()
        self.drop = nn.Dropout(dropout)

        # One input adapter per node type.
        for _ in range(num_types):
            self.adapt_ws.append(nn.Linear(in_dim, n_hid))

        # All layers but the last share `prev_norm`; the last uses `last_norm`.
        for _ in range(n_layers - 1):
            self.gcs.append(
                HGTConv(n_hid, n_hid, num_types, num_relations, n_heads, dropout,
                        use_norm=prev_norm)
            )
        self.gcs.append(
            HGTConv(n_hid, n_hid, num_types, num_relations, n_heads, dropout,
                    use_norm=last_norm)
        )

    def forward(self, node_feature, node_type, edge_time, edge_index, edge_type):
        """Encode the nodes of a heterogeneous graph.

        Rows of ``node_feature`` whose ``node_type`` matches a type id are
        projected by that type's adapter and passed through ``tanh``; rows
        matching no type stay zero. The result goes through dropout and then
        every convolution layer in order.

        Returns
        -------
        torch.Tensor
            Output of the last :class:`HGTConv` layer.
        """
        adapted = torch.zeros(node_feature.size(0), self.n_hid).to(node_feature.device)
        for t_id in range(self.num_types):
            type_mask = (node_type == int(t_id))
            if type_mask.sum() != 0:
                projected = self.adapt_ws[t_id](node_feature[type_mask])
                adapted[type_mask] = torch.tanh(projected)

        meta_xs = self.drop(adapted)
        del adapted
        for conv in self.gcs:
            # Note the argument order: the conv layer takes edge_time last.
            meta_xs = conv(meta_xs, node_type, edge_index, edge_type, edge_time)
        return meta_xs
class RNNModel(nn.Module):
    r"""Container module with an encoder, a recurrent module, and a decoder.

    Parameters
    ----------
    n_word: int
        Number of tokens.
    ninp: int
        Size of the optional conditioning vector that :meth:`forward`
        concatenates to each token embedding.
    nhid: int
        Hidden size (also the token embedding size).
    nlayers: int
        Layer number of LSTM.
    dropout: float, optional
        Dropout rate. (default: :obj:`0.2`)
    """

    def __init__(self, n_word, ninp, nhid, nlayers, dropout=0.2):
        super(RNNModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.rnn = nn.LSTM(nhid, nhid, nlayers)
        self.encoder = nn.Embedding(n_word, nhid)
        self.decoder = nn.Linear(nhid, n_word)
        # Maps [embedding ; conditioning vector] back to the LSTM input size.
        self.adp = nn.Linear(ninp + nhid, nhid)

    def forward(self, inp, hidden=None):
        """Decode token scores for ``inp``.

        Parameters
        ----------
        inp: torch.LongTensor
            Token indices fed to the embedding layer.
        hidden: torch.Tensor, optional
            Conditioning vectors concatenated to the embeddings along the last
            dimension. When given, the concatenation is passed through the
            adapter and GELU; when ``None`` the embeddings go to the LSTM
            unchanged. (default: :obj:`None`)

        Returns
        -------
        torch.Tensor
            Scores over the ``n_word`` tokens for every input position.
        """
        emb = self.encoder(inp)
        if hidden is not None:
            # The adapter expects ninp + nhid features, which only exist after
            # the concatenation, so it must stay inside this branch.
            emb = torch.cat((emb, hidden), dim=-1)
            emb = F.gelu(self.adp(emb))
        output, _ = self.rnn(emb)
        return self.decoder(self.drop(output))

    def from_w2v(self, w2v):
        """Load ``w2v`` as frozen embedding weights shared with the decoder.

        Parameters
        ----------
        w2v: torch.Tensor
            Weights assigned to the embedding layer (presumably of shape
            ``(n_word, nhid)`` -- confirm against the caller).
        """
        self.encoder.weight.data = w2v
        # Tie the decoder to the embedding matrix: both now hold the very same
        # Parameter, so freezing it once freezes both.
        self.decoder.weight = self.encoder.weight
        self.encoder.weight.requires_grad = False