import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from .modules.sinusoidal import SinusoidalEmbedding, SinusoidalPositionalEmbedding2D
from .modules.rope import MultiheadAttentionWithRoPE

class Down(nn.Module):
    """Conditioned downsampling stage: adaptive instance norm -> SiLU -> strided conv.

    The ``in_ch``-channel input is normalised with a parameter-free instance
    norm, modulated per channel by a scale/shift predicted from a 128-feature
    conditioning vector, passed through SiLU, and finally reduced to
    ``out_ch`` channels at half (floor) the spatial resolution by a 4x4,
    stride-2, replicate-padded convolution.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False, padding_mode='replicate')
        # The norm is applied to the block *input* (before the conv), so it is
        # sized for in_ch. With affine=False it holds no parameters; sizing it
        # with out_ch made PyTorch warn about mismatched num_features on every call.
        self.norm = nn.InstanceNorm2d(in_ch, affine=False)
        self.act = nn.SiLU()
        # Conditioning MLP producing [scale | shift], each in_ch wide.
        self.embd_affine = nn.Sequential(nn.Linear(128, in_ch * 2),
                                         nn.SiLU(inplace=True),
                                         nn.Linear(in_ch * 2, in_ch * 2))

    def forward(self, x, embd):
        """Downsample ``x`` under conditioning ``embd``.

        Args:
            x: [B, in_ch, H, W] feature map.
            embd: [B, 128] conditioning vector.

        Returns:
            [B, out_ch, H // 2, W // 2] feature map.
        """
        scale, shift = self.embd_affine(embd).chunk(2, dim=1)
        x = self.norm(x)
        # Broadcast the per-sample, per-channel modulation over H and W.
        x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
        x = self.act(x)
        return self.conv(x)

class DitLayer(nn.Module):
    """Pre-norm transformer layer with adaptive (adaLN-style) conditioning.

    The conditioning vector (128 features, the input width of the first
    ``embd_affine`` Linear) is mapped to six ``d_model``-wide vectors. Two of
    them scale and two shift the parameter-free LayerNorm outputs that feed the
    attention and feed-forward sub-layers. The remaining two (``alpha``) gate
    each residual branch before it is added back to ``x``.
    """

    def __init__(self, d_model, nhead, dim_feedforward):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False)
        self.attn = MultiheadAttentionWithRoPE(d_model, nhead)
        self.norm2 = nn.LayerNorm(d_model, elementwise_affine=False)

        expand = nn.Linear(d_model, dim_feedforward)
        project = nn.Linear(dim_feedforward, d_model)
        self.ffn = nn.Sequential(expand, nn.SiLU(), project)

        cond_hidden = nn.Linear(128, 128)
        cond_out = nn.Linear(128, 6 * d_model)
        self.embd_affine = nn.Sequential(cond_hidden, nn.SiLU(), cond_out)

    @staticmethod
    def _modulate(t, scale, shift):
        """Return ``t * (1 + scale) + shift`` with scale/shift broadcast over dim 1."""
        return t * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

    def forward(self, x, embd, grid_size):
        """Apply gated attention and gated feed-forward residual updates.

        Args:
            x: token tensor; the modulation vectors are broadcast over dim 1
                (presumably [B, num_tokens, d_model] — TODO confirm).
            embd: conditioning tensor whose affine output is split into six
                chunks along dim 1.
            grid_size: passed through unchanged to the attention module.
        """
        params = self.embd_affine(embd)
        scale1, scale2, alpha1, shift1, shift2, alpha2 = params.chunk(6, dim=1)

        attn_in = self._modulate(self.norm1(x), scale1, shift1)
        x = x + self.attn(attn_in, grid_size) * alpha1.unsqueeze(1)

        ffn_in = self._modulate(self.norm2(x), scale2, shift2)
        x = x + self.ffn(ffn_in) * alpha2.unsqueeze(1)
        return x


class DitBlock(nn.Module):
    """Sequential stack of ``num_layers`` DitLayer modules.

    Every layer shares the same conditioning ``embd`` and ``grid_size``; the
    feed-forward hidden width of each layer is ``int(d_main * ffn_multiplier)``.
    """

    def __init__(self, d_main, nhead, num_layers, ffn_multiplier):
        super().__init__()
        self.num_layers = num_layers
        ffn_width = int(d_main * ffn_multiplier)
        stack = nn.ModuleList()
        for _ in range(num_layers):
            stack.append(DitLayer(d_main, nhead, ffn_width))
        self.layers = stack

    def forward(self, x, embd, grid_size):
        """Run ``x`` through every layer in order and return the final output."""
        hidden = x
        for dit_layer in self.layers:
            hidden = dit_layer(hidden, embd, grid_size)
        return hidden

def run_block(module, *args):
    """Invoke ``module`` on ``args`` positionally and return its result.

    A plain module-level function wrapper, suitable as the function argument
    of ``torch.utils.checkpoint.checkpoint``.
    """
    result = module(*args)
    return result

class DitNetwork(nn.Module):
    """Conditioned conv encoder -> DiT transformer -> transposed-conv decoder.

    ``x``, ``ridge_map`` and ``basin_map`` are concatenated along channels
    (3 channels in total, as required by ``in_enc``). The result is encoded and
    downsampled 32x by five ``Down`` stages to a 128-channel feature map.
    That map is unfolded into one token per cell (``patch_size`` = 1) and
    processed by a DiT stack conditioned on time and water level. Finally,
    ``out_enc`` decodes it back to a single channel at 32x the feature-map
    resolution.

    Args:
        heads: attention heads per DiT layer.
        layers: number of DiT layers.
        use_checkpoint: when True, the transformer stack is run under
            activation checkpointing while the module is in training mode,
            trading recomputation for memory.
    """

    def __init__(self, heads=4, layers=16, use_checkpoint=False):
        super().__init__()
        self.patch_size = 1
        self.to_patches = nn.Unfold(kernel_size=(self.patch_size, self.patch_size), stride=(self.patch_size, self.patch_size))
        self.time_embd = SinusoidalEmbedding(64, scaling=1000)
        self.water_level_embd = SinusoidalEmbedding(64, scaling=1)

        self.in_enc = nn.Conv2d(3, 4, 3, 1, 1, padding_mode='replicate', bias=False)
        self.down_0 = Down(4, 8)      # 1/2 of the input resolution
        self.down_1 = Down(8, 16)     # 1/4
        self.down_2 = Down(16, 32)    # 1/8
        self.down_3 = Down(32, 64)    # 1/16
        self.down_4 = Down(64, 128)   # 1/32

        self.blocks = DitBlock(128, heads, layers, 1)

        # 32x upsampling: out = (n - 1) * 32 - 2 * 16 + 64 = 32 * n.
        self.out_enc = nn.ConvTranspose2d(128, 1, 64, 32, 16)

        self.use_checkpoint = use_checkpoint

    def initialize(self):
        """Zero-initialise the output layer of every conditioning MLP.

        For each ``nn.Sequential`` whose qualified name contains ``'affine'``,
        only its last ``nn.Linear`` is zeroed. The affine outputs therefore
        start at 0: scale/shift modulations are the identity, and the DiT
        residual gates (alpha) close. Zeroing the hidden Linear as well would
        make the hidden activation SiLU(0) = 0. In that case no weight of the
        MLP ever receives gradient and only the final bias can learn, which
        permanently disconnects the conditioning.
        """
        for name, m in self.named_modules():
            if 'affine' not in name or not isinstance(m, nn.Sequential):
                continue
            linears = [child for child in m if isinstance(child, nn.Linear)]
            if linears:
                nn.init.zeros_(linears[-1].weight)
                nn.init.zeros_(linears[-1].bias)

    def reconstruct(self, x, H, W):
        """Fold tokens back into an (H, W) feature map and decode it.

        Args:
            x: [B, num_patches, d_main * patch_size**2] tokens.
            H, W: spatial size of the feature map the tokens were unfolded from.

        Returns:
            Output of ``out_enc``: [B, 1, 32 * H, 32 * W].
        """
        x = x.permute(0, 2, 1)  # [B, d_main * patch_size**2, num_patches]
        x = nn.functional.fold(x, (H, W), (self.patch_size, self.patch_size), stride=(self.patch_size, self.patch_size))
        return self.out_enc(x)

    def forward(self, x, ridge_map, basin_map, water_level, time):
        """Predict a single-channel map from the inputs and conditioning scalars.

        Args:
            x, ridge_map, basin_map: [B, C_i, H, W] maps whose channel counts
                must sum to 3 (``in_enc`` input channels).
            water_level, time: conditioning values embedded by the sinusoidal
                embeddings (presumably shape [B] — TODO confirm).

        Returns:
            [B, 1, 32 * (H // 32), 32 * (W // 32)] tensor (H x W when both
            are multiples of 32).
        """
        x = torch.cat([x, ridge_map, basin_map], dim=1)  # concat condition maps
        t_embd = self.time_embd(time)
        w_embd = self.water_level_embd(water_level)
        # Shared conditioning vector; every affine MLP's first Linear expects 128 features.
        embds = torch.cat([t_embd, w_embd], dim=1)

        x = self.in_enc(x)
        for down in (self.down_0, self.down_1, self.down_2, self.down_3, self.down_4):
            x = down(x, embds)

        # The token grid is the downsampled feature map, not the input
        # resolution: it is the layout the tokens are unfolded from.
        feat_h, feat_w = x.shape[-2:]
        grid_size = (feat_h // self.patch_size, feat_w // self.patch_size)
        x = self.to_patches(x).permute(0, 2, 1)  # [B, num_patches, d_main * patch_size**2]

        if self.use_checkpoint and self.training:
            x = checkpoint(run_block, self.blocks, x, embds, grid_size, use_reentrant=False)
        else:
            x = self.blocks(x, embds, grid_size)
        return self.reconstruct(x, feat_h, feat_w)



# Example usage
if __name__ == "__main__":
    # Smoke test: build the network, apply its initialization and run one CPU
    # forward pass. This module uses relative imports, so launch it from its
    # package (python -m <package>.<module>) rather than as a plain script.
    network = DitNetwork()
    # The conditioning MLPs in this file are registered as `embd_affine`;
    # initialize() zeroes their output layers.
    network.initialize()
    network.eval()

    batch, height, width = 2, 256, 256  # H and W should be multiples of 32
    # in_enc takes 3 channels after concatenation; one channel per input is
    # assumed here — TODO confirm the intended split.
    x = torch.randn(batch, 1, height, width)
    ridge_map = torch.randn(batch, 1, height, width)
    basin_map = torch.randn(batch, 1, height, width)
    # Presumably one scalar per sample for each embedding — TODO confirm.
    water_level = torch.rand(batch)
    timestep = torch.rand(batch)
    with torch.no_grad():
        y = network(x, ridge_map, basin_map, water_level, timestep)
    print(tuple(y.shape))