| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from model import DropoutNet | |
is_cuda = torch.cuda.is_available()

# Restore the trained network. Weights are always mapped onto the CPU first so
# the same checkpoint file loads whether or not a GPU is present.
model = DropoutNet()
model.load_state_dict(
    torch.load('final_model.pth', map_location=torch.device('cpu'))
)
model.eval()  # inference mode (e.g. dropout layers become pass-through)

# Report where inference will run and, when a GPU exists, move the weights there.
print("Running on the GPU" if is_cuda else "Running on the CPU")
if is_cuda:
    model = model.to('cuda')
# Class names, indexed by the position of the corresponding model output.
LABELS = (
    'choroidal neovascularization',
    'diabetic macular edema',
    'drusen',
    'normal',
)


def predict(image):
    """Classify a retinal image and return the predicted class name.

    Args:
        image: NumPy array from the Gradio ``"image"`` input. Colour images
            are expected in RGB (or RGBA) channel order; a 2-D grayscale
            array, or an (H, W, 1) array, is accepted as well.

    Returns:
        One of the strings in ``LABELS``.

    Raises:
        ValueError: If ``image`` is ``None`` (e.g. the form was submitted
            without an upload).
    """
    if image is None:
        raise ValueError("No image provided; please upload a retinal image.")

    # Collapse to a single channel. Gradio arrays are RGB(A), not OpenCV's
    # default BGR, so use the RGB conversion codes.
    if image.ndim == 3:
        channels = image.shape[2]
        if channels == 1:
            image = np.ascontiguousarray(image[:, :, 0])
        elif channels == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
        else:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # The network takes one 28x28 single-channel image: (batch=1, channel=1, 28, 28).
    image = cv2.resize(image, (28, 28))
    # NOTE(review): pixels stay in the raw 0-255 range, as in the original code;
    # confirm this matches the normalization used during training.
    tensor = torch.from_numpy(image.reshape(1, 1, 28, 28)).float()

    # The input must live on the same device as the weights: the model may
    # have been moved to CUDA at start-up.
    device = next(model.parameters()).device
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        output = model(tensor.to(device))

    return LABELS[int(output.argmax(dim=1).item())]
# One-paragraph HTML summary of what the demo does.
description_html = (
    "\n"
    "<p>This model predicts the disease based on the retinal image.</p>\n"
)
# Long-form HTML explaining how the model works, how to use the demo, and
# its reported accuracy.
article_html = """
<h3>How does it work?</h3>
<p>The model is a Convolutional Neural Network (CNN) trained on retinal images to predict the disease.</p>
<p>It was trained on the MedMNIST dataset, which contains retinal images in 4 classes: three diseases and normal.</p>
<p>It uses the PyTorch framework for training and prediction.</p>
<h3>How to use?</h3>
<p>Upload an image of the retina and click on 'Submit' to get the prediction.</p>
<p>It will show the predicted class based on the input image.</p>
<p>It predicts one of the following classes:</p>
<ul>
<li>Choroidal Neovascularization</li>
<li>Diabetic Macular Edema</li>
<li>Drusen</li>
<li>Normal</li>
</ul>
<h3>How accurate is it?</h3>
<p>The model has an accuracy of 75% on the test dataset.</p>
"""
# Assemble the web UI: an image upload feeding `predict`, rendered as a label,
# then start the server.
demo = gr.Interface(
    fn=predict,
    inputs="image",
    outputs="label",
    title="Retinal Disease Prediction",
    # Reuse the HTML blurb defined above instead of duplicating its text inline.
    description=description_html,
    article=article_html,
)
demo.launch()