Commit
•
2ec1466
1
Parent(s):
062b83a
Add costom_functions
Browse files- custom_functions.py +487 -0
custom_functions.py
ADDED
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import uuid
|
3 |
+
import re
|
4 |
+
import numpy as np
|
5 |
+
import cv2
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import matplotlib.image as mpimg
|
8 |
+
import plotly.graph_objects as go
|
9 |
+
import plotly.figure_factory as ff
|
10 |
+
import tensorflow as tf
|
11 |
+
import keras
|
12 |
+
|
13 |
+
from keras.layers import Conv2D
|
14 |
+
from tensorflow.keras.preprocessing.image import img_to_array, array_to_img
|
15 |
+
from keras.models import load_model
|
16 |
+
|
17 |
+
# Afficher le nom et l'icone des réseaux d'un membre
|
18 |
+
def show_profile(name, linkedin_url, github_url):
|
19 |
+
st.markdown("""
|
20 |
+
<style>
|
21 |
+
.icon-img {
|
22 |
+
height: 20px; # Taille des icônes
|
23 |
+
width: 20px;
|
24 |
+
margin-top: 0px;
|
25 |
+
margin-left: 5px;
|
26 |
+
margin-right:5px;
|
27 |
+
}
|
28 |
+
.name-with-icons {
|
29 |
+
display: flex-shrink;
|
30 |
+
align-items: center;
|
31 |
+
margin-top: -5px;
|
32 |
+
margin-bottom: 0px;
|
33 |
+
}
|
34 |
+
.name-with-icons:first-child {
|
35 |
+
margin-top: -5px; /* Marge négative pour le premier élément */
|
36 |
+
}
|
37 |
+
</style>
|
38 |
+
""", unsafe_allow_html = True)
|
39 |
+
|
40 |
+
st.markdown(f"""
|
41 |
+
<div class='name-with-icons'>
|
42 |
+
<a href='{github_url}' target='_blank'><img class='icon-img' src='https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg'></a>
|
43 |
+
<a href='{linkedin_url}' target='_blank'><img class='icon-img' src='https://upload.wikimedia.org/wikipedia/commons/c/ca/LinkedIn_logo_initials.png'></a>
|
44 |
+
{name}
|
45 |
+
</div>
|
46 |
+
""", unsafe_allow_html = True)
|
47 |
+
|
48 |
+
# Création de bloc de texte esthétiques
|
49 |
+
def create_styled_box(text, text_color, background_color, alignment = 'left', display = 'block'):
|
50 |
+
unique_id = uuid.uuid4().hex
|
51 |
+
# Ajout de styles CSS pour le cadre avec une couleur de fond et une couleur de texte personnalisée
|
52 |
+
st.markdown(f"""
|
53 |
+
<style>
|
54 |
+
.styled-box-{unique_id} {{
|
55 |
+
background-color: {background_color}; /* couleur de fond */
|
56 |
+
color: {text_color}; /* couleur de texte */
|
57 |
+
padding: 10px 20px; /* espace intérieur vertical et horizontal */
|
58 |
+
border: 1px solid {text_color}; /* bordure de la même couleur que le texte */
|
59 |
+
border-radius: 8px; /* bord arrondi */
|
60 |
+
font-size: 16px; /* taille de la police */
|
61 |
+
font-weight: regular; /* format de la police */
|
62 |
+
box-shadow: 0 4px 8px rgba(0,0,0,0.2); /* ombre pour donner de la profondeur */
|
63 |
+
text-align: {alignment}; /* alignement du texte */
|
64 |
+
display: {display};
|
65 |
+
}}
|
66 |
+
</style>
|
67 |
+
""", unsafe_allow_html = True)
|
68 |
+
|
69 |
+
# Afficher le texte dans le cadre stylisé
|
70 |
+
st.markdown(f"""
|
71 |
+
<div class="styled-box-{unique_id}">{text}</div>
|
72 |
+
""", unsafe_allow_html = True)
|
73 |
+
|
74 |
+
# Fonction pour calculer l'intensité lumineuse moyenne d'une image
|
75 |
+
def calc_mean_intensity(image_path):
|
76 |
+
img = mpimg.imread(image_path)
|
77 |
+
# Convertir en nuances de gris si l'image est en couleur
|
78 |
+
if img.ndim == 3:
|
79 |
+
img_gray = np.dot(img[...,:3], [0.2989, 0.5870, 0.1140])
|
80 |
+
else:
|
81 |
+
img_gray = img
|
82 |
+
return np.mean(img_gray)
|
83 |
+
|
84 |
+
# Fonction pour extraire les sources depuis les urls
|
85 |
+
def source_extract(url):
|
86 |
+
pattern = re.compile(r'https?://(?:www\.)?([^/]+)')
|
87 |
+
match = pattern.search(url)
|
88 |
+
if match:
|
89 |
+
return match.group(1)
|
90 |
+
else:
|
91 |
+
return None
|
92 |
+
|
93 |
+
# Fonctions de plot des métriques
|
94 |
+
def plot_loss_curve(history):
|
95 |
+
fig = go.Figure()
|
96 |
+
fig.add_trace(go.Scatter(x = list(range(len(history['loss']))),
|
97 |
+
y = history['loss'],
|
98 |
+
mode = 'lines+markers',
|
99 |
+
name = 'Perte d\'entraînement',
|
100 |
+
marker = dict(color = 'lightblue')))
|
101 |
+
|
102 |
+
fig.add_trace(go.Scatter(x = list(range(len(history['val_loss']))),
|
103 |
+
y = history['val_loss'],
|
104 |
+
mode = 'lines+markers',
|
105 |
+
name = 'Perte de validation',
|
106 |
+
marker = dict(color = 'salmon')))
|
107 |
+
|
108 |
+
fig.update_layout(title = dict(text = "Courbe de Perte", font = dict(color = 'white')),
|
109 |
+
xaxis_title = dict(text = "Époque", font = dict(color = 'white')),
|
110 |
+
yaxis_title = dict(text = "Perte", font = dict(color = 'white')),
|
111 |
+
template = 'plotly_white',
|
112 |
+
paper_bgcolor = 'rgba(0,0,0,0)',
|
113 |
+
plot_bgcolor = 'rgba(0,0,0,0)',
|
114 |
+
legend = dict(font = dict(color = 'white')))
|
115 |
+
st.plotly_chart(fig)
|
116 |
+
|
117 |
+
def plot_precision_curve(history):
|
118 |
+
fig = go.Figure()
|
119 |
+
fig.add_trace(go.Scatter(x = list(range(len(history['precision']))),
|
120 |
+
y = history['precision'],
|
121 |
+
mode = 'lines+markers',
|
122 |
+
name = "Précision d'entraînement",
|
123 |
+
marker = dict(color = 'lightblue')))
|
124 |
+
|
125 |
+
fig.add_trace(go.Scatter(x = list(range(len(history['val_precision']))),
|
126 |
+
y = history['val_precision'],
|
127 |
+
mode = 'lines+markers',
|
128 |
+
name = 'Précision de validation',
|
129 |
+
marker = dict(color = 'salmon')))
|
130 |
+
|
131 |
+
fig.update_layout(title = dict(text = "Courbe de Précision", font = dict(color = 'white')),
|
132 |
+
xaxis_title=dict(text = "Époque", font = dict(color = 'white')),
|
133 |
+
yaxis_title=dict(text = "Précision", font = dict(color = 'white')),
|
134 |
+
template = 'plotly_white',
|
135 |
+
paper_bgcolor = 'rgba(0,0,0,0)',
|
136 |
+
plot_bgcolor = 'rgba(0,0,0,0)',
|
137 |
+
legend = dict(font = dict(color='white')))
|
138 |
+
st.plotly_chart(fig)
|
139 |
+
|
140 |
+
def plot_auc(history):
|
141 |
+
fig = go.Figure()
|
142 |
+
fig.add_trace(go.Scatter(x = list(range(len(history['auc']))),
|
143 |
+
y = history['auc'],
|
144 |
+
mode = 'lines+markers',
|
145 |
+
name = "AUC moyen d'entraînement",
|
146 |
+
marker = dict(color='lightblue')))
|
147 |
+
|
148 |
+
fig.add_trace(go.Scatter(x = list(range(len(history['val_auc']))),
|
149 |
+
y = history['val_auc'],
|
150 |
+
mode = 'lines+markers',
|
151 |
+
name = 'AUC moyen de validation',
|
152 |
+
marker=dict(color = 'salmon')))
|
153 |
+
|
154 |
+
fig.update_layout(title = "Courbe de AUC-ROC",
|
155 |
+
xaxis_title = "Époque",
|
156 |
+
yaxis_title = "Area Under Curve",
|
157 |
+
template = 'plotly_white',
|
158 |
+
paper_bgcolor = 'rgba(0,0,0,0)',
|
159 |
+
plot_bgcolor = 'rgba(0,0,0,0)',
|
160 |
+
legend = dict(font = dict(color = 'white')),
|
161 |
+
xaxis = dict(tickfont = dict(color = 'white')),
|
162 |
+
yaxis = dict(tickfont = dict(color = 'white')),
|
163 |
+
title_font = dict(color = 'white'))
|
164 |
+
st.plotly_chart(fig)
|
165 |
+
|
166 |
+
def plot_f1_score(history):
|
167 |
+
fig = go.Figure()
|
168 |
+
fig.add_trace(go.Scatter(x = list(range(len(history['f1_score']))),
|
169 |
+
y=np.mean(history['f1_score'], axis = 1),
|
170 |
+
mode = 'lines+markers',
|
171 |
+
name = "F1 Score d'entraînement",
|
172 |
+
marker = dict(color = 'lightblue')))
|
173 |
+
|
174 |
+
fig.add_trace(go.Scatter(x = list(range(len(history['val_f1_score']))),
|
175 |
+
y = np.mean(history['val_f1_score'], axis = 1),
|
176 |
+
mode = 'lines+markers',
|
177 |
+
name = 'F1 Score moyen de validation',
|
178 |
+
marker = dict(color = 'salmon')))
|
179 |
+
|
180 |
+
fig.update_layout(title = "Courbe de F1 Score",
|
181 |
+
xaxis_title = "Époque",
|
182 |
+
yaxis_title = "F1 Score",
|
183 |
+
template = 'plotly_white',
|
184 |
+
paper_bgcolor = 'rgba(0,0,0,0)',
|
185 |
+
plot_bgcolor = 'rgba(0,0,0,0)',
|
186 |
+
legend = dict(font = dict(color = 'white')),
|
187 |
+
xaxis = dict(tickfont = dict(color = 'white')),
|
188 |
+
yaxis = dict(tickfont = dict(color = 'white')),
|
189 |
+
title_font = dict(color = 'white'))
|
190 |
+
st.plotly_chart(fig)
|
191 |
+
|
192 |
+
# Plot des matrices de confusion
|
193 |
+
def plot_CM(matrix):
|
194 |
+
confusion_matrix = np.array(matrix)
|
195 |
+
class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
|
196 |
+
st.title('Matrice de Confusion')
|
197 |
+
fig = ff.create_annotated_heatmap(z = confusion_matrix, x = class_names, y = class_names, colorscale = 'RdBu')
|
198 |
+
fig.update_layout(
|
199 |
+
title = 'Matrice de Confusion',
|
200 |
+
xaxis = dict(title = 'Classe Prédite'),
|
201 |
+
yaxis = dict(title = 'Classe Réelle')
|
202 |
+
)
|
203 |
+
st.plotly_chart(fig)
|
204 |
+
|
205 |
+
def plot_CM_ResNetV2():
|
206 |
+
confusion_lines = [
|
207 |
+
[192, 3, 7, 2],
|
208 |
+
[9, 148, 24, 0],
|
209 |
+
[5, 17, 139, 4],
|
210 |
+
[1, 0, 4, 165]
|
211 |
+
]
|
212 |
+
|
213 |
+
confusion_matrix = np.array(confusion_lines)
|
214 |
+
class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
|
215 |
+
size = max(confusion_matrix.shape)
|
216 |
+
fig = ff.create_annotated_heatmap(z=confusion_matrix, x=class_names, y=class_names, colorscale='Cividis')
|
217 |
+
fig.update_layout(
|
218 |
+
width=size*140,
|
219 |
+
height=size*140,
|
220 |
+
xaxis=dict(title='Classe Prédite'),
|
221 |
+
yaxis=dict(title='Classe Réelle')
|
222 |
+
)
|
223 |
+
st.plotly_chart(fig)
|
224 |
+
|
225 |
+
def plot_CM_ResNet121():
|
226 |
+
confusion_lines = [
|
227 |
+
[181, 11, 11, 1],
|
228 |
+
[8, 148, 25, 0],
|
229 |
+
[10, 17, 135, 3],
|
230 |
+
[2, 0, 1, 167]
|
231 |
+
]
|
232 |
+
|
233 |
+
confusion_matrix = np.array(confusion_lines)
|
234 |
+
class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
|
235 |
+
size = max(confusion_matrix.shape)
|
236 |
+
fig = ff.create_annotated_heatmap(z=confusion_matrix, x=class_names, y=class_names, colorscale='Cividis')
|
237 |
+
fig.update_layout(
|
238 |
+
width=size*140,
|
239 |
+
height=size*140,
|
240 |
+
xaxis=dict(title='Classe Prédite'),
|
241 |
+
yaxis=dict(title='Classe Réelle')
|
242 |
+
)
|
243 |
+
st.plotly_chart(fig)
|
244 |
+
|
245 |
+
def plot_CM_DenseNet201():
|
246 |
+
confusion_lines = [
|
247 |
+
[190, 5, 7, 2],
|
248 |
+
[6, 156, 19, 0],
|
249 |
+
[4, 21, 139, 1],
|
250 |
+
[1, 0, 0, 169]
|
251 |
+
]
|
252 |
+
|
253 |
+
confusion_matrix = np.array(confusion_lines)
|
254 |
+
class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
|
255 |
+
size = max(confusion_matrix.shape)
|
256 |
+
fig = ff.create_annotated_heatmap(z=confusion_matrix, x=class_names, y=class_names, colorscale='Cividis')
|
257 |
+
fig.update_layout(
|
258 |
+
width=size*140,
|
259 |
+
height=size*140,
|
260 |
+
xaxis=dict(title='Classe Prédite'),
|
261 |
+
yaxis=dict(title='Classe Réelle')
|
262 |
+
)
|
263 |
+
st.plotly_chart(fig)
|
264 |
+
|
265 |
+
def plot_CM_VGG16():
|
266 |
+
|
267 |
+
confusion_lines = [
|
268 |
+
[178, 5, 9, 2],
|
269 |
+
[4, 152, 17, 0],
|
270 |
+
[2, 11, 160, 3],
|
271 |
+
[0, 0, 4, 175]
|
272 |
+
]
|
273 |
+
|
274 |
+
confusion_matrix = np.array(confusion_lines)
|
275 |
+
class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
|
276 |
+
size = max(confusion_matrix.shape)
|
277 |
+
fig = ff.create_annotated_heatmap(z=confusion_matrix, x=class_names, y=class_names, colorscale='Cividis')
|
278 |
+
fig.update_layout(
|
279 |
+
width=size*140,
|
280 |
+
height=size*140,
|
281 |
+
xaxis=dict(title='Classe Prédite'),
|
282 |
+
yaxis=dict(title='Classe Réelle')
|
283 |
+
)
|
284 |
+
st.plotly_chart(fig)
|
285 |
+
|
286 |
+
def plot_CM_VGG19():
|
287 |
+
confusion_lines = [
|
288 |
+
[182, 7, 3, 0],
|
289 |
+
[7, 158, 8, 0],
|
290 |
+
[8, 21, 142, 5],
|
291 |
+
[1, 1, 3, 174]
|
292 |
+
]
|
293 |
+
|
294 |
+
confusion_matrix = np.array(confusion_lines)
|
295 |
+
class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
|
296 |
+
size = max(confusion_matrix.shape)
|
297 |
+
fig = ff.create_annotated_heatmap(z=confusion_matrix, x=class_names, y=class_names, colorscale='Cividis')
|
298 |
+
fig.update_layout(
|
299 |
+
width=size*140,
|
300 |
+
height=size*140,
|
301 |
+
xaxis=dict(title='Classe Prédite'),
|
302 |
+
yaxis=dict(title='Classe Réelle')
|
303 |
+
)
|
304 |
+
st.plotly_chart(fig)
|
305 |
+
|
306 |
+
def plot_CM_ConvnextTiny():
|
307 |
+
confusion_lines = [
|
308 |
+
[122, 11, 19, 0],
|
309 |
+
[13, 142, 15, 0],
|
310 |
+
[17, 14, 144, 1],
|
311 |
+
[4, 1, 7, 168]
|
312 |
+
]
|
313 |
+
|
314 |
+
confusion_matrix = np.array(confusion_lines)
|
315 |
+
class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
|
316 |
+
size = max(confusion_matrix.shape)
|
317 |
+
fig = ff.create_annotated_heatmap(z=confusion_matrix, x=class_names, y=class_names, colorscale='Cividis')
|
318 |
+
fig.update_layout(
|
319 |
+
width=size*140,
|
320 |
+
height=size*140,
|
321 |
+
xaxis=dict(title='Classe Prédite'),
|
322 |
+
yaxis=dict(title='Classe Réelle')
|
323 |
+
)
|
324 |
+
st.plotly_chart(fig)
|
325 |
+
|
326 |
+
def plot_CM_ConvnextBase():
|
327 |
+
confusion_lines = [
|
328 |
+
[168, 9, 15, 0],
|
329 |
+
[9, 152, 12, 0],
|
330 |
+
[12, 8, 153, 3],
|
331 |
+
[2, 0, 8, 169]
|
332 |
+
]
|
333 |
+
|
334 |
+
confusion_matrix = np.array(confusion_lines)
|
335 |
+
class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
|
336 |
+
size = max(confusion_matrix.shape)
|
337 |
+
fig = ff.create_annotated_heatmap(z=confusion_matrix, x=class_names, y=class_names, colorscale='Cividis')
|
338 |
+
fig.update_layout(
|
339 |
+
width=size*140,
|
340 |
+
height=size*140,
|
341 |
+
xaxis=dict(title='Classe Prédite'),
|
342 |
+
yaxis=dict(title='Classe Réelle')
|
343 |
+
)
|
344 |
+
st.plotly_chart(fig)
|
345 |
+
|
346 |
+
def plot_CM_EfficientNet_B4():
|
347 |
+
confusion_lines = [
|
348 |
+
[177, 17, 10, 0],
|
349 |
+
[3, 148, 30, 0],
|
350 |
+
[1, 13, 151, 0],
|
351 |
+
[2, 1, 9, 158]
|
352 |
+
]
|
353 |
+
|
354 |
+
confusion_matrix = np.array(confusion_lines)
|
355 |
+
class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
|
356 |
+
size = max(confusion_matrix.shape)
|
357 |
+
fig = ff.create_annotated_heatmap(z=confusion_matrix, x=class_names, y=class_names, colorscale='Cividis')
|
358 |
+
fig.update_layout(
|
359 |
+
width=size*140,
|
360 |
+
height=size*140,
|
361 |
+
xaxis=dict(title='Classe Prédite'),
|
362 |
+
yaxis=dict(title='Classe Réelle')
|
363 |
+
)
|
364 |
+
st.plotly_chart(fig)
|
365 |
+
|
366 |
+
def plot_CM_VGG16_FT():
|
367 |
+
confusion_lines = [
|
368 |
+
[229, 7, 6, 0],
|
369 |
+
[1, 198, 16, 1],
|
370 |
+
[7, 6, 199, 4],
|
371 |
+
[0, 0, 1, 225]
|
372 |
+
]
|
373 |
+
|
374 |
+
confusion_matrix = np.array(confusion_lines)
|
375 |
+
class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
|
376 |
+
size = max(confusion_matrix.shape)
|
377 |
+
fig = ff.create_annotated_heatmap(z=confusion_matrix, x=class_names, y=class_names, colorscale='Cividis')
|
378 |
+
fig.update_layout(
|
379 |
+
width=size*140,
|
380 |
+
height=size*140,
|
381 |
+
xaxis=dict(title='Classe Prédite'),
|
382 |
+
yaxis=dict(title='Classe Réelle')
|
383 |
+
)
|
384 |
+
st.plotly_chart(fig)
|
385 |
+
|
386 |
+
def plot_CM_ResNetFT():
|
387 |
+
|
388 |
+
confusion_lines = [
|
389 |
+
[278, 6, 3, 1],
|
390 |
+
[3, 242, 16, 0],
|
391 |
+
[7, 38, 197, 17],
|
392 |
+
[2, 1, 0, 265]
|
393 |
+
]
|
394 |
+
|
395 |
+
confusion_matrix = np.array(confusion_lines)
|
396 |
+
class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
|
397 |
+
size = max(confusion_matrix.shape)
|
398 |
+
fig = ff.create_annotated_heatmap(z=confusion_matrix, x=class_names, y=class_names, colorscale='Cividis')
|
399 |
+
fig.update_layout(
|
400 |
+
width=size*140,
|
401 |
+
height=size*140,
|
402 |
+
xaxis=dict(title='Classe Prédite'),
|
403 |
+
yaxis=dict(title='Classe Réelle')
|
404 |
+
)
|
405 |
+
st.plotly_chart(fig)
|
406 |
+
|
407 |
+
def plot_CM_DenseNetFT():
|
408 |
+
confusion_lines = [
|
409 |
+
[285, 1, 2, 0],
|
410 |
+
[3,235 , 23, 0],
|
411 |
+
[5, 17, 232, 5],
|
412 |
+
[0, 0, 3, 265]
|
413 |
+
]
|
414 |
+
|
415 |
+
confusion_matrix = np.array(confusion_lines)
|
416 |
+
class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
|
417 |
+
size = max(confusion_matrix.shape)
|
418 |
+
fig = ff.create_annotated_heatmap(z=confusion_matrix, x=class_names, y=class_names, colorscale='Cividis')
|
419 |
+
fig.update_layout(
|
420 |
+
width=size*140,
|
421 |
+
height=size*140,
|
422 |
+
xaxis=dict(title='Classe Prédite'),
|
423 |
+
yaxis=dict(title='Classe Réelle')
|
424 |
+
)
|
425 |
+
st.plotly_chart(fig)
|
426 |
+
|
427 |
+
def plot_CM_ENetB4():
|
428 |
+
confusion_lines = [
|
429 |
+
[282, 2, 3, 1],
|
430 |
+
[4, 244, 13, 0],
|
431 |
+
[2, 22, 220, 15],
|
432 |
+
[1, 0, 0, 267]
|
433 |
+
]
|
434 |
+
|
435 |
+
confusion_matrix = np.array(confusion_lines)
|
436 |
+
class_names = ['Covid', 'Lung opacity', 'Normal', 'Viral Pneumonia']
|
437 |
+
size = max(confusion_matrix.shape)
|
438 |
+
fig = ff.create_annotated_heatmap(z=confusion_matrix, x=class_names, y=class_names, colorscale='Cividis')
|
439 |
+
fig.update_layout(
|
440 |
+
width=size*140,
|
441 |
+
height=size*140,
|
442 |
+
xaxis=dict(title='Classe Prédite'),
|
443 |
+
yaxis=dict(title='Classe Réelle')
|
444 |
+
)
|
445 |
+
st.plotly_chart(fig)
|
446 |
+
|
447 |
+
# GRAD-CAM
|
448 |
+
def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
|
449 |
+
model_output = model.output if isinstance(model.output, list) else [model.output]
|
450 |
+
grad_model = tf.keras.models.Model(
|
451 |
+
inputs=model.inputs,
|
452 |
+
outputs=[model.get_layer(last_conv_layer_name).output] + model_output
|
453 |
+
)
|
454 |
+
|
455 |
+
with tf.GradientTape() as tape:
|
456 |
+
last_conv_layer_output, preds = grad_model(img_array)
|
457 |
+
if pred_index is None:
|
458 |
+
pred_index = tf.argmax(preds[0])
|
459 |
+
class_channel = preds[:, pred_index]
|
460 |
+
|
461 |
+
grads = tape.gradient(class_channel, last_conv_layer_output)
|
462 |
+
|
463 |
+
pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
|
464 |
+
|
465 |
+
last_conv_layer_output = last_conv_layer_output[0]
|
466 |
+
heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
|
467 |
+
heatmap = tf.squeeze(heatmap)
|
468 |
+
|
469 |
+
heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
|
470 |
+
return heatmap.numpy()
|
471 |
+
|
472 |
+
def save_and_display_gradcam(img, heatmap, alpha=0.4):
|
473 |
+
heatmap = np.uint8(255 * heatmap)
|
474 |
+
|
475 |
+
jet = plt.cm.jet
|
476 |
+
|
477 |
+
jet_colors = jet(np.arange(256))[:, :3]
|
478 |
+
jet_heatmap = jet_colors[heatmap]
|
479 |
+
|
480 |
+
jet_heatmap = array_to_img(jet_heatmap)
|
481 |
+
jet_heatmap = jet_heatmap.resize((img.shape[1], img.shape[0]))
|
482 |
+
jet_heatmap = img_to_array(jet_heatmap)
|
483 |
+
|
484 |
+
superimposed_img = jet_heatmap * alpha + img
|
485 |
+
superimposed_img = array_to_img(superimposed_img)
|
486 |
+
|
487 |
+
return superimposed_img
|