formatting, more examples
- app.py +51 -35
- ISIC_0024306.jpg → images/ISIC_0024306.jpg +0 -0
- images/ISIC_0024315.jpg +0 -0
- images/ISIC_0024318.jpg +0 -0
app.py CHANGED
```diff
@@ -13,76 +13,92 @@ import logging
 
 args = {
     "model_path": "model_last_epoch_34_torchvision0_3_state.ptw",
-    "device": torch.device(
-    "dxlabels": ["akiec", "bcc", "bkl", "df", "mel", "nv","vasc"]
+    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+    "dxlabels": ["akiec", "bcc", "bkl", "df", "mel", "nv", "vasc"],
 }
 logging.warning(args)
 
+
 # No specific normalization was performed during training
 def normtransform(x):
     return x
 
+
 # Load model
 logging.warning("Loading model...")
 model = torchvision.models.resnet34()
-model.fc = torch.nn.Linear(model.fc.in_features, len(args
-model.load_state_dict(torch.load(args
+model.fc = torch.nn.Linear(model.fc.in_features, len(args["dxlabels"]))
+model.load_state_dict(torch.load(args["model_path"]))
 model.eval()
-model.to(device=args
+model.to(device=args["device"])
 torch.set_grad_enabled(False)
 logging.warning("Model loaded.")
 
-
+
+def predict(image: str) -> dict:
     global model, args
 
     logging.warning(f"Starting predict('{image}') ...")
-    prediction_tensor = torch.zeros([1, len(args
+    prediction_tensor = torch.zeros([1, len(args["dxlabels"])]).to(
+        device=args["device"]
+    )
 
     # Test-time augmentations
     available_sizes = [224]
     target_sizes, hflips, rotations, crops = available_sizes, [0, 1], [0, 90], [0.8]
     aug_combos = [x for x in itertools.product(target_sizes, hflips, rotations, crops)]
-
+
     # Load image
     img = Image.open(image)
-    img = img.convert(
+    img = img.convert("RGB")
 
     # Predict with Test-time augmentation
-    for
-    tfm = transforms.Compose(
-
-
-
-
-
-
-
-
-
-
-
-
-
+    for target_size, hflip, rotation, crop in aug_combos:
+        tfm = transforms.Compose(
+            [
+                transforms.Resize(int(target_size // crop)),
+                transforms.CenterCrop(target_size),
+                transforms.RandomHorizontalFlip(hflip),
+                transforms.RandomRotation([rotation, rotation]),
+                transforms.ToTensor(),
+                normtransform,
+            ]
+        )
+        test_data = tfm(img).unsqueeze(0).to(device=args["device"])
+        with torch.no_grad():
+            outputs = model(test_data)
+        prediction_tensor += outputs
+
     prediction_tensor /= len(aug_combos)
-    predictions = F.softmax(prediction_tensor, dim=1)[0].cpu().
+    predictions = F.softmax(prediction_tensor, dim=1)[0].detach().cpu().tolist()
     logging.warning(f"Returning {predictions=}")
-    return {args
+    return {args["dxlabels"][enu]: p for enu, p in enumerate(predictions)}
+
+
+description = """
+Research image classification model for multi-class predictions of common dermatologic tumors, the model was trained
+on the [HAM10000 dataset](https://www.nature.com/articles/sdata2018161).
+
+This is the model used in the publication
+[Tschandl P. et al. Nature Medicine 2020](https://www.nature.com/articles/s41591-020-0942-0)
+where human-computer interaction of such a system was analyzed.
 
-
-
-[Tschandl P. et al. Nature Medicine 2020](https://www.nature.com/articles/s41591-020-0942-0).
+Instructions for uploading: Ensure the image is not blurred, the lesion is centered and in focus, and no black/white
+vignette is in the surrounding. The image should depict the whole lesion, and not a zoomed-in part.
 
-For education and research use only.
-
-If you have a skin change in question, seek contact to your physician.'''
+For education and research use only. **DO NOT use this to obtain medical advice!**
+If you have a skin change in question, seek contact to your physician."""
 
 logging.warning("Starting Gradio interface...")
 gr.Interface(
     predict,
     inputs=gr.Image(label="Upload a dermatoscopic image", type="filepath"),
-    outputs=gr.Label(num_top_classes=len(args
+    outputs=gr.Label(num_top_classes=len(args["dxlabels"])),
     title="Dermatoscopic classification",
     description=description,
-    allow_flagging="
-    examples=[
+    allow_flagging="never",
+    examples=[
+        os.path.join(os.path.dirname(__file__), "images", x)
+        for x in ["ISIC_0024306.jpg", "ISIC_0024315.jpg", "ISIC_0024318.jpg"]
+    ],
 ).launch()
```
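For reference, the four test-time-augmentation variants the new loop iterates over can be enumerated directly. This is a minimal sketch using the same grid as app.py (all values taken from the diff above); note that with `p` fixed to 0 or 1, `RandomHorizontalFlip` is deterministic, as is `RandomRotation` with the degenerate `[rotation, rotation]` range, so every pass is reproducible:

```python
import itertools

# Same grid as app.py: 1 size x 2 flips x 2 rotations x 1 crop = 4 variants.
target_sizes, hflips, rotations, crops = [224], [0, 1], [0, 90], [0.8]
aug_combos = list(itertools.product(target_sizes, hflips, rotations, crops))
print(aug_combos)
# [(224, 0, 0, 0.8), (224, 0, 90, 0.8), (224, 1, 0, 0.8), (224, 1, 90, 0.8)]
```

The model's outputs are summed over these four variants, divided by `len(aug_combos)`, and only then passed through softmax, i.e. the averaging happens in logit space rather than probability space.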
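A minimal sketch of what a single variant does to an input image (the file name is one of the bundled examples; `Resize(int(224 // 0.8))` works out to a 280 px short side before the center crop):

```python
from PIL import Image
from torchvision import transforms

# One TTA variant: size 224, flip off (p=0), fixed 90-degree rotation, crop 0.8.
tfm = transforms.Compose(
    [
        transforms.Resize(int(224 // 0.8)),   # short side -> 280 px
        transforms.CenterCrop(224),           # keep the central 80 %
        transforms.RandomHorizontalFlip(0),   # p=0: never flips
        transforms.RandomRotation([90, 90]),  # degenerate range: always 90 degrees
        transforms.ToTensor(),
    ]
)
x = tfm(Image.open("images/ISIC_0024306.jpg").convert("RGB")).unsqueeze(0)
print(x.shape)  # torch.Size([1, 3, 224, 224])
```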
ISIC_0024306.jpg → images/ISIC_0024306.jpg RENAMED
File without changes
images/ISIC_0024315.jpg ADDED
images/ISIC_0024318.jpg ADDED
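With the example images now under `images/`, the Space can also be smoke-tested remotely. A hedged sketch using `gradio_client` (the Space id is a placeholder, and the exact call shape depends on the installed client version):

```python
from gradio_client import Client, handle_file

# "<user>/<space>" is a placeholder for the actual Space id.
client = Client("<user>/<space>")
result = client.predict(
    handle_file("images/ISIC_0024306.jpg"),  # one of the bundled examples
    api_name="/predict",
)
print(result)  # gr.Label output: the seven dx classes with confidences
```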