| import copy
|
| import os
|
| from typing import Optional
|
| import torch
|
| import torch.nn as nn
|
| from torch import Tensor
|
| import deepspeed
|
| from deepspeed import comm as dist
|
| from deepspeed.utils import groups, log_dist
|
| from deepspeed.utils.timer import SynchronizedWallClockTimer
|
| from deepspeed.moe.sharded_moe import FIRST_ALLTOALL_TIMER, MOE_TIMER, SECOND_ALLTOALL_TIMER, _AllToAll, einsum, gumbel_rsample
|
| from transformers.activations import ACT2FN
|
|
|
def compress_matrix(A: torch.Tensor, mask: torch.Tensor, force_dim: Optional[int] = None, allow_larger_dim=None) -> torch.Tensor:
    """Pack, per column, the rows of ``A`` selected by ``mask`` into the leading rows.

    Rows are reordered column-by-column with a descending ``argsort`` of ``mask``,
    so the rows where ``mask[:, e] == 1`` come first in column ``e``. The first
    ``X`` reordered rows are kept and every entry past column ``e``'s own count of
    ones is zeroed. ``decompress_matrix`` recomputes the same ``argsort`` to undo
    this packing.

    Args:
        A: Tensor of shape ``(S, E, *trailing)``.
        mask: 0/1 tensor of shape ``(S, E)``.
        force_dim: Number of output rows ``X``. Defaults to the largest per-column
            count of ones. May be a Python int or an integer 0-d tensor.
        allow_larger_dim: Behaviour when ``X > S``: truthy zero-pads silently,
            ``None`` zero-pads and prints a warning, any other falsy value raises.

    Returns:
        Tensor of shape ``(X, E, *trailing)`` with ``A``'s dtype and device.

    Raises:
        ValueError: If ``mask`` is not 2-D, does not match ``A.shape[:2]``, holds
            values other than 0 and 1, or ``force_dim`` is negative.
        AssertionError: If ``X > S`` and padding is disallowed.
    """
    # Check the rank first so a non-2-D mask reports the accurate problem.
    if mask.ndim != 2:
        raise ValueError("mask must be a 2D tensor.")
    if A.shape[:2] != mask.shape:
        raise ValueError("First two dimensions of A and mask must match.")
    invalid = (mask != 0) & (mask != 1)
    if invalid.any():
        raise ValueError(
            f"mask must only contain 0s and 1s. dtype: {mask.dtype}. "
            f"Invalid elements found at indices: {invalid.nonzero().tolist()} "
            f"with corresponding values: {mask[invalid].tolist()}. "
            f"\nOriginal mask (showing up to first 20 elements if large):\n{mask.flatten()[:20]}{'...' if mask.numel() > 20 else ''}"
        )

    S, E = mask.shape
    trailing_dims_shape = A.shape[2:]
    num_trailing_dims = len(trailing_dims_shape)
    device = A.device

    ones_per_column = mask.sum(dim=0)
    # Normalise X to a Python int: a float mask would otherwise yield a float
    # (unusable as a slice bound), and callers may pass a 0-d tensor as force_dim.
    if force_dim is not None:
        X = int(force_dim)
    elif ones_per_column.numel() > 0:
        X = int(ones_per_column.max().item())
    else:
        X = 0
    if X < 0:
        raise ValueError(f"force_dim must be non-negative, got {X}.")

    if X == 0:
        return torch.empty((0, E, *trailing_dims_shape), dtype=A.dtype, device=device)

    # Descending sort of the 0/1 mask moves selected rows to the top of each column.
    sorted_row_indices_2d = torch.argsort(mask.float(), dim=0, descending=True)
    view_shape_for_indices = (S, E, *((1,) * num_trailing_dims))
    expanded_indices = sorted_row_indices_2d.view(view_shape_for_indices).expand_as(A)
    A_gathered = torch.gather(A, 0, expanded_indices)

    if X <= S:
        B_candidate = A_gathered[:X, ...]
    elif allow_larger_dim or allow_larger_dim is None:
        if allow_larger_dim is None:
            print(f"[Warning compress_matrix] Target dimension X ({X}) is larger than "
                  f"A's original row count S ({S}). Padding B_candidate with zeros.")
        padding = A_gathered.new_zeros((X - S, *A_gathered.shape[1:]))
        B_candidate = torch.cat((A_gathered, padding), dim=0)
    else:
        raise AssertionError(
            f"Target dimension X ({X}) is larger than A's original row count S ({S}) "
            f"and allow_larger_dim is False. Padding is disallowed."
        )

    # Row r of column e is valid only if r < (number of ones in column e).
    row_indices_for_B = torch.arange(X, device=device).unsqueeze(1)
    b_mask_2d = row_indices_for_B < ones_per_column.unsqueeze(0)
    b_mask = b_mask_2d.view(X, E, *((1,) * num_trailing_dims))
    # torch.where rather than multiplication: inf/NaN values gathered from
    # unselected rows must not leak through as 0 * inf = NaN.
    return torch.where(b_mask, B_candidate, B_candidate.new_zeros(()))
