# ldcast/features/batch.py (saved copy: batch.py.save)
# Source: weatherforecast1024 repository, uploaded via huggingface_hub
# (revision d2f661a, verified)
from datetime import datetime, timedelta
import os
import pickle
from numba import njit, prange, types
from numba.typed import Dict
import numpy as np
from torch.utils.data import Dataset, IterableDataset
from .patches import unpack_patches
from .sampling import EqualFrequencySampler
class BatchGenerator:
    """Assembles (predictor, target) training batches from patched raw data.

    Each raw data source gets a PatchIndex for block retrieval; sample
    coordinates come from an EqualFrequencySampler, which can be cached to
    disk with pickle to avoid rebuilding it on every run.

    Fix over the saved version: the leftover debug ``print(samples)`` in
    :meth:`batch` has been removed.
    """
    def __init__(self,
        variables,
        raw,
        predictors,
        target,
        primary_var,
        time_range_sampling=(-1,2),
        forecast_raw_vars=(),
        sampling_bins=None,
        sampler_file=None,
        sample_shape=(4,4),
        batch_size=32,
        interval=timedelta(minutes=5),
        random_seed=None,
        augment=False
    ):
        """
        Parameters:
            variables: dict mapping variable name -> spec dict with keys
                "sources" (raw source names), "timesteps", "transform",
                and optionally "timestep_secs".
            raw: dict mapping raw source name -> packed patch data
                (consumed by unpack_patches).
            predictors: variable names used as model inputs.
            target: variable name used as the model target.
            primary_var: variable whose first raw source drives sampling.
            time_range_sampling: time range passed to the sampler.
            forecast_raw_vars: raw base names stored as per-lag sources
                named "<base>-<lag_hours>".
            sampling_bins: bin edges for the EqualFrequencySampler.
            sampler_file: optional pickle path for caching the sampler.
            sample_shape: sample size in patches (rows, cols).
            batch_size: default number of samples per batch.
            interval: base time resolution of the dataset.
            random_seed: seed for the augmentation RNG.
            augment: if True, apply random transpose/flip augmentation.
        """
        super().__init__()
        self.batch_size = batch_size
        self.interval = interval
        self.interval_secs = np.int64(self.interval.total_seconds())
        self.variables = variables
        self.predictors = predictors
        self.target = target
        self.used_variables = predictors + [target]
        self.rng = np.random.RandomState(seed=random_seed)
        self.augment = augment

        # setup indices for retrieving source raw data
        self.sources = set.union(
            *(set(variables[v]["sources"]) for v in self.used_variables)
        )
        self.forecast_raw_vars = set(forecast_raw_vars) & self.sources
        self.patch_index = {}
        for raw_name_base in self.sources:
            if raw_name_base in forecast_raw_vars:
                # forecast data is split into one source per lag,
                # named "<base>-<lag_hours>"
                raw_names = (
                    rn for rn in raw if rn.startswith(raw_name_base+"-")
                )
            else:
                raw_names = (raw_name_base,)
            for raw_name in raw_names:
                raw_data = raw[raw_name]
                self.setup_index(raw_name, raw_data, sample_shape)
        for raw_name in self.forecast_raw_vars:
            # wrap the per-lag indices so the base name can be queried
            patch_index_var = {
                k: v for (k,v) in self.patch_index.items()
                if k.startswith(raw_name+"-")
            }
            self.patch_index[raw_name] = \
                ForecastPatchIndexWrapper(patch_index_var)

        # setup samplers
        if (sampler_file is None) or not os.path.isfile(sampler_file):
            print("No cached sampler found, creating a new one...")
            primary_raw_var = variables[primary_var]["sources"][0]
            # find the time range spanned by all variables, expressed in
            # units of self.interval
            t0 = t1 = None
            for (var_name, var_data) in variables.items():
                timesteps = var_data["timesteps"][[0,-1]].copy()
                timesteps[0] -= 1
                ts_secs = timesteps * \
                    var_data.get("timestep_secs", self.interval_secs)
                timesteps = ts_secs // self.interval_secs
                t0 = timesteps[0] if t0 is None else min(t0,timesteps[0])
                t1 = timesteps[-1] if t1 is None else max(t1,timesteps[-1])
            time_range_valid = (t0,t1+1)
            self.sampler = EqualFrequencySampler(
                sampling_bins, raw[primary_raw_var],
                self.patch_index[primary_raw_var], sample_shape,
                time_range_valid, time_range_sampling=time_range_sampling,
                timestep_secs=self.interval_secs
            )
            if sampler_file is not None:
                print(f"Caching sampler to {sampler_file}.")
                with open(sampler_file, 'wb') as f:
                    pickle.dump(self.sampler, f)
        else:
            print(f"Loading cached sampler from {sampler_file}.")
            with open(sampler_file, 'rb') as f:
                self.sampler = pickle.load(f)

    def setup_index(self, raw_name, raw_data, box_size):
        """Create and register the PatchIndex for one raw data source."""
        zero_value = raw_data.get("zero_value", 0)
        missing_value = raw_data.get("missing_value", zero_value)
        self.patch_index[raw_name] = PatchIndex(
            *unpack_patches(raw_data),
            zero_value=zero_value,
            missing_value=missing_value,
            interval=self.interval,
            box_size=box_size
        )

    def augmentations(self):
        """Draw random (transpose, flipud, fliplr) flags for one batch."""
        return tuple(self.rng.randint(2, size=3))

    def augment_batch(self, batch, transpose, flipud, fliplr):
        """Apply the given spatial augmentations to the last two axes of
        `batch` and return an owned copy."""
        if self.augment:
            if transpose:
                # swap the last two (spatial) axes
                axes = list(range(batch.ndim))
                axes = axes[:-2] + [axes[-1], axes[-2]]
                batch = batch.transpose(axes)
            flips = []
            if flipud:
                flips.append(-2)
            if fliplr:
                flips.append(-1)
            if flips:
                batch = np.flip(batch, axis=flips)
        # copy so downstream code gets a contiguous array, not a view
        return batch.copy()

    def batch(self, samples=None, batch_size=None):
        """Assemble one batch.

        Parameters:
            samples: optional precomputed sample coordinates, an array
                whose transpose unpacks into (t0, i0, j0); drawn from
                the sampler if None.
            batch_size: number of samples; self.batch_size if None.

        Returns:
            (pred_batch, target_batch) where pred_batch is a list of
            (array, relative_times) pairs, one per predictor, and
            target_batch is the target array without time coordinates.
        """
        if batch_size is None:
            batch_size = self.batch_size
        if samples is None:
            # get the sample coordinates from the sampler
            samples = self.sampler(batch_size)
        (t0,i0,j0) = samples.T
        if self.augment:
            # one set of augmentations shared by all variables so that
            # predictors and target stay spatially consistent
            augmentations = self.augmentations()

        batch = {}
        for var_name in self.used_variables:
            var_data = self.variables[var_name]
            # different timestep from standard (e.g. forecast); round down
            # to times where we have data available
            ts_secs = var_data.get("timestep_secs", self.interval_secs)
            t_shift = -(t0 % ts_secs)
            t0_shifted = t0 + t_shift
            t = t0_shifted[:,None] + ts_secs*var_data["timesteps"][None,:]
            t_relative = (t - t0[:,None]) / self.interval_secs

            # read raw data from index
            raw_data = (
                self.patch_index[raw_name](t,i0,j0)
                for raw_name in var_data["sources"]
            )
            # transform to model variable
            batch_var = var_data["transform"](*raw_data)
            # add channel dimension if not already present
            add_dims = (1,) if batch_var.ndim == 4 else ()
            batch_var = np.expand_dims(batch_var, add_dims)

            # data augmentation
            if self.augment:
                batch_var = self.augment_batch(batch_var, *augmentations)

            # bundle with time coordinates
            batch[var_name] = (batch_var, t_relative.astype(np.float32))

        pred_batch = [batch[v] for v in self.predictors]
        target_batch = batch[self.target][0] # no time coordinates for target
        return (pred_batch, target_batch)

    def batches(self, *args, num=None, **kwargs):
        """Yield batches: `num` of them if given, otherwise forever."""
        if num is not None:
            for i in range(num):
                yield self.batch(*args, **kwargs)
        else:
            while True:
                yield self.batch(*args, **kwargs)
