import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict
from torch.nn.utils.rnn import pad_sequence
from fireredtts2.codec.rvq import ResidualVQ
from fireredtts2.codec.decoder import AcousticDecoder
from fireredtts2.codec.utils import make_nonpad_mask
from fireredtts2.codec.whisper import (
    WhisperEncoderLayer,
    PretrainedWhisperEncoder,
    WhisperAcousticEncoder,
)


class SslAdaptor(nn.Module):
    def __init__(
        self,
        in_dim: int,
        embed_dim: int,
        out_dim: int,
        num_layers: int,
        num_heads: int,
        ffn_dim: int = None,
        attn_dropout: float = 0.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.in_dim = in_dim
        self.embed_dim = embed_dim
        self.dropout = dropout
        # Input projection
        self.in_proj = nn.Linear(in_dim, embed_dim)
        # Transformer
        self.layers = nn.ModuleList(
            [
                WhisperEncoderLayer(
                    embed_dim, num_heads, ffn_dim, attn_dropout, dropout
                )
                for _ in range(num_layers)
            ]
        )
        # Output norm
        self.layer_norm = nn.LayerNorm(embed_dim)
        # Output projection
        self.out_proj = nn.Linear(embed_dim, out_dim)
        # Init weight
        self.apply(self._init_weights)

    def forward(
        self,
        hidden_states: torch.Tensor,
        hidden_length: torch.Tensor,
    ):
        # Input projection
        hidden_states = self.in_proj(hidden_states)
        # Transformer
        attention_mask = make_nonpad_mask(hidden_length).unsqueeze(1)  # (b, 1, t)
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states, hidden_length

    def _init_weights(self, module):
        std = 0.02
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class ResidualDownConv(nn.Module):
    def __init__(
        self,
        embed_dim: int = 768,
        avg_pooler=4,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.avg_pooler = avg_pooler
        self.intermediate_dim = embed_dim * avg_pooler
        # Convolution layers for downsampling
        self.gate_proj = nn.Conv1d(
            embed_dim, self.intermediate_dim, avg_pooler, avg_pooler, bias=False
        )
        self.up_proj = nn.Conv1d(
            embed_dim, self.intermediate_dim, avg_pooler, avg_pooler, bias=False
        )
        # Downsampled linear projection
        self.down_proj = nn.Linear(
            self.intermediate_dim, self.intermediate_dim, bias=False
        )
        # Activation function and layer normalization
        self.act_fn = nn.SiLU()
        self.layer_norm = nn.LayerNorm(self.intermediate_dim)
        # Final output projection
        self.out_proj = nn.Linear(self.intermediate_dim, embed_dim)

    def forward(self, x: torch.Tensor, input_length: torch.Tensor):
        output_length = input_length // self.avg_pooler
        batch_size, seq_len, _ = x.shape  # (B, T, D)

        xt = x.permute(0, 2, 1)  # (B, D, T)
        g = self.gate_proj(xt).permute(0, 2, 1)  # (B, T//4, D*4)
        u = self.up_proj(xt).permute(0, 2, 1)  # (B, T//4, D*4)
        x = x.reshape(batch_size, -1, self.intermediate_dim)  # (B, T//4, D*4)

        c = self.down_proj(self.act_fn(g) * u)  # (B, T//4, D*4)
        res = self.layer_norm(c + x)  # (B, T//4, D*4)
        res = self.out_proj(res)
        return res, output_length


class UpConv(nn.Module):
    def __init__(
        self,
        embed_dim: int = 768,
        stride: int = 4,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.stride = stride

        self.in_proj = nn.Linear(embed_dim, self.stride * embed_dim)
        # Simple transposed convolution layer to keep channel number consistent
        self.up_conv = nn.ConvTranspose1d(
            self.stride * embed_dim,
            embed_dim,
            kernel_size=stride,
            stride=stride,
            bias=False,
        )

    def forward(self, x: torch.Tensor, input_length: torch.Tensor):
        x = self.in_proj(x)
        x = x.transpose(1, 2)
        res = self.up_conv(x)
        res = res.transpose(1, 2)
        output_length = input_length * self.stride
        return res, output_length
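
# ---------------------------------------------------------------------------
# Illustrative shape walkthrough (editor's note, not part of the original
# module): ResidualDownConv trades time resolution for channel width and
# UpConv restores it, so a (B, T, D) input with avg_pooler=4 comes back as
# (B, T, D) after the round trip. The concrete sizes below are assumptions
# chosen only to make the shapes easy to follow, not values from the shipped
# config:
#
#     down = ResidualDownConv(embed_dim=768, avg_pooler=4)
#     up = UpConv(embed_dim=768, stride=4)
#     x = torch.randn(2, 100, 768)            # (B, T, D)
#     lens = torch.tensor([100, 100])
#     y, y_lens = down(x, lens)                # y: (2, 25, 768), y_lens: [25, 25]
#     z, z_lens = up(y, y_lens)                # z: (2, 100, 768), z_lens: [100, 100]
# ---------------------------------------------------------------------------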

class RedCodec(nn.Module):
    def __init__(
        self,
        ssl: PretrainedWhisperEncoder,
        ssl_adaptor: SslAdaptor,
        acoustic_encoder: WhisperAcousticEncoder,
        downsample: ResidualDownConv,
        rvq: ResidualVQ,
        upsample: UpConv,
        semantic_decoder: SslAdaptor,
        acoustic_decoder: AcousticDecoder,
    ):
        super().__init__()
        self.ssl = ssl
        self.ssl_adaptor = ssl_adaptor
        self.acoustic_encoder = acoustic_encoder
        self.downsample = downsample
        self.rvq = rvq
        self.upsample = upsample
        self.semantic_decoder = semantic_decoder
        self.acoustic_decoder = acoustic_decoder

    @classmethod
    def from_config(cls, config_json: str) -> "RedCodec":
        with open(config_json, "rb") as f:
            config = json.load(f)["codec"]
        ssl = PretrainedWhisperEncoder.from_pretrained()
        ssl_adaptor = SslAdaptor(**config["ssl_adaptor"])
        acoustic_encoder = WhisperAcousticEncoder(**config["acoustic_encoder"])
        downsample = ResidualDownConv(**config["downsample"])
        rvq = ResidualVQ(**config["rvq"])
        upsample = UpConv(**config["upsample"])
        semantic_decoder = SslAdaptor(**config["semantic_decoder"])
        acoustic_decoder = AcousticDecoder(**config["acoustic_decoder"])
        return cls(
            ssl,
            ssl_adaptor,
            acoustic_encoder,
            downsample,
            rvq,
            upsample,
            semantic_decoder,
            acoustic_decoder,
        )


class RedCodecInfer(RedCodec):
    def __init__(self, codec: RedCodec):
        super().__init__(
            codec.ssl,
            codec.ssl_adaptor,
            codec.acoustic_encoder,
            codec.downsample,
            codec.rvq,
            codec.upsample,
            codec.semantic_decoder,
            codec.acoustic_decoder,
        )

    @classmethod
    def from_pretrained(cls, conf_path: str, ckpt_path: str) -> "RedCodecInfer":
        codec = RedCodec.from_config(conf_path)
        ckpt = torch.load(ckpt_path)["generator"]
        codec.load_state_dict(ckpt)
        return cls(codec)

    def _encode_one_batch(self, audio16k: torch.Tensor):
        B, T = audio16k.shape
        audio16k_length = torch.tensor(
            [T] * B, dtype=torch.long, device=audio16k.device
        )
        # Semantic
        ssl, ssl_length = self.ssl.forward(audio16k, audio16k_length)
        ssl = ssl.clone()  # For onnx export
        sem_feats, sem_length = self.ssl_adaptor(ssl, ssl_length)
        # Acoustic
        aco_feats, aco_length = self.acoustic_encoder(audio16k, audio16k_length)
        # VQ
        vq_in_feats = torch.cat([sem_feats, aco_feats], dim=2)
        vq_in_feats, vq_in_length = self.downsample(vq_in_feats, aco_length)
        # RVQ
        indices = self.rvq.encode_codes(vq_in_feats.transpose(1, 2))  # (nq, B, L)
        indices = indices.permute(1, 0, 2)
        return indices  # (B, nq, L)

    @staticmethod
    def _pad_and_chunk(audio: torch.Tensor, chunk_size: int) -> List[torch.Tensor]:
        pad_len = math.ceil(audio.shape[1] / chunk_size) * chunk_size - audio.shape[1]
        audio = F.pad(audio, (0, pad_len), mode="constant", value=0)
        audio_chunks = audio.split(chunk_size, dim=1)
        return audio_chunks

    @torch.inference_mode()
    def encode(
        self,
        audio16k: torch.Tensor,
        audio16k_length: torch.Tensor = None,
        batch_size: int = 96,
    ):
        """
        Args:
            audio16k: shape (b, t)
            audio16k_length: (b,)
        Returns:
            token: shape (b, nq, l)
            token_length: (b,)
        """
        if audio16k_length is None:
            assert audio16k.shape[0] == 1
            audio16k_length = torch.tensor(
                [audio16k.shape[1]], dtype=torch.long, device=audio16k.device
            )

        CHUNK_SIZE = 6 * 16000
        B, T = audio16k.shape
        # Pad, chunk, and batch
        audio16k_batch = []
        batch_size_list = []
        for i in range(B):
            # Remove extra paddings
            one_audio_chunks = self._pad_and_chunk(
                audio16k[i : (i + 1), : audio16k_length[i]], CHUNK_SIZE
            )
            audio16k_batch += one_audio_chunks
            batch_size_list.append(len(one_audio_chunks))
        audio16k_batch = torch.cat(audio16k_batch, dim=0)

        # Batch encode
        token_batch = []
        for i in range(0, audio16k_batch.shape[0], batch_size):
            one_audio_batch = audio16k_batch[i : (i + batch_size)]
            one_token_batch = self._encode_one_batch(one_audio_batch)
            token_batch.append(one_token_batch)
        token_batch = torch.cat(token_batch, dim=0)

        # Recover & concat
        token_list = torch.split(
            token_batch, batch_size_list, dim=0
        )  # [(B=1, nq, l), (B=3, nq, l), ...]
        token_list = [
            torch.cat(token_ts.split(1, dim=0), dim=-1)  # (B=1, nq, l)
            for token_ts in token_list
        ]
        # Pad tokens
        token = pad_sequence(
            [ts.squeeze(0).transpose(1, 0) for ts in token_list],
            batch_first=True,
            padding_value=0,
        ).transpose(1, 2)  # (B, nq, L)
        token_length = (audio16k_length / 1280).ceil().long()
        token = token[
            ..., : token_length.max()
        ]  # Remove extra paddings (we pad to multiples of 6s)
        return token, token_length
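
    # Editor's note (not in the original source): the constants above fit
    # together as one token frame per 1280 samples of 16 kHz audio (80 ms,
    # i.e. 12.5 Hz), which is why token_length = ceil(audio16k_length / 1280).
    # Each 6-second chunk (CHUNK_SIZE = 6 * 16000 samples) thus contributes
    # 6 * 16000 / 1280 = 75 frames per quantizer, and the final slice trims
    # the frames introduced by padding every utterance to a multiple of 6 s.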

    @torch.inference_mode()
    def decode(self, tokens: torch.Tensor):
        """
        Args:
            tokens: (B=1, nq, L)
        Returns:
            audio: (B=1, t)
        """
        tokens = tokens.permute(1, 0, 2)  # (B, nq, L) -> (nq, B, L)
        vq_out_feats = self.rvq.decode_codes(tokens)
        vq_out_feats = vq_out_feats.transpose(1, 2)
        vq_out_length = torch.tensor(
            [vq_out_feats.shape[1]], dtype=torch.long, device=vq_out_feats.device
        )
        vq_out_feats, vq_out_length = self.upsample(vq_out_feats, vq_out_length)
        # audio: (b, t)
        audio, audio_length = self.acoustic_decoder(vq_out_feats, vq_out_length)
        return audio

    @torch.inference_mode()
    def decode_one_token(
        self, token: torch.Tensor, cache_dict: Dict[str, torch.Tensor], last_token: bool
    ):
        """Decode one single token to audio.

        Args:
            token: (B=1, nq, L=1)
        Returns:
            audio: (B=1, t)
        """
        # token->latent->upsample, (naturally causal)
        token = token.permute(1, 0, 2)  # (B, nq, L) -> (nq, B, L)
        vq_out_feats = self.rvq.decode_codes(token)
        vq_out_feats = vq_out_feats.transpose(1, 2)
        vq_out_length = torch.tensor(
            [vq_out_feats.shape[1]], dtype=torch.long, device=vq_out_feats.device
        )
        vq_out_feats, vq_out_length = self.upsample(vq_out_feats, vq_out_length)

        # acoustic decoder
        up_conv_cache = cache_dict.get("up_conv_cache", None)
        bb_conv_cache1 = cache_dict.get("bb_conv_cache1", None)
        bb_conv_cache2 = cache_dict.get("bb_conv_cache2", None)
        bb_kv_cache = cache_dict.get("bb_kv_cache", None)
        is_cache = cache_dict.get("is_cache", None)
        (
            audio,
            new_up_conv_cache,
            new_bb_conv_cache1,
            new_bb_conv_cache2,
            new_bb_kv_cache,
            new_is_cache,
        ) = self.acoustic_decoder.forward_chunk(
            vq_out_feats,
            up_conv_cache,
            bb_conv_cache1,
            bb_conv_cache2,
            bb_kv_cache,
            is_cache,
            last_token,
        )
        new_cache_dict = {
            "up_conv_cache": new_up_conv_cache,
            "bb_conv_cache1": new_bb_conv_cache1,
            "bb_conv_cache2": new_bb_conv_cache2,
            "bb_kv_cache": new_bb_kv_cache,
            "is_cache": new_is_cache,
        }
        return audio, new_cache_dict
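
# ---------------------------------------------------------------------------
# Minimal usage sketch (editor's addition, not part of the shipped module).
# It shows the intended round trip: load a pretrained codec, encode 16 kHz
# audio to RVQ tokens, then decode either in one shot or token by token with
# the streaming cache. The config/checkpoint paths and the audio file are
# placeholders, and torchaudio is used here only for I/O.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torchaudio

    codec = RedCodecInfer.from_pretrained(
        conf_path="config.json",  # placeholder path
        ckpt_path="codec_generator.pt",  # placeholder path
    ).eval()

    audio16k, sr = torchaudio.load("example.wav")  # placeholder file, mono
    assert sr == 16000, "the codec expects 16 kHz input"

    # Encode: (B=1, t) -> tokens (B=1, nq, L) and per-utterance lengths (B=1,)
    tokens, token_length = codec.encode(audio16k)

    # Offline decode: all tokens at once
    audio_hat = codec.decode(tokens)

    # Streaming decode: one token frame at a time, carrying the cache forward
    cache_dict = {}
    chunks = []
    for i in range(tokens.shape[-1]):
        last = i == tokens.shape[-1] - 1
        chunk, cache_dict = codec.decode_one_token(
            tokens[..., i : i + 1], cache_dict, last_token=last
        )
        chunks.append(chunk)
    audio_stream = torch.cat(chunks, dim=-1)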