lumenex / app.py
walaa2022's picture
Update app.py
576ba03 verified
# app.py - Medical AI with Proper Vision Analysis
import gradio as gr
import torch
from transformers import (
BlipProcessor, BlipForConditionalGeneration,
AutoProcessor, AutoModelForCausalLM,
pipeline
)
from PIL import Image
import logging
from collections import defaultdict, Counter
import time
import requests
from io import BytesIO
# Configure logging: root logger at INFO so the model-loading and analysis
# progress messages emitted below are visible in the process output.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
# Usage tracking
class UsageTracker:
    """In-memory tally of analysis outcomes, timings and question categories.

    All figures live in the ``stats`` dict and are lost when the process exits.
    """

    def __init__(self):
        # Every counter starts at zero; question categories are tallied in a Counter.
        self.stats = dict(
            total_analyses=0,
            successful_analyses=0,
            failed_analyses=0,
            average_processing_time=0.0,
            question_types=Counter(),
        )

    def log_analysis(self, success, duration, question_type=None):
        """Record one analysis attempt.

        Args:
            success: Truthy when the attempt produced a result.
            duration: Elapsed time of the attempt, in seconds.
            question_type: Optional category label; counted only when truthy.
        """
        stats = self.stats
        previous_count = stats['total_analyses']
        stats['total_analyses'] = previous_count + 1

        outcome_key = 'successful_analyses' if success else 'failed_analyses'
        stats[outcome_key] += 1

        # Rebuild the running mean from the previous mean and the new sample.
        accumulated = stats['average_processing_time'] * previous_count
        stats['average_processing_time'] = (accumulated + duration) / stats['total_analyses']

        if question_type:
            stats['question_types'][question_type] += 1
# Rate limiting
class RateLimiter:
    """Sliding one-hour window request limiter keyed by user id."""

    def __init__(self, max_requests_per_hour=60):
        self.max_requests_per_hour = max_requests_per_hour
        # user_id -> timestamps (seconds since epoch) of accepted requests
        self.requests = defaultdict(list)

    def is_allowed(self, user_id="default"):
        """Return True and record the request if ``user_id`` is under its quota.

        Timestamps older than one hour are discarded on every call, whether
        or not the request is accepted.
        """
        now = time.time()
        cutoff = now - 3600

        recent = [stamp for stamp in self.requests[user_id] if stamp > cutoff]
        self.requests[user_id] = recent

        if len(recent) >= self.max_requests_per_hour:
            return False

        recent.append(now)
        return True
