"""
WavCoch: waveform-to-cochleagram encoder with an LFQ bottleneck.

Transforming waveforms to cochleagrams ("Transformation Imitation").
"""

import math
from math import log2, ceil

import tqdm
from transformers.tokenization_utils import BatchEncoding
from transformers import PreTrainedModel
from functools import partial, cache
from collections import namedtuple
from contextlib import nullcontext

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed import nn as dist_nn
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module
from torch.amp import autocast

from .configuration_wavcoch import WavCochConfig


########################################
###      Cochleagram Transform       ###
########################################

class CochleagramTransform:
    """Callable wrapper around a `chcochleagram` cochleagram pipeline.

    The cochleagram function (filterbank, Hilbert envelope, downsampling and
    power compression) is built once at construction for a fixed signal length
    and sample rate, then applied to waveforms via `__call__` / `cochleagram`.

    Args:
        sr: Sample rate of the input audio in Hz.
        signal_size: Number of samples per input signal; the filterbank and the
            envelope extraction are constructed for this exact length.
        device: Device that hosts the cochleagram function; inputs are moved here.
        batch_mode: If False, a batch of size 1 is squeezed to (time, channels).
        return_on_cpu: If True, results are moved back to the CPU.

    Requires the optional `chcochleagram` package:
        pip install git+https://github.com/jenellefeather/chcochleagram.git
    """

    def __init__(
        self,
        sr: int = 16000,
        signal_size: int = 16000 * 5,  # set default signal size to 5 sec @ 16khz
        device: str = 'cpu',
        batch_mode: bool = False,
        return_on_cpu: bool = True,
    ):
        self.sr = sr
        self.device = device
        self.batch_mode = batch_mode
        self.return_on_cpu = return_on_cpu
        self.cochleagram_fn = self._init_cochleagram_fn(signal_size=signal_size)

    def cochleagram(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Compute the cochleagram of the audio waveform.

        From Jenelle Feather: chcochleagram

        Returns:
            Tensor of shape (batch, n_timesteps, n_channels), or
            (n_timesteps, n_channels) when the batch has size 1 and
            `batch_mode` is False.

        Raises:
            ValueError: if the computed cochleagram contains NaN values.
        """
        # move audio to specified device
        audio = audio.to(self.device)

        cochleagram = self.cochleagram_fn(audio)  # (batch, n_channels, n_timesteps)

        # Transpose the cochleagram such that n_timesteps x n_channels
        cochleagram = cochleagram.permute(0, 2, 1)

        if torch.isnan(cochleagram).any():
            raise ValueError('Cochleagram contains nan values')

        # Move the cochleagram back to CPU to match the semantics of the previous
        # dataloader (the extra transfer has not been a measurable bottleneck).
        if self.return_on_cpu:
            cochleagram = cochleagram.to('cpu')

        # A batch of size 1 is squeezed to match the previous dataloaders' semantics.
        if cochleagram.shape[0] == 1 and not self.batch_mode:
            cochleagram = cochleagram.squeeze(0)

        return cochleagram

    def __call__(self, audio: torch.Tensor) -> torch.Tensor:
        return self.cochleagram(audio)

    def _init_cochleagram_fn(
        self,
        pad_factor: float = 1.5,
        use_rfft: bool = True,
        signal_size: int = 16000 * 5,  # set default signal size to 5 sec @ 16khz
    ):
        """Build the chcochleagram pipeline and move it to `self.device`.

        Raises:
            ImportError: if the optional `chcochleagram` package is not installed.
        """
        # Imported lazily so the rest of this module works without the optional dependency.
        try:
            import chcochleagram
        except ImportError as err:
            raise ImportError(
                "The cochleagram library is required to compute or invert cochleagrams, "
                "please install it with: "
                "pip install git+https://github.com/jenellefeather/chcochleagram.git"
            ) from err

        ### Define the cochlear filters using ERBCosFilters.
        # Arguments for filter construction of ERBCosFilters (see chcochleagram's
        # helpers/erb_filters.py for more documentation).
        half_cos_filter_kwargs = {
            'n': 50,  # Number of filters to evenly tile the space
            'low_lim': 50,  # Lowest center frequency for full filter (if lowpass filters are used they can be centered lower)
            'high_lim': 8000,  # Highest center frequency
            'sample_factor': 4,  # Positive integer that determines how densely ERB function will be sampled
            'full_filter': False,  # Whether to use the full-filter. Must be False if rFFT is true.
        }
        coch_filter_kwargs = {
            'use_rfft': use_rfft,  # Whether to use rFFT or not
            'pad_factor': pad_factor,  # How much to pad the signal
            'filter_kwargs': half_cos_filter_kwargs,
        }

        ### Define an envelope extraction operation
        # Analytic amplitude of the Hilbert transform. Other envelope extractors live in
        # envelope_extraction.py; Identity can be used to keep the raw subbands.
        envelope_extraction = chcochleagram.envelope_extraction.HilbertEnvelopeExtraction(
            signal_size=signal_size,
            sr=self.sr,
            use_rfft=use_rfft,
            pad_factor=pad_factor,
        )

        # Other filterbanks (e.g. linearly spaced or differently shaped filters) can
        # be built with a new CochlearFilter class.
        filters = chcochleagram.cochlear_filters.ERBCosFilters(
            signal_size=signal_size,
            sr=self.sr,
            **coch_filter_kwargs,
        )

        ### Define a downsampling operation (Identity keeps the raw subbands).
        env_sr = 200  # Sampling rate after downsampling
        downsampling_kwargs = {'window_size': 1001}  # Parameters for the downsampling filter (see downsampling.py)
        downsampling_op = chcochleagram.downsampling.SincWithKaiserWindow(
            sr=self.sr,
            env_sr=env_sr,
            **downsampling_kwargs,
        )

        ### Define a compression operation.
        compression_kwargs = {
            'power': 0.3,  # Power compression of 0.3
            'offset': 1e-8,  # Offset for numerical stability in backwards pass
            'scale': 1,  # Optional multiplicative value applied to the envelopes before compression
            'clip_value': 100,  # Clip the gradients for this compression for stability
        }
        compression = chcochleagram.compression.ClippedGradPowerCompression(**compression_kwargs)

        cochleagram_fn = chcochleagram.cochleagram.Cochleagram(
            filter_object=filters,
            envelope_extraction=envelope_extraction,
            downsampling=downsampling_op,
            compression=compression,
        )

        return cochleagram_fn.to(self.device)


########################################
###          LFQ Definition          ###
########################################

"""
Lookup Free Quantization
Proposed in https://arxiv.org/abs/2310.05737

Adapted from vector-quantize-pytorch
https://github.com/lucidrains/vector-quantize-pytorch

In the simplest setup, each dimension is quantized into {-1, 1}.
An entropy penalty is used to encourage utilization.
"""

# constants

# NOTE: the third field holds the *total* weighted aux loss (entropy + commitment);
# the field name is kept for backward compatibility.
Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss'])

LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])

# distributed helpers

@cache
def is_distributed():
    """Return True when torch.distributed is initialized with more than one rank.

    NOTE(review): the result is cached on first call, so calling this before
    `init_process_group` pins it to False for the process lifetime.
    """
    return dist.is_initialized() and dist.get_world_size() > 1


def maybe_distributed_mean(t):
    """Average `t` across ranks (autograd-aware); identity when not distributed."""
    if not is_distributed():
        return t

    # torch.distributed.nn.all_reduce returns the reduced tensor rather than
    # guaranteeing an in-place update, so the result must be kept.
    t = dist_nn.all_reduce(t)
    return t / dist.get_world_size()

# helper functions

def exists(v):
    return v is not None


def identity(t):
    return t


def default(*args):
    """Return the first non-None argument, calling it if it is callable."""
    for arg in args:
        if exists(arg):
            return arg() if callable(arg) else arg
    return None


def pack_one(tensor: torch.Tensor, pattern: str):
    """
    Packs a single tensor by flattening all axes matched by '*' into one.
    Returns (packed_tensor, packed_shapes), where packed_shapes is a list
    of one tuple describing the original wildcard dims.
    """
    tokens = pattern.split()
    if '*' not in tokens:
        raise ValueError("Pattern must contain a '*' wildcard axis")

    n_before = tokens.index('*')
    n_after = len(tokens) - n_before - 1
    shape = tensor.shape

    # split original shape into before / wildcard / after
    before = shape[:n_before]
    if n_after:
        wildcard = shape[n_before:-n_after]
        after = shape[-n_after:]
    else:
        wildcard = shape[n_before:]
        after = ()

    packed = tensor.reshape(before + (math.prod(wildcard),) + after)

    # return list-of-shapes so unpack_one can use the same interface
    return packed, [tuple(wildcard)]


def unpack_one(packed: torch.Tensor, ps: list, pattern: str):
    """
    Reverses pack_one on a single tensor.
    `ps` should be the list-of-shapes returned by pack_one.
    """
    tokens = pattern.split()
    if '*' not in tokens:
        raise ValueError("Pattern must contain a '*' wildcard axis")

    n_before = tokens.index('*')
    n_after = len(tokens) - n_before - 1
    shape = packed.shape

    # the wildcard shape that was saved by pack_one
    wildcard = tuple(ps[0])

    before = shape[:n_before]
    after = shape[-n_after:] if n_after else ()

    return packed.reshape(before + wildcard + after)


def l2norm(t):
    return F.normalize(t, dim = -1)

# entropy

def log(t, eps = 1e-5):
    """Natural log with inputs clamped to `eps` to avoid -inf."""
    return t.clamp(min = eps).log()


def entropy(prob):
    """Shannon entropy (nats) of distributions along the last dimension."""
    return (-prob * log(prob)).sum(dim=-1)

# cosine sim linear

class CosineSimLinear(Module):
    """Linear map with L2-normalized inputs and weight columns, scaled by `scale`."""

    def __init__(
        self,
        dim_in,
        dim_out,
        scale = 1.
    ):
        super().__init__()
        self.scale = scale
        self.weight = nn.Parameter(torch.randn(dim_in, dim_out))

    def forward(self, x):
        x = F.normalize(x, dim = -1)
        w = F.normalize(self.weight, dim = 0)
        return (x @ w) * self.scale

# class

class LFQ(Module):
    """Lookup-free quantizer.

    Each of the `log2(codebook_size)` dimensions per codebook is binarized to
    +/- `codebook_scale`; the resulting bit pattern (most significant bit first)
    is the code index. In training mode an entropy auxiliary loss (and optionally
    a commitment loss) is computed to encourage confident, diverse code usage.
    """

    def __init__(
        self,
        *,
        dim = None,
        codebook_size = None,
        entropy_loss_weight = 0.1,
        commitment_loss_weight = 0.,
        diversity_gamma = 1.,
        straight_through_activation = nn.Identity(),
        num_codebooks = 1,
        keep_num_codebooks_dim = None,
        codebook_scale = 1.,                        # for residual LFQ, codebook scaled down by 2x at each layer
        frac_per_sample_entropy = 1.,               # make less than 1. to only use a random fraction of the probs for per sample entropy
        has_projections = None,
        projection_has_bias = True,
        soft_clamp_input_value = None,
        cosine_sim_project_in = False,
        cosine_sim_project_in_scale = None,
        channel_first = None,
        experimental_softplus_entropy_loss = False,
        entropy_loss_offset = 5.,                   # how much to shift the loss before softplus
        spherical = False,                          # from https://arxiv.org/abs/2406.07548
        force_quantization_f32 = True               # will force the quantization step to be full precision
    ):
        super().__init__()

        # some assert validations

        assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
        assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'

        codebook_size = default(codebook_size, lambda: 2 ** dim)
        self.codebook_size = codebook_size

        codebook_dim = int(log2(codebook_size))
        codebook_dims = codebook_dim * num_codebooks
        dim = default(dim, codebook_dims)

        has_projections = default(has_projections, dim != codebook_dims)

        if cosine_sim_project_in:
            cosine_sim_project_in = default(cosine_sim_project_in_scale, codebook_scale)
            project_in_klass = partial(CosineSimLinear, scale = cosine_sim_project_in)
        else:
            project_in_klass = partial(nn.Linear, bias = projection_has_bias)

        self.project_in = project_in_klass(dim, codebook_dims) if has_projections else nn.Identity()
        self.project_out = nn.Linear(codebook_dims, dim, bias = projection_has_bias) if has_projections else nn.Identity()
        self.has_projections = has_projections

        self.dim = dim
        self.codebook_dim = codebook_dim
        self.num_codebooks = num_codebooks

        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        # channel first

        self.channel_first = channel_first

        # straight through activation

        self.activation = straight_through_activation

        # whether to use BSQ (binary spherical quantization)
        # (the lambda reads self.codebook_scale lazily, after it is set below)

        self.spherical = spherical
        self.maybe_l2norm = (lambda t: l2norm(t) * self.codebook_scale) if spherical else identity

        # entropy aux loss related weights

        assert 0 < frac_per_sample_entropy <= 1.
        self.frac_per_sample_entropy = frac_per_sample_entropy

        self.diversity_gamma = diversity_gamma
        self.entropy_loss_weight = entropy_loss_weight

        # codebook scale

        self.codebook_scale = codebook_scale

        # commitment loss

        self.commitment_loss_weight = commitment_loss_weight

        # whether to soft clamp the input value from -value to value

        self.soft_clamp_input_value = soft_clamp_input_value
        assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= codebook_scale

        # whether to make the entropy loss positive through a (shifted) softplus (experimental)

        self.entropy_loss_offset = entropy_loss_offset
        self.experimental_softplus_entropy_loss = experimental_softplus_entropy_loss

        # bit weights (MSB first) used to turn sign bits into integer indices

        self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
        self.register_buffer('zero', torch.tensor(0.), persistent = False)

        # whether to force quantization step to be f32

        self.force_quantization_f32 = force_quantization_f32

        # codes

        all_codes = torch.arange(codebook_size)
        bits = ((all_codes[..., None].int() & self.mask) != 0).float()
        codebook = self.bits_to_codes(bits)

        self.register_buffer('codebook', codebook.float(), persistent = False)

    def bits_to_codes(self, bits):
        """Map {0, 1} bits to {-codebook_scale, +codebook_scale}."""
        return bits * self.codebook_scale * 2 - self.codebook_scale

    @property
    def dtype(self):
        return self.codebook.dtype

    def indices_to_codes(
        self,
        indices,
        project_out = True
    ):
        """Convert integer code indices back to (optionally projected) code vectors."""
        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
        should_transpose = default(self.channel_first, is_img_or_video)

        if not self.keep_num_codebooks_dim:
            # append a singleton codebook dimension at the end
            indices = indices.unsqueeze(-1)

        # indices to codes, which are bits of either -1 or 1

        bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)

        codes = self.bits_to_codes(bits)

        codes = self.maybe_l2norm(codes)

        codes = codes.flatten(-2, -1)

        # whether to project codes out to original dimensions
        # if the input feature dimensions were not log2(codebook size)

        if project_out:
            codes = self.project_out(codes)

        # move codes back to original shape

        if should_transpose:
            codes = codes.movedim(-1, 1)

        return codes

    def forward(
        self,
        x,
        inv_temperature = 100.,
        return_loss_breakdown = False,
        mask = None,
    ):
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        c - number of codebook dim

        Args:
            x: (b, n, d) features, or channel-first (b, d, ...) images/video.
            inv_temperature: sharpness of the soft code assignment used by the
                entropy loss.
            return_loss_breakdown: if True, also return a `LossBreakdown`.
            mask: optional boolean mask over the packed (b, n) positions that
                restricts the entropy and commitment losses (assumed shape
                (b, n) — TODO confirm against callers).

        Returns:
            `Return(quantized, indices, aux_loss)`, plus a `LossBreakdown` when
            `return_loss_breakdown` is True.
        """
        is_img_or_video = x.ndim >= 4
        should_transpose = default(self.channel_first, is_img_or_video)

        # standardize image or video into (batch, seq, dimension)

        if should_transpose:
            x = x.movedim(1, -1)

        x, ps = pack_one(x, 'b * d')

        assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'

        x = self.project_in(x)

        # maybe soft clamp

        if exists(self.soft_clamp_input_value):
            clamp_value = self.soft_clamp_input_value
            x = (x / clamp_value).tanh() * clamp_value

        # split out number of codebooks

        x = x.reshape(*x.shape[:2], self.num_codebooks, -1)

        # maybe l2norm

        x = self.maybe_l2norm(x)

        # whether to force quantization step to be full precision or not

        force_f32 = self.force_quantization_f32

        quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext

        with quantization_context():

            if force_f32:
                orig_dtype = x.dtype
                x = x.float()

            # quantize by eq 3.

            original_input = x

            codebook_value = torch.ones_like(x) * self.codebook_scale
            quantized = torch.where(x > 0, codebook_value, -codebook_value)

            # calculate indices

            t = (quantized > 0).int() * self.mask.int()
            indices = t.sum(dim=-1)

            quantized = self.maybe_l2norm(quantized)

            # use straight-through gradients (optionally with custom activation fn) if training

            if self.training:
                x = self.activation(x)
                x = x + (quantized - x).detach()
            else:
                x = quantized

            # entropy aux loss

            if self.training:
                # Always bind `codebook` (it was unbound when force_f32 was False);
                # match the input dtype (float32 when forced).
                codebook = self.maybe_l2norm(self.codebook.to(original_input.dtype))

                # Collapse (b, n) into a single token axis -> (tokens, c, d).
                # Boolean-mask indexing already merges the masked (b, n) dims.
                if exists(mask):
                    input_for_entropy = original_input[mask]
                else:
                    input_for_entropy = original_input.flatten(0, 1)

                if self.frac_per_sample_entropy < 1.:
                    # only use a random fraction of the tokens, for reducing memory
                    num_tokens = input_for_entropy.size(0)
                    num_sampled_tokens = max(1, int(num_tokens * self.frac_per_sample_entropy))
                    rand_mask = torch.randn(num_tokens).argsort(dim = -1) < num_sampled_tokens

                    sampled_input = input_for_entropy[rand_mask]

                    sampled_distance = -2 * einsum('... i d, j d -> ... i j', sampled_input, codebook)

                    sampled_prob = (-sampled_distance * inv_temperature).softmax(dim = -1)

                    per_sample_probs = sampled_prob
                else:
                    # the same as euclidean distance up to a constant
                    distance = -2 * einsum('... i d, j d -> ... i j', input_for_entropy, codebook)

                    prob = (-distance * inv_temperature).softmax(dim = -1)

                    per_sample_probs = prob

                # calculate per sample entropy

                per_sample_entropy = entropy(per_sample_probs).mean()

                # distribution over all available tokens in the batch -> (c, codebook_size)

                avg_prob = per_sample_probs.flatten(start_dim=0, end_dim=-3).mean(dim=0)

                avg_prob = maybe_distributed_mean(avg_prob)

                codebook_entropy = entropy(avg_prob).mean()

                # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
                # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch

                entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
            else:
                # if not training, just return dummy 0
                entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero

            # whether to make the entropy loss positive or not through a (shifted) softplus

            if self.training and self.experimental_softplus_entropy_loss:
                entropy_aux_loss = F.softplus(entropy_aux_loss + self.entropy_loss_offset)

            # commit loss

            if self.training and self.commitment_loss_weight > 0.:

                commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')

                if exists(mask):
                    commit_loss = commit_loss[mask]

                commit_loss = commit_loss.mean()
            else:
                commit_loss = self.zero

            # input back to original dtype if needed

            if force_f32:
                x = x.type(orig_dtype)

        # merge back codebook dim

        x = x.flatten(2, 3)

        # project out to feature dimension if needed

        x = self.project_out(x)

        # restore the wildcard dims (a no-op for (b, n, d) inputs), then move the
        # feature dim back to position 1 for channel-first inputs

        x = unpack_one(x, ps, 'b * d')

        if should_transpose:
            x = x.movedim(-1, 1)

        indices = unpack_one(indices, ps, 'b * c')

        # whether to remove single codebook dim

        if not self.keep_num_codebooks_dim:
            indices = indices.squeeze(-1)

        # complete aux loss

        aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight

        # returns

        ret = Return(x, indices, aux_loss)

        if not return_loss_breakdown:
            return ret

        return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss)


