ggfm.conv.HGTConv

class ggfm.conv.HGTConv(in_dim, out_dim, num_types, num_relations, n_heads=1, dropout=0.2, use_norm=True, **kwargs)[source]

The Heterogeneous Graph Transformer (HGT) operator from the “Heterogeneous Graph Transformer” paper.

Parameters:
  • in_dim (int) – Size of each input sample of every node type, or -1 to derive the size from the first input(s) to the forward method.

  • out_dim (int) – Size of each output sample.

  • num_types (int) – Number of node types.

  • num_relations (int) – Number of relations.

  • n_heads (int, optional) – Number of attention heads. (default: 1)

  • dropout (float, optional) – Dropout rate. (default: 0.2)

  • use_norm (bool, optional) – Whether to apply normalization. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(node_inp, node_type, edge_index, edge_type, edge_time)[source]

Runs the forward pass of the module.
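
The signature does not state the expected shapes of its arguments. As a rough sketch, assuming the layout used by the reference HGT implementation (one type id per node, one relation id and one timestamp per edge, and edge_index as a 2 x E array of source and target ids), the inputs for a toy graph could be laid out as follows. Plain Python lists stand in for tensors, and check_hgt_inputs is a hypothetical helper, not part of ggfm.

```python
# Toy heterogeneous graph: 3 nodes, 2 edges. Plain lists stand in for tensors.
# The layout below is an assumption, not taken from the ggfm documentation.
in_dim = 4
node_inp = [[0.1] * in_dim, [0.2] * in_dim, [0.3] * in_dim]  # [N, in_dim] features
node_type = [0, 0, 1]          # [N] type id per node (e.g. 0 = paper, 1 = author)
edge_index = [[2, 0], [0, 1]]  # [2, E]: row 0 = source ids, row 1 = target ids
edge_type = [0, 1]             # [E] relation id per edge
edge_time = [5, 3]             # [E] timestamp per edge


def check_hgt_inputs(node_inp, node_type, edge_index, edge_type, edge_time,
                     num_types, num_relations):
    """Hypothetical helper: validate the assumed input layout before calling forward."""
    num_nodes, num_edges = len(node_inp), len(edge_type)
    assert len(node_type) == num_nodes, "one type id per node"
    assert len(edge_index) == 2, "edge_index needs a source row and a target row"
    assert all(len(row) == num_edges for row in edge_index), "one column per edge"
    assert len(edge_time) == num_edges, "one timestamp per edge"
    assert all(0 <= t < num_types for t in node_type), "node type out of range"
    assert all(0 <= r < num_relations for r in edge_type), "relation id out of range"
    assert all(0 <= n < num_nodes for row in edge_index for n in row), "bad node id"
    return num_nodes, num_edges


shape = check_hgt_inputs(node_inp, node_type, edge_index, edge_type, edge_time,
                         num_types=2, num_relations=2)
```

Checking that every type and relation id stays below num_types and num_relations before the call can surface indexing mistakes early, since such ids are typically used to select per-type and per-relation weights.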

message(edge_index_i, node_inp_i, node_inp_j, node_type_i, node_type_j, edge_type, edge_time)[source]

Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take as input any argument that was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
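
The _i / _j suffix convention can be illustrated without the library. The sketch below assumes PyTorch Geometric's default source_to_target flow, where row 0 of edge_index holds the source nodes \(j\) and row 1 holds the target nodes \(i\). The function lift is a hypothetical stand-in for the gathering that propagate() performs, not an actual API.

```python
def lift(values, edge_index, suffix):
    """Gather one per-node value per edge, mimicking the _i / _j convention.

    Assumes the default source_to_target flow: edge_index[0] holds source
    nodes j and edge_index[1] holds target nodes i.
    """
    row = edge_index[1] if suffix == "_i" else edge_index[0]
    return [values[n] for n in row]


node_type = [0, 1, 2]                # one type id per node
edge_index = [[2, 0, 2], [0, 1, 1]]  # edges 2->0, 0->1, 2->1

node_type_i = lift(node_type, edge_index, "_i")  # types of the target nodes
node_type_j = lift(node_type, edge_index, "_j")  # types of the source nodes
print(node_type_i, node_type_j)
```

Each lifted list has one entry per edge, which is why message() can combine node_inp_i, node_inp_j, edge_type and edge_time position by position.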

update(aggr_out, node_inp, node_type)[source]

Updates node embeddings in analogy to \(\gamma_{\mathbf{\Theta}}\) for each node \(i \in \mathcal{V}\). Takes the output of aggregation as its first argument, followed by any argument that was initially passed to propagate().
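
To make the data flow concrete, here is a minimal sketch of the step that precedes update(): per-edge messages are reduced per target node, and the result reaches update() as aggr_out. Sum aggregation and the residual connection are assumptions chosen for illustration, since this page does not state which scheme HGTConv uses. The names aggregate_add and toy_update are hypothetical.

```python
def aggregate_add(messages, target_index, num_nodes):
    """Sum per-edge message vectors into one vector per target node."""
    dim = len(messages[0])
    out = [[0.0] * dim for _ in range(num_nodes)]
    for msg, i in zip(messages, target_index):
        for d in range(dim):
            out[i][d] += msg[d]
    return out


def toy_update(aggr_out, node_inp):
    """Stand-in for update(): add the node's input to its aggregated messages.

    The residual connection is an assumption, not the documented behavior.
    """
    return [[a + x for a, x in zip(agg, inp)] for agg, inp in zip(aggr_out, node_inp)]


messages = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # one message per edge
target_index = [0, 1, 1]                          # target node i of each edge
node_inp = [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]   # one input vector per node

aggr_out = aggregate_add(messages, target_index, num_nodes=3)
new_emb = toy_update(aggr_out, node_inp)
```

Node 2 has no incoming edges in this toy graph, so its aggregated vector stays at zero and the update returns its input unchanged.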