# WavCochV8192 — modeling_wavcoch.py
"""
WavCoch: waveform-to-cochleagram encoder with an LFQ bottleneck.
Transforming waveforms to cochleagrams ("Transformation Imitation").
"""
import math
from math import log2, ceil
import tqdm
from transformers.tokenization_utils import BatchEncoding
from transformers import PreTrainedModel
from functools import partial, cache
from collections import namedtuple
from contextlib import nullcontext
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed import nn as dist_nn
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module
from torch.amp import autocast
from .configuration_wavcoch import WavCochConfig
########################################
### Cochleagram Transform ###
########################################
class CochleagramTransform:
    """
    Waveform -> cochleagram transform ("transformation imitation" target).

    Wraps Jenelle Feather's `chcochleagram` pipeline: ERB cosine filterbank,
    Hilbert-envelope extraction, sinc/Kaiser-window downsampling to 200 Hz,
    and clipped-gradient power compression.

    Args:
        sr: input sampling rate in Hz.
        signal_size: expected waveform length in samples (default 5 s @ 16 kHz).
        device: device on which the cochleagram is computed.
        batch_mode: when False, a size-1 batch dimension is squeezed away to
            match the semantics of the previous dataloaders.
        return_on_cpu: when True, move the result back to CPU.
    """

    def __init__(
        self,
        sr: int = 16000,
        signal_size: int = 16000 * 5,  # set default signal size to 5 sec @ 16khz
        device: str = 'cpu',
        batch_mode: bool = False,
        return_on_cpu: bool = True,
    ):
        self.sr = sr
        self.device = device
        self.batch_mode = batch_mode
        self.return_on_cpu = return_on_cpu
        # Building the transform requires chcochleagram; fail early with a clear
        # message if it is missing. (The import used to be commented out, which
        # made _init_cochleagram_fn raise a NameError instead.)
        self.cochleagram_fn = self._init_cochleagram_fn(signal_size=signal_size)

    @staticmethod
    def _import_chcochleagram():
        """Lazily import `chcochleagram`, raising an actionable error when absent."""
        try:
            import chcochleagram
        except ImportError as err:
            raise ImportError(
                "The chcochleagram library is required to compute cochleagrams, "
                "please install it with:\n"
                "pip install git+https://github.com/jenellefeather/chcochleagram.git"
            ) from err
        return chcochleagram

    def cochleagram(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Compute the cochleagram of the audio waveform.
        From Jenelle Feather: chcochleagram

        Returns a tensor shaped (batch, n_timesteps, n_channels); the batch
        dimension is squeezed when it is 1 and batch_mode is False.
        """
        # move audio to specified device
        audio = audio.to(self.device)
        cochleagram = self.cochleagram_fn(audio)  # (batch, n_channels, n_timesteps)
        # Transpose the cochleagram such that n_timesteps x n_channels
        cochleagram = cochleagram.permute(0, 2, 1)
        # Check for nan values
        if torch.isnan(cochleagram).any():
            raise ValueError('Cochleagram contains nan values')
        # Move cochleagram back to cpu to match the semantics of previous dataloader.
        # Maybe this can be improved in the future but it does not seem to make a big
        # difference in terms of performance so far.
        if self.return_on_cpu:
            cochleagram = cochleagram.to('cpu')
        # If the cochleagram has batch size of 1 we squeeze it in order to match
        # the semantics of the previous dataloaders.
        if cochleagram.shape[0] == 1 and not self.batch_mode:
            cochleagram = cochleagram.squeeze(0)
        return cochleagram

    def __call__(self, audio: torch.Tensor) -> torch.Tensor:
        """Alias for `cochleagram` so the object is directly callable."""
        return self.cochleagram(audio)

    def _init_cochleagram_fn(
        self,
        pad_factor: float = 1.5,
        use_rfft: bool = True,
        signal_size: int = 16000 * 5,  # set default signal size to 5 sec @ 16khz
    ):
        """Construct the chcochleagram Cochleagram module and move it to self.device."""
        chcochleagram = self._import_chcochleagram()
        ### Define the cochlear filters using ERBCosFilters.
        # These are the arguments used for filter construction of ERBCosFilters. See
        # chcochleagram's helpers/erb_filters.py for more documentation.
        half_cos_filter_kwargs = {
            'n': 50,  # Number of filters to evenly tile the space
            'low_lim': 50,
            # Lowest center frequency for full filter (if lowpass filters are used they can be centered lower)
            'high_lim': 8000,  # Highest center frequency
            'sample_factor': 4,  # Positive integer that determines how densely ERB function will be sampled
            'full_filter': False,  # Whether to use the full-filter. Must be False if rFFT is true.
        }
        coch_filter_kwargs = {
            'use_rfft': use_rfft,  # Whether to use rFFT or not
            'pad_factor': pad_factor,  # How much to pad the signal
            'filter_kwargs': half_cos_filter_kwargs}
        ### Define an envelope extraction operation
        # Use the analytic amplitude of the hilbert transform here. Other types of envelope extraction
        # are also implemented in envelope_extraction.py. Can use Identity if want the raw subbands.
        envelope_extraction = chcochleagram.envelope_extraction.HilbertEnvelopeExtraction(signal_size=signal_size,
                                                                                          sr=self.sr,
                                                                                          use_rfft=use_rfft,
                                                                                          pad_factor=pad_factor)
        # This (and most) cochleagrams use ERBCosFilters, however other types of filterbanks can be
        # constructed for linear spaced filters or different shapes. Make a new CochlearFilter class for
        # these.
        filters = chcochleagram.cochlear_filters.ERBCosFilters(signal_size=signal_size,
                                                               sr=self.sr,
                                                               **coch_filter_kwargs)
        ### Define a downsampling operation
        # Downsample the extracted envelopes. Can use Identity if want the raw subbands.
        env_sr = 200  # Sampling rate after downsampling
        downsampling_kwargs = {'window_size': 1001}  # Parameters for the downsampling filter (see downsampling.py)
        downsampling_op = chcochleagram.downsampling.SincWithKaiserWindow(sr=self.sr, env_sr=env_sr, **downsampling_kwargs)
        ### Define a compression operation.
        compression_kwargs = {'power': 0.3,  # Power compression of 0.3
                              'offset': 1e-8,  # Offset for numerical stability in backwards pass
                              'scale': 1,  # Optional multiplicative value applied to the envelopes before compression
                              'clip_value': 100}  # Clip the gradients for this compression for stability
        compression = chcochleagram.compression.ClippedGradPowerCompression(**compression_kwargs)
        cochleagram_fn = chcochleagram.cochleagram.Cochleagram(filter_object=filters,
                                                               envelope_extraction=envelope_extraction,
                                                               downsampling=downsampling_op,
                                                               compression=compression)
        # Move cochleagram_fn to the specified device
        cochleagram_fn = cochleagram_fn.to(self.device)
        return cochleagram_fn
########################################
### LFQ Definition ###
########################################
"""
Lookup Free Quantization
Proposed in https://arxiv.org/abs/2310.05737
Adapted from vector-quantize-pytorch https://github.com/lucidrains/vector-quantize-pytorch
In the simplest setup, each dimension is quantized into {-1, 1}.
An entropy penalty is used to encourage utilization.
"""
# constants: lightweight result containers returned by LFQ.forward
Return = namedtuple('Return', 'quantized indices entropy_aux_loss')
LossBreakdown = namedtuple('LossBreakdown', 'per_sample_entropy batch_entropy commitment')
# distributed helpers
@cache
def is_distributed():
    """
    True when torch.distributed has an initialized process group spanning
    more than one rank. Guarded with `dist.is_available()` so the check does
    not blow up on torch builds compiled without distributed support (where
    `is_initialized` may be missing). Cached: the answer cannot change after
    process-group setup.
    """
    return dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1
def maybe_distributed_mean(t):
    """Average `t` across ranks when running distributed; otherwise return it unchanged."""
    if is_distributed():
        # differentiable all-reduce (sum), then divide by world size for the mean
        dist_nn.all_reduce(t)
        t = t / dist.get_world_size()
    return t
# helper functions
def exists(v):
    """True iff `v` is not None (falsy values like 0 and '' still count as existing)."""
    return not (v is None)
def identity(t):
    """No-op passthrough: return the argument unchanged (used as a default transform)."""
    return t
def default(*args):
    """
    Return the first argument that is not None; callables are invoked lazily
    so expensive fallbacks only run when needed. Returns None when every
    argument is None.
    """
    for candidate in args:
        if candidate is None:
            continue
        return candidate() if callable(candidate) else candidate
    return None
def pack_one(tensor: torch.Tensor, pattern: str):
    """
    Flatten all axes matched by the '*' wildcard in `pattern` into one axis.

    Returns (packed_tensor, packed_shapes), where packed_shapes is a
    one-element list holding the original wildcard dims, mirroring the
    einops pack interface so unpack_one can restore them.
    """
    axes = pattern.split()
    if '*' not in axes:
        raise ValueError("Pattern must contain a '*' wildcard axis")
    star = axes.index('*')
    trailing = len(axes) - star - 1
    dims = tuple(tensor.shape)
    # partition the shape into leading / wildcard / trailing dims
    leading_dims = dims[:star]
    if trailing:
        star_dims = dims[star:len(dims) - trailing]
        trailing_dims = dims[len(dims) - trailing:]
    else:
        star_dims = dims[star:]
        trailing_dims = ()
    # collapse the wildcard dims into a single flat axis (prod of () is 1)
    flat = math.prod(star_dims)
    packed = tensor.reshape(leading_dims + (flat,) + trailing_dims)
    # return list-of-shapes so unpack_one can use the same interface
    return packed, [star_dims]
def unpack_one(packed: torch.Tensor, ps: list, pattern: str):
    """
    Inverse of pack_one on a single tensor: expand the flat wildcard axis
    back into the original dims recorded in `ps` (the list returned by
    pack_one).
    """
    axes = pattern.split()
    if '*' not in axes:
        raise ValueError("Pattern must contain a '*' wildcard axis")
    star = axes.index('*')
    trailing = len(axes) - star - 1
    dims = tuple(packed.shape)
    # the saved wildcard dims replace the single flat axis
    restored = tuple(ps[0])
    leading_dims = dims[:star]
    trailing_dims = dims[len(dims) - trailing:] if trailing else ()
    return packed.reshape(leading_dims + restored + trailing_dims)
def l2norm(t):
    """L2-normalize `t` along its last dimension (unit vectors)."""
    return F.normalize(t, p=2, dim=-1)
# entropy
def log(t, eps = 1e-5):
    """Numerically safe log: inputs are clamped to at least `eps` to avoid -inf/nan."""
    return torch.log(torch.clamp(t, min=eps))
def entropy(prob):
    """
    Shannon entropy H(p) = -sum(p * log p) over the last dimension, using the
    same 1e-5 clamp on the log input as the module-level `log` helper.
    """
    safe_log_p = prob.clamp(min=1e-5).log()
    return (prob * safe_log_p).sum(dim=-1).neg()
# cosine sim linear
class CosineSimLinear(Module):
    """
    Linear layer computing scaled cosine similarity: the input rows and the
    weight columns are both unit-normalized before the matmul, then the
    result is multiplied by a constant `scale`.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        scale = 1.
    ):
        super().__init__()
        self.scale = scale
        self.weight = nn.Parameter(torch.randn(dim_in, dim_out))

    def forward(self, x):
        """Cosine similarity between rows of `x` and columns of the weight, times `scale`."""
        normed_x = F.normalize(x, dim=-1)
        normed_w = F.normalize(self.weight, dim=0)
        return normed_x.matmul(normed_w) * self.scale
