# Source code for ggfm.models.sgformer
import torch
import torch.nn as nn
from ggfm.conv.sgformer_conv import *
class SGFormer(nn.Module):
    r"""SGFormer model from "SGFormer: Simplifying and Empowering Transformers
    for Large-Graph Representations" (Wu et al., NeurIPS 2023).

    The model runs two branches over the same node features ``x``:

    * a :class:`TransConv` branch that sees only ``x``;
    * when ``use_graph`` is true, a :class:`GraphConv` branch that also sees
      ``edge_index``.

    The branch outputs are merged according to ``aggregate``. The result is
    then passed through a final linear layer ``fc``.

    Args:
        in_channels (int): Number of input features.
        hidden_channels (int): Feature width produced by each branch. The
            final linear layer expects ``hidden_channels`` features per branch.
        out_channels (int): Number of output features of the final layer.
        trans_num_layers, trans_num_heads, trans_dropout, trans_use_bn,
        trans_use_residual, trans_use_weight, trans_use_act: Options forwarded
            positionally, in this order, to :class:`TransConv`.
        gnn_num_layers, gnn_dropout, gnn_use_weight, gnn_use_init, gnn_use_bn,
        gnn_use_residual, gnn_use_act: Options forwarded positionally to
            :class:`GraphConv`. They are passed in the order ``num_layers,
            dropout, use_bn, use_residual, use_weight, use_init, use_act``,
            which differs from the order in this signature.
        use_graph (bool): If ``True``, ``forward`` also runs the graph branch
            and merges it with the transformer branch. Otherwise only the
            transformer branch output is fed to ``fc``.
        graph_weight (float): Mixing weight used when ``aggregate='add'``:
            ``out = graph_weight * x_graph + (1 - graph_weight) * x_trans``.
        aggregate (str): How to merge the two branches. ``'add'`` gives the
            weighted sum above; ``'cat'`` concatenates along dim 1.

    Raises:
        ValueError: If ``aggregate`` is neither ``'add'`` nor ``'cat'``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels,
                 trans_num_layers=1, trans_num_heads=1, trans_dropout=0.5, trans_use_bn=True,
                 trans_use_residual=True, trans_use_weight=True, trans_use_act=True,
                 gnn_num_layers=1, gnn_dropout=0.5, gnn_use_weight=True, gnn_use_init=False,
                 gnn_use_bn=True, gnn_use_residual=True, gnn_use_act=True,
                 use_graph=True, graph_weight=0.8, aggregate='add'):
        super().__init__()
        # Reject a bad configuration before building any sub-network.
        if aggregate not in ('add', 'cat'):
            raise ValueError(f'Invalid aggregate type:{aggregate}')

        self.trans_conv = TransConv(in_channels, hidden_channels, trans_num_layers, trans_num_heads,
                                    trans_dropout, trans_use_bn, trans_use_residual, trans_use_weight,
                                    trans_use_act)
        # NOTE: the graph branch is always constructed, even when use_graph is
        # False. This keeps the module/state_dict layout independent of use_graph.
        self.graph_conv = GraphConv(in_channels, hidden_channels, gnn_num_layers, gnn_dropout,
                                    gnn_use_bn, gnn_use_residual, gnn_use_weight, gnn_use_init,
                                    gnn_use_act)
        self.use_graph = use_graph
        self.graph_weight = graph_weight
        self.aggregate = aggregate

        # Concatenation doubles the width only when a graph branch is actually
        # concatenated. Without it, forward() feeds the transformer output
        # (hidden_channels wide) straight into fc.
        concat_branches = aggregate == 'cat' and use_graph
        fc_in = 2 * hidden_channels if concat_branches else hidden_channels
        self.fc = nn.Linear(fc_in, out_channels)

        # Two parameter groups: transformer parameters, and graph-branch + output-layer
        # parameters. Presumably these exist so callers can optimise them with
        # different settings (verify against the training code).
        self.params1 = list(self.trans_conv.parameters())
        self.params2 = list(self.graph_conv.parameters())
        self.params2.extend(self.fc.parameters())

    def forward(self, x, edge_index):
        """Compute model outputs for node features ``x``.

        Args:
            x: Node feature tensor. Assumed to be ``(num_nodes, in_channels)``;
                TODO confirm against TransConv/GraphConv.
            edge_index: Graph connectivity passed to :class:`GraphConv`. It is
                only used when ``use_graph`` is true.

        Returns:
            The output of the final linear layer ``fc``.
        """
        x_trans = self.trans_conv(x)
        if not self.use_graph:
            return self.fc(x_trans)

        x_graph = self.graph_conv(x, edge_index)
        if self.aggregate == 'add':
            merged = self.graph_weight * x_graph + (1 - self.graph_weight) * x_trans
        else:  # 'cat' -- the only other value accepted by __init__
            merged = torch.cat((x_trans, x_graph), dim=1)
        return self.fc(merged)

    def get_attentions(self, x):
        """Return the attention maps computed by the transformer branch.

        This delegates to ``TransConv.get_attentions``. The result is
        presumably shaped ``[num_layers, N, N]``; confirm in TransConv.
        """
        return self.trans_conv.get_attentions(x)

    def reset_parameters(self):
        """Re-initialise all trainable sub-modules.

        This resets the transformer branch, the graph branch (only when
        ``use_graph`` is true), and the final linear layer ``fc``.
        """
        self.trans_conv.reset_parameters()
        if self.use_graph:
            self.graph_conv.reset_parameters()
        # fc was previously never reset, so "reset" models kept their old head.
        self.fc.reset_parameters()