from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Optional
import warnings

import torch
from safetensors.torch import load_file
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from ..config import DiaConfig, load_config
from ..core.model import Dia2Model
from ..core.precision import Precision, resolve_precision
from ..audio import MimiCodec, DEFAULT_MIMI_MODEL_ID
from .state_machine import StateMachine, TokenIds


# Prefix of the warning PyTorch emits when the legacy ``allow_tf32`` flags are
# written on builds that also expose the newer ``fp32_precision`` API.
_LEGACY_TF32_WARNING = "Please use the new API settings"


@dataclass
class RuntimeContext:
    """Bundle of everything ``build_runtime`` assembles for a generation run.

    Field order is part of the interface (positional construction), so new
    fields must only ever be appended.
    """

    config: DiaConfig  # configuration returned by ``load_config``
    model: Dia2Model  # model with weights loaded, moved to ``device``
    precision: Precision  # result of ``resolve_precision``
    tokenizer: PreTrainedTokenizerBase  # text tokenizer
    mimi: MimiCodec  # audio codec wrapper
    device: torch.device  # device the model and tensors live on
    machine: StateMachine  # state machine built from ``constants``
    transformer_step: Callable[..., Any]  # ``model.transformer.forward_step``
    depformer_step: Callable[..., Any]  # ``model.depformer.forward_step``
    constants: TokenIds  # special token ids used by the state machine
    audio_delays: list[int]  # copy of ``config.data.delay_pattern``
    audio_delay_tensor: torch.Tensor  # ``audio_delays`` as a long tensor
    frame_rate: float  # codec frame rate (75.0 if the codec has none)


def _set_legacy_allow_tf32(backend: Any, *, suppress_warning: bool) -> None:
    """Set ``backend.allow_tf32 = True``, optionally hiding the deprecation warning."""
    if not suppress_warning:
        backend.allow_tf32 = True
        return
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=_LEGACY_TF32_WARNING)
        backend.allow_tf32 = True


def _enable_tf32() -> None:
    """Turn on TF32 for CUDA matmul and cuDNN via both the new and legacy APIs.

    When a backend object exposes ``fp32_precision`` it is set to ``"tf32"``
    first and the legacy ``allow_tf32`` flag is then written with the
    resulting warning suppressed. Otherwise only the legacy flag is written.
    """
    matmul = torch.backends.cuda.matmul
    matmul_has_new_api = hasattr(matmul, "fp32_precision")
    if matmul_has_new_api:
        matmul.fp32_precision = "tf32"
    _set_legacy_allow_tf32(matmul, suppress_warning=matmul_has_new_api)

    # The ``conv`` sub-object is checked separately from matmul, so its
    # absence does not affect the matmul settings above.
    cudnn = torch.backends.cudnn
    conv_has_new_api = hasattr(cudnn, "conv") and hasattr(cudnn.conv, "fp32_precision")
    if conv_has_new_api:
        cudnn.conv.fp32_precision = "tf32"
    _set_legacy_allow_tf32(cudnn, suppress_warning=conv_has_new_api)


def _speaker_token_id(
    tokenizer: PreTrainedTokenizerBase,
    vocab: dict[str, int],
    token: str,
    fallback: int,
) -> int:
    """Return the tokenizer id for ``token`` if it is in ``vocab``, else ``fallback``."""
    if token in vocab:
        return tokenizer.convert_tokens_to_ids(token)
    return fallback


