| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import librosa |
| from transformers import PreTrainedModel, PretrainedConfig |
| from huggingface_hub import PyTorchModelHubMixin |
| from typing import Optional, List, Union, Dict |
|
|
class MuleConfig(PretrainedConfig):
    """Configuration for the ``Mule`` audio model.

    Groups the mel-spectrogram settings consumed by ``Mule.preprocess`` and
    the architecture switches consumed by ``Mule.__init__``.

    Args:
        sample_rate: Rate the audio is resampled to when it is loaded.
        n_mels: Number of mel bands in the spectrogram.
        n_fft: FFT size used for the spectrogram.
        hop_length: Hop between successive analysis frames, in samples.
        win_length: Analysis window length, in samples.
        fmin: Lowest frequency of the mel filter bank.
        fmax: Highest frequency of the mel filter bank.
        alpha: Residual-branch scale handed to every NFNet stage.
        scaled_activation_type: Activation name (``"gelu"`` or ``"relu"``)
            used after pooling and between projector layers.
        f_value: Depth multiplier; each stage gets ``base * (f_value + 1)``
            blocks.
        projector_layers: Output widths of the linear projector layers.
            ``None`` is stored as an empty list (no projector).
        include_fc: Whether the model ends in a linear classifier.
        num_labels: Output width of that classifier.
        **kwargs: Forwarded untouched to ``PretrainedConfig``.
    """

    model_type = "mule"

    def __init__(
        self,
        sample_rate: int = 16000,
        n_mels: int = 128,
        n_fft: int = 1024,
        hop_length: int = 256,
        win_length: int = 512,
        fmin: float = 40.0,
        fmax: float = 8000.0,
        alpha: float = 0.2,
        scaled_activation_type: str = "gelu",
        f_value: int = 0,
        projector_layers: Optional[List[int]] = None,
        include_fc: bool = True,
        num_labels: int = 50,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Mel-spectrogram front end.
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.fmin = fmin
        self.fmax = fmax

        # Network architecture.
        self.alpha = alpha
        self.scaled_activation_type = scaled_activation_type
        self.f_value = f_value
        # A fresh list per instance avoids sharing one mutable default.
        if projector_layers is None:
            projector_layers = []
        self.projector_layers = projector_layers
        self.include_fc = include_fc
        self.num_labels = num_labels
|
|
def get_same_padding(kernel_size, stride):
    """Return symmetric ``(height, width)`` padding equal to half the kernel.

    Args:
        kernel_size: Either one int (used for both axes) or an indexable
            pair ``(k_h, k_w)``.
        stride: Accepted so call sites can pass it along; it does not
            influence the result.

    Returns:
        Tuple ``(k_h // 2, k_w // 2)``.
    """
    if isinstance(kernel_size, int):
        half = kernel_size // 2
        return (half, half)
    k_h, k_w = kernel_size[0], kernel_size[1]
    return (k_h // 2, k_w // 2)
|
|
class ScaledActivation(nn.Module):
    """Apply an activation and multiply its output by a fixed constant.

    Args:
        activation_type: ``"gelu"`` or ``"relu"``; each comes with its own
            hard-coded scale factor.

    Raises:
        ValueError: If ``activation_type`` is neither of the two names.
    """

    def __init__(self, activation_type="gelu"):
        super().__init__()
        self.activation_type = activation_type

        # (name, output scale, activation function)
        supported = (
            ("gelu", 1.7015043497085571, F.gelu),
            ("relu", 1.7139588594436646, F.relu),
        )
        for name, scale, fn in supported:
            if activation_type == name:
                self.scale = scale
                self.act = fn
                break
        else:
            raise ValueError(f"Unsupported activation: {activation_type}")

    def forward(self, x):
        """Return ``act(x) * scale``."""
        activated = self.act(x)
        return activated * self.scale
|
|
class WeightStandardizedConv2d(nn.Conv2d):
    """``nn.Conv2d`` whose kernel is standardized per output channel on each call.

    ``padding="same"`` is resolved here instead of being passed to
    ``nn.Conv2d``:

    * if any stride is greater than 1, the layer is built with zero padding
      and the input is padded inside ``forward`` so the output has
      ``ceil(size / stride)`` positions per axis;
    * otherwise a static padding of half the kernel size is used
      (``get_same_padding``).

    The runtime padding computation uses only kernel size and stride; it
    does not take ``dilation`` into account.

    A learnable per-output-channel ``gain`` (initialised to 1) multiplies the
    standardized kernel.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        pad_at_runtime = False
        if padding == "same":
            if isinstance(stride, int):
                strided = stride > 1
            elif isinstance(stride, (list, tuple)):
                strided = any(s > 1 for s in stride)
            else:
                strided = False

            if strided:
                padding = 0
                pad_at_runtime = True
            else:
                padding = get_same_padding(kernel_size, stride)

        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.manual_padding = pad_at_runtime
        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        # Lower bound on (variance * fan_in) before the reciprocal sqrt.
        self.eps = 1e-4

    def forward(self, x):
        """Pad if required, standardize the kernel, then convolve."""
        if self.manual_padding:
            in_h, in_w = x.shape[-2:]
            k_h, k_w = self.kernel_size
            s_h, s_w = self.stride

            # -(-a // b) is ceil(a / b) for positive integers.
            need_h = max(0, (-(-in_h // s_h) - 1) * s_h + k_h - in_h)
            need_w = max(0, (-(-in_w // s_w) - 1) * s_w + k_w - in_w)

            if need_h > 0 or need_w > 0:
                top = int(need_h // 2)
                left = int(need_w // 2)
                # F.pad order for the last two dims: left, right, top, bottom.
                # Any odd remainder goes to the right / bottom edge.
                x = F.pad(x, [left, int(need_w - left), top, int(need_h - top)])

        kernel = self.weight
        per_filter = (1, 2, 3)
        mu = kernel.mean(dim=per_filter, keepdim=True)
        sigma_sq = kernel.var(dim=per_filter, keepdim=True, unbiased=False)
        fan_in = kernel[0].numel()
        scale = torch.rsqrt(torch.clamp(sigma_sq * fan_in, min=self.eps)) * self.gain
        standardized = kernel * scale - mu * scale
        return F.conv2d(x, standardized, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
|
class ScalarMultiply(nn.Module):
    """Multiply the input by a single scalar stored on the module.

    Args:
        init_gain: Initial value of the scalar.
        learnable: When truthy the scalar is an ``nn.Parameter``; otherwise
            it is registered as a buffer named ``gain``.
    """

    def __init__(self, init_gain, learnable=False):
        super().__init__()
        self.init_gain = init_gain
        self.learnable = learnable

        value = torch.tensor(float(init_gain))
        if not learnable:
            self.register_buffer("gain", value)
        else:
            self.gain = nn.Parameter(value)

    def forward(self, x):
        """Return ``x * gain``."""
        return x * self.gain
|
|
class SqueezeExcite(nn.Module):
    """Channel gating: global average pool, two linear layers, sigmoid.

    The gate is multiplied by 2.0 before being applied. The hidden layer has
    ``out_channels // 2`` units.

    ``forward`` reshapes the gate to the input's channel count, so
    ``out_channels`` has to equal the number of input channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden = out_channels // 2
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(in_channels, hidden)
        self.fc2 = nn.Linear(hidden, out_channels)

    def forward(self, x):
        """Return ``x`` scaled channel-wise by ``2 * sigmoid(gate)``."""
        # Unpacking four values keeps the requirement of a 4-D input.
        batch, channels, _, _ = x.shape
        squeezed = self.avg_pool(x).view(batch, channels)
        excited = F.relu(self.fc1(squeezed))
        gate = torch.sigmoid(self.fc2(excited)).view(batch, channels, 1, 1)
        return x * gate * 2.0
|
|
class StochDepth(nn.Module):
    """Randomly drop the residual branch per sample while training.

    In eval mode the output is simply ``shortcut + residual``. In training
    mode each sample's residual is kept with probability
    ``survival_probability`` and zeroed otherwise; the kept residual is not
    rescaled.

    Args:
        survival_probability: Probability of keeping the residual branch.
    """

    def __init__(self, survival_probability=0.5):
        super().__init__()
        self.survival_probability = survival_probability

    def forward(self, x_list):
        """Combine a ``[shortcut, residual]`` pair.

        Args:
            x_list: Two-element sequence ``(shortcut, residual)``.

        Returns:
            ``shortcut + mask * residual`` (mask is all ones in eval mode).
        """
        shortcut, residual = x_list
        if not self.training:
            return shortcut + residual

        # One keep/drop decision per sample, broadcast over every other axis
        # regardless of the input rank.
        mask_shape = (residual.shape[0],) + (1,) * (residual.dim() - 1)
        keep = torch.bernoulli(
            torch.full(mask_shape, float(self.survival_probability), device=residual.device)
        )
        # Sample in the default dtype, then cast: a float32 mask would
        # otherwise promote half-precision activations to float32.
        return shortcut + keep.to(dtype=residual.dtype) * residual
|
|
class NFNetBlock(nn.Module):
    """Pre-activation residual block made of four weight-standardized convs.

    The input is activated and multiplied by ``beta``. It then passes through
    the four convolutions, squeeze-excite, a learnable scalar gain (starting
    at 0.0) and the fixed ``alpha`` scale, and is merged with the shortcut
    through stochastic depth.

    A block is a "transition" block when it downsamples frequency, changes
    the channel count, or is flagged explicitly. Transition blocks take the
    shortcut from the activated input and project it with a 1x1 convolution;
    all other blocks use the raw input as shortcut.

    Args:
        kernels: Four kernel sizes, one per convolution.
        freq_downsample: Stride along the first spatial axis applied by the
            second convolution (and by average pooling on the shortcut).
        in_ch: Input channel count.
        out_ch: Output channel count; the inner width is ``out_ch // 2``.
        group_size: Channels per group (see the review note in ``__init__``).
        alpha: Scale applied to the residual branch.
        beta: Scale applied to the activated input.
        stoch_depth_prob: Drop probability; survival is ``1 - stoch_depth_prob``.
        is_transition: Force transition behaviour.
    """

    def __init__(self, kernels, freq_downsample, in_ch, out_ch, group_size, alpha, beta, stoch_depth_prob, is_transition=False):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        downsamples = freq_downsample > 1
        self.is_transition = downsamples or (in_ch != out_ch) or is_transition

        self.act = ScaledActivation()
        self.scalar_beta = ScalarMultiply(beta)

        mid_ch = out_ch // 2
        # NOTE(review): this list is computed but never handed to the
        # convolutions below, so every conv runs with groups=1. Wiring it in
        # would change the weight shapes -- confirm against the checkpoints
        # before doing so.
        groups = [1, mid_ch // group_size, mid_ch // group_size, 1]

        channel_plan = [(in_ch, mid_ch), (mid_ch, mid_ch), (mid_ch, mid_ch), (mid_ch, out_ch)]
        stride_plan = [(1, 1), (freq_downsample, 1), (1, 1), (1, 1)]

        # conv, act, conv, act, conv, act, conv -- no activation after the last.
        body = []
        for idx, ((c_in, c_out), stride) in enumerate(zip(channel_plan, stride_plan)):
            if idx > 0:
                body.append(ScaledActivation())
            body.append(
                WeightStandardizedConv2d(c_in, c_out, kernel_size=kernels[idx], stride=stride, padding="same")
            )
        self.residual_convs = nn.Sequential(*body)

        self.se = SqueezeExcite(out_ch, out_ch)
        self.gain = ScalarMultiply(0.0, learnable=True)
        self.scalar_alpha = ScalarMultiply(alpha)

        # The nesting of the shortcut path determines its state-dict keys,
        # so it is built exactly as: [pool, identity] -> [that, 1x1 conv].
        shortcut_path = nn.Identity()
        if downsamples:
            pool_shape = (freq_downsample, 1)
            shortcut_path = nn.Sequential(
                nn.AvgPool2d(kernel_size=pool_shape, stride=pool_shape, padding=0),
                shortcut_path,
            )
        if self.is_transition:
            shortcut_path = nn.Sequential(
                shortcut_path,
                WeightStandardizedConv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0),
            )
        self.skip = shortcut_path

        self.stoch_depth = StochDepth(1.0 - stoch_depth_prob)

    def forward(self, x):
        """Return the shortcut merged with the scaled residual branch."""
        scaled = self.scalar_beta(self.act(x))
        shortcut = self.skip(scaled) if self.is_transition else x

        branch = self.residual_convs(scaled)
        branch = self.scalar_alpha(self.gain(self.se(branch)))

        return self.stoch_depth([shortcut, branch])
|
|
class NFNetStage(nn.Module):
    """A transition ``NFNetBlock`` followed by ``num_blocks - 1`` regular blocks.

    Each block receives ``beta = 1 / s``, where ``s`` starts at
    ``input_expected_var`` and is updated to ``sqrt(s**2 + alpha**2)`` after
    every block.

    NOTE(review): despite its name, ``input_expected_var`` is used like a
    standard deviation here (it is squared before ``alpha**2`` is added).

    Args:
        kernels: Kernel sizes forwarded to every block.
        freq_downsample: Frequency stride of the first block; later blocks use 1.
        in_ch: Input channels of the first block.
        out_ch: Output channels of every block.
        group_size: Forwarded to every block.
        alpha: Residual scale forwarded to every block.
        input_expected_var: Starting value of ``s`` above.
        stoch_depths: Indexable drop probabilities, one per block.
        num_blocks: Total number of blocks in the stage.
    """

    def __init__(self, kernels, freq_downsample, in_ch, out_ch, group_size, alpha, input_expected_var, stoch_depths, num_blocks):
        super().__init__()
        stage_blocks = [
            NFNetBlock(
                kernels, freq_downsample, in_ch, out_ch, group_size, alpha,
                1.0 / input_expected_var, stoch_depths[0], is_transition=True,
            )
        ]

        running_std = (input_expected_var ** 2.0 + alpha ** 2.0) ** 0.5
        for depth_idx in range(1, num_blocks):
            stage_blocks.append(
                NFNetBlock(
                    kernels, 1, out_ch, out_ch, group_size, alpha,
                    1.0 / running_std, stoch_depths[depth_idx],
                )
            )
            running_std = (running_std ** 2.0 + alpha ** 2.0) ** 0.5

        self.blocks = nn.ModuleList(stage_blocks)

    def forward(self, x):
        """Run the input through every block in order."""
        out = x
        for blk in self.blocks:
            out = blk(out)
        return out
|
|
class FusionLayer(nn.Module):
    """Project the fast pathway and concatenate it onto the slow pathway.

    The fast features go through a ``(1, time_kernel_length)`` convolution
    with stride ``(1, time_stride)``, then a 1x1 convolution to ``out_ch``
    channels. The result is concatenated to the slow features along dim 1.

    Args:
        time_kernel_length: Kernel width of the strided convolution.
        time_stride: Stride of that convolution along the last axis.
        in_ch: Channel count of the fast pathway.
        out_ch: Channel count of the projected fast features.
    """

    def __init__(self, time_kernel_length, time_stride, in_ch, out_ch):
        super().__init__()
        self.conv1 = WeightStandardizedConv2d(
            in_ch,
            in_ch,
            kernel_size=(1, time_kernel_length),
            stride=(1, time_stride),
            padding="same",
        )
        self.conv2 = WeightStandardizedConv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, slow, fast):
        """Return ``slow`` with the projected ``fast`` features appended on dim 1."""
        lateral = self.conv2(self.conv1(fast))
        return torch.cat((slow, lateral), dim=1)
|
|
class Mule(PreTrainedModel, PyTorchModelHubMixin):
    """Two-pathway ("slow" / "fast") NFNet over log-scaled mel spectrograms.

    Both pathways start from their own convolutional stem and run through
    four ``NFNetStage`` groups. Before each slow stage the current fast
    features are fused into the slow pathway by a ``FusionLayer``. The two
    pathways are then averaged over their last two axes, concatenated
    (1536 + 192 features), activated, passed through the optional projector
    stack and finally through the classifier (or an identity when
    ``config.include_fc`` is false).
    """

    config_class = MuleConfig

    def __init__(self, config: MuleConfig):
        super().__init__(config)

        self.slow_stem = self._make_stem(
            [1, 16, 32, 64, 128],
            [(3, 1), (3, 1), (3, 1), (3, 3)],
            [(2, 8), (1, 1), (1, 1), (2, 2)],
        )
        self.fast_stem = self._make_stem(
            [1, 2, 4, 8, 16],
            [(3, 3), (3, 3), (3, 3), (3, 3)],
            [(2, 2), (1, 1), (1, 1), (2, 2)],
        )

        # Blocks per stage, scaled by the depth multiplier.
        depth_scale = config.f_value + 1
        nfnet_stage_depths = [base * depth_scale for base in (1, 2, 6, 3)]

        # Drop probability rises linearly with block index (0 up to just
        # under 0.1) and is then split into one slice per stage.
        stage_bounds = np.concatenate(([0], np.cumsum(nfnet_stage_depths)))
        total_blocks = stage_bounds[-1]
        drop_probs = 0.1 * np.arange(total_blocks) / total_blocks
        drop_probs_per_stage = [
            drop_probs[lo:hi] for lo, hi in zip(stage_bounds[:-1], stage_bounds[1:])
        ]

        later_stage_input = (1.0 + config.alpha ** 2) ** 0.5
        stage_expected_vars = [1.0, later_stage_input, later_stage_input, later_stage_input]
        stage_downsamples = [1, 2, 2, 2]
        stage_kernels = [(1, 1), (1, 3), (3, 1), (1, 1)]

        def build_stages(in_channels, out_channels, group_size):
            # Both pathways share everything except widths and group size.
            return nn.ModuleList([
                NFNetStage(
                    stage_kernels,
                    stage_downsamples[i],
                    in_channels[i],
                    out_channels[i],
                    group_size,
                    config.alpha,
                    stage_expected_vars[i],
                    drop_probs_per_stage[i],
                    nfnet_stage_depths[i],
                )
                for i in range(4)
            ])

        # Slow input widths already include the channels added by fusion.
        self.slow_stages = build_stages([256, 512, 1024, 4608], [256, 512, 1536, 1536], 128)
        self.fast_stages = build_stages([16, 32, 64, 192], [32, 64, 192, 192], 16)

        self.fusion_layers = nn.ModuleList([
            FusionLayer(7, 4, fast_ch, lateral_ch)
            for fast_ch, lateral_ch in ((16, 128), (32, 256), (64, 512), (192, 3072))
        ])

        self.final_act = ScaledActivation(config.scaled_activation_type)

        projector_modules = []
        feature_dim = 1536 + 192
        for width in config.projector_layers:
            projector_modules.append(nn.Linear(feature_dim, width))
            projector_modules.append(ScaledActivation(config.scaled_activation_type))
            feature_dim = width

        if config.include_fc:
            self.classifier = nn.Linear(feature_dim, config.num_labels, bias=False)
        else:
            self.classifier = nn.Identity()

        self.projectors = nn.Sequential(*projector_modules)

    def _make_stem(self, channels, kernels, strides):
        """Build a stem: one conv per kernel, with an activation between convs.

        ``channels`` holds ``len(kernels) + 1`` widths; no activation follows
        the last convolution.
        """
        last = len(kernels) - 1
        stem = []
        for i, kernel in enumerate(kernels):
            stem.append(
                WeightStandardizedConv2d(
                    channels[i], channels[i + 1], kernel_size=kernel, stride=strides[i], padding="same"
                )
            )
            if i < last:
                stem.append(ScaledActivation())
        return nn.Sequential(*stem)

    def _embed_spec(self, spec):
        """Run stems, stages, pooling, final activation and projectors.

        A 3-D input gets a channel axis inserted at dim 1; any other rank is
        passed through unchanged.
        """
        if spec.dim() == 3:
            spec = spec.unsqueeze(1)

        slow = self.slow_stem(spec)
        fast = self.fast_stem(spec)

        for fuse, slow_stage, fast_stage in zip(self.fusion_layers, self.slow_stages, self.fast_stages):
            # Fuse the fast features in before the fast pathway advances.
            slow = slow_stage(fuse(slow, fast))
            fast = fast_stage(fast)

        pooled = torch.cat([slow.mean(dim=(2, 3)), fast.mean(dim=(2, 3))], dim=1)
        return self.projectors(self.final_act(pooled))

    def _move_to_default_device(self):
        """Move the model to CUDA when available, else CPU; return that device."""
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(device)
        return device

    def forward(self, x):
        """Return classifier outputs (or embeddings if there is no classifier)."""
        return self.classifier(self._embed_spec(x))

    def preprocess(self, audio_file):
        """Load audio and return ``log10(10000 * mel + 1)`` as a float tensor.

        The file is loaded at ``config.sample_rate`` and converted to a mel
        spectrogram using the front-end settings stored on the config.
        """
        cfg = self.config
        audio, sr = librosa.load(audio_file, sr=cfg.sample_rate)
        mel = librosa.feature.melspectrogram(
            y=audio,
            sr=sr,
            n_fft=cfg.n_fft,
            hop_length=cfg.hop_length,
            win_length=cfg.win_length,
            n_mels=cfg.n_mels,
            fmin=cfg.fmin,
            fmax=cfg.fmax,
        )
        log_mel = np.log10(10000.0 * mel + 1.0)
        return torch.from_numpy(log_mel).float()

    def predict(self, audio_file):
        """Return ``sigmoid(forward(...))`` for one audio file, without gradients.

        Side effect: the whole model is moved to CUDA when available
        (otherwise CPU). The train/eval mode of the module is left untouched.
        """
        device = self._move_to_default_device()
        mel = self.preprocess(audio_file).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = self(mel)
            return torch.sigmoid(logits)

    def extract_embeddings(self, audio_file):
        """Return the pre-classifier embedding for one audio file.

        Side effect: the whole model is moved to CUDA when available
        (otherwise CPU).
        """
        device = self._move_to_default_device()
        mel = self.preprocess(audio_file).unsqueeze(0).to(device)
        return self.extract_embeddings_from_spec(mel)

    def extract_embeddings_from_spec(self, mel):
        """Return the pre-classifier embedding for a spectrogram tensor.

        The input is moved to the device of the model's parameters if it
        lives elsewhere. Runs without gradients; the train/eval mode of the
        module is left untouched.
        """
        with torch.no_grad():
            model_device = next(self.parameters()).device
            if mel.device != model_device:
                mel = mel.to(model_device)
            return self._embed_spec(mel)
|
|