| | import logging |
| |
|
| | import numpy as np |
| | import pandas as pd |
| | from gluonts.model.forecast import QuantileForecast |
| |
|
| | from src.data.frequency import parse_frequency |
| | from src.plotting.plot_timeseries import ( |
| | plot_multivariate_timeseries, |
| | ) |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
| | def _prepare_data_for_plotting(input_data: dict, label_data: dict, max_context_length: int): |
| | history_values = np.asarray(input_data["target"], dtype=np.float32) |
| | future_values = np.asarray(label_data["target"], dtype=np.float32) |
| | start_period = input_data["start"] |
| |
|
| | def ensure_time_first(arr: np.ndarray) -> np.ndarray: |
| | if arr.ndim == 1: |
| | return arr.reshape(-1, 1) |
| | elif arr.ndim == 2: |
| | if arr.shape[0] < arr.shape[1]: |
| | return arr.T |
| | return arr |
| | else: |
| | return arr.reshape(arr.shape[-1], -1).T |
| |
|
| | history_values = ensure_time_first(history_values) |
| | future_values = ensure_time_first(future_values) |
| |
|
| | if max_context_length is not None and history_values.shape[0] > max_context_length: |
| | history_values = history_values[-max_context_length:] |
| |
|
| | |
| | start_timestamp = ( |
| | start_period.to_timestamp() if hasattr(start_period, "to_timestamp") else pd.Timestamp(start_period) |
| | ) |
| | return history_values, future_values, start_timestamp |
| |
|
| |
|
def _extract_quantile_predictions(
    forecast,
) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]:
    """Pull a point forecast and an optional 0.1/0.9 band out of a forecast.

    For a ``QuantileForecast`` the 0.5 quantile is the point forecast and the
    0.1 / 0.9 quantiles form the band. The band is ``None`` when those
    quantiles raise ``KeyError`` or ``ValueError``. For any other object the
    ``samples`` attribute is reduced with a median over axis 0 (presumably
    the sample axis; confirm against the producer) and no band is returned.

    Returns:
        ``(median, lower, upper)``. Each non-``None`` entry is an array with
        at least two dimensions. Extraction is best effort: any failure
        yields ``(None, None, None)`` instead of raising.
    """

    def as_2d(values):
        # None passes through; 1-D becomes a column; anything above 2-D is
        # collapsed onto its first axis.
        if values is None:
            return None
        values = np.asarray(values)
        if values.ndim == 2:
            return values
        if values.ndim == 1:
            return values.reshape(-1, 1)
        return values.reshape(values.shape[0], -1)

    if isinstance(forecast, QuantileForecast):
        try:
            point = forecast.quantile(0.5)
            try:
                low, high = forecast.quantile(0.1), forecast.quantile(0.9)
            except (KeyError, ValueError):
                low = high = None
            return as_2d(point), as_2d(low), as_2d(high)
        except Exception:
            # Something other than a missing quantile went wrong; retry for
            # the median alone before giving up.
            try:
                return as_2d(forecast.quantile(0.5)), None, None
            except Exception:
                return None, None, None

    try:
        draws = forecast.samples
        rank = draws.ndim
        if rank == 1:
            point = draws
        elif rank == 2:
            point = draws[0] if draws.shape[0] == 1 else np.median(draws, axis=0)
        elif rank == 3:
            point = np.median(draws, axis=0)
        else:
            point = draws[0] if len(draws) > 0 else draws
        return as_2d(point), None, None
    except Exception:
        return None, None, None
