Spaces:
Runtime error
Runtime error
| import torch | |
| import os | |
| import gradio as gr | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| from functools import partial | |
| import Utils.Pneumonia_Utils as PU | |
| import Utils.CT_Scan_Utils as CSU | |
| import Utils.Covid19_Utils as C19U | |
| import Utils.DR_Utils as DRU | |
# Checkpoint files for each classifier, all stored under the same folder.
_MODEL_DIR = "cs_models"
CANCER_MODEL_PATH = f"{_MODEL_DIR}/EfficientNet_CT_Scans.pth.tar"
DIABETIC_RETINOPATHY_MODEL_PATH = f"{_MODEL_DIR}/model_DR_9.pth.tar"
PNEUMONIA_MODEL_PATH = f"{_MODEL_DIR}/DenseNet_Pneumonia.pth.tar"
COVID_MODEL_PATH = f"{_MODEL_DIR}/DenseNet_Covid.pth.tar"

# Human-readable class names handed to the prediction helpers
# (presumably listed in the order of each model's output indices).
CANCER_CLASS_LABELS = [
    "adenocarcinoma",
    "large.cell.carcinoma",
    "normal",
    "squamous.cell.carcinoma",
]
DIABETIC_RETINOPATHY_CLASS_LABELS = [
    "No DR",
    "Mild",
    "Moderate",
    "Severe",
    "Proliferative DR",
]
PNEUMONIA_CLASS_LABELS = ["Normal", "Pneumonia"]
COVID_CLASS_LABELS = ["Normal", "Covid19"]

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def _grad_cam_page(image, test_model, utils, class_labels):
    """Run one single-image Grad-CAM prediction for a Gradio tab.

    Shared implementation behind the cancer, Covid19 and pneumonia tabs:
    the image is preprocessed with ``utils.transform_image`` using
    ``utils.val_transform``, moved to ``device``, and passed to
    ``utils.plot_grad_cam``.

    Args:
        image: Image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user pressed "Predict" without uploading anything.
        test_model: Grad-CAM wrapped model to run.
        utils: Utils module (e.g. ``CSU``) that provides ``transform_image``,
            ``val_transform`` and ``plot_grad_cam``.
        class_labels: Class names passed through to ``plot_grad_cam``.

    Returns:
        ``(heatmap, pred_label, pred_conf)`` as produced by
        ``utils.plot_grad_cam``, with ``heatmap`` clipped to ``[0, 1]`` so it
        can be displayed by the ``gr.Image(type="numpy")`` output.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Gradio passes None for an empty image input; show a readable message
        # in the UI instead of a traceback from the transform.
        raise gr.Error("Please upload an image before clicking Predict.")
    image_tensor = utils.transform_image(image, utils.val_transform).to(device)
    heatmap, pred_label, pred_conf = utils.plot_grad_cam(
        test_model, image_tensor, class_labels, normalized=True
    )
    return np.clip(heatmap, 0, 1), pred_label, pred_conf


def cancer_page(image, test_model):
    """Classify a chest CT scan into CANCER_CLASS_LABELS with a Grad-CAM heatmap.

    Returns ``(heatmap, pred_label, pred_conf)``; raises ``gr.Error`` if
    ``image`` is ``None``.
    """
    return _grad_cam_page(image, test_model, CSU, CANCER_CLASS_LABELS)


def covid_page(image, test_model):
    """Classify an image as Normal / Covid19 with a Grad-CAM heatmap.

    Returns ``(heatmap, pred_label, pred_conf)``; raises ``gr.Error`` if
    ``image`` is ``None``.
    """
    return _grad_cam_page(image, test_model, C19U, COVID_CLASS_LABELS)


def pneumonia_page(image, test_model):
    """Classify an image as Normal / Pneumonia with a Grad-CAM heatmap.

    Returns ``(heatmap, pred_label, pred_conf)``; raises ``gr.Error`` if
    ``image`` is ``None``.
    """
    return _grad_cam_page(image, test_model, PU, PNEUMONIA_CLASS_LABELS)
def diabetic_retinopathy_page(image_1, image_2, test_model):
    """Predict diabetic-retinopathy grades for a pair of eye images.

    Both images are preprocessed together by ``DRU.transform_image`` with
    ``DRU.val_transform``. The ensemble model then scores them via
    ``DRU.Inf_predict_image``.

    Args:
        image_1: First image from the ``gr.Image(type="pil")`` input, or
            ``None`` if nothing was uploaded.
        image_2: Second image, same conventions as ``image_1``.
        test_model: The loaded diabetic-retinopathy ensemble model.

    Returns:
        ``(pred_label_1, pred_label_2)``, one prediction per input image
        (presumably drawn from DIABETIC_RETINOPATHY_CLASS_LABELS; the exact
        format is defined by ``DRU.Inf_predict_image``).

    Raises:
        gr.Error: If either image is missing.
    """
    if image_1 is None or image_2 is None:
        # The transform needs both images; fail with a readable UI message
        # instead of a traceback from inside the Utils module.
        raise gr.Error("Please upload both images before clicking Predict.")
    images = DRU.transform_image(image_1, image_2, DRU.val_transform)
    # NOTE(review): unlike the single-image tabs, the transformed images are
    # not moved to `device` here. Presumably DRU.Inf_predict_image handles
    # device placement; confirm, otherwise inference will fail on CUDA.
    pred_label_1, pred_label_2 = DRU.Inf_predict_image(
        test_model, images, DIABETIC_RETINOPATHY_CLASS_LABELS
    )
    return pred_label_1, pred_label_2
def _load_weights(model, path):
    """Load checkpoint weights from ``path`` into ``model`` and put it in eval mode.

    The checkpoint is read onto the CPU and copied into ``model``'s existing
    parameters. Checkpoints saved as ``{"state_dict": ...}`` are unwrapped;
    anything else is treated as a raw state dict.

    Loading stays non-strict, as before, but missing/unexpected keys are
    reported instead of silently ignored. A wrapped checkpoint loaded
    non-strictly without unwrapping would otherwise leave the model with
    random weights and no error.

    Returns:
        ``model`` itself, for chaining.
    """
    import warnings

    checkpoint = torch.load(path, map_location=torch.device("cpu"))
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    result = model.load_state_dict(checkpoint, strict=False)
    missing = getattr(result, "missing_keys", ())
    unexpected = getattr(result, "unexpected_keys", ())
    if missing or unexpected:
        warnings.warn(
            f"{path}: {len(missing)} missing and {len(unexpected)} unexpected "
            "keys while loading weights",
            stacklevel=2,
        )
    # Inference only: disable dropout and use running BatchNorm statistics.
    model.eval()
    return model


def _sample_examples(directory, k):
    """Return up to ``k`` randomly chosen files from ``directory`` as example rows.

    Each row is a one-element list ``[path]``, the shape used by the
    single-input ``gr.Examples`` calls. Hidden files (e.g. ``.gitkeep``) and
    subdirectories are skipped. ``k`` is capped at the number of available
    files, so a small directory no longer makes
    ``np.random.choice(..., replace=False)`` raise.
    """
    files = sorted(
        name
        for name in os.listdir(directory)
        if not name.startswith(".")
        and os.path.isfile(os.path.join(directory, name))
    )
    if not files:
        return []
    chosen = np.random.choice(files, size=min(k, len(files)), replace=False)
    return [[os.path.join(directory, str(name))] for name in chosen]


# Chest CT cancer classifier.
CSU_model = _load_weights(CSU.Efficient().to(device), CANCER_MODEL_PATH)
CSU_test_model = CSU.ModelGradCam(CSU_model).to(device)
CSU_images_dir = "TESTS/CHEST_CT_SCANS"
all_images = os.listdir(CSU_images_dir)
CSU_examples = _sample_examples(CSU_images_dir, 4)

# Covid19 classifier: examples mix Covid19 and normal images.
C19U_model = _load_weights(C19U.DenseNet().to(device), COVID_MODEL_PATH)
C19U_test_model = C19U.ModelGradCam(C19U_model).to(device)
C19U_C19_images_dir = _sample_examples("TESTS/COVID19", 2)
NORM_images_dir = _sample_examples("TESTS/NORMAL", 2)
C19U_examples = C19U_C19_images_dir + NORM_images_dir

# Pneumonia classifier. DenseNet must be instantiated (note the parentheses)
# before `.to(device)`; calling `.to` on the class itself raises at startup.
PU_model = _load_weights(PU.DenseNet().to(device), PNEUMONIA_MODEL_PATH)
PU_test_model = PU.ModelGradCam(PU_model).to(device)
PU_images_dir = _sample_examples("TESTS/PNEUMONIA", 2)
NORM_images_dir = _sample_examples("TESTS/NORMAL", 2)
PU_examples = PU_images_dir + NORM_images_dir

# Diabetic-retinopathy ensemble (CNN + EfficientNet).
DRU_cnn_model = DRU.ConvolutionNeuralNetwork().to(device)
DRU_eff_b3 = DRU.Efficient().to(device)
DRU_ensemble = _load_weights(
    DRU.EnsembleModel(DRU_cnn_model, DRU_eff_b3).to(device),
    DIABETIC_RETINOPATHY_MODEL_PATH,
)
DRU_test_model = DRU_ensemble
DRU_examples = [['TESTS/DR_1/10030_left._aug_0._aug_6.jpeg', 'TESTS/DR_0/10031_right._aug_17.jpeg']]
# Gradio UI: one tab per classifier. Each single-image tab shows a Grad-CAM
# heatmap, the predicted label and the class probabilities.
demo = gr.Blocks(title="X-RAY_CLASSIFIER")
with demo:
    gr.Markdown(
        """ # WELCOME, Try Out the X-ray_Classifier Below