class StreamBatchDataset(IterableDataset):
    """Iterable-style dataset wrapping a BatchGenerator: each pass over
    the dataset yields `batches_per_epoch` freshly generated batches."""

    def __init__(self, batch_gen, batches_per_epoch):
        super().__init__()
        self.batch_gen = batch_gen
        self.batches_per_epoch = batches_per_epoch

    def __iter__(self):
        # batches() is itself a generator function, so this stays lazy
        return iter(self.batch_gen.batches(num=self.batches_per_epoch))
class DeterministicBatchDataset(Dataset):
    """Map-style dataset whose sample coordinates are drawn once at
    construction, so every epoch replays the same fixed batches.

    Fix over the saved version: the leftover debug
    ``print(self.samples[ind])`` in ``__getitem__`` has been removed.
    """
    def __init__(self, batch_gen, batches_per_epoch, random_seed=None):
        """
        Parameters:
            batch_gen: BatchGenerator used to build the batches.
            batches_per_epoch: number of (fixed) batches per epoch.
            random_seed: seed for the sampler RNG, making the precomputed
                sample set reproducible.
        """
        super().__init__()
        self.batch_gen = batch_gen
        self.batches_per_epoch = batches_per_epoch
        # reseed the sampler so the draw below is deterministic
        self.batch_gen.sampler.rng = np.random.RandomState(seed=random_seed)
        self.samples = [
            self.batch_gen.sampler(self.batch_gen.batch_size)
            for i in range(self.batches_per_epoch)
        ]

    def __len__(self):
        return self.batches_per_epoch

    def __getitem__(self, ind):
        return self.batch_gen.batch(samples=self.samples[ind])
