"""Gradio app that analyses IOL biometry report images with a Qwen2-VL model."""

import gc
import math
import os
import tempfile
from datetime import datetime

import gradio as gr
import spaces  # Keep ahead of torch/transformers: ZeroGPU must be imported before CUDA is touched.
import cv2
import numpy as np
import torch
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration


def check_dependencies():
    """Raise ImportError naming every required third-party package that cannot be imported.

    NOTE: the module-level imports above already fail if one of these packages is
    missing, so in practice this acts only as a sanity check.
    """
    required_packages = {
        "torch": "torch",
        "gradio": "gradio",
        "transformers": "transformers",
        "PIL": "pillow",
        "cv2": "opencv-python",
        "numpy": "numpy",
    }
    missing = []
    for module, package in required_packages.items():
        try:
            __import__(module)
        except ImportError:
            missing.append(package)
    if missing:
        raise ImportError(
            f"Missing required packages: {', '.join(missing)}. "
            f"Please install with: pip install {' '.join(missing)}"
        )


check_dependencies()

DEFAULT_MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"
TEMP_DIR = "temp"

# Biometry thresholds shared by validate_measurements() and detect_case_type()
# so the risk factors and the prompt's special considerations always agree.
SHORT_EYE_AL_MM = 22.0
LONG_EYE_AL_MM = 26.0
TORIC_ASTIGMATISM_D = 2.0
K_DIFFERENCE_WARNING_D = 3.0


def _to_rgb_uint8(image_array):
    """Return *image_array* as an ``(H, W, 3)`` uint8 RGB array.

    Float arrays whose maximum is <= 1.0 are treated as normalised and scaled to
    0-255; all values are clipped to 0-255. Grayscale images (2-D or a single
    channel) are replicated to three channels and an alpha channel is dropped.
    """
    arr = np.asarray(image_array)
    if np.issubdtype(arr.dtype, np.floating):
        arr = np.nan_to_num(arr)
        if arr.size and arr.max() <= 1.0:
            arr = arr * 255.0
    arr = np.clip(arr, 0, 255).astype(np.uint8)
    if arr.ndim == 3 and arr.shape[2] == 1:
        arr = arr[:, :, 0]
    if arr.ndim == 2:
        arr = np.stack([arr] * 3, axis=-1)
    elif arr.ndim == 3 and arr.shape[2] == 4:
        arr = arr[:, :, :3]
    return arr


class TempImageFile:
    """Context manager that writes an image array to a PNG file and deletes it on exit.

    ``__enter__`` returns the absolute path of the written file.
    """

    def __init__(self, image_array, temp_dir=TEMP_DIR):
        self.image_array = image_array
        self.temp_dir = temp_dir
        self.path = None

    def __enter__(self):
        if self.image_array is None:
            raise ValueError("No image provided. Please upload an image before submitting.")

        img = Image.fromarray(_to_rgb_uint8(self.image_array))
        os.makedirs(self.temp_dir, exist_ok=True)

        # mkstemp guarantees a unique name; a second-resolution timestamp alone
        # collides when two requests arrive within the same second.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        fd, self.path = tempfile.mkstemp(
            prefix=f"image_{timestamp}_", suffix=".png", dir=self.temp_dir
        )
        os.close(fd)
        try:
            img.save(self.path)
        except Exception:
            # __exit__ is not called when __enter__ raises, so clean up here.
            os.remove(self.path)
            self.path = None
            raise
        return os.path.abspath(self.path)

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            if self.path and os.path.exists(self.path):
                os.remove(self.path)
            # The shared temp directory is intentionally left in place: removing it
            # here could race with another request creating its own file in it.
        except OSError as e:
            print(f"Error cleaning up temporary file: {e}")
        finally:
            # Always clear references
            self.image_array = None
            self.path = None
        return False  # Never suppress an exception raised inside the with-block.


