| """ | |
| Specialized Medical AI Model Router - Phase 3 | |
| Routes structured medical data to appropriate specialized AI models. | |
| This module integrates with the preprocessing pipeline to provide model-specific | |
| preprocessing, inference, and confidence scoring for medical AI analysis. | |
| Author: MiniMax Agent | |
| Date: 2025-10-29 | |
| Version: 1.0.0 | |
| """ | |
import os
import logging
import asyncio
import time
from typing import Dict, List, Optional, Any, Tuple, Union
from dataclasses import dataclass

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Import existing model infrastructure
from model_loader import MedicalModelLoader

# Import new preprocessing components
from preprocessing_pipeline import ProcessingPipelineResult
from medical_schemas import (
    ValidationResult, ConfidenceScore, ECGAnalysis, RadiologyAnalysis,
    LaboratoryResults, ClinicalNotesAnalysis
)

logger = logging.getLogger(__name__)

@dataclass
class ModelInferenceResult:
    """Result of specialized model inference."""
    model_name: str
    input_data: Dict[str, Any]
    output_data: Dict[str, Any]
    confidence_score: float
    processing_time: float
    model_metadata: Dict[str, Any]
    warnings: List[str]
    errors: List[str]

@dataclass
class SpecializedModelConfig:
    """Configuration for specialized medical models."""
    model_name: str
    model_type: str    # "classification", "segmentation", "generation", "extraction"
    input_format: str  # "ecg_signal", "dicom_image", "clinical_text", "lab_text"
    output_schema: str  # Schema name for output validation
    preprocessing_required: bool
    gpu_memory_mb: Optional[int]
    timeout_seconds: int
    fallback_models: List[str]

class SpecializedModelRouter:
    """Routes structured medical data to specialized AI models."""

    def __init__(self, model_loader: Optional[MedicalModelLoader] = None):
        self.model_loader = model_loader or MedicalModelLoader()
        self.model_configs = self._initialize_model_configs()
        self.model_cache = {}  # Reserved for memoizing loaded models; not yet populated in this module
        self.inference_stats = {
            "total_inferences": 0,
            "successful_inferences": 0,
            "average_processing_time": 0.0,
            "model_usage_counts": {},
            "error_counts": {}
        }
        logger.info("Specialized Model Router initialized")

    def _initialize_model_configs(self) -> Dict[str, SpecializedModelConfig]:
        """Initialize configuration for specialized medical models."""
        return {
            # ECG models
            "hubert_ecg": SpecializedModelConfig(
                # The repository path was garbled in the source; substitute the
                # actual HuBERT-ECG checkpoint identifier here.
                model_name="HuBERT-ECG",
                model_type="classification",
                input_format="ecg_signal",
                output_schema="ECGAnalysis",
                preprocessing_required=True,
                gpu_memory_mb=4096,
                timeout_seconds=30,
                fallback_models=["bio_clinicalbert"]
            ),
            # Radiology models
            "monai_unetr": SpecializedModelConfig(
                model_name="monai/UNet",  # Will be loaded from local or remote
                model_type="segmentation",
                input_format="dicom_image",
                output_schema="RadiologyAnalysis",
                preprocessing_required=True,
                gpu_memory_mb=8192,
                timeout_seconds=60,
                fallback_models=["generic_segmentation"]
            ),
            # Clinical text models
            "medgemma": SpecializedModelConfig(
                model_name="google/medgemma-4b",  # Placeholder for the actual MedGemma model
                model_type="generation",
                input_format="clinical_text",
                output_schema="ClinicalNotesAnalysis",
                preprocessing_required=True,
                gpu_memory_mb=16384,
                timeout_seconds=45,
                fallback_models=["bio_clinicalbert", "pubmedbert"]
            ),
            # Laboratory models
            "biomedical_ner": SpecializedModelConfig(
                model_name="Clinical-AI-Apollo/BiomedNLP-PubMedBERT-base-uncased-abstract",
                model_type="extraction",
                input_format="lab_text",
                output_schema="LaboratoryResults",
                preprocessing_required=False,
                gpu_memory_mb=2048,
                timeout_seconds=20,
                fallback_models=["scibert"]
            ),
            # Generic fallback models
            "bio_clinicalbert": SpecializedModelConfig(
                model_name="emilyalsentzer/Bio_ClinicalBERT",
                model_type="classification",
                input_format="clinical_text",
                output_schema="ClinicalNotesAnalysis",
                preprocessing_required=False,
                gpu_memory_mb=1024,
                timeout_seconds=15,
                fallback_models=[]
            ),
            "pubmedbert": SpecializedModelConfig(
                model_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
                model_type="classification",
                input_format="clinical_text",
                output_schema="ClinicalNotesAnalysis",
                preprocessing_required=False,
                gpu_memory_mb=1024,
                timeout_seconds=15,
                fallback_models=[]
            )
        }
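
    # To register a new specialized model, add another SpecializedModelConfig
    # entry above, keyed by a short router name. Note that the fallback_models
    # lists reference "generic_segmentation" and "scibert", which are not (yet)
    # defined in this registry; fallbacks must resolve to registered keys to be
    # usable.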

    async def route_and_infer(self, pipeline_result: ProcessingPipelineResult) -> ModelInferenceResult:
        """
        Route structured data to the appropriate specialized model and perform inference.

        Args:
            pipeline_result: Result from the preprocessing pipeline.

        Returns:
            ModelInferenceResult with model output and confidence.
        """
        start_time = time.time()
        try:
            # Step 1: Determine optimal model routing
            model_config = self._select_optimal_model(pipeline_result)

            # Step 2: Validate the input data format
            input_validation = self._validate_input_format(pipeline_result, model_config)
            if not input_validation["is_valid"]:
                logger.warning(f"Input validation failed: {input_validation['errors']}")
                return self._create_error_result(model_config.model_name, input_validation["errors"])

            # Step 3: Preprocess input data for the model
            preprocessed_input = await self._preprocess_for_model(pipeline_result, model_config)

            # Step 4: Perform model inference
            inference_result = await self._perform_model_inference(preprocessed_input, model_config)

            # Step 5: Post-process and validate the output
            final_output = self._postprocess_model_output(inference_result, model_config)

            # Step 6: Calculate the confidence score
            confidence_score = self._calculate_model_confidence(
                pipeline_result, model_config, final_output
            )

            processing_time = time.time() - start_time

            # Update statistics
            self._update_inference_stats(model_config.model_name, True, processing_time)

            return ModelInferenceResult(
                model_name=model_config.model_name,
                input_data=preprocessed_input,
                output_data=final_output,
                confidence_score=confidence_score,
                processing_time=processing_time,
                model_metadata={
                    "model_config": model_config.__dict__,
                    "input_validation": input_validation,
                    "pipeline_confidence": pipeline_result.validation_result.compliance_score
                },
                warnings=[],
                errors=[]
            )
        except Exception as e:
            logger.error(f"Model routing/inference error: {str(e)}")

            # Try a fallback model
            fallback_result = await self._try_fallback_model(pipeline_result)
            if fallback_result:
                return fallback_result

            # Return an error result
            error_result = ModelInferenceResult(
                model_name="error",
                input_data={},
                output_data={"error": str(e)},
                confidence_score=0.0,
                processing_time=time.time() - start_time,
                model_metadata={"error": str(e)},
                warnings=[],
                errors=[str(e)]
            )
            self._update_inference_stats("error", False, time.time() - start_time)
            return error_result

    def _select_optimal_model(self, pipeline_result: ProcessingPipelineResult) -> SpecializedModelConfig:
        """Select the optimal model based on data type and quality."""
        # Extract the document type from the pipeline result; normalize case so
        # "ECG"/"ecg" variants of the file type value match consistently.
        doc_type = "unknown"
        confidence = pipeline_result.validation_result.compliance_score
        file_type = pipeline_result.file_detection.file_type.value.lower()

        if "ecg" in file_type:
            doc_type = "ecg"
        elif "radiology" in file_type:
            doc_type = "radiology"
        elif "laboratory" in file_type:
            doc_type = "laboratory"
        elif "clinical" in file_type:
            doc_type = "clinical"

        # Model selection logic
        if doc_type == "ecg" and confidence > 0.8:
            return self.model_configs["hubert_ecg"]
        elif doc_type == "radiology" and confidence > 0.7:
            return self.model_configs["monai_unetr"]
        elif doc_type == "clinical" and confidence > 0.6:
            return self.model_configs["medgemma"]
        elif doc_type == "laboratory":
            return self.model_configs["biomedical_ner"]
        else:
            # Use a general biomedical model for low-confidence or unknown types
            return self.model_configs["bio_clinicalbert"]

    def _validate_input_format(self, pipeline_result: ProcessingPipelineResult,
                               model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Validate the input data format for the selected model."""
        validation_result = {
            "is_valid": True,
            "errors": [],
            "warnings": [],
            "input_checks": {}
        }
        try:
            # Check required fields based on the input format
            if model_config.input_format == "ecg_signal":
                validation_result["input_checks"] = self._validate_ecg_input(pipeline_result)
            elif model_config.input_format == "dicom_image":
                validation_result["input_checks"] = self._validate_dicom_input(pipeline_result)
            elif model_config.input_format in ["clinical_text", "lab_text"]:
                validation_result["input_checks"] = self._validate_text_input(pipeline_result)

            # Apply validation rules
            for check_name, check_result in validation_result["input_checks"].items():
                if not check_result["passed"]:
                    validation_result["is_valid"] = False
                    validation_result["errors"].append(f"{check_name}: {check_result['error']}")
        except Exception as e:
            validation_result["is_valid"] = False
            validation_result["errors"].append(f"Validation error: {str(e)}")

        return validation_result

    def _validate_ecg_input(self, pipeline_result: ProcessingPipelineResult) -> Dict[str, Any]:
        """Validate the ECG signal input format."""
        checks = {}

        # Check whether we have signal data
        if hasattr(pipeline_result.extraction_result, 'signal_data'):
            signal_data = pipeline_result.extraction_result.signal_data
            checks["has_signal_data"] = {
                "passed": bool(signal_data),
                "error": "No ECG signal data found" if not signal_data else None
            }

            # Check the sampling rate
            if hasattr(pipeline_result.extraction_result, 'sampling_rate'):
                sampling_rate = pipeline_result.extraction_result.sampling_rate
                checks["adequate_sampling_rate"] = {
                    "passed": sampling_rate >= 250,  # Minimum 250 Hz for ECG
                    "error": f"Sampling rate {sampling_rate} Hz too low for ECG analysis" if sampling_rate < 250 else None
                }

            # Check the signal duration
            if hasattr(pipeline_result.extraction_result, 'duration'):
                duration = pipeline_result.extraction_result.duration
                checks["adequate_duration"] = {
                    "passed": duration >= 5.0,  # Minimum 5 seconds
                    "error": f"Signal duration {duration:.1f}s too short for analysis" if duration < 5.0 else None
                }
        else:
            checks["has_signal_data"] = {
                "passed": False,
                "error": "Extraction result does not contain ECG signal data"
            }

        return checks

    def _validate_dicom_input(self, pipeline_result: ProcessingPipelineResult) -> Dict[str, Any]:
        """Validate the DICOM image input format."""
        checks = {}

        if hasattr(pipeline_result.extraction_result, 'image_data'):
            image_data = pipeline_result.extraction_result.image_data
            # Guard against a missing array before touching .size
            has_pixels = image_data is not None and image_data.size > 0
            checks["has_image_data"] = {
                "passed": has_pixels,
                "error": "No image data found" if not has_pixels else None
            }

            # Check the image dimensions
            if has_pixels:
                checks["adequate_resolution"] = {
                    "passed": min(image_data.shape) >= 64,
                    "error": f"Image resolution too low: {image_data.shape}" if min(image_data.shape) < 64 else None
                }
        else:
            checks["has_image_data"] = {
                "passed": False,
                "error": "Extraction result does not contain DICOM image data"
            }

        return checks

    def _validate_text_input(self, pipeline_result: ProcessingPipelineResult) -> Dict[str, Any]:
        """Validate the text input format."""
        checks = {}

        # Check for text content
        if hasattr(pipeline_result.extraction_result, 'raw_text'):
            text = pipeline_result.extraction_result.raw_text
            checks["has_text_content"] = {
                "passed": bool(text and len(text.strip()) > 50),
                "error": "Insufficient text content for analysis" if not text or len(text.strip()) <= 50 else None
            }
        else:
            checks["has_text_content"] = {
                "passed": False,
                "error": "No text content found in extraction result"
            }

        return checks

    async def _preprocess_for_model(self, pipeline_result: ProcessingPipelineResult,
                                    model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Preprocess input data for model-specific requirements."""
        if not model_config.preprocessing_required:
            # Return structured data as-is for models that do not need preprocessing
            return {
                "raw_data": pipeline_result.structured_data,
                "metadata": pipeline_result.pipeline_metadata,
                "validation_result": pipeline_result.validation_result
            }

        try:
            if model_config.input_format == "ecg_signal":
                return await self._preprocess_ecg_signal(pipeline_result, model_config)
            elif model_config.input_format == "dicom_image":
                return await self._preprocess_dicom_image(pipeline_result, model_config)
            elif model_config.input_format in ["clinical_text", "lab_text"]:
                return await self._preprocess_clinical_text(pipeline_result, model_config)
            else:
                return {"raw_data": pipeline_result.structured_data}
        except Exception as e:
            logger.error(f"Preprocessing error: {str(e)}")
            return {"raw_data": pipeline_result.structured_data, "preprocessing_error": str(e)}

    async def _preprocess_ecg_signal(self, pipeline_result: ProcessingPipelineResult,
                                     model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Preprocess ECG signal data for the HuBERT-ECG model."""
        extraction_result = pipeline_result.extraction_result

        # Prepare the ECG signal in the format expected by HuBERT-ECG
        ecg_input = {
            "signals": extraction_result.signal_data,
            "sampling_rate": extraction_result.sampling_rate,
            "duration": extraction_result.duration,
            "leads": extraction_result.lead_names
        }

        # Add preprocessing metadata
        preprocessing_metadata = {
            "original_sampling_rate": extraction_result.sampling_rate,
            "resampled": False,  # Would implement resampling if needed
            "filtered": True,  # Assumes the signal was already filtered upstream
            "segment_length_seconds": min(10.0, extraction_result.duration)  # Use up to 10 seconds
        }

        return {
            "ecg_data": ecg_input,
            "preprocessing_metadata": preprocessing_metadata,
            "model_ready": True
        }

    async def _preprocess_dicom_image(self, pipeline_result: ProcessingPipelineResult,
                                      model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Preprocess DICOM image data for MONAI UNETR."""
        extraction_result = pipeline_result.extraction_result

        # Prepare the image data for MONAI
        image_input = {
            "image_array": extraction_result.image_data,
            "spacing": extraction_result.pixel_spacing,
            "modality": extraction_result.modality,
            "body_part": extraction_result.body_part
        }

        # Add preprocessing metadata
        preprocessing_metadata = {
            "window_level": self._get_window_settings(extraction_result.modality),
            "normalized": True,
            "resized": False,  # Would implement resizing if needed
            "channels_added": True  # MONAI expects a channel dimension
        }

        return {
            "dicom_data": image_input,
            "preprocessing_metadata": preprocessing_metadata,
            "model_ready": True
        }

    async def _preprocess_clinical_text(self, pipeline_result: ProcessingPipelineResult,
                                        model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Preprocess clinical text for MedGemma or biomedical models."""
        extraction_result = pipeline_result.extraction_result

        # Extract the text content
        if hasattr(extraction_result, 'raw_text'):
            text_content = extraction_result.raw_text
        elif hasattr(extraction_result, 'structured_data'):
            text_content = str(extraction_result.structured_data)
        else:
            text_content = str(pipeline_result.structured_data)

        # Prepare the text for the model
        text_input = {
            "raw_text": text_content,
            "document_type": pipeline_result.file_detection.file_type.value,
            "deidentified": pipeline_result.deidentification_result is not None
        }

        # Add preprocessing metadata
        preprocessing_metadata = {
            "tokenized": False,  # Tokenization is handled by the model
            "max_length": 512,  # Typical maximum sequence length
            "language": "en",
            "medical_domain": self._extract_medical_domain(pipeline_result)
        }

        return {
            "text_data": text_input,
            "preprocessing_metadata": preprocessing_metadata,
            "model_ready": True
        }

    def _get_window_settings(self, modality: str) -> Dict[str, float]:
        """Get appropriate window settings for medical imaging."""
        window_configs = {
            "CT": {"level": 40, "width": 400},  # Soft-tissue window (a lung window would be near level -600, width 1500)
            "MRI": {"level": 0, "width": 500},  # Brain window
            "XRAY": {"level": 0, "width": 1000}  # General window
        }
        return window_configs.get(modality, {"level": 0, "width": 500})
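
    # Note on window settings: "level" is the window center and "width" the
    # span, so a viewer clips pixel values to [level - width/2, level + width/2].
    # The CT soft-tissue window above therefore maps to roughly [-160, 240] HU.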

    def _extract_medical_domain(self, pipeline_result: ProcessingPipelineResult) -> str:
        """Extract the medical domain from the pipeline result."""
        # Normalize case so the substring checks below match all variants
        file_type = pipeline_result.file_detection.file_type.value.lower()
        if "ecg" in file_type:
            return "cardiology"
        elif "radiology" in file_type:
            return "radiology"
        elif "laboratory" in file_type:
            return "laboratory"
        elif "clinical" in file_type:
            return "clinical"
        else:
            return "general"

    async def _perform_model_inference(self, preprocessed_input: Dict[str, Any],
                                       model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Perform inference using the specialized model."""
        try:
            if model_config.model_type == "classification":
                return await self._perform_classification_inference(preprocessed_input, model_config)
            elif model_config.model_type == "segmentation":
                return await self._perform_segmentation_inference(preprocessed_input, model_config)
            elif model_config.model_type == "generation":
                return await self._perform_generation_inference(preprocessed_input, model_config)
            elif model_config.model_type == "extraction":
                return await self._perform_extraction_inference(preprocessed_input, model_config)
            else:
                raise ValueError(f"Unsupported model type: {model_config.model_type}")
        except Exception as e:
            logger.error(f"Model inference error: {str(e)}")
            raise

    async def _perform_classification_inference(self, preprocessed_input: Dict[str, Any],
                                                model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Perform classification inference (e.g., ECG rhythm classification)."""
        # Use the existing model loader for classification tasks
        model_key = "bio_clinicalbert"  # Use the biomedical model for now
        try:
            # Prepare the input for the model
            if "ecg_data" in preprocessed_input:
                # ECG classification
                ecg_data = preprocessed_input["ecg_data"]
                text_input = f"ECG Analysis: {len(ecg_data['signals'])} leads, {ecg_data['duration']:.1f}s duration"
            else:
                text_input = preprocessed_input.get("text_data", {}).get("raw_text", "")

            # Perform inference using the model loader
            result = await self.model_loader.run_inference(
                model_key,
                text_input,
                {"max_new_tokens": 200, "task": "classification"}
            )

            return {
                "model_output": result,
                "classification_type": "medical_document_classification",
                "confidence": 0.8  # Default confidence
            }
        except Exception as e:
            logger.error(f"Classification inference error: {str(e)}")
            raise

    async def _perform_segmentation_inference(self, preprocessed_input: Dict[str, Any],
                                              model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Perform segmentation inference (e.g., organ segmentation in medical images)."""
        try:
            dicom_data = preprocessed_input["dicom_data"]
            image_array = dicom_data["image_array"]
            modality = dicom_data["modality"]

            # Placeholder segmentation result; a real implementation would use MONAI UNETR
            segmentation_result = {
                "segmentation_mask": np.random.rand(*image_array.shape) > 0.7,  # Placeholder
                "organ_detected": f"{modality.lower()}_tissue",
                "volume_estimate_ml": np.prod(image_array.shape) * 0.001,  # Placeholder
                "confidence": 0.75
            }

            return {
                "model_output": segmentation_result,
                "segmentation_type": f"{modality}_segmentation"
            }
        except Exception as e:
            logger.error(f"Segmentation inference error: {str(e)}")
            raise

    async def _perform_generation_inference(self, preprocessed_input: Dict[str, Any],
                                            model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Perform text generation inference (e.g., clinical summary generation)."""
        try:
            text_data = preprocessed_input["text_data"]
            raw_text = text_data["raw_text"]

            # Use the biomedical model for text generation
            model_key = "bio_clinicalbert"

            # Prepare the generation prompt
            prompt = f"Analyze the following medical text and provide a structured summary:\n\n{raw_text}"

            # Perform inference
            result = await self.model_loader.run_inference(
                model_key,
                prompt,
                {"max_new_tokens": 300, "task": "generation"}
            )

            return {
                "model_output": result,
                "generation_type": "clinical_summary",
                "original_length": len(raw_text),
                "generated_length": len(str(result))
            }
        except Exception as e:
            logger.error(f"Generation inference error: {str(e)}")
            raise

    async def _perform_extraction_inference(self, preprocessed_input: Dict[str, Any],
                                            model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Perform extraction inference (e.g., lab value extraction)."""
        try:
            text_data = preprocessed_input["text_data"]
            raw_text = text_data["raw_text"]

            # Use the biomedical NER model for extraction. Note: this key is
            # presumably the model loader's registry key and is distinct from
            # this router's config key "biomedical_ner".
            model_key = "biomedical_ner_all"

            # Perform NER extraction
            result = await self.model_loader.run_inference(
                model_key,
                raw_text,
                {"task": "ner", "aggregation_strategy": "simple"}
            )

            return {
                "model_output": result,
                "extraction_type": "medical_entities",
                "entities_found": len(result) if isinstance(result, list) else 0
            }
        except Exception as e:
            logger.error(f"Extraction inference error: {str(e)}")
            raise

    def _postprocess_model_output(self, inference_result: Dict[str, Any],
                                  model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Post-process model output to match the expected schema."""
        try:
            model_output = inference_result["model_output"]

            # Convert to the appropriate schema format
            if model_config.output_schema == "ECGAnalysis":
                return self._convert_to_ecg_schema(model_output, inference_result)
            elif model_config.output_schema == "RadiologyAnalysis":
                return self._convert_to_radiology_schema(model_output, inference_result)
            elif model_config.output_schema == "LaboratoryResults":
                return self._convert_to_laboratory_schema(model_output, inference_result)
            elif model_config.output_schema == "ClinicalNotesAnalysis":
                return self._convert_to_clinical_notes_schema(model_output, inference_result)
            else:
                return {"model_output": model_output, "schema": "generic"}
        except Exception as e:
            logger.error(f"Post-processing error: {str(e)}")
            return {"model_output": inference_result.get("model_output", {}), "error": str(e)}

    def _convert_to_ecg_schema(self, model_output: Any, inference_result: Dict[str, Any]) -> Dict[str, Any]:
        """Convert model output to the ECG schema format."""
        # This would convert model-specific ECG output to the canonical ECGAnalysis schema
        return {
            "model_output": model_output,
            "schema": "ECGAnalysis",
            "postprocessed": True
        }

    def _convert_to_radiology_schema(self, model_output: Any, inference_result: Dict[str, Any]) -> Dict[str, Any]:
        """Convert model output to the radiology schema format."""
        return {
            "model_output": model_output,
            "schema": "RadiologyAnalysis",
            "postprocessed": True
        }

    def _convert_to_laboratory_schema(self, model_output: Any, inference_result: Dict[str, Any]) -> Dict[str, Any]:
        """Convert model output to the laboratory schema format."""
        return {
            "model_output": model_output,
            "schema": "LaboratoryResults",
            "postprocessed": True
        }

    def _convert_to_clinical_notes_schema(self, model_output: Any, inference_result: Dict[str, Any]) -> Dict[str, Any]:
        """Convert model output to the clinical notes schema format."""
        return {
            "model_output": model_output,
            "schema": "ClinicalNotesAnalysis",
            "postprocessed": True
        }

    def _calculate_model_confidence(self, pipeline_result: ProcessingPipelineResult,
                                    model_config: SpecializedModelConfig,
                                    model_output: Dict[str, Any]) -> float:
        """Calculate a confidence score for the model inference."""
        try:
            # Base confidence from the pipeline
            pipeline_confidence = pipeline_result.validation_result.compliance_score

            # Model-specific confidence adjustments
            model_confidence = 0.8  # Default high confidence for specialized models

            # Adjust based on model type
            if model_config.model_type == "classification":
                model_confidence = 0.85
            elif model_config.model_type == "segmentation":
                model_confidence = 0.80
            elif model_config.model_type == "generation":
                model_confidence = 0.75
            elif model_config.model_type == "extraction":
                model_confidence = 0.90

            # Check for model output quality
            if "error" in model_output:
                model_confidence *= 0.3  # Reduce confidence for error outputs

            # Calculate the weighted confidence
            overall_confidence = (0.4 * pipeline_confidence + 0.6 * model_confidence)
            return min(1.0, max(0.0, overall_confidence))
        except Exception as e:
            logger.error(f"Confidence calculation error: {str(e)}")
            return 0.5
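
    # Worked example of the weighting above: a classification model score of
    # 0.85 combined with a pipeline compliance score of 0.7 yields
    # 0.4 * 0.7 + 0.6 * 0.85 = 0.79 before clamping to [0.0, 1.0].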

    async def _try_fallback_model(self, pipeline_result: ProcessingPipelineResult) -> Optional[ModelInferenceResult]:
        """Try a fallback model when the primary model fails."""
        try:
            # Use the generic biomedical model as a fallback
            fallback_config = self.model_configs["bio_clinicalbert"]

            # Prepare generic text input
            text_input = str(pipeline_result.structured_data)

            # Perform inference with the fallback
            result = await self.model_loader.run_inference(
                "bio_clinicalbert",
                text_input[:1000],  # Limit the text length
                {"max_new_tokens": 150, "task": "general"}
            )

            return ModelInferenceResult(
                model_name="fallback_bio_clinicalbert",
                input_data={"fallback_text": text_input[:1000]},
                output_data={"model_output": result, "fallback_used": True},
                confidence_score=0.4,  # Lower confidence for the fallback
                processing_time=0.0,
                model_metadata={"fallback_reason": "primary_model_failed"},
                warnings=["Used fallback model due to primary model failure"],
                errors=[]
            )
        except Exception as e:
            logger.error(f"Fallback model error: {str(e)}")
            return None

    def _create_error_result(self, model_name: str, errors: List[str]) -> ModelInferenceResult:
        """Create an error result for a failed inference."""
        return ModelInferenceResult(
            model_name=model_name,
            input_data={},
            output_data={"error": "Input validation failed"},
            confidence_score=0.0,
            processing_time=0.0,
            model_metadata={"validation_errors": errors},
            warnings=[],
            errors=errors
        )

    def _update_inference_stats(self, model_name: str, success: bool, processing_time: float):
        """Update inference statistics."""
        self.inference_stats["total_inferences"] += 1
        if success:
            self.inference_stats["successful_inferences"] += 1

        # Update the running average of processing time
        total_time = self.inference_stats["average_processing_time"] * (self.inference_stats["total_inferences"] - 1)
        self.inference_stats["average_processing_time"] = (total_time + processing_time) / self.inference_stats["total_inferences"]

        # Update usage counts
        self.inference_stats["model_usage_counts"][model_name] = self.inference_stats["model_usage_counts"].get(model_name, 0) + 1

        if not success:
            error_type = "inference_failure"
            self.inference_stats["error_counts"][error_type] = self.inference_stats["error_counts"].get(error_type, 0) + 1

    def get_inference_statistics(self) -> Dict[str, Any]:
        """Get comprehensive inference statistics."""
        return {
            "total_inferences": self.inference_stats["total_inferences"],
            "success_rate": self.inference_stats["successful_inferences"] / max(self.inference_stats["total_inferences"], 1),
            "average_processing_time": self.inference_stats["average_processing_time"],
            "model_usage_breakdown": self.inference_stats["model_usage_counts"],
            "error_breakdown": self.inference_stats["error_counts"],
            "router_health": "healthy" if self.inference_stats["successful_inferences"] > self.inference_stats["total_inferences"] * 0.8 else "degraded"
        }

# Export main classes
__all__ = [
    "SpecializedModelRouter",
    "ModelInferenceResult",
    "SpecializedModelConfig"
]