ydshieh
committed on
Commit
•
ac161f7
1
Parent(s):
58a121d
clean up config script
Browse files
vit_gpt2/configuration_vit_gpt2.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1 |
import copy
|
2 |
|
3 |
-
from transformers import
|
4 |
from transformers.configuration_utils import PretrainedConfig
|
5 |
from transformers.utils import logging
|
6 |
|
|
|
7 |
logger = logging.get_logger(__name__)
|
8 |
|
9 |
|
@@ -12,34 +13,28 @@ class ViTGPT2Config(PretrainedConfig):
|
|
12 |
model_type = "vit-gpt2"
|
13 |
is_composition = True
|
14 |
|
15 |
-
def __init__(self, **kwargs):
|
16 |
-
super().__init__(**kwargs)
|
17 |
-
|
18 |
-
if "vit_config" not in kwargs:
|
19 |
-
raise ValueError("`vit_config` can not be `None`.")
|
20 |
|
21 |
-
if
|
22 |
-
|
|
|
23 |
|
24 |
-
|
25 |
-
|
|
|
26 |
|
27 |
-
self.
|
28 |
-
self.
|
29 |
|
30 |
@classmethod
|
31 |
-
def from_vit_gpt2_configs(
|
32 |
-
|
33 |
-
|
34 |
-
return cls(
|
35 |
-
vit_config=vit_config.to_dict(),
|
36 |
-
gpt2_config=gpt2_config.to_dict(),
|
37 |
-
**kwargs
|
38 |
-
)
|
39 |
|
40 |
def to_dict(self):
|
41 |
output = copy.deepcopy(self.__dict__)
|
42 |
-
output["
|
43 |
-
output["
|
44 |
output["model_type"] = self.__class__.model_type
|
45 |
-
return output
|
|
|
1 |
import copy
|
2 |
|
3 |
+
from transformers import ViTConfig, GPT2Config
|
4 |
from transformers.configuration_utils import PretrainedConfig
|
5 |
from transformers.utils import logging
|
6 |
|
7 |
+
|
8 |
logger = logging.get_logger(__name__)
|
9 |
|
10 |
|
|
|
13 |
model_type = "vit-gpt2"
|
14 |
is_composition = True
|
15 |
|
16 |
+
def __init__(self, vision_config_dict=None, text_config_dict=None, **kwargs):
|
17 |
+
super().__init__(text_config_dict=text_config_dict, vision_config_dict=vision_config_dict, **kwargs)
|
|
|
|
|
|
|
18 |
|
19 |
+
if vision_config_dict is None:
|
20 |
+
vision_config_dict = {}
|
21 |
+
logger.info("vision_config_dict is None. initializing the ViTConfig with default values.")
|
22 |
|
23 |
+
if text_config_dict is None:
|
24 |
+
text_config_dict = {}
|
25 |
+
logger.info("text_config_dict is None. Initializing the GPT2Config with default values.")
|
26 |
|
27 |
+
self.vision_config = ViTConfig(**vision_config_dict)
|
28 |
+
self.text_config = GPT2Config(**text_config_dict)
|
29 |
|
30 |
@classmethod
|
31 |
+
def from_vit_gpt2_configs(cls, vision_config: ViTConfig, text_config: GPT2Config, **kwargs):
|
32 |
+
|
33 |
+
return cls(vision_config_dict=vision_config.to_dict(), text_config_dict=text_config.to_dict(), **kwargs)
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
def to_dict(self):
|
36 |
output = copy.deepcopy(self.__dict__)
|
37 |
+
output["vision_config"] = self.vision_config.to_dict()
|
38 |
+
output["text_config"] = self.text_config.to_dict()
|
39 |
output["model_type"] = self.__class__.model_type
|
40 |
+
return output
|