yuxin committed
Commit 13e6989 · 1 Parent(s): 72b250b
Files changed (2)
  1. config.json +1 -0
  2. model_segvol_single.py +8 -4
config.json CHANGED
```diff
@@ -6,6 +6,7 @@
     "AutoConfig": "model_segvol_single.SegVolConfig",
     "AutoModel": "model_segvol_single.SegVolModel"
   },
+  "clip_model": "openai/clip-vit-base-patch32",
   "model_type": "segvol",
   "patch_size": [
     4,
```
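This commit makes the CLIP checkpoint used by the text encoder configurable from config.json instead of hardcoded. A minimal usage sketch, assuming the model is loaded through the `auto_map` entries above with `trust_remote_code` (the repo id below is a placeholder, not taken from this commit):

```python
from transformers import AutoConfig, AutoModel

repo_id = "<segvol-repo-id>"  # placeholder; substitute the actual Hub repo

# clip_model defaults to 'openai/clip-vit-base-patch32'; override it to point
# the tokenizer at a mirror or local clone, e.g. for offline environments.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
config.clip_model = "/path/to/local/clip-vit-base-patch32"
model = AutoModel.from_pretrained(repo_id, config=config, trust_remote_code=True)
```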
model_segvol_single.py CHANGED
```diff
@@ -8,11 +8,13 @@ class SegVolConfig(PretrainedConfig):
     def __init__(
         self,
         test_mode=True,
+        clip_model='openai/clip-vit-base-patch32',
         **kwargs,
     ):
         self.spatial_size = [32, 256, 256]
         self.patch_size = [4, 16, 16]
         self.test_mode = test_mode
+        self.clip_model = clip_model
         super().__init__(**kwargs)
 
 class SegVolModel(PreTrainedModel):
@@ -33,6 +35,7 @@ class SegVolModel(PreTrainedModel):
             prompt_encoder=sam_model.prompt_encoder,
             roi_size=self.config.spatial_size,
             patch_size=self.config.patch_size,
+            clip_model=self.config.clip_model,
             test_mode=self.config.test_mode,
         )
 
@@ -48,7 +51,7 @@ class SegVolModel(PreTrainedModel):
                 point_prompt_group=None,
                 use_zoom=True):
         assert image.shape[0] == 1 and zoomed_image.shape[0] == 1, 'batch size should be 1'
-        assert not (text_prompt is None and bbox_prompt_group is None and point_prompt_group is None)
+        assert not (text_prompt is None and bbox_prompt_group is None and point_prompt_group is None), 'Drive SegVol using at least one type of prompt'
         bbox_prompt, bbox_prompt_map, point_prompt, point_prompt_map=None, None, None, None
         if bbox_prompt_group is not None:
             bbox_prompt, bbox_prompt_map = bbox_prompt_group
@@ -363,6 +366,7 @@ class SegVol(nn.Module):
                  prompt_encoder,
                  roi_size,
                  patch_size,
+                 clip_model,
                  test_mode=False,
                  ):
         super().__init__()
@@ -370,7 +374,7 @@ class SegVol(nn.Module):
         self.image_encoder = image_encoder
         self.mask_decoder = mask_decoder
         self.prompt_encoder = prompt_encoder
-        self.text_encoder = TextEncoder()
+        self.text_encoder = TextEncoder(clip_model)
         self.feat_shape = np.array(roi_size)/np.array(patch_size)
         self.test_mode = test_mode
         self.dice_loss = BinaryDiceLoss().to(self.custom_device)
@@ -547,12 +551,12 @@ class SegVol(nn.Module):
         return pseudo_labels, bboxes
 
 class TextEncoder(nn.Module):
-    def __init__(self):
+    def __init__(self, clip_model):
         super().__init__()
         self.custom_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         config = CLIPTextConfig()
         self.clip_text_model = CLIPTextModel(config)
-        self.tokenizer = AutoTokenizer.from_pretrained('openai/clip-vit-base-patch32')
+        self.tokenizer = AutoTokenizer.from_pretrained(clip_model)
         self.dim_align = nn.Linear(512, 768)
         # freeze text encoder
         for param in self.clip_text_model.parameters():
```
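Note that only the tokenizer source becomes injectable here: the text tower is still built from a default `CLIPTextConfig` (512-dim hidden size), whose output `dim_align` projects to 768. A sketch of the resulting pattern; the class name below is an illustrative stand-in, not the module's actual `TextEncoder`:

```python
import torch.nn as nn
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel

class InjectableTextEncoder(nn.Module):
    """Illustrative stand-in for TextEncoder after this commit."""
    def __init__(self, clip_model: str):
        super().__init__()
        # The tokenizer source is now a constructor argument (Hub id or local path)
        self.tokenizer = AutoTokenizer.from_pretrained(clip_model)
        # The text tower still uses library defaults (hidden_size=512)
        self.clip_text_model = CLIPTextModel(CLIPTextConfig())
        # Align CLIP's 512-dim output to the 768-dim features the model expects
        self.dim_align = nn.Linear(512, 768)

encoder = InjectableTextEncoder("openai/clip-vit-base-patch32")
```

Since the text-model weights presumably arrive with the rest of the SegVol checkpoint at load time, an overridden `clip_model` only needs to supply a tokenizer compatible with CLIP's vocabulary.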