File size: 1,165 Bytes
24d2ad5
 
 
7fe5689
24d2ad5
 
 
 
9a47333
 
24d2ad5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig
from i3_modules import i3Model  # import the original i3Model class

class i3Config(PretrainedConfig):
    """Configuration container for the i3 model.

    Holds the hyperparameters needed to construct an ``i3Model`` and
    registers the ``"i3"`` model type with the transformers config
    machinery. Unrecognized keyword arguments are forwarded to
    ``PretrainedConfig``.
    """

    # Identifier used by transformers' auto-class / serialization machinery.
    model_type = "i3"

    def __init__(self, vocab_size=34, d_model=256, n_layers=6, n_heads=8,
                 max_seq_len=128, rank=8, d_state=16, **kwargs):
        # Let the base class consume its own options (e.g. id2label, etc.).
        super().__init__(**kwargs)
        # Store every model hyperparameter as a plain attribute so it is
        # picked up by PretrainedConfig's to_dict()/save_pretrained().
        hyperparams = {
            "vocab_size": vocab_size,
            "d_model": d_model,
            "n_layers": n_layers,
            "n_heads": n_heads,
            "max_seq_len": max_seq_len,
            "rank": rank,
            "d_state": d_state,
        }
        for attr_name, attr_value in hyperparams.items():
            setattr(self, attr_name, attr_value)

class i3(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the project-local
    ``i3Model``.

    Exposes the custom model through the transformers interface so it can
    be used with ``save_pretrained`` / ``from_pretrained`` and the Trainer.
    """

    config_class = i3Config
    base_model_prefix = "i3"

    def __init__(self, config):
        super().__init__(config)
        # Build the underlying model from the config's hyperparameters.
        self.model = i3Model(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            n_layers=config.n_layers,
            n_heads=config.n_heads,
            max_seq_len=config.max_seq_len,
            rank=config.rank,
            d_state=config.d_state
        )

    def forward(self, input_ids, labels=None, **kwargs):
        """Delegate to the wrapped ``i3Model``.

        Args:
            input_ids: token id tensor passed straight through to the model.
            labels: optional target tensor; forwarded as the model's second
                positional argument (presumably used for loss computation
                inside ``i3Model`` — confirm against its implementation).
            **kwargs: accepted and ignored so that callers such as the HF
                Trainer, which pass extra keys like ``attention_mask``,
                do not raise ``TypeError``. Previously any extra keyword
                argument crashed the forward pass.
        """
        return self.model(input_ids, labels)