import torch
import torch.nn as nn
from .modules.res_blocks import ResConvBlock
from .modules.sinusoidal import SinusoidalEmbedding
from .modules.dit_blocks import DitLayer
from torch.utils.checkpoint import checkpoint


class DitBlock(nn.Module):
    """Sequential stack of ``num_layers`` DitLayer modules.

    Args:
        d_main: token feature width handed to every DitLayer.
        nhead: attention head count for each layer.
        num_layers: number of DitLayer modules to stack.
        ffn_multiplier: feed-forward width as a multiple of ``d_main``
            (the product is truncated with ``int``).
        embd_dim: width of the conditioning embedding given to each layer.
    """

    def __init__(self, d_main, nhead, num_layers, ffn_multiplier, embd_dim):
        super().__init__()
        self.num_layers = num_layers
        # Every layer shares the same feed-forward width.
        hidden_ffn = int(d_main * ffn_multiplier)
        stack = [
            DitLayer(d_main, nhead, hidden_ffn, embd_dim, 'vanilla')
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(stack)

    def forward(self, x, embd, grid_size):
        """Thread ``x`` through each layer in order with the same conditioning.

        ``embd`` and ``grid_size`` are passed unchanged to every layer; only the
        token tensor is updated from one layer to the next.
        """
        # NOTE: gradient checkpointing (torch.utils.checkpoint) was tried here
        # during training and is currently disabled.
        hidden = x
        for dit_layer in self.layers:
            hidden = dit_layer(hidden, embd, grid_size)
        return hidden

def run_block(module, *args):
    """Call ``module`` with ``args`` positionally and hand back its result.

    A thin function-style adapter, usable where an API expects a plain callable
    followed by its arguments (for example ``torch.utils.checkpoint.checkpoint``).
    """
    result = module(*args)
    return result

class Network(nn.Module):
    """Patch-based DiT network conditioned on a time value and a water level.

    ``x``, ``ridge_map`` and ``basin_map`` are concatenated on dim 1 and cut
    into non-overlapping ``patch_size`` x ``patch_size`` patches. The patches
    are processed as a token sequence by a 2-layer ``DitBlock`` conditioned on
    an MLP over the concatenated time / water-level sinusoidal embeddings.
    They are then folded back into an image and projected to one channel by a
    1x1 convolution.

    Args:
        nhead: number of attention heads in each DiT layer.
    """

    def __init__(self, nhead=2):
        super().__init__()
        embd_dim = 64
        self.time_embd = SinusoidalEmbedding(embd_dim, scaling=1000)
        self.waterlevel_embd = SinusoidalEmbedding(embd_dim, scaling=1)
        # forward() concatenates the two embeddings, so downstream width is
        # doubled (presumably each SinusoidalEmbedding outputs embd_dim features).
        embd_dim *= 2
        self.embd_mlp = nn.Sequential(nn.Linear(embd_dim, embd_dim),
                                      nn.LeakyReLU(inplace=True))

        self.patch_size = 16
        # Token width = flattened patch. The "3" assumes the concatenation in
        # forward() yields exactly 3 channels (one per input map) -- TODO confirm
        # against the data pipeline; self.dec below makes the same assumption.
        patch_dim = (self.patch_size ** 2) * 3
        self.to_patches = nn.Unfold(kernel_size=(self.patch_size, self.patch_size),
                                    stride=(self.patch_size, self.patch_size))
        self.dit = DitBlock(patch_dim, nhead, 2, 2, embd_dim)

        self.dec = nn.Conv2d(3, 1, 1)

    def initialize(self):
        """Zero-initialise selected Linear/Conv2d submodules by name.

        Linear layers whose qualified name contains ``embd_affine``,
        ``water_level_affine`` or ``emb_proj``, and Conv2d layers whose name
        contains ``second_conv``, get all-zero weights (and biases, if present).
        These names presumably come from submodules defined outside this file
        (e.g. inside DitLayer) -- verify if renaming layers there.
        """
        linear_tags = ('embd_affine', 'water_level_affine', 'emb_proj')
        for name, m in self.named_modules():
            zero_linear = isinstance(m, nn.Linear) and any(tag in name for tag in linear_tags)
            zero_conv = isinstance(m, nn.Conv2d) and 'second_conv' in name
            if not (zero_linear or zero_conv):
                continue
            nn.init.zeros_(m.weight)
            # Layers built with bias=False have bias=None; the previous
            # unconditional m.bias.data access crashed on them.
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x, ridge_map, basin_map, water_level, t):
        """Run the conditioned DiT over patches of the stacked input maps.

        Args:
            x, ridge_map, basin_map: 4-D tensors concatenated on dim 1. They
                must share batch and spatial size, and the total channel count
                must be 3 to match ``patch_dim`` / ``self.dec``. H and W must
                be multiples of ``patch_size``.
            water_level: value(s) embedded by ``waterlevel_embd``.
            t: value(s) embedded by ``time_embd`` (presumably the diffusion
                timestep, given scaling=1000 -- confirm with the caller).

        Returns:
            Tensor of shape [B, 1, H, W] from the 1x1 output convolution.

        Raises:
            ValueError: if H or W is not a multiple of ``patch_size``.
        """
        t_embed = self.time_embd(t).to(x.dtype)
        waterlevel_embd = self.waterlevel_embd(water_level).to(x.dtype)
        embeds = self.embd_mlp(torch.cat([t_embed, waterlevel_embd], dim=1))

        x = torch.cat([x, ridge_map, basin_map], dim=1)
        _, _, H, W = x.shape
        p = self.patch_size
        if H % p or W % p:
            # Unfold would silently drop the border remainder and fold would
            # refill it with zeros, corrupting the output edges.
            raise ValueError(
                f"input spatial size ({H}, {W}) must be divisible by patch_size={p}")
        grid_size = (H // p, W // p)
        x = self.to_patches(x).permute(0, 2, 1)  # [B, num_patches, d]
        x = self.dit(x, embeds, grid_size)
        x = x.permute(0, 2, 1)  # [B, d, num_patches]
        x = nn.functional.fold(x, (H, W), (p, p), stride=(p, p))

        return self.dec(x)