import torch
import torch.nn as nn
import torch.nn.functional as F
# from torch_sparse import SparseTensor, matmul
from torch_geometric.utils import degree
class GraphConvLayer(nn.Module):
    """One normalized graph-propagation step, optionally followed by a linear map.

    Messages travel along ``edge_index`` from ``row`` to ``col``. Edge ``e`` is
    weighted by ``1 / sqrt(d[col_e] * d[row_e])``, where ``d[i]`` counts how many
    times node ``i`` appears in ``col``. Edges whose weight is non-finite (a node
    that never appears in ``col``) are zeroed.

    Args:
        in_channels: feature width of ``x`` (and of ``x0`` when ``use_init``).
        out_channels: output width of the linear map ``W``.
        use_weight: apply ``W`` after propagation. Ignored when ``use_init`` is
            set, because that branch always applies ``W``.
        use_init: concatenate ``x0`` to the propagated features before ``W``.
    """

    def __init__(self, in_channels, out_channels, use_weight=True, use_init=False):
        super(GraphConvLayer, self).__init__()
        self.use_init = use_init
        self.use_weight = use_weight
        # With use_init the propagated features are concatenated with x0,
        # doubling the input width seen by W.
        in_channels_ = 2 * in_channels if use_init else in_channels
        self.W = nn.Linear(in_channels_, out_channels)

    def reset_parameters(self):
        """Re-initialise the linear map ``W``."""
        self.W.reset_parameters()

    def forward(self, x, edge_index, x0):
        """Propagate ``x`` over ``edge_index`` and optionally transform it.

        Args:
            x: node features with the node axis first (``[N, D]`` in normal use).
            edge_index: integer tensor unpacked as ``row, col``; each column is an
                edge ``row -> col`` whose entries index into the first axis of ``x``.
            x0: features concatenated to the result when ``use_init`` is set;
                unused otherwise.

        Returns:
            The propagated features, passed through ``W`` when ``use_init`` or
            ``use_weight`` is set.
        """
        N = x.shape[0]
        row, col = edge_index
        # Occurrence count of each node in col (same as
        # torch_geometric.utils.degree(col, N) for valid indices).
        d = torch.bincount(col, minlength=N).float()
        d_norm_in = (1. / d[col]).sqrt()
        d_norm_out = (1. / d[row]).sqrt()
        value = d_norm_in * d_norm_out
        # A node with d == 0 yields inf; drop such edges instead of propagating inf.
        value = torch.nan_to_num(value, nan=0.0, posinf=0.0, neginf=0.0)
        # Sparse product A @ x with A[col_e, row_e] = value_e, written as a
        # scatter-add so no torch_sparse dependency is required:
        #     out[col_e] += value_e * x[row_e]
        # Duplicate edges are summed, as in a sparse matmul.
        value = value.to(x.dtype).reshape((-1,) + (1,) * (x.dim() - 1))
        x = torch.zeros_like(x).index_add(0, col, x[row] * value)
        if self.use_init:
            x = torch.cat([x, x0], 1)
            x = self.W(x)
        elif self.use_weight:
            x = self.W(x)
        return x
