# Hugging Face Spaces app — page status when this copy was captured: "Runtime error".
import functools

import gradio as gr
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms
import torchxrayvision as xrv
def classify_image(img, model_name): | |
model = xrv.models.get_model(model_name, from_hf_hub=True) | |
img = xrv.datasets.normalize(img, 255) | |
# Check that images are 2D arrays | |
if len(img.shape) > 2: | |
img = img[:, :, 0] | |
if len(img.shape) < 2: | |
print("error, dimension lower than 2 for image") | |
# Add color channel | |
img = img[None, :, :] | |
transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop()]) | |
img = transform(img) | |
with torch.no_grad(): | |
img = torch.from_numpy(img).unsqueeze(0) | |
preds = model(img).cpu() | |
output = { | |
k: float(v) | |
for k, v in zip(xrv.datasets.default_pathologies, preds[0].detach().numpy()) | |
} | |
return output | |
# torchxrayvision weights offered in the model dropdown.
MODEL_NAMES = [
    "densenet121-res224-all",
    "densenet121-res224-nih",
    "densenet121-res224-pc",
    "densenet121-res224-chex",
    "densenet121-res224-rsna",
    "densenet121-res224-mimic_nb",
    "densenet121-res224-mimic_ch",
    "resnet50-res512-all",
]

demo = gr.Interface(
    fn=classify_image,
    inputs=[
        # NOTE(review): every upload is resized to 224x224 here, including for
        # "resnet50-res512-all". The `shape=` argument no longer exists on
        # gr.Image in Gradio 4; if the Space moves to Gradio 4, drop it and
        # resize inside classify_image instead.
        gr.Image(shape=(224, 224), image_mode="L"),
        gr.Dropdown(
            MODEL_NAMES,
            value="densenet121-res224-all",
            type="value",
            label="Pre-trained model",
        ),
    ],
    # Top-level gr.Label replaces the deprecated gr.outputs.Label (the
    # gr.outputs namespace was removed in Gradio 4), matching gr.Image and
    # gr.Dropdown above.
    outputs=gr.Label(),
    title="Classify chest x-ray image",
    examples=[
        ["16747_3_1.jpg", "densenet121-res224-all"],
        ["00000001_000.png", "resnet50-res512-all"],
    ],
)

if __name__ == "__main__":
    demo.launch()