from __future__ import annotations

from typing import Optional, Tuple

import torch
from torch import nn

from ..config import DiaConfig
from .cache import KVCache
from .precision import Precision
from .layers import (
    Attention,
    Mlp,
    MultiStreamEmbedding,
)


class TransformerDecoder(nn.Module):
    """Inference-time port of dia_v2.model.Transformer."""

    def __init__(self, config: DiaConfig, precision: Precision):
        super().__init__()
        self.config = config
        self.precision = precision
        data_cfg = config.data
        dec_cfg = config.model.decoder
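        # Token channel layout (inferred from the indexing below): channels 0 and 1
        # feed the two-stream text embedding; channels 2 and up each get their own
        # audio-codebook embedding.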
        self.audio_embeds = nn.ModuleList(
            [
                nn.Embedding(data_cfg.audio_vocab_size, dec_cfg.n_embd)
                for _ in range(max(0, data_cfg.channels - 2))
            ]
        )
        self.text_embed = MultiStreamEmbedding(
            data_cfg.text_vocab_size,
            dec_cfg.n_embd,
            pad_id=data_cfg.text_pad_token_id,
            output_dtype=self.precision.compute,
            low_rank_dim=dec_cfg.low_rank_dim,
        )
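        # Decoder stack, a float32 final norm, and two logit heads: one over the
        # action vocabulary and one over audio codebook 0.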
        self.layers = nn.ModuleList(
            [DecoderLayer(config, precision) for _ in range(dec_cfg.n_layer)]
        )
        self.norm = nn.RMSNorm(
            dec_cfg.n_embd,
            eps=config.model.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.action_head = nn.Linear(dec_cfg.n_embd, data_cfg.action_vocab_size, bias=False)
        self.cb0_head = nn.Linear(dec_cfg.n_embd, data_cfg.audio_vocab_size, bias=False)

    def init_cache(self, batch_size: int, device: torch.device, max_steps: int) -> KVCache:
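        """Allocate an empty KV cache sized for `max_steps` single-token decode steps."""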
        heads = self.layers[0].attn.num_kv_heads
        head_dim = self.layers[0].attn.head_dim
        return KVCache.allocate(
            num_layers=len(self.layers),
            batch_size=batch_size,
            heads=heads,
            max_steps=max_steps,
            head_dim=head_dim,
            device=device,
            dtype=self.precision.compute,
        )

    def forward_step(
        self,
        tokens: torch.Tensor,
        positions: torch.Tensor,
        cache: Optional[KVCache],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, KVCache]:
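        """Decode a single step.

        `tokens` is a (B, C, 1) integer tensor holding one token per channel.
        Returns the normed hidden state, action logits, codebook-0 logits, and
        the same cache object.
        """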
        if cache is None:
            raise ValueError("Transformer cache must be initialized")
        if tokens.shape[-1] != 1:
            raise ValueError("forward_step expects sequence length 1")
        x = self._embed(tokens)
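        # One cached attention step per layer; each layer reads and writes its
        # own slot in the shared KV cache.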
        for idx, layer in enumerate(self.layers):
            slot = cache.get_slot(idx)
            x, _ = layer.decode_step(x, positions, slot)
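        # The final norm and both logit heads run in float32, then cast down to
        # the configured logits dtype.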
        hidden_norm = self.norm(x)
        hidden_f32 = hidden_norm.to(torch.float32)
        action_logits = self.action_head(hidden_f32).to(self.precision.logits)
        cb0_logits = self.cb0_head(hidden_f32).to(self.precision.logits)
        return hidden_norm, action_logits, cb0_logits, cache

    def _embed(self, tokens: torch.Tensor) -> torch.Tensor:
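        """Sum the text-stream and audio-codebook embeddings for a (B, C, 1) slice."""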
        _, C, T1 = tokens.shape
        if T1 != 1:
            raise ValueError("_embed expects sequence length 1")
        num_audio_channels = max(0, C - 2)
        hidden = self.text_embed(tokens[:, 0, :], tokens[:, 1, :])
        for idx in range(num_audio_channels):
            hidden = hidden + self.audio_embeds[idx](tokens[:, idx + 2, :])
        return hidden.to(self.precision.compute)


class DecoderLayer(nn.Module):
    def __init__(self, config: DiaConfig, precision: Precision):
        super().__init__()
        dec = config.model.decoder
        eps = config.model.normalization_layer_epsilon
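        # Pre-norm transformer block; both RMSNorms stay in float32 even when
        # attention and the MLP run in the lower compute precision.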
        self.pre_norm = nn.RMSNorm(dec.n_embd, eps=eps, dtype=torch.float32)
        self.attn = Attention(config, dec.n_embd, precision.compute)
        self.post_norm = nn.RMSNorm(dec.n_embd, eps=eps, dtype=torch.float32)
        self.mlp = Mlp(
            dec.n_embd,
            dec.n_hidden,
            precision.compute,
            tuple(config.model.linear.mlp_activations),
        )

    def decode_step(
        self,
        x: torch.Tensor,
        pos: torch.Tensor,
        cache_slot,
    ) -> Tuple[torch.Tensor, object]:
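        """One pre-norm attention + MLP step; `cache_slot` is handed to attention
        and returned alongside the output."""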
        residual = x
        attn_out, _ = self.attn(self.pre_norm(x), pos, cache_slot)
        x = residual + attn_out
        mlp_out = self.mlp(self.post_norm(x))
        return x + mlp_out, cache_slot
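

# Minimal decode-loop sketch. The names `config`, `precision`, `device`,
# `first_step_tokens`, and `sample_next_tokens` are hypothetical stand-ins, not
# part of this module:
#
#     decoder = TransformerDecoder(config, precision).to(device).eval()
#     cache = decoder.init_cache(batch_size=1, device=device, max_steps=256)
#     positions = torch.zeros(1, 1, dtype=torch.long, device=device)
#     tokens = first_step_tokens  # (B, C, 1) integer tensor
#     with torch.inference_mode():
#         for _ in range(256):
#             _, action_logits, cb0_logits, cache = decoder.forward_step(
#                 tokens, positions, cache
#             )
#             tokens = sample_next_tokens(action_logits, cb0_logits)
#             positions = positions + 1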