"""VQModel configuration""" | |
from typing import Any

from transformers import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__) | |
class VQModelConfig(PretrainedConfig):  # type: ignore
    """Configuration for the ``vqmodel`` model type.

    Records a ``repo_id`` and a ``yaml_path`` on the instance and passes any
    remaining keyword arguments to ``PretrainedConfig``.

    Args:
        repo_id: Optional identifier (presumably of a hub repository --
            confirm against the model code). When it is not ``None``,
            ``yaml_path`` is replaced by ``"config.yaml"`` regardless of the
            value that was passed in.
        yaml_path: Optional path of a YAML file. Kept as passed only when
            ``repo_id`` is ``None``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``PretrainedConfig.__init__``.
    """

    model_type = "vqmodel"

    def __init__(
        self,
        repo_id: str | None = None,
        yaml_path: str | None = None,
        **kwargs: Any,
    ) -> None:
        # Whenever a repo_id is supplied the YAML file name is fixed to
        # "config.yaml"; an explicitly passed yaml_path is overridden.
        if repo_id is not None:
            yaml_path = "config.yaml"

        self.repo_id = repo_id
        self.yaml_path = yaml_path

        # The base initializer runs last and receives only the extra kwargs.
        super().__init__(**kwargs)