import torch import torch.nn as nn import torch.nn.functional as F from fireredtts2.codec.whisper import WhisperEncoderLayer from fireredtts2.codec.utils import make_nonpad_mask, make_block_causal_mask class ResnetBlock(nn.Module): def __init__( self, in_channels: int, out_channels: int = None, conv_shortcut: bool = False, dropout: float = 0.0, ): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut self.block1 = nn.Sequential( nn.GroupNorm( num_groups=32, num_channels=in_channels, eps=1e-6, affine=True ), nn.SiLU(), nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1), ) self.block2 = nn.Sequential( nn.GroupNorm( num_groups=32, num_channels=out_channels, eps=1e-6, affine=True ), nn.SiLU(), nn.Dropout(dropout), nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), ) if self.in_channels != self.out_channels: if self.use_conv_shortcut: self.conv_shortcut = torch.nn.Conv1d( in_channels, out_channels, kernel_size=3, stride=1, padding=1 ) else: self.nin_shortcut = torch.nn.Conv1d( in_channels, out_channels, kernel_size=1, stride=1, padding=0 ) def forward(self, x: torch.Tensor): """ Args: x: shape (b, c, t) """ h = x h = self.block1(h) h = self.block2(h) if self.in_channels != self.out_channels: if self.use_conv_shortcut: x = self.conv_shortcut(x) else: x = self.nin_shortcut(x) return x + h class Transpose(torch.nn.Module): def __init__(self, dim0: int, dim1: int): super().__init__() self.dim0 = dim0 self.dim1 = dim1 def forward(self, x: torch.Tensor): x = torch.transpose(x, self.dim0, self.dim1) return x # A causal variant of Conv1d class CausalConv1d(torch.nn.Conv1d): def __init__( self, in_channels: int, out_channels: int, kernel_size: int, ) -> None: super(CausalConv1d, self).__init__(in_channels, out_channels, kernel_size) self.causal_padding = (kernel_size - 1, 0) def forward(self, x: torch.Tensor): x = F.pad(x, self.causal_padding) x = super(CausalConv1d, self).forward(x) return x def forward_chunk(self, x: torch.Tensor, cnn_cache: torch.Tensor = None): if cnn_cache is None: cnn_cache = x.new_zeros( (x.shape[0], self.in_channels, self.causal_padding[0]) ) x = torch.cat([cnn_cache, x], dim=2) new_cnn_cache = x[..., -self.causal_padding[0] :] x = super(CausalConv1d, self).forward(x) return x, new_cnn_cache # A causal variant of ResnetBlock class CausalResnetBlock(nn.Module): def __init__( self, in_channels: int, out_channels: int = None, dropout: float = 0.0, ): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.block1 = nn.Sequential( Transpose(1, 2), nn.LayerNorm(in_channels), Transpose(1, 2), nn.SiLU(), CausalConv1d(in_channels, out_channels, kernel_size=3), ) self.block2 = nn.Sequential( Transpose(1, 2), nn.LayerNorm(out_channels), Transpose(1, 2), nn.SiLU(), nn.Dropout(dropout), CausalConv1d(out_channels, out_channels, kernel_size=3), ) if self.in_channels != self.out_channels: self.nin_shortcut = torch.nn.Conv1d( in_channels, out_channels, kernel_size=1, stride=1, padding=0 ) def forward(self, x: torch.Tensor): """ Args: x: shape (b, c, t) """ h = x h = self.block1(h) h = self.block2(h) if self.in_channels != self.out_channels: x = self.nin_shortcut(x) return x + h def forward_chunk(self, x: torch.Tensor, cache: torch.Tensor = None): """ Args: x: shape (b, c, t) cache: shape (b, c_in+c_out, t=2) """ cache1, cache2 = ( (None, None) if cache is None else cache.split((self.in_channels, self.out_channels), dim=1) ) h = x # block1 h = self.block1[:4](h) h, new_cache1 = self.block1[4].forward_chunk(h, cache1) # block2 h = self.block2[:5](h) h, new_cache2 = self.block2[5].forward_chunk(h, cache2) if self.in_channels != self.out_channels: x = self.nin_shortcut(x) new_cache = torch.cat([new_cache1, new_cache2], dim=1) return x + h, new_cache # Nonstreaming Vocos backbone based on Transformer layers class VocosBackbone(nn.Module): def __init__( self, embed_dim: int = 1024, num_layers: int = 12, num_heads: int = 16, dropout: float = 0.1, ): super().__init__() self.in_proj = nn.Conv1d(embed_dim, embed_dim, kernel_size=7, padding=3) self.prior_net = nn.Sequential( ResnetBlock(embed_dim, embed_dim, dropout=dropout), ResnetBlock(embed_dim, embed_dim, dropout=dropout), ) self.transformers = nn.ModuleList( [WhisperEncoderLayer(embed_dim, num_heads) for _ in range(num_layers)] ) self.post_net = nn.Sequential( ResnetBlock(embed_dim, embed_dim, dropout=dropout), ResnetBlock(embed_dim, embed_dim, dropout=dropout), ) self.final_norm = nn.LayerNorm(embed_dim, eps=1e-6) def forward( self, x: torch.Tensor, x_lens: torch.Tensor, ): """ Args: x: shape (b, t, c) x_lens: shape (b,) """ x = x.transpose(1, 2) x = self.in_proj(x) x = self.prior_net(x) x = x.transpose(1, 2) attention_mask = make_nonpad_mask(x_lens).unsqueeze(1) # (b, 1, t) # NOTE(sfy): I think positional embedding is unnecessary for layer in self.transformers: x = layer(x, attention_mask) x = x.transpose(1, 2) x = self.post_net(x) x = x.transpose(1, 2) x = self.final_norm(x) return x # Streaming Vocos backbone based on Transformer layers class CausalVocosBackbone(nn.Module): def __init__( self, embed_dim: int = 1024, num_layers: int = 12, num_heads: int = 16, dropout: float = 0.1, ): super().__init__() self.in_proj = CausalConv1d(embed_dim, embed_dim, kernel_size=7) self.prior_net = nn.Sequential( CausalResnetBlock(embed_dim, embed_dim, dropout=dropout), CausalResnetBlock(embed_dim, embed_dim, dropout=dropout), ) self.transformers = nn.ModuleList( [WhisperEncoderLayer(embed_dim, num_heads) for _ in range(num_layers)] ) self.post_net = nn.Sequential( CausalResnetBlock(embed_dim, embed_dim, dropout=dropout), CausalResnetBlock(embed_dim, embed_dim, dropout=dropout), ) self.final_norm = nn.LayerNorm(embed_dim, eps=1e-6) def forward( self, x: torch.Tensor, x_lens: torch.Tensor, ): """ Args: x: shape (b, t, c) x_lens: shape (b,) """ x = x.transpose(1, 2) x = self.in_proj(x) x = self.prior_net(x) x = x.transpose(1, 2) # NOTE(sfy): We have no padding in training, so safe for sdpa attention, no Nan. # Also, 1 token(12.5Hz) -> 4 latents(50Hz) -> 8 latents(100Hz), # so we design a 8 block causal attention mask instead of fully causal to improve performance attention_mask = make_block_causal_mask(x_lens, chunk_size=8) for layer in self.transformers: x = layer(x, attention_mask) x = x.transpose(1, 2) x = self.post_net(x) x = x.transpose(1, 2) x = self.final_norm(x) return x def forward_chunk( self, x: torch.Tensor, conv_cache1: torch.Tensor = None, conv_cache2: torch.Tensor = None, kv_cache: torch.Tensor = None, ): # Unpack cache cache1 = conv_cache1 cache2, cache3, cache4, cache5 = ( (None, None, None, None) if conv_cache2 is None else conv_cache2.chunk(4, dim=1) ) # cache1: shape (b, c=embed_dim, t=6) x = x.transpose(1, 2) x, new_cache1 = self.in_proj.forward_chunk(x, cache1) # cache2: shape (b, c=embed_dim*2, t=2) x, new_cache2 = self.prior_net[0].forward_chunk(x, cache2) # cache3: shape (b, c=embed_dim*2, t=2) x, new_cache3 = self.prior_net[1].forward_chunk(x, cache3) x = x.transpose(1, 2) # k,v-cache: shape (b, nlayer, nh, t, c*2) new_kv_cache = [] for idx, layer in enumerate(self.transformers): kv_cache_i = None if kv_cache is None else kv_cache[:, idx] x, new_kv_cache_i = layer.forward_chunk(x, kv_cache=kv_cache_i) new_kv_cache.append(new_kv_cache_i) new_kv_cache = torch.stack(new_kv_cache, dim=1) x = x.transpose(1, 2) # cache4: shape (b, c=embed_dim*2, t=2) x, new_cache4 = self.post_net[0].forward_chunk(x, cache4) # cache5: shape (b, c=embed_dim*2, t=2) x, new_cache5 = self.post_net[1].forward_chunk(x, cache5) x = x.transpose(1, 2) x = self.final_norm(x) new_conv_cache1 = new_cache1 new_conv_cache2 = torch.cat( [new_cache2, new_cache3, new_cache4, new_cache5], dim=1 ) return x, new_conv_cache1, new_conv_cache2, new_kv_cache class ISTFT(nn.Module): """ Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges. See issue: https://github.com/pytorch/pytorch/issues/62323 Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs. The NOLA constraint is met as we trim padded samples anyway. Args: n_fft (int): Size of Fourier transform. hop_length (int): The distance between neighboring sliding window frames. win_length (int): The size of window frame and STFT filter. padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". """ def __init__( self, n_fft: int, hop_length: int, win_length: int, padding: str = "same" ): super().__init__() assert padding in ["center", "same"], "Padding must be 'center' or 'same'." self.padding = padding self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length window = torch.hann_window(win_length) self.register_buffer("window", window) def forward(self, spec: torch.Tensor) -> torch.Tensor: """ Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. Args: spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, N is the number of frequency bins, and T is the number of time frames. Returns: Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. """ if self.padding == "center": # Fallback to pytorch native implementation return torch.istft( spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True, ) elif self.padding == "same": pad = (self.win_length - self.hop_length) // 2 else: raise ValueError("Padding must be 'center' or 'same'.") assert spec.dim() == 3, "Expected a 3D tensor as input" B, N, T = spec.shape # Inverse FFT ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") ifft = ifft * self.window[None, :, None] # Overlap and Add output_size = (T - 1) * self.hop_length + self.win_length y = torch.nn.functional.fold( ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), )[:, 0, 0, pad:-pad] # Window envelope window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) window_envelope = torch.nn.functional.fold( window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), ).squeeze()[pad:-pad] # Normalize assert (window_envelope > 1e-11).all() y = y / window_envelope return y def forward_chunk( self, spec: torch.Tensor, cache: torch.Tensor = None, last_chunk: bool = False ): """Forward only one frame. Args: spec: shape (B, N, T=chunk_size) cache: previous chunk's last ifft frame, shape (B, N, T=3) last_chunk: if last_chunk, will not trim the last (win-hop) segment Returns: y: shape (B, T=effective_length) """ assert self.padding == "same", "Padding must be same." assert ( self.win_length % self.hop_length == 0 ), f"{self.win_length} {self.hop_length}" pad = (self.win_length - self.hop_length) // 2 # Inverse FFT ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") ifft = ifft * self.window[None, :, None] # (B, N, T=chunk_size) # Append previous cache if cache is not None: ifft = torch.cat([cache, ifft], dim=-1) new_cache_t = self.win_length // self.hop_length - 1 new_cache = ifft[..., -new_cache_t:] # Overlap and Add output_size = (ifft.shape[-1] - 1) * self.hop_length + self.win_length y = torch.nn.functional.fold( ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), )[:, 0, 0, :] # Window envelope window_sq = ( self.window.square().expand(1, ifft.shape[-1], -1).transpose(1, 2) ) # (B=1, N, T) window_envelope = torch.nn.functional.fold( window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), ).squeeze() # Normalize # assert (window_envelope > 1e-11).all() y = y / window_envelope # Only take effective part if cache is None: y = y[:, pad:] else: y = y[:, (self.win_length - self.hop_length) :] if last_chunk: y = y[:, :-pad] else: y = y[:, : -(self.win_length - self.hop_length)] return y, new_cache class ISTFTHead(nn.Module): """ ISTFT Head module for predicting STFT complex coefficients. Args: dim (int): Hidden dimension of the model. n_fft (int): Size of Fourier transform. hop_length (int): The distance between neighboring sliding window frames, which should align with the resolution of the input features. padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". """ def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"): super().__init__() self.hop_length = hop_length out_dim = n_fft + 2 self.out = torch.nn.Linear(dim, out_dim) self.istft = ISTFT( n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding ) def forward(self, x: torch.Tensor, x_len: torch.Tensor) -> torch.Tensor: """ Forward pass of the ISTFTHead module. Args: x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, L is the sequence length, and H denotes the model dimension. Returns: Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. """ x_pred = self.out(x) x_pred = x_pred.transpose(1, 2) mag, p = x_pred.chunk(2, dim=1) mag = torch.exp(mag) mag = torch.clip( mag, max=1e2 ) # safeguard to prevent excessively large magnitudes # wrapping happens here. These two lines produce real and imaginary value x = torch.cos(p) y = torch.sin(p) # recalculating phase here does not produce anything new # only costs time # phase = torch.atan2(y, x) # S = mag * torch.exp(phase * 1j) # better directly produce the complex value S = mag * (x + 1j * y) audio = self.istft(S) audio_length = x_len * self.hop_length return audio, audio_length def forward_chunk( self, x: torch.Tensor, cache: torch.Tensor = None, last_chunk: bool = False ): """ISTFTHead can be adapted in streaming inference without retraining. Args: x: shape (B, T, C) cache: shape (B, N, T=3), istft cache Returns: audio: shape (B, t) """ x_pred = self.out(x) x_pred = x_pred.transpose(1, 2) mag, p = x_pred.chunk(2, dim=1) mag = torch.exp(mag) # (B, C, T) mag = torch.clip( mag, max=1e2 ) # safeguard to prevent excessively large magnitudes # wrapping happens here. These two lines produce real and imaginary value x = torch.cos(p) y = torch.sin(p) S = mag * (x + 1j * y) # (B, C, T) audio, new_cache = self.istft.forward_chunk(S, cache, last_chunk) return audio, new_cache # UpsampleConv(50->100Hz) + VocosBackbone + ISTFTHead class AcousticDecoder(nn.Module): def __init__( self, # Transformer embed_dim: int, num_layers: int, num_heads: int, dropout: float = 0.0, # iSTFT hop_length: int = 240, # Causal causal: bool = False, ): super().__init__() self.embed_dim = embed_dim self.num_layers = num_layers self.num_heads = num_heads self.hop_length = hop_length self.causal = causal # Output upsample self.upsample_conv = nn.Sequential( nn.ConvTranspose1d( embed_dim, embed_dim, kernel_size=3, stride=2, padding=0, # Do not fill input side output_padding=0, # Can be adjusted to precisely control length ), nn.GELU(), nn.ConvTranspose1d( embed_dim, embed_dim, kernel_size=3, stride=1, padding=0, # Do not fill input side ), nn.GELU(), ) self.backbone = ( CausalVocosBackbone(embed_dim, num_layers, num_heads, dropout) if causal else VocosBackbone(embed_dim, num_layers, num_heads, dropout) ) self.isift = ISTFTHead(embed_dim, hop_length * 4, hop_length, padding="same") # Init weights self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Conv1d): nn.init.trunc_normal_(m.weight, std=0.02) nn.init.constant_(m.bias, 0) def forward(self, x: torch.Tensor, x_lens: torch.Tensor): """ Args: x: shape (b, t, c) x_lens: shape (b,) """ # Upsample target_length = x.shape[1] * 2 x = x.transpose(1, 2) x = self.upsample_conv(x) x = x.transpose(1, 2) # NOTE strict upsampling, trim the last 3 elements x = x[:, :target_length] x_lens = x_lens * 2 # Backbone x = self.backbone(x, x_lens) # iSTFT y, y_lens = self.isift(x, x_lens) return y, y_lens def forward_upsample_conv_chunk(self, x: torch.Tensor, cache: torch.Tensor = None): """Stream forward upsample_conv module with previous block cache. Args: x: shape (B, C, T) cache: shape (B, C, 3), where 3 denotes 1 history state for 1st conv and 2 for the rest conv. """ # Unpack cache cache1, cache2 = ( (None, None) if cache is None else torch.split(cache, [1, 2], dim=2) ) # 1st conv cache if cache1 is not None: x = torch.cat([cache1, x], dim=2) new_cache1 = x[..., -1:] # 1st conv x = self.upsample_conv[0](x)[..., :-1] # remove extra 1 frame if cache1 is not None: x = x[..., 2:] # remove cache1 part x = self.upsample_conv[1](x) # 2nd conv cache if cache2 is not None: x = torch.cat([cache2, x], dim=2) new_cache2 = x[..., -2:] # 2nd conv x = self.upsample_conv[2](x)[..., :-2] # remove extra 2 frame if cache2 is not None: x = x[..., 2:] # remove cache2 part x = self.upsample_conv[3](x) new_cache = torch.cat([new_cache1, new_cache2], dim=2) return x, new_cache def forward_chunk( self, x: torch.Tensor, # Upsample conv cache up_conv_cache: torch.Tensor = None, # Backbone conv cache bb_conv_cache1: torch.Tensor = None, bb_conv_cache2: torch.Tensor = None, # Backbone attention cache bb_kv_cache: torch.Tensor = None, # iSTFT cache is_cache: torch.Tensor = None, last_chunk: bool = False, ): """ Args: x: input sequence at 50Hz, length should be multiples of 4 """ assert ( self.causal ), "Only AcousticDecoder with causal=True supports forward_chunk method." x = x.transpose(1, 2) x, new_up_conv_cache = self.forward_upsample_conv_chunk(x, up_conv_cache) x = x.transpose(1, 2) # Backbone x, new_bb_conv_cache1, new_bb_conv_cache2, new_bb_kv_cache = ( self.backbone.forward_chunk( x, bb_conv_cache1, bb_conv_cache2, bb_kv_cache, ) ) # iSTFT y, new_is_cache = self.isift.forward_chunk(x, is_cache, last_chunk) return ( y, new_up_conv_cache, new_bb_conv_cache1, new_bb_conv_cache2, new_bb_kv_cache, new_is_cache, )