# Source code for ggfm.models.sgformer

import torch
import torch.nn as nn
from ggfm.conv.sgformer_conv import *

class SGFormer(nn.Module):
    r"""SGFormer model from the `"SGFormer: Simplifying and Empowering
    Transformers for Large-Graph Representations"
    <https://arxiv.org/abs/2306.10759>`_ paper.

    Input features go through two branches: a ``TransConv`` branch and,
    when ``use_graph`` is true, a ``GraphConv`` branch that also receives
    ``edge_index``. The branch outputs are fused (weighted sum or
    concatenation) and projected to ``out_channels`` by a final linear
    layer.

    Args:
        in_channels (int): Number of input features.
        hidden_channels (int): Number of hidden features produced by each
            branch.
        out_channels (int): Number of output features.
        trans_num_layers (int, optional): Layer count passed to
            ``TransConv``. (default: ``1``)
        trans_num_heads (int, optional): Head count passed to
            ``TransConv``. (default: ``1``)
        trans_dropout (float, optional): Dropout passed to ``TransConv``.
            (default: ``0.5``)
        trans_use_bn, trans_use_residual, trans_use_weight, trans_use_act
            (bool, optional): Flags passed to ``TransConv``.
            (default: ``True``)
        gnn_num_layers (int, optional): Layer count passed to
            ``GraphConv``. (default: ``1``)
        gnn_dropout (float, optional): Dropout passed to ``GraphConv``.
            (default: ``0.5``)
        gnn_use_weight, gnn_use_bn, gnn_use_residual, gnn_use_act
            (bool, optional): Flags passed to ``GraphConv``.
            (default: ``True``)
        gnn_use_init (bool, optional): Flag passed to ``GraphConv``.
            (default: ``False``)
        use_graph (bool, optional): If ``False``, only the ``TransConv``
            branch feeds the output layer and ``edge_index`` is ignored in
            :meth:`forward`. (default: ``True``)
        graph_weight (float, optional): With ``aggregate='add'``, the
            ``GraphConv`` output is scaled by ``graph_weight`` and the
            ``TransConv`` output by ``1 - graph_weight``.
            (default: ``0.8``)
        aggregate (str, optional): Branch fusion mode: ``'add'`` (weighted
            sum) or ``'cat'`` (concatenation along the feature dimension).
            (default: ``'add'``)

    Raises:
        ValueError: If ``aggregate`` is neither ``'add'`` nor ``'cat'``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels,
                 trans_num_layers=1, trans_num_heads=1, trans_dropout=0.5,
                 trans_use_bn=True, trans_use_residual=True,
                 trans_use_weight=True, trans_use_act=True,
                 gnn_num_layers=1, gnn_dropout=0.5, gnn_use_weight=True,
                 gnn_use_init=False, gnn_use_bn=True, gnn_use_residual=True,
                 gnn_use_act=True, use_graph=True, graph_weight=0.8,
                 aggregate='add'):
        super().__init__()
        # Validate before building any sub-modules.
        if aggregate not in ('add', 'cat'):
            raise ValueError(f'Invalid aggregate type:{aggregate}')

        # Arguments are passed positionally, in the order the sub-module
        # constructors expect.
        self.trans_conv = TransConv(
            in_channels, hidden_channels, trans_num_layers, trans_num_heads,
            trans_dropout, trans_use_bn, trans_use_residual,
            trans_use_weight, trans_use_act)
        # Constructed unconditionally; forward() only calls it when
        # use_graph is True.
        self.graph_conv = GraphConv(
            in_channels, hidden_channels, gnn_num_layers, gnn_dropout,
            gnn_use_bn, gnn_use_residual, gnn_use_weight, gnn_use_init,
            gnn_use_act)
        self.use_graph = use_graph
        self.graph_weight = graph_weight
        self.aggregate = aggregate

        # forward() only concatenates when the graph branch runs; without
        # it, fc receives just the TransConv output (hidden_channels wide).
        # Sizing fc as 2 * hidden_channels in that case made every forward
        # call fail with a shape mismatch.
        concatenates = aggregate == 'cat' and use_graph
        fc_in = 2 * hidden_channels if concatenates else hidden_channels
        self.fc = nn.Linear(fc_in, out_channels)

        # Two parameter groups (presumably for per-group optimizer
        # settings): the TransConv branch, and the GraphConv branch plus
        # the output layer.
        self.params1 = list(self.trans_conv.parameters())
        self.params2 = list(self.graph_conv.parameters())
        self.params2.extend(self.fc.parameters())

    def forward(self, x, edge_index):
        """Compute model outputs for the given node features.

        Args:
            x (torch.Tensor): Input features, presumably
                ``[num_nodes, in_channels]`` (``torch.cat(..., dim=1)``
                below requires the feature axis to be dim 1).
            edge_index: Graph connectivity passed to ``GraphConv``; unused
                when ``use_graph`` is ``False``.

        Returns:
            torch.Tensor: Output of the final linear layer, whose last
            dimension is ``out_channels``.
        """
        x1 = self.trans_conv(x)
        if not self.use_graph:
            return self.fc(x1)

        x2 = self.graph_conv(x, edge_index)
        if self.aggregate == 'add':
            # Convex combination of the two branches when graph_weight is
            # in [0, 1].
            fused = self.graph_weight * x2 + (1 - self.graph_weight) * x1
        else:
            fused = torch.cat((x1, x2), dim=1)
        return self.fc(fused)

    def get_attentions(self, x):
        """Return the attention maps computed by the ``TransConv`` branch.

        Delegates to ``TransConv.get_attentions``; the original annotation
        gives the result shape as ``[layer num, N, N]``.
        """
        attns = self.trans_conv.get_attentions(x)  # [layer num, N, N]
        return attns

    def reset_parameters(self):
        """Re-initialise every learnable sub-module.

        The graph branch is only reset when ``use_graph`` is true. The
        output layer is always reset; previously it was skipped, so a reset
        model kept its trained output head.
        """
        self.trans_conv.reset_parameters()
        if self.use_graph:
            self.graph_conv.reset_parameters()
        self.fc.reset_parameters()