{ "cells": [ { "cell_type": "markdown", "id": "f2da72a3-830f-4a34-b767-230ed0dbf154", "metadata": {}, "source": [ "# Results of HI estimation using SOTA models" ] }, { "cell_type": "markdown", "id": "e3de787a-3db3-47c0-a1c6-a2b36cc6586f", "metadata": {}, "source": [ "This notebook presents experimental results of HI estimation on the Milling dataset using the pretrained SOTA models Stacked_LSTM, TCM and IE_SBiGRU.\n", "\n", "Importing libraries." ] }, { "cell_type": "code", "execution_count": 1, "id": "0d2bf80e-d11f-4c3d-ac75-e4f0cfd418f6", "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\user\\conda\\envs\\ice_testing\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "from ice.health_index_estimation.datasets import Milling\n", "from ice.health_index_estimation.models import MLP, TCM, IE_SBiGRU, Stacked_LSTM\n", "\n", "import pandas as pd\n", "import numpy as np\n", "import torch\n", "from tqdm.auto import trange\n", " " ] }, { "cell_type": "markdown", "id": "41bd9f60-51bf-4bb8-adc8-33a6abf923a1", "metadata": {}, "source": [ "Initializing the dataset class and the train/test data split" ] }, { "cell_type": "code", "execution_count": 2, "id": "814ba92e-e394-49a4-ac06-7b68361cf078", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Reading data/milling/case_1.csv: 100%|██████████| 153000/153000 [00:00<00:00, 1268689.48it/s]\n", "Reading data/milling/case_2.csv: 100%|██████████| 117000/117000 [00:00<00:00, 1248873.41it/s]\n", "Reading data/milling/case_3.csv: 100%|██████████| 126000/126000 [00:00<00:00, 1276983.81it/s]\n", "Reading data/milling/case_4.csv: 100%|██████████| 63000/63000 [00:00<00:00, 1374151.83it/s]\n", "Reading data/milling/case_5.csv: 100%|██████████| 54000/54000 [00:00<00:00, 1290003.79it/s]\n", 
"Reading data/milling/case_6.csv: 100%|██████████| 9000/9000 [00:00<00:00, 1003395.34it/s]\n", "Reading data/milling/case_7.csv: 100%|██████████| 72000/72000 [00:00<00:00, 1245529.71it/s]\n", "Reading data/milling/case_8.csv: 100%|██████████| 54000/54000 [00:00<00:00, 1321487.68it/s]\n", "Reading data/milling/case_9.csv: 100%|██████████| 81000/81000 [00:00<00:00, 1377467.66it/s]\n", "Reading data/milling/case_10.csv: 100%|██████████| 90000/90000 [00:00<00:00, 1347769.63it/s]\n", "Reading data/milling/case_11.csv: 100%|██████████| 207000/207000 [00:00<00:00, 1180064.87it/s]\n", "Reading data/milling/case_12.csv: 100%|██████████| 126000/126000 [00:00<00:00, 1276980.73it/s]\n", "Reading data/milling/case_13.csv: 100%|██████████| 135000/135000 [00:00<00:00, 1242669.46it/s]\n", "Reading data/milling/case_14.csv: 100%|██████████| 81000/81000 [00:00<00:00, 1310826.20it/s]\n", "Reading data/milling/case_15.csv: 100%|██████████| 63000/63000 [00:00<00:00, 1316886.37it/s]\n", "Reading data/milling/case_16.csv: 100%|██████████| 18000/18000 [00:00<00:00, 1128731.62it/s]\n", "C:\\Users\\user\\conda\\envs\\ice_testing\\Lib\\site-packages\\scipy\\interpolate\\_interpolate.py:479: RuntimeWarning: invalid value encountered in divide\n", " slope = (y_hi - y_lo) / (x_hi - x_lo)[:, None]\n" ] } ], "source": [ "dataset_class = Milling()\n", "\n", "data, target = pd.concat(dataset_class.df), pd.concat(dataset_class.target) \n", "test_data, test_target = dataset_class.test[0], dataset_class.test_target[0]" ] }, { "cell_type": "code", "execution_count": 3, "id": "75a22b4c-5b1c-4ce0-b74e-6886d1c61c90", "metadata": {}, "outputs": [], "source": [ "from sklearn.preprocessing import StandardScaler, MinMaxScaler\n", "import pandas as pd \n", "\n", "scaler = MinMaxScaler()\n", "trainer_data = scaler.fit_transform(data)\n", "tester_data = scaler.transform(test_data)\n", "\n", "trainer_data = pd.DataFrame(trainer_data, index=data.index, columns=data.columns)\n", "tester_data = 
pd.DataFrame(tester_data, index=test_data.index, columns=test_data.columns)" ] }, { "cell_type": "code", "execution_count": 4, "id": "049500f5-24a0-4076-a2eb-2e15970e3abb", "metadata": {}, "outputs": [], "source": [ "# Directory prefix of the checkpoint archives (*.tar) loaded by the model cells below.\n", "path_to_tar = \"hi_sota/\"" ] }, { "cell_type": "code", "execution_count": 5, "id": "ac338ba4-ab10-41cc-916b-0413bed0d1c6", "metadata": { "scrolled": true }, "outputs": [], "source": [ "model_class = Stacked_LSTM(\n", " window_size=64,\n", " stride=1024, # 1024\n", " batch_size=253, # 256\n", " lr= 0.0031789041005068647, # 0.0004999805761074147,\n", " num_epochs=55,\n", " verbose=True,\n", " device='cuda'\n", " )\n", "# model_class.fit(trainer_data, target)\n", "model_class.load_checkpoint(path_to_tar + \"stack_sota.tar\")" ] }, { "cell_type": "code", "execution_count": 6, "id": "1512dcc2-1765-42f6-ab71-b8a44aa699df", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Creating sequence of samples: 100%|██████████| 14/14 [00:00<00:00, 2809.31it/s]\n", " \r" ] }, { "data": { "text/plain": [ "{'mse': 0.0022332468596409335, 'rmse': 0.047257241346072384}" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_class.evaluate(tester_data, test_target)" ] }, { "cell_type": "code", "execution_count": null, "id": "defe310c-5d1e-412d-b68d-447a975ef937", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "b58a2a2b-9b73-4ac8-959e-4f93db8cd8eb", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 7, "id": "e7638bc7-d161-4266-b2d1-aa3545856853", "metadata": { "scrolled": true }, "outputs": [], "source": [ "model_class = TCM(\n", " window_size=64,\n", " stride=1024, # 1024\n", " batch_size=253, # 256\n", " lr= 0.0031789041005068647, # 0.0004999805761074147,\n", " num_epochs=55,\n", " verbose=True,\n", " device='cuda'\n", " )\n", "# model_class.fit(trainer_data, target)\n", 
"model_class.load_checkpoint(path_to_tar + \"TCM_sota.tar\")" ] }, { "cell_type": "code", "execution_count": 8, "id": "550c5609-a0ec-48d3-9149-a4de69ba8368", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Creating sequence of samples: 100%|██████████| 14/14 [00:00<00:00, 3511.98it/s]\n", " \r" ] }, { "data": { "text/plain": [ "{'mse': 0.004014168163365719, 'rmse': 0.06335746335962102}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_class.evaluate(tester_data, test_target)" ] }, { "cell_type": "code", "execution_count": null, "id": "ede9a8c3-27e5-42df-980e-bcdb1ddcec73", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "e1ee17ee-09a8-4e7e-a81d-bea529af3f7f", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "767c37e2-1750-46d5-9cd7-f628e7322305", "metadata": {}, "source": [ "Training and testing with different random seeds for uncertainty estimation" ] }, { "cell_type": "code", "execution_count": 9, "id": "1c5e6783-49df-46b5-86a0-6f09ee501d0e", "metadata": { "scrolled": true }, "outputs": [], "source": [ "model_class = IE_SBiGRU(\n", " window_size=64,\n", " stride=1024, # 1024\n", " batch_size=253, # 256\n", " lr= 0.0011, # 0.0004999805761074147,\n", " num_epochs=35,\n", " verbose=True,\n", " device='cuda'\n", " )\n", "# model_class.fit(trainer_data, target)\n", "model_class.load_checkpoint(path_to_tar + \"IE_SBiGRU_sota.tar\")" ] }, { "cell_type": "code", "execution_count": 10, "id": "6425462a-14a5-4580-b082-92e9165d8984", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Creating sequence of samples: 100%|██████████| 14/14 [00:00<00:00, 2341.13it/s]\n", " \r" ] }, { "data": { "text/plain": [ "{'mse': 0.004956771691496658, 'rmse': 0.07040434426579555}" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ 
"model_class.evaluate(tester_data, test_target)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 }