class GraphConv(nn.Module):
    """Linear input projection followed by a stack of ``GraphConvLayer`` steps.

    Input stage: ``fcs[0]`` -> ``bns[0]`` (if ``use_bn``) -> ReLU -> dropout.
    The ReLU here is applied regardless of ``use_act``. Each of the
    ``num_layers`` propagation steps then runs: conv (receiving the input-stage
    embedding as ``x0``) -> BatchNorm (if ``use_bn``) -> ReLU (if ``use_act``)
    -> dropout -> add the input-stage embedding (if ``use_residual``).
    """

    def __init__(self, in_channels, hidden_channels, num_layers=2, dropout=0.5, use_bn=True, use_residual=True,
                 use_weight=True, use_init=False, use_act=True):
        super(GraphConv, self).__init__()
        # Build the input projection before the conv layers so parameters are
        # drawn from the RNG in the same order as before. Attribute registration
        # order is kept as convs, fcs, bns.
        input_fc = nn.Linear(in_channels, hidden_channels)
        self.convs = nn.ModuleList(
            [GraphConvLayer(hidden_channels, hidden_channels, use_weight, use_init)
             for _ in range(num_layers)])
        self.fcs = nn.ModuleList([input_fc])
        # One norm for the input stage plus one per propagation step.
        self.bns = nn.ModuleList(
            [nn.BatchNorm1d(hidden_channels) for _ in range(num_layers + 1)])
        self.dropout = dropout
        self.activation = F.relu
        self.use_bn = use_bn
        self.use_residual = use_residual
        self.use_act = use_act

    def reset_parameters(self):
        """Re-initialise convs, then norms, then the input projection."""
        for module in (*self.convs, *self.bns, *self.fcs):
            module.reset_parameters()

    def forward(self, x, edge_index):
        """Embed ``x`` and propagate it ``num_layers`` times over ``edge_index``."""
        h = self.fcs[0](x)
        if self.use_bn:
            h = self.bns[0](h)
        h = F.dropout(self.activation(h), p=self.dropout, training=self.training)
        # Both the conv's x0 argument and the residual use this first embedding.
        h0 = h
        for conv, norm in zip(self.convs, self.bns[1:]):
            h = conv(h, edge_index, h0)
            if self.use_bn:
                h = norm(h)
            if self.use_act:
                h = self.activation(h)
            h = F.dropout(h, p=self.dropout, training=self.training)
            if self.use_residual:
                h = h + h0
        return h
class TransConvLayer(nn.Module):
    '''
    transformer with fast attention

    Linear-cost global attention. The query and key tensors are each divided by
    their global (whole-tensor) 2-norm. The output for query row ``n`` and head
    ``h`` is

        (N * v_n + q_n @ (K^T V)) / (N + q_n . sum_l k_l)

    and it is averaged over heads. ``N`` is the number of query rows and ``v_n``
    is the value row with the same index, so ``query_input`` and
    ``source_input`` are expected to have the same number of rows (TransConv
    passes the same tensor for both).

    Args:
        in_channels: input feature width.
        out_channels: per-head width of queries, keys and values.
        num_heads: number of attention heads.
        use_weight: project values with ``Wv``. When False, ``source_input`` is
            used directly as a single-head value, which requires
            ``in_channels == out_channels``.
    '''
    def __init__(self, in_channels,
                 out_channels,
                 num_heads,
                 use_weight=True):
        super().__init__()
        self.Wk = nn.Linear(in_channels, out_channels * num_heads)
        self.Wq = nn.Linear(in_channels, out_channels * num_heads)
        if use_weight:
            self.Wv = nn.Linear(in_channels, out_channels * num_heads)
        self.out_channels = out_channels
        self.num_heads = num_heads
        self.use_weight = use_weight

    def reset_parameters(self):
        """Re-initialise the query/key (and value, if present) projections."""
        self.Wk.reset_parameters()
        self.Wq.reset_parameters()
        if self.use_weight:
            self.Wv.reset_parameters()

    def forward(self, query_input, source_input, output_attn=False):
        """Attend from ``query_input`` rows over ``source_input`` rows.

        Returns:
            ``[N, out_channels]`` head-averaged output. When ``output_attn`` is
            True, also returns an ``[N, L]`` head-averaged score matrix intended
            for visualization only (it omits the ``N * v_n`` self term of the
            numerator).
        """
        # feature transformation
        qs = self.Wq(query_input).reshape(-1, self.num_heads, self.out_channels)
        ks = self.Wk(source_input).reshape(-1, self.num_heads, self.out_channels)
        if self.use_weight:
            vs = self.Wv(source_input).reshape(-1, self.num_heads, self.out_channels)
        else:
            # Single "head" of raw inputs; broadcast across heads below.
            vs = source_input.reshape(-1, 1, self.out_channels)

        # Normalize by the 2-norm of the WHOLE tensor (not per row or head).
        qs = qs / torch.norm(qs, p=2)  # [N, H, M]
        ks = ks / torch.norm(ks, p=2)  # [L, H, M]
        N = qs.shape[0]

        # numerator: Q (K^T V) in linear time, plus the N * V self term
        kvs = torch.einsum("lhm,lhd->hmd", ks, vs)
        attention_num = torch.einsum("nhm,hmd->nhd", qs, kvs)  # [N, H, D]
        attention_num = attention_num + N * vs

        # denominator: q . sum_l k_l + N. Summing ks directly keeps its
        # dtype/device; a float32 ones vector would break einsum for other dtypes.
        ks_sum = ks.sum(dim=0)  # [H, M]
        attention_normalizer = torch.einsum("nhm,hm->nh", qs, ks_sum)  # [N, H]
        attention_normalizer = attention_normalizer.unsqueeze(-1) + N  # [N, H, 1]

        # attentive aggregated results, averaged over heads
        attn_output = attention_num / attention_normalizer  # [N, H, D]
        final_output = attn_output.mean(dim=1)

        if output_attn:
            attention = torch.einsum("nhm,lhm->nlh", qs, ks).mean(dim=-1)  # [N, L]
            normalizer = attention_normalizer.squeeze(dim=-1).mean(dim=-1, keepdim=True)  # [N, 1]
            attention = attention / normalizer
            return final_output, attention
        return final_output