class ModelManager:
    """Lazily load and cache Qwen2-VL models and processors, keyed by model id."""

    def __init__(self):
        self._models = {}
        self._processors = {}
        self.available_models = [DEFAULT_MODEL_ID]

    def get_model(self, model_id):
        """Return the cached model for *model_id*, loading it in float16 on first use.

        Raises:
            RuntimeError: if loading fails.
        """
        if model_id not in self._models:
            try:
                self._models[model_id] = Qwen2VLForConditionalGeneration.from_pretrained(
                    model_id,
                    torch_dtype=torch.float16,
                    device_map="auto",
                ).eval()
            except Exception as e:
                raise RuntimeError(f"Failed to load model {model_id}: {str(e)}") from e
        return self._models[model_id]

    def get_processor(self, model_id):
        """Return the cached processor for *model_id*, loading it on first use.

        Raises:
            RuntimeError: if loading fails.
        """
        if model_id not in self._processors:
            try:
                self._processors[model_id] = AutoProcessor.from_pretrained(
                    model_id, trust_remote_code=True
                )
            except Exception as e:
                raise RuntimeError(f"Failed to load processor {model_id}: {str(e)}") from e
        return self._processors[model_id]

    def unload_model(self, model_id):
        """Drop the cached model for *model_id* (if any) and release cached GPU memory."""
        if self._models.pop(model_id, None) is not None:
            # Collect first so the freed tensors are actually unreferenced before
            # the CUDA caching allocator is asked to release its blocks.
            gc.collect()
            torch.cuda.empty_cache()

    def cleanup_unused_models(self, keep_model_id=None):
        """Unload every cached model except *keep_model_id*."""
        for model_id in list(self._models):
            if model_id != keep_model_id:
                self.unload_model(model_id)

    def cleanup_all(self):
        """Best-effort release of every cached model and processor."""
        try:
            # Dropping the references is enough to free the weights; moving a model
            # dispatched with device_map="auto" to CPU can itself raise.
            self._models.clear()
            self._processors.clear()
            gc.collect()
            torch.cuda.empty_cache()
        except Exception as e:
            print(f"Error during cleanup: {e}")


model_manager = ModelManager()

# Default IOL settings; validate_settings() merges user values over these.
DEFAULT_SETTINGS = {
    "manufacturer": "Alcon",
    "lens_model": "SN60WF",
    "a_constant": 118.7,
    "target_refraction": 0.0,
    "surgeon_factor": 1.7,
    "personalization_factor": 0.0,
    "anterior_chamber_constant": 5.63,
    "vertex_distance": 12.0,
}

# Accepted (inclusive) ranges for numeric settings; out-of-range values keep the default.
SETTING_RANGES = {
    "a_constant": (115, 122),
    "target_refraction": (-10, 10),
}

# Candidate lenses for specific case types.
LENS_MODELS = {
    "standard": {
        "model": "SN60WF",
        "a_constant": 118.7,
    },
    "toric": {
        "model": "SN6AT4",
        "a_constant": 119.0,
    },
    "short_eye": {
        "model": "MA60AC",
        "a_constant": 118.9,
    },
}

# Expected (inclusive) ranges used to flag implausible measurements.
MEASUREMENT_RANGES = {
    "axial_length": (20.0, 30.0),  # mm
    "keratometry": (35.0, 50.0),  # D
    "acd": (2.0, 5.0),  # mm
    "lens_thickness": (3.0, 6.0),  # mm
    "wtw": (10.0, 13.0),  # mm
    "cct": (450, 650),  # μm
    "astigmatism": (0.0, 6.0),  # D
}

# Which MEASUREMENT_RANGES entry applies to each measurement key.
_MEASUREMENT_RANGE_KEYS = {
    "axial_length": "axial_length",
    "k1": "keratometry",
    "k2": "keratometry",
    "mean_k": "keratometry",
    "acd": "acd",
    "lens_thickness": "lens_thickness",
    "wtw": "wtw",
    "cct": "cct",
    "astigmatism": "astigmatism",
}

# Keyed by exception class __name__, which is what format_error() looks up.
ERROR_MESSAGES = {
    "ValueError": "⚠️ Invalid Input",
    "RuntimeError": "🔧 Processing Error",
    "OutOfMemoryError": "💾 Memory Error",
    "Exception": "❌ Unexpected Error",
}


def format_error(error_type, message):
    """Prefix *message* with the user-facing label for *error_type* (a class or an instance)."""
    if not isinstance(error_type, type):
        error_type = type(error_type)
    return f"{ERROR_MESSAGES.get(error_type.__name__, '❌ Error')}: {message}"


