import torch
import torch.nn as nn
import torch.nn.functional as F

class Expert(nn.Module):
    """Memory-efficient feed-forward network expert with numerically stable initialization."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.act = nn.GELU()
        # Conservative Xavier initialization (gain=0.5) keeps early activations
        # small, which helps prevent NaNs.
        nn.init.xavier_uniform_(self.w1.weight, gain=0.5)
        nn.init.xavier_uniform_(self.w2.weight, gain=0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        # Compute internally in float32 to prevent overflow in half precision.
        x = x.to(torch.float32)
        # The module weights may themselves be float16, so cast them to
        # float32 for the matmuls as well.
        w1_weight = self.w1.weight.to(torch.float32)
        w2_weight = self.w2.weight.to(torch.float32)
        h = F.linear(x, w1_weight)
        h = self.act(h)
        out = F.linear(h, w2_weight)
        # Clamp before casting back to avoid Inf: float16 max is ~65504,
        # so +/-65500 is a safe bound.
        if orig_dtype == torch.float16:
            out = torch.clamp(out, min=-65500.0, max=65500.0)
        return out.to(orig_dtype)
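

# Minimal usage sketch (illustration only: d_model=64, d_ff=256, and the fp16
# setup below are assumptions, not values from the Aetheris deployment).
# It shows that a half-precision input runs through float32 internals and
# comes back as finite float16, matching the clamping logic above.
if __name__ == "__main__":
    expert = Expert(d_model=64, d_ff=256).half()  # simulate fp16 deployment weights
    x = torch.randn(2, 10, 64, dtype=torch.float16)  # (batch, seq_len, d_model)
    y = expert(x)
    # The output keeps the input dtype, and clamping guarantees no Inf
    # after the cast back to float16.
    print(y.dtype, tuple(y.shape), torch.isfinite(y).all().item())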