FlameF0X commited on
Commit
24d2ad5
·
verified ·
1 Parent(s): bf66d4f

Create modeling_i3.py

Browse files
Files changed (1) hide show
  1. modeling_i3.py +37 -0
modeling_i3.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from transformers import PreTrainedModel, PretrainedConfig
4
+ from i3_modules import i3Model # import your original i3Model
5
+
6
+ class i3Config(PretrainedConfig):
7
+ model_type = "i3"
8
+
9
+ def __init__(self, vocab_size=65, d_model=256, n_layers=6, n_heads=8,
10
+ max_seq_len=256, rank=8, d_state=16, **kwargs):
11
+ super().__init__(**kwargs)
12
+ self.vocab_size = vocab_size
13
+ self.d_model = d_model
14
+ self.n_layers = n_layers
15
+ self.n_heads = n_heads
16
+ self.max_seq_len = max_seq_len
17
+ self.rank = rank
18
+ self.d_state = d_state
19
+
20
+ class i3(PreTrainedModel):
21
+ config_class = i3Config
22
+ base_model_prefix = "i3"
23
+
24
+ def __init__(self, config):
25
+ super().__init__(config)
26
+ self.model = i3Model(
27
+ vocab_size=config.vocab_size,
28
+ d_model=config.d_model,
29
+ n_layers=config.n_layers,
30
+ n_heads=config.n_heads,
31
+ max_seq_len=config.max_seq_len,
32
+ rank=config.rank,
33
+ d_state=config.d_state
34
+ )
35
+
36
+ def forward(self, input_ids, labels=None):
37
+ return self.model(input_ids, labels)