import torch
import math

import data_utils
import torch.utils.data

import imageio.v3 as imageio
import lightning.pytorch as pl

#from models.network_diffusion_unet import ConditionalUNetManual
from models.network_swin import Network
#from models.network_hybrid import NetworkDeep
from pytorch_optimizer.optimizer import AdaMuon
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.callbacks import LearningRateMonitor, StochasticWeightAveraging
from lightning.pytorch.utilities import grad_norm


def convert_uniform_to_custom(u):
    """Map uniform samples ``u`` in [0, 1] to samples distributed as Beta(2, 2).

    This is the closed-form inverse CDF of the density ``6 t (1 - t)``. Its CDF,
    ``3 t^2 - 2 t^3``, is a cubic whose root is found with the trigonometric
    formula below. The result is concentrated around 0.5 and thins out towards
    the endpoints, with ``0 -> 0``, ``0.5 -> 0.5`` and ``1 -> 1``.

    Args:
        u: Tensor of uniform samples. Values outside [0, 1] make ``acos``
            return NaN.

    Returns:
        Tensor with the same shape as ``u`` and values in [0, 1].
    """
    angle = (1 / 3) * torch.acos(1 - 2 * u) + math.pi / 3
    return 0.5 - torch.cos(angle)
    # Alternative sampler kept for experimentation:
    #return 0.5 + 2 * torch.cos((2 * math.pi - torch.arccos((11/16)*(1-2*u)))/3)


def get_parameter_groups(model, weight_decay=0.0001):
    """Split trainable parameters into a weight-decayed group and a non-decayed group.

    A parameter is exempt from decay when its qualified name contains any of the
    substrings ``bias``, ``bn``, ``batch_norm``, ``layer_norm`` or ``norm``.
    Parameters with ``requires_grad=False`` are left out entirely.

    Args:
        model: Module whose ``named_parameters()`` are partitioned.
        weight_decay: Decay assigned to the first (decayed) group.

    Returns:
        ``[decayed_group, exempt_group]`` as optimizer param-group dicts. The
        exempt group always has ``weight_decay=0.0``. Within each group,
        parameters keep their ``named_parameters()`` order.
    """
    # NOTE(review): plain substring matching also catches unrelated names that
    # merely contain e.g. "bn" or "norm".
    exempt_markers = ("bias", "bn", "batch_norm", "layer_norm", "norm")
    trainable = [(name, p) for name, p in model.named_parameters() if p.requires_grad]

    def is_exempt(name):
        return any(marker in name for marker in exempt_markers)

    return [
        {'params': [p for name, p in trainable if not is_exempt(name)], 'weight_decay': weight_decay},
        {'params': [p for name, p in trainable if is_exempt(name)], 'weight_decay': 0.0},
    ]


def get_parameter_groups_withmuon(model, weight_decay=0.0001):
    """Partition trainable parameters along two axes: weight decay and Muon eligibility.

    * Decay: a parameter is exempt when its name contains ``bias``, ``bn``,
      ``batch_norm``, ``layer_norm`` or ``norm``.
    * Muon: only parameters with exactly two dimensions are flagged
      ``use_muon=True``. Everything else (biases, norm scales, conv kernels
      with ndim != 2) is flagged ``use_muon=False`` and gets ``betas=(0.9, 0.95)``.

    Parameters with ``requires_grad=False`` are skipped.

    Args:
        model: Module whose ``named_parameters()`` are partitioned.
        weight_decay: Decay assigned to the decayed groups.

    Returns:
        Four param-group dicts in this fixed order: decay+Muon, decay+non-Muon,
        no-decay+Muon, no-decay+non-Muon. Any of them may have an empty
        ``params`` list.
    """
    no_decay_markers = ("bias", "bn", "batch_norm", "layer_norm", "norm")
    # Keyed by (applies_decay, uses_muon).
    buckets = {(True, True): [], (True, False): [], (False, True): [], (False, False): []}

    for param_name, tensor in model.named_parameters():
        if not tensor.requires_grad:
            continue
        applies_decay = all(marker not in param_name for marker in no_decay_markers)
        buckets[(applies_decay, tensor.ndim == 2)].append(tensor)

    return [
        {'params': buckets[(True, True)], 'weight_decay': weight_decay, 'use_muon': True},
        {'params': buckets[(True, False)], 'betas': (0.9, 0.95), 'weight_decay': weight_decay, 'use_muon': False},
        {'params': buckets[(False, True)], 'weight_decay': 0.0, 'use_muon': True},
        {'params': buckets[(False, False)], 'betas': (0.9, 0.95), 'weight_decay': 0.0, 'use_muon': False},
    ]