# class
class LFQ(Module):
    """
    Lookup-Free Quantization (LFQ) bottleneck.

    Each latent dimension is quantized independently to
    {-codebook_scale, +codebook_scale} by sign, so the codebook index is the
    binary number formed by the sign pattern (codebook_size = 2 ** codebook_dim).
    An auxiliary entropy loss pushes per-sample code distributions to be
    confident while keeping codebook usage uniform across the batch.

    Proposed in https://arxiv.org/abs/2310.05737; adapted from
    lucidrains/vector-quantize-pytorch.
    """
    def __init__(
        self,
        *,
        dim = None,  # input feature dim; defaults to codebook_dim * num_codebooks
        codebook_size = None,  # must be a power of 2; defaults to 2 ** dim
        entropy_loss_weight = 0.1,  # weight of the entropy aux loss in the combined aux loss
        commitment_loss_weight = 0.,  # weight of the commitment (MSE to quantized) loss
        diversity_gamma = 1.,  # weight on the (maximized) batch codebook-entropy term
        straight_through_activation = nn.Identity(),  # NOTE(review): module instance as default arg is shared across LFQ instances; harmless for stateless Identity but worth confirming
        num_codebooks = 1,
        keep_num_codebooks_dim = None,  # defaults to True when num_codebooks > 1
        codebook_scale = 1., # for residual LFQ, codebook scaled down by 2x at each layer
        frac_per_sample_entropy = 1., # make less than 1. to only use a random fraction of the probs for per sample entropy
        has_projections = None,  # defaults to (dim != codebook_dims)
        projection_has_bias = True,
        soft_clamp_input_value = None,  # when set, tanh-clamps inputs to [-value, value]
        cosine_sim_project_in = False,
        cosine_sim_project_in_scale = None,
        channel_first = None,  # when None, inferred per-call from input rank
        experimental_softplus_entropy_loss = False,
        entropy_loss_offset = 5., # how much to shift the loss before softplus
        spherical = False, # from https://arxiv.org/abs/2406.07548
        force_quantization_f32 = True # will force the quantization step to be full precision
    ):
        super().__init__()

        # some assert validations
        assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
        assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'

        codebook_size = default(codebook_size, lambda: 2 ** dim)
        self.codebook_size = codebook_size

        # each codebook contributes log2(codebook_size) binary dimensions
        codebook_dim = int(log2(codebook_size))
        codebook_dims = codebook_dim * num_codebooks
        dim = default(dim, codebook_dims)

        # project in/out only when the feature dim differs from the code dims
        has_projections = default(has_projections, dim != codebook_dims)

        if cosine_sim_project_in:
            cosine_sim_project_in = default(cosine_sim_project_in_scale, codebook_scale)
            project_in_klass = partial(CosineSimLinear, scale = cosine_sim_project_in)
        else:
            project_in_klass = partial(nn.Linear, bias = projection_has_bias)

        self.project_in = project_in_klass(dim, codebook_dims) if has_projections else nn.Identity()
        self.project_out = nn.Linear(codebook_dims, dim, bias = projection_has_bias) if has_projections else nn.Identity()
        self.has_projections = has_projections

        self.dim = dim
        self.codebook_dim = codebook_dim
        self.num_codebooks = num_codebooks

        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        # channel first
        self.channel_first = channel_first

        # straight through activation
        self.activation = straight_through_activation

        # whether to use BSQ (binary spherical quantization)
        self.spherical = spherical
        # lambda reads self.codebook_scale lazily (set below), so order is safe
        self.maybe_l2norm = (lambda t: l2norm(t) * self.codebook_scale) if spherical else identity

        # entropy aux loss related weights
        assert 0 < frac_per_sample_entropy <= 1.
        self.frac_per_sample_entropy = frac_per_sample_entropy

        self.diversity_gamma = diversity_gamma
        self.entropy_loss_weight = entropy_loss_weight

        # codebook scale
        self.codebook_scale = codebook_scale

        # commitment loss
        self.commitment_loss_weight = commitment_loss_weight

        # whether to soft clamp the input value from -value to value
        self.soft_clamp_input_value = soft_clamp_input_value
        assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= codebook_scale

        # whether to make the entropy loss positive through a softplus (experimental, please report if this worked or not in discussions)
        self.entropy_loss_offset = entropy_loss_offset
        self.experimental_softplus_entropy_loss = experimental_softplus_entropy_loss

        # for no auxiliary loss, during inference
        # 'mask' holds bit weights [2^(d-1), ..., 2, 1] used to pack sign bits into an index
        self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
        self.register_buffer('zero', torch.tensor(0.), persistent = False)

        # whether to force quantization step to be f32
        self.force_quantization_f32 = force_quantization_f32

        # codes: enumerate every index and precompute its +-codebook_scale vector
        all_codes = torch.arange(codebook_size)
        bits = ((all_codes[..., None].int() & self.mask) != 0).float()
        codebook = self.bits_to_codes(bits)
        self.register_buffer('codebook', codebook.float(), persistent = False)

    def bits_to_codes(self, bits):
        """Map {0, 1} bits to {-codebook_scale, +codebook_scale} code values."""
        return bits * self.codebook_scale * 2 - self.codebook_scale

    @property
    def dtype(self):
        # dtype of the precomputed codebook buffer (float32 by construction)
        return self.codebook.dtype

    def indices_to_codes(
        self,
        indices,
        project_out = True
    ):
        """
        Decode integer codebook indices back to continuous code vectors.

        When `project_out` is True, codes are mapped back to the original
        feature dimension; channel-first layouts are restored for image/video
        shaped inputs (or when self.channel_first is set).
        """
        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
        should_transpose = default(self.channel_first, is_img_or_video)

        if not self.keep_num_codebooks_dim:
            # append a singleton codebook dimension at the end
            indices = indices.unsqueeze(-1)

        # indices to codes, which are bits of either -1 or 1
        bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)
        codes = self.bits_to_codes(bits)
        codes = self.maybe_l2norm(codes)
        # merge (codebook, code-dim) axes back into one feature axis
        codes = codes.flatten(-2, -1)

        # whether to project codes out to original dimensions
        # if the input feature dimensions were not log2(codebook size)
        if project_out:
            codes = self.project_out(codes)

        # move codes back to original shape (feature axis back to position 1)
        if should_transpose:
            codes = codes.movedim(-1, 1)

        return codes

    def forward(
        self,
        x,
        inv_temperature = 100.,
        return_loss_breakdown = False,
        mask = None,
    ):
        """
        Quantize `x` and return Return(quantized, indices, entropy_aux_loss);
        with `return_loss_breakdown`, also a LossBreakdown namedtuple.

        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        c - number of codebook dim
        """
        is_img_or_video = x.ndim >= 4
        should_transpose = default(self.channel_first, is_img_or_video)

        # standardize image or video into (batch, seq, dimension)
        if should_transpose:
            x = x.movedim(1, -1)
            x, ps = pack_one(x, 'b * d')

        assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'

        x = self.project_in(x)

        # maybe soft clamp to [-clamp_value, clamp_value] via scaled tanh
        if exists(self.soft_clamp_input_value):
            clamp_value = self.soft_clamp_input_value
            x = (x / clamp_value).tanh() * clamp_value

        # split out number of codebooks: (b, n, c, d)
        x = x.reshape(*x.shape[:2], self.num_codebooks, -1)

        # maybe l2norm (BSQ)
        x = self.maybe_l2norm(x)

        # whether to force quantization step to be full precision or not
        force_f32 = self.force_quantization_f32
        quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext

        with quantization_context():
            if force_f32:
                orig_dtype = x.dtype
                x = x.float()

            # quantize by eq 3: each dim becomes +-codebook_scale by sign
            original_input = x
            codebook_value = torch.ones_like(x) * self.codebook_scale
            quantized = torch.where(x > 0, codebook_value, -codebook_value)

            # calculate indices: pack the sign bits into an integer per codebook
            t = (quantized > 0).int() * self.mask.int()
            indices = t.sum(dim=-1)

            quantized = self.maybe_l2norm(quantized)

            # use straight-through gradients (optionally with custom activation fn) if training
            if self.training:
                x = self.activation(x)
                x = x + (quantized - x).detach()
            else:
                x = quantized

            # entropy aux loss
            if self.training:
                if force_f32:
                    codebook = self.codebook.float()
                # NOTE(review): when force_quantization_f32 is False, `codebook`
                # is never assigned before this line, so training would raise a
                # NameError — confirm against upstream vector-quantize-pytorch.
                codebook = self.maybe_l2norm(codebook)

                # whether to only use a fraction of probs, for reducing memory
                input_for_entropy = original_input
                if exists(mask):
                    input_for_entropy = original_input[mask]
                # flatten (batch, seq) into a single token axis
                input_for_entropy = input_for_entropy.flatten(0, 1)

                if self.frac_per_sample_entropy < 1.:
                    # account for mask: subsample a random fraction of tokens
                    num_tokens = input_for_entropy.size(0)
                    num_sampled_tokens = int(num_tokens * self.frac_per_sample_entropy)
                    rand_mask = torch.randn(num_tokens).argsort(dim = -1) < num_sampled_tokens
                    sampled_input = input_for_entropy[rand_mask]
                    sampled_distance = -2 * einsum('... i d, j d -> ... i j', sampled_input, codebook)
                    sampled_prob = (-sampled_distance * inv_temperature).softmax(dim = -1)
                    per_sample_probs = sampled_prob
                else:
                    # the same as euclidean distance up to a constant
                    distance = -2 * einsum('... i d, j d -> ... i j', input_for_entropy, codebook)
                    prob = (-distance * inv_temperature).softmax(dim = -1)
                    per_sample_probs = prob

                # calculate per sample entropy (should be low: confident assignments)
                per_sample_entropy = entropy(per_sample_probs).mean()

                # distribution over all available tokens in the batch
                avg_prob = (per_sample_probs
                            .flatten(start_dim=0, end_dim=-3)
                            .mean(dim=0))
                # average the distribution across distributed ranks when applicable
                avg_prob = maybe_distributed_mean(avg_prob)
                codebook_entropy = entropy(avg_prob).mean()

                # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
                # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch
                entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
            else:
                # if not training, just return dummy 0
                entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero

            # whether to make the entropy loss positive or not through a (shifted) softplus
            if self.training and self.experimental_softplus_entropy_loss:
                entropy_aux_loss = F.softplus(entropy_aux_loss + self.entropy_loss_offset)

            # commit loss
            if self.training and self.commitment_loss_weight > 0.:
                commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')
                if exists(mask):
                    commit_loss = commit_loss[mask]
                commit_loss = commit_loss.mean()
            else:
                commit_loss = self.zero

            # input back to original dtype if needed
            if force_f32:
                x = x.type(orig_dtype)

        # merge back codebook dim: (b, n, c, d) -> (b, n, c*d)
        x = x.flatten(2, 3)

        # project out to feature dimension if needed
        x = self.project_out(x)

        # reconstitute image or video dimensions
        if should_transpose:
            x = unpack_one(x, ps, 'b * d')
            x = x.movedim(-1, 1)
            indices = unpack_one(indices, ps, 'b * c')

        # whether to remove single codebook dim
        if not self.keep_num_codebooks_dim:
            indices = indices.squeeze(-1)

        # complete aux loss
        aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight

        # returns
        ret = Return(x, indices, aux_loss)

        if not return_loss_breakdown:
            return ret

        return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss)
