Update app.py
Browse files
app.py
CHANGED
@@ -8,6 +8,7 @@ from PIL import Image
|
|
8 |
import itertools
|
9 |
import numpy as np
|
10 |
import os
|
|
|
11 |
|
12 |
|
13 |
args = {
|
@@ -15,23 +16,26 @@ args = {
|
|
15 |
"device": torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
|
16 |
"dxlabels": ["akiec", "bcc", "bkl", "df", "mel", "nv","vasc"]
|
17 |
}
|
|
|
18 |
|
19 |
# No specific normalization was performed during training
|
20 |
def normtransform(x):
|
21 |
return x
|
22 |
|
23 |
# Load model
|
|
|
24 |
model = torchvision.models.resnet34()
|
25 |
model.fc = torch.nn.Linear(model.fc.in_features, len(args.get("dxlabels")))
|
26 |
model.load_state_dict(torch.load(args.get("model_path")))
|
27 |
model.eval()
|
28 |
model.to(device=args.get("device"))
|
29 |
torch.set_grad_enabled(False)
|
30 |
-
|
31 |
|
32 |
def predict(image: str)->dict:
|
33 |
-
global model, args
|
34 |
-
|
|
|
35 |
prediction_tensor = torch.zeros([1, len(args.get("dxlabels"))]).to(device=args.get("device"))
|
36 |
|
37 |
# Test-time augmentations
|
@@ -61,6 +65,7 @@ def predict(image: str)->dict:
|
|
61 |
|
62 |
prediction_tensor /= len(aug_combos)
|
63 |
predictions = F.softmax(prediction_tensor, dim=1)[0].cpu().numpy()
|
|
|
64 |
return {args.get("dxlabels")[enu]: p for enu, p in enumerate(predictions)}
|
65 |
|
66 |
description = '''Research artifact for multi-class predictions of common
|
@@ -71,6 +76,7 @@ For education and research use only.
|
|
71 |
**DO NOT use this to obtain medical advice!**
|
72 |
If you have a skin change in question, seek contact to your physician.'''
|
73 |
|
|
|
74 |
gr.Interface(
|
75 |
predict,
|
76 |
inputs=gr.Image(label="Upload a dermatoscopic image", type="filepath"),
|
|
|
8 |
import itertools
|
9 |
import numpy as np
|
10 |
import os
|
11 |
+
import logging
|
12 |
|
13 |
|
14 |
args = {
|
|
|
16 |
"device": torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
|
17 |
"dxlabels": ["akiec", "bcc", "bkl", "df", "mel", "nv","vasc"]
|
18 |
}
|
19 |
+
logging.warning(args)
|
20 |
|
21 |
# No specific normalization was performed during training
|
22 |
def normtransform(x):
|
23 |
return x
|
24 |
|
25 |
# Load model
|
26 |
+
logging.warning("Loading model...")
|
27 |
model = torchvision.models.resnet34()
|
28 |
model.fc = torch.nn.Linear(model.fc.in_features, len(args.get("dxlabels")))
|
29 |
model.load_state_dict(torch.load(args.get("model_path")))
|
30 |
model.eval()
|
31 |
model.to(device=args.get("device"))
|
32 |
torch.set_grad_enabled(False)
|
33 |
+
logging.warning("Model loaded.")
|
34 |
|
35 |
def predict(image: str)->dict:
|
36 |
+
global model, args
|
37 |
+
|
38 |
+
logging.warning(f"Starting predict('{image}') ...")
|
39 |
prediction_tensor = torch.zeros([1, len(args.get("dxlabels"))]).to(device=args.get("device"))
|
40 |
|
41 |
# Test-time augmentations
|
|
|
65 |
|
66 |
prediction_tensor /= len(aug_combos)
|
67 |
predictions = F.softmax(prediction_tensor, dim=1)[0].cpu().numpy()
|
68 |
+
logging.warning(f"Returning {predictions=}")
|
69 |
return {args.get("dxlabels")[enu]: p for enu, p in enumerate(predictions)}
|
70 |
|
71 |
description = '''Research artifact for multi-class predictions of common
|
|
|
76 |
**DO NOT use this to obtain medical advice!**
|
77 |
If you have a skin change in question, seek contact to your physician.'''
|
78 |
|
79 |
+
logging.warning("Starting Gradio interface...")
|
80 |
gr.Interface(
|
81 |
predict,
|
82 |
inputs=gr.Image(label="Upload a dermatoscopic image", type="filepath"),
|