# DR_Classification/pages/Upload_and_Predict.py
# Last commit: 7fc8c61 ("fixed") by 3v324v23
import streamlit as st
import torch
from torchvision import transforms, models
from PIL import Image
import numpy as np
import pandas as pd
from collections import defaultdict
import os
from datasets import load_dataset
# Title
st.markdown("<h2 style='color: #2E86C1;'>📷 Upload & Predict</h2>", unsafe_allow_html=True)
st.markdown("""
### 📖 About This Feature: Upload & Predict
This section of the **DR Assistive Tool** lets users upload retinal images and get an AI-based prediction of the **Diabetic Retinopathy stage**. It uses a **DenseNet-121** model fine-tuned to detect DR severity levels.

The model classifies the uploaded image into one of five classes:
- **No DR**
- **Mild**
- **Moderate**
- **Severe**
- **Proliferative DR**

This is especially helpful for:
- Students learning about AI in healthcare
- Researchers testing model robustness
- Clinicians exploring AI-assisted screening tools

The tool also shows **sample images from the test set** for each class. You can use these images to test the model's performance and to see what the different DR stages look like.

---
### 🧭 How to Use:
1. 🔍 **View sample images** from the test set, grouped by DR stage.
   - Click the **"🔍 Predict"** button under a sample image to see how the model classifies it.
2. 📁 **Upload your own retinal image** (JPG or PNG) using the file uploader.
3. 🧠 Click the **"🧠 Predict"** button after uploading.
   - The model analyzes the image and displays:
     - 🎯 **Predicted DR stage**
     - 📊 **Model confidence score (in %)**

⚠️ *For best results, use a clear, centered fundus photograph.*

---
### 🛠 Behind the Scenes:
- ✅ Model: fine-tuned **DenseNet-121** with a 5-class output layer
- 🖼 Input size: images are resized so the shorter side is 256 pixels, then center-cropped to 224×224 pixels
- 🔄 Normalization: per-channel mean and standard deviation from ImageNet pretraining
- 📦 Output: the highest-probability class among the 5 DR categories, with its **softmax** probability reported as the confidence

*This tool is for educational and research purposes only. It is not for clinical use.*
""", unsafe_allow_html=True)
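
# Sketch (not used by the app): the preprocessing `transform` defined further
# down does not resize straight to 224x224. transforms.Resize(256) scales the
# shorter side to 256 px while keeping the aspect ratio, and
# transforms.CenterCrop(224) then cuts a centered 224x224 patch. The
# hypothetical helper below reproduces that geometry in pure Python; it follows
# the size and offset formulas torchvision uses as far as we know (int()
# truncation for the long side, round() for the crop offsets), so treat exact
# off-by-one behavior as an assumption rather than a guarantee.
def resize_then_center_crop_box(width, height, resize_to=256, crop=224):
    """Return ((new_w, new_h), (left, top, right, bottom)) for Resize + CenterCrop."""
    # Scale so the shorter side equals `resize_to`, preserving the aspect ratio.
    if width <= height:
        new_w, new_h = resize_to, int(resize_to * height / width)
    else:
        new_w, new_h = int(resize_to * width / height), resize_to
    # Center the crop window inside the resized image.
    left = int(round((new_w - crop) / 2.0))
    top = int(round((new_h - crop) / 2.0))
    return (new_w, new_h), (left, top, left + crop, top + crop)

# Example: a 480x640 portrait image becomes 256x341 after Resize(256), and the
# 224x224 crop starts 16 px from the left and 58 px from the top.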
# DR class names
class_names = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative DR']
@st.cache_data
def load_sample_images_from_csv():
    """Return up to five test-set image paths per DR class, keyed by class name."""
    csv_url = "https://huggingface.co/datasets/Ci-Dave/DDR_dataset_train_test/raw/main/splits/test_labels.csv"
    df = pd.read_csv(csv_url)
    samples = defaultdict(list)
    for i, class_name in enumerate(class_names):
        # Take the first five rows whose integer label matches this class index.
        class_samples = df[df['label'] == i].head(5)
        for _, row in class_samples.iterrows():
            img_path = row['new_path']
            if os.path.exists(img_path):  # works only if the images are available locally
                samples[class_name].append(img_path)
    return samples
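
# Sketch (not used by the app): a pandas-free version of the grouping done in
# load_sample_images_from_csv(), handy for checking the logic offline. It
# assumes the same CSV columns that function reads ('label' holding an integer
# class index, 'new_path' holding the image path) and keeps at most
# `per_class` rows per label in file order, like df[df['label'] == i].head(5).
# The os.path.exists() filter is deliberately left out so it runs without local
# images. The helper name group_first_n_per_class is hypothetical.
import csv
import io
from collections import defaultdict


def group_first_n_per_class(csv_text, names, per_class=5):
    """Map each class name to the first `per_class` image paths carrying that label."""
    grouped = defaultdict(list)
    for row in csv.DictReader(io.StringIO(csv_text)):
        label = int(row["label"])
        # Ignore labels outside the known class range.
        if 0 <= label < len(names):
            bucket = grouped[names[label]]
            if len(bucket) < per_class:
                bucket.append(row["new_path"])
    return grouped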
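
# Load the fine-tuned model once; st.cache_resource shares it across reruns and sessions
@st.cache_resource
def load_model():
    # Build the architecture without downloading ImageNet weights
    # (weights=None replaces the deprecated pretrained=False in torchvision >= 0.13).
    model = models.densenet121(weights=None)
    # Replace the 1000-class ImageNet head with a head for the 5 DR classes.
    model.classifier = torch.nn.Linear(model.classifier.in_features, len(class_names))
    model.load_state_dict(torch.load("./Model/Pretrained_Densenet-121.pth", map_location='cpu'))
    model.eval()  # inference mode: fixes batch-norm statistics and disables dropout
    return model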
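
# Preprocessing: resize the shorter side to 256 px, center-crop to 224x224,
# convert to a [0, 1] tensor, and normalize with ImageNet channel statistics
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])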
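
# Sketch (not used by the app): the arithmetic that ToTensor() followed by
# Normalize(mean, std) applies to a single RGB pixel. ToTensor() scales 8-bit
# values to [0, 1]; Normalize then computes (x - mean[c]) / std[c] for each
# channel c, using the ImageNet statistics passed to Normalize above. The
# constant and helper names here are hypothetical.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_rgb_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Return the normalized (R, G, B) values the model sees for one 0-255 pixel."""
    return tuple((value / 255.0 - m) / s for value, m, s in zip(rgb, mean, std))

# Example: a pure white pixel (255, 255, 255) maps to roughly (2.249, 2.429, 2.640),
# and a pixel equal to the ImageNet mean maps to roughly (0, 0, 0).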

# Prediction: return the predicted class name and its softmax confidence in percent
def predict_image(model, image):
    img_tensor = transform(image).unsqueeze(0)  # add a batch dimension: (1, 3, 224, 224)
    with torch.no_grad():
        outputs = model(img_tensor)
        probs = torch.nn.functional.softmax(outputs, dim=1)
        confidence, pred = torch.max(probs, dim=1)
    return class_names[pred.item()], confidence.item() * 100
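
# Sketch (not used by the app): what predict_image() reports, written in pure
# Python. The confidence is the softmax probability of the highest-scoring
# class, scaled to a percentage. Subtracting the largest logit before exp() is
# the usual numerical-stability trick: it leaves the probabilities unchanged but
# avoids overflow for large logits. The helper name softmax_top1 is hypothetical.
import math


def softmax_top1(logits, names):
    """Return (class_name, confidence_percent) for a list of raw logits."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    # Index of the largest logit (equivalently, of the largest probability).
    best = max(range(len(logits)), key=lambda i: logits[i])
    return names[best], exps[best] / total * 100

# Example: logits [0.1, 0.2, 2.5, 0.3, -1.0] give the third class with about 75% confidence.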

# Create two tabs to keep the sample-image and upload features separate
tab1, tab2 = st.tabs(["🧪 Sample Images", "📤 Upload & Predict"])

with tab1:
    st.markdown("### 🧪 Sample Images from Test Set")
    st.markdown("""
#### 📖 About This Feature: Sample Images
In this tab, you can explore sample retinal images from the test set, grouped by their **Diabetic Retinopathy (DR)** stage. This helps you:
- Understand the **visual differences** between DR stages
- Test the model's performance on known data
- Get familiar with the model's prediction behavior

#### 🧭 How to Use:
1. Browse the sample images under each DR class.
2. Click **🔍 Predict** under an image to let the AI model analyze it.
3. The result will show:
   - 🎯 **Predicted DR stage**
   - 📊 **Confidence score**

> *Ideal for researchers and students testing the model with known data.*
""", unsafe_allow_html=True)

    sample_images = load_sample_images_from_csv()
    for class_name in class_names:
        if class_name in sample_images and sample_images[class_name]:
            cols = st.columns(5)
            for i, img_path in enumerate(sample_images[class_name]):
                with cols[i]:
                    st.image(img_path, use_container_width=True)
                    if st.button("🔍 Predict", key=f"predict_{img_path}_{i}"):
                        image = Image.open(img_path).convert('RGB')
                        model = load_model()
                        pred_class, prob = predict_image(model, image)
                        st.success(f"🎯 Prediction: **{pred_class}** ({prob:.2f}% confidence)")
        else:
            st.warning(f"⚠️ No images found for **{class_name}**")

with tab2:
    st.markdown("### 📤 Upload & Predict")
    st.markdown("""
#### 📖 About This Feature: Upload & Predict
This tool lets you upload a **retinal image** and get an **AI-based prediction** of the DR stage using a fine-tuned **DenseNet-121** model.

The model classifies the image into one of:
- No DR
- Mild
- Moderate
- Severe
- Proliferative DR

#### 🧭 How to Use:
1. 📁 Upload a **clear fundus image** (JPG or PNG).
2. 🧠 Click **Predict** to let the model analyze it.
3. ✅ You'll see:
   - 🎯 The predicted DR stage
   - 📊 Confidence level (in percentage)
""", unsafe_allow_html=True)

    # Accept both .jpg and .jpeg extensions for JPEG files
    uploaded_file = st.file_uploader("📁 Upload Retinal Image", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        image = Image.open(uploaded_file).convert('RGB')
        st.image(image, caption='🖼 Uploaded Image', use_container_width=True)
        if st.button("🧠 Predict"):
            with st.spinner('Analyzing image...'):
                model = load_model()
                pred_class, prob = predict_image(model, image)
            st.success(f"🎯 Prediction: **{pred_class}** ({prob:.2f}% confidence)")