import torch
import torch.nn as nn
from .modules.sinusoidal import SinusoidalEmbedding
from .modules.dit_blocks import DitBlock
from .modules.nerf import NerfLayer


class Network(nn.Module):
    """Patch-based network: a stack of ``DitBlock`` stages followed by a ``NerfLayer``.

    The input maps are concatenated channel-wise, cut into non-overlapping
    ``patch_size`` x ``patch_size`` patches, and passed through five
    ``DitBlock`` stages. A linear projection between consecutive stages
    halves the per-patch feature width each time. The final per-patch
    features, together with the patches of ``x`` alone, are handed to the
    ``NerfLayer`` head.

    NOTE(review): ``DitBlock``, ``NerfLayer`` and ``SinusoidalEmbedding`` are
    project-local modules; the meaning of their positional arguments is not
    visible from this file.
    """

    def __init__(self, base_ch: int = 3, embd_dim: int = 16) -> None:
        """Build the embeddings, patch extractor, DiT stages and head.

        Args:
            base_ch: Channel count used to size the per-patch feature vector
                (``patch_size**2 * base_ch``). ``forward`` concatenates ``x``,
                ``ridge_map`` and ``basin_map`` along the channel axis, so
                their combined channel count presumably has to equal this
                value -- TODO confirm against ``DitBlock``.
            embd_dim: Size argument passed to each ``SinusoidalEmbedding``.
        """
        super().__init__()
        self.patch_size = 16
        patch_dim = (self.patch_size ** 2) * base_ch

        # One embedding per conditioning scalar; forward() concatenates the
        # two results along dim=1 before handing them to every DiT stage.
        self.time_embd = SinusoidalEmbedding(embd_dim, scaling=1000)
        self.waterlevel_embd = SinusoidalEmbedding(embd_dim, scaling=1)

        # Non-overlapping patches: the stride equals the kernel size.
        self.to_patches = nn.Unfold(
            kernel_size=(self.patch_size, self.patch_size),
            stride=(self.patch_size, self.patch_size),
        )

        # Attribute names and construction order are kept stable so that
        # state-dict keys and RNG-dependent initialisation are unchanged.
        self.dit0 = DitBlock(patch_dim, 8, 2, 2)
        self.dit0t1 = nn.Linear(patch_dim, patch_dim // 2)
        self.dit1 = DitBlock(patch_dim // 2, 4, 3, 2)
        self.dit1t2 = nn.Linear(patch_dim // 2, patch_dim // 4)
        self.dit2 = DitBlock(patch_dim // 4, 2, 4, 2)
        self.dit2t3 = nn.Linear(patch_dim // 4, patch_dim // 8)
        self.dit3 = DitBlock(patch_dim // 8, 1, 5, 2)
        self.dit3t4 = nn.Linear(patch_dim // 8, patch_dim // 16)
        self.dit4 = DitBlock(patch_dim // 16, 1, 6, 2)

        self.nerf = NerfLayer(patch_dim // 16, self.patch_size, 4, 64)

    def initialize(self) -> None:
        """Zero the parameters of selected conditioning / residual layers.

        A sub-module is selected by substring match on its dotted module
        name:

        * ``nn.Linear`` whose name contains ``embd_affine``,
          ``water_level_affine`` or ``emb_proj``;
        * ``nn.Conv2d`` whose name contains ``second_conv``.

        Layers created with ``bias=False`` are supported: only the weight is
        zeroed in that case. This method is not called by ``__init__``;
        callers must invoke it explicitly.
        """
        linear_tags = ('embd_affine', 'water_level_affine', 'emb_proj')
        for name, m in self.named_modules():
            is_tagged_linear = isinstance(m, nn.Linear) and any(
                tag in name for tag in linear_tags
            )
            is_tagged_conv = isinstance(m, nn.Conv2d) and 'second_conv' in name
            if not (is_tagged_linear or is_tagged_conv):
                continue
            nn.init.zeros_(m.weight)
            # ``bias`` is None when the layer was built with bias=False.
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(
        self,
        x: torch.Tensor,
        ridge_map: torch.Tensor,
        basin_map: torch.Tensor,
        water_level: torch.Tensor,
        t: torch.Tensor,
    ):
        """Run the network on one batch.

        Args:
            x: 4-D input map ``[B, C, H, W]``; also passed (as patches) to the
                ``NerfLayer`` head.
            ridge_map: Map concatenated with ``x`` along dim 1.
            basin_map: Map concatenated with ``x`` along dim 1.
            water_level: Input to the water-level embedding.
            t: Input to the time embedding.

        Returns:
            Whatever ``NerfLayer`` returns for the original ``(H, W)`` size --
            presumably a full-resolution tensor; TODO confirm.

        NOTE(review): ``H`` and ``W`` are not checked against ``patch_size``;
        ``nn.Unfold`` drops any remainder rows/columns when they are not
        multiples of it.
        """
        # Conditioning vector shared by every DiT stage.
        embeds = torch.cat(
            [
                self.time_embd(t).to(x.dtype),
                self.waterlevel_embd(water_level).to(x.dtype),
            ],
            dim=1,
        )

        h = torch.cat([x, ridge_map, basin_map], dim=1)
        _, _, height, width = h.shape
        grid_size = (height // self.patch_size, width // self.patch_size)

        # Unfold yields [B, d, num_patches]; move the patch axis to dim 1.
        h = self.to_patches(h).permute(0, 2, 1)
        x_patches = self.to_patches(x).permute(0, 2, 1)

        # Each of the first four stages is followed by its width-halving
        # projection; the last stage feeds the head directly.
        stages = (
            (self.dit0, self.dit0t1),
            (self.dit1, self.dit1t2),
            (self.dit2, self.dit2t3),
            (self.dit3, self.dit3t4),
        )
        for block, shrink in stages:
            h = shrink(block(h, embeds, grid_size))
        h = self.dit4(h, embeds, grid_size)

        return self.nerf(x_patches, h, (height, width))

if __name__ == '__main__':
    # Smoke test: push one random single-channel sample through the default
    # network and print the shape of the result.
    side = 1024
    map_shape = (1, 1, side, side)

    model = Network()
    # Drawn in the same order as the forward() arguments: x, ridge, basin.
    x = torch.randn(*map_shape)
    ridge_map = torch.randn(*map_shape)
    basin_map = torch.randn(*map_shape)
    water_level = torch.tensor((0.0,))
    t = torch.tensor((0.5,))

    result = model(x, ridge_map, basin_map, water_level, t)
    print(result.shape)