analist commited on
Commit
ee9aa01
·
verified ·
1 Parent(s): 00994c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +190 -126
app.py CHANGED
@@ -96,148 +96,212 @@ def plot_feature_importance(model, feature_names, model_type):
96
  plt.title(f"Feature Importance - {model_type}")
97
  return plt.gcf()
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  def app():
100
- st.title("Interpréteur de Modèles ML")
 
101
 
102
  # Load data
103
  X_train, y_train, X_test, y_test, feature_names = load_data()
104
 
105
  # Train models if not in session state
106
  if 'model_results' not in st.session_state:
107
- with st.spinner("Entraînement des modèles en cours..."):
108
  st.session_state.model_results = train_models(X_train, y_train, X_test, y_test)
109
 
110
- # Sidebar
111
- st.sidebar.title("Navigation")
112
- selected_model = st.sidebar.selectbox(
113
- "Sélectionnez un modèle",
114
- list(st.session_state.model_results.keys())
115
- )
116
-
117
- page = st.sidebar.radio(
118
- "Sélectionnez une section",
119
- ["Performance des modèles",
120
- "Interprétation du modèle",
121
- "Analyse des caractéristiques",
122
- "Simulateur de prédictions"]
123
- )
 
 
 
124
 
125
  current_model = st.session_state.model_results[selected_model]['model']
126
 
127
- # Performance des modèles
128
- if page == "Performance des modèles":
129
- st.header("Performance des modèles")
130
-
131
- # Plot global performance comparison
132
- st.subheader("Comparaison des performances")
133
- performance_fig = plot_model_performance(st.session_state.model_results)
134
- st.pyplot(performance_fig)
135
-
136
- # Detailed metrics for selected model
137
- st.subheader(f"Métriques détaillées - {selected_model}")
138
- col1, col2 = st.columns(2)
139
-
140
- with col1:
141
- st.write("Métriques d'entraînement:")
142
- for metric, value in st.session_state.model_results[selected_model]['train_metrics'].items():
143
- st.write(f"{metric}: {value:.4f}")
144
-
145
- with col2:
146
- st.write("Métriques de test:")
147
- for metric, value in st.session_state.model_results[selected_model]['test_metrics'].items():
148
- st.write(f"{metric}: {value:.4f}")
149
-
150
- # Interprétation du modèle
151
- elif page == "Interprétation du modèle":
152
- st.header(f"Interprétation du modèle - {selected_model}")
153
-
154
- if selected_model in ["Decision Tree", "Random Forest"]:
155
- if selected_model == "Decision Tree":
156
- st.subheader("Visualisation de l'arbre")
157
- max_depth = st.slider("Profondeur maximale à afficher", 1, 5, 3)
158
- fig, ax = plt.subplots(figsize=(20, 10))
159
- plot_tree(current_model, feature_names=list(feature_names),
160
- max_depth=max_depth, filled=True, rounded=True)
161
- st.pyplot(fig)
162
 
163
- st.subheader("Règles de décision importantes")
164
- if selected_model == "Decision Tree":
165
- st.text(export_text(current_model, feature_names=list(feature_names)))
166
-
167
- # SHAP values for all models
168
- st.subheader("SHAP Values")
169
- with st.spinner("Calcul des valeurs SHAP en cours..."):
170
- explainer = shap.TreeExplainer(current_model) if selected_model != "Logistic Regression" \
171
- else shap.LinearExplainer(current_model, X_train)
172
- shap_values = explainer.shap_values(X_train[:100]) # Using first 100 samples for speed
173
 
174
- fig, ax = plt.subplots(figsize=(10, 6))
175
- shap.summary_plot(shap_values, X_train[:100], feature_names=list(feature_names),
176
- show=False)
177
- st.pyplot(fig)
178
-
179
- # Analyse des caractéristiques
180
- elif page == "Analyse des caractéristiques":
181
- st.header("Analyse des caractéristiques")
182
-
183
- # Feature importance
184
- st.subheader("Importance des caractéristiques")
185
- importance_fig = plot_feature_importance(current_model, feature_names, selected_model)
186
- st.pyplot(importance_fig)
187
-
188
- # Feature correlation
189
- st.subheader("Matrice de corrélation")
190
- corr_matrix = X_train.corr()
191
- fig, ax = plt.subplots(figsize=(10, 8))
192
- sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', center=0)
193
- st.pyplot(fig)
194
-
195
- # Simulateur de prédictions
196
- else:
197
- st.header("Simulateur de prédictions")
198
-
199
- input_values = {}
200
- for feature in feature_names:
201
- if X_train[feature].dtype == 'object':
202
- input_values[feature] = st.selectbox(
203
- f"Sélectionnez {feature}",
204
- options=X_train[feature].unique()
205
- )
206
- else:
207
- input_values[feature] = st.slider(
208
- f"Valeur pour {feature}",
209
- float(X_train[feature].min()),
210
- float(X_train[feature].max()),
211
- float(X_train[feature].mean())
212
- )
213
-
214
- if st.button("Prédire"):
215
- input_df = pd.DataFrame([input_values])
216
 
