# Source code for ggfm.conv.sgformer_conv (graph-convolution and linear-attention layers).

import torch
import torch.nn as nn
import torch.nn.functional as F
# from torch_sparse import SparseTensor, matmul
from torch_geometric.utils import degree


class GraphConvLayer(nn.Module):
    """Single graph-convolution step with symmetric degree normalisation.

    Each edge ``(row, col)`` in ``edge_index`` sends ``x[row]`` to node ``col``
    with weight ``1 / sqrt(deg[col] * deg[row])``, where ``deg`` counts how
    often a node appears in ``edge_index[1]``. The aggregated features are
    then optionally concatenated with ``x0`` and passed through a linear layer.

    Args:
        in_channels: Width of the input node features.
        out_channels: Output width of the linear projection ``W``.
        use_weight: Apply ``W`` after aggregation when ``use_init`` is False.
        use_init: Concatenate ``x0`` to the aggregated features and always
            apply ``W`` (which then expects ``2 * in_channels`` inputs).
    """

    def __init__(self, in_channels, out_channels, use_weight=True, use_init=False):
        super().__init__()
        self.use_init = use_init
        self.use_weight = use_weight
        # The projection sees [aggregated, x0] side by side when use_init is on.
        in_channels_ = 2 * in_channels if use_init else in_channels
        self.W = nn.Linear(in_channels_, out_channels)

    def reset_parameters(self):
        """Re-initialise the linear projection ``W``."""
        self.W.reset_parameters()

    def forward(self, x, edge_index, x0):
        """Aggregate neighbour features and optionally project them.

        Args:
            x: Node features with one row per node along dim 0 (typically
                ``[N, D]``).
            edge_index: Pair ``(row, col)`` of integer index tensors; edge ``e``
                adds the weighted ``x[row[e]]`` into output row ``col[e]``.
            x0: Features concatenated along dim 1 when ``use_init`` is set.

        Returns:
            Aggregated (and possibly projected) features with ``x.shape[0]`` rows.
        """
        num_nodes = x.shape[0]
        row, col = edge_index
        # Degree of each node as an edge target, computed in float32.
        deg = torch.bincount(col, minlength=num_nodes).float()
        # Per-node 1/sqrt(deg); nodes with degree 0 yield inf here.
        deg_inv_sqrt = deg.pow(-0.5)
        value = deg_inv_sqrt[col] * deg_inv_sqrt[row]
        # An edge whose source never appears as a target gets an inf weight;
        # such edges are zeroed instead of propagating inf/nan.
        value = torch.nan_to_num(value, nan=0.0, posinf=0.0, neginf=0.0).to(x.dtype)
        # Broadcast one weight per edge over all feature dimensions.
        weight = value.reshape((-1,) + (1,) * (x.dim() - 1))
        # Equivalent to A_norm @ x with A_norm[col, row] = value, done with native
        # torch ops (torch_sparse's SparseTensor/matmul are not imported here).
        out = torch.zeros_like(x).index_add(0, col, weight * x[row])
        if self.use_init:
            out = torch.cat([out, x0], dim=1)
        if self.use_init or self.use_weight:
            out = self.W(out)
        return out