# Initialize components: process-wide singletons shared by every request.
usage_tracker = UsageTracker()
# Default quota (60 requests per hour) applies because no argument is passed.
rate_limiter = RateLimiter()
# Candidate checkpoints. load_best_model() walks this list in order and keeps
# the first one that loads, so position here is the load priority.
MODELS_TO_TRY = [
    "microsoft/git-base-coco", # Better for detailed descriptions
    "Salesforce/blip2-opt-2.7b", # More capable BLIP2 model
    "Salesforce/blip-image-captioning-large" # Fallback
]
# Global variables, populated by load_best_model().
# `model` holds either a transformers pipeline or a model object, depending on
# which candidate loaded; `processor` stays None for the pipeline case.
model = None
processor = None
# Target device string, chosen once at import time.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Hub id of the checkpoint that actually loaded (None until one does).
current_model_name = None
def load_best_model():
    """Load the first usable checkpoint from ``MODELS_TO_TRY``.

    Candidates are tried in list order. On success the module globals
    ``model``, ``processor`` and ``current_model_name`` are set together;
    a candidate that fails part-way leaves the globals untouched, so no
    stale processor from a failed attempt can leak into later inference.

    Returns:
        bool: True when a model was loaded, False when every candidate failed.
    """
    global model, processor, current_model_name
    use_cuda = torch.cuda.is_available()
    for model_name in MODELS_TO_TRY:
        try:
            logger.info(f"Trying to load: {model_name}")
            if "git-base" in model_name:
                # Use transformers pipeline for GIT model (it owns its processor)
                loaded_model = pipeline("image-to-text", model=model_name, device=0 if use_cuda else -1)
                loaded_processor = None
                family = "GIT"
            elif "blip2" in model_name:
                # BLIP-2 is a vision-to-text architecture; AutoModelForCausalLM has
                # no mapping for its config, so the dedicated class is required.
                # Imported here so an older transformers release merely skips
                # this candidate instead of breaking module import.
                from transformers import Blip2ForConditionalGeneration
                loaded_processor = AutoProcessor.from_pretrained(model_name)
                loaded_model = Blip2ForConditionalGeneration.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16 if use_cuda else torch.float32,
                    device_map="auto" if use_cuda else None,
                )
                family = "BLIP2"
            else:
                # Standard BLIP model
                loaded_processor = BlipProcessor.from_pretrained(model_name)
                loaded_model = BlipForConditionalGeneration.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16 if use_cuda else torch.float32,
                    device_map="auto" if use_cuda else None,
                )
                # No extra .to(device): device_map="auto" has already placed the
                # weights, and moving a dispatched model is unsupported.
                family = "BLIP"
        except Exception as e:
            logger.warning(f"Failed to load {model_name}: {e}")
            continue
        # Publish all three globals only once the candidate is fully loaded.
        model, processor, current_model_name = loaded_model, loaded_processor, model_name
        logger.info(f"βœ… Successfully loaded {family} model: {model_name}")
        return True
    logger.error("❌ Failed to load any model")
    return False
# Load model at startup. This runs synchronously at import time; model_ready
# is True only if one of the candidate checkpoints loaded successfully.
model_ready = load_best_model()
def _prepare_model_inputs(image, text=None):
    """Run the processor and place the resulting tensors for the loaded model."""
    if text is None:
        encoded = processor(image, return_tensors="pt")
    else:
        encoded = processor(image, text, return_tensors="pt")
    if not torch.cuda.is_available():
        return encoded
    # On CUDA the model is loaded in float16. Floating-point inputs (pixel
    # values) must be cast to the same dtype or generate() fails with a dtype
    # mismatch; integer tensors such as input_ids are only moved.
    model_dtype = getattr(model, "dtype", None)
    prepared = {}
    for key, tensor in encoded.items():
        if model_dtype is not None and torch.is_floating_point(tensor):
            prepared[key] = tensor.to(device, dtype=model_dtype)
        else:
            prepared[key] = tensor.to(device)
    return prepared


