| | import logging |
| | import math |
| | from collections import Counter |
| | from pathlib import Path |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from joblib import Parallel, delayed |
| | from torch.quasirandom import SobolEngine |
| |
|
| | from src.gift_eval.data import Dataset |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
def find_consecutive_nan_lengths(series: np.ndarray) -> list[int]:
    """Finds the lengths of all consecutive NaN blocks in a 1D array."""
    # Multi-dimensional inputs are treated as one long 1D sequence.
    flat = series.flatten() if series.ndim > 1 else series

    # Pad with non-NaN sentinels so runs touching either edge are closed,
    # then locate run boundaries via the sign of the first difference:
    # +1 marks a run start, -1 marks the position just past a run end.
    nan_flags = np.concatenate(([False], np.isnan(flat), [False])).astype(int)
    boundaries = np.diff(nan_flags)

    run_starts = np.flatnonzero(boundaries == 1)
    run_ends = np.flatnonzero(boundaries == -1)

    # Element-wise difference gives the length of each run, in order.
    return (run_ends - run_starts).tolist()
| |
|
| |
|
def analyze_datasets_for_augmentation(gift_eval_path_str: str) -> dict:
    """
    Analyzes all datasets to derive statistics needed for NaN augmentation.
    This version collects the full distribution of NaN ratios.
    """
    logger.info("--- Starting Dataset Analysis for Augmentation (Full Distribution) ---")
    root = Path(gift_eval_path_str)
    if not root.exists():
        raise FileNotFoundError(
            f"Provided raw data path for augmentation analysis does not exist: {gift_eval_path_str}"
        )

    # Discover dataset names. Datasets with frequency subdirectories are
    # registered once per frequency as "<dataset>/<freq>"; flat datasets
    # are registered by their directory name alone.
    dataset_names: list[str] = []
    for dataset_dir in root.iterdir():
        if dataset_dir.name.startswith(".") or not dataset_dir.is_dir():
            continue
        freq_dirs = [child for child in dataset_dir.iterdir() if child.is_dir()]
        if not freq_dirs:
            dataset_names.append(dataset_dir.name)
        else:
            dataset_names.extend(f"{dataset_dir.name}/{freq_dir.name}" for freq_dir in freq_dirs)

    total_series_count = 0
    series_with_nans_count = 0
    nan_ratio_distribution: list[float] = []
    all_consecutive_nan_lengths: Counter = Counter()

    for ds_name in sorted(dataset_names):
        try:
            ds = Dataset(name=ds_name, term="short", to_univariate=False)
            for series_data in ds.training_dataset:
                total_series_count += 1
                target = np.atleast_1d(series_data["target"])
                num_nans = np.isnan(target).sum()
                if num_nans > 0:
                    series_with_nans_count += 1
                    # Record both the per-series NaN fraction and the lengths
                    # of its consecutive NaN runs.
                    nan_ratio_distribution.append(float(num_nans / target.size))
                    all_consecutive_nan_lengths.update(find_consecutive_nan_lengths(target))
        except Exception as e:
            # Best-effort: a single unreadable dataset must not abort the scan.
            logger.warning(f"Could not process {ds_name} for augmentation analysis: {e}")

    if total_series_count == 0:
        raise ValueError("No series were found during augmentation analysis. Check dataset path.")

    p_series_has_nan = series_with_nans_count / total_series_count if total_series_count > 0 else 0

    logger.info("--- Augmentation Analysis Complete ---")
    logger.info(f"Total series analyzed: {total_series_count}")
    logger.info(f"Series with NaNs: {series_with_nans_count} ({p_series_has_nan:.4f})")
    logger.info(f"NaN ratio distribution: {Counter(nan_ratio_distribution)}")
    logger.info(f"Consecutive NaN lengths distribution: {all_consecutive_nan_lengths}")
    logger.info("--- End of Dataset Analysis for Augmentation ---")
    return {
        "p_series_has_nan": p_series_has_nan,
        "nan_ratio_distribution": nan_ratio_distribution,
        "nan_length_distribution": all_consecutive_nan_lengths,
    }
