|
"""PyTorch VQModel.""" |
|
|
|
from huggingface_hub import hf_hub_download |
|
from ldm.util import instantiate_from_config |
|
from omegaconf import OmegaConf |
|
from transformers import PreTrainedModel |
|
from transformers.utils import logging |
|
|
|
from .configuration_vqmodel import VQModelConfig |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class VQModel(PreTrainedModel):
    """`PreTrainedModel` wrapper around a model built from a YAML config.

    The wrapped model is created with `instantiate_from_config` from the
    `model` section of the YAML file and exposed as `self.model`; the loaded
    (and patched) YAML config is kept on `self.vq_cfg`.
    """

    config_class = VQModelConfig

    def __init__(self, config: VQModelConfig) -> None:
        """Load the YAML config and instantiate the wrapped model.

        Args:
            config: Supplies `yaml_path` and `repo_id`. When `repo_id` is
                `None`, `yaml_path` is passed to `OmegaConf.load` as-is;
                otherwise it is used as the filename argument of
                `hf_hub_download` for that repo.
        """
        logger.info(f"VQModel config: {config}")
        super().__init__(config)

        # Resolve where the YAML lives: use the path directly, or fetch it
        # from the Hub first.
        if config.repo_id is None:
            config_file = config.yaml_path
        else:
            config_file = hf_hub_download(config.repo_id, config.yaml_path)

        ldm_config = OmegaConf.load(config_file)
        # NOTE(review): "__is_first_stage__" is presumably a sentinel that
        # `instantiate_from_config` resolves to "no loss module" -- confirm
        # against ldm.util.
        ldm_config.model.params.lossconfig = "__is_first_stage__"
        self.vq_cfg = ldm_config

        self.model = instantiate_from_config(ldm_config.model)

        # Best effort: drop the `loss` attribute from the wrapped model if it
        # has one; a missing attribute is not an error.
        try:
            del self.model.loss
        except AttributeError:
            pass
|
|