ldcast_code / scripts /forecast_demo.py
weatherforecast1024's picture
Upload folder using huggingface_hub
d2f661a verified
from datetime import datetime, timedelta
import glob
import os
from fire import Fire
import h5py
from matplotlib import pyplot as plt
import numpy as np
from ldcast import forecast
from ldcast.visualization import plots
def read_data(
data_dir="../data/demo/20210622",
t0=datetime(2021,6,22,18,35),
interval=timedelta(minutes=5),
past_timesteps=4,
crop_box=((128,480), (160,608))
):
cb = crop_box
R_past = []
t = t0 - (past_timesteps-1) * interval
for i in range(past_timesteps):
timestamp = t.strftime("%y%j%H%M")
fn = f"RZC{timestamp}VL.[80]01.h5"
fn = os.path.join(data_dir, fn)
found_files = glob.glob(fn)
if found_files:
fn = found_files[0]
else:
raise FileNotFoundError(f"Unable to find data file {fn}.")
with h5py.File(fn, 'r') as f:
R = f["dataset1"]["data1"]["data"][:]
R = R[cb[0][0]:cb[0][1], cb[1][0]:cb[1][1]]
R_past.append(R)
t += interval
R_past = np.stack(R_past, axis=0)
return R_past
def plot_border(ax, crop_box=((128,480), (160,608))):
import shapefile
border = shapefile.Reader("../data/Border_CH.shp")
shapes = list(border.shapeRecords())
for shape in shapes:
x = np.array([i[0]/1000. for i in shape.shape.points[:]])
y = np.array([i[1]/1000. for i in shape.shape.points[:]])
ax.plot(
x-crop_box[1][0]-255, 480-y-crop_box[0][0],
'k', linewidth=1.0
)
def plot_frame(R, fn, draw_border=True, t=None, label=None):
fig = plt.figure(dpi=150)
ax = fig.add_subplot()
plots.plot_precip_image(ax, R)
if draw_border:
plot_border(ax)
if t is not None:
timestamp = "%Y-%m-%d %H:%M UTC"
if label is not None:
timestamp += f" ({label})"
ax.text(
0.02, 0.98, t.strftime(timestamp),
horizontalalignment='left', verticalalignment='top',
transform=ax.transAxes
)
fig.savefig(fn, bbox_inches='tight')
plt.close(fig)
def forecast_demo(
ldm_weights_fn="../models/genforecast/genforecast-radaronly-256x256-20step.pt",
autoenc_weights_fn="../models/autoenc/autoenc-32-0.01.pt",
num_diffusion_iters=50,
out_dir="../figures/demo/",
data_dir="../data/demo/20210622",
t0=datetime(2021,6,22,18,35),
interval=timedelta(minutes=5),
past_timesteps=4,
crop_box=((128,480), (160,608)),
draw_border=True,
ensemble_members=1,
):
R_past = read_data(
data_dir=data_dir, t0=t0, interval=interval,
past_timesteps=past_timesteps, crop_box=crop_box
)
if ensemble_members == 1:
fc = forecast.Forecast(
ldm_weights_fn=ldm_weights_fn,
autoenc_weights_fn=autoenc_weights_fn
)
R_pred = fc(
R_past,
num_diffusion_iters=num_diffusion_iters
)
elif ensemble_members > 1:
fc = forecast.ForecastDistributed(
ldm_weights_fn=ldm_weights_fn,
autoenc_weights_fn=autoenc_weights_fn,
)
R_past = R_past.reshape((1,) + R_past.shape)
R_pred = fc(
R_past,
num_diffusion_iters=num_diffusion_iters,
ensemble_members=ensemble_members
)
R_past = R_past[0,...]
R_pred = R_pred[0,...].mean(axis=-1) # compute ensemble mean
else:
raise ValueError("ensemble_members must be > 0")
os.makedirs(out_dir, exist_ok=True)
for k in range(R_past.shape[0]):
fn = os.path.join(out_dir, f"R_past-{k:02d}.png")
t = t0 - (R_past.shape[0]-k-1) * interval
plot_frame(R_past[k,:,:], fn, draw_border=draw_border,
t=t, label="Real")
for k in range(R_pred.shape[0]):
fn = os.path.join(out_dir, f"R_pred-{k:02d}.png")
t = t0 + (k+1)*interval
plot_frame(R_pred[k,:,:], fn, draw_border=draw_border,
t=t, label="Predicted")
if __name__ == "__main__":
Fire(forecast_demo)