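"""Chunk-based train/valid/test splitting of raw patch data and a PyTorch
Lightning DataModule that wraps the per-split batch generators."""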
from bisect import bisect_left
import numpy as np
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from . import batch
def get_chunks(
primary_raw, valid_frac=0.1, test_frac=0.1,
chunk_seconds=2*24*60*60, random_seed=None
):
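    """Divide the covered time range into fixed-length chunks and randomly
    assign them to the train/valid/test splits.

    Chunks are `chunk_seconds` long (default: two days). Returns a dict
    mapping each split name to a sorted list of (start, end) time limits;
    `random_seed` makes the assignment reproducible.
    """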
t0 = min(
primary_raw["patch_times"][0],
primary_raw["zero_patch_times"][0]
)
t1 = max(
primary_raw["patch_times"][-1],
primary_raw["zero_patch_times"][-1]
    ) + 1
rng = np.random.RandomState(seed=random_seed)
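    # Fixed chunk boundaries over [t0, t1); note that any trailing partial
    # chunk beyond the last full boundary is excluded from every split.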
    chunk_limits = np.arange(t0, t1, chunk_seconds)
    num_chunks = len(chunk_limits) - 1
chunk_ind = np.arange(num_chunks)
rng.shuffle(chunk_ind)
i_valid = int(round(num_chunks * valid_frac))
i_test = i_valid + int(round(num_chunks * test_frac))
chunk_ind = {
"valid": chunk_ind[:i_valid],
"test": chunk_ind[i_valid:i_test],
"train": chunk_ind[i_test:]
}
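    # Convert each split's shuffled chunk indices back to time-ordered
    # (start, end) limit pairs.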
def get_chunk_limits(chunk_ind_split):
return sorted(
(chunk_limits[i], chunk_limits[i+1])
for i in chunk_ind_split
)
chunks = {
split: get_chunk_limits(chunk_ind_split)
for (split, chunk_ind_split) in chunk_ind.items()
}
return chunks
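
# A minimal, illustrative use of get_chunks on synthetic timestamps (the
# `times` array below is an assumption for demonstration, not package data):
#
#     times = np.arange(0, 100 * 24 * 60 * 60, 3600)  # 100 days, hourly
#     primary = {"patch_times": times, "zero_patch_times": times}
#     chunks = get_chunks(primary, valid_frac=0.1, test_frac=0.1,
#         random_seed=1234)
#     # chunks["train"] / ["valid"] / ["test"]: sorted (start, end) limits
#     # of disjoint two-day windows
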
def train_valid_test_split(
raw_data, primary_raw_var, chunks=None, **kwargs
):
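    """Split each variable's raw data into train/valid/test sets.

    If `chunks` is None, chunk limits are computed from the primary
    variable with get_chunks (extra keyword arguments are forwarded to it).
    Returns a (split_raw_data, chunks) tuple.
    """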
if chunks is None:
primary = raw_data[primary_raw_var]
chunks = get_chunks(primary, **kwargs)
def split_chunks_from_array(x, chunks_split, times):
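        """Gather the rows of `x` whose timestamps fall inside the given
        chunk limits; `times` must be sorted in ascending order for the
        bisection below to be valid."""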
n = 0
chunk_ind = []
        for (t0, t1) in chunks_split:
            # index range [k0, k1) of the timestamps falling in this chunk
            k0 = bisect_left(times, t0)
            k1 = bisect_left(times, t1)
            n += k1 - k0
            chunk_ind.append((k0, k1))
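        # allocate the output once, then copy each chunk's rows contiguously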
shape = (n,) + x.shape[1:]
x_chunk = np.empty_like(x, shape=shape)
j0 = 0
        for (k0, k1) in chunk_ind:
            j1 = j0 + (k1 - k0)
            x_chunk[j0:j1, ...] = x[k0:k1, ...]
            j0 = j1
return x_chunk
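    # one entry per split, each mapping variable name -> dict of arrays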
split_raw_data = {
split: {var: {} for var in raw_data}
for split in chunks
}
for (var, raw_data_var) in raw_data.items():
for (split, chunks_split) in chunks.items():
            # NOTE: per-array chunk extraction is currently disabled, so each
            # split receives the full arrays below and only the chunk limits
            # in `chunks` differ between splits. The disabled logic, kept for
            # reference, split every array along its time axis:
            #
            # for (key, time_key) in [
            #     ("patches", "patch_times"),
            #     ("patch_coords", "patch_times"),
            #     ("patch_times", "patch_times"),
            #     ("zero_patch_coords", "zero_patch_times"),
            #     ("zero_patch_times", "zero_patch_times"),
            # ]:
            #     split_raw_data[split][var][key] = split_chunks_from_array(
            #         raw_data_var[key], chunks_split, raw_data_var[time_key]
            #     )
            #
            # Pass through every array that was not split above
            # (currently all of them):
added_keys = set(split_raw_data[split][var].keys())
missing_keys = set(raw_data[var].keys()) - added_keys
for k in missing_keys:
split_raw_data[split][var][k] = raw_data[var][k]
return (split_raw_data, chunks)
class DataModule(pl.LightningDataModule):
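    """LightningDataModule wrapping one batch.BatchGenerator per data split.

    Training batches are streamed with augmentation enabled; validation and
    test batches are generated deterministically from fixed random seeds.
    """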
def __init__(
self,
variables, raw, predictors, target, primary_var,
sampling_bins, sampler_file,
batch_size=8,
train_epoch_size=10, valid_epoch_size=2, test_epoch_size=10,
valid_seed=None, test_seed=None,
**kwargs
):
super().__init__()
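        # one batch generator per split; augmentation only during training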
self.batch_gen = {
split: batch.BatchGenerator(
variables, raw_var, predictors, target, primary_var,
sampling_bins=sampling_bins, batch_size=batch_size,
sampler_file=sampler_file.get(split),
                augment=(split == "train"),
                **kwargs
            )
            for (split, raw_var) in raw.items()
}
self.datasets = {}
if "train" in self.batch_gen:
self.datasets["train"] = batch.StreamBatchDataset(
self.batch_gen["train"], train_epoch_size
)
if "valid" in self.batch_gen:
self.datasets["valid"] = batch.DeterministicBatchDataset(
self.batch_gen["valid"], valid_epoch_size, random_seed=valid_seed
)
if "test" in self.batch_gen:
self.datasets["test"] = batch.DeterministicBatchDataset(
self.batch_gen["test"], test_epoch_size, random_seed=test_seed
)
def dataloader(self, split):
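        # batch_size=None: the datasets already yield fully formed batches,
        # so the DataLoader must not collate them further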
return DataLoader(
self.datasets[split], batch_size=None,
pin_memory=True, num_workers=0
)
def train_dataloader(self):
return self.dataloader("train")
def val_dataloader(self):
return self.dataloader("valid")
def test_dataloader(self):
return self.dataloader("test")
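
# Hypothetical construction sketch. Every name below (`raw_data`,
# `variables`, `bins`, `model`, the "precip" variable) is a placeholder
# assumption, not part of this package:
#
#     (split_raw, chunks) = train_valid_test_split(
#         raw_data, primary_raw_var="precip", random_seed=1234
#     )
#     dm = DataModule(
#         variables, split_raw, predictors=["precip"], target="precip",
#         primary_var="precip", sampling_bins=bins,
#         sampler_file={"train": None, "valid": None, "test": None},
#         batch_size=8
#     )
#     trainer = pl.Trainer(max_epochs=10)
#     trainer.fit(model, datamodule=dm)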