import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import librosa
from transformers import PreTrainedModel, PretrainedConfig
from huggingface_hub import PyTorchModelHubMixin
from typing import Optional, List, Union, Dict


class MuleConfig(PretrainedConfig):
    """Configuration for the MULE audio model.

    Holds the log-mel front-end settings read by ``Mule.preprocess`` and the
    architecture options read by ``Mule.__init__``.
    """

    model_type = "mule"

    def __init__(
        self,
        sample_rate: int = 16000,
        n_mels: int = 128,
        n_fft: int = 1024,
        hop_length: int = 256,
        win_length: int = 512,
        fmin: float = 40.0,
        fmax: float = 8000.0,
        alpha: float = 0.2,
        scaled_activation_type: str = "gelu",
        f_value: int = 0,
        projector_layers: Optional[List[int]] = None,
        include_fc: bool = True,
        num_labels: int = 50,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.fmin = fmin
        self.fmax = fmax
        # Residual-branch scale used by every NFNet block.
        self.alpha = alpha
        # Activation used after pooling and inside the projector head.
        self.scaled_activation_type = scaled_activation_type
        # Depth multiplier: each stage has base_depth * (f_value + 1) blocks.
        self.f_value = f_value
        # Hidden sizes of the Linear projector head; copied so later mutation
        # of the caller's list cannot silently change this config.
        self.projector_layers = (
            list(projector_layers) if projector_layers is not None else []
        )
        # When False the classifier is an identity and forward() returns the
        # projector output instead of logits.
        self.include_fc = include_fc
        self.num_labels = num_labels


def _as_pair(value):
    """Return ``value`` as a 2-tuple, repeating it when it is a single int."""
    if isinstance(value, int):
        return (value, value)
    return tuple(value)


def get_same_padding(kernel_size, stride, dilation=1):
    """Return symmetric ``(pad_h, pad_w)`` approximating 'same' padding.

    Each axis is padded by ``effective_kernel // 2`` on both sides, where
    ``effective_kernel = dilation * (kernel - 1) + 1``. This reproduces
    TensorFlow 'SAME' exactly only for odd effective kernels at stride 1.
    Strided convolutions and even kernels need asymmetric, input-size-dependent
    padding, which ``WeightStandardizedConv2d`` applies in ``forward``.

    Args:
        kernel_size: int or (k_h, k_w).
        stride: accepted for API compatibility; it does not affect the result.
        dilation: int or (d_h, d_w); defaults to 1.

    Returns:
        Tuple ``(pad_h, pad_w)`` of per-side padding amounts.
    """
    k_h, k_w = _as_pair(kernel_size)
    d_h, d_w = _as_pair(dilation)
    return ((d_h * (k_h - 1) + 1) // 2, (d_w * (k_w - 1) + 1) // 2)


def _tf_same_pad(size, kernel, stride, dilation):
    """Return ``(before, after)`` padding reproducing TF 'SAME' on one axis.

    The output length becomes ``ceil(size / stride)``. When the total padding is
    odd, the extra element goes after, as in TensorFlow.
    """
    effective_kernel = dilation * (kernel - 1) + 1
    out_size = -(-size // stride)  # integer ceil division
    total = max(0, (out_size - 1) * stride + effective_kernel - size)
    return total // 2, total - total // 2


class ScaledActivation(nn.Module):
    """Apply GELU or ReLU, then multiply by a fixed constant.

    The constants appear to be the NF-Net "gamma" values that keep activations
    roughly variance-preserving.

    Raises:
        ValueError: if ``activation_type`` is neither "gelu" nor "relu".
    """

    def __init__(self, activation_type="gelu"):
        super().__init__()
        self.activation_type = activation_type
        if activation_type == "gelu":
            self.scale = 1.7015043497085571
            self.act = F.gelu
        elif activation_type == "relu":
            self.scale = 1.7139588594436646
            self.act = F.relu
        else:
            raise ValueError(f"Unsupported activation: {activation_type}")

    def forward(self, x):
        return self.act(x) * self.scale


class WeightStandardizedConv2d(nn.Conv2d):
    """Conv2d with scaled weight standardization and a learnable output gain.

    Each output filter is standardized to zero mean and to variance ``1/fan_in``
    (via ``rsqrt(var * fan_in)``, clamped by ``eps``). It is then multiplied by a
    per-output-channel ``gain`` parameter.

    ``padding="same"`` follows TensorFlow 'SAME' semantics. For stride 1 with an
    odd effective kernel, it uses nn.Conv2d's symmetric padding. Otherwise, it
    pads asymmetrically in ``forward`` based on the actual input size.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=True):
        manual_padding = False
        if padding == "same":
            k_h, k_w = _as_pair(kernel_size)
            s_h, s_w = _as_pair(stride)
            d_h, d_w = _as_pair(dilation)
            strided = s_h > 1 or s_w > 1
            # An even effective kernel needs asymmetric padding, which nn.Conv2d
            # cannot express, so it also takes the manual path.
            even_kernel = (d_h * (k_h - 1)) % 2 == 1 or (d_w * (k_w - 1)) % 2 == 1
            if strided or even_kernel:
                padding = 0
                manual_padding = True
            else:
                padding = get_same_padding(kernel_size, stride, dilation)
        super().__init__(in_channels, out_channels, kernel_size, stride, padding,
                         dilation, groups, bias)
        self.manual_padding = manual_padding
        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        self.eps = 1e-4

    def forward(self, x):
        if self.manual_padding:
            top, bottom = _tf_same_pad(x.shape[-2], self.kernel_size[0],
                                       self.stride[0], self.dilation[0])
            left, right = _tf_same_pad(x.shape[-1], self.kernel_size[1],
                                       self.stride[1], self.dilation[1])
            if top or bottom or left or right:
                # F.pad order for the last two dims: (left, right, top, bottom).
                x = F.pad(x, [left, right, top, bottom])
        weight = self.weight
        mean = weight.mean(dim=(1, 2, 3), keepdim=True)
        var = weight.var(dim=(1, 2, 3), keepdim=True, unbiased=False)
        fan_in = weight[0].numel()
        scale = torch.rsqrt(torch.clamp(var * fan_in, min=self.eps)) * self.gain
        shift = mean * scale
        standardized_weight = weight * scale - shift
        return F.conv2d(x, standardized_weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


class ScalarMultiply(nn.Module):
    """Multiply the input by a scalar.

    The scalar is a trainable parameter when ``learnable`` is True. Otherwise it
    is a registered buffer, which is still saved and loaded with the state_dict.
    """

    def __init__(self, init_gain, learnable=False):
        super().__init__()
        self.init_gain = init_gain
        self.learnable = learnable
        if learnable:
            self.gain = nn.Parameter(torch.tensor(float(init_gain)))
        else:
            self.register_buffer("gain", torch.tensor(float(init_gain)))

    def forward(self, x):
        return x * self.gain


class SqueezeExcite(nn.Module):
    """Squeeze-and-excitation gate: ``x * 2 * sigmoid(MLP(global_avg_pool(x)))``.

    The forward pass reshapes the gate to the input's channel count. It
    therefore assumes ``out_channels`` equals the channels of ``x``, which holds
    for every use in this file.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(in_channels, out_channels // 2)
        self.fc2 = nn.Linear(out_channels // 2, out_channels)

    def forward(self, x):
        b, c, h, w = x.shape
        y = self.avg_pool(x).view(b, c)
        y = F.relu(self.fc1(y))
        y = torch.sigmoid(self.fc2(y)).view(b, c, 1, 1)
        # The factor 2 makes the gate 1.0 at sigmoid(0), i.e. identity at init.
        return x * y * 2.0


class StochDepth(nn.Module):
    """Residual add that randomly drops the residual per sample during training.

    In eval mode it returns ``shortcut + residual``. In training mode each
    sample's residual is kept with probability ``survival_probability``.

    NOTE(review): the kept residual is not rescaled by ``1/survival_probability``,
    so train-time and eval-time expectations differ. Confirm this matches the
    reference implementation before changing it.
    """

    def __init__(self, survival_probability=0.5):
        super().__init__()
        self.survival_probability = survival_probability

    def forward(self, x_list):
        shortcut, residual = x_list
        if not self.training:
            return shortcut + residual
        keep_prob = torch.full((shortcut.shape[0], 1, 1, 1), self.survival_probability,
                               device=shortcut.device)
        # Sample in float32, then cast so the mask never upcasts half-precision
        # activations.
        mask = torch.bernoulli(keep_prob).to(residual.dtype)
        return shortcut + mask * residual


class NFNetBlock(nn.Module):
    """Normalizer-free bottleneck block with SE, skip-init gain and stochastic depth.

    Computes ``shortcut + alpha * gain * SE(convs(beta * act(x)))``. For
    transition blocks the shortcut is projected by an optional average pool and
    a 1x1 conv applied to ``beta * act(x)``. Otherwise the shortcut is ``x``.
    The residual branch downsamples only the frequency (height) axis, by
    ``freq_downsample``.
    """

    def __init__(self, kernels, freq_downsample, in_ch, out_ch, group_size, alpha, beta,
                 stoch_depth_prob, is_transition=False):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.is_transition = (freq_downsample > 1) or (in_ch != out_ch) or is_transition
        self.act = ScaledActivation()
        self.scalar_beta = ScalarMultiply(beta)

        mid_ch = out_ch // 2
        # NOTE(review): ``group_size`` is accepted, but the convolutions below are
        # not grouped (a ``groups`` list used to be computed and never passed).
        # Passing ``groups=`` would change weight shapes and break checkpoints
        # saved from this code. Confirm against the reference model first.
        strides = [(1, 1), (freq_downsample, 1), (1, 1), (1, 1)]
        channels = [in_ch, mid_ch, mid_ch, mid_ch, out_ch]
        res_layers = []
        for i, (kernel, stride) in enumerate(zip(kernels, strides)):
            if i > 0:
                res_layers.append(ScaledActivation())
            res_layers.append(WeightStandardizedConv2d(
                channels[i], channels[i + 1], kernel_size=kernel, stride=stride,
                padding="same"))
        self.residual_convs = nn.Sequential(*res_layers)

        self.se = SqueezeExcite(out_ch, out_ch)
        # Learnable gain initialised to 0, so the block starts as the identity.
        self.gain = ScalarMultiply(0.0, learnable=True)
        self.scalar_alpha = ScalarMultiply(alpha)

        self.skip = nn.Identity()
        if freq_downsample > 1:
            # ceil_mode=True gives ceil(F / d) rows, matching the strided 'same'
            # conv in the residual branch even when F is odd. Floor-mode pooling
            # would make the residual add fail. Sizes are unchanged when F is
            # divisible by d.
            self.skip = nn.Sequential(
                nn.AvgPool2d(kernel_size=(freq_downsample, 1), stride=(freq_downsample, 1),
                             padding=0, ceil_mode=True),
                self.skip,
            )
        if self.is_transition:
            self.skip = nn.Sequential(
                self.skip,
                WeightStandardizedConv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0),
            )

        self.stoch_depth = StochDepth(1.0 - stoch_depth_prob)

    def forward(self, x):
        z = self.scalar_beta(self.act(x))
        shortcut = self.skip(z) if self.is_transition else x
        residual = self.residual_convs(z)
        residual = self.se(residual)
        residual = self.gain(residual)
        residual = self.scalar_alpha(residual)
        return self.stoch_depth([shortcut, residual])


class NFNetStage(nn.Module):
    """A transition ``NFNetBlock`` followed by ``num_blocks - 1`` regular blocks.

    Each block's ``beta`` is ``1 / expected_std`` of its input. After every block
    the running estimate grows as ``sqrt(std**2 + alpha**2)``.

    NOTE(review): despite its name, ``input_expected_var`` is used as a standard
    deviation (it is squared in the update and inverted for beta).
    """

    def __init__(self, kernels, freq_downsample, in_ch, out_ch, group_size, alpha,
                 input_expected_var, stoch_depths, num_blocks):
        super().__init__()
        self.blocks = nn.ModuleList()
        self.blocks.append(NFNetBlock(
            kernels, freq_downsample, in_ch, out_ch, group_size, alpha,
            1.0 / input_expected_var, stoch_depths[0], is_transition=True,
        ))
        expected_std = (input_expected_var ** 2.0 + alpha ** 2.0) ** 0.5
        for i in range(1, num_blocks):
            self.blocks.append(NFNetBlock(
                kernels, 1, out_ch, out_ch, group_size, alpha,
                1.0 / expected_std, stoch_depths[i],
            ))
            expected_std = (expected_std ** 2.0 + alpha ** 2.0) ** 0.5

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


class FusionLayer(nn.Module):
    """Fuse the fast pathway into the slow pathway by channel concatenation.

    The fast features pass through a (1, time_kernel_length) conv with time
    stride ``time_stride``, then a 1x1 conv to ``out_ch`` channels. They are
    then concatenated after the slow features along dim 1.
    """

    def __init__(self, time_kernel_length, time_stride, in_ch, out_ch):
        super().__init__()
        self.conv1 = WeightStandardizedConv2d(
            in_ch, in_ch, kernel_size=(1, time_kernel_length),
            stride=(1, time_stride), padding="same")
        self.conv2 = WeightStandardizedConv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, slow, fast):
        fast_fused = self.conv1(fast)
        fast_fused = self.conv2(fast_fused)
        return torch.cat([slow, fast_fused], dim=1)


class Mule(PreTrainedModel, PyTorchModelHubMixin):
    """Two-pathway ("slow" / "fast") NFNet-style audio model over log-mel input.

    Each pathway has a conv stem and four ``NFNetStage``s. Before each slow
    stage, a ``FusionLayer`` concatenates processed fast features into the slow
    pathway. Both pathways are mean-pooled over the last two axes, concatenated,
    activated, passed through the optional projector MLP, and finally through
    the classifier (Linear to ``num_labels``, or identity when
    ``include_fc=False``).
    """

    config_class = MuleConfig

    def __init__(self, config: MuleConfig):
        super().__init__(config)

        self.slow_stem = self._make_stem(
            [1, 16, 32, 64, 128],
            [(3, 1), (3, 1), (3, 1), (3, 3)],
            [(2, 8), (1, 1), (1, 1), (2, 2)],
        )
        self.fast_stem = self._make_stem(
            [1, 2, 4, 8, 16],
            [(3, 3), (3, 3), (3, 3), (3, 3)],
            [(2, 2), (1, 1), (1, 1), (2, 2)],
        )

        nfnet_stage_depths = [x * (config.f_value + 1) for x in (1, 2, 6, 3)]
        total_blocks = sum(nfnet_stage_depths)
        # Drop probability rises linearly from 0 towards 0.1 across all blocks.
        stoch_depth_probs = (0.1 * np.arange(total_blocks) / total_blocks).tolist()
        stage_bounds = np.concatenate(([0], np.cumsum(nfnet_stage_depths))).tolist()
        stoch_depth_probs_split = [
            stoch_depth_probs[start:end]
            for start, end in zip(stage_bounds[:-1], stage_bounds[1:])
        ]

        stage_expected_vars = [
            1.0,
            (1.0 + config.alpha ** 2) ** 0.5,
            (1.0 + config.alpha ** 2) ** 0.5,
            (1.0 + config.alpha ** 2) ** 0.5,
        ]
        stage_downsamples = [1, 2, 2, 2]

        # Slow inputs include the channels concatenated by the fusion layers.
        slow_in_channels = [256, 512, 1024, 4608]
        slow_out_channels = [256, 512, 1536, 1536]
        slow_kernels = [[(1, 1), (1, 3), (3, 1), (1, 1)]] * 4
        self.slow_stages = nn.ModuleList([
            NFNetStage(slow_kernels[i], stage_downsamples[i], slow_in_channels[i],
                       slow_out_channels[i], 128, config.alpha, stage_expected_vars[i],
                       stoch_depth_probs_split[i], nfnet_stage_depths[i])
            for i in range(4)
        ])

        fast_in_channels = [16, 32, 64, 192]
        fast_out_channels = [32, 64, 192, 192]
        fast_kernels = [[(1, 1), (1, 3), (3, 1), (1, 1)]] * 4
        self.fast_stages = nn.ModuleList([
            NFNetStage(fast_kernels[i], stage_downsamples[i], fast_in_channels[i],
                       fast_out_channels[i], 16, config.alpha, stage_expected_vars[i],
                       stoch_depth_probs_split[i], nfnet_stage_depths[i])
            for i in range(4)
        ])

        # Time stride 4 matches the extra 4x time downsampling of the slow stem.
        self.fusion_layers = nn.ModuleList([
            FusionLayer(7, 4, 16, 128),
            FusionLayer(7, 4, 32, 256),
            FusionLayer(7, 4, 64, 512),
            FusionLayer(7, 4, 192, 3072),
        ])

        self.final_act = ScaledActivation(config.scaled_activation_type)

        projectors = []
        in_dim = 1536 + 192  # pooled slow + fast channels
        for dim in config.projector_layers:
            projectors.append(nn.Linear(in_dim, dim))
            projectors.append(ScaledActivation(config.scaled_activation_type))
            in_dim = dim

        if config.include_fc:
            self.classifier = nn.Linear(in_dim, config.num_labels, bias=False)
        else:
            self.classifier = nn.Identity()

        self.projectors = nn.Sequential(*projectors)

    def _make_stem(self, channels, kernels, strides):
        """Build a stack of 'same'-padded WS-convs with activations between them."""
        layers = nn.ModuleList()
        for i in range(len(kernels)):
            layers.append(WeightStandardizedConv2d(
                channels[i], channels[i + 1], kernel_size=kernels[i],
                stride=strides[i], padding="same"))
            if i < len(kernels) - 1:
                layers.append(ScaledActivation())
        return nn.Sequential(*layers)

    def _pooled_features(self, x):
        """Run both pathways and return activated, pooled, concatenated features.

        A 3-D input gets a channel axis inserted at dim 1. The result has
        ``1536 + 192`` features per sample.
        """
        if x.dim() == 3:
            x = x.unsqueeze(1)
        slow = self.slow_stem(x)
        fast = self.fast_stem(x)
        for fusion, slow_stage, fast_stage in zip(
                self.fusion_layers, self.slow_stages, self.fast_stages):
            slow = fusion(slow, fast)
            slow = slow_stage(slow)
            fast = fast_stage(fast)
        pooled = torch.cat([slow.mean(dim=(2, 3)), fast.mean(dim=(2, 3))], dim=1)
        return self.final_act(pooled)

    def forward(self, x):
        """Return classifier outputs for a spectrogram batch.

        Args:
            x: 3-D batch (a channel axis is added at dim 1) or 4-D batch with one
                input channel, e.g. ``preprocess(...)`` outputs stacked on dim 0.

        Returns:
            Logits of shape ``(batch, num_labels)`` when ``include_fc`` is True.
            Otherwise, the projector output.
        """
        return self.classifier(self.projectors(self._pooled_features(x)))

    def preprocess(self, audio_file):
        """Load audio and return a compressed log-mel spectrogram tensor.

        The audio is resampled to ``config.sample_rate`` by ``librosa.load``.
        The mel power spectrogram is compressed as ``log10(10000 * mel + 1)``.
        The result is a float32 tensor of shape ``(n_mels, frames)``, per
        librosa's melspectrogram output layout.
        """
        audio, sr = librosa.load(audio_file, sr=self.config.sample_rate)
        mel = librosa.feature.melspectrogram(
            y=audio,
            sr=sr,
            n_fft=self.config.n_fft,
            hop_length=self.config.hop_length,
            win_length=self.config.win_length,
            n_mels=self.config.n_mels,
            fmin=self.config.fmin,
            fmax=self.config.fmax,
        )
        mel = np.log10(10000.0 * mel + 1.0)
        return torch.from_numpy(mel).float()

    def predict(self, audio_file):
        """Return sigmoid probabilities for one audio file.

        Moves the model to CUDA when available (else CPU) and keeps it there.
        The forward pass runs in eval mode, so stochastic depth is disabled; the
        previous train/eval mode is restored afterwards.

        Returns:
            Tensor of shape ``(1, num_labels)`` on the chosen device (when
            ``include_fc`` is True).
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(device)
        mel = self.preprocess(audio_file).unsqueeze(0).to(device)
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self(mel)
        finally:
            self.train(was_training)
        return torch.sigmoid(logits)

    def extract_embeddings(self, audio_file):
        """Return projector embeddings for one audio file.

        Moves the model to CUDA when available (else CPU) and keeps it there.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(device)
        mel = self.preprocess(audio_file).unsqueeze(0).to(device)
        return self.extract_embeddings_from_spec(mel)

    def extract_embeddings_from_spec(self, mel):
        """Return projector embeddings (no classifier) for a spectrogram batch.

        ``mel`` is moved to the model's device. A 3-D input gets a channel axis
        inserted at dim 1. Runs without gradients and in eval mode, so
        stochastic depth is disabled; the previous train/eval mode is restored.
        """
        mel = mel.to(next(self.parameters()).device)
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return self.projectors(self._pooled_features(mel))
        finally:
            self.train(was_training)