import torch from src.Modules import commons from src import utils from src.Voice_Synthesizer import SynthesizerTrn from src.Text.Symbols import symbols from src.Text import text_to_sequence from scipy.io.wavfile import write import logging import os import onnxruntime import numpy as np # Desactiva los molestos logs de matplotlib si lo usas en otro lado logging.getLogger('matplotlib').setLevel(logging.WARNING) class TTS: """ Clase unificada para Texto a Voz (TTS) que soporta tanto modelos PyTorch (.pth) como ONNX (.onnx). """ def __init__(self, config_path, model_path, device="cuda"): """ Inicializa el motor TTS. Detecta automáticamente el tipo de modelo. Args: config_path (str): Ruta al archivo de configuración JSON. model_path (str): Ruta al archivo del modelo (.pth o .onnx). device (str): Dispositivo a usar ("cuda" o "cpu"). """ self.device = device self.hps = utils.get_hparams_from_file(config_path) self.model_path = model_path self.model_type = "onnx" if model_path.endswith(".onnx") else "pytorch" self.net_g = None self.onnx_session = None if self.model_type == "pytorch": self._init_pytorch_model() else: self._init_onnx_model() print(f"Motor TTS inicializado en modo: {self.model_type.upper()}") def _init_pytorch_model(self): """Inicializa el modelo usando PyTorch.""" if ( "use_mel_posterior_encoder" in self.hps.model and self.hps.model.use_mel_posterior_encoder ): posterior_channels = 80 else: posterior_channels = self.hps.data.filter_length // 2 + 1 self.net_g = SynthesizerTrn( len(symbols), posterior_channels, self.hps.train.segment_size // self.hps.data.hop_length, **self.hps.model, ).to(self.device) _ = self.net_g.eval() _ = utils.load_checkpoint(self.model_path, self.net_g, None) def _init_onnx_model(self): """Inicializa el motor de inferencia usando ONNX Runtime.""" providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if self.device == "cuda" else ["CPUExecutionProvider"] self.onnx_session = onnxruntime.InferenceSession( self.model_path, providers=providers ) def _get_text(self, text): """Convierte texto plano a una secuencia de IDs de fonemas.""" text_norm = text_to_sequence(text, self.hps.data.text_cleaners) if self.hps.data.add_blank: text_norm = commons.intersperse(text_norm, 0) return np.array(text_norm, dtype=np.int64) def text_to_speech(self, text, output_path="sample.wav", noise_scale=0.667, noise_scale_w=0.8, length_scale=1.0, sid=None): """Sintetiza audio a partir de un texto y lo guarda en un archivo WAV.""" phoneme_ids = self._get_text(text) if self.model_type == "pytorch": # Inferencia con PyTorch stn_tst = torch.LongTensor(phoneme_ids).to(self.device).unsqueeze(0) x_tst_lengths = torch.LongTensor([stn_tst.size(1)]).to(self.device) sid_tensor = torch.LongTensor([sid]).to(self.device) if sid is not None else None with torch.no_grad(): audio = self.net_g.infer( stn_tst, x_tst_lengths, sid=sid_tensor, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale, )[0][0, 0].data.cpu().float().numpy() else: # Inferencia con ONNX text_input = np.expand_dims(phoneme_ids, 0) text_lengths = np.array([text_input.shape[1]], dtype=np.int64) scales = np.array([noise_scale, length_scale, noise_scale_w], dtype=np.float32) sid_input = np.array([sid], dtype=np.int64) if sid is not None else None audio = self.onnx_session.run( None, { "input": text_input, "input_lengths": text_lengths, "scales": scales, "sid": sid_input, }, )[0].squeeze((0, 1)) write(data=audio, rate=self.hps.data.sampling_rate, filename=output_path) print(f"Audio guardado exitosamente en: {output_path}") # --- Ejemplo de Uso --- if __name__ == "__main__": # Rutas de configuración CONFIG_PATH = "./configs/config.json" # --- PRUEBA CON MODELO ONNX CUANTIZADO (RECOMENDADO) --- ONNX_MODEL_PATH = "./models/LJspeech_quantized.onnx" if os.path.exists(ONNX_MODEL_PATH): print("\n--- Probando con el modelo ONNX ---") tts_onnx_engine = TTS(config_path=CONFIG_PATH, model_path=ONNX_MODEL_PATH) tts_onnx_engine.text_to_speech( "This is a test using the optimized ONNX model. It should be very fast.", "sample_onnx.wav" ) else: print(f"No se encontró el modelo ONNX en {ONNX_MODEL_PATH}. Saltando prueba.") # --- PRUEBA CON MODELO PYTORCH ORIGINAL --- PTH_MODEL_PATH = "./models/LJspeech.pth" if os.path.exists(PTH_MODEL_PATH): print("\n--- Probando con el modelo PyTorch ---") tts_pytorch_engine = TTS(config_path=CONFIG_PATH, model_path=PTH_MODEL_PATH) tts_pytorch_engine.text_to_speech( "This is a test using the original PyTorch model.", "sample_pytorch.wav" ) else: print(f"No se encontró el modelo PyTorch en {PTH_MODEL_PATH}. Saltando prueba.")