from huggingface_hub import PyTorchModelHubMixin
from huggingface_hub.constants import PYTORCH_WEIGHTS_NAME
from huggingface_hub.file_download import hf_hub_download
from unifiedmodel import RRUM
import os
import torch


class YoutubeVideoSimilarityModel(RRUM, PyTorchModelHubMixin):
    """
    Hugging Face `PyTorchModelHubMixin` wrapper for RegretsReporter `RRUM` model.
    This allows loading, using, and saving the model from Hugging Face model hub
    with default Hugging Face methods `from_pretrained` and `save_pretrained`.
    """

    @classmethod
    def _from_pretrained(
        cls,
        model_id,
        revision,
        cache_dir,
        force_download,
        proxies,
        resume_download,
        local_files_only,
        use_auth_token,
        map_location="cpu",
        strict=False,
        **model_kwargs,
    ):
        """Build the model and load its weights from a directory or the Hub.

        Args:
            model_id: Path to a local directory containing the weights file
                (`PYTORCH_WEIGHTS_NAME`), or otherwise a repo id that is
                passed to `hf_hub_download`.
            revision, cache_dir, force_download, proxies, resume_download,
            local_files_only, use_auth_token: Forwarded unchanged to
                `hf_hub_download`; unused when `model_id` is a local directory.
            map_location: Device specification handed to `torch.device` and
                then to `torch.load`. Defaults to ``"cpu"``.
            strict: Forwarded to `load_state_dict`. Defaults to ``False``, so
                key mismatches between checkpoint and model are tolerated.
            **model_kwargs: Constructor arguments for the model. An optional
                ``"config"`` entry (a mapping) is flattened into these
                arguments; explicitly passed keyword arguments take
                precedence over values from the config.

        Returns:
            The instantiated model with weights loaded, switched to eval mode.

        Raises:
            Whatever `hf_hub_download`, `torch.load`, the model constructor or
            `load_state_dict` raise (e.g. a missing weights file).
        """
        map_location = torch.device(map_location)

        if os.path.isdir(model_id):
            print("Loading weights from local directory")
            model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
        else:
            model_file = hf_hub_download(
                repo_id=model_id,
                filename=PYTORCH_WEIGHTS_NAME,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                use_auth_token=use_auth_token,
                local_files_only=local_files_only,
            )

        # convert Huggingface config to RRUM acceptable input parameters
        if "config" in model_kwargs:
            # A `None` (or empty) config means "no extra parameters"; without
            # this guard the dict unpacking below raises TypeError on None.
            config = model_kwargs["config"] or {}
            # Explicit keyword arguments override values coming from config.
            model_kwargs = {**config, **model_kwargs}
            del model_kwargs["config"]

        model = cls(**model_kwargs)

        # NOTE(security): torch.load unpickles the checkpoint, which can
        # execute arbitrary code. Only load weights from trusted sources;
        # consider `weights_only=True` if the minimum supported torch version
        # provides it.
        state_dict = torch.load(model_file, map_location=map_location)
        model.load_state_dict(state_dict, strict=strict)
        model.eval()

        return model