fadindashfr commited on
Commit
fbe0f24
·
1 Parent(s): 4c5329d

fix RuntimeError 'cpu'

Browse files
Files changed (1) hide show
  1. app.py +20 -17
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
  from monai.bundle import ConfigParser
3
  import gradio as gr
 
4
 
5
  parser = ConfigParser() # load configuration files that specify various parameters for running the MONAI workflow.
6
  parser.read_config(f="configs/inference.json") # read the config from specified JSON file
@@ -9,8 +10,9 @@ parser.read_meta(f="configs/metadata.json") # read the metadata from specified J
9
  inference = parser.get_parsed_content("inferer")
10
  network = parser.get_parsed_content("network_def")
11
  preprocess = parser.get_parsed_content("preprocessing")
12
- state_dict = torch.load("models/model.pt")
13
  network.load_state_dict(state_dict, strict=True) # Loads a model’s parameter dictionary
 
14
  class_names = {
15
  0: "Other",
16
  1: "Inflammatory",
@@ -21,6 +23,7 @@ class_names = {
21
  def classify_image(image_file, label_file):
22
  data = {"image":image_file, "label":label_file}
23
  batch = preprocess(data)
 
24
  network.eval()
25
  with torch.no_grad():
26
  pred = inference(batch['image'].unsqueeze(dim=0), network) # expect 4 channels input (3 RGB, 1 Label mask)
@@ -29,25 +32,25 @@ def classify_image(image_file, label_file):
29
  return confidences
30
 
31
  example_files1 = [
32
- [r'sample_data\Images\test_11_2_0628.png',
33
- r'sample_data\Labels\test_11_2_0628.png'],
34
- [r'sample_data\Images\test_9_4_0149.png',
35
- r'sample_data\Labels\test_9_4_0149.png'],
36
- [r'sample_data\Images\test_12_3_0292.png',
37
- r'sample_data\Labels\test_12_3_0292.png'],
38
- [r'sample_data\Images\test_9_4_0019.png',
39
- r'sample_data\Labels\test_9_4_0019.png']
40
  ]
41
 
42
  example_files2 = [
43
- [r'sample_data\Images\test_14_3_0433.png',
44
- r'sample_data\Labels\test_14_3_0433.png'],
45
- [r'sample_data\Images\test_14_4_0544.png',
46
- r'sample_data\Labels\test_14_4_0544.png'],
47
- [r'sample_data\Images\train_1_1_0095.png',
48
- r'sample_data\Labels\train_1_1_0095.png'],
49
- [r'sample_data\Images\train_1_3_0020.png',
50
- r'sample_data\Labels\train_1_3_0020.png'],
51
  ]
52
 
53
  with open('Description.md','r') as file:
 
1
  import torch
2
  from monai.bundle import ConfigParser
3
  import gradio as gr
4
+ import json
5
 
6
  parser = ConfigParser() # load configuration files that specify various parameters for running the MONAI workflow.
7
  parser.read_config(f="configs/inference.json") # read the config from specified JSON file
 
10
  inference = parser.get_parsed_content("inferer")
11
  network = parser.get_parsed_content("network_def")
12
  preprocess = parser.get_parsed_content("preprocessing")
13
+ state_dict = torch.load("models/model.pt", map_location=torch.device('cpu'))
14
  network.load_state_dict(state_dict, strict=True) # Loads a model’s parameter dictionary
15
+
16
  class_names = {
17
  0: "Other",
18
  1: "Inflammatory",
 
23
  def classify_image(image_file, label_file):
24
  data = {"image":image_file, "label":label_file}
25
  batch = preprocess(data)
26
+ batch['image'] = batch['image']
27
  network.eval()
28
  with torch.no_grad():
29
  pred = inference(batch['image'].unsqueeze(dim=0), network) # expect 4 channels input (3 RGB, 1 Label mask)
 
32
  return confidences
33
 
34
  example_files1 = [
35
+ ['sample_data/Images/test_11_2_0628.png',
36
+ 'sample_data/Labels/test_11_2_0628.png'],
37
+ ['sample_data/Images/test_9_4_0149.png',
38
+ 'sample_data/Labels/test_9_4_0149.png'],
39
+ ['sample_data/Images/test_12_3_0292.png',
40
+ 'sample_data/Labels/test_12_3_0292.png'],
41
+ ['sample_data/Images/test_9_4_0019.png',
42
+ 'sample_data/Labels/test_9_4_0019.png']
43
  ]
44
 
45
  example_files2 = [
46
+ ['sample_data/Images/test_14_3_0433.png',
47
+ 'sample_data/Labels/test_14_3_0433.png'],
48
+ ['sample_data/Images/test_14_4_0544.png',
49
+ 'sample_data/Labels/test_14_4_0544.png'],
50
+ ['sample_data/Images/train_1_1_0095.png',
51
+ 'sample_data/Labels/train_1_1_0095.png'],
52
+ ['sample_data/Images/train_1_3_0020.png',
53
+ 'sample_data/Labels/train_1_3_0020.png'],
54
  ]
55
 
56
  with open('Description.md','r') as file: