Source code for ts_benchmark.baselines.darts.darts_models

# -*- coding: utf-8 -*-
import contextlib
import functools
import logging
import os
from typing import Dict, Optional, Any, Tuple, NoReturn, ContextManager

import darts
import darts.models as darts_models
import numpy as np
import pandas as pd
from darts import TimeSeries
from sklearn.preprocessing import StandardScaler

from ts_benchmark.baselines.utils import train_val_split
from ts_benchmark.common.constant import ROOT_PATH
from ts_benchmark.models import ModelBase

# ``NotImportedModule`` is the type this module checks model classes against in
# order to recognise darts models that could not be imported.  It is detected by
# attempting the import instead of comparing version strings: a string
# comparison is lexicographic ("0.9.0" >= "0.25.0" is True), so it selects the
# wrong branch for some releases.  The name is always bound so later references
# cannot raise NameError.
try:
    from darts.models.utils import NotImportedModule
except ImportError:
    NotImportedModule = None

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Table read once, at import time, from
# <ROOT_PATH>/ts_benchmark/baselines/tag_csv/darts_tag.csv.
# NOTE(review): TAG is not referenced by the code visible in this file --
# presumably it is consumed by other modules; confirm before removing.
TAG = pd.read_csv(
    os.path.join(ROOT_PATH, "ts_benchmark", "baselines", "tag_csv", "darts_tag.csv")
)


class DartsConfig:
    """
    Container for the hyper-parameters of a Darts model.

    All keyword arguments are stored in ``self.params``. A parameter can be read
    as an attribute (``config.norm``), by subscription (``config["norm"]``) or
    through :meth:`get`; a missing parameter yields ``None`` in the first two
    cases.
    """

    def __init__(self, **kwargs):
        self.params = {
            **kwargs,
        }

    def __getattr__(self, key: str) -> Any:
        """
        Fallback attribute access that exposes parameters as attributes.

        :param key: name of the parameter.
        :return: the parameter value, or None if the parameter is not set.
        :raises AttributeError: for ``params`` itself and for dunder names.
        """
        # __getattr__ only runs when the normal lookup failed.  If ``params`` is
        # what is missing (instance created without __init__, e.g. by copy or
        # pickle), falling through to self.get() would recurse forever.  Dunder
        # probes such as ``__deepcopy__`` / ``__setstate__`` must also fail
        # properly instead of being answered with None.
        if key == "params" or (key.startswith("__") and key.endswith("__")):
            raise AttributeError(key)
        return self.get(key)

    def __getitem__(self, key: str) -> Any:
        return self.get(key)

    def get(self, key: str, default: Any = None) -> Any:
        """
        Look up a parameter.

        :param key: name of the parameter.
        :param default: value returned when the parameter is not set.
        :return: the parameter value or ``default``.
        """
        return self.params.get(key, default)

    def get_darts_class_params(self) -> dict:
        """
        Build the keyword arguments used to construct the Darts model.

        ``norm`` is consumed by the adapter and is therefore removed; the gpu
        settings are restricted to a single device when necessary.

        :return: a new dictionary, ``self.params`` is left untouched.
        """
        ret = self.params.copy()
        # ``norm`` is optional everywhere else (a missing value reads as None),
        # so its absence must not raise here either.
        ret.pop("norm", None)
        self._fix_multi_gpu(ret)
        return ret

    def _fix_multi_gpu(self, args_dict: Dict) -> None:
        """
        Check and disable using multi-gpu per task

        training and inferencing on multiple gpus with 'ddp' strategy (default in lightning)
        is error-prone in complicated work flow, the problems include but not limited to:

        - do heavy initialization in all processes (e.g. full data loading)
        - hangs when the program is interrupted (e.g. exceptions that are caught elsewhere)
        - not compatible with the parallel paradigm of ray

        As a result, we disallow a single worker to work on multiple gpus by changing
        gpu settings in the input argument dictionary.

        TODO: what if we want to evaluate huge-sized models in the future?

        :param args_dict: argument dictionary to be passed to Darts models.
        """
        # CUDA_VISIBLE_DEVICES should be set by the parallel backend
        gpu_devices = [
            device
            for device in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")
            if device
        ]
        if len(gpu_devices) <= 1:
            return

        # Work on a copy: the nested dictionary is shared with ``self.params``
        # and possibly with module-level defaults, which must not be modified.
        pl_args = dict(args_dict.get("pl_trainer_kwargs") or {})
        device_args = pl_args.get("devices", None)
        # NOTE(review): other device specifications (strings, negative numbers)
        # are passed through unchanged, as before -- confirm this is intended.
        if (
            device_args is None
            or (isinstance(device_args, list) and len(device_args) > 1)
            or (isinstance(device_args, int) and device_args > 1)
        ):
            pl_args["devices"] = [0]
            args_dict["pl_trainer_kwargs"] = pl_args
            logger.warning(
                "Multi-gpu training is not supported, using only gpu %s",
                gpu_devices[0],
            )
