# i3-tiny / modeling_i3.py
# Last commit: 7fe5689 by FlameF0X — "Update modeling_i3.py" (verified)
import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig
from i3_modules import i3Model # import the original i3Model class
class i3Config(PretrainedConfig):
    """Hugging Face configuration object for the ``i3`` model.

    Holds the hyperparameters that ``i3.__init__`` passes to ``i3Model``.
    Any additional keyword arguments are handed unchanged to
    ``PretrainedConfig.__init__``.

    Args:
        vocab_size: Vocabulary size (default 34).
        d_model: Model width (default 256); presumably the embedding/hidden
            size. TODO confirm against ``i3_modules.i3Model``.
        n_layers: Number of layers (default 6).
        n_heads: Number of heads (default 8).
        max_seq_len: Maximum sequence length (default 128).
        rank: Passed through to ``i3Model`` (default 8); the name suggests a
            low-rank factor. NOTE(review): confirm meaning in i3_modules.
        d_state: Passed through to ``i3Model`` (default 16); the name suggests
            a state dimension. NOTE(review): confirm meaning in i3_modules.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "i3"

    def __init__(self, vocab_size=34, d_model=256, n_layers=6, n_heads=8,
                 max_seq_len=128, rank=8, d_state=16, **kwargs):
        # The base class handles the generic config fields first; the
        # i3-specific hyperparameters are attached afterwards, in order.
        super().__init__(**kwargs)
        hyperparameters = {
            "vocab_size": vocab_size,
            "d_model": d_model,
            "n_layers": n_layers,
            "n_heads": n_heads,
            "max_seq_len": max_seq_len,
            "rank": rank,
            "d_state": d_state,
        }
        for attr_name, attr_value in hyperparameters.items():
            setattr(self, attr_name, attr_value)
class i3(PreTrainedModel):
    """``PreTrainedModel`` wrapper that exposes ``i3Model`` to Hugging Face.

    All model computation is delegated to the wrapped ``i3Model`` instance
    stored in ``self.model``.
    """

    config_class = i3Config
    # NOTE(review): the wrapped module lives at ``self.model``, but the prefix
    # is "i3". Confirm this mismatch is intended for checkpoint loading.
    base_model_prefix = "i3"

    # Config fields copied verbatim into the i3Model constructor (as kwargs).
    _I3_MODEL_FIELDS = (
        "vocab_size",
        "d_model",
        "n_layers",
        "n_heads",
        "max_seq_len",
        "rank",
        "d_state",
    )

    def __init__(self, config):
        """Build the wrapped ``i3Model`` from the hyperparameters on ``config``.

        Args:
            config: An ``i3Config`` (or compatible) instance.
        """
        super().__init__(config)
        model_kwargs = {
            field: getattr(config, field) for field in self._I3_MODEL_FIELDS
        }
        self.model = i3Model(**model_kwargs)

    def forward(self, input_ids, labels=None):
        """Run the wrapped model.

        ``input_ids`` and ``labels`` are passed positionally to ``i3Model``.
        The result is whatever ``i3Model`` returns, unchanged.
        """
        wrapped = self.model
        return wrapped(input_ids, labels)