File size: 1,090 Bytes
1c389fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""PyTorch VQModel.

Exposes a model built from an ``ldm`` YAML config (loaded with OmegaConf)
as a Hugging Face ``transformers`` ``PreTrainedModel``; see ``VQModel``.
"""

from huggingface_hub import hf_hub_download
from ldm.util import instantiate_from_config
from omegaconf import OmegaConf
from transformers import PreTrainedModel
from transformers.utils import logging

from .configuration_vqmodel import VQModelConfig

# Module logger obtained through transformers.utils.logging; VQModel.__init__
# logs the config it receives through it.
logger = logging.get_logger(__name__)


class VQModel(PreTrainedModel):  # type: ignore
    """Hugging Face ``PreTrainedModel`` wrapper around an ``ldm`` model.

    The wrapped network is not defined in this file: it is built from an
    OmegaConf YAML file by ``ldm.util.instantiate_from_config`` and stored
    as ``self.model``. The loaded (and patched) YAML config is kept on
    ``self.vq_cfg``.
    """

    config_class = VQModelConfig

    def __init__(self, config: VQModelConfig) -> None:
        """Load the YAML config named by ``config`` and build the model.

        Args:
            config: Provides ``yaml_path`` and ``repo_id``. When ``repo_id``
                is None, ``yaml_path`` is read directly as a local file.
                Otherwise ``yaml_path`` is fetched from that Hub repo with
                ``hf_hub_download`` and the downloaded file is read.
        """
        logger.info(f"VQModel config: {config}")
        super().__init__(config)

        # Decide where the YAML comes from: a local path or a Hub download.
        if config.repo_id is None:
            cfg_file = config.yaml_path
        else:
            cfg_file = hf_hub_download(config.repo_id, config.yaml_path)

        vq_cfg = OmegaConf.load(cfg_file)
        # Replace the loss config with the "__is_first_stage__" sentinel
        # before instantiation. NOTE(review): presumably this makes ldm skip
        # building the training loss -- confirm against ldm.util.
        vq_cfg.model.params.lossconfig = "__is_first_stage__"
        self.vq_cfg = vq_cfg

        self.model = instantiate_from_config(vq_cfg.model)

        # Best effort: drop a ``loss`` attribute if the built model has one;
        # its absence is not an error.
        try:
            del self.model.loss
        except AttributeError:
            pass