import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import softmax
from torch_geometric.nn.inits import glorot
from torch_geometric.nn.conv import MessagePassing
class HGTConv(MessagePassing):
    r"""The Heterogeneous Graph Transformer (HGT) operator from the
    `"Heterogeneous Graph Transformer"
    <https://arxiv.org/abs/2003.01332>`_ paper.

    Parameters
    ----------
    in_dim: int
        Size of each input sample of every node type. The layer output is
        mixed with its input through a learnable skip connection, so this
        must equal :obj:`out_dim`.
    out_dim: int
        Size of each output sample. Must be divisible by :obj:`n_heads`.
    num_types: int
        Number of node types.
    num_relations: int
        Number of relation (edge) types. The relation-specific attention
        and message matrices are indexed by relation id only, so they are
        shared by every (source type, target type) pair.
    n_heads: int, optional
        Number of attention heads.
        (default: :obj:`1`)
    dropout: float, optional
        Dropout rate applied to the layer output.
        (default: :obj:`0.2`)
    use_norm: bool, optional
        If :obj:`True`, apply a per-node-type :class:`~torch.nn.LayerNorm`
        after the skip connection.
        (default: :obj:`True`)
    **kwargs: optional
        Additional arguments of
        :class:`torch_geometric.nn.conv.MessagePassing`.

    Raises
    ------
    ValueError
        If :obj:`out_dim` is not divisible by :obj:`n_heads`.
    """

    def __init__(self, in_dim, out_dim, num_types, num_relations, n_heads=1,
                 dropout=0.2, use_norm=True, **kwargs):
        super().__init__(node_dim=0, aggr='add', **kwargs)
        # Heads split out_dim evenly; otherwise the .view() calls in message()
        # fail (or silently mis-shape) at forward time.
        if out_dim % n_heads != 0:
            raise ValueError(
                f"out_dim ({out_dim}) must be divisible by n_heads ({n_heads})")
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_types = num_types
        self.num_relations = num_relations
        # Size of the full <source type, relation, target type> space. The
        # relation parameters below are indexed by relation id only.
        self.total_rel = num_types * num_relations * num_types
        self.n_heads = n_heads
        self.d_k = out_dim // n_heads
        self.sqrt_dk = math.sqrt(self.d_k)
        self.use_norm = use_norm
        # Normalized attention scores of the last forward pass (for visualization).
        self.att = None

        # Per-node-type projections: K and V for source nodes, Q for target
        # nodes, A for the aggregated output. Creation order is kept stable so
        # seeded initialization and state_dict layout do not change.
        self.k_linears = nn.ModuleList()
        self.q_linears = nn.ModuleList()
        self.v_linears = nn.ModuleList()
        self.a_linears = nn.ModuleList()
        self.norms = nn.ModuleList()
        for _ in range(num_types):
            self.k_linears.append(nn.Linear(in_dim, out_dim))
            self.q_linears.append(nn.Linear(in_dim, out_dim))
            self.v_linears.append(nn.Linear(in_dim, out_dim))
            self.a_linears.append(nn.Linear(out_dim, out_dim))
            if use_norm:
                self.norms.append(nn.LayerNorm(out_dim))

        # TODO: make relation_pri smaller, as not all <st, rt, tt> pairs exist
        # in the meta relation list.
        self.relation_pri = nn.Parameter(torch.ones(num_relations, self.n_heads))
        self.relation_att = nn.Parameter(torch.empty(num_relations, n_heads, self.d_k, self.d_k))
        self.relation_msg = nn.Parameter(torch.empty(num_relations, n_heads, self.d_k, self.d_k))
        # One learnable skip gate per node type; sigmoid(1) ~= 0.73 at init.
        self.skip = nn.Parameter(torch.ones(num_types))
        self.drop = nn.Dropout(dropout)
        self.emb = RelTemporalEncoding(in_dim)

        glorot(self.relation_att)
        glorot(self.relation_msg)

    def forward(self, node_inp, node_type, edge_index, edge_type, edge_time):
        """Run one HGT layer.

        Parameters
        ----------
        node_inp:
            Node features; last dimension is ``in_dim``.
        node_type:
            Integer node-type id per node, in ``[0, num_types)``.
        edge_index:
            Edge connectivity in PyG's ``(2, E)`` COO format (source, target).
        edge_type:
            Integer relation id per edge, in ``[0, num_relations)``.
        edge_time:
            Integer time index per edge, used to look up
            :class:`RelTemporalEncoding`. It must be below that module's
            ``max_len`` (240 with the default used here).

        Returns
        -------
        Tensor of shape ``(num_nodes, out_dim)``.
        """
        return self.propagate(edge_index, node_inp=node_inp, node_type=node_type,
                              edge_type=edge_type, edge_time=edge_time)

    def message(self, edge_index_i, node_inp_i, node_inp_j, node_type_i,
                node_type_j, edge_type, edge_time):
        """Compute per-edge attention-weighted messages.

        Suffix ``_j`` denotes the source node of an edge and ``_i`` its target
        (each edge is ``<j, i>``). Edges are grouped by meta relation
        ``<source_type, relation_type, target_type>``, and each group uses the
        matching type- and relation-specific weights.
        Returns a tensor of shape ``(E, out_dim)``.
        """
        data_size = edge_index_i.size(0)
        device = node_inp_i.device
        # Allocated directly on the target device. Edges not matched by any
        # meta relation below keep zero score and zero message.
        res_att = torch.zeros(data_size, self.n_heads, device=device)
        res_msg = torch.zeros(data_size, self.n_heads, self.d_k, device=device)
        for source_type in range(self.num_types):
            sb = node_type_j == source_type
            k_linear = self.k_linears[source_type]
            v_linear = self.v_linears[source_type]
            for target_type in range(self.num_types):
                tb = (node_type_i == target_type) & sb
                q_linear = self.q_linears[target_type]
                for relation_type in range(self.num_relations):
                    # All edges with meta relation <source_type, relation_type, target_type>.
                    idx = (edge_type == relation_type) & tb
                    if not idx.any():
                        continue
                    # Temporal encoding is added to the source representation (j) only.
                    target_node_vec = node_inp_i[idx]
                    source_node_vec = self.emb(node_inp_j[idx], edge_time[idx])

                    # Step 1: heterogeneous mutual attention.
                    q_mat = q_linear(target_node_vec).view(-1, self.n_heads, self.d_k)
                    k_mat = k_linear(source_node_vec).view(-1, self.n_heads, self.d_k)
                    # (heads, n, d_k) @ (heads, d_k, d_k), then back to (n, heads, d_k).
                    k_mat = torch.bmm(k_mat.transpose(1, 0),
                                      self.relation_att[relation_type]).transpose(1, 0)
                    res_att[idx] = ((q_mat * k_mat).sum(dim=-1)
                                    * self.relation_pri[relation_type] / self.sqrt_dk)

                    # Step 2: heterogeneous message passing.
                    v_mat = v_linear(source_node_vec).view(-1, self.n_heads, self.d_k)
                    res_msg[idx] = torch.bmm(v_mat.transpose(1, 0),
                                             self.relation_msg[relation_type]).transpose(1, 0)
        # Softmax over edges that share a target node (edge_index_i). The
        # normalized scores are kept in self.att for later visualization.
        self.att = softmax(res_att, edge_index_i)
        res = res_msg * self.att.view(-1, self.n_heads, 1)
        return res.view(-1, self.out_dim)

    def update(self, aggr_out, node_inp, node_type):
        """Target-specific aggregation plus gated skip connection.

        Applies GELU, then the per-type output projection ``a_linears[t]``.
        The result is mixed with the input as
        ``alpha * out + (1 - alpha) * node_inp``, where
        ``alpha = sigmoid(skip[t])``, and is optionally LayerNorm-ed, then
        dropout is applied. Nodes whose type id is outside ``[0, num_types)``
        stay zero before dropout.
        """
        aggr_out = F.gelu(aggr_out)
        res = torch.zeros(aggr_out.size(0), self.out_dim, device=node_inp.device)
        for target_type in range(self.num_types):
            idx = node_type == target_type
            if not idx.any():
                continue
            trans_out = self.a_linears[target_type](aggr_out[idx])
            # Learnable skip connection weight for this node type.
            alpha = torch.sigmoid(self.skip[target_type])
            mixed = trans_out * alpha + node_inp[idx] * (1 - alpha)
            res[idx] = self.norms[target_type](mixed) if self.use_norm else mixed
        return self.drop(res)

    def __repr__(self):
        return '{}(in_dim={}, out_dim={}, num_types={}, num_relations={})'.format(
            self.__class__.__name__, self.in_dim, self.out_dim,
            self.num_types, self.num_relations)
