import streamlit as st
import torch
from torchvision import transforms, models
from PIL import Image
import numpy as np
import pandas as pd
from collections import defaultdict
import os
from datasets import load_dataset

# Title
st.markdown("<h2 style='color: #2E86C1;'>📷 Upload & Predict</h2>", unsafe_allow_html=True)
st.markdown("""
### About This Feature: Upload & Predict
This section of the **DR Assistive Tool** lets users upload retinal images and get an AI-based prediction of the **Diabetic Retinopathy (DR) stage**. It uses a **DenseNet-121** model fine-tuned specifically to classify DR severity.

The model classifies the uploaded image into one of five classes:
- **No DR**
- **Mild**
- **Moderate**
- **Severe**
- **Proliferative DR**

This is especially helpful for:
- Students learning about AI in healthcare
- Researchers testing model robustness
- Clinicians exploring AI-assisted screening tools

The tool also shows **sample images from the test set** for each class. You can use these images to test the model's performance and see what the different DR stages look like.

---

### How to Use:
1. **View sample images** from the test set, grouped by DR stage.
   - Click the **"Predict"** button under a sample image to see how the model classifies it.
2. **Upload your own retinal image** (JPG or PNG) using the file uploader.
3. Click the **"Predict"** button after uploading.
   - The model analyzes the image and displays:
     - 🎯 **Predicted DR stage**
     - **Model confidence score (in %)**

⚠️ *Make sure your image is a clear, centered fundus photograph for best results.*

---
### Behind the Scenes:
- Model: pretrained **DenseNet-121**, fine-tuned for DR stage classification
- Input size: images are resized so the shorter side is 256 px, then center-cropped to 224×224 pixels
- Normalization: ImageNet channel means and standard deviations
- Output: the class with the highest **softmax** probability among the 5 DR categories

*This tool is for educational and research purposes only and is not intended for clinical use.*
""", unsafe_allow_html=True)

# DR class names (list index = integer label in the dataset CSV)
class_names = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative DR']


def load_sample_images_from_csv():
    # Return up to 5 test-set image paths per DR class, read from the split CSV.
    # Only paths that exist on local disk are kept.
    csv_url = "https://huggingface.co/datasets/Ci-Dave/DDR_dataset_train_test/raw/main/splits/test_labels.csv"
    df = pd.read_csv(csv_url)
    samples = defaultdict(list)
    for i in range(len(class_names)):
        class_name = class_names[i]
        class_samples = df[df['label'] == i].head(5)
        for _, row in class_samples.iterrows():
            img_path = row['new_path']
            if os.path.exists(img_path):  # works only if images are local
                samples[class_name].append(img_path)
    return samples
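

# --- Illustrative sketch (not part of the original app) ---
# The per-class sampling in load_sample_images_from_csv() can also be written
# with a pandas groupby. This is a minimal sketch assuming the CSV keeps the
# same `label` (integer 0-4) and `new_path` columns; `collect_sample_paths`
# and its `path_exists` parameter are hypothetical names introduced here so
# the existence check can be swapped out (for example, in a test).
import os
from collections import defaultdict

import pandas as pd


def collect_sample_paths(df, class_names, per_class=5, path_exists=os.path.exists):
    # Map each class name to up to `per_class` image paths that pass `path_exists`
    samples = defaultdict(list)
    # Keep only known labels, then the first `per_class` rows of each label
    # (groupby(...).head(n) preserves the original row order)
    known = df[df["label"].isin(list(range(len(class_names))))]
    firsts = known.groupby("label").head(per_class)
    for label, path in zip(firsts["label"], firsts["new_path"]):
        # Like the original, drop paths that are not present on local disk
        if path_exists(path):
            samples[class_names[int(label)]].append(path)
    return samples

# Usage (hypothetical): collect_sample_paths(pd.read_csv(csv_url), class_names)
# As in the original function, the head(n) cut happens before the existence
# check, so a class can end up with fewer than `per_class` local images.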
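

# Load the fine-tuned model once and reuse it across Streamlit reruns
@st.cache_resource
def load_model():
    # Build the DenseNet-121 architecture without downloading ImageNet weights,
    # then replace the classifier head with a 5-class DR output layer
    model = models.densenet121(weights=None)
    model.classifier = torch.nn.Linear(model.classifier.in_features, len(class_names))
    # Load the saved checkpoint onto the CPU
    model.load_state_dict(torch.load("./Model/Pretrained_Densenet-121.pth", map_location='cpu'))
    model.eval()  # inference mode: batch-norm layers use their running statistics
    return model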

# Image transform: resize the shorter side to 256 px, center-crop 224x224,
# convert to a [0, 1] tensor, and normalize with the ImageNet mean/std
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
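

# --- Illustrative sketch (not part of the original app) ---
# What `transform` does to one image, spelled out with plain Python and NumPy.
# To our understanding of torchvision: Resize(256) scales the shorter side to
# 256 px and the longer side proportionally (truncated to an int),
# CenterCrop(224) takes the central 224x224 window, ToTensor() maps uint8
# pixels to [0, 1] in channel-first (C, H, W) order, and Normalize() computes
# (x - mean) / std per channel. The helper names are hypothetical, and the
# rounding is an approximation that may differ from torchvision by a pixel in
# edge cases.
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def resize_and_crop_geometry(width, height, resize=256, crop=224):
    # Returns the (width, height) after Resize and the crop box
    # (left, top, right, bottom) in resized-image coordinates
    if width <= height:
        new_w, new_h = resize, int(resize * height / width)
    else:
        new_w, new_h = int(resize * width / height), resize
    left = int(round((new_w - crop) / 2.0))
    top = int(round((new_h - crop) / 2.0))
    return (new_w, new_h), (left, top, left + crop, top + crop)


def to_normalized_chw(rgb_uint8):
    # (H, W, 3) uint8 array -> normalized (3, H, W) float32 array
    x = rgb_uint8.astype(np.float32) / 255.0  # ToTensor: scale to [0, 1]
    x = (x - IMAGENET_MEAN) / IMAGENET_STD    # Normalize: per-channel standardization
    return x.transpose(2, 0, 1)               # HWC -> CHW, the layout ToTensor produces

# Example: a 1000x500 (width x height) photo is resized to 512x256, and the
# crop box (144, 16, 368, 240) yields the 224x224 model input.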

# Prediction function: returns (predicted class name, confidence in %)
def predict_image(model, image):
    img_tensor = transform(image).unsqueeze(0)  # add a batch dimension: (1, 3, 224, 224)
    with torch.no_grad():
        outputs = model(img_tensor)  # raw logits, shape (1, 5)
        _, pred = torch.max(outputs, 1)  # index of the highest logit
        prob = torch.nn.functional.softmax(outputs, dim=1)[0][pred].item() * 100
    return class_names[pred.item()], prob
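

# --- Illustrative sketch (not part of the original app) ---
# A NumPy stand-in for the last steps of predict_image: softmax turns the raw
# logits z into probabilities p_i = exp(z_i) / sum_j exp(z_j), the predicted
# class is the argmax, and the reported "confidence" is that class's
# probability times 100. `logits_to_prediction` is a hypothetical helper; it
# also returns the full per-class distribution, which the app does not show.
import numpy as np


def logits_to_prediction(logits, class_names):
    z = np.asarray(logits, dtype=np.float64)
    # Subtracting the max before exponentiating avoids overflow and does not
    # change the softmax result
    e = np.exp(z - z.max())
    probs = e / e.sum()
    idx = int(np.argmax(probs))  # same index as the argmax of the raw logits
    per_class = {name: float(p) * 100 for name, p in zip(class_names, probs)}
    return class_names[idx], float(probs[idx]) * 100, per_class

# Usage (hypothetical): logits_to_prediction(outputs[0].numpy(), class_names)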

# Create two tabs to separate the sample-image and upload features
tab1, tab2 = st.tabs(["🧪 Sample Images", "📤 Upload & Predict"])

with tab1:
    st.markdown("### 🧪 Sample Images from Test Set")
    st.markdown("""
#### About This Feature: Sample Images
In this tab, you can explore sample retinal images from the test set, grouped by their **Diabetic Retinopathy (DR)** stage. This helps you:
- Understand the **visual differences** between DR stages
- Test the model's performance on known data
- Get familiar with the model's prediction behavior

#### How to Use:
1. Browse the sample images under each DR class.
2. Click **Predict** under an image to let the AI model analyze it.
3. The result will show:
   - 🎯 **Predicted DR stage**
   - **Confidence score**

> *Ideal for researchers and students testing the model with known data.*
""", unsafe_allow_html=True)
    sample_images = load_sample_images_from_csv()
    for class_name in class_names:
        if class_name in sample_images and sample_images[class_name]:
            cols = st.columns(5)
            for i, img_path in enumerate(sample_images[class_name]):
                with cols[i]:
                    st.image(img_path, use_container_width=True)
                    if st.button("Predict", key=f"predict_{img_path}_{i}"):
                        image = Image.open(img_path).convert('RGB')
                        model = load_model()
                        pred_class, prob = predict_image(model, image)
                        st.success(f"🎯 Prediction: **{pred_class}** ({prob:.2f}% confidence)")
        else:
            st.warning(f"⚠️ No images found for **{class_name}**")

with tab2:
    st.markdown("### 📤 Upload & Predict")
    st.markdown("""
#### About This Feature: Upload & Predict
This tab lets you upload a **retinal image** and get an **AI-based prediction** of the DR stage from a fine-tuned **DenseNet-121** model.

The model classifies the image into one of:
- No DR
- Mild
- Moderate
- Severe
- Proliferative DR

#### How to Use:
1. Upload a **clear fundus image** (JPG or PNG).
2. Click **Predict** to let the model analyze it.
3. You'll see:
   - 🎯 The predicted DR stage
   - The confidence level (as a percentage)
""", unsafe_allow_html=True)
    uploaded_file = st.file_uploader("Upload Retinal Image", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        image = Image.open(uploaded_file).convert('RGB')
        st.image(image, caption='🖼 Uploaded Image', use_container_width=True)
        if st.button("Predict"):
            with st.spinner('Analyzing image...'):
                model = load_model()
                pred_class, prob = predict_image(model, image)
            st.success(f"🎯 Prediction: **{pred_class}** ({prob:.2f}% confidence)")
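

# --- Illustrative sketch (not part of the original app) ---
# A defensive variant of the upload step above. st.file_uploader returns an
# UploadedFile (a BytesIO subclass), so its raw bytes are available through
# .getvalue(). If those bytes are not a decodable image, Pillow raises an
# OSError subclass (UnidentifiedImageError for unknown formats); this sketch
# turns that into None so the app could show a warning instead of a traceback.
# `open_uploaded_image` is a hypothetical helper name.
import io

from PIL import Image


def open_uploaded_image(data):
    # Decode raw bytes into an RGB PIL image, or return None if decoding fails
    try:
        with Image.open(io.BytesIO(data)) as img:
            return img.convert("RGB")  # convert() forces a full decode into memory
    except OSError:
        return None

# Usage (hypothetical):
#   image = open_uploaded_image(uploaded_file.getvalue())
#   if image is None:
#       st.warning("Could not read this file as an image.")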