|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional, Union |
|
|
| from transformers.modeling_outputs import CausalLMOutputWithPast |
| from transformers.generation.utils import GenerationMixin |
|
|
| from .shared_space_config import SharedSpaceDecoderConfig |
| from .shared_space_decoder import ( |
| SharedSpaceDecoderPreTrainedModel, |
| SharedSpaceDecoderModel, |
| DeepseekV3RMSNorm |
| ) |
|
|
def create_norm_layer(hidden_size: int, config: SharedSpaceDecoderConfig) -> nn.Module:
    """
    Build the normalization module selected by ``config.norm_type``.

    Args:
        hidden_size: Size of the dimension the layer normalizes over.
        config: Configuration object; ``norm_type`` picks the layer and
            ``layer_norm_eps`` / ``rms_norm_eps`` supply the matching epsilon.

    Returns:
        An ``nn.LayerNorm`` for ``"layernorm"`` or a ``DeepseekV3RMSNorm``
        for ``"rmsnorm"``.

    Raises:
        ValueError: If ``config.norm_type`` is neither of the two names above.
    """
    if config.norm_type == "layernorm":
        return nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)

    if config.norm_type == "rmsnorm":
        # DeepseekV3RMSNorm comes from the module-level import of
        # .shared_space_decoder, so no function-scope import is needed.
        return DeepseekV3RMSNorm(hidden_size, eps=config.rms_norm_eps)

    # Anything else is a configuration error; fail loudly instead of guessing.
    raise ValueError(f"Unknown norm_type: {config.norm_type}")
|
|
|
|
class SharedSpaceDecoderForCausalLM(GenerationMixin, SharedSpaceDecoderPreTrainedModel):
    """
    Subspace Decoder model with a causal language modeling head.

    This model extends the SharedSpaceDecoderModel with:
    - A final normalization layer (LayerNorm or RMSNorm, chosen by the config)
    - A language modeling head that projects hidden states to vocabulary logits
    - Support for computing cross-entropy loss for language modeling
    - HuggingFace-compatible embedding accessors and generation hooks
    - Decoder-specific initialization of the language modeling head

    The model can be used for:
    - Text generation
    - Language modeling pretraining
    - Fine-tuning on downstream tasks

    Note:
        No key/value cache is implemented: ``forward`` always returns
        ``past_key_values=None`` and generation re-feeds the whole sequence
        at every step.
    """

    def __init__(self, config: SharedSpaceDecoderConfig) -> None:
        """
        Build the backbone, the final norm and the LM head.

        Args:
            config: Model configuration. ``hidden_size`` and ``vocab_size``
                size the LM head; ``norm_type`` selects the final norm.
        """
        super().__init__(config)

        # Decoder backbone. Its forward output is consumed directly as the
        # hidden-state tensor in ``forward``.
        self.model = SharedSpaceDecoderModel(config)

        # Final normalization applied to the backbone output before the head.
        self.norm = create_norm_layer(config.hidden_size, config)

        # Bias-free projection from hidden states to vocabulary logits.
        self.lm_head = nn.Linear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
        )

        # HuggingFace finalization hook; must run after all submodules exist
        # because ``_init_weights`` and ``tie_weights`` reference them.
        self.post_init()

    def _init_weights(self, module: nn.Module) -> None:
        """
        Initialize ``module``, with special handling for the LM head.

        Every module first goes through the parent class initialization.
        The LM head is then re-drawn from a zero-mean normal distribution:

        - std = ``config.initializer_range`` when the backbone has no
          ``vocab_proj`` (the head is tied to the input embeddings in
          ``tie_weights`` in that case);
        - std = ``config.initializer_range * 0.5`` when ``vocab_proj`` is
          present (the head is left untied in that case).

        Args:
            module: The submodule being initialized.
        """
        super()._init_weights(module)

        if module is not self.lm_head:
            return

        std = self.config.initializer_range
        # Use the same tolerant lookup as ``tie_weights`` so both methods agree
        # on what "has a vocab projection" means, even for a backbone that does
        # not define the attribute at all.
        if getattr(self.model, "vocab_proj", None) is not None:
            std = self.config.initializer_range * 0.5

        # Single draw with the final std (no throw-away initialization first).
        module.weight.data.normal_(mean=0.0, std=std)

    def get_input_embeddings(self):
        """Return the input embedding layer for compatibility with HuggingFace."""
        return self.model.vocab_embed

    def set_input_embeddings(self, value):
        """Set the input embedding layer for compatibility with HuggingFace."""
        self.model.vocab_embed = value

    def get_output_embeddings(self):
        """Return the output embedding layer (lm_head) for compatibility."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Set the output embedding layer for compatibility."""
        self.lm_head = new_embeddings

    def tie_weights(self):
        """
        Tie the input and output embedding weights when that is possible.

        If the backbone has no ``vocab_proj`` (attribute missing or ``None``),
        the LM head weight is tied to (or cloned from) ``model.vocab_embed``.
        If a ``vocab_proj`` exists, nothing is tied and the LM head keeps its
        own weights.
        """
        # NOTE(review): unlike the stock HuggingFace implementation, this does
        # not consult ``config.tie_word_embeddings`` - confirm that is intended.
        if getattr(self.model, "vocab_proj", None) is None:
            self._tie_or_clone_weights(self.lm_head, self.model.vocab_embed)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[CausalLMOutputWithPast, tuple]:
        """
        Forward pass for causal language modeling.

        Args:
            input_ids: Token ids of shape [batch_size, seq_len].
            attention_mask: Attention mask of shape [batch_size, seq_len]
                (1 for real tokens, 0 for padding). Defaults to all ones.
            labels: Target token ids, same shape as ``input_ids``. If given,
                the loss is computed via ``self.loss_function``. Assuming that
                resolves to the stock HuggingFace causal-LM loss, the
                one-position shift is done inside the loss, so pass unshifted
                targets (typically ``input_ids`` itself) - do not pre-shift.
            **kwargs: Forwarded unchanged to the backbone and, when ``labels``
                is given, to ``self.loss_function``. ``output_hidden_states``
                is additionally read here (see Returns).

        Returns:
            CausalLMOutputWithPast containing:
                - loss: Loss if ``labels`` was provided, else None
                - logits: ``lm_head`` output computed from the normalized
                  hidden states
                - hidden_states: The normalized backbone output as a single
                  tensor (not a per-layer tuple), only when
                  ``output_hidden_states`` is passed truthy in kwargs;
                  otherwise None
                - past_key_values / attentions: always None
        """
        # Default to "attend to everything" when the caller gives no mask.
        if attention_mask is None and input_ids is not None:
            attention_mask = torch.ones(
                (input_ids.size(0), input_ids.size(1)),
                dtype=torch.long,
                device=input_ids.device,
            )

        # Backbone -> final norm -> vocabulary projection.
        hidden_states = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits,
                labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=hidden_states if kwargs.get("output_hidden_states", False) else None,
            attentions=None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        """
        Build the ``forward`` inputs for one generation step.

        Only ``input_ids`` and ``attention_mask`` are passed through;
        ``past_key_values`` and any extra kwargs are ignored, so the full
        sequence is fed to the model at every step.
        """
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def _reorder_cache(self, past_key_values, beam_idx):
        """Return ``past_key_values`` unchanged; ``beam_idx`` is not used."""
        return past_key_values
|
|
|
|