# Source code for ggfm.models.higpt

import os
import json
import logging
import glob
import math
import re
import html
import gzip
import numpy as np
from collections import OrderedDict
from dataclasses import dataclass
from functools import lru_cache
from typing import Dict, List, Optional, Sequence, Tuple, Union, Callable
from omegaconf import OmegaConf
from urllib.parse import urlparse
from torch import Tensor
from torch_geometric.typing import Adj, EdgeType, NodeType

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.nn import CrossEntropyLoss, Parameter
from torch.utils.data import Dataset

import transformers
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaTokenizer,
    LlamaForCausalLM,
    LlamaModel,
    LlamaConfig,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.configuration_utils import PretrainedConfig

import torch_geometric
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.data import Data
from torch_geometric.utils import (
    remove_self_loops,
    add_self_loops,
    degree,
    softmax,
    add_remaining_self_loops,
)
from torch_geometric.utils.hetero import construct_bipartite_edge_index
# from torch_scatter import scatter_add

import ftfy


"""Special tokens used in the model"""
# Marker strings for graph content embedded in text.
# NOTE(review): presumably <graph> is the placeholder in the prompt, <g_patch>
# stands for one graph-token slot, and <g_start>/<g_end> delimit the span —
# confirm against the code that consumes these constants.
DEFAULT_GRAPH_TOKEN = "<graph>"
DEFAULT_GRAPH_PATCH_TOKEN = "<g_patch>"
DEFAULT_G_START_TOKEN = "<g_start>"
DEFAULT_G_END_TOKEN = "<g_end>"
# Label value for positions to be skipped.
# NOTE(review): -100 matches the default ``ignore_index`` of
# torch.nn.CrossEntropyLoss — confirm this is how it is used.
IGNORE_INDEX = -100
# Default pad / end-of-sequence / beginning-of-sequence / unknown token strings.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"

def is_url(url_or_filename):
    """
    Tell whether a string is an HTTP(S) URL.
    
    Args:
        url_or_filename (str): Candidate URL or local file name
        
    Returns:
        bool: True when the parsed scheme is ``http`` or ``https``
    """
    scheme = urlparse(url_or_filename).scheme
    return scheme == "http" or scheme == "https"

def get_abs_path(path):
    """
    Resolve a path to its absolute form.
    
    A leading ``~`` is expanded to the user's home directory first, then the
    result is made absolute and normalized.
    
    Args:
        path (str): Input path, possibly relative or starting with ``~``
        
    Returns:
        str: Absolute path
    """
    expanded = os.path.expanduser(path)
    return os.path.abspath(expanded)

@lru_cache()
def default_bpe():
    """
    Locate the default BPE vocabulary file.
    
    The file is expected to sit next to this module; the result is cached.
    
    Returns:
        str: Absolute path to ``bpe_simple_vocab_16e6.txt.gz``
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(module_dir, "bpe_simple_vocab_16e6.txt.gz")

@lru_cache()
def bytes_to_unicode():
    """
    Build the byte -> unicode-character table used by the BPE tokenizer.
    
    Bytes in the ranges ``!``-``~``, ``¡``-``¬`` and ``®``-``ÿ`` map to the
    character with the same code point. Every other byte value is assigned,
    in increasing order, the next code point starting at 256.
    
    Returns:
        dict: Mapping from each of the 256 byte values to a one-character string
    """
    kept = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    kept_lookup = set(kept)
    # Bytes outside the kept ranges, in ascending order.
    shifted = [byte for byte in range(2**8) if byte not in kept_lookup]
    keys = kept + shifted
    code_points = kept + [2**8 + offset for offset in range(len(shifted))]
    return {byte: chr(code) for byte, code in zip(keys, code_points)}

def get_pairs(word):
    """
    Get all adjacent pairs of symbols from a word.
    
    Args:
        word (tuple): Word as a sequence of symbols (characters or merged
            BPE units)
        
    Returns:
        set: Set of ``(left, right)`` pairs of neighbouring symbols; empty
            when the word has fewer than two symbols
    """
    # zip() stops at the shorter sequence, so empty and single-symbol words
    # yield an empty set instead of raising IndexError on word[0].
    return set(zip(word, word[1:]))

def basic_clean(text):
    """
    Repair and normalize raw text.
    
    Runs ``ftfy.fix_text``, unescapes HTML entities twice, and strips
    surrounding whitespace.
    
    Args:
        text (str): Input text
        
    Returns:
        str: Cleaned text
    """
    cleaned = ftfy.fix_text(text)
    # Two passes so that doubly-escaped entities (e.g. "&amp;lt;") are resolved.
    for _ in range(2):
        cleaned = html.unescape(cleaned)
    return cleaned.strip()

def whitespace_clean(text):
    """
    Normalize whitespace in text.
    
    Every run of whitespace is collapsed to a single space and leading and
    trailing whitespace is removed.
    
    Args:
        text (str): Input text
        
    Returns:
        str: Text with normalized whitespace
    """
    collapsed = re.sub(r'\s+', ' ', text)
    return collapsed.strip()

class BaseModel(nn.Module):
    """
    Base class for all models in HiGPT.
    
    Provides common functionality for model loading, optimization and evaluation.
    Subclasses are expected to supply ``PRETRAINED_MODEL_CONFIG_DICT``,
    ``from_config`` and ``load_from_pretrained``, which are referenced here
    but not defined by this class.
    """

    def __init__(self):
        """Initialize the base model."""
        super().__init__()

    @property
    def device(self):
        """Get the device of the model's first parameter."""
        return list(self.parameters())[0].device

    def load_checkpoint(self, url_or_filename):
        """
        Load model weights from a checkpoint file.
        
        The checkpoint is loaded onto the CPU. If it is a dict with a
        ``"model"`` entry, that entry is used as the state dict; otherwise
        the checkpoint itself is. Loading is non-strict, so missing or
        unexpected keys do not raise.
        
        Args:
            url_or_filename (str): Path or URL to checkpoint
            
        Returns:
            The result of ``load_state_dict`` (lists missing/unexpected keys)
            
        Raises:
            RuntimeError: If the argument is neither a URL nor an existing file
        """
        if is_url(url_or_filename):
            # NOTE(review): download_cached_file is not defined in the visible
            # part of this file; it must be provided elsewhere in the module.
            checkpoint_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
        elif os.path.isfile(url_or_filename):
            checkpoint_file = url_or_filename
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        # SECURITY: torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        checkpoint = torch.load(checkpoint_file, map_location="cpu")

        state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint

        msg = self.load_state_dict(state_dict, strict=False)
        logging.info("Missing keys %s", msg.missing_keys)
        logging.info("load checkpoint from %s", url_or_filename)
        return msg

    @classmethod
    def from_pretrained(cls, model_type):
        """
        Create a model instance from the default config of a model type.
        
        Args:
            model_type (str): Type/name of the pretrained model
            
        Returns:
            BaseModel: Model instance built by ``cls.from_config``
        """
        model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
        return cls.from_config(model_cfg)

    @classmethod
    def default_config_path(cls, model_type):
        """
        Get the default configuration file path for a model type.
        
        Args:
            model_type (str): Type/name of the model
            
        Returns:
            str: Absolute path to the configuration file
            
        Raises:
            AssertionError: If model_type is not recognized
        """
        assert model_type in cls.PRETRAINED_MODEL_CONFIG_DICT, "Unknown model type {}".format(model_type)
        return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])

    def load_checkpoint_from_config(self, cfg, **kwargs):
        """
        Load checkpoint based on configuration.
        
        When ``load_finetuned`` is true (the default) the ``finetuned``
        checkpoint is loaded. Otherwise, when ``load_pretrained`` is true
        (the default), the ``pretrained`` checkpoint is loaded through
        ``self.load_from_pretrained``.
        
        Args:
            cfg (Config): Configuration object containing checkpoint paths
            **kwargs: Additional arguments for loading pretrained weights
            
        Raises:
            AssertionError: If required paths are missing in config
        """
        if cfg.get("load_finetuned", True):
            finetune_path = cfg.get("finetuned", None)
            assert finetune_path is not None, "Found load_finetuned is True, but finetune_path is None."
            self.load_checkpoint(url_or_filename=finetune_path)
        elif cfg.get("load_pretrained", True):
            pretrain_path = cfg.get("pretrained", None)
            # The original asserted on the message string alone, which is
            # always truthy, so a missing path was never detected here.
            assert pretrain_path is not None, "Found load_finetuned is False, but pretrain_path is None."
            self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)

    def before_training(self, **kwargs):
        """Hook called before training; does nothing by default."""
        pass

    def get_optimizer_params(self, weight_decay, lr_scale=1):
        """
        Get parameters for optimizer with proper weight decay settings.
        
        Trainable parameters that are 1-D or scalar, or whose name contains
        ``"bias"``, ``"ln"`` or ``"bn"``, are exempt from weight decay.
        
        Args:
            weight_decay (float): Weight decay factor
            lr_scale (float, optional): Learning rate scaling factor. Defaults to 1
            
        Returns:
            list: Two parameter groups (with decay, without decay)
        """
        p_wd, p_non_wd = [], []
        for n, p in self.named_parameters():
            if not p.requires_grad:
                continue
            if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n:
                p_non_wd.append(p)
            else:
                p_wd.append(p)
        return [
            {"params": p_wd, "weight_decay": weight_decay, "lr_scale": lr_scale},
            {"params": p_non_wd, "weight_decay": 0, "lr_scale": lr_scale},
        ]

    def before_evaluation(self, **kwargs):
        """Hook called before evaluation; does nothing by default."""
        pass

    def show_n_params(self, return_str=True):
        """
        Calculate and format the total number of parameters.
        
        Args:
            return_str (bool, optional): Whether to return formatted string. Defaults to True
            
        Returns:
            Union[str, int]: Number of parameters as string (with M/K suffix) or integer
        """
        tot = sum(p.numel() for p in self.parameters())
        if not return_str:
            return tot
        if tot >= 1e6:
            return "{:.1f}M".format(tot / 1e6)
        return "{:.1f}K".format(tot / 1e3)

class LayerNorm(nn.LayerNorm):
    """
    ``nn.LayerNorm`` variant whose output always has the input's dtype.
    
    The normalization itself is the standard one; the result is cast back to
    the dtype the input arrived in.
    """

    def forward(self, x: torch.Tensor):
        """
        Apply layer normalization.
        
        Args:
            x (torch.Tensor): Input tensor
            
        Returns:
            torch.Tensor: Normalized tensor in the input's dtype
        """
        normalized = F.layer_norm(
            x, self.normalized_shape, self.weight, self.bias, self.eps
        )
        return normalized.to(x.dtype)

class QuickGELU(nn.Module):
    """
    Sigmoid-based approximation of the GELU activation.
    
    Computes ``x * sigmoid(1.702 * x)``.
    """
    
    def forward(self, x: torch.Tensor):
        """
        Apply the Quick GELU activation element-wise.
        
        Args:
            x (torch.Tensor): Input tensor
            
        Returns:
            torch.Tensor: Activated tensor
        """
        gate = torch.sigmoid(1.702 * x)
        return x * gate
    
class graph_transformer(nn.Module):
    """
    Transformer encoder over graph node features.
    
    Runs node features through a Transformer stack, a final layer norm and a
    learned linear projection.
    """
    
    def __init__(self, args):
        """
        Build the encoder.
        
        Args:
            args: Configuration object providing:
                - gnn_width: Hidden width of the transformer
                - gnn_layers: Number of transformer layers
                - gnn_heads: Number of attention heads
                - gnn_output: Output dimension of the final projection
        """
        super().__init__()
        width = args.gnn_width
        self.config = PretrainedConfig()
        self.gnn = Transformer(width=width, layers=args.gnn_layers, heads=args.gnn_heads)
        self.ln_post = LayerNorm(width)
        # Random projection scaled by 1/sqrt(width).
        init_scale = width ** 0.5
        self.proj = nn.Parameter(torch.randn(width, args.gnn_output) / init_scale)

    def forward(self, g):
        """
        Encode the node features of a graph.
        
        Args:
            g: Graph object exposing:
                - graph_node: Node feature tensor
                - edge_index: Edge connectivity tensor (read but not used in
                  the computation below)
                
        Returns:
            torch.Tensor: Projected node features
        """
        node_feats = g.graph_node
        edge_index = g.edge_index  # accessed for parity with other encoders; unused here
        
        hidden = self.ln_post(self.gnn(node_feats))
        if self.proj is None:
            return hidden
        return hidden @ self.proj

