import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def make_dct_pe(patch_size, L):
    """
    Separable cosine basis sampled on the integer pixel grid of a square patch.

    Entry [i * patch_size + j, (k1 - 1) * L + (k2 - 1)] equals
    cos(k1 * i) * cos(k2 * j) for pixel indices 0 <= i, j < patch_size and
    frequencies 1 <= k1, k2 <= L. Phases are raw integer products (no
    pi / patch_size factor), following the paper implementation.

    returns (patch_size ** 2, L ** 2)
    """
    pixel_idx = torch.arange(patch_size)  # n = 0 .. patch_size - 1 (same for rows and cols)
    freq = torch.arange(1, L + 1)         # k = 1 .. L (same for both axes)

    # cos(k * n) for every pixel index n and frequency k -> (patch_size, L).
    # Both axes use identical index/frequency ranges, so one table serves both.
    axis_cos = torch.cos(pixel_idx.unsqueeze(1) * freq.unsqueeze(0))
    # NOTE: scaling the phase by math.pi / patch_size (a true DCT-II grid) may
    # work better, but the paper's formulation is kept here.

    # Outer product over (row pixel, col pixel, row freq, col freq).
    basis = axis_cos[:, None, :, None] * axis_cos[None, :, None, :]
    return basis.reshape(patch_size * patch_size, L * L)


class NerfLayer(nn.Module):
    """
    Per-patch neural-field decoder.

    Each token in ``last_hidden`` (one per patch) is mapped by a linear layer to
    the flattened parameters of a tiny two-layer MLP. That MLP is evaluated at
    every pixel of the patch on the feature ``[x_pixel, dct_pe(pixel)]``. It
    produces one value per pixel, and the patches are stitched back into an
    image with ``F.fold``.
    """

    def __init__(self, hidden_size, patch_size, dtc_size, nf_hidden_size):
        """
        hidden_size: dimension of transformer token used to produce parameters (per patch)
        patch_size: spatial size of patch (e.g. 16)
        dtc_size: number of DCT frequencies per axis (L)
        nf_hidden_size: hidden dim of neural field MLP
        """
        super().__init__()
        self.patch_size = patch_size
        self.dtc_size = dtc_size
        self.nf_hidden_size = nf_hidden_size
        # (patch_size ** 2, dtc_size ** 2); registered as a buffer so it moves
        # with the module on .to()/.cuda() and is stored in the state_dict.
        dtc_pe = make_dct_pe(patch_size, dtc_size)
        self.register_buffer('dtc_pe', dtc_pe)

        # Sizes of the flattened per-patch MLP parameters. The MLP input is one
        # pixel value concatenated with dtc_size ** 2 DCT features.
        self.s_W1 = (dtc_size ** 2 + 1) * nf_hidden_size
        self.s_b1 = nf_hidden_size
        self.s_W2 = nf_hidden_size * 1  # grey-scale: one output per pixel
        self.s_b2 = 1

        parameters_per_patch = self.s_W1 + self.s_b1 + self.s_W2 + self.s_b2

        self.param_generator = nn.Linear(hidden_size, parameters_per_patch)
        self.act = nn.SiLU(inplace=True)

    def forward(self, x, last_hidden, target_hw):
        """
        Decode per-patch tokens into an image.

        :param x: (B, N, patch_size ** 2) per-pixel input values for each of N patches
        :param last_hidden: (B, N, hidden_size) tokens that generate each patch's MLP
        :param target_hw: tuple(H, W); F.fold requires
            (H // patch_size) * (W // patch_size) == N
        :return: (B, 1, H, W)
        :raises ValueError: if x's last dimension is not patch_size ** 2
        """
        batch_size, num_patches, num_pixels = x.shape
        if num_pixels != self.patch_size ** 2:
            raise ValueError(
                f"x last dim must be patch_size ** 2 = {self.patch_size ** 2}, "
                f"got {num_pixels}"
            )
        num_feats = self.dtc_size ** 2 + 1

        # Per-pixel MLP input [pixel value, DCT features] -> (B, N, P**2, dct**2 + 1).
        # Casting the buffer to x's dtype avoids torch.cat silently upcasting x.
        dct_pe = self.dtc_pe.to(dtype=x.dtype).view(1, 1, num_pixels, num_feats - 1)
        feats = torch.cat(
            (x.unsqueeze(-1), dct_pe.expand(batch_size, num_patches, -1, -1)), dim=3
        )

        # Split each patch's generated parameter vector into weights and biases.
        params = self.param_generator(last_hidden)
        end_W1 = self.s_W1
        end_b1 = end_W1 + self.s_b1
        end_W2 = end_b1 + self.s_W2
        W1 = params[..., :end_W1].view(batch_size, -1, num_feats, self.nf_hidden_size)
        b1 = params[..., end_W1:end_b1].view(batch_size, -1, 1, self.nf_hidden_size)
        W2 = params[..., end_b1:end_W2].view(batch_size, -1, self.nf_hidden_size, 1)
        b2 = params[..., end_W2:].view(batch_size, -1, 1, 1)

        # Weight normalization. In the x @ W layout, W is (fan_in, fan_out), so
        # each output unit's incoming weight vector is a column. It is
        # normalized over dim=-2, which is a row of the conventional nn.Linear
        # (out, in) weight. Normalizing over dim=-1 instead would reduce W2
        # (last dim of size 1) to the sign of each entry, losing all magnitude.
        W1 = F.normalize(W1, dim=-2, eps=1e-6)
        W2 = F.normalize(W2, dim=-2, eps=1e-6)

        # Two-layer MLP evaluated at every pixel of every patch.
        out = self.act(torch.matmul(feats, W1) + b1)  # (B, N, P**2, nf_hidden)
        out = torch.matmul(out, W2) + b2              # (B, N, P**2, 1)

        # F.fold expects (B, C * P**2, N) with C = 1 here.
        out = out.squeeze(-1).transpose(1, 2)
        return F.fold(out, target_hw,
                      (self.patch_size, self.patch_size),
                      stride=(self.patch_size, self.patch_size))


if __name__ == '__main__':
    # Smoke test: two 1024x1024 images as 64 * 64 = 4096 patches of 16x16 pixels.
    batch, num_patches, patch = 2, 4096, 16
    token_dim = 128
    layer = NerfLayer(token_dim, patch, 4, 64)
    pixels = torch.randn(batch, num_patches, patch * patch)
    tokens = torch.randn(batch, num_patches, token_dim)
    image = layer(pixels, tokens, (1024, 1024))
    print(image.shape)