import torch
import torch.nn as nn
import torch.nn.functional as F

from ..config import AetherisConfig


def selective_scan_native(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
                          B: torch.Tensor, C: torch.Tensor, D: torch.Tensor) -> torch.Tensor:
    """Run a sequential, pure-PyTorch selective state-space scan.

    The hidden state ``h`` of shape ``(batch, d_inner, d_state)`` is advanced one
    time step at a time, so no ``(batch, seq_len, d_inner, d_state)`` intermediate
    is materialised::

        h_t = exp(delta_t * A) * h_{t-1} + (delta_t * B_t) * u_t
        y_t = sum_over_state(h_t * C_t) + D * u_t

    Args:
        u: Input sequence, ``(batch, seq_len, d_inner)``.
        delta: Per-step, per-channel step sizes, ``(batch, seq_len, d_inner)``.
        A: State transition matrix, ``(batch, d_inner, d_state)``; an unbatched
            ``(d_inner, d_state)`` tensor also broadcasts correctly.
        B: Per-step input projection, ``(batch, seq_len, d_state)``.
        C: Per-step output projection, ``(batch, seq_len, d_state)``.
        D: Skip-connection weights, broadcastable against ``u`` (e.g. ``(d_inner,)``).

    Returns:
        Tensor of shape ``(batch, seq_len, d_inner)``.
    """
    batch, seq_len, d_inner = u.shape
    d_state = A.shape[-1]

    if seq_len == 0:
        # torch.stack([]) raises; an empty sequence only has the (empty) skip term.
        return u * D

    # h is rebound every step rather than modified in place.
    h = torch.zeros(batch, d_inner, d_state, device=u.device, dtype=u.dtype)
    ys = []
    for t in range(seq_len):
        dt = delta[:, t, :].unsqueeze(-1)            # (batch, d_inner, 1)
        dA = torch.exp(dt * A)                       # discretised state transition
        dB = dt * B[:, t, :].unsqueeze(1)            # (batch, d_inner, d_state)
        u_t = u[:, t, :].unsqueeze(-1)               # (batch, d_inner, 1)
        h = dA * h + dB * u_t
        C_t = C[:, t, :].unsqueeze(1)                # (batch, 1, d_state)
        ys.append(torch.sum(h * C_t, dim=-1))        # (batch, d_inner)

    y = torch.stack(ys, dim=1)                       # (batch, seq_len, d_inner)
    return y + u * D


class SSMBlock(nn.Module):
    """Pre-norm residual block built around a selective state-space scan.

    Pipeline: LayerNorm -> ``in_proj`` -> causal depthwise conv -> SiLU ->
    input-dependent ``delta``/``B``/``C`` -> ``selective_scan_native`` ->
    SiLU(``gate_proj``) gating -> ``out_proj`` -> residual add.
    """

    def __init__(self, config: AetherisConfig):
        super().__init__()
        self.d_model = config.d_model
        self.d_state = config.ssm_d_state
        self.d_inner = config.d_inner

        # NOTE(review): forward() uses only the first half of in_proj's output
        # (x_in); the second half (z_gate) is discarded and gating comes from
        # gate_proj instead. Shapes are left unchanged so existing checkpoints load.
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=False)
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=False)
        # Depthwise conv padded by 2 on both sides; forward() drops the last 2
        # outputs so position t only sees inputs t-2..t (causal).
        self.conv_d = nn.Conv1d(self.d_inner, self.d_inner, kernel_size=3, padding=2,
                                groups=self.d_inner, bias=False)
        self.gate_proj = nn.Linear(self.d_model, self.d_inner, bias=False)
        self.B_proj = nn.Linear(self.d_inner, self.d_state, bias=False)
        self.C_proj = nn.Linear(self.d_inner, self.d_state, bias=False)
        self.delta_proj = nn.Linear(self.d_inner, self.d_inner, bias=False)

        # A = -exp(A_log). With A_log ~ N(-4, 0.1^2), A starts near -exp(-4) ~= -0.018,
        # so the per-step decay factor exp(delta * A) begins close to 1 (slow decay).
        self.A_log = nn.Parameter(torch.randn(self.d_inner, self.d_state) * 0.1 - 4.0)
        self.D = nn.Parameter(torch.ones(self.d_inner) * 0.1)
        self.act = nn.SiLU()
        self.norm = nn.LayerNorm(config.d_model)

        # Small-gain Xavier init; order matches the original to keep RNG streams stable.
        for proj in (self.in_proj, self.out_proj, self.gate_proj,
                     self.B_proj, self.C_proj, self.delta_proj):
            nn.init.xavier_uniform_(proj.weight, gain=0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block with a residual connection.

        Args:
            x: Input of shape ``(batch, seq_len, d_model)``.

        Returns:
            Tensor of the same shape as ``x``: ``x + out_proj(gate * ssm(x))``.
        """
        batch = x.shape[0]
        x_norm = self.norm(x)

        # z_gate is unused (see the note in __init__).
        x_in, _z_gate = self.in_proj(x_norm).chunk(2, dim=-1)

        x_conv = self.conv_d(x_in.transpose(1, 2))
        # Drop the last 2 outputs produced by the right-side padding (future leakage).
        x_conv = self.act(x_conv[:, :, :-2].transpose(1, 2))

        # softplus keeps delta positive, the clamp caps it at 5, and +1e-4 keeps it > 0.
        delta = torch.clamp(F.softplus(self.delta_proj(x_conv)), max=5.0) + 1e-4
        B_ssm = self.B_proj(x_conv)
        C_ssm = self.C_proj(x_conv)

        # Bound A_log so A = -exp(A_log) stays within [-exp(2), -exp(-10)].
        A = -torch.exp(torch.clamp(self.A_log, min=-10.0, max=2.0))
        A_batched = A.unsqueeze(0).expand(batch, -1, -1)

        y_ssm = selective_scan_native(x_conv, delta, A_batched, B_ssm, C_ssm, self.D)
        y_gate = self.act(self.gate_proj(x_norm)) * y_ssm
        return x + self.out_proj(y_gate)