class PLModule(pl.LightningModule):
    """LightningModule that trains ``Network`` with a straight-path flow-matching loss.

    For a clean batch ``x0`` and Gaussian ``noise``, the model input is the
    linear interpolation ``xt = t * x0 + (1 - t) * noise``, so ``t = 0`` is pure
    noise and ``t = 1`` is data. The regression target is the constant velocity
    ``v = x0 - noise``, compared with an L1 loss. Samples are generated by
    Euler-integrating the predicted velocity from ``t = 0`` to ``t = 1``.

    At the end of every training epoch, one sample conditioned on a fixed
    ridge/basin pair is generated and written to ``self.logger.experiment``.
    That logger is assumed to be TensorBoard-style, since ``add_scalar`` and
    ``add_image`` are used.

    Args:
        mid_visual_ridge: Path of the ridge-map image used for the per-epoch
            visualisation.
        mid_visual_basins: Path of the basin-map image used for the per-epoch
            visualisation. It is binarised at sea level 0.0.
        mid_visual_gt: Path of the ground-truth image, logged once at epoch 0.
        lr: Passed to ``AdaMuon`` as ``lr``. ``adamw_lr`` is passed separately
            as 3e-4.
    """

    def __init__(self, mid_visual_ridge, mid_visual_basins, mid_visual_gt, lr=1e-2):
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        # Weight decay for every decayed parameter group.
        # FlashScape Conv reference 0.002
        self.wd = 0.0001
        self.model = Network()
        self.model.initialize()
        #self.map_average = torch.from_numpy(imageio.imread(map_average)).unsqueeze(0)
        #self.map_average = (self.map_average - self.map_average.mean()) / self.map_average.std()
        self.loss_fn = torch.nn.L1Loss()
        # Per-batch validation losses, averaged and cleared in on_validation_epoch_end.
        self.val_metrics = []
        self.mid_visual_ridge, self.mid_visual_basins = mid_visual_ridge, mid_visual_basins
        self.mid_visual_gt = mid_visual_gt

    def configure_optimizers(self):
        """Build AdaMuon over four decay/Muon param groups with a per-epoch cosine LR schedule."""
        param_groups = get_parameter_groups_withmuon(self, weight_decay=self.wd)
        opt = AdaMuon(param_groups, lr=self.lr, weight_decay=self.wd, adamw_lr=3e-4, adamw_wd=self.wd)
        # NOTE(review): T_max=100 is hard-coded; keep it in sync with the Trainer's max_epochs.
        scheduler1 = torch.optim.lr_scheduler.CosineAnnealingLR(opt, 100, eta_min=1e-4)
        #scheduler2 = torch.optim.lr_scheduler.LinearLR(opt, 0.1, 1, 5)
        #scheduler = torch.optim.lr_scheduler.ChainedScheduler([scheduler1, scheduler2], opt)
        return {
            "optimizer": opt,
            "lr_scheduler": {"scheduler": scheduler1, "interval": "epoch", "frequency": 1},
        }

    def _step(self, batch, t):
        """Return the L1 flow-matching loss for one batch.

        Args:
            batch: ``(x0, ridge_map, basin_map, water_level)``. The first
                dimension of ``water_level`` is taken as the batch size.
            t: Per-sample timesteps of shape ``(b,)``, or ``None`` to draw them
                from Beta(2, 2) via ``convert_uniform_to_custom``.
        """
        x0, ridge_map, basin_map, water_level = batch
        b = water_level.shape[0]
        #map_average = self.map_average.expand((b, -1, -1, -1)).to(self.device)

        noise = torch.randn_like(x0, device=self.device, dtype=x0.dtype)
        if t is None:
            t = torch.rand((b,), device=self.device)
            t = convert_uniform_to_custom(t).to(x0.dtype)

        # Broadcast t over the non-batch axes (assumes x0 is (b, C, H, W) — TODO confirm).
        t_b = t.view(-1, 1, 1, 1)
        xt = t_b * x0 + (1 - t_b) * noise
        v = x0 - noise

        predicted_v = self.model(xt, ridge_map, basin_map, water_level, t)  # Predict velocity v
        return self.loss_fn(predicted_v, v)  # Loss between predicted and target v

    def training_step(self, batch, batch_idx):
        """Compute the training loss at random timesteps and log it per global step."""
        loss = self._step(batch, None)
        self.logger.experiment.add_scalar("Train/Loss", loss.detach(), self.global_step)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss at a deterministic timestep.

        ``t`` cycles through 0.001, 0.101, ..., 0.901 with ``batch_idx``. With a
        non-shuffled validation loader, each batch is therefore always scored at
        the same noise level. The noise sample itself is still random.
        """
        t = (batch_idx % 10) * 0.1 + 0.001
        t = torch.tensor(t, device=self.device).expand(batch[0].shape[0])
        loss = self._step(batch, t)
        self.val_metrics.append(loss.detach())
        return loss

    @torch.no_grad()
    def inference_step(self, ridge_map, basin_map, water_level, num_steps=50):
        """Generate samples by Euler-integrating the predicted velocity from t=0 to t=1.

        Runs in whatever train/eval mode the module is currently in. Callers
        wanting deterministic layers should switch to eval mode first.

        Args:
            ridge_map: Conditioning tensor. Its shape also defines the shape of
                the initial noise.
            basin_map: Conditioning tensor passed through to the model.
            water_level: Python scalar broadcast to every sample in the batch.
            num_steps: Number of uniform Euler steps.

        Returns:
            The state at ``t = 1``. It presumably has the same shape as
            ``ridge_map``, assuming the model output matches its input ``x``.
        """
        device = self.device
        b = ridge_map.shape[0]
        x = torch.randn_like(ridge_map, device=device)
        water_level = torch.tensor((water_level,), device=device).expand(b,)
        time = torch.linspace(0, 1, num_steps + 1, device=device)

        for i in range(num_steps):
            t = torch.full((b,), time[i], device=device)
            dt = torch.full((b, 1, 1, 1), time[i+1] - time[i], device=device)

            v = self.model(x, ridge_map, basin_map, water_level, t)

            x = x + dt * v

        return x

    def on_train_epoch_end(self):
        """Log one generated sample, VRAM usage and, at epoch 0, the inputs and ground truth.

        Sampling runs in eval mode so that dropout-style layers are disabled and
        any BatchNorm running statistics are not updated by the visualisation
        pass. The previous train/eval mode is restored afterwards.
        """
        sea_level = 0.0
        # Lift each image to (1, 1, H, W); assumes single-channel 2-D files.
        ridge_map = torch.from_numpy(imageio.imread(self.mid_visual_ridge))[None, None, :].to(device=self.device, dtype=torch.float32)

        basin_map = torch.from_numpy(imageio.imread(self.mid_visual_basins))[None, None, :].to(device=self.device)
        basin_map = (basin_map >= sea_level).to(torch.float32)

        was_training = self.training
        self.eval()
        try:
            output = self.inference_step(ridge_map, basin_map, sea_level)
        finally:
            self.train(was_training)

        experiment = self.logger.experiment
        mid_visual_result = output.squeeze([1])
        experiment.add_scalar("Visualize/Min", mid_visual_result.min(), self.current_epoch)
        experiment.add_scalar("Visualize/Max", mid_visual_result.max(), self.current_epoch)
        experiment.add_scalar("Visualize/Mean", mid_visual_result.mean(), self.current_epoch)
        experiment.add_image('Visualize/Model Output', self._minmax_normalize(mid_visual_result), self.current_epoch)

        if torch.cuda.is_available():
            free_bytes, total_bytes = torch.cuda.mem_get_info()
            # Device-wide used memory in MiB (includes other processes on the GPU).
            vram_usage = (total_bytes - free_bytes) / (1024 ** 2)
            experiment.add_scalar("Other/VRAM Usage", vram_usage, self.current_epoch)
            # NOTE(review): peak stats are reset here but never read in this class.
            torch.cuda.reset_peak_memory_stats()

        if self.current_epoch == 0:
            mid_visual_gt = torch.from_numpy(imageio.imread(self.mid_visual_gt))[None, :]
            # NOTE(review): the ridge map is logged unnormalised; values outside [0, 1] will be clipped.
            experiment.add_image('Visualize/Ridge', ridge_map.squeeze([1]), self.current_epoch)
            experiment.add_image('Visualize/Basin', basin_map.squeeze([1]), self.current_epoch)
            experiment.add_image('Visualize/GT', self._minmax_normalize(mid_visual_gt), self.current_epoch)

    def on_validation_epoch_end(self):
        """Log the NaN-ignoring mean of this epoch's validation losses and reset the buffer."""
        if not self.val_metrics:
            # Nothing was collected; torch.stack would raise on an empty list.
            return
        epoch_averages = torch.stack(self.val_metrics).nanmean(dim=0)
        self.logger.experiment.add_scalar("Val/Loss", epoch_averages, self.current_epoch)
        self.val_metrics.clear()

    @staticmethod
    def _minmax_normalize(image, eps=1e-12):
        """Linearly rescale ``image`` to [0, 1] (as float32) for ``add_image``.

        ``eps`` stops a constant image from producing 0/0 = NaN. Such an image
        maps to all zeros instead.
        """
        image = image.to(torch.float32)
        low, high = image.min(), image.max()
        return (image - low) / (high - low).clamp_min(eps)

    #def on_before_optimizer_step(self, optimizer):
    #    norms = grad_norm(self.model, norm_type=2)
    #    self.log_dict(norms, logger=True)



