import copy

from transformers import GPT2Config, ViTConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class ViTGPT2Config(PretrainedConfig):
    """Composite configuration for a ViT encoder paired with a GPT-2 decoder."""

    model_type = "vit-gpt2"
    is_composition = True
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(self, vision_config_dict=None, text_config_dict=None, **kwargs):
        super().__init__(
            vision_config_dict=vision_config_dict, text_config_dict=text_config_dict, **kwargs
        )

        project_encoder = kwargs.pop("project_encoder", None)

        if vision_config_dict is None:
            vision_config_dict = {}
            logger.info("vision_config_dict is None. Initializing the ViTConfig with default values.")

        if text_config_dict is None:
            text_config_dict = {}
            logger.info("text_config_dict is None. Initializing the GPT2Config with default values.")

        self.vision_config = ViTConfig(**vision_config_dict)
        self.text_config = GPT2Config(**text_config_dict)

        self.is_encoder_decoder = True

        # Required in `generate()`.
        self.bos_token_id = self.text_config.bos_token_id
        self.eos_token_id = self.text_config.eos_token_id
        assert hasattr(self.text_config, "pad_token_id")
        self.pad_token_id = self.text_config.pad_token_id
        self.decoder_start_token_id = self.text_config.bos_token_id
        self.forced_eos_token_id = self.text_config.eos_token_id

        # Reconcile `project_encoder` between the constructor kwargs and the text
        # config: if both are set they must agree, if only one is set it wins, and
        # if neither is set it defaults to False.
        _project_encoder = getattr(self.text_config, "project_encoder", None)
        if project_encoder is not None and _project_encoder is not None:
            assert project_encoder == _project_encoder
        elif project_encoder:
            _project_encoder = project_encoder
        elif _project_encoder:
            project_encoder = _project_encoder
        else:
            project_encoder = False
        self.project_encoder = project_encoder
        self.text_config.project_encoder = project_encoder

    @classmethod
    def from_vision_text_configs(cls, vision_config: ViTConfig, text_config: GPT2Config, **kwargs):
        return cls(
            vision_config_dict=vision_config.to_dict(),
            text_config_dict=text_config.to_dict(),
            **kwargs,
        )

    def to_dict(self):
        output = copy.deepcopy(self.__dict__)
        output["vision_config"] = self.vision_config.to_dict()
        output["text_config"] = self.text_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
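

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module: build the
    # composite config from standalone sub-configs via `from_vision_text_configs`
    # and round-trip it through `to_dict()`. The sub-config sizes below are
    # arbitrary example values, not recommended settings.
    vision_config = ViTConfig(image_size=224, patch_size=16)
    text_config = GPT2Config(n_layer=6, n_head=8, n_embd=512)

    config = ViTGPT2Config.from_vision_text_configs(
        vision_config, text_config, project_encoder=True
    )

    print(config.model_type)       # vit-gpt2
    print(config.project_encoder)  # True
    print(config.to_dict()["text_config"]["n_layer"])  # 6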