class PatchIndex:
    """Indexes data patches by (time, row, col) and assembles contiguous
    sample boxes from them.

    Patches absent from the index are filled with `missing_value`;
    patches recorded as all-zero are filled with `zero_value`.

    Fix over the saved version: ``__call__`` now slices the batch
    dimension of the cached buffer too, so a request smaller than a
    previous one no longer returns stale rows from the earlier batch.
    """
    # sentinel values stored/returned instead of a patch_data row index
    IDX_ZERO = -1     # patch exists but contains only the zero value
    IDX_MISSING = -2  # patch not found in the index

    def __init__(
        self, patch_data, patch_coords, patch_times,
        zero_patch_coords, zero_patch_times,
        interval=timedelta(minutes=5),
        box_size=(4,4), zero_value=0,
        missing_value=0
    ):
        """
        Parameters:
            patch_data: array of patches, shape (num_patches, ph, pw).
            patch_coords: (num_patches, 2) patch (row, col) coordinates.
            patch_times: per-patch time indices.
            zero_patch_coords, zero_patch_times: coordinates/times of
                patches known to be entirely zero (data not stored).
            interval: time resolution corresponding to one time index.
            box_size: sample size in patches (rows, cols).
            zero_value: fill value for all-zero patches.
            missing_value: fill value for missing patches.
        """
        self.dt = int(round(interval.total_seconds()))
        self.box_size = box_size
        self.zero_value = zero_value
        self.missing_value = missing_value
        self.patch_data = patch_data
        # full sample size in pixels
        self.sample_shape = (
            self.patch_data.shape[1]*box_size[0],
            self.patch_data.shape[2]*box_size[1]
        )
        # numba typed dict: (t, i, j) -> row index into patch_data
        self.patch_index = Dict.empty(
            key_type=types.UniTuple(types.int64, 3),
            value_type=types.int64
        )
        init_patch_index(self.patch_index, patch_coords, patch_times)
        init_patch_index_zero(self.patch_index, zero_patch_coords,
            zero_patch_times, PatchIndex.IDX_ZERO)
        self._batch = None

    def _alloc_batch(self, batch_size, num_timesteps):
        """Return a cached output buffer, reallocating only when the
        request exceeds the current buffer in either dimension."""
        needs_rebuild = (self._batch is None) or \
            (self._batch.shape[0] < batch_size) or \
            (self._batch.shape[1] < num_timesteps)
        if needs_rebuild:
            del self._batch
            self._batch = np.zeros(
                (batch_size,num_timesteps)+self.sample_shape,
                self.patch_data.dtype
            )
        return self._batch

    def __call__(self, t, i0_all, j0_all):
        """Assemble sample boxes for times `t` (shape (batch, timesteps))
        anchored at patch coordinates (i0_all, j0_all).

        Returns an array of shape (batch, timesteps) + sample_shape.
        """
        batch = self._alloc_batch(*t.shape)
        i1_all = i0_all + self.box_size[0]
        j1_all = j0_all + self.box_size[1]
        bi_size = self.patch_data.shape[1]
        bj_size = self.patch_data.shape[2]
        build_batch(batch, self.patch_data, self.patch_index,
            t, i0_all, i1_all, j0_all, j1_all,
            bi_size, bj_size, self.zero_value,
            self.missing_value)
        # slice BOTH dims: the cached buffer may be larger than this
        # request, and rows/timesteps beyond it hold stale data
        return batch[:t.shape[0],:t.shape[1],...]
