yuxin commited on
Commit
0672fb5
1 Parent(s): e49ada4
Files changed (1) hide show
  1. model_segvol_single.py +121 -80
model_segvol_single.py CHANGED
@@ -123,7 +123,7 @@ class SegVolModel(PreTrainedModel):
123
  class SegVolProcessor():
124
  def __init__(self, spatial_size) -> None:
125
  self.img_loader = transforms.LoadImage()
126
- self.transform = transforms.Compose(
127
  [
128
  ForegroundNormalization(keys=["image"]),
129
  DimTranspose(keys=["image", "label"]),
@@ -134,6 +134,36 @@ class SegVolProcessor():
134
  )
135
  self.zoom_out_transform = transforms.Resized(keys=["image", "label"], spatial_size=spatial_size, mode='nearest-exact')
136
  self.custom_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
  # ct_path is path for a ct scan file with nii.gz format
139
  # gt_path is path for a ground truth file with nii.gz format
@@ -174,7 +204,7 @@ class SegVolProcessor():
174
  'image': ct_npy,
175
  'label': gt_npy
176
  }
177
- item = self.transform(item)
178
  item_zoom_out = self.zoom_out_transform(item)
179
  item['zoom_out_image'] = item_zoom_out['image']
180
  item['zoom_out_label'] = item_zoom_out['label']
@@ -223,6 +253,17 @@ class SegVolProcessor():
223
  preds_nii = nib.Nifti1Image(preds_save, affine=ct.affine, header=ct.header)
224
  nib.save(preds_nii, save_path)
225
 
 
 
 
 
 
 
 
 
 
 
 
226
  class MinMaxNormalization(transforms.Transform):
227
  def __call__(self, data):
228
  d = dict(data)
@@ -409,8 +450,8 @@ class SegVol(nn.Module):
409
  ## sl
410
  sl_loss = self.supervised_forward(image, image_embedding, img_shape, kwargs['train_organs'], kwargs['train_labels'])
411
  ## ssl
412
- ssl_loss = self.unsupervised_forward(image, image_embedding, kwargs['pseudo_seg_cleaned'], img_shape)
413
- return sl_loss, ssl_loss
414
 
415
  def forward_decoder(self, image_embedding, img_shape, text=None, boxes=None, points=None):
416
  with torch.no_grad():
@@ -456,20 +497,20 @@ class SegVol(nn.Module):
456
  sl_loss += sl_loss_dice + sl_loss_bce
457
  return sl_loss
458
 
459
- def unsupervised_forward(self, image, image_embedding, pseudo_seg_cleaned, img_shape):
460
- sll_loss = 0
461
- for iter in range(self.decoder_iter):
462
- if iter % 2 == 0:
463
- pseudo_labels, pseudo_points_prompt = self.build_pseudo_point_prompt_label(image.shape, pseudo_seg_cleaned)
464
- logits = self.forward_decoder(image_embedding, img_shape, text=None, boxes=None, points=pseudo_points_prompt)
465
- else:
466
- pseudo_labels, pseudo_bboxes_prompt = self.build_pseudo_box_prompt_label(image.shape, pseudo_seg_cleaned)
467
- logits = self.forward_decoder(image_embedding, img_shape, text=None, boxes=pseudo_bboxes_prompt, points=None)
468
- # cal loss
469
- sll_loss_dice = self.dice_loss.forward(logits.squeeze().float(), pseudo_labels.squeeze().float())
470
- sll_loss_bce = self.bce_loss.forward(logits.squeeze().float(), pseudo_labels.squeeze().float())
471
- sll_loss += sll_loss_dice + sll_loss_bce
472
- return sll_loss
473
 
474
  def build_prompt_label(self, bs, training_organs, train_labels):
475
  # generate prompt & label
@@ -501,68 +542,68 @@ class SegVol(nn.Module):
501
  iter_bboxes = torch.stack(iter_bboxes, dim=0).float().to(self.custom_device)
502
  return iter_points, iter_bboxes, iter_organs
503
 