class DartsModelAdapter(ModelBase):
    """
    Darts model adapter class

    Adapts Darts models to OTB forecasting interface.
    """

    def __init__(
        self,
        model_class: type,
        model_args: dict,
        model_name: Optional[str] = None,
        allow_fit_on_eval: bool = False,
        supports_validation: bool = False,
        **kwargs
    ):
        """
        Initialize the Darts model adapter object.

        :param model_class: Darts model class.
        :param model_args: Model initialization parameters.
        :param model_name: Model name.
        :param allow_fit_on_eval: Is it allowed to fit the model during the prediction phase.
        :param supports_validation: Whether the model supports inputting a validation series.
        :param kwargs: other arguments added to model_args.
        """
        self.model = None
        self.model_class = model_class
        # kwargs take precedence over the predefined model_args
        self.config = DartsConfig(**{**model_args, **kwargs})
        self._model_name = model_name
        self.allow_fit_on_eval = allow_fit_on_eval
        self.supports_validation = supports_validation
        self.scaler = StandardScaler()
        self.train_ratio_in_tv = 1

    @property
    def model_name(self):
        """
        Returns the name of the model.
        """
        return self._model_name

    def _scale(self, data: pd.DataFrame) -> pd.DataFrame:
        """
        Apply the fitted scaler to a data frame, keeping its columns and index.

        :param data: Time series data.
        :return: The normalized data.
        """
        return pd.DataFrame(
            self.scaler.transform(data.values),
            columns=data.columns,
            index=data.index,
        )

    def forecast_fit(
        self, train_data: pd.DataFrame, *, train_ratio_in_tv: float = 1.0, **kwargs
    ) -> "ModelBase":
        """
        Fit a suitable Darts model on time series data.

        A new model instance is created on every call. When ``norm`` is set, the
        scaler is refitted on the training part of the data.

        :param train_data: Time series data.
        :param train_ratio_in_tv: Represents the splitting ratio of the training
            set validation set. If it is equal to 1, it means that the validation
            set is not partitioned.
        :return: The fitted model object.
        """
        self.train_ratio_in_tv = train_ratio_in_tv
        if self.allow_fit_on_eval or self.model_name == "RegressionModel":
            # Models that are retrained at prediction time (and RegressionModel)
            # are fitted on the whole series, no validation set is split off.
            valid_data = None
        else:
            train_data, valid_data = train_val_split(
                train_data,
                self.train_ratio_in_tv,
                self.config.get("input_chunk_length", 0),
            )

        self.model = self.model_class(**self.config.get_darts_class_params())
        use_validation = self.supports_validation and valid_data is not None

        if self.config.norm:
            # statistics come from the training part only
            self.scaler.fit(train_data.values)
            train_data = self._scale(train_data)
            if use_validation:
                valid_data = self._scale(valid_data)

        with self._suppress_lightning_logs():
            train_series = TimeSeries.from_dataframe(train_data)
            if use_validation:
                self.model.fit(
                    train_series, val_series=TimeSeries.from_dataframe(valid_data)
                )
            else:
                self.model.fit(train_series)

        return self

    def forecast(self, horizon: int, series: pd.DataFrame, **kwargs) -> np.ndarray:
        """
        Use the adapted Darts model for prediction.

        :param horizon: Forecast length.
        :param series: Time series data to make inferences on.
        :return: Forecast result, in the scale of the input series.
        """
        with self._suppress_lightning_logs():
            if self.allow_fit_on_eval:
                # Hand over the raw series: forecast_fit refits the scaler and
                # normalizes the data itself.  Scaling here as well would
                # normalize twice while the result is only inverted once, and
                # would fail if the scaler has never been fitted.
                self.forecast_fit(series, train_ratio_in_tv=self.train_ratio_in_tv)
                fsct_result = self.model.predict(horizon)
            else:
                if self.config.norm:
                    series = self._scale(series)
                fsct_result = self.model.predict(
                    horizon, TimeSeries.from_dataframe(series)
                )
        predict = fsct_result.values()

        if self.config.norm:
            predict = self.scaler.inverse_transform(predict)

        return predict

    @contextlib.contextmanager
    def _suppress_lightning_logs(self) -> ContextManager:
        """
        Temporarily raise the level of the "pytorch_lightning" logger to CRITICAL.

        The previous level is restored on exit, also when an exception is raised.
        """
        pl_logger = logging.getLogger("pytorch_lightning")
        old_level = pl_logger.level
        pl_logger.setLevel(logging.CRITICAL)
        try:
            yield
        finally:
            pl_logger.setLevel(old_level)
