yuxin
commited on
Commit
•
d099612
1
Parent(s):
13e6989
add model
Browse files- model_segvol_single.py +7 -0
model_segvol_single.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
from transformers import PreTrainedModel, PretrainedConfig
|
2 |
import numpy as np
|
3 |
import monai.transforms as transforms
|
|
|
4 |
|
5 |
class SegVolConfig(PretrainedConfig):
|
6 |
model_type = "segvol"
|
@@ -209,6 +210,12 @@ class SegVolProcessor():
|
|
209 |
target = target.cpu()
|
210 |
return dice
|
211 |
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
class MinMaxNormalization(transforms.Transform):
|
213 |
def __call__(self, data):
|
214 |
d = dict(data)
|
|
|
1 |
from transformers import PreTrainedModel, PretrainedConfig
|
2 |
import numpy as np
|
3 |
import monai.transforms as transforms
|
4 |
+
import nibabel as nib
|
5 |
|
6 |
class SegVolConfig(PretrainedConfig):
|
7 |
model_type = "segvol"
|
|
|
210 |
target = target.cpu()
|
211 |
return dice
|
212 |
|
213 |
+
def save_preds(self, ct_path, save_path, logits_mask, start_coord, end_coord):
|
214 |
+
ct = nib.load(ct_path)
|
215 |
+
print(ct.header, ct.affine, ct.shape)
|
216 |
+
print(start_coord, end_coord)
|
217 |
+
return
|
218 |
+
|
219 |
class MinMaxNormalization(transforms.Transform):
|
220 |
def __call__(self, data):
|
221 |
d = dict(data)
|