alessio21 committed on
Commit
36c3e63
•
1 Parent(s): 95eddc1

Upload run.ipynb

Browse files
Files changed (1) hide show
  1. run.ipynb +653 -0
run.ipynb ADDED
@@ -0,0 +1,653 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ }
15
+ },
16
+ "cells": [
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 3,
20
+ "metadata": {
21
+ "id": "yowZ_FwQ53s6"
22
+ },
23
+ "outputs": [],
24
+ "source": [
25
+ "!pip install -q seaborn plotly sentence-transformers prince gradio==3.41.2"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "source": [
31
+ "import matplotlib.pyplot as plt\n",
32
+ "import numpy as np\n",
33
+ "import pandas as pd\n",
34
+ "import os\n",
35
+ "import tensorflow as tf\n",
36
+ "from tensorflow import keras\n",
37
+ "import seaborn as sns\n",
38
+ "\n",
39
+ "from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score\n",
40
+ "from sklearn.metrics import f1_score, confusion_matrix, precision_recall_curve, roc_curve\n",
41
+ "from sklearn.metrics import ConfusionMatrixDisplay\n",
42
+ "\n",
43
+ "from sklearn.model_selection import train_test_split\n",
44
+ "from tensorflow.keras import layers, losses\n",
45
+ "from tensorflow.keras.datasets import fashion_mnist\n",
46
+ "from tensorflow.keras.models import Model\n",
47
+ "\n",
48
+ "from plotly.subplots import make_subplots\n",
49
+ "import plotly.graph_objects as go\n",
50
+ "\n",
51
+ "from sklearn.decomposition import PCA\n",
52
+ "\n",
53
+ "import plotly.express as px\n",
54
+ "from scipy.interpolate import griddata\n",
55
+ "import sklearn\n",
56
+ "from sklearn.tree import DecisionTreeClassifier\n",
57
+ "from sklearn.metrics import confusion_matrix, precision_score, roc_auc_score, precision_recall_curve\n",
58
+ "from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV, cross_val_predict, StratifiedKFold\n",
59
+ "from sentence_transformers import SentenceTransformer\n",
60
+ "\n",
61
+ "from sklearn import tree\n",
62
+ "\n",
63
+ "\n",
64
+ "import gradio as gr\n",
65
+ "import os\n",
66
+ "import json\n",
67
+ "from datetime import datetime, timedelta\n",
68
+ "import shutil\n",
69
+ "import random\n",
70
+ "import plotly.io as pio\n",
71
+ "\n",
72
+ "import joblib\n",
73
+ "\n",
74
+ "\n",
75
+ "\n",
76
+ "#load models\n",
77
+ "autoencoder = keras.models.load_model('models/autoencoder')\n",
78
+ "classifier = keras.models.load_model('models/classifier')\n",
79
+ "decision_tree = joblib.load(\"models/decision_tree_model.pkl\")\n",
80
+ "llm_model = SentenceTransformer(r\"sentence-transformers/paraphrase-MiniLM-L6-v2\")\n",
81
+ "\n",
82
+ "pca_2d_llm_clusters = joblib.load('models/pca_llm_model.pkl')\n",
83
+ "\n",
84
+ "print(\"models loaded\")\n",
85
+ "\n",
86
+ "\n",
87
+ "\n",
88
+ "#compute training dataset constant (min and max) for data normalization\n",
89
+ "\n",
90
+ "dataframe = pd.read_csv('ecg.csv', header=None)\n",
91
+ "dataframe[140] = dataframe[140].apply(lambda x: 1 if x==0 else 0)\n",
92
+ "\n",
93
+ "df_ecg = dataframe[[i for i in range(140)]]\n",
94
+ "ecg_raw_data = df_ecg.values\n",
95
+ "labels = dataframe.values[:, -1]\n",
96
+ "ecg_data = ecg_raw_data[:, :]\n",
97
+ "train_data, test_data, train_labels, test_labels = train_test_split(\n",
98
+ " ecg_data, labels, test_size=0.2, random_state=21)\n",
99
+ "\n",
100
+ "min_val = tf.reduce_min(train_data)\n",
101
+ "max_val = tf.reduce_max(train_data)\n",
102
+ "\n",
103
+ "print(\"constant computing: OK\")\n",
104
+ "\n",
105
+ "\n",
106
+ "#compute PCA for latent space representation\n",
107
+ "\n",
108
+ "ecg_data = (ecg_data - min_val) / (max_val - min_val)\n",
109
+ "\n",
110
+ "ecg_data = tf.cast(ecg_data, tf.float32)\n",
111
+ "\n",
112
+ "print(ecg_data.shape)\n",
113
+ "X = autoencoder.encoder(ecg_data).numpy()\n",
114
+ "\n",
115
+ "n_components=2\n",
116
+ "pca = PCA(n_components=n_components)\n",
117
+ "X_compressed = pca.fit_transform(X)\n",
118
+ "\n",
119
+ "\n",
120
+ "column_names = [f\"Feature{i + 1}\" for i in range(n_components)]\n",
121
+ "categories = [\"normal\",\"heart disease\"]\n",
122
+ "target_categorical = pd.Categorical.from_codes(labels.astype(int), categories=categories)\n",
123
+ "df_compressed = pd.DataFrame(X_compressed, columns=column_names)\n",
124
+ "df_compressed[\"target\"] = target_categorical\n",
125
+ "\n",
126
+ "print(\"PCA: done\")\n",
127
+ "\n",
128
+ "\n",
129
+ "#load dataset for decision tree map plot\n",
130
+ "df_plot = pd.read_csv(\"df_mappa.csv\", sep=\",\", header=0)\n",
131
+ "print(\"df map for decision tree loaded.\")\n",
132
+ "\n",
133
+ "#load dataset for llm pca\n",
134
+ "df_pca_llm = pd.read_csv(\"df_PCA_llm.csv\",sep=\",\",header=0)\n",
135
+ "\n",
136
+ "\n",
137
+ "\n",
138
+ "\n",
139
+ "\n",
140
+ "\n",
141
+ "#useful functions\n",
142
+ "\n",
143
+ "def df_encoding(df):\n",
144
+ " df.ExerciseAngina.replace(\n",
145
+ " {\n",
146
+ " 'N' : 'No',\n",
147
+ " 'Y' : 'exercise-induced angina'\n",
148
+ " },\n",
149
+ " inplace = True\n",
150
+ " )\n",
151
+ " df.FastingBS.replace(\n",
152
+ " {\n",
153
+ " 0 : 'Not Diabetic',\n",
154
+ " 1 : 'High fasting blood sugar'\n",
155
+ " },\n",
156
+ " inplace = True\n",
157
+ " )\n",
158
+ " df.Sex.replace(\n",
159
+ " {\n",
160
+ " 'M' : 'Man',\n",
161
+ " 'F' : 'Female'\n",
162
+ " },\n",
163
+ " inplace = True\n",
164
+ " )\n",
165
+ " df.ChestPainType.replace(\n",
166
+ " {\n",
167
+ " 'ATA' : 'Atypical',\n",
168
+ " 'NAP' : 'Non-Anginal Pain',\n",
169
+ " 'ASY' : 'Asymptomatic',\n",
170
+ " 'TA' : 'Typical Angina'\n",
171
+ " },\n",
172
+ " inplace = True\n",
173
+ " )\n",
174
+ " df.RestingECG.replace(\n",
175
+ " {\n",
176
+ " 'Normal' : 'Normal',\n",
177
+ " 'ST' : 'ST-T wave abnormality',\n",
178
+ " 'LVH' : 'Probable left ventricular hypertrophy'\n",
179
+ " },\n",
180
+ " inplace = True\n",
181
+ " )\n",
182
+ " df.ST_Slope.replace(\n",
183
+ " {\n",
184
+ " 'Up' : 'Up',\n",
185
+ " 'Flat' : 'Flat',\n",
186
+ " 'Down' : 'Downsloping'\n",
187
+ " },\n",
188
+ " inplace = True\n",
189
+ " )\n",
190
+ "\n",
191
+ " return df\n",
192
+ "\n",
193
+ "\n",
194
+ "\n",
195
+ "def compile_text_no_target(x):\n",
196
+ "\n",
197
+ "\n",
198
+ " text = f\"\"\"Age: {x['Age']},\n",
199
+ " Sex: {x['Sex']},\n",
200
+ " Chest Pain Type: {x['ChestPainType']},\n",
201
+ " RestingBP: {x['RestingBP']},\n",
202
+ " Cholesterol: {x['Cholesterol']},\n",
203
+ " FastingBS: {x['FastingBS']},\n",
204
+ " RestingECG: {x['RestingECG']},\n",
205
+ " MaxHR: {x['MaxHR']}\n",
206
+ " Exercise Angina: {x['ExerciseAngina']},\n",
207
+ " Old peak: {x['Oldpeak']},\n",
208
+ " ST_Slope: {x['ST_Slope']}\n",
209
+ " \"\"\"\n",
210
+ "\n",
211
+ " return text\n",
212
+ "\n",
213
+ "def LLM_transform(df , model = llm_model):\n",
214
+ " sentences = df.apply(lambda x: compile_text_no_target(x), axis=1).tolist()\n",
215
+ "\n",
216
+ "\n",
217
+ "\n",
218
+ " #model = SentenceTransformer(r\"sentence-transformers/paraphrase-MiniLM-L6-v2\")\n",
219
+ "\n",
220
+ " output = model.encode(sentences=sentences, show_progress_bar= True, normalize_embeddings = True)\n",
221
+ "\n",
222
+ " df_embedding = pd.DataFrame(output)\n",
223
+ "\n",
224
+ " return df_embedding\n",
225
+ "\n",
226
+ "\n",
227
+ "\n",
228
+ "\n",
229
+ "\n",
230
+ "\n",
231
+ "\n",
232
+ "\n",
233
+ "\n",
234
+ "def upload_ecg(file):\n",
235
+ "\n",
236
+ "\n",
237
+ "\n",
238
+ " if len(os.listdir(\"current_ecg\"))>0: # se ci sono file nella cartella, eliminali\n",
239
+ "\n",
240
+ " try:\n",
241
+ " for filename in os.listdir(\"current_ecg\"):\n",
242
+ " file_path = os.path.join(\"current_ecg\", filename)\n",
243
+ " if os.path.isfile(file_path):\n",
244
+ " os.remove(file_path)\n",
245
+ " print(f\"I file nella cartella 'current_ecg' sono stati eliminati.\")\n",
246
+ "\n",
247
+ " except Exception as e:\n",
248
+ " print(f\"Errore nell'eliminazione dei file: {str(e)}\")\n",
249
+ "\n",
250
+ "\n",
251
+ "\n",
252
+ " df = pd.read_csv(file.name,header=None) #file.name è il path temporaneo del file caricato\n",
253
+ "\n",
254
+ "\n",
255
+ " source_directory = os.path.dirname(file.name) # Replace with the source directory path\n",
256
+ " destination_directory = 'current_ecg' # Replace with the destination directory path\n",
257
+ "\n",
258
+ "\n",
259
+ " # Specify the filename (including the extension) of the CSV file you want to copy\n",
260
+ " file_to_copy = os.path.basename(file.name) # Replace with the actual filename\n",
261
+ "\n",
262
+ "\n",
263
+ " # Construct the full source and destination file paths\n",
264
+ " source_file_path = f\"{source_directory}/{file_to_copy}\"\n",
265
+ " destination_file_path = f\"{destination_directory}/{file_to_copy}\"\n",
266
+ "\n",
267
+ " # Copy the file from the source directory to the destination directory\n",
268
+ " shutil.copy(source_file_path, destination_file_path)\n",
269
+ "\n",
270
+ "\n",
271
+ " return \"Your ECG is ready, you can analyze it!\"\n",
272
+ "\n",
273
+ "\n",
274
+ "\n",
275
+ "\n",
276
+ "\n",
277
+ "\n",
278
+ "\n",
279
+ "\n",
280
+ "\n",
281
+ "\n",
282
+ "def ecg_availability(patient_name):\n",
283
+ "\n",
284
+ " folder_path = os.path.join(\"PATIENT\",patient_name)\n",
285
+ " status_file_path = os.path.join(folder_path, \"status.json\")\n",
286
+ "\n",
287
+ " # Check if the \"status.json\" file exists\n",
288
+ " if not os.path.isfile(status_file_path):\n",
289
+ " return None # If the file doesn't exist, return None\n",
290
+ "\n",
291
+ " # Load the JSON data from the \"status.json\" file\n",
292
+ " with open(status_file_path, 'r') as status_file:\n",
293
+ " status_data = json.load(status_file)\n",
294
+ "\n",
295
+ " # Extract the last datetime from the status JSON (if available)\n",
296
+ " last_datetime_str = status_data.get(\"last_datetime\", None)\n",
297
+ "\n",
298
+ " # Get the list of CSV files in the folder\n",
299
+ " csv_files = [f for f in os.listdir(folder_path) if f.endswith(\".csv\")]\n",
300
+ "\n",
301
+ " if last_datetime_str is None:\n",
302
+ " return f\"New ECG available\" # If the JSON is empty, return all CSV files\n",
303
+ "\n",
304
+ " last_datetime = datetime.strptime(last_datetime_str, \"%B_%d_%H_%M_%S\")\n",
305
+ "\n",
306
+ " # Find successive CSV files\n",
307
+ " successive_csv_files = []\n",
308
+ " for csv_file in csv_files:\n",
309
+ " csv_datetime_str = csv_file.split('.')[0]\n",
310
+ " csv_datetime = datetime.strptime(csv_datetime_str, \"%B_%d_%H_%M_%S\")\n",
311
+ "\n",
312
+ " # Check if the CSV datetime is successive to the last saved datetime\n",
313
+ " if csv_datetime > last_datetime:\n",
314
+ " successive_csv_files.append(csv_file)\n",
315
+ "\n",
316
+ " if len(successive_csv_files)>0:\n",
317
+ " return f\"New ECG available (last ECG: {last_datetime})\"\n",
318
+ "\n",
319
+ " else:\n",
320
+ " return f\"No ECG available (last ECG: {last_datetime})\"\n",
321
+ "\n",
322
+ "\n",
323
+ "\n",
324
+ "\n",
325
+ "def ecg_analysis():\n",
326
+ "\n",
327
+ " df = pd.read_csv(os.path.join(\"current_ecg\",os.listdir(\"current_ecg\")[0]))\n",
328
+ "\n",
329
+ "\n",
330
+ " df_ecg = df[[str(i) for i in range(140)]] #ecg data columns\n",
331
+ " df_data = df_ecg.values #raw data. shape: (n_rows , 140)\n",
332
+ " df_data = (df_data - min_val) / (max_val - min_val)\n",
333
+ " df_data = tf.cast(df_data, tf.float32) #raw data. shape: (n_rows , 140)\n",
334
+ "\n",
335
+ "\n",
336
+ " df_tree = df[[\"ChestPainType\",\"ST_Slope\"]].copy() #dataset for decision tree\n",
337
+ "\n",
338
+ " df_llm = df[[\"Age\",\"Sex\",\"ChestPainType\",\"RestingBP\",\"Cholesterol\",\"FastingBS\",\"RestingECG\",\"MaxHR\",\"ExerciseAngina\",\"Oldpeak\",\"ST_Slope\"]].copy() # dataset for LLM\n",
339
+ "\n",
340
+ " true_label = df.values[:,-1]\n",
341
+ "\n",
342
+ " # ----------------ECG ANALYSIS WITH AUTOENCODER-------------------------------\n",
343
+ " heartbeat_encoder_preds = autoencoder.encoder(df_data).numpy() #encoder data representation. shape: (n_rows , 8)\n",
344
+ " heartbeat_decoder_preds = autoencoder.decoder(heartbeat_encoder_preds).numpy() #decoder data reconstruction. shape: (n_rows , 140)\n",
345
+ "\n",
346
+ " classification_res = classifier.predict(df_data) #shape: (n_rows , 1)\n",
347
+ "\n",
348
+ "\n",
349
+ " print(\"shapes of: encoder preds, decoder preds, classification preds/n\",heartbeat_encoder_preds.shape,heartbeat_decoder_preds.shape,classification_res.shape)\n",
350
+ "\n",
351
+ " #heartbeat_indexes = [i for i, pred in enumerate(classification_res) if pred == 0]\n",
352
+ "\n",
353
+ " p_encoder_preds = heartbeat_encoder_preds[0,:] #encoder representation of the chosen row\n",
354
+ " p_decoder_preds = heartbeat_decoder_preds[0,:] #decoder reconstruction of the chosen row\n",
355
+ " p_class_res = classification_res[0,:] # classification res of the chosen row\n",
356
+ " p_true = true_label[0]\n",
357
+ "\n",
358
+ "\n",
359
+ "\n",
360
+ "\n",
361
+ " #LATENT SPACE PLOT\n",
362
+ "\n",
363
+ " # Create the scatter plot\n",
364
+ " fig = px.scatter(df_compressed, x='Feature1', y='Feature2', color='target', color_discrete_map={0: 'red', 1: 'blue'},\n",
365
+ " labels={'Target': 'Binary Target'},size_max=18)\n",
366
+ "\n",
367
+ "\n",
368
+ " # Disable hover information\n",
369
+ " # fig.update_traces(mode=\"markers\",\n",
370
+ " # hovertemplate = None,\n",
371
+ " # hoverinfo = \"skip\")\n",
372
+ "\n",
373
+ " # Customize the plot layout\n",
374
+ " fig.update_layout(\n",
375
+ " title='Latent space 2D (PCA reduction)',\n",
376
+ " xaxis_title='component 1',\n",
377
+ " yaxis_title='component 2'\n",
378
+ " )\n",
379
+ "\n",
380
+ " # add new point\n",
381
+ " new_point_compressed = pca.transform(p_encoder_preds.reshape(1,-1))\n",
382
+ "\n",
383
+ " new_point = {'X':[new_point_compressed[0][0]] , 'Y':[new_point_compressed[0][1]] } # Target value 2 for the new point\n",
384
+ "\n",
385
+ " new_point_df = pd.DataFrame(new_point)\n",
386
+ "\n",
387
+ " #fig.add_trace(px.scatter(new_point_df, x='X', y='Y').data[0])\n",
388
+ " fig.add_trace(go.Scatter(\n",
389
+ " x=new_point_df['X'],\n",
390
+ " y=new_point_df['Y'],\n",
391
+ " mode='markers',\n",
392
+ " marker=dict(symbol='star', color='black', size=15),\n",
393
+ " name='actual patient'\n",
394
+ " ))\n",
395
+ "\n",
396
+ " d = fig.to_dict()\n",
397
+ " d[\"data\"][0][\"type\"] = \"scatter\"\n",
398
+ "\n",
399
+ " fig=go.Figure(d)\n",
400
+ "\n",
401
+ "\n",
402
+ "\n",
403
+ " # DECODER RECONSTRUCTION PLOT\n",
404
+ "\n",
405
+ " fig_reconstruction = plt.figure(figsize=(10,8))\n",
406
+ " sns.set(font_scale = 2)\n",
407
+ " sns.set_style(\"white\")\n",
408
+ " plt.plot(df_data[0], 'black',linewidth=2)\n",
409
+ " plt.plot(heartbeat_decoder_preds[0], 'red',linewidth=2)\n",
410
+ " plt.fill_between(np.arange(140), heartbeat_decoder_preds[0], df_data[0], color='lightcoral')\n",
411
+ " plt.legend(labels=[\"Input\", \"Reconstruction\", \"Error\"])\n",
412
+ "\n",
413
+ " #classification probability\n",
414
+ "\n",
415
+ " # ----------DECISION TREE ANALYSIS---------------------------------\n",
416
+ "\n",
417
+ "\n",
418
+ " # Define the desired column order\n",
419
+ " encoded_features = ['ST_Slope_Up', 'ST_Slope_Flat', 'ST_Slope_Down', 'ChestPainType_ASY', 'ChestPainType_ATA', 'ChestPainType_NAP', 'ChestPainType_TA'] #il modello vuole le colonne in un determinato ordine\n",
420
+ "\n",
421
+ " X_plot = pd.DataFrame(columns=encoded_features)\n",
422
+ "\n",
423
+ " for k in range(len(df_tree['ST_Slope'])):\n",
424
+ " X_plot.loc[k] = 0\n",
425
+ " if df_tree['ST_Slope'][k] == 'Up':\n",
426
+ " X_plot['ST_Slope_Up'][k] = 1\n",
427
+ " if df_tree['ST_Slope'][k] == 'Flat':\n",
428
+ " X_plot['ST_Slope_Flat'][k] = 1\n",
429
+ " if df_tree['ST_Slope'][k] == 'Down':\n",
430
+ " X_plot['ST_Slope_Down'][k] = 1\n",
431
+ " if df_tree['ChestPainType'][k] == 'ASY':\n",
432
+ " X_plot['ChestPainType_ASY'][k] = 1\n",
433
+ " if df_tree['ChestPainType'][k] == 'ATA':\n",
434
+ " X_plot['ChestPainType_ATA'][k] = 1\n",
435
+ " if df_tree['ChestPainType'][k] == 'NAP':\n",
436
+ " X_plot['ChestPainType_NAP'][k] = 1\n",
437
+ " if df_tree['ChestPainType'][k] == 'TA':\n",
438
+ " X_plot['ChestPainType_TA'][k] = 1\n",
439
+ "\n",
440
+ "\n",
441
+ " #model prediction\n",
442
+ " y_score = decision_tree.predict_proba(X_plot)[:,1]\n",
443
+ "\n",
444
+ " chest_pain = []\n",
445
+ " slop = []\n",
446
+ "\n",
447
+ " for k in range(len(X_plot)):\n",
448
+ " if X_plot['ChestPainType_ASY'][k] == 1 and X_plot['ChestPainType_ATA'][k] == 0 and X_plot['ChestPainType_NAP'][k] == 0 and X_plot['ChestPainType_TA'][k] == 0:\n",
449
+ " chest_pain.append(0)\n",
450
+ " if X_plot['ChestPainType_ASY'][k] == 0 and X_plot['ChestPainType_ATA'][k] == 1 and X_plot['ChestPainType_NAP'][k] == 0 and X_plot['ChestPainType_TA'][k] == 0:\n",
451
+ " chest_pain.append(1)\n",
452
+ " if X_plot['ChestPainType_ASY'][k] == 0 and X_plot['ChestPainType_ATA'][k] == 0 and X_plot['ChestPainType_NAP'][k] == 1 and X_plot['ChestPainType_TA'][k] == 0:\n",
453
+ " chest_pain.append(2)\n",
454
+ " if X_plot['ChestPainType_ASY'][k] == 0 and X_plot['ChestPainType_ATA'][k] == 0 and X_plot['ChestPainType_NAP'][k] == 0 and X_plot['ChestPainType_TA'][k] == 1:\n",
455
+ " chest_pain.append(3)\n",
456
+ " if X_plot['ST_Slope_Up'][k] == 1 and X_plot['ST_Slope_Flat'][k] == 0 and X_plot['ST_Slope_Down'][k] == 0:\n",
457
+ " slop.append(0)\n",
458
+ " if X_plot['ST_Slope_Up'][k] == 0 and X_plot['ST_Slope_Flat'][k] == 1 and X_plot['ST_Slope_Down'][k] == 0:\n",
459
+ " slop.append(1)\n",
460
+ " if X_plot['ST_Slope_Up'][k] == 0 and X_plot['ST_Slope_Flat'][k] == 0 and X_plot['ST_Slope_Down'][k] == 1:\n",
461
+ " slop.append(2)\n",
462
+ "\n",
463
+ "\n",
464
+ " # Create a structured grid\n",
465
+ " fig_tree = plt.figure()\n",
466
+ " x1 = np.linspace(df_plot['ST_Slope'].min()-0.5, df_plot['ST_Slope'].max()+0.5)\n",
467
+ " x2 = np.linspace(df_plot['ChestPainType'].min()-0.5, df_plot['ChestPainType'].max()+0.5)\n",
468
+ " X1, X2 = np.meshgrid(x1, x2)\n",
469
+ "\n",
470
+ " # Interpolate the 'Prob' values onto the grid\n",
471
+ " points = df_plot[['ST_Slope', 'ChestPainType']].values\n",
472
+ " values = df_plot['Prob'].values\n",
473
+ " Z = griddata(points, values, (X1, X2), method='nearest')\n",
474
+ "\n",
475
+ " # Create the contour plot with regions colored by interpolated 'Prob'\n",
476
+ " plt.contourf(X1, X2, Z, cmap='coolwarm', levels=10)\n",
477
+ " plt.colorbar(label='Predicted Probability')\n",
478
+ "\n",
479
+ " # Add data points if needed\n",
480
+ " plt.scatter(slop[:1], chest_pain[:1], c=\"k\", cmap='coolwarm', edgecolor='k', marker='o', label=f'prob={y_score[:1].round(3)}')\n",
481
+ "\n",
482
+ " # Remove the numerical labels from the x and y axes\n",
483
+ " plt.xticks([])\n",
484
+ " plt.yticks([])\n",
485
+ "\n",
486
+ " # Add custom labels \"0\" and \"1\" near the center of the axis\n",
487
+ " plt.text(0.0, -0.7, \"Up\", ha='center',fontsize=15)\n",
488
+ " plt.text(1.00, -0.7, \"Flat\", ha='center',fontsize=15)\n",
489
+ " plt.text(2.00, -0.7, \"Down\", ha='center',fontsize=15)\n",
490
+ " plt.text(-0.62, 0.0, \"ASY\", rotation='vertical', va='center',fontsize=15)\n",
491
+ " plt.text(-0.62, 1.00, \"ATA\", rotation='vertical', va='center',fontsize=15)\n",
492
+ " plt.text(-0.62, 2.0, \"NAP\", rotation='vertical', va='center',fontsize=15)\n",
493
+ " plt.text(-0.62, 3.0, \"TA\", rotation='vertical', va='center',fontsize=15)\n",
494
+ "\n",
495
+ " # Add labels and title\n",
496
+ " plt.xlabel('ST_Slope', fontsize=15, labelpad=20)\n",
497
+ " plt.ylabel('ChestPainType', fontsize=15, labelpad=20)\n",
498
+ " #plt.legend()\n",
499
+ "\n",
500
+ "\n",
501
+ "\n",
502
+ " # ------------LLM ANALYSIS------------------------------------\n",
503
+ "\n",
504
+ " df_llm_encoding = df_encoding(df_llm)\n",
505
+ " df_point_LLM = LLM_transform(df_llm_encoding)\n",
506
+ "\n",
507
+ " df_point_LLM.columns = [str(column) for column in df_point_LLM.columns]\n",
508
+ "\n",
509
+ " pca_llm_point = pca_2d_llm_clusters.transform(df_point_LLM)\n",
510
+ " pca_llm_point.columns = [\"comp1\", \"comp2\"]\n",
511
+ "\n",
512
+ "\n",
513
+ " #clusters\n",
514
+ "\n",
515
+ " fig_llm_cluster = plt.figure()\n",
516
+ " x = df_pca_llm['comp1']\n",
517
+ " y = df_pca_llm['comp2']\n",
518
+ "\n",
519
+ " labels = ['Cluster 0', 'Cluster 1', 'Cluster 2', 'Cluster 3']\n",
520
+ "\n",
521
+ " # Create a dictionary to map 'RestingECG' values to colors\n",
522
+ " color_mapping = {0: 'r', 1: 'b', 2: 'g', 3: 'y'}\n",
523
+ "\n",
524
+ " for i in df_pca_llm['cluster'].unique():\n",
525
+ " color = color_mapping.get(i, 'k') # Use 'k' (black) for undefined values\n",
526
+ " plt.scatter(x[df_pca_llm['cluster'] == i], y[df_pca_llm['cluster'] == i], c=color, label=labels[i])\n",
527
+ "\n",
528
+ " plt.scatter(pca_llm_point['comp1'], pca_llm_point['comp2'], c='k', marker='D')\n",
529
+ "\n",
530
+ " # Remove the numerical labels from the x and y axes\n",
531
+ " plt.xticks([])\n",
532
+ " plt.yticks([])\n",
533
+ "\n",
534
+ " plt.xlabel('Principal Component 1')\n",
535
+ " plt.ylabel('Principal Component 2')\n",
536
+ " plt.legend()\n",
537
+ " plt.grid(False)\n",
538
+ "\n",
539
+ "\n",
540
+ "\n",
541
+ "\n",
542
+ "\n",
543
+ "\n",
544
+ "\n",
545
+ "\n",
546
+ " return fig, fig_reconstruction , f\"Heart disease probability: {int(p_class_res[0]*100)} %\" , fig_tree , f\"Heart disease probability: {int(y_score[0]*100)} %\" , fig_llm_cluster\n",
547
+ "\n",
548
+ "\n",
549
+ "\n",
550
+ "\n",
551
+ "\n",
552
+ "\n",
553
+ "\n",
554
+ "\n",
555
+ "#demo app\n",
556
+ "\n",
557
+ "with gr.Blocks(title=\"TIQUE - AI DEMO CAPABILITIES\") as demo:\n",
558
+ "\n",
559
+ " gr.Markdown(\"<h1><center>TIQUE: AI DEMO CAPABILITIES<center><h1>\")\n",
560
+ "\n",
561
+ "\n",
562
+ " with gr.Row():\n",
563
+ "\n",
564
+ " pazienti = [\"Elisabeth Smith\",\"Michael Mims\"]\n",
565
+ " menu_pazienti = gr.Dropdown(choices=pazienti,label=\"patients\")\n",
566
+ "\n",
567
+ " available_ecg_result = gr.Textbox()\n",
568
+ "\n",
569
+ "\n",
570
+ " menu_pazienti.input(ecg_availability, inputs=[menu_pazienti], outputs=[available_ecg_result])\n",
571
+ "\n",
572
+ " with gr.Row():\n",
573
+ "\n",
574
+ " input_file = gr.UploadButton(\"Click to Upload an ECG 📁\")\n",
575
+ " text_upload_results = gr.Textbox()\n",
576
+ "\n",
577
+ " input_file.upload(upload_ecg,inputs=[input_file],outputs=text_upload_results)\n",
578
+ "\n",
579
+ " with gr.Row():\n",
580
+ " ecg_start_analysis_button = gr.Button(value=\"Start ECG analysis\",scale=1)\n",
581
+ "\n",
582
+ "\n",
583
+ " gr.Markdown(\"## Large Language Model clustering\")\n",
584
+ "\n",
585
+ " with gr.Row():\n",
586
+ "\n",
587
+ " llm_cluster = gr.Plot()\n",
588
+ "\n",
589
+ "\n",
590
+ " gr.Markdown(\"## Autoencoder results:\")\n",
591
+ "\n",
592
+ " with gr.Row():\n",
593
+ "\n",
594
+ " with gr.Column():\n",
595
+ "\n",
596
+ " latent_space_representation = gr.Plot()\n",
597
+ "\n",
598
+ " with gr.Column():\n",
599
+ "\n",
600
+ " autoencoder_ecg_reconstruction = gr.Plot()\n",
601
+ "\n",
602
+ " classifier_nn_prediction = gr.Textbox()\n",
603
+ "\n",
604
+ " gr.Markdown(\"## Decision Tree results:\")\n",
605
+ "\n",
606
+ " with gr.Row():\n",
607
+ "\n",
608
+ " decision_tree_plot = gr.Plot()\n",
609
+ "\n",
610
+ " decision_tree_proba = gr.Textbox()\n",
611
+ "\n",
612
+ "\n",
613
+ "\n",
614
+ "\n",
615
+ " ecg_start_analysis_button.click(fn=ecg_analysis, inputs=None, outputs=[latent_space_representation,\n",
616
+ " autoencoder_ecg_reconstruction,\n",
617
+ " classifier_nn_prediction,decision_tree_plot, decision_tree_proba,\n",
618
+ " llm_cluster])\n",
619
+ "if __name__ == \"__main__\":\n",
620
+ " demo.launch()\n",
621
+ "\n",
622
+ "\n",
623
+ "\n",
624
+ "\n",
625
+ "\n",
626
+ "\n",
627
+ "\n",
628
+ "\n",
629
+ "\n",
630
+ "\n",
631
+ "\n",
632
+ "\n",
633
+ "\n",
634
+ "\n",
635
+ "\n",
636
+ "\n",
637
+ "\n",
638
+ "\n",
639
+ "\n",
640
+ "\n",
641
+ "\n",
642
+ "\n",
643
+ "\n",
644
+ "\n"
645
+ ],
646
+ "metadata": {
647
+ "id": "bVSujh5-677-"
648
+ },
649
+ "execution_count": null,
650
+ "outputs": []
651
+ }
652
+ ]
653
+ }