import copy
import random
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present


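# HuBERT backbone: a convolutional feature extractor followed by a feature
# projection, a convolutional positional embedding, and a 12-layer transformer
# encoder. `forward` returns cosine-similarity logits (temperature 0.1) against
# the label embeddings, together with the span mask applied during training.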
class Hubert(nn.Module):
    def __init__(self, num_label_embeddings: int = 100, mask: bool = True):
        super().__init__()
        self._mask = mask
        self.feature_extractor = FeatureExtractor()
        self.feature_projection = FeatureProjection()
        self.positional_embedding = PositionalConvEmbedding()
        self.norm = nn.LayerNorm(768)
        self.dropout = nn.Dropout(0.1)
        self.encoder = TransformerEncoder(
            nn.TransformerEncoderLayer(
                768, 12, 3072, activation="gelu", batch_first=True
            ),
            12,
        )
        self.proj = nn.Linear(768, 256)

        self.masked_spec_embed = nn.Parameter(torch.FloatTensor(768).uniform_())
        self.label_embedding = nn.Embedding(num_label_embeddings, 256)

    def mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        mask = None
        if self.training and self._mask:
            mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2)
            x[mask] = self.masked_spec_embed.to(x.dtype)
        return x, mask

    def encode(
        self, x: torch.Tensor, layer: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x = self.feature_extractor(x)
        x = self.feature_projection(x.transpose(1, 2))
        x, mask = self.mask(x)
        x = x + self.positional_embedding(x)
        x = self.dropout(self.norm(x))
        x = self.encoder(x, output_layer=layer)
        return x, mask

    def logits(self, x: torch.Tensor) -> torch.Tensor:
        logits = torch.cosine_similarity(
            x.unsqueeze(2),
            self.label_embedding.weight.unsqueeze(0).unsqueeze(0),
            dim=-1,
        )
        return logits / 0.1

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x, mask = self.encode(x)
        x = self.proj(x)
        logits = self.logits(x)
        return logits, mask


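# HuBERT-Soft head: same backbone, but `units` returns the 256-dimensional
# projected features ("soft speech units"). The waveform is padded by
# (400 - 320) // 2 samples on each side so frames stay centred on the hops.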
class HubertSoft(Hubert):
    def __init__(self):
        super().__init__()

    @torch.inference_mode()
    def units(self, wav: torch.Tensor) -> torch.Tensor:
        wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))
        x, _ = self.encode(wav)
        return self.proj(x)


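# Seven strided 1-D convolutions with an overall stride of 320 samples and a
# receptive field of 400 samples, i.e. 20 ms hops over 25 ms windows at 16 kHz.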
class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv0 = nn.Conv1d(1, 512, 10, 5, bias=False)
        self.norm0 = nn.GroupNorm(512, 512)
        self.conv1 = nn.Conv1d(512, 512, 3, 2, bias=False)
        self.conv2 = nn.Conv1d(512, 512, 3, 2, bias=False)
        self.conv3 = nn.Conv1d(512, 512, 3, 2, bias=False)
        self.conv4 = nn.Conv1d(512, 512, 3, 2, bias=False)
        self.conv5 = nn.Conv1d(512, 512, 2, 2, bias=False)
        self.conv6 = nn.Conv1d(512, 512, 2, 2, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.gelu(self.norm0(self.conv0(x)))
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = F.gelu(self.conv3(x))
        x = F.gelu(self.conv4(x))
        x = F.gelu(self.conv5(x))
        x = F.gelu(self.conv6(x))
        return x


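# Projects the 512-dimensional convolutional features to the transformer
# width of 768, with layer norm and dropout.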
class FeatureProjection(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(512)
        self.projection = nn.Linear(512, 768)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.norm(x)
        x = self.projection(x)
        x = self.dropout(x)
        return x


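# Grouped convolution used as a relative positional embedding (as in
# wav2vec 2.0 / HuBERT). The even kernel size with symmetric padding produces
# one extra frame, which is trimmed before the GELU.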
class PositionalConvEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(
            768,
            768,
            kernel_size=128,
            padding=128 // 2,
            groups=16,
        )
        self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x.transpose(1, 2))
        x = F.gelu(x[:, :, :-1])
        return x.transpose(1, 2)


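# Stack of nn.TransformerEncoderLayer copies. Unlike nn.TransformerEncoder,
# `forward` accepts `output_layer` to return the output of an intermediate
# layer (used by Hubert.encode via the `layer` argument).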
class TransformerEncoder(nn.Module):
    def __init__(
        self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int
    ) -> None:
        super().__init__()
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
        )
        self.num_layers = num_layers

    def forward(
        self,
        src: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        output_layer: Optional[int] = None,
    ) -> torch.Tensor:
        output = src
        for layer in self.layers[:output_layer]:
            output = layer(
                output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
            )
        return output


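# Computes a boolean (batch_size, sequence_length) mask of random spans, in the
# style of SpecAugment / wav2vec 2.0: roughly `mask_prob` of the frames are
# covered by spans of `mask_length` consecutive frames.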
def _compute_mask(
    shape: Tuple[int, int],
    mask_prob: float,
    mask_length: int,
    device: torch.device,
    min_masks: int = 0,
) -> torch.Tensor:
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
        )

    # compute the number of masked spans in the batch
    num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random())
    num_masked_spans = max(num_masked_spans, min_masks)

    # make sure the number of masked indices is <= sequence_length
    if num_masked_spans * mask_length > sequence_length:
        num_masked_spans = sequence_length // mask_length

    # SpecAugment-style mask to fill
    mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)

    # uniform distribution over valid start positions, so that every span fits
    uniform_dist = torch.ones(
        (batch_size, sequence_length - (mask_length - 1)), device=device
    )

    # sample random start indices for the masked spans
    mask_indices = torch.multinomial(uniform_dist, num_masked_spans)

    # expand the start indices to full spans of length `mask_length`
    mask_indices = (
        mask_indices.unsqueeze(dim=-1)
        .expand((batch_size, num_masked_spans, mask_length))
        .reshape(batch_size, num_masked_spans * mask_length)
    )
    offsets = (
        torch.arange(mask_length, device=device)[None, None, :]
        .expand((batch_size, num_masked_spans, mask_length))
        .reshape(batch_size, num_masked_spans * mask_length)
    )
    mask_idxs = mask_indices + offsets

    # scatter the span indices into the mask
    mask = mask.scatter(1, mask_idxs, True)

    return mask


def hubert_soft(path: str) -> HubertSoft:
    r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.

    Args:
        path (str): path to a pretrained model checkpoint.
    """
    hubert = HubertSoft()
    checkpoint = torch.load(path)
    consume_prefix_in_state_dict_if_present(checkpoint, "module.")
    hubert.load_state_dict(checkpoint)
    hubert.eval()
    return hubert
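

# Minimal usage sketch (not part of the original module). It runs a randomly
# initialised HubertSoft on one second of fake 16 kHz audio just to show the
# expected tensor shapes; with real weights you would call
# hubert_soft("path/to/checkpoint.pt") instead of constructing HubertSoft().
if __name__ == "__main__":
    model = HubertSoft()
    model.eval()
    wav = torch.randn(1, 1, 16000)  # (batch, channels, samples), 16 kHz mono
    units = model.units(wav)
    print(units.shape)  # torch.Size([1, 50, 256]): one 256-dim unit per 20 ms hop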