Marti Umbert
commited on
Commit
·
1db6194
1
Parent(s):
72d3f66
whisperlivekit/core.py: print args, load model_cascaded_translation from input parameter
Browse files- whisperlivekit/core.py +16 -4
whisperlivekit/core.py
CHANGED
|
@@ -7,6 +7,7 @@ from argparse import Namespace, ArgumentParser
|
|
| 7 |
import ctranslate2
|
| 8 |
import pyonmttok
|
| 9 |
from huggingface_hub import snapshot_download
|
|
|
|
| 10 |
|
| 11 |
def parse_args():
|
| 12 |
parser = ArgumentParser(description="Whisper FastAPI Online Server")
|
|
@@ -64,6 +65,13 @@ def parse_args():
|
|
| 64 |
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
|
| 65 |
)
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
parser.add_argument(
|
| 68 |
"--model_cache_dir",
|
| 69 |
type=str,
|
|
@@ -141,7 +149,7 @@ def parse_args():
|
|
| 141 |
args.vad = not args.no_vad
|
| 142 |
delattr(args, 'no_transcription')
|
| 143 |
delattr(args, 'no_vad')
|
| 144 |
-
|
| 145 |
return args
|
| 146 |
|
| 147 |
class WhisperLiveKit:
|
|
@@ -162,6 +170,8 @@ class WhisperLiveKit:
|
|
| 162 |
merged_args = {**default_args, **kwargs}
|
| 163 |
|
| 164 |
self.args = Namespace(**merged_args)
|
|
|
|
|
|
|
| 165 |
|
| 166 |
self.asr = None
|
| 167 |
self.tokenizer = None
|
|
@@ -172,9 +182,11 @@ class WhisperLiveKit:
|
|
| 172 |
warmup_asr(self.asr, self.args.warmup_file)
|
| 173 |
|
| 174 |
# translate from transcription
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
|
|
|
|
|
|
| 178 |
|
| 179 |
if self.args.diarization:
|
| 180 |
from whisperlivekit.diarization.diarization_online import DiartDiarization
|
|
|
|
| 7 |
import ctranslate2
|
| 8 |
import pyonmttok
|
| 9 |
from huggingface_hub import snapshot_download
|
| 10 |
+
from pprint import pprint
|
| 11 |
|
| 12 |
def parse_args():
|
| 13 |
parser = ArgumentParser(description="Whisper FastAPI Online Server")
|
|
|
|
| 65 |
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
|
| 66 |
)
|
| 67 |
|
| 68 |
+
parser.add_argument(
|
| 69 |
+
"--model_cascaded_translation",
|
| 70 |
+
type=str,
|
| 71 |
+
default=None,
|
| 72 |
+
help="Name of the model for cascaded translation from transcription output. Tested values: projecte-aina/aina-translator-ca-es. The model is automatically downloaded from the model hub if not present in model cache dir.",
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
parser.add_argument(
|
| 76 |
"--model_cache_dir",
|
| 77 |
type=str,
|
|
|
|
| 149 |
args.vad = not args.no_vad
|
| 150 |
delattr(args, 'no_transcription')
|
| 151 |
delattr(args, 'no_vad')
|
| 152 |
+
|
| 153 |
return args
|
| 154 |
|
| 155 |
class WhisperLiveKit:
|
|
|
|
| 170 |
merged_args = {**default_args, **kwargs}
|
| 171 |
|
| 172 |
self.args = Namespace(**merged_args)
|
| 173 |
+
|
| 174 |
+
pprint(vars(self.args))
|
| 175 |
|
| 176 |
self.asr = None
|
| 177 |
self.tokenizer = None
|
|
|
|
| 182 |
warmup_asr(self.asr, self.args.warmup_file)
|
| 183 |
|
| 184 |
# translate from transcription
|
| 185 |
+
if self.args.model_cascaded_translation:
|
| 186 |
+
print(f"Loading translation model: {self.args.model_cascaded_translation}")
|
| 187 |
+
model_dir = snapshot_download(repo_id=self.args.model_cascaded_translation, revision="main")
|
| 188 |
+
self.translation_tokenizer = pyonmttok.Tokenizer(mode="none", sp_model_path=model_dir + "/spm.model")
|
| 189 |
+
self.translator = ctranslate2.Translator(model_dir)
|
| 190 |
|
| 191 |
if self.args.diarization:
|
| 192 |
from whisperlivekit.diarization.diarization_online import DiartDiarization
|