import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from .modules.res_blocks import ResConvBlock
from .modules.sinusoidal import SinusoidalEmbedding
from .modules.rope import MultiheadAttentionWithRoPE
from .network_pure_dit import ResConvBlockWCC

class DiTLayer(nn.Module):
    """Pre-norm transformer layer whose sub-layers are modulated by a conditioning vector.

    One linear projection of ``embd`` yields six per-channel vectors:
    a scale and a shift applied after each (non-affine) LayerNorm, and a
    gate that multiplies each sub-layer's output before it is added back
    to the residual stream.

    Args:
        d_model: Token feature size.
        embd_dim: Size of the conditioning embedding.
        nhead: Number of heads for ``MultiheadAttentionWithRoPE``.
        dim_feedforward: Hidden width of the feed-forward sub-layer.
    """

    def __init__(self, d_model, embd_dim, nhead, dim_feedforward=1024):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False)
        self.attn = MultiheadAttentionWithRoPE(d_model, nhead)
        self.norm2 = nn.LayerNorm(d_model, elementwise_affine=False)
        # Output is split in forward() as
        # [scale_attn, scale_ffn, shift_attn, shift_ffn, gate_attn, gate_ffn].
        self.embd_affine = nn.Linear(embd_dim, 6 * d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(dim_feedforward, d_model),
        )

    @staticmethod
    def _modulate(tokens, scale, shift):
        """Return ``tokens * (1 + scale) + shift``, broadcasting scale/shift over dim 1."""
        return tokens * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

    def forward(self, x, embd, grid_size):
        """Apply the modulated attention and feed-forward sub-layers.

        Args:
            x: Token tensor, presumably (B, N, d_model) as built by DiTBlock.
            embd: Conditioning tensor; projected and split into six chunks
                along dim 1, so presumably (B, embd_dim).
            grid_size: Forwarded unchanged to ``self.attn`` (presumably the
                (H, W) patch grid used by RoPE -- TODO confirm).

        Returns:
            Tensor with the same shape as ``x``.
        """
        (scale_attn, scale_ffn,
         shift_attn, shift_ffn,
         gate_attn, gate_ffn) = self.embd_affine(embd).chunk(6, dim=1)

        # Gated residual attention branch.
        attn_in = self._modulate(self.norm1(x), scale_attn, shift_attn)
        x = x + self.attn(attn_in, grid_size) * gate_attn.unsqueeze(1)

        # Gated residual feed-forward branch.
        ffn_in = self._modulate(self.norm2(x), scale_ffn, shift_ffn)
        return x + self.ffn(ffn_in) * gate_ffn.unsqueeze(1)