217
- prediction = current_model.predict_proba(input_df)
 
 
 
218
 
219
- st.write("Probabilités prédites:")
220
- st.write({f"Classe {i}": f"{prob:.2%}" for i, prob in enumerate(prediction[0])})
221
-
222
- if selected_model == "Decision Tree":
223
- st.subheader("Chemin de décision")
224
- node_indicator = current_model.decision_path(input_df)
225
- leaf_id = current_model.apply(input_df)
226
-
227
- node_index = node_indicator.indices[node_indicator.indptr[0]:node_indicator.indptr[1]]
228
-
229
- rules = []
230
- for node_id in node_index:
231
- if node_id != leaf_id[0]:
232
- threshold = current_model.tree_.threshold[node_id]
233
- feature = feature_names[current_model.tree_.feature[node_id]]
234
- if input_df.iloc[0][feature] <= threshold:
235
- rules.append(f"{feature} ≤ {threshold:.2f}")
236
- else:
237
- rules.append(f"{feature} > {threshold:.2f}")
238
-
239
- for rule in rules:
240
- st.write(rule)
241
 
242
  if __name__ == "__main__":
243
  app()
 
96
  plt.title(f"Feature Importance - {model_type}")
97
  return plt.gcf()
98
 
99
+ import streamlit as st
100
+ import pandas as pd
101
+ import numpy as np
102
+ import matplotlib.pyplot as plt
103
+ from sklearn.tree import plot_tree, export_text
104
+ import seaborn as sns
105
+ from sklearn.preprocessing import LabelEncoder
106
+ from sklearn.ensemble import RandomForestClassifier
107
+ from sklearn.tree import DecisionTreeClassifier
108
+ from sklearn.ensemble import GradientBoostingClassifier
109
+ from sklearn.linear_model import LogisticRegression
110
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, roc_curve
111
+ import shap
112
+
113
+ # Configuration de la page et du thème
114
+ st.set_page_config(
115
+ page_title="ML Model Interpreter",
116
+ layout="wide",
117
+ initial_sidebar_state="expanded"
118
+ )
119
+
120
+ # CSS personnalisé
121
+ st.markdown("""
122
+ <style>
123
+ /* Couleurs principales */
124
+ :root {
125
+ --primary-blue: #1E88E5;
126
+ --light-blue: #90CAF9;
127
+ --dark-blue: #0D47A1;
128
+ --white: #FFFFFF;
129
+ }
130
+
131
+ /* En-tête principal */
132
+ .main-header {
133
+ color: var(--dark-blue);
134
+ text-align: center;
135
+ padding: 1rem;
136
+ background: linear-gradient(90deg, var(--white) 0%, var(--light-blue) 50%, var(--white) 100%);
137
+ border-radius: 10px;
138
+ margin-bottom: 2rem;
139
+ }
140
+
141
+ /* Carte pour les métriques */
142
+ .metric-card {
143
+ background-color: white;
144
+ padding: 1.5rem;
145
+ border-radius: 10px;
146
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
147
+ margin-bottom: 1rem;
148
+ }
149
+
150
+ /* Style pour les sous-titres */
151
+ .sub-header {
152
+ color: var(--primary-blue);
153
+ border-bottom: 2px solid var(--light-blue);
154
+ padding-bottom: 0.5rem;
155
+ margin-bottom: 1rem;
156
+ }
157
+
158
+ /* Style pour les valeurs de métriques */
159
+ .metric-value {
160
+ font-size: 1.5rem;
161
+ font-weight: bold;
162
+ color: var(--primary-blue);
163
+ }
164
+
165
+ /* Style pour la barre latérale */
166
+ .sidebar .sidebar-content {
167
+ background-color: var(--white);
168
+ }
169
+
170
+ /* Style pour les boutons */
171
+ .stButton > button {
172
+ background-color: var(--primary-blue);
173
+ color: white;
174
+ border-radius: 5px;
175
+ border: none;
176
+ padding: 0.5rem 1rem;
177
+ }
178
+
179
+ /* Style pour les sliders */
180
+ .stSlider > div > div {
181
+ background-color: var(--light-blue);
182
+ }
183
+
184
+ /* Style pour les selectbox */
185
+ .stSelectbox > div > div {
186
+ background-color: white;
187
+ border: 1px solid var(--light-blue);
188
+ }
189
+ </style>
190
+ """, unsafe_allow_html=True)
191
+
192
+ def custom_metric_card(title, value, prefix=""):
193
+ return f"""
194
+ <div class="metric-card">
195
+ <h3 style="color: #1E88E5; margin-bottom: 0.5rem;">{title}</h3>
196
+ <p class="metric-value">{prefix}{value:.4f}</p>
197
+ </div>
198
+ """
199
+
200
+ def plot_with_style(fig):
201
+ # Style matplotlib
202
+ plt.style.use('seaborn')
203
+ fig.patch.set_facecolor('#FFFFFF')
204
+ for ax in fig.axes:
205
+ ax.set_facecolor('#F8F9FA')
206
+ ax.grid(True, linestyle='--', alpha=0.7)
207
+ ax.spines['top'].set_visible(False)
208
+ ax.spines['right'].set_visible(False)
209
+ return fig
210
+
211
+ # [Fonctions load_data et train_models restent identiques]
212
+
213
+ def plot_model_performance(results):
214
+ metrics = ['accuracy', 'f1', 'precision', 'recall', 'roc_auc']
215
+ fig, axes = plt.subplots(1, 2, figsize=(15, 6))
216
+
217
+ # Configuration du style
218
+ plt.style.use('seaborn')
219
+ colors = ['#1E88E5', '#90CAF9', '#0D47A1', '#42A5F5']
220
+
221
+ # Training metrics
222
+ train_data = {model: [results[model]['train_metrics'][metric] for metric in metrics]
223
+ for model in results.keys()}
224
+ train_df = pd.DataFrame(train_data, index=metrics)
225
+ train_df.plot(kind='bar', ax=axes[0], title='Performance d\'Entraînement',
226
+ color=colors)
227
+ axes[0].set_ylim(0, 1)
228
+
229
+ # Test metrics
230
+ test_data = {model: [results[model]['test_metrics'][metric] for metric in metrics]
231
+ for model in results.keys()}
232
+ test_df = pd.DataFrame(test_data, index=metrics)
233
+ test_df.plot(kind='bar', ax=axes[1], title='Performance de Test',
234
+ color=colors)
235
+ axes[1].set_ylim(0, 1)
236
+
237
+ # Style des graphiques
238
+ for ax in axes:
239
+ ax.set_facecolor('#F8F9FA')
240
+ ax.grid(True, linestyle='--', alpha=0.7)
241
+ ax.spines['top'].set_visible(False)
242
+ ax.spines['right'].set_visible(False)
243
+ plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
244
+
245
+ plt.tight_layout()
246
+ return fig
247
+
248
  def app():