########################################
###          Quantizer Model         ###
########################################
class WavCoch(PreTrainedModel):
    """
    WavCoch: waveform-to-cochleagram encoder with an LFQ bottleneck.

    Pipeline: a pair of DFT-initialized 1D convolutions (Hann-windowed
    twiddle factors) produces a spectral representation; a conv encoder maps
    it to `encoder_dim`; an LFQ quantizer discretizes each frame; a conv
    decoder reconstructs the target cochleagram ("transformation imitation").

    Calling the model with `coch=None` acts like a tokenizer and returns a
    BatchEncoding of LFQ code indices.
    """
    config_class = WavCochConfig

    def __init__(self, config):
        super().__init__(config)
        self.N = config.window_size          # analysis window length in samples
        self.hop_length = config.hop_length  # frame hop in samples
        # Initial frequency transform convolutions: one output channel per
        # non-negative DFT bin (N // 2 + 1); weights set in _initialize_conv_filters
        self.conv_real_filters = nn.Conv1d(1, self.N // 2 + 1, kernel_size=self.N, stride=self.hop_length)
        self.conv_imag_filters = nn.Conv1d(1, self.N // 2 + 1, kernel_size=self.N, stride=self.hop_length)
        self._initialize_conv_filters()
        # Configurable encoder and decoder layers
        self.encoder = self._build_conv_block(
            in_channels=self.N // 2 + 1,
            out_channels=config.encoder_dim,
            num_layers=config.encoder_layers,
            kernel_size=config.encoder_kernel_size
        )
        self.quantizer = LFQ(
            codebook_size=config.codebook_size,
            dim=config.encoder_dim,
            num_codebooks=1,
            entropy_loss_weight=config.entropy_loss_weight,
            commitment_loss_weight=config.commit_loss_weight,
            diversity_gamma=config.diversity_gamma,
        )
        # 211 output channels — presumably the channel count of the target
        # cochleagram produced by CochleagramTransform's ERB filterbank;
        # TODO confirm against the training data
        self.decoder = self._build_conv_block(
            in_channels=config.decoder_dim,
            out_channels=211,
            num_layers=config.decoder_layers,
            kernel_size=config.decoder_kernel_size
        )

    def _build_conv_block(self, in_channels, out_channels, num_layers, kernel_size=9):
        """Create a stack of `num_layers` same-padded Conv1d + ReLU layers.

        Note: despite the original comment, there are no residual connections
        here — it is a plain sequential stack.
        """
        layers = []
        for i in range(num_layers):
            conv_layer = nn.Conv1d(
                in_channels if i == 0 else out_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=1,
                padding='same'  # 'same' padding keeps the time axis length unchanged
            )
            layers.extend([
                conv_layer,
                nn.ReLU(),
            ])
        return nn.Sequential(*layers)

    def _compute_twiddle_factors(self):
        """Return (cos, sin) DFT twiddle-factor matrices, each of shape [N, N]."""
        n = torch.arange(self.N).unsqueeze(1)
        k = torch.arange(self.N).unsqueeze(0)
        angles = -2 * math.pi * n * k / self.N
        return torch.cos(angles), torch.sin(angles)  # Real and imaginary parts

    def _initialize_conv_filters(self):
        """Initialize the conv filters as Hann-windowed DFT basis functions,
        keeping only the N // 2 + 1 non-negative-frequency rows."""
        twiddle_factors_real, twiddle_factors_imag = self._compute_twiddle_factors()
        twiddle_factors_real = twiddle_factors_real[:self.N // 2 + 1, :]
        twiddle_factors_imag = twiddle_factors_imag[:self.N // 2 + 1, :]
        window = torch.hann_window(self.N).view(1, 1, -1)
        # broadcast the window over all filter channels -> weight shape [K, 1, N]
        conv_real_filters = twiddle_factors_real.unsqueeze(1) * window
        conv_imag_filters = twiddle_factors_imag.unsqueeze(1) * window
        self.conv_real_filters.weight = nn.Parameter(conv_real_filters)
        self.conv_imag_filters.weight = nn.Parameter(conv_imag_filters)

    @property
    def vocab_size(self):
        # hard-coded for WavCochV8192; presumably equals config.codebook_size — TODO confirm
        return 8192

    def forward(self, wav, coch=None, return_tensors="pt", sample_rate=16000, pad=True):
        """
        Two modes:

        * `coch is None` (tokenizer mode): normalize `wav` (list of 1-D
          tensors, or a 1-D/2-D/3-D tensor) to shape [B, 1, T], optionally
          left-pad to compensate for the cochleagram's cutoff, and return a
          BatchEncoding whose `input_ids`/`input_values` are LFQ code indices.
        * `coch` given (training mode): run the full pipeline and return
          (predicted_cochleagram, mse_loss, entropy_aux_loss).

        `return_tensors` and `sample_rate` are accepted for tokenizer-API
        compatibility but are not used here.
        """
        if coch is None:
            # Handle all input formats
            if isinstance(wav, list):
                # List[Tensor[T]] → pad to [B, T], then unsqueeze to [B, 1, T]
                wav = [w.unsqueeze(0) if w.ndim == 1 else w for w in wav]  # make [1, T]
                wav = torch.nn.utils.rnn.pad_sequence(wav, batch_first=True)  # [B, T]
                wav = wav.unsqueeze(1)  # [B, 1, T]
            elif isinstance(wav, torch.Tensor):
                if wav.ndim == 1:
                    wav = wav.unsqueeze(0).unsqueeze(0)  # [1, 1, T]
                elif wav.ndim == 2:
                    wav = wav.unsqueeze(1)  # [B, T] → [B, 1, T]
                elif wav.ndim != 3:
                    raise ValueError(f"Unexpected tensor shape {wav.shape}, expected 1D, 2D or 3D.")
            else:
                raise TypeError(f"Unsupported input type: {type(wav)}")
            # pad input waveform to correct for cutoff performed by cochleagram
            if pad:
                wav = F.pad(wav, (self.N - self.hop_length, 0), mode='constant', value=0)
            # quantize audio
            codes = self.quantize(wav)
            return BatchEncoding({
                "input_values": codes,
                "input_ids": codes,
            })
        # NOTE(review): only the fixed DFT-style filters are kept under no_grad
        # here so the encoder/quantizer/decoder can train — confirm this scoping
        # matches the original training setup.
        with torch.no_grad():
            real_part = self.conv_real_filters(wav)
            imag_part = self.conv_imag_filters(wav)
            # NOTE(review): this sums real + imaginary responses rather than
            # taking a magnitude — presumably intentional; confirm.
            x = real_part + imag_part
        x = self.encoder(x)
        x = x.permute(0, 2, 1)  # (B, C, T) -> (B, T, C) for the quantizer
        quantized, indices, entropy_aux_loss = self.quantizer(x)
        # despite the name, this holds the predicted cochleagram (B, T, 211)
        mel_spectrogram = self.decoder(quantized.permute(0, 2, 1)).permute(0, 2, 1)
        loss = F.mse_loss(mel_spectrogram, coch)
        return mel_spectrogram, loss, entropy_aux_loss

    def quantize(self, wav):
        """Return LFQ code indices for a [B, 1, T] waveform (no gradients)."""
        with torch.no_grad():
            real_part = self.conv_real_filters(wav)
            imag_part = self.conv_imag_filters(wav)
            x = real_part + imag_part
            x = self.encoder(x)
            x = x.permute(0, 2, 1)
            quantized, indices, _ = self.quantizer(x)
        return indices

    def decode(self, indices):
        """Decode LFQ code indices back to a predicted cochleagram via the decoder."""
        emb = self.quantizer.indices_to_codes(indices)
        # despite the name, this is the predicted cochleagram
        mel_spectrogram = self.decoder(emb.permute(0, 2, 1)).permute(0, 2, 1)
        return mel_spectrogram

    def wav2coch(self, wav):
        """End-to-end inference: waveform -> predicted cochleagram (no gradients)."""
        with torch.no_grad():
            real_part = self.conv_real_filters(wav)
            imag_part = self.conv_imag_filters(wav)
            x = real_part + imag_part
            x = self.encoder(x)
            x = x.permute(0, 2, 1)
            quantized, indices, _ = self.quantizer(x)
            mel_spectrogram = self.decoder(quantized.permute(0, 2, 1)).permute(0, 2, 1)
        return mel_spectrogram

    def invert_cochleagram_to_audio(
        self,
        cochleagram,
        device,
        num_optim_steps=1000,
        lr=1e-2,
        transform_cls=CochleagramTransform
    ):
        """
        Function to invert a cochleagram back to audio using gradient descent:
        optimize a random 5 s waveform so its cochleagram matches the target.

        Note: requires the chcochleagram package (via `transform_cls`) and is
        fixed to 16 kHz / 5 s signals.
        """
        # Initialize the transform function
        transform = transform_cls(sr=16000, signal_size=16000*5, device=device, return_on_cpu=False)
        # Initialize the audio to be optimized
        audio = torch.randn(1, 1, 16000*5).to(device).requires_grad_()
        # Define the optimizer
        optimizer = torch.optim.Adam([audio], lr=lr)
        # Define the loss function
        criterion = torch.nn.MSELoss()
        # Initialize tqdm progress bar
        with tqdm.tqdm(total=num_optim_steps, desc="Inverting the cochleagram") as pbar:
            # Invert the cochleagram
            for _ in range(num_optim_steps):
                optimizer.zero_grad()
                # Compute the cochleagram from the audio
                pred_coch = transform(audio[0])
                # Compute the loss
                loss = criterion(pred_coch, cochleagram)
                # Backpropagate the loss
                loss.backward()
                # Update the audio
                optimizer.step()
                # Update the progress bar
                pbar.set_postfix(loss=loss.item())
                pbar.update(1)
        return audio