# retinal-api / retinal.py
# Author: eyeqkfs. Last commit: da9637e "Update retinal.py" (verified). Size: 1.94 kB.
import gradio as gr
import torch
from torchvision import transforms, models
from PIL import Image
import numpy as np
# Recreate the model architecture. The checkpoint below is a state_dict, so it
# only loads into the same variant it was saved from: change to
# models.convnext_large(weights=None) if that variant was used for training.
# weights=None builds an uninitialised network (the non-deprecated spelling of
# pretrained=False); all trained parameters come from the checkpoint file.
model = models.convnext_base(weights=None)
# Replace the last layer of the classifier head with an 8-way Linear layer,
# one output logit per entry in CLASS_LABELS.
model.classifier[2] = torch.nn.Linear(model.classifier[2].in_features, 8)  # 8 classes
# Load the saved weights onto the CPU. weights_only=True restricts unpickling to
# tensors and plain containers, so a tampered checkpoint file cannot execute
# arbitrary code when loaded; a plain state_dict needs nothing more.
state_dict = torch.load(
    "convnext_oct_model.pth",
    map_location=torch.device("cpu"),
    weights_only=True,
)
model.load_state_dict(state_dict)
# Switch to evaluation mode for inference.
model.eval()
# Class labels, keyed by model output index: the integer returned by the
# argmax over the model's 8 logits is looked up here to get a readable name.
# NOTE(review): presumably this is the class order used at training time;
# confirm against the training script if the checkpoint changes.
CLASS_LABELS = dict(
    enumerate(("AMD", "CNV", "CSR", "DME", "DR", "DRUSEN", "MH", "NORMAL"))
)
# Preprocessing pipeline, built once at import time instead of on every
# request: resize to 224x224, convert to a float tensor, then normalise each
# channel with the standard ImageNet mean/std values.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Convert an uploaded image into a single-image model input batch.

    Args:
        image: The image as a NumPy array (what ``gr.Image`` delivers with its
            default ``type="numpy"``) or a ``PIL.Image.Image``.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``.

    Raises:
        ValueError: If ``image`` is ``None`` (nothing was uploaded).
    """
    if image is None:
        raise ValueError("No image provided.")
    if not isinstance(image, Image.Image):
        # Array input: cast to uint8 so PIL can interpret the pixel buffer.
        image = Image.fromarray(np.asarray(image).astype("uint8"))
    # Force three channels: drops an alpha channel and expands grayscale.
    rgb_image = image.convert("RGB")
    # Add the leading batch dimension expected by the model.
    return _PREPROCESS(rgb_image).unsqueeze(0)
# Prediction function
def predict_image(image):
    """Classify a retinal image and report the model's confidence.

    Args:
        image: The value from the Gradio ``gr.Image`` input (a NumPy array by
            default), or ``None`` when the user submits without uploading.

    Returns:
        A ``(label, confidence)`` tuple. ``label`` is the predicted class name
        from ``CLASS_LABELS`` (``"Unknown"`` for an unmapped index), and
        ``confidence`` is that class's softmax probability rounded to 4
        decimal places.

    Raises:
        gr.Error: If no image was supplied, so the UI shows a readable message
            instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload a retinal image before submitting.")
    image_tensor = preprocess_image(image)
    with torch.no_grad():
        logits = model(image_tensor)
        # Compute the softmax once and take its maximum. Softmax is monotonic,
        # so its argmax is the same class as the argmax of the raw logits.
        probabilities = torch.nn.functional.softmax(logits, dim=1)[0]
        confidence, predicted = probabilities.max(dim=0)
    class_idx = predicted.item()
    return CLASS_LABELS.get(class_idx, "Unknown"), round(confidence.item(), 4)
# Gradio Interface: one image in; predicted label and confidence out.
interface = gr.Interface(
    fn=predict_image,
    # type="numpy" is Gradio's default. It is stated explicitly because
    # predict_image forwards this value to preprocess_image, which handles
    # NumPy arrays.
    inputs=gr.Image(type="numpy"),
    outputs=[gr.Text(label="Prediction"), gr.Text(label="Confidence")],
    title="Retinal Image Classification",
    description="Upload a retinal image to predict the condition using a ConvNeXt model.",
    # NOTE(review): newer Gradio releases rename this to flagging_mode;
    # update if the pinned Gradio version warns about it.
    allow_flagging="never",
    # The endpoint name must be given to the constructor. Assigning
    # interface.api_name afterwards does not change the already-registered
    # endpoint. Gradio takes the bare name, which is served as "/predict".
    api_name="predict",
)

# Launch the interface with a public link when run as a script.
if __name__ == "__main__":
    interface.launch(share=True)