"""PyTorch VQModel."""
import contextlib

from huggingface_hub import hf_hub_download
from ldm.util import instantiate_from_config
from omegaconf import OmegaConf
from transformers import PreTrainedModel
from transformers.utils import logging

from .configuration_vqmodel import VQModelConfig

logger = logging.get_logger(__name__)


class VQModel(PreTrainedModel):  # type: ignore
    """``PreTrainedModel`` wrapper around a model built from an ldm YAML config.

    The YAML file referenced by ``config.yaml_path`` is loaded with OmegaConf,
    its ``model`` section is instantiated via
    ``ldm.util.instantiate_from_config``, and the resulting object is stored
    as ``self.model``. A ``loss`` attribute on the built model, if present, is
    removed.

    Attributes:
        vq_cfg: The loaded OmegaConf configuration, with
            ``model.params.lossconfig`` overridden to ``"__is_first_stage__"``.
        model: The object instantiated from ``vq_cfg.model``.
    """

    config_class = VQModelConfig

    def __init__(self, config: VQModelConfig) -> None:
        """Load the YAML config and instantiate the wrapped model.

        Args:
            config: Model configuration. When ``config.repo_id`` is ``None``,
                ``config.yaml_path`` is used directly as a path; otherwise
                ``config.yaml_path`` is downloaded from the Hugging Face Hub
                repo ``config.repo_id`` via ``hf_hub_download`` and the
                returned local path is used.
        """
        # Lazy %-style args: the config is only stringified if INFO is enabled.
        logger.info("VQModel config: %s", config)
        super().__init__(config)

        yaml_path = (
            config.yaml_path
            if config.repo_id is None
            else hf_hub_download(config.repo_id, config.yaml_path)
        )

        # Load vq_cfg.
        vq_cfg = OmegaConf.load(yaml_path)
        # NOTE(review): presumably this sentinel makes instantiate_from_config
        # return None for the loss instead of building a loss module
        # -- confirm against ldm.util.
        vq_cfg.model.params.lossconfig = "__is_first_stage__"
        self.vq_cfg = vq_cfg

        # Initialize model.
        self.model = instantiate_from_config(vq_cfg.model)

        # Remove the ``loss`` attribute if the built model has one.
        with contextlib.suppress(AttributeError):
            delattr(self.model, "loss")