|
|
|
|
|
def decompress_matrix(B: torch.Tensor, mask: torch.Tensor, allow_larger_dim=None) -> torch.Tensor:
    """Scatter a column-packed tensor back to its original row positions.

    Inverse of ``compress_matrix``: the same descending ``argsort`` of ``mask`` is
    recomputed, and row ``r`` of ``B`` in column ``e`` is written to the row that
    was ``r``-th in that ordering. Positions where ``mask == 0`` are zero in the
    result, even if ``B`` carries non-zero values in the corresponding padding rows.

    Args:
        B: Packed tensor of shape ``(X, E, *trailing)``.
        mask: 0/1 tensor of shape ``(S, E)``.
        allow_larger_dim: Behaviour when ``X > S``: truthy truncates ``B`` to its
            first ``S`` rows, ``None`` truncates and prints a warning, any other
            falsy value raises.

    Returns:
        Tensor of shape ``(S, E, *trailing)`` with ``B``'s dtype and device.

    Raises:
        ValueError: If ``mask`` is not 2-D, ``B`` is not at least 2-D with a second
            dimension equal to ``mask``'s, or ``mask`` holds values other than 0/1.
        AssertionError: If ``X > S`` and truncation is disallowed.
    """
    # Validate rank before indexing shapes so malformed inputs get a clear error.
    if mask.ndim != 2:
        raise ValueError("mask must be a 2D tensor.")
    if B.ndim < 2 or B.shape[1] != mask.shape[1]:
        raise ValueError("B's second dimension and mask's second dimension (E) must match.")
    if not ((mask == 0) | (mask == 1)).all():
        raise ValueError("mask must only contain 0s and 1s.")

    S, E = mask.shape
    X = B.shape[0]
    trailing_dims_shape = B.shape[2:]
    num_trailing_dims = len(trailing_dims_shape)
    device = B.device

    if X == 0:
        return torch.zeros((S, E, *trailing_dims_shape), dtype=B.dtype, device=device)

    if X > S:
        if not (allow_larger_dim or allow_larger_dim is None):
            raise AssertionError(
                f"Input B.shape[0] ({X}) is larger than target A's row count S ({S}) "
                f"and allow_larger_dim is False. Truncation is disallowed."
            )
        if allow_larger_dim is None:
            print(f"[Warning decompress_matrix] Input B.shape[0] ({X}) is larger than "
                  f"target A's row count S ({S}). Truncating B to its first {S} rows.")
        B = B[:S, ...]
        X = S

    # Same ordering compress_matrix used: selected rows first within each column.
    sorted_row_indices_2d = torch.argsort(mask.float(), dim=0, descending=True)
    target_A_row_indices_2d = sorted_row_indices_2d[:X, :]
    expanded_target_indices = target_A_row_indices_2d.view(X, E, *((1,) * num_trailing_dims)).expand_as(B)

    A_reconstructed = torch.zeros((S, E, *trailing_dims_shape), dtype=B.dtype, device=device)
    A_reconstructed.scatter_(dim=0, index=expanded_target_indices, src=B)

    # Rows of B beyond a column's own count of ones land on mask == 0 positions;
    # clear them so the result only holds values for selected (row, column) pairs.
    keep = mask.bool().reshape(S, E, *((1,) * num_trailing_dims))
    return torch.where(keep, A_reconstructed, A_reconstructed.new_zeros(()))
