from typing import Optional

import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig

from i3_modules import i3Model  # import the original i3Model class


class i3Config(PretrainedConfig):
    """Hugging Face configuration holding the hyperparameters for :class:`i3`.

    Every field is forwarded unchanged to ``i3_modules.i3Model`` by
    :class:`i3`. Extra keyword arguments go to ``PretrainedConfig``.

    Args:
        vocab_size: Vocabulary size passed to ``i3Model`` (default 34).
        d_model: Model width passed to ``i3Model`` (default 256).
        n_layers: Number of layers passed to ``i3Model`` (default 6).
        n_heads: Number of heads passed to ``i3Model`` (default 8).
        max_seq_len: Maximum sequence length passed to ``i3Model``
            (default 128).
        rank: Passed to ``i3Model`` as ``rank`` (default 8). Its meaning is
            defined in ``i3_modules``; presumably a low-rank factorization
            rank. TODO confirm.
        d_state: Passed to ``i3Model`` as ``d_state`` (default 16).
            Presumably a state dimension. TODO confirm against ``i3_modules``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    # Identifier Hugging Face stores in config.json for this architecture.
    model_type = "i3"

    def __init__(
        self,
        vocab_size: int = 34,
        d_model: int = 256,
        n_layers: int = 6,
        n_heads: int = 8,
        max_seq_len: int = 128,
        rank: int = 8,
        d_state: int = 16,
        **kwargs,
    ):
        # The base initializer runs first, as in the original. The
        # assignments below therefore take precedence over anything it sets.
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.max_seq_len = max_seq_len
        self.rank = rank
        self.d_state = d_state


class i3(PreTrainedModel):
    """``PreTrainedModel`` wrapper around ``i3_modules.i3Model``.

    Subclassing ``PreTrainedModel`` with ``config_class = i3Config`` lets the
    model use the Hugging Face config and serialization machinery. The
    actual network is the ``i3Model`` instance stored in ``self.model``.
    """

    config_class = i3Config
    # NOTE(review): the wrapped network lives at ``self.model``, not
    # ``self.i3``. Hugging Face uses this prefix when matching checkpoint keys
    # and to resolve ``base_model``. Confirm the mismatch is intentional
    # before changing it, because existing checkpoints depend on the
    # current key layout.
    base_model_prefix = "i3"

    def __init__(self, config: i3Config):
        """Build the wrapped ``i3Model`` from the fields of *config*."""
        super().__init__(config)
        self.model = i3Model(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            n_layers=config.n_layers,
            n_heads=config.n_heads,
            max_seq_len=config.max_seq_len,
            rank=config.rank,
            d_state=config.d_state,
        )
        # NOTE(review): Hugging Face models conventionally finish __init__
        # with ``self.post_init()``. It is not called here. Depending on the
        # transformers version it may re-initialize i3Model's weights, so
        # verify before adding it.

    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None):
        """Run the wrapped ``i3Model`` on *input_ids*.

        Args:
            input_ids: Token ids given to ``i3Model`` as its first positional
                argument. Presumably integer ids of shape
                (batch, seq_len); TODO confirm.
            labels: Given to ``i3Model`` as its second positional argument.
                ``None`` by default.

        Returns:
            Whatever ``i3Model`` returns. Presumably logits, plus a loss when
            *labels* is provided; confirm against ``i3_modules``.
        """
        return self.model(input_ids, labels)