def validate_image(image_array):
    """Check that *image_array* is a usable image array and return it unchanged.

    Raises:
        ValueError: on a missing, empty, wrongly typed/shaped, or oversized image.
    """
    if image_array is None:
        raise ValueError("No image provided")
    if not isinstance(image_array, np.ndarray):
        raise ValueError("Invalid image format: must be numpy array")
    if not image_array.size > 0:
        raise ValueError("Empty image array")
    if image_array.dtype not in [np.uint8, np.float32]:
        raise ValueError(f"Unsupported image data type: {image_array.dtype}")
    if len(image_array.shape) not in [2, 3]:
        raise ValueError("Invalid image dimensions")
    if len(image_array.shape) == 3 and image_array.shape[2] not in [1, 3, 4]:
        raise ValueError("Invalid number of channels")
    if image_array.size > 10 * 1024 * 1024:  # element count; equals bytes for uint8
        raise ValueError("Image size too large")
    height, width = image_array.shape[:2]
    if height > 4000 or width > 4000:
        raise ValueError("Image dimensions too large")
    return image_array


def cleanup_temp_files(temp_dir=TEMP_DIR, max_age_hours=24):
    """Delete regular files in *temp_dir* last modified more than *max_age_hours* ago.

    Best effort: errors are printed, and a failure on one file does not stop the sweep.
    """
    if not os.path.isdir(temp_dir):
        return
    try:
        filenames = os.listdir(temp_dir)
    except OSError as e:
        print(f"Error during cleanup: {e}")
        return
    cutoff = datetime.now().timestamp() - max_age_hours * 3600
    for filename in filenames:
        filepath = os.path.join(temp_dir, filename)
        try:
            if os.path.isfile(filepath) and os.path.getmtime(filepath) < cutoff:
                os.remove(filepath)
        except OSError as e:
            print(f"Error during cleanup: {e}")


def normalize_image_size(image, max_size=1024):
    """Downscale a PIL image so neither side exceeds *max_size*, keeping aspect ratio."""
    width, height = image.size
    if width > max_size or height > max_size:
        ratio = min(max_size / width, max_size / height)
        new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
        image = image.resize(new_size, Image.LANCZOS)
    return image


def extract_measurements(image_array):
    """Return a dict of biometry measurements for the report image.

    PLACEHOLDER: no OCR/vision extraction is implemented yet; this always returns
    the same fixed mock values regardless of *image_array*.
    """
    return {
        "axial_length": 23.5,
        "k1": 42.5,
        "k1_axis": 180,
        "k2": 43.5,
        "k2_axis": 90,
        "mean_k": 43.0,
        "astigmatism": 1.0,
        "acd": 3.2,
        "lens_thickness": 4.1,
        "wtw": 11.8,
        "cct": 540,
    }


def validate_measurements(measurements):
    """Check biometry values and collect warnings and clinical risk factors.

    Returns:
        ``(warnings, risk_factors)`` - two lists of human-readable strings.

    Raises:
        ValueError: if axial length, K1, K2 or ACD is missing or None.
    """
    critical_fields = {
        "axial_length": "Axial Length",
        "k1": "K1",
        "k2": "K2",
        "acd": "Anterior Chamber Depth",
    }
    critical_missing = [
        name for field, name in critical_fields.items() if measurements.get(field) is None
    ]
    if critical_missing:
        raise ValueError(f"Critical measurements missing: {', '.join(critical_missing)}")

    warnings = []
    risk_factors = []

    # Flag values outside the plausible ranges declared in MEASUREMENT_RANGES.
    for field, range_key in _MEASUREMENT_RANGE_KEYS.items():
        value = measurements.get(field)
        if value is None:
            continue
        low, high = MEASUREMENT_RANGES[range_key]
        if not low <= value <= high:
            warnings.append(f"{field} = {value} is outside the expected range {low}-{high}")

    al = measurements["axial_length"]
    if al < SHORT_EYE_AL_MM:
        risk_factors.append("Short eye - Consider special IOL options")
    elif al > LONG_EYE_AL_MM:
        risk_factors.append("Long eye - Higher risk of retinal complications")

    ast = measurements.get("astigmatism")
    if ast is not None and ast > TORIC_ASTIGMATISM_D:
        risk_factors.append("Significant astigmatism - Consider toric IOL")

    # Cross-validation: a large K1/K2 gap may indicate a measurement error.
    k_diff = abs(measurements["k1"] - measurements["k2"])
    if k_diff > K_DIFFERENCE_WARNING_D:
        warnings.append(f"Unusual keratometry difference ({k_diff:.1f}D)")

    return warnings, risk_factors


