from transformers import PreTrainedModel

from config_segvol import SegVolConfig
from network.model import SegVol
from segment_anything_volumetric import sam_model_registry


class SegVolModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that exposes a ``SegVol`` network.

    The wrapped network is stored on ``self.model`` and is assembled from the
    image encoder, mask decoder and prompt encoder of the object returned by
    the ``'vit'`` entry of ``sam_model_registry``.
    """

    config_class = SegVolConfig

    def __init__(self, config):
        """Build the underlying ``SegVol`` network from ``config``.

        Args:
            config: Configuration object. The attributes ``patch_size``,
                ``spatial_size`` and ``test_mode`` are read from it.
        """
        super().__init__(config)

        # The registry entry is called with (patch_size, spatial_size); only
        # its three sub-components are kept, the container itself is dropped.
        sam_model = sam_model_registry['vit'](
            self.config.patch_size,
            self.config.spatial_size,
        )
        self.model = SegVol(
            image_encoder=sam_model.image_encoder,
            mask_decoder=sam_model.mask_decoder,
            prompt_encoder=sam_model.prompt_encoder,
            roi_size=self.config.spatial_size,
            patch_size=self.config.patch_size,
            test_mode=self.config.test_mode,
        )

    def forward(self, image, text=None, boxes=None, points=None, **kwargs):
        """Run the wrapped ``SegVol`` network and return its result unchanged.

        Args:
            image: Input passed positionally to the wrapped network.
            text: Optional text prompt(s), forwarded as ``text=``.
            boxes: Optional box prompt(s), forwarded as ``boxes=``.
            points: Optional point prompt(s), forwarded as ``points=``.
            **kwargs: Any further keyword arguments, forwarded verbatim.

        Returns:
            Whatever the wrapped ``SegVol`` network returns.
        """
        # Call the module instance rather than ``.forward`` directly so that
        # any hooks registered on the wrapped network are executed.
        # NOTE(review): assumes SegVol is a torch.nn.Module (it is held as a
        # submodule of this PreTrainedModel) -- confirm in network/model.py.
        return self.model(
            image,
            text=text,
            boxes=boxes,
            points=points,
            **kwargs,
        )