class DiTBlock(nn.Module):
    """Run a stack of DiT layers over non-overlapping patches of a feature map.

    The input is cut into ``patch_size x patch_size`` patches with
    ``nn.Unfold`` (kernel == stride, so patches do not overlap). Each patch
    becomes one token of size ``channels * patch_size**2``. After the DiT
    layers, the tokens are folded back to the original (H, W) layout.

    Args:
        channels: Channel count of the input feature map.
        embd_dim: Size of the conditioning embedding passed to each layer.
        patch_size: Patch edge length; H and W must be multiples of it.
        nhead: Attention heads per layer (must divide the token size).
        num_layers: Number of stacked ``DiTLayer`` modules.
    """

    def __init__(self, channels, embd_dim, patch_size, nhead, num_layers):
        super().__init__()
        self.patch_size = patch_size
        self.patchify = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        hidden_size = channels * patch_size**2
        self.dit_layers = nn.ModuleList([
            DiTLayer(hidden_size, embd_dim, nhead, 2 * hidden_size)
            for _ in range(num_layers)
        ])

    def forward(self, x, embd):
        """Apply the DiT stack to ``x`` (B, C, H, W) and return a same-shaped tensor.

        Raises:
            ValueError: if H or W is not a multiple of ``patch_size``. Unfold
                would silently drop the remainder and fold would leave those
                pixels zero.
        """
        _, _, height, width = x.shape
        p = self.patch_size
        if height % p or width % p:
            raise ValueError(
                f"DiTBlock: input size {height}x{width} is not divisible by "
                f"patch_size={p}"
            )
        grid_size = (height // p, width // p)

        tokens = self.patchify(x).permute(0, 2, 1)  # [B, num_patches, C*p*p]
        for dit_layer in self.dit_layers:
            tokens = dit_layer(tokens, embd, grid_size)
        tokens = tokens.permute(0, 2, 1)  # [B, C*p*p, num_patches]
        return nn.functional.fold(tokens, (height, width), (p, p), stride=(p, p))


class UpBlock(nn.Module):
    """Decoder stage: residual conv block, 2x upsample, then merge an encoder skip.

    The transposed conv (kernel 4, stride 2, padding 1) doubles H and W.
    The skip is then concatenated along channels (``cat=True``) or added
    (``cat=False``).

    Args:
        in_ch: Input channels.
        out_ch: Channels after upsampling (before any concatenation).
        time_dim: Conditioning embedding size for ``ResConvBlock``.
        cat: Concatenate the skip instead of adding it.
    """

    def __init__(self, in_ch, out_ch, time_dim, cat):
        super().__init__()
        self.res = ResConvBlock(in_ch, time_dim)
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
        self.cat = cat

    def forward(self, x, t_emb, skip=None):
        """Process, upsample and merge ``skip``.

        If ``skip`` is None, no merge is performed and the upsampled tensor
        is returned as is (previously this raised inside cat/add).
        """
        x = self.up(self.res(x, t_emb))
        if skip is None:
            return x
        if self.cat:
            return torch.cat([x, skip], dim=1)
        return x + skip

class UpBlockWithDit(nn.Module):
    """Decoder stage with a residual DiT branch before 2x upsampling.

    Pipeline:
    1. ``res`` (ResConvBlock).
    2. Residual branch ``down_map`` (1x1 conv to ``mid_ch``) -> ``DiTBlock``
       -> ``up_map`` (1x1 conv back to ``in_ch``), added to the input.
    3. Transposed-conv upsample (kernel 4, stride 2, padding 1 doubles H, W).
    4. Skip merge: concatenation (``cat=True``) or addition.

    Args:
        in_ch: Input channels.
        mid_ch: Channels inside the DiT branch.
        out_ch: Channels after upsampling (before any concatenation).
        patch_size: Patch size for the DiTBlock.
        nhead: Attention heads for the DiTBlock.
        time_dim: Conditioning embedding size.
        layers: Number of DiT layers.
        cat: Concatenate the skip instead of adding it.
    """

    def __init__(self, in_ch, mid_ch, out_ch, patch_size, nhead, time_dim, layers, cat):
        super().__init__()
        self.res = ResConvBlock(in_ch, time_dim)
        self.down_map = nn.Conv2d(in_ch, mid_ch, kernel_size=1, bias=False)
        self.dit = DiTBlock(mid_ch, time_dim, patch_size, nhead, layers)
        self.up_map = nn.Conv2d(mid_ch, in_ch, kernel_size=1)
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
        self.cat = cat

    def forward(self, x, embd, skip=None):
        """Process, upsample and merge ``skip``.

        If ``skip`` is None, no merge is performed and the upsampled tensor
        is returned as is (previously this raised inside cat/add).
        """
        x = self.res(x, embd)
        # Residual DiT branch computed in the reduced ``mid_ch`` space.
        h = self.up_map(self.dit(self.down_map(x), embd))
        x = self.up(x + h)
        if skip is None:
            return x
        if self.cat:
            return torch.cat([x, skip], dim=1)
        return x + skip


def run_block(module, *args):
    """Call ``module`` with the positional ``args`` and hand back its result.

    Serves as the plain callable passed to ``checkpoint`` elsewhere in this
    module.
    """
    output = module(*args)
    return output


class ConditionalUNet(nn.Module):
    """Convolutional U-Net conditioned on a diffusion timestep and a water level.

    Encoder:
    - ``depth`` ResConvBlock stages, separated by ``depth - 1`` stride-2
      convolutions that double the channel count.
    - A further stride-2 conv forms the bottleneck.

    Decoder:
    - ``depth`` ``UpBlock`` stages, each upsampling 2x and merging the
      matching encoder skip.
    - All but the last stage add the skip; the last one concatenates it.

    Every downsample is ``Conv2d(k=4, s=2, p=1)`` (floor-halves H, W) and
    every upsample doubles back. H and W therefore need to be divisible by
    ``2 ** depth`` for the skip tensors to line up.

    Args:
        base_ch: Channels after the input projection.
        embd_dim: Size passed to each SinusoidalEmbedding. The blocks receive
            ``2 * embd_dim``, which assumes each embedding returns
            ``embd_dim`` features.
        depth: Number of encoder / decoder stages.
    """

    def __init__(self, base_ch=16, embd_dim=64, depth=5):
        super().__init__()
        self.depth = depth
        self.time_embd = SinusoidalEmbedding(embd_dim)
        self.waterlevel_embd = SinusoidalEmbedding(embd_dim, 10)
        # Time and water-level embeddings are concatenated in forward().
        embd_dim *= 2

        # Input channels = noisy height (1) + ridge map (1) + basin map (1) + map average (1)
        self.expand = nn.Conv2d(4, base_ch, 3, padding=1, padding_mode='replicate')

        # Encoder layers
        self.enc_blocks = nn.ModuleList()
        self.enc_dit_blocks = nn.ModuleList()  # unused by forward(); kept for state_dict compatibility
        self.down_convs = nn.ModuleList()
        current_ch = base_ch

        for i in range(depth):
            self.enc_blocks.append(ResConvBlock(current_ch, embd_dim))
            if i < depth - 1:
                self.down_convs.append(
                    nn.Conv2d(current_ch, current_ch * 2, 4, stride=2, padding=1, padding_mode='replicate')
                )
                current_ch *= 2

        # Bottleneck
        self.bottleneck = nn.Conv2d(current_ch, current_ch * 2, 4, stride=2, padding=1, padding_mode='replicate')
        current_ch *= 2

        # Decoder layers
        self.up_blocks = nn.ModuleList()
        for i in range(depth):
            cat = (i == depth - 1)  # Only concatenate in the final up block
            self.up_blocks.append(UpBlock(current_ch, current_ch // 2, embd_dim, cat))
            # Concatenating the skip doubles the channels going into the next stage.
            current_ch = current_ch // 2 * (2 if cat else 1)

        self.out = ResConvBlock(current_ch, embd_dim)
        self.final = nn.Conv2d(current_ch, 1, 1)

    def _maybe_checkpoint(self, module, *args):
        """Call ``module(*args)``, using activation checkpointing in training mode.

        While training, the call goes through non-reentrant
        ``torch.utils.checkpoint.checkpoint``, which recomputes activations in
        backward to save memory. In eval mode the module is called directly.
        """
        if self.training:
            return checkpoint(run_block, module, *args, use_reentrant=False)
        return module(*args)

    def forward(self, x, map_average, ridge_map, basin_map, water_level, t):
        """Predict a single-channel map from the noisy input and conditioning.

        Args:
            x: Noisy height map, presumably (B, 1, H, W) -- TODO confirm.
            map_average, ridge_map, basin_map: Conditioning maps concatenated
                with ``x`` along dim 1. Together the four inputs must supply
                the 4 channels ``self.expand`` expects.
            water_level: Input to ``self.waterlevel_embd``.
            t: Timestep(s), input to ``self.time_embd``.

        Returns:
            Output of ``self.final``: a 1-channel map at the input resolution.
        """
        t_embed = self.time_embd(t).to(x.dtype)
        waterlevel_embd = self.waterlevel_embd(water_level).to(x.dtype)
        embeds = torch.cat([t_embed, waterlevel_embd], dim=1)

        # Channel order is fixed here: x, ridge, basin, average.
        h = torch.cat([x, ridge_map, basin_map, map_average], dim=1)
        h = self._maybe_checkpoint(self.expand, h)

        # Encoder: every stage output is kept as a skip for the decoder.
        skips = []
        for i in range(self.depth):
            h = self._maybe_checkpoint(self.enc_blocks[i], h, embeds)
            skips.append(h)
            if i < self.depth - 1:
                h = self._maybe_checkpoint(self.down_convs[i], h)

        # Bottleneck
        h = self._maybe_checkpoint(self.bottleneck, h)

        # Decoder: consume skips deepest-first.
        for i in range(self.depth):
            h = self._maybe_checkpoint(self.up_blocks[i], h, embeds, skips[-(i + 1)])

        h = self._maybe_checkpoint(self.out, h, embeds)
        return self._maybe_checkpoint(self.final, h)


class ConditionalUNetManual(nn.Module):
    """Hand-written two-level conditional U-Net.

    Conditioning:
    - The timestep ``t`` and ``water_level`` are sinusoidally embedded.
    - The two embeddings are concatenated and passed through a Linear+SiLU MLP.

    Encoder:
    - ``enc_0`` at full resolution.
    - ``enc_1`` after one stride-2 downsample.
    - A second stride-2 conv produces the bottleneck features.

    Decoder:
    - ``up1`` adds the ``enc_1`` skip.
    - ``up0`` concatenates the ``enc_0`` skip.
    - ``out`` maps to one channel.

    Args:
        base_ch: Channels at full resolution; doubled at each downsample.
        embd_dim: Size passed to each SinusoidalEmbedding. The blocks receive
            ``2 * embd_dim``, which assumes each embedding returns
            ``embd_dim`` features.
    """

    def __init__(self, base_ch=8, embd_dim=16):
        super().__init__()
        self.time_embd = SinusoidalEmbedding(embd_dim, scaling=1000)
        self.waterlevel_embd = SinusoidalEmbedding(embd_dim, scaling=1)
        embd_dim *= 2  # time and water-level embeddings are concatenated
        self.embd_mlp = nn.Sequential(nn.Linear(embd_dim, embd_dim),
                                      nn.SiLU(inplace=True))
        padding_mode = 'replicate'

        # NOTE(review): ``expand`` is never used by forward(); ``enc_0`` consumes
        # the 3-channel input directly. It is kept so existing state_dicts still
        # load with strict=True. Its parameters get no gradient, though, so DDP
        # would likely need find_unused_parameters=True -- consider removing it.
        self.expand = nn.Conv2d(3, base_ch, 3, padding=1, padding_mode=padding_mode)
        # Input channels = noisy height (1) + ridge map (1) + basin/lake map (1)
        self.enc_0 = ResConvBlockWCC((3, 8, base_ch), embd_dim, padding_mode)

        self.down0 = nn.Conv2d(base_ch, base_ch * 2, 4, stride=2, padding=1, padding_mode=padding_mode) # 1024->512
        self.enc_1 = ResConvBlock(base_ch * 2, embd_dim, padding_mode)

        self.down1 = nn.Conv2d(base_ch * 2, base_ch * 4, 4, stride=2, padding=1, padding_mode=padding_mode) # 512->256

        self.up1 = UpBlock(base_ch * 4, base_ch * 2, embd_dim, False) # 256->512
        self.up0 = UpBlock(base_ch * 2, base_ch, embd_dim, True) # 512->1024
        # up0 concatenates the base_ch-channel skip, hence base_ch * 2 inputs.
        self.out = ResConvBlockWCC((base_ch * 2, 8, 1), embd_dim, padding_mode)

    def initialize(self):
        """Zero selected conditioning projections and residual-branch convs.

        Zeroes the weight (and the bias, when present) of:
        - every ``nn.Linear`` whose qualified name contains ``'embd_affine'``
          or ``'water_level_affine'``;
        - every ``nn.Conv2d`` whose name contains ``'second_conv'``.

        Layers built with ``bias=False`` are handled; previously they raised
        ``AttributeError`` on ``None.data``.
        """
        for name, m in self.named_modules():
            zero_linear = isinstance(m, nn.Linear) and (
                'embd_affine' in name or 'water_level_affine' in name
            )
            zero_conv = isinstance(m, nn.Conv2d) and 'second_conv' in name
            if zero_linear or zero_conv:
                nn.init.zeros_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x, ridge_map, basin_map, water_level, t):
        """Predict a 1-channel map (presumably the diffusion noise).

        Args:
            x: Noisy height map.
            ridge_map: Ridge/edge map.
            basin_map: Basin map.
            water_level: Estimated sea level, embedded by ``waterlevel_embd``.
            t: Timestep(s), embedded by ``time_embd``.

        ``x``, ``ridge_map`` and ``basin_map`` are concatenated along dim 1
        and must total 3 channels (``enc_0`` is built for 3 inputs).
        Resolution comments below assume a 1024x1024 input.
        """
        t_embed = self.time_embd(t).to(x.dtype)
        waterlevel_embd = self.waterlevel_embd(water_level).to(x.dtype)
        embeds = torch.cat([t_embed, waterlevel_embd], dim=1)
        embeds = self.embd_mlp(embeds)

        h0 = torch.cat([x, ridge_map, basin_map], dim=1)  # concat condition
        # encode (the full-resolution blocks are checkpointed while training)
        h0 = self.enc_0(h0, embeds) if not self.training else checkpoint(run_block, self.enc_0, h0, embeds, use_reentrant=False)
        h1 = self.down0(h0)
        h1 = self.enc_1(h1, embeds)  # 512x512
        h2 = self.down1(h1)  # 256x256
        # decode with skip connections
        out = self.up1(h2, embeds, h1)  # 512x512
        out = self.up0(out, embeds, h0)  # 1024x1024
        out = self.out(out, embeds) if not self.training else checkpoint(run_block, self.out, out, embeds, use_reentrant=False)
        return out



if __name__ == "__main__":
    # Smoke test: build a network and zero its time-conditioning projections.
    # NOTE(review): this previously instantiated ``ConditionalUNetDiT``, which is
    # neither defined nor imported in this module (NameError). The manual U-Net
    # from this module is used instead -- swap in the intended class if it
    # lives elsewhere.
    network = ConditionalUNetManual()
    # The old condition ``isinstance(m, nn.Linear) and 'time_affine'`` was true
    # for every Linear (a non-empty string is truthy), so it zeroed all of them.
    for name, m in network.named_modules():
        if isinstance(m, nn.Linear) and 'time_affine' in name:
            m.weight.data.zero_()
            if m.bias is not None:
                m.bias.data.zero_()