| |
|
| |
|
| | class NanAugmenter: |
| | """ |
| | Applies realistic NaN augmentation by generating and caching NaN patterns on-demand |
| | during the first transform call for a given data shape. |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | p_series_has_nan: float, |
| | nan_ratio_distribution: list[float], |
| | nan_length_distribution: Counter, |
| | num_patterns: int = 100000, |
| | n_jobs: int = -1, |
| | nan_patterns_path: str | None = None, |
| | ): |
| | """ |
| | Initializes the augmenter. NaN patterns are not generated at this stage. |
| | |
| | Args: |
| | p_series_has_nan (float): Probability that a series in a batch will be augmented. |
| | nan_ratio_distribution (List[float]): A list of NaN ratios observed in the dataset. |
| | nan_length_distribution (Counter): A Counter of consecutive NaN block lengths. |
| | num_patterns (int): The number of unique NaN patterns to generate per data shape. |
| | n_jobs (int): The number of CPU cores to use for parallel pattern generation (-1 for all cores). |
| | """ |
| | self.p_series_has_nan = p_series_has_nan |
| | self.nan_ratio_distribution = nan_ratio_distribution |
| | self.num_patterns = num_patterns |
| | self.n_jobs = n_jobs |
| | self.max_length = 2048 |
| | self.nan_patterns_path = nan_patterns_path |
| | |
| | self.pattern_cache: dict[tuple[int, ...], torch.BoolTensor] = {} |
| |
|
| | if not nan_length_distribution or sum(nan_length_distribution.values()) == 0: |
| | self._has_block_distribution = False |
| | logger.warning("NaN length distribution is empty. Augmentation disabled.") |
| | else: |
| | self._has_block_distribution = True |
| | total_blocks = sum(nan_length_distribution.values()) |
| | self.dist_lengths = [int(i) for i in nan_length_distribution.keys()] |
| | self.dist_probs = [count / total_blocks for count in nan_length_distribution.values()] |
| |
|
| | if not self.nan_ratio_distribution: |
| | logger.warning("NaN ratio distribution is empty. Augmentation disabled.") |
| |
|
| | |
| | self._load_existing_patterns() |
| |
|
| | def _load_existing_patterns(self): |
| | """Load existing NaN patterns from disk if they exist.""" |
| | |
| | explicit_path: Path | None = ( |
| | Path(self.nan_patterns_path).resolve() if self.nan_patterns_path is not None else None |
| | ) |
| |
|
| | candidate_files: list[Path] = [] |
| | if explicit_path is not None: |
| | |
| | if explicit_path.is_file(): |
| | candidate_files.append(explicit_path) |
| | |
| | explicit_dir = explicit_path.parent |
| | explicit_dir.mkdir(exist_ok=True, parents=True) |
| | candidate_files.extend(list(explicit_dir.glob(f"nan_patterns_{self.max_length}_*.pt"))) |
| | else: |
| | |
| | data_dir = Path("data") |
| | data_dir.mkdir(exist_ok=True) |
| | candidate_files.extend(list(data_dir.glob(f"nan_patterns_{self.max_length}_*.pt"))) |
| |
|
| | |
| | seen: set[str] = set() |
| | unique_candidates: list[Path] = [] |
| | for f in candidate_files: |
| | key = str(f.resolve()) |
| | if key not in seen: |
| | seen.add(key) |
| | unique_candidates.append(f) |
| |
|
| | for pattern_file in unique_candidates: |
| | try: |
| | |
| | filename = pattern_file.stem |
| | parts = filename.split("_") |
| | if len(parts) >= 4: |
| | num_channels = int(parts[-1]) |
| |
|
| | |
| | patterns = torch.load(pattern_file, map_location="cpu") |
| | cache_key = (self.max_length, num_channels) |
| | self.pattern_cache[cache_key] = patterns |
| |
|
| | logger.info(f"Loaded {patterns.shape[0]} patterns for shape {cache_key} from {pattern_file}") |
| | except (ValueError, RuntimeError, FileNotFoundError) as e: |
| | logger.warning(f"Failed to load patterns from {pattern_file}: {e}") |
| |
|
| | def _get_pattern_file_path(self, num_channels: int) -> Path: |
| | """Resolve the target file path for storing/loading patterns for a given channel count.""" |
| | |
| | if self.nan_patterns_path is not None: |
| | base_dir = Path(self.nan_patterns_path).resolve().parent |
| | base_dir.mkdir(exist_ok=True, parents=True) |
| | else: |
| | base_dir = Path("data").resolve() |
| | base_dir.mkdir(exist_ok=True, parents=True) |
| |
|
| | return base_dir / f"nan_patterns_{self.max_length}_{num_channels}.pt" |
| |
|
| | def _generate_nan_mask(self, series_shape: tuple[int, ...]) -> np.ndarray: |
| | """Generates a single boolean NaN mask for a given series shape.""" |
| | series_size = int(np.prod(series_shape)) |
| | sampled_ratio = np.random.choice(self.nan_ratio_distribution) |
| | n_nans_to_add = int(round(series_size * sampled_ratio)) |
| |
|
| | if n_nans_to_add == 0: |
| | return np.zeros(series_shape, dtype=bool) |
| |
|
| | mask_flat = np.zeros(series_size, dtype=bool) |
| | nans_added = 0 |
| | max_attempts = n_nans_to_add * 2 |
| | attempts = 0 |
| | while nans_added < n_nans_to_add and attempts < max_attempts: |
| | attempts += 1 |
| | block_length = np.random.choice(self.dist_lengths, p=self.dist_probs) |
| |
|
| | if nans_added + block_length > n_nans_to_add: |
| | block_length = n_nans_to_add - nans_added |
| | if block_length <= 0: |
| | break |
| |
|
| | nan_counts_in_window = np.convolve(mask_flat, np.ones(block_length), mode="valid") |
| | valid_starts = np.where(nan_counts_in_window == 0)[0] |
| |
|
| | if valid_starts.size == 0: |
| | continue |
| |
|
| | start_pos = np.random.choice(valid_starts) |
| | mask_flat[start_pos : start_pos + block_length] = True |
| | nans_added += block_length |
| |
|
| | return mask_flat.reshape(series_shape) |
| |
|
| | def _pregenerate_patterns(self, series_shape: tuple[int, ...]) -> torch.BoolTensor: |
| | """Uses joblib to parallelize the generation of NaN masks for a given shape.""" |
| | if not self._has_block_distribution or not self.nan_ratio_distribution: |
| | return torch.empty(0, *series_shape, dtype=torch.bool) |
| |
|
| | logger.info(f"Generating {self.num_patterns} NaN patterns for shape {series_shape}...") |
| |
|
| | with Parallel(n_jobs=self.n_jobs, backend="loky") as parallel: |
| | masks_list = parallel(delayed(self._generate_nan_mask)(series_shape) for _ in range(self.num_patterns)) |
| |
|
| | logger.info(f"Pattern generation complete for shape {series_shape}.") |
| | return torch.from_numpy(np.stack(masks_list)).bool() |
| |
|
| | def transform(self, time_series_batch: torch.Tensor) -> torch.Tensor: |
| | """ |
| | Applies NaN patterns to a batch, generating them on-demand if the shape is new. |
| | """ |
| | if self.p_series_has_nan == 0: |
| | return time_series_batch |
| |
|
| | history_length, num_channels = time_series_batch.shape[1:] |
| | assert history_length <= self.max_length, ( |
| | f"History length {history_length} exceeds maximum allowed {self.max_length}." |
| | ) |
| |
|
| | |
| | if ( |
| | self.max_length, |
| | num_channels, |
| | ) not in self.pattern_cache: |
| | |
| | target_file = self._get_pattern_file_path(num_channels) |
| | if target_file.exists(): |
| | try: |
| | patterns = torch.load(target_file, map_location="cpu") |
| | self.pattern_cache[(self.max_length, num_channels)] = patterns |
| | logger.info(f"Loaded NaN patterns from {target_file} for shape {(self.max_length, num_channels)}") |
| | except (RuntimeError, FileNotFoundError): |
| | |
| | patterns = self._pregenerate_patterns((self.max_length, num_channels)) |
| | torch.save(patterns, target_file) |
| | self.pattern_cache[(self.max_length, num_channels)] = patterns |
| | logger.info(f"Generated and saved {patterns.shape[0]} NaN patterns to {target_file}") |
| | else: |
| | patterns = self._pregenerate_patterns((self.max_length, num_channels)) |
| | torch.save(patterns, target_file) |
| | self.pattern_cache[(self.max_length, num_channels)] = patterns |
| | logger.info(f"Generated and saved {patterns.shape[0]} NaN patterns to {target_file}") |
| | patterns = self.pattern_cache[(self.max_length, num_channels)][:, :history_length, :] |
| |
|
| | |
| | if patterns.numel() == 0: |
| | return time_series_batch |
| |
|
| | batch_size = time_series_batch.shape[0] |
| | device = time_series_batch.device |
| |
|
| | |
| | augment_mask = torch.rand(batch_size, device=device) < self.p_series_has_nan |
| | indices_to_augment = torch.where(augment_mask)[0] |
| | num_to_augment = indices_to_augment.numel() |
| |
|
| | if num_to_augment == 0: |
| | return time_series_batch |
| |
|
| | |
| | pattern_indices = torch.randint(0, patterns.shape[0], (num_to_augment,), device=device) |
| | |
| | selected_patterns = patterns[pattern_indices].to(device) |
| |
|
| | time_series_batch[indices_to_augment] = time_series_batch[indices_to_augment].masked_fill( |
| | selected_patterns, float("nan") |
| | ) |
| |
|
| | return time_series_batch |
| |
|
| |
|
class CensorAugmenter:
    """
    Applies censor augmentation by clipping values from above, below, or both.
    """

    def __init__(self):
        """Initializes the CensorAugmenter."""
        pass

    def transform(self, time_series_batch: torch.Tensor) -> torch.Tensor:
        """
        Applies a vectorized censor augmentation to a batch of time series.
        """
        n_series, n_steps, n_channels = time_series_batch.shape
        assert n_channels == 1
        series = time_series_batch.squeeze(-1)
        with torch.no_grad():
            n_series, n_steps = series.shape
            dev = series.device

            # Per-series operation: 0 = untouched, 1 = censor from above,
            # 2 = censor from below.
            mode = torch.randint(0, 3, (n_series, 1), device=dev)

            # Two uniform draws per series; the smaller becomes the lower
            # quantile and the larger the upper quantile.
            draw_a = torch.rand(n_series, device=dev)
            draw_b = torch.rand(n_series, device=dev)
            q_lower = torch.minimum(draw_a, draw_b)
            q_upper = torch.maximum(draw_a, draw_b)

            # Empirical quantiles via sorting and index lookup.
            ordered = torch.sort(series, dim=1).values
            low_idx = (q_lower * (n_steps - 1)).long()
            high_idx = (q_upper * (n_steps - 1)).long()
            low_cut = torch.gather(ordered, 1, low_idx.unsqueeze(1))
            high_cut = torch.gather(ordered, 1, high_idx.unsqueeze(1))

            # Select the per-series result in one nested selection; mode 0
            # falls through to the untouched series.
            censored = torch.where(
                mode == 1,
                torch.minimum(series, high_cut),
                torch.where(mode == 2, torch.maximum(series, low_cut), series),
            )

            # Restore the trailing channel dimension.
            return censored.unsqueeze(-1)
| |
|
| |
|
class QuantizationAugmenter:
    """
    Applies non-equidistant quantization using a Sobol sequence to generate
    uniformly distributed levels. This implementation is fully vectorized.
    """

    def __init__(
        self,
        p_quantize: float,
        level_range: tuple[int, int],
        seed: int | None = None,
    ):
        """
        Initializes the augmenter.

        Args:
            p_quantize (float): Probability of applying quantization to a series.
            level_range (Tuple[int, int]): Inclusive range [min, max] to sample the
                number of quantization levels from.
            seed (Optional[int]): Seed for the Sobol sequence generator for reproducibility.
        """
        assert 0.0 <= p_quantize <= 1.0, "Probability must be between 0 and 1."
        assert level_range[0] >= 2, "Minimum number of levels must be at least 2."
        assert level_range[0] <= level_range[1], "Min levels cannot be greater than max."

        self.p_quantize = p_quantize
        self.level_range = level_range

        # Each series always keeps its own min and max as levels, so at most
        # (max_levels - 2) intermediate levels ever need quasi-random draws.
        max_intermediate_levels = self.level_range[1] - 2
        if max_intermediate_levels > 0:
            # One scrambled Sobol engine at the maximum dimensionality; per-call
            # draws are sliced down to the needed number of dimensions.
            self.sobol_engine = SobolEngine(dimension=max_intermediate_levels, scramble=True, seed=seed)
        else:
            # level_range == (2, 2): only min/max levels, no engine needed.
            self.sobol_engine = None

    def transform(self, time_series_batch: torch.Tensor) -> torch.Tensor:
        """
        Applies augmentation in a fully vectorized way on the batch's device.
        Handles input shape (batch, length, 1).
        """
        # Accept (batch, length, 1) or 2D input; remember which so the output
        # can be returned in the caller's layout.
        if time_series_batch.dim() == 3 and time_series_batch.shape[2] == 1:
            is_3d = True
            time_series_squeezed = time_series_batch.squeeze(-1)
        else:
            is_3d = False
            time_series_squeezed = time_series_batch

        # No-op configurations return the input unchanged.
        if self.p_quantize == 0 or self.sobol_engine is None:
            return time_series_batch

        n_series, _ = time_series_squeezed.shape
        device = time_series_squeezed.device

        # Bernoulli selection of which series get quantized.
        augment_mask = torch.rand(n_series, device=device) < self.p_quantize
        n_augment = torch.sum(augment_mask)
        if n_augment == 0:
            return time_series_batch

        series_to_augment = time_series_squeezed[augment_mask]

        # Each selected series gets its own random number of levels in
        # [min_l, max_l]; padding to the batch max is masked out below.
        min_l, max_l = self.level_range
        n_levels_per_series = torch.randint(min_l, max_l + 1, size=(n_augment,), device=device)
        max_levels_in_batch = n_levels_per_series.max().item()

        # Per-series value range; flat (constant) series are handled separately.
        min_vals = torch.amin(series_to_augment, dim=1, keepdim=True)
        max_vals = torch.amax(series_to_augment, dim=1, keepdim=True)
        value_range = max_vals - min_vals
        is_flat = value_range == 0

        # Draw intermediate levels (the levels besides min and max) from the
        # Sobol engine and slice to the number actually needed this call.
        num_intermediate_levels = max_levels_in_batch - 2
        if num_intermediate_levels > 0:
            sobol_points = self.sobol_engine.draw(n_augment).to(device)
            quasi_rand_points = sobol_points[:, :num_intermediate_levels]
        else:
            quasi_rand_points = torch.empty(n_augment, 0, device=device)

        # Map the [0, 1) quasi-random points into each series' value range and
        # assemble the full, sorted level set (min and max always included).
        scaled_quasi_rand_levels = min_vals + value_range * quasi_rand_points
        level_values = torch.cat([min_vals, max_vals, scaled_quasi_rand_levels], dim=1)
        level_values, _ = torch.sort(level_values, dim=1)

        # Distance of every sample to every candidate level:
        # (n_augment, length, 1) vs (n_augment, 1, levels).
        series_expanded = series_to_augment.unsqueeze(2)
        levels_expanded = level_values.unsqueeze(1)
        diff = torch.abs(series_expanded - levels_expanded)

        # Mask out padded levels beyond each series' own level count so argmin
        # never selects them.
        arange_mask = torch.arange(max_levels_in_batch, device=device).unsqueeze(0)
        valid_levels_mask = arange_mask < n_levels_per_series.unsqueeze(1)
        masked_diff = torch.where(valid_levels_mask.unsqueeze(1), diff, float("inf"))
        closest_level_indices = torch.argmin(masked_diff, dim=2)

        # Snap every sample to its nearest valid level.
        quantized_subset = torch.gather(level_values, 1, closest_level_indices)

        # Flat series cannot be quantized meaningfully; keep them unchanged.
        final_subset = torch.where(is_flat, series_to_augment, quantized_subset)

        # Write the augmented series back into a copy of the full batch.
        augmented_batch_squeezed = time_series_squeezed.clone()
        augmented_batch_squeezed[augment_mask] = final_subset

        # Restore the caller's original layout.
        if is_3d:
            return augmented_batch_squeezed.unsqueeze(-1)
        else:
            return augmented_batch_squeezed
| |
|
| |
|
class MixUpAugmenter:
    """
    Applies mixup augmentation by creating a weighted average of multiple time series.

    This version includes an option for time-dependent mixup using Simplex Path
    Interpolation, creating a smooth transition between different mixing weights.
    """

    def __init__(
        self,
        max_n_series_to_combine: int = 10,
        p_combine: float = 0.4,
        p_time_dependent: float = 0.5,
        randomize_k_per_series: bool = True,
        dirichlet_alpha_range: tuple[float, float] = (0.1, 5.0),
    ):
        """
        Initializes the augmenter.

        Args:
            max_n_series_to_combine (int): The maximum number of series to combine.
                The actual number k will be sampled from [2, max].
            p_combine (float): The probability of replacing a series with a combination.
            p_time_dependent (float): The probability of using the time-dependent
                simplex path method for a given mixup operation. Defaults to 0.5.
            randomize_k_per_series (bool): If True, each augmented series will be a
                combination of a different number of series (k).
                If False, one k is chosen for the whole batch.
            dirichlet_alpha_range (Tuple[float, float]): The [min, max] range to sample the
                Dirichlet 'alpha' from. A smaller alpha (e.g., 0.2) creates mixes
                dominated by one series. A larger alpha (e.g., 5.0) creates
                more uniform weights.
        """
        assert max_n_series_to_combine >= 2, "Must combine at least 2 series."
        assert 0.0 <= p_combine <= 1.0, "p_combine must be between 0 and 1."
        assert 0.0 <= p_time_dependent <= 1.0, "p_time_dependent must be between 0 and 1."
        assert dirichlet_alpha_range[0] > 0 and dirichlet_alpha_range[0] <= dirichlet_alpha_range[1]
        self.max_k = max_n_series_to_combine
        self.p_combine = p_combine
        self.p_time_dependent = p_time_dependent
        self.randomize_k = randomize_k_per_series
        self.alpha_range = dirichlet_alpha_range

    def _sample_alpha(self) -> float:
        # Log-uniform sample over alpha_range so small and large concentrations
        # are equally likely. NOTE: uses NumPy's global RNG, not torch's.
        log_alpha_min = math.log10(self.alpha_range[0])
        log_alpha_max = math.log10(self.alpha_range[1])
        log_alpha = log_alpha_min + np.random.rand() * (log_alpha_max - log_alpha_min)
        return float(10**log_alpha)

    def _sample_k(self) -> int:
        # Number of source series to mix, sampled uniformly from [2, max_k].
        return int(torch.randint(2, self.max_k + 1, (1,)).item())

    def _static_mix(
        self,
        source_series: torch.Tensor,
        alpha: float,
        return_weights: bool = False,
    ):
        """Mixes k source series using a single, static set of Dirichlet weights."""
        k = int(source_series.shape[0])
        device = source_series.device
        # One Dirichlet weight per source series; weights sum to 1.
        concentration = torch.full((k,), float(alpha), device=device)
        weights = torch.distributions.Dirichlet(concentration).sample()
        # Broadcast over (k, length, channels), then collapse the k sources.
        weights_view = weights.view(k, 1, 1)
        mixed_series = (source_series * weights_view).sum(dim=0, keepdim=True)
        if return_weights:
            return mixed_series, weights
        return mixed_series

    def _simplex_path_mix(
        self,
        source_series: torch.Tensor,
        alpha: float,
        return_weights: bool = False,
    ):
        """Mixes k series using time-varying weights interpolated along a simplex path."""
        k, length, _ = source_series.shape
        device = source_series.device

        # Sample the two endpoints of the path on the probability simplex.
        concentration = torch.full((k,), float(alpha), device=device)
        dirichlet_dist = torch.distributions.Dirichlet(concentration)
        w_start = dirichlet_dist.sample()
        w_end = dirichlet_dist.sample()

        # Interpolation coefficient over time: 0 -> w_start, 1 -> w_end.
        alpha_ramp = torch.linspace(0, 1, length, device=device)

        # Convex combination at each timestep; columns still sum to 1, so the
        # weights remain on the simplex for every t. Shape: (k, length).
        time_varying_weights = w_start.unsqueeze(1) * (1 - alpha_ramp.unsqueeze(0)) + w_end.unsqueeze(
            1
        ) * alpha_ramp.unsqueeze(0)

        # Broadcast over channels and collapse the k source series.
        weights_view = time_varying_weights.unsqueeze(-1)
        mixed_series = (source_series * weights_view).sum(dim=0, keepdim=True)

        if return_weights:
            return mixed_series, time_varying_weights
        return mixed_series

    def transform(self, time_series_batch: torch.Tensor, return_debug_info: bool = False):
        """
        Applies the mixup augmentation, randomly choosing between static and
        time-dependent mixing methods.
        """
        with torch.no_grad():
            if self.p_combine == 0:
                return (time_series_batch, {}) if return_debug_info else time_series_batch

            batch_size, _, _ = time_series_batch.shape
            device = time_series_batch.device

            # Not enough distinct candidates to draw up to max_k sources from.
            if batch_size <= self.max_k:
                return (time_series_batch, {}) if return_debug_info else time_series_batch

            # Decide which rows of the batch get replaced by a mixture.
            augment_mask = torch.rand(batch_size, device=device) < self.p_combine
            indices_to_replace = torch.where(augment_mask)[0]
            n_augment = indices_to_replace.numel()

            if n_augment == 0:
                return (time_series_batch, {}) if return_debug_info else time_series_batch

            # Per-series (or single batch-wide) number of sources to combine.
            if self.randomize_k:
                k_values = torch.randint(2, self.max_k + 1, (n_augment,), device=device)
            else:
                k = self._sample_k()
                k_values = torch.full((n_augment,), k, device=device)

            # Build each replacement series one at a time.
            new_series_list = []
            all_batch_indices = torch.arange(batch_size, device=device)
            debug_info = {}

            for i, target_idx in enumerate(indices_to_replace):
                current_k = k_values[i].item()

                # Sample current_k distinct source rows, excluding the target itself.
                candidate_mask = all_batch_indices != target_idx
                candidates = all_batch_indices[candidate_mask]
                perm = torch.randperm(candidates.shape[0], device=device)
                source_indices = candidates[perm[:current_k]]
                source_series = time_series_batch[source_indices]

                alpha = self._sample_alpha()
                mix_type = "static"

                # Randomly pick the mixing strategy for this replacement.
                if torch.rand(1).item() < self.p_time_dependent:
                    mixed_series, weights = self._simplex_path_mix(source_series, alpha=alpha, return_weights=True)
                    mix_type = "simplex"
                else:
                    mixed_series, weights = self._static_mix(source_series, alpha=alpha, return_weights=True)

                new_series_list.append(mixed_series)

                if return_debug_info:
                    debug_info[target_idx.item()] = {
                        "source_indices": source_indices.cpu().numpy(),
                        "weights": weights.cpu().numpy(),
                        "alpha": alpha,
                        "k": current_k,
                        "mix_type": mix_type,
                    }

            # Splice the mixed series back into a copy of the batch.
            augmented_batch = time_series_batch.clone()
            if new_series_list:
                new_series_tensor = torch.cat(new_series_list, dim=0)
                augmented_batch[indices_to_replace] = new_series_tensor

            if return_debug_info:
                return augmented_batch.detach(), debug_info
            return augmented_batch.detach()
| |
|
| |
|
class TimeFlipAugmenter:
    """
    Applies time-reversal augmentation to a random subset of time series in a batch.
    """

    def __init__(self, p_flip: float = 0.5):
        """
        Initializes the TimeFlipAugmenter.

        Args:
            p_flip (float): The probability of flipping a single time series in the batch.
                Defaults to 0.5.
        """
        assert 0.0 <= p_flip <= 1.0, "Probability must be between 0 and 1."
        self.p_flip = p_flip

    def transform(self, time_series_batch: torch.Tensor) -> torch.Tensor:
        """
        Applies time-reversal augmentation to a batch of time series.

        Args:
            time_series_batch (torch.Tensor): The input batch of time series with
                shape (batch_size, seq_len, num_channels).

        Returns:
            torch.Tensor: The batch with some series potentially flipped.
        """
        with torch.no_grad():
            if self.p_flip == 0:
                return time_series_batch

            n_series = time_series_batch.shape[0]
            dev = time_series_batch.device

            # Independent Bernoulli draw per series decides who gets reversed.
            selected = (torch.rand(n_series, device=dev) < self.p_flip).nonzero(as_tuple=True)[0]
            if selected.numel() == 0:
                return time_series_batch

            # Reverse the chosen series along the time axis (dim 1) and write
            # them into a copy so the caller's tensor stays untouched.
            reversed_subset = torch.flip(time_series_batch[selected], dims=[1])
            result = time_series_batch.clone()
            result[selected] = reversed_subset
            return result
| |
|
| |
|
class YFlipAugmenter:
    """
    Applies y-reversal (sign flip) augmentation to a random subset of time series in a batch.
    """

    def __init__(self, p_flip: float = 0.5):
        """
        Initializes the YFlipAugmenter.

        Args:
            p_flip (float): The probability of flipping a single time series in the batch.
                Defaults to 0.5.
        """
        assert 0.0 <= p_flip <= 1.0, "Probability must be between 0 and 1."
        self.p_flip = p_flip

    def transform(self, time_series_batch: torch.Tensor) -> torch.Tensor:
        """
        Applies y-axis reversal (negation) augmentation to a batch of time series.

        Args:
            time_series_batch (torch.Tensor): The input batch of time series with
                shape (batch_size, seq_len, num_channels).

        Returns:
            torch.Tensor: The batch with some series potentially negated.
        """
        with torch.no_grad():
            if self.p_flip == 0:
                return time_series_batch

            batch_size = time_series_batch.shape[0]
            device = time_series_batch.device

            # Independently decide per series whether to negate it.
            flip_mask = torch.rand(batch_size, device=device) < self.p_flip
            indices_to_flip = torch.where(flip_mask)[0]

            if indices_to_flip.numel() == 0:
                return time_series_batch

            series_to_flip = time_series_batch[indices_to_flip]

            # y-flip is a simple sign inversion of the values.
            flipped_series = -series_to_flip

            # Work on a copy so the caller's tensor is left untouched.
            augmented_batch = time_series_batch.clone()
            augmented_batch[indices_to_flip] = flipped_series

            return augmented_batch
| |
|
| |
|
| | class DifferentialAugmenter: |
| | """ |
| | Applies calculus-inspired augmentations. This version includes up to the |
| | fourth derivative and uses nn.Conv1d with built-in 'reflect' padding for |
| | cleaner and more efficient convolutions. |
| | |
| | The Gaussian kernel size and sigma for the initial smoothing are randomly |
| | sampled at every transform() call from user-defined ranges. |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | p_transform: float, |
| | gaussian_kernel_size_range: tuple[int, int] = (5, 51), |
| | gaussian_sigma_range: tuple[float, float] = (2.0, 20.0), |
| | ): |
| | """ |
| | Initializes the augmenter. |
| | |
| | Args: |
| | p_transform (float): The probability of applying an augmentation to any given |
| | time series in a batch. |
| | gaussian_kernel_size_range (Tuple[int, int]): The [min, max] inclusive range |
| | for the Gaussian kernel size. |
| | Sizes will be forced to be odd. |
| | gaussian_sigma_range (Tuple[float, float]): The [min, max] inclusive range |
| | for the Gaussian sigma. |
| | """ |
| | self.p_transform = p_transform |
| | self.kernel_size_range = gaussian_kernel_size_range |
| | self.sigma_range = gaussian_sigma_range |
| |
|
| | |
| | if not (self.kernel_size_range[0] <= self.kernel_size_range[1] and self.kernel_size_range[0] >= 3): |
| | raise ValueError("Invalid kernel size range. Ensure min <= max and min >= 3.") |
| | if not (self.sigma_range[0] <= self.sigma_range[1] and self.sigma_range[0] > 0): |
| | raise ValueError("Invalid sigma range. Ensure min <= max and min > 0.") |
| |
|
| | |
| | self.conv_cache: dict[tuple[int, torch.device], dict[str, nn.Module]] = {} |
| |
|
| | def _create_fixed_kernel_layers(self, num_channels: int, device: torch.device) -> dict: |
| | """ |
| | Creates and configures nn.Conv1d layers for fixed-kernel derivative operations. |
| | These layers are cached to improve performance. |
| | """ |
| | sobel_conv = nn.Conv1d( |
| | in_channels=num_channels, |
| | out_channels=num_channels, |
| | kernel_size=3, |
| | padding="same", |
| | padding_mode="reflect", |
| | groups=num_channels, |
| | bias=False, |
| | device=device, |
| | ) |
| | laplace_conv = nn.Conv1d( |
| | in_channels=num_channels, |
| | out_channels=num_channels, |
| | kernel_size=3, |
| | padding="same", |
| | padding_mode="reflect", |
| | groups=num_channels, |
| | bias=False, |
| | device=device, |
| | ) |
| | d3_conv = nn.Conv1d( |
| | in_channels=num_channels, |
| | out_channels=num_channels, |
| | kernel_size=5, |
| | padding="same", |
| | padding_mode="reflect", |
| | groups=num_channels, |
| | bias=False, |
| | device=device, |
| | ) |
| | d4_conv = nn.Conv1d( |
| | in_channels=num_channels, |
| | out_channels=num_channels, |
| | kernel_size=5, |
| | padding="same", |
| | padding_mode="reflect", |
| | groups=num_channels, |
| | bias=False, |
| | device=device, |
| | ) |
| |
|
| | sobel_kernel = ( |
| | torch.tensor([-1, 0, 1], device=device, dtype=torch.float32).view(1, 1, -1).repeat(num_channels, 1, 1) |
| | ) |
| | laplace_kernel = ( |
| | torch.tensor([1, -2, 1], device=device, dtype=torch.float32).view(1, 1, -1).repeat(num_channels, 1, 1) |
| | ) |
| | d3_kernel = ( |
| | torch.tensor([-1, 2, 0, -2, 1], device=device, dtype=torch.float32) |
| | .view(1, 1, -1) |
| | .repeat(num_channels, 1, 1) |
| | ) |
| | d4_kernel = ( |
| | torch.tensor([1, -4, 6, -4, 1], device=device, dtype=torch.float32) |
| | .view(1, 1, -1) |
| | .repeat(num_channels, 1, 1) |
| | ) |
| |
|
| | sobel_conv.weight.data = sobel_kernel |
| | laplace_conv.weight.data = laplace_kernel |
| | d3_conv.weight.data = d3_kernel |
| | d4_conv.weight.data = d4_kernel |
| |
|
| | for layer in [sobel_conv, laplace_conv, d3_conv, d4_conv]: |
| | layer.weight.requires_grad = False |
| |
|
| | return { |
| | "sobel": sobel_conv, |
| | "laplace": laplace_conv, |
| | "d3": d3_conv, |
| | "d4": d4_conv, |
| | } |
| |
|
| | def _create_gaussian_layer( |
| | self, kernel_size: int, sigma: float, num_channels: int, device: torch.device |
| | ) -> nn.Module: |
| | """Creates a single Gaussian convolution layer with the given dynamic parameters.""" |
| | gauss_conv = nn.Conv1d( |
| | in_channels=num_channels, |
| | out_channels=num_channels, |
| | kernel_size=kernel_size, |
| | padding="same", |
| | padding_mode="reflect", |
| | groups=num_channels, |
| | bias=False, |
| | device=device, |
| | ) |
| | ax = torch.arange( |
| | -(kernel_size // 2), |
| | kernel_size // 2 + 1, |
| | device=device, |
| | dtype=torch.float32, |
| | ) |
| | gauss_kernel = torch.exp(-0.5 * (ax / sigma) ** 2) |
| | gauss_kernel /= gauss_kernel.sum() |
| | gauss_kernel = gauss_kernel.view(1, 1, -1).repeat(num_channels, 1, 1) |
| | gauss_conv.weight.data = gauss_kernel |
| | gauss_conv.weight.requires_grad = False |
| | return gauss_conv |
| |
|
| | def _rescale_signal(self, processed_signal: torch.Tensor, original_signal: torch.Tensor) -> torch.Tensor: |
| | """Rescales the processed signal to match the min/max range of the original.""" |
| | original_min = torch.amin(original_signal, dim=2, keepdim=True) |
| | original_max = torch.amax(original_signal, dim=2, keepdim=True) |
| | processed_min = torch.amin(processed_signal, dim=2, keepdim=True) |
| | processed_max = torch.amax(processed_signal, dim=2, keepdim=True) |
| |
|
| | original_range = original_max - original_min |
| | processed_range = processed_max - processed_min |
| | epsilon = 1e-8 |
| | rescaled_signal = ( |
| | (processed_signal - processed_min) / (processed_range + epsilon) |
| | ) * original_range + original_min |
| | return torch.where(original_range < epsilon, original_signal, rescaled_signal) |
| |
|
    def transform(self, time_series_batch: torch.Tensor) -> torch.Tensor:
        """Applies a random augmentation to a subset of the batch.

        Args:
            time_series_batch: Tensor of shape (batch, seq_len, channels).

        Returns:
            torch.Tensor: Same shape as the input. Each series is, with
            probability ``p_transform``, replaced by one of six transforms
            (Gaussian smoothing, smoothed 1st/2nd/3rd/4th derivatives, or a
            cumulative sum), rescaled back to its original per-channel range;
            the remaining series are returned unchanged.
        """
        with torch.no_grad():
            # Fast path: augmentation disabled entirely.
            if self.p_transform == 0:
                return time_series_batch

            batch_size, seq_len, num_channels = time_series_batch.shape
            device = time_series_batch.device

            # Independently decide, per series, whether to augment it.
            augment_mask = torch.rand(batch_size, device=device) < self.p_transform
            indices_to_augment = torch.where(augment_mask)[0]
            num_to_augment = indices_to_augment.numel()

            if num_to_augment == 0:
                return time_series_batch

            # Draw one shared Gaussian configuration for the whole call;
            # `// 2 * 2 + 1` rounds the sampled width to the nearest odd size.
            min_k, max_k = self.kernel_size_range
            kernel_size = torch.randint(min_k, max_k + 1, (1,)).item()
            kernel_size = kernel_size // 2 * 2 + 1

            min_s, max_s = self.sigma_range
            sigma = (min_s + (max_s - min_s) * torch.rand(1)).item()

            # The Gaussian layer depends on (kernel_size, sigma), so it is
            # rebuilt every call; the fixed derivative kernels are cached per
            # (channels, device) pair.
            gauss_conv = self._create_gaussian_layer(kernel_size, sigma, num_channels, device)

            cache_key = (num_channels, device)
            if cache_key not in self.conv_cache:
                self.conv_cache[cache_key] = self._create_fixed_kernel_layers(num_channels, device)
            fixed_layers = self.conv_cache[cache_key]

            # Conv1d expects (batch, channels, time).
            subset_to_augment = time_series_batch[indices_to_augment]
            subset_permuted = subset_to_augment.permute(0, 2, 1)

            # One of six candidate operations per augmented series.
            op_choices = torch.randint(0, 6, (num_to_augment,), device=device)

            # All candidates are computed for the whole subset up front; the
            # per-series choice is applied afterwards via torch.where.
            # Derivatives are taken on the smoothed signal to suppress noise.
            smoothed_subset = gauss_conv(subset_permuted)
            sobel_on_smoothed = fixed_layers["sobel"](smoothed_subset)
            laplace_on_smoothed = fixed_layers["laplace"](smoothed_subset)
            d3_on_smoothed = fixed_layers["d3"](smoothed_subset)
            d4_on_smoothed = fixed_layers["d4"](smoothed_subset)

            # Each candidate is mapped back into the original per-channel range.
            gauss_result = self._rescale_signal(smoothed_subset, subset_permuted)
            sobel_result = self._rescale_signal(sobel_on_smoothed, subset_permuted)
            laplace_result = self._rescale_signal(laplace_on_smoothed, subset_permuted)
            d3_result = self._rescale_signal(d3_on_smoothed, subset_permuted)
            d4_result = self._rescale_signal(d4_on_smoothed, subset_permuted)

            # Cumulative-sum ("integral") candidate, randomly anchored at the
            # left or the right end of each series.
            use_right_integral = torch.rand(num_to_augment, 1, 1, device=device) > 0.5
            flipped_subset = torch.flip(subset_permuted, dims=[2])
            right_integral = torch.flip(torch.cumsum(flipped_subset, dim=2), dims=[2])
            left_integral = torch.cumsum(subset_permuted, dim=2)
            integral_result = torch.where(use_right_integral, right_integral, left_integral)
            integral_result_normalized = self._rescale_signal(integral_result, subset_permuted)

            # Select the chosen transform per series (0=gauss, 1=sobel,
            # 2=laplace, 3=integral, 4=3rd derivative, 5=4th derivative).
            op_choices_view = op_choices.view(-1, 1, 1)
            augmented_subset = torch.where(op_choices_view == 0, gauss_result, subset_permuted)
            augmented_subset = torch.where(op_choices_view == 1, sobel_result, augmented_subset)
            augmented_subset = torch.where(op_choices_view == 2, laplace_result, augmented_subset)
            augmented_subset = torch.where(op_choices_view == 3, integral_result_normalized, augmented_subset)
            augmented_subset = torch.where(op_choices_view == 4, d3_result, augmented_subset)
            augmented_subset = torch.where(op_choices_view == 5, d4_result, augmented_subset)

            # Back to (batch, time, channels), then write the augmented rows
            # into a copy of the input batch.
            augmented_subset_final = augmented_subset.permute(0, 2, 1)
            augmented_batch = time_series_batch.clone()
            augmented_batch[indices_to_augment] = augmented_subset_final

            return augmented_batch
| |
|
| |
|
class RandomConvAugmenter:
    """
    Applies a stack of 1-to-N random 1D convolutions to a time series batch.

    This augmenter is inspired by the principles of ROCKET and RandConv,
    randomizing nearly every aspect of the convolution process to create a
    highly diverse set of transformations. This version includes multiple
    kernel generation strategies, random padding modes, and optional non-linearities.
    """

    def __init__(
        self,
        p_transform: float = 0.5,
        kernel_size_range: tuple[int, int] = (3, 31),
        dilation_range: tuple[int, int] = (1, 8),
        layer_range: tuple[int, int] = (1, 3),
        sigma_range: tuple[float, float] = (0.5, 5.0),
        bias_range: tuple[float, float] = (-0.5, 0.5),
    ):
        """
        Initializes the augmenter.

        Args:
            p_transform (float): Probability of applying the augmentation to a series.
            kernel_size_range (Tuple[int, int]): [min, max] range for kernel sizes.
                Must be odd numbers.
            dilation_range (Tuple[int, int]): [min, max] range for dilation factors.
            layer_range (Tuple[int, int]): [min, max] range for the number of
                stacked convolution layers.
            sigma_range (Tuple[float, float]): [min, max] range for the sigma of
                Gaussian kernels.
            bias_range (Tuple[float, float]): [min, max] range for the bias term.
        """
        assert kernel_size_range[0] % 2 == 1 and kernel_size_range[1] % 2 == 1, "Kernel sizes must be odd."

        self.p_transform = p_transform
        self.kernel_size_range = kernel_size_range
        self.dilation_range = dilation_range
        self.layer_range = layer_range
        self.sigma_range = sigma_range
        self.bias_range = bias_range
        # Non-zero padding modes used with padding="same"; reflect/circular
        # impose limits on padding vs. sequence length, handled in
        # _apply_random_conv_stack.
        self.padding_modes = ["reflect", "replicate", "circular"]

    def _rescale_signal(self, processed_signal: torch.Tensor, original_signal: torch.Tensor) -> torch.Tensor:
        """Rescales the processed signal to match the min/max range of the original.

        Statistics are taken along the last (time) axis. A (near) constant
        processed signal is replaced by the mean level of the original rather
        than stretched, avoiding a division blow-up.
        """
        original_min = torch.amin(original_signal, dim=-1, keepdim=True)
        original_max = torch.amax(original_signal, dim=-1, keepdim=True)
        processed_min = torch.amin(processed_signal, dim=-1, keepdim=True)
        processed_max = torch.amax(processed_signal, dim=-1, keepdim=True)

        original_range = original_max - original_min
        processed_range = processed_max - processed_min
        epsilon = 1e-8

        is_flat = processed_range < epsilon

        # Affine map: processed -> [0, 1] -> [original_min, original_max].
        rescaled_signal = (
            (processed_signal - processed_min) / (processed_range + epsilon)
        ) * original_range + original_min

        # Degenerate (flat) outputs collapse to the original mean level.
        original_mean = torch.mean(original_signal, dim=-1, keepdim=True)
        flat_rescaled = original_mean.expand_as(original_signal)

        return torch.where(is_flat, flat_rescaled, rescaled_signal)

    def _sample_kernel_weights(self, kernel_size: int, device: torch.device) -> torch.Tensor:
        """Samples a 1D kernel of length ``kernel_size`` with a random strategy.

        Strategies: Gaussian bump, white noise, random quadratic polynomial,
        or a noisy Sobel (first-derivative) kernel zero-padded to size.
        Requires ``kernel_size >= 3`` so the Sobel base always fits.
        """
        weight_type = torch.randint(0, 4, (1,)).item()
        if weight_type == 0:
            # Gaussian bump with a random width.
            s_min, s_max = self.sigma_range
            sigma = (s_min + (s_max - s_min) * torch.rand(1)).item()
            ax = torch.arange(
                -(kernel_size // 2),
                kernel_size // 2 + 1,
                device=device,
                dtype=torch.float32,
            )
            kernel = torch.exp(-0.5 * (ax / sigma) ** 2)
        elif weight_type == 1:
            # Pure white-noise kernel.
            kernel = torch.randn(kernel_size, device=device)
        elif weight_type == 2:
            # Random quadratic polynomial sampled on [-1, 1].
            coeffs = torch.randn(3, device=device)
            x_vals = torch.linspace(-1, 1, kernel_size, device=device)
            kernel = coeffs[0] * x_vals**2 + coeffs[1] * x_vals + coeffs[2]
        else:
            # Noisy Sobel (edge detector), centered and zero-padded to size.
            sobel_base = torch.tensor([-1, 0, 1], dtype=torch.float32, device=device)
            noisy_sobel = sobel_base + torch.randn(3, device=device) * 0.1
            pad_total = kernel_size - 3
            pad_left = pad_total // 2
            pad_right = pad_total - pad_left
            kernel = F.pad(noisy_sobel, (pad_left, pad_right), "constant", 0)

        # Usually L1-normalize so the output magnitude stays comparable.
        if torch.rand(1).item() < 0.8:
            kernel = kernel / (torch.sum(torch.abs(kernel)) + 1e-8)
        return kernel

    def _apply_random_conv_stack(self, series: torch.Tensor) -> torch.Tensor:
        """
        Applies a randomly configured stack of convolutions to a single time series.

        Args:
            series (torch.Tensor): A single time series of shape (1, num_channels, seq_len).

        Returns:
            torch.Tensor: The augmented time series (same shape as the input).
        """
        num_channels = series.shape[1]
        seq_len = series.shape[2]
        device = series.device

        # Too short to pad/convolve meaningfully; return unchanged.
        if seq_len < 2:
            return series

        num_layers = torch.randint(self.layer_range[0], self.layer_range[1] + 1, (1,)).item()

        # "reflect" padding requires each side's padding to be strictly
        # smaller than seq_len, so the dilated (effective) kernel extent must
        # not exceed 2 * seq_len - 1. Without this cap, short series crash.
        max_effective_size = 2 * seq_len - 1

        processed_series = series
        for i in range(num_layers):
            # Random odd kernel size (>= 3 so the Sobel strategy always fits).
            k_min, k_max = self.kernel_size_range
            kernel_size = max(3, torch.randint(k_min // 2, k_max // 2 + 1, (1,)).item() * 2 + 1)

            # Random dilation factor.
            d_min, d_max = self.dilation_range
            dilation = torch.randint(d_min, d_max + 1, (1,)).item()

            # Cap dilation first, then kernel size, so "same" padding stays
            # legal for the non-zero padding modes on short series.
            if dilation * (kernel_size - 1) + 1 > max_effective_size:
                dilation = max(1, (max_effective_size - 1) // max(kernel_size - 1, 1))
            if dilation * (kernel_size - 1) + 1 > max_effective_size:
                kernel_size = max(3, ((max_effective_size - 1) // (2 * dilation)) * 2 + 1)

            # Random bias level.
            b_min, b_max = self.bias_range
            bias_val = (b_min + (b_max - b_min) * torch.rand(1)).item()

            # Random boundary handling.
            padding_mode = np.random.choice(self.padding_modes)

            conv_layer = nn.Conv1d(
                in_channels=num_channels,
                out_channels=num_channels,
                kernel_size=kernel_size,
                dilation=dilation,
                padding="same",
                padding_mode=padding_mode,
                groups=num_channels,
                bias=True,
                device=device,
            )

            # Sample the 1D kernel and install it as frozen depthwise weights.
            kernel = self._sample_kernel_weights(kernel_size, device)
            kernel = kernel.view(1, 1, -1).repeat(num_channels, 1, 1)
            conv_layer.weight.data = kernel
            conv_layer.bias.data.fill_(bias_val)
            conv_layer.weight.requires_grad = False
            conv_layer.bias.requires_grad = False

            processed_series = conv_layer(processed_series)

            # Optional non-linearity between (but not after) conv layers:
            # 0 = identity, 1 = ReLU, 2 = tanh.
            if i < num_layers - 1:
                activation_type = torch.randint(0, 3, (1,)).item()
                if activation_type == 1:
                    processed_series = F.relu(processed_series)
                elif activation_type == 2:
                    processed_series = torch.tanh(processed_series)

        return processed_series

    def transform(self, time_series_batch: torch.Tensor) -> torch.Tensor:
        """Applies a random augmentation to a subset of the batch.

        Args:
            time_series_batch: Tensor of shape (batch, seq_len, channels).

        Returns:
            torch.Tensor: Same shape as the input. Each series is, with
            probability ``p_transform``, passed through an independently
            sampled random conv stack and rescaled to its original
            per-channel range; remaining series are returned unchanged.
        """
        with torch.no_grad():
            # Fast path: augmentation disabled entirely.
            if self.p_transform == 0:
                return time_series_batch

            batch_size, seq_len, num_channels = time_series_batch.shape
            device = time_series_batch.device

            # Independently decide, per series, whether to augment it.
            augment_mask = torch.rand(batch_size, device=device) < self.p_transform
            indices_to_augment = torch.where(augment_mask)[0]
            num_to_augment = indices_to_augment.numel()

            if num_to_augment == 0:
                return time_series_batch

            # Conv1d expects (batch, channels, time).
            subset_permuted = time_series_batch[indices_to_augment].permute(0, 2, 1)

            # Each series gets its own independently sampled conv stack, so
            # the work is done one series at a time.
            augmented_subset_list = []
            for i in range(num_to_augment):
                original_series = subset_permuted[i : i + 1]
                augmented_series = self._apply_random_conv_stack(original_series)

                rescaled_series = self._rescale_signal(augmented_series.squeeze(0), original_series.squeeze(0))
                augmented_subset_list.append(rescaled_series.unsqueeze(0))

            # num_to_augment > 0 here, so the list is never empty.
            augmented_subset = torch.cat(augmented_subset_list, dim=0)
            augmented_batch = time_series_batch.clone()
            augmented_batch[indices_to_augment] = augmented_subset.permute(0, 2, 1)
            return augmented_batch
| |
|