def get_detailed_medical_analysis(image, question):
    """Describe ``image`` and answer ``question`` with the loaded model.

    The code path is chosen from the global ``current_model_name``:
    the GIT pipeline, BLIP-2, or plain BLIP.

    Args:
        image: Image accepted by the loaded pipeline/processor (the UI
            supplies a PIL image).
        question: Clinical question text.

    Returns:
        tuple[str, str]: ``(basic_description, question_specific_text)``.
        Any exception is logged and reported as
        ``("Unable to analyze image", "Analysis failed")`` instead of raised.
    """
    try:
        if "git-base" in current_model_name:
            # Use GIT model (usually gives more detailed descriptions)
            results = model(image, max_new_tokens=200)
            description = results[0]['generated_text'] if results else "Unable to analyze image"
            # The pipeline is caption-only, so for diagnostic-style questions the
            # caption is merely relabelled rather than re-generated.
            if any(word in question.lower() for word in ['abnormal', 'diagnosis', 'condition', 'pathology']):
                medical_prompt = f"Medical analysis: {description}"
                return description, medical_prompt
            return description, description

        if "blip2" in current_model_name:
            # Question-conditioned generation
            inputs = _prepare_model_inputs(image, question)
            with torch.no_grad():
                generated_ids = model.generate(**inputs, max_new_tokens=150, do_sample=False)
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            # Also get unconditional description
            basic_inputs = _prepare_model_inputs(image)
            with torch.no_grad():
                basic_ids = model.generate(**basic_inputs, max_new_tokens=100, do_sample=False)
            basic_text = processor.batch_decode(basic_ids, skip_special_tokens=True)[0]
            return basic_text, generated_text

        # Standard BLIP model: unconditional caption first
        inputs = _prepare_model_inputs(image)
        with torch.no_grad():
            output_ids = model.generate(**inputs, max_length=100, num_beams=3, do_sample=False)
        basic_description = processor.decode(output_ids[0], skip_special_tokens=True)

        # Then try several prompt phrasings and keep the first usable answer.
        medical_prompts = [
            f"Question: {question} Answer:",
            f"Medical analysis: {question}",
            f"Describe the medical findings: {question}"
        ]
        best_response = basic_description
        for prompt in medical_prompts:
            try:
                inputs_qa = _prepare_model_inputs(image, prompt)
                with torch.no_grad():
                    qa_output_ids = model.generate(
                        **inputs_qa,
                        max_length=200,
                        num_beams=3,
                        do_sample=False,
                        early_stopping=True
                    )
                # Decode only generated part (skip the echoed prompt tokens)
                input_length = inputs_qa['input_ids'].shape[1]
                qa_response = processor.decode(qa_output_ids[0][input_length:], skip_special_tokens=True).strip()
            except Exception as e:
                # Best-effort: a failing prompt must not abort the others, but
                # the failure is logged instead of vanishing silently.
                logger.warning(f"Prompted generation failed for {prompt!r}: {e}")
                continue
            if qa_response and len(qa_response) > 20 and not qa_response.lower().startswith('question'):
                best_response = qa_response
                break
        return basic_description, best_response
    except Exception as e:
        logger.error(f"Analysis failed: {e}")
        return "Unable to analyze image", "Analysis failed"
