"""Gradio demo that classifies retinal images with a fine-tuned ConvNeXt-Base model."""

import gradio as gr
import torch
from torchvision import transforms, models
from PIL import Image
import numpy as np

# Path of the fine-tuned weights and number of output classes of the trained head.
MODEL_PATH = "convnext_oct_model.pth"
NUM_CLASSES = 8


def _load_model(path=MODEL_PATH, num_classes=NUM_CLASSES):
    """Recreate the ConvNeXt-Base architecture and load the fine-tuned weights.

    Args:
        path: Location of a saved ``state_dict`` for the modified network.
        num_classes: Output size of the replacement classification layer.

    Returns:
        The network on CPU, switched to evaluation mode.
    """
    # weights=None builds the bare architecture; the trained weights come from
    # the checkpoint below. (Replaces the deprecated ``pretrained=False`` flag.)
    net = models.convnext_base(weights=None)
    # Index 2 of the classifier is the final Linear layer; swap it for one
    # with `num_classes` outputs so the checkpoint's head shape matches.
    net.classifier[2] = torch.nn.Linear(net.classifier[2].in_features, num_classes)

    # weights_only=True restricts unpickling to tensors/containers, which is
    # all a state_dict needs, instead of allowing arbitrary pickled objects.
    state_dict = torch.load(path, map_location=torch.device("cpu"), weights_only=True)
    net.load_state_dict(state_dict)
    net.eval()  # inference behaviour for dropout/normalization layers
    return net


# Recreate the model architecture and load the saved weights (once, at startup).
model = _load_model()

# Class labels (full names with abbreviation at the end)
CLASS_LABELS = {
    0: "Age-related Macular Degeneration (AMD)",
    1: "Choroidal Neovascularization (CNV)",
    2: "Central Serous Retinopathy (CSR)",
    3: "Diabetic Macular Edema (DME)",
    4: "Diabetic Retinopathy (DR)",
    5: "Drusen (DRUSEN)",
    6: "Macular Hole (MH)",
    7: "Healthy Retina (NORMAL)",
}

# Preprocessing pipeline: resize to 224x224, convert to a float tensor, and
# normalize per channel. Built once here instead of on every request.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Turn an uploaded image array into a batched, normalized model input.

    Args:
        image: Array-like image from ``gr.Image`` (numpy by default). It is cast
            to uint8 and converted to 3-channel RGB, so grayscale or RGBA
            uploads are accepted too.

    Returns:
        A float tensor of shape (1, 3, 224, 224).
    """
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    return _PREPROCESS(pil_image).unsqueeze(0)  # add the batch dimension


def predict_image(image):
    """Classify one retinal image.

    Args:
        image: Image array supplied by the Gradio ``Image`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A ``(label, confidence)`` tuple. ``label`` is the human-readable class
        name from ``CLASS_LABELS`` (or ``"Unknown"``). ``confidence`` is the
        softmax probability of that class, rounded to 4 decimals.

    Raises:
        gr.Error: If no image was provided. Gradio shows this message to the user.
    """
    if image is None:
        raise gr.Error("Please upload a retinal image.")

    image_tensor = preprocess_image(image)
    with torch.no_grad():
        logits = model(image_tensor)

    # Softmax is monotonic, so its argmax equals the argmax of the raw logits;
    # computing it once yields both the predicted class and its probability.
    probabilities = torch.softmax(logits, dim=1)[0]
    confidence, class_idx = probabilities.max(dim=0)
    return CLASS_LABELS.get(class_idx.item(), "Unknown"), round(confidence.item(), 4)


# Gradio Interface
interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(),
    outputs=[gr.Text(label="Prediction"), gr.Text(label="Confidence")],
    title="Retinal Image Classification",
    description="Upload a retinal image to predict the condition using a ConvNeXt model.",
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
    # confirm against the installed version.
    allow_flagging="never",
)
# NOTE(review): assigning api_name after construction may not change the
# registered API endpoint; confirm whether the Interface constructor's own
# api_name argument should be used instead.
interface.api_name = "/predict"

# Launch the interface with a public link
if __name__ == "__main__":
    interface.launch(share=True)