vit-gpt2 / vit_gpt2 /configuration_vit_gpt2.py
ydshieh
update config script
b04e4c6
raw
history blame
1.7 kB
import copy
from transformers import ViTConfig, GPT2Config
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
# Module-level logger, named after this module via `__name__`.
logger = logging.get_logger(__name__)
class ViTGPT2Config(PretrainedConfig):
    """Composite configuration bundling a ViT config with a GPT-2 config.

    The instance exposes two sub-configurations:

    * ``vision_config`` -- a ``ViTConfig`` built from ``vision_config_dict``
    * ``text_config`` -- a ``GPT2Config`` built from ``text_config_dict``

    After construction ``is_encoder_decoder`` is always ``True``, and
    ``decoder_start_token_id`` / ``forced_eos_token_id`` are taken from the
    text config's ``bos_token_id`` / ``eos_token_id`` respectively.
    """

    model_type = "vit-gpt2"
    is_composition = True
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(self, vision_config_dict=None, text_config_dict=None, **kwargs):
        """Build the composite config from two keyword dictionaries.

        Args:
            vision_config_dict: Keyword arguments for ``ViTConfig``. When
                ``None``, ``ViTConfig`` is created with no arguments.
            text_config_dict: Keyword arguments for ``GPT2Config``. When
                ``None``, ``GPT2Config`` is created with no arguments.
            **kwargs: Passed through to ``PretrainedConfig.__init__``.
        """
        # The dictionaries are handed to the base class exactly as received
        # (i.e. possibly ``None``), before any defaulting happens below.
        super().__init__(
            text_config_dict=text_config_dict,
            vision_config_dict=vision_config_dict,
            **kwargs
        )

        vision_kwargs = vision_config_dict
        if vision_kwargs is None:
            logger.info("vision_config_dict is None. initializing the ViTConfig with default values.")
            vision_kwargs = {}

        text_kwargs = text_config_dict
        if text_kwargs is None:
            logger.info("text_config_dict is None. Initializing the GPT2Config with default values.")
            text_kwargs = {}

        self.vision_config = ViTConfig(**vision_kwargs)
        self.text_config = GPT2Config(**text_kwargs)

        # These three values are set unconditionally after the base-class
        # initializer has run, so they replace anything it stored under the
        # same names.
        self.is_encoder_decoder = True
        decoder_config = self.text_config
        self.decoder_start_token_id = decoder_config.bos_token_id
        self.forced_eos_token_id = decoder_config.eos_token_id

    @classmethod
    def from_vit_gpt2_configs(cls, vision_config: ViTConfig, text_config: GPT2Config, **kwargs):
        """Create an instance from already-constructed sub-configurations.

        Args:
            vision_config: Vision configuration; its ``to_dict()`` output is
                used as ``vision_config_dict``.
            text_config: Text configuration; its ``to_dict()`` output is used
                as ``text_config_dict``.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            A new instance of ``cls``.
        """
        vision_dict = vision_config.to_dict()
        text_dict = text_config.to_dict()
        return cls(vision_config_dict=vision_dict, text_config_dict=text_dict, **kwargs)

    def to_dict(self):
        """Serialize this configuration to a plain dictionary.

        Returns:
            A deep copy of the instance's ``__dict__`` in which
            ``vision_config`` and ``text_config`` are replaced by the result
            of their own ``to_dict()`` calls, and ``model_type`` is set to the
            class-level ``model_type``.
        """
        serialized = copy.deepcopy(self.__dict__)
        serialized.update(
            {
                "vision_config": self.vision_config.to_dict(),
                "text_config": self.text_config.to_dict(),
                "model_type": self.__class__.model_type,
            }
        )
        return serialized