|
|
import torch |
|
|
from torch import nn |
|
|
from transformers import PreTrainedModel, PretrainedConfig |
|
|
from i3_modules import i3Model |
|
|
|
|
|
class i3Config(PretrainedConfig):
    """HuggingFace configuration for the i3 architecture.

    Stores the hyperparameters consumed by ``i3Model`` (vocabulary size,
    model width, depth, attention heads, sequence length, low-rank factor
    and state dimension). Extra keyword arguments are forwarded to
    ``PretrainedConfig`` so standard HF config fields keep working.
    """

    model_type = "i3"

    def __init__(
        self,
        vocab_size=34,
        d_model=256,
        n_layers=6,
        n_heads=8,
        max_seq_len=128,
        rank=8,
        d_state=16,
        **kwargs,
    ):
        # Let the base class consume any generic HF config kwargs first.
        super().__init__(**kwargs)

        # Record each architecture hyperparameter on the instance so it is
        # serialized by the standard config save/load machinery.
        hyperparams = {
            "vocab_size": vocab_size,
            "d_model": d_model,
            "n_layers": n_layers,
            "n_heads": n_heads,
            "max_seq_len": max_seq_len,
            "rank": rank,
            "d_state": d_state,
        }
        for field_name, field_value in hyperparams.items():
            setattr(self, field_name, field_value)
|
|
|
|
|
class i3(PreTrainedModel):
    """HuggingFace ``PreTrainedModel`` wrapper around the project's ``i3Model``.

    Builds the underlying ``i3Model`` from an ``i3Config`` and delegates the
    forward pass to it, so the model plugs into the standard HF save/load
    and training tooling.
    """

    config_class = i3Config
    base_model_prefix = "i3"

    def __init__(self, config):
        """Instantiate the wrapped ``i3Model`` from *config*'s hyperparameters."""
        super().__init__(config)
        # Collect the architecture hyperparameters from the config and pass
        # them through by keyword, exactly mirroring the config fields.
        model_kwargs = {
            "vocab_size": config.vocab_size,
            "d_model": config.d_model,
            "n_layers": config.n_layers,
            "n_heads": config.n_heads,
            "max_seq_len": config.max_seq_len,
            "rank": config.rank,
            "d_state": config.d_state,
        }
        self.model = i3Model(**model_kwargs)

    def forward(self, input_ids, labels=None):
        """Delegate to the wrapped model; *labels* (optional) enables loss
        computation inside ``i3Model``. Return value is whatever the wrapped
        model returns."""
        return self.model(input_ids, labels)
|
|
|