import math

import torch
import torch.nn as nn

from .optimized_diffattn import MultiheadDiffAttn

# --- Tokenizer Definition ---
# Vocabulary layout: ids 0-255 are the raw byte values, followed by the
# special tokens in SPECIAL_TOKENS order (IM_START_TOKEN = 256,
# IM_END_TOKEN = 257, PAD_TOKEN = 258).
IM_START_TOKEN = "<|im_start|>"
IM_END_TOKEN = "<|im_end|>"
# NOTE(review): PAD_TOKEN is the empty string. It still receives its own id
# and cannot clash with the byte keys (those are `bytes`, this is `str`),
# but confirm an empty spelling is intended.
PAD_TOKEN = ""
SPECIAL_TOKENS = [IM_START_TOKEN, IM_END_TOKEN, PAD_TOKEN]
VOCAB_SIZE = 256 + len(SPECIAL_TOKENS)

# Lookup tables in both directions. Byte tokens are keyed by one-byte
# `bytes` objects; special tokens are keyed by their `str` spelling.
token_to_id = {bytes([value]): value for value in range(256)}
token_to_id.update(
    {special: 256 + offset for offset, special in enumerate(SPECIAL_TOKENS)}
)
id_to_token = {token_id: token for token, token_id in token_to_id.items()}

PAD_ID = token_to_id[PAD_TOKEN]
IM_START_ID = token_to_id[IM_START_TOKEN]
IM_END_ID = token_to_id[IM_END_TOKEN]


class ByteTokenizer:
    """Byte-level tokenizer over the module-level vocabulary.

    Every byte value 0-255 is its own token id; the special tokens listed in
    SPECIAL_TOKENS follow at ids 256 and up.
    """

    def __init__(self):
        # The module-level tables are shared, not copied.
        self.token_to_id = token_to_id
        self.id_to_token = id_to_token
        self.vocab_size = VOCAB_SIZE
        self.pad_id = PAD_ID
        self.im_start_id = IM_START_ID
        self.im_end_id = IM_END_ID

    def encode(self, text_bytes: bytes, add_special_tokens=True):
        """Map raw bytes to a list of token ids.

        Args:
            text_bytes: The bytes to encode; each byte becomes one id.
            add_special_tokens: When true, the ids are wrapped in the
                im_start / im_end ids.

        Returns:
            A list of int token ids.
        """
        body = [self.token_to_id[bytes((value,))] for value in text_bytes]
        if not add_special_tokens:
            return body
        return [self.im_start_id, *body, self.im_end_id]

    def decode(self, ids: list[int]):
        """Turn token ids back into raw bytes.

        Byte ids decode to their byte, special-token ids (stored as `str`)
        are dropped, and ids missing from the vocabulary become b"?".
        """

        def to_piece(token_id):
            token = self.id_to_token.get(token_id)
            if token is None:
                return b"?"  # placeholder for an unknown id
            # Special tokens are `str`; they contribute nothing to raw text.
            return token if isinstance(token, bytes) else b""

        return b"".join(to_piece(token_id) for token_id in ids)


# --- RoPE Embeddings ---
def get_rotary_embeddings(seq_len, dim_model, theta=10000.0):
    """Precompute rotary position embedding (RoPE) cos/sin tables.

    Args:
        seq_len: Number of positions to precompute.
        dim_model: Rotary dimension; must be even.
        theta: Base of the geometric frequency progression.

    Returns:
        ``(cos_emb, sin_emb)``, each a float tensor of shape
        ``(seq_len, dim_model // 2)`` whose entry ``[p, k]`` is the cos/sin
        of ``p * theta ** (-2k / dim_model)``.

    Raises:
        ValueError: if ``dim_model`` is odd.
    """
    if dim_model % 2:
        raise ValueError(f"dim_model must be even, got {dim_model}")
    scale = -(math.log(theta) / dim_model)
    inv_freq = torch.exp(torch.arange(0, dim_model, 2).float() * scale)
    positions = torch.arange(0, seq_len, dtype=torch.float)[:, None]
    angles = positions * inv_freq
    return torch.cos(angles), torch.sin(angles)
# --- Model Definition ---
class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> GELU -> Dropout -> Linear."""

    def __init__(self, embed_dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x`` (size ``embed_dim``)."""
        return self.fc2(self.dropout(self.act(self.fc1(x))))


