polymer-aging-ml / utils /batch_processing.py
devjas1
(TASK)[Batch Processing Utilities]: Add module to support batch comparison and processing of spectral files.
5543304
raw
history blame
10.7 kB
"""This file provides utilities for **batch processing** spectral data files (such as Raman spectra) for polymer classification. Its main goal is to process multiple files efficiently—either synchronously or asynchronously—using one or more machine learning models, and to collect, summarize, and export the results. It is designed for integration with a Streamlit-based UI, supporting file uploads and batch inference."""
import os
import time
import json
from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path
from dataclasses import dataclass, asdict
import pandas as pd
import numpy as np
import streamlit as st
from utils.preprocessing import preprocess_spectrum
from utils.multifile import parse_spectrum_data
from utils.async_inference import submit_batch_inference, wait_for_batch_completion
from core_logic import run_inference
@dataclass
class BatchProcessingResult:
    """Outcome of running a single model on a single spectrum file.

    One instance is produced per (file, model) pair, for successful and
    failed inferences alike.
    """

    # --- identification -------------------------------------------------
    filename: str  # name of the processed file
    model_name: str  # model that produced (or failed to produce) the output

    # --- model output ---------------------------------------------------
    prediction: int  # predicted class index; producers in this module use -1 on failure
    confidence: float  # highest class probability; 0.0 when unavailable
    logits: List[float]  # raw model outputs; empty list when unavailable
    inference_time: float  # inference duration as reported by the inference call

    # --- bookkeeping ----------------------------------------------------
    status: str = "success"  # "success" or "failed"
    error: Optional[str] = None  # failure description when status is "failed"
    ground_truth: Optional[int] = None  # known label for the file, if any
class BatchProcessor:
"""Handles batch processing of spectral data files."""
def __init__(self, modality: str = "raman"):
    """Set up a processor for one spectral modality with an empty result log.

    Args:
        modality: Forwarded to ``preprocess_spectrum`` for every file this
            processor handles.
    """
    # Accumulates the results of every process_files_* call on this instance.
    self.results: List[BatchProcessingResult] = []
    self.modality = modality
def process_files_sync(
    self,
    file_data: List[Tuple[str, str]],  # (filename, content)
    model_names: List[str],
    target_len: int = 500,
) -> List[BatchProcessingResult]:
    """Run every requested model on every file, one inference at a time.

    Each file is parsed and preprocessed once and the processed spectrum is
    reused for all models. A failure never aborts the batch: it is recorded
    as a ``status="failed"`` result for the affected (file, model) pair.

    Args:
        file_data: ``(filename, content)`` pairs; ``content`` is the raw
            text handed to ``parse_spectrum_data``.
        model_names: Models to evaluate on each file.
        target_len: Forwarded to ``preprocess_spectrum`` as ``target_len``.

    Returns:
        One result per (file, model) pair, files in input order and models
        in ``model_names`` order. The same results are also appended to
        ``self.results``.
    """

    def failed(filename: str, model_name: str, message: str) -> BatchProcessingResult:
        # Single place that defines what a failed entry looks like.
        return BatchProcessingResult(
            filename=filename,
            model_name=model_name,
            prediction=-1,
            confidence=0.0,
            logits=[],
            inference_time=0.0,
            status="failed",
            error=message,
        )

    results: List[BatchProcessingResult] = []
    for filename, content in file_data:
        # Parsing and preprocessing depend only on the file, so they are done
        # once per file instead of once per model. The asynchronous path
        # already shares one preprocessed spectrum between models.
        try:
            x_raw, y_raw = parse_spectrum_data(content)
            _, y_proc = preprocess_spectrum(
                x_raw, y_raw, modality=self.modality, target_len=target_len
            )
        except Exception as exc:
            # An unreadable file fails for every model, keeping the
            # one-result-per-(file, model) shape of the output.
            results.extend(failed(filename, name, str(exc)) for name in model_names)
            continue

        ground_truth = self._extract_ground_truth(filename)

        for model_name in model_names:
            try:
                prediction, logits_list, probs, inference_time, _ = run_inference(
                    y_proc, model_name
                )
                if prediction is None:
                    result = failed(filename, model_name, "Inference failed")
                else:
                    result = BatchProcessingResult(
                        filename=filename,
                        model_name=model_name,
                        prediction=int(prediction),
                        confidence=float(max(probs)) if probs is not None else 0.0,
                        logits=logits_list or [],
                        inference_time=inference_time or 0.0,
                        ground_truth=ground_truth,
                    )
            except Exception as exc:
                # Keep going: one broken model must not lose the other results.
                result = failed(filename, model_name, str(exc))
            results.append(result)

    self.results.extend(results)
    return results
