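"""Time-based train/valid/test splitting of patch data and a PyTorch
Lightning DataModule built on the batch generators from the local
`batch` module."""
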
from bisect import bisect_left

import numpy as np
import pytorch_lightning as pl
from torch.utils.data import DataLoader

from . import batch


def get_chunks(
    primary_raw, valid_frac=0.1, test_frac=0.1,
    chunk_seconds=2*24*60*60, random_seed=None
):
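    """Divide the covered time range into contiguous chunks (two days each
    by default) and randomly assign them to the "train"/"valid"/"test"
    splits.

    "patch_times" and "zero_patch_times" in `primary_raw` are assumed to be
    sorted timestamps in seconds. Returns a dict mapping each split name to
    a sorted list of (start, end) chunk limits; times beyond the last chunk
    boundary are not assigned to any split.
    """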
    t0 = min(
        primary_raw["patch_times"][0],
        primary_raw["zero_patch_times"][0]
    )
    t1 = max(
        primary_raw["patch_times"][-1],
        primary_raw["zero_patch_times"][-1]
    ) + 1
    rng = np.random.RandomState(seed=random_seed)
    chunk_limits = np.arange(t0, t1, chunk_seconds)
    num_chunks = len(chunk_limits) - 1
    chunk_ind = np.arange(num_chunks)
    rng.shuffle(chunk_ind)
    i_valid = int(round(num_chunks * valid_frac))
    i_test = i_valid + int(round(num_chunks * test_frac))
    chunk_ind = {
        "valid": chunk_ind[:i_valid],
        "test": chunk_ind[i_valid:i_test],
        "train": chunk_ind[i_test:]
    }

    def get_chunk_limits(chunk_ind_split):
        return sorted(
            (chunk_limits[i], chunk_limits[i+1])
            for i in chunk_ind_split
        )

    chunks = {
        split: get_chunk_limits(chunk_ind_split)
        for (split, chunk_ind_split) in chunk_ind.items()
    }
    return chunks


def train_valid_test_split(
    raw_data, primary_raw_var, chunks=None, **kwargs
):
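    """Split every variable in `raw_data` into train/valid/test parts using
    time chunks computed from `raw_data[primary_raw_var]` (or the `chunks`
    passed in).

    Returns `(split_raw_data, chunks)`; the returned `chunks` can be passed
    back in to reproduce the same split.
    """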
    if chunks is None:
        primary = raw_data[primary_raw_var]
        chunks = get_chunks(primary, **kwargs)

    def split_chunks_from_array(x, chunks_split, times):
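        # Concatenate the rows of `x` whose timestamps (assumed sorted)
        # fall inside any of the (t0, t1) intervals of this split.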
        n = 0
        chunk_ind = []
        for (t0, t1) in chunks_split:
            k0 = bisect_left(times, t0)
            k1 = bisect_left(times, t1)
            n += k1 - k0
            chunk_ind.append((k0, k1))
        shape = (n,) + x.shape[1:]
        x_chunk = np.empty_like(x, shape=shape)
        j0 = 0
        for (k0, k1) in chunk_ind:
            j1 = j0 + (k1-k0)
            x_chunk[j0:j1, ...] = x[k0:k1, ...]
            j0 = j1
        return x_chunk

    split_raw_data = {
        split: {var: {} for var in raw_data}
        for split in chunks
    }
    for (var, raw_data_var) in raw_data.items():
        for (split, chunks_split) in chunks.items():
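            # The per-array chunk extraction below is commented out, so
            # every key currently falls through to the unsplit copy at
            # the end of this loop.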
            #split_raw_data[split][var]["patches"] = \
            #    split_chunks_from_array(
            #        raw_data_var["patches"], chunks_split,
            #        raw_data_var["patch_times"]
            #    )
            #split_raw_data[split][var]["patch_coords"] = \
            #    split_chunks_from_array(
            #        raw_data_var["patch_coords"], chunks_split,
            #        raw_data_var["patch_times"]
            #    )
            #split_raw_data[split][var]["patch_times"] = \
            #    split_chunks_from_array(
            #        raw_data_var["patch_times"], chunks_split,
            #        raw_data_var["patch_times"]
            #    )
            #split_raw_data[split][var]["zero_patch_coords"] = \
            #    split_chunks_from_array(
            #        raw_data_var["zero_patch_coords"], chunks_split,
            #        raw_data_var["zero_patch_times"]
            #    )
            #split_raw_data[split][var]["zero_patch_times"] = \
            #    split_chunks_from_array(
            #        raw_data_var["zero_patch_times"], chunks_split,
            #        raw_data_var["zero_patch_times"]
            #    )
            added_keys = set(split_raw_data[split][var].keys())
            missing_keys = set(raw_data[var].keys()) - added_keys
            for k in missing_keys:
                split_raw_data[split][var][k] = raw_data[var][k]

    return (split_raw_data, chunks)


class DataModule(pl.LightningDataModule):
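    """LightningDataModule that wraps one `batch.BatchGenerator` per data
    split and exposes the corresponding train/valid/test dataloaders."""
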
    def __init__(
        self,
        variables, raw, predictors, target, primary_var,
        sampling_bins, sampler_file,
        batch_size=8,
        train_epoch_size=10, valid_epoch_size=2, test_epoch_size=10,
        valid_seed=None, test_seed=None,
        **kwargs
    ):
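        """Build one batch generator and one dataset per split.

        `raw` maps split names ("train"/"valid"/"test") to the raw data for
        that split, and `sampler_file` is assumed to be a dict keyed the
        same way. Data augmentation is enabled only for the training split.
        """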
        super().__init__()
        self.batch_gen = {
            split: batch.BatchGenerator(
                variables, raw_var, predictors, target, primary_var,
                sampling_bins=sampling_bins, batch_size=batch_size,
                sampler_file=sampler_file.get(split),
                augment=(split == "train"),
                **kwargs
            )
            for (split, raw_var) in raw.items()
        }
        self.datasets = {}
        if "train" in self.batch_gen:
            self.datasets["train"] = batch.StreamBatchDataset(
                self.batch_gen["train"], train_epoch_size
            )
        if "valid" in self.batch_gen:
            self.datasets["valid"] = batch.DeterministicBatchDataset(
                self.batch_gen["valid"], valid_epoch_size, random_seed=valid_seed
            )
        if "test" in self.batch_gen:
            self.datasets["test"] = batch.DeterministicBatchDataset(
                self.batch_gen["test"], test_epoch_size, random_seed=test_seed
            )

    def dataloader(self, split):
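        # The datasets yield ready-made batches, so automatic batching in
        # the DataLoader is disabled with batch_size=None.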
        return DataLoader(
            self.datasets[split], batch_size=None,
            pin_memory=True, num_workers=0
        )

    def train_dataloader(self):
        return self.dataloader("train")

    def val_dataloader(self):
        return self.dataloader("valid")

    def test_dataloader(self):
        return self.dataloader("test")