File size: 1,318 Bytes
1df0e33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import contextlib

import torch
import torch.nn as nn
import torch.nn.functional as F

class Expert(nn.Module):
    """Feed-forward network expert: ``Linear -> GELU -> Linear`` without biases.

    The forward pass runs in at least float32 precision (float64 inputs stay in
    float64) and returns a tensor in the caller's dtype. Values that would
    overflow a narrower floating dtype (e.g. float16) are saturated to that
    dtype's largest finite value instead of becoming inf.

    Args:
        d_model: Size of the input and output feature dimension.
        d_ff: Size of the hidden feature dimension.
    """

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # d_model -> d_ff
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # d_ff -> d_model
        self.act = nn.GELU()

        # Scaled-down Xavier init (gain=0.5) keeps initial weights, and hence
        # activations, small to reduce the risk of overflow/NaN early on.
        nn.init.xavier_uniform_(self.w1.weight, gain=0.5)
        nn.init.xavier_uniform_(self.w2.weight, gain=0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the expert FFN to ``x``.

        Args:
            x: Input tensor whose last dimension is ``d_model``.

        Returns:
            Tensor with the same leading dimensions as ``x``, last dimension
            ``d_model``, and the same dtype as ``x``.
        """
        orig_dtype = x.dtype
        # Upcast half-precision inputs to float32 to avoid overflow, but never
        # downcast: float64 inputs keep their full precision.
        compute_dtype = torch.float64 if orig_dtype == torch.float64 else torch.float32

        # Under torch.autocast, F.linear would be re-cast to half precision no
        # matter what dtype its inputs have, so autocast is disabled locally.
        # Devices that autocast does not support cannot be inside an autocast
        # region for that device, so a no-op context suffices there.
        try:
            precision_ctx = torch.autocast(device_type=x.device.type, enabled=False)
        except RuntimeError:
            precision_ctx = contextlib.nullcontext()

        with precision_ctx:
            h = x.to(compute_dtype)
            # Parameters may be stored in half precision; cast them to match.
            # (.to() is a no-op when the weight already has compute_dtype.)
            h = F.linear(h, self.w1.weight.to(compute_dtype))
            h = self.act(h)
            out = F.linear(h, self.w2.weight.to(compute_dtype))

        # Saturate to the target dtype's finite range so the final downcast
        # cannot produce +/-inf (float16 max is 65504).
        if orig_dtype.is_floating_point and orig_dtype != compute_dtype:
            limit = torch.finfo(orig_dtype).max
            out = torch.clamp(out, min=-limit, max=limit)

        return out.to(orig_dtype)