class TransConv(nn.Module):
    """Linear input projection followed by a stack of ``TransConvLayer`` blocks.

    Input stage: ``fcs[0]`` -> ``bns[0]`` LayerNorm (if ``use_bn``) -> ReLU ->
    dropout. The ReLU here is applied regardless of ``use_act``. Each attention
    block then runs self-attention over all rows -> average with the previous
    stage's output (if ``use_residual``) -> LayerNorm (if ``use_bn``) -> ReLU
    (if ``use_act``) -> dropout.
    """

    def __init__(self, in_channels, hidden_channels, num_layers=2, num_heads=1,
                 dropout=0.5, use_bn=True, use_residual=True, use_weight=True, use_act=True):
        super().__init__()
        # Create the input projection first so RNG-driven initialisation order
        # matches; attributes are still registered as convs, fcs, bns.
        input_fc = nn.Linear(in_channels, hidden_channels)
        self.convs = nn.ModuleList(
            [TransConvLayer(hidden_channels, hidden_channels, num_heads=num_heads, use_weight=use_weight)
             for _ in range(num_layers)])
        self.fcs = nn.ModuleList([input_fc])
        # One LayerNorm for the input stage plus one per attention block.
        self.bns = nn.ModuleList(
            [nn.LayerNorm(hidden_channels) for _ in range(num_layers + 1)])
        self.dropout = dropout
        self.activation = F.relu
        self.use_bn = use_bn
        self.use_residual = use_residual
        self.use_act = use_act

    def reset_parameters(self):
        """Re-initialise attention blocks, then norms, then the input projection."""
        for module in (*self.convs, *self.bns, *self.fcs):
            module.reset_parameters()

    def forward(self, x):
        """Embed ``x`` and apply every global-attention block in turn."""
        # input MLP layer
        h = self.fcs[0](x)
        if self.use_bn:
            h = self.bns[0](h)
        h = F.dropout(self.activation(h), p=self.dropout, training=self.training)
        # Output of the preceding stage, used as the residual partner.
        previous = h
        for conv, norm in zip(self.convs, self.bns[1:]):
            # full (all-pairs) attention aggregation
            h = conv(h, h)
            if self.use_residual:
                h = (h + previous) / 2.
            if self.use_bn:
                h = norm(h)
            if self.use_act:
                h = self.activation(h)
            h = F.dropout(h, p=self.dropout, training=self.training)
            previous = h
        return h

    def get_attentions(self, x):
        """Run the stack without dropout and stack each block's attention scores.

        Returns:
            Tensor of shape ``[num_layers, N, N]``. ``torch.stack`` raises when
            ``num_layers`` is 0.
        """
        h = self.fcs[0](x)
        if self.use_bn:
            h = self.bns[0](h)
        h = self.activation(h)
        previous = h
        attentions = []
        for conv, norm in zip(self.convs, self.bns[1:]):
            h, scores = conv(h, h, output_attn=True)
            attentions.append(scores)
            if self.use_residual:
                h = (h + previous) / 2.
            if self.use_bn:
                h = norm(h)
            if self.use_act:
                h = self.activation(h)
            previous = h
        return torch.stack(attentions, dim=0)  # [layer num, N, N]