"""Gradio demo that pairs DenseNet121 image features with BERT text features.

The module exposes two classes:

* ``VisionRAGHealthcare`` loads the vision/text models and turns an image and
  a clinical note into feature tensors.
* ``HealthcareGradioApp`` applies keyword and threshold heuristics to those
  inputs and renders a plain-text report through a Gradio interface.

NOTE(review): the findings produced here are keyword/threshold heuristics, not
a validated clinical model; the report itself tells the user to consult a
healthcare professional.
"""

import io
import logging
import os

import torch
import torchvision
import torchvision.transforms as transforms
import transformers
import gradio as gr
import pandas as pd
import numpy as np
import cv2
from PIL import Image
from transformers import BertTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader
import networkx as nx
import mlflow

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


# Plain-text layout of the generated report. The three placeholders are, in
# order: symptom findings, imaging findings, recommendations.
_REPORT_TEMPLATE = (
    "\n"
    "🏥 Medical Case Analysis Report\n"
    "============================\n"
    "\n"
    "📋 Clinical Symptoms Analysis:\n"
    "---------------------------\n"
    "{}\n"
    "\n"
    "🔍 Imaging Analysis Findings:\n"
    "-------------------------\n"
    "{}\n"
    "\n"
    "💡 Recommendations:\n"
    "----------------\n"
    "{}\n"
    "\n"
    "⚠️ Important Notice:\n"
    "-----------------\n"
    "This is an AI-assisted analysis and should be reviewed by a healthcare professional.\n"
    "Always consult with a medical doctor for proper diagnosis and treatment.\n"
)


class VisionRAGHealthcare:
    """Load the vision and text models and extract features from inputs."""

    def __init__(self, device=None):
        """Pick a device and load the models.

        Args:
            device: Torch device string. When falsy, ``'cuda'`` is used if
                available, otherwise ``'cpu'``.
        """
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info("Using device: %s", self.device)
        self.initialize_models()

    def initialize_models(self):
        """Initialize vision and language models.

        Sets ``vision_model``, ``image_transforms``, ``text_tokenizer`` and
        ``text_model`` on the instance.

        Raises:
            Exception: Whatever the underlying loaders raise; the error is
                logged and re-raised unchanged.
        """
        try:
            # DenseNet121 with ImageNet weights, switched to inference mode.
            self.vision_model = torchvision.models.densenet121(weights='IMAGENET1K_V1')
            self.vision_model.eval()
            self.vision_model = self.vision_model.to(self.device)

            # Resize to 224x224 and normalize with the usual ImageNet statistics.
            self.image_transforms = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ])

            # BERT tokenizer and encoder for the clinical notes.
            self.text_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            self.text_model = BertModel.from_pretrained('bert-base-uncased').to(self.device)

            logger.info("Models initialized successfully")
        except Exception as e:
            logger.error("Error initializing models: %s", e)
            raise

    def preprocess_image(self, image):
        """Convert an image to a normalized batch tensor on ``self.device``.

        Args:
            image: A NumPy array, a filesystem path (``str``) or a
                ``PIL.Image.Image``. The image is converted to RGB.

        Returns:
            The transformed image with a leading batch dimension added.

        Raises:
            ValueError: If ``image`` is none of the supported types.
        """
        try:
            # Handle different image input types
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image).convert('RGB')
            elif isinstance(image, str):
                image = Image.open(image).convert('RGB')
            elif isinstance(image, Image.Image):
                image = image.convert('RGB')
            else:
                raise ValueError(f"Unsupported image type: {type(image)}")

            # Apply preprocessing steps and add the batch dimension.
            image_tensor = self.image_transforms(image).unsqueeze(0)
            return image_tensor.to(self.device)
        except Exception as e:
            logger.error("Error preprocessing image: %s", e)
            raise

    def preprocess_text(self, text):
        """Tokenize text for the BERT model.

        Args:
            text: The clinical note to tokenize. Input is truncated to at
                most 512 tokens.

        Returns:
            A dict of tokenizer output tensors moved to ``self.device``.
        """
        try:
            inputs = self.text_tokenizer(
                text,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt"
            )
            return {k: v.to(self.device) for k, v in inputs.items()}
        except Exception as e:
            logger.error("Error preprocessing text: %s", e)
            raise

    def extract_features(self, image_tensor, text_inputs):
        """Extract features from both image and text.

        Args:
            image_tensor: Output of :meth:`preprocess_image`.
            text_inputs: Output of :meth:`preprocess_text`.

        Returns:
            A ``(image_features, text_features)`` tuple. Image features are
            the DenseNet ``features`` output, average-pooled to 1x1 and
            flattened; no activation is applied before pooling here. Text
            features are the last hidden state at the first token position.
        """
        try:
            with torch.no_grad():
                # Extract image features
                image_features = self.vision_model.features(image_tensor)
                image_features = torch.nn.functional.adaptive_avg_pool2d(image_features, (1, 1))
                image_features = torch.flatten(image_features, 1)

                # Extract text features
                text_features = self.text_model(**text_inputs).last_hidden_state[:, 0, :]

            return image_features, text_features
        except Exception as e:
            logger.error("Error extracting features: %s", e)
            raise


