Source code for ice.health_index_estimation.utils

from tqdm.auto import tqdm
import pandas as pd
from torch.utils.data import Dataset
import numpy as np


class SlidingWindowDataset(Dataset):
    """Map-style dataset that serves fixed-length windows of rows from a frame.

    Rows are grouped by the values in level 0 of ``df.index`` (each distinct
    value is treated as one run). For every run, row positions are sampled
    starting at the run's ``window_size``-th row and then every ``stride``
    rows; each sampled position ``p`` defines one item made of the
    ``window_size`` rows in ``[p - window_size, p)`` and the target value at
    position ``p`` (i.e. the row right after the window).

    Runs with ``window_size`` rows or fewer contribute no items.

    NOTE(review): ``target`` is indexed positionally, so it is presumably
    aligned row-for-row with ``df``; windows staying inside a single run
    also presumes each run occupies contiguous rows -- confirm with callers.
    """

    def __init__(
        self, df: pd.DataFrame, target: pd.Series, window_size: int, stride: int = 1
    ):
        """Precompute the row position that terminates every window.

        Args:
            df: Feature frame whose index exposes the run id at level 0.
            target: Per-row target values, read by integer position.
            window_size: Number of rows in each window.
            stride: Step, in rows, between consecutive window ends of a run.
        """
        self.df = df
        self.target = target
        self.window_size = window_size

        end_positions = []
        unique_runs = df.index.get_level_values(0).unique()
        for run in tqdm(unique_runs, desc="Creating sequence of samples"):
            # Integer row positions (not labels) belonging to this run.
            run_positions = np.array(df.index.get_locs([run]))
            # Skip the first `window_size` rows so every kept position has a
            # full window of rows preceding it within the run.
            end_positions.extend(run_positions[self.window_size :: stride])

        self.window_end_indices = np.array(end_positions)

    def __len__(self):
        """Return the number of windows available across all runs."""
        return len(self.window_end_indices)

    def __getitem__(self, idx):
        """Return ``(window, target)`` for the ``idx``-th precomputed window.

        The window is a ``float32`` array holding the ``window_size`` rows
        that precede the stored end position; the target is the value at
        the end position itself.
        """
        stop = self.window_end_indices[idx]
        start = stop - self.window_size
        features = self.df.values[start:stop]
        label = self.target.values[stop]
        return features.astype(np.float32), label