249
+ # En-tête principal avec style personnalisé
250
+ st.markdown('<h1 class="main-header">Interpréteur de Modèles ML</h1>', unsafe_allow_html=True)
251
 
252
  # Load data
253
  X_train, y_train, X_test, y_test, feature_names = load_data()
254
 
255
  # Train models if not in session state
256
  if 'model_results' not in st.session_state:
257
+ with st.spinner("🔄 Entraînement des modèles en cours..."):
258
  st.session_state.model_results = train_models(X_train, y_train, X_test, y_test)
259
 
260
+ # Sidebar avec style personnalisé
261
+ with st.sidebar:
262
+ st.markdown('<h2 style="color: #1E88E5;">Navigation</h2>', unsafe_allow_html=True)
263
+ selected_model = st.selectbox(
264
+ "📊 Sélectionnez un modèle",
265
+ list(st.session_state.model_results.keys())
266
+ )
267
+
268
+ st.markdown('<hr style="margin: 1rem 0;">', unsafe_allow_html=True)
269
+
270
+ page = st.radio(
271
+ "📑 Sélectionnez une section",
272
+ ["Performance des modèles",
273
+ "Interprétation du modèle",
274
+ "Analyse des caractéristiques",
275
+ "Simulateur de prédictions"]
276
+ )
277
 
278
  current_model = st.session_state.model_results[selected_model]['model']
279
 
280
+ # Container principal avec padding
281
+ main_container = st.container()
282
+ with main_container:
283
+ if page == "Performance des modèles":
284
+ st.markdown('<h2 class="sub-header">Performance des modèles</h2>', unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
 
286
+ # Graphiques de performance
287
+ performance_fig = plot_model_performance(st.session_state.model_results)
288
+ st.pyplot(plot_with_style(performance_fig))
 
 
 
 
 
 
 
289
 
290
+ # Métriques détaillées dans des cartes
291
+ st.markdown('<h3 class="sub-header">Métriques détaillées</h3>', unsafe_allow_html=True)
292
+ col1, col2 = st.columns(2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
 
294
+ with col1:
295
+ st.markdown('<h4 style="color: #1E88E5;">Entraînement</h4>', unsafe_allow_html=True)
296
+ for metric, value in st.session_state.model_results[selected_model]['train_metrics'].items():
297
+ st.markdown(custom_metric_card(metric.capitalize(), value), unsafe_allow_html=True)
298
 
299
+ with col2:
300
+ st.markdown('<h4 style="color: #1E88E5;">Test</h4>', unsafe_allow_html=True)
301
+ for metric, value in st.session_state.model_results[selected_model]['test_metrics'].items():
302
+ st.markdown(custom_metric_card(metric.capitalize(), value), unsafe_allow_html=True)
303
+
304
+ # [Le reste des sections avec style adapté...]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
 
306
  if __name__ == "__main__":
307
  app()