# mule-pytorch / mule.py
# Uploaded by oriyonay with huggingface_hub (commit dd84e41, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import librosa
from transformers import PreTrainedModel, PretrainedConfig
from huggingface_hub import PyTorchModelHubMixin
from typing import Optional, List, Union, Dict
class MuleConfig(PretrainedConfig):
    """Configuration for the ``Mule`` model.

    Front-end settings (consumed by ``Mule.preprocess`` for the log-mel spectrogram):
    ``sample_rate``, ``n_mels``, ``n_fft``, ``hop_length``, ``win_length``, ``fmin``, ``fmax``.

    Model settings:
        alpha: residual-branch scale shared by every NFNet block.
        scaled_activation_type: ``"gelu"`` or ``"relu"``; used for the final activation
            and inside the projector MLP.
        f_value: depth multiplier; each stage gets ``base_depth * (f_value + 1)`` blocks.
        projector_layers: hidden widths of the projector MLP; ``None`` means no projector.
        include_fc: whether a bias-free linear classifier is appended.
        num_labels: classifier output size.

    Remaining keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "mule"

    def __init__(
        self,
        sample_rate: int = 16000,
        n_mels: int = 128,
        n_fft: int = 1024,
        hop_length: int = 256,
        win_length: int = 512,
        fmin: float = 40.0,
        fmax: float = 8000.0,
        alpha: float = 0.2,
        scaled_activation_type: str = "gelu",
        f_value: int = 0,
        projector_layers: Optional[List[int]] = None,
        include_fc: bool = True,
        num_labels: int = 50,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Spectrogram front-end.
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.fmin = fmin
        self.fmax = fmax
        # Network hyper-parameters.
        self.alpha = alpha
        self.scaled_activation_type = scaled_activation_type
        self.f_value = f_value
        self.projector_layers = [] if projector_layers is None else projector_layers
        self.include_fc = include_fc
        # Assigned last, matching the original order (PretrainedConfig's num_labels
        # setter may derive label maps from it).
        self.num_labels = num_labels
def get_same_padding(kernel_size, stride, dilation=1):
    """Return symmetric ``(pad_h, pad_w)`` padding for a stride-1 "same" convolution.

    Args:
        kernel_size: int or ``(k_h, k_w)``.
        stride: int or ``(s_h, s_w)``. Kept for signature compatibility and not used.
            The result only preserves the spatial size when every stride is 1. Strided
            convolutions need input-size dependent (possibly asymmetric) padding, which
            must be applied at run time instead.
        dilation: int or ``(d_h, d_w)``; defaults to 1.

    Returns:
        ``(eff_h // 2, eff_w // 2)`` where ``eff = dilation * (k - 1) + 1`` is the
        effective kernel extent. For odd effective kernels this keeps the spatial size
        unchanged. For even ones no symmetric padding can, and the output is one element
        larger than the input.
    """
    def _pair(value):
        return tuple(value) if isinstance(value, (list, tuple)) else (value, value)

    k_h, k_w = _pair(kernel_size)
    d_h, d_w = _pair(dilation)
    # With dilation == 1 this reduces to k // 2, which is the original behaviour.
    return ((d_h * (k_h - 1) + 1) // 2, (d_w * (k_w - 1) + 1) // 2)
class ScaledActivation(nn.Module):
    """Element-wise nonlinearity followed by multiplication with a fixed gain.

    ``activation_type`` selects ``F.gelu`` or ``F.relu``. Each has its own constant gain;
    these appear to be the NFNet variance-preserving "gamma" constants (unconfirmed here).
    Any other value raises ``ValueError``.
    """

    # (name, function, gain) candidates, checked in order by equality.
    _VARIANTS = (
        ("gelu", F.gelu, 1.7015043497085571),
        ("relu", F.relu, 1.7139588594436646),
    )

    def __init__(self, activation_type="gelu"):
        super().__init__()
        self.activation_type = activation_type
        for name, fn, gain in self._VARIANTS:
            if activation_type == name:
                self.scale = gain
                self.act = fn
                break
        else:
            raise ValueError(f"Unsupported activation: {activation_type}")

    def forward(self, x):
        activated = self.act(x)
        return activated * self.scale
class WeightStandardizedConv2d(nn.Conv2d):
    """Conv2d with per-filter weight standardization and a learnable per-filter gain.

    On every forward pass each output filter of ``weight`` is shifted to zero mean and
    scaled so that ``var * fan_in == 1`` (``fan_in`` = number of elements per filter).
    The ``var * fan_in`` term is clamped below at ``eps`` before the reciprocal square
    root. The result is then multiplied by ``gain`` (shape ``(out_channels, 1, 1, 1)``,
    initialised to ones).

    ``padding="same"`` is handled as follows:

    * stride 1, dilation 1 and odd kernel sizes: static symmetric padding from
      ``get_same_padding``;
    * anything else (strided, dilated or even-sized kernels): padding is computed in
      ``forward`` from the input size with the TensorFlow "SAME" rule. The output length
      is ``ceil(size / stride)``, and when the total padding is odd the extra element
      goes at the bottom/right.

    Any other ``padding`` value is passed to ``nn.Conv2d`` unchanged.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        manual_padding = False
        if padding == "same":
            kernels = tuple(kernel_size) if isinstance(kernel_size, (list, tuple)) else (kernel_size, kernel_size)
            strides = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride)
            dilations = tuple(dilation) if isinstance(dilation, (list, tuple)) else (dilation, dilation)
            needs_runtime_padding = (
                any(s > 1 for s in strides)
                or any(d > 1 for d in dilations)
                or any(k % 2 == 0 for k in kernels)
            )
            if needs_runtime_padding:
                # Input-size dependent and possibly asymmetric: applied in forward().
                padding = 0
                manual_padding = True
            else:
                padding = get_same_padding(kernel_size, stride)
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        # Assigned after nn.Conv2d.__init__ so the module is fully initialised first.
        self.manual_padding = manual_padding
        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        self.eps = 1e-4

    @staticmethod
    def _same_pad_total(size, kernel, stride, dilation):
        """Total TF-"SAME" padding along one axis so the output length is ceil(size / stride)."""
        effective_kernel = dilation * (kernel - 1) + 1
        out_size = -(-size // stride)  # integer ceil division
        return max(0, (out_size - 1) * stride + effective_kernel - size)

    def _standardized_weight(self):
        """Return ``weight`` standardized per output filter and scaled by ``gain``."""
        weight = self.weight
        mean = weight.mean(dim=(1, 2, 3), keepdim=True)
        var = weight.var(dim=(1, 2, 3), keepdim=True, unbiased=False)
        fan_in = weight[0].numel()
        scale = torch.rsqrt(torch.clamp(var * fan_in, min=self.eps)) * self.gain
        shift = mean * scale
        return weight * scale - shift

    def forward(self, x):
        if self.manual_padding:
            in_h, in_w = x.shape[-2:]
            pad_h = self._same_pad_total(in_h, self.kernel_size[0], self.stride[0], self.dilation[0])
            pad_w = self._same_pad_total(in_w, self.kernel_size[1], self.stride[1], self.dilation[1])
            if pad_h > 0 or pad_w > 0:
                # F.pad order for the last two dims: (left, right, top, bottom).
                x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
        return F.conv2d(x, self._standardized_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
class ScalarMultiply(nn.Module):
    """Multiply the input by a single scalar ``gain``.

    With ``learnable=True`` the gain is an ``nn.Parameter``. Otherwise it is registered
    as a (persistent) buffer, so it appears in ``state_dict`` but is not a parameter.
    """

    def __init__(self, init_gain, learnable=False):
        super().__init__()
        self.init_gain = init_gain
        self.learnable = learnable
        initial = torch.tensor(float(init_gain))
        if not learnable:
            self.register_buffer("gain", initial)
        else:
            self.gain = nn.Parameter(initial)

    def forward(self, x):
        scaled = x * self.gain
        return scaled
class SqueezeExcite(nn.Module):
    """Channel gating: global average pool -> Linear/ReLU -> Linear/sigmoid, scaled by 2.

    The gate is reshaped to ``(batch, channels, 1, 1)`` using the input's channel count,
    so ``out_channels`` must equal the input's channel count for ``forward`` to work.
    The hidden width is ``out_channels // 2``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden = out_channels // 2
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(in_channels, hidden)
        self.fc2 = nn.Linear(hidden, out_channels)

    def forward(self, x):
        batch, channels, _, _ = x.shape
        squeezed = self.avg_pool(x).view(batch, channels)
        gate = torch.sigmoid(self.fc2(F.relu(self.fc1(squeezed))))
        return x * gate.view(batch, channels, 1, 1) * 2.0
class StochDepth(nn.Module):
    """Stochastic depth over a ``[shortcut, residual]`` pair.

    While training, each sample's residual is kept with probability
    ``survival_probability`` (a Bernoulli mask of shape ``(batch, 1, 1, 1)``), otherwise
    zeroed. In eval mode it returns ``shortcut + residual`` with no rescaling.

    NOTE(review): tfa.layers.StochasticDepth, which shares this signature, returns
    ``shortcut + survival_probability * residual`` at inference. Confirm which convention
    the reference model used.
    """

    def __init__(self, survival_probability=0.5):
        super().__init__()
        self.survival_probability = survival_probability

    def forward(self, x_list):
        shortcut, residual = x_list
        if self.training:
            keep_mask = torch.bernoulli(
                torch.full((shortcut.shape[0], 1, 1, 1), self.survival_probability, device=shortcut.device)
            )
            residual = keep_mask * residual
        return shortcut + residual
class NFNetBlock(nn.Module):
    """Normalizer-free bottleneck residual block with squeeze-excite and stochastic depth.

    Computes ``stoch_depth([shortcut, alpha * gain * SE(convs(beta * act(x)))])``.
    ``shortcut`` is ``x`` for ordinary blocks. For transition blocks it is
    ``skip(beta * act(x))``: an optional average pool along dim 2 followed by a 1x1 conv.

    Args:
        kernels: four kernel sizes, one per residual conv.
        freq_downsample: stride along dim 2 (frequency, for Mule's input layout), applied by
            the second residual conv and by the skip pooling.
        in_ch: input channels.
        out_ch: output channels; the bottleneck width is ``out_ch // 2``.
        group_size: accepted but currently unused (see note in ``__init__``).
        alpha: fixed scale applied to the residual branch.
        beta: fixed scale applied to the activated input.
        stoch_depth_prob: probability of dropping the residual branch while training.
        is_transition: force the conv shortcut even when shapes match. It is also enabled
            automatically when ``freq_downsample > 1`` or ``in_ch != out_ch``.
        activation_type: nonlinearity for the block's ScaledActivations
            (default ``"gelu"``, the previously hard-coded choice).
    """

    def __init__(self, kernels, freq_downsample, in_ch, out_ch, group_size, alpha, beta, stoch_depth_prob,
                 is_transition=False, activation_type="gelu"):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.is_transition = (freq_downsample > 1) or (in_ch != out_ch) or is_transition
        self.act = ScaledActivation(activation_type)
        self.scalar_beta = ScalarMultiply(beta)

        # Residual branch: in -> mid -> mid -> mid -> out, with an activation between
        # consecutive convs. Only the second conv strides. Sequential indices 0/2/4/6 are
        # the convs, which keeps checkpoint keys stable.
        # NOTE(review): group_size previously fed a per-conv "groups" list that was never
        # passed to the convs, so every conv is dense (groups=1). Enabling groups would
        # change the conv weight shapes; confirm against the reference model and any
        # existing checkpoints first.
        mid_ch = out_ch // 2
        conv_channels = [(in_ch, mid_ch), (mid_ch, mid_ch), (mid_ch, mid_ch), (mid_ch, out_ch)]
        strides = [(1, 1), (freq_downsample, 1), (1, 1), (1, 1)]
        res_layers = []
        for idx in range(4):
            if idx > 0:
                res_layers.append(ScaledActivation(activation_type))
            c_in, c_out = conv_channels[idx]
            res_layers.append(
                WeightStandardizedConv2d(c_in, c_out, kernel_size=kernels[idx], stride=strides[idx], padding="same")
            )
        self.residual_convs = nn.Sequential(*res_layers)

        self.se = SqueezeExcite(out_ch, out_ch)
        # Learnable residual gain initialised to 0, so the block initially outputs its shortcut.
        self.gain = ScalarMultiply(0.0, learnable=True)
        self.scalar_alpha = ScalarMultiply(alpha)

        # The nesting below fixes the checkpoint key names (e.g. "skip.1.weight").
        skip = nn.Identity()
        if freq_downsample > 1:
            skip = nn.Sequential(
                nn.AvgPool2d(kernel_size=(freq_downsample, 1), stride=(freq_downsample, 1), padding=0),
                skip,
            )
        if self.is_transition:
            skip = nn.Sequential(skip, WeightStandardizedConv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0))
        self.skip = skip
        self.stoch_depth = StochDepth(1.0 - stoch_depth_prob)

    def forward(self, x):
        pre = self.scalar_beta(self.act(x))
        # Transition blocks project the pre-activated input; others keep the raw input.
        shortcut = self.skip(pre) if self.is_transition else x
        residual = self.residual_convs(pre)
        residual = self.se(residual)
        residual = self.gain(residual)
        residual = self.scalar_alpha(residual)
        return self.stoch_depth([shortcut, residual])
class NFNetStage(nn.Module):
    """A stage of ``NFNetBlock``s: one transition block, then ``num_blocks - 1`` regular ones.

    A running value ``s`` starts at ``input_expected_var``. Block ``i`` gets
    ``beta = 1 / s``, and then ``s`` becomes ``sqrt(s**2 + alpha**2)``. Despite its name,
    ``input_expected_var`` is therefore treated as a standard deviation. At least one
    block is always built.

    NOTE(review): DeepMind's NFNet reference resets the running std to 1.0 after each
    transition block; this code does not. Each block's beta lives in a persistent buffer
    (ScalarMultiply), so checkpoint values override these constructor values on load.
    Confirm which convention the checkpoint used.
    """

    def __init__(self, kernels, freq_downsample, in_ch, out_ch, group_size, alpha, input_expected_var, stoch_depths, num_blocks):
        super().__init__()
        self.blocks = nn.ModuleList()
        running_std = input_expected_var
        for idx in range(max(num_blocks, 1)):
            if idx == 0:
                block = NFNetBlock(
                    kernels, freq_downsample, in_ch, out_ch, group_size, alpha,
                    1.0 / running_std, stoch_depths[idx], is_transition=True,
                )
            else:
                block = NFNetBlock(
                    kernels, 1, out_ch, out_ch, group_size, alpha,
                    1.0 / running_std, stoch_depths[idx],
                )
            self.blocks.append(block)
            running_std = (running_std**2.0 + alpha**2.0)**0.5

    def forward(self, x):
        out = x
        for blk in self.blocks:
            out = blk(out)
        return out
class FusionLayer(nn.Module):
    """Lateral connection that merges fast-pathway features into the slow pathway.

    ``fast`` passes through a ``(1, time_kernel_length)`` conv strided by ``time_stride``
    along the last axis ('same' padding), then a 1x1 conv to ``out_ch`` channels. The
    result is concatenated onto ``slow`` along the channel axis, so both must agree in
    batch size and spatial dimensions.
    """

    def __init__(self, time_kernel_length, time_stride, in_ch, out_ch):
        super().__init__()
        self.conv1 = WeightStandardizedConv2d(
            in_channels=in_ch,
            out_channels=in_ch,
            kernel_size=(1, time_kernel_length),
            stride=(1, time_stride),
            padding="same",
        )
        self.conv2 = WeightStandardizedConv2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=1,
            stride=1,
            padding=0,
        )

    def forward(self, slow, fast):
        lateral = self.conv2(self.conv1(fast))
        return torch.cat((slow, lateral), dim=1)
class Mule(PreTrainedModel, PyTorchModelHubMixin):
    """Two-pathway (slow/fast) normalizer-free CNN over log-mel spectrograms.

    Input is ``(batch, n_mels, time)`` or ``(batch, 1, n_mels, time)``. Both pathways
    start from the same input. Before each of the four stages, ``fusion_layers[i]``
    strides the fast features in time and concatenates them onto the slow features;
    then each pathway runs its own ``NFNetStage``. The final slow and fast maps are
    mean-pooled over (n_mels, time) and concatenated. The result then goes through
    ``final_act``, the optional projector MLP, and finally ``classifier``. The classifier
    is a bias-free linear layer, or identity when ``config.include_fc`` is False.
    """

    config_class = MuleConfig

    def __init__(self, config: MuleConfig):
        super().__init__(config)
        # Stem strides are (freq, time). Slow stem: freq / 4, time / 16.
        # Fast stem: freq / 4, time / 4.
        self.slow_stem = self._make_stem(
            [1, 16, 32, 64, 128],
            [(3, 1), (3, 1), (3, 1), (3, 3)],
            [(2, 8), (1, 1), (1, 1), (2, 2)],
        )
        self.fast_stem = self._make_stem(
            [1, 2, 4, 8, 16],
            [(3, 3), (3, 3), (3, 3), (3, 3)],
            [(2, 2), (1, 1), (1, 1), (2, 2)],
        )

        # Blocks per stage, scaled by the depth multiplier f_value.
        nfnet_stage_depths = [x * (config.f_value + 1) for x in (1, 2, 6, 3)]
        cumulative_stage_depths = np.concatenate(([0], np.cumsum(nfnet_stage_depths)))
        # Per-block residual drop probability: rises linearly from 0 for the first block towards 0.1.
        stoch_depth_probs = 0.1 * np.arange(cumulative_stage_depths[-1]) / (cumulative_stage_depths[-1])
        stoch_depth_probs_split = [
            stoch_depth_probs[st:end]
            for st, end in zip(cumulative_stage_depths[:-1], cumulative_stage_depths[1:])
        ]
        stage_expected_vars = [
            1.0,
            (1.0 + config.alpha**2)**0.5,
            (1.0 + config.alpha**2)**0.5,
            (1.0 + config.alpha**2)**0.5,
        ]
        # Frequency stride of each stage's transition block (both pathways).
        stage_downsamples = [1, 2, 2, 2]

        # Slow stage i input = slow channels entering fusion i + fusion_layers[i] output:
        # 128 + 128, 256 + 256, 512 + 512, 1536 + 3072.
        slow_in_channels = [256, 512, 1024, 4608]
        slow_out_channels = [256, 512, 1536, 1536]
        slow_kernels = [[(1, 1), (1, 3), (3, 1), (1, 1)]] * 4
        self.slow_stages = nn.ModuleList([
            NFNetStage(slow_kernels[i], stage_downsamples[i], slow_in_channels[i], slow_out_channels[i], 128,
                       config.alpha, stage_expected_vars[i], stoch_depth_probs_split[i], nfnet_stage_depths[i])
            for i in range(4)
        ])

        fast_in_channels = [16, 32, 64, 192]
        fast_out_channels = [32, 64, 192, 192]
        fast_kernels = [[(1, 1), (1, 3), (3, 1), (1, 1)]] * 4
        self.fast_stages = nn.ModuleList([
            NFNetStage(fast_kernels[i], stage_downsamples[i], fast_in_channels[i], fast_out_channels[i], 16,
                       config.alpha, stage_expected_vars[i], stoch_depth_probs_split[i], nfnet_stage_depths[i])
            for i in range(4)
        ])

        # Time stride 4 brings the fast pathway (time / 4) down to the slow pathway's time / 16.
        self.fusion_layers = nn.ModuleList([
            FusionLayer(7, 4, 16, 128),
            FusionLayer(7, 4, 32, 256),
            FusionLayer(7, 4, 64, 512),
            FusionLayer(7, 4, 192, 3072),
        ])
        self.final_act = ScaledActivation(config.scaled_activation_type)

        # Pooled feature width = last slow stage channels + last fast stage channels.
        in_dim = slow_out_channels[-1] + fast_out_channels[-1]
        projectors = []
        for dim in config.projector_layers:
            projectors.append(nn.Linear(in_dim, dim))
            projectors.append(ScaledActivation(config.scaled_activation_type))
            in_dim = dim
        # Creation/registration order (projector Linears, classifier, projectors) kept as before.
        if config.include_fc:
            self.classifier = nn.Linear(in_dim, config.num_labels, bias=False)
        else:
            self.classifier = nn.Identity()
        self.projectors = nn.Sequential(*projectors)

    def _make_stem(self, channels, kernels, strides):
        """Build 'same'-padded WS convs with a ScaledActivation between consecutive convs (none after the last)."""
        layers = []
        for i in range(len(kernels)):
            layers.append(WeightStandardizedConv2d(channels[i], channels[i + 1], kernel_size=kernels[i],
                                                   stride=strides[i], padding="same"))
            if i < len(kernels) - 1:
                layers.append(ScaledActivation())
        return nn.Sequential(*layers)

    def _embed(self, x):
        """Run stems, fused stages, pooling, ``final_act`` and projectors; return the embedding."""
        if x.dim() == 3:
            x = x.unsqueeze(1)
        slow = self.slow_stem(x)
        fast = self.fast_stem(x)
        for fusion, slow_stage, fast_stage in zip(self.fusion_layers, self.slow_stages, self.fast_stages):
            slow = slow_stage(fusion(slow, fast))
            fast = fast_stage(fast)
        pooled = torch.cat([slow.mean(dim=(2, 3)), fast.mean(dim=(2, 3))], dim=1)
        return self.projectors(self.final_act(pooled))

    def _run_eval(self, fn, *args):
        """Call ``fn(*args)`` under ``no_grad`` in eval mode, then restore every submodule's mode.

        Eval mode matters because StochDepth randomly drops residual branches while training.
        """
        saved_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                return fn(*args)
        finally:
            for module, mode in saved_modes:
                module.training = mode

    def forward(self, x):
        """Return classifier outputs (or the embedding when ``include_fc`` is False)."""
        return self.classifier(self._embed(x))

    def preprocess(self, audio_file):
        """Load audio at ``config.sample_rate`` and return a ``(n_mels, time)`` float32 log-mel tensor.

        The mel power spectrogram is compressed with ``log10(10000 * mel + 1)``.
        """
        audio, sr = librosa.load(audio_file, sr=self.config.sample_rate)
        mel = librosa.feature.melspectrogram(
            y=audio, sr=sr, n_fft=self.config.n_fft, hop_length=self.config.hop_length,
            win_length=self.config.win_length, n_mels=self.config.n_mels,
            fmin=self.config.fmin, fmax=self.config.fmax
        )
        mel = np.log10(10000.0 * mel + 1.0)
        return torch.from_numpy(mel).float()

    def predict(self, audio_file):
        """Return sigmoid probabilities, shape ``(1, num_labels)``, for one audio file.

        Moves the model to CUDA when available (otherwise CPU) as a side effect. Runs in
        eval mode and restores the previous train/eval mode afterwards.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(device)
        mel = self.preprocess(audio_file).unsqueeze(0).to(device)
        logits = self._run_eval(self, mel)
        return torch.sigmoid(logits)

    def extract_embeddings(self, audio_file):
        """Return the embedding for one audio file (see ``extract_embeddings_from_spec``).

        Moves the model to CUDA when available (otherwise CPU) as a side effect.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(device)
        mel = self.preprocess(audio_file).unsqueeze(0).to(device)
        return self.extract_embeddings_from_spec(mel)

    def extract_embeddings_from_spec(self, mel):
        """Return the pre-classifier embedding for a batch of log-mel spectrograms.

        ``mel`` is moved to the model's device if needed. Runs without gradients in eval
        mode and restores the previous train/eval mode afterwards.
        """
        mel = mel.to(next(self.parameters()).device)
        return self._run_eval(self._embed, mel)