{ "cells": [ { "cell_type": "markdown", "id": "5d62c97e", "metadata": {}, "source": [ "# Results of fault diagnosis using TCN" ] }, { "cell_type": "code", "execution_count": 9, "id": "c12297e6", "metadata": {}, "outputs": [], "source": [ "from ice.fault_diagnosis.datasets import FaultDiagnosisRiethTEP\n", "from ice.fault_diagnosis.models import TCN\n", "from sklearn.preprocessing import StandardScaler\n", "import numpy as np\n", "import pandas as pd" ] }, { "cell_type": "markdown", "id": "b9cb6459", "metadata": {}, "source": [ "Download and load the Rieth Tennessee Eastman Process (TEP) fault diagnosis dataset; the progress bars below show its CSV files being read from `data/rieth_tep/`." ] }, { "cell_type": "code", "execution_count": 2, "id": "6a71a39f", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ec7af5252301494c85187fe4084fe3da", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Reading data/rieth_tep/df.csv: 0%| | 0/15330000 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4046be2c40d94339a68d4b17003dcc68", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Reading data/rieth_tep/target.csv: 0%| | 0/15330000 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7c68ffa1b402446ebf41665797a05e05", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Reading data/rieth_tep/train_mask.csv: 0%| | 0/15330000 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dataset = FaultDiagnosisRiethTEP()" ] }, { "cell_type": "markdown", "id": "88388f56", "metadata": {}, "source": [ "Standardize the features to zero mean and unit variance. The scaler is fitted on the training rows only and then applied unchanged to the test rows, so no test-set statistics leak into preprocessing." 
] }, { "cell_type": "code", "execution_count": 3, "id": "9295a11a", "metadata": {}, "outputs": [], "source": [ "scaler = StandardScaler()\n", "dataset.df[dataset.train_mask] = scaler.fit_transform(dataset.df[dataset.train_mask])\n", "dataset.df[dataset.test_mask] = scaler.transform(dataset.df[dataset.test_mask])" ] }, { "cell_type": "markdown", "id": "dd2fbd9a", "metadata": {}, "source": [ "Create the TCN (temporal convolutional network) model. Its hyperparameters should match the configuration used to train the checkpoint loaded in the next step." ] }, { "cell_type": "code", "execution_count": 4, "id": "297fb60d", "metadata": {}, "outputs": [], "source": [ "model = TCN(\n", " window_size=60,\n", " batch_size=128,\n", " num_layers=1,\n", " kernel_size=3,\n", " hidden_dim=32,\n", " lr=1e-4,\n", " num_epochs=30,\n", " verbose=True,\n", " device='cpu',\n", " save_checkpoints=True,\n", " val_ratio=0.1,\n", ")" ] }, { "cell_type": "markdown", "id": "468592df", "metadata": {}, "source": [ "Load the trained weights from the checkpoint file `tcn_fault_diagnosis_epoch_30.tar` instead of training the model here. The path is relative, so the file must be in the notebook's working directory." ] }, { "cell_type": "code", "execution_count": 5, "id": "8beffc86", "metadata": {}, "outputs": [], "source": [ "model.load_checkpoint('tcn_fault_diagnosis_epoch_30.tar')" ] }, { "cell_type": "markdown", "id": "d4719e4a", "metadata": {}, "source": [ "Evaluate the model on the standardized test split and store the results in `metrics`." 
] }, { "cell_type": "code", "execution_count": 6, "id": "dd9ad1af", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "164dd9d5bcfc4e8fa1663f328291083c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Creating sequence of samples: 0%| | 0/10500 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Steps ...: 0%| | 0/73829 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "metrics = model.evaluate(\n", " dataset.df[dataset.test_mask],\n", " dataset.target[dataset.test_mask]\n", ")" ] }, { "cell_type": "code", "execution_count": 19, "id": "f08af96f", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | Fault | \n", "TPR | \n", "FPR | \n", "
---|---|---|---|
0 | \n", "0 | \n", "0.9675 | \n", "0.0000 | \n", "
1 | \n", "1 | \n", "0.9738 | \n", "0.0000 | \n", "
2 | \n", "3 | \n", "0.9643 | \n", "0.0000 | \n", "
3 | \n", "4 | \n", "0.9584 | \n", "0.0000 | \n", "
4 | \n", "5 | \n", "0.9731 | \n", "0.0000 | \n", "
5 | \n", "6 | \n", "0.9679 | \n", "0.0000 | \n", "
6 | \n", "7 | \n", "0.9691 | \n", "0.0000 | \n", "
7 | \n", "9 | \n", "0.9651 | \n", "0.0000 | \n", "
8 | \n", "10 | \n", "0.9788 | \n", "0.0000 | \n", "
9 | \n", "11 | \n", "0.9526 | \n", "0.0000 | \n", "
10 | \n", "12 | \n", "0.9418 | \n", "0.0001 | \n", "
11 | \n", "13 | \n", "0.9780 | \n", "0.0000 | \n", "
12 | \n", "15 | \n", "0.9752 | \n", "0.0000 | \n", "
13 | \n", "16 | \n", "0.9608 | \n", "0.0000 | \n", "
14 | \n", "17 | \n", "0.9357 | \n", "0.0000 | \n", "
15 | \n", "18 | \n", "0.9717 | \n", "0.0000 | \n", "
16 | \n", "19 | \n", "0.9482 | \n", "0.0000 | \n", "