ts_benchmark.models package

ts_benchmark.models.model_base module

Functions

annotate(**kwargs)

Decorate a function to add or update its annotations.

Classes

BatchMaker()

The standard interface of batch maker.

ModelBase()

The standard interface of benchmark-compatible models.

class BatchMaker

Bases: object

The standard interface of batch maker.

abstract make_batch(batch_size: int, win_size: int) → dict

Provide a batch of data to be used for batch prediction.

Parameters:
  • batch_size – The number of samples (windows) in one batch.

  • win_size – The length of the input window for one prediction.

Returns:

A batch of data for prediction.
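To make the contract concrete, here is a minimal sketch of a batch maker that slides a window over an in-memory array. It defines a local stand-in for BatchMaker instead of importing ts_benchmark. SlidingWindowBatchMaker, its one-step stride, and the "input" key of the returned dictionary are hypothetical choices for illustration, not the benchmark's actual batch format.

```python
from abc import ABC, abstractmethod

import numpy as np


class BatchMaker(ABC):
    """Local stand-in mirroring the documented BatchMaker interface."""

    @abstractmethod
    def make_batch(self, batch_size: int, win_size: int) -> dict:
        """Return a batch of data for prediction."""


class SlidingWindowBatchMaker(BatchMaker):
    """Hypothetical batch maker that cuts overlapping windows from a 2-D array."""

    def __init__(self, data: np.ndarray):
        self.data = data  # shape: (n_timestamps, n_variables)
        self.pos = 0      # start index of the next window

    def make_batch(self, batch_size: int, win_size: int) -> dict:
        windows = []
        # Collect up to batch_size windows of win_size rows, advancing one step each time
        while len(windows) < batch_size and self.pos + win_size <= len(self.data):
            windows.append(self.data[self.pos : self.pos + win_size])
            self.pos += 1
        if windows:
            batch = np.stack(windows)  # shape: (n_windows, win_size, n_variables)
        else:
            batch = np.empty((0, win_size, self.data.shape[1]))
        # The "input" key is an assumed name for illustration
        return {"input": batch}


data = np.arange(20, dtype=float).reshape(10, 2)
maker = SlidingWindowBatchMaker(data)
print(maker.make_batch(batch_size=4, win_size=3)["input"].shape)  # (4, 3, 2)
```

Because the maker keeps its own position, repeated calls walk forward through the series until no full window remains, at which point an empty batch is returned.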

class ModelBase

Bases: object

The standard interface of benchmark-compatible models.

Users are encouraged to subclass this class to implement or adapt their own models.

batch_forecast(horizon: int, batch_maker: BatchMaker, **kwargs) → numpy.ndarray

Perform batch forecasting with the model.

Parameters:
  • horizon – The length of each prediction.

  • batch_maker – The batch maker that supplies batches of data for prediction.

Returns:

The prediction result.

abstract forecast(horizon: int, series: pandas.DataFrame, **kwargs) → numpy.ndarray

Forecast with the model.

TODO: support returning DataFrames

Parameters:
  • horizon – Forecast length.

  • series – Time series data to make inferences on.

Returns:

Forecast result.

abstract forecast_fit(train_data: pandas.DataFrame, *, train_ratio_in_tv: float = 1.0, **kwargs) → ModelBase

Fit the model on time series data.

Parameters:
  • train_data – Time series data.

  • train_ratio_in_tv – The proportion of the training set within the combined training and validation data. If it equals 1, no validation set is split off.

Returns:

The fitted model object.
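The description of train_ratio_in_tv leaves the splitting rule open. One plausible reading, sketched below under that assumption, is a chronological split in which the first train_ratio_in_tv fraction of rows is used for training and the remainder for validation. split_train_valid is a hypothetical helper, not part of the benchmark.

```python
import pandas as pd


def split_train_valid(train_data: pd.DataFrame, train_ratio_in_tv: float = 1.0):
    """Hypothetical chronological split into (train, valid); valid is None at ratio 1."""
    if not 0.0 < train_ratio_in_tv <= 1.0:
        raise ValueError("train_ratio_in_tv must be in (0, 1]")
    if train_ratio_in_tv == 1.0:
        # A ratio of 1 means no validation set is split off
        return train_data, None
    split = int(len(train_data) * train_ratio_in_tv)
    # Earlier rows train, later rows validate, so the time order is preserved
    return train_data.iloc[:split], train_data.iloc[split:]


df = pd.DataFrame({"value": range(10)})
train, valid = split_train_valid(df, train_ratio_in_tv=0.8)
print(len(train), len(valid))  # 8 2
```

Splitting by position rather than at random keeps the validation data strictly after the training data, which avoids leaking future values into training.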

abstract property model_name

Returns the name of the model.
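A minimal sketch of a model written against this interface follows. To stay self-contained it declares a local stand-in for ModelBase with only the abstract members listed above. LastValueModel, its win_size and batch_size hyperparameters, and the assumption that make_batch returns an "input" array shaped (batch, win_size, n_variables) are all hypothetical.

```python
from abc import ABC, abstractmethod

import numpy as np
import pandas as pd


class ModelBase(ABC):
    """Local stand-in with the abstract members documented above."""

    @abstractmethod
    def forecast_fit(
        self, train_data: pd.DataFrame, *, train_ratio_in_tv: float = 1.0, **kwargs
    ) -> "ModelBase":
        ...

    @abstractmethod
    def forecast(self, horizon: int, series: pd.DataFrame, **kwargs) -> np.ndarray:
        ...

    @property
    @abstractmethod
    def model_name(self):
        ...


class LastValueModel(ModelBase):
    """Hypothetical naive model that repeats the last observed row."""

    def __init__(self, win_size: int = 8, batch_size: int = 32):
        self.win_size = win_size      # hypothetical hyperparameter
        self.batch_size = batch_size  # hypothetical hyperparameter

    @property
    def model_name(self):
        return "LastValueModel"

    def forecast_fit(self, train_data, *, train_ratio_in_tv=1.0, **kwargs):
        # A naive model learns nothing; returning self allows chaining
        return self

    def forecast(self, horizon, series, **kwargs):
        last_row = series.to_numpy()[-1]
        return np.tile(last_row, (horizon, 1))  # shape: (horizon, n_variables)

    def batch_forecast(self, horizon, batch_maker, **kwargs):
        # Assumes make_batch returns {"input": array of shape (batch, win_size, n_variables)}
        windows = batch_maker.make_batch(self.batch_size, self.win_size)["input"]
        last_rows = windows[:, -1:, :]                # shape: (batch, 1, n_variables)
        return np.repeat(last_rows, horizon, axis=1)  # shape: (batch, horizon, n_variables)


series = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [10.0, 20.0, 30.0]})
model = LastValueModel().forecast_fit(series)
print(model.forecast(2, series).tolist())  # [[3.0, 30.0], [3.0, 30.0]]
```

Here the model, not the caller, decides the batch size and window length passed to make_batch; that is one reasonable design given that batch_forecast only receives horizon and batch_maker.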

annotate(**kwargs)

Decorate a function to add or update its annotations.

Parameters:

kwargs – Keyword arguments representing the annotations to be added or updated.

Returns:

A wrapper function that updates the annotations of the original function.
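A decorator with this behavior can be sketched in a few lines, assuming the annotations live in the function's standard __annotations__ dictionary. This is an illustrative sketch rather than the module's source, and the example_flag key is hypothetical.

```python
from typing import Any, Callable


def annotate(**kwargs: Any) -> Callable[[Callable], Callable]:
    """Sketch: return a decorator that merges kwargs into func.__annotations__."""

    def decorator(func: Callable) -> Callable:
        # New keys are added and existing keys are overwritten in place
        func.__annotations__.update(kwargs)
        return func

    return decorator


@annotate(example_flag=True)
def forecast(horizon: int) -> list:
    return [0.0] * horizon


print(forecast.__annotations__["example_flag"])  # True
```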

ts_benchmark.models.model_loader module

Functions

get_model_hyper_params(...)

Obtain the hyperparameters of the model.

get_model_info(model_config)

Obtain model information based on model configuration.

get_models(all_model_config)

Obtain a list of ModelFactory objects based on model configuration.

import_model_info(model_path)

Import model information.

Classes

ModelFactory(model_name, model_factory, ...)

Model factory, the standard type to instantiate models in the pipeline.

class ModelFactory(model_name: str, model_factory: Callable, model_hyper_params: dict)

Bases: object

Model factory, the standard type to instantiate models in the pipeline.
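The constructor signature suggests that a ModelFactory bundles a model name, a constructor, and its hyperparameters. The sketch below models that idea with a dataclass. ModelFactorySketch, and the assumption that calling it builds a fresh instance via model_factory(**model_hyper_params), are illustrative and not the documented API.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict


@dataclass
class ModelFactorySketch:
    """Hypothetical stand-in: a model name, a constructor, and its hyperparameters."""

    model_name: str
    model_factory: Callable[..., Any]
    model_hyper_params: Dict[str, Any] = field(default_factory=dict)

    def __call__(self) -> Any:
        # Assumed behavior: build a fresh model from the stored hyperparameters
        return self.model_factory(**self.model_hyper_params)


class NaiveModel:
    """Hypothetical model used only to exercise the factory."""

    def __init__(self, input_window_size: int):
        self.input_window_size = input_window_size


factory = ModelFactorySketch("Naive", NaiveModel, {"input_window_size": 96})
print(factory().input_window_size)  # 96
```

Storing the constructor rather than an instance would let a pipeline create an independent model for each evaluation run.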

get_model_hyper_params(recommend_model_hyper_params: Dict, required_hyper_params: Dict, model_config: Dict) → Dict

Obtain the hyperparameters of the model.

The hyperparameter dictionary is constructed following these steps:

  • Fill in the recommended hyperparameters;

  • Update the hyperparameters with those specified in model_config.

Parameters:
  • recommend_model_hyper_params – A dictionary of hyperparameters recommended by the benchmark.

  • required_hyper_params – A dictionary of hyperparameters to be filled by the benchmark, in format {model_param_name: std_param_name}. Please refer to import_model_info() for details about this argument.

  • model_config – Model configuration. The supported fields are:

    • model_hyper_params: dictionary, optional. The hyperparameters to use for the corresponding model.

Returns:

The constructed model hyperparameter dictionary.

Raises:

ValueError – If there are unfilled hyperparameters.
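The two construction steps, plus the check behind the ValueError, can be sketched as follows. This assumes that "filling in the recommended hyperparameters" means looking up each std_param_name from required_hyper_params in the recommended dictionary. get_model_hyper_params_sketch is a hypothetical name, not the benchmark's own function.

```python
from typing import Dict


def get_model_hyper_params_sketch(
    recommend_model_hyper_params: Dict,
    required_hyper_params: Dict,
    model_config: Dict,
) -> Dict:
    # Step 1: fill each required parameter from its standard (recommended) name
    params = {
        param_name: recommend_model_hyper_params[std_name]
        for param_name, std_name in required_hyper_params.items()
        if std_name in recommend_model_hyper_params
    }
    # Step 2: values from model_config["model_hyper_params"] take precedence
    params.update(model_config.get("model_hyper_params", {}))
    # Any required parameter that is still unset is an error
    missing = set(required_hyper_params) - set(params)
    if missing:
        raise ValueError(f"Unfilled hyperparameters: {sorted(missing)}")
    return params


print(get_model_hyper_params_sketch(
    {"input_chunk_length": 96},
    {"input_window_size": "input_chunk_length"},
    {"model_hyper_params": {"lr": 0.001}},
))  # {'input_window_size': 96, 'lr': 0.001}
```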

get_model_info(model_config: Dict) → Dict | Callable

Obtain model information based on model configuration.

Parameters:

model_config – A dictionary containing the model configuration. The supported fields are:

  • model_name: str. The path to the model information. The following paths are searched, in order, to find the model information:

    • {model_name[7:]} if model_name.startswith("global.")

    • ts_benchmark.baselines.{model_name}

    • {model_name}

  • adapter: str, optional. The name of the adapter used to wrap the found model information. Must be one of the adapters defined in ts_benchmark.baselines.__init__.

Returns:

The model information corresponding to the config.

Raises:
  • ImportError – If the specified model package cannot be imported.

  • AttributeError – If the specified model_name cannot be found in the imported module.
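The search order above can be sketched with importlib. The helper names are hypothetical, adapter handling is omitted, and the sketch assumes each path has the form package.module.attribute, with the last component looked up on the imported module.

```python
import importlib
from typing import Any, List


def candidate_paths(model_name: str) -> List[str]:
    """Candidate paths in the documented search order."""
    paths = []
    if model_name.startswith("global."):
        paths.append(model_name[len("global."):])
    paths.append(f"ts_benchmark.baselines.{model_name}")
    paths.append(model_name)
    return paths


def import_by_path(path: str) -> Any:
    """Import 'package.module.attr' and return attr (hypothetical helper)."""
    module_name, _, attr_name = path.rpartition(".")
    if not module_name:
        raise ImportError(f"{path!r} has no module component")
    module = importlib.import_module(module_name)  # may raise ImportError
    return getattr(module, attr_name)              # may raise AttributeError


def find_model_info(model_name: str) -> Any:
    """Return the first candidate that imports successfully."""
    last_error: Exception = ImportError(model_name)
    for path in candidate_paths(model_name):
        try:
            return import_by_path(path)
        except (ImportError, AttributeError) as exc:
            last_error = exc  # remember the failure and try the next candidate
    raise last_error


print(find_model_info("global.collections.OrderedDict").__name__)  # OrderedDict
```

Re-raising the last failure means a missing module surfaces as ImportError and a missing name inside an importable module surfaces as AttributeError, matching the two documented exceptions.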

get_models(all_model_config: Dict) → List[ModelFactory]

Obtain a list of ModelFactory objects based on model configuration.

Parameters:

all_model_config – A dictionary containing the configuration of all models. The supported fields are:

  • models: list. A list of model information, where each item is a dictionary. The supported fields in each dictionary are:

    • model_name: str. The path to the model information. Please refer to get_model_info() for details about the model search strategy.

    • adapter: str, optional. The name of the adapter used to wrap the found model information. Must be one of the adapters defined in ts_benchmark.baselines.__init__.

  • recommend_model_hyper_params: dictionary, optional. A dictionary of globally recommended hyperparameters that the benchmark supplies to all models.

Returns:

List of model factories used to instantiate different models.
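Putting the pieces together, a simplified get_models-style loop might look like the sketch below. To keep it self-contained, model lookup is injected through a resolve callable instead of being performed by import, only the dictionary form of model information is handled, and the per-model model_hyper_params field is taken from the get_model_hyper_params() description. The tuple output, the fallback name (last dotted component of model_name), and NaiveModel are assumptions.

```python
from typing import Any, Callable, Dict, List, Tuple


def build_factories(
    all_model_config: Dict, resolve: Callable[[str], Dict]
) -> List[Tuple[str, Callable[..., Any], Dict]]:
    """Simplified get_models-style loop returning (name, factory, hyperparams) tuples."""
    recommended = all_model_config.get("recommend_model_hyper_params", {})
    factories = []
    for model_config in all_model_config["models"]:
        info = resolve(model_config["model_name"])  # dictionary-form model information
        required = info.get("required_hyper_params", {})
        params = {k: recommended[std] for k, std in required.items() if std in recommended}
        params.update(model_config.get("model_hyper_params", {}))
        if set(required) - set(params):
            raise ValueError(f"Unfilled hyperparameters for {model_config['model_name']}")
        # Assumed fallback: use the last dotted component as the display name
        name = info.get("model_name", model_config["model_name"].split(".")[-1])
        factories.append((name, info["model_factory"], params))
    return factories


class NaiveModel:
    """Hypothetical model for the example."""

    def __init__(self, input_window_size: int, lr: float = 0.1):
        self.input_window_size, self.lr = input_window_size, lr


registry = {  # stands in for import-based lookup
    "my_pkg.NaiveModel": {
        "model_factory": NaiveModel,
        "required_hyper_params": {"input_window_size": "input_chunk_length"},
    }
}
config = {
    "models": [{"model_name": "my_pkg.NaiveModel", "model_hyper_params": {"lr": 0.01}}],
    "recommend_model_hyper_params": {"input_chunk_length": 96},
}
for name, factory, params in build_factories(config, registry.__getitem__):
    model = factory(**params)
    print(name, model.input_window_size, model.lr)  # NaiveModel 96 0.01
```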

import_model_info(model_path: str) → Dict | Callable

Import model information.

We first clarify some concepts before defining model information:

  • required hyperparameters: A mechanism that lets a model delegate the setting of some hyperparameters to the benchmark. For example, if a model cannot automatically decide the best input window size (hyperparameter input_window_size), it can leave the decision to the benchmark, which can then apply a globally recommended setting (standard hyperparameter input_chunk_length) to ensure a fair comparison between models. To enable this mechanism in this example, the model must provide a required_hyper_params field with the value {"input_window_size": "input_chunk_length"}.

Model information should be either:

  • A dictionary containing these fields:

    • model_factory: Callable. A callable that accepts hyperparameters as keyword arguments.

    • model_hyper_params: dictionary, optional. Hyperparameters for the model. These override the recommended hyperparameters.

    • required_hyper_params: dictionary, optional. A dictionary of hyperparameters to be filled by the benchmark, in format {model_param_name: std_param_name}.

    • model_name: str, optional. The name of the model recorded in the output logs.

  • A callable that, when called with hyperparameters as keyword arguments, returns an instance compatible with the ModelBase interface. This callable may optionally support the following features:

    • attribute required_hyper_params: dictionary, optional. A dictionary of hyperparameters to be filled by the benchmark, in format {model_param_name: std_param_name}.
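The two accepted forms of model information might look like the following. NaiveModel is a hypothetical skeleton (its ModelBase methods are omitted for brevity), and describe() is an illustrative helper showing how a loader could read either form; neither is part of the benchmark.

```python
from typing import Any, Callable, Dict, Optional, Tuple


class NaiveModel:
    """Hypothetical skeleton; the ModelBase methods are omitted for brevity."""

    # Callable form: ask the benchmark to fill input_window_size from input_chunk_length
    required_hyper_params = {"input_window_size": "input_chunk_length"}

    def __init__(self, input_window_size: int, **kwargs: Any):
        self.input_window_size = input_window_size


# Dictionary form describing the same model
NAIVE_MODEL_INFO = {
    "model_factory": NaiveModel,
    "model_hyper_params": {"input_window_size": 48},  # overrides the recommended value
    "required_hyper_params": {"input_window_size": "input_chunk_length"},
    "model_name": "Naive",
}


def describe(info: Any) -> Tuple[Callable[..., Any], Dict, Optional[str]]:
    """Illustrative reader returning (factory, required_hyper_params, name) for either form."""
    if isinstance(info, dict):
        return info["model_factory"], info.get("required_hyper_params", {}), info.get("model_name")
    return info, getattr(info, "required_hyper_params", {}), getattr(info, "__name__", None)


print(describe(NAIVE_MODEL_INFO)[2], describe(NaiveModel)[2])  # Naive NaiveModel
```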

Parameters:

model_path – The fully qualified path to the model information.

Returns:

The imported model information.