ggfm.conv.HGTConv
- class ggfm.conv.HGTConv(in_dim, out_dim, num_types, num_relations, n_heads=1, dropout=0.2, use_norm=True, **kwargs)
The Heterogeneous Graph Transformer (HGT) operator from the “Heterogeneous Graph Transformer” paper.
- Parameters:
in_dim (int) – Size of each input sample of every node type, or -1 to derive the size from the first input(s) to the forward method.
out_dim (int) – Size of each output sample.
num_types (int) – Number of node types.
num_relations (int) – Number of relations.
n_heads (int, optional) – Number of attention heads. (default: 1)
dropout (float, optional) – Dropout rate. (default: 0.2)
use_norm (bool, optional) – Whether to apply normalization. (default: True)
**kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
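As a minimal sketch of how the constructor arguments might be assembled, the snippet below derives num_types and num_relations from integer-coded type lists. The toy graph, the count_ids helper, the chosen dimensions, and the assumption that type ids are coded 0..max are all illustrative and not taken from the ggfm documentation; the final constructor call is left as a comment so the snippet runs without ggfm installed.

```python
# Hypothetical toy graph: one integer type id per node and per edge.
node_type = [0, 0, 1, 2, 1]
edge_type = [0, 1, 1, 3]


def count_ids(ids):
    """Number of distinct ids, assuming they are coded 0..max (illustrative helper)."""
    return max(ids) + 1 if ids else 0


# Keyword names follow the class signature documented above; values are examples.
hgt_kwargs = dict(
    in_dim=64,
    out_dim=64,
    num_types=count_ids(node_type),
    num_relations=count_ids(edge_type),
    n_heads=4,
    dropout=0.2,
    use_norm=True,
)

# Assumption: as in most multi-head attention layers, out_dim is split
# evenly across heads, so it should be divisible by n_heads.
assert hgt_kwargs["out_dim"] % hgt_kwargs["n_heads"] == 0

# With ggfm installed, the layer would then be built as:
#   from ggfm.conv import HGTConv
#   conv = HGTConv(**hgt_kwargs)
```

Deriving the counts from the data rather than hard-coding them keeps the layer's type-specific parameters in step with the graph, provided the id coding assumption holds.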
- forward(node_inp, node_type, edge_index, edge_type, edge_time)
Runs the forward pass of the module.
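The documentation does not state the expected shapes of the forward arguments. The sketch below assumes the usual PyTorch Geometric layout: one feature row and one type id per node, a two-row edge_index, and one relation id and one timestamp per edge. The check_forward_inputs helper is hypothetical, and plain Python lists stand in for tensors so the example runs without any dependencies.

```python
def check_forward_inputs(node_inp, node_type, edge_index, edge_type, edge_time):
    """Hypothetical helper: verify the five forward() arguments agree in size."""
    num_nodes = len(node_inp)
    if len(node_type) != num_nodes:
        raise ValueError("node_type needs one entry per node")
    if len(edge_index) != 2 or len(edge_index[0]) != len(edge_index[1]):
        raise ValueError("edge_index needs two rows of equal length")
    num_edges = len(edge_index[0])
    if len(edge_type) != num_edges or len(edge_time) != num_edges:
        raise ValueError("edge_type and edge_time need one entry per edge")
    if any(not 0 <= n < num_nodes for row in edge_index for n in row):
        raise ValueError("edge_index refers to a node that does not exist")
    return num_nodes, num_edges


# Toy inputs (assumed layout): 3 nodes with 2 features each, 3 edges.
node_inp = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
node_type = [0, 1, 1]
edge_index = [[0, 1, 2], [1, 2, 0]]
edge_type = [0, 1, 0]
edge_time = [5, 3, 7]

shapes = check_forward_inputs(node_inp, node_type, edge_index, edge_type, edge_time)
```

A check of this kind can catch mismatched lengths before they surface as indexing errors inside message passing.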
- message(edge_index_i, node_inp_i, node_inp_j, node_type_i, node_type_j, edge_type, edge_time)
Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input that was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
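The _i and _j suffix convention can be illustrated without the library. The sketch below assumes PyTorch Geometric's default flow, source_to_target, in which the first row of edge_index holds the source nodes \(j\) and the second row the target nodes \(i\). The gather helper and the toy values are illustrative, and plain lists again stand in for tensors.

```python
def gather(values, index):
    """Pick one entry of `values` per edge, following `index` (illustrative helper)."""
    return [values[k] for k in index]


# Toy node-level inputs: 3 nodes with 2 features and one type id each.
node_inp = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
node_type = [0, 1, 1]

# Edges 0->1, 1->2, 2->0. Row 0: source nodes j, row 1: target nodes i
# (assuming the default source_to_target flow).
edge_index = [[0, 1, 2], [1, 2, 0]]
src, dst = edge_index

# Per-edge views, named as they appear in the message() signature.
node_inp_j = gather(node_inp, src)    # features of each edge's source node
node_inp_i = gather(node_inp, dst)    # features of each edge's target node
node_type_j = gather(node_type, src)  # type id of each edge's source node
node_type_i = gather(node_type, dst)  # type id of each edge's target node
edge_index_i = dst                    # target node index of each edge
```

Edge-level arguments such as edge_type and edge_time carry no suffix because they already hold one entry per edge and need no gathering.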