def validate_image_quality(image_array):
    """Score image quality from contrast, brightness and sharpness.

    Expects pixel values on a 0-255 scale. Accepts 2-D grayscale or 3-D arrays
    with 1, 3 (RGB) or 4 (RGBA) channels; anything array-like (e.g. a PIL image)
    is converted with ``np.asarray``.

    Returns:
        ``(score, feedback)`` - a weighted score and a comma-separated feedback string.
    """
    if image_array is None:
        return 0, "No image provided"
    image_array = np.asarray(image_array)

    if image_array.ndim == 3:
        channels = image_array.shape[2]
        if channels == 1:
            gray = image_array[:, :, 0]
        elif channels == 4:
            gray = cv2.cvtColor(image_array, cv2.COLOR_RGBA2GRAY)
        else:
            gray = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
    else:
        gray = image_array

    contrast = np.std(gray)
    brightness = np.mean(gray)
    # Variance of the Laplacian: low values indicate few sharp edges (blur).
    blur_score = np.var(cv2.Laplacian(gray, cv2.CV_64F))

    contrast_score = min(contrast / 50, 1.0)
    brightness_score = 1.0 - abs(brightness - 127.5) / 127.5
    blur_score_normalized = min(blur_score / 500, 1.0)

    feedback = []
    if contrast_score < 0.5:
        feedback.append("Low contrast detected")
    if brightness_score < 0.5:
        feedback.append("Sub-optimal brightness")
    if blur_score_normalized < 0.5:
        feedback.append("Image may be blurry")

    quality_score = contrast_score * 0.4 + brightness_score * 0.3 + blur_score_normalized * 0.3
    return quality_score, ", ".join(feedback) if feedback else "Good image quality"


def format_measurements_text(measurements):
    """Render the measurements as a fixed-width box-drawing table (missing values show N/A)."""
    rows = [
        ("Axial Length", "axial_length", "mm"),
        ("K1", "k1", "D"),
        ("K2", "k2", "D"),
        ("Mean K", "mean_k", "D"),
        ("Astigmatism", "astigmatism", "D"),
        ("ACD", "acd", "mm"),
        ("Lens Thickness", "lens_thickness", "mm"),
        ("WTW", "wtw", "mm"),
        ("CCT", "cct", "μm"),
    ]
    # Column widths between the vertical bars: 24 for labels, 14 for values.
    lines = [
        "EXTRACTED MEASUREMENTS:",
        "╔" + "═" * 24 + "╦" + "═" * 14 + "╗",
        f"║ {'Parameter':<22} ║ {'Value':<12} ║",
        "╠" + "═" * 24 + "╬" + "═" * 14 + "╣",
    ]
    for label, key, unit in rows:
        value = measurements.get(key, "N/A")
        lines.append(f"║ {label:<22} ║ {value!s:>9} {unit:<2} ║")
    lines.append("╚" + "═" * 24 + "╩" + "═" * 14 + "╝")
    return "\n" + "\n".join(lines) + "\n"


def detect_case_type(measurements):
    """Return ``(case_key, description)`` pairs for special cases found in *measurements*.

    Uses the same axial-length and astigmatism thresholds as validate_measurements().
    """
    cases = []

    al = measurements.get("axial_length")
    if al is not None:
        if al < SHORT_EYE_AL_MM:
            lens = LENS_MODELS["short_eye"]
            cases.append((
                "short_eye",
                f"Short eye (AL {al} mm < {SHORT_EYE_AL_MM} mm): formula choice is critical; "
                f"candidate lens {lens['model']} (A-constant {lens['a_constant']})",
            ))
        elif al > LONG_EYE_AL_MM:
            cases.append((
                "long_eye",
                f"Long eye (AL {al} mm > {LONG_EYE_AL_MM} mm): consider formula adjustments "
                f"for long axial length",
            ))

    ast = measurements.get("astigmatism")
    if ast is not None and ast > TORIC_ASTIGMATISM_D:
        lens = LENS_MODELS["toric"]
        cases.append((
            "toric",
            f"Significant astigmatism ({ast} D > {TORIC_ASTIGMATISM_D} D): consider a toric IOL "
            f"such as {lens['model']} (A-constant {lens['a_constant']})",
        ))

    return cases