def process_files_async(
    self,
    file_data: List[Tuple[str, str]],
    model_names: List[str],
    target_len: int = 500,
    max_concurrent: int = 3,
) -> "List[BatchProcessingResult]":
    """Process files in chunks using the asynchronous inference helpers.

    ``file_data`` is split into consecutive chunks of ``max_concurrent``
    files. Chunks are handed to ``_process_chunk_async`` one after another
    and their results are concatenated in order.

    Args:
        file_data: ``(filename, content)`` pairs to process.
        model_names: Models to evaluate on each file.
        target_len: Forwarded unchanged to ``_process_chunk_async``.
        max_concurrent: Number of files per chunk; must be at least 1.

    Returns:
        All results in chunk order. The same results are also appended to
        ``self.results``.

    Raises:
        ValueError: If ``max_concurrent`` is smaller than 1. Previously 0
            failed with an opaque ``range()`` error and negative values
            silently processed no files at all.
    """
    if max_concurrent < 1:
        raise ValueError(f"max_concurrent must be at least 1, got {max_concurrent}")

    results = []
    for start in range(0, len(file_data), max_concurrent):
        chunk = file_data[start : start + max_concurrent]
        results.extend(self._process_chunk_async(chunk, model_names, target_len))

    self.results.extend(results)
    return results
def _process_chunk_async(
    self, file_chunk: List[Tuple[str, str]], model_names: List[str], target_len: int
) -> List[BatchProcessingResult]:
    """Process one chunk of files through the asynchronous inference helpers.

    Files in the chunk are handled one after another. For each file, all
    models are submitted together via ``submit_batch_inference`` and the
    call then blocks in ``wait_for_batch_completion`` (60 s timeout) before
    moving on to the next file.

    Args:
        file_chunk: ``(filename, content)`` pairs to process.
        model_names: Models to evaluate on each file.
        target_len: Forwarded to ``preprocess_spectrum`` as ``target_len``.

    Returns:
        Exactly one result per (file, model) pair, files in input order and
        models in ``model_names`` order. Failures are reported as
        ``status="failed"`` entries instead of being raised.
    """

    def failed(filename: str, model_name: str, message) -> BatchProcessingResult:
        # Single place that defines what a failed entry looks like.
        return BatchProcessingResult(
            filename=filename,
            model_name=model_name,
            prediction=-1,
            confidence=0.0,
            logits=[],
            inference_time=0.0,
            status="failed",
            error=message,
        )

    def to_result(filename, model_name, payload, ground_truth) -> BatchProcessingResult:
        # Error payloads are looked up by the "error" key, so they must be
        # mappings. Checking the type first avoids running `in` against a
        # result tuple, which compares element-wise and can misbehave for
        # array-like members.
        if isinstance(payload, dict) and "error" in payload:
            return failed(filename, model_name, payload["error"])

        prediction, logits_list, probs, inference_time, _ = payload
        if prediction is None:
            # Same convention as the synchronous path.
            return failed(filename, model_name, "Inference failed")

        # `probs` may be array-like, so avoid relying on its truth value.
        has_probs = probs is not None and len(probs) > 0
        return BatchProcessingResult(
            filename=filename,
            model_name=model_name,
            # `prediction or -1` would turn the valid class index 0 into -1.
            prediction=int(prediction),
            confidence=float(max(probs)) if has_probs else 0.0,
            logits=list(logits_list) if logits_list is not None else [],
            inference_time=inference_time or 0.0,
            ground_truth=ground_truth,
        )

    results: List[BatchProcessingResult] = []
    for filename, content in file_chunk:
        try:
            # Parse and preprocess once; every model receives the same input.
            x_raw, y_raw = parse_spectrum_data(content)
            _, y_proc = preprocess_spectrum(
                x_raw, y_raw, modality=self.modality, target_len=target_len
            )
            task_ids = submit_batch_inference(
                model_names=model_names,
                input_data=y_proc,
                inference_func=run_inference,
            )
            inference_results = wait_for_batch_completion(task_ids, timeout=60.0)
        except Exception as exc:
            # Nothing has been recorded for this file yet, so emitting one
            # failure per model cannot create duplicates.
            results.extend(failed(filename, name, str(exc)) for name in model_names)
            continue

        ground_truth = self._extract_ground_truth(filename)

        for model_name in model_names:
            # Errors are handled per model so that a malformed payload for one
            # model neither discards nor duplicates the entries of the others.
            try:
                if model_name in inference_results:
                    result = to_result(
                        filename,
                        model_name,
                        inference_results[model_name],
                        ground_truth,
                    )
                else:
                    result = failed(filename, model_name, "No result received")
            except Exception as exc:
                result = failed(filename, model_name, str(exc))
            results.append(result)

    return results
