ptschandl commited on
Commit
64f0bb0
β€’
1 Parent(s): a135607

formatting, more examples

Browse files
app.py CHANGED
@@ -13,76 +13,92 @@ import logging
13
 
14
  args = {
15
  "model_path": "model_last_epoch_34_torchvision0_3_state.ptw",
16
- "device": torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
17
- "dxlabels": ["akiec", "bcc", "bkl", "df", "mel", "nv","vasc"]
18
  }
19
  logging.warning(args)
20
 
 
21
  # No specific normalization was performed during training
22
  def normtransform(x):
23
  return x
24
 
 
25
  # Load model
26
  logging.warning("Loading model...")
27
  model = torchvision.models.resnet34()
28
- model.fc = torch.nn.Linear(model.fc.in_features, len(args.get("dxlabels")))
29
- model.load_state_dict(torch.load(args.get("model_path")))
30
  model.eval()
31
- model.to(device=args.get("device"))
32
  torch.set_grad_enabled(False)
33
  logging.warning("Model loaded.")
34
 
35
- def predict(image: str)->dict:
 
36
  global model, args
37
 
38
  logging.warning(f"Starting predict('{image}') ...")
39
- prediction_tensor = torch.zeros([1, len(args.get("dxlabels"))]).to(device=args.get("device"))
 
 
40
 
41
  # Test-time augmentations
42
  available_sizes = [224]
43
  target_sizes, hflips, rotations, crops = available_sizes, [0, 1], [0, 90], [0.8]
44
  aug_combos = [x for x in itertools.product(target_sizes, hflips, rotations, crops)]
45
-
46
  # Load image
47
  img = Image.open(image)
48
- img = img.convert('RGB')
49
 
50
  # Predict with Test-time augmentation
