# -*- coding: utf-8 -*-
from typing import Tuple
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from ts_benchmark.baselines.time_series_library.utils.timefeatures import (
time_features,
)
from ts_benchmark.utils.data_processing import split_before
[docs]
class SlidingWindowDataLoader:
    """
    SlidingWindowDataLoader class.

    This class encapsulates a sliding window data loader for generating time
    series training samples. Each sample is one window of
    ``history_length + prediction_length`` consecutive rows of ``dataset``:
    the first ``history_length`` rows are the input and the remaining
    ``prediction_length`` rows are the target.

    Every emitted batch holds exactly ``batch_size`` windows; a trailing group
    of fewer than ``batch_size`` windows is not emitted.
    """

    def __init__(
        self,
        dataset: pd.DataFrame,
        batch_size: int = 1,
        history_length: int = 10,
        prediction_length: int = 2,
        shuffle: bool = True,
    ):
        """
        Initialize SlidingWindowDataLoader.

        :param dataset: Pandas DataFrame containing time series data.
        :param batch_size: Batch size (number of windows per batch).
        :param history_length: The length of historical data.
        :param prediction_length: The length of the predicted data.
        :param shuffle: Whether to visit the windows in random order.
        :raises ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            # A non-positive batch size would yield empty batches forever.
            raise ValueError(
                "batch_size must be a positive integer, got {}".format(batch_size)
            )
        self.dataset = dataset
        self.batch_size = batch_size
        self.history_length = history_length
        self.prediction_length = prediction_length
        self.shuffle = shuffle
        self.current_index = 0
        # Visiting order of the window start rows; None means sequential.
        self._window_order = None

    def __len__(self) -> int:
        """
        Returns the number of sliding windows available in the dataset.

        :return: The number of windows; 0 when the dataset is shorter than
            one window (``len()`` must never be negative).
        """
        window_size = self.history_length + self.prediction_length
        return max(0, len(self.dataset) - window_size + 1)

    def __iter__(self) -> "SlidingWindowDataLoader":
        """
        Create an iterator and return.

        Restarts from the first window and, when ``shuffle`` is set, draws a
        fresh random visiting order for the windows.

        :return: Data loader iterator.
        """
        if self.shuffle:
            self._shuffle_dataset()
        else:
            self._window_order = None
        self.current_index = 0
        return self

    def __next__(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate data for the next batch.

        :return: A tuple containing input data of shape
            ``(batch_size, history_length, n_columns)`` and target data of
            shape ``(batch_size, prediction_length, n_columns)``, both float32.
        :raises StopIteration: When fewer than ``batch_size`` windows remain.
        """
        num_windows = len(self)
        if self.current_index >= num_windows:
            raise StopIteration

        batch_end = self.current_index + self.batch_size
        if batch_end > num_windows:
            # Not enough windows left for a full batch: the remainder is
            # dropped and the loader is marked as exhausted.
            self.current_index = num_windows
            raise StopIteration

        if self._window_order is None:
            starts = range(self.current_index, batch_end)
        else:
            starts = self._window_order[self.current_index : batch_end]

        window_size = self.history_length + self.prediction_length
        batch_inputs = []
        batch_targets = []
        for start in starts:
            window_values = self.dataset.iloc[start : start + window_size].values
            batch_inputs.append(window_values[: self.history_length])
            batch_targets.append(window_values[self.history_length :])
        self.current_index = batch_end

        # Stack in NumPy first: building a tensor from a list of arrays is
        # much slower than converting one contiguous array.
        inputs = torch.tensor(np.stack(batch_inputs), dtype=torch.float32)
        targets = torch.tensor(np.stack(batch_targets), dtype=torch.float32)
        return inputs, targets

    def _shuffle_dataset(self):
        """
        Shuffle the order in which windows are visited.

        Only the window start positions are permuted; the rows of the
        dataset stay in their original order so that every window still
        consists of consecutive time steps.
        """
        self._window_order = np.random.permutation(len(self))
