import torch
import torch.nn as nn
from .modules.swin import SwinTransformerBlock, RoPESwinTransformerBlock
from .modules.sinusoidal import SinusoidalEmbedding
from torch.utils.checkpoint import checkpoint

class SwinBlock(nn.Module):
    """Stack of Swin transformer block pairs conditioned on an embedding.

    Each of the ``num_layers`` stages applies one ``SwinTransformerBlock``
    with shift 0 followed by one with shift ``window_size // 2``. In training
    mode every block is run under non-reentrant gradient checkpointing, which
    trades recomputation in the backward pass for activation memory.

    Args:
        d_main: Token feature dimension forwarded to every block.
        nhead: Number of heads forwarded to every block.
        num_layers: Number of (unshifted, shifted) block pairs.
        ffn_multiplier: Feed-forward width multiplier forwarded to every block.
        embd_dim: Width of the conditioning embedding forwarded to every block.
        window_size: Window size forwarded to every block; the shifted blocks
            use a shift of ``window_size // 2``.
        input_resolution: Token-grid resolution forwarded to every block.
            Defaults to ``(64, 64)``, the value that used to be hard-coded.
            Presumably it must match the spatial layout of the incoming tokens
            (e.g. ``(H // patch, W // patch)``) — confirm against
            ``SwinTransformerBlock``.
    """

    def __init__(self, d_main, nhead, num_layers, ffn_multiplier, embd_dim, window_size,
                 input_resolution=(64, 64)):
        super().__init__()
        self.num_layers = num_layers
        self.layers_no_shift = nn.ModuleList([
            SwinTransformerBlock(d_main, input_resolution, nhead, embd_dim,
                                 window_size, 0, ffn_multiplier)
            for _ in range(num_layers)
        ])
        self.layers_shift = nn.ModuleList([
            SwinTransformerBlock(d_main, input_resolution, nhead, embd_dim,
                                 window_size, window_size // 2, ffn_multiplier)
            for _ in range(num_layers)
        ])

    def _run(self, layer, x, embd):
        """Apply ``layer(x, embd)``, checkpointed when in training mode."""
        if self.training:
            # Non-reentrant checkpointing: activations inside ``layer`` are
            # recomputed during backward instead of being stored.
            return checkpoint(run_block, layer, x, embd, use_reentrant=False)
        return layer(x, embd)

    def forward(self, x, embd):
        """Run all block pairs in order.

        Args:
            x: Token tensor, presumably ``(B, num_tokens, d_main)`` — confirm.
            embd: Conditioning embedding passed unchanged to every block.

        Returns:
            The transformed token tensor.
        """
        for layer_n, layer_s in zip(self.layers_no_shift, self.layers_shift):
            x = self._run(layer_n, x, embd)
            x = self._run(layer_s, x, embd)
        return x


def run_block(module, *args):
    """Call ``module`` with the given positional arguments and return the result.

    Serves as the callable handed to ``torch.utils.checkpoint.checkpoint`` so
    the module itself travels as an ordinary argument.
    """
    outputs = module(*args)
    return outputs

class Network(nn.Module):
    """Patch-token Swin network conditioned on ``t`` and ``water_level``.

    The inputs ``x``, ``ridge_map`` and ``basin_map`` are concatenated along the
    channel axis and cut into non-overlapping ``patch_size`` x ``patch_size``
    patches. The flattened patches are processed by a ``SwinBlock`` stack
    conditioned on an MLP-mixed time/water-level embedding. They are then
    projected to ``patch_size**2`` values each and folded back into a
    single-channel image.
    """

    def __init__(self, nhead=2):
        super().__init__()
        embd_dim = 64
        self.time_embd = SinusoidalEmbedding(embd_dim, scaling=1000)
        self.waterlevel_embd = SinusoidalEmbedding(embd_dim, scaling=1)
        # forward() concatenates the time and water-level embeddings along
        # dim 1, so the MLP and the Swin conditioning see twice the width.
        embd_dim *= 2
        self.embd_mlp = nn.Sequential(nn.Linear(embd_dim, embd_dim),
                                      nn.LayerNorm(embd_dim),
                                      nn.LeakyReLU(inplace=True))

        self.patch_size = 16
        # Flattened patch length: patch_size**2 pixels x 3 channels. Assumes
        # x, ridge_map and basin_map together provide 3 channels — TODO confirm.
        self.patch_dim = (self.patch_size**2) * 3
        self.to_patches = nn.Unfold(kernel_size=(self.patch_size, self.patch_size),
                                    stride=(self.patch_size, self.patch_size))
        self.patch_up = nn.Linear(self.patch_dim, self.patch_dim)
        # 4 (unshifted, shifted) pairs, FFN multiplier 3, window size 16.
        self.dit = SwinBlock(self.patch_dim, nhead, 4, 3, embd_dim, 16)

        # Project each token back to patch_size**2 values, i.e. exactly one
        # output channel once the patches are folded back into an image.
        self.patch_down = nn.Linear(self.patch_dim, self.patch_dim//3)

    def initialize(self):
        """Zero-initialize selected ``nn.Linear`` layers.

        Every ``nn.Linear`` whose qualified module name contains
        ``embd_affine``, ``water_level_affine`` or ``emb_proj`` gets its weight
        (and bias, if it has one) set to zero. No module created directly in
        ``__init__`` carries these names, so they presumably live inside the
        Swin blocks; the zero init presumably makes those conditioning
        projections start as no-ops — confirm.
        """
        targets = ('embd_affine', 'water_level_affine', 'emb_proj')
        for name, m in self.named_modules():
            if isinstance(m, nn.Linear) and any(t in name for t in targets):
                nn.init.zeros_(m.weight)
                # Linear layers built with bias=False have ``bias is None``.
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x, ridge_map, basin_map, water_level, t):
        """Run the network on one batch.

        Args:
            x, ridge_map, basin_map: Image tensors ``(B, C_i, H, W)`` with
                matching ``B``, ``H``, ``W``; their channel counts must sum
                to 3 (see ``patch_dim``).
            water_level: Values fed to the water-level sinusoidal embedding.
            t: Values fed to the time sinusoidal embedding.

        Returns:
            Tensor of shape ``(B, 1, H, W)``.

        Raises:
            ValueError: If ``H`` or ``W`` is not a multiple of ``patch_size``.
                Without this check, ``Unfold`` would drop the remainder and
                ``fold`` would silently leave a zero border.
        """
        H, W = x.shape[-2:]
        p = self.patch_size
        if H % p or W % p:
            raise ValueError(
                f"spatial size ({H}, {W}) must be divisible by patch_size={p}"
            )

        t_embed = self.time_embd(t).to(x.dtype)
        waterlevel_embd = self.waterlevel_embd(water_level).to(x.dtype)
        embeds = self.embd_mlp(torch.cat([t_embed, waterlevel_embd], dim=1))

        x = torch.cat([x, ridge_map, basin_map], dim=1)
        # Unfold yields (B, C*p*p, num_patches); put tokens on dim 1.
        tokens = self.to_patches(x).transpose(1, 2)
        tokens = self.patch_up(tokens)
        tokens = self.dit(tokens, embeds)
        tokens = self.patch_down(tokens)  # (B, num_patches, p*p)
        # fold expects (B, C*p*p, num_patches).
        return nn.functional.fold(tokens.transpose(1, 2), (H, W), (p, p), stride=(p, p))