# app.py - Hugging Face Space by Anupam251272
# Commit cb18ff3 ("Create app.py", verified)
import io
import logging
import os
import re

import torch
import torchvision
import torchvision.transforms as transforms
import transformers
import gradio as gr
import pandas as pd
import numpy as np
import cv2
from PIL import Image
from transformers import BertTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader
import networkx as nx
import mlflow
# Module-wide logging: INFO and above, one timestamped
# "time - logger name - level - message" line per record.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
class VisionRAGHealthcare:
    """Own the vision (DenseNet121) and text (BERT) encoders used by the app.

    Responsibilities: load the models, turn raw images / clinical text into
    model inputs, and extract one feature vector per image and per text.
    No retrieval or diagnosis happens in this class.
    """

    def __init__(self, device=None):
        """Select a compute device and load both models.

        Args:
            device: torch device string such as ``"cpu"`` or ``"cuda"``.
                When falsy, CUDA is used if available, otherwise CPU.
        """
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info("Using device: %s", self.device)
        self.initialize_models()

    def initialize_models(self):
        """Initialize vision and language models.

        Loads an ImageNet-pretrained DenseNet121 with its 224x224 input
        transform, plus the ``bert-base-uncased`` tokenizer and model. Both
        networks are moved to ``self.device`` and switched to eval mode.

        Raises:
            Exception: anything torchvision/transformers raise while loading
                (e.g. a failed weight download); logged with traceback, re-raised.
        """
        try:
            # DenseNet121 with generic ImageNet weights (not medically fine-tuned).
            self.vision_model = torchvision.models.densenet121(weights='IMAGENET1K_V1')
            self.vision_model.eval()
            self.vision_model = self.vision_model.to(self.device)

            # Resize + ImageNet mean/std normalization to match the pretrained weights.
            self.image_transforms = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])

            self.text_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            self.text_model = BertModel.from_pretrained('bert-base-uncased').to(self.device)
            # Explicit eval() so dropout is off regardless of loader defaults,
            # mirroring what is done for the vision model above.
            self.text_model.eval()

            logger.info("Models initialized successfully")
        except Exception as e:
            logger.exception("Error initializing models: %s", e)
            raise

    def preprocess_image(self, image):
        """Preprocess medical images into a normalized model input batch.

        Args:
            image: a numpy array, a file path (``str`` or ``os.PathLike``), or
                a ``PIL.Image.Image``. Any of them is converted to RGB.

        Returns:
            torch.Tensor: shape ``(1, 3, 224, 224)`` on ``self.device``.

        Raises:
            ValueError: for any other input type.
        """
        try:
            if isinstance(image, np.ndarray):
                pil_image = Image.fromarray(image).convert('RGB')
            elif isinstance(image, (str, os.PathLike)):
                # convert() returns an independent copy, so the file can be
                # closed deterministically as soon as pixels are read.
                with Image.open(image) as opened:
                    pil_image = opened.convert('RGB')
            elif isinstance(image, Image.Image):
                pil_image = image.convert('RGB')
            else:
                raise ValueError(f"Unsupported image type: {type(image)}")

            image_tensor = self.image_transforms(pil_image).unsqueeze(0)
            return image_tensor.to(self.device)
        except Exception as e:
            logger.exception("Error preprocessing image: %s", e)
            raise

    def preprocess_text(self, text):
        """Tokenize medical text for BERT.

        Args:
            text: the clinical text passed straight to the tokenizer; it is
                padded and truncated to at most 512 tokens.

        Returns:
            dict[str, torch.Tensor]: tokenizer outputs moved to ``self.device``.
        """
        try:
            inputs = self.text_tokenizer(
                text,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            return {name: tensor.to(self.device) for name, tensor in inputs.items()}
        except Exception as e:
            logger.exception("Error preprocessing text: %s", e)
            raise

    def extract_features(self, image_tensor, text_inputs):
        """Extract one feature vector per image and per text sequence.

        Args:
            image_tensor: batch from ``preprocess_image``.
            text_inputs: mapping from ``preprocess_text``.

        Returns:
            tuple: ``(image_features, text_features)``. ``image_features`` is
            the global-average-pooled output of ``vision_model.features``,
            flattened to ``(batch, channels)``. ``text_features`` is the
            first-token (``[CLS]``) slice of BERT's ``last_hidden_state``.
        """
        try:
            with torch.no_grad():
                # NOTE(review): torchvision's DenseNet.forward applies a ReLU
                # before pooling; it is skipped here, so values can be
                # negative. The downstream heuristics rely on that, so keep it.
                feature_maps = self.vision_model.features(image_tensor)
                pooled = torch.nn.functional.adaptive_avg_pool2d(feature_maps, (1, 1))
                image_features = torch.flatten(pooled, 1)

                text_features = self.text_model(**text_inputs).last_hidden_state[:, 0, :]
            return image_features, text_features
        except Exception as e:
            logger.exception("Error extracting features: %s", e)
            raise
class HealthcareGradioApp:
    """Gradio front-end that turns a medical image plus clinical notes into a report.

    The analysis is rule-based. It matches keywords in the notes and applies
    simple statistics to the image feature vector produced by
    ``VisionRAGHealthcare``.
    """

    def __init__(self):
        """Create the app and load the underlying vision/text models."""
        self.vision_rag = VisionRAGHealthcare()

    @staticmethod
    def _mentions(text, terms):
        """Return True if any of ``terms`` occurs in ``text`` at the start of a word.

        Anchoring on a leading word boundary keeps "febrile" from matching
        "afebrile", while "cough" still matches "coughing"/"coughs".
        """
        return any(re.search(r"\b" + re.escape(term), text) for term in terms)

    def analyze_symptoms(self, clinical_notes):
        """Analyze clinical symptoms and provide relevant insights.

        Args:
            clinical_notes: free-text notes; ``None`` or empty yields no findings.

        Returns:
            list[str]: finding sentences in a fixed order: cough, fever,
            breathing, chest pain, duration, mild severity, moderate/severe.

        NOTE(review): matching is purely lexical, so negated mentions such as
        "no fever" or "denies cough" are still reported as findings.
        """
        symptoms = {
            'cough': ['cough', 'persistent cough', 'coughing', 'dry cough', 'wet cough'],
            'fever': ['fever', 'high temperature', 'febrile'],
            'breathing': ['shortness of breath', 'difficulty breathing', 'dyspnea'],
            'chest_pain': ['chest pain', 'chest discomfort', 'chest tightness'],
            'duration': ['days', 'weeks', 'months', 'chronic'],
            'severity_mild': ['mild'],
            'severity_high': ['moderate', 'severe', 'intense'],
        }
        checks = (
            ('cough', "Cough detected in symptoms"),
            ('fever', "Fever reported"),
            ('breathing', "Breathing difficulties noted"),
            ('chest_pain', "Chest discomfort/pain reported"),
            ('duration', "Chronic/ongoing symptoms noted"),
            # "mild" must not be reported as a moderate-to-severe presentation.
            ('severity_mild', "Mild symptoms indicated"),
            ('severity_high', "Moderate to severe symptoms indicated"),
        )
        notes = (clinical_notes or "").lower()
        return [finding for key, finding in checks if self._mentions(notes, symptoms[key])]

    def analyze_image_features(self, image_features):
        """Map summary statistics of the image feature tensor to textual findings.

        Args:
            image_features: torch tensor from
                ``VisionRAGHealthcare.extract_features``. All elements are
                pooled into a single mean and standard deviation.

        Returns:
            list[str]: zero or more findings, in a fixed order.

        NOTE(review): the thresholds are applied to generic ImageNet DenseNet
        features and are not calibrated against any medical ground truth.
        """
        values = image_features.detach().cpu().numpy()
        feature_mean = np.mean(values)
        feature_std = np.std(values)

        image_findings = []
        if feature_std > 0.5:
            image_findings.append("Notable variations detected in lung tissue density")
        if feature_mean > 0:
            image_findings.append("Areas of increased opacity observed")
        if feature_mean < -0.5:
            image_findings.append("Possible areas of decreased density noted")
        return image_findings

    def suggest_recommendations(self, symptom_findings, image_findings):
        """Generate medical recommendations based on findings.

        Args:
            symptom_findings: sentences produced by ``analyze_symptoms``.
            image_findings: sentences produced by ``analyze_image_features``.

        Returns:
            list[str]: "- "-prefixed lines. Symptom-driven items come first,
            then image-driven items, then two general items that are always
            present.
        """
        recommendations = []

        # Symptom-based recommendations
        if "Cough detected in symptoms" in symptom_findings:
            recommendations.append("- Monitor cough characteristics and duration")
            recommendations.append("- Consider pulmonary function testing")
        if "Fever reported" in symptom_findings:
            recommendations.append("- Regular temperature monitoring recommended")
            recommendations.append("- Consider complete blood count and inflammatory markers")
        if "Breathing difficulties noted" in symptom_findings:
            recommendations.append("- Pulse oximetry monitoring advised")
            recommendations.append("- Consider spirometry assessment")

        # Image-based recommendations
        if "Notable variations detected in lung tissue density" in image_findings:
            recommendations.append("- Follow-up imaging recommended in 2-4 weeks")
        if "Areas of increased opacity observed" in image_findings:
            recommendations.append("- Consider high-resolution CT scan")

        # General recommendations
        recommendations.append("- Regular monitoring of vital signs")
        recommendations.append("- Follow-up with primary care physician")
        return recommendations

    @staticmethod
    def _format_report(symptom_findings, image_findings, recommendations):
        """Render findings and recommendations as the plain-text report."""

        def bullets(items):
            # Without a placeholder an empty section renders as a blank line.
            return "\n".join(f"• {item}" for item in items) or "• No specific findings"

        lines = [
            "🏥 Medical Case Analysis Report",
            "============================",
            "",
            "📋 Clinical Symptoms Analysis:",
            "---------------------------",
            bullets(symptom_findings),
            "",
            "🔍 Imaging Analysis Findings:",
            "-------------------------",
            bullets(image_findings),
            "",
            "💡 Recommendations:",
            "----------------",
            "\n".join(recommendations),  # items already carry a "- " prefix
            "",
            "⚠️ Important Notice:",
            "-----------------",
            "This is an AI-assisted analysis and should be reviewed by a healthcare professional.",
            "Always consult with a medical doctor for proper diagnosis and treatment.",
        ]
        return "\n".join(lines)

    def process_medical_case(self, image, clinical_notes):
        """Process a medical case with both image and text data.

        Args:
            image: image in any form accepted by
                ``VisionRAGHealthcare.preprocess_image``.
            clinical_notes: free-text clinical notes.

        Returns:
            str: the report. On missing input or any failure it returns an
            error message instead; this method never raises.
        """
        try:
            if image is None:
                return "Error: Please upload a medical image for analysis."
            if not clinical_notes or not clinical_notes.strip():
                return "Error: Please provide clinical notes for analysis."

            image_tensor = self.vision_rag.preprocess_image(image)
            text_inputs = self.vision_rag.preprocess_text(clinical_notes)
            # Text embeddings are computed but not yet used by any analysis step.
            image_features, _text_features = self.vision_rag.extract_features(
                image_tensor, text_inputs
            )

            symptom_findings = self.analyze_symptoms(clinical_notes)
            image_findings = self.analyze_image_features(image_features)
            recommendations = self.suggest_recommendations(symptom_findings, image_findings)
            return self._format_report(symptom_findings, image_findings, recommendations)
        except Exception as e:
            error_msg = f"Error in analysis: {str(e)}"
            # Keep the traceback in the server log; the UI only gets the message.
            logger.exception(error_msg)
            return error_msg

    def launch_gradio_interface(self):
        """Build (but do not launch) the Gradio interface for this app.

        Returns:
            gr.Interface: wired to ``process_medical_case``. The caller is
            responsible for calling ``.launch()``.
        """
        # NOTE(review): the ".image-upload" class is not attached to any
        # component here (no elem_classes), so that rule currently has no effect.
        custom_css = """
.gradio-container {max-width: 900px !important}
.image-upload {min-height: 400px !important}
"""
        iface = gr.Interface(
            fn=self.process_medical_case,
            inputs=[
                gr.Image(
                    type="numpy",
                    label="Upload Medical Image",
                ),
                gr.Textbox(
                    lines=5,
                    label="Clinical Notes",
                    placeholder="Enter detailed patient symptoms, medical history, and current condition...",
                    info="Include: symptoms, duration, severity, and relevant medical history",
                ),
            ],
            outputs=gr.Textbox(
                label="Medical Analysis Report",
                lines=15,
            ),
            title="🏥 Healthcare Vision-RAG AI System",
            # Markdown strings are built without code indentation, which could
            # otherwise render as a code block.
            description=(
                "Upload medical images and clinical notes for AI-powered analysis.\n"
                "This system combines visual and textual analysis for comprehensive medical insights."
            ),
            article=(
                "### How to use:\n"
                "1. Upload a medical image (X-ray, CT scan)\n"
                "2. Enter detailed clinical notes\n"
                "3. Review the AI-generated analysis report\n"
                "\n"
                "### Note:\n"
                "This is an AI assistance tool and should be used in conjunction with professional medical judgment.\n"
            ),
            theme="default",
            css=custom_css,
            # NOTE(review): newer Gradio releases rename this to flagging_mode;
            # confirm against the pinned Gradio version.
            allow_flagging="never",
        )
        return iface
def main():
    """Configure MLflow tracking, build the Gradio app and launch it.

    Any startup error is logged with its traceback and then re-raised.
    """
    # MLflow tracking: a local SQLite store (mlflow.db) in the working directory.
    mlflow.set_tracking_uri("sqlite:///mlflow.db")
    mlflow.set_experiment("vision-rag-healthcare")
    try:
        print("🚀 Initializing Healthcare Vision-RAG System...")
        app = HealthcareGradioApp()
        print("📊 Launching Gradio interface...")
        iface = app.launch_gradio_interface()
        iface.launch(
            share=True,  # request a public share link
            debug=True,  # keep the main thread attached so errors are printed
            show_error=True,
        )
    except Exception as e:
        logger.exception("Error launching application: %s", e)
        raise


if __name__ == "__main__":
    main()