# echo-tts-preview / inference.py
from dataclasses import dataclass
from typing import Callable, List, Tuple
import torch
import safetensors.torch as st
from huggingface_hub import hf_hub_download
from model import EchoDiT
from autoencoder import build_ae, DAC
import torchaudio
from torchcodec.decoders import AudioDecoder
# from samplers import Sampler
SampleFn = Callable[
    [EchoDiT, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int],
    torch.Tensor,
]

### Loading
def load_model_from_hf(
    repo_id: str = 'jordand/echo-tts-base',
    device: str = 'cuda',
    dtype: torch.dtype | None = torch.bfloat16,
    compile: bool = False,
    token: str | None = None,
) -> EchoDiT:
    with torch.device('meta'):
        model = EchoDiT(
            latent_size=80, model_size=2048, num_layers=24, num_heads=16,
            intermediate_size=5888, norm_eps=1e-5, max_seq_len=640,
            text_vocab_size=256, text_model_size=1280, text_num_layers=14,
            text_num_heads=10, text_intermediate_size=3328, text_max_seq_len=768,
            speaker_patch_size=4, speaker_model_size=1280, speaker_num_layers=14,
            speaker_num_heads=10, speaker_intermediate_size=3328,
            speaker_max_patched_seq_len=640, timestep_embed_size=512, adaln_rank=256,
        )

    w_path = hf_hub_download(repo_id, 'pytorch_model.safetensors', token=token)

    # Load to CPU first
    state = st.load_file(w_path, device='cpu')
    # Convert dtype on CPU if needed
    if dtype is not None:
        state = {k: v.to(dtype=dtype) for k, v in state.items()}
    # Now move to device
    state = {k: v.to(device=device) for k, v in state.items()}

    model.load_state_dict(state, strict=False, assign=True)
    model = model.eval()

    if compile:
        model = torch.compile(model)
        model.get_kv_cache = torch.compile(model.get_kv_cache)

    return model

def load_fish_ae_from_hf(
    repo_id: str = 'jordand/fish-s1-dac-min',
    device: str = 'cuda',
    dtype: torch.dtype | None = torch.float32,
    compile: bool = False,
    token: str | None = None,
) -> DAC:
    # have not tested lower precisions with fish AE yet
    with torch.device('meta'):
        fish_ae = build_ae()

    w_path = hf_hub_download(repo_id, 'pytorch_model.safetensors', token=token)

    if dtype is not None and dtype != torch.float32:
        state = st.load_file(w_path, device='cpu')
        state = {k: v.to(dtype=dtype) for k, v in state.items()}
        state = {k: v.to(device=device) for k, v in state.items()}
        fish_ae.load_state_dict(state, strict=False, assign=True)
    else:
        state = st.load_file(w_path, device=device)
        fish_ae.load_state_dict(state, strict=False, assign=True)

    fish_ae = fish_ae.eval().to(device)

    if compile:
        fish_ae.encoder = torch.compile(fish_ae.encoder)
        fish_ae.decoder = torch.compile(fish_ae.decoder)

    return fish_ae

@dataclass
class PCAState:
    pca_components: torch.Tensor
    pca_mean: torch.Tensor
    latent_scale: float


def load_pca_state_from_hf(
    repo_id: str = 'jordand/echo-tts',
    device: str = 'cuda',
    filename: str = 'pca_state.safetensors',
    token: str | None = None,
) -> PCAState:
    p_path = hf_hub_download(repo_id, filename, token=token)
    t = st.load_file(p_path, device=device)
    return PCAState(
        pca_components=t["pca_components"],
        pca_mean=t["pca_mean"],
        latent_scale=float(t["latent_scale"].item()),
    )

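# Illustrative sketch (not called anywhere in this file): how the three loaders above are
# typically combined. Assumes a CUDA device is available; repo ids are the defaults.
def _example_load_all():
    model = load_model_from_hf(device='cuda', dtype=torch.bfloat16)
    fish_ae = load_fish_ae_from_hf(device='cuda')
    pca_state = load_pca_state_from_hf(device='cuda')
    return model, fish_ae, pca_state
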
### default load audio
def load_audio(path: str) -> torch.Tensor:
    decoder = AudioDecoder(path)
    sr = decoder.metadata.sample_rate
    # Take at most the first 120 seconds, mix down to mono, resample to 44.1 kHz
    audio = decoder.get_samples_played_in_range(0, 120)
    audio = audio.data.mean(dim=0).unsqueeze(0)
    audio = torchaudio.functional.resample(audio, sr, 44_100)
    # Peak-normalize only if the signal exceeds full scale
    audio = audio / torch.maximum(audio.abs().max(), torch.tensor(1.))
    # TODO is this better than clipping? should we target a specific energy level?
    return audio

### Text helpers
def tokenizer_encode(text: str, append_bos: bool = True, normalize: bool = True) -> torch.Tensor:
    if normalize:
        text = text.replace('…', '...')
        text = text.replace('“', '"')
        text = text.replace('”', '"')
        text = text.replace('’', "'")
        text = text.replace('\n', " ")
        text = text.replace(':', ',')
        text = text.replace(';', ',')
    b = list(text.encode('utf-8'))
    if append_bos:
        b.insert(0, 0)
    return torch.tensor(b)

def get_text_input_ids_and_mask(text_arr: List[str], max_length: int | None, device: str | None = None) -> tuple[torch.Tensor, torch.Tensor]:
    batch_size = len(text_arr)
    if max_length is None:
        max_length = max(len(tokenizer_encode(text)) for text in text_arr)  # obviously bad...
    tokens = torch.zeros((batch_size, max_length), dtype=torch.int32)
    mask = torch.zeros((batch_size, max_length), dtype=torch.bool)
    for i, text in enumerate(text_arr):
        encoded = tokenizer_encode(text)
        length = min(len(encoded), max_length)
        tokens[i, :length] = encoded[:length]
        mask[i, :length] = 1
    if device is not None:
        tokens = tokens.to(device)
        mask = mask.to(device)
    return tokens, mask

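# Illustrative sketch: the "tokenizer" is just raw UTF-8 bytes with id 0 prepended as BOS,
# matching text_vocab_size=256 in the model config. The prompt below is a placeholder.
def _example_text_inputs():
    ids = tokenizer_encode("Hello there.")  # 1-D tensor of byte ids, BOS prepended
    tokens, mask = get_text_input_ids_and_mask(["Hello there."], max_length=768)
    return ids, tokens, mask
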
### Autoencoder Inference
@torch.inference_mode()
def ae_encode(fish_ae: DAC, pca_state: PCAState, audio: torch.Tensor) -> torch.Tensor:
    assert audio.ndim == 3 and audio.shape[1] == 1  # (b, 1, length)
    z_q = fish_ae.encode_zq(audio).float()
    # Project quantized latents into the PCA basis and rescale
    z_q = (z_q.transpose(1, 2) - pca_state.pca_mean) @ pca_state.pca_components.T
    z_q = z_q * pca_state.latent_scale
    return z_q


@torch.inference_mode()
def ae_decode(fish_ae: DAC, pca_state: PCAState, z_q: torch.Tensor) -> torch.Tensor:
    # Invert the PCA projection, then decode back to a waveform
    z_q = (z_q / pca_state.latent_scale) @ pca_state.pca_components + pca_state.pca_mean
    return fish_ae.decode_zq(z_q.transpose(1, 2).to(fish_ae.dtype)).float()


@torch.inference_mode()
def ae_reconstruct(fish_ae: DAC, pca_state: PCAState, audio: torch.Tensor) -> torch.Tensor:
    # (audio is (b, 1, length))
    z_q = ae_encode(fish_ae, pca_state, audio.to(fish_ae.dtype))
    return ae_decode(fish_ae, pca_state, z_q)

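# Illustrative sketch: round-tripping a clip through the PCA-whitened AE latent space.
# Assumes `fish_ae`, `pca_state`, and a mono (1, length) tensor from load_audio.
def _example_ae_roundtrip(fish_ae: DAC, pca_state: PCAState, audio: torch.Tensor):
    z_q = ae_encode(fish_ae, pca_state, audio.unsqueeze(0).to(fish_ae.dtype))  # (1, T, 80)
    recon = ae_decode(fish_ae, pca_state, z_q)  # roughly (1, 1, T * 2048) samples
    return recon
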
@torch.inference_mode()
def get_speaker_latent_and_mask(
    fish_ae: DAC,
    pca_state: PCAState,
    audio: torch.Tensor,  # (1, length)
    max_speaker_latent_len: int = 2560,  # pretrained max length
    audio_chunk_size: int = 640 * 2048,  # (~30 seconds, 1/4 max speaker condition size)
) -> tuple[torch.Tensor, torch.Tensor]:
    # gets speaker latent and mask from audio, computes in chunks and concatenates (similar to pretraining setup)
    AE_DOWNSAMPLE_FACTOR = 2048
    max_audio_len = max_speaker_latent_len * AE_DOWNSAMPLE_FACTOR
    assert audio.ndim == 2 and audio.shape[0] == 1  # (1, length)
    audio = audio[:, :max_audio_len]
    audio_len = audio.shape[1]

    latent_arr = []
    for i in range(0, audio_len, audio_chunk_size):
        audio_chunk = audio[:, i:i + audio_chunk_size]
        if audio_chunk.shape[1] < audio_chunk_size:
            audio_chunk = torch.nn.functional.pad(audio_chunk, (0, audio_chunk_size - audio_chunk.shape[1]))
        latent_chunk = ae_encode(fish_ae, pca_state, audio_chunk.unsqueeze(0))
        latent_arr.append(latent_chunk)
    speaker_latent = torch.cat(latent_arr, dim=1)

    actual_latent_len = audio_len // AE_DOWNSAMPLE_FACTOR
    speaker_mask = (torch.arange(speaker_latent.shape[1], device=speaker_latent.device) < actual_latent_len).unsqueeze(0)

    if speaker_latent.shape[1] < max_speaker_latent_len:
        speaker_latent = torch.nn.functional.pad(speaker_latent, (0, 0, 0, max_speaker_latent_len - speaker_latent.shape[1]))
        speaker_mask = torch.nn.functional.pad(speaker_mask, (0, max_speaker_latent_len - speaker_mask.shape[1]))

    return speaker_latent, speaker_mask

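# Illustrative sketch: turning a reference clip into the speaker conditioning pair expected
# by sample_pipeline. The file path is a placeholder.
def _example_speaker_conditioning(fish_ae: DAC, pca_state: PCAState):
    ref = load_audio('reference_speaker.wav')  # placeholder path; (1, length) at 44.1 kHz
    speaker_latent, speaker_mask = get_speaker_latent_and_mask(fish_ae, pca_state, ref.to(fish_ae.dtype))
    return speaker_latent, speaker_mask  # (1, 2560, 80), (1, 2560)
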
### Full sample pipeline
def find_flattening_point(data, target_value=0.0, window_size=20, std_threshold=0.05):
    # Find the first index where a `window_size` slice of the latent is nearly constant and
    # close to `target_value`; used below to trim trailing padding/silence from the output.
    # Zero padding at the end guarantees a fallback near the full length.
    padded_data = torch.cat([data, torch.zeros(window_size, *data.shape[1:], device=data.device, dtype=data.dtype)])
    for i in range(len(padded_data) - window_size):
        window = padded_data[i:i + window_size]
        if window.std() < std_threshold and abs(window.mean() - target_value) < 0.1:
            return i
    return len(data)

@torch.inference_mode()
def sample_pipeline(
    model: EchoDiT,
    fish_ae: DAC,
    pca_state: PCAState,
    sample_fn: SampleFn,
    text_prompt: str,
    speaker_audio: torch.Tensor | None,
    rng_seed: int,
    pad_to_max_speaker_latent_len: int | None = 2560,
    pad_to_max_text_seq_len: int | None = 768,
) -> torch.Tensor:
    MAX_SPEAKER_LATENT_LEN = 2560
    MAX_TEXT_SEQ_LEN = 768
    device, dtype = model.device, model.dtype

    text_input_ids, text_mask = get_text_input_ids_and_mask(
        [text_prompt],
        min(pad_to_max_text_seq_len or MAX_TEXT_SEQ_LEN, MAX_TEXT_SEQ_LEN),
        device=device,
    )

    if speaker_audio is None:
        # No speaker prompt - use zero speaker latent and mask
        speaker_latent = torch.zeros((1, pad_to_max_speaker_latent_len or MAX_SPEAKER_LATENT_LEN, 80), device=device, dtype=dtype)
        speaker_mask = torch.zeros((1, pad_to_max_speaker_latent_len or MAX_SPEAKER_LATENT_LEN), device=device, dtype=torch.bool)
    else:
        speaker_latent, speaker_mask = get_speaker_latent_and_mask(
            fish_ae,
            pca_state,
            speaker_audio.to(fish_ae.dtype),
            max_speaker_latent_len=pad_to_max_speaker_latent_len or MAX_SPEAKER_LATENT_LEN,
        )
        speaker_latent = speaker_latent.to(device)
        speaker_mask = speaker_mask.to(device)

    latent_out = sample_fn(model, speaker_latent, speaker_mask, text_input_ids, text_mask, rng_seed)
    audio_out = ae_decode(fish_ae, pca_state, latent_out)

    # Trim the waveform where the generated latent flattens out (end of speech)
    flattening_point = find_flattening_point(latent_out[0])
    audio_out = audio_out[..., :flattening_point * 2048]
    return audio_out
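

# Illustrative end-to-end sketch, assuming a real sampler (e.g. from the project's samplers
# module, commented out above) is substituted for the dummy; paths and prompt are placeholders.
if __name__ == '__main__':
    model = load_model_from_hf()
    fish_ae = load_fish_ae_from_hf()
    pca_state = load_pca_state_from_hf()
    sample_fn = _dummy_sample_fn  # placeholder: swap in a real SampleFn here
    speaker_audio = load_audio('reference_speaker.wav')  # placeholder path; or None for no speaker prompt
    audio_out = sample_pipeline(
        model, fish_ae, pca_state, sample_fn,
        text_prompt='Hello from Echo TTS.',  # placeholder prompt
        speaker_audio=speaker_audio,
        rng_seed=0,
    )
    # audio_out is (1, 1, length); save as mono 44.1 kHz
    torchaudio.save('sample_out.wav', audio_out[0].cpu(), 44_100)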