def enhance_medical_description(basic_desc, clinical_question, patient_history):
    """Wrap a model caption in a static educational framework.

    A chest-radiograph template is used when the caption or the question
    mentions chest-related terms; otherwise a generic framework is returned.
    Keywords are matched as whole words, so unrelated words that merely
    contain them ("orchestra", "crib") no longer select the chest template.

    Args:
        basic_desc: Caption produced by the vision model (``None`` tolerated).
        clinical_question: The user's question (``None`` tolerated).
        patient_history: Free-text history; may be empty or ``None``.

    Returns:
        str: Markdown-formatted framework text.
    """
    import re  # local import keeps this block self-contained

    description = basic_desc or ""
    question = clinical_question or ""
    history = (patient_history or "").strip()

    # Whole-word patterns; common inflections of the original keywords are kept.
    description_terms = r"\b(?:chests?|lungs?|ribs?|hearts?|x-rays?|radiograph\w*)\b"
    question_terms = r"\b(?:chests?|lungs?|respiratory|cough\w*)\b"
    looks_like_chest = bool(
        re.search(description_terms, description.lower())
        or re.search(question_terms, question.lower())
    )

    if looks_like_chest:
        # Quote the supplied history instead of asserting symptoms the user
        # never entered.
        if history:
            presentation = f"Given the patient's presentation ({history}), key considerations include:"
        else:
            presentation = "Given the patient's presentation, key considerations include:"
        return f"""
**Systematic Chest X-ray Analysis:**
**Technical Quality:**
- Image appears to be a standard PA chest radiograph
- Adequate penetration and positioning for diagnostic evaluation
**Anatomical Review:**
- **Heart**: Cardiac silhouette evaluation for size and contour
- **Lungs**: Assessment of lung fields for opacity, consolidation, or air trapping
- **Pleura**: Examination for pleural effusion or pneumothorax
- **Bones**: Rib cage and spine alignment assessment
- **Soft Tissues**: Evaluation of surrounding structures
**Clinical Correlation Needed:**
{presentation}
- **Pneumonia**: Look for consolidation, air bronchograms, or infiltrates
- **Viral vs Bacterial**: Pattern recognition for different infectious etiologies
- **Atelectasis**: Collapsed lung segments that might appear as increased opacity
- **Pleural Changes**: Fluid collection that could indicate infection complications
**Educational Points:**
- Chest X-rays are the first-line imaging for respiratory symptoms
- Clinical correlation is essential - symptoms guide interpretation
- Follow-up imaging may be needed based on treatment response
"""

    # Generic medical image analysis
    return f"""
**Medical Image Analysis Framework:**
**Image Description:**
{description}
**Clinical Context Integration:**
- Patient presentation: {history if history else 'Clinical history provided'}
- Imaging indication: {question}
**Systematic Approach:**
1. **Technical Assessment**: Image quality and acquisition parameters
2. **Anatomical Review**: Systematic evaluation of visible structures
3. **Pathological Assessment**: Identification of any abnormal findings
4. **Clinical Correlation**: Integration with patient symptoms and history
**Educational Considerations:**
- Medical imaging interpretation requires systematic approach
- Clinical context significantly influences interpretation priorities
- Multiple imaging modalities may be complementary for diagnosis
- Professional radiological review is essential for clinical decisions
"""
def analyze_medical_image(image, clinical_question, patient_history=""):
    """Produce the full educational report for one uploaded image.

    Inputs are validated before the rate limiter is consulted, so a request
    rejected for a missing image or question does not consume quota.

    Args:
        image: Uploaded image, or ``None`` when nothing was uploaded.
        clinical_question: Question text; required (``None`` treated as empty).
        patient_history: Optional free-text history (``None`` treated as empty).

    Returns:
        str: The Markdown-formatted report followed by the disclaimer, or a
        short warning/error message when the request cannot be served.
    """
    start_time = time.time()
    # UI callers may hand over None for untouched fields; normalise once.
    clinical_question = clinical_question or ""
    patient_history = patient_history or ""

    if not model_ready or model is None:
        usage_tracker.log_analysis(False, time.time() - start_time)
        return "❌ Medical AI model not loaded. Please refresh the page."
    if image is None:
        return "⚠️ Please upload a medical image first."
    if not clinical_question.strip():
        return "⚠️ Please provide a clinical question."
    # Rate limiting (checked last so only serviceable requests use quota)
    if not rate_limiter.is_allowed():
        usage_tracker.log_analysis(False, time.time() - start_time)
        return "⚠️ Rate limit exceeded. Please wait before trying again."

    try:
        logger.info("Starting enhanced medical image analysis...")
        # Get detailed analysis from AI model
        basic_description, detailed_response = get_detailed_medical_analysis(image, clinical_question)
        # Enhance with medical knowledge
        enhanced_analysis = enhance_medical_description(basic_description, clinical_question, patient_history)
        # Create comprehensive medical report
        history_section = f"## **Patient History:** {patient_history}" if patient_history.strip() else ""
        formatted_response = f"""# πŸ₯ **Enhanced Medical AI Analysis**
## **Clinical Question:** {clinical_question}
{history_section}
---
## πŸ” **AI Vision Analysis**
### **Image Description:**
{basic_description}
### **Question-Specific Analysis:**
{detailed_response}
---
## πŸ“‹ **Medical Assessment Framework**
{enhanced_analysis}
---
## πŸŽ“ **Educational Summary**
**Learning Objectives:**
- Demonstrate systematic approach to medical image interpretation
- Integrate clinical history with imaging findings
- Understand the importance of professional validation in medical diagnosis
**Key Teaching Points:**
- Medical imaging is one component of comprehensive patient assessment
- Clinical correlation enhances diagnostic accuracy
- Multiple imaging modalities may provide complementary information
- Professional interpretation is essential for patient care decisions
**Clinical Decision Making:**
Based on the combination of:
- Patient symptoms: {patient_history if patient_history else 'As provided'}
- Imaging findings: As described above
- Clinical context: {clinical_question}
**Next Steps in Clinical Practice:**
- Professional radiological review
- Correlation with laboratory findings
- Consider additional imaging if clinically indicated
- Follow-up based on treatment response
"""
        # Add medical disclaimer. This must be an f-string: as a plain string
        # the model/device placeholders on the last line were shown literally.
        disclaimer = f"""
---
## ⚠️ **IMPORTANT MEDICAL DISCLAIMER**
**FOR EDUCATIONAL AND RESEARCH PURPOSES ONLY**
- **🚫 AI Limitations**: AI analysis has significant limitations for medical diagnosis
- **πŸ‘¨β€βš•οΈ Professional Review Required**: All findings must be validated by qualified healthcare professionals
- **🚨 Emergency Care**: For urgent medical concerns, seek immediate medical attention
- **πŸ₯ Clinical Integration**: AI findings are educational tools, not diagnostic conclusions
- **πŸ“‹ Learning Tool**: Designed for medical education and training purposes
- **πŸ”’ Privacy**: Do not upload real patient data or identifiable information
**This analysis demonstrates AI-assisted medical image interpretation concepts for educational purposes only.**
---
**Model**: {current_model_name} | **Device**: {device.upper()} | **Purpose**: Medical Education
"""
        # Log successful analysis
        duration = time.time() - start_time
        question_type = classify_question(clinical_question)
        usage_tracker.log_analysis(True, duration, question_type)
        logger.info(f"βœ… Enhanced medical analysis completed in {duration:.2f}s")
        return formatted_response + disclaimer
    except Exception as e:
        # Top-level boundary for the UI callback: report the failure to the
        # user rather than letting the handler raise.
        duration = time.time() - start_time
        usage_tracker.log_analysis(False, duration)
        logger.error(f"❌ Analysis error: {str(e)}")
        return f"❌ Enhanced analysis failed: {str(e)}\n\nPlease try again with a different image."
