from dataclasses import dataclass
from typing import Callable, List

import torch
import safetensors.torch as st
from huggingface_hub import hf_hub_download
from model import EchoDiT
from autoencoder import build_ae, DAC
import torchaudio
from torchcodec.decoders import AudioDecoder

# from samplers import Sampler

# A SampleFn maps (model, speaker_latent, speaker_mask, text_input_ids,
# text_mask, rng_seed) to a sampled latent sequence; see sample_pipeline below.
SampleFn = Callable[
    [EchoDiT, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int],
    torch.Tensor,
]
### Loading

def load_model_from_hf(
    repo_id: str = 'jordand/echo-tts-base',
    device: str = 'cuda',
    dtype: torch.dtype | None = torch.bfloat16,
    compile: bool = False,
    token: str | None = None,
) -> EchoDiT:
    with torch.device('meta'):
        model = EchoDiT(
            latent_size=80, model_size=2048, num_layers=24, num_heads=16,
            intermediate_size=5888, norm_eps=1e-5, max_seq_len=640,
            text_vocab_size=256, text_model_size=1280, text_num_layers=14,
            text_num_heads=10, text_intermediate_size=3328, text_max_seq_len=768,
            speaker_patch_size=4, speaker_model_size=1280, speaker_num_layers=14,
            speaker_num_heads=10, speaker_intermediate_size=3328,
            speaker_max_patched_seq_len=640, timestep_embed_size=512, adaln_rank=256,
        )
    w_path = hf_hub_download(repo_id, 'pytorch_model.safetensors', token=token)
    # Load to CPU first, convert dtype if needed, then move to device.
    state = st.load_file(w_path, device='cpu')
    if dtype is not None:
        state = {k: v.to(dtype=dtype) for k, v in state.items()}
    state = {k: v.to(device=device) for k, v in state.items()}
    model.load_state_dict(state, strict=False, assign=True)
    model = model.eval()
    if compile:
        model = torch.compile(model)
        model.get_kv_cache = torch.compile(model.get_kv_cache)
    return model
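
# Usage sketch (assumes a CUDA machine and access to the default repo):
#   model = load_model_from_hf(dtype=torch.bfloat16)
#   print(model.device, model.dtype)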
def load_fish_ae_from_hf(
    repo_id: str = 'jordand/fish-s1-dac-min',
    device: str = 'cuda',
    dtype: torch.dtype | None = torch.float32,
    compile: bool = False,
    token: str | None = None,
) -> DAC:
    # lower precisions have not been tested with the fish AE yet
    with torch.device('meta'):
        fish_ae = build_ae()
    w_path = hf_hub_download(repo_id, 'pytorch_model.safetensors', token=token)
    if dtype is not None and dtype != torch.float32:
        state = st.load_file(w_path, device='cpu')
        state = {k: v.to(dtype=dtype) for k, v in state.items()}
        state = {k: v.to(device=device) for k, v in state.items()}
    else:
        state = st.load_file(w_path, device=device)
    fish_ae.load_state_dict(state, strict=False, assign=True)
    fish_ae = fish_ae.eval().to(device)
    if compile:
        fish_ae.encoder = torch.compile(fish_ae.encoder)
        fish_ae.decoder = torch.compile(fish_ae.decoder)
    return fish_ae
@dataclass
class PCAState:
    pca_components: torch.Tensor
    pca_mean: torch.Tensor
    latent_scale: float

def load_pca_state_from_hf(
    repo_id: str = 'jordand/echo-tts',
    device: str = 'cuda',
    filename: str = 'pca_state.safetensors',
    token: str | None = None,
) -> PCAState:
    p_path = hf_hub_download(repo_id, filename, token=token)
    t = st.load_file(p_path, device=device)
    return PCAState(
        pca_components=t["pca_components"],
        pca_mean=t["pca_mean"],
        latent_scale=float(t["latent_scale"].item()),
    )
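
# A typical setup loads all three artifacts once at startup (repo ids are the
# defaults above; each is an independent download):
#   model = load_model_from_hf()
#   fish_ae = load_fish_ae_from_hf()
#   pca_state = load_pca_state_from_hf()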
### Default audio loading

def load_audio(path: str) -> torch.Tensor:
    decoder = AudioDecoder(path)
    sr = decoder.metadata.sample_rate
    # Take up to the first 120 seconds, downmix to mono, resample to 44.1 kHz.
    audio = decoder.get_samples_played_in_range(0, 120)
    audio = audio.data.mean(dim=0).unsqueeze(0)
    audio = torchaudio.functional.resample(audio, sr, 44_100)
    # Peak-normalize only when the peak exceeds 1; quiet audio is left as-is.
    audio = audio / torch.maximum(audio.abs().max(), torch.tensor(1.))
    # TODO is this better than clipping? should we target a specific energy level?
    return audio
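
# Example (hypothetical path): returns a (1, num_samples) float tensor,
# mono at 44.1 kHz, truncated to 120 s, with peaks scaled to at most 1.0:
#   ref = load_audio('speaker.wav')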
### Text helpers

def tokenizer_encode(text: str, append_bos: bool = True, normalize: bool = True) -> torch.Tensor:
    if normalize:
        # Fold typographic punctuation into characters the model saw in training.
        text = text.replace('…', '...')
        text = text.replace('“', '"')
        text = text.replace('”', '"')
        text = text.replace('’', "'")
        text = text.replace('\n', " ")
        text = text.replace(':', ',')
        text = text.replace(';', ',')
    b = list(text.encode('utf-8'))
    if append_bos:
        b.insert(0, 0)
    return torch.tensor(b)
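
# The "tokenizer" is plain UTF-8 bytes with 0 prepended as BOS (hence
# text_vocab_size=256 above), e.g.:
#   tokenizer_encode('Hi')  ->  tensor([  0,  72, 105])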
def get_text_input_ids_and_mask(text_arr: List[str], max_length: int | None, device: str | None = None) -> tuple[torch.Tensor, torch.Tensor]:
    batch_size = len(text_arr)
    if max_length is None:
        # note: this encodes every text twice (once to measure, once to fill)
        max_length = max(len(tokenizer_encode(text)) for text in text_arr)
    tokens = torch.zeros((batch_size, max_length), dtype=torch.int32)
    mask = torch.zeros((batch_size, max_length), dtype=torch.bool)
    for i, text in enumerate(text_arr):
        encoded = tokenizer_encode(text)
        length = min(len(encoded), max_length)
        tokens[i, :length] = encoded[:length]
        mask[i, :length] = 1
    if device is not None:
        tokens = tokens.to(device)
        mask = mask.to(device)
    return tokens, mask
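
# Example: two prompts padded to a shared length of 16; the mask marks the
# real bytes (including BOS) and is zero over the padding:
#   tokens, mask = get_text_input_ids_and_mask(['Hello.', 'Hi.'], 16)
#   tokens.shape  ->  torch.Size([2, 16])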
### Autoencoder inference

def ae_encode(fish_ae: DAC, pca_state: PCAState, audio: torch.Tensor) -> torch.Tensor:
    assert audio.ndim == 3 and audio.shape[1] == 1  # (b, 1, length)
    z_q = fish_ae.encode_zq(audio).float()
    # Project into the PCA basis and scale to the model's latent range.
    z_q = (z_q.transpose(1, 2) - pca_state.pca_mean) @ pca_state.pca_components.T
    z_q = z_q * pca_state.latent_scale
    return z_q

def ae_decode(fish_ae: DAC, pca_state: PCAState, z_q: torch.Tensor) -> torch.Tensor:
    # Invert the PCA projection before running the AE decoder.
    z_q = (z_q / pca_state.latent_scale) @ pca_state.pca_components + pca_state.pca_mean
    return fish_ae.decode_zq(z_q.transpose(1, 2).to(fish_ae.dtype)).float()

def ae_reconstruct(fish_ae: DAC, pca_state: PCAState, audio: torch.Tensor) -> torch.Tensor:
    # audio is (b, 1, length)
    z_q = ae_encode(fish_ae, pca_state, audio.to(fish_ae.dtype))
    return ae_decode(fish_ae, pca_state, z_q)
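
# Round-trip sketch: latents are 80-dim at a 2048x downsample of 44.1 kHz
# audio (~21.5 frames/s). Assumes the AE and PCA state live on the same
# device as the input; the path is hypothetical:
#   wav = load_audio('ref.wav').unsqueeze(0)   # (1, 1, length)
#   recon = ae_reconstruct(fish_ae, pca_state, wav.cuda())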
def get_speaker_latent_and_mask(
    fish_ae: DAC,
    pca_state: PCAState,
    audio: torch.Tensor,  # (1, length)
    max_speaker_latent_len: int = 2560,  # pretrained max length
    audio_chunk_size: int = 640 * 2048,  # ~30 seconds, 1/4 max speaker condition size
) -> tuple[torch.Tensor, torch.Tensor]:
    # Encodes the speaker audio in chunks and concatenates the latents
    # (mirroring the pretraining setup), then pads and masks to max length.
    AE_DOWNSAMPLE_FACTOR = 2048
    max_audio_len = max_speaker_latent_len * AE_DOWNSAMPLE_FACTOR
    assert audio.ndim == 2 and audio.shape[0] == 1  # (1, length)
    audio = audio[:, :max_audio_len]
    audio_len = audio.shape[1]
    latent_arr = []
    for i in range(0, audio_len, audio_chunk_size):
        audio_chunk = audio[:, i:i + audio_chunk_size]
        if audio_chunk.shape[1] < audio_chunk_size:
            audio_chunk = torch.nn.functional.pad(audio_chunk, (0, audio_chunk_size - audio_chunk.shape[1]))
        latent_chunk = ae_encode(fish_ae, pca_state, audio_chunk.unsqueeze(0))
        latent_arr.append(latent_chunk)
    speaker_latent = torch.cat(latent_arr, dim=1)
    actual_latent_len = audio_len // AE_DOWNSAMPLE_FACTOR
    speaker_mask = (torch.arange(speaker_latent.shape[1], device=speaker_latent.device) < actual_latent_len).unsqueeze(0)
    if speaker_latent.shape[1] < max_speaker_latent_len:
        speaker_latent = torch.nn.functional.pad(speaker_latent, (0, 0, 0, max_speaker_latent_len - speaker_latent.shape[1]))
        speaker_mask = torch.nn.functional.pad(speaker_mask, (0, max_speaker_latent_len - speaker_mask.shape[1]))
    return speaker_latent, speaker_mask
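
# Sizing note: 2560 latent frames * 2048 samples/frame / 44100 Hz ≈ 119 s,
# which is why load_audio caps clips at 120 s; each 640 * 2048-sample chunk
# covers a quarter of that (~30 s).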
### Full sample pipeline

def find_flattening_point(data, target_value=0.0, window_size=20, std_threshold=0.05):
    # Finds the first frame where the latent flattens to target_value (i.e.
    # the model has stopped producing audio), so trailing padding can be cut.
    padded_data = torch.cat([data, torch.zeros(window_size, *data.shape[1:], device=data.device, dtype=data.dtype)])
    for i in range(len(padded_data) - window_size):
        window = padded_data[i:i + window_size]
        if window.std() < std_threshold and abs(window.mean() - target_value) < 0.1:
            return i
    return len(data)
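
# Sanity-check sketch (synthetic data, not real model output): a latent that
# goes silent halfway through should flatten at its midpoint:
#   fake = torch.cat([torch.randn(50, 80), torch.zeros(50, 80)])
#   find_flattening_point(fake)  ->  50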
def sample_pipeline(
    model: EchoDiT,
    fish_ae: DAC,
    pca_state: PCAState,
    sample_fn: SampleFn,
    text_prompt: str,
    speaker_audio: torch.Tensor | None,
    rng_seed: int,
    pad_to_max_speaker_latent_len: int | None = 2560,
    pad_to_max_text_seq_len: int | None = 768,
) -> torch.Tensor:
    MAX_SPEAKER_LATENT_LEN = 2560
    MAX_TEXT_SEQ_LEN = 768
    device, dtype = model.device, model.dtype
    text_input_ids, text_mask = get_text_input_ids_and_mask(
        [text_prompt],
        min(pad_to_max_text_seq_len or MAX_TEXT_SEQ_LEN, MAX_TEXT_SEQ_LEN),
        device=device,
    )
    if speaker_audio is None:
        # No speaker prompt: condition on an all-zero speaker latent and mask.
        speaker_latent = torch.zeros((1, pad_to_max_speaker_latent_len or MAX_SPEAKER_LATENT_LEN, 80), device=device, dtype=dtype)
        speaker_mask = torch.zeros((1, pad_to_max_speaker_latent_len or MAX_SPEAKER_LATENT_LEN), device=device, dtype=torch.bool)
    else:
        speaker_latent, speaker_mask = get_speaker_latent_and_mask(
            fish_ae,
            pca_state,
            speaker_audio.to(fish_ae.dtype),
            max_speaker_latent_len=pad_to_max_speaker_latent_len or MAX_SPEAKER_LATENT_LEN,
        )
        speaker_latent = speaker_latent.to(device)
        speaker_mask = speaker_mask.to(device)
    latent_out = sample_fn(model, speaker_latent, speaker_mask, text_input_ids, text_mask, rng_seed)
    audio_out = ae_decode(fish_ae, pca_state, latent_out)
    # Trim trailing silence: cut the audio where the sampled latent flattens.
    flattening_point = find_flattening_point(latent_out[0])
    audio_out = audio_out[..., :flattening_point * 2048]
    return audio_out
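
# End-to-end sketch. `my_sample_fn` is a placeholder for a real SampleFn (e.g.
# one built from the samplers module hinted at above); the audio path is
# hypothetical:
#   model = load_model_from_hf()
#   fish_ae = load_fish_ae_from_hf()
#   pca_state = load_pca_state_from_hf()
#   ref = load_audio('speaker.wav').to('cuda')
#   wav = sample_pipeline(model, fish_ae, pca_state, my_sample_fn,
#                         'Hello there.', ref, rng_seed=0)
#   torchaudio.save('out.wav', wav[0].cpu(), 44_100)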