Spaces:
Sleeping
Sleeping
devjas1
fix(display): Refactor batch results display to improve clarity and metrics presentation
9e65713
| """Multi-file processing utilities for batch inference. | |
| Handles multiple file uploads and iterative processing.""" | |
| from typing import List, Dict, Any, Tuple, Optional | |
| import time | |
| import streamlit as st | |
| import numpy as np | |
| import pandas as pd | |
| from .preprocessing import resample_spectrum | |
| from .errors import ErrorHandler, safe_execute | |
| from .results_manager import ResultsManager | |
| from .confidence import calculate_softmax_confidence | |
def _to_finite_float(token: str) -> Optional[float]:
    """Return ``token`` as a finite float, or None if it is not one.

    Delegating to ``float()`` accepts every numeric spelling Python knows
    (including scientific notation such as ``1.5e-3``). Tokens that parse to
    NaN or +/-inf are treated as non-numeric so they never reach the arrays.
    """
    try:
        value = float(token)
    except ValueError:
        return None
    return value if np.isfinite(value) else None


def parse_spectrum_data(
    text_content: str, filename: str = "unknown"
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Parse two-column spectrum data from text content.

    Blank lines and lines starting with ``#`` or ``%`` are ignored. On every
    other line, commas are treated as whitespace and the first two tokens that
    are finite numbers are taken as (x, y). Lines without any numeric token
    (e.g. column headers) are skipped silently; lines with exactly one numeric
    token are skipped with a logged warning.

    Args:
        text_content: Raw text content of the spectrum file
        filename: Name of the file for error reporting
    Returns:
        Tuple of (x_values, y_values) as numpy arrays
    Raises:
        ValueError: If the data cannot be parsed, has fewer than 10 points,
            has non-increasing x values, or has an implausible x range.
            The message is always prefixed with "Failed to parse spectrum data: ".
    """
    try:
        x_vals: List[float] = []
        y_vals: List[float] = []
        saw_data_line = False

        # splitlines() copes with \n, \r\n and lone \r; numbering the raw lines
        # keeps the reported line number equal to the line number in the file.
        for line_no, raw_line in enumerate(text_content.splitlines(), start=1):
            line = raw_line.strip()
            # ==Skip empty lines and comments==
            if not line or line.startswith(("#", "%")):
                continue
            saw_data_line = True

            # Handle different separators (comma or whitespace)
            tokens = line.replace(",", " ").split()
            values = [v for v in map(_to_finite_float, tokens) if v is not None]

            if len(values) >= 2:
                x_vals.append(values[0])
                y_vals.append(values[1])
            elif values:
                # A lone number looks like a truncated data row, not a header.
                ErrorHandler.log_warning(
                    f"Could not parse line {line_no}: {line}", f"Parsing {filename}"
                )

        if not saw_data_line:
            raise ValueError("No data lines found in file")

        if len(x_vals) < 10:  # ==Need minimum points for interpolation==
            raise ValueError(
                f"Insufficient data points ({len(x_vals)}). Need at least 10 points."
            )

        x = np.array(x_vals)
        y = np.array(y_vals)

        # Check monotonic increasing x
        if not np.all(np.diff(x) > 0):
            raise ValueError("Wavenumbers must be strictly increasing")

        # Check reasonable range for Raman spectroscopy
        x_min, x_max = x.min(), x.max()
        if x_min < 0 or x_max > 10000 or (x_max - x_min) < 100:
            raise ValueError(
                f"Invalid wavenumber range: {x_min} - {x_max}. Expected ~400-4000 cm⁻¹ with span >100"
            )

        return x, y
    except Exception as e:
        # Single error type for callers; keep the original cause chained.
        raise ValueError(f"Failed to parse spectrum data: {str(e)}") from e
def process_single_file(
    filename: str,
    text_content: str,
    model_choice: str,
    load_model_func,
    run_inference_func,
    label_file_func,
    target_len: int = 500,
) -> Optional[Dict[str, Any]]:
    """
    Process a single spectrum file: parse, resample, run inference, score.

    Args:
        filename: Name of the file
        text_content: Raw text content
        model_choice: Selected model name
        load_model_func: Function to load the model (accepted for interface
            compatibility; not called here)
        run_inference_func: Function to run inference; called as
            ``run_inference_func(y_resampled, model_choice)`` and expected to
            yield a 5-tuple ``(prediction, logits_list, probs, inference_time, logits)``
        label_file_func: Function to extract ground truth label from the filename
        target_len: Number of points the spectrum is resampled to (default 500)
    Returns:
        - A dict with ``"success": True`` and the prediction, confidence,
          ground truth, timing and raw/resampled arrays on success.
        - ``None`` if the parse, resample or inference stage reports failure
          through ``safe_execute``.
        - A dict with ``"success": False`` and an ``"error"`` message if an
          unexpected exception escapes any stage.
    """
    start_time = time.time()
    try:
        # ==Parse spectrum data==
        result, success = safe_execute(
            parse_spectrum_data,
            text_content,
            filename,
            error_context=f"parsing {filename}",
            show_error=False,
        )
        if not success or result is None:
            return None
        x_raw, y_raw = result

        # ==Resample spectrum==
        result, success = safe_execute(
            resample_spectrum,
            x_raw,
            y_raw,
            target_len,
            error_context=f"resampling {filename}",
            show_error=False,
        )
        if not success or result is None:
            return None
        x_resampled, y_resampled = result

        # ==Run inference==
        result, success = safe_execute(
            run_inference_func,
            y_resampled,
            model_choice,
            error_context=f"inference on {filename}",
            show_error=False,
        )
        if not success or result is None:
            ErrorHandler.log_error(
                Exception("Inference failed"), f"processing {filename}"
            )
            return None
        # probs and inference_time from the inference call are not used here;
        # confidence is recomputed from the logits below.
        prediction, logits_list, _probs, _inference_time, logits = result

        # ==Calculate confidence==
        if logits is not None:
            probs_np, max_confidence, confidence_level, confidence_emoji = (
                calculate_softmax_confidence(logits)
            )
        else:
            probs_np = np.array([])
            max_confidence = 0.0
            confidence_level = "LOW"
            confidence_emoji = "🔴"

        # ==Get ground truth==
        # Best-effort: a missing or unparseable label must not fail the file.
        try:
            ground_truth = label_file_func(filename)
        except Exception:
            ground_truth = None
        # Negative labels (and None) mean "no ground truth available".
        if ground_truth is None or not ground_truth >= 0:
            ground_truth = None

        # ==Get predicted class==
        label_map = {0: "Stable (Unweathered)", 1: "Weathered (Degraded)"}
        predicted_class = label_map.get(prediction, f"Unknown ({prediction})")

        # Explicit None/length test: plain truthiness raises on a
        # multi-element numpy array.
        has_logits = logits_list is not None and len(logits_list) > 0

        processing_time = time.time() - start_time
        return {
            "filename": filename,
            "success": True,
            "prediction": prediction,
            "predicted_class": predicted_class,
            "confidence": max_confidence,
            "confidence_level": confidence_level,
            "confidence_emoji": confidence_emoji,
            "logits": logits_list if has_logits else [],
            "probabilities": probs_np.tolist() if len(probs_np) > 0 else [],
            "ground_truth": ground_truth,
            "processing_time": processing_time,
            "x_raw": x_raw,
            "y_raw": y_raw,
            "x_resampled": x_resampled,
            "y_resampled": y_resampled,
        }
    except Exception as e:
        ErrorHandler.log_error(e, f"processing {filename}")
        return {
            "filename": filename,
            "success": False,
            "error": str(e),
            "processing_time": time.time() - start_time,
        }
def process_multiple_files(
    uploaded_files: List,
    model_choice: str,
    load_model_func,
    run_inference_func,
    label_file_func,
    progress_callback=None,
) -> List[Dict[str, Any]]:
    """
    Process multiple uploaded files

    Every input file yields exactly one entry in the returned list: the
    success dict from ``process_single_file`` or a failure dict with
    ``"success": False`` and an ``"error"`` message. Successful results are
    also forwarded to ``ResultsManager.add_results``.

    Args:
        uploaded_files: List of uploaded file objects (each needs ``.name``
            and either ``.getvalue()`` or ``.read()``)
        model_choice: Selected model name
        load_model_func: Function to load the model
        run_inference_func: Function to run inference
        label_file_func: Function to extract ground truth label
        progress_callback: Optional callback to update progress; called as
            ``progress_callback(index, total, name)`` before each file and
            once with ``(total, total, "Complete")`` at the end
    Returns:
        List of processing results
    """
    results: List[Dict[str, Any]] = []
    total_files = len(uploaded_files)
    ErrorHandler.log_info(f"Starting batch processing of {total_files} files")

    for i, uploaded_file in enumerate(uploaded_files):
        name = uploaded_file.name
        if progress_callback:
            progress_callback(i, total_files, name)

        # ==Read file content==
        try:
            # getvalue() returns the whole buffer regardless of the current
            # read position, so a file already consumed by an earlier read()
            # is not seen as empty.
            if hasattr(uploaded_file, "getvalue"):
                raw = uploaded_file.getvalue()
            else:
                raw = uploaded_file.read()
            # utf-8-sig decodes plain UTF-8 identically and strips a leading
            # BOM that would otherwise corrupt the first line.
            if isinstance(raw, (bytes, bytearray)):
                text_content = raw.decode("utf-8-sig")
            else:
                text_content = raw
        except Exception as e:
            ErrorHandler.log_error(e, f"reading file {name}")
            results.append(
                {
                    "filename": name,
                    "success": False,
                    "error": f"Failed to read file: {str(e)}",
                }
            )
            continue

        # ==Process the file==
        try:
            result = process_single_file(
                name,
                text_content,
                model_choice,
                load_model_func,
                run_inference_func,
                label_file_func,
            )
        except Exception as e:
            ErrorHandler.log_error(e, f"processing file {name}")
            results.append(
                {
                    "filename": name,
                    "success": False,
                    "error": f"Failed to process file: {str(e)}",
                }
            )
            continue

        if not result:
            # process_single_file returns None when a stage fails; record it
            # so the file is still counted (as failed) in the batch summary.
            results.append(
                {
                    "filename": name,
                    "success": False,
                    "error": "Processing failed (parsing, resampling or inference); see logs for details",
                }
            )
            continue

        results.append(result)

        # ==Add successful results to the results manager==
        if result.get("success", False):
            try:
                ResultsManager.add_results(
                    filename=result["filename"],
                    model_name=model_choice,
                    prediction=result["prediction"],
                    predicted_class=result["predicted_class"],
                    confidence=result["confidence"],
                    logits=result["logits"],
                    ground_truth=result["ground_truth"],
                    processing_time=result["processing_time"],
                    metadata={
                        "confidence_level": result["confidence_level"],
                        "confidence_emoji": result["confidence_emoji"],
                    },
                )
            except Exception as e:
                # A bookkeeping failure must not add a second, contradictory
                # entry for a file that was processed successfully.
                ErrorHandler.log_error(e, f"recording results for {name}")

    if progress_callback:
        progress_callback(total_files, total_files, "Complete")

    ErrorHandler.log_info(
        f"Completed batch processing: {sum(1 for r in results if r.get('success', False))}/{total_files} successful"
    )
    return results
def display_batch_results(batch_results: list):
    """Render a consolidated summary of batch processing results.

    Shows three metrics (total, successful, failed) and, when any file
    failed, a collapsed expander listing each failed file with its reason.
    An empty or missing result list produces only an informational notice.

    Args:
        batch_results: List of per-file result dicts; an entry counts as
            successful when its "success" value is truthy.
    """
    if not batch_results:
        st.info("No batch results to display.")
        return

    # Partition the entries in one pass: only failures need to be kept.
    failures = []
    success_count = 0
    for entry in batch_results:
        if entry.get("success", False):
            success_count += 1
        else:
            failures.append(entry)
    failure_count = len(failures)

    # 1. High-level metrics
    st.markdown("###### Batch Summary")
    total_col, ok_col, failed_col = st.columns(3)
    total_col.metric("Total Files Processed", str(len(batch_results)))
    ok_col.metric("✔️ Successful", str(success_count))
    failed_col.metric("❌ Failed", str(failure_count))

    # 2. Failure details, hidden behind an expander
    if not failures:
        return
    expander_title = f"View details for {failure_count} failed file(s)"
    with st.expander(expander_title, expanded=False):
        for entry in failures:
            failed_name = entry.get("filename", "unknown")
            reason = entry.get("error", "No details provided")
            st.error(f"**File:** `{failed_name}`")
            st.caption(f"Reason for failure: {reason}")
| # Legacy display batch results | |
| # def display_batch_results(results: List[Dict[str, Any]]) -> None: | |
| # """ | |
| # Display batch processing results in the UI | |
| # Args: | |
| # results: List of processing results | |
| # """ | |
| # if not results: | |
| # st.warning("No results to display") | |
| # return | |
| # successful = [r for r in results if r.get("success", False)] | |
| # failed = [r for r in results if not r.get("success", False)] | |
| # # ==Summary== | |
| # col1, col2, col3 = st.columns(3, border=True) | |
| # with col1: | |
| # st.metric("Total Files", len(results)) | |
| # with col2: | |
| # st.metric("Successful", len(successful), | |
| # delta=f"{len(successful)/len(results)*100:.1f}%") | |
| # with col3: | |
| # st.metric("Failed", len( | |
| # failed), delta=f"-{len(failed)/len(results)*100:.1f}%" if failed else "0%") | |
| # # ==Results tabs== | |
| # tab1, tab2 = st.tabs(["✅Successful", "❌ Failed"], width="stretch") | |
| # with tab1: | |
| # with st.expander("Successful"): | |
| # if successful: | |
| # for result in successful: | |
| # with st.expander(f"{result['filename']}", expanded=False): | |
| # col1, col2 = st.columns(2) | |
| # with col1: | |
| # st.write( | |
| # f"**Prediction:** {result['predicted_class']}") | |
| # st.write( | |
| # f"**Confidence:** {result['confidence_emoji']} {result['confidence_level']} ({result['confidence']:.3f})") | |
| # with col2: | |
| # st.write( | |
| # f"**Processing Time:** {result['processing_time']:.3f}s") | |
| # if result['ground_truth'] is not None: | |
| # gt_label = {0: "Stable", 1: "Weathered"}.get( | |
| # result['ground_truth'], "Unknown") | |
| # correct = "✅" if result['prediction'] == result['ground_truth'] else "❌" | |
| # st.write( | |
| # f"**Ground Truth:** {gt_label} {correct}") | |
| # else: | |
| # st.info("No successful results") | |
| # with tab2: | |
| # if failed: | |
| # for result in failed: | |
| # with st.expander(f"❌ {result['filename']}", expanded=False): | |
| # st.error(f"Error: {result.get('error', 'Unknown error')}") | |
| # else: | |
| # st.success("No failed files!") | |
def create_batch_uploader() -> List:
    """
    Render the multi-file upload widget for Raman spectrum text files.

    Returns:
        The files currently selected in the widget, or an empty list when
        nothing has been uploaded.
    """
    uploader_options = {
        "type": "txt",
        "accept_multiple_files": True,
        "help": "Select multiple .txt files with wavenumber and intensity columns",
        "key": "batch_uploader",
    }
    selected_files = st.file_uploader(
        "Upload multiple Raman spectrum files (.txt)", **uploader_options
    )
    # Normalise any falsy widget value (None / empty) to an empty list.
    return selected_files or []