Model card

Try our model here

Model description

This is an image classification model that uses ResNet-50 as the base model to classify the severity of diabetic retinopathy.

Intended uses & limitations

Given an image taken using fundus photography, this model grades diabetic retinopathy on a scale of 0 to 4:

  • 0 - No DR
  • 1 - Mild
  • 2 - Moderate
  • 3 - Severe
  • 4 - Proliferative DR

Training

  • We trained our model with retina images taken using fundus photography under a variety of imaging conditions (a minimal training sketch follows this list).
  • The training data was gathered for a Kaggle competition run by the Asia Pacific Tele-Ophthalmology Society (APTOS) in 2019.
  • Training data
  • Training Process
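
We have not reproduced the full training notebook here, but a comparable fastai setup is sketched below. The file names, column names, image size, and epoch counts are illustrative assumptions based on the APTOS 2019 data layout, not the exact values used for this model; the get_x/get_y helpers mirror the ones referenced later in app.py.

import pandas as pd
from fastai.vision.all import *

# Assumed APTOS 2019 layout: train.csv with 'id_code' and 'diagnosis' columns,
# and fundus photographs stored as train_images/<id_code>.png
df = pd.read_csv('train.csv')

# Helpers used by the DataBlock; the exported model expects these names to exist
def get_x(r): return Path('train_images') / (r['id_code'] + '.png')
def get_y(r): return r['diagnosis']

dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),              # image in, severity grade (0-4) out
    get_x=get_x,
    get_y=get_y,
    splitter=RandomSplitter(valid_pct=0.2, seed=42), # hold out 20% for validation
    item_tfms=Resize(224),                           # illustrative input size
    batch_tfms=aug_transforms(),                     # standard augmentations; actual settings may differ
)
dls = dblock.dataloaders(df)

# ResNet-50 backbone, as stated in the model description
learn = vision_learner(dls, resnet50, metrics=[accuracy, error_rate])
learn.fine_tune(50, freeze_epochs=3)                 # epoch counts chosen to mirror the table below
learn.export('model.pkl')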

Evaluation

Training accuracy - we trained for 50 epochs, reaching roughly 84% accuracy on the validation set (see the table below). Note that the epoch counter restarts after the first three rows; these appear to correspond to an initial training phase before the main fine-tuning run.

Epoch   Train Loss   Valid Loss   Accuracy   Error Rate   Time
0       1.271288     1.351223     0.665301   0.334699     03:47
1       1.013268     0.742499     0.741803   0.258197     04:12
2       0.806825     0.687152     0.754098   0.245902     03:42
0       0.631816     0.533298     0.789617   0.210383     04:22
1       0.537469     0.457713     0.829235   0.170765     04:23
2       0.498419     0.515875     0.810109   0.189891     04:20
3       0.478353     0.511856     0.815574   0.184426     04:13
4       0.459457     0.475843     0.801913   0.198087     04:17
...
48      0.024947     0.800241     0.840164   0.159836     03:21
49      0.027916     0.803851     0.838798   0.161202     03:26

Confusion matrix

We submitted our model for validation to the APTOS 2019 Blindness Detection competition, achieving a private leaderboard score of 0.869345 (quadratic weighted kappa).
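
The competition scores submissions with quadratic weighted kappa, which rewards agreement between predicted and true grades and penalises large grading errors more heavily than small ones. If you want to compute the same metric on your own predictions, a minimal sketch using scikit-learn is shown below (the arrays are illustrative placeholders, not competition data):

from sklearn.metrics import cohen_kappa_score

# Illustrative severity grades (0-4); replace with real labels and model predictions
y_true = [0, 2, 4, 1, 0, 3, 2]
y_pred = [0, 2, 3, 1, 0, 3, 1]

qwk = cohen_kappa_score(y_true, y_pred, weights="quadratic")
print(f"quadratic weighted kappa: {qwk:.4f}")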

Trying the model

Note: You can easily try our model here

This application uses a trained model to detect the severity of diabetic retinopathy from a given retina image taken using fundus photography. The severity levels are:

  • 0 - No DR
  • 1 - Mild
  • 2 - Moderate
  • 3 - Severe
  • 4 - Proliferative DR

How to Use the Model

To use the model, you need to provide an image of the retina taken using fundus photography. The model will then predict the severity of diabetic retinopathy and return a dictionary where the keys are the severity levels and the values are the corresponding probabilities.
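
If you prefer to call the model directly from Python rather than through the Gradio app, a minimal sketch is shown below. It assumes the exported model.pkl is in the working directory; retina.png is a placeholder path for your own fundus photograph, and the get_x/get_y stubs are required so the exported learner can be unpickled (see the app.py breakdown below).

from fastai.vision.all import *

# Stubs for the helpers the learner was trained with; needed by load_learner
def get_x(r): return ""
def get_y(r): return r['diagnosis']

learn = load_learner('model.pkl')            # load the exported fastai learner
img = PILImage.create('retina.png')          # placeholder path to a fundus photograph
pred, pred_idx, probs = learn.predict(img)   # predicted grade, its index, per-class probabilities

# Map each class in the vocabulary to its probability, mirroring the app's output
print({str(learn.dls.vocab[i]): float(probs[i]) for i in range(len(probs))})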

Breakdown of the app.py File

Here's a breakdown of what the app.py file is doing:

  1. Import necessary libraries: The file starts by importing the necessary libraries. This includes gradio for creating the UI, fastai.vision.all for loading the trained model, and skimage for image processing.

  2. Define helper functions: The get_x and get_y functions are defined next. These are the helpers that were used during training to extract the image (x) and the diagnosis (y) from each row of the training data; they must be defined again here so that load_learner can unpickle the exported model.

  3. Load the trained model: The trained model is loaded from the model.pkl file using the load_learner function from fastai.

  4. Define label descriptions: A dictionary is defined to map label numbers to descriptions. This is used to return descriptions instead of numbers in the prediction result.

  5. Define the prediction function: The predict function is defined. This function takes an image as input, makes a prediction using the trained model, and returns a dictionary where the keys are the severity levels and the values are the corresponding probabilities.

  6. Define title and description: The title and description of the application are defined. These will be displayed in the Gradio UI.

To run the application, you create a Gradio interface with predict as the prediction function, an image input, and a label output, then launch the interface to start the application, as shown in the full app.py source below.

import os
import gradio as gr
from fastai.vision.all import *
import skimage

# Define the functions to get the x and y values from the input dictionary - in this case, the x value is the image and the y value is the diagnosis
# needed to load the model since we defined them during training
def get_x(r): return ""

def get_y(r): return r['diagnosis']

# Load the exported learner and grab the class vocabulary (severity grades 0-4)
learn = load_learner('model.pkl')
labels = learn.dls.vocab

# Define the mapping from label numbers to descriptions
label_descriptions = {
    0: "No DR",
    1: "Mild",
    2: "Moderate",
    3: "Severe",
    4: "Proliferative DR"
}

def predict(img):
    img = PILImage.create(img)
    pred, pred_idx, probs = learn.predict(img)
    # Use the label_descriptions dictionary to return descriptions instead of numbers
    return {label_descriptions[labels[i]]: float(probs[i]) for i in range(len(labels))}

title = "Diabetic Retinopathy Detection"
description = """Detects severity of diabetic retinopathy from a given retina image taken using fundus photography -

    0 - No DR

    1 - Mild

    2 - Moderate

    3 - Severe

    4 - Proliferative DR
"""
article = "<p style='text-align: center'><a href='https://www.kaggle.com/code/josemauriciodelgado/proliferative-retinopathy' target='_blank'>Notebook</a></p>"

# Get a list of all image paths in the test folder
test_folder = "test"  # replace with the actual path to your test folder
image_paths = [os.path.join(test_folder, img) for img in os.listdir(test_folder) if img.endswith(('.png', '.jpg', '.jpeg'))]

gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=5),
    examples=image_paths,  # set the examples parameter to the list of image paths
    article=article,
    title=title,
    description=description,
).launch()

Source code
