import torch
import math

import torch.utils.data

import imageio.v3 as imageio
import lightning.pytorch as pl
import matplotlib.pyplot as plt

from models.network_swin import Network
from safetensors.torch import load_file

class PLModule(pl.LightningModule):
    """Lightning wrapper around ``Network`` exposing two ODE samplers.

    Both samplers start from Gaussian noise shaped like ``ridge_map`` and
    integrate the velocity returned by ``self.model`` from t = 0 to t = 1.
    Every tensor created by the samplers is float16 on ``self.device``.
    """

    def __init__(self):
        super().__init__()
        self.model = Network()

    def _init_sampling(self, ridge_map, water_level, num_steps):
        """Validate ``num_steps`` and build the state shared by both samplers.

        Returns:
            Tuple ``(b, x, level)``: the batch size taken from
            ``ridge_map.shape[0]``, the initial noise tensor shaped like
            ``ridge_map``, and ``water_level`` broadcast to shape ``(b,)``.

        Raises:
            ValueError: if ``num_steps`` is smaller than 1. Without any
                integration step the samplers would hand back raw noise.
        """
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")

        device = self.device
        b = ridge_map.shape[0]
        x = torch.randn_like(ridge_map, device=device, dtype=torch.float16)
        level = torch.tensor((water_level,), device=device, dtype=torch.float16).expand(b)
        return b, x, level

    @torch.no_grad()
    def inference_step(self, ridge_map, basin_map, water_level, num_steps=50):
        """Sample with explicit Euler steps on a uniform time grid.

        Args:
            ridge_map: Conditioning tensor; its shape defines the output shape
                and its first dimension is the batch size (the update
                broadcasts a ``(b, 1, 1, 1)`` step size, so presumably
                ``(B, C, H, W)`` -- confirm against ``Network``).
            basin_map: Second conditioning tensor, passed to the model as is.
            water_level: Python scalar broadcast to one value per batch item.
            num_steps: Number of Euler steps (one model call each); >= 1.

        Returns:
            The integrated tensor ``x`` at t = 1 (float16, same shape as
            ``ridge_map``).

        Raises:
            ValueError: if ``num_steps`` is smaller than 1.
        """
        device = self.device
        b, x, water_level = self._init_sampling(ridge_map, water_level, num_steps)
        time = torch.linspace(0, 1, num_steps + 1, device=device, dtype=torch.float16)

        for i in range(num_steps):
            t = torch.full((b,), time[i], device=device, dtype=torch.float16)
            dt = torch.full((b, 1, 1, 1), time[i + 1] - time[i], device=device, dtype=torch.float16)

            v = self.model(x, ridge_map, basin_map, water_level, t)

            x = x + dt * v

        return x

    @torch.no_grad()
    def inference_improved(self, ridge_map, basin_map, water_level, num_steps=50):
        """Sample with Heun's 2nd-order method on a cosine-warped time grid.

        The grid is ``1 - cos(theta)`` for ``theta`` evenly spaced on
        ``[0, pi/2]``, so steps are smallest near t = 0 and grow toward
        t = 1. Each step calls the model twice (predictor and corrector).

        Args:
            ridge_map: Conditioning tensor; its shape defines the output shape
                and its first dimension is the batch size.
            basin_map: Second conditioning tensor, passed to the model as is.
            water_level: Python scalar broadcast to one value per batch item.
            num_steps: Number of Heun steps (two model calls each); >= 1.

        Returns:
            The integrated tensor ``x`` at t = 1 (float16, same shape as
            ``ridge_map``).

        Raises:
            ValueError: if ``num_steps`` is smaller than 1.
        """
        device = self.device
        b, x, water_level = self._init_sampling(ridge_map, water_level, num_steps)

        # Build the schedule in the default dtype first and cast afterwards,
        # so the cosine itself is not evaluated in half precision.
        time = 1 - torch.cos(torch.linspace(0, torch.pi / 2, num_steps + 1, device=device)).to(dtype=torch.float16)

        for i in range(num_steps):
            t = torch.full((b,), time[i], device=device, dtype=torch.float16)
            t_next = torch.full((b,), time[i + 1], device=device, dtype=torch.float16)
            dt = time[i + 1] - time[i]
            dt_tensor = dt.view(1, 1, 1, 1).expand(b, 1, 1, 1)

            # Predictor: plain Euler step using the slope at the current point.
            v_current = self.model(x, ridge_map, basin_map, water_level, t)
            x_pred = x + dt_tensor * v_current
            # Corrector: average the slopes at both ends of the step.
            v_next = self.model(x_pred, ridge_map, basin_map, water_level, t_next)
            x = x + dt_tensor * (v_current + v_next) / 2

        return x

