|
|
from datetime import datetime, timedelta
|
|
|
import os
|
|
|
import pickle
|
|
|
|
|
|
from numba import njit, prange, types
|
|
|
from numba.typed import Dict
|
|
|
import numpy as np
|
|
|
from torch.utils.data import Dataset, IterableDataset
|
|
|
|
|
|
from .patches import unpack_patches
|
|
|
from .sampling import EqualFrequencySampler
|
|
|
|
|
|
|
|
|
class BatchGenerator:
    """Draws batches of spatio-temporal samples from patched raw data.

    Each batch pairs predictor variables with a target variable. Sample
    positions (time, row, col) are drawn by an ``EqualFrequencySampler``,
    which is built on demand and optionally cached to disk with pickle.

    Fixes vs. previous revision: removed a leftover debug ``print`` of the
    raw sample array in :meth:`batch`, and the source-setup loop now tests
    membership against ``self.forecast_raw_vars`` (the set intersected with
    the actually-used sources) instead of the raw constructor argument.
    """

    def __init__(self,
        variables,
        raw,
        predictors,
        target,
        primary_var,
        time_range_sampling=(-1,2),
        forecast_raw_vars=(),
        sampling_bins=None,
        sampler_file=None,
        sample_shape=(4,4),
        batch_size=32,
        interval=timedelta(minutes=5),
        random_seed=None,
        augment=False
    ):
        """Initialize the generator.

        Parameters
        ----------
        variables : dict
            Per-variable config; each entry provides at least "sources",
            "timesteps" and "transform" (see :meth:`batch`).
        raw : dict
            Raw patched datasets, keyed by source name. Forecast sources are
            stored per lag under keys of the form "<base>-<lag_hours>".
        predictors : list
            Names of predictor variables (keys of ``variables``).
        target : str
            Name of the target variable.
        primary_var : str
            Variable whose first raw source drives the sampler statistics.
        time_range_sampling : tuple, optional
            Relative time range passed through to the sampler.
        forecast_raw_vars : iterable, optional
            Base names of sources that are forecasts with multiple lags.
        sampling_bins : optional
            Binning spec for the equal-frequency sampler.
        sampler_file : str, optional
            Path used to cache/load the pickled sampler.
        sample_shape : tuple, optional
            Sample size in patch units (boxes per sample, rows x cols).
        batch_size : int, optional
            Default number of samples per batch.
        interval : datetime.timedelta, optional
            Base time resolution of the data.
        random_seed : int, optional
            Seed for the augmentation RNG.
        augment : bool, optional
            If True, apply random transpose/flip augmentation to batches.
        """
        super().__init__()
        self.batch_size = batch_size
        self.interval = interval
        self.interval_secs = np.int64(self.interval.total_seconds())
        self.variables = variables
        self.predictors = predictors
        self.target = target
        self.used_variables = predictors + [target]
        self.rng = np.random.RandomState(seed=random_seed)
        self.augment = augment

        # All raw sources needed by any predictor or the target.
        self.sources = set.union(
            *(set(variables[v]["sources"]) for v in self.used_variables)
        )
        # Restrict forecast handling to sources that are actually used.
        self.forecast_raw_vars = set(forecast_raw_vars) & self.sources
        self.patch_index = {}
        for raw_name_base in self.sources:
            if raw_name_base in self.forecast_raw_vars:
                # Forecast sources are stored per lag as "<base>-<lag>";
                # index every lag of this base.
                raw_names = (
                    rn for rn in raw if rn.startswith(raw_name_base+"-")
                )
            else:
                raw_names = (raw_name_base,)
            for raw_name in raw_names:
                raw_data = raw[raw_name]
                self.setup_index(raw_name, raw_data, sample_shape)

        # Expose each forecast source under its base name via a wrapper
        # that dispatches lookups to the appropriate per-lag index.
        for raw_name in self.forecast_raw_vars:
            patch_index_var = {
                k: v for (k,v) in self.patch_index.items()
                if k.startswith(raw_name+"-")
            }
            self.patch_index[raw_name] = \
                ForecastPatchIndexWrapper(patch_index_var)

        if (sampler_file is None) or not os.path.isfile(sampler_file):
            print("No cached sampler found, creating a new one...")
            primary_raw_var = variables[primary_var]["sources"][0]
            # Determine the valid sampling time range (in units of
            # self.interval) covering all variables' timesteps.
            t0 = t1 = None
            for (var_name, var_data) in variables.items():
                timesteps = var_data["timesteps"][[0,-1]].copy()
                # Widen the lower bound by one native timestep.
                timesteps[0] -= 1
                ts_secs = timesteps * \
                    var_data.get("timestep_secs", self.interval_secs)
                timesteps = ts_secs // self.interval_secs
                t0 = timesteps[0] if t0 is None else min(t0,timesteps[0])
                t1 = timesteps[-1] if t1 is None else max(t1,timesteps[-1])
            time_range_valid = (t0,t1+1)
            self.sampler = EqualFrequencySampler(
                sampling_bins, raw[primary_raw_var],
                self.patch_index[primary_raw_var], sample_shape,
                time_range_valid, time_range_sampling=time_range_sampling,
                timestep_secs=self.interval_secs
            )
            if sampler_file is not None:
                print(f"Caching sampler to {sampler_file}.")
                with open(sampler_file, 'wb') as f:
                    pickle.dump(self.sampler, f)
        else:
            print(f"Loading cached sampler from {sampler_file}.")
            # NOTE(review): pickle.load on an externally supplied path is
            # unsafe for untrusted files; acceptable only for local caches.
            with open(sampler_file, 'rb') as f:
                self.sampler = pickle.load(f)

    def setup_index(self, raw_name, raw_data, box_size):
        """Build and register a PatchIndex for one raw source."""
        zero_value = raw_data.get("zero_value", 0)
        # Missing data defaults to the zero value unless specified.
        missing_value = raw_data.get("missing_value", zero_value)

        self.patch_index[raw_name] = PatchIndex(
            *unpack_patches(raw_data),
            zero_value=zero_value,
            missing_value=missing_value,
            interval=self.interval,
            box_size=box_size
        )

    def augmentations(self):
        """Draw random (transpose, flipud, fliplr) bits for one batch."""
        return tuple(self.rng.randint(2, size=3))

    def augment_batch(self, batch, transpose, flipud, fliplr):
        """Apply transpose/flip augmentation to the last two (spatial) axes.

        Always returns a contiguous copy so downstream code never sees a
        negative-stride view.
        """
        if self.augment:
            if transpose:
                # Swap the last two axes, leaving leading axes in place.
                axes = list(range(batch.ndim))
                axes = axes[:-2] + [axes[-1], axes[-2]]
                batch = batch.transpose(axes)
            flips = []
            if flipud:
                flips.append(-2)
            if fliplr:
                flips.append(-1)
            if flips:
                batch = np.flip(batch, axis=flips)
        return batch.copy()

    def batch(self, samples=None, batch_size=None):
        """Assemble one batch.

        Parameters
        ----------
        samples : ndarray, optional
            Pre-drawn sample coordinates with columns (t, i, j); drawn from
            the sampler when omitted.
        batch_size : int, optional
            Overrides the default batch size (only used when drawing).

        Returns
        -------
        tuple
            ``(pred_batch, target_batch)`` where ``pred_batch`` is a list of
            ``(data, t_relative)`` pairs per predictor and ``target_batch``
            is the target data array.
        """
        if batch_size is None:
            batch_size = self.batch_size

        if samples is None:
            samples = self.sampler(batch_size)

        (t0,i0,j0) = samples.T

        if self.augment:
            # One augmentation draw shared by all variables in the batch so
            # predictors and target stay spatially aligned.
            augmentations = self.augmentations()

        batch = {}

        for var_name in self.used_variables:
            var_data = self.variables[var_name]

            # Snap each sample time down to this variable's native timestep
            # grid, then expand to the variable's relative timesteps.
            # Assumes sample times t0 are in seconds — TODO confirm against
            # the sampler implementation.
            ts_secs = var_data.get("timestep_secs", self.interval_secs)
            t_shift = -(t0 % ts_secs)
            t0_shifted = t0 + t_shift
            t = t0_shifted[:,None] + ts_secs*var_data["timesteps"][None,:]
            t_relative = (t - t0[:,None]) / self.interval_secs

            # Lazily pull each source's data; the transform consumes them.
            raw_data = (
                self.patch_index[raw_name](t,i0,j0)
                for raw_name in var_data["sources"]
            )

            batch_var = var_data["transform"](*raw_data)

            # Ensure a channel axis: (N,T,H,W) -> (N,T,1,H,W).
            add_dims = (1,) if batch_var.ndim == 4 else ()
            batch_var = np.expand_dims(batch_var, add_dims)

            if self.augment:
                batch_var = self.augment_batch(batch_var, *augmentations)

            batch[var_name] = (batch_var, t_relative.astype(np.float32))

        pred_batch = [batch[v] for v in self.predictors]
        target_batch = batch[self.target][0]
        return (pred_batch, target_batch)

    def batches(self, *args, num=None, **kwargs):
        """Yield ``num`` batches, or stream forever when ``num`` is None."""
        if num is not None:
            for i in range(num):
                yield self.batch(*args, **kwargs)
        else:
            while True:
                yield self.batch(*args, **kwargs)
