|
|
from bisect import bisect_left
|
|
|
|
|
|
import numpy as np
|
|
|
import pytorch_lightning as pl
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
from . import batch
|
|
|
|
|
|
|
|
|
def get_chunks(
    primary_raw, valid_frac=0.1, test_frac=0.1,
    chunk_seconds=2*24*60*60, random_seed=None
):
    """Randomly assign consecutive time chunks to valid/test/train splits.

    The time span covered by ``primary_raw["patch_times"]`` and
    ``primary_raw["zero_patch_times"]`` is cut into consecutive half-open
    intervals ``[start, start + chunk_seconds)``, with the last interval
    extending past the final timestamp so that every timestamp falls in
    exactly one chunk. The chunk order is shuffled. The first
    ``round(n * valid_frac)`` chunks go to "valid", the next
    ``round(n * test_frac)`` go to "test", and the rest go to "train".

    Parameters
    ----------
    primary_raw : mapping
        Must contain "patch_times" and "zero_patch_times". Each is assumed
        to be sorted ascending, since its first/last elements are used as
        the extremes. Empty arrays are ignored when finding the span.
    valid_frac, test_frac : float
        Fractions of the chunks assigned to the validation and test splits.
    chunk_seconds : int or float
        Chunk length, in the same units as the timestamps. The default,
        2*24*60*60, is two days if the timestamps are in seconds.
    random_seed : int, optional
        Seed for ``np.random.RandomState``; makes the assignment reproducible.

    Returns
    -------
    dict
        Maps "valid", "test" and "train" to a sorted list of ``(start, end)``
        chunk limits. All three lists are empty if both time arrays are empty.
    """
    time_arrays = [
        primary_raw[key]
        for key in ("patch_times", "zero_patch_times")
        if len(primary_raw[key]) > 0
    ]
    if not time_arrays:
        return {"valid": [], "test": [], "train": []}

    t0 = min(times[0] for times in time_arrays)
    # +1 so the last timestamp lies inside the half-open span [t0, t1).
    t1 = max(times[-1] for times in time_arrays) + 1

    # Round the chunk count up. Using np.arange(t0, t1, chunk_seconds) alone
    # omits the closing boundary and silently drops the trailing partial
    # chunk, including the final timestamp.
    num_chunks = int(-(-(t1 - t0) // chunk_seconds))
    chunk_limits = t0 + chunk_seconds * np.arange(num_chunks + 1)

    rng = np.random.RandomState(seed=random_seed)
    order = np.arange(num_chunks)
    rng.shuffle(order)
    i_valid = int(round(num_chunks * valid_frac))
    i_test = i_valid + int(round(num_chunks * test_frac))
    split_ind = {
        "valid": order[:i_valid],
        "test": order[i_valid:i_test],
        "train": order[i_test:],
    }

    # Sort each split's chunks chronologically; downstream bisect-based
    # slicing then emits rows in time order.
    return {
        split: sorted((chunk_limits[i], chunk_limits[i+1]) for i in ind)
        for (split, ind) in split_ind.items()
    }
|
|
|
|
|
|
|
|
|
def train_valid_test_split(
    raw_data, primary_raw_var, chunks=None, **kwargs
):
    """Distribute each variable's raw entries into per-split dicts.

    Parameters
    ----------
    raw_data : dict
        Maps a variable name to a dict of that variable's raw entries.
        NOTE(review): presumably each inner dict holds time-indexed arrays
        such as "patch_times"/"zero_patch_times" (the keys get_chunks reads
        from the primary variable) — confirm with callers.
    primary_raw_var : str
        Key of ``raw_data`` whose entry is passed to ``get_chunks`` when
        ``chunks`` is None.
    chunks : dict, optional
        Maps a split name to a list of ``(t0, t1)`` time limits. If None, it
        is computed as ``get_chunks(raw_data[primary_raw_var], **kwargs)``.
    **kwargs
        Forwarded to ``get_chunks``; only used when ``chunks`` is None.

    Returns
    -------
    tuple
        ``(split_raw_data, chunks)``. ``split_raw_data[split][var]`` is a dict
        of entries for that split and variable; ``chunks`` is the chunk
        mapping actually used.
    """
    if chunks is None:
        primary = raw_data[primary_raw_var]
        chunks = get_chunks(primary, **kwargs)

    def split_chunks_from_array(x, chunks_split, times):
        """Concatenate the rows of ``x`` whose times fall in the given chunks.

        For each ``(t0, t1)`` in ``chunks_split``, in that order, selects the
        index range ``k0:k1`` with ``t0 <= times[k] < t1`` and copies
        ``x[k0:k1]`` into a newly allocated array with the dtype of ``x``.
        Assumes ``times`` is sorted ascending (``bisect_left`` requires it)
        and is aligned with the first axis of ``x`` — TODO confirm at call
        sites.
        """
        # First pass: find each chunk's index range and the total row count,
        # so the output can be allocated once.
        n = 0
        chunk_ind = []
        for (t0,t1) in chunks_split:
            k0 = bisect_left(times, t0)
            k1 = bisect_left(times, t1)
            n += k1 - k0
            chunk_ind.append((k0,k1))

        shape = (n,) + x.shape[1:]
        # Uninitialized, same dtype as x; every row is written below.
        x_chunk = np.empty_like(x, shape=shape)

        # Second pass: copy each chunk's rows contiguously into the output.
        j0 = 0
        for (k0,k1) in chunk_ind:
            j1 = j0 + (k1-k0)
            x_chunk[j0:j1,...] = x[k0:k1,...]
            j0 = j1

        return x_chunk

    # One empty entry dict per (split, variable) pair.
    split_raw_data = {
        split: {var: {} for var in raw_data}
        for split in chunks
    }

    for (var, raw_data_var) in raw_data.items():
        for (split, chunks_split) in chunks.items():
            # NOTE(review): as written, nothing in this loop calls
            # split_chunks_from_array and raw_data_var is unused. So
            # added_keys is always empty, and every split receives all of
            # raw_data[var]'s entries unsplit. Train, valid and test then
            # share identical data. The code that splits the time-indexed
            # arrays appears to be missing here — verify.
            added_keys = set(split_raw_data[split][var].keys())
            missing_keys = set(raw_data[var].keys()) - added_keys
            # Entries not split above are shared across all splits
            # (same objects, not copies).
            for k in missing_keys:
                split_raw_data[split][var][k] = raw_data[var][k]

    return (split_raw_data, chunks)
|
|
|
|
|
|
|
|
|
class DataModule(pl.LightningDataModule):
    """Lightning data module with one batch generator per data split.

    For every ``(split, raw_var)`` pair in ``raw`` a ``batch.BatchGenerator``
    is built, with augmentation enabled only for the "train" split.
    Datasets are created for whichever of "train", "valid" and "test" are
    present:

    * "train" -> ``batch.StreamBatchDataset`` of ``train_epoch_size``.
    * "valid"/"test" -> ``batch.DeterministicBatchDataset`` with the
      corresponding epoch size and seed.
    """

    def __init__(
        self,
        variables, raw, predictors, target, primary_var,
        sampling_bins, sampler_file,
        batch_size=8,
        train_epoch_size=10, valid_epoch_size=2, test_epoch_size=10,
        valid_seed=None, test_seed=None,
        **kwargs
    ):
        """Build the per-split generators and datasets.

        ``raw`` maps a split name to that split's raw data. ``sampler_file``
        is looked up per split with ``.get``, so splits without an entry
        pass ``None`` to the generator. Extra ``kwargs`` are forwarded to
        every ``batch.BatchGenerator``.
        """
        super().__init__()

        generators = {}
        for split_name, split_raw in raw.items():
            generators[split_name] = batch.BatchGenerator(
                variables, split_raw, predictors, target, primary_var,
                sampling_bins=sampling_bins,
                batch_size=batch_size,
                sampler_file=sampler_file.get(split_name),
                augment=(split_name == "train"),
                **kwargs
            )
        self.batch_gen = generators

        datasets = {}
        if "train" in generators:
            datasets["train"] = batch.StreamBatchDataset(
                generators["train"], train_epoch_size
            )
        # Evaluation splits use deterministic datasets with their own seeds.
        evaluation_specs = (
            ("valid", valid_epoch_size, valid_seed),
            ("test", test_epoch_size, test_seed),
        )
        for split_name, epoch_size, seed in evaluation_specs:
            if split_name not in generators:
                continue
            datasets[split_name] = batch.DeterministicBatchDataset(
                generators[split_name], epoch_size, random_seed=seed
            )
        self.datasets = datasets

    def dataloader(self, split):
        """Return a single-process DataLoader over ``self.datasets[split]``.

        ``batch_size=None`` disables DataLoader's automatic batching;
        presumably the datasets already yield complete batches.
        """
        dataset = self.datasets[split]
        return DataLoader(
            dataset,
            batch_size=None,
            num_workers=0,
            pin_memory=True,
        )

    def train_dataloader(self):
        """DataLoader for the "train" split."""
        return self.dataloader("train")

    def val_dataloader(self):
        """DataLoader for the "valid" split."""
        return self.dataloader("valid")

    def test_dataloader(self):
        """DataLoader for the "test" split."""
        return self.dataloader("test")
|
|
|
|