Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| import torchvision | |
| import torchvision.transforms | |
| import torchxrayvision as xrv | |
# Models already loaded in this process, keyed by model name, so repeated
# requests do not re-instantiate the network on every call.
_MODEL_CACHE = {}


def _load_model(model_name):
    """Return the model for *model_name*, loading it only on first use."""
    if model_name not in _MODEL_CACHE:
        _MODEL_CACHE[model_name] = xrv.models.get_model(model_name, from_hf_hub=True)
    return _MODEL_CACHE[model_name]


def classify_image(img, model_name: str) -> dict:
    """Score a chest x-ray with one of the pre-trained models.

    Args:
        img: Image array with at least two dimensions. If it has more than
            two, only channel 0 of the last axis is used. Values are
            normalized with a maximum of 255.
        model_name: Name passed to ``xrv.models.get_model``.

    Returns:
        A dict pairing each name in ``xrv.datasets.default_pathologies`` with
        the corresponding model output as a ``float``.

    Raises:
        ValueError: If no image is supplied or it has fewer than two
            dimensions.
    """
    # Validate first: previously a bad input only printed a message and then
    # failed later with an unrelated IndexError.
    if img is None:
        raise ValueError("no image provided")
    if img.ndim < 2:
        raise ValueError("error, dimension lower than 2 for image")

    model = _load_model(model_name)
    img = xrv.datasets.normalize(img, 255)

    # Reduce multi-channel input to a single 2D plane.
    if img.ndim > 2:
        img = img[:, :, 0]

    # Add the leading channel axis, then centre-crop.
    img = img[None, :, :]
    transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop()])
    img = transform(img)

    with torch.no_grad():
        batch = torch.from_numpy(img).unsqueeze(0)
        preds = model(batch).cpu()

    # NOTE(review): labels always come from default_pathologies, whichever
    # model was chosen - confirm this matches each model's output order.
    scores = preds[0].detach().numpy()
    return {
        name: float(score)
        for name, score in zip(xrv.datasets.default_pathologies, scores)
    }
# Web UI: a grayscale image and a model selector go in, per-label scores
# come out. The example rows pair an image file with a model name.
# NOTE(review): `shape=` on gr.Image appears to be Gradio 3.x only - confirm
# against the pinned Gradio version.
demo = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(shape=(224, 224), image_mode="L"),
        gr.Dropdown(
            [
                "densenet121-res224-all",
                "densenet121-res224-nih",
                "densenet121-res224-pc",
                "densenet121-res224-chex",
                "densenet121-res224-rsna",
                "densenet121-res224-mimic_nb",
                "densenet121-res224-mimic_ch",
                "resnet50-res512-all",
            ],
            value="densenet121-res224-all",
            type="value",
            label="Pre-trained model",
        ),
    ],
    # Component-style output, consistent with the gr.Image / gr.Dropdown
    # inputs; replaces the legacy `gr.outputs.Label()` namespace.
    outputs=gr.Label(),
    title="Classify chest x-ray image",
    examples=[
        ["16747_3_1.jpg", "densenet121-res224-all"],
        ["00000001_000.png", "resnet50-res512-all"],
    ],
)
demo.launch()