from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, List, Optional, Sequence, TYPE_CHECKING

import numpy as np
import torch

from ..generation import PrefixConfig
from .audio_io import encode_audio_tokens, load_mono_audio
from .state_machine import Entry

if TYPE_CHECKING:  # pragma: no cover
    from .context import RuntimeContext


@dataclass
class WhisperWord:
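    """A transcribed word with start/end timestamps in seconds."""
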
    text: str
    start: float
    end: float


@dataclass
class PrefixPlan:
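    """Aligned prefix data used to seed generation: state-machine entries,
    the steps at which new words start, and the encoded audio tokens."""
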
    entries: List[Entry]
    new_word_steps: List[int]
    aligned_tokens: torch.Tensor
    aligned_frames: int


def build_prefix_plan(
    runtime: "RuntimeContext",
    prefix: Optional[PrefixConfig],
    *,
    transcribe_fn: Optional[Callable[[str, torch.device], List[WhisperWord]]] = None,
    load_audio_fn: Optional[Callable[[str, int], np.ndarray]] = None,
    encode_fn: Optional[Callable[[np.ndarray], torch.Tensor]] = None,
) -> Optional[PrefixPlan]:
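    """Build a ``PrefixPlan`` from the configured prefix audio.

    Returns ``None`` when no prefix audio is configured. The ``*_fn``
    arguments are test hooks: when omitted, the real transcription,
    audio-loading, and token-encoding implementations are used.
    """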
    if prefix is None:
        return None
    if not prefix.speaker_1:
        if prefix.speaker_2:
            raise ValueError("speaker_2 requires speaker_1 to be provided")
        return None
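    # Fall back to the real implementations unless test hooks were injected.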
    transcribe = transcribe_fn or transcribe_words
    load_audio = load_audio_fn or load_mono_audio
    encode_audio = encode_fn or (lambda audio: encode_audio_tokens(runtime.mimi, audio))
    entries1, steps1, tokens1 = _process_prefix_audio(
        runtime=runtime,
        audio_path=prefix.speaker_1,
        speaker_token=runtime.constants.spk1,
        transcribe=transcribe,
        load_audio=load_audio,
        encode_audio=encode_audio,
    )
    offset = 3  # Match legacy BOS/PAD offset
    entries = list(entries1)
    new_word_steps = [step + offset for step in steps1]
    audio_tokens = tokens1.to(runtime.device)
    if prefix.speaker_2:
        entries2, steps2, tokens2 = _process_prefix_audio(
            runtime=runtime,
            audio_path=prefix.speaker_2,
            speaker_token=runtime.constants.spk2,
            transcribe=transcribe,
            load_audio=load_audio,
            encode_audio=encode_audio,
        )
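        # Speaker 2 follows speaker 1 on the timeline: shift its word steps
        # by speaker 1's frame count, then concatenate the audio tokens.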
        spk1_frames = audio_tokens.shape[-1]
        new_word_steps.extend(step + spk1_frames for step in steps2)
        entries.extend(entries2)
        audio_tokens = torch.cat([audio_tokens, tokens2.to(runtime.device)], dim=1)
    return PrefixPlan(
        entries=entries,
        new_word_steps=new_word_steps,
        aligned_tokens=audio_tokens,
        aligned_frames=audio_tokens.shape[-1],
    )


def _process_prefix_audio(
    runtime: "RuntimeContext",
    audio_path: str,
    speaker_token: int,
    *,
    transcribe: Callable[[str, torch.device], List[WhisperWord]],
    load_audio: Callable[[str, int], np.ndarray],
    encode_audio: Callable[[np.ndarray], torch.Tensor],
) -> tuple[List[Entry], List[int], torch.Tensor]:
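    """Transcribe one speaker's audio, convert the words to state-machine
    entries, and encode the waveform to audio tokens."""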
    words = transcribe(audio_path, runtime.device)
    entries, steps = words_to_entries(
        words=words,
        tokenizer=runtime.tokenizer,
        speaker_token=speaker_token,
        frame_rate=runtime.frame_rate,
    )
    audio = load_audio(audio_path, runtime.mimi.sample_rate)
    tokens = encode_audio(audio)
    return entries, steps, tokens


def transcribe_words(
    audio_path: str,
    device: torch.device,
    language: Optional[str] = None,
) -> List[WhisperWord]:
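    """Transcribe ``audio_path`` with whisper_timestamped and flatten the
    segments into a list of ``WhisperWord``s, skipping empty entries."""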
    import whisper_timestamped as wts  # Imported lazily

    model = wts.load_model("openai/whisper-large-v3", device=str(device))
    result = wts.transcribe(model, audio_path, language=language)
    words: List[WhisperWord] = []
    for segment in result.get("segments", []):
        for word in segment.get("words", []):
            text = (word.get("text") or word.get("word") or "").strip()
            if not text:
                continue
            words.append(
                WhisperWord(
                    text=text,
                    start=float(word.get("start", 0.0)),
                    end=float(word.get("end", 0.0)),
                )
            )
    return words


def words_to_entries(
    *,
    words: Sequence[WhisperWord],
    tokenizer,
    speaker_token: int,
    frame_rate: float,
) -> tuple[List[Entry], List[int]]:
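    """Convert timestamped words into state-machine ``Entry`` objects plus
    the frame steps at which each word begins. The first word carries the
    speaker tag when the tokenizer exposes ``[S1]``/``[S2]`` ids."""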
    entries: List[Entry] = []
    new_word_steps: List[int] = []
    if not words:
        return entries, new_word_steps
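    # Map the numeric speaker token back to its tag so the first word can
    # carry the speaker prefix.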
    convert = getattr(tokenizer, "convert_tokens_to_ids", None)
    speaker_prefix: Optional[str] = None
    if callable(convert):
        s1_id = convert("[S1]")
        s2_id = convert("[S2]")
        if speaker_token == s1_id:
            speaker_prefix = "[S1]"
        elif speaker_token == s2_id:
            speaker_prefix = "[S2]"
    pending_prefix: Optional[str] = speaker_prefix
    current_pos = 0
    for idx, word in enumerate(words):
        tokens = _encode_word(word.text, tokenizer, pending_prefix)
        pending_prefix = None
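        # Snap the word to its Whisper start frame, but never earlier than
        # one frame past the previous word's tokens.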
        start_frame = max(current_pos + 1, int(round(word.start * frame_rate)))
        end_frame = start_frame + len(tokens)
        new_word_steps.append(start_frame - 1)
        if idx < len(words) - 1:
            next_start = int(round(words[idx + 1].start * frame_rate))
            next_word_start = max(end_frame + 1, next_start)
        else:
            end_time = int(round(words[-1].end * frame_rate))
            next_word_start = max(end_frame + 1, end_time)
        padding = max(0, next_word_start - start_frame - 1)
        entries.append(Entry(tokens=tokens, text=word.text, padding=padding))
        current_pos = end_frame
    return entries, new_word_steps


def _encode_word(text: str, tokenizer, prefix: Optional[str]) -> List[int]:
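    """Tokenize a word, optionally prepending the speaker tag."""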
    if prefix:
        return tokenizer.encode(f"{prefix} {text}", add_special_tokens=False)
    return tokenizer.encode(text, add_special_tokens=False)


__all__ = [
    "PrefixPlan",
    "WhisperWord",
    "build_prefix_plan",
    "transcribe_words",
    "words_to_entries",
]
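
# Example usage (a minimal sketch, not part of this module: it assumes an
# initialized RuntimeContext named `runtime` and that PrefixConfig takes its
# speaker paths as keyword arguments, as the checks above suggest):
#
#     plan = build_prefix_plan(runtime, PrefixConfig(speaker_1="intro.wav"))
#     if plan is not None:
#         print(plan.aligned_frames, len(plan.entries))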