File size: 975 Bytes
cd16641
ce179d5
cd16641
7ff77f3
 
 
 
 
 
 
cd16641
 
 
 
7ff77f3
 
 
 
 
 
 
 
 
 
ce179d5
7ff77f3
 
 
 
 
 
ce179d5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import os
from typing import Optional

from transformers.modeling_utils import PretrainedConfig


class CaSEDConfig(PretrainedConfig):
    """Configuration class for CaSED.

    Args:
        index_name (str): Name of the index. Defaults to "cc12m".
        alpha (float): Weight of the vision loss. Defaults to 0.5.
        retrieval_num_results (int): Number of results to return. Defaults to 10.
        cache_dir (Optional[str]): Path to cache directory. When ``None`` (or
            empty), defaults to "~/.cache/cased". A leading "~" is expanded to
            the user's home directory for both the default and caller-supplied
            paths, and the stored value is always a plain ``str``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # Config-type identifier that transformers associates with this class.
    model_type = "cased"
    # NOTE(review): transformers documents `is_composition` as marking a config
    # composed of sub-configs; no sub-configs are visible here — confirm it is
    # intended.
    is_composition = True

    def __init__(
        self,
        index_name: str = "cc12m",
        alpha: float = 0.5,
        retrieval_num_results: int = 10,
        cache_dir: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.index_name = index_name
        self.alpha = alpha
        self.retrieval_num_results = retrieval_num_results
        # Falsy values (None or "") fall back to the default location, as before.
        # Expand "~" for caller-supplied paths too so they are handled the same
        # way as the default rather than kept as a literal "~" component.
        # expanduser also accepts os.PathLike and returns a plain str, which keeps
        # the attribute JSON-serializable when the config is saved.
        self.cache_dir = os.path.expanduser(cache_dir or "~/.cache/cased")