import torch
import torch.nn as nn


class LearnedPositionalEncoding(nn.Module):
    """Add a trainable per-position vector to every sequence in a batch.

    One embedding row is learned for each position ``0 .. max_len - 1``.
    The rows are initialised uniformly in ``[-0.1, 0.1]``.
    """

    def __init__(self, max_len: int, embedding_dim: int):
        """
        Initialize the learned positional encoding module.

        Args:
            max_len (int): Maximum sequence length (number of rows in the
                position table).
            embedding_dim (int): Dimensionality of the embeddings.
        """
        super().__init__()
        self.positional_embeddings = nn.Embedding(max_len, embedding_dim)
        nn.init.uniform_(self.positional_embeddings.weight, -0.1, 0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add positional embeddings to the input embeddings.

        Args:
            x (torch.Tensor): Input embeddings of shape
                (batch_size, seq_len, embedding_dim).

        Returns:
            torch.Tensor: ``x`` with the embedding of position ``i`` added to
            every vector at sequence index ``i``.

        Raises:
            ValueError: If ``seq_len`` is larger than the ``max_len`` the
                module was built with.
        """
        seq_len = x.size(1)
        max_len = self.positional_embeddings.num_embeddings
        # Fail early with a readable message instead of an out-of-range
        # lookup inside the embedding table.
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the maximum supported "
                f"length {max_len}"
            )

        # Shape (1, seq_len): the leading singleton axis lets the looked-up
        # table broadcast over the batch dimension in the addition below.
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        positional_encoding = self.positional_embeddings(positions)
        return x + positional_encoding



class MultiheadAttention(nn.Module):
    """Multi-head self-attention with learned positions and QK normalization.

    The input first receives a learned positional encoding, is projected to
    queries, keys and values by a single linear layer, and the per-head
    queries and keys are RMS-normalized before non-causal scaled dot-product
    attention. The heads are then concatenated and passed through an output
    projection.
    """

    def __init__(self, d_model: int, nhead: int, max_len: int = 4096):
        """
        Args:
            d_model (int): Width of the input and output embeddings.
            nhead (int): Number of attention heads; must divide ``d_model``.
            max_len (int): Size of the learned position table, i.e. the
                longest sequence that can be encoded. Defaults to 4096.

        Raises:
            ValueError: If ``nhead`` is not positive or does not divide
                ``d_model`` evenly.
        """
        super().__init__()
        # Without this check a non-divisible pair only fails later, inside
        # forward(), when the fused projection is reshaped into heads.
        if nhead <= 0 or d_model % nhead != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"nhead ({nhead})"
            )
        self.d_model = d_model
        self.nhead = nhead
        self.head_dim = d_model // nhead
        self.pe = LearnedPositionalEncoding(max_len, d_model)

        # One fused projection produces Q, K and V in a single matmul.
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.q_norm = nn.RMSNorm(self.head_dim)
        self.k_norm = nn.RMSNorm(self.head_dim)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor, grid_size) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input of shape (batch_size, seq_len, d_model).
            grid_size: Not used by this method; accepted so the call
                signature stays the same for existing callers.

        Returns:
            torch.Tensor: Attention output of shape
            (batch_size, seq_len, d_model).
        """
        batch_size, seq_len, _ = x.shape
        x = self.pe(x)

        # Split the fused projection into (3, batch, nhead, seq_len, head_dim)
        # and peel off queries, keys and values along the leading axis.
        qkv = self.qkv_proj(x)
        qkv = qkv.view(batch_size, seq_len, 3, self.nhead, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        # Normalize each head's queries and keys over head_dim.
        q = self.q_norm(q)
        k = self.k_norm(k)

        attn_output = nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=False
        )

        # (batch, nhead, seq_len, head_dim) -> (batch, seq_len, d_model)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(batch_size, seq_len, self.d_model)
        return self.out_proj(attn_output)


class EfficientMultiheadAttention(nn.Module):
    """Multi-head attention computed as ``rho_q(Q) (rho_k(K)^T V)``.

    Instead of forming a (seq_len x seq_len) attention matrix, the keys and
    values are first contracted into a (head_dim x head_dim) context per
    head, which the queries are then multiplied with. ``rho_q`` is a softmax
    over the head dimension of the queries and ``rho_k`` a softmax over the
    sequence dimension of the keys. Queries and keys are RMS-normalized
    before those softmaxes, and a learned positional encoding is added to
    the input first.
    """

    def __init__(self, d_model: int, nhead: int, max_len: int = 4096):
        """
        Args:
            d_model (int): Width of the input and output embeddings.
            nhead (int): Number of attention heads; must divide ``d_model``.
            max_len (int): Size of the learned position table, i.e. the
                longest sequence that can be encoded. Defaults to 4096.

        Raises:
            ValueError: If ``nhead`` is not positive or does not divide
                ``d_model`` evenly.
        """
        super().__init__()
        # Without this check a non-divisible pair only fails later, inside
        # forward(), when the fused projection is reshaped into heads.
        if nhead <= 0 or d_model % nhead != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"nhead ({nhead})"
            )
        self.d_model = d_model
        self.nhead = nhead
        self.head_dim = d_model // nhead

        # One fused projection produces Q, K and V in a single matmul.
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.q_norm = nn.RMSNorm(self.head_dim)
        self.k_norm = nn.RMSNorm(self.head_dim)

        self.softmax_q = nn.Softmax(dim=-1)  # rho_q: softmax along head_dim
        self.softmax_k = nn.Softmax(dim=-2)  # rho_k: softmax along sequence length
        self.out_proj = nn.Linear(d_model, d_model)
        self.pe = LearnedPositionalEncoding(max_len, d_model)

    def forward(self, x: torch.Tensor, grid_size) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input of shape (batch_size, seq_len, d_model).
            grid_size: Not used by this method; accepted so the call
                signature stays the same for existing callers.

        Returns:
            torch.Tensor: Attention output of shape
            (batch_size, seq_len, d_model).
        """
        batch_size, seq_len, _ = x.shape

        x = self.pe(x)

        # Split the fused projection into (3, batch, nhead, seq_len, head_dim)
        # and peel off queries, keys and values along the leading axis.
        qkv = self.qkv_proj(x)
        qkv = qkv.view(batch_size, seq_len, 3, self.nhead, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        # QK normalization is applied before the softmaxes.
        q = self.q_norm(q)
        k = self.k_norm(k)

        q = self.softmax_q(q)  # rho_q: softmax along head_dim
        k = self.softmax_k(k)  # rho_k: softmax along sequence length

        # K^T V has shape (batch, nhead, head_dim, head_dim); multiplying the
        # queries by it yields (batch, nhead, seq_len, head_dim).
        context = torch.matmul(k.transpose(-2, -1), v)
        attn_output = torch.matmul(q, context)

        # (batch, nhead, seq_len, head_dim) -> (batch, seq_len, d_model)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(batch_size, seq_len, self.d_model)
        return self.out_proj(attn_output)

if __name__ == '__main__':
    # Smoke test: push one random batch through the standard attention
    # module on the CPU and print the resulting tensor.
    device = 'cpu'
    batch, tokens, width, heads = 8, 4096, 512, 4

    t_tensor = torch.randn((batch, tokens, width)).to(device)
    grid_size = (64, 64)  # passed through; MultiheadAttention.forward ignores it
    attention = MultiheadAttention(width, heads).to(device)

    output = attention(t_tensor, grid_size)
    print(output)