| |
|
| |
|
def _create_plot(
    input_data: dict,
    label_data: dict,
    forecast,
    dataset_full_name: str,
    dataset_freq: str,
    max_context_length: int,
    title: str | None = None,
):
    """Build one history / ground-truth / forecast figure for a window.

    Args:
        input_data: Window history, forwarded to ``_prepare_data_for_plotting``.
        label_data: Window ground truth, forwarded to the same helper.
        forecast: Forecast object handed to ``_extract_quantile_predictions``.
        dataset_full_name: Name used in log messages and the default title.
        dataset_freq: Frequency string passed to ``parse_frequency``.
        max_context_length: History cap forwarded to the preparation helper.
        title: Figure title; defaults to ``"GIFT-Eval: <dataset_full_name>"``.

    Returns:
        Whatever ``plot_multivariate_timeseries`` returns, or ``None`` when no
        prediction could be extracted or any step raised (a warning is logged
        in both cases).
    """

    def match_target_shape(pred, target):
        # Best-effort coercion of a prediction array to the target's shape.
        if pred is None:
            return None
        pred = np.asarray(pred)
        target = np.asarray(target)
        if pred.ndim == 1:
            pred = pred.reshape(-1, 1)
        if target.ndim == 1:
            target = target.reshape(-1, 1)
        if pred.shape == target.shape:
            return pred

        if pred.shape[0] == target.shape[0]:
            # Same length, different column count.
            if pred.shape[1] == 1 and target.shape[1] > 1:
                return np.broadcast_to(pred, target.shape)
            if pred.shape[1] > 1 and target.shape[1] == 1:
                return pred[:, :1]
            return pred

        if pred.shape[1] == target.shape[1]:
            # Same column count, different length: cut to the shorter one.
            return pred[: min(pred.shape[0], target.shape[0])]

        if pred.T.shape == target.shape:
            return pred.T

        if pred.size >= target.shape[0]:
            # Last resort: take the leading values as a single column.
            pred = pred.flatten()[: target.shape[0]].reshape(-1, 1)
            if target.shape[1] > 1:
                pred = np.broadcast_to(pred, target.shape)
        return pred

    try:
        history, future, start_ts = _prepare_data_for_plotting(input_data, label_data, max_context_length)
        point, low, high = _extract_quantile_predictions(forecast)
        if point is None:
            logger.warning(f"Could not extract predictions for {dataset_full_name}")
            return None

        point = match_target_shape(point, future)
        low = match_target_shape(low, future)
        high = match_target_shape(high, future)

        if not title:
            title = f"GIFT-Eval: {dataset_full_name}"
        return plot_multivariate_timeseries(
            history_values=history,
            future_values=future,
            predicted_values=point,
            lower_bound=low,
            upper_bound=high,
            start=start_ts,
            frequency=parse_frequency(dataset_freq),
            title=title,
            show=False,
        )
    except Exception as exc:
        logger.warning(f"Failed to create plot for {dataset_full_name}: {exc}")
        return None
| |
|
| |
|
def create_plots_for_dataset(
    forecasts: list,
    test_data,
    dataset_metadata,
    max_plots: int,
    max_context_length: int,
) -> list[tuple[object, str]]:
    """Create up to ``max_plots`` forecast figures for one dataset.

    Args:
        forecasts: Forecast objects, one per evaluation window.
        test_data: Object exposing iterable ``input`` and ``label`` attributes
            whose items line up with ``forecasts``.
        dataset_metadata: Object optionally exposing ``full_name`` (used in
            titles and logs) and ``freq`` (defaults to ``"D"``).
        max_plots: Upper bound on the number of windows to plot.
        max_context_length: History cap forwarded to ``_create_plot``.

    Returns:
        ``(figure, filename)`` pairs for every window that produced a figure.
        Filenames look like ``"<freq>_window_001.png"``. Windows that fail
        are logged and skipped.
    """
    from itertools import islice

    has_full_name = hasattr(dataset_metadata, "full_name")
    full_name = dataset_metadata.full_name if has_full_name else "dataset"
    display_name = full_name if has_full_name else str(dataset_metadata)
    freq = getattr(dataset_metadata, "freq", "D")

    # Pull only the windows that will be plotted rather than materialising
    # the whole test set.
    requested = max(0, min(len(forecasts), max_plots))
    inputs = list(islice(test_data.input, requested))
    labels = list(islice(test_data.label, requested))

    # Inputs or labels may be shorter than the forecasts; plot the common
    # prefix instead of hitting an IndexError on every missing window.
    num_plots = min(requested, len(inputs), len(labels))
    if num_plots < requested:
        logger.warning(
            "Only %d of %d requested windows have both input and label data for %s",
            num_plots,
            requested,
            display_name,
        )
    logger.info("Creating %d plots for %s", num_plots, display_name)

    figures_with_names: list[tuple[object, str]] = []
    windows = zip(forecasts, inputs, labels)  # stops at the shortest, i.e. num_plots
    for window, (forecast, input_data, label_data) in enumerate(windows, start=1):
        try:
            if has_full_name:
                title = f"GIFT-Eval: {full_name} - Window {window}/{num_plots}"
            else:
                title = f"Window {window}/{num_plots}"
            fig = _create_plot(
                input_data=input_data,
                label_data=label_data,
                forecast=forecast,
                dataset_full_name=full_name,
                dataset_freq=freq,
                max_context_length=max_context_length,
                title=title,
            )
            if fig is not None:
                figures_with_names.append((fig, f"{freq}_window_{window:03d}.png"))
        except Exception as exc:
            # Best effort: one bad window must not abort the rest.
            logger.warning("Error creating plot for window %d: %s", window, exc)
    return figures_with_names
| |
|