|
|
|
|
|
|
|
|
|
class StreamBatchDataset(IterableDataset):
    """Iterable-style torch dataset that streams batches from a generator.

    Each pass over the dataset yields ``batches_per_epoch`` freshly drawn
    batches from the wrapped batch generator.
    """

    def __init__(self, batch_gen, batches_per_epoch):
        super().__init__()
        self.batch_gen = batch_gen
        self.batches_per_epoch = batches_per_epoch

    def __iter__(self):
        # A new, independent stream is created for every epoch.
        return iter(self.batch_gen.batches(num=self.batches_per_epoch))
|
|
|
|
|
|
|
|
|
class DeterministicBatchDataset(Dataset):
    """Map-style torch dataset with sample positions fixed at construction.

    The sampler is reseeded and all batch sample coordinates are pre-drawn
    up front, so fetching the same index reproduces the same batch positions
    (useful for validation). Fix vs. previous revision: removed a leftover
    debug ``print`` of the sample array in ``__getitem__``.
    """

    def __init__(self, batch_gen, batches_per_epoch, random_seed=None):
        """Pre-draw ``batches_per_epoch`` sets of sample coordinates.

        Parameters
        ----------
        batch_gen : BatchGenerator
            Generator providing ``sampler``, ``batch_size`` and ``batch``.
        batches_per_epoch : int
            Number of batches (= dataset length).
        random_seed : int, optional
            Seed for the sampler RNG, fixing the drawn positions.
        """
        super().__init__()
        self.batch_gen = batch_gen
        self.batches_per_epoch = batches_per_epoch
        # Reseed the sampler so the pre-drawn coordinates are reproducible.
        self.batch_gen.sampler.rng = np.random.RandomState(seed=random_seed)
        self.samples = [
            self.batch_gen.sampler(self.batch_gen.batch_size)
            for i in range(self.batches_per_epoch)
        ]

    def __len__(self):
        return self.batches_per_epoch

    def __getitem__(self, ind):
        # Rebuild the batch from the coordinates frozen at init time.
        return self.batch_gen.batch(samples=self.samples[ind])
