import math

import torch
import torch.nn as nn

from .optimized_diffattn import MultiheadDiffAttn
# --- Tokenizer Definition ---
# Vocabulary: 256 bytes + IM_START_TOKEN + IM_END_TOKEN + <pad>
IM_START_TOKEN = "<|im_start|>"
IM_END_TOKEN = "<|im_end|>"
PAD_TOKEN = "<pad>"
SPECIAL_TOKENS = [IM_START_TOKEN, IM_END_TOKEN, PAD_TOKEN]
VOCAB_SIZE = 256 + len(SPECIAL_TOKENS)

# Create token <-> id mappings: ids 0-255 map to raw bytes, special tokens follow.
token_to_id = {}
id_to_token = {}
for i in range(256):
    token_to_id[bytes([i])] = i
    id_to_token[i] = bytes([i])
for i, token_str in enumerate(SPECIAL_TOKENS):
    token_id = 256 + i
    token_to_id[token_str] = token_id
    id_to_token[token_id] = token_str

PAD_ID = token_to_id[PAD_TOKEN]
IM_START_ID = token_to_id[IM_START_TOKEN]
IM_END_ID = token_to_id[IM_END_TOKEN]
class ByteTokenizer:
    def __init__(self):
        self.token_to_id = token_to_id
        self.id_to_token = id_to_token
        self.vocab_size = VOCAB_SIZE
        self.pad_id = PAD_ID
        self.im_start_id = IM_START_ID
        self.im_end_id = IM_END_ID

    def encode(self, text_bytes: bytes, add_special_tokens=True):
        ids = [self.token_to_id[bytes([b])] for b in text_bytes]
        if add_special_tokens:
            return [self.im_start_id] + ids + [self.im_end_id]
        return ids

    def decode(self, ids: list[int]):
        tokens = []
        for i in ids:
            token = self.id_to_token.get(i)
            if token is None:
                # Unknown token id: substitute a placeholder byte instead of raising.
                tokens.append(b"?")
            elif isinstance(token, bytes):
                tokens.append(token)
            # Special tokens (str entries) are skipped so decoding returns raw bytes only.
        return b"".join(tokens)
# --- RoPE Embeddings --- (reused from the previous script)
def get_rotary_embeddings(seq_len, dim_model, theta=10000.0):
    if dim_model % 2 != 0:
        raise ValueError(f"dim_model must be even, got {dim_model}")
    position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, dim_model, 2).float() * -(math.log(theta) / dim_model)
    )
    angles = position * div_term
    cos_emb = torch.cos(angles)
    sin_emb = torch.sin(angles)
    return cos_emb, sin_emb
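

# Shape sketch (illustrative only; `_demo_rope` is a helper added for
# documentation): both tables have shape (seq_len, dim_model // 2), one angle
# per rotated pair of dimensions.
def _demo_rope():
    cos_emb, sin_emb = get_rotary_embeddings(seq_len=128, dim_model=32)
    assert cos_emb.shape == (128, 16)
    assert sin_emb.shape == (128, 16)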
# --- Model Definition ---
class FeedForward(nn.Module):
    def __init__(self, embed_dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, x):
        return self.fc2(self.dropout(self.act(self.fc1(x))))
class DiffTransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, depth, ffn_hidden_dim, dropout=0.1):
        super().__init__()
        self.attn = MultiheadDiffAttn(embed_dim, depth, num_heads, dropout=dropout)
        self.ffn = FeedForward(embed_dim, ffn_hidden_dim, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, rel_pos, attn_mask=None):
        # Pre-norm: normalize before attention/FFN, then add the residual.
        attn_out = self.attn(self.norm1(x), rel_pos, attn_mask)
        x = x + self.dropout(attn_out)
        ffn_out = self.ffn(self.norm2(x))
        x = x + self.dropout(ffn_out)
        return x
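

# Usage sketch for a single block (illustrative only; `_demo_block` is a helper
# added for documentation and assumes MultiheadDiffAttn accepts the
# (embed_dim, depth, num_heads) constructor and (x, rel_pos, attn_mask) call
# used above).
def _demo_block():
    embed_dim, num_heads, seq_len = 256, 8, 16
    block = DiffTransformerBlock(embed_dim, num_heads, depth=0, ffn_hidden_dim=1024)
    x = torch.randn(2, seq_len, embed_dim)
    rel_pos = get_rotary_embeddings(seq_len, embed_dim // num_heads // 2)  # (cos, sin)
    causal_mask = torch.triu(
        torch.full((seq_len, seq_len), float("-inf")), diagonal=1
    )
    out = block(x, rel_pos, attn_mask=causal_mask)
    assert out.shape == x.shape  # residual block preserves (batch, seq, embed_dim)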
class DiffTransformerLLM(nn.Module):
    def __init__(
        self,
        vocab_size,
        embed_dim,
        num_layers,
        num_heads,
        ffn_hidden_dim,
        max_seq_len,
        dropout=0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.token_embeddings = nn.Embedding(vocab_size, embed_dim)
        # Positions are handled by RoPE, so there is no separate nn.Embedding for positions.
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList(
            [
                DiffTransformerBlock(
                    embed_dim, num_heads, depth, ffn_hidden_dim, dropout
                )
                for depth in range(num_layers)
            ]
        )
        self.norm_out = nn.LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        # Tie input embedding and output projection weights.
        self.token_embeddings.weight = self.lm_head.weight
        # RoPE precomputation: the head_dim for MultiheadDiffAttn is
        # embed_dim // num_heads // 2.
        self.rope_head_dim = embed_dim // num_heads // 2
        cos_emb, sin_emb = get_rotary_embeddings(max_seq_len, self.rope_head_dim)
        self.register_buffer("cos_emb", cos_emb, persistent=False)
        self.register_buffer("sin_emb", sin_emb, persistent=False)
    def forward(self, input_ids, attn_mask=None):
        batch_size, seq_len = input_ids.shape
        x = self.token_embeddings(input_ids) * math.sqrt(self.embed_dim)
        x = self.dropout(x)

        # Slice the precomputed RoPE tables and match the activations' device and dtype.
        rel_pos = (
            self.cos_emb[:seq_len, :].to(x.device, dtype=x.dtype),
            self.sin_emb[:seq_len, :].to(x.device, dtype=x.dtype),
        )

        # MultiheadDiffAttn expects an additive mask: 0 where attention is allowed,
        # -inf where it is blocked. A standard causal mask is always applied here;
        # a caller-provided padding mask (Hugging Face style, 1 = attend, 0 = pad)
        # is currently ignored, and padding is handled via the loss function's
        # ignore_index instead.
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=x.device),
            diagonal=1,
        )

        for layer in self.layers:
            x = layer(x, rel_pos, attn_mask=causal_mask)

        x = self.norm_out(x)
        logits = self.lm_head(x)
        return logits
    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
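

# Minimal smoke test (a sketch, not part of the training code): build a small
# model and check output shapes. The hyperparameters below are illustrative;
# this assumes the relative import of MultiheadDiffAttn resolves, so run it as
# a module (python -m <package>.<this_module>).
if __name__ == "__main__":
    tokenizer = ByteTokenizer()
    model = DiffTransformerLLM(
        vocab_size=tokenizer.vocab_size,
        embed_dim=256,
        num_layers=2,
        num_heads=8,
        ffn_hidden_dim=1024,
        max_seq_len=128,
    )
    ids = torch.tensor([tokenizer.encode(b"hello world")])  # (1, seq_len)
    logits = model(ids)
    print(logits.shape)              # expected: torch.Size([1, seq_len, VOCAB_SIZE])
    print(model.count_parameters())  # trainable parameter count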