Spaces:
Sleeping
Sleeping
devjas1
(FEAT)[Enhanced Results Widget]: Integrate advanced probability breakdown, QC, and provenance export
fe030dd
| """ | |
| Training job management system for ML Hub functionality. | |
| Handles asynchronous training jobs, progress tracking, and result management. | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import time | |
| import uuid | |
| import threading | |
| import concurrent.futures | |
| import multiprocessing | |
| from datetime import datetime, timedelta | |
| from typing import Dict, List, Optional, Callable, Any, Tuple | |
| from pathlib import Path | |
| from dataclasses import dataclass, field | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from torch.utils.data import TensorDataset, DataLoader | |
| from sklearn.metrics import confusion_matrix, accuracy_score, f1_score | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from scipy.signal import find_peaks | |
| from scipy.spatial.distance import euclidean | |
| # Add project-specific imports | |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) | |
| from models.registry import choices as model_choices, build as build_model | |
| from utils.training_engine import TrainingEngine | |
| from utils.training_types import ( | |
| TrainingConfig, | |
| TrainingProgress, | |
| TrainingStatus, | |
| CVStrategy, | |
| get_cv_splitter, | |
| ) | |
| from utils.preprocessing import preprocess_spectrum | |
def spectral_cosine_similarity(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return the cosine similarity between two spectra.

    1-D inputs are treated as single spectra; for 2-D inputs only the first
    row of each array is compared (matching the previous behaviour, which
    read element ``[0, 0]`` of the sklearn similarity matrix).

    Args:
        y_true: Reference spectrum (1-D, or 2-D with spectra as rows).
        y_pred: Predicted spectrum with the same row length as ``y_true``.

    Returns:
        Cosine similarity in [-1, 1]; 0.0 when either vector has zero norm
        (same convention sklearn uses for all-zero rows).
    """
    a = np.atleast_2d(np.asarray(y_true, dtype=float))[0]
    b = np.atleast_2d(np.asarray(y_pred, dtype=float))[0]
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)
    # Guard against division by zero for empty/silent spectra.
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return float(np.dot(a, b) / (norm_a * norm_b))
def peak_matching_score(
    spectrum1: np.ndarray,
    spectrum2: np.ndarray,
    height_threshold: float = 0.1,
    distance: int = 5,
    tolerance: int = 3,
) -> float:
    """Calculate the fraction of matched peaks between two spectra.

    Peaks are detected independently in each spectrum; a peak in
    ``spectrum1`` counts as matched when any peak in ``spectrum2`` lies
    within ``tolerance`` index positions of it.

    Args:
        spectrum1: First spectrum (1-D intensity array).
        spectrum2: Second spectrum (1-D intensity array).
        height_threshold: Minimum peak height passed to ``find_peaks``.
        distance: Minimum index separation between detected peaks.
        tolerance: Maximum index offset for two peaks to count as a match
            (previously a hard-coded constant; default keeps old behaviour).

    Returns:
        Matched-peak count normalized by the larger peak count, in [0, 1];
        0.0 when either spectrum has no peaks or the input is malformed.
    """
    try:
        peaks1, _ = find_peaks(spectrum1, height=height_threshold, distance=distance)
        peaks2, _ = find_peaks(spectrum2, height=height_threshold, distance=distance)
        if len(peaks1) == 0 or len(peaks2) == 0:
            return 0.0
        # Each peak1 matches at most once; peaks2 may be reused (as before).
        matches = sum(
            1 for peak in peaks1 if np.any(np.abs(peaks2 - peak) <= tolerance)
        )
        return matches / max(len(peaks1), len(peaks2))
    except Exception:
        # Defensive: treat malformed input (bad dtype/shape) as "no match"
        # rather than propagating; was a bare `except:` before.
        return 0.0