Try out the following classification models below."""
    )
    with gr.Tab("Chest Cancer"):
        with gr.Row():
            cancer_input = gr.Image(type="pil", label="Image")
            cancer_output1 = gr.Image(type="numpy", label="Heatmap Image")
            cancer_output2 = gr.Textbox(label="Labels Present")
            cancer_output3 = gr.Label(label="Probabilities", show_label=False)
        cancer_button = gr.Button("Predict")
        cancer_examples = gr.Examples(CSU_examples, inputs=[cancer_input])
    with gr.Tab("Covid19"):
        with gr.Row():
            covid_input = gr.Image(type="pil", label="Image")
            covid_output1 = gr.Image(type="numpy", label="Heatmap Image")
            covid_output2 = gr.Textbox(label="Labels Present")
            covid_output3 = gr.Label(label="Probabilities", show_label=False)
        covid_button = gr.Button("Predict")
        covid_examples = gr.Examples(C19U_examples, inputs=[covid_input])
    with gr.Tab("Pneumonia"):
        with gr.Row():
            pneumonia_input = gr.Image(type="pil", label="Image")
            pneumonia_output1 = gr.Image(type="numpy", label="Heatmap Image")
            pneumonia_output2 = gr.Textbox(label="Labels Present")
            pneumonia_output3 = gr.Label(label="Probabilities", show_label=False)
        pneumonia_button = gr.Button("Predict")
        pneumonia_examples = gr.Examples(PU_examples, inputs=[pneumonia_input])
    with gr.Tab("Diabetic Retinopathy"):
        # Two inputs / two outputs: label them distinctly so users can tell
        # which prediction belongs to which image.
        with gr.Row():
            dr_input1 = gr.Image(type="pil", label="Image 1")
            dr_input2 = gr.Image(type="pil", label="Image 2")
            dr_output1 = gr.Textbox(label="Labels Present (Image 1)")
            dr_output2 = gr.Textbox(label="Labels Present (Image 2)")
        dr_button = gr.Button("Predict")
        dr_examples = gr.Examples(DRU_examples, inputs=[dr_input1, dr_input2])

    # Event listeners must be registered inside the Blocks context. The
    # loaded model is bound via `partial` so Gradio only supplies the images.
    cancer_button.click(partial(cancer_page, test_model=CSU_test_model),
                        inputs=cancer_input,
                        outputs=[cancer_output1, cancer_output2, cancer_output3])
    covid_button.click(partial(covid_page, test_model=C19U_test_model),
                       inputs=covid_input,
                       outputs=[covid_output1, covid_output2, covid_output3])
    pneumonia_button.click(partial(pneumonia_page, test_model=PU_test_model),
                           inputs=pneumonia_input,
                           outputs=[pneumonia_output1, pneumonia_output2, pneumonia_output3])
    dr_button.click(partial(diabetic_retinopathy_page,
                            test_model=DRU_test_model),
                    inputs=[dr_input1, dr_input2],
                    outputs=[dr_output1, dr_output2])
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    demo.launch()