def _generate_model_factory(
    model_class: type,
    model_args: dict,
    model_name: str,
    required_args: dict,
    allow_fit_on_eval: bool,
    supports_validation: bool,
) -> Dict:
    """
    Generate model factory information for creating Darts model adapters.

    :param model_name: Model name.
    :param model_class: Darts model class.
    :param model_args: Predefined model hyperparameters that can be overwritten by
        the hyperparameters of the input factory function.
    :param required_args: Requires hyperparameters recommended by benchmark.
    :param allow_fit_on_eval: Is it allowed to fit the model during the prediction phase.
    :param supports_validation: Whether the model supports inputting a validation series.
    :return: A dictionary containing the model factory and required parameters.
    """
    model_factory = functools.partial(
        DartsModelAdapter,
        model_class=model_class,
        model_args=model_args,
        model_name=model_name,
        allow_fit_on_eval=allow_fit_on_eval,
        supports_validation=supports_validation,
    )

    return {"model_factory": model_factory, "required_hyper_params": required_args}


# predefined model_args and required_args for darts models
DEEP_MODEL_REQUIRED_ARGS = {
    "input_chunk_length": "input_chunk_length",
    "output_chunk_length": "output_chunk_length",
    "norm": "norm",
}
REGRESSION_MODEL_REQUIRED_ARGS = {
    "lags": "input_chunk_length",
    "output_chunk_length": "output_chunk_length",
    "norm": "norm",
}
STAT_MODEL_REQUIRED_ARGS = {
    "norm": "norm",
}
DEEP_MODEL_ARGS = {
    "pl_trainer_kwargs": {
        "enable_progress_bar": False,
    }
}


def _get_model_info(model_name: str, required_args: Dict, model_args: Dict) -> Tuple:
    """
    Helper function to retrieve darts model information by name

    :param model_name: name of the model.
    :param required_args: arguments that the model requires from the pipeline.
    :param model_args: specified model arguments.
    :return: a tuple including model name, model_class, required args and model args.
        The model class is None when ``darts.models`` does not provide a class
        under the given name.
    """
    model_class = getattr(darts_models, model_name, None)
    # Only a real class can be instantiated as a model.  Anything else found
    # under that name (for instance a placeholder object standing in for a
    # model that could not be imported) is treated as "not available".
    # The previous test, isinstance(model_class, type(NotImportedModule)),
    # compared against the metaclass and therefore matched every real model
    # class instead of the placeholders.
    if not isinstance(model_class, type):
        model_class = None
    return model_name, model_class, required_args, model_args


