File size: 4,190 Bytes
d2f661a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from datetime import datetime, timedelta
import glob
import os

from fire import Fire
import h5py
from matplotlib import pyplot as plt
import numpy as np

from ldcast import forecast
from ldcast.visualization import plots


def read_data(

    data_dir="../data/demo/20210622",

    t0=datetime(2021,6,22,18,35),

    interval=timedelta(minutes=5),

    past_timesteps=4,

    crop_box=((128,480), (160,608))

):
    """Load and crop a sequence of past radar fields from HDF5 files.

    One file is read per timestep, for the ``past_timesteps`` times ending
    at ``t0`` (inclusive), spaced by ``interval`` and ordered oldest first.
    For a time ``t`` the file name searched in ``data_dir`` is
    ``RZC<%y%j%H%M>VL.801.h5`` or ``RZC<%y%j%H%M>VL.001.h5``, and the array
    at ``dataset1/data1/data`` is read from it.

    Args:
        data_dir: Directory containing the HDF5 files. It is matched
            literally; glob metacharacters in the path are escaped.
        t0: Time of the most recent frame to load.
        interval: Time step between consecutive frames.
        past_timesteps: Number of frames to load (must be >= 1).
        crop_box: ``((start0, stop0), (start1, stop1))`` slice bounds
            applied to the first and second axis of each field.

    Returns:
        The cropped fields stacked along a new leading axis (one entry per
        timestep, oldest first), as returned by ``np.stack``.

    Raises:
        ValueError: If ``past_timesteps`` is smaller than 1.
        FileNotFoundError: If no matching file exists for some timestep.
    """
    if past_timesteps < 1:
        raise ValueError("past_timesteps must be > 0")
    (start0, stop0), (start1, stop1) = crop_box
    # Escape the directory so characters like "[" in the path are matched
    # literally; only the file-name pattern below is meant as a wildcard.
    escaped_dir = glob.escape(data_dir)

    R_past = []
    t = t0 - (past_timesteps-1) * interval
    for _ in range(past_timesteps):
        timestamp = t.strftime("%y%j%H%M")
        pattern = f"RZC{timestamp}VL.[80]01.h5"
        # glob returns matches in arbitrary order; sort so the chosen file
        # is deterministic if both the "801" and "001" variants exist.
        # NOTE(review): which variant should be preferred is not specified
        # here -- sorting picks "001" first; confirm this is acceptable.
        found_files = sorted(glob.glob(os.path.join(escaped_dir, pattern)))
        if not found_files:
            raise FileNotFoundError(
                f"Unable to find data file {os.path.join(data_dir, pattern)}."
            )
        with h5py.File(found_files[0], 'r') as f:
            R = f["dataset1"]["data1"]["data"][:]
        R_past.append(R[start0:stop0, start1:stop1])
        t += interval

    return np.stack(R_past, axis=0)


def plot_border(ax, crop_box=((128,480), (160,608)),
        shapefile_fn="../data/Border_CH.shp"):
    """Draw the geometries of a shapefile onto ``ax`` as thin black lines.

    Point coordinates are divided by 1000 and shifted so they line up with
    a field cropped by ``crop_box``:

    - x is offset by ``crop_box[1][0] + 255``.
    - y is flipped as ``480 - y`` and then offset by ``crop_box[0][0]``.

    NOTE(review): the /1000 presumably converts metres to km (one pixel per
    km), and 255 / 480 presumably are the grid origin of the uncropped
    radar field -- confirm against the data source.

    Args:
        ax: Matplotlib axes to draw into.
        crop_box: Same crop as applied to the plotted field (see
            ``read_data``), used to align the outlines with it.
        shapefile_fn: Path of the shapefile to read. The default keeps the
            previously hard-coded border file.
    """
    # pyshp is only needed when borders are drawn, so import it lazily.
    import shapefile
    # Read all geometries up front so the file handles are released at once.
    with shapefile.Reader(shapefile_fn) as border:
        shapes = border.shapes()

    x_offset = crop_box[1][0] + 255
    y_offset = crop_box[0][0]
    for shape in shapes:
        if not shape.points:
            continue  # nothing to draw
        xy = np.array(
            [(p[0], p[1]) for p in shape.points], dtype=float
        ) / 1000.
        x = xy[:, 0] - x_offset
        y = 480 - xy[:, 1] - y_offset
        # A shape may consist of several parts (given by their start
        # indices); plot each part on its own so they are not joined by a
        # spurious connecting line.
        starts = list(getattr(shape, "parts", None) or [0])
        ends = starts[1:] + [len(xy)]
        for start, end in zip(starts, ends):
            ax.plot(x[start:end], y[start:end], 'k', linewidth=1.0)