def generate_analysis_prompt(measurements, settings, warnings=None, risk_factors=None,
                             quality_score=None, quality_feedback=None):
    """Build the model prompt for analysing an IOL biometry report image.

    Args:
        measurements: Measurement dict (see extract_measurements()).
        settings: Validated IOL settings dict (see validate_settings()).
        warnings: Optional measurement warnings to surface to the model.
        risk_factors: Optional clinical risk factors to surface to the model.
        quality_score: Optional image-quality score from validate_image_quality().
        quality_feedback: Optional image-quality feedback text.
    """
    case_types = detect_case_type(measurements)

    prompt = f"""Analyze this IOL biometry report image.

TASK:
1. Verify all visible measurements
2. Confirm measurement accuracy and consistency
3. Provide IOL power calculation using:
   - SRK/T formula
   - Barrett Universal II
   - Hill-RBF if applicable
   - Other formulas as appropriate for this case

CONSIDERATIONS:
1. Image Quality Assessment
2. Measurement Validation
3. Formula Selection Rationale
4. Risk Factor Analysis
{format_measurements_text(measurements)}
SELECTED IOL PARAMETERS:
• Manufacturer: {settings.get('manufacturer')}
• Model: {settings.get('lens_model')}
• A-Constant: {settings.get('a_constant')}
• Target: {settings.get('target_refraction', 0)} D

Please provide:
1. Measurement verification
2. IOL power calculations with multiple formulas
3. Specific recommendations for:
   - Optimal IOL power
   - Formula selection rationale
   - Surgical approach
4. Risk assessment and special considerations
5. Alternative options if needed"""

    if quality_score is not None:
        prompt += f"\n\nImage Quality Score: {quality_score:.2f}"
        if quality_feedback:
            prompt += f" ({quality_feedback})"

    if warnings:
        prompt += "\n\nMeasurement Warnings:"
        for warning in warnings:
            prompt += f"\n• {warning}"

    if risk_factors:
        prompt += "\n\nRisk Factors:"
        for risk in risk_factors:
            prompt += f"\n• {risk}"

    if case_types:
        prompt += "\n\nSpecial Considerations:"
        for _, description in case_types:
            prompt += f"\n• {description}"

    return prompt


def validate_settings(settings):
    """Merge *settings* over DEFAULT_SETTINGS, keeping the default for any invalid value.

    Unknown keys, ``None``, blank strings, non-numeric or non-finite values for
    numeric settings, and values outside SETTING_RANGES are ignored. Always
    returns a new dict, never DEFAULT_SETTINGS itself.
    """
    validated_settings = DEFAULT_SETTINGS.copy()
    if not isinstance(settings, dict):
        return validated_settings

    for key, value in settings.items():
        if key not in validated_settings or value is None:
            continue
        if isinstance(DEFAULT_SETTINGS[key], (int, float)):
            try:
                value = float(value)
            except (TypeError, ValueError):
                continue
            if not math.isfinite(value):
                continue
            bounds = SETTING_RANGES.get(key)
            if bounds is not None and not bounds[0] <= value <= bounds[1]:
                continue
        elif isinstance(value, str):
            value = value.strip()
            if not value:
                continue
        validated_settings[key] = value

    return validated_settings


def collect_settings(manufacturer, lens_model, a_constant, target_refraction):
    """Bundle the Settings-panel values into a validated settings dict."""
    return validate_settings({
        "manufacturer": manufacturer,
        "lens_model": lens_model,
        "a_constant": a_constant,
        "target_refraction": target_refraction,
    })


def _prepare_inputs(processor, model, image_path, prompt):
    """Build tokenised chat inputs for one image file plus a text prompt, on the model's device."""
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image_path},
            {"type": "text", "text": prompt},
        ],
    }]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    return processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)


def _generate(model, processor, inputs, max_new_tokens=1024):
    """Sample a completion and decode only the newly generated tokens."""
    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )
    # Strip the echoed prompt tokens from each output sequence.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    return processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]


