ptschandl commited on
Commit
e1d994c
1 Parent(s): 6bfeebe

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import gradio as gr
3
+ import torch
4
+ import torchvision
5
+ import torch.nn.functional as F
6
+ from torch.utils.data import Dataset
7
+ from torchvision import transforms, models
8
+ from PIL import Image
9
+ import itertools
10
+ import numpy as np
11
+
12
+
13
+ args = {
14
+ "model_path": "model_last_epoch_34_torchvision0_3_state.ptw"
15
+ }
16
+ args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
+
18
+ # Get classes
19
+ dxlabels = ["akiec", "bcc", "bkl", "df", "mel", "nv","vasc"]
20
+
21
+ # No specific normalization was performed during training
22
+ def normtransform(x):
23
+ return x
24
+
25
+ # Load model
26
+ model = torchvision.models.resnet34()
27
+ model.fc = torch.nn.Linear(model.fc.in_features, len(dxlabels))
28
+ model.load_state_dict(torch.load(args.model_path))
29
+ model.eval()
30
+ model.to(device=args.device)
31
+ torch.set_grad_enabled(False)
32
+
33
+
34
+ def predict(image)->dict:
35
+ global model, dxlabels, args, normtransform
36
+
37
+ prediction_tensor = torch.zeros([1, len(dxlabels)]).to(device=args.device)
38
+
39
+ # Test-time augmentations
40
+ available_sizes = [224]
41
+ target_sizes, hflips, rotations, crops = available_sizes, [0, 1], [0, 90], [0.8]
42
+ aug_combos = [x for x in itertools.product(target_sizes, hflips, rotations, crops)]
43
+
44
+ # Load image
45
+ img = Image.open(image)
46
+ img = img.convert('RGB')
47
+
48
+ # Predict with Test-time augmentation
49
+ for (target_size, hflip, rotation, crop) in tqdm(aug_combos, leave=True):
50
+ tfm = transforms.Compose([
51
+ transforms.Resize(int(target_size // crop)),
52
+ transforms.CenterCrop(target_size),
53
+ transforms.RandomHorizontalFlip(hflip),
54
+ transforms.RandomRotation([rotation, rotation]),
55
+ transforms.ToTensor(),
56
+ # shades_of_grey_torch,
57
+ normtransform
58
+ ])
59
+ test_data = tfm(img).unsqueeze(0).to(device=args.device)
60
+ running_preds = torch.FloatTensor().to(device=args.device)
61
+ outputs = model(test_data)
62
+ prediction_tensor += running_preds
63
+
64
+ prediction_tensor /= len(aug_combos)
65
+ predictions = F.softmax(prediction_tensor, dim=1)[0].cpu().numpy()
66
+ return {dxlabels[enu]: p for enu, p in enumerate(predictions)}
67
+
68
+ description = """Research artifact for multi-class predictions of common
69
+ dermatologic tumors. This is the model used in the publication
70
+ [Tschandl P. et al. Nature Medicine 2020](https://www.nature.com/articles/s41591-020-0942-0).
71
+
72
+ For education and research use only.
73
+ **DO NOT use this to obtain medical advice!**
74
+ If you have a skin change in question, seek contact to your physician.
75
+ """
76
+
77
+ gr.Interface(
78
+ predict,
79
+ inputs=gr.inputs.Image(label="Upload a dermatoscopic image", type="filepath"),
80
+ outputs=gr.outputs.Label(num_top_classes=len(dxlabels)),
81
+ title="Dermatoscopic evaluation",
82
+ description=description,
83
+ allow_flagging="manual"
84
+ ).launch()