import math
import torch
import torch.nn as nn
import imageio.v3 as imageio


class RotaryEmbedding2D(nn.Module):
    """Sin/cos tables for axial 2D rotary position embeddings (RoPE).

    Half of the head dimension encodes the row (height) position and the other
    half the column (width) position. Within each half, channel ``i`` is meant
    to be rotated together with channel ``i + dim // 4`` (the pairing produced
    by ``rotate_half``), so both channels of a pair are given the same
    frequency; this makes the transform a true, norm-preserving rotation.

    Args:
        dim: head dimension; must be divisible by 4 (two axes, each made of
            2-channel rotation pairs).
        base: frequency base. Pair ``j`` rotates by ``base ** (-4 * j / dim)``
            radians per grid step, so the fastest pair advances 1 rad per step
            and the slowest approaches ``1 / base`` rad per step.

    ``inv_freq`` is registered as a (persistent) buffer, so it is saved to and
    restored from ``state_dict``.
    """

    def __init__(self, dim, base=50.0):
        super().__init__()
        if dim % 4 != 0:
            raise ValueError(
                f"head dim must be divisible by 4 (2 axes x 2-channel rotation pairs), got {dim}"
            )
        self.dim = dim
        self.half_dim = dim // 2  # channels per axis (height or width)

        # One frequency per rotation pair, duplicated so that channel i and
        # channel i + half_dim // 2 (paired by rotate_half) share an angle.
        pair_freq = 1.0 / (
            base ** (torch.arange(0, self.half_dim, 2).float() / self.half_dim)
        )
        inv_freq = torch.cat([pair_freq, pair_freq])  # (half_dim,)
        self.register_buffer("inv_freq", inv_freq)

    def reshape(self, x):
        """Flatten an ``(H, W, C)`` grid to ``(H * W, C)``, token index ``h * W + w``."""
        return x.reshape(-1, x.shape[-1])

    def forward(self, grid_size):
        """Build the rotation tables for an ``H x W`` token grid.

        Args:
            grid_size: ``(H, W)``.

        Returns:
            ``((sin_h, cos_h), (sin_w, cos_w))``, each of shape
            ``(H * W, dim // 2)`` in row-major token order. The ``_h`` tables
            depend only on the row index, the ``_w`` tables only on the column.
        """
        H, W = grid_size
        device = self.inv_freq.device
        dtype = self.inv_freq.dtype

        y_pos = torch.arange(H, device=device, dtype=dtype)
        x_pos = torch.arange(W, device=device, dtype=dtype)

        # Rotation angle per row / per column: (H, half_dim) and (W, half_dim).
        angle_h = y_pos[:, None] * self.inv_freq[None, :]
        angle_w = x_pos[:, None] * self.inv_freq[None, :]

        # Broadcast over the full grid, then flatten tokens row-major.
        angle_h = self.reshape(angle_h[:, None, :].expand(H, W, self.half_dim))
        angle_w = self.reshape(angle_w[None, :, :].expand(H, W, self.half_dim))

        return (angle_h.sin(), angle_h.cos()), (angle_w.sin(), angle_w.cos())


def rotate_half(x):
    """Return ``[-back, front]`` where ``front, back`` split the last axis of *x*.

    The split is done with ``torch.chunk`` (so for an odd size the front part
    is the larger one). For an even size this maps every pair
    ``(x[i], x[i + n // 2])`` to ``(-x[i + n // 2], x[i])``.
    """
    front, back = torch.chunk(x, 2, dim=-1)
    return torch.cat([back.neg(), front], dim=-1)


