| import time | |
| from typing import List, Optional | |
| import pandas as pd | |
| from gluonts.dataset.common import Dataset | |
| from gluonts.model.forecast import Forecast, QuantileForecast | |
| from .abstract import AbstractPredictor | |
class AutoGluonPredictor(AbstractPredictor):
    """Predictor that trains an AutoGluon ``TimeSeriesPredictor`` and returns GluonTS forecasts.

    AutoGluon's quantile predictions are converted into GluonTS
    ``QuantileForecast`` objects, one per series of the input dataset.
    """

    def __init__(
        self,
        prediction_length: int,
        freq: str,
        seasonality: int,
        time_limit: Optional[int] = None,
        presets: str = "high_quality",
        eval_metric: str = "MASE",
        seed: int = 1,
        enable_ensemble: bool = True,
        hyperparameters: Optional[dict] = None,
        **kwargs,
    ):
        """Store the AutoGluon training configuration.

        Args:
            prediction_length: Forecast horizon; forwarded to the base class and
                to ``TimeSeriesPredictor``.
            freq: Frequency string; forwarded to the base class and used to build
                the ``pd.Period`` start date of each forecast.
            seasonality: Seasonal period; forwarded to the base class and passed
                to AutoGluon as ``eval_metric_seasonal_period``.
            time_limit: Forwarded to ``TimeSeriesPredictor.fit`` as ``time_limit``.
            presets: Forwarded to ``TimeSeriesPredictor.fit`` as ``presets``.
            eval_metric: Forwarded to ``TimeSeriesPredictor`` as ``eval_metric``.
            seed: Forwarded to ``TimeSeriesPredictor.fit`` as ``random_seed``.
            enable_ensemble: Forwarded to ``TimeSeriesPredictor.fit``.
            hyperparameters: Forwarded to ``TimeSeriesPredictor.fit``.
            **kwargs: Accepted but not used by this class.
        """
        # NOTE(review): the base class presumably sets self.prediction_length,
        # self.freq, self.seasonality and self.quantile_levels (all read below) — confirm.
        super().__init__(prediction_length, freq, seasonality)
        self.presets = presets
        self.eval_metric = eval_metric
        self.time_limit = time_limit
        self.seed = seed
        self.enable_ensemble = enable_ensemble
        self.hyperparameters = hyperparameters

    def fit_predict(self, dataset: Dataset) -> List[Forecast]:
        """Train AutoGluon on ``dataset`` and return one forecast per series.

        The elapsed time recorded via ``self.save_runtime`` covers both
        ``predictor.fit`` and ``predictor.predict``.

        Args:
            dataset: GluonTS dataset used both as training data and as the
                context passed to ``predictor.predict``.

        Returns:
            One ``QuantileForecast`` per entry of ``dataset``, in dataset order.

        Raises:
            ValueError: If AutoGluon returns predictions for a different number
                of items than ``dataset`` contains.
        """
        # Deferred import: importing this module does not require AutoGluon.
        from autogluon.timeseries import TimeSeriesDataFrame, TimeSeriesPredictor

        train_data = TimeSeriesDataFrame(dataset)
        predictor = TimeSeriesPredictor(
            prediction_length=self.prediction_length,
            eval_metric=self.eval_metric,
            eval_metric_seasonal_period=self.seasonality,
            quantile_levels=self.quantile_levels,
        )

        start_time = time.time()
        predictor.fit(
            train_data,
            time_limit=self.time_limit,
            presets=self.presets,
            random_seed=self.seed,
            enable_ensemble=self.enable_ensemble,
            hyperparameters=self.hyperparameters,
        )
        predictions = predictor.predict(train_data)
        self.save_runtime(time.time() - start_time)

        # Only the quantile columns become QuantileForecast keys; the point
        # forecast column "mean" is dropped.
        return self._predictions_df_to_gluonts_forecast(
            predictions_df=predictions.drop("mean", axis=1), dataset=dataset
        )

    def _predictions_df_to_gluonts_forecast(
        self, predictions_df, dataset: Dataset
    ) -> List[Forecast]:
        """Convert an AutoGluon predictions frame into GluonTS ``QuantileForecast`` objects.

        Args:
            predictions_df: Predictions indexed by an ``item_id`` level plus a
                timestamp level, with one column per quantile key.
            dataset: The GluonTS dataset the predictions were produced for.
                Entries are matched to predicted items by position.

        Returns:
            One ``QuantileForecast`` per dataset entry, in dataset order.

        Raises:
            ValueError: If the number of predicted items differs from the number
                of entries in ``dataset``.
        """
        import itertools

        # sort=False keeps items in order of first appearance. Forecasts are
        # paired with dataset entries purely by position, which assumes AutoGluon
        # preserved the dataset's item order — NOTE(review): confirm for the
        # AutoGluon version in use.
        agts_forecasts = [
            f.droplevel("item_id")
            for _, f in predictions_df.groupby(level="item_id", sort=False)
        ]

        # zip() would silently drop unmatched entries; zip_longest with a
        # sentinel lets a count mismatch surface as an error instead.
        missing = object()
        forecast_list = []
        for ts, f in itertools.zip_longest(dataset, agts_forecasts, fillvalue=missing):
            if ts is missing or f is missing:
                raise ValueError(
                    f"Number of predicted items ({len(agts_forecasts)}) does not "
                    "match the number of entries in the dataset"
                )
            forecast_list.append(
                QuantileForecast(
                    # Rows of f are timestamps and columns are quantile keys;
                    # transposing gives one array row per key.
                    forecast_arrays=f.values.T,
                    forecast_keys=list(f.columns),
                    start_date=pd.Period(f.index[0], freq=self.freq),
                    # item_id is optional in GluonTS entries.
                    item_id=ts.get("item_id"),
                )
            )
        return forecast_list