drkareemkamal commited on
Commit
ce08beb
ยท
verified ยท
1 Parent(s): 62e030b

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +312 -116
src/streamlit_app.py CHANGED
@@ -1,131 +1,327 @@
1
-
2
  import streamlit as st
3
- import matplotlib.pyplot as plt
4
- import joblib
5
  import pandas as pd
6
- import json
7
  import numpy as np
8
- import pandas as pd
9
  import joblib
 
 
 
 
10
 
11
- DEFAULT_MATERIALS = ["Local Anaesthesia","Dry Needle","Botox","Saline","Magnesium","PRF"]
12
-
13
- def load_materials(path="material_list.json"):
14
- try:
15
- with open(path, "r", encoding="utf-8") as f:
16
- mats = json.load(f).get("materials", [])
17
- return mats if mats else DEFAULT_MATERIALS
18
- except FileNotFoundError:
19
- return DEFAULT_MATERIALS
20
-
21
- def build_baseline_from_form(row_dict):
22
- # Convert to single-row DataFrame
23
- df = pd.DataFrame([row_dict])
24
- # Ensure presence of all baseline columns model might expect.
25
- # (The pipeline has imputation & one-hot; we just pass what's available.)
26
- # Guarantee numeric types for core measures
27
- for c in ["age","pain_m0","mmo_m0","ohip_14_m0"]:
28
- if c in df.columns:
29
- df[c] = pd.to_numeric(df[c], errors="coerce")
30
- # previous_injection as 0/1
31
- if "previous_injection" in df.columns:
32
- df["previous_injection"] = pd.to_numeric(df["previous_injection"], errors="coerce")
33
- # include placeholder for time2_days if training had it
34
- if "time2_days" not in df.columns:
35
- df["time2_days"] = np.nan
36
- # Fill optional text columns if missing
37
- for tcol in ["location","muscle_injected","adjunctive_tretment","material_in_previous_injection"]:
38
- if tcol not in df.columns:
39
- df[tcol] = ""
40
- return df
41
-
42
- def score_material(model, row_dict, material):
43
- base = build_baseline_from_form(row_dict)
44
- base["material_injected"] = material
45
- proba = float(model.predict_proba(base)[0,1])
46
- return proba
47
-
48
- def rank_materials(model, row_dict, materials):
49
- rows = []
50
- for m in materials:
51
- p = score_material(model, row_dict, m)
52
- rows.append((m, p))
53
- rows.sort(key=lambda x: x[1], reverse=True)
54
- return rows
55
-
56
- st.set_page_config(page_title="TMJ Success Predictor (Material-specific)", page_icon="๐Ÿฆท", layout="wide")
57
- st.title("๐Ÿฆท TMJ Success Predictor โ€” Material-specific")
58
- st.caption("Predicts 3-month success probability **conditioned on the selected injection material**. Also compares across all materials.")
59
-
60
- MODEL_PATH = "src/best_tmj_success_classifier.pkl"
61
- MATS_PATH = "src/material_list.json"
62
 
 
63
  @st.cache_resource
64
  def load_artifacts():
65
- model = joblib.load(MODEL_PATH)
66
- mats = load_materials(MATS_PATH)
67
- return model, mats
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- try:
70
- model, materials = load_artifacts()
71
- except Exception as e:
72
- st.error(f"Failed to load model: {e}")
73
- st.stop()
 
74
 
75
- with st.form("frm"):
76
- st.subheader("Baseline inputs")
77
- c1,c2,c3 = st.columns(3)
78
- with c1:
79
- sex = st.selectbox("Sex", ["male","female","unknown"], index=1)
80
- age = st.number_input("Age", 10, 100, value=30)
81
- with c2:
82
- pain_m0 = st.number_input("Pain M0 (0โ€“10)", min_value=0.0, max_value=10.0, value=7.0, step=0.1)
83
- mmo_m0 = st.number_input("MMO M0 (mm)", min_value=0.0, max_value=80.0, value=35.0, step=1.0)
84
- with c3:
85
- ohip_14_m0 = st.number_input("OHIP-14 M0 (0โ€“56)", min_value=0.0, max_value=56.0, value=20.0, step=1.0)
86
- prev = st.selectbox("Previous injection?", ["no","yes"], index=0)
 
 
 
