File size: 908 Bytes
2fbf9d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from transformers import PreTrainedModel
from config_segvol import SegVolConfig
from network.model import SegVol
from segment_anything_volumetric import sam_model_registry


class SegVolModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the SegVol network.

    The wrapped ``SegVol`` instance is assembled from the ``image_encoder``,
    ``mask_decoder`` and ``prompt_encoder`` of the model built by the
    ``'vit'`` entry of ``sam_model_registry``, configured from a
    :class:`SegVolConfig`.
    """

    config_class = SegVolConfig

    def __init__(self, config):
        """Build the SAM-style components and the SegVol model from ``config``.

        Args:
            config: A ``SegVolConfig``; this constructor reads its
                ``patch_size``, ``spatial_size`` and ``test_mode`` attributes.
        """
        super().__init__(config)
        # NOTE(review): the registry builder receives its arguments
        # positionally; confirm its expected order is
        # (patch_size, spatial_size).
        sam_model = sam_model_registry['vit'](
            self.config.patch_size, self.config.spatial_size
        )
        self.model = SegVol(
            image_encoder=sam_model.image_encoder,
            mask_decoder=sam_model.mask_decoder,
            prompt_encoder=sam_model.prompt_encoder,
            # SegVol's ROI size is taken from the configured spatial size.
            roi_size=self.config.spatial_size,
            patch_size=self.config.patch_size,
            test_mode=self.config.test_mode,
        )

    def forward(self, image, text=None, boxes=None, points=None, **kwargs):
        """Run the wrapped SegVol model on ``image`` with optional prompts.

        Args:
            image: Forwarded unchanged as SegVol's first positional argument.
            text: Optional; forwarded as the ``text`` keyword.
            boxes: Optional; forwarded as the ``boxes`` keyword.
            points: Optional; forwarded as the ``points`` keyword.
            **kwargs: Any extra keyword arguments, forwarded verbatim.

        Returns:
            Whatever the wrapped ``SegVol`` model returns for these inputs.
        """
        # Invoke the submodule through ``__call__`` rather than ``.forward()``
        # so forward pre/post hooks registered on it are not bypassed.
        # NOTE(review): assumes SegVol is a torch.nn.Module (it is assigned as
        # a submodule of this PreTrainedModel above) — confirm.
        return self.model(image, text=text, boxes=boxes, points=points, **kwargs)