|
|
import time |
|
|
import argparse |
|
|
from dataclasses import dataclass |
|
|
from pathlib import Path |
|
|
from typing import List, Sequence |
|
|
import sys |
|
|
from datetime import datetime, timedelta |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from model_architect.UNet_DDPM import UNet_with_time, DDPM |
|
|
|
|
|
@dataclass
class Config:
    """Hyper-parameter bundle handed to ``UNet_with_time``.

    The defaults correspond to the 1hr setup. The entry-point code overrides
    ``input_frame``/``output_frame`` (to 72/36) when ``--pred-hr 6hr`` is
    selected. How each field is actually consumed is decided by
    ``model_architect.UNet_DDPM``, which is not visible here, so the per-field
    notes below are best-effort readings of the names.
    """

    # Temporal window sizes (frame counts) on the input and output side.
    input_frame: int = 12
    output_frame: int = 6
    # NOTE(review): presumably the number of conditioning channels — confirm in UNet_with_time.
    cond_nc: int = 5
    # Presumably the width of the diffusion-timestep embedding.
    time_emb_dim: int = 128
    # Presumably the UNet base width and per-level channel multipliers.
    base_chs: int = 32
    chs_mult: tuple = (1, 2, 4, 8, 8)
    # Presumably a per-level attention on/off flag (same length as chs_mult).
    use_attn_list: tuple = (0, 0, 1, 1, 1)
    n_res_blocks: int = 2
    # Presumably the number of diffusion steps in the noise schedule.
    n_steps: int = 1000
    dropout: float = 0.1
|
|
|
|
|
def data_loading(BASETIME, device, data_dir='./sample_data'):
    """Load every array of ``<data_dir>/sample_<BASETIME>.npz`` as torch tensors.

    Args:
        BASETIME: Base-time string used to build the file name (the script
            passes the ``--basetime`` value, e.g. ``'202504131100'``).
        device: ``torch.device`` (or device string) the tensors are moved to.
        data_dir: Directory holding the sample archives. Defaults to
            ``'./sample_data'``, i.e. relative to the current working directory.

    Returns:
        dict mapping each array name stored in the archive to a
        ``torch.Tensor`` on *device*.

    Raises:
        FileNotFoundError: if the archive does not exist.
    """
    npz_path = Path(data_dir) / f'sample_{BASETIME}.npz'
    # np.load on an .npz returns an NpzFile that keeps the archive open until
    # closed; the context manager releases the handle once every array has
    # been read out (the arrays themselves are fully loaded into memory).
    with np.load(npz_path) as data_npz:
        return {key: torch.from_numpy(data_npz[key]).to(device) for key in data_npz}
|
|
|
|
|
def arg_parse():
    """Parse the command-line options of the inference script.

    Options:
        --pred-hr    forecast horizon: '1hr' (default) or '6hr'.
        --pred-mode  sampler to use: 'DDPM' (default) or 'DDIM'.
        --basetime   base-time string selecting the sample file
                     (default '202504131100').

    Returns:
        argparse.Namespace with attributes ``pred_hr``, ``pred_mode`` and
        ``basetime``.
    """
    parser = argparse.ArgumentParser()

    # (flag, default, allowed values) for the two enumerated options,
    # registered in the same order as the help output lists them.
    choice_options = (
        ('--pred-hr', '1hr', ['1hr', '6hr']),
        ('--pred-mode', 'DDPM', ['DDPM', 'DDIM']),
    )
    for flag, default_value, allowed in choice_options:
        parser.add_argument(flag, type=str, default=default_value, choices=allowed)

    parser.add_argument('--basetime', type=str, default='202504131100')
    return parser.parse_args()
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: load one sample, run the diffusion sampler, and save the
    # predicted field to ./pred_<BASETIME>_<pred_hr>_<pred_mode>.npy.
    args = arg_parse()
    pred_hr = args.pred_hr
    pred_mode = args.pred_mode
    BASETIME = args.basetime

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    inputs = data_loading(BASETIME, device)

    # The 6hr model is built with longer input/output windows than the
    # Config defaults (which are used as-is for 1hr).
    model_config = Config()
    if pred_hr == '6hr':
        model_config.input_frame = 72
        model_config.output_frame = 36

    print("Prediction mode:", pred_mode)
    print("Prediction horizon:", pred_hr)

    # ---- assemble model inputs ------------------------------------------
    # NOTE(review): squeeze(2) presumably drops a singleton axis, i.e. the
    # Himawari / WRF arrays arrive as (B, T, 1, H, W) — confirm against data.
    prev_himawari = inputs['Himawari'].squeeze(2)
    topo = inputs['topo']
    # Concatenated along dim=1 (presumably the frame/channel axis).
    input_ = torch.cat([prev_himawari, topo], dim=1)
    # WRF conditioning is upsampled 4x in the last two dims (bilinear).
    WRF = F.interpolate(inputs['WRF'].squeeze(2), scale_factor=4, mode='bilinear')

    clearsky = inputs['clearsky']
    if pred_hr == '1hr':
        # Keep only the first 6 entries along dim=1, matching the 1hr
        # model's default output_frame of 6.
        WRF = WRF[:, :6]
        clearsky = clearsky[:, :6]

    # ---- build model and load weights -----------------------------------
    backbone = UNet_with_time(model_config)
    model = DDPM(backbone, output_shape=(model_config.output_frame, 512, 512))

    ckpt_path = {
        '1hr': './model_weights/ft06_01hr/weights.ckpt',
        '6hr': './model_weights/ft36_06hr/weights.ckpt',
    }[pred_hr]
    # map_location='cpu' lets a checkpoint that was saved from a GPU load on
    # a CPU-only host (the script falls back to CPU above); the model is
    # moved to `device` afterwards.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    model = model.to(device)

    # ---- sample ---------------------------------------------------------
    # Inference only: disable autograd so the iterative sampler does not
    # record a computation graph / keep activations across its steps.
    with torch.no_grad():
        if pred_mode == 'DDPM':
            pred_clr_idx = model.sample_ddpm(
                input_,
                input_cond=WRF,
                verbose="text"
            )
        elif pred_mode == 'DDIM':
            pred_clr_idx = model.sample_ddim(
                input_,
                input_cond=WRF,
                ddim_steps=100,
                verbose="text"
            )
        else:
            raise ValueError(f"Unsupported prediction mode: {pred_mode}")

    # Affine map from [-1, 1] to [0, 1], then clip to that range.
    pred_clr_idx = ((pred_clr_idx + 1.0) / 2.0).clamp(0.0, 1.0)

    # Scale the predicted (clear-sky index) field by the clear-sky input.
    # NOTE(review): presumably yields surface solar radiation ("srad") —
    # confirm units with the data provider.
    pred_srad = pred_clr_idx * clearsky

    np.save(f'./pred_{BASETIME}_{pred_hr}_{pred_mode}.npy', pred_srad.cpu().numpy())

    print('Done')
|
|
|