from __future__ import annotations

from pathlib import Path
from typing import Optional, Sequence

from .assets import resolve_assets
from .runtime.context import RuntimeContext, build_runtime
from .runtime.generator import (
    build_initial_state,
    decode_audio,
    run_generation_loop,
    warmup_with_prefix,
)
from .runtime.script_parser import parse_script
from .audio.grid import undelay_frames, write_wav
from .runtime.voice_clone import build_prefix_plan
from .generation import (
    GenerationConfig,
    GenerationResult,
    merge_generation_config,
    normalize_script,
)
from .runtime.logger import RuntimeLogger


class Dia2:
    """High-level entry point for turning a script into audio.

    Construction only resolves asset locations through ``resolve_assets``;
    the runtime itself is built on first use and then cached until the
    device/dtype changes or ``close`` is called.
    """

    def __init__(
        self,
        *,
        repo: Optional[str] = None,
        config_path: Optional[str | Path] = None,
        weights_path: Optional[str | Path] = None,
        tokenizer_id: Optional[str | Path] = None,
        mimi_id: Optional[str] = None,
        device: str = "cuda",
        dtype: str = "auto",
        default_config: Optional[GenerationConfig] = None,
    ) -> None:
        """Resolve assets and record the settings used to build the runtime.

        Args:
            repo: Repository identifier forwarded to ``resolve_assets``.
            config_path: Config location forwarded to ``resolve_assets``.
            weights_path: Weights location forwarded to ``resolve_assets``.
            tokenizer_id: Explicit tokenizer reference; when falsy the value
                from the resolved asset bundle is used instead.
            mimi_id: Explicit Mimi reference; when falsy the value from the
                resolved asset bundle is used instead.
            device: Device string handed to the runtime builder.
            dtype: Dtype preference; a falsy value is stored as ``"auto"``.
            default_config: Config used by ``generate`` when the caller
                passes none; defaults to a fresh ``GenerationConfig()``.
        """
        assets = resolve_assets(
            repo=repo,
            config_path=config_path,
            weights_path=weights_path,
        )
        self._config_path = assets.config_path
        self._weights_path = assets.weights_path
        self._repo_id = assets.repo_id

        # An explicit tokenizer wins over whatever the bundle resolved.
        explicit_tokenizer = str(tokenizer_id) if tokenizer_id else None
        self._tokenizer_id = explicit_tokenizer or assets.tokenizer_id
        self._mimi_id = mimi_id or assets.mimi_id

        self.device = device
        self._dtype_pref = dtype or "auto"
        self.default_config = default_config or GenerationConfig()

        # Built lazily by ``_ensure_runtime``.
        self._runtime: Optional[RuntimeContext] = None

    @classmethod
    def from_repo(
        cls,
        repo: str,
        *,
        device: str = "cuda",
        dtype: str = "auto",
        tokenizer_id: Optional[str] = None,
        mimi_id: Optional[str] = None,
    ) -> "Dia2":
        """Create an instance whose assets are resolved from ``repo``."""
        return cls(
            repo=repo,
            device=device,
            dtype=dtype,
            tokenizer_id=tokenizer_id,
            mimi_id=mimi_id,
        )

    @classmethod
    def from_local(
        cls,
        config_path: str | Path,
        weights_path: str | Path,
        *,
        device: str = "cuda",
        dtype: str = "auto",
        tokenizer_id: Optional[str | Path] = None,
        mimi_id: Optional[str] = None,
    ) -> "Dia2":
        """Create an instance from explicit config and weights paths."""
        return cls(
            config_path=config_path,
            weights_path=weights_path,
            tokenizer_id=tokenizer_id,
            device=device,
            dtype=dtype,
            mimi_id=mimi_id,
        )

    def set_device(self, device: str, *, dtype: Optional[str] = None) -> None:
        """Switch device and/or dtype preference.

        When neither value actually changes this is a no-op; otherwise the
        cached runtime is dropped so the next use rebuilds it.
        """
        target_dtype = dtype or self._dtype_pref
        same_device = self.device == device
        same_dtype = target_dtype == self._dtype_pref
        if same_device and same_dtype:
            return
        self.device = device
        self._dtype_pref = target_dtype
        self._runtime = None

    def close(self) -> None:
        """Drop the reference to the cached runtime."""
        self._runtime = None

    def _ensure_runtime(self) -> RuntimeContext:
        """Return the cached runtime, building it first if necessary."""
        cached = self._runtime
        if cached is None:
            cached = self._build_runtime()
            self._runtime = cached
        return cached

    def generate(
        self,
        script: str | Sequence[str],
        *,
        config: Optional[GenerationConfig] = None,
        output_wav: Optional[str | Path] = None,
        prefix_speaker_1: Optional[str] = None,
        prefix_speaker_2: Optional[str] = None,
        include_prefix: Optional[bool] = None,
        verbose: bool = False,
        **overrides,
    ):
        """Generate audio for ``script``.

        Args:
            script: Script text, normalised with ``normalize_script``.
            config: Base generation config; falls back to
                ``self.default_config`` when falsy.
            output_wav: When not ``None``, the waveform is also written to
                this path with ``write_wav``.
            prefix_speaker_1: When not ``None``, passed to
                ``merge_generation_config`` as the ``prefix_speaker_1``
                override.
            prefix_speaker_2: Same as above for ``prefix_speaker_2``.
            include_prefix: When not ``None``, passed as the
                ``include_prefix`` override.
            verbose: Passed to ``RuntimeLogger``.
            **overrides: Additional overrides for
                ``merge_generation_config``.

        Returns:
            A ``GenerationResult`` built from the aligned frames, the decoded
            waveform, the Mimi sample rate and a list of ``(word, time)``
            pairs, where ``time`` is the cropped step offset divided by the
            runtime frame rate.
        """
        runtime = self._ensure_runtime()
        logger = RuntimeLogger(verbose)

        # Explicit keyword arguments are layered on top of **overrides, but
        # only when the caller actually supplied them.
        explicit_options = {
            "prefix_speaker_1": prefix_speaker_1,
            "prefix_speaker_2": prefix_speaker_2,
            "include_prefix": include_prefix,
        }
        option_overrides = dict(overrides)
        option_overrides.update(
            {key: value for key, value in explicit_options.items() if value is not None}
        )
        merged = merge_generation_config(
            base=config or self.default_config,
            overrides=option_overrides,
        )

        max_context = runtime.config.runtime.max_context_steps
        text = normalize_script(script)
        prefix_plan = build_prefix_plan(runtime, merged.prefix)
        has_prefix = prefix_plan is not None

        # Prefix entries (if any) come first, followed by the parsed script.
        entries = list(prefix_plan.entries) if has_prefix else []
        entries += parse_script(
            [text],
            runtime.tokenizer,
            runtime.constants,
            runtime.frame_rate,
        )

        runtime.machine.initial_padding = merged.initial_padding
        logger.event(
            f"starting generation: max_context={max_context} cfg_scale={merged.cfg_scale:.2f} "
            f"device={self.device} dtype={self._dtype_pref}"
        )
        state = runtime.machine.new_state(entries)

        if merged.cfg_scale != 1.0:
            logger.event(f"classifier-free guidance enabled (scale={merged.cfg_scale:.2f})")
        else:
            logger.event("classifier-free guidance disabled (scale=1.0)")

        gen_state = build_initial_state(runtime, prefix=prefix_plan)
        include_prefix_audio = bool(
            prefix_plan and merged.prefix and merged.prefix.include_audio
        )

        start_step = 0
        if has_prefix:
            logger.event(f"warming up with prefix ({prefix_plan.aligned_frames} frames)")
            start_step = warmup_with_prefix(runtime, prefix_plan, state, gen_state)
            logger.event(
                "prefix audio will be kept in output"
                if include_prefix_audio
                else "prefix audio trimmed from output"
            )

        first_word_frame, audio_buf = run_generation_loop(
            runtime,
            state=state,
            generation=gen_state,
            config=merged,
            start_step=start_step,
            logger=logger,
        )

        undelayed = undelay_frames(
            audio_buf[0],
            runtime.audio_delays,
            runtime.constants.audio_pad,
        )
        aligned = undelayed.unsqueeze(0)

        # Trim everything before the first generated word unless the prefix
        # audio is kept. A crop that would remove every frame is abandoned.
        crop = 0 if include_prefix_audio else max(first_word_frame, 0)
        if crop >= aligned.shape[-1]:
            crop = 0
        elif crop > 0:
            aligned = aligned[:, :, crop:]

        logger.event(f"decoding {aligned.shape[-1]} Mimi frames")
        waveform = decode_audio(runtime, aligned)

        if output_wav is not None:
            write_wav(
                str(output_wav),
                waveform.detach().cpu().numpy(),
                runtime.mimi.sample_rate,
            )
            duration = waveform.shape[-1] / max(runtime.mimi.sample_rate, 1)
            logger.event(f"saved {output_wav} ({duration:.2f}s)")

        frame_rate = max(runtime.frame_rate, 1.0)

        # When the prefix audio was trimmed, its transcript entries are
        # dropped as well so timestamps only cover the generated script.
        transcript_entries = state.transcript
        if has_prefix and not include_prefix_audio:
            prefix_entry_count = len(prefix_plan.entries)
            if len(transcript_entries) > prefix_entry_count:
                transcript_entries = transcript_entries[prefix_entry_count:]
            else:
                transcript_entries = []

        # Shift steps by the crop and skip words that fall before it.
        shifted = ((word, step - crop) for word, step in transcript_entries)
        timestamps = [
            (word, offset / frame_rate) for word, offset in shifted if offset >= 0
        ]

        logger.event(f"generation finished in {logger.elapsed():.2f}s")
        return GenerationResult(aligned, waveform, runtime.mimi.sample_rate, timestamps)

    def save_wav(self, script: str | Sequence[str], path: str | Path, **kwargs):
        """Run ``generate`` with ``output_wav=path`` and return its result."""
        return self.generate(script, output_wav=path, **kwargs)

    @property
    def sample_rate(self) -> int:
        """Sample rate reported by the runtime's Mimi component.

        Accessing this builds the runtime if it does not exist yet.
        """
        runtime = self._ensure_runtime()
        return runtime.mimi.sample_rate

    @property
    def tokenizer_id(self) -> Optional[str]:
        """Best available tokenizer reference.

        The stored id is returned when truthy. Otherwise, if a runtime is
        cached, its tokenizer's ``name_or_path`` attribute (or ``None``) is
        returned. With no stored id and no runtime, the repo id is returned.
        """
        if self._tokenizer_id:
            return self._tokenizer_id
        if self._runtime is None:
            return self._repo_id
        return getattr(self._runtime.tokenizer, "name_or_path", None)

    @property
    def dtype(self) -> str:
        """Current dtype preference string."""
        return self._dtype_pref

    @property
    def max_context_steps(self) -> int:
        """``max_context_steps`` from the runtime config.

        Accessing this builds the runtime if it does not exist yet.
        """
        runtime = self._ensure_runtime()
        return runtime.config.runtime.max_context_steps

    @property
    def repo(self) -> Optional[str]:
        """Repository id taken from the resolved asset bundle."""
        return self._repo_id

    def _build_runtime(self) -> RuntimeContext:
        """Build a runtime from the stored settings.

        The tokenizer and Mimi references returned by ``build_runtime``
        replace the stored ids.
        """
        built = build_runtime(
            config_path=self._config_path,
            weights_path=self._weights_path,
            tokenizer_id=self._tokenizer_id,
            repo_id=self._repo_id,
            mimi_id=self._mimi_id,
            device=self.device,
            dtype_pref=self._dtype_pref,
        )
        runtime, tokenizer_ref, mimi_ref = built
        self._tokenizer_id = tokenizer_ref
        self._mimi_id = mimi_ref
        return runtime