yuxin
commited on
Commit
•
e49ada4
1
Parent(s):
d099612
add model
Browse files- model_segvol_single.py +10 -3
model_segvol_single.py
CHANGED
@@ -212,9 +212,16 @@ class SegVolProcessor():
|
|
212 |
|
213 |
def save_preds(self, ct_path, save_path, logits_mask, start_coord, end_coord):
|
214 |
ct = nib.load(ct_path)
|
215 |
-
|
216 |
-
|
217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
|
219 |
class MinMaxNormalization(transforms.Transform):
|
220 |
def __call__(self, data):
|
|
|
212 |
|
213 |
def save_preds(self, ct_path, save_path, logits_mask, start_coord, end_coord):
|
214 |
ct = nib.load(ct_path)
|
215 |
+
logits_mask = logits_mask.transpose(-1, -3)
|
216 |
+
start_coord[-1], start_coord[-3] = start_coord[-3], start_coord[-1]
|
217 |
+
end_coord[-1], end_coord[-3] = end_coord[-3], end_coord[-1]
|
218 |
+
preds_save = torch.zeros(ct.shape)
|
219 |
+
preds_save[start_coord[0]:end_coord[0],
|
220 |
+
start_coord[1]:end_coord[1],
|
221 |
+
start_coord[2]:end_coord[2]] = torch.sigmoid(logits_mask)
|
222 |
+
preds_save = torch.where(preds_save > 0.5, 1., 0.).numpy()
|
223 |
+
preds_nii = nib.Nifti1Image(preds_save, affine=ct.affine, header=ct.header)
|
224 |
+
nib.save(preds_nii, save_path)
|
225 |
|
226 |
class MinMaxNormalization(transforms.Transform):
|
227 |
def __call__(self, data):
|