@njit
def init_patch_index(patch_index, patch_coords, patch_times):
    # Register each stored patch in the typed dict under its
    # (time, row, col) key, mapping to its row index in the patch data.
    n_patches = patch_coords.shape[0]
    for idx in range(n_patches):
        key = (
            patch_times[idx],
            np.int64(patch_coords[idx,0]),
            np.int64(patch_coords[idx,1])
        )
        patch_index[key] = idx
@njit
def init_patch_index_zero(patch_index, zero_patch_coords,
                          zero_patch_times, idx_zero):
    # Mark all-zero patches with the sentinel idx_zero instead of a
    # row index into the patch data (their content is not stored).
    n_zero = zero_patch_coords.shape[0]
    for idx in range(n_zero):
        key = (
            zero_patch_times[idx],
            np.int64(zero_patch_coords[idx,0]),
            np.int64(zero_patch_coords[idx,1])
        )
        patch_index[key] = idx_zero
# numba can't find these values from PatchIndex
# (njit-compiled code cannot resolve class attributes, so the sentinel
# indices are mirrored as module-level constants for use in build_batch)
IDX_ZERO = PatchIndex.IDX_ZERO
IDX_MISSING = PatchIndex.IDX_MISSING
@njit(parallel=True)
def build_batch(
    batch, patch_data, patch_index,
    t_all, i0_all, i1_all, j0_all, j1_all,
    bi_size, bj_size, zero_value, missing_value
):
    """Fill `batch` in place with sample boxes assembled from patches.

    For each sample k and timestep bt, every patch position (i, j) in
    the box [i0,i1) x [j0,j1) is looked up in patch_index: a nonnegative
    index copies that patch from patch_data, IDX_ZERO fills the slot
    with zero_value, and absent keys (IDX_MISSING) fill it with
    missing_value. bi_size/bj_size are the patch height/width in pixels.
    """
    for k in prange(t_all.shape[0]):  # samples processed in parallel
        i0 = i0_all[k]
        i1 = i1_all[k]
        j0 = j0_all[k]
        j1 = j1_all[k]
        for (bt,t) in enumerate(t_all[k,:]):
            for i in range(i0, i1):
                # pixel row range of this patch within the sample box
                bi0 = (i-i0) * bi_size
                bi1 = bi0 + bi_size
                for j in range(j0, j1):
                    # IDX_MISSING is the default for keys not indexed
                    ind = int(patch_index.get((t,i,j), IDX_MISSING))
                    # pixel column range of this patch within the box
                    bj0 = (j-j0) * bj_size
                    bj1 = bj0 + bj_size
                    if ind >= 0:
                        batch[k,bt,bi0:bi1,bj0:bj1] = patch_data[ind]
                    elif ind == IDX_ZERO:
                        batch[k,bt,bi0:bi1,bj0:bj1] = zero_value
                    elif ind == IDX_MISSING:
                        batch[k,bt,bi0:bi1,bj0:bj1] = missing_value
