File size: 3,086 Bytes
6056cf9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89

from typing import Literal
from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.mistral import MistralConfig

# Model-type identifiers. Each string is the key under which the matching
# config class defined in this file is registered with AutoConfig.
NVEMBED_TYPE = "nvembed"
LATENT_ATTENTION_TYPE = "latent_attention"
BIDIR_MISTRAL_TYPE = "bidir_mistral"

class NVEmbedConfig(PretrainedConfig):
    """Top-level configuration for the NV-Embed model.

    Holds a latent-attention sub-config, an optional text sub-config and a
    handful of scalar options that are stored verbatim as attributes.

    Args:
        hidden_size: Stored as ``self.hidden_size``.
        latent_attention_config: A ready config object, a ``dict`` of
            constructor arguments, or ``None``. A dict is resolved through
            ``CONFIG_MAPPING`` using its ``"model_type"`` entry (defaulting
            to ``LATENT_ATTENTION_TYPE``); ``None`` builds the default
            latent-attention config. Any other value is stored as given.
        text_config: A ready config object, a ``dict`` of constructor
            arguments, or ``None``. A dict is resolved through
            ``CONFIG_MAPPING`` using its ``"model_type"`` entry (defaulting
            to ``"llama"``); ``None`` is stored as ``None``.
        padding_side: ``"right"`` or ``"left"``; stored as given.
        add_pad_token: Stored as given.
        is_mask_instruction: Stored as given.
        add_eos: Stored as given.
        mask_type: Stored as given. NOTE(review): the meaning of the
            default ``"b"`` is not visible in this file; confirm against
            the modeling code.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        KeyError: If a dict sub-config names a ``"model_type"`` that is not
            present in ``CONFIG_MAPPING``.
    """

    model_type = NVEMBED_TYPE
    is_composition = False

    def __init__(
        self,
        hidden_size=4096,
        latent_attention_config=None,
        text_config=None,
        padding_side: Literal["right", "left"] = "right",
        add_pad_token: bool = True,
        is_mask_instruction: bool = True,
        add_eos: bool = True,
        mask_type: str = "b",
        **kwargs,
    ):
        if isinstance(latent_attention_config, dict):
            # Build from a shallow copy so the caller's dict is never
            # modified; an explicit "model_type" in the dict wins over the
            # default because it is unpacked last.
            latent_kwargs = {"model_type": LATENT_ATTENTION_TYPE, **latent_attention_config}
            latent_attention_config = CONFIG_MAPPING[latent_kwargs["model_type"]](**latent_kwargs)
        elif latent_attention_config is None:
            latent_attention_config = CONFIG_MAPPING[LATENT_ATTENTION_TYPE]()

        self.latent_attention_config = latent_attention_config

        if isinstance(text_config, dict):
            # Same copy-then-default treatment as above; "llama" is the
            # fallback model type when the dict does not name one.
            text_kwargs = {"model_type": "llama", **text_config}
            text_config = CONFIG_MAPPING[text_kwargs["model_type"]](**text_kwargs)

        self.hidden_size = hidden_size
        self.text_config = text_config
        self.padding_side = padding_side
        self.is_mask_instruction = is_mask_instruction
        self.add_pad_token = add_pad_token
        self.add_eos = add_eos
        self.mask_type = mask_type

        super().__init__(**kwargs)


class LatentAttentionConfig(PretrainedConfig):
    """Configuration for the latent-attention component.

    Every named argument is stored unchanged as an attribute of the same
    name.

    Args:
        num_latents_value: Stored as ``self.num_latents_value``.
        num_cross_heads: Stored as ``self.num_cross_heads``.
        output_normalize: Stored as ``self.output_normalize``.
        hidden_dim: Stored as ``self.hidden_dim``.
        latent_dim: Stored as ``self.latent_dim``.
        cross_dim_head: Stored as ``self.cross_dim_head``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__`` so that the
            base-class attributes are initialised and extra keys (for
            example a ``"model_type"`` entry coming from a serialized
            config dict) are handled by the base class instead of being
            silently discarded.
    """

    model_type = LATENT_ATTENTION_TYPE
    is_composition = False
    _name_or_path = "latent_attention"

    def __init__(
        self,
        num_latents_value: int = 512,
        num_cross_heads: int = 8,
        output_normalize: bool = True,
        hidden_dim: int = 4096,
        latent_dim: int = 4096,
        cross_dim_head: int = 4096,
        **kwargs,
    ):
        self.num_latents_value = num_latents_value
        self.num_cross_heads = num_cross_heads
        self.output_normalize = output_normalize
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.cross_dim_head = cross_dim_head

        # The base initializer assigns ``_name_or_path`` from the
        # ``name_or_path`` keyword; pass the class-level default through so
        # instances keep reporting "latent_attention" unless the caller
        # supplies a different name.
        kwargs.setdefault("name_or_path", self._name_or_path)
        super().__init__(**kwargs)


class BidirectionalMistralConfig(MistralConfig):
    """``MistralConfig`` subclass exposed under the ``BIDIR_MISTRAL_TYPE`` model type.

    Adds no new fields; it only overrides the two class attributes below.
    """

    model_type = BIDIR_MISTRAL_TYPE
    # NOTE(review): presumably tells inference utilities to skip the cached
    # key/value outputs; confirm against the transformers documentation.
    keys_to_ignore_at_inference = ["past_key_values"]

# Import-time side effect: map each model-type string to its config class in
# AutoConfig so the classes can be resolved by model type.
AutoConfig.register(NVEMBED_TYPE, NVEmbedConfig)
AutoConfig.register(LATENT_ATTENTION_TYPE, LatentAttentionConfig)
AutoConfig.register(BIDIR_MISTRAL_TYPE, BidirectionalMistralConfig)

# Mark each class as custom code for the auto-class machinery.
# NOTE(review): per the transformers custom-model docs this is what lets the
# defining source file travel with a saved config; confirm for the pinned version.
NVEmbedConfig.register_for_auto_class()
LatentAttentionConfig.register_for_auto_class()
BidirectionalMistralConfig.register_for_auto_class()