|
|
|
import gradio as gr |
|
import torch |
|
from transformers import ( |
|
BlipProcessor, BlipForConditionalGeneration, |
|
    AutoProcessor, Blip2ForConditionalGeneration,
|
pipeline |
|
) |
|
from PIL import Image |
|
import logging |
|
from collections import defaultdict, Counter |
|
import time |
|
import requests |
|
from io import BytesIO |
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class UsageTracker: |
|
def __init__(self): |
|
self.stats = { |
|
'total_analyses': 0, |
|
'successful_analyses': 0, |
|
'failed_analyses': 0, |
|
'average_processing_time': 0.0, |
|
'question_types': Counter() |
|
} |
|
|
|
def log_analysis(self, success, duration, question_type=None): |
|
self.stats['total_analyses'] += 1 |
|
if success: |
|
self.stats['successful_analyses'] += 1 |
|
else: |
|
self.stats['failed_analyses'] += 1 |
|
|
|
total_time = self.stats['average_processing_time'] * (self.stats['total_analyses'] - 1) |
|
self.stats['average_processing_time'] = (total_time + duration) / self.stats['total_analyses'] |
|
|
|
if question_type: |
|
self.stats['question_types'][question_type] += 1 |
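
# Illustrative UsageTracker behaviour (comments only, not executed): the average is
# maintained incrementally, so logging durations of 2.0s and 4.0s yields 3.0s.
#
#   tracker = UsageTracker()
#   tracker.log_analysis(True, 2.0, question_type="descriptive")
#   tracker.log_analysis(False, 4.0)
#   tracker.stats['average_processing_time']  # -> 3.0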
|
|
|
|
|
class RateLimiter: |
|
def __init__(self, max_requests_per_hour=60): |
|
self.max_requests_per_hour = max_requests_per_hour |
|
self.requests = defaultdict(list) |
|
|
|
def is_allowed(self, user_id="default"): |
|
current_time = time.time() |
|
hour_ago = current_time - 3600 |
|
self.requests[user_id] = [req_time for req_time in self.requests[user_id] if req_time > hour_ago] |
|
if len(self.requests[user_id]) < self.max_requests_per_hour: |
|
self.requests[user_id].append(current_time) |
|
return True |
|
return False |
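
# Illustrative RateLimiter behaviour (comments only, not executed), using a small
# limit for clarity; the module-level instance below keeps the default of 60/hour.
#
#   limiter = RateLimiter(max_requests_per_hour=2)
#   limiter.is_allowed("user-a")  # True
#   limiter.is_allowed("user-a")  # True
#   limiter.is_allowed("user-a")  # False until earlier requests age past one hour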
|
|
|
|
|
usage_tracker = UsageTracker() |
|
rate_limiter = RateLimiter() |
|
|
|
|
|
MODELS_TO_TRY = [ |
|
"microsoft/git-base-coco", |
|
"Salesforce/blip2-opt-2.7b", |
|
"Salesforce/blip-image-captioning-large" |
|
] |
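# The models above are tried in order; the first one that loads successfully is used
# for every analysis. Weights are fetched from the Hugging Face Hub on first use, so
# the initial startup can take several minutes.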
|
|
|
|
|
model = None |
|
processor = None |
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
current_model_name = None |
|
|
|
def load_best_model(): |
|
"""Try to load the best available model for medical image analysis""" |
|
global model, processor, current_model_name |
|
|
|
for model_name in MODELS_TO_TRY: |
|
try: |
|
logger.info(f"Trying to load: {model_name}") |
|
|
|
if "git-base" in model_name: |
|
|
|
model = pipeline("image-to-text", model=model_name, device=0 if torch.cuda.is_available() else -1) |
|
processor = None |
|
current_model_name = model_name |
|
                logger.info(f"Successfully loaded GIT model: {model_name}")
|
return True |
|
|
|
elif "blip2" in model_name: |
|
|
|
processor = AutoProcessor.from_pretrained(model_name) |
|
                model = Blip2ForConditionalGeneration.from_pretrained(
|
model_name, |
|
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, |
|
device_map="auto" if torch.cuda.is_available() else None, |
|
) |
|
current_model_name = model_name |
|
                logger.info(f"Successfully loaded BLIP2 model: {model_name}")
|
return True |
|
|
|
else: |
|
|
|
processor = BlipProcessor.from_pretrained(model_name) |
|
model = BlipForConditionalGeneration.from_pretrained( |
|
model_name, |
|
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, |
|
device_map="auto" if torch.cuda.is_available() else None, |
|
) |
|
if torch.cuda.is_available() and hasattr(model, 'to'): |
|
model = model.to(device) |
|
current_model_name = model_name |
|
                logger.info(f"Successfully loaded BLIP model: {model_name}")
|
return True |
|
|
|
except Exception as e: |
|
logger.warning(f"Failed to load {model_name}: {e}") |
|
continue |
|
|
|
    logger.error("Failed to load any model")
|
return False |
|
|
|
|
|
model_ready = load_best_model() |
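# load_best_model() runs once at import time; model_ready gates every analysis request.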
|
|
|
def get_detailed_medical_analysis(image, question): |
|
"""Get detailed medical analysis using the best available model""" |
|
try: |
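        # Strategy depends on which backend was loaded:
        #   - GIT: a single image-to-text pipeline call produces the description.
        #   - BLIP-2: one question-conditioned pass plus one unconditional captioning pass.
        #   - BLIP: an unconditional caption, then several prompted attempts; the first
        #     sufficiently long answer that does not merely echo the prompt is kept.
        # All branches return a (basic_description, question_specific_response) tuple.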
|
if "git-base" in current_model_name: |
|
|
|
results = model(image, max_new_tokens=200) |
|
description = results[0]['generated_text'] if results else "Unable to analyze image" |
|
|
|
|
|
if any(word in question.lower() for word in ['abnormal', 'diagnosis', 'condition', 'pathology']): |
|
|
|
medical_prompt = f"Medical analysis: {description}" |
|
return description, medical_prompt |
|
|
|
return description, description |
|
|
|
elif "blip2" in current_model_name: |
|
|
|
inputs = processor(image, question, return_tensors="pt") |
|
if torch.cuda.is_available(): |
|
inputs = {k: v.to(device) for k, v in inputs.items()} |
|
|
|
with torch.no_grad(): |
|
generated_ids = model.generate(**inputs, max_new_tokens=150, do_sample=False) |
|
|
|
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] |
|
|
|
|
|
basic_inputs = processor(image, return_tensors="pt") |
|
if torch.cuda.is_available(): |
|
basic_inputs = {k: v.to(device) for k, v in basic_inputs.items()} |
|
|
|
with torch.no_grad(): |
|
basic_ids = model.generate(**basic_inputs, max_new_tokens=100, do_sample=False) |
|
|
|
basic_text = processor.batch_decode(basic_ids, skip_special_tokens=True)[0] |
|
|
|
return basic_text, generated_text |
|
|
|
else: |
|
|
|
|
|
inputs = processor(image, return_tensors="pt") |
|
if torch.cuda.is_available(): |
|
inputs = {k: v.to(device) for k, v in inputs.items()} |
|
|
|
with torch.no_grad(): |
|
output_ids = model.generate(**inputs, max_length=100, num_beams=3, do_sample=False) |
|
|
|
basic_description = processor.decode(output_ids[0], skip_special_tokens=True) |
|
|
|
|
|
medical_prompts = [ |
|
f"Question: {question} Answer:", |
|
f"Medical analysis: {question}", |
|
f"Describe the medical findings: {question}" |
|
] |
|
|
|
best_response = basic_description |
|
|
|
for prompt in medical_prompts: |
|
try: |
|
inputs_qa = processor(image, prompt, return_tensors="pt") |
|
if torch.cuda.is_available(): |
|
inputs_qa = {k: v.to(device) for k, v in inputs_qa.items()} |
|
|
|
with torch.no_grad(): |
|
qa_output_ids = model.generate( |
|
**inputs_qa, |
|
max_length=200, |
|
num_beams=3, |
|
do_sample=False, |
|
early_stopping=True |
|
) |
|
|
|
|
|
input_length = inputs_qa['input_ids'].shape[1] |
|
qa_response = processor.decode(qa_output_ids[0][input_length:], skip_special_tokens=True).strip() |
|
|
|
if qa_response and len(qa_response) > 20 and not qa_response.lower().startswith('question'): |
|
best_response = qa_response |
|
break |
|
|
|
except Exception as e: |
|
continue |
|
|
|
return basic_description, best_response |
|
|
|
except Exception as e: |
|
logger.error(f"Analysis failed: {e}") |
|
return "Unable to analyze image", "Analysis failed" |
|
|
|
def enhance_medical_description(basic_desc, clinical_question, patient_history): |
|
"""Enhance basic description with medical context and educational content""" |
|
|
|
|
|
chest_xray_analysis = """ |
|
**Systematic Chest X-ray Analysis:** |
|
|
|
**Technical Quality:** |
|
- Image appears to be a standard PA chest radiograph |
|
- Adequate penetration and positioning for diagnostic evaluation |
|
|
|
**Anatomical Review:** |
|
- **Heart**: Cardiac silhouette evaluation for size and contour |
|
- **Lungs**: Assessment of lung fields for opacity, consolidation, or air trapping |
|
- **Pleura**: Examination for pleural effusion or pneumothorax |
|
- **Bones**: Rib cage and spine alignment assessment |
|
- **Soft Tissues**: Evaluation of surrounding structures |
|
|
|
**Clinical Correlation Needed:** |
|
Given the patient's presentation with cough and fever, key considerations include: |
|
- **Pneumonia**: Look for consolidation, air bronchograms, or infiltrates |
|
- **Viral vs Bacterial**: Pattern recognition for different infectious etiologies |
|
- **Atelectasis**: Collapsed lung segments that might appear as increased opacity |
|
- **Pleural Changes**: Fluid collection that could indicate infection complications |
|
|
|
**Educational Points:** |
|
- Chest X-rays are the first-line imaging for respiratory symptoms |
|
- Clinical correlation is essential - symptoms guide interpretation |
|
- Follow-up imaging may be needed based on treatment response |
|
""" |
|
|
|
|
|
    if (any(term in basic_desc.lower() for term in ['chest', 'lung', 'rib', 'heart', 'x-ray', 'radiograph'])
            or any(term in clinical_question.lower() for term in ['chest', 'lung', 'respiratory', 'cough'])):
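        # Note: chest_xray_analysis is static educational scaffolding selected by keyword
        # matching on the caption/question; it is not derived from the image itself.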
|
enhanced_analysis = chest_xray_analysis |
|
else: |
|
|
|
enhanced_analysis = f""" |
|
**Medical Image Analysis Framework:** |
|
|
|
**Image Description:** |
|
{basic_desc} |
|
|
|
**Clinical Context Integration:** |
|
- Patient presentation: {patient_history if patient_history else 'Clinical history provided'} |
|
- Imaging indication: {clinical_question} |
|
|
|
**Systematic Approach:** |
|
1. **Technical Assessment**: Image quality and acquisition parameters |
|
2. **Anatomical Review**: Systematic evaluation of visible structures |
|
3. **Pathological Assessment**: Identification of any abnormal findings |
|
4. **Clinical Correlation**: Integration with patient symptoms and history |
|
|
|
**Educational Considerations:** |
|
- Medical imaging interpretation requires systematic approach |
|
- Clinical context significantly influences interpretation priorities |
|
- Multiple imaging modalities may be complementary for diagnosis |
|
- Professional radiological review is essential for clinical decisions |
|
""" |
|
|
|
return enhanced_analysis |
|
|
|
def analyze_medical_image(image, clinical_question, patient_history=""): |
|
"""Enhanced medical image analysis with better AI models""" |
|
start_time = time.time() |
|
|
|
|
|
if not rate_limiter.is_allowed(): |
|
usage_tracker.log_analysis(False, time.time() - start_time) |
|
        return "Rate limit exceeded. Please wait before trying again."
|
|
|
if not model_ready or model is None: |
|
usage_tracker.log_analysis(False, time.time() - start_time) |
|
        return "Medical AI model not loaded. Please refresh the page."
|
|
|
if image is None: |
|
        return "Please upload a medical image first."
|
|
|
if not clinical_question.strip(): |
|
        return "Please provide a clinical question."
|
|
|
try: |
|
logger.info("Starting enhanced medical image analysis...") |
|
|
|
|
|
basic_description, detailed_response = get_detailed_medical_analysis(image, clinical_question) |
|
|
|
|
|
enhanced_analysis = enhance_medical_description(basic_description, clinical_question, patient_history) |
|
|
|
|
|
        formatted_response = f"""# **Enhanced Medical AI Analysis**
|
|
|
## **Clinical Question:** {clinical_question} |
|
{f"## **Patient History:** {patient_history}" if patient_history.strip() else ""} |
|
|
|
--- |
|
|
|
## **AI Vision Analysis**
|
|
|
### **Image Description:** |
|
{basic_description} |
|
|
|
### **Question-Specific Analysis:** |
|
{detailed_response} |
|
|
|
--- |
|
|
|
## **Medical Assessment Framework**
|
{enhanced_analysis} |
|
|
|
--- |
|
|
|
## **Educational Summary**
|
|
|
**Learning Objectives:** |
|
- Demonstrate systematic approach to medical image interpretation |
|
- Integrate clinical history with imaging findings |
|
- Understand the importance of professional validation in medical diagnosis |
|
|
|
**Key Teaching Points:** |
|
- Medical imaging is one component of comprehensive patient assessment |
|
- Clinical correlation enhances diagnostic accuracy |
|
- Multiple imaging modalities may provide complementary information |
|
- Professional interpretation is essential for patient care decisions |
|
|
|
**Clinical Decision Making:** |
|
Based on the combination of: |
|
- Patient symptoms: {patient_history if patient_history else 'As provided'} |
|
- Imaging findings: As described above |
|
- Clinical context: {clinical_question} |
|
|
|
**Next Steps in Clinical Practice:** |
|
- Professional radiological review |
|
- Correlation with laboratory findings |
|
- Consider additional imaging if clinically indicated |
|
- Follow-up based on treatment response |
|
""" |
|
|
|
|
|
        disclaimer = f"""
---

## **IMPORTANT MEDICAL DISCLAIMER**

**FOR EDUCATIONAL AND RESEARCH PURPOSES ONLY**

- **AI Limitations**: AI analysis has significant limitations for medical diagnosis
- **Professional Review Required**: All findings must be validated by qualified healthcare professionals
- **Emergency Care**: For urgent medical concerns, seek immediate medical attention
- **Clinical Integration**: AI findings are educational tools, not diagnostic conclusions
- **Learning Tool**: Designed for medical education and training purposes
- **Privacy**: Do not upload real patient data or identifiable information

**This analysis demonstrates AI-assisted medical image interpretation concepts for educational purposes only.**

---

**Model**: {current_model_name} | **Device**: {device.upper()} | **Purpose**: Medical Education
"""
|
|
|
|
|
duration = time.time() - start_time |
|
question_type = classify_question(clinical_question) |
|
usage_tracker.log_analysis(True, duration, question_type) |
|
|
|
        logger.info(f"Enhanced medical analysis completed in {duration:.2f}s")
|
return formatted_response + disclaimer |
|
|
|
except Exception as e: |
|
duration = time.time() - start_time |
|
usage_tracker.log_analysis(False, duration) |
|
        logger.error(f"Analysis error: {str(e)}")
|
        return f"Enhanced analysis failed: {str(e)}\n\nPlease try again with a different image."
|
|
|
def classify_question(question): |
|
"""Classify clinical question type""" |
|
question_lower = question.lower() |
|
if any(word in question_lower for word in ['describe', 'findings', 'observe', 'see', 'show']): |
|
return 'descriptive' |
|
elif any(word in question_lower for word in ['diagnosis', 'differential', 'condition']): |
|
return 'diagnostic' |
|
elif any(word in question_lower for word in ['abnormal', 'pathology', 'disease']): |
|
return 'pathological' |
|
else: |
|
return 'general' |
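
# Illustrative classifications (comments only, not executed):
#   classify_question("Describe the findings on this scan")   -> 'descriptive'
#   classify_question("What is the differential diagnosis?")  -> 'diagnostic'
#   classify_question("Is there any pathology present?")      -> 'pathological'
#   classify_question("Explain this image")                   -> 'general'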
|
|
|
def get_usage_stats(): |
|
"""Get usage statistics""" |
|
stats = usage_tracker.stats |
|
if stats['total_analyses'] == 0: |
|
        return "**Usage Statistics**\n\nNo analyses performed yet."
|
|
|
success_rate = (stats['successful_analyses'] / stats['total_analyses']) * 100 |
|
|
|
    return f"""**Enhanced Medical AI Statistics**
|
|
|
**Performance Metrics:** |
|
- **Total Analyses**: {stats['total_analyses']} |
|
- **Success Rate**: {success_rate:.1f}% |
|
- **Average Processing Time**: {stats['average_processing_time']:.2f} seconds |
|
|
|
**Question Types:** |
|
{chr(10).join([f"- **{qtype.title()}**: {count}" for qtype, count in stats['question_types'].most_common(3)])} |
|
|
|
**System Status**: {'Enhanced Model Active' if model_ready else 'Offline'}
|
**Current Model**: {current_model_name if current_model_name else 'None'} |
|
**Device**: {device.upper()} |
|
""" |
|
|
|
def clear_all(): |
|
"""Clear all inputs and outputs""" |
|
return None, "", "", "" |
|
|
|
def set_chest_example(): |
|
"""Set chest X-ray example""" |
|
return "Describe this chest X-ray systematically and identify any abnormalities", "30-year-old patient with productive cough, fever, and shortness of breath" |
|
|
|
def set_pathology_example(): |
|
"""Set pathology example""" |
|
return "What pathological findings are visible? Describe the tissue characteristics.", "Biopsy specimen for histopathological evaluation" |
|
|
|
def set_general_example(): |
|
"""Set general analysis example""" |
|
return "Provide a systematic analysis of this medical image", "Patient requiring comprehensive imaging evaluation" |
|
|
|
|
|
def create_interface(): |
|
with gr.Blocks( |
|
title="Enhanced Medical AI Analysis", |
|
theme=gr.themes.Soft(), |
|
css=""" |
|
.gradio-container { max-width: 1400px !important; } |
|
.disclaimer { background-color: #fef2f2; border: 1px solid #fecaca; border-radius: 8px; padding: 16px; margin: 16px 0; } |
|
.success { background-color: #f0f9ff; border: 1px solid #bae6fd; border-radius: 8px; padding: 16px 0; } |
|
.enhanced { background-color: #f0fdf4; border: 1px solid #bbf7d0; border-radius: 8px; padding: 16px 0; } |
|
""" |
|
) as demo: |
|
|
|
|
|
gr.Markdown(""" |
|
# Enhanced Medical AI Image Analysis
|
|
|
**Advanced Medical AI with Better Vision Models - Educational Analysis** |
|
|
|
**Enhanced Features:** Multiple AI Models • Systematic Analysis • Educational Framework • Clinical Integration
|
""") |
|
|
|
|
|
if model_ready: |
|
gr.Markdown(f""" |
|
<div class="enhanced"> |
|
<strong>ENHANCED MEDICAL AI READY</strong><br>
|
Advanced model loaded: <strong>{current_model_name}</strong><br> |
|
Now provides detailed medical image analysis with systematic framework and educational content. |
|
</div> |
|
""") |
|
else: |
|
gr.Markdown(""" |
|
<div class="disclaimer"> |
|
<strong>MODEL LOADING</strong><br>
|
Enhanced Medical AI is loading. Please wait and refresh if needed. |
|
</div> |
|
""") |
|
|
|
|
|
gr.Markdown(""" |
|
<div class="disclaimer"> |
|
<strong>MEDICAL DISCLAIMER</strong><br>
|
This enhanced tool provides AI-assisted medical analysis for <strong>educational purposes only</strong>. |
|
Uses advanced vision models for better image understanding. Do not upload real patient data. |
|
</div> |
|
""") |
|
|
|
with gr.Row(): |
|
|
|
with gr.Column(scale=2): |
|
|
|
                gr.Markdown("## Medical Image Upload")
|
image_input = gr.Image( |
|
label="Upload Medical Image (Enhanced Analysis)", |
|
type="pil", |
|
height=300 |
|
) |
|
|
|
|
|
                gr.Markdown("## Clinical Information")
|
clinical_question = gr.Textbox( |
|
label="Clinical Question *", |
|
                    placeholder="Enhanced examples:\n• Systematically describe this chest X-ray and identify abnormalities\n• What pathological findings are visible in this image?\n• Provide detailed analysis of anatomical structures\n• Analyze this medical scan for educational purposes",
|
lines=4 |
|
) |
|
|
|
patient_history = gr.Textbox( |
|
label="Patient History & Clinical Context", |
|
                    placeholder="Detailed example: 35-year-old female with 3-day history of productive cough, fever (38.5°C), shortness of breath, and left-sided chest pain",
|
lines=3 |
|
) |
|
|
|
|
|
with gr.Row(): |
|
                    clear_btn = gr.Button("Clear All", variant="secondary")
|
                    analyze_btn = gr.Button("Enhanced Medical Analysis", variant="primary", size="lg")
|
|
|
|
|
                gr.Markdown("## Enhanced Medical Analysis Results")
|
output = gr.Textbox( |
|
label="Advanced AI Medical Analysis (Multiple Models)", |
|
lines=25, |
|
show_copy_button=True, |
|
placeholder="Upload a medical image and provide detailed clinical question for comprehensive AI analysis..." |
|
) |
|
|
|
|
|
with gr.Column(scale=1): |
|
                gr.Markdown("## Enhanced System Status")
|
|
|
                system_info = f"""
**Status**: {'Advanced Models Active' if model_ready else 'Loading'}

**Primary Model**: {current_model_name if current_model_name else 'Loading...'}

**Device**: {device.upper()}

**Enhancement**: Multiple AI Models

**Analysis**: Systematic Framework
"""
|
gr.Markdown(system_info) |
|
|
|
|
|
                gr.Markdown("## Usage Analytics")
|
stats_display = gr.Markdown(get_usage_stats()) |
|
                refresh_stats_btn = gr.Button("Refresh Stats", size="sm")
|
|
|
|
|
if model_ready: |
|
                    gr.Markdown("## Enhanced Examples")
|
                    chest_btn = gr.Button("Chest X-ray Analysis", size="sm")
|
                    pathology_btn = gr.Button("Pathology Study", size="sm")
|
                    general_btn = gr.Button("Systematic Analysis", size="sm")
|
|
|
                    gr.Markdown("## Enhancements")
|
                    gr.Markdown(f"""
- **Advanced Vision Models**
- **Systematic Medical Framework**
- **Educational Integration**
- **Clinical Context Analysis**
- **Model**: {current_model_name.split('/')[-1] if current_model_name else 'Enhanced'}
""")
|
|
|
|
|
analyze_btn.click( |
|
fn=analyze_medical_image, |
|
inputs=[image_input, clinical_question, patient_history], |
|
outputs=output, |
|
show_progress=True |
|
) |
|
|
|
clear_btn.click( |
|
fn=clear_all, |
|
outputs=[image_input, clinical_question, patient_history, output] |
|
) |
|
|
|
refresh_stats_btn.click( |
|
fn=get_usage_stats, |
|
outputs=stats_display |
|
) |
|
|
|
|
|
if model_ready: |
|
chest_btn.click( |
|
fn=set_chest_example, |
|
outputs=[clinical_question, patient_history] |
|
) |
|
|
|
pathology_btn.click( |
|
fn=set_pathology_example, |
|
outputs=[clinical_question, patient_history] |
|
) |
|
|
|
general_btn.click( |
|
fn=set_general_example, |
|
outputs=[clinical_question, patient_history] |
|
) |
|
|
|
|
|
gr.Markdown(f""" |
|
--- |
|
## **Enhanced Medical AI Features**
|
|
|
### **Advanced Vision Models:** |
|
- **Microsoft GIT**: Enhanced image-to-text capabilities |
|
- **BLIP2**: Advanced vision-language understanding |
|
- **Multi-Model Fallback**: Automatic best model selection |
|
- **Better Descriptions**: More detailed and accurate analysis |
|
|
|
### **Medical Framework Integration:** |
|
- **Systematic Analysis**: Structured medical image interpretation |
|
- **Clinical Correlation**: Integration of symptoms with imaging |
|
- **Educational Content**: Teaching points and learning objectives |
|
- **Professional Guidelines**: Follows medical education standards |
|
|
|
**Current Model**: {current_model_name if current_model_name else 'Loading...'} | **Purpose**: Enhanced Medical Education |
|
""") |
|
|
|
return demo |
|
|
|
|
|
if __name__ == "__main__": |
|
demo = create_interface() |
|
demo.launch( |
|
server_name="0.0.0.0", |
|
server_port=7860, |
|
show_error=True, |
|
share=False |
|
) |