from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, List, Optional, Sequence, TYPE_CHECKING

import numpy as np
import torch

from ..generation import PrefixConfig
from .audio_io import encode_audio_tokens, load_mono_audio
from .state_machine import Entry

if TYPE_CHECKING:  # pragma: no cover
    from .context import RuntimeContext


@dataclass
class WhisperWord:
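    """A single transcribed word with start/end timestamps in seconds."""
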
    text: str
    start: float
    end: float


@dataclass
class PrefixPlan:
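    """Aligned prefix transcript entries plus encoded audio tokens used to seed generation."""
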
    entries: List[Entry]
    new_word_steps: List[int]
    aligned_tokens: torch.Tensor
    aligned_frames: int


def build_prefix_plan(
    runtime: "RuntimeContext",
    prefix: Optional[PrefixConfig],
    *,
    transcribe_fn: Optional[Callable[[str, torch.device], List[WhisperWord]]] = None,
    load_audio_fn: Optional[Callable[[str, int], np.ndarray]] = None,
    encode_fn: Optional[Callable[[np.ndarray], torch.Tensor]] = None,
) -> Optional[PrefixPlan]:
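    """Build a PrefixPlan from the configured prefix audio, or return None.

    The optional ``*_fn`` arguments override the default transcription,
    audio-loading, and encoding steps (e.g. to inject fakes in tests).
    Raises ValueError if speaker_2 is given without speaker_1.
    """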
    if prefix is None:
        return None
    if not prefix.speaker_1:
        if prefix.speaker_2:
            raise ValueError("speaker_2 requires speaker_1 to be provided")
        return None

    # Default to the real implementations unless the caller injected overrides.
    transcribe = transcribe_fn or (lambda path, device: transcribe_words(path, device))
    load_audio = load_audio_fn or (lambda path, sr: load_mono_audio(path, sr))
    encode_audio = encode_fn or (lambda audio: encode_audio_tokens(runtime.mimi, audio))

    entries1, steps1, tokens1 = _process_prefix_audio(
        runtime=runtime,
        audio_path=prefix.speaker_1,
        speaker_token=runtime.constants.spk1,
        transcribe=transcribe,
        load_audio=load_audio,
        encode_audio=encode_audio,
    )

    offset = 3  # Match legacy BOS/PAD offset
    entries = list(entries1)
    new_word_steps = [step + offset for step in steps1]
    audio_tokens = tokens1.to(runtime.device)
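
    # Speaker 2's clip is appended after speaker 1's, so its word steps are
    # shifted by the number of speaker-1 audio frames.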
    if prefix.speaker_2:
        entries2, steps2, tokens2 = _process_prefix_audio(
            runtime=runtime,
            audio_path=prefix.speaker_2,
            speaker_token=runtime.constants.spk2,
            transcribe=transcribe,
            load_audio=load_audio,
            encode_audio=encode_audio,
        )
        spk1_frames = audio_tokens.shape[-1]
        new_word_steps.extend(step + spk1_frames for step in steps2)
        entries.extend(entries2)
        audio_tokens = torch.cat([audio_tokens, tokens2.to(runtime.device)], dim=1)

    return PrefixPlan(
        entries=entries,
        new_word_steps=new_word_steps,
        aligned_tokens=audio_tokens,
        aligned_frames=audio_tokens.shape[-1],
    )


def _process_prefix_audio(
    runtime: "RuntimeContext",
    audio_path: str,
    speaker_token: int,
    *,
    transcribe: Callable[[str, torch.device], List[WhisperWord]],
    load_audio: Callable[[str, int], np.ndarray],
    encode_audio: Callable[[np.ndarray], torch.Tensor],
) -> tuple[List[Entry], List[int], torch.Tensor]:
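    """Transcribe a single prefix clip and encode its audio to Mimi tokens."""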
    words = transcribe(audio_path, runtime.device)
    entries, steps = words_to_entries(
        words=words,
        tokenizer=runtime.tokenizer,
        speaker_token=speaker_token,
        frame_rate=runtime.frame_rate,
    )
    audio = load_audio(audio_path, runtime.mimi.sample_rate)
    tokens = encode_audio(audio)
    return entries, steps, tokens


def transcribe_words(
    audio_path: str,
    device: torch.device,
    language: Optional[str] = None,
) -> List[WhisperWord]:
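    """Run whisper_timestamped on an audio file and return word-level timestamps."""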
    import whisper_timestamped as wts  # Imported lazily

    model = wts.load_model("openai/whisper-large-v3", device=str(device))
    result = wts.transcribe(model, audio_path, language=language)
    words: List[WhisperWord] = []
    for segment in result.get("segments", []):
        for word in segment.get("words", []):
            # Whisper variants name the word key differently; accept both.
            text = (word.get("text") or word.get("word") or "").strip()
            if not text:
                continue
            words.append(
                WhisperWord(
                    text=text,
                    start=float(word.get("start", 0.0)),
                    end=float(word.get("end", 0.0)),
                )
            )
    return words


def words_to_entries(
    *,
    words: Sequence[WhisperWord],
    tokenizer,
    speaker_token: int,
    frame_rate: float,
) -> tuple[List[Entry], List[int]]:
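    """Convert timestamped words into state-machine entries plus the frame
    steps at which each new word should be introduced."""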
    entries: List[Entry] = []
    new_word_steps: List[int] = []
    if not words:
        return entries, new_word_steps

    # Map the numeric speaker token back to its "[S1]"/"[S2]" text tag so the
    # tag can be prepended to the first word before tokenization.
    convert = getattr(tokenizer, "convert_tokens_to_ids", None)
    speaker_prefix: Optional[str] = None
    if callable(convert):
        s1_id = convert("[S1]")
        s2_id = convert("[S2]")
        if speaker_token == s1_id:
            speaker_prefix = "[S1]"
        elif speaker_token == s2_id:
            speaker_prefix = "[S2]"

    pending_prefix: Optional[str] = speaker_prefix
    current_pos = 0
    for idx, word in enumerate(words):
        tokens = _encode_word(word.text, tokenizer, pending_prefix)
        pending_prefix = None
        # Words never overlap: each starts at least one frame after the
        # previous word's tokens end, even if Whisper's timestamps disagree.
        start_frame = max(current_pos + 1, int(round(word.start * frame_rate)))
        end_frame = start_frame + len(tokens)
        new_word_steps.append(start_frame - 1)
        if idx < len(words) - 1:
            next_start = int(round(words[idx + 1].start * frame_rate))
            next_word_start = max(end_frame + 1, next_start)
        else:
            end_time = int(round(words[-1].end * frame_rate))
            next_word_start = max(end_frame + 1, end_time)
        # Pad with silence frames until the next word is due to start.
        padding = max(0, next_word_start - start_frame - 1)
        entries.append(Entry(tokens=tokens, text=word.text, padding=padding))
        current_pos = end_frame
    return entries, new_word_steps


def _encode_word(text: str, tokenizer, prefix: Optional[str]) -> List[int]:
    if prefix:
        return tokenizer.encode(f"{prefix} {text}", add_special_tokens=False)
    return tokenizer.encode(text, add_special_tokens=False)


__all__ = [
    "PrefixPlan",
    "WhisperWord",
    "build_prefix_plan",
    "transcribe_words",
    "words_to_entries",
]
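
# A minimal usage sketch (hypothetical: assumes an initialized RuntimeContext
# and that PrefixConfig takes speaker_1/speaker_2 audio-path fields):
#
#   plan = build_prefix_plan(
#       runtime,
#       PrefixConfig(speaker_1="alice.wav", speaker_2="bob.wav"),
#   )
#   if plan is not None:
#       print(plan.aligned_frames, [entry.text for entry in plan.entries])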