# Dia2-2B / dia2/generation.py
# Uploaded by NariLabs using huggingface_hub (commit 1315cad, verified).
from __future__ import annotations
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Mapping, Optional, Sequence, Tuple
import torch
@dataclass(frozen=True)
class SamplingConfig:
    """Immutable pair of sampling hyperparameters.

    Attributes:
        temperature: Sampling temperature; defaults to 0.8.
        top_k: Top-k cutoff; defaults to 50.
    """

    temperature: float = 0.8
    top_k: int = 50
def _default_text_sampling() -> SamplingConfig:
    """Build the default sampling settings for the text stream (0.6 / 50)."""
    text_defaults = SamplingConfig(temperature=0.6, top_k=50)
    return text_defaults
def _default_audio_sampling() -> SamplingConfig:
    """Build the default sampling settings for the audio stream (0.8 / 50)."""
    audio_defaults = SamplingConfig(temperature=0.8, top_k=50)
    return audio_defaults
@dataclass(frozen=True)
class PrefixConfig:
    """Immutable description of an optional per-speaker prefix.

    Attributes:
        speaker_1: Optional string for the first speaker; ``None`` when unset.
            (Presumably a path to a reference audio clip -- confirm with callers.)
        speaker_2: Optional string for the second speaker; ``None`` when unset.
        include_audio: Flag defaulting to ``False``. (Presumably controls
            whether the prefix audio is kept in the output -- confirm.)
    """

    speaker_1: Optional[str] = None
    speaker_2: Optional[str] = None
    include_audio: bool = False
@dataclass(frozen=True)
class GenerationConfig:
    """Immutable bundle of every knob used for one generation run.

    Attributes:
        text: Sampling settings for the text stream; built by
            ``_default_text_sampling`` when omitted.
        audio: Sampling settings for the audio stream; built by
            ``_default_audio_sampling`` when omitted.
        cfg_scale: Defaults to 2.0. (Presumably the classifier-free
            guidance scale -- confirm against the sampler.)
        cfg_filter_k: Defaults to 50.
        initial_padding: Defaults to 2.
        prefix: Optional ``PrefixConfig``; ``None`` means no prefix.
        use_cuda_graph: Defaults to ``False``.
    """

    # default_factory gives each config its own freshly built SamplingConfig.
    text: SamplingConfig = field(default_factory=_default_text_sampling)
    audio: SamplingConfig = field(default_factory=_default_audio_sampling)
    cfg_scale: float = 2.0
    cfg_filter_k: int = 50
    initial_padding: int = 2
    prefix: Optional["PrefixConfig"] = None
    use_cuda_graph: bool = False
@dataclass(frozen=True)
class GenerationResult:
    """Immutable container for the outputs of one generation run.

    Attributes:
        audio_tokens: Tensor of generated audio tokens.
        waveform: Tensor holding the waveform.
        sample_rate: Integer sample rate that accompanies ``waveform``.
        timestamps: List of ``(str, float)`` pairs. (Presumably a text unit
            and its time offset -- confirm against the producer.)
    """

    audio_tokens: torch.Tensor
    waveform: torch.Tensor
    sample_rate: int
    timestamps: List[Tuple[str, float]]
def normalize_script(script: str | Sequence[str]) -> str:
    """Return *script* as a single whitespace-trimmed string.

    A plain string is stripped as a whole. Any other sequence is treated as
    a list of lines: each entry is stripped individually and the results are
    joined with newline characters.
    """
    if not isinstance(script, str):
        trimmed_lines = (entry.strip() for entry in script)
        return "\n".join(trimmed_lines)
    return script.strip()
def load_script_text(path: str | Path) -> str:
    """Resolve a script argument to its stripped text.

    The argument is interpreted in this order:

    1. ``"-"``: read the whole of standard input.
    2. An existing, non-directory filesystem path: read that file.
    3. Anything else: treat the argument itself as the script text.

    Args:
        path: ``"-"``, a path to a script file, or literal script text.

    Returns:
        The script text with leading and trailing whitespace removed.
    """
    if path == "-":
        return sys.stdin.read().strip()

    candidate = Path(path)
    try:
        # Directories are excluded: Path("") resolves to "." and would
        # otherwise be "read" as a file and raise IsADirectoryError.
        readable = candidate.exists() and not candidate.is_dir()
    except (OSError, ValueError):
        # Literal script text is often not a valid path (for example it can
        # exceed the maximum file-name length); the probe may then raise
        # instead of returning False, so fall back to the literal text.
        readable = False

    if readable:
        return candidate.read_text().strip()
    return str(path).strip()