def load_model(
    model_path: str,
    device: str,
    num_gpus: int,
    max_gpu_memory: Optional[str] = None,
    load_8bit: bool = False,
    cpu_offloading: bool = False,
    debug: bool = False,
):
    """
    Load a pretrained model from Hugging Face.
    
    Args:
        model_path (str): Path or name of model on HuggingFace
        device (str): Device to load model on ('cpu', 'cuda', 'mps')
        num_gpus (int): Number of GPUs to use
        max_gpu_memory (str, optional): Maximum GPU memory per device. When
            omitted with several GPUs, 85% of each GPU's total memory is used
        load_8bit (bool, optional): Accepted for interface compatibility;
            currently not used by this function. Defaults to False
        cpu_offloading (bool, optional): When True, a single-GPU model is not
            moved to the GPU after loading. Defaults to False
        debug (bool, optional): Whether to print the model. Defaults to False
        
    Returns:
        AutoModelForCausalLM: Loaded model instance
        
    Raises:
        ValueError: If device is invalid
    """
    if device == "cpu":
        kwargs = {"torch_dtype": torch.float32}
    elif device == "cuda":
        kwargs = {"torch_dtype": torch.float16}
        if num_gpus != 1:
            if max_gpu_memory is None:
                # No explicit budget: fill GPUs one after another, leaving
                # 15% headroom on each.
                kwargs["device_map"] = "sequential"
                available_gpu_memory = get_gpu_memory(num_gpus)
                kwargs["max_memory"] = {
                    i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
                    for i in range(num_gpus)
                }
            else:
                kwargs["device_map"] = "auto"
                kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
    elif device == "mps":
        # "mps" is advertised by the docstring and the CLI choices and is
        # handled after loading below, but previously fell through to the
        # ValueError branch and could never be used.
        kwargs = {"torch_dtype": torch.float16}
    else:
        raise ValueError(f"Invalid device: {device}")

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        **kwargs
    )

    # With a device_map (multi-GPU) placement is already done by from_pretrained.
    if (device == "cuda" and num_gpus == 1 and not cpu_offloading) or device == "mps":
        model.to(device)

    if debug:
        print(model)

    return model


def add_model_args(parser):
    """
    Register the model-related command line options.
    
    Args:
        parser: An ``argparse`` parser; a ``model`` argument group is added to it
        
    Returns:
        The newly created argument group
    """
    option_specs = (
        ("--model-path", dict(
            type=str,
            default="lmsys/vicuna-7b-v1.3",
            help="Path to the model weights or model name on Hugging Face.",
        )),
        ("--device", dict(
            type=str,
            choices=["cpu", "cuda", "mps"],
            default="cuda",
            help="The device to run the model on.",
        )),
        ("--num-gpus", dict(
            type=int,
            default=1,
            help="Number of GPUs to use.",
        )),
        ("--max-gpu-memory", dict(
            type=str,
            help="Maximum GPU memory to use per GPU.",
        )),
        ("--load-8bit", dict(
            action="store_true",
            help="Load the model in 8-bit precision.",
        )),
        ("--cpu-offloading", dict(
            action="store_true",
            help="Offload model weights to CPU to save GPU memory.",
        )),
    )
    group = parser.add_argument_group('model')
    for flag, options in option_specs:
        group.add_argument(flag, **options)
    return group

def get_gpu_memory(num_gpus):
    """
    Report the total memory of the first ``num_gpus`` CUDA devices.
    
    Args:
        num_gpus (int): Number of devices to query, starting at index 0
        
    Returns:
        list: Total memory of each device in GiB (floats)
    """
    import torch.cuda
    bytes_per_gib = 1024**3
    memory_gib = []
    for gpu_index in range(num_gpus):
        with torch.cuda.device(gpu_index):
            active = torch.cuda.current_device()
            props = torch.cuda.get_device_properties(active)
            memory_gib.append(props.total_memory / bytes_per_gib)
    return memory_gib

def maybe_zero_3(param, ignore_status=False, name=None):
    """
    Return a detached CPU copy of a parameter, handling DeepSpeed ZeRO-3 params.
    
    Parameters carrying a ``ds_id`` attribute are treated as ZeRO-3
    partitioned and are gathered before being copied; all others are copied
    directly.
    
    Args:
        param: Tensor or parameter to copy
        ignore_status (bool, optional): Suppress the warning emitted when the
            parameter's ``ds_status`` is ``NOT_AVAILABLE``. Defaults to False
        name (str, optional): Parameter name used in the warning message
        
    Returns:
        torch.Tensor: Detached clone of the parameter on the CPU
    """
    if not hasattr(param, "ds_id"):
        return param.detach().cpu().clone()

    # These imports were commented out, so this branch used to fail with
    # NameError. Importing lazily keeps deepspeed optional: only parameters
    # created by DeepSpeed carry ``ds_id``.
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
    with zero.GatheredParameters([param]):
        return param.data.detach().cpu().clone()

class ResidualAttentionBlock(nn.Module):
    """
    Pre-norm transformer block: self-attention followed by an MLP.
    
    Each sub-layer is applied to a layer-normalized input and added back to
    the residual stream.
    """
    
    def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU):
        """
        Build the block.
        
        Args:
            d_model (int): Hidden dimension size
            n_head (int): Number of attention heads
            act_layer (Callable, optional): Factory for the MLP activation.
                Defaults to GELU
        """
        super().__init__()
        mlp_width = d_model * 4
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        mlp_layers = OrderedDict()
        mlp_layers["c_fc"] = nn.Linear(d_model, mlp_width)
        mlp_layers["gelu"] = act_layer()
        mlp_layers["c_proj"] = nn.Linear(mlp_width, d_model)
        self.mlp = nn.Sequential(mlp_layers)
        self.ln_2 = LayerNorm(d_model)

    def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """
        Run self-attention with ``x`` as query, key and value.
        
        Args:
            x (torch.Tensor): Input tensor
            attn_mask (torch.Tensor, optional): Attention mask
            
        Returns:
            torch.Tensor: Attention output (attention weights are not computed)
        """
        attn_out, _ = self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)
        return attn_out

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """
        Apply the attention and MLP sub-layers with residual connections.
        
        Args:
            x (torch.Tensor): Input tensor
            attn_mask (torch.Tensor, optional): Attention mask
            
        Returns:
            torch.Tensor: Tensor of the same shape as the input
        """
        attended = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
        return attended + self.mlp(self.ln_2(attended))

class Transformer(nn.Module):
    """
    Stack of ResidualAttentionBlocks applied one after another.
    """
    
    def __init__(self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU):
        """
        Build the stack.
        
        Args:
            width (int): Hidden dimension size
            layers (int): Number of transformer layers
            heads (int): Number of attention heads per layer
            act_layer (Callable, optional): Factory for the MLP activation.
                Defaults to GELU
        """
        super().__init__()
        self.width = width
        self.layers = layers
        blocks = []
        for _ in range(layers):
            blocks.append(ResidualAttentionBlock(width, heads, act_layer=act_layer))
        self.resblocks = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """
        Pass the input through every block in order.
        
        Args:
            x (torch.Tensor): Input tensor
            attn_mask (torch.Tensor, optional): Attention mask handed to each block
            
        Returns:
            torch.Tensor: Output of the last block
        """
        hidden = x
        for block in self.resblocks:
            hidden = block(hidden, attn_mask=attn_mask)
        return hidden

