File size: 770 Bytes
7ff77f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from transformers.modeling_utils import PretrainedConfig


class CaSEDConfig(PretrainedConfig):
    """Configuration object holding the CaSED-specific settings.

    Any keyword argument not listed below is passed through untouched to
    ``PretrainedConfig.__init__``.

    Args:
        index_name (str, optional): Name of the index to use.
            Defaults to ``"cc12m"``.
        alpha (float, optional): Weight of the vision loss; stored on the
            config as-is. Defaults to ``0.5``.
        retrieval_num_results (int, optional): Number of results to return
            from retrieval. Defaults to ``10``.
    """

    # Class-level identifier for this configuration type.
    model_type = "cased"
    # Class-level flag read by the PretrainedConfig machinery; kept enabled.
    is_composition = True

    def __init__(
        self,
        index_name: str = "cc12m",
        alpha: float = 0.5,
        retrieval_num_results: int = 10,
        **kwargs,
    ) -> None:
        # The base class handles every generic config option first.
        super().__init__(**kwargs)
        # Then record the CaSED-specific values. Targets are assigned left to
        # right, i.e. in the same order as the signature.
        self.index_name, self.alpha, self.retrieval_num_results = (
            index_name,
            alpha,
            retrieval_num_results,
        )