class RelTemporalEncoding(nn.Module):
    '''
    Implement the relative temporal encoding (sinusoid) used by HGT.

    A fixed sinusoidal table of shape ``(max_len, 2 * n_hid)`` is looked up by
    integer time index. The looked-up rows go through dropout and a trainable
    linear projection back to ``n_hid``, and the result is added to the input
    features.

    Parameters
    ----------
    n_hid: int
        Feature size of the input ``x``, which is also the output size.
    max_len: int, optional
        Number of distinct time indices; ``t`` must lie in ``[0, max_len)``.
        (default: :obj:`240`)
    dropout: float, optional
        Dropout rate applied to the looked-up sinusoid before projection.
        (default: :obj:`0.2`)
    '''
    def __init__(self, n_hid, max_len=240, dropout=0.2):
        super().__init__()
        self.drop = nn.Dropout(dropout)
        position = torch.arange(0., max_len).unsqueeze(1)
        # Standard transformer frequencies 1 / 10000^(2i / d) with d = 2 * n_hid.
        # The division has to happen inside the exponent. Computing
        # `10000 ** arange` first and dividing afterwards yields frequencies
        # that collapse to ~0 after the first few columns.
        div_term = 1 / (10000 ** (torch.arange(0., n_hid * 2, 2.) / (n_hid * 2)))
        self.emb = nn.Embedding(max_len, n_hid * 2)
        self.emb.weight.data[:, 0::2] = torch.sin(position * div_term) / math.sqrt(n_hid)
        self.emb.weight.data[:, 1::2] = torch.cos(position * div_term) / math.sqrt(n_hid)
        # Freeze the sinusoid table. Assigning `self.emb.requires_grad` would only
        # create a plain module attribute and leave the weight trainable.
        self.emb.weight.requires_grad_(False)
        self.lin = nn.Linear(n_hid * 2, n_hid)

    def forward(self, x, t):
        """Return ``x + lin(dropout(emb(t)))``.

        ``x`` has last dimension ``n_hid``. ``t`` holds integer time indices
        in ``[0, max_len)`` and has the same leading shape as ``x``.
        """
        return x + self.lin(self.drop(self.emb(t)))