|
|
|
|
|
|
|
|
|
class PatchIndex:
    """Random-access index over patched raw data.

    Maps (time, row, col) keys to positions in ``patch_data`` via a numba
    typed Dict so the jitted ``build_batch`` kernel can look patches up
    without leaving compiled code. Calling the instance assembles a batch of
    samples by tiling the requested patches into a reusable buffer.
    """

    # Sentinel index values stored in / returned from the typed dict:
    # IDX_ZERO marks a known all-zero patch, IDX_MISSING marks absent data.
    IDX_ZERO = -1
    IDX_MISSING = -2

    def __init__(
        self, patch_data, patch_coords, patch_times,
        zero_patch_coords, zero_patch_times,
        interval=timedelta(minutes=5),
        box_size=(4,4), zero_value=0,
        missing_value=0
    ):
        """Build the (t, i, j) -> patch-position index.

        Parameters
        ----------
        patch_data : ndarray
            Stored patches, shape (num_patches, patch_h, patch_w).
        patch_coords, patch_times : ndarray
            Per-patch (i, j) coordinates and timestamps.
        zero_patch_coords, zero_patch_times : ndarray
            Coordinates/times of patches known to be entirely zero-valued
            (stored only as the IDX_ZERO sentinel, saving memory).
        interval : datetime.timedelta, optional
            Base time resolution.
        box_size : tuple, optional
            Sample extent in patches (rows, cols).
        zero_value, missing_value : optional
            Fill values used for zero / missing patches when assembling.
        """
        self.dt = int(round(interval.total_seconds()))
        self.box_size = box_size
        self.zero_value = zero_value
        self.missing_value = missing_value
        self.patch_data = patch_data
        # Spatial size of one assembled sample, in pixels.
        self.sample_shape = (
            self.patch_data.shape[1]*box_size[0],
            self.patch_data.shape[2]*box_size[1]
        )

        # Typed dict keyed by (t, i, j) int64 triples so jitted code can
        # query it directly.
        self.patch_index = Dict.empty(
            key_type=types.UniTuple(types.int64, 3),
            value_type=types.int64
        )
        init_patch_index(self.patch_index, patch_coords, patch_times)
        init_patch_index_zero(self.patch_index, zero_patch_coords,
            zero_patch_times, PatchIndex.IDX_ZERO)

        # Reusable batch buffer, grown on demand by _alloc_batch.
        self._batch = None

    def _alloc_batch(self, batch_size, num_timesteps):
        """Return a batch buffer of at least the requested size.

        The buffer is reused across calls and only reallocated when it is
        too small in either leading dimension.
        """
        needs_rebuild = (self._batch is None) or \
            (self._batch.shape[0] < batch_size) or \
            (self._batch.shape[1] < num_timesteps)
        if needs_rebuild:
            # Drop the old buffer first to keep peak memory down.
            del self._batch
            self._batch = np.zeros(
                (batch_size,num_timesteps)+self.sample_shape,
                self.patch_data.dtype
            )
        return self._batch

    def __call__(self, t, i0_all, j0_all):
        """Assemble samples for times ``t`` at top-left corners (i0, j0).

        NOTE(review): the returned array is a view into a shared buffer that
        is overwritten by the next call — callers must copy if they need to
        keep it.
        """
        batch = self._alloc_batch(*t.shape)

        # Exclusive end coordinates of each sample's patch box.
        i1_all = i0_all + self.box_size[0]
        j1_all = j0_all + self.box_size[1]
        # Pixel size of a single patch.
        bi_size = self.patch_data.shape[1]
        bj_size = self.patch_data.shape[2]

        build_batch(batch, self.patch_data, self.patch_index,
            t, i0_all, i1_all, j0_all, j1_all,
            bi_size, bj_size, self.zero_value,
            self.missing_value)

        # The buffer may be larger than needed; trim the time axis.
        return batch[:,:t.shape[1],...]
