# scripts/train_nowcaster.py

from datetime import timedelta
import gc
import gzip
import os
import pickle

import numpy as np

from ldcast.features import batch, patches, split, transform
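
# Base directory used to resolve the data and split-chunk paths below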
#file_dir = os.path.dirname(os.path.abspath(__file__))
file_dir = os.path.dirname("/data/data_WF/ldcast_precipitation/ldcast/")


def setup_data(
    use_obs=True,
    use_nwp=False,
    obs_vars=("RZC",),
    nwp_vars=(
        "cape", "cin", "rate-cp", "rate-tp", "t2m",
        "tclw", "tcwv", "u", "v"
    ),
    nwp_lags=(0, 12),
    target_var="RZC",
    batch_size=8,
    past_timesteps=4,
    future_timesteps=20,
    timestep_secs=300,
    nwp_timestep_secs=3600,
    sampler_file=None,
    #chunks_file="./preprocess_data/split_chunks.pkl.gz",
    chunks_file="./data/split_chunks.pkl.gz",
    sample_shape=(4, 4)
):
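    # Naming convention: "<var>-T" marks the forecast target and "<var>-O" the
    # corresponding observed (past) predictor.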
    target = target_var + "-T"
    predictors_obs = [v + "-O" for v in obs_vars]
    predictors = []
    if use_obs:
        predictors += predictors_obs
    if use_nwp:
        predictors.append("nwp")

    variables = {
        target: {
            "sources": [target_var],
            "timesteps": np.arange(1, future_timesteps+1),
        }
    }
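    # Each observed predictor provides the last `past_timesteps` frames,
    # i.e. timesteps -past_timesteps+1 through 0.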
    for (var, raw_var) in zip(predictors_obs, obs_vars):
        variables[var] = {
            "sources": [raw_var],
            "timesteps": np.arange(-past_timesteps+1, 1)
        }
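
    # NWP inputs run on a coarser timestep (nwp_timestep_secs) and cover the
    # full forecast window plus two extra steps.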
    nwp_t1 = int(np.ceil(future_timesteps*timestep_secs/nwp_timestep_secs)) + 2
    nwp_range = np.arange(nwp_t1)
    variables["nwp"] = {
        "sources": nwp_vars,
        "timesteps": nwp_range,
        "timestep_secs": nwp_timestep_secs
    }

    # determine which raw variables are needed, then load them
    raw_vars = set.union(
        *(set(variables[v]["sources"]) for v in predictors_obs+[target])
    )
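    # NWP fields are stored per forecast lag, e.g. "cape-0" and "cape-12"
    # for nwp_lags=(0, 12)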
    if use_nwp:
        for raw_var_base in variables["nwp"]["sources"]:
            raw_vars.update(f"{raw_var_base}-{lag}" for lag in nwp_lags)
    raw = {
        var: patches.load_all_patches(
            os.path.join(file_dir, f"./data/{var}/"), var
            #os.path.join(file_dir, f"./preprocess_data/{var}/"), var
        )
        for var in raw_vars
    }

    # Load pregenerated train/valid/test split data.
    # These can be generated with features.split.get_chunks()
    with gzip.open(os.path.join(file_dir, chunks_file), 'rb') as f:
        chunks = pickle.load(f)
    (raw, _) = split.train_valid_test_split(raw, var, chunks=chunks)
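
    # Per-variable normalization transforms (fixed statistics for the NWP fields)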
    transform_rain = lambda: transform.default_rainrate_transform(
        raw["train"]["RZC"]["scale"]
    )
    transform_cape = lambda: transform.normalize_threshold(
        log=True,
        threshold=1.0, fill_value=1.0,
        mean=1.530, std=0.859
    )
    transform_rate_tp = lambda: transform.normalize_threshold(
        log=True,
        threshold=1e-5, fill_value=1e-5,
        mean=-3.831, std=0.650
    )
    transform_wind = lambda: transform.normalize(std=9.44)
    transforms = {
        "RZC-T": transform_rain(),
        "RZC-O": transform_rain(),
        "cape": transform_cape(),
        "cin": transform_cape(),
        "rate-tp": transform_rate_tp(),
        "rate-cp": transform_rate_tp(),
        "t2m": transform.normalize(mean=286.069, std=7.323),
        "tclw": transform.normalize_threshold(
            log=True,
            threshold=0.001, fill_value=0.001,
            mean=-1.486, std=0.638
        ),
        "tcwv": transform.normalize(std=17.307),
        "u": transform_wind(),
        "v": transform_wind()
    }
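    # The "nwp" input uses the combined transforms of all selected NWP variables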
transforms["nwp"] = transform.combine([transforms[v] for v in nwp_vars])
for (var_name, var_data) in variables.items():
var_data["transform"] = transforms[var_name]
    if sampler_file is None:
        sampler_file = {
            #"train": "../cache/sampler_nowcaster_train.pkl",
            #"valid": "../cache/sampler_nowcaster_valid.pkl",
            #"test": "../cache/sampler_nowcaster_test.pkl",
            "train": "/data/data_WF/ldcast_precipitation/ldcast/cache/sampler_nowcaster_train.pkl",
            "valid": "/data/data_WF/ldcast_precipitation/ldcast/cache/sampler_nowcaster_valid.pkl",
            "test": "/data/data_WF/ldcast_precipitation/ldcast/cache/sampler_nowcaster_test.pkl",
        }
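
    # Log-spaced precipitation-intensity bins from 0.2 to 50, used as the
    # DataModule's sampling_bins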
    bins = np.exp(np.linspace(np.log(0.2), np.log(50), 10))
    datamodule = split.DataModule(
        variables, raw, predictors, target, target,
        forecast_raw_vars=nwp_vars,
        interval=timedelta(seconds=timestep_secs),
        batch_size=batch_size, sampling_bins=bins,
        time_range_sampling=(-past_timesteps+1, future_timesteps+1),
        sampler_file=sampler_file,
        sample_shape=sample_shape,
        valid_seed=1234, test_seed=2345,
    )
    gc.collect()
    return datamodule
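

# Hedged usage sketch (not part of the original script): with the defaults,
# setup_data() builds an observation-only (RZC) nowcasting datamodule with
# 4 past and 20 future frames at 5-minute resolution; set use_nwp=True to add
# the NWP predictors. The train_dataloader() call assumes split.DataModule
# follows the PyTorch Lightning DataModule interface.
def _example_build_datamodule(use_nwp=False):  # hypothetical helper, for illustration only
    datamodule = setup_data(use_obs=True, use_nwp=use_nwp)
    train_loader = datamodule.train_dataloader()
    return datamodule, train_loader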