Source code for ice.fault_diagnosis.models.base
from abc import ABC
import pandas as pd
import torch
from torch import nn
from torch.optim import Adam
from ice.base import BaseModel
from ice.fault_diagnosis.metrics import (
accuracy, correct_daignosis_rate, true_positive_rate, false_positive_rate)
[docs]class BaseFaultDiagnosis(BaseModel, ABC):
"""Base class for all fault diagnosis models."""
def _prepare_for_training(self, input_dim: int, output_dim: int):
weight = torch.ones(output_dim, device=self.device) * 0.5
weight[1:] /= output_dim - 1
self.loss_fn = nn.CrossEntropyLoss(weight=weight)
self.optimizer = Adam(self.model.parameters(), lr=self.lr)
def _predict(self, sample: torch.Tensor) -> torch.Tensor:
sample = sample.to(self.device)
logits = self.model(sample)
return logits.argmax(axis=1).cpu()
def _calculate_metrics(self, pred: torch.tensor, target: torch.tensor) -> dict:
metrics = {
'accuracy': accuracy(pred, target),
'correct_daignosis_rate': correct_daignosis_rate(pred, target),
'true_positive_rate': true_positive_rate(pred, target),
'false_positive_rate': false_positive_rate(pred, target),
}
return metrics
def _set_dims(self, df: pd.DataFrame, target: pd.Series):
self.input_dim = df.shape[1]
self.output_dim = len(set(target))