ChaoukiBenzekri commited on
Commit
2ec1466
1 Parent(s): 062b83a

Add costom_functions

Browse files
Files changed (1) hide show
  1. custom_functions.py +487 -0
custom_functions.py ADDED
@@ -0,0 +1,487 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import uuid
3
+ import re
4
+ import numpy as np
5
+ import cv2
6
+ import matplotlib.pyplot as plt
7
+ import matplotlib.image as mpimg
8
+ import plotly.graph_objects as go
9
+ import plotly.figure_factory as ff
10
+ import tensorflow as tf
11
+ import keras
12
+
13
+ from keras.layers import Conv2D
14
+ from tensorflow.keras.preprocessing.image import img_to_array, array_to_img
15
+ from keras.models import load_model
16
+
17
+ # Afficher le nom et l'icone des réseaux d'un membre
18
+ def show_profile(name, linkedin_url, github_url):
19
+ st.markdown("""
20
+ <style>
21
+ .icon-img {
22
+ height: 20px; # Taille des icônes
23
+ width: 20px;
24
+ margin-top: 0px;
25
+ margin-left: 5px;
26
+ margin-right:5px;
27
+ }
28
+ .name-with-icons {
29
+ display: flex-shrink;
30
+ align-items: center;
31
+ margin-top: -5px;
32
+ margin-bottom: 0px;
33
+ }
34
+ .name-with-icons:first-child {
35
+ margin-top: -5px; /* Marge négative pour le premier élément */
36
+ }
37
+ </style>
38
+ """, unsafe_allow_html = True)
39
+
40
+ st.markdown(f"""
41
+ <div class='name-with-icons'>
42
+ <a href='{github_url}' target='_blank'><img class='icon-img' src='https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg'></a>
43
+ <a href='{linkedin_url}' target='_blank'><img class='icon-img' src='https://upload.wikimedia.org/wikipedia/commons/c/ca/LinkedIn_logo_initials.png'></a>
44
+ {name}
45
+ </div>
46
+ """, unsafe_allow_html = True)
47
+
48
+ # Création de bloc de texte esthétiques
49
+ def create_styled_box(text, text_color, background_color, alignment = 'left', display = 'block'):
50
+ unique_id = uuid.uuid4().hex
51
+ # Ajout de styles CSS pour le cadre avec une couleur de fond et une couleur de texte personnalisée
52
+ st.markdown(f"""
53
+ <style>
54
+ .styled-box-{unique_id} {{
55
+ background-color: {background_color}; /* couleur de fond */
56
+ color: {text_color}; /* couleur de texte */
57
+ padding: 10px 20px; /* espace intérieur vertical et horizontal */
58
+ border: 1px solid {text_color}; /* bordure de la même couleur que le texte */
59
+ border-radius: 8px; /* bord arrondi */
60
+ font-size: 16px; /* taille de la police */
61
+ font-weight: regular; /* format de la police */
62
+ box-shadow: 0 4px 8px rgba(0,0,0,0.2); /* ombre pour donner de la profondeur */
63
+ text-align: {alignment}; /* alignement du texte */
64
+ display: {display};
65
+ }}
66
+ </style>
67
+ """, unsafe_allow_html = True)
68
+
69
+ # Afficher le texte dans le cadre stylisé
70
+ st.markdown(f"""
71
+ <div class="styled-box-{unique_id}">{text}</div>
72
+ """, unsafe_allow_html = True)
73
+
74
+ # Fonction pour calculer l'intensité lumineuse moyenne d'une image
75
+ def calc_mean_intensity(image_path):
76
+ img = mpimg.imread(image_path)
77
+ # Convertir en nuances de gris si l'image est en couleur
78
+ if img.ndim == 3:
79
+ img_gray = np.dot(img[...,:3], [0.2989, 0.5870, 0.1140])
80
+ else:
81
+ img_gray = img
82
+ return np.mean(img_gray)
83
+
84
+ # Fonction pour extraire les sources depuis les urls
85
+ def source_extract(url):
86
+ pattern = re.compile(r'https?://(?:www\.)?([^/]+)')
87
+ match = pattern.search(url)
88
+ if match:
89
+ return match.group(1)
90
+ else:
91
+ return None
92
+
93
+ # Fonctions de plot des métriques
94
+ def plot_loss_curve(history):
95
+ fig = go.Figure()
96
+ fig.add_trace(go.Scatter(x = list(range(len(history['loss']))),
97
+ y = history['loss'],
98
+ mode = 'lines+markers',
99
+ name = 'Perte d\'entraînement',
100
+ marker = dict(color = 'lightblue')))
101
+
102
+ fig.add_trace(go.Scatter(x = list(range(len(history['val_loss']))),
103
+ y = history['val_loss'],
104
+ mode = 'lines+markers',
105
+ name = 'Perte de validation',
106
+ marker = dict(color = 'salmon')))
107
+
108
+ fig.update_layout(title = dict(text = "Courbe de Perte", font = dict(color = 'white')),
109
+ xaxis_title = dict(text = "Époque", font = dict(color = 'white')),
110
+ yaxis_title = dict(text = "Perte", font = dict(color = 'white')),
111
+ template = 'plotly_white',
112
+ paper_bgcolor = 'rgba(0,0,0,0)',
113
+ plot_bgcolor = 'rgba(0,0,0,0)',
114
+ legend = dict(font = dict(color = 'white')))
115
+ st.plotly_chart(fig)
116
+
117
+ def plot_precision_curve(history):
118
+ fig = go.Figure()
119
+ fig.add_trace(go.Scatter(x = list(range(len(history['precision']))),
120
+ y = history['precision'],
121
+ mode = 'lines+markers',
122
+ name = "Précision d'entraînement",
123
+ marker = dict(color = 'lightblue')))
124
+
125
+ fig.add_trace(go.Scatter(x = list(range(len(history['val_precision']))),
126
+ y = history['val_precision'],
127
+ mode = 'lines+markers',
128
+ name = 'Précision de validation',
129
+ marker = dict(color = 'salmon')))
130
+
131
+ fig.update_layout(title = dict(text = "Courbe de Précision", font = dict(color = 'white')),
132
+ xaxis_title=dict(text = "Époque", font = dict(color = 'white')),
133
+ yaxis_title=dict(text = "Précision", font = dict(color = 'white')),
134
+ template = 'plotly_white',
135
+ paper_bgcolor = 'rgba(0,0,0,0)',
136
+ plot_bgcolor = 'rgba(0,0,0,0)',
137
+ legend = dict(font = dict(color='white')))
138
+ st.plotly_chart(fig)
139
+
140
+ def plot_auc(history):
141
+ fig = go.Figure()
142
+ fig.add_trace(go.Scatter(x = list(range(len(history['auc']))),
143
+ y = history['auc'],
144
+ mode = 'lines+markers',
145
+ name = "AUC moyen d'entraînement",
146
+ marker = dict(color='lightblue')))
147
+
148
+ fig.add_trace(go.Scatter(x = list(range(len(history['val_auc']))),
149
+ y = history['val_auc'],
150
+ mode = 'lines+markers',
151
+ name = 'AUC moyen de validation',
152
+ marker=dict(color = 'salmon')))
153
+
154
+ fig.update_layout(title = "Courbe de AUC-ROC",
155
+ xaxis_title = "Époque",
156
+ yaxis_title = "Area Under Curve",
157
+ template = 'plotly_white',
158
+ paper_bgcolor = 'rgba(0,0,0,0)',
159
+ plot_bgcolor = 'rgba(0,0,0,0)',
160
+ legend = dict(font = dict(color = 'white')),
161
+ xaxis = dict(tickfont = dict(color = 'white')),
162
+ yaxis = dict(tickfont = dict(color = 'white')),
163
+ title_font = dict(color = 'white'))
164
+ st.plotly_chart(fig)
165
+
166
+ def plot_f1_score(history):
167
+ fig = go.Figure()
168
+ fig.add_trace(go.Scatter(x = list(range(len(history['f1_score']))),
169
+ y=np.mean(history['f1_score'], axis = 1),
170
+ mode = 'lines+markers',
171
+ name = "F1 Score d'entraînement",
172
+ marker = dict(color = 'lightblue')))
173
+
174
+ fig.add_trace(go.Scatter(x = list(range(len(history['val_f1_score']))),
175
+ y = np.mean(history['val_f1_score'], axis = 1),
176
+ mode = 'lines+markers',
177
+ name = 'F1 Score moyen de validation',
178
+ marker = dict(color = 'salmon')))
179
+
180
+ fig.update_layout(title = "Courbe de F1 Score",
181
+ xaxis_title = "Époque",
182
+ yaxis_title = "F1 Score",
183
+ template = 'plotly_white',
184
+ paper_bgcolor = 'rgba(0,0,0,0)',
185
+ plot_bgcolor = 'rgba(0,0,0,0)',
186
+ legend = dict(font = dict(color = 'white')),
187
+ xaxis = dict(tickfont = dict(color = 'white')),
188
+ yaxis = dict(tickfont = dict(color = 'white')),
189
+ title_font = dict(color = 'white'))
190
+ st.plotly_chart(fig)
191
+
192
+ # Plot des matrices de confusion
193
+ def plot_CM(matrix):
194
+ confusion_matrix = np.array(matrix)
195
+ class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
196
+ st.title('Matrice de Confusion')
197
+ fig = ff.create_annotated_heatmap(z = confusion_matrix, x = class_names, y = class_names, colorscale = 'RdBu')
198
+ fig.update_layout(
199
+ title = 'Matrice de Confusion',
200
+ xaxis = dict(title = 'Classe Prédite'),
201
+ yaxis = dict(title = 'Classe Réelle')
202
+ )
203
+ st.plotly_chart(fig)
204
+
205
+ def plot_CM_ResNetV2():
206
+ confusion_lines = [
207
+ [192, 3, 7, 2],
208
+ [9, 148, 24, 0],
209
+ [5, 17, 139, 4],
210
+ [1, 0, 4, 165]
211
+ ]
212
+
213
+ confusion_matrix = np.array(confusion_lines)
214
+ class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
215
+ size = max(confusion_matrix.shape)
216
+ fig = ff.create_annotated_heatmap(z=confusion_matrix, x=class_names, y=class_names, colorscale='Cividis')
217
+ fig.update_layout(
218
+ width=size*140,
219
+ height=size*140,
220
+ xaxis=dict(title='Classe Prédite'),
221
+ yaxis=dict(title='Classe Réelle')
222
+ )
223
+ st.plotly_chart(fig)
224
+
225
+ def plot_CM_ResNet121():
226
+ confusion_lines = [
227
+ [181, 11, 11, 1],
228
+ [8, 148, 25, 0],
229
+ [10, 17, 135, 3],
230
+ [2, 0, 1, 167]
231
+ ]
232
+
233
+ confusion_matrix = np.array(confusion_lines)
234
+ class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
235
+ size = max(confusion_matrix.shape)
236
+ fig = ff.create_annotated_heatmap(z=confusion_matrix, x=class_names, y=class_names, colorscale='Cividis')
237
+ fig.update_layout(
238
+ width=size*140,
239
+ height=size*140,
240
+ xaxis=dict(title='Classe Prédite'),
241
+ yaxis=dict(title='Classe Réelle')
242
+ )
243
+ st.plotly_chart(fig)
244
+
245
+ def plot_CM_DenseNet201():
246
+ confusion_lines = [
247
+ [190, 5, 7, 2],
248
+ [6, 156, 19, 0],
249
+ [4, 21, 139, 1],
250
+ [1, 0, 0, 169]
251
+ ]
252
+
253
+ confusion_matrix = np.array(confusion_lines)
254
+ class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
255
+ size = max(confusion_matrix.shape)
256
+ fig = ff.create_annotated_heatmap(z=confusion_matrix, x=class_names, y=class_names, colorscale='Cividis')
257
+ fig.update_layout(
258
+ width=size*140,
259
+ height=size*140,
260
+ xaxis=dict(title='Classe Prédite'),
261
+ yaxis=dict(title='Classe Réelle')
262
+ )
263
+ st.plotly_chart(fig)
264
+
265
+ def plot_CM_VGG16():
266
+
267
+ confusion_lines = [
268
+ [178, 5, 9, 2],
269
+ [4, 152, 17, 0],
270
+ [2, 11, 160, 3],
271
+ [0, 0, 4, 175]
272
+ ]
273
+
274
+ confusion_matrix = np.array(confusion_lines)
275
+ class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
276
+ size = max(confusion_matrix.shape)
277
+ fig = ff.create_annotated_heatmap(z=confusion_matrix, x=class_names, y=class_names, colorscale='Cividis')
278
+ fig.update_layout(
279
+ width=size*140,
280
+ height=size*140,
281
+ xaxis=dict(title='Classe Prédite'),
282
+ yaxis=dict(title='Classe Réelle')
283
+ )
284
+ st.plotly_chart(fig)
285
+
286
+ def plot_CM_VGG19():
287
+ confusion_lines = [
288
+ [182, 7, 3, 0],
289
+ [7, 158, 8, 0],
290
+ [8, 21, 142, 5],
291
+ [1, 1, 3, 174]
292
+ ]
293
+
294
+ confusion_matrix = np.array(confusion_lines)
295
+ class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
296
+ size = max(confusion_matrix.shape)
297
+ fig = ff.create_annotated_heatmap(z=confusion_matrix, x=class_names, y=class_names, colorscale='Cividis')
298
+ fig.update_layout(
299
+ width=size*140,
300
+ height=size*140,
301
+ xaxis=dict(title='Classe Prédite'),
302
+ yaxis=dict(title='Classe Réelle')
303
+ )
304
+ st.plotly_chart(fig)
305
+
306
+ def plot_CM_ConvnextTiny():
307
+ confusion_lines = [
308
+ [122, 11, 19, 0],
309
+ [13, 142, 15, 0],
310
+ [17, 14, 144, 1],
311
+ [4, 1, 7, 168]
312
+ ]
313
+
314
+ confusion_matrix = np.array(confusion_lines)
315
+ class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
316
+ size = max(confusion_matrix.shape)
317
+ fig = ff.create_annotated_heatmap(z=confusion_matrix, x=class_names, y=class_names, colorscale='Cividis')
318
+ fig.update_layout(
319
+ width=size*140,
320
+ height=size*140,
321
+ xaxis=dict(title='Classe Prédite'),
322
+ yaxis=dict(title='Classe Réelle')
323
+ )
324
+ st.plotly_chart(fig)
325
+
326
+ def plot_CM_ConvnextBase():
327
+ confusion_lines = [
328
+ [168, 9, 15, 0],
329
+ [9, 152, 12, 0],
330
+ [12, 8, 153, 3],
331
+ [2, 0, 8, 169]
332
+ ]
333
+
334
+ confusion_matrix = np.array(confusion_lines)
335
+ class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
336
+ size = max(confusion_matrix.shape)
337
+ fig = ff.create_annotated_heatmap(z=confusion_matrix, x=class_names, y=class_names, colorscale='Cividis')
338
+ fig.update_layout(
339
+ width=size*140,
340
+ height=size*140,
341
+ xaxis=dict(title='Classe Prédite'),
342
+ yaxis=dict(title='Classe Réelle')
343
+ )
344
+ st.plotly_chart(fig)
345
+
346
+ def plot_CM_EfficientNet_B4():
347
+ confusion_lines = [
348
+ [177, 17, 10, 0],
349
+ [3, 148, 30, 0],
350
+ [1, 13, 151, 0],
351
+ [2, 1, 9, 158]
352
+ ]
353
+
354
+ confusion_matrix = np.array(confusion_lines)
355
+ class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
356
+ size = max(confusion_matrix.shape)
357
+ fig = ff.create_annotated_heatmap(z=confusion_matrix, x=class_names, y=class_names, colorscale='Cividis')
358
+ fig.update_layout(
359
+ width=size*140,
360
+ height=size*140,
361
+ xaxis=dict(title='Classe Prédite'),
362
+ yaxis=dict(title='Classe Réelle')
363
+ )
364
+ st.plotly_chart(fig)
365
+
366
+ def plot_CM_VGG16_FT():
367
+ confusion_lines = [
368
+ [229, 7, 6, 0],
369
+ [1, 198, 16, 1],
370
+ [7, 6, 199, 4],
371
+ [0, 0, 1, 225]
372
+ ]
373
+
374
+ confusion_matrix = np.array(confusion_lines)
375
+ class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
376
+ size = max(confusion_matrix.shape)
377
+ fig = ff.create_annotated_heatmap(z=confusion_matrix, x=class_names, y=class_names, colorscale='Cividis')
378
+ fig.update_layout(
379
+ width=size*140,
380
+ height=size*140,
381
+ xaxis=dict(title='Classe Prédite'),
382
+ yaxis=dict(title='Classe Réelle')
383
+ )
384
+ st.plotly_chart(fig)
385
+
386
+ def plot_CM_ResNetFT():
387
+
388
+ confusion_lines = [
389
+ [278, 6, 3, 1],
390
+ [3, 242, 16, 0],
391
+ [7, 38, 197, 17],
392
+ [2, 1, 0, 265]
393
+ ]
394
+
395
+ confusion_matrix = np.array(confusion_lines)
396
+ class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
397
+ size = max(confusion_matrix.shape)
398
+ fig = ff.create_annotated_heatmap(z=confusion_matrix, x=class_names, y=class_names, colorscale='Cividis')
399
+ fig.update_layout(
400
+ width=size*140,
401
+ height=size*140,
402
+ xaxis=dict(title='Classe Prédite'),
403
+ yaxis=dict(title='Classe Réelle')
404
+ )
405
+ st.plotly_chart(fig)
406
+
407
+ def plot_CM_DenseNetFT():
408
+ confusion_lines = [
409
+ [285, 1, 2, 0],
410
+ [3,235 , 23, 0],
411
+ [5, 17, 232, 5],
412
+ [0, 0, 3, 265]
413
+ ]
414
+
415
+ confusion_matrix = np.array(confusion_lines)
416
+ class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
417
+ size = max(confusion_matrix.shape)
418
+ fig = ff.create_annotated_heatmap(z=confusion_matrix, x=class_names, y=class_names, colorscale='Cividis')
419
+ fig.update_layout(
420
+ width=size*140,
421
+ height=size*140,
422
+ xaxis=dict(title='Classe Prédite'),
423
+ yaxis=dict(title='Classe Réelle')
424
+ )
425
+ st.plotly_chart(fig)
426
+
427
+ def plot_CM_ENetB4():
428
+ confusion_lines = [
429
+ [282, 2, 3, 1],
430
+ [4, 244, 13, 0],
431
+ [2, 22, 220, 15],
432
+ [1, 0, 0, 267]
433
+ ]
434
+
435
+ confusion_matrix = np.array(confusion_lines)
436
+ class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
437
+ size = max(confusion_matrix.shape)
438
+ fig = ff.create_annotated_heatmap(z=confusion_matrix, x=class_names, y=class_names, colorscale='Cividis')
439
+ fig.update_layout(
440
+ width=size*140,
441
+ height=size*140,
442
+ xaxis=dict(title='Classe Prédite'),
443
+ yaxis=dict(title='Classe Réelle')
444
+ )
445
+ st.plotly_chart(fig)
446
+
447
+ # GRAD-CAM
448
+ def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
449
+ model_output = model.output if isinstance(model.output, list) else [model.output]
450
+ grad_model = tf.keras.models.Model(
451
+ inputs=model.inputs,
452
+ outputs=[model.get_layer(last_conv_layer_name).output] + model_output
453
+ )
454
+
455
+ with tf.GradientTape() as tape:
456
+ last_conv_layer_output, preds = grad_model(img_array)
457
+ if pred_index is None:
458
+ pred_index = tf.argmax(preds[0])
459
+ class_channel = preds[:, pred_index]
460
+
461
+ grads = tape.gradient(class_channel, last_conv_layer_output)
462
+
463
+ pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
464
+
465
+ last_conv_layer_output = last_conv_layer_output[0]
466
+ heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
467
+ heatmap = tf.squeeze(heatmap)
468
+
469
+ heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
470
+ return heatmap.numpy()
471
+
472
+ def save_and_display_gradcam(img, heatmap, alpha=0.4):
473
+ heatmap = np.uint8(255 * heatmap)
474
+
475
+ jet = plt.cm.jet
476
+
477
+ jet_colors = jet(np.arange(256))[:, :3]
478
+ jet_heatmap = jet_colors[heatmap]
479
+
480
+ jet_heatmap = array_to_img(jet_heatmap)
481
+ jet_heatmap = jet_heatmap.resize((img.shape[1], img.shape[0]))
482
+ jet_heatmap = img_to_array(jet_heatmap)
483
+
484
+ superimposed_img = jet_heatmap * alpha + img
485
+ superimposed_img = array_to_img(superimposed_img)
486
+
487
+ return superimposed_img