87
 
88
- st.subheader("Context (optional)")
89
- c4,c5 = st.columns(2)
90
- with c4:
91
- location = st.text_input("Location", value="TMJ Right")
92
- muscle_injected = st.text_input("Muscle injected", value="masseter")
93
- adjunctive_treatment = st.text_input("Adjunctive treatment", value="physiotherapy")
94
- with c5:
95
- material_prev = st.text_input("Material in previous injection", value="")
96
- # free text fields are optional; model pipeline is robust
97
- material_choice = st.selectbox("Material injected (to score)", materials)
98
- do_compare = st.checkbox("Compare all materials", value=True)
99
 
100
- submitted = st.form_submit_button("Predict")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
 
102
  if submitted:
103
- row = {
104
- "sex": sex,
105
- "age": age,
106
- "pain_m0": pain_m0,
107
- "mmo_m0": mmo_m0,
108
- "ohip_14_m0": ohip_14_m0,
109
- "previous_injection": 1 if prev=="yes" else 0,
110
- "location": location,
111
- "muscle_injected": muscle_injected,
112
- "adjunctive_treatment": adjunctive_treatment,
113
- "material_in_previous_injection": material_prev
114
- }
115
- # Score selected material
116
- proba = score_material(model, row, material_choice)
117
- st.metric("Success probability (selected material)", f"{proba:.3f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
- if do_compare:
120
- st.markdown("---")
121
- st.subheader("Compare all materials (ranked)")
122
- ranked = rank_materials(model, row, materials)
123
- df_rank = pd.DataFrame(ranked, columns=["material","success_proba"])
124
- st.dataframe(df_rank, width=True)
125
- # Bar plot (matplotlib; single plot; default colors)
126
- fig = plt.figure(figsize=(6,4))
127
- plt.bar(df_rank["material"], df_rank["success_proba"])
128
- plt.xticks(rotation=45, ha="right")
129
- plt.ylabel("Predicted success probability")
130
- plt.title("Material comparison โ€” success probability")
131
- st.pyplot(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
 
 
2
  import pandas as pd
 
3
  import numpy as np
 
4
  import joblib
5
+ import json
6
+ import plotly.express as px
7
+ import plotly.graph_objects as go
8
+ from datetime import datetime
9
 
10
+ # Page config
11
+ st.set_page_config(
12
+ page_title="TMJ Injection Success Predictor",
13
+ page_icon="๐Ÿ’‰",
14
+ layout="wide"
15
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ # Load model and materials
18
  @st.cache_resource
19
  def load_artifacts():
20
+ """Load the trained model and materials list"""
21
+ try:
22
+ # Load model
23
+ model = joblib.load('src/best_tmj_success_classifier_without_fe.pkl')
24
+
25
+ # Load materials list
26
+ try:
27
+ with open('src/material_list.json', 'r') as f:
28
+ materials_data = json.load(f)
29
+ materials = materials_data.get('materials', [])
30
+ except FileNotFoundError:
31
+ # Fallback to default materials
32
+ materials = ['Local Anaesthesia', 'Dry Needle', 'Botox',
33
+ 'Saline', 'Magnesium', 'PRF']
34
+ st.warning("Using default materials list. Train the model to generate actual materials from your data.")
35
+
36
+ # Load metadata if available
37
+ metadata = {}
38
+ try:
39
+ with open('model_metadata.json', 'r') as f:
40
+ metadata = json.load(f)
41
+ except FileNotFoundError:
42
+ pass
43
+
44
+ return model, materials, metadata
45
+ except Exception as e:
46
+ st.error(f"Error loading model: {str(e)}")
47
+ st.stop()
48
+
49
+ # Initialize
50
+ model, materials, metadata = load_artifacts()
51
 
52
+ # Title and description
53
+ st.title("๐Ÿฆท TMJ Injection Success Predictor")
54
+ st.markdown("""
55
+ This tool predicts the 3-month treatment success probability for TMJ injections based on patient baseline characteristics.
56
+ Enter the patient information below to see predictions for different injection materials.
57
+ """)
58
 
59
+ # Display model info if available
60
+ if metadata:
61
+ with st.expander("โ„น๏ธ Model Information"):
62
+ col1, col2, col3 = st.columns(3)
63
+ with col1:
64
+ st.metric("Model Type", metadata.get('model_type', 'Unknown'))
65
+ with col2:
66
+ st.metric("Test ROC-AUC", f"{metadata.get('test_roc_auc', 0):.3f}")
67
+ with col3:
68
+ st.metric("Training Date", metadata.get('training_date', 'Unknown')[:10])
69
+
70
+ st.write(f"**Success Definition:** {metadata.get('success_definition', 'Unknown')}")
71
+
72
+ if metadata.get('simplified_version', False):
73
+ st.info("This model uses the simplified feature set without text analysis.")
74
 
75
+ st.divider()
 
 
 
 
 
 
 
 
 
 
76
 
77
+ # Create form
78
+ with st.form("patient_form"):
79
+ st.subheader("Patient Information")
80
+
81
+ # Required fields
82
+ col1, col2 = st.columns(2)
83
+
84
+ with col1:
85
+ st.markdown("**Required Fields**")
86
+ sex = st.selectbox("Sex", options=['Male', 'Female'], help="Patient's biological sex")
87
+ age = st.number_input("Age", min_value=10, max_value=100, value=45, help="Patient age in years")
88
+ pain_m0 = st.slider("Baseline Pain (M0)", min_value=0, max_value=10, value=7,
89
+ help="Pain score at baseline (0-10 scale)")
90
+
91
+ with col2:
92
+ st.markdown("** **") # Empty space to align with "Required Fields"
93
+ mmo_m0 = st.slider("Baseline MMO (M0)", min_value=0, max_value=80, value=35,
94
+ help="Maximum mouth opening at baseline (mm)")
95
+ ohip_14_m0 = st.slider("Baseline OHIP-14 (M0)", min_value=0, max_value=56, value=28,
96
+ help="Oral Health Impact Profile score at baseline (0-56)")
97
+
98
+ st.divider()
99
+
100
+ # Optional fields
101
+ st.markdown("**Optional Fields**")
102
+ col3, col4 = st.columns(2)
103
+
104
+ with col3:
105
+ location = st.text_input("Location", placeholder="e.g., Right TMJ",
106
+ help="Injection location")
107
+ muscle_injected = st.text_input("Muscle Injected", placeholder="e.g., Masseter",
108
+ help="Specific muscle targeted")
109
+ adjunctive_treatment = st.text_input("Adjunctive Treatment", placeholder="e.g., Physical therapy",
110
+ help="Additional treatments")
111
+
112
+ with col4:
113
+ previous_injection = st.selectbox("Previous Injection", options=['No', 'Yes'],
114
+ help="Has the patient had previous TMJ injections?")
115
+ if previous_injection == 'Yes':
116
+ material_in_previous_injection = st.selectbox("Previous Material",
117
+ options=[''] + materials,
118
+ help="Material used in previous injection")
119
+ else:
120
+ material_in_previous_injection = ''
121
+
122
+ st.divider()
123
+
124
+ # Material selection for primary prediction
125
+ st.markdown("**Primary Prediction**")
126
+ selected_material = st.selectbox("Select Material for Prediction",
127
+ options=materials,
128
+ help="Choose the material you're considering for this patient")
129
+
130
+ # Compare all materials option
131
+ compare_all = st.checkbox("Compare all available materials", value=True,
132
+ help="Show predictions for all materials to help with decision making")
133
+
134
+ # Submit button
135
+ submitted = st.form_submit_button("๐Ÿ”ฎ Predict Success", use_container_width=True, type="primary")
136
 
137
+ # Process form submission
138
  if submitted:
139
+ # Create input dataframe
140
+ input_data = pd.DataFrame({
141
+ 'sex': [sex],
142
+ 'age': [age],
143
+ 'pain_m0': [pain_m0],
144
+ 'mmo_m0': [mmo_m0],
145
+ 'ohip_14_m0': [ohip_14_m0],
146
+ 'location': [location if location else np.nan],
147
+ 'muscle_injected': [muscle_injected if muscle_injected else np.nan],
148
+ 'adjunctive_treatment': [adjunctive_treatment if adjunctive_treatment else np.nan],
149
+ 'previous_injection': [1 if previous_injection == 'Yes' else 0],
150
+ 'material_in_previous_injection': [material_in_previous_injection if material_in_previous_injection else np.nan],
151
+ 'material_injected': [selected_material]
152
+ })
153
+
154
+ # Make prediction for selected material
155
+ try:
156
+ prediction_proba = model.predict_proba(input_data)[0, 1]
157
+
158
+ # Display primary prediction
159
+ st.divider()
160
+ st.subheader("Prediction Results")
161
+
162
+ # Create a visual indicator
163
+ col1, col2, col3 = st.columns([1, 2, 1])
164
+ with col2:
165
+ # Success probability gauge
166
+ fig = go.Figure(go.Indicator(
167
+ mode = "gauge+number+delta",
168
+ value = prediction_proba * 100,
169
+ domain = {'x': [0, 1], 'y': [0, 1]},
170
+ title = {'text': f"Success Probability with {selected_material}"},
171
+ number = {'suffix': "%", 'font': {'size': 40}},
172
+ gauge = {
173
+ 'axis': {'range': [None, 100]},
174
+ 'bar': {'color': "darkblue"},
175
+ 'steps': [
176
+ {'range': [0, 30], 'color': "lightgray"},
177
+ {'range': [30, 70], 'color': "gray"},
178
+ {'range': [70, 100], 'color': "lightgreen"}
179
+ ],
180
+ 'threshold': {
181
+ 'line': {'color': "red", 'width': 4},
182
+ 'thickness': 0.75,
183
+ 'value': 50
184
+ }
185
+ }
186
+ ))
187
+ fig.update_layout(height=400)
188
+ st.plotly_chart(fig, use_container_width=True)
189
+
190
+ # Interpretation
191
+ if prediction_proba >= 0.7:
192
+ st.success(f"โœ… High likelihood of success ({prediction_proba:.1%}) with {selected_material}")
193
+ elif prediction_proba >= 0.5:
194
+ st.warning(f"โš ๏ธ Moderate likelihood of success ({prediction_proba:.1%}) with {selected_material}")
195
+ else:
196
+ st.error(f"โŒ Low likelihood of success ({prediction_proba:.1%}) with {selected_material}")
197
+
198
+ # Compare all materials if requested
199
+ if compare_all:
200
+ st.divider()
201
+ st.subheader("๐Ÿ“Š Material Comparison")
202
+
203
+ # Predict for all materials
204
+ material_results = []
205
+ for material in materials:
206
+ temp_data = input_data.copy()
207
+ temp_data['material_injected'] = material
208
+ prob = model.predict_proba(temp_data)[0, 1]
209
+ material_results.append({
210
+ 'Material': material,
211
+ 'Success Probability': prob,
212
+ 'Success %': f"{prob:.1%}"
213
+ })
214
+
215
+ # Sort by probability
216
+ material_df = pd.DataFrame(material_results)
217
+ material_df = material_df.sort_values('Success Probability', ascending=False)
218
+
219
+ # Display results
220
+ col1, col2 = st.columns([1, 1])
221
+
222
+ with col1:
223
+ # Table view
224
+ st.markdown("**Ranked Materials**")
225
+ display_df = material_df[['Material', 'Success %']].reset_index(drop=True)
226
+ display_df.index += 1 # Start index at 1
227
+ st.dataframe(display_df, use_container_width=True)
228
+
229
+ # Highlight best option
230
+ best_material = material_df.iloc[0]['Material']
231
+ best_prob = material_df.iloc[0]['Success Probability']
232
+ if best_material != selected_material:
233
+ st.info(f"๐Ÿ’ก Consider using **{best_material}** for potentially better outcomes ({best_prob:.1%} vs {prediction_proba:.1%})")
234
+
235
+ with col2:
236
+ # Bar chart
237
+ st.markdown("**Visual Comparison**")
238
+ fig = px.bar(material_df,
239
+ x='Success Probability',
240
+ y='Material',
241
+ orientation='h',
242
+ color='Success Probability',
243
+ color_continuous_scale='RdYlGn',
244
+ range_color=[0, 1],
245
+ text='Success %')
246
+
247
+ fig.update_traces(textposition='outside')
248
+ fig.update_layout(
249
+ xaxis_title="Success Probability",
250
+ yaxis_title="",
251
+ showlegend=False,
252
+ xaxis=dict(range=[0, 1.1]),
253
+ height=400
254
+ )
255
+
256
+ # Add vertical line at 50%
257
+ fig.add_vline(x=0.5, line_dash="dash", line_color="gray",
258
+ annotation_text="50% threshold")
259
+
260
+ st.plotly_chart(fig, use_container_width=True)
261
+
262
+ # Additional insights
263
+ st.divider()
264
+ with st.expander("๐Ÿ“‹ Patient Summary"):
265
+ st.write("**Baseline Characteristics:**")
266
+ summary_cols = st.columns(3)
267
+ with summary_cols[0]:
268
+ st.write(f"- Age: {age} years")
269
+ st.write(f"- Sex: {sex}")
270
+ st.write(f"- Previous injection: {previous_injection}")
271
+ with summary_cols[1]:
272
+ st.write(f"- Pain score: {pain_m0}/10")
273
+ st.write(f"- MMO: {mmo_m0} mm")
274
+ st.write(f"- OHIP-14: {ohip_14_m0}/56")
275
+ with summary_cols[2]:
276
+ if location:
277
+ st.write(f"- Location: {location}")
278
+ if muscle_injected:
279
+ st.write(f"- Muscle: {muscle_injected}")
280
+ if adjunctive_treatment:
281
+ st.write(f"- Adjunctive: {adjunctive_treatment}")
282
+
283
+ except Exception as e:
284
+ st.error(f"Error making prediction: {str(e)}")
285
+ st.info("Please ensure the model was trained with all the necessary features.")
286
+
287
+ # Footer
288
+ st.divider()
289
+ st.markdown("""
290
+ <div style='text-align: center; color: gray;'>
291
+ <small>
292
+ TMJ Injection Success Predictor |
293
+ Model trained on historical patient data |
294
+ Predictions are probabilistic and should be used alongside clinical judgment
295
+ </small>
296
+ </div>
297
+ """, unsafe_allow_html=True)
298
 
299
+ # Sidebar with instructions
300
+ with st.sidebar:
301
+ st.header("๐Ÿ“– Instructions")
302
+ st.markdown("""
303
+ 1. **Enter patient baseline data** in the form
304
+ 2. **Select the material** you're considering
305
+ 3. **Click Predict** to see the success probability
306
+ 4. **Compare materials** to find the optimal choice
307
+
308
+ ---
309
+
310
+ ### ๐ŸŽฏ Success Definition
311
+ Treatment success is typically defined as:
312
+ - Pain reduction > 2 points
313
+ - MMO increase > 5 mm
314
+ - OHIP-14 reduction > 5 points
315
+
316
+ ---
317
+
318
+ ### ๐Ÿ“Š Interpretation Guide
319
+ - **70%+**: High success likelihood โœ…
320
+ - **50-70%**: Moderate success โš ๏ธ
321
+ - **<50%**: Low success likelihood โŒ
322
+
323
+ ---
324
+
325
+ ### โš•๏ธ Clinical Note
326
+ These predictions are based on statistical models and should complement, not replace, clinical expertise and patient-specific considerations.
327
+ """)