def apply_rotary_pos_emb_split(x, sin_h, cos_h, sin_w, cos_w):
    """Apply axial 2D rotary embedding to *x*.

    The last dimension of ``x`` is split into two equal halves:

    - the first half is rotated with the height (row) tables ``sin_h``/``cos_h``;
    - the second half is rotated with the width (column) tables ``sin_w``/``cos_w``.

    Args:
        x: tensor of shape ``(..., seq_len, dim)`` with even ``dim``; any number
            of leading dimensions (e.g. ``(batch, heads)``) is accepted.
        sin_h, cos_h, sin_w, cos_w: tables broadcastable against
            ``(..., seq_len, dim // 2)``, e.g. of shape ``(seq_len, dim // 2)``.

    Returns:
        A tensor with the same shape as ``x``.

    Raises:
        ValueError: if the last dimension of ``x`` is odd.
    """
    dim = x.shape[-1]
    if dim % 2 != 0:
        # Unequal halves would not line up with the (dim // 2)-wide tables and
        # can even broadcast silently when dim // 2 == 1.
        raise ValueError(f"last dimension of x must be even, got {dim}")
    half_dim = dim // 2

    x_h = x[..., :half_dim]  # rotated by row position
    x_w = x[..., half_dim:]  # rotated by column position

    x_h_rotated = x_h * cos_h + rotate_half(x_h) * sin_h
    x_w_rotated = x_w * cos_w + rotate_half(x_w) * sin_w

    return torch.cat([x_h_rotated, x_w_rotated], dim=-1)


class MultiheadAttentionWithRoPE(nn.Module):
    """Multi-head self-attention with QK-RMSNorm and axial 2D rotary embeddings.

    Args:
        d_model: model width; must be divisible by ``nhead``.
        nhead: number of attention heads.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``nhead``.
    """

    def __init__(self, d_model, nhead):
        super().__init__()
        if d_model % nhead != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by nhead ({nhead})"
            )
        self.d_model = d_model
        self.nhead = nhead
        self.head_dim = d_model // nhead

        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.q_norm = nn.RMSNorm(self.head_dim)
        self.k_norm = nn.RMSNorm(self.head_dim)
        self.out_proj = nn.Linear(d_model, d_model)
        self.rotary_embedding = RotaryEmbedding2D(self.head_dim)

    def forward(self, x, grid_size):
        """Attend over tokens laid out on an ``H x W`` grid.

        Args:
            x: ``(batch, seq_len, d_model)`` tokens. The rotary tables index
                token ``i`` as grid position ``(i // W, i % W)``, so ``x`` must
                be flattened row-major.
            grid_size: ``(H, W)`` with ``H * W == seq_len``.

        Returns:
            ``(batch, seq_len, d_model)`` tensor.

        Raises:
            ValueError: if ``H * W != seq_len``.
        """
        batch_size, seq_len, _ = x.shape
        height, width = grid_size
        if height * width != seq_len:
            raise ValueError(
                f"grid_size {tuple(grid_size)} covers {height * width} tokens, "
                f"but x has seq_len={seq_len}"
            )

        (sin_h, cos_h), (sin_w, cos_w) = self.rotary_embedding(grid_size)

        qkv = self.qkv_proj(x)
        qkv = qkv.view(batch_size, seq_len, 3, self.nhead, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch_size, nhead, seq_len, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # QK-norm first, then RoPE. RMSNorm's learned per-channel gain applied
        # after the rotation would rescale the two channels of a rotated pair
        # unevenly, so q . k would no longer depend only on relative position.
        q = self.q_norm(q)
        k = self.k_norm(k)
        q = apply_rotary_pos_emb_split(q, sin_h, cos_h, sin_w, cos_w)
        k = apply_rotary_pos_emb_split(k, sin_h, cos_h, sin_w, cos_w)

        attn_output = nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=False
        )

        # (batch, nhead, seq, head_dim) -> (batch, seq, d_model)
        attn_output = attn_output.transpose(1, 2).reshape(
            batch_size, seq_len, self.d_model
        )
        return self.out_proj(attn_output)


