# reference: https://github.com/NVlabs/AFNO-transformer
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .attention import TemporalAttention


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop) if drop > 0 else nn.Identity()

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class AFNO2D(nn.Module):
    def __init__(self, hidden_size, num_blocks=8, sparsity_threshold=0.01,
                 hard_thresholding_fraction=1, hidden_size_factor=1):
        super().__init__()
        assert hidden_size % num_blocks == 0, \
            f"hidden_size {hidden_size} should be divisible by num_blocks {num_blocks}"

        self.hidden_size = hidden_size
        self.sparsity_threshold = sparsity_threshold
        self.num_blocks = num_blocks
        self.block_size = self.hidden_size // self.num_blocks
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.hidden_size_factor = hidden_size_factor
        self.scale = 0.02

        # Block-diagonal complex MLP weights; index 0 holds the real part, index 1 the imaginary part.
        self.w1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor))
        self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor))
        self.w2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size))
        self.b2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size))

    def forward(self, x):
        bias = x
        dtype = x.dtype
        x = x.float()
        B, H, W, C = x.shape

        # Token mixing via a 2D FFT over the spatial dimensions.
        x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")
        x = x.reshape(B, H, W // 2 + 1, self.num_blocks, self.block_size)

        o1_real = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
        o1_imag = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
        o2_real = torch.zeros(x.shape, device=x.device)
        o2_imag = torch.zeros(x.shape, device=x.device)

        # Hard thresholding: only the lowest-frequency modes pass through the MLP.
        total_modes = H // 2 + 1
        kept_modes = int(total_modes * self.hard_thresholding_fraction)

        o1_real[:, total_modes - kept_modes:total_modes + kept_modes, :kept_modes] = F.relu(
            torch.einsum('...bi,bio->...bo', x[:, total_modes - kept_modes:total_modes + kept_modes, :kept_modes].real, self.w1[0]) -
            torch.einsum('...bi,bio->...bo', x[:, total_modes - kept_modes:total_modes + kept_modes, :kept_modes].imag, self.w1[1]) +
            self.b1[0]
        )

        o1_imag[:, total_modes - kept_modes:total_modes + kept_modes, :kept_modes] = F.relu(
            torch.einsum('...bi,bio->...bo', x[:, total_modes - kept_modes:total_modes + kept_modes, :kept_modes].imag, self.w1[0]) +
            torch.einsum('...bi,bio->...bo', x[:, total_modes - kept_modes:total_modes + kept_modes, :kept_modes].real, self.w1[1]) +
            self.b1[1]
        )

        o2_real[:, total_modes - kept_modes:total_modes + kept_modes, :kept_modes] = (
            torch.einsum('...bi,bio->...bo', o1_real[:, total_modes - kept_modes:total_modes + kept_modes, :kept_modes], self.w2[0]) -
            torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes - kept_modes:total_modes + kept_modes, :kept_modes], self.w2[1]) +
            self.b2[0]
        )

        o2_imag[:, total_modes - kept_modes:total_modes + kept_modes, :kept_modes] = (
            torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes - kept_modes:total_modes + kept_modes, :kept_modes], self.w2[0]) +
            torch.einsum('...bi,bio->...bo', o1_real[:, total_modes - kept_modes:total_modes + kept_modes, :kept_modes], self.w2[1]) +
            self.b2[1]
        )

        # Soft-shrink promotes sparsity in the frequency domain, then invert the FFT.
        x = torch.stack([o2_real, o2_imag], dim=-1)
        x = F.softshrink(x, lambd=self.sparsity_threshold)
        x = torch.view_as_complex(x)
        x = x.reshape(B, H, W // 2 + 1, C)
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm="ortho")
        x = x.type(dtype)

        return x + bias
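
# Usage sketch (illustrative, not from the referenced repo): AFNO2D is
# channels-last and shape-preserving, with C == hidden_size, e.g.:
#
#   afno = AFNO2D(hidden_size=256, num_blocks=8)
#   out = afno(torch.randn(2, 32, 32, 256))  # -> (2, 32, 32, 256)
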
class Block(nn.Module):
    def __init__(
        self,
        dim,
        mlp_ratio=4.,
        drop=0.,
        drop_path=0.,  # accepted for API compatibility; unused here
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        double_skip=True,
        num_blocks=8,
        sparsity_threshold=0.01,
        hard_thresholding_fraction=1.0,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.filter = AFNO2D(dim, num_blocks, sparsity_threshold, hard_thresholding_fraction)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.double_skip = double_skip

    def forward(self, x):
        residual = x
        x = self.norm1(x)
        x = self.filter(x)

        if self.double_skip:
            x = x + residual
            residual = x

        x = self.norm2(x)
        x = self.mlp(x)
        x = x + residual
        return x


class AFNO3D(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_blocks=8,
        sparsity_threshold=0.01,
        hard_thresholding_fraction=1,
        hidden_size_factor=1,
    ):
        super().__init__()
        assert hidden_size % num_blocks == 0, \
            f"hidden_size {hidden_size} should be divisible by num_blocks {num_blocks}"

        self.hidden_size = hidden_size
        self.sparsity_threshold = sparsity_threshold
        self.num_blocks = num_blocks
        self.block_size = self.hidden_size // self.num_blocks
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.hidden_size_factor = hidden_size_factor
        self.scale = 0.02

        self.w1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor))
        self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor))
        self.w2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size))
        self.b2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size))

    def forward(self, x):
        bias = x
        dtype = x.dtype
        x = x.float()
        B, D, H, W, C = x.shape

        # Token mixing via a 3D FFT over the (D, H, W) dimensions.
        x = torch.fft.rfftn(x, dim=(1, 2, 3), norm="ortho")
        x = x.reshape(B, D, H, W // 2 + 1, self.num_blocks, self.block_size)

        o1_real = torch.zeros([B, D, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
        o1_imag = torch.zeros([B, D, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
        o2_real = torch.zeros(x.shape, device=x.device)
        o2_imag = torch.zeros(x.shape, device=x.device)

        # Hard thresholding along the H and W frequency axes.
        total_modes = H // 2 + 1
        kept_modes = int(total_modes * self.hard_thresholding_fraction)

        o1_real[:, :, total_modes - kept_modes:total_modes + kept_modes, :kept_modes] = F.relu(
            torch.einsum('...bi,bio->...bo', x[:, :, total_modes - kept_modes:total_modes + kept_modes, :kept_modes].real, self.w1[0]) -
            torch.einsum('...bi,bio->...bo', x[:, :, total_modes - kept_modes:total_modes + kept_modes, :kept_modes].imag, self.w1[1]) +
            self.b1[0]
        )

        o1_imag[:, :, total_modes - kept_modes:total_modes + kept_modes, :kept_modes] = F.relu(
            torch.einsum('...bi,bio->...bo', x[:, :, total_modes - kept_modes:total_modes + kept_modes, :kept_modes].imag, self.w1[0]) +
            torch.einsum('...bi,bio->...bo', x[:, :, total_modes - kept_modes:total_modes + kept_modes, :kept_modes].real, self.w1[1]) +
            self.b1[1]
        )

        o2_real[:, :, total_modes - kept_modes:total_modes + kept_modes, :kept_modes] = (
            torch.einsum('...bi,bio->...bo', o1_real[:, :, total_modes - kept_modes:total_modes + kept_modes, :kept_modes], self.w2[0]) -
            torch.einsum('...bi,bio->...bo', o1_imag[:, :, total_modes - kept_modes:total_modes + kept_modes, :kept_modes], self.w2[1]) +
            self.b2[0]
        )

        o2_imag[:, :, total_modes - kept_modes:total_modes + kept_modes, :kept_modes] = (
            torch.einsum('...bi,bio->...bo', o1_imag[:, :, total_modes - kept_modes:total_modes + kept_modes, :kept_modes], self.w2[0]) +
            torch.einsum('...bi,bio->...bo', o1_real[:, :, total_modes - kept_modes:total_modes + kept_modes, :kept_modes], self.w2[1]) +
            self.b2[1]
        )

        x = torch.stack([o2_real, o2_imag], dim=-1)
        x = F.softshrink(x, lambd=self.sparsity_threshold)
        x = torch.view_as_complex(x)
        x = x.reshape(B, D, H, W // 2 + 1, C)
        x = torch.fft.irfftn(x, s=(D, H, W), dim=(1, 2, 3), norm="ortho")
        x = x.type(dtype)

        return x + bias
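
# Usage sketch (illustrative, not from the referenced repo): AFNO3D expects
# channels-last input of shape (B, D, H, W, C) with C == hidden_size, e.g.:
#
#   afno3d = AFNO3D(hidden_size=64, num_blocks=8)
#   out = afno3d(torch.randn(1, 8, 16, 16, 64))  # -> (1, 8, 16, 16, 64)
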
class AFNOBlock3d(nn.Module):
    def __init__(
        self,
        dim,
        mlp_ratio=4.,
        drop=0.,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        double_skip=True,
        num_blocks=8,
        sparsity_threshold=0.01,
        hard_thresholding_fraction=1.0,
        data_format="channels_last",
        mlp_out_features=None,
    ):
        super().__init__()
        self.norm_layer = norm_layer
        self.norm1 = norm_layer(dim)
        self.filter = AFNO3D(dim, num_blocks, sparsity_threshold, hard_thresholding_fraction)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            out_features=mlp_out_features,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.double_skip = double_skip
        self.channels_first = (data_format == "channels_first")

    def forward(self, x):
        if self.channels_first:
            # AFNO natively uses a channels-last data format
            x = x.permute(0, 2, 3, 4, 1)

        residual = x
        x = self.norm1(x)
        x = self.filter(x)

        if self.double_skip:
            x = x + residual
            residual = x

        x = self.norm2(x)
        x = self.mlp(x)
        x = x + residual

        if self.channels_first:
            x = x.permute(0, 4, 1, 2, 3)
        return x


class PatchEmbed3d(nn.Module):
    def __init__(self, patch_size=(4, 4, 4), in_chans=1, embed_dim=256):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)
        x = x.permute(0, 2, 3, 4, 1)  # convert to channels-last (B, D, H, W, C)
        return x


class PatchExpand3d(nn.Module):
    def __init__(self, patch_size=(4, 4, 4), out_chans=1, embed_dim=256):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Linear(embed_dim, out_chans * np.prod(patch_size))

    def forward(self, x):
        x = self.proj(x)
        x = rearrange(
            x,
            "b d h w (p0 p1 p2 c_out) -> b c_out (d p0) (h p1) (w p2)",
            p0=self.patch_size[0],
            p1=self.patch_size[1],
            p2=self.patch_size[2],
            d=x.shape[1],
            h=x.shape[2],
            w=x.shape[3],
        )
        return x
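
# Round-trip sketch (illustrative): PatchEmbed3d takes channels-first volumes,
# PatchExpand3d inverts the patching from channels-last tokens.
#
#   embed = PatchEmbed3d(patch_size=(4, 4, 4), in_chans=1, embed_dim=256)
#   expand = PatchExpand3d(patch_size=(4, 4, 4), out_chans=1, embed_dim=256)
#   tokens = embed(torch.randn(1, 1, 16, 16, 16))  # -> (1, 4, 4, 4, 256)
#   volume = expand(tokens)                        # -> (1, 1, 16, 16, 16)
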
""" def __init__( self, dim, context_dim, mlp_ratio=2., drop=0., act_layer=nn.GELU, norm_layer=nn.Identity, double_skip=True, num_blocks=8, sparsity_threshold=0.01, hard_thresholding_fraction=1.0, data_format="channels_last", timesteps=None ): super().__init__() self.norm1 = norm_layer(dim) self.norm2 = norm_layer(dim+context_dim) mlp_hidden_dim = int((dim+context_dim) * mlp_ratio) self.pre_proj = nn.Linear(dim+context_dim, dim+context_dim) self.filter = AFNO3D(dim+context_dim, num_blocks, sparsity_threshold, hard_thresholding_fraction) self.mlp = Mlp( in_features=dim+context_dim, out_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop ) self.channels_first = (data_format == "channels_first") def forward(self, x, y): if self.channels_first: # AFNO natively uses a channels-last order x = x.permute(0,2,3,4,1) y = y.permute(0,2,3,4,1) xy = torch.concat((self.norm1(x),y), axis=-1) xy = self.pre_proj(xy) + xy xy = self.filter(self.norm2(xy)) + xy # AFNO filter x = self.mlp(xy) + x # feed-forward if self.channels_first: x = x.permute(0,4,1,2,3) return x