@spaces.GPU
def run_example(image, model_id=DEFAULT_MODEL_ID, settings=None, progress=gr.Progress()):
    """Analyse an IOL biometry report image with a Qwen2-VL model.

    Args:
        image: Report image as a numpy array (as produced by ``gr.Image(type="numpy")``).
        model_id: Model id loaded through ``model_manager``.
        settings: Optional IOL settings dict, merged over DEFAULT_SETTINGS.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        The model's analysis text, or a human-readable error message. Failures are
        reported as text rather than raised so the UI can display them.
    """
    cleanup_temp_files()
    model = None
    try:
        progress(0, desc="Validating input...")
        if image is None:
            return "Error: No image provided"
        try:
            image = validate_image(image)
        except ValueError as e:
            return format_error(ValueError, str(e))
        settings = validate_settings(settings or {})

        progress(0.2, desc="Loading model...")
        try:
            model = model_manager.get_model(model_id)
            processor = model_manager.get_processor(model_id)
        except Exception as e:
            return f"Error loading model: {str(e)}"

        progress(0.4, desc="Processing image...")
        try:
            rgb_image = _to_rgb_uint8(image)
            # The downscaled copy is what the model actually sees, bounding its input size.
            model_image = np.asarray(normalize_image_size(Image.fromarray(rgb_image)))
        except Exception as e:
            return f"Error processing image: {str(e)}"

        try:
            measurements = extract_measurements(rgb_image)
            warnings, risk_factors = validate_measurements(measurements)
            # Quality is scored at full resolution; downscaling alters the blur metric.
            quality_score, quality_feedback = validate_image_quality(rgb_image)
            prompt = generate_analysis_prompt(
                measurements=measurements,
                settings=settings,
                warnings=warnings,
                risk_factors=risk_factors,
                quality_score=quality_score,
                quality_feedback=quality_feedback,
            )
        except Exception as e:
            return f"Error analyzing image: {str(e)}"

        with TempImageFile(model_image) as image_path:
            try:
                inputs = _prepare_inputs(processor, model, image_path, prompt)
            except Exception as e:
                return f"Error preparing model inputs: {str(e)}"

            try:
                progress(0.6, desc="Generating analysis...")
                output_text = _generate(model, processor, inputs)
            except torch.cuda.OutOfMemoryError:
                return "Error: GPU memory exhausted. Please try with a smaller image."
            except Exception as e:
                return f"Error during analysis: {str(e)}"

        progress(1.0, desc="Complete!")
        return output_text
    except Exception as e:
        return f"Unexpected error: {str(e)}"
    finally:
        try:
            if model is not None:
                model_manager.cleanup_unused_models(keep_model_id=model_id)
            torch.cuda.empty_cache()
            gc.collect()
        except Exception as e:
            print(f"Cleanup error: {e}")


DESCRIPTION = """# IOL Biometry Report Analyzer

Upload an IOL biometry report image and adjust the lens settings if needed.
The vision-language model verifies the measurements and discusses IOL power options.

**For research and demonstration only - not for clinical decision-making.**
"""

CSS = """
.gradio-container { max-width: 1200px !important; margin: auto; }
"""

with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown(DESCRIPTION)

    model_state = gr.State(DEFAULT_MODEL_ID)
    settings_state = gr.State(DEFAULT_SETTINGS.copy())

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="IOL Report Image", type="numpy")

            with gr.Accordion("Settings", open=False):
                manufacturer = gr.Dropdown(
                    choices=["Alcon", "Johnson & Johnson", "Zeiss", "Bausch & Lomb"],
                    value=DEFAULT_SETTINGS["manufacturer"],
                    label="Manufacturer",
                )
                lens_model = gr.Textbox(
                    value=DEFAULT_SETTINGS["lens_model"],
                    label="Lens Model",
                )
                a_constant = gr.Number(
                    value=DEFAULT_SETTINGS["a_constant"],
                    label="A-Constant",
                )
                target_refraction = gr.Number(
                    value=DEFAULT_SETTINGS["target_refraction"],
                    label="Target Refraction (D)",
                )

            submit_btn = gr.Button("Analyze Report", variant="primary")

        with gr.Column():
            output_text = gr.Textbox(
                label="Analysis Results",
                lines=20,
                show_copy_button=True,
            )

    # First gather the Settings-panel values into state, then run the GPU analysis.
    # run_example stays the direct handler so Gradio injects its progress tracker.
    submit_btn.click(
        fn=collect_settings,
        inputs=[manufacturer, lens_model, a_constant, target_refraction],
        outputs=settings_state,
    ).then(
        fn=run_example,
        inputs=[input_img, model_state, settings_state],
        outputs=output_text,
    )

# `concurrency_count` was removed in Gradio 4 in favour of `default_concurrency_limit`.
try:
    demo.queue(default_concurrency_limit=1)
except TypeError:
    demo.queue(concurrency_count=1)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)