# deep models implemented by darts
DARTS_DEEP_MODELS = [
    _get_model_info("TCNModel", DEEP_MODEL_REQUIRED_ARGS, DEEP_MODEL_ARGS),
    _get_model_info(
        "TFTModel",
        DEEP_MODEL_REQUIRED_ARGS,
        DEEP_MODEL_ARGS,
    ),
    _get_model_info("TransformerModel", DEEP_MODEL_REQUIRED_ARGS, DEEP_MODEL_ARGS),
    _get_model_info("NHiTSModel", DEEP_MODEL_REQUIRED_ARGS, DEEP_MODEL_ARGS),
    _get_model_info("TiDEModel", DEEP_MODEL_REQUIRED_ARGS, DEEP_MODEL_ARGS),
    _get_model_info("BlockRNNModel", DEEP_MODEL_REQUIRED_ARGS, DEEP_MODEL_ARGS),
    _get_model_info("RNNModel", DEEP_MODEL_REQUIRED_ARGS, DEEP_MODEL_ARGS),
    _get_model_info("DLinearModel", DEEP_MODEL_REQUIRED_ARGS, DEEP_MODEL_ARGS),
    _get_model_info("NBEATSModel", DEEP_MODEL_REQUIRED_ARGS, DEEP_MODEL_ARGS),
    _get_model_info("NLinearModel", DEEP_MODEL_REQUIRED_ARGS, DEEP_MODEL_ARGS),
]

# regression models implemented by darts
DARTS_REGRESSION_MODELS = [
    _get_model_info("RandomForest", REGRESSION_MODEL_REQUIRED_ARGS, {}),
    _get_model_info("XGBModel", REGRESSION_MODEL_REQUIRED_ARGS, {}),
    _get_model_info("CatBoostModel", REGRESSION_MODEL_REQUIRED_ARGS, {}),
    _get_model_info("LightGBMModel", REGRESSION_MODEL_REQUIRED_ARGS, {}),
    _get_model_info("LinearRegressionModel", REGRESSION_MODEL_REQUIRED_ARGS, {}),
    _get_model_info("RegressionModel", REGRESSION_MODEL_REQUIRED_ARGS, {}),
]

# statistical models implemented by darts,
# these models are specially allowed to retrain during inference
DARTS_STAT_MODELS = [
    _get_model_info("KalmanForecaster", STAT_MODEL_REQUIRED_ARGS, {}),
    _get_model_info("ARIMA", STAT_MODEL_REQUIRED_ARGS, {}),
    _get_model_info("VARIMA", STAT_MODEL_REQUIRED_ARGS, {}),
    _get_model_info("AutoARIMA", STAT_MODEL_REQUIRED_ARGS, {}),
    _get_model_info("StatsForecastAutoCES", STAT_MODEL_REQUIRED_ARGS, {}),
    _get_model_info("StatsForecastAutoTheta", STAT_MODEL_REQUIRED_ARGS, {}),
    _get_model_info("StatsForecastAutoETS", STAT_MODEL_REQUIRED_ARGS, {}),
    _get_model_info("ExponentialSmoothing", STAT_MODEL_REQUIRED_ARGS, {}),
    _get_model_info("StatsForecastAutoARIMA", STAT_MODEL_REQUIRED_ARGS, {}),
    _get_model_info("FFT", STAT_MODEL_REQUIRED_ARGS, {}),
    _get_model_info("FourTheta", STAT_MODEL_REQUIRED_ARGS, {}),
    _get_model_info("Croston", STAT_MODEL_REQUIRED_ARGS, {}),
    _get_model_info("NaiveDrift", STAT_MODEL_REQUIRED_ARGS, {}),
    _get_model_info("NaiveMean", STAT_MODEL_REQUIRED_ARGS, {}),
    _get_model_info("NaiveSeasonal", STAT_MODEL_REQUIRED_ARGS, {}),
    _get_model_info("NaiveMovingAverage", STAT_MODEL_REQUIRED_ARGS, {}),
]

