import functools

import gradio as gr
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms
import torchxrayvision as xrv

# Pre-trained torchxrayvision weight sets offered in the UI dropdown.
MODEL_NAMES = [
    "densenet121-res224-all",
    "densenet121-res224-nih",
    "densenet121-res224-pc",
    "densenet121-res224-chex",
    "densenet121-res224-rsna",
    "densenet121-res224-mimic_nb",
    "densenet121-res224-mimic_ch",
    "resnet50-res512-all",
]


@functools.lru_cache(maxsize=len(MODEL_NAMES))
def _load_model(model_name):
    """Return the torchxrayvision model *model_name*, loading it only once.

    Fetching weights from the Hugging Face Hub and building the network on
    every request is slow, so each model is kept in memory after first use.
    Failed loads raise and are not cached.
    """
    model = xrv.models.get_model(model_name, from_hf_hub=True)
    model.eval()  # inference only: put dropout/batch-norm layers in eval mode
    return model


def classify_image(img, model_name):
    """Classify a chest X-ray and return per-pathology scores.

    Args:
        img: Image array from the ``gr.Image`` input (``image_mode="L"``).
            If it has more than two dimensions, only the first channel of the
            last axis is used.
        model_name: Name of the torchxrayvision pre-trained weights to use.

    Returns:
        Dict mapping each pathology label of the selected model to its
        predicted score (as a Python float), suitable for ``gr.Label``.

    Raises:
        ValueError: If no image was supplied, or the image has fewer than
            two dimensions.
    """
    if img is None:
        raise ValueError("no image provided")

    model = _load_model(model_name)

    # Rescale pixel values (max 255) with torchxrayvision's normalizer.
    img = xrv.datasets.normalize(img, 255)

    # Reduce to a single 2D plane; anything smaller cannot be classified.
    if len(img.shape) > 2:
        img = img[:, :, 0]
    if len(img.shape) < 2:
        raise ValueError(
            f"image must have at least 2 dimensions, got shape {img.shape}"
        )

    # Add a leading channel axis: (H, W) -> (1, H, W).
    img = img[None, :, :]
    transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop()])
    img = transform(img)

    with torch.no_grad():
        batch = torch.from_numpy(img).unsqueeze(0)  # add batch axis
        preds = model(batch).cpu()

    # Label scores with the loaded model's own pathology list: it differs
    # between weight sets, so the global default list would mislabel outputs.
    # Empty names mark outputs the model was not trained on; skip them.
    return {
        name: float(score)
        for name, score in zip(model.pathologies, preds[0].numpy())
        if name
    }


demo = gr.Interface(
    fn=classify_image,
    inputs=[
        # NOTE(review): inputs are resized to 224x224, while
        # resnet50-res512-all was trained at 512 -- confirm the library
        # rescales internally or adjust the input size for that model.
        gr.Image(shape=(224, 224), image_mode="L"),
        gr.Dropdown(
            MODEL_NAMES,
            value="densenet121-res224-all",
            type="value",
            label="Pre-trained model",
        ),
    ],
    outputs=gr.Label(),
    title="Classify chest x-ray image",
    examples=[
        ["16747_3_1.jpg", "densenet121-res224-all"],
        ["00000001_000.png", "resnet50-res512-all"],
    ],
)
demo.launch()