def classify_question(question):
    """Map a clinical question to a coarse category label.

    Categories are checked in a fixed priority order using case-insensitive
    substring matching; the first category with a matching keyword wins and
    'general' is returned when none match.
    """
    keyword_table = (
        ('descriptive', ('describe', 'findings', 'observe', 'see', 'show')),
        ('diagnostic', ('diagnosis', 'differential', 'condition')),
        ('pathological', ('abnormal', 'pathology', 'disease')),
    )
    lowered = question.lower()
    for label, keywords in keyword_table:
        for keyword in keywords:
            if keyword in lowered:
                return label
    return 'general'
def get_usage_stats():
    """Render the usage counters as a Markdown summary string.

    Returns a short placeholder when no analysis has been recorded yet;
    otherwise totals, success rate, mean duration, the three most common
    question categories and the current model/device status.
    """
    stats = usage_tracker.stats
    total = stats['total_analyses']
    if total == 0:
        return "πŸ“Š **Usage Statistics**\n\nNo analyses performed yet."

    success_rate = (stats['successful_analyses'] / total) * 100
    # One bullet per category, most frequent first (at most three).
    question_lines = "\n".join(
        f"- **{qtype.title()}**: {count}"
        for qtype, count in stats['question_types'].most_common(3)
    )
    status_label = '🟒 Enhanced Model Active' if model_ready else 'πŸ”΄ Offline'
    model_label = current_model_name if current_model_name else 'None'
    return f"""πŸ“Š **Enhanced Medical AI Statistics**
**Performance Metrics:**
- **Total Analyses**: {total}
- **Success Rate**: {success_rate:.1f}%
- **Average Processing Time**: {stats['average_processing_time']:.2f} seconds
**Question Types:**
{question_lines}
**System Status**: {status_label}
**Current Model**: {model_label}
**Device**: {device.upper()}
"""
def clear_all():
    """Return reset values for the image, the two text inputs and the output."""
    cleared_image = None
    cleared_text = ""
    return cleared_image, cleared_text, cleared_text, cleared_text
