from __future__ import annotations

from typing import Optional, Tuple

import torch
from torch import nn
import torch.nn.functional as F

from ..config import DiaConfig
from .cache import KVCache
from .layers import MultiStreamEmbedding, Mlp, RotaryEmbedding
from .precision import Precision


class ScheduleAttention(nn.Module):
    """Depformer attention that mirrors dia_v2 ScheduleAttention."""

    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype) -> None:
        super().__init__()
        dep_cfg = config.model.depformer
        runtime = config.runtime
        # Maps each codebook stage to the id of the projection weights it uses,
        # so several stages can share one set of in/out projections.
        self.schedule = runtime.weights_schedule
        self.num_query_heads = dep_cfg.gqa_query_heads
        self.num_kv_heads = dep_cfg.kv_heads
        self.head_dim = dep_cfg.gqa_head_dim
        self.num_gqa_groups = self.num_query_heads // max(self.num_kv_heads, 1)
        self.apply_rope = dep_cfg.apply_rope
        # Distinct weight ids actually referenced by the schedule.
        self.used_ids = sorted(set(self.schedule))
        self.compute_dtype = compute_dtype
        # Fused QKV projection, one per weight id.
        self.in_proj = nn.ModuleDict(
            {
                str(i): nn.Linear(
                    dep_cfg.n_embd,
                    3 * self.num_query_heads * self.head_dim,
                    bias=False,
                )
                for i in self.used_ids
            }
        )
        self.out_proj = nn.ModuleDict(
            {
                str(i): nn.Linear(
                    self.num_query_heads * self.head_dim,
                    dep_cfg.n_embd,
                    bias=False,
                )
                for i in self.used_ids
            }
        )
        eps = config.model.normalization_layer_epsilon
        self.q_norm = nn.RMSNorm(self.head_dim, eps=eps, dtype=torch.float32)
        self.k_norm = nn.RMSNorm(self.head_dim, eps=eps, dtype=torch.float32)
        if self.apply_rope:
            self.rotary = RotaryEmbedding(
                self.head_dim,
                config.model.rope_min_timescale,
                config.model.rope_max_timescale,
            )
            # One RoPE position per stage: the stages form a short sequence
            # over the codebooks.
            stage_count = max(len(self.schedule), 1)
            self.register_buffer(
                "stage_positions",
                torch.arange(stage_count, dtype=torch.long).view(stage_count, 1),
                persistent=False,
            )
        else:
            self.rotary = None
            self.register_buffer(
                "stage_positions",
                torch.zeros(0, 1, dtype=torch.long),
                persistent=False,
            )

    def forward_incremental(
        self,
        x_t: torch.Tensor,
        stage_index: int,
        cache_slot,
    ) -> Tuple[torch.Tensor, object]:
        """Single-token attention for one depformer stage."""
        bsz, seq, _ = x_t.shape
        if seq != 1:
            raise ValueError("ScheduleAttention expects seq len 1 during decoding")
        orig_dtype = x_t.dtype
        module_index = self.schedule[stage_index]
        # Projections run in float32; attention runs in the compute dtype.
        proj = self.in_proj[str(module_index)](x_t.to(torch.float32))
        proj = proj.view(bsz, seq, 3, self.num_query_heads, self.head_dim).to(self.compute_dtype)
        q_proj = self.q_norm(proj[:, :, 0])
        k_proj = self.k_norm(proj[:, :, 1])
        v_proj = proj[:, :, 2]
        if self.apply_rope:
            pos_ids = self.stage_positions[stage_index : stage_index + 1]
            if pos_ids.device != x_t.device:
                pos_ids = pos_ids.to(x_t.device)
            q_proj = self.rotary(q_proj, pos_ids)
            k_proj = self.rotary(k_proj, pos_ids)
        q = q_proj.transpose(1, 2)
        k = k_proj.transpose(1, 2)
        v = v_proj.transpose(1, 2)
        if cache_slot is not None:
            # Append this stage's K/V and attend over all cached stages.
            k, v, attn_mask = cache_slot.write_and_view(k, v)
        else:
            attn_mask = None
        attn = F.scaled_dot_product_attention(
            q,
            k,
            v,
            # scale=1.0 disables SDPA's default 1/sqrt(head_dim) scaling.
            scale=1.0,
            attn_mask=attn_mask,
            enable_gqa=self.num_gqa_groups > 1,
        )
        attn = attn.transpose(1, 2).contiguous()
        flat = attn.reshape(bsz, seq, self.num_query_heads * self.head_dim)
        out = self.out_proj[str(module_index)](flat.to(torch.float32))
        return out.to(orig_dtype), cache_slot
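

# The cache slots consumed above come from KVCache in .cache. The class below
# is NOT that implementation; it is a minimal, hypothetical sketch of the
# write_and_view contract that forward_incremental relies on: append the
# current stage's K/V at the write cursor and return the valid prefix plus an
# attention mask (None here, since every cached stage may be attended to).
class _CacheSlotSketch:
    def __init__(
        self,
        batch: int,
        heads: int,
        max_steps: int,
        head_dim: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        # Preallocated K/V buffers: (batch, heads, max_steps, head_dim).
        self.k = torch.zeros(batch, heads, max_steps, head_dim, device=device, dtype=dtype)
        self.v = torch.zeros_like(self.k)
        self.length = 0

    def write_and_view(
        self, k: torch.Tensor, v: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, None]:
        # k/v arrive as (batch, heads, 1, head_dim) during single-token decode.
        self.k[:, :, self.length : self.length + 1] = k
        self.v[:, :, self.length : self.length + 1] = v
        self.length += 1
        # Return the valid prefix; all cached positions are visible, so no mask.
        return self.k[:, :, : self.length], self.v[:, :, : self.length], None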


class DepformerLayer(nn.Module):
    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
        super().__init__()
        dep_cfg = config.model.depformer
        eps = config.model.normalization_layer_epsilon
        self.pre_norm = nn.RMSNorm(dep_cfg.n_embd, eps=eps, dtype=torch.float32)
        self.post_norm = nn.RMSNorm(dep_cfg.n_embd, eps=eps, dtype=torch.float32)
        self.self_attention = ScheduleAttention(config, compute_dtype)
        self.mlp = Mlp(
            dep_cfg.n_embd,
            dep_cfg.n_hidden,
            compute_dtype,
            tuple(config.model.depformer.mlp_activations),
        )

    def decode_step(
        self,
        x_t: torch.Tensor,
        stage_index: int,
        cache_slot,
    ) -> Tuple[torch.Tensor, object]:
        residual = x_t
        x_norm = self.pre_norm(x_t)
        sa_out, _ = self.self_attention.forward_incremental(x_norm, stage_index, cache_slot)
        x = residual + sa_out
        residual2 = x
        x_norm2 = self.post_norm(x)
        mlp_out = self.mlp(x_norm2)
        return residual2 + mlp_out, cache_slot
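

# Weight-sharing illustration (the schedule values here are hypothetical, not
# taken from any real config): with runtime.weights_schedule = [0, 0, 1, 1, 2],
# five stages are served by three sets of projections. used_ids becomes
# sorted(set(schedule)) == [0, 1, 2], and stage s looks up its weights via
# str(schedule[s]) in the ModuleDicts keyed by weight id.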


class Depformer(nn.Module):
    def __init__(self, config: DiaConfig, precision: Precision):
        super().__init__()
        self.config = config
        self.precision = precision
        dep_cfg = config.model.depformer
        data_cfg = config.data
        runtime = config.runtime
        # Data channels beyond the two text streams carry audio codebooks; the
        # depformer runs one stage per audio channel after the first.
        self.num_audio_channels = max(0, data_cfg.channels - 2)
        self.num_depth = max(self.num_audio_channels - 1, 0)
        self.weights_schedule = runtime.weights_schedule
        # One embedding table and one logits head per depth stage.
        self.audio_embeds = nn.ModuleList(
            [nn.Embedding(data_cfg.audio_vocab_size, dep_cfg.n_embd) for _ in range(self.num_depth)]
        )
        if dep_cfg.text_embedding:
            self.text_embed = MultiStreamEmbedding(
                data_cfg.text_vocab_size,
                dep_cfg.n_embd,
                pad_id=data_cfg.text_pad_token_id,
                output_dtype=precision.compute,
            )
        else:
            self.text_embed = None
        used_ids = sorted(set(self.weights_schedule))
        # Projections from the main decoder, shared according to the schedule.
        self.depformer_in = nn.ModuleDict(
            {
                str(i): nn.Linear(
                    config.model.decoder.n_embd,
                    dep_cfg.n_embd,
                    bias=False,
                )
                for i in used_ids
            }
        )
        self.layers = nn.ModuleList(
            [DepformerLayer(config, precision.compute) for _ in range(dep_cfg.n_layer)]
        )
        self.norm = nn.RMSNorm(dep_cfg.n_embd, eps=config.model.normalization_layer_epsilon)
        self.logits_dtype = precision.logits
        self.logits = nn.ModuleList(
            [
                nn.Linear(dep_cfg.n_embd, data_cfg.audio_vocab_size, bias=False)
                for _ in range(self.num_depth)
            ]
        )
        # Logits are sliced below the first special token id (pad/bos).
        self.audio_vocab_limit = min(data_cfg.audio_pad_token_id, data_cfg.audio_bos_token_id)

    def init_cache(self, batch_size: int, device: torch.device, max_steps: int) -> KVCache:
        """Allocate a per-layer KV cache holding up to max_steps stage positions."""
        heads = self.layers[0].self_attention.num_kv_heads
        head_dim = self.layers[0].self_attention.head_dim
        return KVCache.allocate(
            num_layers=len(self.layers),
            batch_size=batch_size,
            heads=heads,
            max_steps=max_steps,
            head_dim=head_dim,
            device=device,
            dtype=self.precision.compute,
        )

    def forward_step(
        self,
        prev_audio: torch.Tensor,
        transformer_out: torch.Tensor,
        stage_index: int,
        cache: KVCache,
        main_text: Optional[torch.Tensor],
        second_text: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, KVCache]:
        self._validate_inputs(stage_index, cache)
        return self._forward_stage(stage_index, prev_audio, transformer_out, cache, main_text, second_text)

    def _forward_stage(
        self,
        stage_index: int,
        prev_audio: torch.Tensor,
        transformer_out: torch.Tensor,
        cache: KVCache,
        main_text: Optional[torch.Tensor],
        second_text: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, KVCache]:
        prev_audio = prev_audio.long()
        weight_idx = self.weights_schedule[stage_index]
        token_emb = self.audio_embeds[stage_index](prev_audio[:, None]).to(self.precision.compute)
        # Stage 0 additionally conditions on the two text streams.
        if stage_index == 0 and self.text_embed is not None:
            if main_text is None or second_text is None:
                raise ValueError("stage 0 requires text tokens")
            token_emb = token_emb + self.text_embed(main_text[:, None], second_text[:, None])
        dep_in = self.depformer_in[str(weight_idx)](transformer_out.to(torch.float32))
        dep_in = dep_in.to(self.precision.compute)
        dep_in = dep_in + token_emb.to(dep_in.dtype)
        x = dep_in
        for idx, layer in enumerate(self.layers):
            slot = cache.get_slot(idx)
            x, _ = layer.decode_step(x, stage_index, slot)
        hidden = self.norm(x)
        logits = self.logits[stage_index](hidden.to(torch.float32))
        logits = logits.to(self.logits_dtype)
        logits = logits.unsqueeze(1)  # (batch, 1, 1, vocab)
        logits = logits[..., : self.audio_vocab_limit]
        return logits, cache

    def _validate_inputs(self, stage_index: int, cache: KVCache | None) -> None:
        if stage_index < 0 or stage_index >= self.num_depth:
            raise ValueError(f"stage_index {stage_index} out of range (depth={self.num_depth})")
        if cache is None:
            raise ValueError("depformer cache must be initialized")
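

# Illustrative driver, not part of this module's API: a sketch of one full
# depformer pass, calling forward_step once per stage with greedy argmax
# sampling. The function name, the bos_token argument, and greedy sampling are
# assumptions made for the example; real inference may sample differently.
def _greedy_decode_sketch(
    model: Depformer,
    transformer_out: torch.Tensor,  # (batch, 1, decoder n_embd) for the current step
    bos_token: int,
    main_text: torch.Tensor,  # (batch,) text token ids, consumed at stage 0 only
    second_text: torch.Tensor,  # (batch,)
) -> torch.Tensor:
    batch = transformer_out.shape[0]
    device = transformer_out.device
    # One cache position per stage: later stages attend over earlier ones.
    cache = model.init_cache(batch, device, max_steps=model.num_depth)
    prev = torch.full((batch,), bos_token, dtype=torch.long, device=device)
    tokens = []
    for stage in range(model.num_depth):
        logits, cache = model.forward_step(
            prev, transformer_out, stage, cache, main_text, second_text
        )
        # logits: (batch, 1, 1, audio_vocab_limit); pick the top token per item.
        prev = logits.argmax(dim=-1).reshape(batch)
        tokens.append(prev)
    return torch.stack(tokens, dim=-1)  # (batch, num_depth)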