ldm-vq-f16 / modeling_vqmodel.py
ktrk115's picture
Upload model
1c389fc verified
raw
history blame contribute delete
No virus
1.09 kB
"""PyTorch VQModel."""
import contextlib

from huggingface_hub import hf_hub_download
from ldm.util import instantiate_from_config
from omegaconf import OmegaConf
from transformers import PreTrainedModel
from transformers.utils import logging

from .configuration_vqmodel import VQModelConfig
# Module-level logger obtained through transformers' logging utilities and
# named after this module, so its verbosity follows the transformers settings.
logger = logging.get_logger(name=__name__)
class VQModel(PreTrainedModel):  # type: ignore
    """``PreTrainedModel`` wrapper around a VQ model built from an OmegaConf YAML.

    The wrapped network is created with ``ldm.util.instantiate_from_config``
    from the ``model`` section of the YAML file referenced by the config and
    is exposed as ``self.model``. The parsed YAML is kept as ``self.vq_cfg``.
    """

    config_class = VQModelConfig

    def __init__(self, config: VQModelConfig) -> None:
        """Build the wrapped VQ model described by ``config``.

        Args:
            config: Model configuration. Only ``config.repo_id`` and
                ``config.yaml_path`` are read here. When ``repo_id`` is
                ``None``, ``yaml_path`` is used directly as a local file
                path. Otherwise ``yaml_path`` names a file inside the Hugging
                Face Hub repository ``repo_id`` and is fetched with
                ``hf_hub_download``.

        Raises:
            Any exception from ``hf_hub_download``, ``OmegaConf.load`` or
            ``instantiate_from_config`` propagates unchanged.
        """
        # Lazy %-style argument: the message (and the config's repr) is only
        # rendered if INFO logging is actually enabled for this logger.
        logger.info("VQModel config: %s", config)
        super().__init__(config)

        yaml_path = (
            config.yaml_path
            if config.repo_id is None
            else hf_hub_download(config.repo_id, config.yaml_path)
        )

        # Load the model YAML and override its loss config.
        # NOTE(review): "__is_first_stage__" presumably makes ldm's
        # instantiate_from_config yield no loss object, since only the
        # autoencoder itself is needed here -- confirm against ldm.util.
        vq_cfg = OmegaConf.load(yaml_path)
        vq_cfg.model.params.lossconfig = "__is_first_stage__"
        self.vq_cfg = vq_cfg

        # Build the wrapped network from the ``model`` section of the YAML.
        self.model = instantiate_from_config(vq_cfg.model)

        # Best-effort removal of a ``loss`` attribute on the wrapped model;
        # it is fine for the attribute to be absent.
        with contextlib.suppress(AttributeError):
            delattr(self.model, "loss")