import math
import torch
import torch.nn as nn

class SinusoidalEmbedding(nn.Module):
    """Sinusoidal (sin/cos) embedding of scalar values.

    Each input value is multiplied by ``scaling`` and then by a bank of
    ``embedding_dim // 2`` geometrically spaced frequencies
    ``base ** (-i / half_dim)`` for ``i = 0 .. half_dim - 1`` (highest
    frequency 1, lowest approaching ``1 / base``). The output is
    ``[sin(args), cos(args)]`` concatenated along the last dimension.

    Args:
        embedding_dim: Requested embedding size. The output has
            ``2 * (embedding_dim // 2)`` features, so an odd value yields
            ``embedding_dim - 1`` features.
        base: Controls the spread of the frequency bank (see above).
        scaling: Factor applied to the input before the frequencies.
    """

    def __init__(self, embedding_dim=128, base=1000, scaling=1000):
        super().__init__()
        self.embedding_dim = embedding_dim
        half_dim = embedding_dim // 2
        freqs = torch.exp(-math.log(base) * torch.arange(0, half_dim) / half_dim)
        # At base 1000, max-range = ±500π ≈ -1571 to 1571.
        # register_buffer is what assigning an nn.Buffer does under the hood
        # (persistent, follows .to()/.cuda()), and it also works on torch
        # releases that predate nn.Buffer.
        self.register_buffer("scaling", torch.tensor(scaling))
        self.register_buffer("freqs", freqs)

    def forward(self, scaler):
        """Embed ``scaler``.

        Args:
            scaler: Tensor of values to embed. Any shape is accepted,
                including 0-D; a 1-D batch of shape ``(B,)`` gives
                ``(B, 2 * (embedding_dim // 2))``.

        Returns:
            Tensor of shape ``(*scaler.shape, 2 * (embedding_dim // 2))``.
        """
        scaled = scaler * self.scaling
        # Add a trailing axis and broadcast it against the frequency bank.
        # For 1-D input this equals scaler[:, None] * freqs[None].
        args = scaled[..., None] * self.freqs
        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)


class SinusoidalPositionalEmbedding2D(nn.Module):
    """Fixed 2-D sinusoidal positional embedding for a ``height x width`` grid.

    The first ``embedding_dim // 2`` features of each position encode the row
    index and the remaining ``embedding_dim // 2`` encode the column index.
    Within each half, sin and cos of ``pos * div_term[k]`` are interleaved as
    ``[sin_0, cos_0, sin_1, cos_1, ...]``. When ``embedding_dim // 2`` is odd,
    the trailing cos term of each half is dropped so the total size is exactly
    ``embedding_dim``.

    Args:
        embedding_dim: Total embedding size; must be even.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        assert embedding_dim % 2 == 0, "embedding_dim must be even"
        self.embedding_dim = embedding_dim
        half_dim = self.embedding_dim // 2
        # Since our grid size is small, 100 should be enough
        div_term = torch.exp(torch.arange(0, half_dim, 2) * (-math.log(100.0) / half_dim))
        # Equivalent to assigning an nn.Buffer, but also works on older torch.
        self.register_buffer("div_term", div_term.to(torch.float32))

    def _axis_embed(self, length):
        """Return the ``(length, embedding_dim // 2)`` interleaved sin/cos table for one axis."""
        half_dim = self.embedding_dim // 2
        positions = torch.arange(length, dtype=torch.float32, device=self.div_term.device)
        angles = positions[:, None] * self.div_term[None, :]
        # stack on a new last axis then flatten it to interleave sin/cos.
        # flatten(-2) also handles length == 0, where view(length, -1) raises.
        table = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1).flatten(-2)
        # div_term has ceil(half_dim / 2) entries, so for odd half_dim the
        # table has one extra column; trim to exactly half_dim features.
        return table[:, :half_dim]

    def forward(self, height, width):
        """Generate embeddings for a grid of size (height, width).

        Returns:
            Tensor of shape ``(height * width, embedding_dim)`` in row-major
            order: row ``y * width + x`` holds the embedding of cell (y, x).
        """
        y_embed = self._axis_embed(height)
        x_embed = self._axis_embed(width)

        # Broadcast row features across columns and column features across
        # rows, then concatenate them per cell.
        pos_embed = torch.cat([y_embed[:, None, :].expand(-1, width, -1),
                               x_embed[None, :, :].expand(height, -1, -1)], dim=-1)
        return pos_embed.reshape(height * width, self.embedding_dim)