ydshieh
commited on
Commit
•
b04e4c6
1
Parent(s):
ac161f7
update config script
Browse files
vit_gpt2/configuration_vit_gpt2.py
CHANGED
@@ -12,9 +12,11 @@ class ViTGPT2Config(PretrainedConfig):
|
|
12 |
|
13 |
model_type = "vit-gpt2"
|
14 |
is_composition = True
|
|
|
15 |
|
16 |
def __init__(self, vision_config_dict=None, text_config_dict=None, **kwargs):
|
17 |
-
super().__init__(
|
|
|
18 |
|
19 |
if vision_config_dict is None:
|
20 |
vision_config_dict = {}
|
@@ -27,6 +29,11 @@ class ViTGPT2Config(PretrainedConfig):
|
|
27 |
self.vision_config = ViTConfig(**vision_config_dict)
|
28 |
self.text_config = GPT2Config(**text_config_dict)
|
29 |
|
|
|
|
|
|
|
|
|
|
|
30 |
@classmethod
|
31 |
def from_vit_gpt2_configs(cls, vision_config: ViTConfig, text_config: GPT2Config, **kwargs):
|
32 |
|
|
|
12 |
|
13 |
model_type = "vit-gpt2"
|
14 |
is_composition = True
|
15 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
16 |
|
17 |
def __init__(self, vision_config_dict=None, text_config_dict=None, **kwargs):
|
18 |
+
super().__init__(
|
19 |
+
text_config_dict=text_config_dict, vision_config_dict=vision_config_dict, **kwargs)
|
20 |
|
21 |
if vision_config_dict is None:
|
22 |
vision_config_dict = {}
|
|
|
29 |
self.vision_config = ViTConfig(**vision_config_dict)
|
30 |
self.text_config = GPT2Config(**text_config_dict)
|
31 |
|
32 |
+
self.is_encoder_decoder = True
|
33 |
+
|
34 |
+
self.decoder_start_token_id = self.text_config.bos_token_id
|
35 |
+
self.forced_eos_token_id = self.text_config.eos_token_id
|
36 |
+
|
37 |
@classmethod
|
38 |
def from_vit_gpt2_configs(cls, vision_config: ViTConfig, text_config: GPT2Config, **kwargs):
|
39 |
|