51
- for (target_size, hflip, rotation, crop) in aug_combos:
52
- tfm = transforms.Compose([
53
- transforms.Resize(int(target_size // crop)),
54
- transforms.CenterCrop(target_size),
55
- transforms.RandomHorizontalFlip(hflip),
56
- transforms.RandomRotation([rotation, rotation]),
57
- transforms.ToTensor(),
58
- # shades_of_grey_torch,
59
- normtransform
60
- ])
61
- test_data = tfm(img).unsqueeze(0).to(device=args.get("device"))
62
- running_preds = torch.FloatTensor().to(device=args.get("device"))
63
- outputs = model(test_data)
64
- prediction_tensor += running_preds
65
-
 
66
  prediction_tensor /= len(aug_combos)
67
- predictions = F.softmax(prediction_tensor, dim=1)[0].cpu().numpy()
68
  logging.warning(f"Returning {predictions=}")
69
- return {args.get("dxlabels")[enu]: p for enu, p in enumerate(predictions)}
 
 
 
 
 
 
 
 
 
70
 
71
- description = '''Research artifact for multi-class predictions of common
72
- dermatologic tumors. This is the model used in the publication
73
- [Tschandl P. et al. Nature Medicine 2020](https://www.nature.com/articles/s41591-020-0942-0).
74
 
75
- For education and research use only.
76
- **DO NOT use this to obtain medical advice!**
77
- If you have a skin change in question, seek contact to your physician.'''
78
 
79
  logging.warning("Starting Gradio interface...")
80
  gr.Interface(
81
  predict,
82
  inputs=gr.Image(label="Upload a dermatoscopic image", type="filepath"),
83
- outputs=gr.Label(num_top_classes=len(args.get("dxlabels"))),
84
  title="Dermatoscopic classification",
85
  description=description,
86
- allow_flagging="manual",
87
- examples=[os.path.join(os.path.dirname(__file__), "ISIC_0024306.jpg")]
 
 
 
88
  ).launch()
 
13
 
14
  args = {
15
  "model_path": "model_last_epoch_34_torchvision0_3_state.ptw",
16
+ "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
17
+ "dxlabels": ["akiec", "bcc", "bkl", "df", "mel", "nv", "vasc"],
18
  }
19
  logging.warning(args)
20
 
21
+
22
  # No specific normalization was performed during training
23
  def normtransform(x):
24
  return x
25
 
26
+
27
  # Load model
28
  logging.warning("Loading model...")
29
  model = torchvision.models.resnet34()
30
+ model.fc = torch.nn.Linear(model.fc.in_features, len(args["dxlabels"]))
31
+ model.load_state_dict(torch.load(args["model_path"]))
32
  model.eval()
33
+ model.to(device=args["device"])
34
  torch.set_grad_enabled(False)
35
  logging.warning("Model loaded.")
36
 
37
+
38
+ def predict(image: str) -> dict:
39
  global model, args
40
 
41
  logging.warning(f"Starting predict('{image}') ...")
42
+ prediction_tensor = torch.zeros([1, len(args["dxlabels"])]).to(
43
+ device=args["device"]
44
+ )
45
 
46
  # Test-time augmentations
47
  available_sizes = [224]
48
  target_sizes, hflips, rotations, crops = available_sizes, [0, 1], [0, 90], [0.8]
49
  aug_combos = [x for x in itertools.product(target_sizes, hflips, rotations, crops)]
50
+
51
  # Load image
52
  img = Image.open(image)
53
+ img = img.convert("RGB")
54
 
55
  # Predict with Test-time augmentation
56
+ for target_size, hflip, rotation, crop in aug_combos:
57
+ tfm = transforms.Compose(
58
+ [
59
+ transforms.Resize(int(target_size // crop)),
60
+ transforms.CenterCrop(target_size),
61
+ transforms.RandomHorizontalFlip(hflip),
62
+ transforms.RandomRotation([rotation, rotation]),
63
+ transforms.ToTensor(),
64
+ normtransform,
65
+ ]
66
+ )
67
+ test_data = tfm(img).unsqueeze(0).to(device=args["device"])
68
+ with torch.no_grad():
69
+ outputs = model(test_data)
70
+ prediction_tensor += outputs
71
+
72
  prediction_tensor /= len(aug_combos)
73
+ predictions = F.softmax(prediction_tensor, dim=1)[0].detach().cpu().tolist()
74
  logging.warning(f"Returning {predictions=}")
75
+ return {args["dxlabels"][enu]: p for enu, p in enumerate(predictions)}
76
+
77
+
78
+ description = """
79
+ Research image classification model for multi-class predictions of common dermatologic tumors, the model was trained
80
+ on the [HAM10000 dataset](https://www.nature.com/articles/sdata2018161).
81
+
82
+ This is the model used in the publication
83
+ [Tschandl P. et al. Nature Medicine 2020](https://www.nature.com/articles/s41591-020-0942-0)
84
+ where human-computer interaction of such a system was analyzed.
85
 
86
+ Instructions for uploading: Ensure the image is not blurred, the lesion is centered and in focus, and no black/white
87
+ vignette is in the surrounding. The image should depict the whole lesion, and not a zoomed-in part.
 
88
 
89
+ For education and research use only. **DO NOT use this to obtain medical advice!**
90
+ If you have a skin change in question, seek contact to your physician."""
 
91
 
92
  logging.warning("Starting Gradio interface...")
93
  gr.Interface(
94
  predict,
95
  inputs=gr.Image(label="Upload a dermatoscopic image", type="filepath"),
96
+ outputs=gr.Label(num_top_classes=len(args["dxlabels"])),
97
  title="Dermatoscopic classification",
98
  description=description,
99
+ allow_flagging="never",
100
+ examples=[
101
+ os.path.join(os.path.dirname(__file__), "images", x)
102
+ for x in ["ISIC_0024306.jpg", "ISIC_0024315.jpg", "ISIC_0024318.jpg"]
103
+ ],
104
  ).launch()
ISIC_0024306.jpg β†’ images/ISIC_0024306.jpg RENAMED
File without changes
images/ISIC_0024315.jpg ADDED
images/ISIC_0024318.jpg ADDED