yuxin commited on
Commit
d099612
1 Parent(s): 13e6989
Files changed (1) hide show
  1. model_segvol_single.py +7 -0
model_segvol_single.py CHANGED
@@ -1,6 +1,7 @@
1
  from transformers import PreTrainedModel, PretrainedConfig
2
  import numpy as np
3
  import monai.transforms as transforms
 
4
 
5
  class SegVolConfig(PretrainedConfig):
6
  model_type = "segvol"
@@ -209,6 +210,12 @@ class SegVolProcessor():
209
  target = target.cpu()
210
  return dice
211
 
 
 
 
 
 
 
212
  class MinMaxNormalization(transforms.Transform):
213
  def __call__(self, data):
214
  d = dict(data)
 
1
  from transformers import PreTrainedModel, PretrainedConfig
2
  import numpy as np
3
  import monai.transforms as transforms
4
+ import nibabel as nib
5
 
6
  class SegVolConfig(PretrainedConfig):
7
  model_type = "segvol"
 
210
  target = target.cpu()
211
  return dice
212
 
213
+ def save_preds(self, ct_path, save_path, logits_mask, start_coord, end_coord):
214
+ ct = nib.load(ct_path)
215
+ print(ct.header, ct.affine, ct.shape)
216
+ print(start_coord, end_coord)
217
+ return
218
+
219
  class MinMaxNormalization(transforms.Transform):
220
  def __call__(self, data):
221
  d = dict(data)