from transformers import PreTrainedModel from .config import MonoSceneConfig from monoscene.monoscene import MonoScene class MonoSceneModel(PreTrainedModel): config_class = MonoSceneConfig def __init__(self, config): super().__init__(config) self.model = MonoScene( dataset=config.dataset, n_classes=config.n_classes, feature=config.feature, project_res=['1', '2', '4', '8'], project_scale=config.project_scale, full_scene_size=config.full_scene_size ) def forward(self, tensor): return self.model.forward(tensor)