|
|
|
|
|
|
|
|
|
@njit
def init_patch_index(patch_index, patch_coords, patch_times):
    """Record each stored patch in the typed dict, keyed by (t, i, j)."""
    num_patches = patch_coords.shape[0]
    for pos in range(num_patches):
        row = np.int64(patch_coords[pos, 0])
        col = np.int64(patch_coords[pos, 1])
        # The value is the patch's position in the patch_data array.
        patch_index[(patch_times[pos], row, col)] = pos
|
|
|
|
|
|
|
|
|
@njit
def init_patch_index_zero(patch_index, zero_patch_coords,
        zero_patch_times, idx_zero):
    """Mark known all-zero patches in the typed dict with the sentinel.

    These patches carry no stored data; only the idx_zero sentinel is
    recorded under their (t, i, j) key.
    """
    num_zero = zero_patch_coords.shape[0]
    for pos in range(num_zero):
        row = np.int64(zero_patch_coords[pos, 0])
        col = np.int64(zero_patch_coords[pos, 1])
        patch_index[(zero_patch_times[pos], row, col)] = idx_zero
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level aliases of the sentinel indices: the jitted build_batch below
# reads these as plain globals, which numba freezes as compile-time constants.
IDX_ZERO = PatchIndex.IDX_ZERO
IDX_MISSING = PatchIndex.IDX_MISSING
|
|
|
@njit(parallel=True)
def build_batch(
    batch, patch_data, patch_index,
    t_all, i0_all, i1_all, j0_all, j1_all,
    bi_size, bj_size, zero_value, missing_value
):
    """Tile patches into the batch buffer, parallel over samples.

    For each sample k and each of its timesteps, looks up every patch in
    the [i0,i1) x [j0,j1) box and writes a (bi_size x bj_size) tile into
    the sample image: stored data for real patches, ``zero_value`` for
    IDX_ZERO entries, ``missing_value`` for absent keys.
    """
    for k in prange(t_all.shape[0]):
        # Patch-box bounds for this sample (i1/j1 exclusive).
        i0 = i0_all[k]
        i1 = i1_all[k]
        j0 = j0_all[k]
        j1 = j1_all[k]

        for (bt,t) in enumerate(t_all[k,:]):
            for i in range(i0, i1):
                # Pixel rows covered by patch row i within the sample.
                bi0 = (i-i0) * bi_size
                bi1 = bi0 + bi_size
                for j in range(j0, j1):
                    # Absent keys default to the missing sentinel.
                    ind = int(patch_index.get((t,i,j), IDX_MISSING))
                    bj0 = (j-j0) * bj_size
                    bj1 = bj0 + bj_size
                    if ind >= 0:
                        batch[k,bt,bi0:bi1,bj0:bj1] = patch_data[ind]
                    elif ind == IDX_ZERO:
                        batch[k,bt,bi0:bi1,bj0:bj1] = zero_value
                    elif ind == IDX_MISSING:
                        batch[k,bt,bi0:bi1,bj0:bj1] = missing_value