class GraphConv(nn.Module):
    """Input projection followed by a stack of ``GraphConvLayer`` blocks.

    Args:
        in_channels: Width of the raw node features.
        hidden_channels: Width used by the projection and every graph layer.
        num_layers: Number of ``GraphConvLayer`` blocks.
        dropout: Dropout probability applied after every stage.
        use_bn: Apply ``BatchNorm1d`` after the projection and each layer.
        use_residual: Add the initial (post-projection) embedding to every
            layer's output.
        use_weight: Forwarded to every ``GraphConvLayer``.
        use_init: Forwarded to every ``GraphConvLayer``.
        use_act: Apply ReLU after each graph layer (the projection is always
            followed by ReLU).
    """

    def __init__(self, in_channels, hidden_channels, num_layers=2, dropout=0.5,
                 use_bn=True, use_residual=True, use_weight=True, use_init=False,
                 use_act=True):
        super().__init__()
        # Construction order (projection, first norm, then layer/norm pairs)
        # decides which random draws each module's initial weights receive.
        self.convs = nn.ModuleList()
        self.fcs = nn.ModuleList([nn.Linear(in_channels, hidden_channels)])
        self.bns = nn.ModuleList([nn.BatchNorm1d(hidden_channels)])
        for _ in range(num_layers):
            self.convs.append(
                GraphConvLayer(hidden_channels, hidden_channels, use_weight, use_init))
            self.bns.append(nn.BatchNorm1d(hidden_channels))

        self.dropout = dropout
        self.activation = F.relu
        self.use_bn = use_bn
        self.use_residual = use_residual
        self.use_act = use_act

    def reset_parameters(self):
        """Re-initialise graph layers, then norms, then the input projection."""
        for module in (*self.convs, *self.bns, *self.fcs):
            module.reset_parameters()

    def forward(self, x, edge_index):
        """Embed ``x`` and run it through every graph layer.

        Args:
            x: Raw node features fed to the input projection.
            edge_index: Edge indices passed unchanged to each ``GraphConvLayer``.

        Returns:
            Node embeddings of width ``hidden_channels``.
        """
        h = self.fcs[0](x)
        if self.use_bn:
            h = self.bns[0](h)
        h = F.dropout(self.activation(h), p=self.dropout, training=self.training)

        # Every layer is fed -- and, with use_residual, skip-connected to --
        # this initial embedding rather than the previous layer's output.
        h0 = h
        for conv, norm in zip(self.convs, self.bns[1:]):
            h = conv(h, edge_index, h0)
            if self.use_bn:
                h = norm(h)
            if self.use_act:
                h = self.activation(h)
            h = F.dropout(h, p=self.dropout, training=self.training)
            if self.use_residual:
                h = h + h0
        return h
class TransConvLayer(nn.Module):
    """Multi-head linear ("fast") global attention layer.

    Per head it computes ``(N * V + Q (K^T V)) / (N + Q (K^T 1))``, where
    ``Q`` and ``K`` are each divided by their whole-tensor Frobenius norm, and
    then averages the heads. The dense ``N x L`` score matrix is never formed
    except when ``output_attn`` is requested. The ``N * V`` term adds one
    source row to each query row, so ``query_input`` and ``source_input`` are
    expected to have the same number of rows.

    Args:
        in_channels: Width of the input features.
        out_channels: Per-head width of queries, keys and values.
        num_heads: Number of attention heads.
        use_weight: Project values with ``Wv``; otherwise ``source_input`` is
            reshaped to ``[-1, 1, out_channels]`` and used directly as values.
    """

    def __init__(self, in_channels, out_channels, num_heads, use_weight=True):
        super().__init__()
        self.Wk = nn.Linear(in_channels, out_channels * num_heads)
        self.Wq = nn.Linear(in_channels, out_channels * num_heads)
        if use_weight:
            self.Wv = nn.Linear(in_channels, out_channels * num_heads)

        self.out_channels = out_channels
        self.num_heads = num_heads
        self.use_weight = use_weight

    def reset_parameters(self):
        """Re-initialise the key, query and (if present) value projections."""
        self.Wk.reset_parameters()
        self.Wq.reset_parameters()
        if self.use_weight:
            self.Wv.reset_parameters()

    def forward(self, query_input, source_input, output_attn=False):
        """Attend from ``query_input`` rows over ``source_input`` rows.

        Args:
            query_input: Inputs to ``Wq``; one query per row.
            source_input: Inputs to ``Wk`` (and ``Wv``); one key/value per row.
            output_attn: Also return a dense, head-averaged score matrix.

        Returns:
            The head-averaged output ``[N, out_channels]``; with
            ``output_attn`` a tuple ``(output, attention)`` where ``attention``
            is ``[N, L]``.
        """
        qs = self.Wq(query_input).reshape(-1, self.num_heads, self.out_channels)  # [N, H, M]
        ks = self.Wk(source_input).reshape(-1, self.num_heads, self.out_channels)  # [L, H, M]
        if self.use_weight:
            vs = self.Wv(source_input).reshape(-1, self.num_heads, self.out_channels)  # [L, H, D]
        else:
            vs = source_input.reshape(-1, 1, self.out_channels)  # [L, 1, D]

        # Normalise by the norm of the whole tensor (not per row).
        qs = qs / torch.norm(qs, p=2)
        ks = ks / torch.norm(ks, p=2)
        N = qs.shape[0]

        # Numerator: Q (K^T V) plus the N * V self term, without forming Q K^T.
        kvs = torch.einsum("lhm,lhd->hmd", ks, vs)
        attention_num = torch.einsum("nhm,hmd->nhd", qs, kvs) + N * vs  # [N, H, D]

        # Denominator: Q (K^T 1) + N. Summing ks directly keeps the inputs'
        # dtype/device; contracting against a float32 ones vector broke
        # float64/float16 inputs because einsum does not promote dtypes.
        ks_sum = ks.sum(dim=0)  # [H, M]
        attention_normalizer = (
            torch.einsum("nhm,hm->nh", qs, ks_sum).unsqueeze(-1) + N
        )  # [N, H, 1]
        attn_output = attention_num / attention_normalizer  # [N, H, D]

        final_output = attn_output.mean(dim=1)  # average over heads
        if not output_attn:
            return final_output

        # Dense head-averaged score matrix, only built for inspection.
        attention = torch.einsum("nhm,lhm->nlh", qs, ks).mean(dim=-1)  # [N, L]
        normalizer = attention_normalizer.squeeze(dim=-1).mean(dim=-1, keepdim=True)  # [N, 1]
        return final_output, attention / normalizer
