yuxin committed
Commit 13e6989 · 1 Parent(s): 72b250b
add model

Files changed:
- config.json (+1, -0)
- model_segvol_single.py (+8, -4)
config.json CHANGED
@@ -6,6 +6,7 @@
     "AutoConfig": "model_segvol_single.SegVolConfig",
     "AutoModel": "model_segvol_single.SegVolModel"
   },
+  "clip_model": "openai/clip-vit-base-patch32",
   "model_type": "segvol",
   "patch_size": [
     4,
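The new "clip_model" entry exposes the CLIP text checkpoint as an ordinary config field instead of a hard-coded string, so it can be overridden per checkpoint and read back after loading. A minimal usage sketch, assuming a hypothetical repo id "BAAI/SegVol" and the standard transformers auto classes (neither appears in this commit):

from transformers import AutoConfig, AutoModel

# trust_remote_code is required because SegVolConfig / SegVolModel live in model_segvol_single.py
config = AutoConfig.from_pretrained("BAAI/SegVol", trust_remote_code=True)  # hypothetical repo id
print(config.clip_model)  # "openai/clip-vit-base-patch32" unless overridden in config.json
model = AutoModel.from_pretrained("BAAI/SegVol", config=config, trust_remote_code=True)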
model_segvol_single.py CHANGED
@@ -8,11 +8,13 @@ class SegVolConfig(PretrainedConfig):
     def __init__(
         self,
         test_mode=True,
+        clip_model='openai/clip-vit-base-patch32',
         **kwargs,
     ):
         self.spatial_size = [32, 256, 256]
         self.patch_size = [4, 16, 16]
         self.test_mode = test_mode
+        self.clip_model = clip_model
         super().__init__(**kwargs)
 
 class SegVolModel(PreTrainedModel):
@@ -33,6 +35,7 @@ class SegVolModel(PreTrainedModel):
             prompt_encoder=sam_model.prompt_encoder,
             roi_size=self.config.spatial_size,
             patch_size=self.config.patch_size,
+            clip_model=self.config.clip_model,
             test_mode=self.config.test_mode,
         )
 
@@ -48,7 +51,7 @@ class SegVolModel(PreTrainedModel):
                 point_prompt_group=None,
                 use_zoom=True):
         assert image.shape[0] == 1 and zoomed_image.shape[0] == 1, 'batch size should be 1'
-        assert not (text_prompt is None and bbox_prompt_group is None and point_prompt_group is None)
+        assert not (text_prompt is None and bbox_prompt_group is None and point_prompt_group is None), 'Drive SegVol using at least one type of prompt'
         bbox_prompt, bbox_prompt_map, point_prompt, point_prompt_map=None, None, None, None
         if bbox_prompt_group is not None:
             bbox_prompt, bbox_prompt_map = bbox_prompt_group
@@ -363,6 +366,7 @@ class SegVol(nn.Module):
                 prompt_encoder,
                 roi_size,
                 patch_size,
+                clip_model,
                 test_mode=False,
                 ):
         super().__init__()
@@ -370,7 +374,7 @@ class SegVol(nn.Module):
         self.image_encoder = image_encoder
         self.mask_decoder = mask_decoder
         self.prompt_encoder = prompt_encoder
-        self.text_encoder = TextEncoder()
+        self.text_encoder = TextEncoder(clip_model)
         self.feat_shape = np.array(roi_size)/np.array(patch_size)
         self.test_mode = test_mode
         self.dice_loss = BinaryDiceLoss().to(self.custom_device)
@@ -547,12 +551,12 @@ class SegVol(nn.Module):
         return pseudo_labels, bboxes
 
 class TextEncoder(nn.Module):
-    def __init__(self):
+    def __init__(self, clip_model):
         super().__init__()
         self.custom_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         config = CLIPTextConfig()
         self.clip_text_model = CLIPTextModel(config)
-        self.tokenizer = AutoTokenizer.from_pretrained(
+        self.tokenizer = AutoTokenizer.from_pretrained(clip_model)
         self.dim_align = nn.Linear(512, 768)
         # freeze text encoder
         for param in self.clip_text_model.parameters():
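For reference, the configurable checkpoint only selects the tokenizer; the CLIP text weights themselves are still instantiated from a default CLIPTextConfig and aligned to the prompt-embedding width by dim_align. A minimal sketch of that text path, assuming the same transformers classes used in the file (organ names and shapes are illustrative only):

import torch
import torch.nn as nn
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel

clip_model = 'openai/clip-vit-base-patch32'   # now supplied via SegVolConfig.clip_model
tokenizer = AutoTokenizer.from_pretrained(clip_model)
text_model = CLIPTextModel(CLIPTextConfig())  # default CLIP text config: 512-dim features
dim_align = nn.Linear(512, 768)               # aligns CLIP features with the 768-dim prompt embedding used downstream

tokens = tokenizer(['liver', 'kidney'], padding=True, return_tensors='pt')
with torch.no_grad():
    text_feat = text_model(**tokens).pooler_output   # shape [2, 512]
print(dim_align(text_feat).shape)                    # torch.Size([2, 768])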