yuxin commited on
Commit
2fbf9d3
1 Parent(s): f205f50
Files changed (4) hide show
  1. config.json +23 -0
  2. config_segvol.py +13 -0
  3. model_segvol.py +24 -0
  4. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "SegVolModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "config_segvol.SegVolConfig",
7
+ "AutoModel": "model_segvol.SegVolModel"
8
+ },
9
+ "model_type": "segvol",
10
+ "patch_size": [
11
+ 4,
12
+ 16,
13
+ 16
14
+ ],
15
+ "spatial_size": [
16
+ 32,
17
+ 256,
18
+ 256
19
+ ],
20
+ "test_mode": true,
21
+ "torch_dtype": "float32",
22
+ "transformers_version": "4.18.0"
23
+ }
config_segvol.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+ class SegVolConfig(PretrainedConfig):
4
+ model_type = "segvol"
5
+
6
+ def __init__(
7
+ self,
8
+ **kwargs,
9
+ ):
10
+ self.spatial_size = [32, 256, 256]
11
+ self.patch_size = [4, 16, 16]
12
+ self.test_mode = True
13
+ super().__init__(**kwargs)
model_segvol.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ from config_segvol import SegVolConfig
3
+ from network.model import SegVol
4
+ from segment_anything_volumetric import sam_model_registry
5
+
6
+
7
+ class SegVolModel(PreTrainedModel):
8
+ config_class = SegVolConfig
9
+
10
+ def __init__(self, config):
11
+ super().__init__(config)
12
+ sam_model = sam_model_registry['vit'](self.config.patch_size, self.config.spatial_size)
13
+ self.model = SegVol(
14
+ image_encoder=sam_model.image_encoder,
15
+ mask_decoder=sam_model.mask_decoder,
16
+ prompt_encoder=sam_model.prompt_encoder,
17
+ roi_size=self.config.spatial_size,
18
+ patch_size=self.config.patch_size,
19
+ test_mode=self.config.test_mode,
20
+ )
21
+
22
+ def forward(self, image, text=None, boxes=None, points=None, **kwargs):
23
+ return self.model.forward(image, text=text, boxes=boxes, points=points, **kwargs)
24
+
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:500f2758a8f989339b2b2baf09a819169bc87549795193d3cfe505726ac0b399
3
+ size 723726667