|
|
|
import gradio as gr |
|
import torch |
|
import torchvision |
|
import torch.nn.functional as F |
|
from torchvision import transforms, models |
|
from PIL import Image |
|
import itertools |
|
import numpy as np |
|
import os |
|
import logging |
|
|
|
|
|
# Runtime configuration shared by the model loader, predict() and the UI.
args = {
    # Resolve the checkpoint next to this script (the same way the example images
    # are located) so the app does not depend on the current working directory.
    "model_path": os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "model_last_epoch_34_torchvision0_3_state.ptw",
    ),
    # Prefer the GPU when one is available, otherwise run on the CPU.
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    # Output labels; the order must match the model's output indices, because
    # predict() maps logit i to dxlabels[i].
    # NOTE(review): presumably the HAM10000 diagnosis codes — confirm.
    "dxlabels": ["akiec", "bcc", "bkl", "df", "mel", "nv", "vasc"],
}

# warning level is used so the message is visible with the default root logger.
logging.warning(args)
|
|
|
|
|
|
|
def normtransform(x):
    """Final step of the per-view preprocessing pipeline; currently a no-op.

    The argument is handed back untouched, so tensors coming out of
    ``transforms.ToTensor()`` reach the model without any mean/std normalisation.

    NOTE(review): presumably the checkpoint was trained on un-normalised inputs —
    confirm before adding a Normalize step here.
    """
    result = x
    return result
|
|
|
|
|
|
|
logging.warning("Loading model...")

# ResNet-34 backbone with the classification head resized to one output per label.
model = torchvision.models.resnet34()
model.fc = torch.nn.Linear(model.fc.in_features, len(args["dxlabels"]))

# map_location="cpu" lets a checkpoint that was saved from a GPU process load on a
# CPU-only host (args["device"] falls back to CPU). The model is then moved to the
# target device below. The state dict is passed inline so no extra module-level
# reference keeps a second copy of the weights alive.
model.load_state_dict(torch.load(args["model_path"], map_location="cpu"))
model.eval()
model.to(device=args["device"])

# Grad mode is thread-local: this only disables autograd for the importing thread.
# predict() therefore also wraps inference in torch.no_grad().
torch.set_grad_enabled(False)
logging.warning("Model loaded.")
|
|
|
|
|
def _tta_transform(target_size, hflip, rotation, crop):
    """Build the deterministic preprocessing pipeline for one test-time-augmentation view."""
    return transforms.Compose(
        [
            # Resize the shorter side to target_size / crop so the centre crop below
            # keeps roughly the central `crop` fraction of the image.
            transforms.Resize(int(target_size // crop)),
            transforms.CenterCrop(target_size),
            # predict() passes p = 0 or 1, so the flip is never or always applied.
            transforms.RandomHorizontalFlip(hflip),
            # A degenerate [rotation, rotation] range pins the angle exactly.
            transforms.RandomRotation([rotation, rotation]),
            transforms.ToTensor(),
            normtransform,
        ]
    )


def predict(image: str) -> dict:
    """Classify an image with test-time augmentation (TTA).

    The image is run through the model once per combination of target size,
    horizontal flip, rotation and centre-crop ratio. The raw model outputs are
    averaged over all views, and a softmax over that average gives one probability
    per label.

    Args:
        image: Path to the image file (the Gradio input uses ``type="filepath"``).

    Returns:
        ``{label: probability}`` for every label in ``args["dxlabels"]``, in order.
    """
    logging.warning(f"Starting predict('{image}') ...")

    device = args["device"]

    # TTA grid: every size x flip x rotation x crop combination is evaluated
    # (1 * 2 * 2 * 1 = 4 forward passes per request).
    target_sizes = [224]
    hflips = [0, 1]
    rotations = [0, 90]
    crops = [0.8]
    aug_combos = list(itertools.product(target_sizes, hflips, rotations, crops))

    # Close the underlying file promptly. convert() returns a fully loaded copy,
    # so `img` stays usable after the `with` block exits.
    with Image.open(image) as raw:
        img = raw.convert("RGB")

    prediction_tensor = torch.zeros([1, len(args["dxlabels"])], device=device)

    # Gradio may call this from a worker thread. Grad mode is thread-local, so the
    # module-level torch.set_grad_enabled(False) does not necessarily apply here.
    with torch.no_grad():
        for target_size, hflip, rotation, crop in aug_combos:
            tfm = _tta_transform(target_size, hflip, rotation, crop)
            batch = tfm(img).unsqueeze(0).to(device=device)
            prediction_tensor += model(batch)

    # Average the raw outputs over all views before the softmax.
    prediction_tensor /= len(aug_combos)
    predictions = F.softmax(prediction_tensor, dim=1)[0].detach().cpu().tolist()
    logging.warning(f"Returning {predictions=}")
    return dict(zip(args["dxlabels"], predictions))
|
|
|
|
|
# Markdown-formatted text (uses link and bold syntax) passed to gr.Interface as
# `description=` and displayed in the app. It carries the research-use-only
# disclaimer shown to users, so keep that wording intact when editing.
description = """
Research image classification model for multi-class predictions of common dermatologic tumors, the model was trained on the [HAM10000 dataset](https://www.nature.com/articles/sdata2018161).

This is the model used in the publication [Tschandl P. et al. Nature Medicine 2020](https://www.nature.com/articles/s41591-020-0942-0) where human-computer interaction of such a system was analyzed.

Instructions for uploading: Ensure the image is not blurred, the lesion is centered and in focus, and no black/white vignette is in the surrounding. The image should depict the whole lesion, and not a zoomed-in part.

For education and research use only. **DO NOT use this to obtain medical advice!**
If you have a skin change in question, seek contact to a health care professional."""
|
|
|
logging.warning("Starting Gradio interface...")

# Bundled sample images, located relative to this script, offered as clickable examples.
_EXAMPLE_DIR = os.path.join(os.path.dirname(__file__), "images")
_EXAMPLE_FILES = ("ISIC_0024306.jpg", "ISIC_0024315.jpg", "ISIC_0024318.jpg")

# predict() receives the upload as a file path (type="filepath"). The Label output
# shows every class, because num_top_classes equals the number of labels.
demo = gr.Interface(
    predict,
    inputs=gr.Image(label="Upload a dermatoscopic image", type="filepath"),
    outputs=gr.Label(num_top_classes=len(args["dxlabels"])),
    title="Dermatoscopic classification",
    description=description,
    allow_flagging="never",
    examples=[os.path.join(_EXAMPLE_DIR, name) for name in _EXAMPLE_FILES],
)
demo.launch()
|
|