import torch
import torch.nn.functional as F

from fireredtts.modules.acoustic_llm import AcousticLLM
from fireredtts.modules.acoustic_codec import AcousticCodec
from fireredtts.modules.flowmatching import FlowToken2Mel
from fireredtts.modules.bigvgan import BigVGAN, MelExtractor


class TwoStageCodec(torch.nn.Module):
    """Semantic-token-to-waveform decoder: an acoustic LLM predicts acoustic
    tokens from semantic tokens, then an acoustic codec reconstructs the
    waveform from those acoustic tokens."""

    def __init__(self, config):
        super().__init__()
        self.acoustic_llm = AcousticLLM(**config["acoustic_llm"])
        self.acoustic_codec = AcousticCodec(**config["acoustic_codec"])

    def load_model(self, acoustic_llm_path, acoustic_codec_path):
        self.acoustic_llm.load_state_dict(
            torch.load(acoustic_llm_path, map_location="cpu"), strict=True
        )
        self.acoustic_codec.load_state_dict(
            torch.load(acoustic_codec_path, map_location="cpu"), strict=True
        )

    @torch.inference_mode()
    def forward(
        self, semantic_token, prompt_semantic_token, prompt_acoustic_token, spk_gpt
    ):
        # Prepend the prompt's semantic tokens so the LLM continues from the
        # prompt context.
        token_pred = torch.cat((prompt_semantic_token, semantic_token), dim=1)
        # Acoustic LLM inference: predict acoustic tokens conditioned on the
        # speaker latent and the prompt's acoustic tokens.
        token_pred = self.acoustic_llm.inference_speech(
            speech_conditioning_latent=spk_gpt,
            text_inputs=token_pred,
            num_return_sequences=1,
            input_tokens=prompt_acoustic_token,
        )[0]
        # Restore the batch dimension; a tuple/list holds one tensor per
        # token stream.
        if isinstance(token_pred, (tuple, list)):
            token_pred = [x.unsqueeze(0) for x in token_pred]
        else:
            token_pred = token_pred.unsqueeze(0)
        # Decode the predicted acoustic tokens back to a waveform.
        acoustic_outputs = self.acoustic_codec.reconstruct_wav(token=token_pred)
        wav = acoustic_outputs["wav_pred"].squeeze(1)
        return wav

    def extract(self, wavs, wav_lengths, spk):
        """Extract the prompt conditioning: acoustic tokens from the prompt
        audio and the speaker conditioning latent for the acoustic LLM."""
        if torch.cuda.is_available():
            wavs = wavs.cuda()
        cond_tok = self.acoustic_codec.extract_speech_tokens(wavs, wav_lengths)[
            "token"
        ][0]
        spk_gpt = self.acoustic_llm.get_conditioning(spk)
        return cond_tok, spk_gpt


class FlowToken2Audio(torch.nn.Module):
    """Flow-matching semantic-token-to-audio decoder.

    Keeps the interface consistent with TwoStageCodec to minimize code
    changes: ``prompt_acoustic_token`` is an alias for the prompt mel
    spectrogram, and ``spk_gpt`` is an alias for the speaker embedding.
    """

    def __init__(self, config):
        super().__init__()
        self.flow = FlowToken2Mel(config["flow"])
        self.bigvgan = BigVGAN(**config["bigvgan"])
        self.mel_extractor = MelExtractor(**config["mel"])

    def load_model(self, flow_path, bigvgan_path):
        self.flow.load_state_dict(
            torch.load(flow_path, map_location="cpu"), strict=True
        )
        # The BigVGAN checkpoint stores the generator weights under "generator".
        self.bigvgan.load_state_dict(
            torch.load(bigvgan_path, map_location="cpu")["generator"], strict=True
        )
        # Weight norm is only needed for training; fold it for inference.
        self.bigvgan.remove_weight_norm()

    @torch.inference_mode()
    def forward(
        self, semantic_token, prompt_semantic_token, prompt_acoustic_token, spk_gpt
    ):
        # Align the prompt mel to the prompt semantic tokens: each semantic
        # token corresponds to 2 mel frames. Pad short mels with the log-mel
        # silence floor (-11.5 ~= log(1e-5)); trim long ones.
        target_mel_length = prompt_semantic_token.shape[1] * 2
        if target_mel_length > prompt_acoustic_token.shape[1]:
            prompt_acoustic_token = F.pad(
                prompt_acoustic_token,
                (0, 0, 0, target_mel_length - prompt_acoustic_token.shape[1]),
                mode="constant",
                value=-11.5,
            )
        elif target_mel_length < prompt_acoustic_token.shape[1]:
            prompt_acoustic_token = prompt_acoustic_token[:, :target_mel_length]
        # Flow matching: generate the mel spectrogram from semantic tokens,
        # conditioned on the prompt tokens, prompt mel, and speaker embedding.
        mel_pred = self.flow.inference(
            prompt_token=prompt_semantic_token,
            prompt_xvec=spk_gpt,
            prompt_feat=prompt_acoustic_token,
            token=semantic_token,
        )
        # The BigVGAN vocoder expects (batch, n_mels, frames).
        wav = self.bigvgan(mel_pred.transpose(1, 2)).squeeze(1)
        return wav

    def extract(self, wavs, wav_lengths, spk):
        """Extract the prompt conditioning: the prompt mel spectrogram
        (at 24 kHz) and the speaker embedding."""
        mel = self.mel_extractor(wavs, 24000).transpose(1, 2)
        if torch.cuda.is_available():
            mel = mel.cuda()
        # Drop the leading dim so the speaker embedding stands in for spk_gpt.
        return mel, spk.squeeze(0)
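

# ---------------------------------------------------------------------------
# Usage sketch. Both decoders expose the same load_model/extract/forward
# interface, so either can sit behind the same call site. The config file
# and checkpoint filenames below ("config.json", "flow.pt", "bigvgan.pt")
# are illustrative assumptions, not files shipped with this module, and the
# input tensors are placeholders with assumed (not documented) shapes.
if __name__ == "__main__":
    import json

    with open("config.json") as f:  # assumed config with "flow"/"bigvgan"/"mel" keys
        config = json.load(f)

    decoder = FlowToken2Audio(config)
    decoder.load_model("flow.pt", "bigvgan.pt")  # assumed checkpoint paths
    decoder.eval()

    # Placeholder inputs: 3 s of 24 kHz prompt audio, a speaker embedding,
    # and random semantic token ids (embedding size and vocab size assumed).
    prompt_wavs = torch.randn(1, 24000 * 3)
    prompt_wav_lengths = torch.tensor([24000 * 3])
    spk = torch.randn(1, 1, 192)
    semantic_token = torch.randint(0, 1024, (1, 200))
    prompt_semantic_token = torch.randint(0, 1024, (1, 75))

    # extract() moves the mel to GPU when one is available, so keep the
    # model and the remaining inputs on the same device.
    if torch.cuda.is_available():
        decoder = decoder.cuda()
        spk = spk.cuda()
        semantic_token = semantic_token.cuda()
        prompt_semantic_token = prompt_semantic_token.cuda()

    # extract() returns the prompt mel and speaker embedding, which feed
    # forward() under the prompt_acoustic_token / spk_gpt aliases.
    prompt_mel, spk_emb = decoder.extract(prompt_wavs, prompt_wav_lengths, spk)
    wav = decoder(semantic_token, prompt_semantic_token, prompt_mel, spk_emb)
    print(wav.shape)  # (1, num_samples)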