devjas1 committed
Commit a64b261
1 Parent(s): f5cad9a

(FEAT)[Results Management]: Implement session and persistent results/statistics handling


- Added 'ResultsManager' class for managing session and persistent storage of inference and comparison results.
- Added functions to log, retrieve, and aggregate model predictions, confidences, processing times, and accuracy for single- and multi-model workflows (usage sketch below).
- Implemented agreement matrix calculation for model comparison statistics.
- Added export functions for JSON output and full report generation covering recent and historical results.
- Integrated hooks for Streamlit session state, supporting UI display and download features (see the UI wiring sketch after the diff).
- Added error handling for missing data and data consistency checks.
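
A minimal usage sketch of the new multi-model flow (the model names, file name, and values are hypothetical; the ResultsManager methods and result-dict keys are the ones added in this diff, and the sketch assumes utils/ is importable as a package):

    from utils.results_manager import ResultsManager

    ResultsManager.init_session_state()  # existing session-state setup

    # Per-model result dicts; keys match what add_multi_model_results() reads.
    model_results = {
        "model_a": {  # hypothetical model name
            "prediction": 0,              # 0 is counted as "stable"
            "predicted_class": "Stable",
            "confidence": 0.93,
            "logits": [2.1, -1.4],
            "processing_time": 0.04,
        },
        "model_b": {  # hypothetical model name
            "prediction": 1,              # 1 is counted as "weathered"
            "predicted_class": "Weathered",
            "confidence": 0.71,
            "logits": [-0.6, 0.9],
            "processing_time": 0.09,
        },
    }

    ResultsManager.add_multi_model_results(
        filename="example_spectrum.txt",  # hypothetical file name
        model_results=model_results,
        ground_truth=1,                   # optional true label
    )

    stats = ResultsManager.get_comparison_stats()      # per-model metrics dict
    agreement = ResultsManager.get_agreement_matrix()  # DataFrame; empty if < 2 models
    report_json = ResultsManager.export_comparison_report()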

Files changed (1)
  1. utils/results_manager.py +218 -2
utils/results_manager.py CHANGED
@@ -1,14 +1,17 @@
 """Session results management for multi-file inference.
-Handles in-memory results table and export functionality"""
+Handles in-memory results table and export functionality.
+Supports multi-model comparison and statistical analysis."""
 
 import streamlit as st
 import pandas as pd
 import json
 from datetime import datetime
-from typing import Dict, List, Any, Optional
+from typing import Dict, List, Any, Optional, Tuple
 import numpy as np
 from pathlib import Path
 import io
+from collections import defaultdict
+import matplotlib.pyplot as plt
 
 
 def local_css(file_name):
@@ -199,6 +202,219 @@ class ResultsManager:
 
         return len(st.session_state[ResultsManager.RESULTS_KEY]) < original_length
 
+    @staticmethod
+    def add_multi_model_results(
+        filename: str,
+        model_results: Dict[str, Dict[str, Any]],
+        ground_truth: Optional[int] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """
+        Add results from multiple models for the same file.
+
+        Args:
+            filename: Name of the processed file
+            model_results: Dict with model_name -> result dict
+            ground_truth: True label if available
+            metadata: Additional file metadata
+        """
+        for model_name, result in model_results.items():
+            ResultsManager.add_results(
+                filename=filename,
+                model_name=model_name,
+                prediction=result["prediction"],
+                predicted_class=result["predicted_class"],
+                confidence=result["confidence"],
+                logits=result["logits"],
+                ground_truth=ground_truth,
+                processing_time=result.get("processing_time", 0.0),
+                metadata=metadata,
+            )
+
+    @staticmethod
+    def get_comparison_stats() -> Dict[str, Any]:
+        """Get comparative statistics across all models."""
+        results = ResultsManager.get_results()
+        if not results:
+            return {}
+
+        # Group results by model
+        model_stats = defaultdict(list)
+        for result in results:
+            model_stats[result["model"]].append(result)
+
+        comparison = {}
+        for model_name, model_results in model_stats.items():
+            stats = {
+                "total_predictions": len(model_results),
+                "avg_confidence": np.mean([r["confidence"] for r in model_results]),
+                "std_confidence": np.std([r["confidence"] for r in model_results]),
+                "avg_processing_time": np.mean(
+                    [r["processing_time"] for r in model_results]
+                ),
+                "stable_predictions": sum(
+                    1 for r in model_results if r["prediction"] == 0
+                ),
+                "weathered_predictions": sum(
+                    1 for r in model_results if r["prediction"] == 1
+                ),
+            }
+
+            # Calculate accuracy if ground truth available
+            with_gt = [r for r in model_results if r["ground_truth"] is not None]
+            if with_gt:
+                correct = sum(
+                    1 for r in with_gt if r["prediction"] == r["ground_truth"]
+                )
+                stats["accuracy"] = correct / len(with_gt)
+                stats["num_with_ground_truth"] = len(with_gt)
+            else:
+                stats["accuracy"] = None
+                stats["num_with_ground_truth"] = 0
+
+            comparison[model_name] = stats
+
+        return comparison
+
+    @staticmethod
+    def get_agreement_matrix() -> pd.DataFrame:
+        """
+        Calculate agreement matrix between models for the same files.
+
+        Returns:
+            DataFrame showing model agreement rates
+        """
+        results = ResultsManager.get_results()
+        if not results:
+            return pd.DataFrame()
+
+        # Group by filename
+        file_results = defaultdict(dict)
+        for result in results:
+            file_results[result["filename"]][result["model"]] = result["prediction"]
+
+        # Get unique models
+        all_models = list(set(r["model"] for r in results))
+
+        if len(all_models) < 2:
+            return pd.DataFrame()
+
+        # Calculate agreement matrix
+        agreement_matrix = np.zeros((len(all_models), len(all_models)))
+
+        for i, model1 in enumerate(all_models):
+            for j, model2 in enumerate(all_models):
+                if i == j:
+                    agreement_matrix[i, j] = 1.0  # Perfect self-agreement
+                else:
+                    agreements = 0
+                    comparisons = 0
+
+                    for filename, predictions in file_results.items():
+                        if model1 in predictions and model2 in predictions:
+                            comparisons += 1
+                            if predictions[model1] == predictions[model2]:
+                                agreements += 1
+
+                    if comparisons > 0:
+                        agreement_matrix[i, j] = agreements / comparisons
+
+        return pd.DataFrame(agreement_matrix, index=all_models, columns=all_models)
+
+    @staticmethod
+    def create_comparison_visualization() -> plt.Figure:
+        """Create visualization comparing model performance."""
+        comparison_stats = ResultsManager.get_comparison_stats()
+
+        if not comparison_stats:
+            return None
+
+        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))
+
+        models = list(comparison_stats.keys())
+
+        # 1. Average Confidence
+        confidences = [comparison_stats[m]["avg_confidence"] for m in models]
+        conf_stds = [comparison_stats[m]["std_confidence"] for m in models]
+        ax1.bar(models, confidences, yerr=conf_stds, capsize=5)
+        ax1.set_title("Average Confidence by Model")
+        ax1.set_ylabel("Confidence")
+        ax1.tick_params(axis="x", rotation=45)
+
+        # 2. Processing Time
+        proc_times = [comparison_stats[m]["avg_processing_time"] for m in models]
+        ax2.bar(models, proc_times)
+        ax2.set_title("Average Processing Time")
+        ax2.set_ylabel("Time (seconds)")
+        ax2.tick_params(axis="x", rotation=45)
+
+        # 3. Prediction Distribution
+        stable_counts = [comparison_stats[m]["stable_predictions"] for m in models]
+        weathered_counts = [
+            comparison_stats[m]["weathered_predictions"] for m in models
+        ]
+
+        x = np.arange(len(models))
+        width = 0.35
+        ax3.bar(x - width / 2, stable_counts, width, label="Stable", alpha=0.8)
+        ax3.bar(x + width / 2, weathered_counts, width, label="Weathered", alpha=0.8)
+        ax3.set_title("Prediction Distribution")
+        ax3.set_ylabel("Count")
+        ax3.set_xticks(x)
+        ax3.set_xticklabels(models, rotation=45)
+        ax3.legend()
+
+        # 4. Accuracy (if available)
+        accuracies = []
+        models_with_acc = []
+        for model in models:
+            if comparison_stats[model]["accuracy"] is not None:
+                accuracies.append(comparison_stats[model]["accuracy"])
+                models_with_acc.append(model)
+
+        if accuracies:
+            ax4.bar(models_with_acc, accuracies)
+            ax4.set_title("Model Accuracy (where ground truth available)")
+            ax4.set_ylabel("Accuracy")
+            ax4.set_ylim(0, 1)
+            ax4.tick_params(axis="x", rotation=45)
+        else:
+            ax4.text(
+                0.5,
+                0.5,
+                "No ground truth\navailable",
+                ha="center",
+                va="center",
+                transform=ax4.transAxes,
+            )
+            ax4.set_title("Model Accuracy")
+
+        plt.tight_layout()
+        return fig
+
+    @staticmethod
+    def export_comparison_report() -> str:
+        """Export comprehensive comparison report as JSON."""
+        comparison_stats = ResultsManager.get_comparison_stats()
+        agreement_matrix = ResultsManager.get_agreement_matrix()
+
+        report = {
+            "timestamp": datetime.now().isoformat(),
+            "model_comparison": comparison_stats,
+            "agreement_matrix": (
+                agreement_matrix.to_dict() if not agreement_matrix.empty else {}
+            ),
+            "summary": {
+                "total_models_compared": len(comparison_stats),
+                "total_files_processed": len(
+                    set(r["filename"] for r in ResultsManager.get_results())
+                ),
+                "overall_statistics": ResultsManager.get_summary_stats(),
+            },
+        }
+
+        return json.dumps(report, indent=2, default=str)
+
     @staticmethod
     # ==UTILITY FUNCTIONS==
     def init_session_state():
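
For context, a sketch of how the new figure and JSON report could be surfaced on a Streamlit page (this wiring is illustrative, not part of the commit; st.pyplot, st.dataframe, and st.download_button are standard Streamlit calls):

    import streamlit as st

    from utils.results_manager import ResultsManager

    st.subheader("Model comparison")  # hypothetical page section

    fig = ResultsManager.create_comparison_visualization()
    if fig is not None:  # None when no results have been logged yet
        st.pyplot(fig)

    agreement = ResultsManager.get_agreement_matrix()
    if not agreement.empty:
        st.dataframe(agreement)  # pairwise agreement rates between models

    st.download_button(
        label="Download comparison report (JSON)",
        data=ResultsManager.export_comparison_report(),
        file_name="comparison_report.json",  # hypothetical file name
        mime="application/json",
    )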