Source code for ice.fault_diagnosis.utils
from tqdm.auto import tqdm
import pandas as pd
from torch.utils.data import Dataset
import numpy as np
class SlidingWindowDataset(Dataset):
    """Map-style dataset yielding fixed-length row windows of a DataFrame.

    The first level of ``df.index`` identifies a run. For every run, each
    row position from the run's ``window_size``-th row onwards becomes a
    "window end". A sample consists of the ``window_size`` rows located
    immediately *before* that position, paired with the target value *at*
    that position.

    Windows are sliced by integer position over the whole frame, so a
    window only stays inside its run when the rows of that run occupy
    consecutive positions in ``df``.
    """

    def __init__(self, df: pd.DataFrame, target: pd.Series, window_size: int):
        """Precompute the positions at which a window may end.

        Args:
            df: Feature table with a ``MultiIndex`` whose first level is the
                run identifier (``get_locs`` is called on the index).
            target: Values indexed positionally in step with the rows of
                ``df``.
            window_size: Number of rows in each sample; must be at least 1.

        Raises:
            ValueError: If ``window_size`` is smaller than 1.
        """
        # A zero or negative size would silently produce empty windows
        # (and, when negative, pick end positions from the tail of each run).
        if window_size < 1:
            raise ValueError(
                f"window_size must be a positive integer, got {window_size}"
            )

        self.df = df
        self.target = target
        self.window_size = window_size

        per_run_ends = []
        run_ids = df.index.get_level_values(0).unique()
        for run_id in tqdm(run_ids, desc='Creating sequence of samples'):
            positions = np.asarray(df.index.get_locs([run_id]))
            # The first `window_size` rows of a run lack a full window of
            # preceding rows, so they cannot serve as a window end.
            per_run_ends.append(positions[window_size:])

        if per_run_ends:
            self.window_end_indices = np.concatenate(per_run_ends)
        else:
            # No runs at all: keep an integer dtype so the array is still
            # usable as an index.
            self.window_end_indices = np.empty(0, dtype=np.intp)

    def __len__(self):
        """Return the number of available windows."""
        return len(self.window_end_indices)

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(sample, target)`` pair.

        Returns:
            A tuple of the window rows cast to ``float32`` and the target
            value at the position directly following the window.
        """
        window_index = self.window_end_indices[idx]
        sample = self.df.values[window_index - self.window_size:window_index]
        target = self.target.values[window_index]
        return sample.astype(np.float32), target