def spectral_euclidean_distance(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return the length-normalized Euclidean distance between two spectra.

    Both inputs are flattened; the Euclidean (L2) distance is divided by
    the number of points, so the value is comparable across spectrum
    lengths.

    Args:
        y_true: Reference spectrum (any shape; flattened internally).
        y_pred: Predicted spectrum with the same number of elements.

    Returns:
        Normalized distance >= 0.0, or ``float("inf")`` for empty or
        mismatched-length inputs (previously reached via a bare except).
    """
    a = np.asarray(y_true, dtype=float).ravel()
    b = np.asarray(y_pred, dtype=float).ravel()
    # Explicit checks replace exception-driven control flow.
    if a.size == 0 or a.size != b.size:
        return float("inf")
    # np.linalg.norm(a - b) equals scipy.spatial.distance.euclidean(a, b).
    return float(np.linalg.norm(a - b) / a.size)
def calculate_spectroscopy_metrics(
    y_true: np.ndarray, y_pred: np.ndarray, probabilities: Optional[np.ndarray] = None
) -> Dict[str, float]:
    """Calculate classification plus spectroscopy-specific quality metrics.

    Args:
        y_true: Ground-truth class labels, shape (n_samples,).
        y_pred: Predicted class labels, shape (n_samples,).
        probabilities: Optional per-class probabilities, shape
            (n_samples, n_classes); enables the cosine-similarity metric.

    Returns:
        Dict with "accuracy", "f1_score" and, when computable,
        "cosine_similarity" and "distribution_similarity".  On any internal
        failure a complete fallback dict is returned instead of raising.
    """
    metrics: Dict[str, float] = {}
    try:
        # Standard classification metrics.
        metrics["accuracy"] = accuracy_score(y_true, y_pred)
        metrics["f1_score"] = f1_score(y_true, y_pred, average="weighted")
        # Spectroscopy-specific: cosine similarity between the mean one-hot
        # truth distribution and the mean predicted probability distribution.
        if probabilities is not None and len(probabilities.shape) > 1:
            unique_classes = np.unique(y_true)
            if len(unique_classes) > 1:
                # Bug fix: map labels to 0..k-1 indices before one-hot
                # encoding; labels are not guaranteed to be contiguous
                # integers starting at 0 (np.eye(k)[y_true] assumed that).
                class_indices = np.searchsorted(unique_classes, y_true)
                y_true_onehot = np.eye(len(unique_classes))[class_indices]
                metrics["cosine_similarity"] = float(
                    cosine_similarity(
                        y_true_onehot.mean(axis=0).reshape(1, -1),
                        probabilities.mean(axis=0).reshape(1, -1),
                    )[0, 0]
                )
        # Bias audit: compare class distributions of truth vs. prediction.
        unique_true, counts_true = np.unique(y_true, return_counts=True)
        unique_pred, counts_pred = np.unique(y_pred, return_counts=True)
        true_dist = counts_true / len(y_true)
        pred_dist = np.zeros_like(true_dist)
        for i, class_label in enumerate(unique_true):
            if class_label in unique_pred:
                pred_idx = np.where(unique_pred == class_label)[0][0]
                pred_dist[i] = counts_pred[pred_idx] / len(y_pred)
        # 1 - mean absolute difference: 1.0 means identical distributions.
        metrics["distribution_similarity"] = float(
            1.0 - np.mean(np.abs(true_dist - pred_dist))
        )
    except Exception as e:
        print(f"Error calculating spectroscopy metrics: {e}")
        # Fall back so callers always receive a complete metrics dict.
        metrics = {
            "accuracy": accuracy_score(y_true, y_pred) if len(y_true) > 0 else 0.0,
            "f1_score": (
                f1_score(y_true, y_pred, average="weighted") if len(y_true) > 0 else 0.0
            ),
            "cosine_similarity": 0.0,
            "distribution_similarity": 0.0,
        }
    return metrics
@dataclass
class AugmentationConfig:
    """Data augmentation configuration.

    Bug fix: the ``@dataclass`` decorator was missing, so no ``__init__``
    was generated and ``field(default_factory=AugmentationConfig)`` users
    received a bare class-attribute holder.
    """

    enable_augmentation: bool = False  # master switch for augmentation
    noise_level: float = 0.01  # noise amplitude used when augmenting
@dataclass
class PreprocessingConfig:
    """Preprocessing configuration for spectra.

    Bug fix: the ``@dataclass`` decorator was missing, so instances could
    not be constructed with per-field values and ``field(default_factory=
    PreprocessingConfig)`` consumers got only class attributes.
    """

    baseline_correction: bool = True  # subtract baseline before training
    smoothing: bool = True  # apply smoothing filter
    normalization: bool = True  # normalize intensities
@dataclass
class TrainingConfig:
    """Training configuration parameters.

    Bug fixes: the ``@dataclass`` decorator was missing (so ``field(...)``
    defaults were raw Field sentinels and no ``__init__`` existed), and
    ``to_dict`` called ``asdict`` which was never imported.

    NOTE(review): this class shadows the ``TrainingConfig`` imported from
    ``utils.training_types`` at the top of the file — confirm which
    definition callers should use and remove one of them.
    """

    model_name: str  # key understood by models.registry.build
    dataset_path: str  # root directory of the labeled dataset
    target_len: int = 500  # resampled spectrum length
    batch_size: int = 16
    epochs: int = 10
    learning_rate: float = 1e-3
    num_folds: int = 10  # cross-validation folds
    modality: str = "raman"
    device: str = "auto"  # auto, cpu, cuda
    cv_strategy: str = "stratified_kfold"  # CV strategy name
    spectral_weight: float = 0.1  # weight for spectroscopy-specific metrics
    augmentation: AugmentationConfig = field(default_factory=AugmentationConfig)
    preprocessing: PreprocessingConfig = field(default_factory=PreprocessingConfig)

    # Backward-compatible aliases: some call sites read these flags directly
    # off the config instead of via the nested ``preprocessing`` object.
    @property
    def baseline_correction(self) -> bool:
        return self.preprocessing.baseline_correction

    @property
    def smoothing(self) -> bool:
        return self.preprocessing.smoothing

    @property
    def normalization(self) -> bool:
        return self.preprocessing.normalization

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dict (nested dataclasses included) for JSON."""
        from dataclasses import asdict  # local: not imported at module level

        return asdict(self)
@dataclass
class TrainingProgress:
    """Training progress tracking with enhanced metrics.

    Bug fix: the ``@dataclass`` decorator was missing, so the
    ``field(default_factory=list)`` defaults were raw Field sentinels and
    no ``__init__`` was generated.

    NOTE(review): shadows the ``TrainingProgress`` imported from
    ``utils.training_types`` — confirm which definition is intended.
    """

    current_fold: int = 0
    total_folds: int = 10
    current_epoch: int = 0
    total_epochs: int = 10
    current_loss: float = 0.0
    current_accuracy: float = 0.0
    # Per-fold results accumulated during cross-validation.
    fold_accuracies: List[float] = field(default_factory=list)
    confusion_matrices: List[List[List[int]]] = field(default_factory=list)
    spectroscopy_metrics: List[Dict[str, float]] = field(default_factory=list)
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
@dataclass
class TrainingJob:
    """Container for a single training job and its lifecycle metadata.

    Bug fix: the ``@dataclass`` decorator was missing — without it there is
    no generated ``__init__`` and ``__post_init__`` is never invoked.
    ``progress``/``created_at`` are typed Optional since ``None`` is their
    declared default (filled in by ``__post_init__``).
    """

    job_id: str  # unique id from TrainingManager.generate_job_id
    config: TrainingConfig
    status: TrainingStatus = TrainingStatus.PENDING
    progress: Optional[TrainingProgress] = None
    error_message: Optional[str] = None
    created_at: Optional[datetime] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    weights_path: Optional[str] = None  # saved model weights (.pth)
    logs_path: Optional[str] = None  # saved JSON results log

    def __post_init__(self) -> None:
        # Derive a fresh progress tracker sized from the config, and stamp
        # the creation time, unless the caller supplied them explicitly.
        if self.progress is None:
            self.progress = TrainingProgress(
                total_folds=self.config.num_folds, total_epochs=self.config.epochs
            )
        if self.created_at is None:
            self.created_at = datetime.now()
class TrainingManager:
    """Manager for training jobs with async execution and progress tracking.

    Jobs are executed on a ``concurrent.futures`` executor and tracked in an
    in-memory registry keyed by job id.

    NOTE(review): with ``use_multiprocessing=True`` (the default) the worker
    runs in a separate process, so mutations of the ``TrainingJob`` object
    (status, progress) happen on a pickled copy and are NOT reflected in this
    manager's ``self.jobs`` registry; submitting the bound method also
    requires ``self`` to be picklable, which the executor attribute prevents.
    Reliable progress tracking currently needs ``use_multiprocessing=False``
    — confirm the intended execution mode.
    """

    def __init__(
        self,
        max_workers: int = 2,
        output_dir: str = "outputs",
        use_multiprocessing: bool = True,
    ):
        """Create the manager and its output directory layout.

        Args:
            max_workers: Upper bound on concurrently running jobs.
            output_dir: Root directory that will hold ``weights/`` artifacts.
            use_multiprocessing: Use a process pool (CPU/GPU-bound work)
                instead of a thread pool (I/O-bound work).
        """
        self.max_workers = max_workers
        self.use_multiprocessing = use_multiprocessing
        if use_multiprocessing:
            # Cap workers at the CPU count to avoid oversubscription.
            actual_workers = min(max_workers, multiprocessing.cpu_count())
            self.executor = concurrent.futures.ProcessPoolExecutor(
                max_workers=actual_workers
            )
        else:
            self.executor = concurrent.futures.ThreadPoolExecutor(
                max_workers=max_workers
            )
        self.jobs: Dict[str, TrainingJob] = {}
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True)
        (self.output_dir / "weights").mkdir(exist_ok=True)

    def generate_job_id(self) -> str:
        """Generate a unique, time-stamped job ID."""
        return f"train_{uuid.uuid4().hex[:8]}_{int(time.time())}"

    def submit_training_job(
        self, config: TrainingConfig, progress_callback: Optional[Callable] = None
    ) -> str:
        """Register a new job and schedule it on the executor.

        Args:
            config: Training configuration for the job.
            progress_callback: Optional callable invoked with the job after
                each progress update (start, per-fold/epoch, completion).

        Returns:
            The generated job id, usable with ``get_job_status``.
        """
        job_id = self.generate_job_id()
        job = TrainingJob(job_id=job_id, config=config)
        self.jobs[job_id] = job
        self.executor.submit(
            self._run_training_job, job, progress_callback=progress_callback
        )
        return job_id

    def _run_training_job(
        self, job: TrainingJob, progress_callback: Optional[Callable] = None
    ) -> None:
        """Execute a training job (runs on an executor worker).

        Bug fix: this method previously took no ``progress_callback``
        parameter even though ``submit_training_job`` passed one (and the
        body referenced the undefined name), so every submitted job raised
        immediately inside the executor.
        """
        try:
            job.status = TrainingStatus.RUNNING
            job.started_at = datetime.now()
            if job.progress:
                job.progress.start_time = job.started_at
            if progress_callback:
                progress_callback(job)

            X, y = self._load_and_preprocess_data(job)
            if X is None or y is None:
                raise ValueError("Failed to load dataset")

            def engine_progress_callback(progress_data: dict):
                # Mirror engine events onto the job's progress object.
                if job.progress:
                    if progress_data["type"] == "fold_start":
                        job.progress.current_fold = progress_data["fold"]
                    elif progress_data["type"] == "epoch_end":
                        job.progress.current_epoch = progress_data["epoch"]
                        job.progress.current_loss = progress_data["loss"]
                if progress_callback:
                    progress_callback(job)

            engine = TrainingEngine(job.config)
            results = engine.run(X, y, progress_callback=engine_progress_callback)

            if job.progress:
                job.progress.fold_accuracies = results["fold_accuracies"]
                job.progress.confusion_matrices = results["confusion_matrices"]

            self._save_model_weights(job, results["model_state_dict"])
            self._save_training_results(job)

            job.status = TrainingStatus.COMPLETED
            job.completed_at = datetime.now()
            if job.progress:  # bug fix: end_time was set without this guard
                job.progress.end_time = job.completed_at
        except Exception as e:
            job.status = TrainingStatus.FAILED
            job.error_message = str(e)
            job.completed_at = datetime.now()
        finally:
            if progress_callback:
                progress_callback(job)

    @staticmethod
    def _load_spectrum_file(file_path: Path):
        """Load one (x, y) spectrum from a .txt/.csv/.json file.

        Returns an ``(x_raw, y_raw)`` tuple of arrays, or ``None`` when the
        file layout is not recognized.  May raise on unreadable files; the
        caller handles that per-file.
        """
        suffix = file_path.suffix.lower()
        if suffix == ".txt":
            data = np.loadtxt(file_path)
            if data.ndim == 2 and data.shape[1] >= 2:
                return data[:, 0], data[:, 1]
            if data.ndim == 1:
                # Single-column data: synthesize an index axis.
                return np.arange(len(data)), data
            return None
        if suffix == ".csv":
            import pandas as pd

            df = pd.read_csv(file_path)
            if df.shape[1] >= 2:
                return df.iloc[:, 0].values, df.iloc[:, 1].values
            return np.arange(len(df)), df.iloc[:, 0].values
        if suffix == ".json":
            with open(file_path, "r") as f:
                data_dict = json.load(f)
            if isinstance(data_dict, dict):
                if "x" in data_dict and "y" in data_dict:
                    return np.array(data_dict["x"]), np.array(data_dict["y"])
                if "spectrum" in data_dict:
                    y_raw = np.array(data_dict["spectrum"])
                    return np.arange(len(y_raw)), y_raw
            return None
        return None

    @staticmethod
    def _validate_raw_spectrum(
        x_raw: np.ndarray, y_raw: np.ndarray, file_path: Path
    ) -> bool:
        """Integrity-check one raw spectrum; log and reject bad data."""
        if len(x_raw) != len(y_raw) or len(x_raw) < 10:
            print(f"Invalid data in file {file_path}: insufficient data points")
            return False
        if np.any(np.isnan(y_raw)) or np.any(np.isinf(y_raw)):
            print(f"Invalid data in file {file_path}: NaN or infinite values")
            return False
        # Sanity range for spectroscopy intensity values.
        if np.min(y_raw) < -1000 or np.max(y_raw) > 1e6:
            print(
                f"Suspicious data values in file {file_path}: outside expected range"
            )
            return False
        return True

    def _load_and_preprocess_data(
        self, job: TrainingJob
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Load and preprocess the job's dataset with validation and path checks.

        Directory layout: one subdirectory per class under the dataset root;
        a subdirectory whose name contains "stable" is class 0, all others
        class 1.

        Returns:
            ``(X, y)`` as float32/int64 arrays, or ``(None, None)`` on any
            failure (errors are printed, not raised, to the caller).
        """
        try:
            config = job.config
            dataset_path = Path(config.dataset_path)
            if not dataset_path.exists():
                raise FileNotFoundError(f"Dataset path not found: {dataset_path}")
            # Security: confine dataset paths to known base directories.
            try:
                dataset_path = dataset_path.resolve()
                allowed_bases = [
                    Path("datasets").resolve(),
                    Path("data").resolve(),
                    Path("/tmp").resolve(),
                ]
                if not any(
                    str(dataset_path).startswith(str(base)) for base in allowed_bases
                ):
                    raise ValueError(
                        f"Dataset path outside allowed directories: {dataset_path}"
                    )
            except Exception as e:
                print(f"Path validation error: {e}")
                raise ValueError("Invalid dataset path")

            # Bug fix: the preprocessing flags live on config.preprocessing;
            # fall back to flat attributes for older/simpler config objects.
            prep = getattr(config, "preprocessing", config)
            do_baseline = getattr(prep, "baseline_correction", True)
            do_smooth = getattr(prep, "smoothing", True)
            do_normalize = getattr(prep, "normalization", True)

            X, y = [], []
            total_files = 0
            processed_files = 0
            max_files_per_class = 1000  # guard against runaway memory use
            max_file_size = 10 * 1024 * 1024  # 10 MB per file

            for label_dir in dataset_path.iterdir():
                if not label_dir.is_dir():
                    continue
                # Binary labeling convention: "stable" in dir name -> class 0.
                label = 0 if "stable" in label_dir.name.lower() else 1
                files_in_class = 0
                for pattern in ("*.txt", "*.csv", "*.json"):
                    for file_path in label_dir.glob(pattern):
                        total_files += 1
                        # Security: skip oversized files.
                        if file_path.stat().st_size > max_file_size:
                            print(
                                f"Skipping large file: {file_path} ({file_path.stat().st_size} bytes)"
                            )
                            continue
                        if files_in_class >= max_files_per_class:
                            print(
                                f"Reached maximum files per class ({max_files_per_class}) for {label_dir.name}"
                            )
                            break
                        try:
                            loaded = self._load_spectrum_file(file_path)
                            if loaded is None:
                                continue
                            x_raw, y_raw = loaded
                            if not self._validate_raw_spectrum(
                                x_raw, y_raw, file_path
                            ):
                                continue
                            _, y_processed = preprocess_spectrum(
                                x_raw,
                                y_raw,
                                modality=config.modality,
                                target_len=config.target_len,
                                do_baseline=do_baseline,
                                do_smooth=do_smooth,
                                do_normalize=do_normalize,
                            )
                            # Final validation of the processed spectrum.
                            if (
                                y_processed is None
                                or len(y_processed) != config.target_len
                            ):
                                print(f"Preprocessing failed for file {file_path}")
                                continue
                            X.append(y_processed)
                            y.append(label)
                            files_in_class += 1
                            processed_files += 1
                        except Exception as e:
                            print(f"Error processing file {file_path}: {e}")
                            continue

            # Validate the assembled dataset as a whole.
            if len(X) == 0:
                raise ValueError("No valid data files found in dataset")
            if len(X) < 10:
                raise ValueError(
                    f"Insufficient data: only {len(X)} samples found (minimum 10 required)"
                )
            unique_labels, counts = np.unique(y, return_counts=True)
            if len(unique_labels) < 2:
                raise ValueError("Dataset must contain at least 2 classes")
            min_class_size = min(counts)
            if min_class_size < 3:
                raise ValueError(
                    f"Insufficient samples in one class: minimum {min_class_size} (need at least 3)"
                )
            print(f"Dataset loaded: {processed_files}/{total_files} files processed")
            print(f"Class distribution: {dict(zip(unique_labels, counts))}")
            return np.array(X, dtype=np.float32), np.array(y, dtype=np.int64)
        except Exception as e:
            print(f"Error loading dataset: {e}")
            return None, None

    def _save_model_weights(self, job: TrainingJob, model_state_dict: dict):
        """Save the model's state dictionary under ``output_dir/weights``."""
        weights_dir = self.output_dir / "weights"
        weights_dir.mkdir(exist_ok=True)
        job.weights_path = str(weights_dir / f"{job.config.model_name}_model.pth")
        torch.save(model_state_dict, job.weights_path)

    def _save_training_results(self, job: TrainingJob):
        """Write the job's results and summary metrics to a JSON log file."""
        logs_dir = self.output_dir / "logs"
        logs_dir.mkdir(exist_ok=True)
        job.logs_path = str(logs_dir / f"{job.job_id}_log.json")
        # Per-metric mean/std across folds of the spectroscopy metrics.
        spectro_summary = {}
        if job.progress.spectroscopy_metrics:
            metric_keys = job.progress.spectroscopy_metrics[0].keys()
            for key in metric_keys:
                values = [
                    fold_metrics.get(key, 0.0)
                    for fold_metrics in job.progress.spectroscopy_metrics
                ]
                spectro_summary[f"mean_{key}"] = float(np.mean(values))
                spectro_summary[f"std_{key}"] = float(np.std(values))
        results = {
            "job_id": job.job_id,
            "config": job.config.to_dict(),
            "status": job.status.value,
            "created_at": job.created_at.isoformat(),
            "started_at": job.started_at.isoformat() if job.started_at else None,
            "completed_at": job.completed_at.isoformat() if job.completed_at else None,
            "progress": {
                "fold_accuracies": job.progress.fold_accuracies,
                "confusion_matrices": job.progress.confusion_matrices,
                "spectroscopy_metrics": job.progress.spectroscopy_metrics,
                "mean_accuracy": (
                    np.mean(job.progress.fold_accuracies)
                    if job.progress.fold_accuracies
                    else 0.0
                ),
                "std_accuracy": (
                    np.std(job.progress.fold_accuracies)
                    if job.progress.fold_accuracies
                    else 0.0
                ),
                "spectroscopy_summary": spectro_summary,
            },
            "weights_path": job.weights_path,
            "error_message": job.error_message,
        }
        if job.logs_path:
            with open(job.logs_path, "w") as f:
                json.dump(results, f, indent=2)

    def get_job_status(self, job_id: str) -> Optional[TrainingJob]:
        """Return the job with the given id, or None if unknown."""
        return self.jobs.get(job_id)

    def list_jobs(
        self, status_filter: Optional[TrainingStatus] = None
    ) -> List[TrainingJob]:
        """List jobs (newest first), optionally filtered by status."""
        jobs = list(self.jobs.values())
        if status_filter:
            jobs = [job for job in jobs if job.status == status_filter]
        return sorted(jobs, key=lambda j: j.created_at, reverse=True)

    def cancel_job(self, job_id: str) -> bool:
        """Mark a running job as cancelled.

        Returns True when the job existed and was running.  This only flips
        the status flag — the worker itself is not interrupted.
        """
        job = self.jobs.get(job_id)
        if job and job.status == TrainingStatus.RUNNING:
            job.status = TrainingStatus.CANCELLED
            job.completed_at = datetime.now()
            return True
        return False

    def cleanup_old_jobs(self, max_age_hours: int = 24):
        """Drop finished (completed/failed/cancelled) jobs older than the cutoff."""
        cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
        terminal = (
            TrainingStatus.COMPLETED,
            TrainingStatus.FAILED,
            TrainingStatus.CANCELLED,
        )
        to_remove = [
            job_id
            for job_id, job in self.jobs.items()
            if job.status in terminal
            and job.completed_at
            and job.completed_at < cutoff_time
        ]
        for job_id in to_remove:
            del self.jobs[job_id]

    def shutdown(self):
        """Shut down the executor, blocking until running jobs finish."""
        self.executor.shutdown(wait=True)
# Process-wide singleton; created lazily by get_training_manager().
_training_manager = None


def get_training_manager() -> TrainingManager:
    """Return the shared TrainingManager, constructing it on first use."""
    global _training_manager
    if _training_manager is not None:
        return _training_manager
    _training_manager = TrainingManager()
    return _training_manager