504
- def build_pseudo_point_prompt_label(self, input_shape, seg_labels):
505
- pseudo_labels = torch.zeros(input_shape).to(self.custom_device)
506
- # generate points
507
- points = []
508
- point_labels = []
509
- for batch_idx in range(input_shape[0]):
510
- # generate pseudo label
511
- unique_ids = torch.unique(seg_labels[batch_idx])
512
- unique_ids = unique_ids[unique_ids != -1]
513
- region_id = random.choice(unique_ids).item()
514
- pseudo_labels[batch_idx][seg_labels[batch_idx]==region_id] = 1
515
- # generate point prompt
516
- num_positive_extra_max, num_negative_extra_max = 10, 10
517
- num_positive_extra = random.randint(4, num_positive_extra_max)
518
- num_negative_extra = random.randint(0, num_negative_extra_max)
519
- assert len(pseudo_labels[batch_idx][0].shape) == 3
520
- point, point_label = select_points(
521
- pseudo_labels[batch_idx][0],
522
- num_positive_extra=num_positive_extra,
523
- num_negative_extra=num_negative_extra,
524
- fix_extra_point_num=num_positive_extra_max + num_negative_extra_max)
525
- points.append(point)
526
- point_labels.append(point_label)
527
- points = torch.stack(points, dim=0).to(self.custom_device)
528
- point_labels = torch.stack(point_labels, dim=0).to(self.custom_device)
529
- pseudo_points_prompt = (points, point_labels)
530
- return pseudo_labels, pseudo_points_prompt
531
-
532
- def build_pseudo_box_prompt_label(self, input_shape, seg_labels_cleaned):
533
- pseudo_labels = torch.zeros(input_shape).to(self.custom_device)
534
- iter_bboxes = []
535
- # generate boxes
536
- for batch_idx in range(input_shape[0]):
537
- # generate ori pseudo label
538
- unique_ids = torch.unique(seg_labels_cleaned[batch_idx])
539
- unique_ids = unique_ids[unique_ids != -1]
540
- region_id = random.choice(unique_ids).item()
541
- pseudo_labels[batch_idx][seg_labels_cleaned[batch_idx]==region_id] = 1
542
- # generate box prompt
543
- box = generate_box(pseudo_labels[batch_idx][0])
544
- iter_bboxes.append(box)
545
- # refine pseudo label
546
- x_min, y_min, z_min, x_max, y_max, z_max = box
547
- binary_cube = torch.zeros_like(pseudo_labels[batch_idx][0]).int()
548
- binary_cube[x_min:x_max+1, y_min:y_max+1, z_min:z_max+1] = 1
549
- # cal iou
550
- mask_label = seg_labels_cleaned[batch_idx][0]
551
- assert binary_cube.shape == mask_label.shape, str(binary_cube.shape) + ' ' + str(mask_label.shape)
552
- mask_values_in_binary_cube = mask_label[binary_cube == 1]
553
- unique_mask_values = torch.unique(mask_values_in_binary_cube)
554
- # print('unique_mask_values ', unique_mask_values)
555
- for value in unique_mask_values:
556
- if value == -1: continue
557
- mask_area = (mask_label == value)
558
- intersection = (binary_cube & mask_area)
559
- iou = intersection.float().sum() / mask_area.float().sum()
560
- if iou > 0.90:
561
- # print(f"Mask value {value} has IOU > 0.90 in binary cube.")
562
- pseudo_labels[batch_idx][seg_labels_cleaned[batch_idx]==value] = 1
563
-
564
- bboxes = torch.stack(iter_bboxes, dim=0).float().to(self.custom_device)
565
- return pseudo_labels, bboxes
566
 
567
  class TextEncoder(nn.Module):
568
  def __init__(self, clip_model):
 
123
  class SegVolProcessor():
124
  def __init__(self, spatial_size) -> None:
125
  self.img_loader = transforms.LoadImage()
