|
from transformers.modeling_utils import PretrainedConfig |
|
|
|
|
|
class CaSEDConfig(PretrainedConfig):
    """Hold the hyperparameters that configure a CaSED model.

    Any keyword argument not listed below is passed through unchanged to
    ``PretrainedConfig.__init__``.

    Args:
        index_name (str, optional): Name of the index. Defaults to "cc12m".
        alpha (float, optional): Weight of the vision loss. Defaults to 0.5.
        retrieval_num_results (int, optional): Number of results to return.
            Defaults to 10.
    """

    # Class-level identifier for this configuration type.
    model_type = "cased"
    # NOTE(review): this flags the config as a composition of sub-configs,
    # yet none are defined here; confirm the flag is intentional.
    is_composition = True

    def __init__(
        self,
        index_name: str = "cc12m",
        alpha: float = 0.5,
        retrieval_num_results: int = 10,
        **kwargs,
    ):
        # Let the base class consume the generic options first, then record
        # the CaSED-specific values on the instance (assigned left to right).
        super().__init__(**kwargs)
        self.index_name, self.alpha, self.retrieval_num_results = (
            index_name,
            alpha,
            retrieval_num_results,
        )
|
|