msr2000 committed
Commit fba8af1 · verified · 1 Parent(s): ba658fb

Add files using upload-large-folder tool

Files changed (50)
  1. inference/README.md +14 -0
  2. inference/config_671B_v3.2.json +26 -0
  3. inference/convert.py +100 -0
  4. inference/generate.py +186 -0
  5. inference/kernel.py +274 -0
  6. inference/model.py +923 -0
  7. inference/requirements.txt +5 -0
  8. model-00001-of-000163.safetensors +3 -0
  9. model-00002-of-000163.safetensors +3 -0
  10. model-00003-of-000163.safetensors +3 -0
  11. model-00004-of-000163.safetensors +3 -0
  12. model-00005-of-000163.safetensors +3 -0
  13. model-00006-of-000163.safetensors +3 -0
  14. model-00007-of-000163.safetensors +3 -0
  15. model-00008-of-000163.safetensors +3 -0
  16. model-00009-of-000163.safetensors +3 -0
  17. model-00010-of-000163.safetensors +3 -0
  18. model-00011-of-000163.safetensors +3 -0
  19. model-00012-of-000163.safetensors +3 -0
  20. model-00013-of-000163.safetensors +3 -0
  21. model-00014-of-000163.safetensors +3 -0
  22. model-00015-of-000163.safetensors +3 -0
  23. model-00016-of-000163.safetensors +3 -0
  24. model-00017-of-000163.safetensors +3 -0
  25. model-00018-of-000163.safetensors +3 -0
  26. model-00019-of-000163.safetensors +3 -0
  27. model-00020-of-000163.safetensors +3 -0
  28. model-00021-of-000163.safetensors +3 -0
  29. model-00022-of-000163.safetensors +3 -0
  30. model-00023-of-000163.safetensors +3 -0
  31. model-00024-of-000163.safetensors +3 -0
  32. model-00025-of-000163.safetensors +3 -0
  33. model-00149-of-000163.safetensors +3 -0
  34. model-00150-of-000163.safetensors +3 -0
  35. model-00151-of-000163.safetensors +3 -0
  36. model-00152-of-000163.safetensors +3 -0
  37. model-00153-of-000163.safetensors +3 -0
  38. model-00154-of-000163.safetensors +3 -0
  39. model-00155-of-000163.safetensors +3 -0
  40. model-00156-of-000163.safetensors +3 -0
  41. model-00157-of-000163.safetensors +3 -0
  42. model-00158-of-000163.safetensors +3 -0
  43. model-00159-of-000163.safetensors +3 -0
  44. model-00160-of-000163.safetensors +3 -0
  45. model-00161-of-000163.safetensors +3 -0
  46. model-00162-of-000163.safetensors +3 -0
  47. model-00163-of-000163.safetensors +3 -0
  48. model.safetensors.index.json +0 -0
  49. tokenizer.json +0 -0
  50. tokenizer_config.json +34 -0
inference/README.md ADDED
@@ -0,0 +1,14 @@
+ # DeepSeek V3.2
+
+ First convert the Hugging Face model weights to the format required by our inference demo. Set `MP` to match your available GPU count:
+ ```bash
+ cd inference
+ export EXPERTS=256
+ python convert.py --hf-ckpt-path ${HF_CKPT_PATH} --save-path ${SAVE_PATH} --n-experts ${EXPERTS} --model-parallel ${MP}
+ ```
+
+ Launch the interactive chat interface and start exploring DeepSeek's capabilities:
+ ```bash
+ export CONFIG=config_671B_v3.2.json
+ torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --interactive
+ ```
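`generate.py` also supports a non-interactive batch mode in addition to the chat loop shown above: pass `--input-file` with a text file of prompts separated by blank lines (at most `max_batch_size` prompts), e.g. `torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --input-file <prompt file>`; see `inference/generate.py` below.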
inference/config_671B_v3.2.json ADDED
@@ -0,0 +1,26 @@
+ {
+     "vocab_size": 129280,
+     "dim": 7168,
+     "inter_dim": 18432,
+     "moe_inter_dim": 2048,
+     "n_layers": 61,
+     "n_dense_layers": 3,
+     "n_heads": 128,
+     "n_routed_experts": 256,
+     "n_shared_experts": 1,
+     "n_activated_experts": 8,
+     "n_expert_groups": 8,
+     "n_limited_groups": 4,
+     "route_scale": 2.5,
+     "score_func": "sigmoid",
+     "q_lora_rank": 1536,
+     "kv_lora_rank": 512,
+     "qk_nope_head_dim": 128,
+     "qk_rope_head_dim": 64,
+     "v_head_dim": 128,
+     "dtype": "fp8",
+     "scale_fmt": "ue8m0",
+     "index_n_heads": 64,
+     "index_head_dim": 128,
+     "index_topk": 2048
+ }
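These fields map directly onto the `ModelArgs` dataclass in `inference/model.py`, and `generate.py` simply unpacks the JSON into it; fields not listed here (e.g. `max_batch_size`, `rope_theta`) keep their dataclass defaults. A minimal sketch of that loading step, assuming it is run from the `inference/` directory:

```python
import json

from model import ModelArgs  # inference/model.py

# Unpack the JSON config into ModelArgs, exactly as generate.py does.
with open("config_671B_v3.2.json") as f:
    args = ModelArgs(**json.load(f))

print(args.n_routed_experts, args.n_activated_experts, args.index_topk)  # 256 8 2048
```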
inference/convert.py ADDED
@@ -0,0 +1,100 @@
+ import os
+ import shutil
+ from argparse import ArgumentParser
+ from glob import glob
+ from tqdm import tqdm, trange
+
+ import torch
+ from safetensors.torch import safe_open, save_file
+
+
+ mapping = {
+     "embed_tokens": ("embed", 0),
+     "input_layernorm": ("attn_norm", None),
+     "post_attention_layernorm": ("ffn_norm", None),
+     "q_proj": ("wq", 0),
+     "q_a_proj": ("wq_a", None),
+     "q_a_layernorm": ("q_norm", None),
+     "q_b_proj": ("wq_b", 0),
+     "kv_a_proj_with_mqa": ("wkv_a", None),
+     "kv_a_layernorm": ("kv_norm", None),
+     "kv_b_proj": ("wkv_b", 0),
+     "o_proj": ("wo", 1),
+     "gate": ("gate", None),
+     "gate_proj": ("w1", 0),
+     "down_proj": ("w2", 1),
+     "up_proj": ("w3", 0),
+     "norm": ("norm", None),
+     "lm_head": ("head", 0),
+     "scale": ("scale", None),
+     "wq_b": ("wq_b", None),
+     "wk": ("wk", None),
+     "k_norm": ("k_norm", None),
+     "weights_proj": ("weights_proj", None),
+ }
+
+
+ def main(hf_ckpt_path, save_path, n_experts, mp):
+     """
+     Converts and saves model checkpoint files into a specified format.
+
+     Args:
+         hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
+         save_path (str): Path to the directory where the converted checkpoint files will be saved.
+         n_experts (int): Total number of experts in the model.
+         mp (int): Model parallelism factor.
+
+     Returns:
+         None
+     """
+     torch.set_num_threads(8)
+     n_local_experts = n_experts // mp
+     state_dicts = [{} for _ in range(mp)]
+
+     for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
+         with safe_open(file_path, framework="pt", device="cpu") as f:
+             for name in f.keys():
+                 if "model.layers.61" in name:
+                     continue
+                 param: torch.Tensor = f.get_tensor(name)
+                 if name.startswith("model."):
+                     name = name[len("model."):]
+                 name = name.replace("self_attn", "attn")
+                 name = name.replace("mlp", "ffn")
+                 name = name.replace("weight_scale_inv", "scale")
+                 name = name.replace("e_score_correction_bias", "bias")
+                 key = name.split(".")[-2]
+                 assert key in mapping, f"Key {key} not found in mapping"
+                 new_key, dim = mapping[key]
+                 name = name.replace(key, new_key)
+                 for i in range(mp):
+                     new_param = param
+                     if "experts" in name and "shared_experts" not in name:
+                         idx = int(name.split(".")[-3])
+                         if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
+                             continue
+                     elif dim is not None:
+                         assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
+                         shard_size = param.size(dim) // mp
+                         new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
+                     state_dicts[i][name] = new_param
+
+     os.makedirs(save_path, exist_ok=True)
+
+     for i in trange(mp):
+         save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
+
+     for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
+         new_file_path = os.path.join(save_path, os.path.basename(file_path))
+         shutil.copyfile(file_path, new_file_path)
+
+
+ if __name__ == "__main__":
+     parser = ArgumentParser()
+     parser.add_argument("--hf-ckpt-path", type=str, required=True)
+     parser.add_argument("--save-path", type=str, required=True)
+     parser.add_argument("--n-experts", type=int, required=True)
+     parser.add_argument("--model-parallel", type=int, required=True)
+     args = parser.parse_args()
+     assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
+     main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
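One detail worth calling out in the loop above: a routed expert's weights are copied only to the rank whose contiguous slice of experts contains it, while every tensor with a split dimension (`dim` of 0 or 1 in `mapping`) is narrowed per rank. A small sketch of the resulting expert-to-rank assignment, assuming the 671B config with `--n-experts 256` and `MP=8`:

```python
# Which model-parallel rank owns which routed experts after convert.py
# (illustrative values: n_experts=256, mp=8 -> 32 local experts per rank).
n_experts, mp = 256, 8
n_local_experts = n_experts // mp

def owner_rank(idx: int) -> int:
    # Mirrors convert.py's range check: rank i keeps expert idx iff
    # i * n_local_experts <= idx < (i + 1) * n_local_experts.
    return idx // n_local_experts

assert owner_rank(0) == 0 and owner_rank(31) == 0
assert owner_rank(32) == 1
assert owner_rank(255) == mp - 1
print({r: (r * n_local_experts, (r + 1) * n_local_experts - 1) for r in range(mp)})
```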
inference/generate.py ADDED
@@ -0,0 +1,186 @@
1
+ import os
2
+ import json
3
+ from argparse import ArgumentParser
4
+ from typing import List
5
+
6
+ import torch
7
+ import torch.distributed as dist
8
+ from transformers import AutoTokenizer
9
+ from safetensors.torch import load_model
10
+
11
+ from model import Transformer, ModelArgs
12
+
13
+
14
+ def sample(logits, temperature: float = 1.0):
15
+ """
16
+ Samples a token from the logits using temperature scaling.
17
+
18
+ Args:
19
+ logits (torch.Tensor): The logits tensor for token predictions.
20
+ temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.
21
+
22
+ Returns:
23
+ torch.Tensor: The sampled token.
24
+ """
25
+ logits = logits / max(temperature, 1e-5)
26
+ probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
27
+ return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
28
+
29
+
30
+ @torch.inference_mode()
31
+ def generate(
32
+ model: Transformer,
33
+ prompt_tokens: List[List[int]],
34
+ max_new_tokens: int,
35
+ eos_id: int,
36
+ temperature: float = 1.0
37
+ ) -> List[List[int]]:
38
+ """
39
+ Generates new tokens based on the given prompt tokens using the specified model.
40
+
41
+ Args:
42
+ model (Transformer): The transformer model used for token generation.
43
+ prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence.
44
+ max_new_tokens (int): The maximum number of new tokens to generate.
45
+ eos_id (int): The end-of-sequence token ID.
46
+ temperature (float, optional): The temperature value for sampling. Defaults to 1.0.
47
+
48
+ Returns:
49
+ List[List[int]]: A list of lists containing the generated tokens for each sequence.
50
+ """
51
+ prompt_lens = [len(t) for t in prompt_tokens]
52
+ assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
53
+ total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
54
+ tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
55
+ for i, t in enumerate(prompt_tokens):
56
+ tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
57
+ prev_pos = 0
58
+ finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
59
+ prompt_mask = tokens != -1
60
+ for cur_pos in range(min(prompt_lens), total_len):
61
+ logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
62
+ if temperature > 0:
63
+ next_token = sample(logits, temperature)
64
+ else:
65
+ next_token = logits.argmax(dim=-1)
66
+ next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
67
+ tokens[:, cur_pos] = next_token
68
+ finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
69
+ prev_pos = cur_pos
70
+ if finished.all():
71
+ break
72
+ completion_tokens = []
73
+ for i, toks in enumerate(tokens.tolist()):
74
+ toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
75
+ if eos_id in toks:
76
+ toks = toks[:toks.index(eos_id)]
77
+ completion_tokens.append(toks)
78
+ return completion_tokens
79
+
80
+
81
+ def main(
82
+ ckpt_path: str,
83
+ config: str,
84
+ input_file: str = "",
85
+ interactive: bool = True,
86
+ max_new_tokens: int = 100,
87
+ temperature: float = 1.0,
88
+ ) -> None:
89
+ """
90
+ Main function to load the model and perform interactive or batch text generation.
91
+
92
+ Args:
93
+ ckpt_path (str): Path to the model checkpoint directory.
94
+ config (str): Path to the model configuration file.
95
+ input_file (str, optional): Path to a file containing input prompts. Defaults to "".
96
+ interactive (bool, optional): Whether to run in interactive mode. Defaults to True.
97
+ max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
98
+ temperature (float, optional): Temperature for sampling. Defaults to 1.0.
99
+ """
100
+ world_size = int(os.getenv("WORLD_SIZE", "1"))
101
+ rank = int(os.getenv("RANK", "0"))
102
+ local_rank = int(os.getenv("LOCAL_RANK", "0"))
103
+ if world_size > 1:
104
+ dist.init_process_group("nccl")
105
+ global print
106
+ if rank != 0:
107
+ print = lambda *_, **__: None
108
+ torch.cuda.set_device(local_rank)
109
+ torch.set_default_dtype(torch.bfloat16)
110
+ torch.set_num_threads(8)
111
+ torch.manual_seed(33377335)
112
+ with open(config) as f:
113
+ args = ModelArgs(**json.load(f))
114
+ print(args)
115
+ with torch.device("cuda"):
116
+ model = Transformer(args)
117
+ tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
118
+ print("load model")
119
+ load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
120
+ print("I'm DeepSeek 👋")
121
+
122
+ if interactive:
123
+ messages = []
124
+ while True:
125
+ if world_size == 1:
126
+ prompt = input(">>> ")
127
+ elif rank == 0:
128
+ prompt = input(">>> ")
129
+ objects = [prompt]
130
+ dist.broadcast_object_list(objects, 0)
131
+ else:
132
+ objects = [None]
133
+ dist.broadcast_object_list(objects, 0)
134
+ prompt = objects[0]
135
+ if prompt == "/exit":
136
+ break
137
+ elif prompt == "/clear":
138
+ messages.clear()
139
+ continue
140
+ messages.append({"role": "user", "content": prompt})
141
+ prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
142
+ completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
143
+ completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
144
+ print(completion)
145
+ messages.append({"role": "assistant", "content": completion})
146
+ else:
147
+ with open(input_file) as f:
148
+ prompts = f.read().split("\n\n")
149
+ assert len(prompts) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})"
150
+ prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
151
+ completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
152
+ completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
153
+ for prompt, completion in zip(prompts, completions):
154
+ print("Prompt:", prompt)
155
+ print("Completion:", completion)
156
+ print()
157
+
158
+ if world_size > 1:
159
+ dist.destroy_process_group()
160
+
161
+
162
+ if __name__ == "__main__":
163
+ """
164
+ Command-line interface for distributed text generation.
165
+
166
+ Arguments:
167
+ --ckpt-path (str): Path to the model checkpoint directory.
168
+ --config (str): Path to the model configuration file.
169
+ --input-file (str, optional): File containing prompts for batch processing.
170
+ --interactive (bool, optional): Enable interactive mode for generating text.
171
+ --max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200.
172
+ --temperature (float, optional): Temperature for sampling. Defaults to 0.6.
173
+
174
+ Raises:
175
+ AssertionError: If neither input-file nor interactive mode is specified.
176
+ """
177
+ parser = ArgumentParser()
178
+ parser.add_argument("--ckpt-path", type=str, required=True)
179
+ parser.add_argument("--config", type=str, required=True)
180
+ parser.add_argument("--input-file", type=str, default="")
181
+ parser.add_argument("--interactive", action="store_true")
182
+ parser.add_argument("--max-new-tokens", type=int, default=200)
183
+ parser.add_argument("--temperature", type=float, default=0.6)
184
+ args = parser.parse_args()
185
+ assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
186
+ main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
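A note on `sample()` above: instead of `torch.multinomial`, it uses the exponential-race (Gumbel-max style) trick of dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax, which draws index `i` with probability `probs[i]`. A small self-contained check of that equivalence (vocabulary size and sample count are arbitrary):

```python
import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])

# Exponential-race sampling, as in generate.py's sample():
# argmax_i probs[i] / E_i with E_i ~ Exp(1) picks i with probability probs[i].
n = 200_000
noise = torch.empty(n, probs.numel()).exponential_(1)
samples = (probs / noise).argmax(dim=-1)

freq = torch.bincount(samples, minlength=probs.numel()).float() / n
print(freq)  # approximately tensor([0.1, 0.2, 0.3, 0.4])
```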
inference/kernel.py ADDED
@@ -0,0 +1,274 @@
1
+ import torch
2
+ import tilelang
3
+ import tilelang.language as T
4
+ from typing import Tuple, Optional
5
+
6
+
7
+ tilelang.set_log_level("WARNING")
8
+
9
+ pass_configs = {
10
+ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
11
+ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
12
+ tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
13
+ }
14
+
15
+ FP8 = "float8_e4m3"
16
+ BF16 = "bfloat16"
17
+ FP32 = "float32"
18
+
19
+
20
+ def fast_log2_ceil(x):
21
+ bits_x = T.reinterpret("uint32", x)
22
+ exp_x = (bits_x >> 23) & 0xFF
23
+ man_bits = bits_x & ((1 << 23) - 1)
24
+ return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
25
+
26
+
27
+ def fast_pow2(x):
28
+ bits_x = (x + 127) << 23
29
+ return T.reinterpret("float32", bits_x)
30
+
31
+
32
+ def fast_round_scale(amax, fp8_max_inv):
33
+ return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
34
+
35
+
36
+ @tilelang.jit(pass_configs=pass_configs)
37
+ def act_quant_kernel(
38
+ N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False
39
+ ):
40
+ M = T.symbolic("M")
41
+ fp8_min = -448.0
42
+ fp8_max = 448.0
43
+ fp8_max_inv = 1 / fp8_max
44
+ num_stages = 0 if round_scale else 2
45
+ blk_m = 32
46
+ group_size = 128
47
+
48
+ @T.prim_func
49
+ def act_quant_kernel_(
50
+ X: T.Tensor[(M, N), in_dtype],
51
+ Y: T.Tensor[(M, N), out_dtype],
52
+ S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
53
+ ):
54
+ with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
55
+ pid_m,
56
+ pid_n,
57
+ ):
58
+ x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
59
+ x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
60
+ amax_local = T.alloc_fragment((blk_m,), scale_dtype)
61
+ s_local = T.alloc_fragment((blk_m,), scale_dtype)
62
+ y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
63
+ y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
64
+
65
+ for _ in T.Pipelined(1, num_stages=num_stages):
66
+ T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
67
+ T.copy(x_shared, x_local)
68
+ T.reduce_absmax(x_local, amax_local, dim=1)
69
+ for i in T.Parallel(blk_m):
70
+ amax_local[i] = T.max(amax_local[i], 1e-4)
71
+ if round_scale:
72
+ s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
73
+ else:
74
+ s_local[i] = amax_local[i] * fp8_max_inv
75
+ for i, j in T.Parallel(blk_m, group_size):
76
+ y_local[i, j] = T.clamp(
77
+ x_local[i, j] / s_local[i], fp8_min, fp8_max
78
+ )
79
+ for i in T.Parallel(blk_m):
80
+ S[pid_m * blk_m + i, pid_n] = s_local[i]
81
+ T.copy(y_local, y_shared)
82
+ T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
83
+
84
+ return act_quant_kernel_
85
+
86
+
87
+ def act_quant(
88
+ x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
89
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
90
+ """
91
+ Quantizes the input tensor `x` using block-wise quantization.
92
+
93
+ Args:
94
+ x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
95
+ block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
96
+ scale_fmt (Optional[str], optional): The format of the scale. Default is None.
97
+ Returns:
98
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
99
+ - The quantized tensor with dtype `torch.float8_e4m3fn`.
100
+ - A tensor of scaling factors with dtype `torch.float32`.
101
+ """
102
+ assert x.is_contiguous(), "Input tensor must be contiguous"
103
+ assert x.size(-1) % block_size == 0, (
104
+ f"Last dimension size must be divisible by block_size (block_size={block_size})"
105
+ )
106
+ N = x.size(-1)
107
+ y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
108
+ s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
109
+ kernel = act_quant_kernel(N, round_scale=scale_fmt is not None)
110
+ kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
111
+ return y, s
112
+
113
+
114
+ @tilelang.jit(pass_configs=pass_configs)
115
+ def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype="float32"):
116
+ assert out_dtype in [BF16, "float32"]
117
+
118
+ M = T.symbolic("M")
119
+ group_size = 128
120
+ block_M = 32
121
+ block_N = 128
122
+ block_K = 128
123
+
124
+ @T.prim_func
125
+ def fp8_gemm_kernel_(
126
+ A: T.Tensor[(M, K), FP8],
127
+ B: T.Tensor[(N, K), FP8],
128
+ C: T.Tensor[(M, N), out_dtype],
129
+ scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), FP32],
130
+ scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), FP32],
131
+ ):
132
+ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
133
+ bx,
134
+ by,
135
+ ):
136
+ A_shared = T.alloc_shared((block_M, block_K), FP8)
137
+ B_shared = T.alloc_shared((block_N, block_K), FP8)
138
+ C_shared = T.alloc_shared((block_M, block_N), out_dtype)
139
+ Scale_C_shared = T.alloc_shared((block_M), FP32)
140
+ C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
141
+ C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
142
+
143
+ # Improve L2 Cache
144
+ T.use_swizzle(panel_size=10)
145
+
146
+ T.clear(C_local)
147
+ T.clear(C_local_accum)
148
+ K_iters = T.ceildiv(K, block_K)
149
+ for k in T.Pipelined(K_iters, num_stages=4):
150
+ # Load A into shared memory
151
+ T.copy(A[by * block_M, k * block_K], A_shared)
152
+ # Load B into shared memory
153
+ T.copy(B[bx * block_N, k * block_K], B_shared)
154
+ # Load scale into shared memory
155
+ Scale_B = scales_b[bx * block_N // group_size, k]
156
+ for i in T.Parallel(block_M):
157
+ Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
158
+
159
+ T.gemm(A_shared, B_shared, C_local, transpose_B=True)
160
+ # Promote to enable 2xAcc
161
+ for i, j in T.Parallel(block_M, block_N):
162
+ C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
163
+ T.clear(C_local)
164
+ # TMA store
165
+ T.copy(C_local_accum, C_shared)
166
+ T.copy(C_shared, C[by * block_M, bx * block_N])
167
+
168
+ return fp8_gemm_kernel_
169
+
170
+
171
+ def fp8_gemm(
172
+ a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor
173
+ ) -> torch.Tensor:
174
+ """
175
+ Perform a matrix multiplication using FP8 precision.
176
+
177
+ Args:
178
+ a (torch.Tensor): The first input matrix, must be contiguous.
179
+ a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
180
+ b (torch.Tensor): The second input matrix, must be contiguous.
181
+ b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
182
+
183
+ Returns:
184
+ torch.Tensor: The result of the matrix multiplication.
185
+ """
186
+ assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
187
+ assert a_s.is_contiguous() and b_s.is_contiguous(), (
188
+ "Scaling factor tensors must be contiguous"
189
+ )
190
+ K = a.size(-1)
191
+ M = a.numel() // K
192
+ N = b.size(0)
193
+ c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
194
+ kernel = fp8_gemm_kernel(N, K)
195
+ kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
196
+ return c
197
+
198
+
199
+ @tilelang.jit(out_idx=[4], pass_configs=pass_configs)
200
+ def fp8_index_kernel(h: int, d: int):
201
+ b = T.symbolic("b")
202
+ m = T.symbolic("m")
203
+ n = T.symbolic("n")
204
+
205
+ blk_n1 = 512
206
+ blk_n2 = 128
207
+
208
+ @T.prim_func
209
+ def fp8_index_kernel_(
210
+ q: T.Tensor[(b, m, h, d), FP8],
211
+ q_s: T.Tensor[(b, m, h), FP32],
212
+ k: T.Tensor[(b, n, d), FP8],
213
+ k_s: T.Tensor[(b, n), FP32],
214
+ o: T.Tensor[(b, m, n), FP32],
215
+ ) -> None:
216
+ with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n):
217
+ q_smem = T.alloc_shared((h, d), FP8)
218
+ T.copy(q[i_b, i_m, 0, 0], q_smem)
219
+
220
+ q_s_frag = T.alloc_fragment(h, FP32)
221
+ T.copy(q_s[i_b, i_m, 0], q_s_frag)
222
+
223
+ for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2):
224
+ k_smem = T.alloc_shared((blk_n2, d), FP8)
225
+ T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem)
226
+
227
+ k_s_frag = T.alloc_fragment(blk_n2, FP32)
228
+ T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag)
229
+
230
+ logits = T.alloc_fragment((blk_n2, h), FP32)
231
+ T.gemm(
232
+ k_smem,
233
+ q_smem,
234
+ logits,
235
+ transpose_A=False,
236
+ transpose_B=True,
237
+ clear_accum=True,
238
+ )
239
+
240
+ for i_h, i3_n in T.Parallel(h, blk_n2):
241
+ logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h]
242
+
243
+ logits_sum = T.alloc_fragment(blk_n2, FP32)
244
+ T.reduce_sum(logits, logits_sum, dim=1)
245
+
246
+ for i3_n in T.Parallel(blk_n2):
247
+ logits_sum[i3_n] *= k_s_frag[i3_n]
248
+
249
+ T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2])
250
+
251
+ return fp8_index_kernel_
252
+
253
+
254
+ def fp8_index(
255
+ q: torch.Tensor,
256
+ q_s: torch.Tensor,
257
+ k: torch.Tensor,
258
+ k_s: torch.Tensor,
259
+ ) -> torch.Tensor:
260
+ """
261
+ Perform index score using FP8 precision.
262
+
263
+ Args:
264
+ q (torch.Tensor): The Q tensor, must be contiguous.
265
+ q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous.
266
+ k (torch.Tensor): The K tensor, must be contiguous.
267
+ k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous.
268
+
269
+ fp8 q @ fp8 k -> fp32 logits
270
+ relu(fp32 logits) * q_s (weights) -> fp32 logits
271
+ fp32 logits -> fp32 logits_sum
272
+ fp32 logits_sum * k_s (e8m0) -> fp32 index_score
273
+ """
274
+ return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s)
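For readers without a `tilelang`-capable GPU, the block-wise quantization that `act_quant` performs is easy to reproduce in plain PyTorch: split the last dimension into groups of 128, take each group's absmax (clamped to 1e-4), divide by the FP8 E4M3 maximum of 448 to get the scale, and round the scale up to a power of two when the `ue8m0` scale format is requested. A CPU-only sketch of the scale computation (no FP8 storage, purely illustrative):

```python
import torch

def reference_act_quant_scales(x: torch.Tensor, block_size: int = 128,
                               round_scale: bool = False) -> torch.Tensor:
    """Per-group scales matching the logic of act_quant_kernel (illustrative only)."""
    fp8_max = 448.0
    groups = x.float().view(*x.shape[:-1], -1, block_size)
    amax = groups.abs().amax(dim=-1).clamp_min(1e-4)
    scale = amax / fp8_max
    if round_scale:
        # ue8m0-style scales: round each scale up to the next power of two.
        scale = torch.exp2(torch.ceil(torch.log2(scale)))
    return scale

x = torch.randn(4, 512, dtype=torch.bfloat16)
s = reference_act_quant_scales(x, round_scale=True)
print(s.shape)  # torch.Size([4, 4]): one scale per 128-wide group
```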
inference/model.py ADDED
@@ -0,0 +1,923 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import Tuple, Optional, Literal
4
+
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+ import torch.distributed as dist
9
+
10
+ from kernel import act_quant, fp8_gemm, fp8_index
11
+
12
+
13
+ world_size = 1
14
+ rank = 0
15
+ block_size = 128
16
+
17
+ @dataclass
18
+ class ModelArgs:
19
+ """
20
+ Data class for defining model arguments and hyperparameters.
21
+
22
+ Attributes:
23
+ max_batch_size (int): Maximum batch size.
24
+ max_seq_len (int): Maximum sequence length.
25
+ dtype (Literal["bf16", "fp8"]): Data type for computations.
26
+ scale_fmt (Optional[str]): Format for quantization scale.
27
+ vocab_size (int): Vocabulary size.
28
+ dim (int): Model dimension.
29
+ inter_dim (int): Intermediate dimension for MLP layers.
30
+ moe_inter_dim (int): Intermediate dimension for MoE layers.
31
+ n_layers (int): Number of transformer layers.
32
+ n_dense_layers (int): Number of dense layers in the model.
33
+ n_heads (int): Number of attention heads.
34
+ n_routed_experts (int): Number of routed experts for MoE layers.
35
+ n_shared_experts (int): Number of shared experts for MoE layers.
36
+ n_activated_experts (int): Number of activated experts in MoE layers.
37
+ n_expert_groups (int): Number of expert groups.
38
+ n_limited_groups (int): Number of limited groups for MoE routing.
39
+ score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
40
+ route_scale (float): Scaling factor for routing scores.
41
+ q_lora_rank (int): LoRA rank for query projections.
42
+ kv_lora_rank (int): LoRA rank for key-value projections.
43
+ qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
44
+ qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
45
+ v_head_dim (int): Dimension for value projections.
46
+ original_seq_len (int): Original sequence length.
47
+ rope_theta (float): Base for rotary positional encoding.
48
+ rope_factor (float): Scaling factor for extended sequence lengths.
49
+ beta_fast (int): Fast beta correction factor.
50
+ beta_slow (int): Slow beta correction factor.
51
+ mscale (float): Scaling factor for extended attention.
52
+ index_head_dim (int): Dimension for index head.
53
+ index_topk (int): Top-k for index head.
54
+ """
55
+ max_batch_size: int = 8
56
+ max_seq_len: int = 4096 * 4
57
+ dtype: Literal["bf16", "fp8"] = "bf16"
58
+ scale_fmt: Optional[str] = None
59
+ vocab_size: int = 102400
60
+ dim: int = 2048
61
+ inter_dim: int = 10944
62
+ moe_inter_dim: int = 1408
63
+ n_layers: int = 27
64
+ n_dense_layers: int = 1
65
+ n_heads: int = 16
66
+ # moe
67
+ n_routed_experts: int = 64
68
+ n_shared_experts: int = 2
69
+ n_activated_experts: int = 6
70
+ n_expert_groups: int = 1
71
+ n_limited_groups: int = 1
72
+ score_func: Literal["softmax", "sigmoid"] = "softmax"
73
+ route_scale: float = 1.
74
+ # mla
75
+ q_lora_rank: int = 0
76
+ kv_lora_rank: int = 512
77
+ qk_nope_head_dim: int = 128
78
+ qk_rope_head_dim: int = 64
79
+ v_head_dim: int = 128
80
+ # yarn
81
+ original_seq_len: int = 4096
82
+ rope_theta: float = 10000.0
83
+ rope_factor: float = 40
84
+ beta_fast: int = 32
85
+ beta_slow: int = 1
86
+ mscale: float = 1.
87
+ # index
88
+ index_n_heads: int = 64
89
+ index_head_dim: int = 128
90
+ index_topk: int = 2048
91
+
92
+ class ParallelEmbedding(nn.Module):
93
+ """
94
+ Embedding layer with parallelism support across distributed processes.
95
+
96
+ Args:
97
+ vocab_size (int): Vocabulary size.
98
+ dim (int): Embedding dimension.
99
+ """
100
+ def __init__(self, vocab_size: int, dim: int):
101
+ super().__init__()
102
+ self.vocab_size = vocab_size
103
+ self.dim = dim
104
+ assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"
105
+ self.part_vocab_size = (vocab_size // world_size)
106
+ self.vocab_start_idx = rank * self.part_vocab_size
107
+ self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
108
+ self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
109
+
110
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
111
+ """
112
+ Forward pass for parallel embedding layer.
113
+
114
+ Args:
115
+ x (torch.Tensor): Input tensor containing token indices.
116
+
117
+ Returns:
118
+ torch.Tensor: Embedded representations.
119
+
120
+ Raises:
121
+ ValueError: If `world_size` is not defined.
122
+ """
123
+ if world_size > 1:
124
+ mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
125
+ x = x - self.vocab_start_idx
126
+ x[mask] = 0
127
+ y = F.embedding(x, self.weight)
128
+ if world_size > 1:
129
+ y[mask] = 0
130
+ dist.all_reduce(y)
131
+ return y
132
+
133
+
134
+ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None,
135
+ scale_fmt: Optional[str] = None) -> torch.Tensor:
136
+ """
137
+ Applies a linear transformation to the incoming data: y = xA^T + b.
138
+ This function supports specialized implementations based on quantization
139
+ and tensor formats.
140
+
141
+ Args:
142
+ x (torch.Tensor): The input tensor.
143
+ weight (torch.Tensor): The weight tensor. It may be quantized and
144
+ requires dequantization for certain cases.
145
+ bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None.
146
+ scale_fmt (Optional[str]): The format of scaling factors.
147
+
148
+ Returns:
149
+ torch.Tensor: The result of the linear transformation, which may involve
150
+ quantization-aware computations depending on the input parameters.
151
+
152
+ Notes:
153
+ - If `weight` is quantized (e.g., `element_size() == 1`), a dequantized version
154
+ is used for computation.
155
+ - For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation.
156
+ """
157
+ assert bias is None
158
+
159
+ if weight.dtype != torch.float8_e4m3fn:
160
+ return F.linear(x, weight)
161
+ else:
162
+ x, scale = act_quant(x, block_size, scale_fmt)
163
+ return fp8_gemm(x, scale, weight, weight.scale)
164
+
165
+
166
+ class Linear(nn.Module):
167
+ """
168
+ Custom linear layer with support for quantized weights and optional bias.
169
+
170
+ Args:
171
+ in_features (int): Number of input features.
172
+ out_features (int): Number of output features.
173
+ bias (bool): Whether to include a bias term. Defaults to False.
174
+ dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
175
+ """
176
+ dtype = torch.bfloat16
177
+ scale_fmt: Optional[str] = None
178
+
179
+ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
180
+ super().__init__()
181
+ self.in_features = in_features
182
+ self.out_features = out_features
183
+ self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
184
+ if self.weight.element_size() == 1:
185
+ scale_out_features = (out_features + block_size - 1) // block_size
186
+ scale_in_features = (in_features + block_size - 1) // block_size
187
+ self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
188
+ else:
189
+ self.register_parameter("scale", None)
190
+ if bias:
191
+ self.bias = nn.Parameter(torch.empty(out_features))
192
+ else:
193
+ self.register_parameter("bias", None)
194
+
195
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
196
+ """
197
+ Forward pass for the custom linear layer.
198
+
199
+ Args:
200
+ x (torch.Tensor): Input tensor.
201
+
202
+ Returns:
203
+ torch.Tensor: Transformed tensor after linear computation.
204
+ """
205
+ return linear(x, self.weight, self.bias, self.scale_fmt)
206
+
207
+
208
+ class ColumnParallelLinear(Linear):
209
+ """
210
+ Linear layer with column parallelism, splitting output features across distributed processes.
211
+
212
+ Args:
213
+ in_features (int): Number of input features.
214
+ out_features (int): Total number of output features.
215
+ bias (bool): Whether to include a bias term. Defaults to False.
216
+ dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
217
+ """
218
+ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
219
+ assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
220
+ self.part_out_features = out_features // world_size
221
+ super().__init__(in_features, self.part_out_features, bias, dtype)
222
+
223
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
224
+ """
225
+ Forward pass for column parallel linear layer.
226
+
227
+ Args:
228
+ x (torch.Tensor): Input tensor.
229
+
230
+ Returns:
231
+ torch.Tensor: Transformed tensor with column-parallel computation.
232
+ """
233
+ y = linear(x, self.weight, self.bias, self.scale_fmt)
234
+ return y
235
+
236
+
237
+ class RowParallelLinear(Linear):
238
+ """
239
+ Linear layer with row parallelism, splitting input features across distributed processes.
240
+
241
+ Args:
242
+ in_features (int): Total number of input features.
243
+ out_features (int): Number of output features.
244
+ bias (bool): Whether to include a bias term. Defaults to False.
245
+ dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
246
+ """
247
+ def __init__(self, in_features: int, out_features: int, bias: bool = False, reduce_output = True, dtype = None):
248
+ assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
249
+ self.part_in_features = in_features // world_size
250
+ self.reduce_output = reduce_output
251
+ super().__init__(self.part_in_features, out_features, bias, dtype)
252
+
253
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
254
+ """
255
+ Forward pass for row parallel linear layer.
256
+
257
+ Args:
258
+ x (torch.Tensor): Input tensor.
259
+
260
+ Returns:
261
+ torch.Tensor: Transformed tensor with row-parallel computation.
262
+ """
263
+ y = linear(x, self.weight, None, self.scale_fmt)
264
+ if self.reduce_output and world_size > 1:
265
+ y = y.float()
266
+ dist.all_reduce(y)
267
+ if self.bias is not None:
268
+ y += self.bias
269
+ return y.type_as(x)
270
+
271
+
272
+ class RMSNorm(nn.Module):
273
+ """
274
+ Root Mean Square Layer Normalization (RMSNorm).
275
+
276
+ Args:
277
+ dim (int): Dimension of the input tensor.
278
+ eps (float): Epsilon value for numerical stability. Defaults to 1e-6.
279
+ """
280
+ def __init__(self, dim: int, eps: float = 1e-6):
281
+ super().__init__()
282
+ self.dim = dim
283
+ self.eps = eps
284
+ self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
285
+
286
+ def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
287
+ """
288
+ Forward pass for RMSNorm.
289
+
290
+ Args:
291
+ x (torch.Tensor): Input tensor.
292
+
293
+ Returns:
294
+ torch.Tensor: Normalized tensor with the same shape as input.
295
+ """
296
+ dtype = x.dtype
297
+ if residual is None:
298
+ x = x.float()
299
+ var = x.pow(2).mean(-1, keepdim=True)
300
+ x = x * torch.rsqrt(var + self.eps)
301
+ return (self.weight * x).to(dtype)
302
+ else:
303
+ x = residual = x.float() + residual.float()
304
+ var = x.pow(2).mean(-1, keepdim=True)
305
+ x = x * torch.rsqrt(var + self.eps)
306
+ return (self.weight * x).to(dtype), residual.to(dtype)
307
+
308
+
309
+ class LayerNorm(nn.Module):
310
+ """
311
+ Layer Normalization.
312
+ """
313
+ def __init__(self, dim: int, eps: float = 1e-6):
314
+ super().__init__()
315
+ self.dim = dim
316
+ self.eps = eps
317
+ self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
318
+ self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
319
+
320
+ def forward(self, x: torch.Tensor):
321
+ return F.layer_norm(x.float(), (self.dim,), self.weight, self.bias, self.eps).type_as(x)
322
+
323
+
324
+ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
325
+ """
326
+ Precomputes frequency-based complex exponential values for rotary positional embeddings.
327
+
328
+ Args:
329
+ args (ModelArgs): Model arguments containing positional embedding parameters.
330
+
331
+ Returns:
332
+ torch.Tensor: Precomputed complex exponential values for positional embeddings.
333
+ """
334
+ dim = args.qk_rope_head_dim
335
+ seqlen = args.max_seq_len
336
+ beta_fast = args.beta_fast
337
+ beta_slow = args.beta_slow
338
+ base = args.rope_theta
339
+ factor = args.rope_factor
340
+
341
+ def find_correction_dim(num_rotations, dim, base, max_seq_len):
342
+ """
343
+ Computes the correction dimension for a given number of rotations in the rotary positional embedding.
344
+
345
+ Args:
346
+ num_rotations (float): Number of rotations to compute the correction for.
347
+ dim (int): Dimensionality of the embedding space.
348
+ base (float): Base value for the exponential computation.
349
+ max_seq_len (int): Maximum sequence length.
350
+
351
+ Returns:
352
+ float: The correction dimension based on the input parameters.
353
+ """
354
+ return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
355
+
356
+ def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
357
+ """
358
+ Computes the range of correction dimensions for rotary positional embeddings.
359
+
360
+ Args:
361
+ low_rot (float): Lower bound for the number of rotations.
362
+ high_rot (float): Upper bound for the number of rotations.
363
+ dim (int): Dimensionality of the embedding space.
364
+ base (float): Base value for the exponential computation.
365
+ max_seq_len (int): Maximum sequence length.
366
+
367
+ Returns:
368
+ Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
369
+ """
370
+ low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
371
+ high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
372
+ return max(low, 0), min(high, dim-1)
373
+
374
+ def linear_ramp_factor(min, max, dim):
375
+ """
376
+ Computes a linear ramp function used to smooth values between a minimum and maximum range.
377
+
378
+ Args:
379
+ min (float): Minimum value for the ramp function.
380
+ max (float): Maximum value for the ramp function.
381
+ dim (int): Dimensionality of the ramp tensor.
382
+
383
+ Returns:
384
+ torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
385
+ clamped to the range [0, 1].
386
+ """
387
+ if min == max:
388
+ max += 0.001
389
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
390
+ ramp_func = torch.clamp(linear_func, 0, 1)
391
+ return ramp_func
392
+
393
+ freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
394
+ if seqlen > args.original_seq_len:
395
+ low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
396
+ smooth = 1 - linear_ramp_factor(low, high, dim // 2)
397
+ freqs = freqs / factor * (1 - smooth) + freqs * smooth
398
+
399
+ t = torch.arange(seqlen)
400
+ freqs = torch.outer(t, freqs)
401
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
402
+ return freqs_cis
403
+
404
+
405
+ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, interleaved: bool = True) -> torch.Tensor:
406
+ """
407
+ Applies rotary positional embeddings to the input tensor.
408
+
409
+ Args:
410
+ x (torch.Tensor): Input tensor with positional embeddings to be applied.
411
+ freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.
412
+
413
+ Returns:
414
+ torch.Tensor: Tensor with rotary embeddings applied.
415
+ """
416
+ dtype = x.dtype
417
+ shape = x.shape
418
+ if not interleaved:
419
+ x = x.view(*shape[:-1], 2, -1).transpose(-1, -2).contiguous()
420
+ x = torch.view_as_complex(x.float().view(*shape[:-1], -1, 2))
421
+ freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
422
+ y = torch.view_as_real(x * freqs_cis).flatten(3)
423
+ if not interleaved:
424
+ y = torch.cat([y[..., 0::2], y[..., 1::2]], dim=-1)
425
+ return y.to(dtype)
426
+
427
+
428
+ def rotate_activation(x: torch.Tensor) -> torch.Tensor:
429
+ assert x.dtype == torch.bfloat16
430
+ from fast_hadamard_transform import hadamard_transform
431
+ hidden_size = x.size(-1)
432
+ return hadamard_transform(x, scale=hidden_size ** -0.5)
433
+
434
+
435
+ class Indexer(torch.nn.Module):
436
+ def __init__(self, args: ModelArgs):
437
+ super().__init__()
438
+ self.dim: int = args.dim
439
+ self.n_heads: int = args.index_n_heads
440
+ self.n_local_heads = args.index_n_heads // world_size
441
+ self.head_dim: int = args.index_head_dim
442
+ self.rope_head_dim: int = args.qk_rope_head_dim
443
+ self.index_topk: int = args.index_topk
444
+ self.q_lora_rank: int = args.q_lora_rank
445
+ self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.head_dim)
446
+ self.wk = Linear(self.dim, self.head_dim)
447
+ self.k_norm = LayerNorm(self.head_dim)
448
+ # weights_proj in the checkpoint is stored in bf16, while the parameters here are stored in fp32 for convenience.
449
+ self.weights_proj = Linear(self.dim, self.n_heads, dtype=torch.float32)
450
+ self.softmax_scale = self.head_dim ** -0.5
451
+ self.scale_fmt = args.scale_fmt
452
+
453
+ self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.head_dim, dtype=torch.float8_e4m3fn), persistent=False)
454
+ self.register_buffer("k_scale_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.head_dim // block_size, dtype=torch.float32), persistent=False)
455
+
456
+
457
+ def forward(self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
458
+ bsz, seqlen, _ = x.size()
459
+ end_pos = start_pos + seqlen
460
+ q = self.wq_b(qr)
461
+ q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
462
+ q_pe, q_nope = torch.split(q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
463
+ # rope in indexer is not interleaved
464
+ q_pe = apply_rotary_emb(q_pe, freqs_cis, False)
465
+ q = torch.cat([q_pe, q_nope], dim=-1)
466
+ k = self.wk(x)
467
+ k = self.k_norm(k)
468
+ k_pe, k_nope = torch.split(k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
469
+ # rope in indexer is not interleaved
470
+ k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis, False).squeeze(2)
471
+ k = torch.cat([k_pe, k_nope], dim=-1)
472
+ q = rotate_activation(q)
473
+ k = rotate_activation(k)
474
+ q_fp8, q_scale = act_quant(q, block_size, self.scale_fmt)
475
+ k_fp8, k_scale = act_quant(k, block_size, self.scale_fmt)
476
+ self.k_cache[:bsz, start_pos:end_pos] = k_fp8
477
+ self.k_scale_cache[:bsz, start_pos:end_pos] = k_scale
478
+ weights = self.weights_proj(x.float()) * self.n_heads ** -0.5
479
+ weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
480
+ index_score = fp8_index(q_fp8.contiguous(), weights, self.k_cache[:bsz, :end_pos].contiguous(), self.k_scale_cache[:bsz, :end_pos].contiguous())
481
+ if mask is not None:
482
+ index_score += mask
483
+ topk_indices = index_score.topk(min(self.index_topk, end_pos), dim=-1)[1]
484
+ topk_indices_ = topk_indices.clone()
485
+ dist.broadcast(topk_indices_, src=0)
486
+ assert torch.all(topk_indices == topk_indices_), f"{topk_indices=} {topk_indices_=}"
487
+ return topk_indices
488
+
489
+
490
+ def weight_dequant(weight, scale):
491
+ shape = weight.shape
492
+ assert weight.dim() == 2
493
+ weight = weight.view(shape[0] // block_size, block_size, shape[1] // block_size, block_size).transpose(1, 2).contiguous().view(-1, block_size * block_size)
494
+ weight = (weight.float() * scale.view(-1, 1).float()).to(torch.get_default_dtype()).view(shape[0] // block_size, shape[1] // block_size, block_size, block_size).transpose(1, 2).contiguous().view(shape)
495
+ return weight
496
+
497
+
498
+ class MLA(nn.Module):
499
+ """
500
+ Multi-Head Latent Attention (MLA) Layer.
501
+
502
+ Attributes:
503
+ dim (int): Dimensionality of the input features.
504
+ n_heads (int): Number of attention heads.
505
+ n_local_heads (int): Number of local attention heads for distributed systems.
506
+ q_lora_rank (int): Rank for low-rank query projection.
507
+ kv_lora_rank (int): Rank for low-rank key/value projection.
508
+ qk_nope_head_dim (int): Dimensionality of non-positional query/key projections.
509
+ qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections.
510
+ qk_head_dim (int): Total dimensionality of query/key projections.
511
+ v_head_dim (int): Dimensionality of value projections.
512
+ softmax_scale (float): Scaling factor for softmax in attention computation.
513
+ """
514
+ def __init__(self, args: ModelArgs):
515
+ super().__init__()
516
+ self.dim = args.dim
517
+ self.n_heads = args.n_heads
518
+ self.n_local_heads = args.n_heads // world_size
519
+ self.q_lora_rank = args.q_lora_rank
520
+ self.kv_lora_rank = args.kv_lora_rank
521
+ self.qk_nope_head_dim = args.qk_nope_head_dim
522
+ self.qk_rope_head_dim = args.qk_rope_head_dim
523
+ self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
524
+ self.v_head_dim = args.v_head_dim
525
+
526
+ self.wq_a = Linear(self.dim, self.q_lora_rank)
527
+ self.q_norm = RMSNorm(self.q_lora_rank)
528
+ self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
529
+ self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
530
+ self.kv_norm = RMSNorm(self.kv_lora_rank)
531
+ self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
532
+ self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
533
+ self.softmax_scale = self.qk_head_dim ** -0.5
534
+ self.scale_fmt = args.scale_fmt
535
+ if args.max_seq_len > args.original_seq_len:
536
+ mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
537
+ self.softmax_scale = self.softmax_scale * mscale * mscale
538
+
539
+ self.indexer = Indexer(args)
540
+
541
+ self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
542
+ self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
543
+ self.dequant_wkv_b = None
544
+
545
+ def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
546
+ """
547
+ Forward pass for the Multi-Head Latent Attention (MLA) Layer.
548
+
549
+ Args:
550
+ x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
551
+ start_pos (int): Starting position in the sequence for caching.
552
+ freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
553
+ mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.
554
+
555
+ Returns:
556
+ torch.Tensor: Output tensor with the same shape as the input.
557
+ """
558
+ bsz, seqlen, _ = x.size()
559
+ end_pos = start_pos + seqlen
560
+ qr = self.q_norm(self.wq_a(x))
561
+ q = self.wq_b(qr)
562
+ q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
563
+ q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
564
+ q_pe = apply_rotary_emb(q_pe, freqs_cis)
565
+ kv = self.wkv_a(x)
566
+ kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
567
+ kv = self.kv_norm(kv)
568
+ k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
569
+ # we use fp8 kv cache in actual deployment, so here we simulate the precision by casting kv to fp8 and then back to bf16.
570
+ kv_fp8, kv_scale = act_quant(kv, block_size, self.scale_fmt)
571
+ kv = (kv_fp8.view(-1, block_size).float() * kv_scale.view(-1, 1)).to(kv.dtype).view_as(kv)
572
+ self.kv_cache[:bsz, start_pos:end_pos] = kv
573
+ self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
574
+ if mask is not None: # MHA prefill
575
+ q = torch.cat([q_nope, q_pe], dim=-1)
576
+ kv = self.wkv_b(kv)
577
+ kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
578
+ k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
579
+ k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
580
+ scores = torch.einsum("bshd,bthd->bsht", q, k).mul_(self.softmax_scale)
581
+
582
+ # indexer
583
+ topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
584
+ index_mask = torch.full((bsz, seqlen, seqlen), float("-inf"), device=x.device).scatter_(-1, topk_indices, 0)
585
+ index_mask += mask
586
+ scores += index_mask.unsqueeze(2)
587
+
588
+ scores = scores.softmax(dim=-1)
589
+ x = torch.einsum("bsht,bthd->bshd", scores, v)
590
+ else: # MQA decode
591
+ if self.dequant_wkv_b is None and self.wkv_b.scale is not None:
592
+ self.dequant_wkv_b = weight_dequant(self.wkv_b.weight, self.wkv_b.scale)
593
+ wkv_b = self.wkv_b.weight if self.dequant_wkv_b is None else self.dequant_wkv_b
594
+ wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
595
+ q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
596
+ scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
597
+ torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
598
+
599
+ # indexer
600
+ topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
601
+ index_mask = torch.full((bsz, 1, end_pos), float("-inf"), device=x.device).scatter_(-1, topk_indices, 0)
602
+ scores += index_mask.unsqueeze(2)
603
+
604
+ scores = scores.softmax(dim=-1)
605
+ x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
606
+ x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
607
+ x = self.wo(x.flatten(2))
608
+ return x
609
+
610
+
611
+ class MLP(nn.Module):
612
+ """
613
+ Multi-Layer Perceptron (MLP) used as a feed-forward layer.
614
+
615
+ Attributes:
616
+ w1 (nn.Module): Linear layer for input-to-hidden transformation.
617
+ w2 (nn.Module): Linear layer for hidden-to-output transformation.
618
+ w3 (nn.Module): Additional linear layer for feature transformation.
619
+ """
620
+ def __init__(self, dim: int, inter_dim: int, reduce_output: bool = True):
621
+ """
622
+ Initializes the MLP layer.
623
+
624
+ Args:
625
+ dim (int): Input and output dimensionality.
626
+ inter_dim (int): Hidden layer dimensionality.
627
+ """
628
+ super().__init__()
629
+ self.w1 = ColumnParallelLinear(dim, inter_dim)
630
+ self.w2 = RowParallelLinear(inter_dim, dim, reduce_output=reduce_output)
631
+ self.w3 = ColumnParallelLinear(dim, inter_dim)
632
+
633
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
634
+ """
635
+ Forward pass for the MLP layer.
636
+
637
+ Args:
638
+ x (torch.Tensor): Input tensor.
639
+
640
+ Returns:
641
+ torch.Tensor: Output tensor after MLP computation.
642
+ """
643
+ return self.w2((F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x))
644
+
645
+
646
+ class Gate(nn.Module):
647
+ """
648
+ Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.
649
+
650
+ Attributes:
651
+ dim (int): Dimensionality of input features.
652
+ topk (int): Number of top experts activated for each input.
653
+ n_groups (int): Number of groups for routing.
654
+ topk_groups (int): Number of groups to route inputs to.
655
+ score_func (str): Scoring function ('softmax' or 'sigmoid').
656
+ route_scale (float): Scaling factor for routing weights.
657
+ weight (torch.nn.Parameter): Learnable weights for the gate.
658
+ bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
659
+ """
660
+ def __init__(self, args: ModelArgs):
661
+ """
662
+ Initializes the Gate module.
663
+
664
+ Args:
665
+ args (ModelArgs): Model arguments containing gating parameters.
666
+ """
667
+ super().__init__()
668
+ self.dim = args.dim
669
+ self.topk = args.n_activated_experts
670
+ self.n_groups = args.n_expert_groups
671
+ self.topk_groups = args.n_limited_groups
672
+ self.score_func = args.score_func
673
+ self.route_scale = args.route_scale
674
+ self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
675
+ self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32)) if self.dim == 7168 else None
676
+
677
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
678
+ """
679
+ Forward pass for the gating mechanism.
680
+
681
+ Args:
682
+ x (torch.Tensor): Input tensor.
683
+
684
+ Returns:
685
+ Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices.
686
+ """
687
+ scores = linear(x.float(), self.weight.float())
688
+ if self.score_func == "softmax":
689
+ scores = scores.softmax(dim=-1)
690
+ else:
691
+ scores = scores.sigmoid()
692
+ original_scores = scores
693
+ if self.bias is not None:
694
+ scores = scores + self.bias
695
+ if self.n_groups > 1:
696
+ scores = scores.view(x.size(0), self.n_groups, -1)
697
+ if self.bias is None:
698
+ group_scores = scores.amax(dim=-1)
699
+ else:
700
+ group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
701
+ indices = group_scores.topk(self.topk_groups, dim=-1)[1]
702
+ mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)
703
+ scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)
704
+ indices = scores.topk(self.topk, dim=-1)[1]
705
+ weights = original_scores.gather(1, indices)
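+ # Sigmoid scores do not sum to 1, so renormalize over the selected experts before applying route_scale.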
706
+ if self.score_func == "sigmoid":
707
+ weights /= weights.sum(dim=-1, keepdim=True)
708
+ weights *= self.route_scale
709
+ return weights, indices
710
+
711
+
712
+ class Expert(nn.Module):
713
+ """
714
+ Expert layer for Mixture-of-Experts (MoE) models.
715
+
716
+ Attributes:
717
+ w1 (nn.Module): Linear layer for input-to-hidden transformation.
718
+ w2 (nn.Module): Linear layer for hidden-to-output transformation.
719
+ w3 (nn.Module): Additional linear layer for feature transformation.
720
+ """
721
+ def __init__(self, dim: int, inter_dim: int):
722
+ """
723
+ Initializes the Expert layer.
724
+
725
+ Args:
726
+ dim (int): Input and output dimensionality.
727
+ inter_dim (int): Hidden layer dimensionality.
728
+ """
729
+ super().__init__()
730
+ self.w1 = Linear(dim, inter_dim)
731
+ self.w2 = Linear(inter_dim, dim)
732
+ self.w3 = Linear(dim, inter_dim)
733
+
734
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
735
+ """
736
+ Forward pass for the Expert layer.
737
+
738
+ Args:
739
+ x (torch.Tensor): Input tensor.
740
+
741
+ Returns:
742
+ torch.Tensor: Output tensor after expert computation.
743
+ """
744
+ return self.w2((F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x))
745
+
746
+
747
+ class MoE(nn.Module):
748
+ """
749
+ Mixture-of-Experts (MoE) module.
750
+
751
+ Attributes:
752
+ dim (int): Dimensionality of input features.
753
+ n_routed_experts (int): Total number of experts in the model.
754
+ n_local_experts (int): Number of experts handled locally in distributed systems.
755
+ n_activated_experts (int): Number of experts activated for each input.
756
+ gate (nn.Module): Gating mechanism to route inputs to experts.
757
+ experts (nn.ModuleList): List of expert modules.
758
+ shared_experts (nn.Module): Shared experts applied to all inputs.
759
+ """
760
+ def __init__(self, args: ModelArgs):
761
+ """
762
+ Initializes the MoE module.
763
+
764
+ Args:
765
+ args (ModelArgs): Model arguments containing MoE parameters.
766
+ """
767
+ super().__init__()
768
+ self.dim = args.dim
769
+ assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
770
+ self.n_routed_experts = args.n_routed_experts
771
+ self.n_local_experts = args.n_routed_experts // world_size
772
+ self.n_activated_experts = args.n_activated_experts
773
+ self.experts_start_idx = rank * self.n_local_experts
774
+ self.experts_end_idx = self.experts_start_idx + self.n_local_experts
775
+ self.gate = Gate(args)
776
+ self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
777
+ for i in range(self.n_routed_experts)])
778
+ self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim, reduce_output=False)
779
+
780
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
781
+ """
782
+ Forward pass for the MoE module.
783
+
784
+ Args:
785
+ x (torch.Tensor): Input tensor.
786
+
787
+ Returns:
788
+ torch.Tensor: Output tensor after expert routing and computation.
789
+ """
790
+ shape = x.size()
791
+ x = x.view(-1, self.dim)
792
+ weights, indices = self.gate(x)
793
+ y = torch.zeros_like(x, dtype=torch.float32)
794
+ counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
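+ # Only the experts assigned to this rank are evaluated; the all-reduce below sums contributions across ranks.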
795
+ for i in range(self.experts_start_idx, self.experts_end_idx):
796
+ if counts[i] == 0:
797
+ continue
798
+ expert = self.experts[i]
799
+ idx, top = torch.where(indices == i)
800
+ y[idx] += expert(x[idx]) * weights[idx, top, None]
801
+ y += self.shared_experts(x)
802
+ if world_size > 1:
803
+ dist.all_reduce(y)
804
+ return y.type_as(x).view(shape)
805
+
806
+
807
+ class Block(nn.Module):
808
+ """
809
+ Transformer block combining attention and feed-forward layers.
810
+
811
+ Attributes:
812
+ attn (nn.Module): Attention layer (MLA).
813
+ ffn (nn.Module): Feed-forward network (MLP or MoE).
814
+ attn_norm (nn.Module): Layer normalization for attention.
815
+ ffn_norm (nn.Module): Layer normalization for feed-forward network.
816
+ """
817
+ def __init__(self, layer_id: int, args: ModelArgs):
818
+ """
819
+ Initializes the Transformer block.
820
+
821
+ Args:
822
+ layer_id (int): Layer index in the transformer.
823
+ args (ModelArgs): Model arguments containing block parameters.
824
+ """
825
+ super().__init__()
826
+ self.attn = MLA(args)
827
+ self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
828
+ self.attn_norm = RMSNorm(args.dim)
829
+ self.ffn_norm = RMSNorm(args.dim)
830
+
831
+ def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor], start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
832
+ """
833
+ Forward pass for the Transformer block.
834
+
835
+ Args:
836
+ x (torch.Tensor): Input tensor.
+ residual (Optional[torch.Tensor]): Running residual stream from the previous block, or None for the first block.
837
+ start_pos (int): Starting position in the sequence.
838
+ freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
839
+ mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.
840
+
841
+ Returns:
842
+ Tuple[torch.Tensor, torch.Tensor]: Output tensor and updated residual after block computation.
843
+ """
844
+ if residual is None:
845
+ x, residual = self.attn_norm(x), x
846
+ else:
847
+ x, residual = self.attn_norm(x, residual)
848
+ x = self.attn(x, start_pos, freqs_cis, mask)
849
+ x, residual = self.ffn_norm(x, residual)
850
+ x = self.ffn(x)
851
+ return x, residual
852
+
853
+
854
+ class Transformer(nn.Module):
855
+ """
856
+ Transformer model with positional embeddings, multiple layers, and output projection.
857
+
858
+ Attributes:
859
+ max_seq_len (int): Maximum sequence length for the transformer.
860
+ embed (nn.Module): Embedding layer for input tokens.
861
+ layers (torch.nn.ModuleList): List of transformer blocks.
862
+ norm (nn.Module): Layer normalization applied after all blocks.
863
+ head (nn.Module): Output projection layer mapping to vocabulary size.
864
+ freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
865
+ """
866
+ def __init__(self, args: ModelArgs):
867
+ """
868
+ Initializes the Transformer model.
869
+
870
+ Args:
871
+ args (ModelArgs): Model arguments containing transformer parameters.
872
+ """
873
+ global world_size, rank
874
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
875
+ rank = dist.get_rank() if dist.is_initialized() else 0
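+ # Choose the weight dtype for all Linear layers from the config: fp8 (e4m3) or bf16.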
876
+ Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
877
+ Linear.scale_fmt = args.scale_fmt
878
+ super().__init__()
879
+ self.max_seq_len = args.max_seq_len
880
+ self.embed = ParallelEmbedding(args.vocab_size, args.dim)
881
+ self.layers = torch.nn.ModuleList()
882
+ for layer_id in range(args.n_layers):
883
+ self.layers.append(Block(layer_id, args))
884
+ self.norm = RMSNorm(args.dim)
885
+ # lm_head in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for easier computation of logits later.
886
+ self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.float32)
887
+ self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
888
+
889
+ @torch.inference_mode()
890
+ def forward(self, tokens: torch.Tensor, start_pos: int = 0):
891
+ """
892
+ Forward pass for the Transformer model.
893
+
894
+ Args:
895
+ tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len).
896
+ start_pos (int, optional): Starting position in the sequence for rotary embeddings. Defaults to 0.
897
+
898
+ Returns:
899
+ torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
900
+ """
901
+ seqlen = tokens.size(1)
902
+ freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
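+ # A causal mask is only needed for prefill (seqlen > 1); single-token decoding attends to the full cache.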
903
+ mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1) if seqlen > 1 else None
904
+ h, residual = self.embed(tokens), None
905
+ for layer in self.layers:
906
+ h, residual = layer(h, residual, start_pos, freqs_cis, mask)
907
+ h, _ = self.norm(h, residual)
908
+ logits = self.head(h[:, -1].float())
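+ # The output head is column-parallel: each rank produces a vocabulary shard, gathered into full logits below.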
909
+ if world_size > 1:
910
+ all_logits = [torch.empty_like(logits) for _ in range(world_size)]
911
+ dist.all_gather(all_logits, logits)
912
+ logits = torch.cat(all_logits, dim=-1)
913
+ return logits
914
+
915
+
916
+ if __name__ == "__main__":
917
+ torch.set_default_dtype(torch.bfloat16)
918
+ torch.set_default_device("cuda")
919
+ torch.manual_seed(0)
920
+ args = ModelArgs()
921
+ x = torch.randint(0, args.vocab_size, (2, 128))
922
+ model = Transformer(args)
923
+ print(model(x).size())
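
The least obvious step in the MoE path above is the gate's group-limited top-k selection. The snippet below reproduces that selection on toy tensors, outside the model classes; the sizes, and the omission of the gate bias, are illustrative assumptions rather than part of the repository code.

```python
# Stand-alone sketch of group-limited top-k routing (toy sizes, sigmoid scoring, no gate bias).
import torch

n_tokens, n_experts = 4, 16          # illustrative; the real model routes over 256 experts
n_groups, topk_groups, topk = 4, 2, 4

scores = torch.rand(n_tokens, n_experts).sigmoid()
original_scores = scores

# Score each expert group by its best expert and keep only the top `topk_groups` groups per token.
grouped = scores.view(n_tokens, n_groups, -1)
group_idx = grouped.amax(dim=-1).topk(topk_groups, dim=-1)[1]
drop = torch.ones(n_tokens, n_groups, dtype=torch.bool).scatter_(1, group_idx, False)
masked = grouped.masked_fill(drop.unsqueeze(-1), float("-inf")).flatten(1)

# Select the top-k experts among the surviving groups and renormalize the sigmoid weights.
indices = masked.topk(topk, dim=-1)[1]
weights = original_scores.gather(1, indices)
weights = weights / weights.sum(dim=-1, keepdim=True)
print(indices.shape, weights.sum(dim=-1))  # torch.Size([4, 4]); each row of weights sums to ~1
```

In `Gate.forward` these weights are then scaled by `route_scale` before being applied to the expert outputs in `MoE.forward`.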
inference/requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ torch
2
+ transformers
3
+ safetensors
4
+ fast_hadamard_transform
5
+ tilelang==0.1.6
model-00001-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a20d4376cb0fef16425f38a2c819e957f48e83752c2ec8a747ec297a06460976
3
+ size 5233198531
model-00002-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca4cbcfcfbe0efc7ce703b3454e8f4d4985f1ad2b8a77b91cff539e68fafd07f
3
+ size 4302383956
model-00003-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9addd18a0a46128fe8c8fff100f9595b377772a4cff3e98fabb978ebffb4c14f
3
+ size 4302384377
model-00004-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce280c84088ede36cab620fe64d934c5a53d47d0ca03e7d6919b73ba1fb6b413
3
+ size 4302121967
model-00005-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0aeaa31d376fb3bb6285e8bb46825fe81cac90ddaf02ee6b4eeeee561dd78c1c
3
+ size 4302384146
model-00006-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f2caca0cf47e65fef1cd778159deff8a9a7225eb5a7368c87c93e1f84fc627a8
3
+ size 4307162046
model-00007-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e8c95f758f736d14c7b2edf7acbfa6fe294bb26b4ab991bbeeee81e1889931f
3
+ size 4312028034
model-00008-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e588199641eace5be198076ee4711c45b39e4020d6da9d65cdca25fdf4a4b697
3
+ size 4302384334
model-00009-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d3cec332f932e3ca6c1ab25be36fe618de69b9073ce8275a507bbb0cc3089c6
3
+ size 4302122175
model-00010-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36c8ee43fba46da1f8c4ed768ba0afeca1a685e5052d5c30d5a08808e7a5319c
3
+ size 4302383938
model-00011-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3092be80417e33ff688063305c420a45f58ec83efa73e5102f784d5bedd8dd11
3
+ size 4302384377
model-00012-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b04e90caba836f79efd90b27ed6971559a048e068606c6479d93fb39a72f8b76
3
+ size 1483135583
model-00013-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:abbe4de59b3410aeb922fc0d39830dc98e9ab7ad89f5020939b83d424d2d24f1
3
+ size 4302060527
model-00014-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51f32772258fe3d4fc8af313a6d6c545d238e37a294acbf3a3efccbb8480f77e
3
+ size 4302384328
model-00015-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f88eda8849a7b4eb65281b56ff651af2e10f443d1f3da8e7942bc38efb785fbc
3
+ size 4302122183
model-00016-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:877d653ea176c9d31650ca4d243e973fdf2852185dbac60423bebf53d728d152
3
+ size 4302383930
model-00017-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d33bbd747e7b4567567e6437f1f3bf23659de73bdca0a20809f23088f2291c66
3
+ size 4302384375
model-00018-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:71e34312df5cbd21426344230370d1d207fcd6d4e58a47b35c8c26e04e918566
3
+ size 4302121995
model-00019-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0764e374a4ed477da741b7db81fe7a1cbdff94a3376c85a5b4280d31750ae0b
3
+ size 4302384118
model-00020-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5847909aca3a67d7018b8ae719169c7f70721d47331355ffe19383d68ac9a2b2
3
+ size 4302384377
model-00021-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42412fde8f987689cf13ccfeb54d8a72169142ec113d414dd5799d540744e9a3
3
+ size 4302122373
model-00022-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de909ed548c13a98e6811ef62491fe3a6d838c268c39c5e0cff2014ea97d55e5
3
+ size 4302384890
model-00023-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6fd845c676f59da8d2ed229ab6036d1cb5166d120449a8c1bfd0627011647daa
3
+ size 4302122786
model-00024-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3743fa6abc0f0e35b8932640b0cdf8d94831e889e9511e52190c3e383a0ea2f1
3
+ size 4302384494
model-00025-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:167fe31e96d8d9820f127a5a06b9631e97db72dd9cc93ffadc06475248fc214c
3
+ size 4302384963
model-00149-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5619ea0af4a76e2acbb5d806a611cb9f8c6c325f634416e3f4aee1f89c037fd3
3
+ size 4302384963
model-00150-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c43edb1641b8c73ef535d3ee4034a509db314ea4f8c18196c573d6e8d4845056
3
+ size 4302122398
model-00151-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8420a702f53eeb2b5d919e3d5b66647bc1356e8523c8d1a2c880a0aa5a90cbfb
3
+ size 4302384890
model-00152-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2132b8c21e1bdd5c237b61c3497b22d40293d14e2945ddeaa355e0d11b3ad5ff
3
+ size 4302122786
model-00153-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab87c6c695bfa0807f487aec72f9f7532a04c913ec854a3606d364983fbf5770
3
+ size 4302384494
model-00154-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d6968de6684a923bd388e21ba9f629ba8839badf60090303aa8b5d7878e57046
3
+ size 4302384963
model-00155-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d5440080078866fa242cf10aba0f6f5d8491f15bc81f05960c5443accf3737c5
3
+ size 4302122598
model-00156-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2e9b3ffac4fcdc5a720df75265ac61075e88f8cd21df9739453f3e49ebb4fbc7
3
+ size 4302384680
model-00157-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24d1fafd6e13f045807718f0f2d6496b52dc928b67ec51b6d07310777daada88
3
+ size 4302384963
model-00158-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:95f1b28e15eca2910aa5c5fa5c9c025841e0d129dbf8d54cb49b88e167eb1e83
3
+ size 4302122420
model-00159-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08cfd9dc2a2cee511196621f934d8d1b1bf985e41daa3a2e04e8031bb4774622
3
+ size 4302384870
model-00160-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0e810a252ae59df23764d67c6cb0ff29d115d57cf07487d08bdcc8f13222a0b
3
+ size 5285723731
model-00161-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3bdc2e1a96d7cbfb78bdbc136d6f2412c7dbb306653238ee93e6376654397b7
3
+ size 4305819875
model-00162-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:96ec087d8534e08fc67accb038bebd537cc2ffb949c686ecce14e02242f50a9c
3
+ size 4302384938
model-00163-of-000163.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dda8c6d06366716bfa1143ee1ddd060144a619120314ed3a5b2810cd29977a5d
3
+ size 6643591758
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "bos_token": {
5
+ "__type": "AddedToken",
6
+ "content": "<|begin▁of▁sentence|>",
7
+ "lstrip": false,
8
+ "normalized": true,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "clean_up_tokenization_spaces": false,
13
+ "eos_token": {
14
+ "__type": "AddedToken",
15
+ "content": "<|end▁of▁sentence|>",
16
+ "lstrip": false,
17
+ "normalized": true,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "legacy": true,
22
+ "model_max_length": 131072,
23
+ "pad_token": {
24
+ "__type": "AddedToken",
25
+ "content": "<|end▁of▁sentence|>",
26
+ "lstrip": false,
27
+ "normalized": true,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ },
31
+ "sp_model_kwargs": {},
32
+ "unk_token": null,
33
+ "tokenizer_class": "LlamaTokenizerFast"
34
+ }
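
Since the tokenizer shipped here is a standard Hugging Face fast tokenizer, it can be loaded and inspected independently of the inference demo. A minimal sketch; the local directory name is an illustrative placeholder for wherever the checkpoint files were downloaded:

```python
# Load the tokenizer from a local copy of this repository (path is a placeholder).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./DeepSeek-V3.2-checkpoint")

# add_bos_token/add_eos_token are false in tokenizer_config.json, so no special tokens are inserted here.
ids = tokenizer.encode("Hello, DeepSeek!")
print(ids)
print(tokenizer.decode(ids))
print(tokenizer.eos_token)  # "<|end▁of▁sentence|>"
```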