class HealthcareGradioApp:
    """Gradio front end that turns an image and notes into a text report."""

    def __init__(self):
        """Create the underlying :class:`VisionRAGHealthcare` pipeline."""
        self.vision_rag = VisionRAGHealthcare()

    def analyze_symptoms(self, clinical_notes):
        """Analyze clinical symptoms and provide relevant insights.

        Matching is a case-insensitive substring search, so negations such
        as "no fever" still count as a mention.

        Args:
            clinical_notes: Free-text clinical notes.

        Returns:
            A list of finding strings, in a fixed order.
        """
        symptoms = {
            'cough': ['persistent cough', 'coughing', 'dry cough', 'wet cough'],
            'fever': ['fever', 'high temperature', 'febrile'],
            'breathing': ['shortness of breath', 'difficulty breathing', 'dyspnea'],
            'chest_pain': ['chest pain', 'chest discomfort', 'chest tightness'],
            'duration': ['days', 'weeks', 'months', 'chronic'],
            'severity': ['moderate', 'severe', 'intense'],
            'mild_severity': ['mild']
        }

        notes = clinical_notes.lower()

        def mentions(category):
            return any(term in notes for term in symptoms[category])

        # These strings are matched verbatim by suggest_recommendations().
        messages = (
            ('cough', "Cough detected in symptoms"),
            ('fever', "Fever reported"),
            ('breathing', "Breathing difficulties noted"),
            ('chest_pain', "Chest discomfort/pain reported"),
            # NOTE(review): any mention of "days" is reported as
            # chronic/ongoing; confirm that is the intended reading.
            ('duration', "Chronic/ongoing symptoms noted"),
        )
        findings = [message for category, message in messages if mentions(category)]

        # "mild" used to be reported as "Moderate to severe"; it now gets its
        # own finding, and only when no stronger severity term is present.
        if mentions('severity'):
            findings.append("Moderate to severe symptoms indicated")
        elif mentions('mild_severity'):
            findings.append("Mild symptoms indicated")

        return findings

    def analyze_image_features(self, image_features):
        """Analyze image features for potential medical conditions.

        NOTE(review): the thresholds below are applied to the mean and
        standard deviation of features from an ImageNet-pretrained
        DenseNet121; nothing in this file validates them clinically.

        Args:
            image_features: Tensor returned by
                ``VisionRAGHealthcare.extract_features``.

        Returns:
            A list of finding strings (possibly empty).
        """
        feature_array = image_features.cpu().numpy()
        feature_mean = np.mean(feature_array)
        feature_std = np.std(feature_array)

        image_findings = []

        if feature_std > 0.5:
            image_findings.append("Notable variations detected in lung tissue density")
        if feature_mean > 0:
            image_findings.append("Areas of increased opacity observed")
        if feature_mean < -0.5:
            image_findings.append("Possible areas of decreased density noted")

        return image_findings

    def suggest_recommendations(self, symptom_findings, image_findings):
        """Generate medical recommendations based on findings.

        Args:
            symptom_findings: Output of :meth:`analyze_symptoms`.
            image_findings: Output of :meth:`analyze_image_features`.

        Returns:
            A list of ``"- ..."`` recommendation strings; the two general
            recommendations are always appended last.
        """
        recommendations = []

        # Symptom-based recommendations
        if "Cough detected in symptoms" in symptom_findings:
            recommendations.append("- Monitor cough characteristics and duration")
            recommendations.append("- Consider pulmonary function testing")

        if "Fever reported" in symptom_findings:
            recommendations.append("- Regular temperature monitoring recommended")
            recommendations.append("- Consider complete blood count and inflammatory markers")

        if "Breathing difficulties noted" in symptom_findings:
            recommendations.append("- Pulse oximetry monitoring advised")
            recommendations.append("- Consider spirometry assessment")

        # Image-based recommendations
        if "Notable variations detected in lung tissue density" in image_findings:
            recommendations.append("- Follow-up imaging recommended in 2-4 weeks")

        if "Areas of increased opacity observed" in image_findings:
            recommendations.append("- Consider high-resolution CT scan")

        # General recommendations
        recommendations.append("- Regular monitoring of vital signs")
        recommendations.append("- Follow-up with primary care physician")

        return recommendations

    @staticmethod
    def _bullets(findings, empty_message):
        """Render findings as bullet lines, or one fallback bullet if empty."""
        if not findings:
            return f"• {empty_message}"
        return "\n".join(f"• {finding}" for finding in findings)

    def process_medical_case(self, image, clinical_notes):
        """Process medical case with both image and text data.

        Args:
            image: Image supplied by the Gradio image input, or ``None``.
            clinical_notes: Free-text notes from the Gradio textbox.

        Returns:
            The formatted report, or a string starting with ``"Error"`` when
            an input is missing or processing fails. Exceptions are logged
            and reported to the user rather than raised.
        """
        try:
            if image is None:
                return "Error: Please upload a medical image for analysis."

            if not clinical_notes or not clinical_notes.strip():
                return "Error: Please provide clinical notes for analysis."

            # Process inputs
            image_tensor = self.vision_rag.preprocess_image(image)
            text_inputs = self.vision_rag.preprocess_text(clinical_notes)

            # Text features are computed by the pipeline but not used by the
            # heuristics below.
            image_features, _text_features = self.vision_rag.extract_features(
                image_tensor, text_inputs
            )

            # Analyze components
            symptom_findings = self.analyze_symptoms(clinical_notes)
            image_findings = self.analyze_image_features(image_features)
            recommendations = self.suggest_recommendations(symptom_findings, image_findings)

            # Generate report; empty sections get an explicit placeholder
            # instead of a blank gap.
            return _REPORT_TEMPLATE.format(
                self._bullets(symptom_findings, "No specific symptoms identified in the clinical notes"),
                self._bullets(image_findings, "No notable imaging findings"),
                "\n".join(recommendations)
            )

        except Exception as e:
            error_msg = f"Error in analysis: {e}"
            # UI boundary: log with traceback, return the message to the user.
            logger.exception(error_msg)
            return error_msg

    def launch_gradio_interface(self):
        """Launch Gradio interface.

        Returns:
            The configured ``gr.Interface``; the caller is responsible for
            calling ``launch()`` on it.
        """
        # Custom CSS for better styling
        custom_css = (
            ".gradio-container {max-width: 900px !important}\n"
            ".image-upload {min-height: 400px !important}\n"
        )

        iface = gr.Interface(
            fn=self.process_medical_case,
            inputs=[
                gr.Image(
                    type="numpy",
                    label="Upload Medical Image"
                ),
                gr.Textbox(
                    lines=5,
                    label="Clinical Notes",
                    placeholder="Enter detailed patient symptoms, medical history, and current condition...",
                    info="Include: symptoms, duration, severity, and relevant medical history"
                )
            ],
            outputs=gr.Textbox(
                label="Medical Analysis Report",
                lines=15
            ),
            title="🏥 Healthcare Vision-RAG AI System",
            description=(
                "Upload medical images and clinical notes for AI-powered analysis.\n"
                "This system combines visual and textual analysis for comprehensive medical insights."
            ),
            # Markdown needs real line breaks for the headings and the list.
            article=(
                "### How to use:\n"
                "1. Upload a medical image (X-ray, CT scan)\n"
                "2. Enter detailed clinical notes\n"
                "3. Review the AI-generated analysis report\n"
                "\n"
                "### Note:\n"
                "This is an AI assistance tool and should be used in conjunction with "
                "professional medical judgment.\n"
            ),
            theme="default",
            css=custom_css,
            allow_flagging="never"
        )

        return iface


def main():
    """Configure MLflow tracking, build the app and launch the interface."""
    # MLflow tracking
    mlflow.set_tracking_uri("sqlite:///mlflow.db")
    mlflow.set_experiment("vision-rag-healthcare")

    try:
        print("🚀 Initializing Healthcare Vision-RAG System...")
        app = HealthcareGradioApp()

        print("📊 Launching Gradio interface...")
        iface = app.launch_gradio_interface()
        # NOTE(review): share=True asks Gradio for a public share link;
        # confirm that is acceptable for clinical images and notes.
        iface.launch(
            share=True,
            debug=True,
            show_error=True
        )
    except Exception as e:
        logger.exception("Error launching application: %s", e)
        raise


if __name__ == "__main__":
    main()