|
|
from datetime import datetime, timedelta
|
|
|
import glob
|
|
|
import os
|
|
|
|
|
|
from fire import Fire
|
|
|
import h5py
|
|
|
from matplotlib import pyplot as plt
|
|
|
import numpy as np
|
|
|
|
|
|
from ldcast import forecast
|
|
|
from ldcast.visualization import plots
|
|
|
|
|
|
|
|
|
def read_data(
    data_dir="../data/demo/20210622",
    t0=datetime(2021, 6, 22, 18, 35),
    interval=timedelta(minutes=5),
    past_timesteps=4,
    crop_box=((128, 480), (160, 608)),
):
    """Load and crop the ``past_timesteps`` radar frames ending at ``t0``.

    Frames are read for the times ``t0 - (past_timesteps - 1) * interval``,
    ..., ``t0``. For each time the file ``RZC<stamp>VL.801.h5`` or
    ``RZC<stamp>VL.001.h5`` is looked up in ``data_dir``, where ``<stamp>`` is
    the time formatted with ``%y%j%H%M`` (2-digit year, day of year, hour,
    minute). The ``dataset1/data1/data`` HDF5 dataset is read and cropped.

    Args:
        data_dir: Directory holding the HDF5 files. It is used literally; any
            glob metacharacters in it are escaped.
        t0: Time of the last (most recent) frame.
        interval: Time step between consecutive frames.
        past_timesteps: Number of frames to load; must be >= 1.
        crop_box: ``((row_start, row_stop), (col_start, col_stop))`` slice
            bounds applied to the first two axes of every frame.

    Returns:
        numpy array with the cropped frames stacked along a new leading axis,
        oldest frame first.

    Raises:
        ValueError: If ``past_timesteps`` is smaller than 1.
        FileNotFoundError: If no matching file exists for one of the times.
    """
    if past_timesteps < 1:
        # np.stack would fail on an empty list anyway; give a clear message.
        raise ValueError("past_timesteps must be >= 1")

    (row_start, row_stop), (col_start, col_stop) = crop_box
    # Only the file-name part should act as a glob pattern, so escape the
    # directory (e.g. a literal "[" in a path must not start a char class).
    escaped_dir = glob.escape(os.fspath(data_dir))

    frames = []
    for k in range(past_timesteps):
        t = t0 - (past_timesteps - 1 - k) * interval
        timestamp = t.strftime("%y%j%H%M")
        fn = f"RZC{timestamp}VL.[80]01.h5"
        # glob returns matches in arbitrary order; sort so that the same file
        # is picked every run when both the .001 and .801 variants exist.
        found_files = sorted(glob.glob(os.path.join(escaped_dir, fn)))
        if not found_files:
            raise FileNotFoundError(
                f"Unable to find data file {os.path.join(data_dir, fn)}."
            )
        with h5py.File(found_files[0], "r") as f:
            R = f["dataset1"]["data1"]["data"][:]
        frames.append(R[row_start:row_stop, col_start:col_stop])

    return np.stack(frames, axis=0)
|
|
|
|
|
|
|
|
|
def plot_border(ax, crop_box=((128, 480), (160, 608)),
                shapefile_path="../data/Border_CH.shp"):
    """Overlay the outlines stored in a shapefile onto ``ax`` as black lines.

    Every shape's point coordinates are divided by 1000 and then mapped to
    plot coordinates as ``x - 255 - col_start`` and ``480 - y - row_start``,
    where ``row_start``/``col_start`` are the lower crop bounds. This keeps the
    outline aligned with data cropped by the same ``crop_box``.

    NOTE(review): 255 and 480 presumably place the origin of a 1 km grid in
    the shapefile's projected coordinate system (coordinates in metres) -
    confirm against the radar grid definition.

    Args:
        ax: Matplotlib axes to draw on.
        crop_box: ``((row_start, row_stop), (col_start, col_stop))``; only the
            start offsets are used.
        shapefile_path: Path of the shapefile to draw. The default is relative
            to the current working directory.
    """
    import shapefile

    row_offset = crop_box[0][0]
    col_offset = crop_box[1][0]
    # Use the Reader as a context manager so its file handles are released.
    with shapefile.Reader(shapefile_path) as border:
        for shape in border.shapes():
            # reshape(-1, 2) keeps shapes without points as an empty (0, 2) array.
            points = np.asarray(shape.points, dtype=float).reshape(-1, 2) / 1000.
            x = points[:, 0]
            y = points[:, 1]
            # All points of a shape are drawn as one polyline (multi-part
            # shapes are not split at part boundaries).
            ax.plot(
                x - col_offset - 255, 480 - y - row_offset,
                'k', linewidth=1.0
            )