def plot_frame(R, fn, draw_border=True, t=None, label=None,
        crop_box=((128,480), (160,608))):
    """Render a single precipitation field and save it as an image.

    Args:
        R: Field to draw, passed to ``plots.plot_precip_image``.
        fn: Output file path, passed to ``Figure.savefig``.
        draw_border: If True, overlay outlines via ``plot_border``.
        t: Optional ``datetime``. When given, it is written in the top-left
            corner formatted as ``"%Y-%m-%d %H:%M UTC"``.
        label: Optional text appended as ``" (label)"`` after the timestamp
            (only shown when ``t`` is given).
        crop_box: Crop that was applied to ``R``. It is forwarded to
            ``plot_border`` so the overlay lines up with the field.
    """
    fig = plt.figure(dpi=150)
    try:
        ax = fig.add_subplot()
        plots.plot_precip_image(ax, R)
        if draw_border:
            plot_border(ax, crop_box=crop_box)
        if t is not None:
            # Format the time first and append the label afterwards, so a
            # "%" inside the label is not interpreted by strftime.
            text = t.strftime("%Y-%m-%d %H:%M UTC")
            if label is not None:
                text += f" ({label})"
            ax.text(
                0.02, 0.98, text,
                horizontalalignment='left', verticalalignment='top',
                transform=ax.transAxes
            )
        fig.savefig(fn, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving fails, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)


def forecast_demo(

    ldm_weights_fn="../models/genforecast/genforecast-radaronly-256x256-20step.pt",

    autoenc_weights_fn="../models/autoenc/autoenc-32-0.01.pt",

    num_diffusion_iters=50,

    out_dir="../figures/demo/",

    data_dir="../data/demo/20210622",

    t0=datetime(2021,6,22,18,35),

    interval=timedelta(minutes=5),

    past_timesteps=4,

    crop_box=((128,480), (160,608)),

    draw_border=True,

    ensemble_members=1,

):
    """Run the demo forecast and save past and predicted frames as PNGs.

    Past frames are loaded with ``read_data`` and fed to the forecaster.
    ``forecast.Forecast`` is used for a single member and
    ``forecast.ForecastDistributed`` for an ensemble. Each frame is written
    to ``out_dir`` as ``R_past-XX.png`` / ``R_pred-XX.png``.

    Args:
        ldm_weights_fn: Weights file passed to the forecaster.
        autoenc_weights_fn: Autoencoder weights file passed to the
            forecaster.
        num_diffusion_iters: Passed to the forecaster call.
        out_dir: Output directory (created if missing).
        data_dir: Directory with the input HDF5 files (see ``read_data``).
        t0: Time of the last observed frame. A ``datetime``, or an ISO 8601
            string such as ``"2021-06-22T18:35"`` (as given on the command
            line).
        interval: ``timedelta`` between frames, used both for loading and
            for the timestamps written on the images.
        past_timesteps: Number of observed frames to load.
        crop_box: Crop applied to the data and to the border overlay.
        draw_border: Whether to overlay the border outlines.
        ensemble_members: Number of ensemble members (>= 1). For more than
            one member, the saved prediction is the ensemble mean.

    Raises:
        ValueError: If ``ensemble_members`` is smaller than 1.
    """
    # Validate before the (potentially slow) data loading and model setup.
    if ensemble_members < 1:
        raise ValueError("ensemble_members must be > 0")
    if isinstance(t0, str):
        # Fire passes command-line values that are not Python literals
        # through as strings, e.g. --t0=2021-06-22T18:35.
        t0 = datetime.fromisoformat(t0)

    R_past = read_data(
        data_dir=data_dir, t0=t0, interval=interval,
        past_timesteps=past_timesteps, crop_box=crop_box
    )
    if ensemble_members == 1:
        fc = forecast.Forecast(
            ldm_weights_fn=ldm_weights_fn,
            autoenc_weights_fn=autoenc_weights_fn
        )
        R_pred = fc(
            R_past,
            num_diffusion_iters=num_diffusion_iters
        )
    else:
        fc = forecast.ForecastDistributed(
            ldm_weights_fn=ldm_weights_fn,
            autoenc_weights_fn=autoenc_weights_fn,
        )
        # The distributed forecaster is given a leading batch axis of size 1.
        R_pred = fc(
            R_past[np.newaxis, ...],
            num_diffusion_iters=num_diffusion_iters,
            ensemble_members=ensemble_members
        )
        # Take the single batch entry and average over its last axis.
        # NOTE(review): assumes the last axis is the ensemble axis -- confirm
        # against ForecastDistributed's output layout.
        R_pred = R_pred[0, ...].mean(axis=-1)

    os.makedirs(out_dir, exist_ok=True)
    num_past = R_past.shape[0]
    for k in range(num_past):
        fn = os.path.join(out_dir, f"R_past-{k:02d}.png")
        # The last past frame corresponds to t0.
        t = t0 - (num_past-k-1) * interval
        plot_frame(R_past[k,:,:], fn, draw_border=draw_border,
            t=t, label="Real", crop_box=crop_box)
    for k in range(R_pred.shape[0]):
        fn = os.path.join(out_dir, f"R_pred-{k:02d}.png")
        # The first predicted frame is one interval after t0.
        t = t0 + (k+1)*interval
        plot_frame(R_pred[k,:,:], fn, draw_border=draw_border,
            t=t, label="Predicted", crop_box=crop_box)


if __name__ == "__main__":
    # Script entry point: every keyword argument of forecast_demo becomes a
    # command-line flag, e.g. `--ensemble_members=4`.
    Fire(component=forecast_demo)