osbm commited on
Commit
c7f853f
1 Parent(s): 9ddc211

add it to compose

Browse files
Files changed (1) hide show
  1. inference.py +7 -5
inference.py CHANGED
@@ -55,7 +55,8 @@ transforms = Compose(
55
  EnsureTyped(keys=keys),
56
  ConcatItemsd(keys=("t2"), name=CommonKeys.IMAGE, dim=0),
57
  ConcatItemsd(keys=("t2_anatomy_reader1"), name=CommonKeys.LABEL, dim=0),
58
- ]
 
59
  )
60
 
61
 
@@ -68,7 +69,8 @@ postprocessing = Compose(
68
  keys=CommonKeys.PRED,
69
  applied_labels=list(range(1, 3))
70
  ),
71
- ]
 
72
  )
73
  inferer = monai.inferers.SlidingWindowInferer(
74
  roi_size=(96, 96, 96),
@@ -133,19 +135,19 @@ def make_inference(data_dict:list) -> str:
133
  test_ds = Dataset(
134
  data=data_dict,
135
  transform=transforms,
136
- allow_missing_keys=True,
137
  )
138
  model.eval()
139
  with torch.no_grad():
140
  example = test_ds[0]
141
- label = example["t2_anatomy_reader1"]
142
  input_tensor = example["t2"].unsqueeze(0)
143
  input_tensor = input_tensor.to(device)
144
  output_tensor = inferer(input_tensor, model)
145
  output_tensor = output_tensor.argmax(dim=1, keepdim=False)
146
  output_tensor = output_tensor.squeeze(0).to(torch.device("cpu"))
147
 
148
- output_tensor = postprocessing({"pred": output_tensor, "label": label})["pred"]
 
149
  output_tensor = output_tensor.numpy().astype(np.uint8)
150
  target_shape = example["t2_meta_dict"]["spatial_shape"]
151
  output_tensor = resize_image(output_tensor, target_shape)
 
55
  EnsureTyped(keys=keys),
56
  ConcatItemsd(keys=("t2"), name=CommonKeys.IMAGE, dim=0),
57
  ConcatItemsd(keys=("t2_anatomy_reader1"), name=CommonKeys.LABEL, dim=0),
58
+ ],
59
+ allow_missing_keys=True,
60
  )
61
 
62
 
 
69
  keys=CommonKeys.PRED,
70
  applied_labels=list(range(1, 3))
71
  ),
72
+ ],
73
+ allow_missing_keys=True,
74
  )
75
  inferer = monai.inferers.SlidingWindowInferer(
76
  roi_size=(96, 96, 96),
 
135
  test_ds = Dataset(
136
  data=data_dict,
137
  transform=transforms,
 
138
  )
139
  model.eval()
140
  with torch.no_grad():
141
  example = test_ds[0]
142
+ # label = example["t2_anatomy_reader1"]
143
  input_tensor = example["t2"].unsqueeze(0)
144
  input_tensor = input_tensor.to(device)
145
  output_tensor = inferer(input_tensor, model)
146
  output_tensor = output_tensor.argmax(dim=1, keepdim=False)
147
  output_tensor = output_tensor.squeeze(0).to(torch.device("cpu"))
148
 
149
+ # output_tensor = postprocessing({"pred": output_tensor, "label": label})["pred"]
150
+ output_tensor = postprocessing({"pred": output_tensor})["pred"]
151
  output_tensor = output_tensor.numpy().astype(np.uint8)
152
  target_shape = example["t2_meta_dict"]["spatial_shape"]
153
  output_tensor = resize_image(output_tensor, target_shape)