class EfficientMultiheadAttentionWithRoPE(nn.Module):
    """Linear-complexity multi-head attention with QK-RMSNorm and 2D RoPE.

    Per head it computes ``rho_q(Q) @ (rho_k(K)^T @ V)``, where ``rho_q`` is a
    softmax over the head dimension and ``rho_k`` a softmax over the sequence.
    The ``(head_dim, head_dim)`` context matrix is formed first, so the cost is
    linear in sequence length instead of quadratic.

    Args:
        d_model: model width; must be divisible by ``nhead``.
        nhead: number of attention heads.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``nhead``.
    """

    def __init__(self, d_model, nhead):
        super().__init__()
        if d_model % nhead != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by nhead ({nhead})"
            )
        self.d_model = d_model
        self.nhead = nhead
        self.head_dim = d_model // nhead

        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.q_norm = nn.RMSNorm(self.head_dim)
        self.k_norm = nn.RMSNorm(self.head_dim)

        self.softmax_q = nn.Softmax(dim=-1)  # rho_q: softmax along head_dim
        self.softmax_k = nn.Softmax(dim=-2)  # rho_k: softmax along sequence length
        self.out_proj = nn.Linear(d_model, d_model)
        self.rotary_embedding = RotaryEmbedding2D(self.head_dim)

    def forward(self, x, grid_size):
        """Attend over tokens laid out on an ``H x W`` grid.

        Args:
            x: ``(batch, seq_len, d_model)`` tokens. The rotary tables index
                token ``i`` as grid position ``(i // W, i % W)``, so ``x`` must
                be flattened row-major.
            grid_size: ``(H, W)`` with ``H * W == seq_len``.

        Returns:
            ``(batch, seq_len, d_model)`` tensor.

        Raises:
            ValueError: if ``H * W != seq_len``.
        """
        batch_size, seq_len, _ = x.shape
        height, width = grid_size
        if height * width != seq_len:
            raise ValueError(
                f"grid_size {tuple(grid_size)} covers {height * width} tokens, "
                f"but x has seq_len={seq_len}"
            )

        (sin_h, cos_h), (sin_w, cos_w) = self.rotary_embedding(grid_size)

        qkv = self.qkv_proj(x)
        qkv = qkv.view(batch_size, seq_len, 3, self.nhead, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch_size, nhead, seq_len, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Standard QK-norm -> RoPE order: RMSNorm's learned per-channel gain acts
        # on unrotated features instead of rescaling rotated pairs unevenly.
        q = self.q_norm(q)
        k = self.k_norm(k)
        q = apply_rotary_pos_emb_split(q, sin_h, cos_h, sin_w, cos_w)
        k = apply_rotary_pos_emb_split(k, sin_h, cos_h, sin_w, cos_w)

        q = self.softmax_q(q)  # rho_q: softmax along head_dim
        k = self.softmax_k(k)  # rho_k: softmax along sequence length

        # rho_q(Q) (rho_k(K)^T V): the context is (batch, nhead, head_dim, head_dim)
        context = torch.matmul(k.transpose(-2, -1), v)
        attn_output = torch.matmul(q, context)  # (batch, nhead, seq_len, head_dim)

        # (batch, nhead, seq, head_dim) -> (batch, seq, d_model)
        attn_output = attn_output.transpose(1, 2).reshape(
            batch_size, seq_len, self.d_model
        )
        return self.out_proj(attn_output)


if __name__ == '__main__':
    # Visualise the 2D rotary pattern: rotate a constant feature vector at every
    # position of a grid_h x grid_w token grid, then tile each token's head_dim
    # values as a patch x patch square of an output image.
    grid_h, grid_w = 32, 32
    head_dim = 64
    patch = 8  # fold below needs patch * patch == head_dim (one output channel)

    # RotaryEmbedding2D takes only the head dim; the previous
    # `interleave=False` argument raised TypeError.
    rotation_embedding = RotaryEmbedding2D(head_dim)
    (sin_h, cos_h), (sin_w, cos_w) = rotation_embedding((grid_h, grid_w))

    tokens = torch.full((1, 1, grid_h * grid_w, head_dim), 5.0)  # (batch, heads, seq, dim)
    tokens = apply_rotary_pos_emb_split(tokens, sin_h, cos_h, sin_w, cos_w)
    tokens = tokens.squeeze(1)  # (1, seq, head_dim)

    image = torch.nn.functional.fold(
        tokens.permute(0, 2, 1),  # (1, head_dim, seq) as fold expects
        (grid_h * patch, grid_w * patch),
        kernel_size=(patch, patch),
        stride=(patch, patch),
    ).squeeze()
    imageio.imwrite('filtered.tiff', image.numpy())