class TransConv(nn.Module):
    """Input projection followed by a stack of ``TransConvLayer`` attention layers.

    Args:
        in_channels: Width of the raw node features.
        hidden_channels: Width of the projection and every attention layer.
        num_layers: Number of ``TransConvLayer`` blocks.
        num_heads: Heads per ``TransConvLayer``.
        dropout: Dropout probability used by ``forward``.
        use_bn: Apply ``LayerNorm`` (held in ``self.bns``) after the projection
            and after each layer.
        use_residual: Average each layer's output with that layer's input.
        use_weight: Forwarded to every ``TransConvLayer``.
        use_act: Apply ReLU after each attention layer (the projection is
            always followed by ReLU).
    """

    def __init__(self, in_channels, hidden_channels, num_layers=2, num_heads=1,
                 dropout=0.5, use_bn=True, use_residual=True, use_weight=True,
                 use_act=True):
        super().__init__()
        # Construction order (projection, first norm, then layer/norm pairs)
        # decides which random draws each module's initial weights receive.
        self.convs = nn.ModuleList()
        self.fcs = nn.ModuleList([nn.Linear(in_channels, hidden_channels)])
        self.bns = nn.ModuleList([nn.LayerNorm(hidden_channels)])
        for _ in range(num_layers):
            self.convs.append(
                TransConvLayer(hidden_channels, hidden_channels,
                               num_heads=num_heads, use_weight=use_weight))
            self.bns.append(nn.LayerNorm(hidden_channels))

        self.dropout = dropout
        self.activation = F.relu
        self.use_bn = use_bn
        self.use_residual = use_residual
        self.use_act = use_act

    def reset_parameters(self):
        """Re-initialise attention layers, then norms, then the input projection."""
        for module in (*self.convs, *self.bns, *self.fcs):
            module.reset_parameters()

    def forward(self, x):
        """Embed ``x`` and apply every attention layer as self-attention.

        Args:
            x: Raw node features fed to the input projection.

        Returns:
            Node embeddings of width ``hidden_channels``.
        """
        h = self.fcs[0](x)
        if self.use_bn:
            h = self.bns[0](h)
        h = F.dropout(self.activation(h), p=self.dropout, training=self.training)

        for conv, norm in zip(self.convs, self.bns[1:]):
            skip = h  # input of this layer, used by the residual average
            h = conv(h, h)
            if self.use_residual:
                h = (h + skip) / 2.
            if self.use_bn:
                h = norm(h)
            if self.use_act:
                h = self.activation(h)
            h = F.dropout(h, p=self.dropout, training=self.training)
        return h

    def get_attentions(self, x):
        """Collect each layer's attention matrix, stacked along a new dim 0.

        Follows the same path as ``forward`` but never applies dropout.

        Args:
            x: Raw node features fed to the input projection.

        Returns:
            Tensor of shape ``[num_layers, N, N]``.
        """
        h = self.fcs[0](x)
        if self.use_bn:
            h = self.bns[0](h)
        h = self.activation(h)

        maps = []
        for conv, norm in zip(self.convs, self.bns[1:]):
            skip = h
            h, attn = conv(h, h, output_attn=True)
            maps.append(attn)
            if self.use_residual:
                h = (h + skip) / 2.
            if self.use_bn:
                h = norm(h)
            if self.use_act:
                h = self.activation(h)
        return torch.stack(maps, dim=0)