# Diffusion_SolRad / inference.py
# Author: Jason-thingnario
# Commit message: "upload DDPM inference script"
# Commit: be89dda
import time
import argparse
from dataclasses import dataclass
from pathlib import Path
from typing import List, Sequence
import sys
from datetime import datetime, timedelta
import numpy as np
import torch
import torch.nn.functional as F
from model_architect.UNet_DDPM import UNet_with_time, DDPM
@dataclass
class Config:
    """Hyper-parameters shared by the UNet backbone and the DDPM wrapper.

    The defaults describe the 1-hour setup (12 input frames, 6 output
    frames); callers may override fields on an instance for other horizons.
    """

    # --- temporal layout -------------------------------------------------
    input_frame: int = 12
    output_frame: int = 6

    # --- conditioning / embeddings ---------------------------------------
    cond_nc: int = 5          # number of conditional input channels
    time_emb_dim: int = 128   # size of the diffusion-timestep embedding

    # --- UNet topology ---------------------------------------------------
    base_chs: int = 32
    # Channel multiplier per resolution level.
    chs_mult: tuple = (1, 2, 4, 8, 8)
    # Per-level attention switch, aligned with chs_mult: 0 = off, 1 = on.
    use_attn_list: tuple = (0, 0, 1, 1, 1)
    n_res_blocks: int = 2

    # --- diffusion / regularisation --------------------------------------
    n_steps: int = 1000       # number of diffusion steps
    dropout: float = 0.1
def data_loading(BASETIME, device, data_dir='./sample_data'):
    """Load every array of a sample ``.npz`` archive as tensors on *device*.

    The archive read is ``<data_dir>/sample_<BASETIME>.npz``.

    Args:
        BASETIME: Basetime string embedded in the file name
            (e.g. ``'202504131100'``).
        device: Target ``torch.device`` for the returned tensors.
        data_dir: Directory containing the sample archives. Defaults to
            ``'./sample_data'``, resolved against the current working
            directory.

    Returns:
        dict mapping each array name stored in the archive to a tensor
        (same dtype as stored) on *device*.
    """
    npz_path = Path(data_dir) / f'sample_{BASETIME}.npz'
    # np.load on an .npz returns an NpzFile holding an open file handle;
    # the context manager closes it once every array has been read.
    with np.load(npz_path) as data_npz:
        return {key: torch.from_numpy(data_npz[key]).to(device) for key in data_npz}
def arg_parse(argv: "Sequence[str] | None" = None):
    """Parse the command-line options of the inference script.

    Args:
        argv: Arguments to parse (without the program name). ``None``
            (the default) parses ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``pred_hr`` ('1hr' or '6hr'),
        ``pred_mode`` ('DDPM' or 'DDIM') and ``basetime`` (str).
    """
    parser = argparse.ArgumentParser(
        description='Run DDPM/DDIM solar-radiation inference.'
    )
    parser.add_argument(
        '--pred-hr',
        type=str,
        default='1hr',
        choices=['1hr', '6hr'],
        help='Prediction horizon (default: 1hr).',
    )
    parser.add_argument(
        '--pred-mode',
        type=str,
        default='DDPM',
        choices=['DDPM', 'DDIM'],
        help='Sampling procedure (default: DDPM).',
    )
    parser.add_argument(
        '--basetime',
        type=str,
        default='202504131100',
        help='Basetime string selecting the sample file (default: 202504131100).',
    )
    return parser.parse_args(argv)
# Checkpoint file for each prediction horizon.
_CKPT_PATHS = {
    '1hr': './model_weights/ft06_01hr/weights.ckpt',
    '6hr': './model_weights/ft36_06hr/weights.ckpt',
}


def _prepare_inputs(inputs, pred_hr, output_frame):
    """Build the sampler inputs from the raw sample tensors.

    Concatenates the previous Himawari frames with topography along dim 1
    to form the conditional input. Upsamples WRF 4x (bilinear). For the
    1-hour horizon, keeps only the first ``output_frame`` WRF and clearsky
    frames.

    Returns:
        tuple ``(input_, wrf, clearsky)``.
    """
    # NOTE(review): squeeze(2) assumes a singleton axis at dim 2 for the
    # Himawari and WRF arrays -- confirm against the sample .npz layout.
    prev_himawari = inputs['Himawari'].squeeze(2)
    topo = inputs['topo']
    # Conditional input; expected to be (B, 5, 512, 512) (see Config.cond_nc).
    input_ = torch.cat([prev_himawari, topo], dim=1)
    # align_corners=False is PyTorch's default; passed explicitly for clarity.
    wrf = F.interpolate(
        inputs['WRF'].squeeze(2), scale_factor=4, mode='bilinear', align_corners=False
    )
    clearsky = inputs['clearsky']
    if pred_hr == '1hr':
        # The 1-hour model predicts only output_frame steps; drop the rest.
        wrf = wrf[:, :output_frame]
        clearsky = clearsky[:, :output_frame]
    return input_, wrf, clearsky


def main():
    """Run diffusion inference for one basetime and save solar radiation.

    The prediction is written to ``./pred_<basetime>_<pred_hr>_<pred_mode>.npy``.
    """
    args = arg_parse()
    pred_hr = args.pred_hr
    pred_mode = args.pred_mode
    basetime = args.basetime
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    inputs = data_loading(basetime, device)

    model_config = Config()
    if pred_hr == '6hr':
        model_config.input_frame = 72
        model_config.output_frame = 36
    print("Prediction mode:", pred_mode)
    print("Prediction horizon:", pred_hr)

    input_, wrf, clearsky = _prepare_inputs(inputs, pred_hr, model_config.output_frame)

    backbone = UNet_with_time(model_config)
    model = DDPM(backbone, output_shape=(model_config.output_frame, 512, 512))

    try:
        ckpt_path = _CKPT_PATHS[pred_hr]
    except KeyError:
        raise ValueError(f"Unsupported prediction horizon: {pred_hr!r}") from None
    # map_location='cpu': the model is still on CPU here, and this lets a
    # checkpoint saved from GPU tensors load on a CPU-only machine.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    model = model.to(device)

    # Inference only: no autograd graph, so the result can go straight to
    # NumPy and memory stays bounded.
    with torch.no_grad():
        if pred_mode == 'DDPM':
            pred_clr_idx = model.sample_ddpm(input_, input_cond=wrf, verbose="text")
        elif pred_mode == 'DDIM':
            pred_clr_idx = model.sample_ddim(
                input_, input_cond=wrf, ddim_steps=100, verbose="text"
            )
        else:
            raise ValueError(f"Unsupported prediction mode: {pred_mode!r}")

    # Map the sampler output from [-1, 1] to a clear-sky index in [0, 1].
    pred_clr_idx = ((pred_clr_idx + 1.0) / 2.0).clamp(0.0, 1.0)
    # Solar radiation = clear-sky index * clear-sky radiation.
    pred_srad = pred_clr_idx * clearsky

    np.save(f'./pred_{basetime}_{pred_hr}_{pred_mode}.npy', pred_srad.cpu().numpy())
    print('Done')


if __name__ == "__main__":
    main()