126
+ self.transform4test = transforms.Compose(
127
  [
128
  ForegroundNormalization(keys=["image"]),
129
  DimTranspose(keys=["image", "label"]),
 
134
  )
135
  self.zoom_out_transform = transforms.Resized(keys=["image", "label"], spatial_size=spatial_size, mode='nearest-exact')
136
  self.custom_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
137
+ self.transform4train = transforms.Compose(
138
+ [
139
+ transforms.AddChanneld(keys=["image"]),
140
+ DimTranspose(keys=["image", "label"]),
141
+ MinMaxNormalization(),
142
+ transforms.CropForegroundd(keys=["image", "label"], source_key="image"),
143
+ transforms.SpatialPadd(keys=["image", "label"], spatial_size=spatial_size, mode='constant'),
144
+ transforms.OneOf(transforms=[
145
+ transforms.Resized(keys=["image", "label"],spatial_size=spatial_size),
146
+ transforms.RandCropByPosNegLabeld(
147
+ keys=["image", "label"],
148
+ label_key="label",
149
+ spatial_size=spatial_size,
150
+ pos=5,
151
+ neg=1,
152
+ num_samples=1,
153
+ image_key="image",
154
+ image_threshold=0,
155
+ ),
156
+ ],
157
+ weights=[1, 3]
158
+ ),
159
+ transforms.RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=0),
160
+ transforms.RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=1),
161
+ transforms.RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=2),
162
+ transforms.RandScaleIntensityd(keys="image", factors=0.2, prob=0.2),
163
+ transforms.RandShiftIntensityd(keys="image", offsets=0.2, prob=0.2),
164
+ transforms.ToTensord(keys=["image", "label"]),
165
+ ]
166
+ )
167
 
168
  # ct_path is path for a ct scan file with nii.gz format
169
  # gt_path is path for a ground truth file with nii.gz format
 
204
  'image': ct_npy,
205
  'label': gt_npy
206
  }
207
+ item = self.transform4test(item)
208
  item_zoom_out = self.zoom_out_transform(item)
209
  item['zoom_out_image'] = item_zoom_out['image']
210
  item['zoom_out_label'] = item_zoom_out['label']
 
253
  preds_nii = nib.Nifti1Image(preds_save, affine=ct.affine, header=ct.header)
254
  nib.save(preds_nii, save_path)
255
 
256
+ def train_transform(self, ct_npy, gt_npy):
257
+ item = {
258
+ 'image': ct_npy,
259
+ 'label': gt_npy
260
+ }
261
+ item = self.transform4train(item)
262
+ if type(item) is list:
263
+ assert len(item) == 1
264
+ item = item[0]
265
+ return item
266
+
267
  class MinMaxNormalization(transforms.Transform):
268
  def __call__(self, data):
269
  d = dict(data)
 
450
  ## sl
451
  sl_loss = self.supervised_forward(image, image_embedding, img_shape, kwargs['train_organs'], kwargs['train_labels'])
452
  ## ssl
453
+ # ssl_loss = self.unsupervised_forward(image, image_embedding, kwargs['pseudo_seg_cleaned'], img_shape)
454
+ return sl_loss
455
 
456
  def forward_decoder(self, image_embedding, img_shape, text=None, boxes=None, points=None):
457
  with torch.no_grad():
 
497
  sl_loss += sl_loss_dice + sl_loss_bce
498
  return sl_loss
499
 
500
+ # def unsupervised_forward(self, image, image_embedding, pseudo_seg_cleaned, img_shape):
501
+ # sll_loss = 0
502
+ # for iter in range(self.decoder_iter):
503
+ # if iter % 2 == 0:
504
+ # pseudo_labels, pseudo_points_prompt = self.build_pseudo_point_prompt_label(image.shape, pseudo_seg_cleaned)
505
+ # logits = self.forward_decoder(image_embedding, img_shape, text=None, boxes=None, points=pseudo_points_prompt)
506
+ # else:
507
+ # pseudo_labels, pseudo_bboxes_prompt = self.build_pseudo_box_prompt_label(image.shape, pseudo_seg_cleaned)
508
+ # logits = self.forward_decoder(image_embedding, img_shape, text=None, boxes=pseudo_bboxes_prompt, points=None)
509
+ # # cal loss
510
+ # sll_loss_dice = self.dice_loss.forward(logits.squeeze().float(), pseudo_labels.squeeze().float())
511
+ # sll_loss_bce = self.bce_loss.forward(logits.squeeze().float(), pseudo_labels.squeeze().float())
512
+ # sll_loss += sll_loss_dice + sll_loss_bce
513
+ # return sll_loss
514
 
515
  def build_prompt_label(self, bs, training_organs, train_labels):
516
  # generate prompt & label
 
