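"""Generation loop for the speech model runtime.

Runs the transformer one step at a time and the depformer one codebook stage
at a time, with optional CUDA graph capture for steady-state decoding,
classifier-free guidance across a two-branch batch, and voice-clone prefix
warmup.
"""
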
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Tuple

import torch

from ..core.cache import KVCache
from ..core.model import DecodeState
from ..generation import GenerationConfig
from ..audio.grid import delay_frames, mask_audio_logits, undelay_frames
from .context import RuntimeContext
from .state_machine import State, TokenIds
from .guidance import apply_classifier_guidance, sample_audio_logits
from .sampler import sample_token
from .voice_clone import PrefixPlan
from .logger import RuntimeLogger


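# cuBLAS allocates its workspace lazily on the first matmul. Running one tiny
# eager matmul (and synchronizing) before any CUDA graph capture keeps that
# allocation out of the captured region, where it could not be recorded.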
_GRAPH_CUBLAS_READY = False


def _ensure_graph_cublas_ready(device: torch.device) -> None:
    global _GRAPH_CUBLAS_READY
    if _GRAPH_CUBLAS_READY or device.type != "cuda":
        return
    tmp = torch.empty((1, 1), device=device, dtype=torch.float32)
    torch.matmul(tmp, tmp)
    torch.cuda.synchronize()
    _GRAPH_CUBLAS_READY = True


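# Per-request decoding state: the model's KV caches plus the rolling token
# buffers. Row 0 of each buffer is the conditional branch; row 1, when
# present, is the unconditional branch used for classifier-free guidance.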
@dataclass
class GenerationState:
    decode: DecodeState
    step_tokens: torch.Tensor
    audio_buf: torch.Tensor

    def trim_audio(self, limit: int, pad_token: int, ungenerated: int) -> torch.Tensor:
        # Cut the buffer at `limit` and replace never-generated slots with padding.
        trimmed = self.audio_buf[:, :, :limit]
        pad = torch.full_like(trimmed, pad_token)
        trimmed = torch.where(trimmed == ungenerated, pad, trimmed)
        self.audio_buf = trimmed
        return trimmed

    @property
    def transformer_cache(self) -> KVCache:
        return self.decode.transformer

    @transformer_cache.setter
    def transformer_cache(self, cache: KVCache) -> None:
        self.decode.transformer = cache

    @property
    def depformer_cache(self) -> KVCache:
        return self.decode.depformer

    @depformer_cache.setter
    def depformer_cache(self, cache: KVCache) -> None:
        self.decode.depformer = cache

    def reset_dep_cache(self) -> None:
        self.decode.depformer.reset()


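# Preallocated logits buffers. CUDA graph replay requires fixed storage, so
# each step copies the network outputs into these tensors instead of
# allocating new ones.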
@dataclass
class NetworkBuffers:
    text: torch.Tensor
    cb0: torch.Tensor
    dep: list[torch.Tensor]


def _allocate_network_buffers(runtime: RuntimeContext, branches: int) -> NetworkBuffers:
    device = runtime.device
    logits_dtype = runtime.precision.logits
    data_cfg = runtime.config.data
    text_logits = torch.empty((branches, 1, data_cfg.action_vocab_size), dtype=logits_dtype, device=device)
    cb0_logits = torch.empty((branches, 1, data_cfg.audio_vocab_size), dtype=logits_dtype, device=device)
    dep_vocab = runtime.model.depformer.audio_vocab_limit or data_cfg.audio_vocab_size
    dep_logits = [
        torch.empty((branches, 1, 1, dep_vocab), dtype=logits_dtype, device=device)
        for _ in range(runtime.model.depformer.num_depth)
    ]
    return NetworkBuffers(text=text_logits, cb0=cb0_logits, dep=dep_logits)


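# Step tokens are laid out as (branch, channel, 1): channel 0 is the main text
# stream, channel 1 the auxiliary text stream, and channels 2.. carry one
# audio codebook each. Branch 0 starts from BOS (conditional); branch 1 starts
# from the zero token (unconditional, for guidance).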
def build_initial_state(
    runtime: RuntimeContext,
    *,
    prefix: PrefixPlan | None = None,
) -> GenerationState:
    dep_q = runtime.model.depformer.num_audio_channels
    channels = 2 + dep_q
    branches = 2
    token_ids = runtime.constants
    step_tokens = torch.full(
        (branches, channels, 1),
        token_ids.pad,
        dtype=torch.long,
        device=runtime.device,
    )
    step_tokens[0, 0, 0] = token_ids.bos
    step_tokens[0, 1, 0] = token_ids.pad
    step_tokens[1, 0, 0] = token_ids.zero
    step_tokens[1, 1, 0] = token_ids.pad
    delayed: torch.Tensor | None = None
    prefix_len = 0
    if prefix is not None:
        delayed = delay_frames(prefix.aligned_tokens, runtime.audio_delays, token_ids.audio_pad).to(runtime.device)
        prefix_len = delayed.shape[1]
    limit = runtime.config.runtime.max_context_steps
    total_steps = max(limit + prefix_len + 1, limit)
    decode_state = runtime.model.init_state(branches, runtime.device, total_steps)
    audio_buf = torch.full(
        (branches, dep_q, total_steps),
        token_ids.ungenerated,
        dtype=torch.long,
        device=runtime.device,
    )
    if delayed is not None:
        # Seed every branch with the delayed prefix audio.
        audio_buf[0, :, : delayed.shape[1]] = delayed
        if branches > 1:
            audio_buf[1:, :, : delayed.shape[1]] = delayed
    return GenerationState(decode_state, step_tokens, audio_buf)


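# Audio inputs for step `step` come straight from the delayed buffer at that
# index; channels whose delay has not yet elapsed receive the audio BOS token
# instead.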
def _fill_audio_channels(
    step_tokens: torch.Tensor,
    audio_buf: torch.Tensor,
    delays: torch.Tensor,
    step: int,
    bos_token: int,
) -> None:
    channels = delays.numel()
    if channels == 0:
        return
    target = step_tokens[:, 2 : 2 + channels, 0]
    if step < audio_buf.shape[-1]:
        target.copy_(audio_buf[:, :channels, step])
    else:
        target.fill_(bos_token)
    mask = delays > step
    if mask.any().item():
        target[:, mask] = bos_token


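# The transformer step writes its logits into the preallocated buffers, so
# the whole call is safe to record in a CUDA graph and replay with fresh
# inputs.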
def _execute_transformer_step(
    step_tokens: torch.Tensor,
    positions_view: torch.Tensor,
    generation: GenerationState,
    transformer_step,
    buffers: NetworkBuffers,
) -> torch.Tensor:
    hidden_t, text_logits_t, cb0_logits_t, present = transformer_step(
        step_tokens,
        positions_view,
        generation.transformer_cache,
    )
    buffers.text.copy_(text_logits_t)
    buffers.cb0.copy_(cb0_logits_t)
    generation.transformer_cache = present
    return hidden_t


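# One depformer stage predicts one audio codebook, conditioned on the
# transformer output and the previous stage's token; the text tokens are only
# injected at stage 0.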
def _execute_depformer_stage(
    stage_index: int,
    prev_audio: torch.Tensor,
    hidden_t: torch.Tensor,
    generation: GenerationState,
    depformer_step,
    main_tokens: Optional[torch.Tensor],
    second_tokens: Optional[torch.Tensor],
    buffers: NetworkBuffers,
) -> None:
    logits_stage, dep_present = depformer_step(
        prev_audio=prev_audio,
        transformer_out=hidden_t,
        stage_index=stage_index,
        cache=generation.depformer_cache,
        main_text=main_tokens if stage_index == 0 else None,
        second_text=second_tokens if stage_index == 0 else None,
    )
    target = buffers.dep[stage_index]
    if logits_stage.shape != target.shape:
        raise RuntimeError(
            f"depformer logits shape mismatch: {logits_stage.shape} vs {target.shape}"
        )
    target.copy_(logits_stage)
    generation.depformer_cache = dep_present


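# Main decode loop. Each step: run the transformer (optionally via a captured
# CUDA graph), sample a text token with classifier-free guidance, advance the
# word-level state machine, sample codebook 0, then run the depformer stages
# for the remaining codebooks. Once the state machine reports an end step,
# generation continues for `flush_tail` extra frames so the delayed audio
# channels can drain.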
def run_generation_loop(
    runtime: RuntimeContext,
    *,
    state: State,
    generation: GenerationState,
    config: GenerationConfig,
    start_step: int = 0,
    logger: RuntimeLogger | None = None,
) -> tuple[Optional[int], torch.Tensor]:
    step_tokens = generation.step_tokens
    audio_buf = generation.audio_buf
    branches = step_tokens.shape[0]
    max_context = runtime.config.runtime.max_context_steps
    if max_context <= 0:
        raise ValueError("Runtime configuration must specify a positive max_context_steps")
    positions = torch.empty(1, 1, dtype=torch.long, device=runtime.device)
    main_tokens = torch.empty(branches, dtype=torch.long, device=runtime.device)
    aux_tokens = torch.empty(branches, dtype=torch.long, device=runtime.device)
    cfg_active = config.cfg_scale != 1.0
    token_ids = runtime.constants
    delay_tensor = runtime.audio_delay_tensor
    max_delay = int(delay_tensor.max().item()) if delay_tensor.numel() else 0
    flush_tail = max_delay + getattr(runtime.machine, "max_padding", 0)
    first_word_frame: Optional[int] = None
    eos_cutoff: Optional[int] = None
    last_step = start_step - 1
    use_graph = bool(config.use_cuda_graph and runtime.device.type == "cuda")
    transformer_step = runtime.transformer_step
    depformer_step = runtime.depformer_step
    buffers = _allocate_network_buffers(runtime, branches)
    positions_view = positions.expand(branches, -1)
    transformer_capture: tuple[torch.cuda.CUDAGraph, torch.Tensor] | None = None
    dep_captures: list[dict] | None = None
    if use_graph:
        _ensure_graph_cublas_ready(runtime.device)
    processed_steps = 0
    report_interval = 12
    with torch.inference_mode():
        for offset in range(max_context):
            t = start_step + offset
            if eos_cutoff is not None and t >= eos_cutoff:
                break
            if t + 1 >= audio_buf.shape[-1]:
                break
            generation.reset_dep_cache()
            positions.fill_(t)
            _fill_audio_channels(step_tokens, audio_buf, delay_tensor, t, token_ids.audio_bos)
            if branches > 1:
                step_tokens[1:, 0, 0] = token_ids.zero
                step_tokens[1:, 1, 0] = token_ids.pad
            if use_graph:
                if transformer_capture is None:
                    torch.cuda.synchronize()
                    graph = torch.cuda.CUDAGraph()
                    with torch.cuda.graph(graph):
                        hidden_ref = _execute_transformer_step(
                            step_tokens,
                            positions_view,
                            generation,
                            transformer_step,
                            buffers,
                        )
                    # Capture only records the kernels without running them,
                    # so replay once to actually compute this step's outputs.
                    graph.replay()
                    transformer_capture = (graph, hidden_ref)
                    if runtime.model.depformer.num_depth > 0:
                        dep_captures = []
                        for idx in range(runtime.model.depformer.num_depth):
                            capture = {
                                "graph": torch.cuda.CUDAGraph(),
                                "captured": False,
                                "prev_audio": torch.empty((branches,), dtype=torch.long, device=runtime.device),
                                "main_tokens": torch.empty((branches,), dtype=torch.long, device=runtime.device) if idx == 0 else None,
                                "second_tokens": torch.empty((branches,), dtype=torch.long, device=runtime.device) if idx == 0 else None,
                            }
                            dep_captures.append(capture)
                else:
                    transformer_capture[0].replay()
                hidden_t = transformer_capture[1]
            else:
                hidden_t = _execute_transformer_step(
                    step_tokens,
                    positions_view,
                    generation,
                    transformer_step,
                    buffers,
                )
            guided_text = apply_classifier_guidance(buffers.text, cfg_active, config.cfg_scale, config.cfg_filter_k)
            if guided_text.shape[0] > 1:
                guided_text = guided_text[:1]
            text_token = sample_token(
                guided_text,
                temp=config.text.temperature,
                top_k=config.text.top_k,
            ).item()
            main_token, aux_token, _ = runtime.machine.process(t, state, text_token)
            second_token = aux_token if aux_token != -1 else token_ids.pad
            if first_word_frame is None and main_token == token_ids.new_word:
                first_word_frame = t - config.initial_padding
            step_tokens[:, 0, 0] = main_token
            step_tokens[:, 1, 0] = second_token
            guided_cb0 = apply_classifier_guidance(buffers.cb0, cfg_active, config.cfg_scale, config.cfg_filter_k)
            if guided_cb0.shape[0] > 1:
                guided_cb0 = guided_cb0[:1]
            masked_cb0 = mask_audio_logits(guided_cb0, token_ids.audio_pad, token_ids.audio_bos)
            codebook_token = sample_audio_logits(masked_cb0, config.audio.temperature, config.audio.top_k)
            audio_buf[:, 0, t + 1] = codebook_token
            prev_audio = codebook_token.expand(branches)
            main_tokens.fill_(main_token)
            aux_tokens.fill_(second_token)
            for stage in range(runtime.model.depformer.num_depth):
                if use_graph and dep_captures is not None:
                    capture = dep_captures[stage]
                    capture["prev_audio"].copy_(prev_audio)
                    if capture["main_tokens"] is not None and stage == 0:
                        capture["main_tokens"].copy_(main_tokens)
                        capture["second_tokens"].copy_(aux_tokens)
                    if not capture["captured"]:
                        torch.cuda.synchronize()
                        with torch.cuda.graph(capture["graph"]):
                            _execute_depformer_stage(
                                stage_index=stage,
                                prev_audio=capture["prev_audio"],
                                hidden_t=hidden_t,
                                generation=generation,
                                depformer_step=depformer_step,
                                main_tokens=capture["main_tokens"],
                                second_tokens=capture["second_tokens"],
                                buffers=buffers,
                            )
                        capture["captured"] = True
                        # As with the transformer graph, replay immediately so
                        # the captured stage produces real logits this step.
                        capture["graph"].replay()
                    else:
                        capture["graph"].replay()
                else:
                    _execute_depformer_stage(
                        stage_index=stage,
                        prev_audio=prev_audio,
                        hidden_t=hidden_t,
                        generation=generation,
                        depformer_step=depformer_step,
                        main_tokens=main_tokens,
                        second_tokens=aux_tokens,
                        buffers=buffers,
                    )
                dep_logits = apply_classifier_guidance(buffers.dep[stage], cfg_active, config.cfg_scale, config.cfg_filter_k)
                if dep_logits.shape[0] > 1:
                    dep_logits = dep_logits[:1]
                stage_token = sample_audio_logits(
                    dep_logits,
                    config.audio.temperature,
                    config.audio.top_k,
                )
                audio_buf[:, stage + 1, t + 1] = stage_token
                prev_audio = stage_token.expand(branches)
            last_step = t
            if eos_cutoff is None and state.end_step is not None:
                eos_cutoff = state.end_step + flush_tail
            processed_steps = offset + 1
            if logger and processed_steps % report_interval == 0:
                logger.progress(processed_steps, max_context)
    if logger and processed_steps and processed_steps % report_interval != 0:
        logger.progress(processed_steps, max_context)
    if first_word_frame is None:
        first_word_frame = start_step
    if last_step < start_step:
        limit = min(start_step + 1, audio_buf.shape[-1])
    else:
        limit = min(last_step + 2, audio_buf.shape[-1])
    trimmed = generation.trim_audio(limit, token_ids.audio_pad, token_ids.ungenerated)
    return first_word_frame, trimmed


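# Decode a (channels, frames) token grid to mono PCM with the runtime's Mimi
# codec.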
def decode_audio(runtime: RuntimeContext, tokens: torch.Tensor) -> torch.Tensor:
    if tokens.shape[-1] == 0:
        return torch.zeros(0, device=runtime.device)
    with torch.inference_mode():
        pcm = runtime.mimi.decode(tokens.to(runtime.device))
    return pcm[0, 0]


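# Teacher-forced prefill: replay the voice-clone prefix through the
# transformer to populate its KV cache, forcing the state machine along the
# known word boundaries. Returns the index of the last prefilled step.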
def warmup_with_prefix(
    runtime: RuntimeContext,
    plan: PrefixPlan,
    state: State,
    generation: GenerationState,
) -> int:
    step_tokens = generation.step_tokens
    model_state = generation.decode
    branches = step_tokens.shape[0]
    device = runtime.device
    tokens = plan.aligned_tokens.to(device)
    new_word_steps = set(plan.new_word_steps)
    positions = torch.empty(1, 1, dtype=torch.long, device=device)
    with torch.inference_mode():
        for t in range(plan.aligned_frames):
            positions.fill_(t)
            channels = tokens.shape[0]
            for cb in range(channels):
                # Apply the per-codebook delay by hand; frames before the
                # delay has elapsed are fed the audio BOS token.
                delay = runtime.audio_delays[cb] if cb < len(runtime.audio_delays) else 0
                idx = t - delay
                value = tokens[cb, idx] if idx >= 0 else runtime.constants.audio_bos
                step_tokens[:, 2 + cb, 0] = value
            _hidden, _text_logits, _cb0_logits, present = runtime.model.transformer.forward_step(
                step_tokens,
                positions.expand(branches, -1),
                model_state.transformer,
            )
            model_state.transformer = present
            forced = runtime.constants.new_word if t in new_word_steps else runtime.constants.pad
            main_token, aux_token, _ = runtime.machine.process(t, state, forced, is_forced=True)
            second_token = runtime.constants.pad if aux_token == -1 else aux_token
            step_tokens[0, 0, 0] = main_token
            step_tokens[0, 1, 0] = second_token
            if branches > 1:
                step_tokens[1:, 0, 0] = runtime.constants.zero
                step_tokens[1:, 1, 0] = runtime.constants.pad
    return max(plan.aligned_frames - 1, 0)


__all__ = [
    "build_initial_state",
    "run_generation_loop",
    "decode_audio",
    "warmup_with_prefix",
    "GenerationState",
]