|
|
|
|
|
|
|
class AudioSharedExpertMLP(nn.Module):
    """Gated feed-forward expert; used as the fixed (shared) expert in this file.

    Output is ``down_proj(act(gate_proj(x)) * up_proj(x))``. All three projections
    are bias-free, with inner width ``config.shared_intermediate_size``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        inner = config.shared_intermediate_size
        self.hidden_size = width
        self.intermediate_size = inner
        # Creation order (gate, up, down) is kept so parameter registration matches.
        self.gate_proj = nn.Linear(width, inner, bias=False)
        self.up_proj = nn.Linear(width, inner, bias=False)
        self.down_proj = nn.Linear(inner, width, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        activated_gate = self.act_fn(self.gate_proj(hidden_state))
        linear_branch = self.up_proj(hidden_state)
        return self.down_proj(activated_gate * linear_branch)
|
|
|
|
|
class AudioDynamicExpertMLP(nn.Module):
    """Gated feed-forward expert; the template for the routed (dynamic) experts.

    Output is ``down_proj(act(gate_proj(x)) * up_proj(x))``. All three projections
    are bias-free, with inner width ``config.dynamic_intermediate_size``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        inner = config.dynamic_intermediate_size
        self.hidden_size = width
        self.intermediate_size = inner
        # Creation order (gate, up, down) is kept so parameter registration matches.
        self.gate_proj = nn.Linear(width, inner, bias=False)
        self.up_proj = nn.Linear(width, inner, bias=False)
        self.down_proj = nn.Linear(inner, width, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        activated_gate = self.act_fn(self.gate_proj(hidden_state))
        linear_branch = self.up_proj(hidden_state)
        return self.down_proj(activated_gate * linear_branch)
|
|
|
|
|
class AudioNullExpertMLP(nn.Module):
    """Expert that contributes nothing: its output is an all-zero tensor.

    ``config`` is accepted only for constructor parity with the other experts; it
    is not read, and the module has no parameters.
    """

    def __init__(self, config):
        super().__init__()

    def forward(self, hidden_state):
        # zeros_like already inherits shape, dtype and device from its argument.
        return torch.zeros_like(hidden_state)
|
|
|
|
|
def audio_sparse_expert_mixer(scores, top_k, jitter_eps, training):
    """Greedily pick ``top_k`` experts per row of ``scores`` and a gate value for each.

    Each step selects the best not-yet-chosen expert. Its gate value is its
    softmax probability among the not-yet-chosen experts whose original score is
    within a relative margin of the best one:
    ``(best - score) / max(|score|, |best|) <= 2 * jitter_eps``.
    Experts outside the margin are excluded from that softmax. The chosen expert
    is then set to ``-inf`` before the next step.

    Note: ``training`` is accepted but never read, and no noise is sampled here.

    Returns:
        ``(multiplier, selected_experts)``, both of shape ``(*scores.shape[:-1], top_k)``.
        ``selected_experts`` holds int64 expert indices in selection order.
    """
    remaining = scores
    gate_values = []
    chosen = []
    for _ in range(top_k):
        with torch.no_grad():
            best_score, best_idx = remaining.max(dim=-1, keepdim=True)
            denom = scores.abs().clamp(min=best_score.abs())
            outside_margin = ((best_score - scores) / denom) > (2 * jitter_eps)
        probs = torch.softmax(remaining.masked_fill(outside_margin, float("-inf")), dim=-1)
        gate_values.append(probs.gather(dim=-1, index=best_idx))
        chosen.append(best_idx)
        # Remove the chosen expert from further consideration.
        remaining = torch.scatter(remaining, -1, best_idx, float("-inf"))
    return torch.concat(gate_values, dim=-1), torch.concat(chosen, dim=-1)
|
|
|
|
|
def audio_dynamic_expert_selection(logits, top_p):
    """Return how many experts each row needs to reach cumulative probability ``top_p``.

    The logits are softmaxed over the last dimension and sorted in descending
    order. The count is the number of sorted positions whose running sum is still
    below ``top_p``, plus one, i.e. the smallest ``k`` whose top-``k`` mass reaches
    ``top_p``. The result is clamped to the number of experts. Floating-point
    cumsums may never reach ``top_p`` (e.g. ``top_p >= 1``), which would otherwise
    yield ``num_experts + 1``, a value no caller can route.

    Args:
        logits: Router logits with experts along the last dimension.
        top_p: Cumulative-probability threshold.

    Returns:
        Integer tensor of shape ``logits.shape[:-1]`` with values at most
        ``logits.shape[-1]``.
    """
    num_experts = logits.shape[-1]
    probs = torch.softmax(logits, dim=-1)
    probs_sorted, _ = torch.sort(probs, dim=-1, descending=True)
    cumulative = probs_sorted.cumsum(dim=-1)
    # `~(>=)` rather than `<` keeps the original treatment of NaN entries.
    below_threshold = (~(cumulative >= top_p)).sum(dim=-1)
    return (below_threshold + 1).clamp(max=num_experts)
|
|
|
|
|
| def _audio_expert_capacity(num_tokens, num_experts, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor:
|
| """Calculate expert capacity for UniMoE-Audio based on token distribution and capacity factor."""
|
| capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64)
|
| if capacity < min_capacity:
|
| capacity = min_capacity.to(torch.int64)
|
| return capacity
|
|
|
|
|
def calculate_audio_global_routing_weight(
    expert_mask: torch.Tensor,
    full_router_logits: torch.Tensor,
    mlp_dynamic_expert_num: int,
    routing_weights: torch.Tensor,
):
    """Combine dynamic routing weights with fixed-expert weights into one row per token.

    A softmax is taken over ``full_router_logits`` with unselected experts
    (``expert_mask == 0``) set to ``-inf``. The mass on the first
    ``mlp_dynamic_expert_num`` columns is summed and used to scale
    ``routing_weights``. The remaining (fixed-expert) probabilities are appended
    unchanged. A row with no selected expert produces NaN (softmax of all ``-inf``).

    Returns:
        ``cat([routing_weights * dynamic_mass, fixed_probs], dim=-1)``.
    """
    masked_logits = full_router_logits.masked_fill(expert_mask == 0, float("-inf"))
    probs = torch.softmax(masked_logits, dim=-1)
    dynamic_mass = probs[:, :mlp_dynamic_expert_num].sum(dim=-1, keepdim=True)
    fixed_probs = probs[:, mlp_dynamic_expert_num:]
    return torch.cat((routing_weights * dynamic_mass, fixed_probs), dim=-1)
|
|
|
|
|
class UniMoEAudioSparseMoeBlock(nn.Module):
    """Sparse MoE block mixing dynamically routed experts with always-on fixed experts.

    Column layout of the router output (``self.gate``):
      * ``[0, mlp_dynamic_real_expert_num)``: real dynamic experts, executed by
        ``self.dynamic_real_moe``;
      * ``[mlp_dynamic_real_expert_num, mlp_dynamic_expert_num)``: "null" dynamic
        experts. They can be selected and take routing mass, but no module is run
        for them.
      * ``[mlp_dynamic_expert_num, num_experts)``: fixed experts
        (``self.fixed_real_moe``), applied to every token.

    Each token uses a variable number of dynamic experts. It is the smallest set
    whose softmax mass reaches ``mlp_dynamic_top_p``, or exactly
    ``mlp_dynamic_top_k`` when ``mlp_dynamic_top_p == 0``, capped at the number of
    dynamic experts.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.mlp_dynamic_expert_num = config.mlp_dynamic_expert_num + config.mlp_dynamic_null_expert_num
        self.mlp_dynamic_real_expert_num = config.mlp_dynamic_expert_num
        self.mlp_dynamic_null_expert_num = config.mlp_dynamic_null_expert_num
        self.mlp_dynamic_top_p = config.mlp_dynamic_top_p
        self.mlp_dynamic_top_k = config.mlp_dynamic_top_k
        self.mlp_fixed_expert_num = config.mlp_fixed_expert_num
        self.num_experts = self.mlp_dynamic_expert_num + self.mlp_fixed_expert_num

        if self.mlp_dynamic_top_p == 0:
            print(f"mlp_dynamic_top_p is 0, will use mlp_dynamic_top_k={self.mlp_dynamic_top_k} instead !!!")

        self.ignore_differentiable_router = config.ignore_differentiable_router
        if self.ignore_differentiable_router:
            print("ignore_differentiable_router is True, will not use router_logits !!!")

        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
        self.fixed_real_moe = nn.ModuleList([AudioSharedExpertMLP(config) for _ in range(self.mlp_fixed_expert_num)])
        self.dynamic_real_moe = UniMoEAudioMoE(config, AudioDynamicExpertMLP(config), self.mlp_dynamic_real_expert_num, config.ep_size)

        self.router_jitter_noise = config.router_jitter_noise
        self.input_jitter_noise = config.input_jitter_noise

        self.min_capacity = config.min_capacity
        self.capacity_factor = config.capacity_factor
        self.token_drop = config.token_drop
        self.drop_policy = config.drop_policy

        self.avg_hidden_states_last = config.avg_hidden_states_last
        self.drop_token_num_print = config.drop_token_num_print
        self.fp32_gate = config.fp32_gate

    def _select_dynamic_experts(self, dynamic_router_logits, num_tokens, dtype, device):
        """Choose dynamic experts per token; return ``(dynamic_top_k, expert_mask, routing_weights)``."""
        if self.mlp_dynamic_top_p != 0:
            dynamic_top_k = audio_dynamic_expert_selection(dynamic_router_logits, self.mlp_dynamic_top_p)
        else:
            dynamic_top_k = torch.full((dynamic_router_logits.shape[0],), self.mlp_dynamic_top_k, dtype=torch.int, device=dynamic_router_logits.device)
        # A k above the expert count matches no group in the loop below and would
        # leave the token without any dynamic expert.
        dynamic_top_k = dynamic_top_k.clamp(max=self.mlp_dynamic_expert_num)

        expert_mask = torch.zeros((num_tokens, self.num_experts), dtype=torch.int, device=device)
        routing_weights = torch.zeros((num_tokens, self.mlp_dynamic_expert_num), dtype=dtype, device=device)
        # Group tokens by their k so the greedy top-k mixer runs once per distinct k.
        for top_k in range(1, self.mlp_dynamic_expert_num + 1):
            group_idx = torch.nonzero(dynamic_top_k == top_k, as_tuple=True)[0]
            if len(group_idx) == 0:
                continue

            group_routing_weights, group_selected_experts = audio_sparse_expert_mixer(
                dynamic_router_logits[group_idx],
                top_k=top_k,
                jitter_eps=self.router_jitter_noise,
                training=self.training and not self.ignore_differentiable_router,
            )

            group_expert_mask = torch.nn.functional.one_hot(group_selected_experts, num_classes=self.num_experts).sum(dim=1)
            group_weight = torch.zeros((len(group_idx), self.mlp_dynamic_expert_num), dtype=dtype, device=device)
            group_weight.scatter_(dim=-1, index=group_selected_experts, src=group_routing_weights)
            routing_weights.index_add_(0, group_idx, group_weight)
            expert_mask.index_add_(0, group_idx, group_expert_mask.to(expert_mask.dtype))

        routing_weights = routing_weights / (routing_weights.sum(dim=-1, keepdim=True) + 1e-6)
        return dynamic_top_k, expert_mask, routing_weights

    def _apply_token_drop(self, expert_mask, dynamic_router_logits, num_tokens):
        """Enforce per-expert capacity on the dynamic columns; fixed experts are never dropped."""
        expert_mask_dtype = expert_mask.dtype
        # Plain int: used as topk's k and compared against Python ints below.
        capacity = int(_audio_expert_capacity(num_tokens, self.mlp_dynamic_expert_num, torch.tensor(self.capacity_factor), torch.tensor(self.min_capacity)))

        if self.drop_policy == "probs":
            if capacity > dynamic_router_logits.shape[0]:
                print(f"[warning] token capacity({capacity}) > token num({dynamic_router_logits.shape[0]}), setting capacity=token num")
                capacity = dynamic_router_logits.shape[0]
            dynamic_expert_mask = expert_mask[:, : self.mlp_dynamic_expert_num].bool()
            # Keep, per expert, the `capacity` routed tokens with the highest logits.
            token_drop_router_logits = torch.masked_fill(dynamic_router_logits, ~dynamic_expert_mask, torch.finfo(dynamic_router_logits.dtype).min)
            _, capacity_indices = torch.topk(token_drop_router_logits, k=capacity, dim=0, sorted=False)
            capacity_mask = torch.zeros_like(expert_mask).scatter(0, capacity_indices, 1)
            capacity_mask[:, self.mlp_dynamic_expert_num :] = 1
            expert_mask = torch.logical_and(expert_mask, capacity_mask)

            ori_token_num = dynamic_expert_mask.sum().item()
            cur_token_num = expert_mask[:, : self.mlp_dynamic_expert_num].sum().item()
            if self.drop_token_num_print and ("RANK" not in os.environ or int(os.environ["RANK"]) == 0):
                print(f"drop {ori_token_num - cur_token_num} tokens from total {ori_token_num} tokens")
        elif self.drop_policy == "position":
            # Keep the first `capacity` tokens (in token order) per dynamic expert.
            # Fixed columns are excluded: capping them would drop shared experts and
            # could leave a token with no selected expert at all (NaN after the
            # -inf-masked softmax in the global weight).
            dynamic_mask = expert_mask[:, : self.mlp_dynamic_expert_num]
            locations = torch.cumsum(dynamic_mask, dim=0) - 1
            expert_mask[:, : self.mlp_dynamic_expert_num] = dynamic_mask * torch.lt(locations, capacity)
        else:
            raise ValueError(f"Invalid drop_policy: {self.drop_policy}")
        return expert_mask.to(expert_mask_dtype)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, aux_balance_weight: torch.Tensor = None):
        """Route tokens, run the experts and combine their outputs.

        Args:
            hidden_states: ``(batch, seq, hidden)`` input.
            attention_mask: Optional. It is cast to the expert-mask dtype, flattened to
                one value per token and multiplied into the expert mask before the
                fixed-expert columns are forced to 1 (presumably 1 = real token,
                0 = padding; confirm against callers).
            aux_balance_weight: Optional ``(batch, seq)`` per-token weights for the
                load-balancing loss.

        Returns:
            ``(final_hidden_states, full_router_logits, dynamic_top_k, expert_mask,
            global_weight, aux_loss)``. ``final_hidden_states`` has the input's shape.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        num_tokens = batch_size * sequence_length
        # Experts consume the unmodified input; fp32 casting and jitter only affect routing.
        original_hidden_states = hidden_states

        if self.training and self.fp32_gate:
            hidden_states = hidden_states.float()

        if self.training and self.input_jitter_noise > 0:
            # Out-of-place on purpose: when no dtype cast happened above,
            # `hidden_states` aliases the caller's tensor (and `original_hidden_states`),
            # so an in-place `*=` would mutate the input and leak jitter into the experts.
            noise = torch.empty_like(hidden_states).uniform_(1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise)
            hidden_states = hidden_states * noise

        hidden_states = hidden_states.reshape(-1, hidden_dim)

        if self.training and self.fp32_gate:
            full_router_logits = torch.nn.functional.linear(hidden_states, weight=self.gate.weight.float(), bias=None)
        else:
            full_router_logits = self.gate(hidden_states)
        dynamic_router_logits = full_router_logits[:, : self.mlp_dynamic_expert_num]

        dynamic_top_k, expert_mask, routing_weights = self._select_dynamic_experts(dynamic_router_logits, num_tokens, hidden_states.dtype, hidden_states.device)

        if attention_mask is not None:
            token_mask = attention_mask.to(expert_mask.dtype).reshape(-1).unsqueeze(-1)
            expert_mask = expert_mask * token_mask

        if self.mlp_dynamic_expert_num < self.num_experts:
            # Fixed experts are selected for every token.
            expert_mask[:, self.mlp_dynamic_expert_num :] = 1

        aux_loss = audio_load_balancing_loss_func(
            expert_mask=expert_mask,
            mlp_dynamic_expert_num=self.mlp_dynamic_expert_num,
            global_weight=None,
            full_router_logits=full_router_logits,
            routing_weights=routing_weights,
            aux_balance_weight=aux_balance_weight,
        )

        if self.token_drop:
            expert_mask = self._apply_token_drop(expert_mask, dynamic_router_logits, num_tokens)

        # Renormalise dynamic weights over the experts that survived masking/dropping.
        routing_weights = routing_weights.masked_fill(~(expert_mask[:, : self.mlp_dynamic_expert_num].bool()), 0.0)
        routing_weights = routing_weights / (routing_weights.sum(dim=-1, keepdim=True) + 1e-6)

        if self.mlp_dynamic_expert_num < self.num_experts:
            global_weight = calculate_audio_global_routing_weight(expert_mask, full_router_logits, self.mlp_dynamic_expert_num, routing_weights)
        else:
            global_weight = routing_weights

        hidden_states = original_hidden_states.reshape(-1, hidden_dim)
        global_weight = global_weight.to(hidden_states.dtype)

        final_hidden_states = torch.zeros((num_tokens, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device)
        final_hidden_states = final_hidden_states + self.dynamic_real_moe(
            hidden_states,
            expert_mask=expert_mask[:, : self.mlp_dynamic_real_expert_num],
            router_weight=global_weight[:, : self.mlp_dynamic_real_expert_num],
        )

        for expert_idx, expert_layer in enumerate(self.fixed_real_moe):
            fixed_weight = global_weight[:, self.mlp_dynamic_expert_num + expert_idx].unsqueeze(-1)
            final_hidden_states = final_hidden_states + expert_layer(hidden_states) * fixed_weight

        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)

        if not self.training and self.avg_hidden_states_last:
            dist.all_reduce(final_hidden_states, op=dist.ReduceOp.AVG, group=self.dynamic_real_moe.deepspeed_moe.ep_group)

        return final_hidden_states, full_router_logits, dynamic_top_k, expert_mask, global_weight, aux_loss
|
|
|
|
|
def audio_load_balancing_loss_func(
    expert_mask: torch.Tensor,
    mlp_dynamic_expert_num: int,
    global_weight: Optional[torch.Tensor] = None,
    full_router_logits: Optional[torch.Tensor] = None,
    routing_weights: Optional[torch.Tensor] = None,
    aux_balance_weight: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Load-balancing loss over the dynamic experts.

    The loss is ``num_experts * sum_e(fraction_of_tokens_e * mean_router_prob_e)``.
    ``num_experts`` is ``mlp_dynamic_expert_num``. Router probabilities are a
    softmax over the dynamic logits with unselected experts filled with the
    dtype's finite minimum.

    Args:
        expert_mask: ``(tokens, experts)`` 0/1 selection mask; only the first
            ``mlp_dynamic_expert_num`` columns are used.
        mlp_dynamic_expert_num: Number of leading (dynamic) expert columns.
        global_weight: Ignored; recomputed from ``full_router_logits``. Kept for
            call-site compatibility.
        full_router_logits: ``(tokens, experts)`` router logits. Required.
        routing_weights: Ignored; kept for call-site compatibility.
        aux_balance_weight: Optional ``(batch, seq)`` per-token weights. When given,
            the token and probability averages become weighted averages.

    Returns:
        A scalar tensor.

    Raises:
        ValueError: If ``full_router_logits`` is ``None``.
    """
    if full_router_logits is None:
        raise ValueError("full_router_logits is required to compute the load-balancing loss.")

    min_dtype = torch.finfo(full_router_logits.dtype).min
    dynamic_mask = expert_mask[:, :mlp_dynamic_expert_num]
    # Finite minimum (not -inf): a row with no selected dynamic expert becomes a
    # uniform distribution instead of NaN.
    dynamic_logits = full_router_logits[:, :mlp_dynamic_expert_num].masked_fill(dynamic_mask == 0, min_dtype)
    router_probs = torch.softmax(dynamic_logits, dim=-1)

    num_experts = dynamic_mask.shape[-1]
    if aux_balance_weight is None:
        tokens_per_expert = torch.mean(dynamic_mask.float(), dim=0)
        router_prob_per_expert = torch.mean(router_probs, dim=0)
    else:
        batch_size, sequence_length = aux_balance_weight.shape
        # Tokens may be stacked from several layers; infer how many copies there are.
        num_hidden_layers = router_probs.shape[0] // (batch_size * sequence_length)
        token_weights = (
            aux_balance_weight[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(router_probs.device)
        )
        weight_total = torch.sum(token_weights, dim=0)
        tokens_per_expert = torch.sum(dynamic_mask.float() * token_weights, dim=0) / weight_total
        router_prob_per_expert = torch.sum(router_probs * token_weights, dim=0) / weight_total

    return torch.sum(tokens_per_expert * router_prob_per_expert) * num_experts
|
|
|
|
|
class AudioExperts(deepspeed.moe.experts.Experts):
    """DeepSpeed ``Experts`` container holding ``num_local_experts`` deep copies of ``expert``.

    ``Experts.__init__`` is deliberately skipped (``super`` is resolved past it),
    so only the attributes set here exist. Every expert parameter is tagged with
    ``allreduce = False`` and ``group_name = expert_group_name``.
    """

    def __init__(self, expert, num_local_experts=1, expert_group_name=None):
        super(deepspeed.moe.experts.Experts, self).__init__()

        self.deepspeed_experts = torch.nn.ModuleList(copy.deepcopy(expert) for _ in range(num_local_experts))
        self.num_local_experts = num_local_experts

        for local_expert in self.deepspeed_experts:
            for param in local_expert.parameters():
                param.allreduce = False
                param.group_name = expert_group_name

    def forward(self, inputs):
        """Split ``inputs`` along dim 1 into one chunk per local expert, run each, re-join on dim 1.

        An output whose type is exactly ``tuple`` is reduced to its first element.
        """
        pieces = inputs.chunk(self.num_local_experts, dim=1)
        results = []
        for piece, local_expert in zip(pieces, self.deepspeed_experts):
            produced = local_expert(piece)
            results.append(produced[0] if type(produced) is tuple else produced)
        return torch.cat(results, dim=1)
|
|
|
|
|
class AudioMOELayer(deepspeed.moe.sharded_moe.MOELayer):
    """DeepSpeed ``MOELayer`` variant that dispatches tokens via ``compress_matrix``.

    ``MOELayer.__init__`` is skipped (``super`` is resolved past it), so no gate is
    built. Routing decisions arrive as ``expert_mask``/``router_weight`` arguments
    to ``forward``. ``use_tutel`` is accepted for signature compatibility and ignored.
    """

    def __init__(
        self,
        experts: nn.Module,
        ep_group_name,
        ep_size,
        num_local_experts: int,
        use_tutel: bool = False,
    ) -> None:
        super(deepspeed.moe.sharded_moe.MOELayer, self).__init__()

        self.experts = experts
        self.ep_group = None
        self.ep_size = ep_size
        self.ep_group_name = ep_group_name
        self.num_local_experts = num_local_experts
        # Timing results (only updated when wall_clock_breakdown is enabled).
        self.time_falltoall = 0.0
        self.time_salltoall = 0.0
        self.time_moe = 0.0
        self.timers = SynchronizedWallClockTimer()
        self.wall_clock_breakdown = False

    def _set_ep_group(self, ep_group):
        self.ep_group = ep_group

    def forward(self, hidden_states: Tensor, expert_mask: Tensor, router_weight: Tensor) -> Tensor:
        """Dispatch tokens to experts across the EP group and combine the results.

        Args:
            hidden_states: ``(num_tokens, d_model)`` tokens.
            expert_mask: ``(num_tokens, num_experts)`` 0/1 routing mask.
            router_weight: ``(num_tokens, num_experts)`` combine weights; entries
                where ``expert_mask`` is 0 are zeroed here.

        Returns:
            ``(num_tokens, d_model)`` weighted sum of expert outputs per token.
        """
        timing = self.wall_clock_breakdown
        gate_weight = router_weight * expert_mask

        if timing:
            self.timers(MOE_TIMER).start()

        model_dim = hidden_states.shape[-1]
        num_tokens = hidden_states.shape[0]
        num_experts = expert_mask.shape[-1]

        # Packed rows per expert = busiest expert's load, maxed over the EP group so
        # every rank uses the same packed size.
        slots = expert_mask.sum(dim=0).max()
        if self.ep_group is not None:
            dist.all_reduce(slots, op=dist.ReduceOp.MAX, group=self.ep_group)

        per_expert_tokens = hidden_states.unsqueeze(1).expand(num_tokens, num_experts, model_dim)
        packed_tokens = compress_matrix(per_expert_tokens, expert_mask, force_dim=slots, allow_larger_dim=True)
        packed_mask = compress_matrix(expert_mask, expert_mask, force_dim=slots, allow_larger_dim=True)
        outbound = einsum("ce,cem->ecm", packed_mask, packed_tokens)

        if timing:
            self.timers(FIRST_ALLTOALL_TIMER).start()
        outbound = _AllToAll.apply(self.ep_group, outbound)
        if timing:
            self.timers(FIRST_ALLTOALL_TIMER).stop()
            self.time_falltoall = self.timers(FIRST_ALLTOALL_TIMER).elapsed(reset=False)

        local_results = self.experts(outbound.reshape(self.ep_size, self.num_local_experts, -1, model_dim))

        if timing:
            self.timers(SECOND_ALLTOALL_TIMER).start()
        inbound = _AllToAll.apply(self.ep_group, local_results)
        if timing:
            self.timers(SECOND_ALLTOALL_TIMER).stop()
            self.time_salltoall = self.timers(SECOND_ALLTOALL_TIMER).elapsed(reset=False)

        inbound = inbound.reshape(self.ep_size * self.num_local_experts, -1, model_dim)
        unpacked = decompress_matrix(inbound.transpose(0, 1), expert_mask, allow_larger_dim=True)
        combined_output = einsum("se,sem->sm", gate_weight, unpacked)

        if timing:
            self.timers(MOE_TIMER).stop()
            self.time_moe = self.timers(MOE_TIMER).elapsed(reset=False)

        return combined_output
|
|
|
|
|
class UniMoEAudioMoE(deepspeed.moe.layer.MoE):
    """DeepSpeed ``MoE`` wrapper that wires ``AudioExperts`` into an ``AudioMOELayer``.

    ``MoE.__init__`` is skipped (``super`` is resolved past it); only the attributes
    set here exist. The ``num_experts`` experts are split evenly over ``ep_size``
    expert-parallel ranks.

    Raises:
        ValueError: If ``num_experts`` is not a multiple of a positive ``ep_size``.
    """

    def __init__(self, config, expert, num_experts, ep_size, moe_name_prefix="ep_size"):
        super(deepspeed.moe.layer.MoE, self).__init__()
        # With a remainder, floor division below would silently build fewer experts
        # than the router addresses, and AudioMOELayer's reshape would not line up.
        if ep_size <= 0 or num_experts % ep_size != 0:
            raise ValueError(
                f"Number of experts ({num_experts}) must be divisible by the expert parallel size ({ep_size})."
            )
        self.enable_expert_tensor_parallelism = config.enable_expert_tensor_parallelism
        self.ep_size = ep_size
        self.num_experts = num_experts
        self.expert_group_name = f"{moe_name_prefix}_{self.ep_size}"
        self.num_local_experts = self.num_experts // self.ep_size
        log_dist(f"Creating MoE layer with num_experts: {self.num_experts} | num_local_experts: {self.num_local_experts} | expert_parallel_size: {self.ep_size}", [0])
        experts = AudioExperts(expert, self.num_local_experts, self.expert_group_name)
        self.deepspeed_moe = AudioMOELayer(experts, self.expert_group_name, self.ep_size, self.num_local_experts)

    def set_deepspeed_parallelism(self, use_data_before_expert_parallel_=False):
        """Create (or reuse) the expert-parallel process groups for this layer."""
        self._create_process_groups(use_data_before_expert_parallel_=use_data_before_expert_parallel_)

    def _create_process_groups(self, use_data_before_expert_parallel_=False):
        """Create the named expert-parallel group if missing and hand it to the MoE layer."""
        if self.expert_group_name not in groups._get_expert_parallel_group_dict():
            print(f"No existing process group found, creating a new group named: {self.expert_group_name}")
            if (groups.mpu is None) or (not self.enable_expert_tensor_parallelism):
                groups._create_expert_and_data_parallel(self.ep_size, use_data_before_expert_parallel_=use_data_before_expert_parallel_)
            else:
                groups._create_expert_data_and_model_parallel(self.ep_size, mpu=groups.mpu, use_data_before_expert_parallel_=use_data_before_expert_parallel_)
        self.deepspeed_moe._set_ep_group(groups._get_expert_parallel_group(self.expert_group_name))

    def forward(self, *input_args, **input_kwargs):
        """Delegate to ``AudioMOELayer.forward``."""
        return self.deepspeed_moe(*input_args, **input_kwargs)
|
|
|