542
  iter_bboxes = torch.stack(iter_bboxes, dim=0).float().to(self.custom_device)
543
  return iter_points, iter_bboxes, iter_organs
544
 
545
+ # def build_pseudo_point_prompt_label(self, input_shape, seg_labels):
546
+ # pseudo_labels = torch.zeros(input_shape).to(self.custom_device)
547
+ # # generate points
548
+ # points = []
549
+ # point_labels = []
550
+ # for batch_idx in range(input_shape[0]):
551
+ # # generate pseudo label
552
+ # unique_ids = torch.unique(seg_labels[batch_idx])
553
+ # unique_ids = unique_ids[unique_ids != -1]
554
+ # region_id = random.choice(unique_ids).item()
555
+ # pseudo_labels[batch_idx][seg_labels[batch_idx]==region_id] = 1
556
+ # # generate point prompt
557
+ # num_positive_extra_max, num_negative_extra_max = 10, 10
558
+ # num_positive_extra = random.randint(4, num_positive_extra_max)
559
+ # num_negative_extra = random.randint(0, num_negative_extra_max)
560
+ # assert len(pseudo_labels[batch_idx][0].shape) == 3
561
+ # point, point_label = select_points(
562
+ # pseudo_labels[batch_idx][0],
563
+ # num_positive_extra=num_positive_extra,
564
+ # num_negative_extra=num_negative_extra,
565
+ # fix_extra_point_num=num_positive_extra_max + num_negative_extra_max)
566
+ # points.append(point)
567
+ # point_labels.append(point_label)
568
+ # points = torch.stack(points, dim=0).to(self.custom_device)
569
+ # point_labels = torch.stack(point_labels, dim=0).to(self.custom_device)
570
+ # pseudo_points_prompt = (points, point_labels)
571
+ # return pseudo_labels, pseudo_points_prompt
572
+
573
+ # def build_pseudo_box_prompt_label(self, input_shape, seg_labels_cleaned):
574
+ # pseudo_labels = torch.zeros(input_shape).to(self.custom_device)
575
+ # iter_bboxes = []
576
+ # # generate boxes
577
+ # for batch_idx in range(input_shape[0]):
578
+ # # generate ori pseudo label
579
+ # unique_ids = torch.unique(seg_labels_cleaned[batch_idx])
580
+ # unique_ids = unique_ids[unique_ids != -1]
581
+ # region_id = random.choice(unique_ids).item()
582
+ # pseudo_labels[batch_idx][seg_labels_cleaned[batch_idx]==region_id] = 1
583
+ # # generate box prompt
584
+ # box = generate_box(pseudo_labels[batch_idx][0])
585
+ # iter_bboxes.append(box)
586
+ # # refine pseudo label
587
+ # x_min, y_min, z_min, x_max, y_max, z_max = box
588
+ # binary_cube = torch.zeros_like(pseudo_labels[batch_idx][0]).int()
589
+ # binary_cube[x_min:x_max+1, y_min:y_max+1, z_min:z_max+1] = 1
590
+ # # cal iou
591
+ # mask_label = seg_labels_cleaned[batch_idx][0]
592
+ # assert binary_cube.shape == mask_label.shape, str(binary_cube.shape) + ' ' + str(mask_label.shape)
593
+ # mask_values_in_binary_cube = mask_label[binary_cube == 1]
594
+ # unique_mask_values = torch.unique(mask_values_in_binary_cube)
595
+ # # print('unique_mask_values ', unique_mask_values)
596
+ # for value in unique_mask_values:
597
+ # if value == -1: continue
598
+ # mask_area = (mask_label == value)
599
+ # intersection = (binary_cube & mask_area)
600
+ # iou = intersection.float().sum() / mask_area.float().sum()
601
+ # if iou > 0.90:
602
+ # # print(f"Mask value {value} has IOU > 0.90 in binary cube.")
603
+ # pseudo_labels[batch_idx][seg_labels_cleaned[batch_idx]==value] = 1
604
+
605
+ # bboxes = torch.stack(iter_bboxes, dim=0).float().to(self.custom_device)
606
+ # return pseudo_labels, bboxes
607
 
608
  class TextEncoder(nn.Module):
609
  def __init__(self, clip_model):