yuxin commited on
Commit
e49ada4
1 Parent(s): d099612
Files changed (1) hide show
  1. model_segvol_single.py +10 -3
model_segvol_single.py CHANGED
@@ -212,9 +212,16 @@ class SegVolProcessor():
212
 
213
  def save_preds(self, ct_path, save_path, logits_mask, start_coord, end_coord):
214
  ct = nib.load(ct_path)
215
- print(ct.header, ct.affine, ct.shape)
216
- print(start_coord, end_coord)
217
- return
 
 
 
 
 
 
 
218
 
219
  class MinMaxNormalization(transforms.Transform):
220
  def __call__(self, data):
 
212
 
213
  def save_preds(self, ct_path, save_path, logits_mask, start_coord, end_coord):
214
  ct = nib.load(ct_path)
215
+ logits_mask = logits_mask.transpose(-1, -3)
216
+ start_coord[-1], start_coord[-3] = start_coord[-3], start_coord[-1]
217
+ end_coord[-1], end_coord[-3] = end_coord[-3], end_coord[-1]
218
+ preds_save = torch.zeros(ct.shape)
219
+ preds_save[start_coord[0]:end_coord[0],
220
+ start_coord[1]:end_coord[1],
221
+ start_coord[2]:end_coord[2]] = torch.sigmoid(logits_mask)
222
+ preds_save = torch.where(preds_save > 0.5, 1., 0.).numpy()
223
+ preds_nii = nib.Nifti1Image(preds_save, affine=ct.affine, header=ct.header)
224
+ nib.save(preds_nii, save_path)
225
 
226
  class MinMaxNormalization(transforms.Transform):
227
  def __call__(self, data):