import torch
import torch.nn as nn
from .modules.res_blocks import ResConvBlock
from .modules.sinusoidal import SinusoidalEmbedding
from .modules.dit_blocks import DitBlock
from .network_pure_dit import ResConvBlockWCC
from torch.utils.checkpoint import checkpoint

class Network(nn.Module):
    """Conv encoder, a DiT over 8x8 patches, and a conv decoder with 1 output channel.

    ``forward`` concatenates ``x``, ``ridge_map`` and ``basin_map`` along dim 1,
    so together they must supply the 3 channels the first stage is built for
    (presumably one channel each -- TODO confirm against callers). The time and
    water-level values are embedded sinusoidally, concatenated, passed through
    ``embd_mlp`` and used to condition every stage.
    """

    def __init__(self, embd_dim=32, nhead=4):
        super().__init__()
        self.patch_size = 8
        # One patch flattens a patch_size x patch_size window of the 6 channels
        # produced by `self.down`.
        patch_dim = (self.patch_size**2) * 6
        self.time_embd = SinusoidalEmbedding(embd_dim, scaling=1000)
        self.waterlevel_embd = SinusoidalEmbedding(embd_dim, scaling=1)
        # The two embeddings are concatenated, so the conditioning width doubles.
        embd_dim *= 2
        self.embd_mlp = nn.Sequential(nn.Linear(embd_dim, embd_dim),
                                      nn.SiLU(inplace=True))
        padding_mode = 'replicate'

        self.enc = ResConvBlock(3, embd_dim, padding_mode)
        # kernel 4 / stride 2 / padding 1 halves even spatial sizes.
        self.down = nn.Conv2d(3, 6, 4, 2, 1, padding_mode=padding_mode)
        self.to_patches = nn.Unfold(kernel_size=(self.patch_size, self.patch_size),
                                    stride=(self.patch_size, self.patch_size))
        self.dit = DitBlock(patch_dim, nhead, 14, 2, embd_dim)
        # Mirror of `self.down`: doubles the spatial size again.
        self.up = nn.ConvTranspose2d(6, 3, 4, stride=2, padding=1)
        self.dec = ResConvBlock(3, embd_dim, padding_mode)
        self.final = nn.Conv2d(3, 1, 1)

    def initialize(self):
        """Zero the weights and biases of selected projection layers.

        Zeroed layers: ``nn.Linear`` modules whose qualified name contains
        ``embd_affine``, ``water_level_affine`` or ``emb_proj``, and ``nn.Conv2d``
        modules whose name contains ``second_conv``. Presumably this makes the
        conditioning/residual branches start near identity -- depends on the
        internals of ResConvBlock/DitBlock.
        """
        zero_linear_keys = ('embd_affine', 'water_level_affine', 'emb_proj')
        for name, m in self.named_modules():
            zero_linear = isinstance(m, nn.Linear) and any(k in name for k in zero_linear_keys)
            zero_conv = isinstance(m, nn.Conv2d) and 'second_conv' in name
            if zero_linear or zero_conv:
                nn.init.zeros_(m.weight)
                # Layers built with bias=False have `bias is None`.
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x, ridge_map, basin_map, water_level, t):
        """Run encoder -> DiT -> decoder.

        Args:
            x, ridge_map, basin_map: spatial maps concatenated on dim 1; assumed
                ``[B, C_i, H, W]`` with channel counts summing to 3 (TODO confirm).
            water_level: per-sample conditioning value, embedded sinusoidally.
            t: per-sample time value (presumably the diffusion step),
                embedded sinusoidally.

        Returns:
            A 1-channel tensor (from ``self.final``).

        Raises:
            ValueError: if H or W is not a multiple of ``2 * patch_size``.
                Otherwise ``Unfold`` would silently drop the border after
                downsampling, ``fold`` would refill it with zeros, and odd sizes
                would not survive the down/up round trip.
        """
        # `self.down` halves H and W, and the patch grid must then tile exactly.
        factor = 2 * self.patch_size
        in_h, in_w = x.shape[-2:]
        if in_h % factor or in_w % factor:
            raise ValueError(
                f"Network expects H and W divisible by {factor}, got {in_h}x{in_w}"
            )

        t_embed = self.time_embd(t).to(x.dtype)
        waterlevel_embd = self.waterlevel_embd(water_level).to(x.dtype)
        embeds = torch.cat([t_embed, waterlevel_embd], dim=1)
        embeds = self.embd_mlp(embeds)
        x = torch.cat([x, ridge_map, basin_map], dim=1)

        x = self.enc(x, embeds)
        x = self.down(x)

        _, _, H, W = x.shape
        grid_size = (H // self.patch_size, W // self.patch_size)
        x = self.to_patches(x).permute(0, 2, 1)  # [B, num_patches, d]
        x = self.dit(x, embeds, grid_size)
        x = x.permute(0, 2, 1)  # [B, d, num_patches]
        x = nn.functional.fold(x, (H, W), (self.patch_size, self.patch_size),
                               stride=(self.patch_size, self.patch_size))

        x = self.up(x)
        x = self.dec(x, embeds)
        x = self.final(x)
        return x


def run_block(module, *args):
    """Invoke ``module`` with the given positional arguments and return the result.

    A generic top-level callable, suitable for handing to
    ``torch.utils.checkpoint.checkpoint`` as the function to (re)run.
    """
    result = module(*args)
    return result


class NetworkDeep(nn.Module):
    """Three-level conv encoder, a DiT over the coarsest map, and a mirrored conv decoder.

    The stride-2 downsampling convs widen 4 -> 16 -> 64 -> 256 channels. The DiT
    treats every spatial position of the 256-channel map as one token
    (``patch_size`` is 1). The decoder mirrors the encoder with transposed convs.
    Encoder activations are not fed to the decoder (there are no skip
    connections). While training with autograd enabled, each stage is
    gradient-checkpointed to trade recomputation for activation memory.
    """

    def __init__(self, embd_dim=64, nhead=4):
        super().__init__()
        self.patch_size = 1
        patch_dim = (self.patch_size**2) * 256
        self.time_embd = SinusoidalEmbedding(embd_dim, scaling=1000)
        self.waterlevel_embd = SinusoidalEmbedding(embd_dim, scaling=1)
        # The two embeddings are concatenated, so the conditioning width doubles.
        embd_dim *= 2
        self.embd_mlp = nn.Sequential(nn.Linear(embd_dim, embd_dim),
                                      nn.SiLU(inplace=True))
        padding_mode = 'replicate'

        self.enc_0 = ResConvBlockWCC((3, 4, 4), embd_dim, padding_mode)
        self.down_0 = nn.Conv2d(4, 16, 4, 2, 1, padding_mode=padding_mode)  # 1/2
        self.enc_1 = ResConvBlock(16, embd_dim, padding_mode)
        self.down_1 = nn.Conv2d(16, 64, 4, 2, 1, padding_mode=padding_mode)  # 1/4
        self.enc_2 = ResConvBlock(64, embd_dim, padding_mode)
        self.down_2 = nn.Conv2d(64, 256, 4, 2, 1, padding_mode=padding_mode)  # 1/8

        self.to_patches = nn.Unfold(kernel_size=(self.patch_size, self.patch_size),
                                    stride=(self.patch_size, self.patch_size))
        self.dit = DitBlock(patch_dim, nhead, 16, 2, embd_dim)

        self.up_2 = nn.ConvTranspose2d(256, 64, 4, stride=2, padding=1)
        self.dec_2 = ResConvBlock(64, embd_dim, padding_mode)
        self.up_1 = nn.ConvTranspose2d(64, 16, 4, stride=2, padding=1)
        self.dec_1 = ResConvBlock(16, embd_dim, padding_mode)
        self.up_0 = nn.ConvTranspose2d(16, 4, 4, stride=2, padding=1)
        self.dec_0 = ResConvBlockWCC((4, 4, 1), embd_dim, padding_mode)

    def initialize(self):
        """Zero the weights and biases of selected projection layers.

        Zeroed layers: ``nn.Linear`` modules whose qualified name contains
        ``embd_affine``, ``water_level_affine`` or ``emb_proj``, and ``nn.Conv2d``
        modules whose name contains ``second_conv``. Presumably this makes the
        conditioning/residual branches start near identity -- depends on the
        internals of ResConvBlock/ResConvBlockWCC/DitBlock.
        """
        zero_linear_keys = ('embd_affine', 'water_level_affine', 'emb_proj')
        for name, m in self.named_modules():
            zero_linear = isinstance(m, nn.Linear) and any(k in name for k in zero_linear_keys)
            zero_conv = isinstance(m, nn.Conv2d) and 'second_conv' in name
            if zero_linear or zero_conv:
                nn.init.zeros_(m.weight)
                # Layers built with bias=False have `bias is None`.
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _maybe_checkpoint(self, module, *args):
        """Return ``module(*args)``, gradient-checkpointed when that can save memory.

        Checkpointing only helps when autograd is recording, so it is skipped
        in eval mode and under ``torch.no_grad()``. The forward result is the
        same either way.
        """
        if self.training and torch.is_grad_enabled():
            return checkpoint(module, *args, use_reentrant=False)
        return module(*args)

    def forward(self, x, ridge_map, basin_map, water_level, t):
        """Run encoder -> DiT -> decoder.

        Args:
            x, ridge_map, basin_map: spatial maps concatenated on dim 1; assumed
                ``[B, C_i, H, W]`` with channel counts matching what ``enc_0``
                expects (TODO confirm).
            water_level: per-sample conditioning value, embedded sinusoidally.
            t: per-sample time value (presumably the diffusion step),
                embedded sinusoidally.

        Returns:
            The output of ``dec_0``.
        """
        t_embed = self.time_embd(t).to(x.dtype)
        waterlevel_embd = self.waterlevel_embd(water_level).to(x.dtype)
        embeds = torch.cat([t_embed, waterlevel_embd], dim=1)
        embeds = self.embd_mlp(embeds)
        x0 = torch.cat([x, ridge_map, basin_map], dim=1)

        run = self._maybe_checkpoint
        x0 = run(self.enc_0, x0, embeds)
        x1 = run(self.down_0, x0)
        x1 = run(self.enc_1, x1, embeds)
        x2 = run(self.down_1, x1)
        x2 = run(self.enc_2, x2, embeds)
        x3 = run(self.down_2, x2)

        _, _, H, W = x3.shape
        grid_size = (H // self.patch_size, W // self.patch_size)
        x3 = self.to_patches(x3).permute(0, 2, 1)  # [B, num_patches, d]
        x3 = run(self.dit, x3, embeds, grid_size)
        x3 = x3.permute(0, 2, 1)  # [B, d, num_patches]
        x3 = nn.functional.fold(x3, (H, W), (self.patch_size, self.patch_size),
                                stride=(self.patch_size, self.patch_size))

        x = run(self.up_2, x3)
        x = run(self.dec_2, x, embeds)
        x = run(self.up_1, x)
        x = run(self.dec_1, x, embeds)
        x = run(self.up_0, x)
        x = run(self.dec_0, x, embeds)
        return x



if __name__ == "__main__":
    # Smoke check: building the model and walking its module tree must not raise.
    model = Network()
    for _module_name, _module in model.named_modules():
        # Hook for ad-hoc inspection, e.g. printing `_module.weight.shape`.
        pass