class DiffTransformerBlock(nn.Module):
    """Pre-norm transformer block built on MultiheadDiffAttn.

    Computes ``x = x + drop(attn(norm1(x)))`` followed by
    ``x = x + drop(ffn(norm2(x)))``.
    """

    def __init__(self, embed_dim, num_heads, depth, ffn_hidden_dim, dropout=0.1):
        super().__init__()
        # `depth` is this block's layer index (see DiffTransformerLLM); it is
        # forwarded to MultiheadDiffAttn, whose use of it is defined there.
        self.attn = MultiheadDiffAttn(embed_dim, depth, num_heads, dropout=dropout)
        self.ffn = FeedForward(embed_dim, ffn_hidden_dim, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, rel_pos, attn_mask=None):
        """Run attention then feed-forward, each as a pre-norm residual.

        Args:
            x: Activations whose last dimension is ``embed_dim``.
            rel_pos: ``(cos, sin)`` RoPE tables, passed straight to attention.
            attn_mask: Mask passed straight to attention.
        """
        attn_out = self.attn(self.norm1(x), rel_pos, attn_mask)
        x = x + self.dropout(attn_out)
        ffn_out = self.ffn(self.norm2(x))
        x = x + self.dropout(ffn_out)
        return x


class DiffTransformerLLM(nn.Module):
    """Decoder-only language model made of DiffTransformerBlock layers.

    The token embedding table is tied to the output projection, positions
    are encoded with RoPE (no learned position embeddings), and a causal
    attention mask is always applied.
    """

    def __init__(
        self,
        vocab_size,
        embed_dim,
        num_layers,
        num_heads,
        ffn_hidden_dim,
        max_seq_len,
        dropout=0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len

        self.token_embeddings = nn.Embedding(vocab_size, embed_dim)
        # Positions are handled by RoPE, so there is no position embedding.
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList(
            [
                DiffTransformerBlock(
                    embed_dim, num_heads, depth, ffn_hidden_dim, dropout
                )
                for depth in range(num_layers)
            ]
        )
        self.norm_out = nn.LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)

        # Tie weights: both tensors are (vocab_size, embed_dim), so the
        # embedding table and the output projection share one Parameter.
        self.token_embeddings.weight = self.lm_head.weight

        # RoPE precomputation. Per the original note, MultiheadDiffAttn's
        # head_dim is embed_dim // num_heads // 2.
        self.rope_head_dim = embed_dim // num_heads // 2
        cos_emb, sin_emb = get_rotary_embeddings(max_seq_len, self.rope_head_dim)
        # Non-persistent: moved with the module but not saved in state_dict.
        self.register_buffer("cos_emb", cos_emb, persistent=False)
        self.register_buffer("sin_emb", sin_emb, persistent=False)

    @staticmethod
    def _causal_mask(seq_len, device):
        """Return a (seq_len, seq_len) additive causal mask.

        Entries above the diagonal are -inf (future positions masked); the
        diagonal and everything below it are 0.
        """
        return torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=device),
            diagonal=1,
        )

    def forward(self, input_ids, attn_mask=None):
        """Compute next-token logits.

        Args:
            input_ids: Integer token ids of shape ``(batch, seq_len)``.
            attn_mask: Currently not applied; a warning is emitted when it is
                given. A causal mask is always used, and padding is expected
                to be handled by the loss (e.g. ``ignore_index``).

        Returns:
            Logits whose last dimension is ``vocab_size``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_seq_len``; the RoPE
                tables only cover that many positions.
        """
        _, seq_len = input_ids.shape
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len "
                f"{self.max_seq_len}"
            )
        if attn_mask is not None:
            import warnings

            # Previously the mask was silently dropped; make that visible.
            warnings.warn(
                "DiffTransformerLLM.forward ignores attn_mask; "
                "a causal mask is always applied."
            )

        x = self.token_embeddings(input_ids) * math.sqrt(self.embed_dim)
        x = self.dropout(x)

        # RoPE tables must match the activations' device and dtype.
        rel_pos = (
            self.cos_emb[:seq_len, :].to(x.device, dtype=x.dtype),
            self.sin_emb[:seq_len, :].to(x.device, dtype=x.dtype),
        )
        # Additive mask: 0 = attend, -inf = masked (the convention the
        # original code notes MultiheadDiffAttn expects).
        causal_mask = self._causal_mask(seq_len, x.device)

        for layer in self.layers:
            x = layer(x, rel_pos, attn_mask=causal_mask)

        x = self.norm_out(x)
        return self.lm_head(x)

    def count_parameters(self):
        """Return the number of trainable parameters.

        ``nn.Module.parameters()`` yields a shared Parameter only once, so the
        tied embedding / lm_head weight is counted a single time.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad)