import os
from typing import Optional, Union

from transformers.configuration_utils import PretrainedConfig


class CaSEDConfig(PretrainedConfig):
    """Configuration class for CaSED.

    Args:
        index_name (str): Name of the index. Defaults to "cc12m".
        alpha (float): Weight of the vision loss. Defaults to 0.5.
        retrieval_num_results (int): Number of results to return. Defaults to 10.
        cache_dir (str or os.PathLike, optional): Path to cache directory. A
            leading "~" is expanded to the user's home directory and the value
            is stored as a plain ``str``. When ``None`` or empty, defaults to
            "~/.cache/cased" (expanded).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "cased"
    # NOTE(review): no sub-configs are visible in this class; confirm that
    # composition semantics (``is_composition``) are intended here.
    is_composition = True

    def __init__(
        self,
        index_name: str = "cc12m",
        alpha: float = 0.5,
        retrieval_num_results: int = 10,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.index_name = index_name
        self.alpha = alpha
        self.retrieval_num_results = retrieval_num_results
        # Expand "~" for user-supplied paths as well as for the default, so the
        # stored path never contains a literal tilde; expanduser also coerces
        # PathLike inputs to a plain str.
        self.cache_dir = os.path.expanduser(cache_dir or "~/.cache/cased")