class SimpleTokenizer(object):
    """
    Basic tokenizer implementation with BPE encoding.
    
    Implements byte-pair encoding (BPE) tokenization with support for special tokens
    and caching.
    """
    
    def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
        """
        Initialize the tokenizer.
        
        Args:
            bpe_path (str, optional): Path to the gzip-compressed BPE merges file
            special_tokens (list, optional): Additional special tokens to add
                after ``<start_of_text>`` and ``<end_of_text>``
        """
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Context manager so the gzip handle is closed after reading.
        with gzip.open(bpe_path) as bpe_file:
            merges = bpe_file.read().decode("utf-8").split("\n")
        # Skip the header line and keep the merges that fill the vocabulary.
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(self.byte_encoder.values())
        vocab = vocab + [v+'</w>' for v in vocab]
        vocab.extend(''.join(merge) for merge in merges)
        special_tokens = ["<start_of_text>", "<end_of_text>"] + list(special_tokens or [])
        vocab.extend(special_tokens)
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {t: t for t in special_tokens}
        # Escape special tokens so characters such as "[" or "|" are matched
        # literally instead of being interpreted as regex syntax.
        special = "|".join(re.escape(t) for t in special_tokens)
        # The original pattern used \p{L} / \p{N}, which only the third-party
        # ``regex`` module understands; the stdlib ``re`` imported by this
        # file rejects them at compile time. Stdlib equivalents:
        #   [^\W\d_]+       letters (word chars minus decimal digits and "_")
        #   \d              one decimal digit
        #   (?:[^\s\w]|_)+  runs of everything else that is not whitespace
        # NOTE: non-decimal numeric characters (e.g. "½", "Ⅳ") fall into the
        # letter class here, whereas \p{N} matched them as single numbers.
        self.pat = re.compile(
            special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[^\W\d_]+|\d|(?:[^\s\w]|_)+""",
            re.IGNORECASE,
        )
        self.vocab_size = len(self.encoder)
        self.all_special_ids = [self.encoder[t] for t in special_tokens]

    def bpe(self, token):
        """
        Apply byte-pair encoding to a token.
        
        Repeatedly merges the adjacent symbol pair with the lowest merge rank
        until no known pair remains. Results are cached per token.
        
        Args:
            token (str): Input token (already mapped through the byte encoder)
            
        Returns:
            str: Space-separated BPE symbols; the last one carries ``</w>``
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        pairs = get_pairs(word)

        if not pairs:
            return token + "</w>"

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # No further occurrence of ``first``: copy the tail as-is.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """
        Encode text into token IDs.
        
        The text is cleaned, whitespace-normalized and lower-cased before
        being split and BPE-encoded.
        
        Args:
            text (str): Input text
            
        Returns:
            list: List of token IDs
        """
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in self.pat.findall(text):
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    def decode(self, tokens):
        """
        Decode token IDs back to text.
        
        Args:
            tokens (list): List of token IDs
            
        Returns:
            str: Decoded text, with each ``</w>`` marker replaced by a space
                and undecodable bytes replaced
        """
        text = "".join(self.decoder[token] for token in tokens)
        raw = bytearray(self.byte_decoder[c] for c in text)
        return raw.decode("utf-8", errors="replace").replace("</w>", " ")

def _get_tokenizer():
    """Return the module-level tokenizer, building a default SimpleTokenizer on first use."""
    global _tokenizer
    try:
        return _tokenizer
    except NameError:
        # The eager module-level ``_tokenizer = SimpleTokenizer()`` is
        # commented out in this file, so create it lazily instead of failing.
        _tokenizer = SimpleTokenizer()
        return _tokenizer


def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
    """
    Tokenize text(s) with special tokens and padding.
    
    Each text is wrapped in ``<start_of_text>`` / ``<end_of_text>`` tokens,
    truncated to ``context_length`` and right-padded with zeros.
    
    Args:
        texts (Union[str, List[str]]): Input text or list of texts
        context_length (int, optional): Maximum sequence length. Defaults to 77
        
    Returns:
        torch.LongTensor: Tensor of token IDs with shape (batch_size, context_length)
    """
    if isinstance(texts, str):
        texts = [texts]

    tokenizer = _get_tokenizer()
    sot_token = tokenizer.encoder["<start_of_text>"]
    eot_token = tokenizer.encoder["<end_of_text>"]
    all_tokens = [[sot_token] + tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        # NOTE(review): truncation drops the trailing <end_of_text> token;
        # confirm whether downstream code expects it to be kept.
        tokens = tokens[:context_length]
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result

# _tokenizer = SimpleTokenizer()



def gcn_conv(h, edge_index):
    """
    Parameter-free graph convolution with symmetric degree normalization.
    
    Self loops are reset (removed, then added once per node) and every edge
    ``(src, dst)`` gets the weight ``deg(src)^-0.5 * deg(dst)^-0.5``, where
    the degree is counted over edge destinations.
    
    Args:
        h (torch.Tensor): Node feature matrix of shape (N, F)
        edge_index (torch.Tensor): Graph connectivity in COO format, shape (2, E)
        
    Returns:
        torch.Tensor: Updated node features of shape (N, F)
    """
    N, node_feas = h.shape
    edge_index, _ = remove_self_loops(edge_index)
    edge_index, _ = add_self_loops(edge_index, num_nodes=N)
    
    src, dst = edge_index
    deg = degree(dst, num_nodes=N)

    deg_src = deg[src].pow(-0.5)
    deg_src.masked_fill_(deg_src == float('inf'), 0)
    deg_dst = deg[dst].pow(-0.5)
    deg_dst.masked_fill_(deg_dst == float('inf'), 0)
    edge_weight = deg_src * deg_dst

    # NOTE(review): messages are gathered from h[src] and accumulated back at
    # index ``src`` (not ``dst``), so each row of the output is the node's own
    # features scaled by the sum of its outgoing edge weights. Kept as-is
    # because existing weights may depend on it — confirm this is intended.
    edge_msg = h[src, :] * edge_weight.unsqueeze(-1)
    # Allocate the accumulator with the messages' dtype and device; a
    # default-dtype buffer makes index_add_ fail for e.g. float64 features.
    h_prime = torch.zeros(N, node_feas, dtype=edge_msg.dtype, device=edge_msg.device)
    h_prime.index_add_(0, src.to(edge_msg.device), edge_msg)
    return h_prime

class MPNN(nn.Module):
    """
    Message Passing Neural Network implementation.
    
    Stacks optional linear transforms with ``gcn_conv`` propagation steps.
    
    Args:
        in_channels (int): Input feature dimension
        hidden_channels (int): Hidden layer dimension
        out_channels (int): Output feature dimension
        **kwargs: Additional arguments (read with ``kwargs.get`` and no
            default, so they are effectively required by ``forward``):
            - dropout (float): Dropout rate
            - num_layers (int): Number of message passing layers
            - if_param (bool): Whether to use learnable parameters
    """
    
    def __init__(self, in_channels, hidden_channels, out_channels, **kwargs):
        super().__init__()
        self.config = PretrainedConfig()
        self.dropout = kwargs.get('dropout')
        self.num_layers = kwargs.get('num_layers')
        self.ff_bias = True

        self.bns = nn.BatchNorm1d(hidden_channels, affine=False, track_running_stats=False)
        self.activation = F.relu
        self.if_param = kwargs.get('if_param')

        if self.if_param:
            self.fcs = nn.ModuleList([])
            self.fcs.append(nn.Linear(in_channels, hidden_channels, bias=self.ff_bias))
            for _ in range(self.num_layers - 2):
                self.fcs.append(nn.Linear(hidden_channels, hidden_channels, bias=self.ff_bias))
            self.fcs.append(nn.Linear(hidden_channels, out_channels, bias=self.ff_bias))
            self.reset_parameters()

    def reset_parameters(self):
        """Initialize model parameters using Xavier initialization."""
        for mlp in self.fcs:
            nn.init.xavier_uniform_(mlp.weight, gain=1.414)
            nn.init.zeros_(mlp.bias)

    def forward(self, g, use_conv=True):
        """
        Forward pass through the MPNN.
        
        Args:
            g: Graph object exposing ``graph_node`` (node features) and
                ``edge_index`` (connectivity)
            use_conv (bool, optional): Whether to apply ``gcn_conv`` in each
                layer. Defaults to True
            
        Returns:
            torch.Tensor: Updated node features
        """
        x = g.graph_node
        edge_index = g.edge_index
        try:
            device = next(self.parameters()).device
        except StopIteration:
            # Parameter-free configuration (if_param falsy): stay on the
            # input's device.
            device = x.device
        x = x.to(device)
        edge_index = edge_index.to(device)

        for i in range(self.num_layers - 1):
            if self.if_param:
                x = x @ self.fcs[i].weight.t()
            if use_conv:
                x = gcn_conv(x, edge_index)
            # The bias is added after propagation, not before.
            if self.ff_bias and self.if_param:
                x = x + self.fcs[i].bias
            try:
                x = self.activation(self.bns(x))
            except (ValueError, RuntimeError):
                # Best effort: batch norm cannot run on this input (e.g. a
                # single node), so fall back to the bare activation. Narrowed
                # from a bare ``except`` so that KeyboardInterrupt and
                # SystemExit are no longer swallowed.
                x = self.activation(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        if self.if_param:
            x = x @ self.fcs[-1].weight.t()
        if use_conv:
            x = gcn_conv(x, edge_index)
        if self.ff_bias and self.if_param:
            x = x + self.fcs[-1].bias
        return x

@dataclass
class MetaHGTConvCfg:
    """
    Configuration class for Meta Heterogeneous Graph Transformer Convolution.
    
    The field names mirror the leading constructor arguments of
    ``MetaHGTConv``. Note that ``dynamic`` defaults to True here, whereas
    ``MetaHGTConv.__init__`` defaults it to False.
    
    Attributes:
        in_channels (int): Input feature dimension
        out_channels (int): Output feature dimension
        heads (int): Number of attention heads
        dynamic (bool): Whether to use dynamic weight generation
    """
    # Required fields (no defaults).
    in_channels: int
    out_channels: int
    heads: int
    # Optional flag; enabled unless overridden.
    dynamic: bool = True

class MetaHGTConv(MessagePassing):
    """
    Meta Heterogeneous Graph Transformer Convolution layer.

    Implements attention-based message passing for heterogeneous graphs. The
    per-type projections (``MetaHeteroDictLinear`` / ``MetaHeteroLinear``) as
    well as the skip-gate and relation-prior heads all take type feature
    vectors of width ``text_cfg.width`` as an additional input.
    """

    def __init__(self, in_channels, out_channels, heads=1, dynamic=False, text_cfg=None, **kwargs):
        """
        Initialize the MetaHGTConv layer.

        Args:
            in_channels (int): Input feature dimension
            out_channels (int): Output feature dimension
            heads (int, optional): Number of attention heads. Defaults to 1
            dynamic (bool, optional): Whether to use dynamic weights. Defaults to False
            text_cfg: Text configuration; its ``width`` and ``context_length``
                attributes are read here, so it must not be ``None``
            **kwargs: Additional arguments forwarded to ``MessagePassing``

        Raises:
            ValueError: If ``out_channels`` is not divisible by ``heads`` or
                ``text_cfg`` is not provided
        """
        super().__init__(aggr='add', node_dim=0, **kwargs)

        self.config = PretrainedConfig()

        if out_channels % heads != 0:
            raise ValueError(f"'out_channels' (got {out_channels}) must be divisible by the number of heads (got {heads})")
        if text_cfg is None:
            # Fail early with a clear message instead of an AttributeError on
            # ``None.width`` further down.
            raise ValueError("'text_cfg' is required: it provides 'width' and 'context_length'")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads

        # One fused projection producing keys, queries and values.
        self.kqv_lin = MetaHeteroDictLinear(text_cfg.width, self.in_channels,
                                        self.out_channels * 3, dynamic)

        self.out_lin = MetaHeteroDictLinear(text_cfg.width, self.out_channels, self.out_channels, dynamic)
        self.context_length = text_cfg.context_length

        dim = out_channels // heads

        # Relation-specific transforms applied per head to keys and values.
        self.k_rel = MetaHeteroLinear(text_cfg.width, dim, dim, dynamic)
        self.v_rel = MetaHeteroLinear(text_cfg.width, dim, dim, dynamic)

        self.skipTrans = nn.Linear(text_cfg.width, 1)
        self.p_relTrans = nn.Linear(text_cfg.width, heads)
        self.norm = nn.LayerNorm(self.out_channels, eps=1e-6)

        self.reset_parameters()

    def reset_parameters(self):
        """Delegate to ``MessagePassing.reset_parameters``; submodules keep their own default init."""
        super().reset_parameters()

    def _cat(self, x_dict: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, int]]:
        """
        Concatenate features from different node types.

        Args:
            x_dict (Dict[str, Tensor]): Dictionary of node features by type

        Returns:
            Tuple[Tensor, Dict[str, int]]: Concatenated features and, for each
                node type, the row offset at which its block starts
        """
        cumsum = 0
        outs: List[Tensor] = []
        offset: Dict[str, int] = {}
        for key, x in x_dict.items():
            outs.append(x)
            offset[key] = cumsum
            cumsum += x.size(0)
        return torch.cat(outs, dim=0), offset

    def _construct_src_node_feat(
        self, k_dict: Dict[str, Tensor], v_dict: Dict[str, Tensor],
        edge_index_dict: Dict[EdgeType, Adj],
        edge_type_feas_dict: Dict[EdgeType, Tensor],
    ) -> Tuple[Tensor, Tensor, Dict[EdgeType, int]]:
        """
        Construct source node representations for attention.

        For every edge type the keys/values of its source node type are
        stacked, tagged with a per-head relation index and passed through the
        relation-specific ``k_rel`` / ``v_rel`` transforms.

        Args:
            k_dict (Dict[str, Tensor]): Key vectors by node type
            v_dict (Dict[str, Tensor]): Value vectors by node type
            edge_index_dict (Dict[EdgeType, Adj]): Edge indices by type
            edge_type_feas_dict (Dict[EdgeType, Tensor]): Edge type features

        Returns:
            Tuple[Tensor, Tensor, Dict[EdgeType, int]]: Processed key and value
                vectors, plus the row offset of each edge type's source block
        """
        cumsum = 0
        num_edge_types = len(edge_index_dict)
        H, D = self.heads, self.out_channels // self.heads

        ks: List[Tensor] = []
        vs: List[Tensor] = []
        type_list: List[Tensor] = []
        offset: Dict[EdgeType, int] = {}

        # Stable integer id for every edge type (dict insertion order).
        edge_types_map = {
            edge_type: i
            for i, edge_type in enumerate(edge_index_dict.keys())
        }
        for edge_type in edge_index_dict.keys():
            src = edge_type[0]
            N = k_dict[src].size(0)
            offset[edge_type] = cumsum
            cumsum += N

            # Row h of type_vec holds ``h * num_edge_types + edge_type_id``,
            # i.e. a distinct relation index per (head, edge type) pair.
            edge_type_offset = edge_types_map[edge_type]
            type_vec = torch.arange(H, dtype=torch.long).view(-1, 1).repeat(
                1, N) * num_edge_types + edge_type_offset

            type_list.append(type_vec)
            ks.append(k_dict[src])
            vs.append(v_dict[src])

        # Move the head axis first and flatten so every row is one
        # (head, node) vector of size D.
        ks = torch.cat(ks, dim=0).transpose(0, 1).reshape(-1, D)
        vs = torch.cat(vs, dim=0).transpose(0, 1).reshape(-1, D)
        type_vec = torch.cat(type_list, dim=1).flatten()

        edge_feas_dict = {edge_types_map[k]: v for k, v in edge_type_feas_dict.items()}

        k = self.k_rel(ks, type_vec, edge_feas_dict).view(H, -1, D).transpose(0, 1)
        v = self.v_rel(vs, type_vec, edge_feas_dict).view(H, -1, D).transpose(0, 1)

        return k, v, offset

    def _construct_p_rel(self, edge_type_feas_dict: Dict[EdgeType, Tensor]):
        """
        Construct relation-specific attention weights.

        Args:
            edge_type_feas_dict (Dict[EdgeType, Tensor]): Edge type features

        Returns:
            Dict[EdgeType, Tensor]: For each edge type, ``p_relTrans`` applied
                to its feature vector with a leading axis added
        """
        p_rel = {k: self.p_relTrans(v).unsqueeze(0) for k, v in edge_type_feas_dict.items()}
        return p_rel

    def _construct_skip(self, node_type_feas_dict: Dict[NodeType, Tensor]):
        """
        Construct skip connection weights.

        Args:
            node_type_feas_dict (Dict[NodeType, Tensor]): Node type features

        Returns:
            Dict[NodeType, Tensor]: Raw (pre-sigmoid) skip gate for each node type
        """
        skip = {k: self.skipTrans(v) for k, v in node_type_feas_dict.items()}
        return skip

    def forward(
        self,
        x_dict: Dict[NodeType, Tensor],
        edge_index_dict: Dict[EdgeType, Adj],
        data_type: str = 'dblp',
        node_type_feas_dict: Optional[Dict[NodeType, Tensor]] = None,
        edge_type_feas_dict: Optional[Dict[EdgeType, Tensor]] = None,
    ) -> Dict[NodeType, Optional[Tensor]]:
        """
        Run one round of heterogeneous attention message passing.

        Args:
            x_dict (Dict[NodeType, Tensor]): Node features by node type
            edge_index_dict (Dict[EdgeType, Adj]): Edge indices by edge type
            data_type (str, optional): Not used by this method; kept for
                interface compatibility. Defaults to 'dblp'
            node_type_feas_dict (Dict[NodeType, Tensor]): Node type features
            edge_type_feas_dict (Dict[EdgeType, Tensor]): Edge type features

        Returns:
            Dict[NodeType, Tensor]: Updated features for every node type that
                is the destination of at least one edge type

        Raises:
            ValueError: If either type-feature dictionary is missing
        """
        if node_type_feas_dict is None or edge_type_feas_dict is None:
            # Both dictionaries are iterated unconditionally below, so a
            # missing one could only surface as an opaque AttributeError.
            raise ValueError(
                "'node_type_feas_dict' and 'edge_type_feas_dict' must both be provided")

        H = self.heads
        D = self.out_channels // H

        k_dict, q_dict, v_dict, out_dict = {}, {}, {}, {}

        kqv_dict = self.kqv_lin(x_dict, node_type_feas_dict)

        # Split the fused projection into per-head keys, queries and values.
        for key, val in kqv_dict.items():
            k, q, v = torch.tensor_split(val, 3, dim=1)
            k_dict[key] = k.view(-1, H, D)
            q_dict[key] = q.view(-1, H, D)
            v_dict[key] = v.view(-1, H, D)

        q, dst_offset = self._cat(q_dict)
        k, v, src_offset = self._construct_src_node_feat(
            k_dict, v_dict, edge_index_dict, edge_type_feas_dict)
        p_rel = self._construct_p_rel(edge_type_feas_dict)
        edge_index, edge_attr = construct_bipartite_edge_index(
            edge_index_dict, src_offset, dst_offset, edge_attr_dict=p_rel)

        out = self.propagate(edge_index, k=k, q=q, v=v, edge_attr=edge_attr,
                             size=None)

        dst_node_types = {key[-1] for key in edge_index_dict}

        # Slice the stacked result back into per-type blocks; node types that
        # never receive messages are left out of the output.
        for node_type, start_offset in dst_offset.items():
            end_offset = start_offset + q_dict[node_type].size(0)
            if node_type in dst_node_types:
                out_dict[node_type] = out[start_offset:end_offset]

        # NOTE(review): the aggregated features go into ``out_lin`` unchanged.
        # The original expression here was a no-op conditional
        # (``v if v is not None else v``); PyG's HGTConv applies GELU at this
        # point -- confirm whether an activation was intended.
        a_dict = self.out_lin(dict(out_dict), node_type_feas_dict)

        skip = self._construct_skip(node_type_feas_dict)

        for node_type in out_dict:
            out = a_dict[node_type]

            # The gated residual is only possible when input and output widths match.
            if out.size(-1) == x_dict[node_type].size(-1):
                alpha = skip[node_type].sigmoid()
                out = alpha * out + (1 - alpha) * x_dict[node_type]
            out_dict[node_type] = self.norm(out)

        return out_dict

    def message(self, k_j: Tensor, q_i: Tensor, v_j: Tensor, edge_attr: Tensor,
                index: Tensor, ptr: Optional[Tensor],
                size_i: Optional[int]) -> Tensor:
        """
        Compute messages in the message passing framework.

        Scores are the per-head dot product of query and key, multiplied by
        the edge attribute, scaled by ``sqrt`` of the head dimension and
        normalized with a softmax over the edges of each target node.

        Args:
            k_j (Tensor): Key vectors of source nodes
            q_i (Tensor): Query vectors of target nodes
            v_j (Tensor): Value vectors of source nodes
            edge_attr (Tensor): Edge attributes
            index (Tensor): Target node indices
            ptr (Optional[Tensor]): Compressed sparse format pointer
            size_i (Optional[int]): Number of target nodes

        Returns:
            Tensor: Attention-weighted values flattened to ``out_channels``
        """
        alpha = (q_i * k_j).sum(dim=-1) * edge_attr
        alpha = alpha / math.sqrt(q_i.size(-1))
        alpha = softmax(alpha, index, ptr, size_i)
        out = v_j * alpha.view(-1, self.heads, 1)
        return out.view(-1, self.out_channels)

    def __repr__(self) -> str:
        """
        Get string representation of the layer.

        Returns:
            str: Layer description with output channels and number of heads
        """
        return (f'{self.__class__.__name__}(-1, {self.out_channels}, '
                f'heads={self.heads})')

class GNN(MessagePassing):
    """
    Graph Neural Network implementation.

    A two-layer GNN: each layer runs one propagation step over the graph and
    then applies a linear map. Weights and biases are stored in the flat
    ``self.vars`` list as ``[w1, b1, w2, b2]``.

    Args:
        args: Configuration object containing:
            - gnn_hid (int): Hidden dimension size
            - gnn_input (int): Input feature dimension
            - gnn_output (int): Output feature dimension
    """

    def __init__(self, args, **kwargs):
        super(GNN, self).__init__(aggr='add', **kwargs)
        self.config = PretrainedConfig()
        self.vars = nn.ParameterList()

        # Layer 1: gnn_input -> gnn_hid
        w = nn.Parameter(torch.ones([args.gnn_hid, args.gnn_input]))
        torch.nn.init.xavier_uniform_(w)
        self.vars.append(w)
        self.vars.append(nn.Parameter(torch.zeros(args.gnn_hid)))

        # Layer 2: gnn_hid -> gnn_output
        w = nn.Parameter(torch.ones([args.gnn_output, args.gnn_hid]))
        torch.nn.init.xavier_uniform_(w)
        self.vars.append(w)
        self.vars.append(nn.Parameter(torch.zeros(args.gnn_output)))

    @staticmethod
    def norm(edge_index, num_nodes, improved=False, dtype=None):
        """
        Compute normalized edge weights.

        Self-loops are added where missing, node degrees are accumulated from
        the unit edge weights, and every edge is weighted by
        ``deg[row]^-1/2 * w * deg[col]^-1/2``.

        Args:
            edge_index (Tensor): Edge indices
            num_nodes (int): Number of nodes in graph
            improved (bool, optional): Whether to use improved normalization
                (self-loop weight 2 instead of 1). Defaults to False
            dtype (torch.dtype, optional): Data type of weights

        Returns:
            Tuple[Tensor, Tensor]: Edge indices including self-loops, and the
                normalized weight of each edge
        """
        edge_weight = torch.ones((edge_index.size(1),), dtype=dtype,
                                 device=edge_index.device)

        fill_value = 2.0 if improved else 1.0
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)

        row, col = edge_index
        # Native replacement for torch_scatter.scatter_add, whose import is
        # commented out at the top of the file (the call raised NameError).
        deg = torch.zeros(num_nodes, dtype=edge_weight.dtype,
                          device=edge_weight.device)
        deg.scatter_add_(0, row, edge_weight)
        deg_inv_sqrt = deg.pow(-0.5)
        # Isolated nodes have degree 0 -> inf after pow(-0.5); zero them out.
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, g, vars=None):
        """
        Apply the two GNN layers to a graph.

        Args:
            g: Graph object exposing ``edge_index`` and ``graph_node``; it is
                moved to the device of the first parameter
            vars (optional): Alternative ``[w1, b1, w2, b2]`` parameter list;
                defaults to ``self.vars``

        Returns:
            torch.Tensor: Node features after the second layer
        """
        device = self.parameters()[0].device
        g = g.to(device)

        edge_index = g.edge_index
        x = g.graph_node
        if vars is None:
            vars = self.vars
        improved = False

        # NOTE(review): ``norm`` is handed to ``propagate`` but this class
        # defines no ``message()`` override, so unless the base class consumes
        # it the normalization weights have no effect -- confirm.
        w, b = vars[0], vars[1]
        edge_index, norm = self.norm(edge_index, x.size(self.node_dim), improved, x.dtype)
        x = self.propagate(edge_index, x=x, norm=norm)
        w = w.to(x.device)
        b = b.to(x.device)
        x = F.linear(x, w, b)
        x = F.leaky_relu(x)

        w, b = vars[2], vars[3]
        edge_index, norm = self.norm(edge_index, x.size(self.node_dim), improved, x.dtype)
        x = self.propagate(edge_index, x=x, norm=norm)
        w = w.to(x.device)
        b = b.to(x.device)
        x = F.linear(x, w, b)

        return x

    def parameters(self):
        """Return the flat parameter list ``self.vars`` (overrides ``nn.Module.parameters``)."""
        return self.vars


@dataclass
class CLIPTextCfg:
    """
    Settings for the CLIP text encoder.

    All five fields are required and have no defaults.
    """
    # Maximum sequence length for text.
    context_length: int
    # Size of the vocabulary.
    vocab_size: int
    # Hidden dimension size.
    width: int
    # Number of attention heads.
    heads: int
    # Number of transformer layers.
    layers: int

@dataclass
class ClipOutputFeatures:
    """
    Container for features extracted by a CLIP model.

    Every field is optional and defaults to ``None``.
    """
    # Raw image embeddings.
    image_embeds: Optional[torch.FloatTensor] = None
    # Projected image embeddings.
    image_embeds_proj: Optional[torch.FloatTensor] = None
    # Raw text embeddings.
    text_embeds: Optional[torch.FloatTensor] = None
    # Projected text embeddings.
    text_embeds_proj: Optional[torch.FloatTensor] = None

@dataclass
class ClipOutput:
    """
    Result bundle returned by the CLIP model's forward pass.

    Every field is optional and defaults to ``None``.
    """
    # Intermediate feature outputs (CLIP.forward in this file stores a
    # HeteClipOutputFeatures instance here).
    intermediate_output: Optional[ClipOutputFeatures] = None
    # Exponential of the learnable temperature parameter.
    logit_scale_exp: Optional[torch.FloatTensor] = None
    # Contrastive loss value.
    loss: Optional[torch.FloatTensor] = None

@dataclass
class HeteClipOutputFeatures:
    """
    Container for features produced by the heterogeneous-graph CLIP model.

    Mirrors ``ClipOutputFeatures`` with graph embeddings in place of image
    embeddings. Every field is optional and defaults to ``None``.
    """
    # Raw graph embeddings.
    graph_embeds: Optional[torch.FloatTensor] = None
    # Projected graph embeddings.
    graph_embeds_proj: Optional[torch.FloatTensor] = None
    # Raw text embeddings.
    text_embeds: Optional[torch.FloatTensor] = None
    # Projected text embeddings.
    text_embeds_proj: Optional[torch.FloatTensor] = None

class CLIP(BaseModel):
    """
    CLIP model adapted for graph-text contrastive learning.

    Implements a CLIP-style architecture that learns joint embeddings of
    heterogeneous graphs and text descriptions.

    Args:
        embed_dim (int): Joint embedding dimension
        graph_cfg (MetaHGTConvCfg): Configuration for graph encoder (a plain
            dict is converted automatically)
        text_cfg (CLIPTextCfg): Configuration for text encoder (a plain dict
            is converted automatically)
        quick_gelu (bool, optional): Whether to use quick GELU activation. Defaults to False
    """
    def __init__(
        self,
        embed_dim: int,
        graph_cfg: MetaHGTConvCfg,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
    ):
        super().__init__()

        self.tokenizer = tokenize
        self._loss = None

        if isinstance(graph_cfg, dict):
            graph_cfg = MetaHGTConvCfg(**graph_cfg)
        if isinstance(text_cfg, dict):
            text_cfg = CLIPTextCfg(**text_cfg)

        self.context_length = text_cfg.context_length
        act_layer = QuickGELU if quick_gelu else nn.GELU

        self.graph_encoder = MetaHGTConv(
            in_channels = graph_cfg.in_channels,
            out_channels = graph_cfg.out_channels,
            heads = graph_cfg.heads,
            dynamic = graph_cfg.dynamic,
            text_cfg = text_cfg
        )

        self.transformer = Transformer(
            width=text_cfg.width,
            layers=text_cfg.layers,
            heads=text_cfg.heads,
            act_layer=act_layer,
        )

        self.vocab_size = text_cfg.vocab_size
        self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, text_cfg.width))
        self.ln_final = LayerNorm(text_cfg.width)

        self.text_projection = nn.Parameter(torch.empty(text_cfg.width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        # Non-persistent: the mask is rebuilt from context_length, not saved in checkpoints.
        self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)

        self.prompt_templates = openai_imagenet_template
        self.classifier = None

        self.init_parameters()

    @property
    def loss(self):
        """Get the contrastive loss function (created lazily on first access)."""
        if self._loss is None:
            self._loss = HeteClipLoss()
        return self._loss

    def init_parameters(self):
        """Initialize embeddings, transformer blocks and the text projection with scaled normal noise."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)
        nn.init.constant_(self.logit_scale, np.log(1 / 0.07))

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        """
        Build causal attention mask for transformer.

        Returns:
            torch.Tensor: ``(context_length, context_length)`` mask with
                ``-inf`` strictly above the diagonal and zeros elsewhere
        """
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        return mask

    def encode_graph(self, graph: List[Dict[str, torch.Tensor]], des_order: List[List[str]]):
        """
        Encode heterogeneous graph inputs.

        Args:
            graph: List of graph objects; each must expose ``x_dict`` and
                ``edge_index_dict`` attributes
            des_order (List[List[str]]): Node type ordering for each graph

        Returns:
            torch.Tensor: Embeddings of the selected node types, concatenated
                graph by graph in ``des_order`` order
        """
        # NOTE(review): the encoder is called without node/edge type feature
        # dicts, while MetaHGTConv.forward defaults them to None and then
        # iterates them -- confirm how type features are meant to be supplied.
        graph_list = [
            self.graph_encoder(graph_dict.x_dict, graph_dict.edge_index_dict)
            for graph_dict in graph
        ]
        assert len(graph_list) == len(des_order)
        graph_embeds = []
        for idx, order in enumerate(des_order):
            graph_embeds.extend([graph_list[idx][o] for o in order])
        return torch.cat(graph_embeds, dim=0)

    def encode_text(self, text):
        """
        Encode text inputs through transformer.

        Args:
            text: Input text tokens

        Returns:
            torch.Tensor: Text embeddings
        """
        x = self.token_embedding(text)
        x = x + self.positional_embedding
        # The transformer is fed sequence-first input; permute there and back.
        x = x.permute(1, 0, 2)
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = x.permute(1, 0, 2)
        x = self.ln_final(x)
        # Pool at the position holding the largest token id in each sequence,
        # then project into the joint embedding space.
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return x

    def _tokenize(self, text):
        """Tokenize raw text to ``context_length`` tokens on the model's device."""
        tokens = self.tokenizer(text, self.context_length)
        return tokens.to(self.token_embedding.weight.device)

    def _paired_features(self, graph, des_order, text):
        """Encode a graph/text pair and return raw plus L2-normalized embeddings."""
        graph_embeds = self.encode_graph(graph, des_order)
        graph_features = F.normalize(graph_embeds, dim=-1)

        text_embeds = self.encode_text(text)
        text_features = F.normalize(text_embeds, dim=-1)
        assert graph_features.shape == text_features.shape

        return HeteClipOutputFeatures(
            graph_embeds=graph_embeds,
            graph_embeds_proj=graph_features,
            text_embeds=text_embeds,
            text_embeds_proj=text_features,
        )

    def forward(self, samples):
        """
        Forward pass computing contrastive loss between graph and text.

        When only one modality is present in ``samples`` the embeddings of
        that modality are returned directly and no loss is computed.

        Args:
            samples: Dictionary containing:
                - graph (List): Graph inputs
                - text_input (List[str]): Text inputs
                - des_order (List[List[str]]): Node type ordering

        Returns:
            ClipOutput: Model outputs including features and loss
        """
        graph = samples.get("graph")
        text = samples.get("text_input")
        des_order = samples.get("des_order")

        if text is not None:
            text = self._tokenize(text)

        if graph is None:
            return self.encode_text(text)
        if text is None:
            return self.encode_graph(graph, des_order)

        features = self._paired_features(graph, des_order, text)
        logit_scale_exp = self.logit_scale.exp()
        loss = self.loss(features.graph_embeds_proj, features.text_embeds_proj, logit_scale_exp)

        return ClipOutput(
            intermediate_output=features,
            loss=loss,
            logit_scale_exp=logit_scale_exp,
        )

    def extract_features(self, samples):
        """
        Extract features without computing loss.

        Similar to forward() but only returns embeddings without loss
        computation. Text is tokenized exactly as in forward(): padded to
        ``context_length`` (the positional embedding has that many rows) and
        moved to the model's device.

        Args:
            samples: Dictionary containing graph and text inputs

        Returns:
            HeteClipOutputFeatures: Extracted features
        """
        graph = samples.get("graph")
        text = samples.get("text_input")
        des_order = samples.get("des_order")

        if text is not None:
            text = self._tokenize(text)

        if graph is None:
            return self.encode_text(text)
        if text is None:
            return self.encode_graph(graph, des_order)

        return self._paired_features(graph, des_order, text)



class GraphLlamaConfig(LlamaConfig):
    """
    ``LlamaConfig`` subclass identifying the GraphLLaMA model family.

    Only the ``model_type`` tag differs from the base class; graph-specific
    options are attached to instances as ordinary config attributes.
    """
    model_type = "GraphLlama"

class GraphPretrainConfig:
    """
    Attribute-style view of a graph pre-training configuration.

    Each key/value pair of the given dictionary becomes an instance attribute
    of the same name.

    Args:
        dictionary (dict): Configuration dictionary
    """
    def __init__(self, dictionary):
        entries = dictionary.items()
        for attr_name, attr_value in entries:
            setattr(self, attr_name, attr_value)

def load_model_pretrained(model_name, pretrain_model_path):
    """
    Load a pretrained model from checkpoint.

    Reads ``config.json`` from the checkpoint directory, builds the model via
    ``model_name(args)`` and loads the weights from the first ``*.pkl`` file
    (in sorted order) found there. A ``logit_scale`` entry, if present, is
    dropped from the state dict before loading.

    Args:
        model_name: Model class to instantiate; called with the parsed config
        pretrain_model_path (str): Path to pretrained model checkpoint directory

    Returns:
        Tuple[nn.Module, GraphPretrainConfig]: Loaded model and its configuration

    Raises:
        AssertionError: If config.json is missing
        FileNotFoundError: If the directory contains no ``*.pkl`` checkpoint
    """
    config_path = os.path.join(pretrain_model_path, 'config.json')
    assert os.path.exists(config_path), 'config.json missing'
    with open(config_path, 'r') as f:
        config_dict = json.load(f)
    args = GraphPretrainConfig(config_dict)
    model = model_name(args)

    # Sort so the chosen checkpoint is deterministic when several exist.
    pkl_files = sorted(glob.glob(os.path.join(pretrain_model_path, '*.pkl')))
    if not pkl_files:
        raise FileNotFoundError(f"no '*.pkl' checkpoint found in {pretrain_model_path!r}")

    # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only
    # hosts; load_state_dict copies the values into the model afterwards.
    # SECURITY: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(pkl_files[0], map_location='cpu')
    state_dict.pop('logit_scale', None)
    print('loading graph pre train model')
    model.load_state_dict(state_dict)
    return model, args

def load_metahgt_pretrained(model_name, pretrain_model_path):
    """
    Load a pretrained MetaHGT model from checkpoint.

    Builds the layer from ``graph_config.json`` and ``text_config.json`` and
    loads the graph-encoder weights from the first ``*.ckpt`` file (in sorted
    order) found in the directory. Only state-dict entries under the
    ``model.graph_encoder.`` prefix are used, with the prefix stripped.

    Args:
        model_name: Should be MetaHGTConv class
        pretrain_model_path (str): Path to pretrained model checkpoint directory

    Returns:
        MetaHGTConv: Loaded model instance

    Raises:
        AssertionError: If config files are missing or model_name is incorrect
        FileNotFoundError: If the directory contains no ``*.ckpt`` checkpoint
    """
    graph_config_path = os.path.join(pretrain_model_path, 'graph_config.json')
    assert os.path.exists(graph_config_path), 'graph_config.json missing'
    with open(graph_config_path, 'r') as f:
        graph_config_dict = json.load(f)
    graph_cfg = MetaHGTConvCfg(**graph_config_dict)

    text_config_path = os.path.join(pretrain_model_path, 'text_config.json')
    assert os.path.exists(text_config_path), 'text_config.json missing'
    with open(text_config_path, 'r') as f:
        text_config_dict = json.load(f)
    text_cfg = CLIPTextCfg(**text_config_dict)

    assert model_name == MetaHGTConv
    model = model_name(
        in_channels=graph_cfg.in_channels,
        out_channels=graph_cfg.out_channels,
        heads=graph_cfg.heads,
        dynamic=graph_cfg.dynamic,
        text_cfg=text_cfg,
    )

    # Sort so the chosen checkpoint is deterministic when several exist.
    ckpt_files = sorted(glob.glob(os.path.join(pretrain_model_path, '*.ckpt')))
    if not ckpt_files:
        raise FileNotFoundError(f"no '*.ckpt' checkpoint found in {pretrain_model_path!r}")

    # SECURITY: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(ckpt_files[0], map_location='cpu')['state_dict']
    print('loading graph pre train model ...')

    # Match on the full dotted prefix and slice it off, so keys that merely
    # begin with 'model.graph_encoder' (no trailing dot) are ignored instead
    # of raising IndexError, and a repeated prefix inside a key is preserved.
    prefix = 'model.graph_encoder.'
    gnn_state_dict = {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }
    # strict=False: keys missing from the checkpoint keep their initial values.
    model.load_state_dict(gnn_state_dict, strict=False)
    return model

def transfer_param_tograph(clip_graph, gnn):
    """
    Copy the weights of ``clip_graph.gnn`` into ``gnn``.

    Args:
        clip_graph: Source model exposing its graph encoder as ``.gnn``
        gnn: Target GNN model; must accept the source's state dict

    Returns:
        nn.Module: The same ``gnn`` object, now holding the transferred parameters
    """
    gnn.load_state_dict(clip_graph.gnn.state_dict())
    return gnn

class GraphLlamaModel(LlamaModel):
    """
    GraphLLaMA model that combines LLaMA with graph processing capabilities.
    
    Extends LlamaModel to handle graph inputs through various graph neural network
    architectures including MPNN, GCN, and graph transformers.
    
    Args:
        config (LlamaConfig): Model configuration
    """
    config_class = GraphLlamaConfig

    def __init__(self, config: LlamaConfig):
        """
        Build the LLaMA backbone and, when the config asks for it, the graph
        tower and the graph-to-text projector.

        Args:
            config (LlamaConfig): Model configuration. ``graph_tower`` selects
                the graph encoder; the mere presence of ``use_graph_proj``
                creates the projector
        """
        super(GraphLlamaModel, self).__init__(config)

        if hasattr(config, "graph_tower"):
            tower_name = config.graph_tower
            if tower_name == 'MPNN':
                self.graph_tower = MPNN(
                    in_channels=config.graph_hidden_size,
                    hidden_channels=config.graph_hidden_size * 2,
                    out_channels=config.graph_hidden_size,
                    dropout=0.1,
                    num_layers=2,
                    if_param=False
                )
            elif tower_name == "clip_gcn_arxiv":
                clip_graph, args = load_model_pretrained(CLIP, config.pretrain_graph_model_path)
                self.graph_tower = transfer_param_tograph(clip_graph, GNN(args))
            elif tower_name in ("clip_gt", "clip_gt_arxiv", "clip_gt_arxiv_pub"):
                # All graph-transformer variants are constructed the same way.
                clip_graph, args = load_model_pretrained(CLIP, config.pretrain_graph_model_path)
                self.graph_tower = transfer_param_tograph(clip_graph, graph_transformer(args))
            # Any other value leaves the model without a graph tower.

        if hasattr(config, "use_graph_proj"):
            # Maps graph features into the language model's hidden size.
            self.graph_projector = nn.Linear(config.graph_hidden_size, config.hidden_size)

    def get_graph_tower(self):
        """
        Return the graph encoder attached to this model, if any.

        The tower may be stored wrapped in a one-element list (as done by
        ``initialize_graph_modules`` when FSDP is configured); in that case
        the wrapped module is returned.

        Returns:
            nn.Module: Graph neural network module, or ``None`` when no
                ``graph_tower`` attribute exists
        """
        tower = getattr(self, 'graph_tower', None)
        return tower[0] if type(tower) is list else tower

    def initialize_graph_modules(self, graph_tower, graph_select_layer,
                               pretrain_graph_mlp_adapter=None, fsdp=None):
        """
        Initialize graph processing modules.

        Records the tower choice in the config, builds the graph tower if the
        model does not have one yet, freezes it, and makes sure a graph
        projector exists (optionally loading pretrained projector weights).

        Args:
            graph_tower (str): Type of graph neural network to use
            graph_select_layer (int): Which layer to select features from
            pretrain_graph_mlp_adapter (str, optional): Path to pretrained adapter weights
            fsdp (list, optional): FSDP configuration; when non-empty the tower
                is stored wrapped in a list

        Raises:
            ValueError: If a tower has to be built and ``graph_tower`` is not
                one of the supported names
        """
        self.config.graph_tower = graph_tower

        if hasattr(self, 'graph_tower'):
            graph_tower = self.graph_tower
            # A previous call with FSDP enabled stores the tower inside a
            # list; unwrap it so it can be frozen below.
            if isinstance(graph_tower, list):
                graph_tower = graph_tower[0]
        else:
            tower_name = self.config.graph_tower
            if tower_name == 'MPNN':
                graph_tower = MPNN(
                    in_channels=self.config.graph_hidden_size,
                    hidden_channels=self.config.graph_hidden_size * 2,
                    out_channels=self.config.graph_hidden_size,
                    dropout=0.1,
                    num_layers=2,
                    if_param=False
                )
            elif tower_name == "clip_gcn_arxiv":
                clip_graph, args = load_model_pretrained(CLIP, self.config.pretrain_graph_model_path)
                graph_tower = transfer_param_tograph(clip_graph, GNN(args))
            elif tower_name in ("clip_gt", "clip_gt_arxiv", "clip_gt_arxiv_pub"):
                # All graph-transformer variants are constructed the same way.
                clip_graph, args = load_model_pretrained(CLIP, self.config.pretrain_graph_model_path)
                graph_tower = transfer_param_tograph(clip_graph, graph_transformer(args))
            else:
                # Previously the tower name (a str) fell through and crashed
                # on ``str.requires_grad_``.
                raise ValueError(f"Unsupported graph_tower: {tower_name!r}")

        # The graph tower is kept frozen.
        graph_tower.requires_grad_(False)

        if fsdp is not None and len(fsdp) > 0:
            # Stored inside a plain list rather than assigned directly.
            self.graph_tower = [graph_tower]
        else:
            self.graph_tower = graph_tower

        self.config.use_graph_proj = True
        self.config.graph_select_layer = graph_select_layer

        if not hasattr(self, 'graph_projector'):
            self.graph_projector = nn.Linear(self.config.graph_hidden_size, self.config.hidden_size)

        if pretrain_graph_mlp_adapter is not None:
            # SECURITY: torch.load unpickles the file -- only load trusted adapters.
            graph_projector_weights = torch.load(pretrain_graph_mlp_adapter, map_location='cpu')
            # Keep only the final key component (e.g. 'weight' / 'bias') so the
            # names match a bare nn.Linear.
            self.graph_projector.load_state_dict({k.split('.')[-1]: v for k, v in graph_projector_weights.items()})

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        graph_data: Optional[List[Union[Data, Dict[str, Data]]]] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """
        Forward pass of the GraphLLaMA model.

        Embeds the text tokens, runs the graph tower (under ``torch.no_grad``)
        on every entry of ``graph_data``, projects the resulting node features
        with ``self.graph_projector`` and writes them over the embeddings of
        the graph patch tokens before delegating to the parent class forward.

        The graph branch is skipped when there is no graph tower, when
        ``graph_data`` is None, when ``input_ids`` is None (graph tokens cannot
        be located from ``inputs_embeds`` alone), or when a single token is
        decoded outside of training.

        Args:
            input_ids (torch.LongTensor, optional): Input token IDs
            attention_mask (torch.Tensor, optional): Attention mask
            past_key_values (List[torch.FloatTensor], optional): Cached key/values for faster inference
            inputs_embeds (torch.FloatTensor, optional): Pre-computed input embeddings
            use_cache (bool, optional): Whether to use past key/values
            output_attentions (bool, optional): Whether to output attention weights
            output_hidden_states (bool, optional): Whether to output all hidden states
            graph_data (list, optional): List of graphs; each element is either a
                ``Data`` object or a dict holding ``'graph_1'`` and ``'graph_2'``
            return_dict (bool, optional): Whether to return a dictionary output

        Returns:
            Union[Tuple, BaseModelOutputWithPast]: Output of the parent class
                forward computed on the graph-augmented embeddings

        Raises:
            ValueError: If ``graph_data`` is not a list, or if the graph
                start/end/patch tokens of a sample are inconsistent with the
                number of graph features.
        """
        orig_embeds_params = getattr(self, 'orig_embeds_params', None)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        graph_tower = self.get_graph_tower()
        # ``input_ids`` is checked before ``.shape`` is read so that calls made
        # with ``inputs_embeds`` only no longer fail with an AttributeError.
        should_splice = (
            graph_tower is not None
            and graph_data is not None
            and input_ids is not None
            and (input_ids.shape[1] != 1 or self.training)
        )
        if should_splice:
            if not isinstance(graph_data, list):
                raise ValueError(f'graph_node_reps is expected to be a list but got {type(graph_data)}')

            # The graph tower is only used as a feature extractor here.
            with torch.no_grad():
                graph_node_features = []
                if isinstance(graph_data[0], Data):
                    for g in graph_data:
                        graph_node_features.append(graph_tower(g))
                elif isinstance(graph_data[0], dict):
                    # Paired input: each sample contributes two graphs, in order.
                    for g_dict in graph_data:
                        node_forward_out_1 = graph_tower(g_dict['graph_1'])
                        node_forward_out_2 = graph_tower(g_dict['graph_2'])
                        graph_node_features.append(node_forward_out_1)
                        graph_node_features.append(node_forward_out_2)

            # The projector runs outside ``no_grad`` so it can receive gradients.
            graph_node_features = [self.graph_projector(node_feature) for node_feature in graph_node_features]

            # Width follows the projector instead of a hard-coded 128 so that any
            # ``graph_hidden_size`` works.
            dummy_graph_features = torch.zeros(
                256,
                self.graph_projector.in_features,
                device=inputs_embeds.device,
                dtype=inputs_embeds.dtype,
            )
            dummy_graph_features = self.graph_projector(dummy_graph_features)

            patch_token = graph_tower.config.graph_patch_token
            new_input_embeds = []
            cur_graph_idx = 0
            for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
                if (cur_input_ids == patch_token).sum() == 0:
                    # Sample without graph tokens: the multiply-by-zero term leaves
                    # the values untouched but keeps the projector in the autograd
                    # graph of this sample.
                    cur_input_embeds = cur_input_embeds + (0. * dummy_graph_features).sum()
                    new_input_embeds.append(cur_input_embeds)
                    cur_graph_idx += 1
                    continue

                if graph_tower.config.use_graph_start_end:
                    start_token = graph_tower.config.graph_start_token
                    end_token = graph_tower.config.graph_end_token
                    if (cur_input_ids == start_token).sum() != (cur_input_ids == end_token).sum():
                        raise ValueError("The number of graph start tokens and graph end tokens should be the same.")
                    graph_start_tokens = torch.where(cur_input_ids == start_token)[0]
                    if graph_start_tokens.numel() == 0:
                        # Previously this appended a stale (or unbound) tensor.
                        raise ValueError("Graph patch tokens were found without a graph start token.")
                    # NOTE(review): every iteration splices into the original
                    # ``cur_input_embeds``, so when a sample holds several graph
                    # spans only the last splice survives in
                    # ``cur_new_input_embeds``. Left unchanged to keep model
                    # outputs identical -- confirm whether this is intended.
                    for graph_start_token_pos in graph_start_tokens:
                        cur_graph_features = graph_node_features[cur_graph_idx].to(device=cur_input_embeds.device)
                        num_patches = cur_graph_features.shape[0]
                        span_end = graph_start_token_pos + num_patches + 1
                        if cur_input_ids[span_end] != end_token:
                            raise ValueError("The graph end token should follow the graph start token.")
                        if orig_embeds_params is not None:
                            # Only the start/end token embeddings and the graph
                            # features stay attached; the surrounding text
                            # embeddings are detached.
                            cur_new_input_embeds = torch.cat(
                                (
                                    cur_input_embeds[:graph_start_token_pos].detach(),
                                    cur_input_embeds[graph_start_token_pos:graph_start_token_pos + 1],
                                    cur_graph_features,
                                    cur_input_embeds[span_end:span_end + 1],
                                    cur_input_embeds[span_end + 1:].detach(),
                                ),
                                dim=0,
                            )
                        else:
                            cur_new_input_embeds = torch.cat(
                                (
                                    cur_input_embeds[:graph_start_token_pos + 1],
                                    cur_graph_features,
                                    cur_input_embeds[span_end:],
                                ),
                                dim=0,
                            )
                        cur_graph_idx += 1
                    new_input_embeds.append(cur_new_input_embeds)
                else:
                    cur_graph_features = graph_node_features[cur_graph_idx]
                    num_patches = cur_graph_features.shape[0]
                    if (cur_input_ids == patch_token).sum() != num_patches:
                        raise ValueError("The number of graph patch tokens should be the same as the number of graph patches.")
                    masked_indices = torch.where(cur_input_ids == patch_token)[0]
                    mask_index_start = masked_indices[0]
                    expected_indices = torch.arange(
                        mask_index_start,
                        mask_index_start + num_patches,
                        device=masked_indices.device,
                        dtype=masked_indices.dtype,
                    )
                    if (masked_indices != expected_indices).any():
                        raise ValueError("The graph patch tokens should be consecutive.")
                    if orig_embeds_params is not None:
                        cur_new_input_embeds = torch.cat(
                            (
                                cur_input_embeds[:mask_index_start].detach(),
                                cur_graph_features,
                                cur_input_embeds[mask_index_start + num_patches:].detach(),
                            ),
                            dim=0,
                        )
                    else:
                        cur_new_input_embeds = torch.cat(
                            (
                                cur_input_embeds[:mask_index_start],
                                cur_graph_features,
                                cur_input_embeds[mask_index_start + num_patches:],
                            ),
                            dim=0,
                        )
                    new_input_embeds.append(cur_new_input_embeds)
                    cur_graph_idx += 1

            # Every encoded graph must have been consumed by exactly one sample.
            assert cur_graph_idx == len(graph_node_features)
            inputs_embeds = torch.stack(new_input_embeds, dim=0)

        # ``input_ids`` is dropped on purpose: the embeddings already carry both
        # the text and the graph information.
        return super(GraphLlamaModel, self).forward(
            input_ids=None,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

class GraphLlamaForCausalLM(LlamaForCausalLM):
    """
    GraphLLaMA model for causal language modeling.

    Combines a ``GraphLlamaModel`` backbone with a bias-free linear
    language-modeling head so that next-token prediction can be conditioned
    on graph inputs.

    Args:
        config (GraphLlamaConfig): Model configuration
    """
    config_class = GraphLlamaConfig

    def __init__(self, config):
        # The lookup starts *after* LlamaForCausalLM in the MRO, so that class's
        # own __init__ is skipped -- presumably to avoid building the stock
        # backbone before replacing it with the graph-aware one below.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = GraphLlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_model(self):
        """Get the underlying GraphLlamaModel."""
        return self.model

    def get_graph_tower(self):
        """Get the graph processing component."""
        return self.get_model().get_graph_tower()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        graph_data: Optional[List[Union[Data, Dict[str, Data]]]] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Run the backbone and the LM head, optionally computing the LM loss.

        Args:
            input_ids (torch.LongTensor, optional): Input token IDs
            attention_mask (torch.Tensor, optional): Attention mask
            past_key_values (List[torch.FloatTensor], optional): Cached key/values
            inputs_embeds (torch.FloatTensor, optional): Pre-computed input embeddings
            labels (torch.LongTensor, optional): Target token IDs; when given, a
                next-token cross-entropy loss is computed
            use_cache (bool, optional): Whether to use past key/values
            output_attentions (bool, optional): Whether to output attention weights
            output_hidden_states (bool, optional): Whether to output all hidden states
            graph_data (list, optional): Graph inputs forwarded to the backbone
            return_dict (bool, optional): Whether to return a dictionary output

        Returns:
            Union[Tuple, CausalLMOutputWithPast]: Loss (if labels are given),
                logits and the remaining backbone outputs
        """
        # Fall back to the configuration defaults for unspecified flags.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            graph_data=graph_data
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Position i predicts token i + 1: drop the last logit and the
            # first label, then flatten for the token-level loss.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """
        Prepare inputs for text generation.

        Args:
            input_ids: Input token IDs
            past_key_values: Cached key/values from previous forward passes
            attention_mask: Attention mask
            inputs_embeds: Pre-computed input embeddings
            **kwargs: Additional arguments; ``use_cache`` and ``graph_data``
                are read from here

        Returns:
            dict: Dictionary of model inputs
        """
        if past_key_values:
            # With a cache only the newest token has to be fed.
            input_ids = input_ids[:, -1:]

        # Pre-computed embeddings are only usable for the very first step.
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        # A single graph is wrapped into the list form expected by the
        # backbone. A missing graph is passed as None rather than ``[None]``
        # so that text-only generation skips the graph branch instead of
        # failing inside it.
        graph_data = kwargs.get("graph_data", None)
        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "graph_data": None if graph_data is None else [graph_data],
            }
        )
        return model_inputs

    def initialize_graph_tokenizer(self, use_graph_start_end, tokenizer, device,
                                 tune_graph_mlp_adapter=False, pretrain_graph_mlp_adapter=None):
        """
        Initialize tokenizer for graph inputs.

        Adds the graph patch token (and optionally the graph start/end tokens)
        to ``tokenizer``, resizes the embedding matrices accordingly and stores
        the resulting token IDs on the graph tower's config.

        Args:
            use_graph_start_end (bool): Whether to use special graph tokens
            tokenizer: Base tokenizer to extend
            device: Device on which the copy of the original input embeddings is kept
            tune_graph_mlp_adapter (bool): Whether to tune graph MLP adapter
            pretrain_graph_mlp_adapter (str, optional): Path to pretrained adapter

        Raises:
            ValueError: If the pretrained embedding matrix has an unexpected shape
        """
        vision_config = self.get_graph_tower().config
        vision_config.use_graph_start_end = use_graph_start_end
        tokenizer.add_tokens([DEFAULT_GRAPH_PATCH_TOKEN], special_tokens=True)
        self.resize_token_embeddings(len(tokenizer))

        if use_graph_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_G_START_TOKEN, DEFAULT_G_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))
            vision_config.graph_start_token, vision_config.graph_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_G_START_TOKEN, DEFAULT_G_END_TOKEN])

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                # New rows start from the mean of the pre-existing rows.
                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if tune_graph_mlp_adapter:
                # Keep a copy of the current input embeddings on the backbone;
                # its presence also switches the backbone forward to the
                # variant that detaches the text embeddings.
                self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)]
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if pretrain_graph_mlp_adapter:
                mm_projector_weights = torch.load(pretrain_graph_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                assert num_new_tokens == 2
                # The checkpoint may hold either the full embedding matrix or
                # only the rows of the new tokens.
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")

        vision_config.graph_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_GRAPH_PATCH_TOKEN])[0]


# Register the "GraphLlama" model type with the transformers Auto* factories:
# the config class under its type name, and the causal-LM class under that config.
AutoConfig.register("GraphLlama", GraphLlamaConfig)
AutoModelForCausalLM.register(GraphLlamaConfig, GraphLlamaForCausalLM)


class HeteroLlamaConfig(LlamaConfig):
    """
    Configuration class for HeteroLLaMA model.

    Subclasses LlamaConfig and only overrides ``model_type``; no additional
    configuration fields are declared here.
    """
    # Type name under which this configuration is identified.
    model_type = "HeteroLlama"

class HeteroLlamaModel(LlamaModel):
    """
    HeteroLLaMA model that handles heterogeneous graph inputs.

    Extends LlamaModel with a graph tower and a linear projector whose outputs
    are written over the embeddings of the graph patch tokens.

    Args:
        config (LlamaConfig): Model configuration
    """

    config_class = HeteroLlamaConfig

    def __init__(self, config: LlamaConfig):
        super(HeteroLlamaModel, self).__init__(config)

        # Graph modules are only created when the config already describes them;
        # otherwise ``initialize_graph_modules`` adds them later.
        if hasattr(config, "graph_tower"):
            tower = self._build_graph_tower(config, config.graph_tower)
            if tower is not None:
                self.graph_tower = tower

        if hasattr(config, "use_graph_proj"):
            self.graph_projector = nn.Linear(config.graph_hidden_size, config.hidden_size)

    @staticmethod
    def _build_graph_tower(config, name):
        """Instantiate the graph tower called ``name``; return None for unknown names."""
        if name == 'MPNN':
            return MPNN(
                in_channels=config.graph_hidden_size,
                hidden_channels=config.graph_hidden_size * 2,
                out_channels=config.graph_hidden_size,
                dropout=0.1,
                num_layers=2,
                if_param=False
            )
        if name == "meta_hgt":
            return load_metahgt_pretrained(MetaHGTConv, config.pretrain_graph_model_path)
        return None

    def get_graph_tower(self):
        """Return the graph tower, unwrapping the single-element list used with FSDP."""
        graph_tower = getattr(self, 'graph_tower', None)
        if type(graph_tower) is list:
            graph_tower = graph_tower[0]
        return graph_tower

    def initialize_graph_modules(self, graph_tower, graph_select_layer,
                               pretrain_graph_mlp_adapter=None, fsdp=None):
        """
        Create (or reuse) the graph tower and projector and record them in the config.

        Args:
            graph_tower: Name of the graph tower to build ('MPNN' or 'meta_hgt')
            graph_select_layer: Value stored as ``config.graph_select_layer``
            pretrain_graph_mlp_adapter (str, optional): Path to a checkpoint with
                pretrained projector weights
            fsdp (optional): When non-empty, the tower is stored inside a list

        Raises:
            ValueError: If no graph tower exists yet and ``graph_tower`` does not
                name a supported one
        """
        self.config.graph_tower = graph_tower

        if not hasattr(self, 'graph_tower'):
            built = self._build_graph_tower(self.config, graph_tower)
            if built is not None:
                graph_tower = built
        else:
            # Unwrap a tower that an earlier call stored in its FSDP list form.
            graph_tower = self.get_graph_tower()

        if graph_tower is None or isinstance(graph_tower, str):
            # Previously an unknown name surfaced as an AttributeError on ``str``.
            raise ValueError(f"Unsupported graph tower: {graph_tower!r}")

        graph_tower.requires_grad_(False)

        if fsdp is not None and len(fsdp) > 0:
            self.graph_tower = [graph_tower]
        else:
            self.graph_tower = graph_tower

        self.config.use_graph_proj = True
        self.config.graph_select_layer = graph_select_layer

        if not hasattr(self, 'graph_projector'):
            self.graph_projector = nn.Linear(self.config.graph_hidden_size, self.config.hidden_size)

        if pretrain_graph_mlp_adapter is not None:
            graph_projector_weights = torch.load(pretrain_graph_mlp_adapter, map_location='cpu')
            # Keys are reduced to their last component ('weight' / 'bias'), so
            # unrelated entries such as an embedding 'weight' stored in the same
            # checkpoint must be dropped first or they could overwrite the
            # projector's. Fall back to all keys when none is prefixed.
            projector_only = {k: v for k, v in graph_projector_weights.items() if 'graph_projector' in k}
            selected = projector_only or graph_projector_weights
            self.graph_projector.load_state_dict({k.split('.')[-1]: v for k, v in selected.items()})

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        graph_data: Optional[List[Union[Data, Dict[str, Data]]]] = None,
        return_dict: Optional[bool] = None,
        hetero_key_order: Optional[List[List[str]]] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """
        Forward pass of the HeteroLLaMA model.

        Embeds the text tokens, runs the graph tower (under ``torch.no_grad``)
        on every entry of ``graph_data``, selects the per-type outputs in the
        order given by ``hetero_key_order``, projects them and writes them over
        the embeddings of the graph patch tokens before calling
        ``LlamaModel.forward``.

        The graph branch is skipped when there is no graph tower, when
        ``graph_data`` is None, when ``input_ids`` is None, or when a single
        token is decoded outside of training.

        Args:
            input_ids (torch.LongTensor, optional): Input token IDs
            attention_mask (torch.Tensor, optional): Attention mask
            past_key_values (List[torch.FloatTensor], optional): Cached key/values
            inputs_embeds (torch.FloatTensor, optional): Pre-computed input embeddings
            use_cache (bool, optional): Whether to use past key/values
            output_attentions (bool, optional): Whether to output attention weights
            output_hidden_states (bool, optional): Whether to output all hidden states
            graph_data (list, optional): List of graphs; each element is either a
                ``Data`` object exposing ``x_dict`` / ``edge_index_dict`` or a dict
                holding ``'graph_1'`` and ``'graph_2'``
            return_dict (bool, optional): Whether to return a dictionary output
            hetero_key_order (List[List[str]], optional): For each graph, the keys
                of the tower output to use, in splice order

        Returns:
            Union[Tuple, BaseModelOutputWithPast]: Output of ``LlamaModel.forward``
                computed on the graph-augmented embeddings

        Raises:
            ValueError: If ``graph_data`` is not a list, if ``hetero_key_order``
                is missing, or if the graph tokens of a sample are inconsistent
                with the number of graph features.
        """
        orig_embeds_params = getattr(self, 'orig_embeds_params', None)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        graph_tower = self.get_graph_tower()
        # ``input_ids`` is checked before ``.shape`` is read so that calls made
        # with ``inputs_embeds`` only no longer fail with an AttributeError.
        should_splice = (
            graph_tower is not None
            and graph_data is not None
            and input_ids is not None
            and (input_ids.shape[1] != 1 or self.training)
        )
        if should_splice:
            if not isinstance(graph_data, list):
                raise ValueError(f'graph_node_reps is expected to be a list but got {type(graph_data)}')
            if hetero_key_order is None:
                raise ValueError('hetero_key_order is required when graph_data is provided.')

            with torch.no_grad():
                graph_node_features = []
                if isinstance(graph_data[0], Data):
                    for g in graph_data:
                        graph_node_features.append(graph_tower(g.x_dict, g.edge_index_dict))
                elif isinstance(graph_data[0], dict):
                    # NOTE(review): this branch calls the tower with a single
                    # argument, unlike the (x_dict, edge_index_dict) call above --
                    # confirm it is still supported by the heterogeneous towers.
                    for g_dict in graph_data:
                        node_forward_out_1 = graph_tower(g_dict['graph_1'])
                        node_forward_out_2 = graph_tower(g_dict['graph_2'])
                        graph_node_features.append(node_forward_out_1)
                        graph_node_features.append(node_forward_out_2)

            # Flatten the per-graph outputs into one feature block per requested
            # key, then project each block (outside ``no_grad``).
            graph_node_features_list = []
            for idx, order in enumerate(hetero_key_order):
                graph_node_features_list.extend([graph_node_features[idx][o] for o in order])
            graph_node_features = [self.graph_projector(node_feature) for node_feature in graph_node_features_list]

            # Width follows the projector instead of a hard-coded 128 so that any
            # ``graph_hidden_size`` works.
            dummy_graph_features = torch.zeros(
                256,
                self.graph_projector.in_features,
                device=inputs_embeds.device,
                dtype=inputs_embeds.dtype,
            )
            dummy_graph_features = self.graph_projector(dummy_graph_features)

            patch_token = graph_tower.config.graph_patch_token
            new_input_embeds = []
            cur_graph_idx = 0
            for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
                if (cur_input_ids == patch_token).sum() == 0:
                    # Sample without graph tokens: the multiply-by-zero term leaves
                    # the values untouched but keeps the projector in the autograd
                    # graph of this sample.
                    cur_input_embeds = cur_input_embeds + (0. * dummy_graph_features).sum()
                    new_input_embeds.append(cur_input_embeds)
                    cur_graph_idx += 1
                    continue

                if graph_tower.config.use_graph_start_end:
                    start_token = graph_tower.config.graph_start_token
                    end_token = graph_tower.config.graph_end_token
                    if (cur_input_ids == start_token).sum() != (cur_input_ids == end_token).sum():
                        raise ValueError("The number of graph start tokens and graph end tokens should be the same.")
                    graph_start_tokens = torch.where(cur_input_ids == start_token)[0]
                    if graph_start_tokens.numel() == 0:
                        # Previously this appended a stale (or unbound) tensor.
                        raise ValueError("Graph patch tokens were found without a graph start token.")
                    # NOTE(review): every iteration splices into the original
                    # ``cur_input_embeds``, so when a sample holds several graph
                    # spans only the last splice survives in
                    # ``cur_new_input_embeds``. Left unchanged to keep model
                    # outputs identical -- confirm whether this is intended.
                    for graph_start_token_pos in graph_start_tokens:
                        cur_graph_features = graph_node_features[cur_graph_idx].to(device=cur_input_embeds.device)
                        num_patches = cur_graph_features.shape[0]
                        span_end = graph_start_token_pos + num_patches + 1
                        if cur_input_ids[span_end] != end_token:
                            raise ValueError("The graph end token should follow the graph start token.")
                        if orig_embeds_params is not None:
                            # Only the start/end token embeddings and the graph
                            # features stay attached; the surrounding text
                            # embeddings are detached.
                            cur_new_input_embeds = torch.cat(
                                (
                                    cur_input_embeds[:graph_start_token_pos].detach(),
                                    cur_input_embeds[graph_start_token_pos:graph_start_token_pos + 1],
                                    cur_graph_features,
                                    cur_input_embeds[span_end:span_end + 1],
                                    cur_input_embeds[span_end + 1:].detach(),
                                ),
                                dim=0,
                            )
                        else:
                            cur_new_input_embeds = torch.cat(
                                (
                                    cur_input_embeds[:graph_start_token_pos + 1],
                                    cur_graph_features,
                                    cur_input_embeds[span_end:],
                                ),
                                dim=0,
                            )
                        cur_graph_idx += 1
                    new_input_embeds.append(cur_new_input_embeds)
                else:
                    cur_graph_features = graph_node_features[cur_graph_idx]
                    num_patches = cur_graph_features.shape[0]
                    if (cur_input_ids == patch_token).sum() != num_patches:
                        raise ValueError("The number of graph patch tokens should be the same as the number of graph patches.")
                    masked_indices = torch.where(cur_input_ids == patch_token)[0]
                    mask_index_start = masked_indices[0]
                    expected_indices = torch.arange(
                        mask_index_start,
                        mask_index_start + num_patches,
                        device=masked_indices.device,
                        dtype=masked_indices.dtype,
                    )
                    if (masked_indices != expected_indices).any():
                        raise ValueError("The graph patch tokens should be consecutive.")
                    if orig_embeds_params is not None:
                        cur_new_input_embeds = torch.cat(
                            (
                                cur_input_embeds[:mask_index_start].detach(),
                                cur_graph_features,
                                cur_input_embeds[mask_index_start + num_patches:].detach(),
                            ),
                            dim=0,
                        )
                    else:
                        cur_new_input_embeds = torch.cat(
                            (
                                cur_input_embeds[:mask_index_start],
                                cur_graph_features,
                                cur_input_embeds[mask_index_start + num_patches:],
                            ),
                            dim=0,
                        )
                    new_input_embeds.append(cur_new_input_embeds)
                    cur_graph_idx += 1

            # Every projected feature block must have been consumed exactly once.
            assert cur_graph_idx == len(graph_node_features)
            inputs_embeds = torch.stack(new_input_embeds, dim=0)

        # ``input_ids`` is dropped on purpose: the embeddings already carry both
        # the text and the graph information.
        return super(HeteroLlamaModel, self).forward(
            input_ids=None,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

class HeteroLlamaForCausalLM(LlamaForCausalLM):
    """
    HeteroLLaMA model for causal language modeling with heterogeneous graphs.

    Extends LlamaForCausalLM to support language modeling conditioned on
    heterogeneous graph structures.

    Args:
        config (HeteroLlamaConfig): Model configuration
    """
    config_class = HeteroLlamaConfig

    def __init__(self, config):
        # The lookup starts *after* LlamaForCausalLM in the MRO, so that class's
        # own __init__ is skipped -- presumably to avoid building the stock
        # backbone before replacing it with the graph-aware one below.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = HeteroLlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_model(self):
        """Get the underlying HeteroLlamaModel."""
        return self.model

    def get_graph_tower(self):
        """Get the heterogeneous graph processing component."""
        return self.get_model().get_graph_tower()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        graph_data: Optional[Data] = None,
        return_dict: Optional[bool] = None,
        hetero_key_order: Optional[List[List[str]]] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Run the backbone and the LM head, optionally computing the LM loss.

        Args:
            input_ids (torch.LongTensor, optional): Input token IDs
            attention_mask (torch.Tensor, optional): Attention mask
            past_key_values (List[torch.FloatTensor], optional): Cached key/values
            inputs_embeds (torch.FloatTensor, optional): Pre-computed input embeddings
            labels (torch.LongTensor, optional): Target token IDs; when given, a
                next-token cross-entropy loss is computed
            use_cache (bool, optional): Whether to use past key/values
            output_attentions (bool, optional): Whether to output attention weights
            output_hidden_states (bool, optional): Whether to output all hidden states
            graph_data (optional): Graph inputs forwarded to the backbone
            return_dict (bool, optional): Whether to return a dictionary output
            hetero_key_order (List[List[str]], optional): Per-graph key ordering
                forwarded to the backbone

        Returns:
            Union[Tuple, CausalLMOutputWithPast]: Loss (if labels are given),
                logits and the remaining backbone outputs
        """
        # Fall back to the configuration defaults for unspecified flags.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            graph_data=graph_data,
            hetero_key_order=hetero_key_order
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Position i predicts token i + 1: drop the last logit and the
            # first label, then flatten for the token-level loss.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """
        Prepare inputs for text generation.

        Args:
            input_ids: Input token IDs
            past_key_values: Cached key/values from previous forward passes
            attention_mask: Attention mask
            inputs_embeds: Pre-computed input embeddings
            **kwargs: Additional arguments; ``use_cache``, ``graph_data`` and
                ``hetero_key_order`` are read from here

        Returns:
            dict: Dictionary of model inputs
        """
        if past_key_values:
            # With a cache only the newest token has to be fed.
            input_ids = input_ids[:, -1:]

        # Pre-computed embeddings are only usable for the very first step.
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        # A single graph is wrapped into a one-element list; a missing graph
        # stays None so the backbone skips its graph branch.
        graph_data = kwargs.get("graph_data")
        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "graph_data": None if graph_data is None else [graph_data],
                "hetero_key_order": [kwargs.get("hetero_key_order", None)],
            }
        )
        return model_inputs

    def initialize_graph_tokenizer(self, use_graph_start_end, tokenizer, device,
                                   tune_graph_mlp_adapter=False, pretrain_graph_mlp_adapter=None):
        """
        Initialize tokenizer for graph inputs.

        Adds the graph patch token (and optionally the graph start/end tokens)
        to ``tokenizer``, resizes the embedding matrices accordingly and stores
        the resulting token IDs on the graph tower's config.

        Args:
            use_graph_start_end (bool): Whether to use special graph tokens
            tokenizer: Base tokenizer to extend
            device: Device on which the copy of the original input embeddings is kept
            tune_graph_mlp_adapter (bool): Whether to tune graph MLP adapter
            pretrain_graph_mlp_adapter (str, optional): Path to pretrained adapter

        Raises:
            ValueError: If the pretrained embedding matrix has an unexpected shape
        """
        vision_config = self.get_graph_tower().config
        vision_config.use_graph_start_end = use_graph_start_end
        tokenizer.add_tokens([DEFAULT_GRAPH_PATCH_TOKEN], special_tokens=True)
        self.resize_token_embeddings(len(tokenizer))

        if use_graph_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_G_START_TOKEN, DEFAULT_G_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))
            vision_config.graph_start_token, vision_config.graph_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_G_START_TOKEN, DEFAULT_G_END_TOKEN])

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                # New rows start from the mean of the pre-existing rows.
                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if tune_graph_mlp_adapter:
                # Keep a copy of the current input embeddings on the backbone;
                # its presence also switches the backbone forward to the
                # variant that detaches the text embeddings.
                self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)]
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if pretrain_graph_mlp_adapter:
                mm_projector_weights = torch.load(pretrain_graph_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                assert num_new_tokens == 2
                # The checkpoint may hold either the full embedding matrix or
                # only the rows of the new tokens.
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")

        vision_config.graph_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_GRAPH_PATCH_TOKEN])[0]

# Make the custom config / model discoverable through the transformers Auto* factories.
AutoConfig.register("HeteroLlama", HeteroLlamaConfig)
AutoModelForCausalLM.register(HeteroLlamaConfig, HeteroLlamaForCausalLM)


def _tokenize_fn(strings: Sequence[str],
                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """
    Tokenize a list of strings.

    Each string is tokenized on its own (so ``padding="longest"`` pads to the
    length of that single string) and truncated to
    ``tokenizer.model_max_length``.

    Args:
        strings (Sequence[str]): List of input strings to tokenize
        tokenizer (PreTrainedTokenizer): Tokenizer to use

    Returns:
        Dict: Dictionary containing:
            - input_ids: Token IDs
            - labels: Labels for language modeling (the same list object as
              ``input_ids``)
            - input_ids_lens: Lengths of input sequences (non-pad tokens)
            - labels_lens: Lengths of label sequences (the same list object as
              ``input_ids_lens``)
    """
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    # `labels` deliberately aliases `input_ids`; callers copy before masking.
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
        for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def _mask_targets(target, tokenized_lens, speakers):
    """
    Mask target tokens for dialogue modeling (in place).

    The first ``tokenized_lens[0]`` positions (the header) are always set to
    ``IGNORE_INDEX``. For every following segment spoken by ``"human"``, all
    positions except the first two of that segment are set to
    ``IGNORE_INDEX``.

    Args:
        target: Target token IDs to mask
        tokenized_lens: List of token sequence lengths; the first entry is the
            header length, the remaining ones pair up with ``speakers``
        speakers: List of speaker identifiers
    """
    cur_idx = tokenized_lens[0]
    tokenized_lens = tokenized_lens[1:]
    target[:cur_idx] = IGNORE_INDEX
    for tokenized_len, speaker in zip(tokenized_lens, speakers):
        if speaker == "human":
            target[cur_idx + 2:cur_idx + tokenized_len] = IGNORE_INDEX
        cur_idx += tokenized_len


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """
    Resize tokenizer and embedding layers to accommodate new special tokens.

    Newly created embedding rows (input and output) are initialised to the
    mean of the rows that existed before the resize.

    Args:
        special_tokens_dict (Dict): Dictionary of special tokens to add
        tokenizer (PreTrainedTokenizer): Tokenizer to modify
        model (PreTrainedModel): Model whose embeddings need resizing
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg


def unwrap_model(model: nn.Module) -> nn.Module:
    """
    Recursively unwrap a model from potential containers.

    Any object exposing a ``module`` attribute is treated as a wrapper.

    Args:
        model (nn.Module): Model to unwrap

    Returns:
        nn.Module: Unwrapped model
    """
    if hasattr(model, "module"):
        return unwrap_model(model.module)
    return model


def find_all_linear_names(model):
    """
    Find all linear layer names in the model.

    Only the last component of each dotted module name is kept, so layers
    that share a leaf name collapse into one entry.

    Args:
        model: Model to analyze

    Returns:
        list: List of linear layer names, excluding lm_head
    """
    lora_module_names = {
        name.split('.')[-1]
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
    }
    lora_module_names.discard('lm_head')
    return list(lora_module_names)


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
                                   output_dir: str):
    """
    Collects the state dict and saves model to disk safely.

    Handles DeepSpeed and regular model saving with proper synchronization.

    Args:
        trainer (transformers.Trainer): HuggingFace trainer instance
        output_dir (str): Directory to save model
    """
    if trainer.deepspeed:
        torch.cuda.synchronize()
        trainer.save_model(output_dir)
        return

    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        # Move everything to CPU first, then drop the original references.
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)


