from __future__ import annotations

import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Mapping, Optional, Sequence, Tuple

import torch


@dataclass
class SamplingConfig:
    temperature: float = 0.8
    top_k: int = 50


def _default_text_sampling() -> SamplingConfig:
    return SamplingConfig(temperature=0.6, top_k=50)


def _default_audio_sampling() -> SamplingConfig:
    return SamplingConfig(temperature=0.8, top_k=50)


@dataclass
class PrefixConfig:
    speaker_1: Optional[str] = None
    speaker_2: Optional[str] = None
    include_audio: bool = False


@dataclass
class GenerationConfig:
    """Sampling and guidance settings for a single generation call."""

    text: SamplingConfig = field(default_factory=_default_text_sampling)
    audio: SamplingConfig = field(default_factory=_default_audio_sampling)
    cfg_scale: float = 2.0
    cfg_filter_k: int = 50
    initial_padding: int = 2
    prefix: Optional["PrefixConfig"] = None
    use_cuda_graph: bool = False


@dataclass
class GenerationResult:
    """Audio tokens, decoded waveform, sample rate, and timestamps produced by generation."""

    audio_tokens: torch.Tensor
    waveform: torch.Tensor
    sample_rate: int
    timestamps: List[Tuple[str, float]]


def normalize_script(script: str | Sequence[str]) -> str:
    """Return the script as stripped text, joining a sequence of lines with newlines."""
    if isinstance(script, str):
        return script.strip()
    return "\n".join(line.strip() for line in script)


def load_script_text(path: str | Path) -> str:
    """Read script text from stdin ("-"), an existing file, or treat the value as literal text."""
    if path == "-":
        return sys.stdin.read().strip()
    path_obj = Path(path)
    if path_obj.exists():
        return path_obj.read_text().strip()
    return str(path).strip()
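
# Illustrative note (not part of the original module): load_script_text accepts
# three kinds of input. "-" reads the script from stdin, an existing file path is
# read from disk, and any other value is used verbatim as the script text. The
# paths and text below are hypothetical:
#
#     load_script_text("-")            # script piped via stdin
#     load_script_text("script.txt")   # read from disk if the file exists
#     load_script_text("[S1] Hello.")  # falls through to literal text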


def validate_generation_params(
    *,
    temperature: float,
    top_k: int,
    cfg_scale: float,
) -> tuple[float, int, float]:
    """Ensure temperature, top_k, and cfg_scale are all positive, returning them unchanged."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if top_k <= 0:
        raise ValueError("top_k must be positive")
    if cfg_scale <= 0:
        raise ValueError("cfg_scale must be positive")
    return temperature, top_k, cfg_scale


def build_generation_config(
    *,
    temperature: float,
    top_k: int,
    cfg_scale: float,
) -> GenerationConfig:
    """Create a GenerationConfig that applies the same sampling settings to text and audio."""
    sampling = SamplingConfig(temperature=temperature, top_k=top_k)
    return GenerationConfig(
        text=sampling,
        audio=sampling,
        cfg_scale=cfg_scale,
    )


def merge_generation_config(
    *,
    base: GenerationConfig,
    overrides: Mapping[str, object],
) -> GenerationConfig:
    """Apply non-None overrides from a flat mapping on top of base, returning a new config."""
    clean_overrides = {k: v for k, v in overrides.items() if v is not None}

    text_temp = clean_overrides.pop("temp_text", None)
    text_topk = clean_overrides.pop("topk_text", None)
    audio_temp = clean_overrides.pop("temp_audio", None)
    audio_topk = clean_overrides.pop("topk_audio", None)
    prefix_speaker_1 = clean_overrides.pop("prefix_speaker_1", None)
    prefix_speaker_2 = clean_overrides.pop("prefix_speaker_2", None)
    include_prefix = clean_overrides.pop("include_prefix", None)

    text_sampling = base.text
    if text_temp is not None or text_topk is not None:
        text_sampling = SamplingConfig(
            temperature=text_temp if text_temp is not None else text_sampling.temperature,
            top_k=text_topk if text_topk is not None else text_sampling.top_k,
        )

    audio_sampling = base.audio
    if audio_temp is not None or audio_topk is not None:
        audio_sampling = SamplingConfig(
            temperature=audio_temp if audio_temp is not None else audio_sampling.temperature,
            top_k=audio_topk if audio_topk is not None else audio_sampling.top_k,
        )

    prefix_cfg = base.prefix
    if (
        prefix_speaker_1 is not None
        or prefix_speaker_2 is not None
        or include_prefix is not None
        or prefix_cfg is not None
    ):
        prefix_cfg = prefix_cfg or PrefixConfig()
        prefix_cfg = PrefixConfig(
            speaker_1=prefix_speaker_1 if prefix_speaker_1 is not None else prefix_cfg.speaker_1,
            speaker_2=prefix_speaker_2 if prefix_speaker_2 is not None else prefix_cfg.speaker_2,
            include_audio=include_prefix if include_prefix is not None else prefix_cfg.include_audio,
        )

    return GenerationConfig(
        text=text_sampling,
        audio=audio_sampling,
        cfg_scale=clean_overrides.pop("cfg_scale", base.cfg_scale),
        cfg_filter_k=clean_overrides.pop("cfg_filter_k", base.cfg_filter_k),
        initial_padding=clean_overrides.pop("initial_padding", base.initial_padding),
        prefix=prefix_cfg,
        use_cuda_graph=clean_overrides.pop("use_cuda_graph", base.use_cuda_graph),
    )


__all__ = [
    "SamplingConfig",
    "GenerationConfig",
    "GenerationResult",
    "PrefixConfig",
    "normalize_script",
    "load_script_text",
    "validate_generation_params",
    "build_generation_config",
    "merge_generation_config",
]
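

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal example of how these helpers compose: validate flat CLI-style values,
# build a baseline GenerationConfig, then apply sparse overrides through
# merge_generation_config. The override keys match the ones popped above; the
# speaker-tagged script format and the "alice.wav" prefix path are assumptions
# chosen purely for demonstration.
if __name__ == "__main__":
    script = normalize_script([
        "[S1] Hello there.",
        "[S2] Hi! Ready to record?",
    ])

    temperature, top_k, cfg_scale = validate_generation_params(
        temperature=0.7, top_k=40, cfg_scale=2.5
    )
    base = build_generation_config(
        temperature=temperature, top_k=top_k, cfg_scale=cfg_scale
    )

    merged = merge_generation_config(
        base=base,
        overrides={
            "temp_audio": 0.9,                # re-sample only the audio head hotter
            "topk_text": None,                # None values are dropped before merging
            "prefix_speaker_1": "alice.wav",  # hypothetical prefix value
            "include_prefix": True,
        },
    )

    print(script)
    print(merged.text)    # unchanged: SamplingConfig(temperature=0.7, top_k=40)
    print(merged.audio)   # overridden: SamplingConfig(temperature=0.9, top_k=40)
    print(merged.prefix)  # PrefixConfig(speaker_1='alice.wav', speaker_2=None, include_audio=True)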