import os
import json
import time
from traceback import format_exc

import torch
from transformers import AutoTokenizer

from fireredtts.utils.utils import load_audio
from fireredtts.modules.text_normalizer.utils import text_split
from fireredtts.utils.spliter import clean_text
from fireredtts.modules.text_normalizer.normalize import TextNormalizer
from fireredtts.modules.semantic_tokenizer import SemanticTokenizer
from fireredtts.modules.semantic_llm.llm_gpt2 import Speech_LLM_GPT2
from fireredtts.models.token2audio import TwoStageCodec, FlowToken2Audio
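
# Overview of the inference pipeline implemented below: input text is cleaned
# and (optionally) normalized, a GPT-2-style semantic LLM generates semantic
# speech tokens conditioned on a speaker prompt, and an acoustic decoder
# (either a two-stage codec LLM or a flow-matching model with a BigVGAN
# vocoder, selected by the config) renders the final waveform.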

class FireRedTTS:
    def __init__(self, config_path, pretrained_path, device="cuda"):
        self.device = device
        self.config = json.load(open(config_path))
        self.EOS_TOKEN = self.config["semantic_llm"]["EOS_TOKEN"]

        # Pretrained model paths
        self.tokenizer_path = os.path.join(pretrained_path, "tokenizer")
        self.speech_tokenizer_path = os.path.join(pretrained_path, "speech_tokenizer")
        self.semantic_llm_path = os.path.join(pretrained_path, "semantic_llm.pt")
        assert os.path.exists(self.tokenizer_path)
        assert os.path.exists(self.speech_tokenizer_path)
        assert os.path.exists(self.semantic_llm_path)
        # The config selects one of two acoustic back-ends: a two-stage codec
        # (acoustic LLM + codec) or a flow-matching model + BigVGAN vocoder
        if "acoustic_llm" in self.config:
            self.acoustic_llm_path = os.path.join(pretrained_path, "acoustic_llm.bin")
            self.acoustic_codec_path = os.path.join(pretrained_path, "acoustic_codec.bin")
            assert os.path.exists(self.acoustic_llm_path)
            assert os.path.exists(self.acoustic_codec_path)
        else:
            self.flow_path = os.path.join(pretrained_path, "flow.pt")
            self.bigvgan_path = os.path.join(pretrained_path, "bigvgan.pt")
            assert os.path.exists(self.flow_path)
            assert os.path.exists(self.bigvgan_path)
        # Text normalizer
        self.text_normalizer = TextNormalizer()
        # Text tokenizer
        self.text_tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
        # Semantic LLM: a GPT-2-style decoder configured from the config file
        self.semantic_llm = Speech_LLM_GPT2(
            start_text_token=self.config["semantic_llm"]["start_text_token"],
            stop_text_token=self.config["semantic_llm"]["stop_text_token"],
            num_text_tokens=self.config["semantic_llm"]["num_text_tokens"],
            start_audio_token=self.config["semantic_llm"]["start_audio_token"],
            stop_audio_token=self.config["semantic_llm"]["stop_audio_token"],
            num_audio_tokens=self.config["semantic_llm"]["num_audio_tokens"],
            llm_hidden_size=self.config["semantic_llm"]["llm_hidden_size"],
            llm_intermediate_size=self.config["semantic_llm"]["llm_intermediate_size"],
            llm_num_layers=self.config["semantic_llm"]["llm_num_layers"],
            llm_num_heads=self.config["semantic_llm"]["llm_num_heads"],
            llm_max_audio_seq_len=self.config["semantic_llm"]["llm_max_audio_seq_len"],
            llm_max_text_seq_len=self.config["semantic_llm"]["llm_max_text_seq_len"],
            llm_max_prompt_len=self.config["semantic_llm"]["llm_max_prompt_len"],
            code_stride_len=self.config["semantic_llm"]["code_stride_len"],
        )
        sd = torch.load(self.semantic_llm_path, map_location=device)["model"]
        self.semantic_llm.load_state_dict(sd, strict=True)
        self.semantic_llm = self.semantic_llm.to(device=device)
        self.semantic_llm.eval()
        # Re-initialize the GPT backbone for inference with KV caching enabled
        self.semantic_llm.init_gpt_for_inference(kv_cache=True)
        # Speech (semantic) tokenizer
        self.speech_tokenizer = SemanticTokenizer(
            config=self.config["semantic_tokenizer"], path=self.speech_tokenizer_path
        )
        # Acoustic decoder, matching the checkpoint paths resolved above
        if "acoustic_llm" in self.config:
            self.acoustic_decoder = TwoStageCodec(self.config)
            self.acoustic_decoder.load_model(self.acoustic_llm_path, self.acoustic_codec_path)
        else:
            self.acoustic_decoder = FlowToken2Audio(self.config)
            self.acoustic_decoder.load_model(self.flow_path, self.bigvgan_path)
        self.acoustic_decoder.eval()
        self.acoustic_decoder = self.acoustic_decoder.to(device)

    def extract_spk_embeddings(self, prompt_wav):
        """Extract speaker prompt features from a reference waveform.

        Returns the prompt's semantic tokens, a [1, 512] speaker embedding,
        the prompt's acoustic tokens, and the acoustic decoder's speaker
        conditioning.
        """
        _, _, audio_resampled = load_audio(
            audiopath=prompt_wav,
            sampling_rate=16000,
        )
        _, _, audio_resampled24k = load_audio(
            audiopath=prompt_wav,
            sampling_rate=24000,
        )
        audio_resampled = audio_resampled.to(self.device)
        audio_len = torch.tensor(
            data=[audio_resampled.shape[1]], dtype=torch.long, requires_grad=False
        )
        # spk_embeddings: [1, 512]
        prompt_tokens, token_lengths, spk_embeddings = self.speech_tokenizer(
            audio_resampled, audio_len
        )
        # The two-stage codec takes the 16 kHz audio; the flow decoder takes 24 kHz
        prompt_acoustic_tokens, acoustic_llm_spk = self.acoustic_decoder.extract(
            audio_resampled
            if isinstance(self.acoustic_decoder, TwoStageCodec)
            else audio_resampled24k,
            audio_len,
            spk_embeddings.unsqueeze(0),
        )
        return prompt_tokens, spk_embeddings, prompt_acoustic_tokens, acoustic_llm_spk

    def synthesize_base(
        self,
        prompt_semantic_tokens,
        prompt_acoustic_tokens,
        spk_semantic_llm,
        spk_acoustic_llm,
        prompt_text,
        text,
        lang="auto",
    ):
        """Synthesize a waveform for a single text chunk.

        Args:
            prompt_semantic_tokens: semantic tokens of the speaker prompt.
            prompt_acoustic_tokens: acoustic tokens of the speaker prompt.
            spk_semantic_llm: speaker conditioning latents for the semantic LLM.
            spk_acoustic_llm: speaker conditioning for the acoustic decoder.
            prompt_text: transcript of the prompt audio.
            text: text to synthesize.
            lang (str, optional): language of `text`. Defaults to "auto".

        Returns:
            Reconstructed waveform tensor on CPU.
        """
        # Prepend the prompt transcript so generation continues the prompt;
        # English needs an explicit space between the two texts
        if lang == "en":
            text = prompt_text + " " + text
        else:
            text = prompt_text + text
        print("---text:\n", text)

        # Text to tokens
        text_tokens = self.text_tokenizer.encode(
            text=text,
            add_special_tokens=False,
            max_length=10**6,
            truncation=False,
        )
        # print("---decode", [self.text_tokenizer.decode([c]) for c in text_tokens])
        text_tokens = torch.IntTensor(text_tokens).unsqueeze(0).to(self.device)
        # Guard: each chunk must stay within the semantic LLM's text context
        assert text_tokens.shape[-1] < 200
        with torch.no_grad():
            gpt_codes = self.semantic_llm.generate_ic(
                cond_latents=spk_semantic_llm,
                text_inputs=text_tokens,
                # Drop the last 3 prompt tokens before conditioning
                prompt_tokens=prompt_semantic_tokens[:, :-3],
                do_sample=True,
                top_p=0.85,
                top_k=30,
                temperature=0.75,
                num_return_sequences=7,
                num_beams=1,
                length_penalty=2.0,
                repetition_penalty=5.0,
                output_attentions=False,
            )

        # Truncate each candidate at its first EOS token (guarding against
        # candidates that never emitted EOS)
        seqs = []
        for seq in gpt_codes:
            eos_positions = (seq == self.EOS_TOKEN).nonzero(as_tuple=True)[0]
            if len(eos_positions) > 0:
                seq = seq[: eos_positions[0]]
            seqs.append(seq)
        # Pick a mid-length candidate: sort the 7 samples by length and take
        # the third shortest, skipping degenerate short generations
        sorted_seqs = sorted(seqs, key=len)
        gpt_codes = sorted_seqs[2].unsqueeze(0)

        # Acoustic decoder: semantic tokens -> waveform
        rec_wavs = self.acoustic_decoder(
            gpt_codes, prompt_semantic_tokens, prompt_acoustic_tokens, spk_acoustic_llm
        )
        rec_wavs = rec_wavs.detach().cpu()
        return rec_wavs

    def synthesize(self, prompt_wav, prompt_text, text, lang="auto", use_tn=False):
        """Synthesize speech for `text` in the voice of `prompt_wav`.

        Args:
            prompt_wav: path to the reference (speaker prompt) audio file.
            prompt_text: transcript of the prompt audio.
            text: text to synthesize.
            lang (str, optional): "zh", "en", or "auto". Defaults to "auto".
            use_tn (bool, optional): split the text into chunks and run text
                normalization per chunk. Defaults to False.

        Returns:
            Waveform tensor, or None if synthesis failed.
        """
        assert lang in ["zh", "en", "auto"]
        assert os.path.exists(prompt_wav)
        (
            prompt_semantic_tokens,
            spk_embeddings,
            prompt_acoustic_tokens,
            spk_acoustic_llm,
        ) = self.extract_spk_embeddings(prompt_wav=prompt_wav)
        spk_embeddings = spk_embeddings.unsqueeze(0)
        spk_semantic_llm = self.semantic_llm.reference_embedding(spk_embeddings)
        # print("---prompt_semantic_tokens:\n", prompt_semantic_tokens)
        # print("---spk_embeddings:\n", spk_embeddings)

        # Clean text
        prompt_text = clean_text(prompt_text)
        text = clean_text(text=text)

        if use_tn:
            # Split into chunks, normalize each, synthesize, and concatenate
            substrings = text_split(text=text)
            out_wavs = []
            try:
                for sub in substrings:
                    res_lang = self.text_normalizer.tn(text=sub)[1]
                    chunk = self.synthesize_base(
                        prompt_semantic_tokens=prompt_semantic_tokens,
                        prompt_acoustic_tokens=prompt_acoustic_tokens,
                        spk_semantic_llm=spk_semantic_llm,
                        spk_acoustic_llm=spk_acoustic_llm,
                        prompt_text=prompt_text,
                        text=sub,
                        lang=res_lang,
                    )
                    out_wavs.append(chunk)
                out_wav = torch.cat(out_wavs, dim=-1)
                return out_wav
            except Exception:
                print("[ERROR] ", format_exc())
                return None
        else:
            try:
                res_lang = self.text_normalizer.tn(text=text)[1]
                out_wav = self.synthesize_base(
                    prompt_semantic_tokens=prompt_semantic_tokens,
                    prompt_acoustic_tokens=prompt_acoustic_tokens,
                    spk_semantic_llm=spk_semantic_llm,
                    spk_acoustic_llm=spk_acoustic_llm,
                    prompt_text=prompt_text,
                    text=text,
                    lang=res_lang,
                )
                return out_wav
            except Exception:
                print("[ERROR] ", format_exc())
                return None
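

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the library). The config/checkpoint paths,
# the prompt audio, and the 24 kHz output rate are assumptions; adjust them to
# your checkout and checkpoints.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torchaudio

    tts = FireRedTTS(
        config_path="configs/config_24k.json",  # hypothetical config path
        pretrained_path="pretrained_models",  # hypothetical checkpoint dir
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    wav = tts.synthesize(
        prompt_wav="examples/prompt.wav",  # hypothetical reference audio
        prompt_text="Transcript of the prompt audio.",
        text="Hello, this is a quick FireRedTTS test.",
        lang="auto",
        use_tn=True,
    )
    if wav is not None:
        # Reshape to [1, num_samples] for torchaudio; assumes mono output
        torchaudio.save("output.wav", wav.reshape(1, -1).float(), 24000)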