Source code for ice.health_index_estimation.models.base
from ice.base import BaseModel
from abc import ABC
import pandas as pd
from torch.optim import AdamW
import torch
from torch import nn
from ice.health_index_estimation.metrics import mse, rmse
class BaseHealthIndexEstimation(BaseModel, ABC):
    """Base class for all health index (HI) estimation models.

    Provides the shared hooks used by HI estimation models: training setup
    (loss + optimizer), prediction, metric computation and input/output
    dimension inference. The methods read ``self.model``, ``self.lr`` and
    ``self.device``; presumably these are provided by ``BaseModel`` or the
    concrete subclass — confirm there.
    """

    def _prepare_for_training(self, input_dim: int, output_dim: int):
        """Create the loss function and optimizer used during training.

        Args:
            input_dim: Number of input features. Not used here; presumably
                accepted to match the hook signature expected by ``BaseModel``.
            output_dim: Number of model outputs. Not used here either.
        """
        # NOTE(review): training minimizes L1 (mean absolute error) loss while
        # _calculate_metrics reports MSE/RMSE — confirm this mismatch is intended.
        self.loss_fn = nn.L1Loss()
        self.optimizer = AdamW(self.model.parameters(), lr=self.lr)

    def _predict(self, sample: torch.Tensor) -> torch.Tensor:
        """Run the model on ``sample`` and return its output on the CPU.

        Args:
            sample: Model input; moved to ``self.device`` before the forward pass.

        Returns:
            The raw model output, moved back to the CPU.
        """
        sample = sample.to(self.device)
        predicted = self.model(sample)
        return predicted.cpu()

    def _calculate_metrics(self, pred: torch.Tensor, target: torch.Tensor) -> dict:
        """Compute evaluation metrics of ``pred`` against ``target``.

        Args:
            pred: Model predictions.
            target: Ground-truth values.

        Returns:
            A dict with keys ``"mse"`` and ``"rmse"``, holding the values
            returned by the ``mse`` and ``rmse`` metric helpers.
        """
        return {
            "mse": mse(pred, target),
            "rmse": rmse(pred, target),
        }

    def _set_dims(self, df: pd.DataFrame, target: pd.Series):
        """Set ``self.input_dim`` and ``self.output_dim`` from the training data.

        Args:
            df: Feature table; its number of columns becomes ``input_dim``.
            target: Target values. Not inspected: ``output_dim`` is always 1,
                i.e. the model is configured for a single output value.
        """
        self.input_dim = df.shape[1]
        self.output_dim = 1