# Generate model factories for each model class and required parameters in DARTS_DEEP_MODELS
# and add them to global variables
for _model_name, _model_class, _required_args, _model_args in DARTS_DEEP_MODELS:
    if _model_class is None:
        logger.warning(
            "Model %s is not available, skipping model registration", _model_name
        )
        # keep the name importable, bound to None
        globals()[_model_name] = None
        continue
    globals()[_model_name] = _generate_model_factory(
        model_class=_model_class,
        model_args=_model_args,
        model_name=_model_name,
        required_args=_required_args,
        allow_fit_on_eval=False,
        supports_validation=True,
    )

# Generate model factories for each model class and required parameters in DARTS_REGRESSION_MODELS
# and add them to global variables
for _model_name, _model_class, _required_args, _model_args in DARTS_REGRESSION_MODELS:
    if _model_class is None:
        logger.warning(
            "Model %s is not available, skipping model registration", _model_name
        )
        globals()[_model_name] = None
        continue
    globals()[_model_name] = _generate_model_factory(
        model_class=_model_class,
        model_args=_model_args,
        model_name=_model_name,
        required_args=_required_args,
        allow_fit_on_eval=False,
        supports_validation=False,
    )

# Generate model factories for each model class and required parameters in DARTS_STAT_MODELS
# and add them to global variables
for _model_name, _model_class, _required_args, _model_args in DARTS_STAT_MODELS:
    if _model_class is None:
        logger.warning(
            "Model %s is not available, skipping model registration", _model_name
        )
        globals()[_model_name] = None
        continue
    globals()[_model_name] = _generate_model_factory(
        model_class=_model_class,
        model_args=_model_args,
        # unlike the two loops above, the adapter is named after the class
        model_name=_model_class.__name__,
        required_args=_required_args,
        allow_fit_on_eval=True,
        supports_validation=False,
    )


# Adapters for general darts models
def darts_deep_model_adapter(model_class: type) -> Dict:
    """
    Wrap a Darts deep model class so that it follows the OTB protocol.

    The factory is created with the default deep-model arguments and the
    deep-model required hyper-parameters; the model is not refitted at
    prediction time and receives a validation series.

    :param model_class: a class of deep forecasting model from Darts library.
    :return: model factory that follows the OTB protocol.
    """
    factory_info = _generate_model_factory(
        model_class=model_class,
        model_args=DEEP_MODEL_ARGS,
        model_name=model_class.__name__,
        required_args=DEEP_MODEL_REQUIRED_ARGS,
        allow_fit_on_eval=False,
        supports_validation=True,
    )
    return factory_info
def darts_statistical_model_adapter(model_class: type) -> Dict:
    """
    Wrap a Darts statistical model class so that it follows the OTB protocol.

    The factory is created without predefined model arguments; the model is
    refitted at prediction time and receives no validation series.

    :param model_class: a class of statistical forecasting model from Darts library.
    :return: model factory that follows the OTB protocol.
    """
    factory_info = _generate_model_factory(
        model_class=model_class,
        model_args={},
        model_name=model_class.__name__,
        required_args=STAT_MODEL_REQUIRED_ARGS,
        allow_fit_on_eval=True,
        supports_validation=False,
    )
    return factory_info
def darts_regression_model_adapter(model_class: type) -> Dict:
    """
    Wrap a Darts regression model class so that it follows the OTB protocol.

    The factory is created without predefined model arguments; the model is
    refitted at prediction time and receives no validation series.

    :param model_class: a class of regression forecasting model from Darts library.
    :return: model factory that follows the OTB protocol.
    """
    # NOTE(review): the regression models registered at module level are created
    # with allow_fit_on_eval=False, whereas this adapter passes True -- confirm
    # which of the two is intended.
    factory_info = _generate_model_factory(
        model_class=model_class,
        model_args={},
        model_name=model_class.__name__,
        required_args=REGRESSION_MODEL_REQUIRED_ARGS,
        allow_fit_on_eval=True,
        supports_validation=False,
    )
    return factory_info