Marti Umbert committed on
Commit
1db6194
·
1 Parent(s): 72d3f66

whisperlivekit/core.py: print args, load model_cascaded_translation from input parameter

Browse files
Files changed (1) hide show
  1. whisperlivekit/core.py +16 -4
whisperlivekit/core.py CHANGED
@@ -7,6 +7,7 @@ from argparse import Namespace, ArgumentParser
7
  import ctranslate2
8
  import pyonmttok
9
  from huggingface_hub import snapshot_download
 
10
 
11
  def parse_args():
12
  parser = ArgumentParser(description="Whisper FastAPI Online Server")
@@ -64,6 +65,13 @@ def parse_args():
64
  help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
65
  )
66
 
 
 
 
 
 
 
 
67
  parser.add_argument(
68
  "--model_cache_dir",
69
  type=str,
@@ -141,7 +149,7 @@ def parse_args():
141
  args.vad = not args.no_vad
142
  delattr(args, 'no_transcription')
143
  delattr(args, 'no_vad')
144
-
145
  return args
146
 
147
  class WhisperLiveKit:
@@ -162,6 +170,8 @@ class WhisperLiveKit:
162
  merged_args = {**default_args, **kwargs}
163
 
164
  self.args = Namespace(**merged_args)
 
 
165
 
166
  self.asr = None
167
  self.tokenizer = None
@@ -172,9 +182,11 @@ class WhisperLiveKit:
172
  warmup_asr(self.asr, self.args.warmup_file)
173
 
174
  # translate from transcription
175
- model_dir = snapshot_download(repo_id="projecte-aina/aina-translator-ca-es", revision="main")
176
- self.translation_tokenizer = pyonmttok.Tokenizer(mode="none", sp_model_path=model_dir + "/spm.model")
177
- self.translator = ctranslate2.Translator(model_dir)
 
 
178
 
179
  if self.args.diarization:
180
  from whisperlivekit.diarization.diarization_online import DiartDiarization
 
7
  import ctranslate2
8
  import pyonmttok
9
  from huggingface_hub import snapshot_download
10
+ from pprint import pprint
11
 
12
  def parse_args():
13
  parser = ArgumentParser(description="Whisper FastAPI Online Server")
 
65
  help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
66
  )
67
 
68
+ parser.add_argument(
69
+ "--model_cascaded_translation",
70
+ type=str,
71
+ default=None,
72
+ help="Name of the model for cascaded translation from transcription output. Tested values: projecte-aina/aina-translator-ca-es. The model is automatically downloaded from the model hub if not present in model cache dir.",
73
+ )
74
+
75
  parser.add_argument(
76
  "--model_cache_dir",
77
  type=str,
 
149
  args.vad = not args.no_vad
150
  delattr(args, 'no_transcription')
151
  delattr(args, 'no_vad')
152
+
153
  return args
154
 
155
  class WhisperLiveKit:
 
170
  merged_args = {**default_args, **kwargs}
171
 
172
  self.args = Namespace(**merged_args)
173
+
174
+ pprint(vars(self.args))
175
 
176
  self.asr = None
177
  self.tokenizer = None
 
182
  warmup_asr(self.asr, self.args.warmup_file)
183
 
184
  # translate from transcription
185
+ if self.args.model_cascaded_translation:
186
+ print(f"Loading translation model: {self.args.model_cascaded_translation}")
187
+ model_dir = snapshot_download(repo_id=self.args.model_cascaded_translation, revision="main")
188
+ self.translation_tokenizer = pyonmttok.Tokenizer(mode="none", sp_model_path=model_dir + "/spm.model")
189
+ self.translator = ctranslate2.Translator(model_dir)
190
 
191
  if self.args.diarization:
192
  from whisperlivekit.diarization.diarization_online import DiartDiarization