|
|
|
|
|
|
|
|
|
def plot_frame(R, fn, draw_border=True, t=None, label=None,
               crop_box=((128, 480), (160, 608))):
    """Render one field with ``plots.plot_precip_image`` and save it to ``fn``.

    Args:
        R: Field to draw; passed unchanged to ``plots.plot_precip_image``.
        fn: Output path given to ``Figure.savefig`` (``bbox_inches='tight'``).
        draw_border: If true, overlay the border outline via ``plot_border``.
        t: Optional datetime. If given, it is written in the top-left corner
            as ``YYYY-MM-DD HH:MM UTC``.
        label: Optional text appended in parentheses after the timestamp;
            only shown when ``t`` is given.
        crop_box: Crop used for ``R``. It is forwarded to ``plot_border`` so
            the outline lines up with the cropped field.
    """
    fig = plt.figure(dpi=150)
    try:
        ax = fig.add_subplot()
        plots.plot_precip_image(ax, R)
        if draw_border:
            plot_border(ax, crop_box=crop_box)
        if t is not None:
            # Append the label after formatting. Putting it into the strftime
            # format would make any '%' in the label act as a directive.
            caption = t.strftime("%Y-%m-%d %H:%M UTC")
            if label is not None:
                caption += f" ({label})"
            ax.text(
                0.02, 0.98, caption,
                horizontalalignment='left', verticalalignment='top',
                transform=ax.transAxes,
            )
        fig.savefig(fn, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)


def forecast_demo(
    ldm_weights_fn="../models/genforecast/genforecast-radaronly-256x256-20step.pt",
    autoenc_weights_fn="../models/autoenc/autoenc-32-0.01.pt",
    num_diffusion_iters=50,
    out_dir="../figures/demo/",
    data_dir="../data/demo/20210622",
    t0=datetime(2021, 6, 22, 18, 35),
    interval=timedelta(minutes=5),
    past_timesteps=4,
    crop_box=((128, 480), (160, 608)),
    draw_border=True,
    ensemble_members=1,
):
    """Run a forecast from past radar frames and save all frames as PNGs.

    Past frames are loaded with ``read_data``. A single member uses
    ``forecast.Forecast``; more than one member uses
    ``forecast.ForecastDistributed``, and the members are averaged. Past
    frames are written to ``out_dir/R_past-XX.png`` and predicted frames to
    ``out_dir/R_pred-XX.png``. Predicted frame ``k`` is labelled with time
    ``t0 + (k + 1) * interval``.

    Args:
        ldm_weights_fn: Weights file for the forecasting model.
        autoenc_weights_fn: Weights file for the autoencoder.
        num_diffusion_iters: Passed to the forecaster call.
        out_dir: Output directory; created if missing.
        data_dir: Directory with the input HDF5 files (see ``read_data``).
        t0: Time of the last observed frame. A ``str`` (e.g. from the Fire
            CLI) is parsed with ``datetime.fromisoformat``.
        interval: Time step between frames. A plain number (e.g. from the
            CLI) is interpreted as minutes, the unit of the default.
        past_timesteps: Number of observed frames fed to the model.
        crop_box: Crop applied to the input data and to the drawn border.
        draw_border: Whether to overlay the border outline on each image.
        ensemble_members: Number of ensemble members; must be >= 1.

    Raises:
        ValueError: If ``ensemble_members`` is smaller than 1. This is checked
            before any data is loaded.
    """
    # Validate before the slow data loading and model construction.
    # Written as `not >=` so that NaN is rejected as well.
    if not ensemble_members >= 1:
        raise ValueError("ensemble_members must be > 0")

    # Fire delivers command-line values as strings/numbers, not datetime objects.
    if isinstance(t0, str):
        t0 = datetime.fromisoformat(t0)
    if isinstance(interval, (int, float)):
        interval = timedelta(minutes=interval)

    R_past = read_data(
        data_dir=data_dir, t0=t0, interval=interval,
        past_timesteps=past_timesteps, crop_box=crop_box,
    )

    if ensemble_members == 1:
        fc = forecast.Forecast(
            ldm_weights_fn=ldm_weights_fn,
            autoenc_weights_fn=autoenc_weights_fn,
        )
        R_pred = fc(R_past, num_diffusion_iters=num_diffusion_iters)
    else:
        fc = forecast.ForecastDistributed(
            ldm_weights_fn=ldm_weights_fn,
            autoenc_weights_fn=autoenc_weights_fn,
        )
        # The distributed forecaster takes a leading batch axis; add one of
        # size 1 without changing R_past itself.
        R_pred = fc(
            R_past.reshape((1,) + R_past.shape),
            num_diffusion_iters=num_diffusion_iters,
            ensemble_members=ensemble_members,
        )
        # Take the single batch element and average over the last axis.
        # Presumably that axis indexes the ensemble members; confirm.
        R_pred = R_pred[0, ...].mean(axis=-1)

    os.makedirs(out_dir, exist_ok=True)

    n_past = R_past.shape[0]
    for k in range(n_past):
        fn = os.path.join(out_dir, f"R_past-{k:02d}.png")
        t = t0 - (n_past - k - 1) * interval
        plot_frame(R_past[k, :, :], fn, draw_border=draw_border,
                   t=t, label="Real", crop_box=crop_box)

    for k in range(R_pred.shape[0]):
        fn = os.path.join(out_dir, f"R_pred-{k:02d}.png")
        t = t0 + (k + 1) * interval
        plot_frame(R_pred[k, :, :], fn, draw_border=draw_border,
                   t=t, label="Predicted", crop_box=crop_box)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: Fire exposes forecast_demo's keyword
    # arguments as command-line flags.
    Fire(component=forecast_demo)
|
|
|
|