def _extract_ground_truth(self, filename: str) -> Optional[int]:
    """Return the ground-truth label derived from ``filename``, if any.

    The lookup is delegated to ``core_logic.label_file`` and is strictly
    best-effort: any ordinary error (helper not importable, filename not
    recognised, ...) results in ``None``.

    Args:
        filename: Name of the spectrum file.

    Returns:
        Whatever ``label_file`` returns for the filename, or ``None`` when
        the label cannot be determined.
    """
    try:
        from core_logic import label_file

        return label_file(filename)
    except Exception:
        # Deliberately broad, but not a bare `except:` so that
        # KeyboardInterrupt and SystemExit still propagate.
        return None
def get_summary_statistics(self) -> Dict[str, Any]:
    """Summarise everything accumulated in ``self.results``.

    Returns:
        An empty dict when no results exist. Otherwise a dict with:

        * ``total_files``: number of distinct filenames.
        * ``total_inferences``: number of results.
        * ``successful_inferences`` / ``failed_inferences``: counts of
          results whose status is ``"success"`` / ``"failed"``.
        * ``success_rate``: successful results divided by all results.
        * ``models_used``: distinct model names in first-seen order.
        * ``average_inference_time`` / ``total_processing_time``: mean and
          sum of ``inference_time`` over successful results (0 if none).
        * ``accuracy`` / ``samples_with_ground_truth``: present only when at
          least one successful result carries a ground-truth label.
    """
    total = len(self.results)
    if total == 0:
        return {}

    succeeded = [r for r in self.results if r.status == "success"]
    failed_count = sum(1 for r in self.results if r.status == "failed")
    total_time = sum(r.inference_time for r in succeeded)

    stats: Dict[str, Any] = {
        "total_files": len({r.filename for r in self.results}),
        "total_inferences": total,
        "successful_inferences": len(succeeded),
        "failed_inferences": failed_count,
        "success_rate": len(succeeded) / total,
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # list is stable between runs (a set's iteration order is not).
        "models_used": list(dict.fromkeys(r.model_name for r in self.results)),
        "average_inference_time": total_time / len(succeeded) if succeeded else 0,
        "total_processing_time": total_time,
    }

    # Accuracy is only meaningful for successful results with a known label.
    labelled = [r for r in succeeded if r.ground_truth is not None]
    if labelled:
        correct = sum(1 for r in labelled if r.prediction == r.ground_truth)
        stats["accuracy"] = correct / len(labelled)
        stats["samples_with_ground_truth"] = len(labelled)

    return stats
def export_results(self, format: str = "csv") -> str:
    """Serialise the accumulated results as text.

    Args:
        format: ``"csv"`` (default) or ``"json"``; matched case-insensitively.
            The name shadows the ``format`` builtin but is kept because it
            is part of the public signature.

    Returns:
        For ``"json"``: a JSON array with one object per result. For
        ``"csv"``: a header row followed by one row per result, with the
        ``logits`` list stored as a JSON string in a single cell; an empty
        string when there are no results.

    Raises:
        ValueError: If ``format`` is not a supported format.
    """
    import csv
    import io

    kind = format.strip().lower()
    rows = [asdict(result) for result in self.results]

    # default=float is deliberate: it lets float-like scalar types that the
    # json module does not know (NOTE(review): presumably NumPy scalars
    # coming out of inference - confirm) serialise as plain numbers.
    if kind == "json":
        return json.dumps(rows, indent=2, default=float)

    if kind == "csv":
        if not rows:
            return ""
        buffer = io.StringIO()
        writer = csv.DictWriter(buffer, fieldnames=list(rows[0]), lineterminator="\n")
        writer.writeheader()
        for row in rows:
            # Keep the variable-length logits list in one parseable cell.
            writer.writerow({**row, "logits": json.dumps(row["logits"], default=float)})
        return buffer.getvalue()

    raise ValueError(f"Unsupported export format: {format!r}")