# Training entry point: tune the GPU backend, build the Trainer, model and data
# loaders, then resume fitting from the reference checkpoint.
if __name__ == "__main__":
    torch.set_float32_matmul_precision('medium')
    # Cap how many times dynamo may recompile a single frame.
    torch._dynamo.config.recompile_limit = 2
    if torch.cuda.is_available():
        if torch.version.cuda:
            print('Optimising computing and memory use via cuDNN! (NVIDIA GPU only).')
            torch.backends.cudnn.enabled = True
            torch.backends.cudnn.benchmark = True
            torch.backends.cudnn.allow_tf32 = True
        elif torch.version.hip:
            print('Optimising computing using TunableOp! (AMD GPU only).')
            torch.cuda.tunable.enable()
            torch.cuda.tunable.set_filename('TunableOp_results')

    train_split, val_split = data_utils.make_dataset_t_v('dataset_large')

    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath="",
        filename="FlashScape",
        save_weights_only=False,
        enable_version_counter=False,
        save_last=False,
    )
    callbacks = [lr_monitor, checkpoint_callback]
    # Optional callbacks, disabled:
    #swa_callback = StochasticWeightAveraging(1e-5, 0.8, int(0.2 * 100 - 1))
    #callbacks.append(swa_callback)
    #lr_finder = LearningRateFinder(1e-6, 0.1)
    #callbacks.append(lr_finder)
    #model = PLModule.load_from_checkpoint('FlashScape Swin Reference.ckpt')

    tb_logger = TensorBoardLogger('lightning_logs', name='FlashScape Pure DiT Swin p16 16ws 3ffn')
    trainer = pl.Trainer(
        max_epochs=100,
        log_every_n_steps=1,
        logger=tb_logger,
        accelerator="gpu",
        enable_checkpointing=True,
        precision='16-mixed',
        enable_progress_bar=True,
        num_sanity_val_steps=0,
        callbacks=callbacks,
    )

    # Build the module under init_module so its tensors are created on the
    # Trainer's target device.
    with trainer.init_module():
        model = PLModule('dataset_large/Ridge_11417648.tiff',
                         'dataset_large/Basins_11417648.tiff',
                         'dataset_large/11417648.tiff')
    model = torch.compile(model)

    train_dataset = data_utils.TrainDataset(train_split)
    val_dataset = data_utils.ValDataset(val_split)

    loader_options = dict(batch_size=16, num_workers=8, pin_memory=False, persistent_workers=True)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True, drop_last=True, **loader_options)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, **loader_options)

    # ckpt_path resumes the full training state stored in the checkpoint.
    trainer.fit(
        model,
        train_dataloaders=train_loader,
        val_dataloaders=val_loader,
        ckpt_path='FlashScape Swin Reference.ckpt',
    )