import copy

from transformers import ViTConfig, GPT2Config
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)


class ViTGPT2Config(PretrainedConfig):
    """
    Composite configuration for a ViT encoder paired with a GPT-2 decoder.
    Wraps a `ViTConfig` and a `GPT2Config` and wires up the token ids that
    `generate()` expects.
    """

    model_type = "vit-gpt2"
    is_composition = True
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(self, vision_config_dict=None, text_config_dict=None, **kwargs):
        # Pop `project_encoder` before calling `super().__init__` so it is not
        # also set as a stray attribute by `PretrainedConfig`.
        project_encoder = kwargs.pop("project_encoder", None)

        super().__init__(
            vision_config_dict=vision_config_dict, text_config_dict=text_config_dict, **kwargs
        )

        if vision_config_dict is None:
            vision_config_dict = {}
            logger.info("vision_config_dict is None. initializing the ViTConfig with default values.")

        if text_config_dict is None:
            text_config_dict = {}
            logger.info("text_config_dict is None. Initializing the GPT2Config with default values.")

        self.vision_config = ViTConfig(**vision_config_dict)
        self.text_config = GPT2Config(**text_config_dict)

        self.is_encoder_decoder = True

        # Required in `generate()`.
        self.bos_token_id = self.text_config.bos_token_id
        self.eos_token_id = self.text_config.eos_token_id

        # `PretrainedConfig` always defines `pad_token_id` (GPT-2 defaults it
        # to `None`), so mirror whatever the text config carries.
        assert hasattr(self.text_config, "pad_token_id")
        self.pad_token_id = self.text_config.pad_token_id

        self.decoder_start_token_id = self.text_config.bos_token_id
        self.forced_eos_token_id = self.text_config.eos_token_id

        # Resolve `project_encoder` from the kwarg and the text config: if both
        # are set they must agree; otherwise fall back to whichever is set, and
        # default to `False`.
        _project_encoder = getattr(self.text_config, "project_encoder", None)
        if project_encoder is not None and _project_encoder is not None:
            assert project_encoder == _project_encoder
        elif _project_encoder:
            project_encoder = _project_encoder
        elif project_encoder is None:
            project_encoder = False

        self.project_encoder = project_encoder
        self.text_config.project_encoder = project_encoder

    @classmethod
    def from_vision_text_configs(cls, vision_config: ViTConfig, text_config: GPT2Config, **kwargs):
        """
        Instantiate a `ViTGPT2Config` from a ViT vision configuration and a
        GPT-2 text configuration.
        """
        return cls(vision_config_dict=vision_config.to_dict(), text_config_dict=text_config.to_dict(), **kwargs)

    def to_dict(self):
        """
        Serialize to a plain dictionary, expanding the nested sub-configs so
        the result is fully serializable.
        """
        output = copy.deepcopy(self.__dict__)
        output["vision_config"] = self.vision_config.to_dict()
        output["text_config"] = self.text_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
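

# A minimal usage sketch, not part of the original module: it only assumes
# `transformers` provides `ViTConfig` and `GPT2Config`, and shows how the
# composite config is built, queried, and round-tripped through `to_dict`.
if __name__ == "__main__":
    vision_config = ViTConfig(image_size=224, patch_size=16)
    text_config = GPT2Config(n_layer=6, n_head=8, n_embd=512)

    config = ViTGPT2Config.from_vision_text_configs(
        vision_config, text_config, project_encoder=True
    )

    assert config.is_encoder_decoder
    assert config.project_encoder and config.text_config.project_encoder
    # Token ids needed by `generate()` are mirrored from the GPT-2 config.
    assert config.decoder_start_token_id == text_config.bos_token_id

    # `to_dict` expands the sub-configs, so the composite config can be
    # re-created from the serialized form.
    as_dict = config.to_dict()
    restored = ViTGPT2Config(
        vision_config_dict=as_dict["vision_config"],
        text_config_dict=as_dict["text_config"],
        project_encoder=as_dict["project_encoder"],
    )
    assert restored.text_config.n_layer == 6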