from dataclasses import dataclass
import time

import sentencepiece
import sounddevice as sd
import torch
import torchaudio
from scipy.io import wavfile as wav

from moshi.models import loaders, MimiModel, LMModel, LMGen

# --------------------------
# Audio settings
FILENAME = "my_audio.wav"
DURATION = 5         # recording duration in seconds
TARGET_SR = 24000    # sample rate expected by Moshi/Mimi

# --------------------------
# Helpers to record and load audio
def record_audio(filename=FILENAME, duration=DURATION, samplerate=TARGET_SR):
    print(f"\n🎙️ Recording {duration}s of audio...")
    recording = sd.rec(int(duration * samplerate), samplerate=samplerate,
                       channels=1, dtype='int16')
    sd.wait()
    wav.write(filename, samplerate, recording)
    print("✅ Audio recorded.")

def load_audio(filename, target_sr=TARGET_SR):
    # torchaudio.load already returns float32 samples in [-1, 1],
    # so no manual int16 rescaling is needed.
    waveform, sr = torchaudio.load(filename)
    if sr != target_sr:
        waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(waveform)
    return waveform.squeeze().float(), target_sr

# --------------------------
# Streaming inference state for the Moshi STT model
@dataclass
class InferenceState:
    mimi: MimiModel
    text_tokenizer: sentencepiece.SentencePieceProcessor
    lm_gen: LMGen

    def __init__(self, mimi: MimiModel, text_tokenizer: sentencepiece.SentencePieceProcessor,
                 lm: LMModel, batch_size: int, device: str | torch.device):
        self.mimi = mimi
        self.text_tokenizer = text_tokenizer
        # Greedy decoding: sampling disabled, zero temperature.
        self.lm_gen = LMGen(lm, temp=0, temp_text=0, use_sampling=False)
        self.device = device
        self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
        self.batch_size = batch_size
        self.mimi.streaming_forever(batch_size)
        self.lm_gen.streaming_forever(batch_size)

    def run(self, in_pcms: torch.Tensor):
        ntokens = 0
        first_frame = True
        # Split the waveform into full frames; drop any trailing partial frame.
        chunks = [
            c for c in in_pcms.split(self.frame_size, dim=2)
            if c.shape[-1] == self.frame_size
        ]
        start_time = time.time()
        all_text = []
        for chunk in chunks:
            codes = self.mimi.encode(chunk)
            if first_frame:
                # Feed the first frame an extra time so the transformer actually
                # sees it instead of the initial tokens.
                tokens = self.lm_gen.step(codes)
                first_frame = False
            tokens = self.lm_gen.step(codes)
            if tokens is None:
                continue
            assert tokens.shape[1] == 1
            one_text = tokens[0, 0].cpu()
            # Skip the special padding tokens (ids 0 and 3).
            if one_text.item() not in [0, 3]:
                text = self.text_tokenizer.id_to_piece(one_text.item())
                text = text.replace("▁", " ")
                all_text.append(text)
            ntokens += 1
        dt = time.time() - start_time
        print(f"processed {ntokens} steps in {dt:.0f}s, {1000 * dt / ntokens:.2f}ms/step")
        return "".join(all_text)

# --------------------------
# Record the audio
record_audio(FILENAME, DURATION, TARGET_SR)

# --------------------------
# Load the Moshi STT checkpoint
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint_info = loaders.CheckpointInfo.from_hf_repo("kyutai/stt-1b-en_fr")
mimi = checkpoint_info.get_mimi(device=device)
text_tokenizer = checkpoint_info.get_text_tokenizer()
lm = checkpoint_info.get_moshi(device=device)

# --------------------------
# Load the recorded audio
waveform, sr = load_audio(FILENAME, TARGET_SR)   # float32, shape [N]
in_pcms = waveform[None, None, :].to(device)     # shape [1, 1, N]

# Pad with the silence prefix and decoding delay required by the STT config.
stt_config = checkpoint_info.stt_config
pad_left = int(stt_config.get("audio_silence_prefix_seconds", 0.0) * mimi.sample_rate)
pad_right = int((stt_config.get("audio_delay_seconds", 0.0) + 1.0) * mimi.sample_rate)
in_pcms = torch.nn.functional.pad(in_pcms, (pad_left, pad_right), mode="constant")

# --------------------------
# Transcription
state = InferenceState(mimi, text_tokenizer, lm, batch_size=1, device=device)
text = state.run(in_pcms)
print("\n📄 Transcription:\n")
print(text)

# --------------------------
# Play back the recorded audio (renders only inside a notebook)
from IPython.display import Audio
Audio(FILENAME)
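
# --------------------------
# Optional follow-up: a minimal sketch (not part of the original pipeline) for
# running this file as a plain script, where IPython.display.Audio does not
# render. It reuses the already-imported scipy wavfile and sounddevice modules
# to play the recording back, and writes the transcript to a hypothetical
# "my_audio.txt" file for illustration.
playback_sr, playback_data = wav.read(FILENAME)  # wavfile.read returns (rate, data)
sd.play(playback_data, playback_sr)
sd.wait()
with open("my_audio.txt", "w", encoding="utf-8") as f:
    f.write(text)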