def _load_map(path):
    """Read an image file into a float16 CUDA tensor with two leading singleton dims.

    For a 2-D image of shape (H, W) the result has shape (1, 1, H, W).
    """
    return torch.from_numpy(imageio.imread(path))[None, None, :].to(dtype=torch.float16, device='cuda')


def main():
    """Generate samples for one test tile and save/show a comparison grid.

    Loads the weights from ``FlashScape.safetensors``, samples ``num_images``
    outputs conditioned on the ridge and thresholded basin maps, and writes
    ``result_grid.png`` with the conditions, the ground truth and the samples.
    Requires a CUDA device and the dataset files referenced below.
    """
    torch.manual_seed(0)
    # Alternative: restore the model from a Lightning checkpoint instead, e.g.
    # PLModule.load_from_checkpoint('FlashScape Swin Reference.ckpt').to(device='cuda', dtype=torch.float16)
    model = PLModule()
    model.model.load_state_dict(load_file('FlashScape.safetensors'))
    model.to(device='cuda', dtype=torch.float16)
    model.eval()

    test_ridge = _load_map('dataset_large/Ridge_11417648.tiff')
    test_basin = _load_map('dataset_large/Basins_11417648.tiff')
    # The ground truth is only displayed, so keep it on the CPU in float32
    # rather than rounding it through float16 on the GPU.
    gt_display = torch.from_numpy(imageio.imread('dataset_large/11417648.tiff')).float()
    water_level = 0.0
    num_steps = 50
    num_images = 2

    # Binarize the basin map at the chosen water level, then repeat both
    # conditions along the batch dimension (expand creates views, not copies).
    test_basin = (test_basin >= water_level).to(torch.float16)
    test_ridge = test_ridge.expand(num_images, -1, -1, -1)
    test_basin = test_basin.expand(num_images, -1, -1, -1)
    generated = model.inference_improved(test_ridge, test_basin, water_level, num_steps)
    # Back to original range; done in float32 so the affine rescale is not
    # rounded to half precision.
    generated = generated.float() * 330.8314960521203 + 149.95293407563648

    # Prepare images for visualization
    ridge_display = test_ridge[0, 0].cpu().float()
    basin_display = test_basin[0, 0].cpu().float()
    generated_display = generated[:, 0].cpu()  # Remove channel dim

    # Grid layout: two conditions + ground truth + generated images,
    # at most 6 columns for readability.
    total_images = num_images + 3
    cols = min(6, total_images)
    rows = math.ceil(total_images / cols)

    # Figure size in inches, derived from the grid layout.
    base_height_per_image = 10
    base_width_per_image = 10
    fig_width = cols * base_width_per_image + 0.1  # small extra margin for the colorbar
    fig_height = rows * base_height_per_image

    # squeeze=False always yields a 2-D array of axes, so one ravel() covers
    # every rows/cols combination.
    fig, axes = plt.subplots(rows, cols, figsize=(fig_width, fig_height), squeeze=False)
    axes = axes.ravel()

    # Hide unused subplots
    for ax in axes[total_images:]:
        ax.set_visible(False)

    # Use one colour range for every elevation panel so that the single
    # colorbar is valid for the ground truth and all generated images.
    elevation_panels = [gt_display, *generated_display]
    elev_min = min(panel.min().item() for panel in elevation_panels)
    elev_max = max(panel.max().item() for panel in elevation_panels)
    if not (math.isfinite(elev_min) and math.isfinite(elev_max)):
        # Non-finite values (e.g. NaN in a sample): fall back to autoscaling.
        elev_min = elev_max = None

    # Plot ridge condition
    axes[0].imshow(ridge_display, cmap='gray')
    axes[0].set_title('Ridge Condition', fontsize=12, pad=2)
    axes[0].set_axis_off()

    # Plot basin condition
    axes[1].imshow(basin_display, cmap='gray')
    axes[1].set_title(f'Basin Condition at level {water_level}', fontsize=12, pad=2)
    axes[1].set_axis_off()

    # Plot ground truth image
    im_gt = axes[2].imshow(gt_display, cmap='gray', vmin=elev_min, vmax=elev_max)
    axes[2].set_title('Ground Truth', fontsize=12, pad=2)
    axes[2].set_axis_off()

    # Plot generated images
    for i, sample in enumerate(generated_display):
        ax = axes[i + 3]
        ax.imshow(sample, cmap='gray', vmin=elev_min, vmax=elev_max)
        ax.set_title(f'Generated {i + 1}', fontsize=10, pad=2)
        ax.set_axis_off()

    # Add colorbar; the ground-truth image always exists, even with zero samples.
    cbar = fig.colorbar(im_gt, ax=axes.tolist(), shrink=0.8, location='right')
    cbar.set_label('Elevation', fontsize=14)

    plt.savefig('result_grid.png', bbox_inches='tight', dpi=300)
    plt.show()


if __name__ == "__main__":
    main()