haritsahm commited on
Commit
a1f1417
1 Parent(s): 18ab064

Update codes to enable device selection

Browse files
Files changed (2) hide show
  1. app.py +29 -5
  2. configs/inference.json +1 -1
app.py CHANGED
@@ -1,11 +1,24 @@
 
 
1
  from pathlib import Path
2
- import torch
3
- from monai.bundle import ConfigParser
4
  import gradio as gr
5
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  parser = ConfigParser()
8
- parser.read_config(f="configs/inference.json")
9
  parser.read_meta(f="configs/metadata.json")
10
 
11
  inference = parser.get_parsed_content("inferer")
@@ -14,9 +27,17 @@ network = parser.get_parsed_content("network_def")
14
  preprocess = parser.get_parsed_content("preprocessing")
15
  postprocess = parser.get_parsed_content("postprocessing")
16
 
 
 
17
  state_dict = torch.load("models/model.pt")
18
  network.load_state_dict(state_dict, strict=True)
19
 
 
 
 
 
 
 
20
  label2color = {0: (0, 0, 0),
21
  1: (225, 24, 69), # RED
22
  2: (135, 233, 17), # GREEN
@@ -38,8 +59,11 @@ def visualize_instance_seg_mask(mask):
38
  def query_image(img):
39
  data = {"image": img}
40
  batch = preprocess(data)
 
 
 
 
41
 
42
- network.eval()
43
  with torch.no_grad():
44
  pred = inference(batch['image'].unsqueeze(dim=0), network)
45
 
@@ -65,7 +89,7 @@ with open('Description.md','r') as file:
65
  markdown_content = file.read()
66
 
67
  demo = gr.Interface(
68
- query_image,
69
  inputs=[gr.Image(type="filepath")],
70
  outputs="image",
71
  title="Medical Image Classification with MONAI - Pathology Nuclei Segmentation Classification",
 
1
+ import json
2
+ import os
3
  from pathlib import Path
4
+
 
5
  import gradio as gr
6
  import numpy as np
7
+ import torch
8
+ from monai.bundle import ConfigParser
9
+
10
+ with open("configs/inference.json") as f:
11
+ inference_config = json.load(f)
12
+
13
+ device = torch.device('cpu')
14
+ if torch.cuda.is_available():
15
+ device = torch.device('cuda:0')
16
+
17
+ # * NOTE: device must be hardcoded, config file won't affect the device selection
18
+ inference_config["device"] = device
19
 
20
  parser = ConfigParser()
21
+ parser.read_config(f=inference_config)
22
  parser.read_meta(f="configs/metadata.json")
23
 
24
  inference = parser.get_parsed_content("inferer")
 
27
  preprocess = parser.get_parsed_content("preprocessing")
28
  postprocess = parser.get_parsed_content("postprocessing")
29
 
30
+ use_fp16 = os.environ.get('USE_FP16', False)
31
+
32
  state_dict = torch.load("models/model.pt")
33
  network.load_state_dict(state_dict, strict=True)
34
 
35
+ network = network.to(device)
36
+ network.eval()
37
+
38
+ if use_fp16 and torch.cuda.is_available():
39
+ network = network.half()
40
+
41
  label2color = {0: (0, 0, 0),
42
  1: (225, 24, 69), # RED
43
  2: (135, 233, 17), # GREEN
 
59
  def query_image(img):
60
  data = {"image": img}
61
  batch = preprocess(data)
62
+ batch['image'] = batch['image'].to(device)
63
+
64
+ if use_fp16 and torch.cuda.is_available():
65
+ batch['image'] = batch['image'].half()
66
 
 
67
  with torch.no_grad():
68
  pred = inference(batch['image'].unsqueeze(dim=0), network)
69
 
 
89
  markdown_content = file.read()
90
 
91
  demo = gr.Interface(
92
+ query_image,
93
  inputs=[gr.Image(type="filepath")],
94
  outputs="image",
95
  title="Medical Image Classification with MONAI - Pathology Nuclei Segmentation Classification",
configs/inference.json CHANGED
@@ -12,7 +12,7 @@
12
  "hovernet_mode": "fast",
13
  "patch_size": 256,
14
  "out_size": 164,
15
- "device": "cpu",
16
  "network_def": {
17
  "_target_": "HoVerNet",
18
  "mode": "@hovernet_mode",
 
12
  "hovernet_mode": "fast",
13
  "patch_size": 256,
14
  "out_size": 164,
15
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
16
  "network_def": {
17
  "_target_": "HoVerNet",
18
  "mode": "@hovernet_mode",