File size: 3,602 Bytes
e1d994c
 
 
 
 
 
 
 
 
e846bea
a135607
e1d994c
 
 
4765d7e
64f0bb0
 
e1d994c
a135607
e1d994c
64f0bb0
e1d994c
 
 
 
64f0bb0
e1d994c
a135607
e1d994c
64f0bb0
 
e1d994c
64f0bb0
e1d994c
a135607
e1d994c
64f0bb0
 
a135607
 
 
64f0bb0
 
 
e1d994c
 
 
 
 
64f0bb0
e1d994c
206b8d6
64f0bb0
e1d994c
 
64f0bb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1d994c
64f0bb0
a135607
64f0bb0
 
 
 
dc5064f
64f0bb0
dc5064f
e1d994c
dc5064f
e1d994c
64f0bb0
dc5064f
e1d994c
a135607
e1d994c
 
206b8d6
64f0bb0
60c1b24
e1d994c
64f0bb0
 
 
 
 
e1d994c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#!/usr/bin/env python3
import gradio as gr
import torch
import torchvision
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import itertools
import numpy as np
import os
import logging


# Runtime configuration shared by model loading and prediction:
#   model_path - file holding the trained weights (state dict)
#   device     - CUDA when available, otherwise CPU
#   dxlabels   - diagnosis labels, one per model output
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
args = dict(
    model_path="model_last_epoch_34_torchvision0_3_state.ptw",
    device=torch.device(_device_name),
    dxlabels=["akiec", "bcc", "bkl", "df", "mel", "nv", "vasc"],
)
logging.warning(args)


def normtransform(x):
    """Return ``x`` unchanged.

    No specific normalization was performed during training, so this
    pass-through stands in for a normalization step.
    """
    unchanged = x
    return unchanged


# Load model
# Build a ResNet-34 whose final fully connected layer has one output per
# diagnosis label, then restore the trained weights from disk.
logging.warning("Loading model...")
model = torchvision.models.resnet34()
model.fc = torch.nn.Linear(model.fc.in_features, len(args["dxlabels"]))
# map_location remaps stored tensors onto the selected device, so the
# checkpoint still loads on a CPU-only host in case it was saved from CUDA
# tensors (without it, torch.load would try to restore them onto CUDA).
model.load_state_dict(torch.load(args["model_path"], map_location=args["device"]))
model.eval()
model.to(device=args["device"])
# This script only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)
logging.warning("Model loaded.")


def predict(image: str) -> dict:
    """Classify an image file using test-time augmentation.

    The image is passed through the model once per augmentation combination
    (target size x horizontal flip x rotation x crop factor). The raw model
    outputs are averaged over all combinations and a softmax is applied to
    the average.

    Args:
        image: Filesystem path of the image to classify.

    Returns:
        A dict mapping each label in ``args["dxlabels"]`` to its softmax
        probability.

    Raises:
        gr.Error: If no image path was supplied (``image`` is ``None`` or
            empty).
    """
    logging.warning(f"Starting predict('{image}') ...")

    # Fail with a readable message instead of an opaque error from PIL when
    # the function is invoked without an image.
    if not image:
        raise gr.Error("Please upload an image before submitting.")

    device = args["device"]
    labels = args["dxlabels"]
    prediction_tensor = torch.zeros([1, len(labels)]).to(device=device)

    # Test-time augmentations
    available_sizes = [224]
    target_sizes, hflips, rotations, crops = available_sizes, [0, 1], [0, 90], [0.8]
    aug_combos = list(itertools.product(target_sizes, hflips, rotations, crops))

    # Load image. The context manager closes the underlying file; convert()
    # loads the pixel data and returns a separate in-memory image, so `img`
    # remains usable afterwards.
    with Image.open(image) as source_image:
        img = source_image.convert("RGB")

    # Predict with Test-time augmentation
    for target_size, hflip, rotation, crop in aug_combos:
        tfm = transforms.Compose(
            [
                # Float floor division: 224 // 0.8 evaluates to 279 (not 280).
                # Left as-is so the preprocessing stays unchanged.
                transforms.Resize(int(target_size // crop)),
                transforms.CenterCrop(target_size),
                # Probability 0 or 1 makes the "random" flip deterministic.
                transforms.RandomHorizontalFlip(hflip),
                # A degenerate [r, r] range pins the rotation angle to r.
                transforms.RandomRotation([rotation, rotation]),
                transforms.ToTensor(),
                normtransform,
            ]
        )
        test_data = tfm(img).unsqueeze(0).to(device=device)
        with torch.no_grad():
            outputs = model(test_data)
        prediction_tensor += outputs

    # Average the raw outputs over all augmentations, then apply softmax.
    prediction_tensor /= len(aug_combos)
    predictions = F.softmax(prediction_tensor, dim=1)[0].detach().cpu().tolist()
    logging.warning(f"Returning {predictions=}")
    return dict(zip(labels, predictions))


# User-facing text passed as the `description` argument of the Gradio
# interface; it uses Markdown link and bold syntax.
description = """
Research image classification model for multi-class predictions of common dermatologic tumors, the model was trained on the [HAM10000 dataset](https://www.nature.com/articles/sdata2018161).

This is the model used in the publication [Tschandl P. et al. Nature Medicine 2020](https://www.nature.com/articles/s41591-020-0942-0) where human-computer interaction of such a system was analyzed.

Instructions for uploading: Ensure the image is not blurred, the lesion is centered and in focus, and no black/white vignette is in the surrounding. The image should depict the whole lesion, and not a zoomed-in part.

For education and research use only. **DO NOT use this to obtain medical advice!**
If you have a skin change in question, seek contact to a health care professional."""

logging.warning("Starting Gradio interface...")

# Example images, resolved relative to this file's directory.
sample_dir = os.path.join(os.path.dirname(__file__), "images")
sample_images = [
    os.path.join(sample_dir, filename)
    for filename in ("ISIC_0024306.jpg", "ISIC_0024315.jpg", "ISIC_0024318.jpg")
]

# Input: an uploaded image handed to predict() as a file path.
image_input = gr.Image(label="Upload a dermatoscopic image", type="filepath")
# Output: a label widget showing as many classes as there are diagnosis labels.
label_output = gr.Label(num_top_classes=len(args["dxlabels"]))

demo = gr.Interface(
    predict,
    inputs=image_input,
    outputs=label_output,
    title="Dermatoscopic classification",
    description=description,
    allow_flagging="never",
    examples=sample_images,
)
demo.launch()