"""Helper for building VISTA3D Hugging Face pipelines."""

from transformers import pipeline

from vista3d_config import VISTA3DConfig
from vista3d_model import VISTA3DModel, register_my_model
from vista3d_pipeline import VISTA3DPipeline, register_simple_pipeline


class HuggingFacePipelineHelper:
    """Create VISTA3D pipelines by registered name or from a checkpoint.

    Args:
        pipeline_name: Name passed as the first positional argument to
            ``transformers.pipeline`` by :meth:`get_pipeline`.
    """

    def __init__(self, pipeline_name: str = "vista3d") -> None:
        self.pipeline_name = pipeline_name

    def __model_register(self) -> None:
        """Call ``register_my_model`` from ``vista3d_model``."""
        register_my_model()

    def __pipeline_register(self) -> None:
        """Call ``register_simple_pipeline`` from ``vista3d_pipeline``."""
        register_simple_pipeline()

    def get_pipeline(self):
        """Register the model and the pipeline, then build the pipeline.

        Returns:
            The object returned by ``transformers.pipeline(self.pipeline_name)``.
        """
        # Registration runs first so the name can be resolved by
        # ``pipeline``; presumably both helpers hook into the Transformers
        # registries -- confirm in their modules.
        self.__model_register()
        self.__pipeline_register()
        return pipeline(self.pipeline_name)

    def _update_config(self, config, config_dict):
        """Apply overrides from ``config_dict`` onto ``config`` in place.

        Only keys naming an attribute that already exists on ``config`` are
        applied; any other key is silently ignored.

        Args:
            config: Configuration object to mutate.
            config_dict: Mapping of attribute name to new value. ``None`` or
                an empty mapping leaves ``config`` untouched.

        Returns:
            The same ``config`` object, for convenience.
        """
        if not config_dict:
            return config

        for key, value in config_dict.items():
            # Assign unconditionally instead of comparing old and new values
            # first: an equality test would skip type-distinct overrides
            # (e.g. 1.0 vs 1) and can raise for values whose truthiness is
            # ambiguous, such as arrays.
            if hasattr(config, key):
                setattr(config, key, value)
        return config

    def init_pipeline(self, pretrained_model_name_or_path: str, **kwargs):
        """Build a ``VISTA3DPipeline`` around a pretrained ``VISTA3DModel``.

        Args:
            pretrained_model_name_or_path: Checkpoint location forwarded to
                ``VISTA3DModel.from_pretrained``.
            **kwargs: ``config_dict`` (optional mapping of configuration
                overrides) is consumed here; every remaining keyword is
                forwarded to ``VISTA3DPipeline``.

        Returns:
            A ``VISTA3DPipeline`` wrapping the loaded model.
        """
        config_dict = kwargs.pop("config_dict", None)
        config = self._update_config(VISTA3DConfig(), config_dict)

        # Keep the object returned by ``from_pretrained``. In the Hugging
        # Face ``PreTrainedModel`` API it is a classmethod that returns a new
        # model with the weights loaded; calling it on an instance and
        # dropping the result would leave that instance randomly initialised.
        # NOTE(review): assumes VISTA3DModel follows that API -- confirm.
        model = VISTA3DModel.from_pretrained(
            pretrained_model_name_or_path,
            config=config,
        )
        return VISTA3DPipeline(model, **kwargs)