File size: 552 Bytes
1c389fc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
"""VQModel configuration"""
from typing import Any

from transformers import PretrainedConfig
from transformers.utils import logging
# Module-level logger from transformers' logging utilities, named after this module.
logger = logging.get_logger(__name__)
class VQModelConfig(PretrainedConfig):  # type: ignore
    """Configuration recording where a VQModel's YAML config file lives.

    Args:
        repo_id: Identifier of the repository the model is loaded from
            (presumably a Hugging Face Hub repo id -- confirm against the
            model loader). When given, ``yaml_path`` is always set to
            ``"config.yaml"``.
        yaml_path: Path to the YAML configuration file. Replaced by
            ``"config.yaml"`` when ``repo_id`` is given; a different explicit
            value is discarded with a logged warning.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "vqmodel"

    def __init__(
        self,
        repo_id: str | None = None,
        yaml_path: str | None = None,
        **kwargs: Any,
    ) -> None:
        if repo_id is not None:
            # With a repo_id the YAML file name is fixed to "config.yaml"
            # (presumably the name used inside the repo -- confirm). Keep that
            # override, but do not drop a conflicting explicit value silently.
            if yaml_path is not None and yaml_path != "config.yaml":
                logger.warning(
                    "yaml_path=%r is ignored because repo_id=%r is set; using %r.",
                    yaml_path,
                    repo_id,
                    "config.yaml",
                )
            yaml_path = "config.yaml"
        self.repo_id = repo_id
        self.yaml_path = yaml_path
        # PretrainedConfig consumes the remaining standard config fields.
        super().__init__(**kwargs)
|