|
|
|
|
|
|
|
|
|
class ForecastPatchIndexWrapper(PatchIndex):
    """Presents multiple per-lag PatchIndex objects as a single index.

    Forecast data is stored under names "<base>-<lag_hours>", one PatchIndex
    per lag. When called, this wrapper picks for every requested time the
    lag whose forecast covers it and merges the per-lag batches.

    Note: does NOT call PatchIndex.__init__; it only sets the attributes
    that the inherited _alloc_batch needs (_batch, sample_shape, patch_data).
    """

    def __init__(self, patch_index):
        """Wrap a dict of {"<base>-<lag>": PatchIndex} for one base name.

        Raises ValueError if the base names differ, the lags are not evenly
        spaced, or 24 h is not a multiple of the lag spacing.
        NOTE(review): a single lag also raises "Lags must be evenly spaced"
        (np.diff of one element is empty) — confirm that is intended.
        """
        self.patch_index = patch_index
        # Everything before the last "-" is the base name; the suffix is
        # the lag in hours.
        raw_names = {"-".join(v.split("-")[:-1]) for v in patch_index}
        if len(raw_names) != 1:
            raise ValueError(
                "Can only wrap variables with the same base name")
        self.raw_name = list(raw_names)[0]
        lags_hour = [int(v.split("-")[-1]) for v in patch_index]
        self.lags_hour = set(lags_hour)
        forecast_interval_hour = np.diff(sorted(lags_hour))
        if len(set(forecast_interval_hour)) != 1:
            raise ValueError("Lags must be evenly spaced")
        forecast_interval_hour = forecast_interval_hour[0]
        if (24 % forecast_interval_hour):
            raise ValueError(
                "24 hours must be a multiple of the forecast interval")
        self.forecast_interval_hour = forecast_interval_hour
        # Spacing between consecutive forecast initializations, in seconds.
        self.forecast_interval = 3600 * forecast_interval_hour

        # Attributes required by the inherited _alloc_batch; sample shape
        # and dtype are taken from an arbitrary per-lag index (all match).
        self._batch = None
        v = list(self.patch_index.keys())[0]
        self.sample_shape = self.patch_index[v].sample_shape
        self.patch_data = self.patch_index[v].patch_data

    def __call__(self, t, i0, j0):
        """Assemble samples, choosing the appropriate lag per timestep.

        Returns a view into the shared buffer (overwritten on next call),
        trimmed to t.shape[1] timesteps, as in PatchIndex.__call__.
        """
        batch = self._alloc_batch(*t.shape)

        # Offset of each sample's first time from the most recent forecast
        # initialization, then propagated to all of its timesteps.
        # Assumes t is in seconds — TODO confirm against the caller.
        t0 = t[:,:1]
        start_time_from_fc = t0 % self.forecast_interval
        time_from_fc = start_time_from_fc + (t - t0)
        # Per-timestep lag (hours) identifying which forecast to read.
        lags_hour = (time_from_fc // self.forecast_interval) * \
            self.forecast_interval_hour

        for lag in self.lags_hour:
            raw_name_lag = f"{self.raw_name}-{lag}"
            batch_lag = self.patch_index[raw_name_lag](t,i0,j0)
            # Copy only the timesteps assigned to this lag.
            lag_mask = (lags_hour == lag)
            copy_masked_times(batch_lag, batch, lag_mask)

        return batch[:,:t.shape[1],...]
|
|
|
|
|
|
|
|
|
@njit(parallel=True)
def copy_masked_times(from_batch, to_batch, mask):
    """Copy time slices from from_batch into to_batch where mask is set.

    Both batches are (sample, time, h, w); mask is (sample, time).
    Parallelized over the sample axis.
    """
    num_samples = from_batch.shape[0]
    num_times = from_batch.shape[1]
    for s in prange(num_samples):
        for ts in range(num_times):
            if mask[s, ts]:
                to_batch[s, ts, :, :] = from_batch[s, ts, :, :]
|
|
|
|