class ForecastPatchIndexWrapper(PatchIndex):
    """Presents several per-lag forecast PatchIndex objects (keyed
    "<base>-<lag_hours>") as a single index for the base variable,
    reading all requested times of a sample from the same forecast run.

    Fix over the saved version: ``__call__`` now slices the batch
    dimension of the cached buffer too, so a request smaller than a
    previous one no longer returns stale rows from the earlier batch.
    """
    def __init__(self, patch_index):
        """
        Parameters:
            patch_index: dict mapping "<base>-<lag_hours>" -> PatchIndex.
                All keys must share the same base name; the lags must be
                evenly spaced, with 24 h a multiple of the spacing.
        """
        self.patch_index = patch_index
        raw_names = {"-".join(v.split("-")[:-1]) for v in patch_index}
        if len(raw_names) != 1:
            raise ValueError(
                "Can only wrap variables with the same base name")
        self.raw_name = list(raw_names)[0]
        lags_hour = [int(v.split("-")[-1]) for v in patch_index]
        self.lags_hour = set(lags_hour)
        forecast_interval_hour = np.diff(sorted(lags_hour))
        if len(set(forecast_interval_hour)) != 1:
            raise ValueError("Lags must be evenly spaced")
        forecast_interval_hour = forecast_interval_hour[0]
        if (24 % forecast_interval_hour):
            raise ValueError(
                "24 hours must be a multiple of the forecast interval")
        self.forecast_interval_hour = forecast_interval_hour
        self.forecast_interval = 3600 * forecast_interval_hour  # seconds
        # need to set these for _alloc_batch to work
        self._batch = None
        v = list(self.patch_index.keys())[0]
        self.sample_shape = self.patch_index[v].sample_shape
        self.patch_data = self.patch_index[v].patch_data

    def __call__(self, t, i0, j0):
        """Assemble sample boxes for times `t` (shape (batch, timesteps)),
        choosing for every time the appropriate lag of the forecast run
        preceding the sample's first time."""
        batch = self._alloc_batch(*t.shape)
        # ensure that all data come from the same forecast: compute each
        # time's offset relative to the forecast run covering t[:, 0]
        t0 = t[:,:1]
        start_time_from_fc = t0 % self.forecast_interval
        time_from_fc = start_time_from_fc + (t - t0)
        lags_hour = (time_from_fc // self.forecast_interval) * \
            self.forecast_interval_hour
        # NOTE(review): a computed lag beyond max(self.lags_hour) matches
        # no iteration below and leaves those slots unfilled — presumably
        # callers keep the requested span within the available lags;
        # TODO confirm.
        for lag in self.lags_hour:
            raw_name_lag = f"{self.raw_name}-{lag}"
            batch_lag = self.patch_index[raw_name_lag](t,i0,j0)
            lag_mask = (lags_hour == lag)
            copy_masked_times(batch_lag, batch, lag_mask)
        # slice BOTH dims: the cached buffer may be larger than this
        # request, and rows/timesteps beyond it hold stale data
        return batch[:t.shape[0],:t.shape[1],...]
@njit(parallel=True)
def copy_masked_times(from_batch, to_batch, mask):
    # Copy only the (sample, timestep) slices selected by mask,
    # leaving the other slots of to_batch untouched.
    n_samples = from_batch.shape[0]
    n_times = from_batch.shape[1]
    for s in prange(n_samples):
        for ts in range(n_times):
            if mask[s,ts]:
                to_batch[s,ts,:,:] = from_batch[s,ts,:,:]