yuxin committed on
Commit
84e7143
1 Parent(s): cc3ec16

add processor

Files changed (2)
  1. config_segvol.py +4 -1
  2. model_segvol_single.py +106 -0
config_segvol.py CHANGED
@@ -5,9 +5,12 @@ class SegVolConfig(PretrainedConfig):
 
     def __init__(
         self,
+        test_mode=True,
+        test_w_zoom=False,
         **kwargs,
     ):
         self.spatial_size = [32, 256, 256]
         self.patch_size = [4, 16, 16]
-        self.test_mode = True
+        self.test_mode = test_mode
+        self.test_w_zoom = test_w_zoom
         super().__init__(**kwargs)
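With this change, test_mode is no longer hardcoded and a test_w_zoom flag is exposed alongside it, so both can be set when the config is constructed. A minimal sketch of the intended usage, assuming the standard transformers PretrainedConfig constructor path (the import path and values below are illustrative, not part of the commit):

    from config_segvol import SegVolConfig  # illustrative import path

    cfg = SegVolConfig()                                   # defaults: test_mode=True, test_w_zoom=False
    train_cfg = SegVolConfig(test_mode=False)              # non-test configuration
    zoom_cfg = SegVolConfig(test_mode=True, test_w_zoom=True)

    print(cfg.spatial_size, cfg.patch_size)                # [32, 256, 256] [4, 16, 16]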
model_segvol_single.py CHANGED
@@ -1,5 +1,7 @@
 from transformers import PreTrainedModel
 from .config_segvol import SegVolConfig
+import numpy as np
+import monai.transforms as transforms
 
 class SegVolModel(PreTrainedModel):
     config_class = SegVolConfig
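The two new imports back the SegVolProcessor introduced in the next hunk: numpy for the mask bookkeeping and monai.transforms for the dictionary-style preprocessing pipeline. As a hedged aside (not part of the commit), MONAI's dictionary transforms follow the convention the processor relies on: they take and return a dict keyed by "image" and "label", and untouched keys pass through. Keys and shapes below are illustrative:

    import numpy as np
    import monai.transforms as transforms

    pipeline = transforms.Compose([
        transforms.ToTensord(keys=["image", "label"]),  # converts both arrays to torch tensors
    ])
    sample = {
        "image": np.random.rand(1, 32, 256, 256).astype(np.float32),
        "label": np.zeros((1, 32, 256, 256), dtype=np.int32),
    }
    out = pipeline(sample)  # dict in, dict out; other keys would pass through unchanged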
 
@@ -21,10 +23,114 @@ class SegVolModel(PreTrainedModel):
             patch_size=self.config.patch_size,
             test_mode=self.config.test_mode,
         )
+
+        self.processor = SegVolProcessor(spatial_size=self.config.spatial_size)
 
     def forward(self, image, text=None, boxes=None, points=None, **kwargs):
         return self.model.forward(image, text=text, boxes=boxes, points=points, **kwargs)
 
+# processor
+class SegVolProcessor():
+    def __init__(self, spatial_size) -> None:
+        self.img_loader = transforms.LoadImage()
+        self.transform = transforms.Compose(
+            [
+                ForegroundNormalization(keys=["image"]),
+                DimTranspose(keys=["image", "label"]),
+                MinMaxNormalization(),
+                transforms.CropForegroundd(keys=["image", "label"], source_key="image"),
+                transforms.ToTensord(keys=["image", "label"]),
+            ]
+        )
+        self.zoom_out_transform = transforms.Resized(keys=["image", "label"], spatial_size=spatial_size, mode='nearest-exact')
+
+    # ct_path is path for a ct scan file with nii.gz format
+    # gt_path is path for a ground truth file with nii.gz format
+    def preprocess_ct_gt(self, ct_path, gt_path, category):
+        item = {}
+        # generate ct_voxel_ndarray
+        ct_voxel_ndarray, _ = self.img_loader(ct_path)
+        ct_voxel_ndarray = np.array(ct_voxel_ndarray).squeeze()
+        ct_shape = ct_voxel_ndarray.shape
+        ct_voxel_ndarray = np.expand_dims(ct_voxel_ndarray, axis=0)
+        item['image'] = ct_voxel_ndarray
+
+        # generate gt_voxel_ndarray
+        gt_voxel_ndarray, _ = self.img_loader(gt_path)
+        gt_voxel_ndarray = np.array(gt_voxel_ndarray)
+        present_categories = np.unique(gt_voxel_ndarray)
+        gt_masks = []
+        for cls_idx in range(len(category)):
+            # ignore background
+            cls = cls_idx + 1
+            if cls not in present_categories:
+                gt_voxel_ndarray_category = np.zeros(ct_shape)
+                gt_masks.append(gt_voxel_ndarray_category)
+            else:
+                gt_voxel_ndarray_category = gt_voxel_ndarray.copy()
+                gt_voxel_ndarray_category[gt_voxel_ndarray != cls] = 0
+                gt_voxel_ndarray_category[gt_voxel_ndarray == cls] = 1
+                gt_masks.append(gt_voxel_ndarray_category)
+        gt_voxel_ndarray = np.stack(gt_masks, axis=0)
+        assert gt_voxel_ndarray.shape[0] == len(category) and gt_voxel_ndarray.shape[1:] == ct_voxel_ndarray.shape[1:]
+        item['label'] = gt_voxel_ndarray.astype(np.int32)
+
+        # transform
+        item = self.transform(item)
+        print('ready for zoom out')
+        item_zoom_out = self.zoom_out_transform(item)
+        item['zoom_out_image'] = item_zoom_out['image']
+        item['zoom_out_label'] = item_zoom_out['label']
+        print('Zoom_in image shape: ', item['image'].shape,
+              '\nZoom_in label shape: ', item['label'].shape,
+              '\nZoom_out image shape: ', item['zoom_out_image'].shape,
+              '\nZoom_out label shape: ', item['zoom_out_label'].shape,
+              )
+        return item
+
+class MinMaxNormalization(transforms.Transform):
+    def __call__(self, data):
+        d = dict(data)
+        k = "image"
+        d[k] = d[k] - d[k].min()
+        d[k] = d[k] / np.clip(d[k].max(), a_min=1e-8, a_max=None)
+        return d
+
+class DimTranspose(transforms.Transform):
+    def __init__(self, keys):
+        self.keys = keys
+
+    def __call__(self, data):
+        d = dict(data)
+        for key in self.keys:
+            d[key] = np.swapaxes(d[key], -1, -3)
+        return d
+
+class ForegroundNormalization(transforms.Transform):
+    def __init__(self, keys):
+        self.keys = keys
+
+    def __call__(self, data):
+        d = dict(data)
+
+        for key in self.keys:
+            d[key] = self.normalize(d[key])
+        return d
+
+    def normalize(self, ct_narray):
+        ct_voxel_ndarray = ct_narray.copy()
+        ct_voxel_ndarray = ct_voxel_ndarray.flatten()
+        thred = np.mean(ct_voxel_ndarray)
+        voxel_filtered = ct_voxel_ndarray[(ct_voxel_ndarray > thred)]
+        upper_bound = np.percentile(voxel_filtered, 99.95)
+        lower_bound = np.percentile(voxel_filtered, 00.05)
+        mean = np.mean(voxel_filtered)
+        std = np.std(voxel_filtered)
+        ### transform ###
+        ct_narray = np.clip(ct_narray, lower_bound, upper_bound)
+        ct_narray = (ct_narray - mean) / max(std, 1e-8)
+        return ct_narray
+
 # SegVol
 import torch
 import torch.nn as nn
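Taken together, the commit attaches a SegVolProcessor to the model so that CT / ground-truth pairs can be prepared via model.processor before calling forward. A hedged end-to-end sketch, assuming the repo is loadable through transformers' remote-code path; the repo id, file paths, and category list are placeholders, not from the commit, and the exact prompt format the wrapped SegVol model expects is defined elsewhere in the repo:

    import torch
    from transformers import AutoModel

    # assumes the repo registers SegVolModel for AutoModel with remote code;
    # SegVolModel.from_pretrained(...) would be the direct alternative
    model = AutoModel.from_pretrained("path/to/segvol-repo", trust_remote_code=True)

    categories = ["liver", "kidney", "spleen", "pancreas"]  # illustrative label set
    item = model.processor.preprocess_ct_gt("case0001_ct.nii.gz", "case0001_gt.nii.gz", categories)

    # 'image'/'label' keep the cropped native resolution (zoom-in view);
    # 'zoom_out_image'/'zoom_out_label' are resized to config.spatial_size ([32, 256, 256]).
    zoom_out_image = item["zoom_out_image"].unsqueeze(0)   # add a batch dimension

    with torch.no_grad():
        out = model(zoom_out_image, text=categories)        # matches forward(image, text=..., boxes=..., points=...)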