import torch
import torch.nn as nn
from .res_blocks import ResConvBlock
from .sinusoidal import SinusoidalEmbedding
from .linear_attention import ImageLinearAttention
from .rope import MultiheadAttentionWithRoPE, EfficientMultiheadAttentionWithRoPE
from .fnet import FNetBlock


class DitLayer(nn.Module):
    """Transformer layer whose normalised activations are modulated by an embedding.

    A single linear projection of ``embd`` yields six ``d_model``-wide vectors:
    a scale, a shift and an output gate for the token-mixing sub-block, and
    the same three for the feed-forward sub-block. Each sub-block is applied
    as ``x + gate * f(norm(x) * (1 + scale) + shift)``.

    Args:
        d_model: Feature width of the layer.
        nhead: Head count handed to the RoPE attention modules (not used by
            the ``'fnet'`` mixer).
        dim_feedforward: Hidden width of the feed-forward network.
        embd_dim: Width of the conditioning embedding.
        type: Token mixer to use: ``'eff'``
            (``EfficientMultiheadAttentionWithRoPE``), ``'vanilla'``
            (``MultiheadAttentionWithRoPE``) or ``'fnet'`` (``FNetBlock``).

    Raises:
        ValueError: If ``type`` is not one of the supported mixers.
    """

    _ATTN_TYPES = ('eff', 'vanilla', 'fnet')

    def __init__(self, d_model, nhead, dim_feedforward, embd_dim, type='eff'):
        super().__init__()
        # Fail fast: without this check an unknown type would leave `self.attn`
        # undefined and only surface as an AttributeError inside forward().
        if type not in self._ATTN_TYPES:
            raise ValueError(
                f"Unknown DitLayer type {type!r}; expected one of {self._ATTN_TYPES}"
            )
        self.type = type
        # Norms carry no learnable affine; scale/shift come from `embd` instead.
        self.norm1 = nn.RMSNorm(d_model, elementwise_affine=False)
        if type == 'eff':
            self.attn = EfficientMultiheadAttentionWithRoPE(d_model, nhead)
        elif type == 'vanilla':
            self.attn = MultiheadAttentionWithRoPE(d_model, nhead)
        else:  # 'fnet' (guaranteed by the validation above)
            self.attn = FNetBlock()
        self.norm2 = nn.RMSNorm(d_model, elementwise_affine=False)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.SiLU(inplace=True),
            nn.Linear(dim_feedforward, d_model),
        )
        # One projection produces all six modulation vectors at once.
        self.embd_affine = nn.Linear(embd_dim, 6*d_model)

    def forward(self, x, embd, grid_size):
        """Apply the modulated token-mixing and feed-forward sub-blocks.

        Args:
            x: Input activations; presumably ``(batch, tokens, d_model)`` so
                that the ``(batch, 1, d_model)`` modulation terms broadcast
                over the token axis -- TODO confirm against callers.
            embd: Conditioning embedding; presumably ``(batch, embd_dim)``,
                since its projection is chunked along ``dim=1``.
            grid_size: Forwarded unchanged to the attention module; ignored
                when ``type == 'fnet'``.

        Returns:
            The result of adding both gated sub-block outputs to ``x``.
        """
        affine_params = self.embd_affine(embd)
        # The slot order below fixes which slice of the projection plays which
        # role; it must stay as-is for trained weights to keep their meaning.
        scale1, scale2, alpha1, shift1, shift2, alpha2 = affine_params.chunk(6, dim=1)

        # Token-mixing sub-block
        h = self.norm1(x)
        h = h * (1 + scale1[:, None, :]) + shift1[:, None, :]
        if self.type != 'fnet':
            attn_output = self.attn(h, grid_size)
        else:
            # FNetBlock is called without the grid size.
            attn_output = self.attn(h)
        attn_output = attn_output * alpha1[:, None, :]
        x = x + attn_output

        # Feedforward sub-block
        h = self.norm2(x)
        h = h * (1 + scale2[:, None, :]) + shift2[:, None, :]
        ffn_output = self.ffn(h)
        ffn_output = ffn_output * alpha2[:, None, :]
        x = x + ffn_output
        return x


class DitBlock(nn.Module):
    """Sequential stack of ``DitLayer`` modules sharing one conditioning embedding.

    The first two layers are built with the ``'vanilla'`` mixer and every
    later layer with the ``'eff'`` mixer. Each layer's feed-forward width is
    ``int(d_main * ffn_multiplier)``.
    """

    def __init__(self, d_main, nhead, num_layers, ffn_multiplier, embd_dim):
        super().__init__()
        self.num_layers = num_layers
        stack = nn.ModuleList()
        for idx in range(num_layers):
            if idx < 2:
                mixer = 'vanilla'
            else:
                mixer = 'eff'
            stack.append(
                DitLayer(d_main, nhead, int(d_main * ffn_multiplier), embd_dim, mixer)
            )
        self.layers = stack

    def forward(self, x, embd, grid_size):
        """Run ``x`` through every layer in order and return the final output.

        ``embd`` and ``grid_size`` are passed unchanged to each layer.
        """
        out = x
        for dit_layer in self.layers:
            out = dit_layer(out, embd, grid_size)
        return out



class LinearAttnBlock(nn.Module):
    """Sequential stack of ``ImageLinearAttention`` modules.

    All ``num_layers`` layers are constructed with the same
    ``(channel, kernel_size, heads, embd_dim)`` arguments.
    """

    def __init__(self, channel, kernel_size, heads, embd_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        stack = nn.ModuleList()
        for _ in range(num_layers):
            stack.append(ImageLinearAttention(channel, kernel_size, heads, embd_dim))
        self.layers = stack

    def forward(self, x, embd):
        """Run ``x`` through every layer in order, passing ``embd`` to each."""
        out = x
        for attn_layer in self.layers:
            out = attn_layer(out, embd)
        return out