def validate_generation_params(
    *,
    temperature: float,
    top_k: int,
    cfg_scale: float,
) -> tuple[float, int, float]:
    """Check that the three generation parameters are strictly positive.

    Parameters are checked in the order ``temperature``, ``top_k``,
    ``cfg_scale``; the first offending one is reported.

    Args:
        temperature: Sampling temperature.
        top_k: Top-k cutoff.
        cfg_scale: Guidance scale.

    Returns:
        The unchanged ``(temperature, top_k, cfg_scale)`` tuple.

    Raises:
        ValueError: If any value is not greater than zero, including NaN.
    """
    checks = (
        ("temperature", temperature),
        ("top_k", top_k),
        ("cfg_scale", cfg_scale),
    )
    for label, value in checks:
        # ``not value > 0`` rather than ``value <= 0``: every comparison with
        # NaN is False, so the latter would let NaN through unnoticed.
        if not value > 0:
            raise ValueError(f"{label} must be positive")
    return temperature, top_k, cfg_scale
def build_generation_config(
    *,
    temperature: float,
    top_k: int,
    cfg_scale: float,
) -> GenerationConfig:
    """Create a ``GenerationConfig`` from three scalar settings.

    One ``SamplingConfig`` is built from ``temperature`` and ``top_k`` and
    used for both the text and the audio stream. All remaining fields keep
    the ``GenerationConfig`` defaults.
    """
    shared_sampling = SamplingConfig(temperature=temperature, top_k=top_k)
    config = GenerationConfig(
        text=shared_sampling,
        audio=shared_sampling,
        cfg_scale=cfg_scale,
    )
    return config
def merge_generation_config(
    *,
    base: GenerationConfig,
    overrides: Mapping[str, object],
) -> GenerationConfig:
    """Return a new ``GenerationConfig`` with *overrides* layered on *base*.

    Entries whose value is ``None`` are treated as "not supplied" and leave
    the corresponding base value untouched. Recognised keys:

    * ``temp_text`` / ``topk_text``: text sampling settings
    * ``temp_audio`` / ``topk_audio``: audio sampling settings
    * ``prefix_speaker_1`` / ``prefix_speaker_2`` / ``include_prefix``:
      fields of the prefix configuration
    * ``cfg_scale``, ``cfg_filter_k``, ``initial_padding``,
      ``use_cuda_graph``: top-level fields

    Any other key is ignored. Neither *base* nor *overrides* is modified.
    """
    # Only values that were actually supplied take part in the merge.
    supplied = {key: value for key, value in overrides.items() if value is not None}

    def choose(key: str, fallback: object) -> object:
        """Return the supplied value for *key*, or *fallback* when absent."""
        return supplied.get(key, fallback)

    def merged_sampling(
        current: SamplingConfig, temp_key: str, topk_key: str
    ) -> SamplingConfig:
        """Rebuild *current* only when at least one of its keys is supplied."""
        if temp_key not in supplied and topk_key not in supplied:
            return current
        return SamplingConfig(
            temperature=choose(temp_key, current.temperature),
            top_k=choose(topk_key, current.top_k),
        )

    prefix_keys = ("prefix_speaker_1", "prefix_speaker_2", "include_prefix")
    merged_prefix = base.prefix
    # A prefix config is produced when the base already has one or when any
    # prefix override is supplied; otherwise it stays None.
    if merged_prefix is not None or any(key in supplied for key in prefix_keys):
        seed = merged_prefix or PrefixConfig()
        merged_prefix = PrefixConfig(
            speaker_1=choose("prefix_speaker_1", seed.speaker_1),
            speaker_2=choose("prefix_speaker_2", seed.speaker_2),
            include_audio=choose("include_prefix", seed.include_audio),
        )

    return GenerationConfig(
        text=merged_sampling(base.text, "temp_text", "topk_text"),
        audio=merged_sampling(base.audio, "temp_audio", "topk_audio"),
        cfg_scale=choose("cfg_scale", base.cfg_scale),
        cfg_filter_k=choose("cfg_filter_k", base.cfg_filter_k),
        initial_padding=choose("initial_padding", base.initial_padding),
        prefix=merged_prefix,
        use_cuda_graph=choose("use_cuda_graph", base.use_cuda_graph),
    )
# Names exported by a star-import of this module; the private
# _default_*_sampling helpers are intentionally not listed.
__all__ = [
"SamplingConfig",
"GenerationConfig",
"GenerationResult",
"PrefixConfig",
"normalize_script",
"load_script_text",
"validate_generation_params",
"build_generation_config",
"merge_generation_config",
]