# vit_gpt2/configuration_vit_gpt2.py
import copy

from transformers import ViTConfig, GPT2Config
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)
class ViTGPT2Config(PretrainedConfig):
    """Composite configuration pairing a ViT encoder config with a GPT2 decoder config."""

    model_type = "vit-gpt2"
    is_composition = True
    keys_to_ignore_at_inference = ["past_key_values"]
    def __init__(self, vision_config_dict=None, text_config_dict=None, **kwargs):
        super().__init__(
            vision_config_dict=vision_config_dict, text_config_dict=text_config_dict, **kwargs
        )

        # `project_encoder` may be given as a top-level kwarg and/or inside `text_config_dict`;
        # the two sources are reconciled below.
        project_encoder = kwargs.pop("project_encoder", None)
        if vision_config_dict is None:
            vision_config_dict = {}
            logger.info("vision_config_dict is None. Initializing the ViTConfig with default values.")

        if text_config_dict is None:
            text_config_dict = {}
            logger.info("text_config_dict is None. Initializing the GPT2Config with default values.")
        self.vision_config = ViTConfig(**vision_config_dict)
        self.text_config = GPT2Config(**text_config_dict)

        self.is_encoder_decoder = True

        # Required in `generate()`.
        self.bos_token_id = self.text_config.bos_token_id
        self.eos_token_id = self.text_config.eos_token_id
        assert hasattr(self.text_config, "pad_token_id")
        self.pad_token_id = self.text_config.pad_token_id
        self.decoder_start_token_id = self.text_config.bos_token_id
        self.forced_eos_token_id = self.text_config.eos_token_id
        # Reconcile the `project_encoder` kwarg with the value (if any) already stored on the
        # text config: if both are given they must agree; if neither is given, default to False.
        _project_encoder = getattr(self.text_config, "project_encoder", None)
        if project_encoder is not None and _project_encoder is not None:
            assert project_encoder == _project_encoder
        elif project_encoder is not None:
            _project_encoder = project_encoder
        elif _project_encoder is not None:
            project_encoder = _project_encoder
        else:
            project_encoder = False

        self.project_encoder = project_encoder
        self.text_config.project_encoder = project_encoder
    @classmethod
    def from_vision_text_configs(cls, vision_config: ViTConfig, text_config: GPT2Config, **kwargs):
        """Instantiate a `ViTGPT2Config` from a ViT encoder config and a GPT2 decoder config."""
        return cls(vision_config_dict=vision_config.to_dict(), text_config_dict=text_config.to_dict(), **kwargs)
    def to_dict(self):
        """Serialize this instance to a Python dictionary, expanding the nested sub-configs."""
        output = copy.deepcopy(self.__dict__)
        output["vision_config"] = self.vision_config.to_dict()
        output["text_config"] = self.text_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
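

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): builds a
# ViTGPT2Config from default sub-configs via `from_vision_text_configs` and
# checks that the `project_encoder` flag is propagated to the decoder config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    vision_config = ViTConfig()
    text_config = GPT2Config()

    config = ViTGPT2Config.from_vision_text_configs(
        vision_config, text_config, project_encoder=True
    )

    # Both the composite config and the text (decoder) config expose the flag.
    print(config.project_encoder, config.text_config.project_encoder)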