ptschandl's picture
Update app.py
206b8d6
raw
history blame
No virus
2.91 kB
#!/usr/bin/env python3
import gradio as gr
import torch
import torchvision
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import itertools
import numpy as np
import os
# Runtime configuration: the checkpoint file name, the compute device
# (CUDA when available, otherwise CPU), and the class labels listed in the
# same order as the model's output indices.
args = dict(
    model_path="model_last_epoch_34_torchvision0_3_state.ptw",
    device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
    dxlabels=["akiec", "bcc", "bkl", "df", "mel", "nv", "vasc"],
)
# Training did not apply any specific normalization, so the normalization
# slot of the preprocessing pipeline is simply the identity.
def normtransform(x):
    """Return the input unchanged (placeholder for a normalization step)."""
    result = x
    return result
# Load model: a torchvision ResNet-34 whose final fully-connected layer is
# replaced by one with an output per diagnostic label, then restored from the
# trained state dict and switched to evaluation mode on the configured device.
model = torchvision.models.resnet34()
model.fc = torch.nn.Linear(model.fc.in_features, len(args.get("dxlabels")))
model.load_state_dict(
    torch.load(
        # Resolve the checkpoint next to this script (like the Gradio example
        # image) instead of relative to the working directory; os.path.join
        # returns an absolute model_path unchanged.
        os.path.join(os.path.dirname(__file__), args.get("model_path")),
        # Remap stored tensors onto the configured device so a checkpoint
        # saved from GPU tensors still loads on a CPU-only host.
        map_location=args.get("device"),
    )
)
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.eval()
model.to(device=args.get("device"))
# Grad mode is thread-local: this disables autograd only for the thread that
# executes this module, not for threads that may later run inference.
torch.set_grad_enabled(False)
def predict(image: str) -> dict:
    """Classify an image using test-time augmentation (TTA).

    The image is run through every combination of target size, horizontal
    flip (off/on), rotation (0/90 degrees) and crop ratio defined below. The
    model outputs are summed over all combinations and divided by the number
    of combinations. A softmax then turns the averaged outputs into class
    probabilities.

    Args:
        image: Path to the image file. The Gradio input component uses
            ``type="filepath"``, so a path string is passed in.

    Returns:
        Mapping from each label in ``args["dxlabels"]`` to its probability,
        in label order.
    """
    device = args.get("device")
    dxlabels = args.get("dxlabels")

    # Test-time augmentations
    available_sizes = [224]
    target_sizes, hflips, rotations, crops = available_sizes, [0, 1], [0, 90], [0.8]
    aug_combos = list(itertools.product(target_sizes, hflips, rotations, crops))

    # Load image; the `with` closes the underlying file once the RGB copy
    # (a new Image object) has been made.
    with Image.open(image) as raw_img:
        img = raw_img.convert('RGB')

    # Accumulator for the model outputs, shape [1, n_classes] to match the
    # batch-of-one output of the resized fc layer.
    prediction_tensor = torch.zeros([1, len(dxlabels)], device=device)

    # Grad mode is thread-local, so a module-level set_grad_enabled(False)
    # does not reach the threads Gradio may run this in; disable it here.
    with torch.no_grad():
        for target_size, hflip, rotation, crop in aug_combos:
            tfm = transforms.Compose([
                # Enlarge by 1/crop so the center crop keeps `crop` of the image.
                transforms.Resize(int(target_size // crop)),
                transforms.CenterCrop(target_size),
                # p is 0 or 1, so the flip is deterministic.
                transforms.RandomHorizontalFlip(hflip),
                # A degenerate [r, r] range rotates by exactly `rotation` degrees.
                transforms.RandomRotation([rotation, rotation]),
                transforms.ToTensor(),
                # shades_of_grey_torch,
                normtransform,
            ])
            test_data = tfm(img).unsqueeze(0).to(device=device)
            # Previously the outputs were discarded and an empty tensor was
            # added instead; accumulate this augmentation's outputs.
            prediction_tensor += model(test_data)

    prediction_tensor /= len(aug_combos)
    predictions = F.softmax(prediction_tensor, dim=1)[0].cpu().numpy()
    return {label: p for label, p in zip(dxlabels, predictions)}
# Markdown text shown on the Gradio page: publication reference and a
# usage disclaimer. The closing sentence is reworded; the original read
# "seek contact to your physician", which is ungrammatical.
description = '''Research artifact for multi-class predictions of common
dermatologic tumors. This is the model used in the publication
[Tschandl P. et al. Nature Medicine 2020](https://www.nature.com/articles/s41591-020-0942-0).
For education and research use only.
**DO NOT use this to obtain medical advice!**
If you are concerned about a change in your skin, please contact your physician.'''
# Build and start the web UI: an image upload goes in (as a file path) and a
# label widget showing every class with its probability comes out.
gr.Interface(
    fn=predict,
    title="Dermatoscopic classification",
    description=description,
    inputs=gr.Image(type="filepath", label="Upload a dermatoscopic image"),
    outputs=gr.Label(num_top_classes=len(args["dxlabels"])),
    examples=[
        # The example image ships alongside this script.
        os.path.join(os.path.dirname(__file__), "ISIC_0024306.jpg"),
    ],
    allow_flagging="manual",
).launch()