Source code for ice.health_index_estimation.models.TCM

import torch
from torch import nn
from torch.nn import functional as F

from ice.health_index_estimation.models.base import BaseHealthIndexEstimation


class LocalFeatureExtraction(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.conv1 = nn.Conv1d(
            in_channels, out_channels, kernel_size, padding=kernel_size // 2
        )
        # Note: kernel size 1 makes this max-pool an identity operation.
        self.pool = nn.MaxPool1d(1)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool(x)
        return x

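# A minimal shape check for LocalFeatureExtraction (sizes are illustrative,
# not taken from this file). With an odd kernel and padding=kernel_size // 2,
# the convolution preserves the sequence length:
#
#     x = torch.randn(32, 14, 64)               # (batch, channels, time)
#     out = LocalFeatureExtraction(14, 8, 3)(x)
#     assert out.shape == (32, 8, 64)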

class AttentionMechanism(nn.Module):
    def __init__(self, feature_dim, timestep_dim):
        super().__init__()
        self.feature_fc = nn.Linear(feature_dim, feature_dim)
        self.timestep_fc = nn.Linear(timestep_dim, timestep_dim)

    def forward(self, x):
        # Compute feature weights
        feature_weights = F.softmax(self.feature_fc(x), dim=-1)
        # Compute timestep weights
        timestep_weights = F.softmax(
            self.timestep_fc(x.transpose(1, 2)), dim=-1
        ).transpose(1, 2)

        # Apply attention
        feature_attention = feature_weights * x
        timestep_attention = timestep_weights * x
        return feature_attention * timestep_attention

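# A minimal sketch of the dual attention above (illustrative sizes). Feature
# weights are softmaxed across channels, timestep weights across time, and the
# module returns the elementwise product of the two reweighted signals, so the
# output shape matches the input:
#
#     attn = AttentionMechanism(feature_dim=14, timestep_dim=64)
#     x = torch.randn(32, 64, 14)               # (batch, time, features)
#     assert attn(x).shape == (32, 64, 14)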

class ParallelCNNBiLSTM(nn.Module):
    def __init__(
        self,
        input_dim,
        cnn_out_channels,
        cnn_kernel_size,
        lstm_hidden_size,
        lstm_num_layers,
        num_heads,  # note: accepted but currently unused in this block
        sequence_length,
    ):
        super().__init__()
        self.attention = AttentionMechanism(input_dim, sequence_length)
        self.cnn1 = LocalFeatureExtraction(input_dim, cnn_out_channels, cnn_kernel_size)
        self.cnn2 = LocalFeatureExtraction(input_dim, cnn_out_channels, cnn_kernel_size)
        self.cnn3 = LocalFeatureExtraction(
            cnn_out_channels, cnn_out_channels, cnn_kernel_size
        )
        self.bilstm = nn.LSTM(
            cnn_out_channels * 2,
            lstm_hidden_size,
            lstm_num_layers,
            batch_first=True,
            bidirectional=True,
        )
        self.adaptive_pool = nn.AdaptiveAvgPool1d(
            sequence_length
        )  # Adaptive pooling to match dimensions; not applied in forward
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = self.attention(x)  # Apply attention mechanism to x
        cnn1_out = self.cnn1(x.transpose(1, 2)).transpose(
            1, 2
        )  # Apply CNN and keep the original sequence length
        cnn2_out = self.cnn2(x.transpose(1, 2)).transpose(1, 2)

        cnn1_out = self.dropout(cnn1_out)
        cnn2_out = self.dropout(cnn2_out)

        cnn2_out = self.cnn3(cnn2_out.transpose(1, 2)).transpose(1, 2)

        # concatenate and pass to lstm
        combined_out = torch.cat((cnn1_out, cnn2_out), dim=-1)

        lstm_out, _ = self.bilstm(combined_out)

        lstm_out = self.dropout(lstm_out)
        return lstm_out  # (batch, seq_len, 2 * lstm_hidden_size)

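# A minimal forward-pass shape check (illustrative sizes). The two parallel CNN
# branches each emit cnn_out_channels features per timestep; their concatenation
# (2 * cnn_out_channels) feeds the BiLSTM, whose bidirectional output doubles
# lstm_hidden_size:
#
#     block = ParallelCNNBiLSTM(input_dim=14, cnn_out_channels=8,
#                               cnn_kernel_size=3, lstm_hidden_size=8,
#                               lstm_num_layers=2, num_heads=2,
#                               sequence_length=64)
#     x = torch.randn(32, 64, 14)               # (batch, time, features)
#     assert block(x).shape == (32, 64, 16)     # 2 * lstm_hidden_size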

class ToolWearMonitoringModel(nn.Module):
    def __init__(
        self,
        input_dim,
        output_dim,
        window_size,
        cnn_out_channels=8,
        cnn_kernel_size=3,
        lstm_hidden_size=8,
        lstm_num_layers=2,
        num_heads=2,
    ):
        super().__init__()
        self.parallel_cnn_bilstm = ParallelCNNBiLSTM(
            input_dim,
            cnn_out_channels,
            cnn_kernel_size,
            lstm_hidden_size,
            lstm_num_layers,
            num_heads,
            window_size,
        )
        self.linear = nn.Linear(
            lstm_hidden_size * window_size * 2, output_dim
        )  # Flatten before the linear layer

    def forward(self, x):
        x = self.parallel_cnn_bilstm(x)
        x = x.view(x.size(0), -1)  # Flatten
        x = self.linear(x)
        return torch.squeeze(x, dim=1)

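# An end-to-end shape check for the full model (illustrative sizes; with
# output_dim=1 the squeeze yields one wear estimate per window):
#
#     model = ToolWearMonitoringModel(input_dim=14, output_dim=1, window_size=64)
#     x = torch.randn(32, 64, 14)               # (batch, window, sensors)
#     assert model(x).shape == (32,)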

class TCM(BaseHealthIndexEstimation):
    """
    Wear condition monitoring module from https://doi.org/10.1016/j.jmsy.2021.12.002
    """

    def __init__(
        self,
        window_size: int = 64,
        stride: int = 300,
        batch_size: int = 64,
        lr: float = 5e-5,
        num_epochs: int = 50,
        device: str = "cpu",
        verbose: bool = True,
        name: str = "TCM_hi_estimation",
        random_seed: int = 42,
        val_ratio: float = 0.15,
        save_checkpoints: bool = False,
        lstm_hidden_size: int = 8,
        lstm_num_layers: int = 2,
    ):
        """
        Args:
            window_size (int): The window size to train the model.
            stride (int): The time interval between the first points of
                consecutive sliding windows.
            batch_size (int): The batch size to train the model.
            lr (float): The learning rate to train the model.
            num_epochs (int): The number of epochs to train the model.
            device (str): The name of a device to train the model. `cpu` and
                `cuda` are possible.
            verbose (bool): If true, show the progress bar in training.
            name (str): The name of the model for artifact storing.
            random_seed (int): Seed for random number generation to ensure
                reproducible results.
            val_ratio (float): Proportion of the dataset used for validation,
                between 0 and 1.
            save_checkpoints (bool): If true, store checkpoints.
            lstm_hidden_size (int): The hidden size of the BiLSTM layers.
            lstm_num_layers (int): The number of stacked BiLSTM layers.
        """
        super().__init__(
            window_size,
            stride,
            batch_size,
            lr,
            num_epochs,
            device,
            verbose,
            name,
            random_seed,
            val_ratio,
            save_checkpoints,
        )
        self.val_metrics = True
        self.window_size = window_size
        self.lstm_hidden_size = lstm_hidden_size
        self.lstm_num_layers = lstm_num_layers

    _param_conf_map = dict(
        BaseHealthIndexEstimation._param_conf_map,
        **{"lstm_hidden_size": ["MODEL", "HIDDEN_DIM"]},
        **{"lstm_num_layers": ["MODEL", "NUM_LAYERS"]},
    )

    def _create_model(self, input_dim: int, output_dim: int):
        self.model = nn.Sequential(
            ToolWearMonitoringModel(
                input_dim=input_dim,
                output_dim=output_dim,
                window_size=self.window_size,
                lstm_hidden_size=self.lstm_hidden_size,
                lstm_num_layers=self.lstm_num_layers,
            )
        )
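
# A hypothetical usage sketch. The fit/predict entry points are assumed from
# the BaseHealthIndexEstimation interface, which is not shown in this file:
#
#     model = TCM(window_size=64, stride=300, num_epochs=50, device="cpu")
#     model.fit(df, target)                  # sensor data and tool-wear targets
#     predictions = model.predict(df_test)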