import torch
import torch.nn as nn
import torch.nn.functional as F


class Expert(nn.Module):
    """Memory-efficient feed-forward network expert with stable initialization."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.act = nn.GELU()
        # Conservative Xavier initialization to prevent NaNs early in training.
        nn.init.xavier_uniform_(self.w1.weight, gain=0.5)
        nn.init.xavier_uniform_(self.w2.weight, gain=0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        # Force float32 for internal computation to prevent overflow in half precision.
        x = x.to(torch.float32)
        # Cast weights to float32 as well, since the module weights may be stored in float16.
        w1_weight = self.w1.weight.to(torch.float32)
        w2_weight = self.w2.weight.to(torch.float32)
        h = F.linear(x, w1_weight)
        h = self.act(h)
        out = F.linear(h, w2_weight)
        # Clamp to avoid Inf when casting back to float16 (whose max finite value is ~65504).
        if orig_dtype == torch.float16:
            out = torch.clamp(out, min=-65500.0, max=65500.0)
        return out.to(orig_dtype)
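

# A minimal usage sketch, not part of the original source: the shapes (d_model=512,
# d_ff=2048, batch of 4 sequences of length 16) are illustrative, and the float16
# branch assumes a CUDA device is available.
if __name__ == "__main__":
    expert = Expert(d_model=512, d_ff=2048)
    x = torch.randn(4, 16, 512)

    # float32 path: output keeps the input dtype and shape.
    out = expert(x)
    print(out.dtype, tuple(out.shape))  # torch.float32 (4, 16, 512)

    # float16 path: the forward pass upcasts internally to float32 and clamps
    # before casting back, so the output stays finite in half precision.
    if torch.cuda.is_available():
        expert_fp16 = Expert(d_model=512, d_ff=2048).half().cuda()
        x_fp16 = torch.randn(4, 16, 512, device="cuda", dtype=torch.float16)
        out_fp16 = expert_fp16(x_fp16)
        print(out_fp16.dtype, torch.isfinite(out_fp16).all().item())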