########################################
###         Quantizer Model          ###
########################################

class WavCoch(PreTrainedModel):
    """Waveform -> discrete codes -> cochleagram model.

    Pipeline: a Fourier-initialized strided Conv1d front-end, a convolutional
    encoder, an LFQ bottleneck producing integer codes, and a convolutional
    decoder predicting the cochleagram.
    """

    config_class = WavCochConfig

    def __init__(self, config):
        super().__init__(config)

        self.N = config.window_size
        self.hop_length = config.hop_length

        # Front-end: strided convolutions initialized to Hann-windowed cosine / sine
        # kernels for the first N // 2 + 1 frequency bins.
        self.conv_real_filters = nn.Conv1d(1, self.N // 2 + 1, kernel_size=self.N, stride=self.hop_length)
        self.conv_imag_filters = nn.Conv1d(1, self.N // 2 + 1, kernel_size=self.N, stride=self.hop_length)
        self._initialize_conv_filters()

        # Configurable encoder and decoder layers
        self.encoder = self._build_conv_block(
            in_channels=self.N // 2 + 1,
            out_channels=config.encoder_dim,
            num_layers=config.encoder_layers,
            kernel_size=config.encoder_kernel_size,
        )

        self.quantizer = LFQ(
            codebook_size=config.codebook_size,
            dim=config.encoder_dim,
            num_codebooks=1,
            entropy_loss_weight=config.entropy_loss_weight,
            commitment_loss_weight=config.commit_loss_weight,
            diversity_gamma=config.diversity_gamma,
        )

        # NOTE(review): the quantizer outputs config.encoder_dim channels, so
        # config.decoder_dim must equal config.encoder_dim. The 211 output channels
        # are presumably the channel count of the target cochleagrams — TODO confirm.
        self.decoder = self._build_conv_block(
            in_channels=config.decoder_dim,
            out_channels=211,
            num_layers=config.decoder_layers,
            kernel_size=config.decoder_kernel_size,
        )

    def _build_conv_block(self, in_channels, out_channels, num_layers, kernel_size=9):
        """Stack `num_layers` (Conv1d, ReLU) pairs; there are no residual connections.

        Stride 1 with 'same' padding keeps the time length unchanged, and every
        layer (including the last) is followed by a ReLU.
        """
        layers = []
        for i in range(num_layers):
            conv_layer = nn.Conv1d(
                in_channels if i == 0 else out_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=1,
                padding='same',
            )
            layers.extend([conv_layer, nn.ReLU()])
        return nn.Sequential(*layers)

    def _compute_twiddle_factors(self):
        """Return the real and imaginary parts of the N x N DFT matrix."""
        n = torch.arange(self.N).unsqueeze(1)
        k = torch.arange(self.N).unsqueeze(0)
        angles = -2 * math.pi * n * k / self.N
        return torch.cos(angles), torch.sin(angles)  # Real and imaginary parts

    def _initialize_conv_filters(self):
        """Initialize the front-end conv weights to Hann-windowed DFT kernels."""
        twiddle_factors_real, twiddle_factors_imag = self._compute_twiddle_factors()
        twiddle_factors_real = twiddle_factors_real[:self.N // 2 + 1, :]
        twiddle_factors_imag = twiddle_factors_imag[:self.N // 2 + 1, :]

        window = torch.hann_window(self.N).view(1, 1, -1)
        conv_real_filters = twiddle_factors_real.unsqueeze(1) * window
        conv_imag_filters = twiddle_factors_imag.unsqueeze(1) * window

        self.conv_real_filters.weight = nn.Parameter(conv_real_filters)
        self.conv_imag_filters.weight = nn.Parameter(conv_imag_filters)

    @property
    def vocab_size(self):
        """Number of discrete codes, i.e. the configured LFQ codebook size."""
        return self.config.codebook_size

    @staticmethod
    def _prepare_wav(wav):
        """Normalize supported waveform inputs to a [B, 1, T] tensor.

        Accepts a list/tuple of 1-D [T] (or [1, T]) tensors, right-padded with
        zeros to the longest length, or a 1-D [T], 2-D [B, T] or 3-D [B, 1, T] tensor.
        """
        if isinstance(wav, (list, tuple)):
            if not wav:
                raise ValueError("Received an empty list of waveforms.")
            flat = []
            for w in wav:
                # pad_sequence pads along dim 0, so every waveform must be 1-D.
                if w.ndim == 2 and w.shape[0] == 1:
                    w = w.squeeze(0)
                if w.ndim != 1:
                    raise ValueError(f"Unexpected waveform shape {tuple(w.shape)} in list, expected [T] or [1, T].")
                flat.append(w)
            wav = torch.nn.utils.rnn.pad_sequence(flat, batch_first=True)  # [B, T]
            return wav.unsqueeze(1)  # [B, 1, T]

        if isinstance(wav, torch.Tensor):
            if wav.ndim == 1:
                return wav.unsqueeze(0).unsqueeze(0)  # [1, 1, T]
            if wav.ndim == 2:
                return wav.unsqueeze(1)  # [B, T] -> [B, 1, T]
            if wav.ndim != 3:
                raise ValueError(f"Unexpected tensor shape {wav.shape}, expected 1D, 2D or 3D.")
            return wav

        raise TypeError(f"Unsupported input type: {type(wav)}")

    def _frontend(self, wav):
        """Apply the (gradient-free) Fourier front-end to a [B, 1, T] waveform."""
        with torch.no_grad():
            real_part = self.conv_real_filters(wav)
            imag_part = self.conv_imag_filters(wav)
        # The two projections are summed (not combined into a magnitude), as in
        # the original pipeline.
        return real_part + imag_part

    def _encode(self, wav):
        """Front-end + encoder; returns features of shape (B, frames, encoder_dim)."""
        x = self.encoder(self._frontend(wav))
        return x.permute(0, 2, 1)

    def forward(self, wav, coch=None, return_tensors="pt", sample_rate=16000, pad=True):
        """Tokenize audio (when `coch` is None) or run a training step.

        Args:
            wav: waveform(s); see `_prepare_wav` for the accepted formats.
            coch: target cochleagram. When None, the model returns codes instead.
            return_tensors, sample_rate: accepted for API compatibility; unused.
            pad: in tokenization mode, left-pad by N - hop_length zeros to
                correct for the cutoff performed by the cochleagram.

        Returns:
            `BatchEncoding` with "input_values"/"input_ids" codes when `coch`
            is None, else `(predicted_cochleagram, mse_loss, aux_loss)`.
        """
        wav = self._prepare_wav(wav)

        if coch is None:
            if pad:
                wav = F.pad(wav, (self.N - self.hop_length, 0), mode='constant', value=0)

            codes = self.quantize(wav)
            return BatchEncoding({
                "input_values": codes,
                "input_ids": codes,
            })

        x = self._encode(wav)
        quantized, indices, entropy_aux_loss = self.quantizer(x)
        pred_coch = self.decoder(quantized.permute(0, 2, 1)).permute(0, 2, 1)

        loss = F.mse_loss(pred_coch, coch)
        return pred_coch, loss, entropy_aux_loss

    def quantize(self, wav):
        """Return integer code indices (B, frames) for a [B, 1, T] waveform (no padding)."""
        with torch.no_grad():
            x = self._encode(wav)
            _, indices, _ = self.quantizer(x)
        return indices

    def decode(self, indices):
        """Decode code indices (B, frames) into a predicted cochleagram (B, frames, 211)."""
        emb = self.quantizer.indices_to_codes(indices)
        return self.decoder(emb.permute(0, 2, 1)).permute(0, 2, 1)

    def wav2coch(self, wav):
        """Predict a cochleagram (B, frames, 211) from a [B, 1, T] waveform via the codes."""
        x = self._encode(wav)
        quantized, _, _ = self.quantizer(x)
        return self.decoder(quantized.permute(0, 2, 1)).permute(0, 2, 1)

    def invert_cochleagram_to_audio(
        self,
        cochleagram,
        device,
        num_optim_steps=1000,
        lr=1e-2,
        transform_cls=CochleagramTransform,
        sr=16000,
        signal_size=16000 * 5,
    ):
        """
        Invert a cochleagram back to audio by gradient descent on a random waveform.

        Args:
            cochleagram: target cochleagram; it is compared against the transform's
                output, which for a single waveform is (n_timesteps, n_channels).
            device: device to run the optimization on.
            num_optim_steps: number of Adam steps.
            lr: Adam learning rate.
            transform_cls: cochleagram transform factory (see CochleagramTransform).
            sr: sample rate of the synthesized audio.
            signal_size: number of samples to synthesize.

        Returns:
            The optimized waveform of shape (1, 1, signal_size).
        """
        transform = transform_cls(sr=sr, signal_size=signal_size, device=device, return_on_cpu=False)

        # Random initial waveform (drawn on CPU, then moved) to be optimized
        audio = torch.randn(1, 1, signal_size).to(device).requires_grad_()

        optimizer = torch.optim.Adam([audio], lr=lr)
        criterion = torch.nn.MSELoss()

        # The target must live on the same device and must not carry a graph
        # (e.g. from wav2coch); otherwise backward() would reach into it every step.
        target = cochleagram.to(device).detach()

        with tqdm.tqdm(total=num_optim_steps, desc="Inverting the cochleagram") as pbar:
            for _ in range(num_optim_steps):
                optimizer.zero_grad()

                pred_coch = transform(audio[0])
                loss = criterion(pred_coch, target)

                loss.backward()
                optimizer.step()

                pbar.set_postfix(loss=loss.item())
                pbar.update(1)

        return audio