File size: 1,279 Bytes
08efd84 1c8621b 08efd84 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
import os
import monai.networks.nets
import torch
from transformers import AutoConfig, AutoModel, PreTrainedModel
from vista3d_config import VISTA3DConfig
class VISTA3DModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around MONAI's VISTA3D network.

    The network is built with ``monai.networks.nets.vista3d132`` and stored as
    ``self.network``; ``forward`` simply delegates to it.
    """

    config_class = VISTA3DConfig

    def __init__(self, config):
        """Build the wrapped VISTA3D network from ``config``.

        Args:
            config: A ``VISTA3DConfig``. Its ``model_type`` must be
                ``"VISTA3D"``, and it must provide ``encoder_embed_dim`` and
                ``input_channels``. These are passed to ``vista3d132`` as
                ``encoder_embed_dim`` and ``in_channels``.

        Raises:
            ValueError: If ``config.model_type`` is not ``"VISTA3D"``.
        """
        super().__init__(config)
        # Fail fast: without this check an unsupported config produced an object
        # with no ``network`` attribute, which only broke later in ``forward``.
        if config.model_type != "VISTA3D":
            raise ValueError(
                f"Unsupported model_type {config.model_type!r}; expected 'VISTA3D'."
            )
        self.network = monai.networks.nets.vista3d132(
            encoder_embed_dim=config.encoder_embed_dim,
            in_channels=config.input_channels,
        )

    def forward(self, input, *args, **kwargs):
        """Run the wrapped VISTA3D network.

        Args:
            input: First positional argument for ``self.network``. Presumably
                the input image tensor (confirm against the MONAI VISTA3D
                ``forward`` signature).
            *args: Extra positional arguments, forwarded unchanged.
            **kwargs: Extra keyword arguments (e.g. prompt inputs, if the MONAI
                network accepts them; verify), forwarded unchanged.

        Returns:
            Whatever ``self.network`` returns for these arguments.
        """
        return self.network(input, *args, **kwargs)
def register_my_model():
    """Register VISTA3D with the ``transformers`` Auto* factories.

    After this call, ``AutoConfig`` maps the ``"VISTA3D"`` model type to
    ``VISTA3DConfig``, and ``AutoModel`` maps ``VISTA3DConfig`` to
    ``VISTA3DModel``. The model can therefore be instantiated through the
    ``AutoModel`` factory.
    """
    model_type = "VISTA3D"
    # The config must be registered first so AutoModel can resolve it.
    AutoConfig.register(model_type, VISTA3DConfig)
    AutoModel.register(VISTA3DConfig, VISTA3DModel)
if __name__ == "__main__":
    # Convert the raw PyTorch checkpoint next to this file into a Hugging Face
    # ``save_pretrained`` directory.

    # abspath: if the script is launched via a relative path, dirname(__file__)
    # may be "" and every path below would resolve against the current working
    # directory instead of this file's directory.
    FILE_PATH = os.path.dirname(os.path.abspath(__file__))
    MODEL_WEIGHT_PATH = os.path.join(FILE_PATH, "models/model.pt")
    MODEL_PATH = os.path.join(FILE_PATH, "vista3d_pretrained_model")

    config = VISTA3DConfig()
    hugging_face_model = VISTA3DModel(config)

    # map_location="cpu": a checkpoint saved from GPU tensors can still be
    # loaded on a machine without CUDA.
    # weights_only=True: restricts unpickling to tensors and plain containers,
    # limiting the risk of code execution from a malicious checkpoint. The
    # result goes straight into load_state_dict, so a plain state dict is
    # expected. Requires torch >= 1.13.
    state_dict = torch.load(MODEL_WEIGHT_PATH, map_location="cpu", weights_only=True)
    hugging_face_model.network.load_state_dict(state_dict)
    hugging_face_model.save_pretrained(MODEL_PATH)
|