# dia2/runtime/generator.py
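"""Autoregressive generation loop for the Dia2 runtime.

Builds the initial decode state, runs the per-frame transformer/depformer
sampling loop (optionally captured as CUDA graphs), warms the transformer
KV cache with a voice-clone prefix, and decodes the finished token grid to
PCM through the Mimi codec.
"""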
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from ..core.cache import KVCache
from ..core.model import DecodeState
from ..generation import GenerationConfig
from ..audio.grid import delay_frames, mask_audio_logits, undelay_frames
from .context import RuntimeContext
from .state_machine import State, TokenIds
from .guidance import apply_classifier_guidance, sample_audio_logits
from .sampler import sample_token
from .voice_clone import PrefixPlan
from .logger import RuntimeLogger
_GRAPH_CUBLAS_READY = False
def _ensure_graph_cublas_ready(device: torch.device) -> None:
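    """Run a throwaway matmul once per process so cuBLAS finishes its lazy
    initialization before any CUDA graph capture begins."""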
global _GRAPH_CUBLAS_READY
if _GRAPH_CUBLAS_READY or device.type != "cuda":
return
tmp = torch.empty((1, 1), device=device, dtype=torch.float32)
torch.matmul(tmp, tmp)
torch.cuda.synchronize()
_GRAPH_CUBLAS_READY = True
@dataclass
class GenerationState:
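    """Mutable per-generation state: the model decode state, the per-step
    token buffer fed to the transformer, and the delayed audio token grid."""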
decode: DecodeState
step_tokens: torch.Tensor
audio_buf: torch.Tensor
def trim_audio(self, limit: int, pad_token: int, ungenerated: int) -> torch.Tensor:
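        """Clip the audio grid to `limit` frames and replace any slots still
        marked as ungenerated with the pad token."""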
trimmed = self.audio_buf[:, :, :limit]
pad = torch.full_like(trimmed, pad_token)
trimmed = torch.where(trimmed == ungenerated, pad, trimmed)
self.audio_buf = trimmed
return trimmed
@property
def transformer_cache(self) -> KVCache:
return self.decode.transformer
@transformer_cache.setter
def transformer_cache(self, cache: KVCache) -> None:
self.decode.transformer = cache
@property
def depformer_cache(self) -> KVCache:
return self.decode.depformer
@depformer_cache.setter
def depformer_cache(self, cache: KVCache) -> None:
self.decode.depformer = cache
def reset_dep_cache(self) -> None:
self.decode.depformer.reset()
@dataclass
class NetworkBuffers:
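    """Preallocated logits buffers (text, codebook 0, one per depformer stage)
    that are written in place so CUDA graph replays reuse fixed memory."""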
text: torch.Tensor
cb0: torch.Tensor
dep: list[torch.Tensor]
def _allocate_network_buffers(runtime: RuntimeContext, branches: int) -> NetworkBuffers:
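    """Allocate the static logits buffers sized from the model and data config."""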
device = runtime.device
logits_dtype = runtime.precision.logits
data_cfg = runtime.config.data
text_logits = torch.empty((branches, 1, data_cfg.action_vocab_size), dtype=logits_dtype, device=device)
cb0_logits = torch.empty((branches, 1, data_cfg.audio_vocab_size), dtype=logits_dtype, device=device)
dep_vocab = runtime.model.depformer.audio_vocab_limit or data_cfg.audio_vocab_size
dep_logits = [
torch.empty((branches, 1, 1, dep_vocab), dtype=logits_dtype, device=device)
for _ in range(runtime.model.depformer.num_depth)
]
return NetworkBuffers(text=text_logits, cb0=cb0_logits, dep=dep_logits)
def build_initial_state(
runtime: RuntimeContext,
*,
prefix: PrefixPlan | None = None,
) -> GenerationState:
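    """Create the decode state, step-token buffer, and audio grid for two
    decoding branches (the second branch feeds classifier-free guidance),
    optionally seeding the grid with the delayed voice-clone prefix."""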
dep_q = runtime.model.depformer.num_audio_channels
channels = 2 + dep_q
branches = 2
token_ids = runtime.constants
step_tokens = torch.full(
(branches, channels, 1),
token_ids.pad,
dtype=torch.long,
device=runtime.device,
)
step_tokens[0, 0, 0] = token_ids.bos
step_tokens[0, 1, 0] = token_ids.pad
step_tokens[1, 0, 0] = token_ids.zero
step_tokens[1, 1, 0] = token_ids.pad
    prefix_len = 0
    delayed: torch.Tensor | None = None
    if prefix is not None:
        # Apply the per-codebook delays once; the same tensor seeds the audio buffer below.
        delayed = delay_frames(prefix.aligned_tokens, runtime.audio_delays, token_ids.audio_pad).to(runtime.device)
        prefix_len = delayed.shape[1]
    limit = runtime.config.runtime.max_context_steps
    total_steps = max(limit + prefix_len + 1, limit)
    decode_state = runtime.model.init_state(branches, runtime.device, total_steps)
    audio_buf = torch.full(
        (branches, dep_q, total_steps),
        token_ids.ungenerated,
        dtype=torch.long,
        device=runtime.device,
    )
    if delayed is not None:
        # Seed every branch with the delayed voice-clone prefix.
        audio_buf[:, :, : delayed.shape[1]] = delayed
return GenerationState(decode_state, step_tokens, audio_buf)
def _fill_audio_channels(
step_tokens: torch.Tensor,
audio_buf: torch.Tensor,
delays: torch.Tensor,
step: int,
bos_token: int,
) -> None:
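    """Copy the audio tokens for `step` into the audio channels of
    `step_tokens`, substituting the audio BOS token for channels whose delay
    has not yet elapsed or once the buffer is exhausted."""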
channels = delays.numel()
if channels == 0:
return
target = step_tokens[:, 2 : 2 + channels, 0]
if step < audio_buf.shape[-1]:
target.copy_(audio_buf[:, :channels, step])
else:
target.fill_(bos_token)
mask = delays > step
if mask.any().item():
target[:, mask] = bos_token
def _execute_transformer_step(
step_tokens: torch.Tensor,
positions_view: torch.Tensor,
generation: GenerationState,
transformer_step,
buffers: NetworkBuffers,
) -> torch.Tensor:
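    """Run one transformer step, copy its text and codebook-0 logits into the
    static buffers, update the KV cache, and return the hidden state that
    conditions the depformer."""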
hidden_t, text_logits_t, cb0_logits_t, present = transformer_step(
step_tokens,
positions_view,
generation.transformer_cache,
)
buffers.text.copy_(text_logits_t)
buffers.cb0.copy_(cb0_logits_t)
generation.transformer_cache = present
return hidden_t
def _execute_depformer_stage(
stage_index: int,
prev_audio: torch.Tensor,
hidden_t: torch.Tensor,
generation: GenerationState,
depformer_step,
main_tokens: Optional[torch.Tensor],
second_tokens: Optional[torch.Tensor],
buffers: NetworkBuffers,
) -> None:
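    """Run one depformer stage and copy its logits into the preallocated
    buffer after a shape check; text tokens are only consumed by stage 0."""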
logits_stage, dep_present = depformer_step(
prev_audio=prev_audio,
transformer_out=hidden_t,
stage_index=stage_index,
cache=generation.depformer_cache,
main_text=main_tokens if stage_index == 0 else None,
second_text=second_tokens if stage_index == 0 else None,
)
target = buffers.dep[stage_index]
if logits_stage.shape != target.shape:
raise RuntimeError(
f"depformer logits shape mismatch: {logits_stage.shape} vs {target.shape}"
)
target.copy_(logits_stage)
generation.depformer_cache = dep_present
def run_generation_loop(
runtime: RuntimeContext,
*,
state: State,
generation: GenerationState,
config: GenerationConfig,
start_step: int = 0,
logger: RuntimeLogger | None = None,
) -> tuple[Optional[int], torch.Tensor]:
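    """Sample up to `max_context_steps` frames: one transformer step per frame
    followed by one depformer stage per remaining codebook, optionally through
    captured CUDA graphs. Returns the frame index of the first generated word
    and the trimmed audio token grid."""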
step_tokens = generation.step_tokens
audio_buf = generation.audio_buf
branches = step_tokens.shape[0]
max_context = runtime.config.runtime.max_context_steps
if max_context <= 0:
raise ValueError("Runtime configuration must specify a positive max_context_steps")
positions = torch.empty(1, 1, dtype=torch.long, device=runtime.device)
main_tokens = torch.empty(branches, dtype=torch.long, device=runtime.device)
aux_tokens = torch.empty(branches, dtype=torch.long, device=runtime.device)
cfg_active = config.cfg_scale != 1.0
token_ids = runtime.constants
delay_tensor = runtime.audio_delay_tensor
max_delay = int(delay_tensor.max().item()) if delay_tensor.numel() else 0
flush_tail = max_delay + getattr(runtime.machine, "max_padding", 0)
first_word_frame: Optional[int] = None
eos_cutoff: Optional[int] = None
last_step = start_step - 1
use_graph = bool(config.use_cuda_graph and runtime.device.type == "cuda")
transformer_step = runtime.transformer_step
depformer_step = runtime.depformer_step
buffers = _allocate_network_buffers(runtime, branches)
positions_view = positions.expand(branches, -1)
transformer_capture = None
dep_captures: list[dict] | None = None
if use_graph:
_ensure_graph_cublas_ready(runtime.device)
processed_steps = 0
report_interval = 12
with torch.inference_mode():
for offset in range(max_context):
t = start_step + offset
if eos_cutoff is not None and t >= eos_cutoff:
break
if t + 1 >= audio_buf.shape[-1]:
break
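            # The depformer cache is rebuilt from scratch on every frame.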
generation.reset_dep_cache()
positions.fill_(t)
_fill_audio_channels(step_tokens, audio_buf, delay_tensor, t, token_ids.audio_bos)
if branches > 1:
step_tokens[1:, 0, 0] = token_ids.zero
step_tokens[1:, 1, 0] = token_ids.pad
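            # CUDA graph path: capture the transformer step on the first frame,
            # then replay it into the same static tensors on later frames.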
if use_graph:
if transformer_capture is None:
torch.cuda.synchronize()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
hidden_ref = _execute_transformer_step(
step_tokens,
positions_view,
generation,
transformer_step,
buffers,
)
                    transformer_capture = (graph, hidden_ref)
                    # Kernels do not execute during graph capture, so replay once
                    # to populate the static output buffers for this first frame.
                    graph.replay()
                    hidden_t = hidden_ref
if runtime.model.depformer.num_depth > 0:
dep_captures = []
for idx in range(runtime.model.depformer.num_depth):
capture = {
"graph": torch.cuda.CUDAGraph(),
"captured": False,
"prev_audio": torch.empty((branches,), dtype=torch.long, device=runtime.device),
"main_tokens": torch.empty((branches,), dtype=torch.long, device=runtime.device) if idx == 0 else None,
"second_tokens": torch.empty((branches,), dtype=torch.long, device=runtime.device) if idx == 0 else None,
}
dep_captures.append(capture)
else:
transformer_capture[0].replay()
hidden_t = transformer_capture[1]
else:
hidden_t = _execute_transformer_step(
step_tokens,
positions_view,
generation,
transformer_step,
buffers,
)
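            # Classifier-free guidance combines the two branches (when
            # cfg_scale != 1); only the guided first row is sampled.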
guided_text = apply_classifier_guidance(buffers.text, cfg_active, config.cfg_scale, config.cfg_filter_k)
if guided_text.shape[0] > 1:
guided_text = guided_text[:1]
text_token = sample_token(
guided_text,
temp=config.text.temperature,
top_k=config.text.top_k,
).item()
main_token, aux_token, _ = runtime.machine.process(t, state, text_token)
second_token = aux_token if aux_token != -1 else token_ids.pad
if first_word_frame is None and main_token == token_ids.new_word:
first_word_frame = t - config.initial_padding
step_tokens[:, 0, 0] = main_token
step_tokens[:, 1, 0] = second_token
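            # Codebook 0 comes from the transformer head; the remaining
            # codebooks are produced by the depformer stages below.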
guided_cb0 = apply_classifier_guidance(buffers.cb0, cfg_active, config.cfg_scale, config.cfg_filter_k)
if guided_cb0.shape[0] > 1:
guided_cb0 = guided_cb0[:1]
masked_cb0 = mask_audio_logits(guided_cb0, token_ids.audio_pad, token_ids.audio_bos)
codebook_token = sample_audio_logits(masked_cb0, config.audio.temperature, config.audio.top_k)
audio_buf[:, 0, t + 1] = codebook_token
prev_audio = codebook_token.expand(branches)
main_tokens.fill_(main_token)
aux_tokens.fill_(second_token)
for stage in range(runtime.model.depformer.num_depth):
if use_graph and dep_captures is not None:
capture = dep_captures[stage]
capture["prev_audio"].copy_(prev_audio)
if capture["main_tokens"] is not None and stage == 0:
capture["main_tokens"].copy_(main_tokens)
capture["second_tokens"].copy_(aux_tokens)
if not capture["captured"]:
torch.cuda.synchronize()
with torch.cuda.graph(capture["graph"]):
_execute_depformer_stage(
stage_index=stage,
prev_audio=capture["prev_audio"],
hidden_t=hidden_t,
generation=generation,
depformer_step=depformer_step,
main_tokens=capture["main_tokens"],
second_tokens=capture["second_tokens"],
buffers=buffers,
)
                        capture["captured"] = True
                        # As above, kernels do not run during capture; replay once
                        # so this stage's logits buffer is valid for this frame.
                        capture["graph"].replay()
else:
capture["graph"].replay()
else:
_execute_depformer_stage(
stage_index=stage,
prev_audio=prev_audio,
hidden_t=hidden_t,
generation=generation,
depformer_step=depformer_step,
main_tokens=main_tokens,
second_tokens=aux_tokens,
buffers=buffers,
)
dep_logits = apply_classifier_guidance(buffers.dep[stage], cfg_active, config.cfg_scale, config.cfg_filter_k)
if dep_logits.shape[0] > 1:
dep_logits = dep_logits[:1]
stage_token = sample_audio_logits(
dep_logits,
config.audio.temperature,
config.audio.top_k,
)
audio_buf[:, stage + 1, t + 1] = stage_token
prev_audio = stage_token.expand(branches)
last_step = t
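            # Once the state machine reports an end step, keep generating for
            # the delay/padding tail so the delayed codebooks are flushed.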
if eos_cutoff is None and state.end_step is not None:
eos_cutoff = state.end_step + flush_tail
processed_steps = offset + 1
if logger and processed_steps % report_interval == 0:
logger.progress(processed_steps, max_context)
if logger and processed_steps and processed_steps % report_interval != 0:
logger.progress(processed_steps, max_context)
if first_word_frame is None:
first_word_frame = start_step
if last_step < start_step:
limit = min(start_step + 1, audio_buf.shape[-1])
else:
limit = min(last_step + 2, audio_buf.shape[-1])
trimmed = generation.trim_audio(limit, token_ids.audio_pad, token_ids.ungenerated)
return first_word_frame, trimmed
def decode_audio(runtime: RuntimeContext, tokens: torch.Tensor) -> torch.Tensor:
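    """Decode an audio token grid to mono PCM with the Mimi codec; an empty
    grid yields an empty tensor."""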
if tokens.shape[-1] == 0:
return torch.zeros(0, device=runtime.device)
with torch.inference_mode():
pcm = runtime.mimi.decode(tokens.to(runtime.device))
return pcm[0, 0]
def warmup_with_prefix(
runtime: RuntimeContext,
plan: PrefixPlan,
state: State,
generation: GenerationState,
) -> int:
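    """Teacher-force the aligned voice-clone prefix through the transformer to
    prime its KV cache, replaying the prefix's word boundaries through the
    state machine. Returns the index of the last prefilled step."""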
step_tokens = generation.step_tokens
model_state = generation.decode
branches = step_tokens.shape[0]
device = runtime.device
tokens = plan.aligned_tokens.to(device)
new_word_steps = set(plan.new_word_steps)
positions = torch.empty(1, 1, dtype=torch.long, device=device)
with torch.inference_mode():
for t in range(plan.aligned_frames):
positions.fill_(t)
channels = tokens.shape[0]
for cb in range(channels):
delay = runtime.audio_delays[cb] if cb < len(runtime.audio_delays) else 0
idx = t - delay
value = tokens[cb, idx] if idx >= 0 else runtime.constants.audio_bos
step_tokens[:, 2 + cb, 0] = value
hidden, text_logits, cb0_logits, present = runtime.model.transformer.forward_step(
step_tokens,
positions.expand(branches, -1),
model_state.transformer,
)
model_state.transformer = present
forced = runtime.constants.new_word if t in new_word_steps else runtime.constants.pad
main_token, aux_token, _ = runtime.machine.process(t, state, forced, is_forced=True)
second_token = runtime.constants.pad if aux_token == -1 else aux_token
step_tokens[0, 0, 0] = main_token
step_tokens[0, 1, 0] = second_token
if branches > 1:
step_tokens[1:, 0, 0] = runtime.constants.zero
step_tokens[1:, 1, 0] = runtime.constants.pad
return max(plan.aligned_frames - 1, 0)
__all__ = [
"build_initial_state",
"run_generation_loop",
"decode_audio",
"warmup_with_prefix",
"GenerationState",
]