# Source code for ggfm.models.utils

import torch
import torch.nn as nn
import torch.nn.functional as F

def get_optimizer(parameters, name, optimizer_args):
    r"""Build a ``torch.optim`` optimizer selected by name.

    Parameters
    ----------
    parameters: iterable
        Passed unchanged as the first positional argument of the optimizer
        constructor.
    name: str
        Optimizer identifier. Supported values (case-sensitive):
        ``"adam"``, ``"adamw"``, ``"adadelta"``, ``"radam"``.
    optimizer_args: dict
        Keyword arguments forwarded to the optimizer constructor.

    Returns
    -------
    torch.optim.Optimizer
        The constructed optimizer instance.

    Raises
    ------
    NotImplementedError
        If ``name`` is not one of the supported identifiers.
    """
    if name == "adam":
        return torch.optim.Adam(parameters, **optimizer_args)
    if name == "adamw":
        return torch.optim.AdamW(parameters, **optimizer_args)
    if name == "adadelta":
        return torch.optim.Adadelta(parameters, **optimizer_args)
    if name == "radam":
        return torch.optim.RAdam(parameters, **optimizer_args)

    # Raise (rather than return the exception class) so an unsupported name
    # fails here instead of later when the caller tries to use the result.
    raise NotImplementedError(
        f"Unsupported optimizer {name!r}; expected one of "
        "'adam', 'adamw', 'adadelta', 'radam'."
    )

class LinkPredictor(nn.Module):
    r"""LinkPredictor for graph link prediction task.

    A two-layer perceptron applied to the concatenation of a source and a
    destination embedding.

    Parameters
    ----------
    n_hid: int
        Input size.
    n_out: int
        Output size.
    """

    def __init__(self, n_hid, n_out):
        super().__init__()
        # The first layer consumes the concatenated (src, dst) pair, hence
        # twice the per-node feature size.
        self.fc1 = nn.Linear(n_hid * 2, n_hid)
        self.fc2 = nn.Linear(n_hid, n_out)

    def forward(self, src, dst):
        r"""Score source/destination pairs.

        Parameters
        ----------
        src: torch.Tensor
            Source features; the last dimension must be ``n_hid``.
        dst: torch.Tensor
            Destination features with the same shape as ``src``.

        Returns
        -------
        torch.Tensor
            Output of the second linear layer; the last dimension is
            ``n_out``.
        """
        # Concatenate along the feature (last) axis. For 2-D inputs this is
        # the same as dim=1, and it also accepts unbatched 1-D or
        # higher-rank inputs since nn.Linear operates on the last axis.
        x = torch.cat([src, dst], dim=-1)
        y = self.fc2(F.relu(self.fc1(x)))
        return y