import torch
import torch.nn as nn
from .modules.rope import MultiheadAttentionWithRoPE, MultiheadCrossAttentionWithRoPE
from .modules.sinusoidal import SinusoidalEmbedding


class DitLayer(nn.Module):
    """One DiT layer: modulated self-attention, cross-attention, modulated FFN.

    The self-attention and feed-forward sub-blocks are modulated (scale/shift
    before, gate after) by parameters predicted from a conditioning embedding.
    The cross-attention sub-block attends from the token sequence to a
    separately normalized condition sequence and is not modulated.

    Args:
        d_model: Token feature width.
        nhead: Number of attention heads, passed to the RoPE attention modules.
        dim_feedforward: Hidden width of the feed-forward network.
        d_embd: Width of the conditioning embedding given to ``forward``.
            Defaults to 128, the value previously hard-coded here, so existing
            callers and checkpoints are unaffected.
    """

    def __init__(self, d_model, nhead, dim_feedforward, d_embd=128):
        super().__init__()
        self.self_attn = MultiheadAttentionWithRoPE(d_model, nhead)
        self.cross_attn = MultiheadCrossAttentionWithRoPE(d_model, nhead)
        # norm1/norm3 have no learned affine: scale/shift come from embd_affine.
        self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False)
        self.norm2_x = nn.LayerNorm(d_model, elementwise_affine=True)
        self.norm2_c = nn.LayerNorm(d_model, elementwise_affine=True)
        self.norm3 = nn.LayerNorm(d_model, elementwise_affine=False)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.LeakyReLU(inplace=True),
            nn.Linear(dim_feedforward, d_model),
        )
        # Maps the conditioning embedding to 6 * d_model modulation values
        # (two scales, two shifts, two gates).
        self.embd_affine = nn.Sequential(
            nn.Linear(d_embd, d_embd),
            nn.LeakyReLU(inplace=True),
            nn.Linear(d_embd, 6 * d_model),
        )

    @staticmethod
    def _modulate(h, scale, shift):
        """Return ``h * (1 + scale) + shift`` with per-sample params broadcast over dim 1."""
        return h * (1 + scale[:, None, :]) + shift[:, None, :]

    def forward(self, x, condition, t_embd, grid_size):
        """Apply the layer.

        Args:
            x: Token tensor, indexed as [B, N, d_model] (the ``[:, None, :]``
                broadcasting below requires batch-first 3-D input).
            condition: Condition tokens for cross-attention, last dim d_model.
            t_embd: Conditioning embedding [B, d_embd].
            grid_size: Passed through unchanged to both attention modules
                (DitNetwork supplies the (H_p, W_p) patch grid).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Do not reorder: the chunk order fixes which rows of the trained
        # embd_affine output act as which scale/shift/gate.
        scale1, scale2, alpha1, shift1, shift2, alpha2 = (
            self.embd_affine(t_embd).chunk(6, dim=1)
        )

        # Modulated, gated self-attention.
        h = self._modulate(self.norm1(x), scale1, shift1)
        x = x + self.self_attn(h, grid_size) * alpha1[:, None, :]

        # Unmodulated cross-attention to the separately normalized condition.
        h = self.norm2_x(x)
        x = x + self.cross_attn(h, self.norm2_c(condition), grid_size)

        # Modulated, gated feed-forward.
        h = self._modulate(self.norm3(x), scale2, shift2)
        x = x + self.ffn(h) * alpha2[:, None, :]
        return x


class DitBlock(nn.Module):
    """Sequential stack of ``num_layers`` identical-shape DitLayer modules.

    Every layer sees the same condition tokens, conditioning embedding and
    grid size; only the token sequence is carried from one layer to the next.

    Args:
        d_main: Token feature width of every layer.
        nhead: Attention heads per layer.
        num_layers: How many layers to stack.
        ffn_multiplier: Feed-forward hidden width as a multiple of ``d_main``
            (truncated to int).
    """

    def __init__(self, d_main, nhead, num_layers, ffn_multiplier):
        super().__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList(
            DitLayer(d_main, nhead, int(d_main * ffn_multiplier))
            for _ in range(num_layers)
        )

    def forward(self, x, condition, embd, grid_size):
        """Run ``x`` through every layer in order and return the result."""
        hidden = x
        for dit_layer in self.layers:
            hidden = dit_layer(hidden, condition, embd, grid_size)
        return hidden


class DitNetwork(nn.Module):
    """Patch-based DiT over single-channel maps.

    The input map is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches, each embedded as one token. The tokens are processed by a
    DitBlock that cross-attends to tokens built the same way from the
    concatenated ridge/basin maps. The block is conditioned on sinusoidal
    embeddings of ``time`` and ``water_level``. Finally the tokens are folded
    back into a [B, 1, H, W] map.

    Args:
        heads: Attention heads per layer.
        layers: Number of DitLayer modules.
        patch_size: Patch side length; H and W must be multiples of it.
        d_block: Token width inside the transformer.
    """

    def __init__(self, heads=4, layers=12, patch_size=64, d_block=1024):
        super().__init__()
        self.patch_size = patch_size
        self.to_patches = nn.Unfold(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
        )

        d_x = patch_size ** 2  # one input channel -> p*p features per patch
        d_c = d_x * 2  # ridge_map + basin_map: two channels in total
        self.x_enc = nn.Linear(d_x, d_block)
        self.condition_enc = nn.Linear(d_c, d_block)
        # NOTE(review): presumably 64 features each, concatenated to the
        # 128-wide embedding that DitLayer's embd_affine takes by default.
        self.time_embd = SinusoidalEmbedding(64, scaling=1000)
        self.water_level_embd = SinusoidalEmbedding(64, scaling=1)
        self.out_enc = nn.Linear(d_block, d_x)

        self.blocks = DitBlock(d_block, heads, layers, 1)

    def _grid_size(self, H, W):
        """Return the (H_p, W_p) patch grid for an H x W input.

        Raises:
            ValueError: if H or W is not a multiple of ``patch_size``.
                Otherwise Unfold would drop the trailing rows/columns and
                fold would silently zero-fill them in the output.
        """
        p = self.patch_size
        if H % p or W % p:
            raise ValueError(
                f"input size ({H}, {W}) must be divisible by patch_size={p}"
            )
        return H // p, W // p

    def reconstruct(self, x, H, W):
        """Project tokens back to pixels and fold them into an image.

        Args:
            x: Token tensor [B, num_patches, d_block].
            H, W: Spatial size of the output image.

        Returns:
            Tensor [B, 1, H, W].
        """
        x = self.out_enc(x)  # [B, num_patches, d_x]
        x = x.transpose(1, 2)  # [B, d_x, num_patches], the layout fold expects
        return nn.functional.fold(
            x, (H, W),
            kernel_size=(self.patch_size, self.patch_size),
            stride=(self.patch_size, self.patch_size),
        )

    def forward(self, x, ridge_map, basin_map, water_level, time):
        """Run the network.

        Args:
            x: Input map [B, 1, H, W]; H and W must be multiples of patch_size.
            ridge_map, basin_map: Conditioning maps with the same spatial size;
                together they must supply the 2 channels condition_enc expects.
            water_level, time: Per-sample conditioning values fed to the
                sinusoidal embeddings (shape [B] in the demo; TODO confirm).

        Returns:
            Tensor [B, 1, H, W].
        """
        _, _, H, W = x.shape
        grid_size = self._grid_size(H, W)

        embds = torch.cat(
            [self.time_embd(time), self.water_level_embd(water_level)], dim=1
        )

        # Unfold gives [B, C*p*p, L]; tokens need [B, L, C*p*p].
        tokens = self.x_enc(self.to_patches(x).transpose(1, 2))
        condition = torch.cat([ridge_map, basin_map], dim=1)
        cond_tokens = self.condition_enc(self.to_patches(condition).transpose(1, 2))

        tokens = self.blocks(tokens, cond_tokens, embds, grid_size)
        return self.reconstruct(tokens, H, W)

if __name__ == '__main__':
    # Smoke test: one forward pass on random inputs, then check the output shape.
    t_x = torch.randn((4, 1, 1024, 1024))
    t_r = torch.randn((4, 1, 1024, 1024))
    t_b = torch.randn((4, 1, 1024, 1024))
    water_level = torch.tensor((0, 1, 2, 3))
    time = torch.tensor((0.0, 0.1, 0.2, 0.3))
    model = DitNetwork()
    # Inference only: skip storing activations for backprop on a 12-layer model.
    with torch.no_grad():
        output = model(t_x, t_r, t_b, water_level, time)
    if output.shape != t_x.shape:
        raise RuntimeError(
            f"unexpected output shape {tuple(output.shape)}, expected {tuple(t_x.shape)}"
        )
    print(f"output shape: {tuple(output.shape)}")