ts_benchmark.models package
ts_benchmark.models.model_base module
Functions
- Decorate a function to add or update its annotations.
Classes
- The standard interface of batch maker.
- The standard interface of benchmark-compatible models.
- class ModelBase
Bases: object
The standard interface of benchmark-compatible models.
Users are encouraged to inherit from this class to implement their own models or to adapt existing ones.
- batch_forecast(horizon: int, batch_maker: BatchMaker, **kwargs) → numpy.ndarray
Perform batch forecasting with the model.
- Parameters:
horizon – The length of each prediction.
batch_maker – The batch maker that produces the batch data used for prediction.
- Returns:
The prediction result.
- abstract forecast(horizon: int, series: pandas.DataFrame, **kwargs) → numpy.ndarray
Forecast with the model.
TODO: support returning DataFrames
- Parameters:
horizon – Forecast length.
series – Time series data to make inferences on.
- Returns:
Forecast result.
- abstract forecast_fit(train_data: pandas.DataFrame, *, train_ratio_in_tv: float = 1.0, **kwargs) → ModelBase
Fit a model on time series data.
- Parameters:
train_data – Time series data.
train_ratio_in_tv – The proportion of the training set within the combined training and validation data. If it equals 1, no validation set is split off.
- Returns:
The fitted model object.
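The documentation does not specify how a model should apply train_ratio_in_tv. The sketch below assumes a chronological split, where the first fraction of rows is used for training and the remaining rows for validation. split_train_valid is a hypothetical helper and is not part of ts_benchmark.

```python
import pandas as pd


def split_train_valid(train_data: pd.DataFrame, train_ratio_in_tv: float = 1.0):
    """Hypothetical helper: split data by train_ratio_in_tv, keeping time order.

    Returns (train, valid). valid is None when train_ratio_in_tv == 1,
    meaning no validation set is partitioned.
    """
    if not 0.0 < train_ratio_in_tv <= 1.0:
        raise ValueError("train_ratio_in_tv must be in (0, 1]")
    if train_ratio_in_tv == 1.0:
        return train_data, None
    # Assumption: the head of the series trains, the tail validates.
    n_train = int(len(train_data) * train_ratio_in_tv)
    return train_data.iloc[:n_train], train_data.iloc[n_train:]
```

A chronological split (rather than a random one) avoids training on observations that come after the validation period.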
- abstract property model_name
Returns the name of the model.
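To illustrate the abstract members above, here is a minimal naive model that repeats the last observed row. This is a sketch under an assumption: to keep the block self-contained, it defines a stand-in ModelBase that mirrors only the documented abstract members. With the package installed, you would inherit from ModelBase in ts_benchmark.models.model_base instead. LastValueModel is a hypothetical name.

```python
import abc

import numpy as np
import pandas as pd


class ModelBase(metaclass=abc.ABCMeta):
    """Stand-in mirroring the documented abstract members (not the real class)."""

    @abc.abstractmethod
    def forecast_fit(
        self, train_data: pd.DataFrame, *, train_ratio_in_tv: float = 1.0, **kwargs
    ) -> "ModelBase":
        ...

    @abc.abstractmethod
    def forecast(self, horizon: int, series: pd.DataFrame, **kwargs) -> np.ndarray:
        ...

    @property
    @abc.abstractmethod
    def model_name(self):
        ...


class LastValueModel(ModelBase):
    """Hypothetical baseline: repeat the last observed row for every step."""

    def forecast_fit(self, train_data, *, train_ratio_in_tv=1.0, **kwargs):
        # Nothing to learn; return self as the "fitted" model object.
        return self

    def forecast(self, horizon, series, **kwargs):
        last_row = series.to_numpy()[-1]
        # Result shape: (horizon, number of columns in series).
        return np.tile(last_row, (horizon, 1))

    @property
    def model_name(self):
        return "LastValueModel"
```

Because the base class is abstract, instantiating it directly raises TypeError. Only subclasses that implement all three members can be created.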
ts_benchmark.models.model_loader module
Functions
- Obtain the hyperparameters of the model.
- Obtain model information based on model configuration.
- Obtain a list of ModelFactory objects based on model configuration.
- Import model information.
Classes
- Model factory, the standard type to instantiate models in the pipeline.
- class ModelFactory(model_name: str, model_factory: Callable, model_hyper_params: dict)
Bases: object
Model factory, the standard type to instantiate models in the pipeline.
- get_model_hyper_params(recommend_model_hyper_params: Dict, required_hyper_params: Dict, model_config: Dict) → Dict
Obtain the hyperparameters of the model.
The hyperparameter dictionary is constructed in the following steps:
1. Fill in the recommended hyperparameters;
2. Update the hyperparameters with those specified in model_config.
- Parameters:
recommend_model_hyper_params – A dictionary of hyperparameters recommended by the benchmark.
required_hyper_params – A dictionary of hyperparameters to be filled by the benchmark, in the format {model_param_name: std_param_name}. Please refer to import_model_info() for details about this argument.
model_config –
Model configuration. The supported fields are:
model_hyper_params: dictionary, optional. This dictionary specifies the hyperparameters used by the corresponding model.
- Returns:
The constructed model hyperparameter dictionary.
- Raises:
ValueError – If there are unfilled hyperparameters.
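Here is a minimal sketch of the two construction steps and the missing-value check. It assumes that step 1 copies recommended values only for the entries in required_hyper_params, translating each std_param_name to its model_param_name. get_model_hyper_params_sketch is a hypothetical name, not the library function.

```python
from typing import Dict


def get_model_hyper_params_sketch(
    recommend_model_hyper_params: Dict,
    required_hyper_params: Dict,
    model_config: Dict,
) -> Dict:
    # Step 1 (assumed): fill required hyperparameters from the recommended
    # ones, mapping std_param_name -> model_param_name.
    hyper_params = {
        param_name: recommend_model_hyper_params[std_name]
        for param_name, std_name in required_hyper_params.items()
        if std_name in recommend_model_hyper_params
    }
    # Step 2: values from model_config["model_hyper_params"] take precedence.
    hyper_params.update(model_config.get("model_hyper_params", {}))
    # A required hyperparameter that is still unset is an error.
    missing = set(required_hyper_params) - set(hyper_params)
    if missing:
        raise ValueError(f"Unfilled hyperparameters: {sorted(missing)}")
    return hyper_params
```

For example, with required_hyper_params = {"input_window_size": "input_chunk_length"} and a recommended input_chunk_length of 96, the model receives input_window_size=96 unless model_config overrides it.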
- get_model_info(model_config: Dict) → Dict | Callable
Obtain model information based on model configuration.
- Parameters:
model_config –
A dictionary that contains model configuration information. The supported fields are:
model_name: str. The path to the model information. The following paths are searched, in order, to find the model information:
{model_name[7:]}, if model_name.startswith("global.")
ts_benchmark.baselines.{model_name}
{model_name}
adapter: str, optional. The name of the adapter used to wrap the found model information. Must be one of the adapters defined in ts_benchmark.baselines.__init__.
- Returns:
The model information corresponding to the config.
- Raises:
ImportError – If the specified model package cannot be imported.
AttributeError – If the specified model_name cannot be found in the imported module.
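Here is a sketch of the search order described above. It reads the list literally: the "global." condition applies only to the first candidate, and the remaining two candidates are always tried. It also assumes that each candidate is a "module.attribute" path resolved with importlib. candidate_paths and resolve_model_info are hypothetical helpers and may differ from the library's actual lookup.

```python
import importlib
from typing import Any, List


def candidate_paths(model_name: str) -> List[str]:
    """List the fully qualified paths searched for model_name, in order."""
    paths = []
    if model_name.startswith("global."):
        # Strip the 7-character "global." prefix, i.e. model_name[7:].
        paths.append(model_name[len("global."):])
    paths.append(f"ts_benchmark.baselines.{model_name}")
    paths.append(model_name)
    return paths


def resolve_model_info(model_name: str) -> Any:
    """Return the attribute at the first candidate path that resolves."""
    last_error: Exception = ImportError(f"cannot resolve {model_name!r}")
    for path in candidate_paths(model_name):
        # "pkg.module.Attr" -> module "pkg.module", attribute "Attr".
        module_name, _, attr_name = path.rpartition(".")
        if not module_name:
            continue
        try:
            module = importlib.import_module(module_name)
            return getattr(module, attr_name)
        except (ImportError, AttributeError) as exc:
            last_error = exc
    # Mirrors the documented errors: ImportError or AttributeError.
    raise last_error
```

For example, "global.math.sqrt" resolves to math.sqrt on the first candidate. "math.sqrt" falls through the ts_benchmark.baselines candidate before resolving to the same function.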
- get_models(all_model_config: Dict) → List[ModelFactory]
Obtain a list of ModelFactory objects based on model configuration.
- Parameters:
all_model_config –
A dictionary that contains all model configuration information, supported fields are:
models: list. A list of model information, where each item is a dictionary. The supported fields in each dictionary are:
model_name: str. The path to the model information. Please refer to get_model_info() for details about the model search strategy;
adapter: str, optional. The name of the adapter used to wrap the found model information. Must be one of the adapters defined in ts_benchmark.baselines.__init__;
recommend_model_hyper_params: dictionary, optional. A dictionary of globally recommended hyperparameters that the benchmark supplies to all models.
- Returns:
List of model factories used to instantiate different models.
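Here is a sketch of an all_model_config value that uses only the documented fields. The model paths are hypothetical placeholders. The adapter field is omitted because the valid adapter names live in ts_benchmark.baselines.__init__ and are not listed here. check_all_model_config is a hypothetical validator, not a library function.

```python
ALL_MODEL_CONFIG = {
    "models": [
        # Hypothetical paths; "global." makes the search start at the path itself.
        {"model_name": "global.my_package.my_module.MyModel"},
        {"model_name": "my_package.my_module.OtherModel"},
    ],
    # Globally recommended hyperparameters, shared by all models.
    "recommend_model_hyper_params": {"input_chunk_length": 96},
}


def check_all_model_config(cfg: dict) -> int:
    """Validate the documented structure; return the number of models."""
    models = cfg.get("models")
    if not isinstance(models, list):
        raise ValueError("'models' must be a list")
    for i, item in enumerate(models):
        if not isinstance(item, dict) or not isinstance(item.get("model_name"), str):
            raise ValueError(f"models[{i}] needs a string 'model_name'")
    recommended = cfg.get("recommend_model_hyper_params", {})
    if not isinstance(recommended, dict):
        raise ValueError("'recommend_model_hyper_params' must be a dictionary")
    return len(models)
```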
- import_model_info(model_path: str) → Dict | Callable
Import model information.
We first clarify some concepts before defining model information:
required hyperparameters: This is a mechanism that lets a model delegate the setting of some hyperparameters to the benchmark. For example, if a model cannot automatically choose the best input window size (hyperparameter input_window_size), it can leave that decision to the benchmark. The benchmark can then apply a globally recommended setting (hyperparameter input_chunk_length), which makes the comparison between different models fair. To enable this mechanism in this example, the model must provide a required_hyper_params field set to the dictionary {"input_window_size": "input_chunk_length"}.
Model information should be either:
A dictionary containing these fields:
model_factory: Callable. A callable that accepts hyperparameters as kwargs;
model_hyper_params: Dictionary, optional; A dictionary containing hyperparameters for the model. These hyperparameters overwrite the ones specified by recommended hyperparameters;
required_hyper_params: Dictionary, optional; A dictionary of hyperparameters to be filled by the benchmark, in format {model_param_name: std_param_name}.
model_name: str, optional; The name of the model that is recorded in the output logs.
A callable that returns an instance compatible with the ModelBase interface when called with hyperparameters as keyword arguments. This callable may optionally support the following features:
attribute required_hyper_params: Dictionary, optional; A dictionary of hyperparameters to be filled by the benchmark, in format {model_param_name: std_param_name}.
- Parameters:
model_path – The fully qualified path to the model information.
- Returns:
The imported model information.
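Here is a sketch of the two accepted forms of model information for the same model, using the input_window_size example above. WindowMeanModel, window_mean_factory and read_required_hyper_params are hypothetical names. The model is duck-typed against the documented ModelBase members rather than inheriting from the real class, and the benchmark's actual handling of these forms may differ.

```python
import numpy as np
import pandas as pd


class WindowMeanModel:
    """Hypothetical model: forecast the mean of the last input_window_size rows."""

    def __init__(self, input_window_size: int = 24):
        self.input_window_size = input_window_size

    def forecast_fit(self, train_data, *, train_ratio_in_tv=1.0, **kwargs):
        return self  # nothing to train

    def forecast(self, horizon, series, **kwargs):
        window = series.to_numpy()[-self.input_window_size:]
        # Repeat the per-column window mean for every future step.
        return np.tile(window.mean(axis=0), (horizon, 1))

    @property
    def model_name(self):
        return "WindowMeanModel"


REQUIRED = {"input_window_size": "input_chunk_length"}

# Form 1: a dictionary of model information.
MODEL_INFO_DICT = {
    "model_factory": WindowMeanModel,
    "model_hyper_params": {},
    "required_hyper_params": REQUIRED,
    "model_name": "WindowMeanModel",
}


# Form 2: a callable that carries a required_hyper_params attribute.
def window_mean_factory(**hyper_params):
    return WindowMeanModel(**hyper_params)


window_mean_factory.required_hyper_params = REQUIRED


def read_required_hyper_params(model_info):
    """Read required_hyper_params from either form; default to empty."""
    if isinstance(model_info, dict):
        return model_info.get("required_hyper_params", {})
    return getattr(model_info, "required_hyper_params", {})
```

Both forms describe the same requirement, so the benchmark can fill input_window_size from its recommended input_chunk_length whichever form a model provides.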