ptschandl commited on
Commit
a135607
1 Parent(s): 206b8d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -8,6 +8,7 @@ from PIL import Image
8
  import itertools
9
  import numpy as np
10
  import os
 
11
 
12
 
13
  args = {
@@ -15,23 +16,26 @@ args = {
15
  "device": torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
16
  "dxlabels": ["akiec", "bcc", "bkl", "df", "mel", "nv","vasc"]
17
  }
 
18
 
19
  # No specific normalization was performed during training
20
  def normtransform(x):
21
  return x
22
 
23
  # Load model
 
24
  model = torchvision.models.resnet34()
25
  model.fc = torch.nn.Linear(model.fc.in_features, len(args.get("dxlabels")))
26
  model.load_state_dict(torch.load(args.get("model_path")))
27
  model.eval()
28
  model.to(device=args.get("device"))
29
  torch.set_grad_enabled(False)
30
-
31
 
32
  def predict(image: str)->dict:
33
- global model, args, normtransform
34
-
 
35
  prediction_tensor = torch.zeros([1, len(args.get("dxlabels"))]).to(device=args.get("device"))
36
 
37
  # Test-time augmentations
@@ -61,6 +65,7 @@ def predict(image: str)->dict:
61
 
62
  prediction_tensor /= len(aug_combos)
63
  predictions = F.softmax(prediction_tensor, dim=1)[0].cpu().numpy()
 
64
  return {args.get("dxlabels")[enu]: p for enu, p in enumerate(predictions)}
65
 
66
  description = '''Research artifact for multi-class predictions of common
@@ -71,6 +76,7 @@ For education and research use only.
71
  **DO NOT use this to obtain medical advice!**
72
  If you have a skin change in question, seek contact to your physician.'''
73
 
 
74
  gr.Interface(
75
  predict,
76
  inputs=gr.Image(label="Upload a dermatoscopic image", type="filepath"),
 
8
  import itertools
9
  import numpy as np
10
  import os
11
+ import logging
12
 
13
 
14
  args = {
 
16
  "device": torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
17
  "dxlabels": ["akiec", "bcc", "bkl", "df", "mel", "nv","vasc"]
18
  }
19
+ logging.warning(args)
20
 
21
  # No specific normalization was performed during training
22
  def normtransform(x):
23
  return x
24
 
25
  # Load model
26
+ logging.warning("Loading model...")
27
  model = torchvision.models.resnet34()
28
  model.fc = torch.nn.Linear(model.fc.in_features, len(args.get("dxlabels")))
29
  model.load_state_dict(torch.load(args.get("model_path")))
30
  model.eval()
31
  model.to(device=args.get("device"))
32
  torch.set_grad_enabled(False)
33
+ logging.warning("Model loaded.")
34
 
35
  def predict(image: str)->dict:
36
+ global model, args
37
+
38
+ logging.warning(f"Starting predict('{image}') ...")
39
  prediction_tensor = torch.zeros([1, len(args.get("dxlabels"))]).to(device=args.get("device"))
40
 
41
  # Test-time augmentations
 
65
 
66
  prediction_tensor /= len(aug_combos)
67
  predictions = F.softmax(prediction_tensor, dim=1)[0].cpu().numpy()
68
+ logging.warning(f"Returning {predictions=}")
69
  return {args.get("dxlabels")[enu]: p for enu, p in enumerate(predictions)}
70
 
71
  description = '''Research artifact for multi-class predictions of common
 
76
  **DO NOT use this to obtain medical advice!**
77
  If you have a skin change in question, seek contact to your physician.'''
78
 
79
+ logging.warning("Starting Gradio interface...")
80
  gr.Interface(
81
  predict,
82
  inputs=gr.Image(label="Upload a dermatoscopic image", type="filepath"),