from datetime import timedelta
import gc
import gzip
import os
import pickle
import numpy as np
from ldcast.features import batch, patches, split, transform
# Base directory that data/ and cache paths are resolved against.
# NOTE(review): the portable __file__-based version is commented out in
# favor of a machine-specific absolute path; this breaks on any other
# machine — consider restoring the __file__ form or reading an env var.
#file_dir = os.path.dirname(os.path.abspath(__file__))
file_dir = os.path.dirname("/data/data_WF/ldcast_precipitation/ldcast/")
def setup_data(
    use_obs=True,
    use_nwp=False,
    obs_vars=("RZC",),
    nwp_vars=(
        "cape", "cin", "rate-cp", "rate-tp", "t2m",
        "tclw", "tcwv", "u", "v"
    ),
    nwp_lags=(0, 12),
    target_var="RZC",
    batch_size=8,
    past_timesteps=4,
    future_timesteps=20,
    timestep_secs=300,
    nwp_timestep_secs=3600,
    sampler_file=None,
    chunks_file="./data/split_chunks.pkl.gz",
    sample_shape=(4, 4)
):
    """Assemble the DataModule used to train/evaluate the nowcaster.

    Loads raw patch data for the requested observation / NWP variables,
    applies the pregenerated train/valid/test split, attaches per-variable
    normalization transforms (precomputed statistics), and wraps everything
    in a ``split.DataModule``.

    Args:
        use_obs: include the observation variables as predictors.
        use_nwp: include the NWP forecast variables as predictors.
        obs_vars: raw observation variable names; must be non-empty (the
            split step needs one of them as its reference variable).
        nwp_vars: raw NWP variable names.
        nwp_lags: forecast-initialization lags used to build the lagged
            NWP file names.  # NOTE(review): units not visible here — confirm
        target_var: raw variable used as the prediction target.
        batch_size: samples per batch.
        past_timesteps: number of past observation timesteps per sample.
        future_timesteps: number of future (target) timesteps per sample.
        timestep_secs: observation timestep length in seconds.
        nwp_timestep_secs: NWP timestep length in seconds.
        sampler_file: mapping of split name -> sampler cache path; a
            hard-coded machine-specific default is used when None.
        chunks_file: pickled train/valid/test chunk file, resolved
            relative to ``file_dir``.
        sample_shape: spatial sample shape, in patch units.

    Returns:
        A configured ``split.DataModule``.

    Raises:
        ValueError: if ``obs_vars`` is empty (the original code raised a
            bare ``NameError`` in that case via a leaked loop variable).
    """
    if not obs_vars:
        raise ValueError("obs_vars must contain at least one variable")

    target = target_var + "-T"
    predictors_obs = [v + "-O" for v in obs_vars]
    predictors = []
    if use_obs:
        predictors += predictors_obs
    if use_nwp:
        predictors.append("nwp")

    # Target uses future timesteps 1..future_timesteps; observation
    # predictors use past timesteps -past_timesteps+1..0.
    variables = {
        target: {
            "sources": [target_var],
            "timesteps": np.arange(1, future_timesteps+1),
        }
    }
    for (var, raw_var) in zip(predictors_obs, obs_vars):
        variables[var] = {
            "sources": [raw_var],
            "timesteps": np.arange(-past_timesteps+1, 1)
        }
    # Number of NWP steps needed to cover the forecast horizon, plus a
    # 2-step margin.
    nwp_t1 = int(np.ceil(future_timesteps*timestep_secs/nwp_timestep_secs)) + 2
    variables["nwp"] = {
        "sources": nwp_vars,
        "timesteps": np.arange(nwp_t1),
        "timestep_secs": nwp_timestep_secs
    }

    # Determine which raw variables are needed, then load them.
    # NOTE: observation sources are loaded even when use_obs is False.
    raw_vars = set.union(
        *(set(variables[v]["sources"]) for v in predictors_obs+[target])
    )
    if use_nwp:
        for raw_var_base in variables["nwp"]["sources"]:
            raw_vars.update(f"{raw_var_base}-{lag}" for lag in nwp_lags)
    raw = {
        rv: patches.load_all_patches(
            os.path.join(file_dir, f"./data/{rv}/"), rv
        )
        for rv in raw_vars
    }

    # Load pregenerated train/valid/test split data.
    # These can be generated with features.split.get_chunks()
    with gzip.open(os.path.join(file_dir, chunks_file), 'rb') as f:
        chunks = pickle.load(f)
    # The original code passed the leaked loop variable `var` (left over
    # from the zip loop above) as the split reference variable; it always
    # equaled the last observation predictor. Make that explicit.
    ref_var = predictors_obs[-1]
    (raw, _) = split.train_valid_test_split(raw, ref_var, chunks=chunks)

    # Per-variable normalization transforms; the numeric statistics are
    # precomputed over the training data.
    def transform_rain():
        # NOTE(review): hard-codes the "RZC" scale even when target_var
        # differs — confirm whether this is intended.
        return transform.default_rainrate_transform(
            raw["train"]["RZC"]["scale"]
        )

    def transform_cape():
        return transform.normalize_threshold(
            log=True,
            threshold=1.0, fill_value=1.0,
            mean=1.530, std=0.859
        )

    def transform_rate_tp():
        return transform.normalize_threshold(
            log=True,
            threshold=1e-5, fill_value=1e-5,
            mean=-3.831, std=0.650
        )

    def transform_wind():
        return transform.normalize(std=9.44)

    transforms = {
        "RZC-T": transform_rain(),
        "RZC-O": transform_rain(),
        "cape": transform_cape(),
        "cin": transform_cape(),
        "rate-tp": transform_rate_tp(),
        "rate-cp": transform_rate_tp(),
        "t2m": transform.normalize(mean=286.069, std=7.323),
        "tclw": transform.normalize_threshold(
            log=True,
            threshold=0.001, fill_value=0.001,
            mean=-1.486, std=0.638
        ),
        "tcwv": transform.normalize(std=17.307),
        "u": transform_wind(),
        "v": transform_wind()
    }
    transforms["nwp"] = transform.combine([transforms[v] for v in nwp_vars])
    for (var_name, var_data) in variables.items():
        var_data["transform"] = transforms[var_name]

    if sampler_file is None:
        sampler_file = {
            "train": "/data/data_WF/ldcast_precipitation/ldcast/cache/sampler_nowcaster_train.pkl",
            "valid": "/data/data_WF/ldcast_precipitation/ldcast/cache/sampler_nowcaster_valid.pkl",
            "test": "/data/data_WF/ldcast_precipitation/ldcast/cache/sampler_nowcaster_test.pkl",
        }

    # Log-spaced bins used for importance sampling of rain rates
    # (presumably 0.2–50 mm/h — confirm against the sampler).
    bins = np.exp(np.linspace(np.log(0.2), np.log(50), 10))
    datamodule = split.DataModule(
        variables, raw, predictors, target, target,
        forecast_raw_vars=nwp_vars,
        interval=timedelta(seconds=timestep_secs),
        batch_size=batch_size, sampling_bins=bins,
        time_range_sampling=(-past_timesteps+1, future_timesteps+1),
        sampler_file=sampler_file,
        sample_shape=sample_shape,
        valid_seed=1234, test_seed=2345,
    )
    gc.collect()  # release the large intermediate structures eagerly
    return datamodule