[docs]
def train_val_split(train_data, ratio, seq_len):
    """
    Split a dataset into a training part and a validation part.

    :param train_data: The full dataset to be split.
    :param ratio: The ratio of the training data. If 1, no split is performed.
    :param seq_len: When not None, the validation part is taken from
        ``seq_len`` positions before the split border (presumably so the
        first validation window has a complete history -- confirm against
        ``split_before``).
    :return: A tuple ``(training data, validation data)``. If the ratio is 1,
        the second value is None.
    """
    # Nothing to split off: hand the data back untouched.
    if ratio == 1:
        return train_data, None

    cut = int(train_data.shape[0] * ratio)
    head, tail = split_before(train_data, cut)
    if seq_len is None:
        return head, tail

    # Re-split earlier so the validation part overlaps the training tail.
    _, overlapping_tail = split_before(train_data, cut - seq_len)
    return head, overlapping_tail
[docs]
def decompose_time(
    time: np.ndarray,
    freq: str,
) -> np.ndarray:
    """
    Break timestamps into calendar components chosen by the frequency code.

    The month is always included; finer components (day, weekday, hour,
    minute, second) are appended as the frequency gets finer. A frequency
    code that is not one of ``m, w, b, d, h, t, s`` is treated like the
    finest one, so all components are returned.

    :param time: Array of timestamps.
    :param freq: The frequency of the time stamp.
    :return: 2-D array with one row per timestamp and one column per component.
    """
    # Granularity rank of each known frequency code (higher means finer).
    rank_by_freq = {
        "m": 0,
        "w": 1,
        "b": 2,
        "d": 2,
        "h": 3,
        "t": 4,
        "s": 5,
    }
    rank = rank_by_freq.get(freq, max(rank_by_freq.values()))

    # Components in output column order, each with the rank that enables it.
    components = (
        (0, "month"),
        (1, "day"),
        (2, "weekday"),
        (3, "hour"),
        (4, "minute"),
        (5, "second"),
    )

    stamps = pd.DataFrame(pd.to_datetime(time), columns=["date"])
    for required_rank, component in components:
        if rank >= required_rank:
            stamps[component] = getattr(stamps["date"].dt, component)
    return stamps.drop(["date"], axis=1).values
[docs]
def get_time_mark(
    time_stamp: np.ndarray,
    timeenc: int,
    freq: str,
) -> np.ndarray:
    """
    Extract temporal features from the time stamp.

    The input is flattened, encoded, and reshaped back to the original shape
    with one extra trailing axis holding the features.

    :param time_stamp: The time stamp ndarray.
    :param timeenc: The time encoding type: 0 uses ``decompose_time``,
        1 uses ``time_features``.
    :param freq: The frequency of the time stamp.
    :return: The mark of the time stamp as a float32 array.
    :raises ValueError: If ``timeenc`` is neither 0 nor 1.
    """
    if timeenc not in (0, 1):
        raise ValueError("Unknown time encoding {}".format(timeenc))

    # Original shape plus a trailing axis inferred from the feature count.
    marked_shape = time_stamp.shape + (-1,)
    flat_stamps = time_stamp.flatten()

    if timeenc == 0:
        marks = decompose_time(flat_stamps, freq)
    else:
        encoded = time_features(pd.to_datetime(flat_stamps), freq=freq)
        # Swap axes so each row corresponds to one timestamp before reshaping.
        marks = encoded.transpose(1, 0)

    return marks.reshape(marked_shape).astype(np.float32)
[docs]
def forecasting_data_provider(data, config, timeenc, batch_size, shuffle, drop_last):
    """
    Build the forecasting dataset and a DataLoader over it.

    :param data: The dataset to be used for forecasting.
    :param config: Configuration object; its ``seq_len``, ``pred_len``,
        ``label_len``, ``freq`` and ``num_workers`` attributes are read here.
    :param timeenc: Time encoding type, forwarded to ``DatasetForTransformer``.
    :param batch_size: The size of batches for the data loader.
    :param shuffle: Whether the data loader shuffles the samples.
    :param drop_last: Whether to drop the last incomplete batch if the dataset
        size is not divisible by the batch size.
    :return: A tuple containing the dataset and data loader.
    """
    # Window geometry and time-encoding settings for the dataset wrapper.
    dataset_options = {
        "dataset": data,
        "history_len": config.seq_len,
        "prediction_len": config.pred_len,
        "label_len": config.label_len,
        "timeenc": timeenc,
        "freq": config.freq,
    }
    windowed_dataset = DatasetForTransformer(**dataset_options)

    loader = DataLoader(
        windowed_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=config.num_workers,
        drop_last=drop_last,
    )
    return windowed_dataset, loader