import torch
import torch.nn as nn

class ImageLinearAttention(nn.Module):
    """Multi-head linear attention over 2D feature maps with a residual output.

    Queries, keys and values come from ``kernel_size`` convolutions with
    replicate "same" padding. They are normalized with a single-group
    GroupNorm and optionally modulated by a projected conditioning embedding.
    They are then combined with linear attention: keys are softmaxed over the
    ``h * w`` positions and contracted with values before meeting the queries.
    The result is projected back to ``chan`` channels and added to the input.

    Args:
        chan: Number of input and output channels.
        kernel_size: Kernel size of the Q/K/V and output convolutions.
        heads: Number of attention heads. Each head has ``chan // heads``
            channels, so the internal attention width is
            ``heads * (chan // heads)``.
        norm_queries: If True, softmax the queries over the per-head
            feature dimension.
        embd_dim: If given, size of a conditioning embedding that is linearly
            projected to a per-channel scale and shift. These are applied
            after a non-affine GroupNorm. If None, the GroupNorm carries its
            own learnable affine parameters instead.

    Raises:
        ValueError: If ``heads < 1`` or ``chan < heads`` (each head would get
            zero channels).
    """

    def __init__(self, chan, kernel_size=3, heads=4, norm_queries=True, embd_dim=None):
        super().__init__()
        if heads < 1 or chan < heads:
            raise ValueError(
                f"need 1 <= heads <= chan, got chan={chan}, heads={heads}"
            )
        self.chan = chan
        self.heads = heads
        self.key_dim = key_dim = chan // heads
        self.value_dim = value_dim = chan // heads
        self.norm_queries = norm_queries

        # Convolutional projections for Q, K, V and the output.
        self.to_q = nn.Conv2d(chan, key_dim * heads, kernel_size, padding='same', padding_mode='replicate')
        self.to_k = nn.Conv2d(chan, key_dim * heads, kernel_size, padding='same', padding_mode='replicate')
        self.to_v = nn.Conv2d(chan, value_dim * heads, kernel_size, padding='same', padding_mode='replicate')
        self.to_out = nn.Conv2d(value_dim * heads, chan, kernel_size, padding='same', padding_mode='replicate')

        # One GroupNorm instance is shared by Q, K and V. It is sized with
        # key_dim * heads, which equals value_dim * heads here.
        if embd_dim is not None:
            # Adaptive norm: the affine part comes from the embedding projection.
            self.norm = nn.GroupNorm(1, key_dim * heads, affine=False)
            self.emb_proj = nn.Linear(embd_dim, 2 * key_dim * heads)  # -> (scale, shift)
        else:
            self.norm = nn.GroupNorm(1, key_dim * heads, affine=True)
            self.emb_proj = None

    def forward(self, x, emb=None):
        """Return ``x`` plus the attention output.

        Args:
            x: Input of shape ``(b, chan, h, w)``.
            emb: Optional conditioning embedding. Its projection must reshape
                to ``(b, 2, heads * key_dim)``, so typically it is
                ``(b, embd_dim)``. It is ignored when the module was built
                without ``embd_dim``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        b, _, h, w = x.shape
        heads, key_dim, value_dim = self.heads, self.key_dim, self.value_dim

        # Always normalize. Previously the norm only ran on the adaptive path,
        # which left the affine GroupNorm's parameters unused when embd_dim
        # was None.
        q = self.norm(self.to_q(x))
        k = self.norm(self.to_k(x))
        v = self.norm(self.to_v(x))

        # FiLM-style modulation: the same per-channel scale/shift is broadcast
        # over all spatial positions of Q, K and V.
        if emb is not None and self.emb_proj is not None:
            scale, shift = self.emb_proj(emb).view(b, 2, -1).unbind(dim=1)
            gain = 1 + scale[:, :, None, None]
            bias = shift[:, :, None, None]
            q = q * gain + bias
            k = k * gain + bias
            v = v * gain + bias

        # Split channels into heads and flatten space: (b, heads, dim, h*w).
        q = q.view(b, heads, key_dim, h * w)
        k = k.view(b, heads, key_dim, h * w)
        v = v.view(b, heads, value_dim, h * w)

        # Split the usual 1/sqrt(key_dim) scaling evenly between Q and K.
        q = q * (key_dim ** -0.25)
        k = k * (key_dim ** -0.25)

        # Keys: softmax over positions. Queries: softmax over features.
        k = k.softmax(dim=-1)
        if self.norm_queries:
            q = q.softmax(dim=-2)

        # Contract keys with values first (dim x dim per head), so the cost is
        # linear in the number of positions rather than quadratic.
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhdn,bhde->bhen', q, context)
        out = out.reshape(b, -1, h, w)
        out = self.to_out(out)
        return x + out