from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, List, Optional, Sequence, TYPE_CHECKING

import numpy as np
import torch

from ..generation import PrefixConfig
from .audio_io import encode_audio_tokens, load_mono_audio
from .state_machine import Entry

if TYPE_CHECKING:  # pragma: no cover
    from .context import RuntimeContext


@dataclass
class WhisperWord:
    """A single transcribed word with start/end times in seconds."""

    text: str
    start: float
    end: float


@dataclass
class PrefixPlan:
    """Prepared prefix material: text entries, new-word step indices, and aligned audio tokens."""

    entries: List[Entry]
    new_word_steps: List[int]
    aligned_tokens: torch.Tensor
    aligned_frames: int


def build_prefix_plan(
    runtime: "RuntimeContext",
    prefix: Optional[PrefixConfig],
    *,
    transcribe_fn: Optional[Callable[[str, torch.device], List[WhisperWord]]] = None,
    load_audio_fn: Optional[Callable[[str, int], np.ndarray]] = None,
    encode_fn: Optional[Callable[[np.ndarray], torch.Tensor]] = None,
) -> Optional[PrefixPlan]:
    """Build a PrefixPlan from the configured speaker prefix audio, or return None.

    The ``*_fn`` hooks let callers inject transcription, audio loading, and
    audio-token encoding in place of the default Whisper/Mimi pipeline.
    """
    if prefix is None:
        return None
    if not prefix.speaker_1:
        if prefix.speaker_2:
            raise ValueError("speaker_2 requires speaker_1 to be provided")
        return None

    transcribe = transcribe_fn or (lambda path, device: transcribe_words(path, device))
    load_audio = load_audio_fn or (lambda path, sr: load_mono_audio(path, sr))
    encode_audio = encode_fn or (lambda audio: encode_audio_tokens(runtime.mimi, audio))

    entries1, steps1, tokens1 = _process_prefix_audio(
        runtime=runtime,
        audio_path=prefix.speaker_1,
        speaker_token=runtime.constants.spk1,
        transcribe=transcribe,
        load_audio=load_audio,
        encode_audio=encode_audio,
    )

    offset = 3  # Match legacy BOS/PAD offset
    entries = list(entries1)
    new_word_steps = [step + offset for step in steps1]
    audio_tokens = tokens1.to(runtime.device)

    if prefix.speaker_2:
        entries2, steps2, tokens2 = _process_prefix_audio(
            runtime=runtime,
            audio_path=prefix.speaker_2,
            speaker_token=runtime.constants.spk2,
            transcribe=transcribe,
            load_audio=load_audio,
            encode_audio=encode_audio,
        )
        # Speaker-2 word steps are shifted by the number of speaker-1 audio frames,
        # and the audio tokens are concatenated along the frame dimension.
        spk1_frames = audio_tokens.shape[-1]
        new_word_steps.extend(step + spk1_frames for step in steps2)
        entries.extend(entries2)
        audio_tokens = torch.cat([audio_tokens, tokens2.to(runtime.device)], dim=1)

    return PrefixPlan(
        entries=entries,
        new_word_steps=new_word_steps,
        aligned_tokens=audio_tokens,
        aligned_frames=audio_tokens.shape[-1],
    )


def _process_prefix_audio(
    runtime: "RuntimeContext",
    audio_path: str,
    speaker_token: int,
    *,
    transcribe: Callable[[str, torch.device], List[WhisperWord]],
    load_audio: Callable[[str, int], np.ndarray],
    encode_audio: Callable[[np.ndarray], torch.Tensor],
) -> tuple[List[Entry], List[int], torch.Tensor]:
    """Transcribe one prefix audio file and encode it to audio tokens."""
    words = transcribe(audio_path, runtime.device)
    entries, steps = words_to_entries(
        words=words,
        tokenizer=runtime.tokenizer,
        speaker_token=speaker_token,
        frame_rate=runtime.frame_rate,
    )
    audio = load_audio(audio_path, runtime.mimi.sample_rate)
    tokens = encode_audio(audio)
    return entries, steps, tokens


def transcribe_words(
    audio_path: str,
    device: torch.device,
    language: Optional[str] = None,
) -> List[WhisperWord]:
    """Run whisper-timestamped on ``audio_path`` and return word-level timestamps."""
    import whisper_timestamped as wts  # Imported lazily

    model = wts.load_model("openai/whisper-large-v3", device=str(device))
    result = wts.transcribe(model, audio_path, language=language)
    words: List[WhisperWord] = []
    for segment in result.get("segments", []):
        for word in segment.get("words", []):
            text = (word.get("text") or word.get("word") or "").strip()
            if not text:
                continue
            words.append(
                WhisperWord(
                    text=text,
                    start=float(word.get("start", 0.0)),
                    end=float(word.get("end", 0.0)),
                )
            )
    return words


def words_to_entries(
    *,
    words: Sequence[WhisperWord],
    tokenizer,
    speaker_token: int,
    frame_rate: float,
) -> tuple[List[Entry], List[int]]:
    """Convert timestamped words into state-machine entries and new-word step indices."""
    entries: List[Entry] = []
    new_word_steps: List[int] = []
    if not words:
        return entries, new_word_steps

    # Resolve which speaker tag ("[S1]"/"[S2]") corresponds to the given speaker token,
    # so it can be prepended to the first word.
    convert = getattr(tokenizer, "convert_tokens_to_ids", None)
    speaker_prefix: Optional[str] = None
    if callable(convert):
        s1_id = convert("[S1]")
        s2_id = convert("[S2]")
        if speaker_token == s1_id:
            speaker_prefix = "[S1]"
        elif speaker_token == s2_id:
            speaker_prefix = "[S2]"

    pending_prefix: Optional[str] = speaker_prefix
    current_pos = 0
    for idx, word in enumerate(words):
        tokens = _encode_word(word.text, tokenizer, pending_prefix)
        pending_prefix = None  # Only the first word carries the speaker tag
        # Map the word's start time to a frame index, never overlapping the previous word.
        start_frame = max(current_pos + 1, int(round(word.start * frame_rate)))
        end_frame = start_frame + len(tokens)
        new_word_steps.append(start_frame - 1)
        if idx < len(words) - 1:
            next_start = int(round(words[idx + 1].start * frame_rate))
            next_word_start = max(end_frame + 1, next_start)
        else:
            end_time = int(round(words[-1].end * frame_rate))
            next_word_start = max(end_frame + 1, end_time)
        padding = max(0, next_word_start - start_frame - 1)
        entries.append(Entry(tokens=tokens, text=word.text, padding=padding))
        current_pos = end_frame
    return entries, new_word_steps


def _encode_word(text: str, tokenizer, prefix: Optional[str]) -> List[int]:
    """Tokenize a word, optionally prepending the speaker tag."""
    if prefix:
        return tokenizer.encode(f"{prefix} {text}", add_special_tokens=False)
    return tokenizer.encode(text, add_special_tokens=False)


__all__ = [
    "PrefixPlan",
    "WhisperWord",
    "build_prefix_plan",
    "transcribe_words",
    "words_to_entries",
]
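

# A minimal sketch (not part of the module API) of how the injection hooks on
# build_prefix_plan can be used in a test: transcription, audio loading, and token
# encoding are replaced with stubs so neither Whisper nor Mimi is needed. The
# SimpleNamespace runtime, the stub tokenizer, and the PrefixConfig keyword arguments
# are assumptions for illustration; real construction depends on the surrounding package.
#
#     from types import SimpleNamespace
#
#     class _StubTokenizer:
#         def encode(self, text, add_special_tokens=False):
#             return [1] * len(text.split())
#
#     runtime = SimpleNamespace(
#         device=torch.device("cpu"),
#         tokenizer=_StubTokenizer(),
#         frame_rate=12.5,
#         mimi=SimpleNamespace(sample_rate=24_000),
#         constants=SimpleNamespace(spk1=10, spk2=11),
#     )
#     plan = build_prefix_plan(
#         runtime,
#         PrefixConfig(speaker_1="speaker1.wav"),
#         transcribe_fn=lambda path, device: [WhisperWord("hello", 0.0, 0.4)],
#         load_audio_fn=lambda path, sr: np.zeros(sr, dtype=np.float32),
#         encode_fn=lambda audio: torch.zeros(8, 13, dtype=torch.long),
#     )
#     assert plan is not None and plan.aligned_frames == 13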