{ "cells": [ { "cell_type": "markdown", "id": "b3ceb4f1", "metadata": {}, "source": [ "# Tutorial on fault diagnosis task" ] }, { "cell_type": "code", "execution_count": 1, "id": "741dc760", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/vitalijpozdnakov/miniconda3/envs/ice/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import numpy as np\n", "from sklearn.preprocessing import StandardScaler\n", "\n", "from ice.fault_diagnosis.datasets import FaultDiagnosisSmallTEP\n", "from ice.fault_diagnosis.models import MLP" ] }, { "cell_type": "markdown", "id": "d6b9fced", "metadata": {}, "source": [ "Download the dataset." ] }, { "cell_type": "code", "execution_count": 2, "id": "d6ca9486", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Downloading small_tep: 100%|██████████| 18.2M/18.2M [00:01<00:00, 10.1MB/s]\n", "Extracting df.csv: 58.6MB [00:00, 153MB/s] \n", "Extracting train_mask.csv: 9.77MB [00:00, 1.60GB/s] \n", "Extracting target.csv: 9.77MB [00:00, 2.33GB/s] \n", "Reading data/small_tep/df.csv: 100%|██████████| 153300/153300 [00:01<00:00, 79692.16it/s]\n", "Reading data/small_tep/target.csv: 100%|██████████| 153300/153300 [00:00<00:00, 1771006.14it/s]\n", "Reading data/small_tep/train_mask.csv: 100%|██████████| 153300/153300 [00:00<00:00, 1837204.89it/s]\n" ] } ], "source": [ "dataset = FaultDiagnosisSmallTEP()" ] }, { "cell_type": "code", "execution_count": 3, "id": "92db5267", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
xmeas_1xmeas_2xmeas_3xmeas_4xmeas_5xmeas_6xmeas_7xmeas_8xmeas_9xmeas_10...xmv_2xmv_3xmv_4xmv_5xmv_6xmv_7xmv_8xmv_9xmv_10xmv_11
run_idsample
41340207310.250383674.04529.09.232026.88942.4022704.374.863120.410.33818...53.74424.65762.54422.13739.93542.32347.75747.51041.25818.447
20.251093659.44556.69.426426.72142.5762705.075.000120.410.33620...53.41424.58859.25922.08440.17638.55443.69247.42741.35917.194
30.250383660.34477.89.442626.87542.0702706.274.771120.420.33563...54.35724.66661.27522.38040.24438.99046.69947.46841.19920.530
40.249773661.34512.19.477626.75842.0632707.275.224120.390.33553...53.94624.72559.85622.27740.25738.07247.54147.65841.64318.089
50.294053679.04497.09.338126.88942.6502705.175.388120.390.32632...53.65828.79760.71721.94739.14441.95547.64547.34641.50718.461
.....................................................................
3121488199560.248423694.24491.29.394626.78042.6552708.374.765120.410.32959...53.89124.58063.32021.86738.86836.06148.08845.47041.46317.078
9570.226123736.44523.19.365526.77842.7302711.075.142120.380.32645...53.67521.83164.14222.02738.84239.14444.56045.59841.59116.720
9580.223863692.84476.59.398426.67342.5282712.774.679120.430.32484...54.23322.05359.22822.23539.04035.11645.73745.49041.88416.310
9590.225613664.24483.09.429326.43542.4692710.274.857120.380.31932...53.33522.24860.56721.82037.97933.39448.50345.51240.63020.996
9600.225853717.64492.89.406126.86942.1762710.574.722120.410.31926...53.21722.22563.42922.25937.98634.81047.81045.63941.89818.378
\n", "

153300 rows × 52 columns

\n", "
" ], "text/plain": [ " xmeas_1 xmeas_2 xmeas_3 xmeas_4 xmeas_5 xmeas_6 \\\n", "run_id sample \n", "413402073 1 0.25038 3674.0 4529.0 9.2320 26.889 42.402 \n", " 2 0.25109 3659.4 4556.6 9.4264 26.721 42.576 \n", " 3 0.25038 3660.3 4477.8 9.4426 26.875 42.070 \n", " 4 0.24977 3661.3 4512.1 9.4776 26.758 42.063 \n", " 5 0.29405 3679.0 4497.0 9.3381 26.889 42.650 \n", "... ... ... ... ... ... ... \n", "312148819 956 0.24842 3694.2 4491.2 9.3946 26.780 42.655 \n", " 957 0.22612 3736.4 4523.1 9.3655 26.778 42.730 \n", " 958 0.22386 3692.8 4476.5 9.3984 26.673 42.528 \n", " 959 0.22561 3664.2 4483.0 9.4293 26.435 42.469 \n", " 960 0.22585 3717.6 4492.8 9.4061 26.869 42.176 \n", "\n", " xmeas_7 xmeas_8 xmeas_9 xmeas_10 ... xmv_2 xmv_3 \\\n", "run_id sample ... \n", "413402073 1 2704.3 74.863 120.41 0.33818 ... 53.744 24.657 \n", " 2 2705.0 75.000 120.41 0.33620 ... 53.414 24.588 \n", " 3 2706.2 74.771 120.42 0.33563 ... 54.357 24.666 \n", " 4 2707.2 75.224 120.39 0.33553 ... 53.946 24.725 \n", " 5 2705.1 75.388 120.39 0.32632 ... 53.658 28.797 \n", "... ... ... ... ... ... ... ... \n", "312148819 956 2708.3 74.765 120.41 0.32959 ... 53.891 24.580 \n", " 957 2711.0 75.142 120.38 0.32645 ... 53.675 21.831 \n", " 958 2712.7 74.679 120.43 0.32484 ... 54.233 22.053 \n", " 959 2710.2 74.857 120.38 0.31932 ... 53.335 22.248 \n", " 960 2710.5 74.722 120.41 0.31926 ... 53.217 22.225 \n", "\n", " xmv_4 xmv_5 xmv_6 xmv_7 xmv_8 xmv_9 xmv_10 \\\n", "run_id sample \n", "413402073 1 62.544 22.137 39.935 42.323 47.757 47.510 41.258 \n", " 2 59.259 22.084 40.176 38.554 43.692 47.427 41.359 \n", " 3 61.275 22.380 40.244 38.990 46.699 47.468 41.199 \n", " 4 59.856 22.277 40.257 38.072 47.541 47.658 41.643 \n", " 5 60.717 21.947 39.144 41.955 47.645 47.346 41.507 \n", "... ... ... ... ... ... ... ... 
\n", "312148819 956 63.320 21.867 38.868 36.061 48.088 45.470 41.463 \n", " 957 64.142 22.027 38.842 39.144 44.560 45.598 41.591 \n", " 958 59.228 22.235 39.040 35.116 45.737 45.490 41.884 \n", " 959 60.567 21.820 37.979 33.394 48.503 45.512 40.630 \n", " 960 63.429 22.259 37.986 34.810 47.810 45.639 41.898 \n", "\n", " xmv_11 \n", "run_id sample \n", "413402073 1 18.447 \n", " 2 17.194 \n", " 3 20.530 \n", " 4 18.089 \n", " 5 18.461 \n", "... ... \n", "312148819 956 17.078 \n", " 957 16.720 \n", " 958 16.310 \n", " 959 20.996 \n", " 960 18.378 \n", "\n", "[153300 rows x 52 columns]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset.df" ] }, { "cell_type": "code", "execution_count": 4, "id": "5021318f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "run_id sample\n", "413402073 1 0\n", " 2 0\n", " 3 0\n", " 4 0\n", " 5 0\n", " ..\n", "312148819 956 20\n", " 957 20\n", " 958 20\n", " 959 20\n", " 960 20\n", "Name: target, Length: 153300, dtype: int64" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset.target" ] }, { "cell_type": "markdown", "id": "6bd716cf", "metadata": {}, "source": [ "The split into train and test sets by `run_id` is already provided by the dataset as `dataset.train_mask` and `dataset.test_mask`, so no manual split is needed." ] }, { "cell_type": "markdown", "id": "71b30525", "metadata": {}, "source": [ "Scale the data: fit the scaler on the train set only, then apply it to both the train and test sets." ] }, { "cell_type": "code", "execution_count": 12, "id": "77572a8e", "metadata": {}, "outputs": [], "source": [ "scaler = StandardScaler()\n", "dataset.df[dataset.train_mask] = scaler.fit_transform(dataset.df[dataset.train_mask])\n", "dataset.df[dataset.test_mask] = scaler.transform(dataset.df[dataset.test_mask])" ] }, { "cell_type": "markdown", "id": "b2a742bd", "metadata": {}, "source": [ "Create the [MLP](ice.fault_diagnosis.models.mlp) model and fit it on the train set." 
] }, { "cell_type": "code", "execution_count": 14, "id": "8d04ae19", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Creating sequence of samples: 100%|██████████| 105/105 [00:00<00:00, 777.58it/s]\n", "Epochs ...: 10%|█ | 1/10 [00:02<00:24, 2.70s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1, Loss: 0.4350\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epochs ...: 20%|██ | 2/10 [00:05<00:23, 2.88s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2, Loss: 0.5717\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epochs ...: 30%|███ | 3/10 [00:08<00:19, 2.82s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3, Loss: 0.3677\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epochs ...: 40%|████ | 4/10 [00:11<00:16, 2.79s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4, Loss: 0.6005\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epochs ...: 50%|█████ | 5/10 [00:13<00:13, 2.77s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5, Loss: 0.7325\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epochs ...: 60%|██████ | 6/10 [00:16<00:11, 2.86s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 6, Loss: 0.3584\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epochs ...: 70%|███████ | 7/10 [00:19<00:08, 2.87s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 7, Loss: 0.5394\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epochs ...: 80%|████████ | 8/10 [00:22<00:05, 2.81s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 8, Loss: 0.3939\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epochs ...: 90%|█████████ | 9/10 [00:27<00:03, 3.57s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 9, Loss: 0.3941\n" ] }, { 
"name": "stderr", "output_type": "stream", "text": [ "Epochs ...: 100%|██████████| 10/10 [00:35<00:00, 3.52s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 10, Loss: 0.4601\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "model = MLP(window_size=10, lr=0.001, verbose=True)\n", "model.fit(dataset.df[dataset.train_mask], dataset.target[dataset.train_mask])" ] }, { "cell_type": "markdown", "id": "dd8990a1", "metadata": {}, "source": [ "Evaluate the metrics." ] }, { "cell_type": "code", "execution_count": 15, "id": "3d761723", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Creating sequence of samples: 100%|██████████| 105/105 [00:00<00:00, 345.39it/s]\n", " \r" ] }, { "data": { "text/plain": [ "{'accuracy': 0.7945664160401003}" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "metrics = model.evaluate(dataset.df[dataset.test_mask], dataset.target[dataset.test_mask])\n", "metrics" ] }, { "cell_type": "code", "execution_count": null, "id": "7323d50c", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 5 }