ydshieh commited on
Commit
b04e4c6
1 Parent(s): ac161f7

update config script

Browse files
Files changed (1) hide show
  1. vit_gpt2/configuration_vit_gpt2.py +8 -1
vit_gpt2/configuration_vit_gpt2.py CHANGED
@@ -12,9 +12,11 @@ class ViTGPT2Config(PretrainedConfig):
12
 
13
  model_type = "vit-gpt2"
14
  is_composition = True
 
15
 
16
  def __init__(self, vision_config_dict=None, text_config_dict=None, **kwargs):
17
- super().__init__(text_config_dict=text_config_dict, vision_config_dict=vision_config_dict, **kwargs)
 
18
 
19
  if vision_config_dict is None:
20
  vision_config_dict = {}
@@ -27,6 +29,11 @@ class ViTGPT2Config(PretrainedConfig):
27
  self.vision_config = ViTConfig(**vision_config_dict)
28
  self.text_config = GPT2Config(**text_config_dict)
29
 
 
 
 
 
 
30
  @classmethod
31
  def from_vit_gpt2_configs(cls, vision_config: ViTConfig, text_config: GPT2Config, **kwargs):
32
 
 
12
 
13
  model_type = "vit-gpt2"
14
  is_composition = True
15
+ keys_to_ignore_at_inference = ["past_key_values"]
16
 
17
  def __init__(self, vision_config_dict=None, text_config_dict=None, **kwargs):
18
+ super().__init__(
19
+ text_config_dict=text_config_dict, vision_config_dict=vision_config_dict, **kwargs)
20
 
21
  if vision_config_dict is None:
22
  vision_config_dict = {}
 
29
  self.vision_config = ViTConfig(**vision_config_dict)
30
  self.text_config = GPT2Config(**text_config_dict)
31
 
32
+ self.is_encoder_decoder = True
33
+
34
+ self.decoder_start_token_id = self.text_config.bos_token_id
35
+ self.forced_eos_token_id = self.text_config.eos_token_id
36
+
37
  @classmethod
38
  def from_vit_gpt2_configs(cls, vision_config: ViTConfig, text_config: GPT2Config, **kwargs):
39