def get_peft_state_maybe_zero_3(named_params, bias):
    """
    Get PEFT state dict handling DeepSpeed ZeRO-3.

    Args:
        named_params: Named parameters from model (iterable of
            ``(name, tensor)`` pairs)
        bias (str): Bias handling mode ('none', 'all', or 'lora_only')

    Returns:
        dict: State dict with gathered parameters

    Raises:
        NotImplementedError: If bias mode is not recognized
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                # Name of the bias that belongs to the module carrying this LoRA weight.
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Keep only the biases of modules that actually carry LoRA weights.
        # (Iterate the items and test each bias's own name; the previous code
        # unpacked the dict keys and reused a stale name from the loop above.)
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v, name=k) for k, v in to_return.items()}
    return to_return


def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    """
    Get non-LoRA PEFT state dict handling DeepSpeed ZeRO-3.

    Args:
        named_params: Named parameters from model (iterable of
            ``(name, tensor)`` pairs)
        require_grad_only (bool, optional): Whether to only include parameters
            requiring gradients. Defaults to True

    Returns:
        dict: State dict with gathered parameters (moved to CPU)
    """
    to_return = {k: t for k, t in named_params if "lora_" not in k}
    if require_grad_only:
        to_return = {k: t for k, t in to_return.items() if t.requires_grad}
    to_return = {
        k: maybe_zero_3(v, ignore_status=True).cpu()
        for k, v in to_return.items()
    }
    return to_return


def maybe_zero_3(param, ignore_status=False, name=None):
    """
    Handle DeepSpeed ZeRO-3 parameters.

    Parameters carrying a ``ds_id`` attribute are gathered through
    ``deepspeed.zero.GatheredParameters`` before being copied; all other
    parameters are simply detached and cloned on CPU.

    Args:
        param: Model parameter
        ignore_status (bool, optional): Whether to ignore parameter status.
            Defaults to False
        name (str, optional): Parameter name for logging. Defaults to None

    Returns:
        torch.Tensor: Gathered parameter data
    """
    if hasattr(param, "ds_id"):
        # Imported lazily so DeepSpeed is only needed when a ZeRO-partitioned
        # parameter is actually seen.
        from deepspeed import zero
        from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


def download_cached_file(url: str, check_hash: bool = True, progress: bool = True) -> str:
    """
    Download a file from URL and cache it locally.

    The file is stored under ``torch.hub.get_dir()`` using the basename of the
    URL path and is downloaded only if it is not already present.

    Args:
        url (str): URL to download from
        check_hash (bool, optional): Whether to verify file hash. Currently
            unused: the download is always performed with ``hash_prefix=None``.
            Defaults to True
        progress (bool, optional): Whether to show progress bar. Defaults to True

    Returns:
        str: Path to cached file
    """
    from torch.hub import download_url_to_file, get_dir

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cache_dir = get_dir()
    cached_file = os.path.join(cache_dir, filename)
    if not os.path.exists(cached_file):
        # The hub directory may not exist yet on a fresh machine.
        os.makedirs(cache_dir, exist_ok=True)
        print(f'Downloading: "{url}" to {cached_file}')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file


class MetaHeteroLinear(nn.Module):
    """
    Meta-learning based linear layer for heterogeneous inputs.

    Implements a dynamic or static linear transformation that can handle
    different types of inputs through meta-learning.

    Args:
        text_width (int): Width of text embeddings for generating weights
        in_features (int): Input feature dimension
        out_features (int): Output feature dimension
        dynamic (bool, optional): Whether to use dynamic weight generation.
            Defaults to True
    """

    def __init__(self, text_width: int, in_features: int, out_features: int,
                 dynamic: bool = True):
        super().__init__()
        self.text_width = text_width
        self.in_features = in_features
        self.out_features = out_features
        self.dynamic = dynamic

        if dynamic:
            # Weights and biases are generated from the type embeddings.
            self.weight_proj = nn.Linear(text_width, in_features * out_features)
            self.bias_proj = nn.Linear(text_width, out_features)
        else:
            self.weight = nn.Parameter(
                torch.randn(in_features, out_features) / in_features ** 0.5)
            self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x: Tensor, type_id: Tensor,
                type_embed_dict: Dict[int, Tensor]) -> Tensor:
        """
        Forward pass with type-specific transformations.

        Args:
            x (Tensor): Input features
            type_id (Tensor): Type IDs for each input, used to index the
                stacked per-type weights and biases
            type_embed_dict (Dict[int, Tensor]): Type embeddings dictionary,
                keyed by the consecutive integers ``0 .. len - 1``

        Returns:
            Tensor: Transformed features
        """
        if self.dynamic:
            weight = []
            bias = []
            for i in range(len(type_embed_dict)):
                type_embed = type_embed_dict[i]
                weight.append(
                    self.weight_proj(type_embed).view(self.in_features, self.out_features))
                bias.append(self.bias_proj(type_embed))
            # One weight matrix / bias vector per type, then select by type id.
            weight = torch.stack(weight)[type_id]
            bias = torch.stack(bias)[type_id]
        else:
            weight = self.weight
            bias = self.bias
        # NOTE(review): F.linear expects `weight` as (out_features, in_features)
        # and at most 2-D. The weights here are laid out as
        # (in_features, out_features), and in dynamic mode a tensor `type_id`
        # yields one matrix per input — confirm the intended layout; as written
        # this appears to line up only when in_features == out_features and the
        # selected weight is 2-D.
        return F.linear(x, weight, bias)


class MetaHeteroDictLinear(nn.Module):
    """
    Meta-learning based linear layer for dictionary of heterogeneous inputs.

    Similar to MetaHeteroLinear but handles dictionary inputs where each key
    represents a different node/edge type.

    Args:
        text_width (int): Width of text embeddings for generating weights
        in_features (int): Input feature dimension
        out_features (int): Output feature dimension
        dynamic (bool, optional): Whether to use dynamic weight generation.
            Defaults to True
    """

    def __init__(self, text_width: int, in_features: int, out_features: int,
                 dynamic: bool = True):
        super().__init__()
        self.text_width = text_width
        self.in_features = in_features
        self.out_features = out_features
        self.dynamic = dynamic

        if dynamic:
            # Weights and biases are generated from the type embeddings.
            self.weight_proj = nn.Linear(text_width, in_features * out_features)
            self.bias_proj = nn.Linear(text_width, out_features)
        else:
            self.weight = nn.Parameter(
                torch.randn(in_features, out_features) / in_features ** 0.5)
            self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x_dict: Dict[str, Tensor],
                type_embed_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """
        Forward pass with type-specific transformations on dictionary inputs.

        Args:
            x_dict (Dict[str, Tensor]): Dictionary of input features by type
            type_embed_dict (Dict[str, Tensor]): Dictionary of type embeddings;
                must contain every key of ``x_dict`` in dynamic mode

        Returns:
            Dict[str, Tensor]: Dictionary of transformed features by type
        """
        out_dict = {}
        for node_type, x in x_dict.items():
            if self.dynamic:
                type_embed = type_embed_dict[node_type]
                weight = self.weight_proj(type_embed).view(
                    self.in_features, self.out_features)
                bias = self.bias_proj(type_embed)
            else:
                weight = self.weight
                bias = self.bias
            # NOTE(review): F.linear expects `weight` as
            # (out_features, in_features) while it is built here as
            # (in_features, out_features) — this appears to line up only when
            # in_features == out_features; confirm the intended layout.
            out_dict[node_type] = F.linear(x, weight, bias)
        return out_dict


class HeteClipLoss(nn.Module):
    """
    Contrastive loss for heterogeneous CLIP model.

    Implements InfoNCE-style contrastive loss between graph and text
    embeddings with temperature scaling. The i-th graph embedding and the
    i-th text embedding of a batch form the positive pair.
    """

    def __init__(self):
        super().__init__()
        # Cached target indices, rebuilt when the batch size or device changes.
        self.labels = None
        self.last_local_batch_size = None

    def forward(self, graph_features, text_features, logit_scale):
        """
        Compute contrastive loss between graph and text features.

        Args:
            graph_features (Tensor): Graph embeddings
            text_features (Tensor): Text embeddings
            logit_scale (Tensor): Temperature scaling factor

        Returns:
            Tensor: Contrastive loss value (mean of the graph-to-text and
            text-to-graph cross-entropy terms)
        """
        device = graph_features.device
        local_batch_size = graph_features.shape[0]

        if (
            self.labels is None
            or local_batch_size != self.last_local_batch_size
            or self.labels.device != device
        ):
            # Row i must match column i, so the targets are 0..B-1 (a constant
            # target would point every sample at the same column).
            self.labels = torch.arange(local_batch_size, dtype=torch.long, device=device)
            self.last_local_batch_size = local_batch_size

        graph_features = F.normalize(graph_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)

        logits_per_graph = logit_scale * graph_features @ text_features.t()
        logits_per_text = logits_per_graph.t()

        loss = (
            F.cross_entropy(logits_per_graph, self.labels)
            + F.cross_entropy(logits_per_text, self.labels)
        ) / 2
        return loss


# Prompt templates; each entry maps a class name to a caption string.
openai_imagenet_template = [
    lambda c: f"a photo of a {c}.",
    lambda c: f"a photograph of a {c}.",
    lambda c: f"an image of a {c}.",
    lambda c: f"a picture of a {c}.",
]