File size: 946 Bytes
cb3de11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from typing import Any, Dict, List, Optional, Tuple

from transformers import PretrainedConfig


class ILKTConfig(PretrainedConfig):
    """Configuration container for the ILKT model.

    Stores the per-component settings (backbone, embedding head, MLM head,
    classification head), the list of classification heads, and the maximum
    sequence length as attributes on a ``PretrainedConfig`` subclass.
    """

    model_type = "ILKT"

    def __init__(
        self,
        backbone_config: Optional[Dict[str, Any]] = None,
        embedding_head_config: Optional[Dict[str, Any]] = None,
        mlm_head_config: Optional[Dict[str, Any]] = None,
        cls_head_config: Optional[Dict[str, Any]] = None,
        cls_heads: Optional[List[Tuple[int, str]]] = None,
        max_length: int = 512,
        **kwargs
    ):
        """Build the configuration.

        Args:
            backbone_config: Settings dict stored as ``self.backbone_config``.
                Defaults to a new empty dict.
            embedding_head_config: Settings dict stored as
                ``self.embedding_head_config``. Defaults to a new empty dict.
            mlm_head_config: Settings dict stored as ``self.mlm_head_config``.
                Defaults to a new empty dict.
            cls_head_config: Settings dict stored as ``self.cls_head_config``.
                Defaults to a new empty dict.
            cls_heads: List of ``(int, str)`` pairs stored as
                ``self.cls_heads``; presumably one entry per classification
                head -- confirm the meaning of each field against the model
                code. Defaults to a new empty list.
            max_length: Maximum sequence length stored as ``self.max_length``.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        # Use None sentinels instead of ``{}`` / ``[]`` defaults so that every
        # instance gets its own container rather than sharing a single
        # module-level object across all configs created with defaults.
        self.backbone_config = {} if backbone_config is None else backbone_config
        self.embedding_head_config = (
            {} if embedding_head_config is None else embedding_head_config
        )
        self.mlm_head_config = {} if mlm_head_config is None else mlm_head_config
        self.cls_head_config = {} if cls_head_config is None else cls_head_config
        self.cls_heads = [] if cls_heads is None else cls_heads
        # Keep the caller-supplied value (previously hard-coded to False,
        # which discarded the argument).
        self.max_length = max_length
        self.output_hidden_states = False

        # TODO:
        # make config a proper HF config, save max length etc.; unclear how
        # this works exactly in the HF ecosystem.
        # NOTE(review): depending on the transformers version, the base
        # __init__ may itself assign ``max_length`` / ``output_hidden_states``
        # from kwargs defaults -- confirm the values set above survive the
        # call below.

        super().__init__(**kwargs)