Source code for ts_benchmark.pipeline

# -*- coding: utf-8 -*-
from dataclasses import dataclass
from functools import reduce
from operator import and_
from typing import List, Dict, Type, Optional

import pandas as pd

from ts_benchmark.data.data_source import (
    LocalForecastingDataSource,
    DataSource,
)
from ts_benchmark.data.suites.global_storage import GlobalStorageDataServer
from ts_benchmark.evaluation.evaluate_model import eval_model
from ts_benchmark.models import get_models
from ts_benchmark.recording import save_log
from ts_benchmark.utils.parallel import ParallelBackend


@dataclass
class DatasetInfo:
    """
    Static description of one predefined dataset collection.

    :param size_value: The possible values of the meta-info field 'size'
        for this dataset.
    :param datasrc_class: The class of data source for this dataset.
    """

    size_value: List
    datasrc_class: Type[DataSource]
# Registry of the dataset collections known to the pipeline, keyed by the name
# used in the data configuration. Every predefined collection is currently
# served by LocalForecastingDataSource; they differ only in the accepted
# values of the 'size' meta-info field.
PREDEFINED_DATASETS = {
    dataset_name: DatasetInfo(
        size_value=allowed_sizes,
        datasrc_class=LocalForecastingDataSource,
    )
    for dataset_name, allowed_sizes in (
        ("large_forecast", ["large", "small"]),
        ("small_forecast", ["small"]),
        ("user_forecast", ["user"]),
    )
}
def filter_data(
    metadata: pd.DataFrame, size_value: List[str], feature_dict: Optional[Dict] = None
) -> List[str]:
    """
    Filters the dataset based on given filters

    :param metadata: The meta information DataFrame. It must contain the columns
        'size' and 'file_name', plus one column per key in ``feature_dict``.
    :param size_value: The allowed values of the 'size' meta-info field.
    :param feature_dict: A dictionary of filters where each key is a meta-info field
        and the corresponding value is the field value to keep. Entries whose value
        is None are ignored. If None (or an empty dictionary) is given, no extra
        filter is applied.
    :return: A list of file names that meet the filter criteria.
    """
    # Remove items with a value of None in feature_dict; a missing feature_dict
    # is treated the same as an empty one.
    active_filters = {
        k: v for k, v in (feature_dict or {}).items() if v is not None
    }

    filt_metadata = metadata
    # reduce() without an initial value fails on an empty iterable, so the
    # feature mask is only built when at least one filter remains.
    if active_filters:
        # Combine the per-field equality masks with a logical AND
        feature_mask = reduce(
            and_, (metadata[k] == v for k, v in active_filters.items())
        )
        filt_metadata = metadata[feature_mask]

    filt_metadata = filt_metadata[filt_metadata["size"].isin(size_value)]
    return filt_metadata["file_name"].tolist()
def _get_model_names(model_names: List[str]): """ Rename models if there exists duplications. If a model A appears multiple times in the list, each appearance will be renamed to `A`, `A_1`, `A_2`, ... :param model_names: A list of model names. :return: The renamed list of model names. """ s = pd.Series(model_names) cumulative_counts = s.groupby(s).cumcount() return [ f"{model_name}_{cnt}" if cnt > 0 else model_name for model_name, cnt in zip(model_names, cumulative_counts) ]
def pipeline(
    data_config: dict,
    model_config: dict,
    evaluation_config: dict,
    save_path: str,
) -> List[str]:
    """
    Execute the benchmark pipeline process

    The pipeline includes loading data, building models, evaluating models, and
    generating reports.

    :param data_config: Configuration for data loading. The keys read here are
        'data_set_name' (a name or list of names from PREDEFINED_DATASETS, defaulting
        to 'small_forecast'), 'data_name_list' (an explicit list of series to load)
        and 'feature_dict' (meta-info filters used when 'data_name_list' is empty).
    :param model_config: Configuration for model construction.
    :param evaluation_config: Configuration for model evaluation.
    :param save_path: The relative path for saving evaluation results, relative to
        the result folder.
    :return: A list of log file names where evaluation results are saved.
    :raises ValueError: If a dataset name is unknown, if the selected datasets use
        different data source classes, or if no series is selected.
    """
    # prepare data
    # TODO: move these code into the data module, after the pipeline interface is unified
    dataset_name_list = data_config.get("data_set_name", ["small_forecast"])
    if not dataset_name_list:
        dataset_name_list = ["small_forecast"]
    if isinstance(dataset_name_list, str):
        dataset_name_list = [dataset_name_list]
    for dataset_name in dataset_name_list:
        if dataset_name not in PREDEFINED_DATASETS:
            raise ValueError(f"Unknown dataset {dataset_name}.")

    # All selected datasets must share one data source class, because a single
    # data source instance serves every series below.
    data_src_type = PREDEFINED_DATASETS[dataset_name_list[0]].datasrc_class
    if not all(
        PREDEFINED_DATASETS[dataset_name].datasrc_class is data_src_type
        for dataset_name in dataset_name_list
    ):
        raise ValueError("Not supporting different types of data sources.")

    data_src: DataSource = data_src_type()
    data_name_list = data_config.get("data_name_list", None)
    if not data_name_list:
        # No explicit series list: select series through the meta-info filters
        feature_dict = data_config.get("feature_dict", None)
        data_name_list = []
        for dataset_name in dataset_name_list:
            size_value = PREDEFINED_DATASETS[dataset_name].size_value
            data_name_list.extend(
                filter_data(
                    data_src.dataset.metadata, size_value, feature_dict=feature_dict
                )
            )
    # Drop duplicates while keeping first-seen order, so the series order is
    # the same on every run (a set would reorder names between interpreter runs).
    data_name_list = list(dict.fromkeys(data_name_list))
    if not data_name_list:
        raise ValueError("No dataset specified.")
    data_src.load_series_list(data_name_list)
    data_server = GlobalStorageDataServer(data_src, ParallelBackend())
    data_server.start_async()

    # modeling
    model_factory_list = get_models(model_config)
    # Submit the evaluation of every model before collecting any result
    result_list = [
        eval_model(model_factory, data_name_list, evaluation_config)
        for model_factory in model_factory_list
    ]
    # Strip the module path first and de-duplicate afterwards, so that models
    # sharing a short name (e.g. `a.X` and `b.X`) still get distinct save names.
    model_save_names = _get_model_names(
        [
            model_factory.model_name.split(".")[-1]
            for model_factory in model_factory_list
        ]
    )

    log_file_names = []
    for result_itr, model_save_name in zip(result_list, model_save_names):
        for i, result_df in enumerate(result_itr.collect()):
            log_file_names.append(
                save_log(
                    result_df,
                    save_path,
                    # The first result keeps the plain name; later ones get an index
                    model_save_name if i == 0 else f"{model_save_name}-{i}",
                )
            )
    return log_file_names