Spaces:
Sleeping
Sleeping
import torch | |
import torchvision | |
import torchvision.transforms as transforms | |
import transformers | |
import gradio as gr | |
import pandas as pd | |
import numpy as np | |
import cv2 | |
from PIL import Image | |
import io | |
from transformers import BertTokenizer, BertModel | |
from torch.utils.data import Dataset, DataLoader | |
import networkx as nx | |
import logging | |
import mlflow | |
import os | |
# Configure root logging at import time: INFO level with timestamped lines.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by the classes and main() in this file.
logger = logging.getLogger(__name__)
class VisionRAGHealthcare:
    """Feature extractor pairing a DenseNet121 image encoder with a BERT text encoder.

    Both models are loaded once in ``__init__`` and moved to ``self.device``.
    The class only preprocesses inputs and produces embeddings; it performs no
    retrieval or classification itself.
    """

    def __init__(self, device=None):
        """Select a device and load both models.

        Args:
            device: Torch device string such as ``'cpu'`` or ``'cuda'``. When
                falsy, ``'cuda'`` is used if available, otherwise ``'cpu'``.

        Raises:
            Exception: propagated from ``initialize_models`` if loading fails.
        """
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {self.device}")
        self.initialize_models()

    def initialize_models(self):
        """Load the vision model, image transform, tokenizer and text model.

        Sets ``vision_model`` (ImageNet-pretrained DenseNet121),
        ``image_transforms``, ``text_tokenizer`` and ``text_model``
        (``bert-base-uncased``). Pretrained weights are fetched by
        torchvision / transformers and may require network access on first
        use — TODO confirm caching in the deployment environment.

        Raises:
            Exception: any loading error is logged with its traceback and re-raised.
        """
        try:
            # DenseNet121 with ImageNet weights, used for inference only.
            vision_model = torchvision.models.densenet121(weights='IMAGENET1K_V1')
            self.vision_model = vision_model.eval().to(self.device)

            # Resize to 224x224 and normalize with the standard ImageNet
            # channel mean/std that match the pretrained weights above.
            self.image_transforms = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])

            self.text_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            self.text_model = BertModel.from_pretrained('bert-base-uncased').to(self.device)
            # Make inference mode explicit (disables dropout), matching the vision model.
            self.text_model.eval()
            logger.info("Models initialized successfully")
        except Exception as e:
            # logger.exception records the traceback as well as the message.
            logger.exception("Error initializing models: %s", e)
            raise

    def preprocess_image(self, image):
        """Convert an image input into a normalized (1, 3, 224, 224) tensor on ``self.device``.

        Args:
            image: a NumPy array (e.g. from ``gr.Image(type="numpy")``), a
                filesystem path (``str`` or ``os.PathLike``), or a PIL image.
                Any input is converted to 3-channel RGB first.

        Returns:
            torch.Tensor: batch of one transformed image.

        Raises:
            ValueError: if ``image`` is of an unsupported type.
        """
        try:
            if isinstance(image, np.ndarray):
                pil_image = Image.fromarray(image).convert('RGB')
            elif isinstance(image, (str, os.PathLike)):
                # The context manager releases the file handle deterministically,
                # including for multi-frame formats.
                with Image.open(image) as opened:
                    pil_image = opened.convert('RGB')
            elif isinstance(image, Image.Image):
                pil_image = image.convert('RGB')
            else:
                raise ValueError(f"Unsupported image type: {type(image)}")
            image_tensor = self.image_transforms(pil_image).unsqueeze(0)
            return image_tensor.to(self.device)
        except Exception as e:
            logger.exception("Error preprocessing image: %s", e)
            raise

    def preprocess_text(self, text):
        """Tokenize clinical text for BERT and move the tensors to ``self.device``.

        Args:
            text: a string (or list of strings) to tokenize; sequences are
                truncated to 512 tokens.

        Returns:
            dict[str, torch.Tensor]: tokenizer outputs (e.g. ``input_ids``,
            ``attention_mask``) on ``self.device``.
        """
        try:
            inputs = self.text_tokenizer(
                text,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            return {name: tensor.to(self.device) for name, tensor in inputs.items()}
        except Exception as e:
            logger.exception("Error preprocessing text: %s", e)
            raise

    def extract_features(self, image_tensor, text_inputs):
        """Compute pooled image features and first-token text features without gradients.

        Args:
            image_tensor: output of ``preprocess_image``.
            text_inputs: dict of tensors from ``preprocess_text``.

        Returns:
            tuple: ``(image_features, text_features)``. ``image_features`` is the
            DenseNet ``features`` map, global-average-pooled and flattened to
            (batch, channels). ``text_features`` is BERT's last hidden state at
            position 0 (the [CLS] token), shape (batch, hidden).
        """
        try:
            with torch.no_grad():
                feature_map = self.vision_model.features(image_tensor)
                # NOTE(review): torchvision's own DenseNet.forward applies a ReLU
                # before pooling; this path pools the raw feature map, so values
                # can be negative. Kept as-is because downstream thresholds may
                # rely on it — confirm intended.
                pooled = torch.nn.functional.adaptive_avg_pool2d(feature_map, (1, 1))
                image_features = torch.flatten(pooled, 1)
                text_features = self.text_model(**text_inputs).last_hidden_state[:, 0, :]
            return image_features, text_features
        except Exception as e:
            logger.exception("Error extracting features: %s", e)
            raise
class HealthcareGradioApp:
    """Gradio app that turns a medical image plus clinical notes into a text report.

    Findings come from keyword rules over the notes and from threshold rules
    over summary statistics of the image embedding. These are heuristics, not
    a validated clinical model, and the generated report says so.
    """

    # (message, terms) pairs checked in order by ``analyze_symptoms``. Terms are
    # lower-case and match only at the start of a word (see ``_mentions_any``),
    # so 'cough' also covers 'coughing', 'coughs', 'dry cough', 'wet cough', etc.
    _SYMPTOM_RULES = (
        ("Cough detected in symptoms", ('cough',)),
        ("Fever reported", ('fever', 'high temperature', 'febrile')),
        ("Breathing difficulties noted", ('shortness of breath', 'difficulty breathing', 'dyspnea')),
        ("Chest discomfort/pain reported", ('chest pain', 'chest discomfort', 'chest tightness')),
        ("Chronic/ongoing symptoms noted", ('days', 'weeks', 'months', 'chronic')),
    )
    # Severity is reported once, at the strongest level mentioned.
    _HIGH_SEVERITY_TERMS = ('moderate', 'severe', 'intense')
    _MILD_SEVERITY_TERMS = ('mild',)
    # Placeholder shown when a report section would otherwise be empty.
    _NO_FINDINGS = "No specific findings identified"

    def __init__(self):
        """Create the underlying image/text feature extractor (loads its models)."""
        self.vision_rag = VisionRAGHealthcare()

    @staticmethod
    def _mentions_any(text, terms):
        """Return True if any of *terms* starts a word in the lower-cased *text*.

        A term must not be preceded by a letter, so 'febrile' does not fire on
        'afebrile' and 'severe' does not fire on 'persevere'; a preceding digit
        is allowed, so '3days' still matches 'days'. Negations such as
        'no fever' are NOT detected.
        """
        import re  # local import keeps this class self-contained
        return any(re.search(r'(?<![a-z])' + re.escape(term), text) for term in terms)

    def analyze_symptoms(self, clinical_notes):
        """Return keyword-based findings extracted from clinical notes.

        Args:
            clinical_notes: free-text notes; ``None`` is treated as empty.

        Returns:
            list[str]: finding messages in a fixed order (cough, fever,
            breathing, chest pain, duration, then at most one severity message).
        """
        notes = (clinical_notes or "").lower()
        findings = [message for message, terms in self._SYMPTOM_RULES
                    if self._mentions_any(notes, terms)]
        if self._mentions_any(notes, self._HIGH_SEVERITY_TERMS):
            findings.append("Moderate to severe symptoms indicated")
        elif self._mentions_any(notes, self._MILD_SEVERITY_TERMS):
            findings.append("Mild symptoms indicated")
        return findings

    def analyze_image_features(self, image_features):
        """Map summary statistics of an image embedding to textual findings.

        NOTE(review): the thresholds are applied to features from an
        ImageNet-pretrained network, not a model trained on medical images, so
        the phrases below are heuristic labels, not validated radiology findings.

        Args:
            image_features: torch tensor of image features; mean and standard
                deviation are taken over all of its elements.

        Returns:
            list[str]: zero or more finding messages.
        """
        feature_array = image_features.cpu().numpy()
        feature_mean = np.mean(feature_array)
        feature_std = np.std(feature_array)

        image_findings = []
        if feature_std > 0.5:
            image_findings.append("Notable variations detected in lung tissue density")
        # The two mean thresholds are mutually exclusive.
        if feature_mean > 0:
            image_findings.append("Areas of increased opacity observed")
        elif feature_mean < -0.5:
            image_findings.append("Possible areas of decreased density noted")
        return image_findings

    def suggest_recommendations(self, symptom_findings, image_findings):
        """Build recommendation lines (each prefixed with '- ') from the findings.

        Args:
            symptom_findings: messages returned by ``analyze_symptoms``.
            image_findings: messages returned by ``analyze_image_features``.

        Returns:
            list[str]: finding-specific recommendations followed by two general
            ones that are always present.
        """
        recommendations = []

        # Symptom-based recommendations
        if "Cough detected in symptoms" in symptom_findings:
            recommendations.append("- Monitor cough characteristics and duration")
            recommendations.append("- Consider pulmonary function testing")
        if "Fever reported" in symptom_findings:
            recommendations.append("- Regular temperature monitoring recommended")
            recommendations.append("- Consider complete blood count and inflammatory markers")
        if "Breathing difficulties noted" in symptom_findings:
            recommendations.append("- Pulse oximetry monitoring advised")
            recommendations.append("- Consider spirometry assessment")

        # Image-based recommendations
        if "Notable variations detected in lung tissue density" in image_findings:
            recommendations.append("- Follow-up imaging recommended in 2-4 weeks")
        if "Areas of increased opacity observed" in image_findings:
            recommendations.append("- Consider high-resolution CT scan")

        # General recommendations
        recommendations.append("- Regular monitoring of vital signs")
        recommendations.append("- Follow-up with primary care physician")
        return recommendations

    @classmethod
    def _format_report(cls, symptom_findings, image_findings, recommendations):
        """Render the plain-text report shown in the Gradio output box.

        Empty sections get a placeholder line instead of being left blank.
        """
        def bullets(items):
            return "\n".join(f"• {item}" for item in (items or [cls._NO_FINDINGS]))

        recommendation_lines = "\n".join(recommendations) or f"- {cls._NO_FINDINGS}"
        return (
            "\n"
            "🏥 Medical Case Analysis Report\n"
            "============================\n"
            "\n"
            "📋 Clinical Symptoms Analysis:\n"
            "---------------------------\n"
            f"{bullets(symptom_findings)}\n"
            "\n"
            "🔍 Imaging Analysis Findings:\n"
            "-------------------------\n"
            f"{bullets(image_findings)}\n"
            "\n"
            "💡 Recommendations:\n"
            "----------------\n"
            f"{recommendation_lines}\n"
            "\n"
            "⚠️ Important Notice:\n"
            "-----------------\n"
            "This is an AI-assisted analysis and should be reviewed by a healthcare professional.\n"
            "Always consult with a medical doctor for proper diagnosis and treatment.\n"
        )

    def process_medical_case(self, image, clinical_notes):
        """Run the full image + notes pipeline and return a text report.

        Args:
            image: image input accepted by ``VisionRAGHealthcare.preprocess_image``.
            clinical_notes: free-text clinical notes.

        Returns:
            str: the formatted report, or a message starting with "Error:" when
            an input is missing or any step raises (the exception is logged, not
            propagated, so the UI always receives text).
        """
        try:
            if image is None:
                return "Error: Please upload a medical image for analysis."
            if not clinical_notes or not clinical_notes.strip():
                return "Error: Please provide clinical notes for analysis."

            image_tensor = self.vision_rag.preprocess_image(image)
            text_inputs = self.vision_rag.preprocess_text(clinical_notes)
            # The text embedding is computed but not yet used by any rule below.
            image_features, _text_features = self.vision_rag.extract_features(
                image_tensor, text_inputs
            )

            symptom_findings = self.analyze_symptoms(clinical_notes)
            image_findings = self.analyze_image_features(image_features)
            recommendations = self.suggest_recommendations(symptom_findings, image_findings)
            return self._format_report(symptom_findings, image_findings, recommendations)
        except Exception as e:
            error_msg = f"Error in analysis: {str(e)}"
            logger.exception(error_msg)
            return error_msg

    def launch_gradio_interface(self):
        """Build and return the Gradio ``Interface`` for this app (does not launch it).

        Returns:
            gr.Interface: wired to ``process_medical_case`` with an image input
            delivered as a NumPy array and a clinical-notes textbox.
        """
        # Custom CSS for better styling
        custom_css = """
        .gradio-container {max-width: 900px !important}
        .image-upload {min-height: 400px !important}
        """
        return gr.Interface(
            fn=self.process_medical_case,
            inputs=[
                # Deliberately no 'info' argument on gr.Image (an earlier revision removed it).
                gr.Image(type="numpy", label="Upload Medical Image"),
                gr.Textbox(
                    lines=5,
                    label="Clinical Notes",
                    placeholder="Enter detailed patient symptoms, medical history, and current condition...",
                    info="Include: symptoms, duration, severity, and relevant medical history",
                ),
            ],
            outputs=gr.Textbox(label="Medical Analysis Report", lines=15),
            title="🏥 Healthcare Vision-RAG AI System",
            description=(
                "\n"
                "Upload medical images and clinical notes for AI-powered analysis.\n"
                "This system combines visual and textual analysis for comprehensive medical insights.\n"
            ),
            article=(
                "\n"
                "### How to use:\n"
                "1. Upload a medical image (X-ray, CT scan)\n"
                "2. Enter detailed clinical notes\n"
                "3. Review the AI-generated analysis report\n"
                "\n"
                "### Note:\n"
                "This is an AI assistance tool and should be used in conjunction with professional medical judgment.\n"
            ),
            theme="default",
            css=custom_css,
            # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
            # confirm against the pinned Gradio version.
            allow_flagging="never",
        )
def main():
    """Configure MLflow tracking, build the Gradio app, and launch it.

    Raises:
        Exception: any startup error is logged with its traceback and re-raised.
    """
    # MLflow tracking against a local SQLite store.
    # NOTE(review): no MLflow run is started or logged to in the visible code —
    # confirm whether tracking is still needed.
    mlflow.set_tracking_uri("sqlite:///mlflow.db")
    mlflow.set_experiment("vision-rag-healthcare")
    try:
        print("🚀 Initializing Healthcare Vision-RAG System...")
        app = HealthcareGradioApp()
        print("🌐 Launching Gradio interface...")
        iface = app.launch_gradio_interface()
        iface.launch(
            share=True,
            debug=True,
            # 'enable_queue' is intentionally omitted: it is deprecated in Gradio.
            show_error=True,
        )
    except Exception as e:
        logger.exception("Error launching application: %s", e)
        raise


if __name__ == "__main__":
    main()