Source code for ggfm.conv.hgt_conv

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import softmax
from torch_geometric.nn.inits import glorot
from torch_geometric.nn.conv import MessagePassing


class HGTConv(MessagePassing):
    r"""The Heterogeneous Graph Transformer (HGT) operator from the
    `"Heterogeneous Graph Transformer" <https://arxiv.org/abs/2003.01332>`_
    paper.

    Every node type owns its own key/query/value/output projections, and
    every relation type owns a per-head attention matrix, message matrix and
    prior. Source features receive a relative temporal encoding before they
    are projected.

    Parameters
    ----------
    in_dim: int
        Size of each input sample of every node type. It is passed directly
        to :class:`torch.nn.Linear`, so it must be a concrete positive size.
        The skip connection in :meth:`update` mixes the input features with
        an ``out_dim``-wide tensor, so ``in_dim`` has to be
        broadcast-compatible with ``out_dim`` (in practice: equal).
    out_dim: int
        Size of each output sample. It is split into ``n_heads`` chunks of
        ``out_dim // n_heads`` features, so it should be divisible by
        ``n_heads``.
    num_types: int
        Number of node types.
    num_relations: int
        Number of relations.
    n_heads: int, optional
        Number of multi-head-attentions.
        (default: :obj:`1`)
    dropout: float, optional
        Dropout rate applied to the layer output.
        (default: :obj:`0.2`)
    use_norm: bool, optional
        If set, a per-node-type :class:`torch.nn.LayerNorm` is applied after
        the skip connection.
        (default: :obj:`True`)
    **kwargs: optional
        Additional arguments of
        :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(self, in_dim, out_dim, num_types, num_relations, n_heads=1,
                 dropout=0.2, use_norm=True, **kwargs):
        super(HGTConv, self).__init__(node_dim=0, aggr='add', **kwargs)

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_types = num_types
        self.num_relations = num_relations
        # Number of possible <source type, relation, target type> triples.
        # Kept as a public attribute; it is not used inside this class.
        self.total_rel = num_types * num_relations * num_types
        self.n_heads = n_heads
        self.d_k = out_dim // n_heads
        self.sqrt_dk = math.sqrt(self.d_k)
        self.use_norm = use_norm
        # Attention weights of the most recent message() call.
        self.att = None

        # One projection of each kind per node type.
        self.k_linears = nn.ModuleList()
        self.q_linears = nn.ModuleList()
        self.v_linears = nn.ModuleList()
        self.a_linears = nn.ModuleList()
        self.norms = nn.ModuleList()
        for _ in range(num_types):
            self.k_linears.append(nn.Linear(in_dim, out_dim))
            self.q_linears.append(nn.Linear(in_dim, out_dim))
            self.v_linears.append(nn.Linear(in_dim, out_dim))
            self.a_linears.append(nn.Linear(out_dim, out_dim))
            if use_norm:
                self.norms.append(nn.LayerNorm(out_dim))

        # TODO: make relation_pri smaller, as not all <st, rt, tt> pairs
        # exist in the meta relation list.
        self.relation_pri = nn.Parameter(
            torch.ones(num_relations, self.n_heads))
        self.relation_att = nn.Parameter(
            torch.Tensor(num_relations, n_heads, self.d_k, self.d_k))
        self.relation_msg = nn.Parameter(
            torch.Tensor(num_relations, n_heads, self.d_k, self.d_k))
        # Learnable per-node-type gate of the skip connection (passed through
        # a sigmoid in update()).
        self.skip = nn.Parameter(torch.ones(num_types))
        self.drop = nn.Dropout(dropout)
        self.emb = RelTemporalEncoding(in_dim)

        # relation_att / relation_msg are allocated uninitialised above.
        glorot(self.relation_att)
        glorot(self.relation_msg)

    def forward(self, node_inp, node_type, edge_index, edge_type, edge_time):
        """Run one round of heterogeneous message passing.

        Parameters
        ----------
        node_inp: torch.Tensor
            Node features, indexed along dimension 0 (``node_dim=0``).
        node_type: torch.Tensor
            Integer type id of every node; compared against
            ``range(num_types)``.
        edge_index: torch.Tensor
            Edge index forwarded to ``propagate``.
        edge_type: torch.Tensor
            Integer relation id of every edge; compared against
            ``range(num_relations)``.
        edge_time: torch.Tensor
            Integer time index of every edge, used to look up the temporal
            embedding table.

        Returns
        -------
        torch.Tensor
            Output of ``propagate``, i.e. the result of :meth:`update`.
        """
        return self.propagate(edge_index, node_inp=node_inp,
                              node_type=node_type, edge_type=edge_type,
                              edge_time=edge_time)

    def message(self, edge_index_i, node_inp_i, node_inp_j, node_type_i,
                node_type_j, edge_type, edge_time):
        """Compute the attention-weighted message of every edge.

        ``_j`` arguments refer to the source node of an edge and ``_i``
        arguments to its target node.

        Returns
        -------
        torch.Tensor
            One ``out_dim``-wide message per edge. Edges that match no
            ``<source type, relation, target type>`` combination keep an
            all-zero message and an attention logit of zero.
        """
        # j: source, i: target; <j, i>
        data_size = edge_index_i.size(0)
        device = node_inp_i.device

        # Create attention and message tensors beforehand, directly on the
        # target device.
        res_att = torch.zeros(data_size, self.n_heads, device=device)
        res_msg = torch.zeros(data_size, self.n_heads, self.d_k,
                              device=device)

        for source_type in range(self.num_types):
            sb = node_type_j == source_type
            k_linear = self.k_linears[source_type]
            v_linear = self.v_linears[source_type]
            for target_type in range(self.num_types):
                tb = (node_type_i == target_type) & sb
                q_linear = self.q_linears[target_type]
                for relation_type in range(self.num_relations):
                    # idx selects all edges with meta relation
                    # <source_type, relation_type, target_type>.
                    idx = (edge_type == relation_type) & tb
                    if not idx.any():
                        continue

                    # Gather the node representations of the selected edges
                    # and add the temporal encoding to the source side (j).
                    target_node_vec = node_inp_i[idx]
                    source_node_vec = self.emb(node_inp_j[idx],
                                               edge_time[idx])

                    # Step 1: heterogeneous mutual attention.
                    q_mat = q_linear(target_node_vec).view(
                        -1, self.n_heads, self.d_k)
                    k_mat = k_linear(source_node_vec).view(
                        -1, self.n_heads, self.d_k)
                    # bmm batches over heads, hence the transposes.
                    k_mat = torch.bmm(
                        k_mat.transpose(1, 0),
                        self.relation_att[relation_type]).transpose(1, 0)
                    res_att[idx] = ((q_mat * k_mat).sum(dim=-1)
                                    * self.relation_pri[relation_type]
                                    / self.sqrt_dk)

                    # Step 2: heterogeneous message passing.
                    v_mat = v_linear(source_node_vec).view(
                        -1, self.n_heads, self.d_k)
                    res_msg[idx] = torch.bmm(
                        v_mat.transpose(1, 0),
                        self.relation_msg[relation_type]).transpose(1, 0)

        # Softmax grouped by target node id (edge_index_i). The attention
        # values are stored in self.att for later visualization.
        self.att = softmax(res_att, edge_index_i)
        res = res_msg * self.att.view(-1, self.n_heads, 1)
        del res_att, res_msg
        return res.view(-1, self.out_dim)

    def update(self, aggr_out, node_inp, node_type):
        """Transform the aggregated messages into the layer output.

        Applies GELU, the output projection of the node's type, a sigmoid
        gated skip connection with the input features, an optional
        LayerNorm, and finally dropout. Nodes whose type is not in
        ``range(num_types)`` keep an all-zero row.
        """
        aggr_out = F.gelu(aggr_out)
        res = torch.zeros(aggr_out.size(0), self.out_dim,
                          device=node_inp.device)
        for target_type in range(self.num_types):
            idx = node_type == target_type
            if not idx.any():
                continue
            trans_out = self.a_linears[target_type](aggr_out[idx])
            # Skip connection with learnable weight self.skip[target_type].
            alpha = torch.sigmoid(self.skip[target_type])
            mixed = trans_out * alpha + node_inp[idx] * (1 - alpha)
            if self.use_norm:
                res[idx] = self.norms[target_type](mixed)
            else:
                res[idx] = mixed
        return self.drop(res)

    def __repr__(self):
        return '{}(in_dim={}, out_dim={}, num_types={}, num_relations={})'.format(
            self.__class__.__name__, self.in_dim, self.out_dim,
            self.num_types, self.num_relations)
class RelTemporalEncoding(nn.Module):
    """Sinusoidal relative temporal encoding added to node features.

    A fixed sinusoid table with ``max_len`` rows and ``2 * n_hid`` columns is
    looked up by an integer time index, passed through dropout and a
    learnable linear projection down to ``n_hid`` features, and added to the
    input.

    Parameters
    ----------
    n_hid: int
        Width of the features the encoding is added to.
    max_len: int, optional
        Number of distinct time indices; valid indices are
        ``0 .. max_len - 1``. (default: 240)
    dropout: float, optional
        Dropout rate applied to the looked-up encoding. (default: 0.2)
    """

    def __init__(self, n_hid, max_len=240, dropout=0.2):
        super(RelTemporalEncoding, self).__init__()
        self.drop = nn.Dropout(dropout)

        emb_dim = n_hid * 2
        position = torch.arange(0., max_len).unsqueeze(1)
        # Frequencies 1 / 10000^(2i / emb_dim). The exponent has to be
        # normalised *before* the power is taken; raising 10000 to the raw
        # index overflows to inf in float32 and zeroes most frequencies.
        div_term = 1 / (10000 ** (torch.arange(0., emb_dim, 2.) / emb_dim))

        self.emb = nn.Embedding(max_len, emb_dim)
        scale = math.sqrt(n_hid)
        self.emb.weight.data[:, 0::2] = torch.sin(position * div_term) / scale
        self.emb.weight.data[:, 1::2] = torch.cos(position * div_term) / scale
        # Freeze the table. The flag lives on the weight parameter; setting
        # it on the module would only create an unrelated attribute.
        self.emb.weight.requires_grad = False

        self.lin = nn.Linear(emb_dim, n_hid)

    def forward(self, x, t):
        """Return ``x`` plus the projected temporal encoding of ``t``.

        Parameters
        ----------
        x: torch.Tensor
            Features whose last dimension is ``n_hid``.
        t: torch.Tensor
            Integer time indices used to look up the embedding table; the
            result of the lookup must broadcast against ``x``.
        """
        return x + self.lin(self.drop(self.emb(t)))