def set_chest_example():
    """Return the (question, history) pair for the chest X-ray example."""
    example_question = "Describe this chest X-ray systematically and identify any abnormalities"
    example_history = "30-year-old patient with productive cough, fever, and shortness of breath"
    return example_question, example_history
def set_pathology_example():
    """Return the (question, history) pair for the pathology example."""
    example_question = "What pathological findings are visible? Describe the tissue characteristics."
    example_history = "Biopsy specimen for histopathological evaluation"
    return example_question, example_history
def set_general_example():
    """Return the (question, history) pair for the general-analysis example."""
    example_question = "Provide a systematic analysis of this medical image"
    example_history = "Patient requiring comprehensive imaging evaluation"
    return example_question, example_history
# Create enhanced Gradio interface
def create_interface():
    """Assemble and return the Gradio Blocks application.

    Layout: header, model status banner and disclaimer on top; a two-column
    row (inputs and results on the left, status, statistics and example
    shortcuts on the right); then a feature footer. Button callbacks are
    wired after the components exist. The example shortcuts are only created
    when a model was loaded.
    """
    custom_css = """
.gradio-container { max-width: 1400px !important; }
.disclaimer { background-color: #fef2f2; border: 1px solid #fecaca; border-radius: 8px; padding: 16px; margin: 16px 0; }
.success { background-color: #f0f9ff; border: 1px solid #bae6fd; border-radius: 8px; padding: 16px 0; }
.enhanced { background-color: #f0fdf4; border: 1px solid #bbf7d0; border-radius: 8px; padding: 16px 0; }
"""

    # Status banner text depends on whether a model loaded at startup.
    if model_ready:
        status_banner = f"""
<div class="enhanced">
βœ… <strong>ENHANCED MEDICAL AI READY</strong><br>
Advanced model loaded: <strong>{current_model_name}</strong><br>
Now provides detailed medical image analysis with systematic framework and educational content.
</div>
"""
    else:
        status_banner = """
<div class="disclaimer">
⚠️ <strong>MODEL LOADING</strong><br>
Enhanced Medical AI is loading. Please wait and refresh if needed.
</div>
"""

    with gr.Blocks(title="Enhanced Medical AI Analysis", theme=gr.themes.Soft(), css=custom_css) as app:
        # Header
        gr.Markdown("""
# πŸ₯ Enhanced Medical AI Image Analysis
**Advanced Medical AI with Better Vision Models - Educational Analysis**
**Enhanced Features:** 🧠 Multiple AI Models β€’ πŸ”¬ Systematic Analysis β€’ πŸ“‹ Educational Framework β€’ πŸŽ“ Clinical Integration
""")
        # Status display
        gr.Markdown(status_banner)
        # Medical disclaimer
        gr.Markdown("""
<div class="disclaimer">
⚠️ <strong>MEDICAL DISCLAIMER</strong><br>
This enhanced tool provides AI-assisted medical analysis for <strong>educational purposes only</strong>.
Uses advanced vision models for better image understanding. Do not upload real patient data.
</div>
""")

        # (button, callback) pairs for the example shortcuts; stays empty
        # when no model is available.
        example_shortcuts = []

        with gr.Row():
            # Left column - Main interface
            with gr.Column(scale=2):
                # Image upload
                gr.Markdown("## πŸ“€ Medical Image Upload")
                image_box = gr.Image(
                    label="Upload Medical Image (Enhanced Analysis)",
                    type="pil",
                    height=300
                )
                # Clinical inputs
                gr.Markdown("## πŸ’¬ Clinical Information")
                question_box = gr.Textbox(
                    label="Clinical Question *",
                    placeholder="Enhanced examples:\nβ€’ Systematically describe this chest X-ray and identify abnormalities\nβ€’ What pathological findings are visible in this image?\nβ€’ Provide detailed analysis of anatomical structures\nβ€’ Analyze this medical scan for educational purposes",
                    lines=4
                )
                history_box = gr.Textbox(
                    label="Patient History & Clinical Context",
                    placeholder="Detailed example: 35-year-old female with 3-day history of productive cough, fever (38.5Β°C), shortness of breath, and left-sided chest pain",
                    lines=3
                )
                # Action buttons
                with gr.Row():
                    reset_button = gr.Button("πŸ—‘οΈ Clear All", variant="secondary")
                    run_button = gr.Button("πŸ” Enhanced Medical Analysis", variant="primary", size="lg")
                # Results
                gr.Markdown("## πŸ“‹ Enhanced Medical Analysis Results")
                report_box = gr.Textbox(
                    label="Advanced AI Medical Analysis (Multiple Models)",
                    lines=25,
                    show_copy_button=True,
                    placeholder="Upload a medical image and provide detailed clinical question for comprehensive AI analysis..."
                )
            # Right column - Status and controls
            with gr.Column(scale=1):
                gr.Markdown("## ℹ️ Enhanced System Status")
                gr.Markdown(f"""
**Status**: {'βœ… Advanced Models Active' if model_ready else 'πŸ”„ Loading'}
**Primary Model**: {current_model_name if current_model_name else 'Loading...'}
**Device**: {device.upper()}
**Enhancement**: 🧠 Multiple AI Models
**Analysis**: πŸ“‹ Systematic Framework
""")
                # Statistics
                gr.Markdown("## πŸ“Š Usage Analytics")
                stats_panel = gr.Markdown(get_usage_stats())
                stats_button = gr.Button("πŸ”„ Refresh Stats", size="sm")
                # Quick examples
                if model_ready:
                    gr.Markdown("## 🎯 Enhanced Examples")
                    example_shortcuts.append(
                        (gr.Button("🫁 Chest X-ray Analysis", size="sm"), set_chest_example)
                    )
                    example_shortcuts.append(
                        (gr.Button("πŸ”¬ Pathology Study", size="sm"), set_pathology_example)
                    )
                    example_shortcuts.append(
                        (gr.Button("πŸ“‹ Systematic Analysis", size="sm"), set_general_example)
                    )
                gr.Markdown("## πŸš€ Enhancements")
                gr.Markdown(f"""
βœ… **Advanced Vision Models**
βœ… **Systematic Medical Framework**
βœ… **Educational Integration**
βœ… **Clinical Context Analysis**
βœ… **Model**: {current_model_name.split('/')[-1] if current_model_name else 'Enhanced'}
""")

        # Event handlers
        run_button.click(
            fn=analyze_medical_image,
            inputs=[image_box, question_box, history_box],
            outputs=report_box,
            show_progress=True
        )
        reset_button.click(
            fn=clear_all,
            outputs=[image_box, question_box, history_box, report_box]
        )
        stats_button.click(
            fn=get_usage_stats,
            outputs=stats_panel
        )
        # Quick example handlers: each fills the question and history boxes.
        for shortcut_button, example_fn in example_shortcuts:
            shortcut_button.click(
                fn=example_fn,
                outputs=[question_box, history_box]
            )

        # Footer
        gr.Markdown(f"""
---
## πŸš€ **Enhanced Medical AI Features**
### **Advanced Vision Models:**
- **Microsoft GIT**: Enhanced image-to-text capabilities
- **BLIP2**: Advanced vision-language understanding
- **Multi-Model Fallback**: Automatic best model selection
- **Better Descriptions**: More detailed and accurate analysis
### **Medical Framework Integration:**
- **Systematic Analysis**: Structured medical image interpretation
- **Clinical Correlation**: Integration of symptoms with imaging
- **Educational Content**: Teaching points and learning objectives
- **Professional Guidelines**: Follows medical education standards
**Current Model**: {current_model_name if current_model_name else 'Loading...'} | **Purpose**: Enhanced Medical Education
""")
    return app
# Launch the application
if __name__ == "__main__":
    # Serve on all interfaces at port 7860, surface errors in the UI and do
    # not create a public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
        "share": False,
    }
    demo = create_interface()
    demo.launch(**launch_options)