		| """Multi-file processing utilities for batch inference. | |
| Handles multiple file uploads and iterative processing. | |
| Supports TXT, CSV, and JSON file formats with automatic detection.""" | |
| from typing import List, Dict, Any, Tuple, Optional, Union | |
| import time | |
| import streamlit as st | |
| import numpy as np | |
| import pandas as pd | |
| import json | |
| import csv | |
| import io | |
| from pathlib import Path | |
| import hashlib | |
| from .preprocessing import preprocess_spectrum | |
| from .errors import ErrorHandler, safe_execute | |
| from .results_manager import ResultsManager | |
| from .confidence import calculate_softmax_confidence | |
| from config import TARGET_LEN | |
def detect_file_format(filename: str, content: str) -> str:
    """Automatically detect the file format based on extension and content.

    Args:
        filename: Name of the file
        content: Content of the file

    Returns:
        File format: 'txt', 'csv', or 'json'
    """
    # First try by extension
    suffix = Path(filename).suffix.lower()
    if suffix == ".json":
        try:
            json.loads(content)
            return "json"
        except json.JSONDecodeError:
            pass
    elif suffix == ".csv":
        return "csv"
    elif suffix == ".txt":
        return "txt"

    # If the extension is missing or misleading, fall back to content detection
    content_stripped = content.strip()

    # Try JSON
    if content_stripped.startswith(("{", "[")):
        try:
            json.loads(content)
            return "json"
        except json.JSONDecodeError:
            pass

    # Try CSV (look for commas in the first few lines)
    lines = content_stripped.split("\n")[:5]
    comma_count = sum(line.count(",") for line in lines)
    if comma_count > len(lines):  # More commas than lines suggests CSV
        return "csv"

    # Default to TXT
    return "txt"
def parse_json_spectrum(content: str) -> Tuple[np.ndarray, np.ndarray]:
    """
    Parse spectrum data from JSON format.

    Expected formats:
    - {"wavenumbers": [...], "intensities": [...]}
    - {"x": [...], "y": [...]}
    - [{"wavenumber": val, "intensity": val}, ...]
    """
    try:
        data = json.loads(content)

        # Format 1: Object with arrays
        if isinstance(data, dict):
            x_key = None
            y_key = None

            # Try common key names for the x-axis
            for key in ["wavenumbers", "wavenumber", "x", "freq", "frequency"]:
                if key in data:
                    x_key = key
                    break

            # Try common key names for the y-axis
            for key in ["intensities", "intensity", "y", "counts", "absorbance"]:
                if key in data:
                    y_key = key
                    break

            if x_key and y_key:
                x_vals = np.array(data[x_key], dtype=float)
                y_vals = np.array(data[y_key], dtype=float)
                return x_vals, y_vals

        # Format 2: Array of objects
        elif isinstance(data, list) and len(data) > 0 and isinstance(data[0], dict):
            x_vals = []
            y_vals = []
            for item in data:
                # Try to find x and y values
                x_val = None
                y_val = None
                for x_key in ["wavenumber", "wavenumbers", "x", "freq"]:
                    if x_key in item:
                        x_val = float(item[x_key])
                        break
                for y_key in ["intensity", "intensities", "y", "counts"]:
                    if y_key in item:
                        y_val = float(item[y_key])
                        break
                if x_val is not None and y_val is not None:
                    x_vals.append(x_val)
                    y_vals.append(y_val)
            if x_vals and y_vals:
                return np.array(x_vals), np.array(y_vals)

        raise ValueError(
            "JSON format not recognized. Expected wavenumber/intensity pairs."
        )
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON format: {str(e)}") from e
    except Exception as e:
        raise ValueError(f"Failed to parse JSON spectrum: {str(e)}") from e
def parse_csv_spectrum(
    content: str, filename: str = "unknown"
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Parse spectrum data from CSV format.

    Handles various CSV formats, with or without headers.
    """
    try:
        # Use StringIO to treat the string as a file-like object
        csv_file = io.StringIO(content)

        # Try to detect the delimiter
        sample = content[:1024]
        delimiter = ","
        if sample.count(";") > sample.count(","):
            delimiter = ";"
        elif sample.count("\t") > sample.count(","):
            delimiter = "\t"

        # Read CSV
        csv_reader = csv.reader(csv_file, delimiter=delimiter)
        rows = list(csv_reader)
        if not rows:
            raise ValueError("Empty CSV file")

        # Check if the first row is a header
        has_header = False
        try:
            # If the first row contains non-numeric data, it's likely a header
            float(rows[0][0])
            float(rows[0][1])
        except (ValueError, IndexError):
            has_header = True

        data_rows = rows[1:] if has_header else rows

        # Extract x and y values
        x_vals = []
        y_vals = []
        for i, row in enumerate(data_rows):
            if len(row) < 2:
                continue
            try:
                x_val = float(row[0])
                y_val = float(row[1])
                x_vals.append(x_val)
                y_vals.append(y_val)
            except ValueError:
                ErrorHandler.log_warning(
                    f"Could not parse CSV row {i+1}: {row}", f"Parsing {filename}"
                )
                continue

        if len(x_vals) < 10:
            raise ValueError(
                f"Insufficient data points ({len(x_vals)}). Need at least 10 points."
            )

        return np.array(x_vals), np.array(y_vals)
    except Exception as e:
        raise ValueError(f"Failed to parse CSV spectrum: {str(e)}") from e
def parse_spectrum_data(
    text_content: str, filename: str = "unknown", file_format: Optional[str] = None
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Parse spectrum data from text content with automatic format detection.

    Args:
        text_content: Raw text content of the spectrum file
        filename: Name of the file for error reporting
        file_format: Force a specific format ('txt', 'csv', 'json') or None for auto-detection

    Returns:
        Tuple of (x_values, y_values) as numpy arrays

    Raises:
        ValueError: If the data cannot be parsed
    """
    try:
        # Detect the format if not specified
        if file_format is None:
            file_format = detect_file_format(filename, text_content)

        # Parse based on the detected/specified format
        if file_format == "json":
            x, y = parse_json_spectrum(text_content)
        elif file_format == "csv":
            x, y = parse_csv_spectrum(text_content, filename)
        else:  # Default to TXT format
            x, y = parse_txt_spectrum(text_content, filename)

        # Common validation for all formats
        validate_spectrum_data(x, y, filename)
        return x, y
    except Exception as e:
        raise ValueError(f"Failed to parse spectrum data: {str(e)}") from e
def parse_txt_spectrum(
    content: str, filename: str = "unknown"
) -> Tuple[np.ndarray, np.ndarray]:
    """Robustly parse spectrum data from TXT format."""
    lines = content.strip().split("\n")
    x_vals, y_vals = [], []

    for i, line in enumerate(lines):
        line = line.strip()
        if not line or line.startswith(("#", "%")):
            continue
        try:
            # Handle different separators
            parts = line.replace(",", " ").replace(";", " ").replace("\t", " ").split()

            # Find the first two valid numbers in the line
            numbers = []
            for part in parts:
                if part:  # Skip empty strings from multiple spaces
                    try:
                        numbers.append(float(part))
                    except ValueError:
                        continue  # Ignore non-numeric parts

            if len(numbers) >= 2:
                x_vals.append(numbers[0])
                y_vals.append(numbers[1])
            else:
                ErrorHandler.log_warning(
                    f"Could not find two numbers on line {i+1}: '{line}'",
                    f"Parsing {filename}",
                )
        except ValueError as e:
            ErrorHandler.log_warning(
                f"Error parsing line {i+1}: '{line}'. Error: {e}",
                f"Parsing {filename}",
            )
            continue

    if len(x_vals) < 10:
        raise ValueError(
            f"Insufficient data points ({len(x_vals)}). Need at least 10 points."
        )

    return np.array(x_vals), np.array(y_vals)
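
# The TXT parser tolerates comment lines and mixed separators, e.g.
# (hypothetical input):
#
#   content = "# Raman shift vs counts\n" + "\n".join(f"{400 + i}\t{1000 + i}" for i in range(10))
#   x, y = parse_txt_spectrum(content)   # 10 points; the "#" line is skipped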
def validate_spectrum_data(x: np.ndarray, y: np.ndarray, filename: str) -> None:
    """
    Validate parsed spectrum data for common issues.
    """
    # Check for NaNs
    if np.any(np.isnan(x)) or np.any(np.isnan(y)):
        raise ValueError("Input data contains NaN values")

    # Check for monotonically increasing x (sort if needed)
    if not np.all(np.diff(x) >= 0):
        # Sort by x values if not monotonic. Assign in place (x[:] = ...) so the
        # caller's arrays are reordered too; rebinding the local names would
        # silently discard the sort.
        sort_idx = np.argsort(x)
        x[:] = x[sort_idx]
        y[:] = y[sort_idx]
        ErrorHandler.log_warning(
            "Wavenumbers were not monotonic - data has been sorted",
            f"Parsing {filename}",
        )

    # Check for a reasonable spectroscopy range
    if min(x) < 0 or max(x) > 10000 or (max(x) - min(x)) < 100:
        ErrorHandler.log_warning(
            f"Unusual wavenumber range: {min(x):.1f} - {max(x):.1f} cm⁻¹",
            f"Parsing {filename}",
        )
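
# Because the sort is in place, callers observe the fix, e.g. (hypothetical arrays):
#
#   x = np.array([200.0, 100.0])
#   y = np.array([0.2, 0.1])
#   validate_spectrum_data(x, y, "demo")   # logs a warning; x, y are now ascending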
def process_single_file(
    filename: str,
    text_content: str,
    model_choice: str,
    run_inference_func,
    label_file_func,
    modality: str,
    target_len: int,
) -> Optional[Dict[str, Any]]:
    """
    Process a single spectrum file.

    Args:
        filename: Name of the file
        text_content: Raw text content
        model_choice: Selected model name
        run_inference_func: Function to run inference
        label_file_func: Function to extract the ground-truth label
        modality: Spectroscopy modality used by the preprocessing pipeline
        target_len: Length to which the spectrum is resampled

    Returns:
        Dictionary with processing results, or a failure record
        (success=False) if parsing or inference raised a ValueError
    """
    start_time = time.time()
    try:
        # 1. Parse spectrum data
        x_raw, y_raw = parse_spectrum_data(text_content, filename)

        # 2. Preprocess the spectrum using the full, modality-aware pipeline
        x_resampled, y_resampled = preprocess_spectrum(
            x_raw, y_raw, modality=modality, target_len=target_len
        )

        # 3. Run inference, passing the modality; the cache key hashes the
        # resampled signal together with the model name
        cache_key = hashlib.md5(
            f"{y_resampled.tobytes()}{model_choice}".encode()
        ).hexdigest()
        prediction, logits_list, probs, inference_time, logits = run_inference_func(
            y_resampled, model_choice, modality=modality, cache_key=cache_key
        )
        if prediction is None:
            raise ValueError("Inference returned None. Model may have failed to load.")

        # ==Calculate confidence==
        if logits is not None:
            probs_np, max_confidence, confidence_level, confidence_emoji = (
                calculate_softmax_confidence(logits)
            )
        else:
            # Fallback for older models or if logits are not returned
            probs_np = np.array(probs) if probs is not None else np.array([])
            max_confidence = float(np.max(probs_np)) if probs_np.size > 0 else 0.0
            confidence_level = "LOW"
            confidence_emoji = "🔴"

        # ==Get ground truth==
        ground_truth = label_file_func(filename)
        ground_truth = (
            ground_truth if ground_truth is not None and ground_truth >= 0 else None
        )

        # ==Get predicted class==
        label_map = {0: "Stable (Unweathered)", 1: "Weathered (Degraded)"}
        predicted_class = label_map.get(int(prediction), f"Unknown ({prediction})")

        processing_time = time.time() - start_time
        return {
            "filename": filename,
            "success": True,
            "prediction": int(prediction),
            "predicted_class": predicted_class,
            "confidence": max_confidence,
            "confidence_level": confidence_level,
            "confidence_emoji": confidence_emoji,
            "logits": logits_list if logits_list else [],
            "probabilities": probs_np.tolist() if len(probs_np) > 0 else [],
            "ground_truth": ground_truth,
            "processing_time": processing_time,
            "x_raw": x_raw,
            "y_raw": y_raw,
            "x_resampled": x_resampled,
            "y_resampled": y_resampled,
        }
    except ValueError as e:
        ErrorHandler.log_error(e, f"processing {filename}")
        return {
            "filename": filename,
            "success": False,
            "error": str(e),
            "processing_time": time.time() - start_time,
        }
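
# The two injected callables are constrained only by their call sites above.
# A minimal hypothetical pair, useful for wiring tests (not the real app code):
#
#   def fake_inference(y, model_choice, modality=None, cache_key=None):
#       logits = np.array([0.2, 1.3])
#       probs = np.exp(logits) / np.exp(logits).sum()
#       # (prediction, logits_list, probs, inference_time, logits)
#       return int(np.argmax(logits)), logits.tolist(), probs.tolist(), 0.01, logits
#
#   def fake_label(filename):
#       return 1 if "weathered" in filename.lower() else 0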
def process_multiple_files(
    uploaded_files: List,
    model_choice: str,
    run_inference_func,
    label_file_func,
    modality: str,
    progress_callback=None,
) -> List[Dict[str, Any]]:
    """
    Process multiple uploaded files.

    Args:
        uploaded_files: List of uploaded file objects
        model_choice: Selected model name
        run_inference_func: Function to run inference
        label_file_func: Function to extract the ground-truth label
        modality: Spectroscopy modality passed through to preprocessing
        progress_callback: Optional callback to update progress

    Returns:
        List of processing results
    """
    results = []
    total_files = len(uploaded_files)

    ErrorHandler.log_info(
        f"Starting batch processing of {total_files} files with modality '{modality}'"
    )

    for i, uploaded_file in enumerate(uploaded_files):
        if progress_callback:
            progress_callback(i, total_files, uploaded_file.name)
        try:
            # ==Read file content==
            raw = uploaded_file.read()
            text_content = raw.decode("utf-8") if isinstance(raw, bytes) else raw

            # ==Process the file==
            result = process_single_file(
                filename=uploaded_file.name,
                text_content=text_content,
                model_choice=model_choice,
                run_inference_func=run_inference_func,
                label_file_func=label_file_func,
                modality=modality,
                target_len=TARGET_LEN,
            )
            if result:
                results.append(result)

                # ==Add successful results to the results manager==
                if result.get("success", False):
                    ResultsManager.add_results(
                        filename=result["filename"],
                        model_name=model_choice,
                        prediction=result["prediction"],
                        predicted_class=result["predicted_class"],
                        confidence=result["confidence"],
                        logits=result["logits"],
                        ground_truth=result["ground_truth"],
                        processing_time=result["processing_time"],
                        metadata={
                            "confidence_level": result["confidence_level"],
                            "confidence_emoji": result["confidence_emoji"],
                            # Store the spectrum data for later visualization
                            "x_raw": result["x_raw"],
                            "y_raw": result["y_raw"],
                            "x_resampled": result["x_resampled"],
                            "y_resampled": result["y_resampled"],
                        },
                    )
        except ValueError as e:
            ErrorHandler.log_error(e, f"reading file {uploaded_file.name}")
            results.append(
                {
                    "filename": uploaded_file.name,
                    "success": False,
                    "error": f"Failed to read file: {str(e)}",
                }
            )

    if progress_callback:
        progress_callback(total_files, total_files, "Complete")

    ErrorHandler.log_info(
        f"Completed batch processing: {sum(1 for r in results if r.get('success', False))}/{total_files} successful"
    )
    return results
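
if __name__ == "__main__":
    # Minimal smoke test for the parsing path only (hypothetical invocation:
    # the relative imports above require running this as part of its package,
    # e.g. `python -m <package>.<module>`; names are repo-specific). The
    # inference and Streamlit paths are not exercised here.
    demo = "\n".join(f"{400 + 10 * i} {1.0 + 0.01 * i}" for i in range(20))
    x, y = parse_spectrum_data(demo, "demo.txt")
    print(f"Parsed {len(x)} points, x range {x.min():.0f}-{x.max():.0f} cm⁻¹")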
