devjas1 committed
Commit a64b261 · 1 Parent(s): f5cad9a
(FEAT)[Results Management]: Implement session and persistent results/statistics handling
- Added 'ResultsManager' class for managing session and persistent storage of inference and comparison results.
- Added functions to log, retrieve, and aggregate model predictions, confidences, processing times, and accuracy for single- and multi-model workflows.
- Implemented agreement-matrix calculation for model comparison statistics.
- Added export functions for JSON output and full report generation over recent and historical results.
- Integrated hooks for Streamlit session state, supporting UI display and download features (a usage sketch follows the file summary below).
- Added error handling for missing data and data consistency checks.
utils/results_manager.py  +218  -2
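The new comparison API is intended to be driven from a Streamlit page roughly as follows. This is a minimal sketch, not code from the commit: it assumes the module is importable as utils.results_manager, that ResultsManager.init_session_state() has prepared the session keys, and that your own inference step already produced one result dict per model with the keys add_multi_model_results expects (prediction, predicted_class, confidence, logits, processing_time). The model names, filename, and numeric values below are placeholders.

    # Hypothetical Streamlit page wiring for the new multi-model comparison API.
    import streamlit as st
    from utils.results_manager import ResultsManager  # assumed import path

    ResultsManager.init_session_state()

    # One result dict per model for the same file (placeholder values).
    model_results = {
        "model_a": {
            "prediction": 0,
            "predicted_class": "stable",
            "confidence": 0.91,
            "logits": [2.4, -1.1],
            "processing_time": 0.03,
        },
        "model_b": {
            "prediction": 1,
            "predicted_class": "weathered",
            "confidence": 0.67,
            "logits": [-0.4, 0.3],
            "processing_time": 0.08,
        },
    }
    ResultsManager.add_multi_model_results("sample_001.txt", model_results, ground_truth=1)

    # Per-model statistics, pairwise agreement, and the comparison figure for the UI.
    st.json(ResultsManager.get_comparison_stats())
    st.dataframe(ResultsManager.get_agreement_matrix())
    fig = ResultsManager.create_comparison_visualization()
    if fig is not None:
        st.pyplot(fig)

    # Offer the aggregated JSON report as a download.
    st.download_button(
        "Download comparison report",
        data=ResultsManager.export_comparison_report(),
        file_name="comparison_report.json",
        mime="application/json",
    )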
utils/results_manager.py
CHANGED
@@ -1,14 +1,17 @@
 """Session results management for multi-file inference.
-Handles in-memory results table and export functionality
+Handles in-memory results table and export functionality.
+Supports multi-model comparison and statistical analysis."""
 
 import streamlit as st
 import pandas as pd
 import json
 from datetime import datetime
-from typing import Dict, List, Any, Optional
+from typing import Dict, List, Any, Optional, Tuple
 import numpy as np
 from pathlib import Path
 import io
+from collections import defaultdict
+import matplotlib.pyplot as plt
 
 
 def local_css(file_name):
@@ -199,6 +202,219 @@ class ResultsManager:
 
         return len(st.session_state[ResultsManager.RESULTS_KEY]) < original_length
 
+    @staticmethod
+    def add_multi_model_results(
+        filename: str,
+        model_results: Dict[str, Dict[str, Any]],
+        ground_truth: Optional[int] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """
+        Add results from multiple models for the same file.
+
+        Args:
+            filename: Name of the processed file
+            model_results: Dict with model_name -> result dict
+            ground_truth: True label if available
+            metadata: Additional file metadata
+        """
+        for model_name, result in model_results.items():
+            ResultsManager.add_results(
+                filename=filename,
+                model_name=model_name,
+                prediction=result["prediction"],
+                predicted_class=result["predicted_class"],
+                confidence=result["confidence"],
+                logits=result["logits"],
+                ground_truth=ground_truth,
+                processing_time=result.get("processing_time", 0.0),
+                metadata=metadata,
+            )
+
+    @staticmethod
+    def get_comparison_stats() -> Dict[str, Any]:
+        """Get comparative statistics across all models."""
+        results = ResultsManager.get_results()
+        if not results:
+            return {}
+
+        # Group results by model
+        model_stats = defaultdict(list)
+        for result in results:
+            model_stats[result["model"]].append(result)
+
+        comparison = {}
+        for model_name, model_results in model_stats.items():
+            stats = {
+                "total_predictions": len(model_results),
+                "avg_confidence": np.mean([r["confidence"] for r in model_results]),
+                "std_confidence": np.std([r["confidence"] for r in model_results]),
+                "avg_processing_time": np.mean(
+                    [r["processing_time"] for r in model_results]
+                ),
+                "stable_predictions": sum(
+                    1 for r in model_results if r["prediction"] == 0
+                ),
+                "weathered_predictions": sum(
+                    1 for r in model_results if r["prediction"] == 1
+                ),
+            }
+
+            # Calculate accuracy if ground truth available
+            with_gt = [r for r in model_results if r["ground_truth"] is not None]
+            if with_gt:
+                correct = sum(
+                    1 for r in with_gt if r["prediction"] == r["ground_truth"]
+                )
+                stats["accuracy"] = correct / len(with_gt)
+                stats["num_with_ground_truth"] = len(with_gt)
+            else:
+                stats["accuracy"] = None
+                stats["num_with_ground_truth"] = 0
+
+            comparison[model_name] = stats
+
+        return comparison
+
+    @staticmethod
+    def get_agreement_matrix() -> pd.DataFrame:
+        """
+        Calculate agreement matrix between models for the same files.
+
+        Returns:
+            DataFrame showing model agreement rates
+        """
+        results = ResultsManager.get_results()
+        if not results:
+            return pd.DataFrame()
+
+        # Group by filename
+        file_results = defaultdict(dict)
+        for result in results:
+            file_results[result["filename"]][result["model"]] = result["prediction"]
+
+        # Get unique models
+        all_models = list(set(r["model"] for r in results))
+
+        if len(all_models) < 2:
+            return pd.DataFrame()
+
+        # Calculate agreement matrix
+        agreement_matrix = np.zeros((len(all_models), len(all_models)))
+
+        for i, model1 in enumerate(all_models):
+            for j, model2 in enumerate(all_models):
+                if i == j:
+                    agreement_matrix[i, j] = 1.0  # Perfect self-agreement
+                else:
+                    agreements = 0
+                    comparisons = 0
+
+                    for filename, predictions in file_results.items():
+                        if model1 in predictions and model2 in predictions:
+                            comparisons += 1
+                            if predictions[model1] == predictions[model2]:
+                                agreements += 1
+
+                    if comparisons > 0:
+                        agreement_matrix[i, j] = agreements / comparisons
+
+        return pd.DataFrame(agreement_matrix, index=all_models, columns=all_models)
+
+    @staticmethod
+    def create_comparison_visualization() -> plt.Figure:
+        """Create visualization comparing model performance."""
+        comparison_stats = ResultsManager.get_comparison_stats()
+
+        if not comparison_stats:
+            return None
+
+        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))
+
+        models = list(comparison_stats.keys())
+
+        # 1. Average Confidence
+        confidences = [comparison_stats[m]["avg_confidence"] for m in models]
+        conf_stds = [comparison_stats[m]["std_confidence"] for m in models]
+        ax1.bar(models, confidences, yerr=conf_stds, capsize=5)
+        ax1.set_title("Average Confidence by Model")
+        ax1.set_ylabel("Confidence")
+        ax1.tick_params(axis="x", rotation=45)
+
+        # 2. Processing Time
+        proc_times = [comparison_stats[m]["avg_processing_time"] for m in models]
+        ax2.bar(models, proc_times)
+        ax2.set_title("Average Processing Time")
+        ax2.set_ylabel("Time (seconds)")
+        ax2.tick_params(axis="x", rotation=45)
+
+        # 3. Prediction Distribution
+        stable_counts = [comparison_stats[m]["stable_predictions"] for m in models]
+        weathered_counts = [
+            comparison_stats[m]["weathered_predictions"] for m in models
+        ]
+
+        x = np.arange(len(models))
+        width = 0.35
+        ax3.bar(x - width / 2, stable_counts, width, label="Stable", alpha=0.8)
+        ax3.bar(x + width / 2, weathered_counts, width, label="Weathered", alpha=0.8)
+        ax3.set_title("Prediction Distribution")
+        ax3.set_ylabel("Count")
+        ax3.set_xticks(x)
+        ax3.set_xticklabels(models, rotation=45)
+        ax3.legend()
+
+        # 4. Accuracy (if available)
+        accuracies = []
+        models_with_acc = []
+        for model in models:
+            if comparison_stats[model]["accuracy"] is not None:
+                accuracies.append(comparison_stats[model]["accuracy"])
+                models_with_acc.append(model)
+
+        if accuracies:
+            ax4.bar(models_with_acc, accuracies)
+            ax4.set_title("Model Accuracy (where ground truth available)")
+            ax4.set_ylabel("Accuracy")
+            ax4.set_ylim(0, 1)
+            ax4.tick_params(axis="x", rotation=45)
+        else:
+            ax4.text(
+                0.5,
+                0.5,
+                "No ground truth\navailable",
+                ha="center",
+                va="center",
+                transform=ax4.transAxes,
+            )
+            ax4.set_title("Model Accuracy")
+
+        plt.tight_layout()
+        return fig
+
+    @staticmethod
+    def export_comparison_report() -> str:
+        """Export comprehensive comparison report as JSON."""
+        comparison_stats = ResultsManager.get_comparison_stats()
+        agreement_matrix = ResultsManager.get_agreement_matrix()
+
+        report = {
+            "timestamp": datetime.now().isoformat(),
+            "model_comparison": comparison_stats,
+            "agreement_matrix": (
+                agreement_matrix.to_dict() if not agreement_matrix.empty else {}
+            ),
+            "summary": {
+                "total_models_compared": len(comparison_stats),
+                "total_files_processed": len(
+                    set(r["filename"] for r in ResultsManager.get_results())
+                ),
+                "overall_statistics": ResultsManager.get_summary_stats(),
+            },
+        }
+
+        return json.dumps(report, indent=2, default=str)
+
     @staticmethod
     # ==UTILITY FUNCTIONS==
     def init_session_state():
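As a sanity check on the agreement matrix added above: for each pair of models it counts the fraction of shared files on which both predicted the same class, with the diagonal fixed at 1.0. A hand-worked toy case (hypothetical data, run inside a Streamlit session after ResultsManager.init_session_state(); add_results is called with the same keyword arguments the new add_multi_model_results forwards):

    # Toy agreement-matrix check (hypothetical values): model_a predicts 0, 1, 1
    # and model_b predicts 0, 0, 1 on the same three files, so they agree on
    # 2 of 3 files; the off-diagonal entries should come out to 2/3, and the
    # diagonal stays at 1.0 (row/column order may vary).
    for fname, (pred_a, pred_b) in {
        "file_1.txt": (0, 0),
        "file_2.txt": (1, 0),
        "file_3.txt": (1, 1),
    }.items():
        for model_name, pred in (("model_a", pred_a), ("model_b", pred_b)):
            ResultsManager.add_results(
                filename=fname,
                model_name=model_name,
                prediction=pred,
                predicted_class="stable" if pred == 0 else "weathered",
                confidence=0.9,
                logits=[1.0, -1.0] if pred == 0 else [-1.0, 1.0],
                ground_truth=None,
                processing_time=0.01,
                metadata=None,
            )

    print(ResultsManager.get_agreement_matrix())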