Source code for ts_benchmark.evaluation.strategy.strategy

# -*- coding: utf-8 -*-
import abc
import base64
import inspect
import json
import logging
import pickle
from functools import cached_property, lru_cache
from typing import Any, NoReturn, List, Dict, Optional

import numpy as np
import pandas as pd

from ts_benchmark.evaluation.evaluator import Evaluator
from ts_benchmark.models import ModelFactory


class ResultCollector:
    """
    Result collector

    Result collectors help to gather outputs returned by strategy classes.
    It is helpful to define a custom result collector when the strategy class
    has irregular returns (e.g. returning multiple records in one evaluation).
    """

    def __init__(self):
        # Results are kept in insertion order.
        self.results = []

    def add(self, result: Any) -> None:
        """
        Adds the output of a strategy to the collection

        :param result: The return value of a strategy's :meth:`execute` method.
        """
        self.results.append(result)

    def collect(self) -> List:
        """
        Returns the current result collection

        :return: The internal list of collected results (not a copy), in the
            order they were added.
        """
        return self.results

    def reset(self) -> None:
        """
        Resets the current result collection

        A new empty list is bound instead of clearing the old one, so a list
        previously obtained from :meth:`collect` keeps its contents.
        """
        self.results = []

    def get_size(self) -> int:
        """
        Gets the number of collected results

        :return: The number of results currently held by the collector.
        """
        return len(self.results)
