# File size: 1,165 Bytes
# 24d2ad5 7fe5689 24d2ad5 9a47333 24d2ad5
import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig
from i3_modules import i3Model # import the original i3Model class
class i3Config(PretrainedConfig):
    """Serializable configuration holding the i3 architecture hyperparameters.

    Mirrors the constructor arguments of ``i3Model`` so the model can be
    rebuilt from a saved config via the ``transformers`` API.
    """

    model_type = "i3"

    def __init__(self, vocab_size=34, d_model=256, n_layers=6, n_heads=8,
                 max_seq_len=128, rank=8, d_state=16, **kwargs):
        # Let PretrainedConfig consume any framework-level kwargs first.
        super().__init__(**kwargs)
        # Record every architecture hyperparameter on the instance so it is
        # included when the config is serialized.
        hyperparams = {
            "vocab_size": vocab_size,
            "d_model": d_model,
            "n_layers": n_layers,
            "n_heads": n_heads,
            "max_seq_len": max_seq_len,
            "rank": rank,
            "d_state": d_state,
        }
        for attr_name, attr_value in hyperparams.items():
            setattr(self, attr_name, attr_value)
class i3(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the original ``i3Model``.

    Exposes the custom architecture through the ``transformers`` API so it
    works with ``from_pretrained`` / ``save_pretrained`` and the ``Trainer``.
    """

    config_class = i3Config
    base_model_prefix = "i3"

    def __init__(self, config):
        """Build the wrapped ``i3Model`` from the hyperparameters in *config*."""
        super().__init__(config)
        # Reconstruct the underlying model purely from serializable config
        # values so a checkpoint round-trips through save/load.
        self.model = i3Model(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            n_layers=config.n_layers,
            n_heads=config.n_heads,
            max_seq_len=config.max_seq_len,
            rank=config.rank,
            d_state=config.d_state,
        )

    def forward(self, input_ids, labels=None, **kwargs):
        """Run the wrapped model.

        Args:
            input_ids: token-id tensor, forwarded unchanged to ``i3Model``.
            labels: optional targets, forwarded positionally (the wrapped
                model's exact signature lives in ``i3_modules``).
            **kwargs: extra batch keys (e.g. ``attention_mask``) that the HF
                ``Trainer``/data collator may pass. Accepted and ignored so
                training does not crash with a ``TypeError``.
                TODO(review): confirm whether ``i3Model`` should consume them.

        Returns:
            Whatever ``i3Model.__call__`` returns (loss/logits per its
            implementation in ``i3_modules``).
        """
        return self.model(input_ids, labels)