# -*- coding: utf-8 -*-
import logging
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Dict, NoReturn, List
import pandas as pd
from ts_benchmark.common.constant import FORECASTING_DATASET_PATH
from ts_benchmark.data.dataset import Dataset
from ts_benchmark.data.utils import load_series_info, read_data
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class DataSource:
    """
    A class that manages and reads from data sources

    A data source is responsible for loading data into the internal dataset object,
    as well as detecting and updating data in the source storage.

    Subclasses are expected to override :meth:`load_series_list` to support
    loading series on demand; the base implementation always raises.
    """

    # The class for the internal dataset object
    DATASET_CLASS = Dataset

    def __init__(
        self,
        data_dict: Optional[Dict[str, pd.DataFrame]] = None,
        metadata: Optional[pd.DataFrame] = None,
    ):
        """
        initializer

        Creates an instance of ``DATASET_CLASS`` and hands ``data_dict`` and
        ``metadata`` to its ``set_data`` method unchanged.

        :param data_dict: A dictionary of time series, where the keys are the names and
            the values are DataFrames following the OTB protocol.
        :param metadata: A DataFrame where the index contains series names and columns
            contains meta-info fields.
        """
        self._dataset = self.DATASET_CLASS()
        self._dataset.set_data(data_dict, metadata)

    @property
    def dataset(self) -> Dataset:
        """
        Returns the internally maintained dataset object

        This dataset is where the DataSource loads data into.
        """
        return self._dataset

    def load_series_list(self, series_list: List[str]) -> None:
        """
        Loads a list of time series from the source

        The series data and (optionally) meta information are loaded into the internal dataset.

        The return annotation is ``None`` rather than ``NoReturn``: this base
        implementation always raises, but overriding implementations are
        expected to return normally.

        :param series_list: The list of series names.
        :raises NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support loading series at runtime."
        )
class LocalDataSource(DataSource):
    """
    The data source that manages data files in a local directory
    """

    #: index column name of the metadata
    _INDEX_COL = "file_name"

    def __init__(self, local_data_path: str, metadata_file_name: str):
        """
        initializer

        Only the metadata is loaded during initialization, while all series data are
        loaded on demand.

        :param local_data_path: the directory that contains csv data files and metadata.
        :param metadata_file_name: name of the metadata file.
        """
        self.local_data_path = local_data_path
        self.metadata_path = os.path.join(local_data_path, metadata_file_name)
        metadata = self.update_meta_index()
        # Start with an empty series dictionary; series are added later by
        # load_series_list.
        super().__init__({}, metadata)

    def update_meta_index(self) -> pd.DataFrame:
        """
        Returns the metadata used to initialize the internal dataset

        NOTE(review): ``__init__`` calls this method, but no definition of it was
        present in this class or in ``DataSource``, so constructing the data
        source would fail with ``AttributeError``. This minimal implementation
        only reads the metadata file; if the project intends this hook to do
        more (e.g. register data files that are not yet listed in the
        metadata), that logic still needs to be added here - TODO confirm.

        :return: The metadata indexed by the ``file_name`` column.
        """
        return self._load_metadata()

    def load_series_list(self, series_list: List[str]) -> None:
        """
        Loads a list of time series from the local directory in parallel

        Each series is read by :meth:`_load_series` in a thread pool and the
        results are merged into the internal dataset through ``update_data``.

        :param series_list: The list of series names. Each name is joined to
            ``local_data_path`` to locate the data file.
        """
        logger.info("Start loading %s series in parallel", len(series_list))
        # Repeated names would map to the same dictionary key anyway, so read
        # every file only once while keeping first-occurrence order.
        unique_names = list(dict.fromkeys(series_list))
        with ThreadPoolExecutor() as executor:
            # executor.map yields results in submission order and re-raises the
            # first worker exception when its result is consumed.
            data_dict = dict(
                zip(unique_names, executor.map(self._load_series, unique_names))
            )
        logger.info("Data loading finished.")
        self.dataset.update_data(data_dict)

    def _load_metadata(self) -> pd.DataFrame:
        """
        Loads metadata from a local csv file

        :return: The metadata read from ``metadata_path``, indexed by the
            ``file_name`` column, which is also kept as a regular column.
        """
        metadata = pd.read_csv(self.metadata_path)
        # drop=False keeps the index column available as an ordinary column.
        return metadata.set_index(self._INDEX_COL, drop=False)

    def _load_series(self, series_name: str) -> pd.DataFrame:
        """
        Loads a time series from a single data file

        :param series_name: Series name, used as the file name inside
            ``local_data_path``.
        :return: A time series in DataFrame format.
        """
        datafile_path = os.path.join(self.local_data_path, series_name)
        return read_data(datafile_path)
class LocalForecastingDataSource(LocalDataSource):
    """
    The local data source of the forecasting task

    A :class:`LocalDataSource` bound to ``FORECASTING_DATASET_PATH`` and the
    ``FORECAST_META.csv`` metadata file.
    """

    def __init__(self):
        """
        initializer

        Takes no arguments; the data directory and the metadata file name are fixed.
        """
        super().__init__(
            local_data_path=FORECASTING_DATASET_PATH,
            metadata_file_name="FORECAST_META.csv",
        )