class Strategy(metaclass=abc.ABCMeta):
    """
    The base class of strategies

    A strategy defines the evaluation pipeline of the specific time-series analysis task.

    .. warning::
        Strategies are currently using pickle to store Python objects in the evaluation
        results, which is known to be unsafe during decoding. Although reading the
        evaluation log itself is safe, please DO NOT decode any pickled columns in the
        log file if the data source is untrusted.
    """

    # The required fields by the current class in the `strategy_config`, subclasses should overwrite
    # this attribute when there are new required fields, and the required fields in the super
    # classes need not be included
    REQUIRED_CONFIGS = ["strategy_name"]

    # Most strategy configs allow inputting a mapping from data names to config values, this is
    # the key in such mapping that sets the default config value for unspecified data names
    DEFAULT_CONFIG_KEY = "__default__"

    def __init__(self, strategy_config: Dict, evaluator: Evaluator):
        """
        Initialize

        :param strategy_config: The configuration dict of a strategy.
            All scalar-valued configs accept inputting a mapping from data names to
            config values, which enables us to use different configs for different data.
            A "__default__" key in such mappings specifies the default config value for
            unspecified data names.
        :param evaluator: An evaluation object that calculates metrics.
        :raises RuntimeError: If any required config is missing from `strategy_config`.
        """
        self.strategy_config = strategy_config
        self.evaluator = evaluator
        self._check_config()

    @abc.abstractmethod
    def execute(self, series_name: str, model_factory: ModelFactory) -> Any:
        """
        The primary interface to execute a strategy

        :param series_name: The name of a series data to evaluate.
        :param model_factory: A model factory that creates a new model with each invocation.
        :return: The results generated by evaluating a model on a series.
        """

    def get_config_str(self, required_configs_only: bool = False) -> str:
        """
        Gets the string representation of the strategy config

        :param required_configs_only: If True, includes only the keys specified by
            `REQUIRED_CONFIGS` in the string, otherwise, encode the strategy config as is.
        :return: A JSON string representation of the strategy config.
        """
        if required_configs_only:
            required_configs = self.get_required_configs()
            # NOTE: keys are emitted in the insertion order of `strategy_config` here,
            # whereas the full-config branch below sorts them. The output is kept as is
            # because callers may rely on the exact string.
            return json.dumps(
                {
                    k: v
                    for k, v in self.strategy_config.items()
                    if k in required_configs
                }
            )
        return json.dumps(self.strategy_config, sort_keys=True)

    def _check_config(self) -> None:
        """
        Checks if there are missing configs or unexpected config

        Unknown options only trigger a warning, missing options are fatal.

        :raises RuntimeError: If any required config is missing.
        """
        provided_args = set(self.strategy_config)
        required_args = set(self.get_required_configs())
        missing_args = required_args - provided_args
        extra_args = provided_args - required_args
        if missing_args:
            error_message = f"Missing options: {', '.join(sorted(missing_args))} "
            raise RuntimeError(error_message)
        if extra_args:
            error_message = f"Unknown options: {', '.join(sorted(extra_args))} "
            logging.warning(error_message)

    def get_collector(self) -> ResultCollector:
        """
        Creates a new compatible result collector

        :return: A fresh :class:`ResultCollector` instance.
        """
        return ResultCollector()

    @classmethod
    @lru_cache(maxsize=None)
    def get_required_configs(cls) -> List[str]:
        """
        Gets the required configs from the current class and all super classes

        The result is cached per class; the cache is unbounded because the key space is
        the (finite) set of strategy classes, so different subclasses do not evict each
        other. The returned list is shared between calls and must not be mutated.

        :return: A sorted list of unique required config names.
        """
        ret = []
        for super_cls in inspect.getmro(cls):
            if hasattr(super_cls, "REQUIRED_CONFIGS"):
                ret.extend(super_cls.REQUIRED_CONFIGS)
        return sorted(set(ret))

    @staticmethod
    @abc.abstractmethod
    def accepted_metrics() -> List[str]:
        """
        Gets the accepted metrics by this strategy
        """

    @property
    @abc.abstractmethod
    def field_names(self) -> List[str]:
        """
        Gets the field names of the result records
        """

    @cached_property
    def _field_name_to_idx(self) -> Dict:
        """
        A helper method that returns a mapping from result field names to index
        """
        return {k: i for i, k in enumerate(self.field_names)}

    def get_default_result(self, **kwargs) -> List:
        """
        Gets the default result when the strategy fails to execute

        The evaluator's default result is padded with NaN up to the number of result
        fields, then the fields named in `kwargs` are overwritten.

        :param kwargs: Each key-value pair updates the `key` field with value `value`
            of the return value.
        :return: A list holding one value per entry of `field_names`.
        :raises ValueError: If a key in `kwargs` is not a known field name.
        """
        # Copy first so that padding/overwriting never mutates a list that the
        # evaluator might share between calls.
        ret = list(self.evaluator.default_result())
        ret += [np.nan] * (len(self.field_names) - len(ret))
        for k, v in kwargs.items():
            if k not in self._field_name_to_idx:
                raise ValueError(f"Unknown field name {k}")
            ret[self._field_name_to_idx[k]] = v
        return ret

    def _encode_data(self, data: Any) -> str:
        """
        Encodes Python objects in the results to a string with base64 coding

        So that the objects are properly stored in text files such as csv files.
        The object is pickled first; never unpickle such strings from untrusted sources.

        :param data: Any python object to be saved.
        :return: A string that encodes the data.
        """
        encoded = pickle.dumps(data)
        encoded = base64.b64encode(encoded).decode("utf-8")
        return encoded

    def _get_scalar_config_value(
        self, config_name: str, series_name: Optional[str]
    ) -> Any:
        """
        A helper method that retrieves a scalar config value for target series

        This method handles special input values such as a mapping from data names to
        config values. Subclasses are recommended to get config values using this method
        as long as the config is scalar-valued.

        :param config_name: The name of the config to retrieve.
        :param series_name: The name of the series. If None, the default config values
            is returned.
        :return: A scalar config value for the specified series name.
        :raises ValueError: If the config is missing, or if it is a mapping that has
            neither an entry for `series_name` nor a default entry.
        """
        if config_name not in self.strategy_config:
            raise ValueError(f"Missing config {config_name}.")
        config_value = self.strategy_config[config_name]
        if not isinstance(config_value, dict):
            return config_value

        # Look the keys up lazily: a mapping with an entry for this series is valid even
        # without a default entry, so the default must not be evaluated up front.
        if series_name in config_value:
            return config_value[series_name]
        if self.DEFAULT_CONFIG_KEY in config_value:
            return config_value[self.DEFAULT_CONFIG_KEY]
        raise ValueError(
            f"Config {config_name} for series {series_name} is missing, "
            f"please add {config_name} or a {self.DEFAULT_CONFIG_KEY} key to "
            "the configuration dict"
        )

    def _get_meta_info(
        self, meta_info: Optional[pd.Series], field: str, default: Any
    ) -> Any:
        """
        A helper method to get fields from the meta information

        This method returns the default value when the meta-info is missing, and it
        raises an exception when the meta-info exists but the specified key is missing.

        :param meta_info: Meta-information returned by the data pool.
        :param field: The field to get.
        :param default: The default value to return if the meta-information is not
            available.
        :return: `meta_info[field].item()` when `meta_info` is not None, otherwise
            `default`.
        """
        # NOTE(review): `.item()` requires the looked-up value to provide that method
        # (e.g. a numpy scalar or a single-element pandas object) -- confirm against
        # what the data pool actually returns.
        return meta_info[field].item() if meta_info is not None else default