def build_runtime(
    *,
    config_path: str | Path,
    weights_path: str | Path,
    tokenizer_id: Optional[str],
    repo_id: Optional[str],
    mimi_id: Optional[str],
    device: str,
    dtype_pref: str,
) -> tuple[RuntimeContext, str, str]:
    """Load the model, tokenizer and codec and assemble a ``RuntimeContext``.

    Args:
        config_path: Path passed to ``load_config``.
        weights_path: Path to a safetensors file holding the model state dict.
        tokenizer_id: Explicit tokenizer reference; takes priority over
            ``config.assets.tokenizer`` and ``repo_id``.
        repo_id: Last-resort tokenizer reference.
        mimi_id: Explicit codec reference; takes priority over
            ``config.assets.mimi`` and ``DEFAULT_MIMI_MODEL_ID``.
        device: Device string accepted by ``torch.device``.
        dtype_pref: Precision preference forwarded to ``resolve_precision``.

    Returns:
        A tuple ``(runtime, tokenizer_ref, mimi_ref)`` where the two strings
        are the references that were actually used to load the tokenizer and
        the codec.

    Raises:
        ValueError: If no tokenizer reference can be determined.
    """
    device_obj = torch.device(device)
    if device_obj.type == "cuda":
        _enable_tf32()

    precision = resolve_precision(dtype_pref, device_obj)
    config = load_config(config_path)
    model = Dia2Model(config, precision)
    state = load_file(str(weights_path))
    model.load_state_dict(state)
    model = model.to(device_obj)

    tokenizer_ref = tokenizer_id or config.assets.tokenizer or repo_id
    # The ``or`` chain already treats empty strings as "not provided", so the
    # guard must do the same; otherwise an all-empty configuration would slip
    # past this message and fail later inside ``from_pretrained``.
    if not tokenizer_ref:
        raise ValueError("Tokenizer id is missing. Provide --tokenizer or add assets.tokenizer to the config.")
    # NOTE(review): trust_remote_code=True allows the referenced repository to
    # run its own Python code while loading; only use trusted references.
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_ref,
        use_fast=False,
        trust_remote_code=True,
    )

    mimi_ref = mimi_id or config.assets.mimi or DEFAULT_MIMI_MODEL_ID
    mimi = MimiCodec.from_pretrained(mimi_ref, device=device_obj)

    data_cfg = config.data
    # Build the vocabulary mapping once; it is consulted for both speaker tags.
    vocab = tokenizer.get_vocab()
    speaker_fallback = data_cfg.text_new_word_token_id
    constants = TokenIds(
        card=data_cfg.text_vocab_size,
        new_word=data_cfg.text_new_word_token_id,
        pad=data_cfg.text_pad_token_id,
        # A missing, ``None`` or zero ``bos_token_id`` falls back to 1.
        bos=getattr(tokenizer, "bos_token_id", 1) or 1,
        zero=data_cfg.text_zero_token_id,
        spk1=_speaker_token_id(tokenizer, vocab, "[S1]", speaker_fallback),
        spk2=_speaker_token_id(tokenizer, vocab, "[S2]", speaker_fallback),
        audio_pad=data_cfg.audio_pad_token_id,
        audio_bos=data_cfg.audio_bos_token_id,
    )
    machine = StateMachine(
        token_ids=constants,
        second_stream_ahead=data_cfg.second_stream_ahead,
        max_padding=6,
        initial_padding=0,
    )

    audio_delays = list(data_cfg.delay_pattern)
    if audio_delays:
        audio_delay_tensor = torch.tensor(audio_delays, device=device_obj, dtype=torch.long)
    else:
        # Keep an explicit empty long tensor for an empty delay pattern.
        audio_delay_tensor = torch.empty(0, dtype=torch.long, device=device_obj)

    # Use the codec's own frame rate when it exposes one; otherwise 75.0.
    frame_rate = getattr(mimi, "frame_rate", 75.0)

    runtime = RuntimeContext(
        config=config,
        precision=precision,
        model=model,
        tokenizer=tokenizer,
        mimi=mimi,
        device=device_obj,
        machine=machine,
        constants=constants,
        audio_delays=audio_delays,
        audio_delay_tensor=audio_delay_tensor,
        frame_rate=frame_rate,
        transformer_step=model.transformer.forward_step,
        depformer_step=model.depformer.forward_step,
    )
    return runtime, tokenizer_ref, mimi_ref


__all__ = [
    "RuntimeContext",
    "build_runtime",
]