File size: 1,279 Bytes
08efd84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c8621b
 
 
 
 
08efd84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os

import monai.networks.nets
import torch
from transformers import AutoConfig, AutoModel, PreTrainedModel
from vista3d_config import VISTA3DConfig


class VISTA3DModel(PreTrainedModel):
    """VISTA3D model wrapped as a Hugging Face ``PreTrainedModel``.

    The actual network is MONAI's ``vista3d132`` architecture, stored as
    ``self.network``. This wrapper adds Hugging Face config handling on top of
    it, so the model can be saved with ``save_pretrained``.
    """

    config_class = VISTA3DConfig

    def __init__(self, config):
        """Build the MONAI VISTA3D network described by *config*.

        Args:
            config: a ``VISTA3DConfig``. Its ``encoder_embed_dim`` and
                ``input_channels`` are forwarded to ``vista3d132``.

        Raises:
            ValueError: if ``config.model_type`` is not ``"VISTA3D"``.
        """
        super().__init__(config)
        if config.model_type != "VISTA3D":
            # Fail fast. Otherwise the model would be built without a
            # ``network`` attribute and only break later, in ``forward`` or
            # ``load_state_dict``, with a confusing AttributeError.
            raise ValueError(
                f"Unsupported model_type {config.model_type!r}; expected 'VISTA3D'."
            )
        self.network = monai.networks.nets.vista3d132(
            encoder_embed_dim=config.encoder_embed_dim,
            in_channels=config.input_channels,
        )

    def forward(self, input, *args, **kwargs):
        """Run the wrapped VISTA3D network and return its output unchanged.

        Args:
            input: first positional argument of ``self.network``. This is
                presumably the image tensor; confirm against the MONAI API.
            *args, **kwargs: any further arguments are forwarded verbatim to
                ``self.network``. Examples include prompt arguments of MONAI's
                VISTA3D forward; check the installed MONAI version.
        """
        # ``input`` shadows the builtin but is kept so keyword callers
        # (``model(input=x)``) keep working.
        return self.network(input, *args, **kwargs)


def register_my_model():
    """Register VISTA3D with the transformers ``Auto*`` factories.

    After this call, ``AutoConfig`` maps the ``"VISTA3D"`` model type to
    ``VISTA3DConfig``, and ``AutoModel`` maps ``VISTA3DConfig`` to
    ``VISTA3DModel``. The model can then be instantiated through ``AutoModel``.
    """
    model_type = "VISTA3D"
    # Register the config first; the model registration is keyed on it.
    AutoConfig.register(model_type, VISTA3DConfig)
    AutoModel.register(VISTA3DConfig, VISTA3DModel)


if __name__ == "__main__":
    # Convert the MONAI checkpoint at models/model.pt (next to this script)
    # into a Hugging Face pretrained-model directory.
    FILE_PATH = os.path.dirname(os.path.abspath(__file__))
    MODEL_WEIGHT_PATH = os.path.join(FILE_PATH, "models", "model.pt")
    MODEL_PATH = os.path.join(FILE_PATH, "vista3d_pretrained_model")
    config = VISTA3DConfig()
    hugging_face_model = VISTA3DModel(config)
    # map_location="cpu" lets a checkpoint saved from a GPU load on CPU-only
    # hosts. load_state_dict copies the tensors into the freshly built network.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(MODEL_WEIGHT_PATH, map_location="cpu")
    hugging_face_model.network.load_state_dict(state_dict)
    hugging_face_model.save_pretrained(MODEL_PATH)