ydshieh commited on
Commit
ac161f7
1 Parent(s): 58a121d

clean up config script

Browse files
Files changed (1) hide show
  1. vit_gpt2/configuration_vit_gpt2.py +18 -23
vit_gpt2/configuration_vit_gpt2.py CHANGED
@@ -1,9 +1,10 @@
1
  import copy
2
 
3
- from transformers import GPT2Config, ViTConfig
4
  from transformers.configuration_utils import PretrainedConfig
5
  from transformers.utils import logging
6
 
 
7
  logger = logging.get_logger(__name__)
8
 
9
 
@@ -12,34 +13,28 @@ class ViTGPT2Config(PretrainedConfig):
12
  model_type = "vit-gpt2"
13
  is_composition = True
14
 
15
- def __init__(self, **kwargs):
16
- super().__init__(**kwargs)
17
-
18
- if "vit_config" not in kwargs:
19
- raise ValueError("`vit_config` can not be `None`.")
20
 
21
- if "gpt2_config" not in kwargs:
22
- raise ValueError("`gpt2_config` can not be `None`.")
 
23
 
24
- vit_config = kwargs.pop("vit_config")
25
- gpt2_config = kwargs.pop("gpt2_config")
 
26
 
27
- self.vit_config = ViTConfig(**vit_config)
28
- self.gpt2_config = GPT2Config(**gpt2_config)
29
 
30
  @classmethod
31
- def from_vit_gpt2_configs(
32
- cls, vit_config: PretrainedConfig, gpt2_config: PretrainedConfig, **kwargs
33
- ):
34
- return cls(
35
- vit_config=vit_config.to_dict(),
36
- gpt2_config=gpt2_config.to_dict(),
37
- **kwargs
38
- )
39
 
40
  def to_dict(self):
41
  output = copy.deepcopy(self.__dict__)
42
- output["vit_config"] = self.vit_config.to_dict()
43
- output["gpt2_config"] = self.gpt2_config.to_dict()
44
  output["model_type"] = self.__class__.model_type
45
- return output
 
1
  import copy
2
 
3
+ from transformers import ViTConfig, GPT2Config
4
  from transformers.configuration_utils import PretrainedConfig
5
  from transformers.utils import logging
6
 
7
+
8
  logger = logging.get_logger(__name__)
9
 
10
 
 
13
  model_type = "vit-gpt2"
14
  is_composition = True
15
 
16
+ def __init__(self, vision_config_dict=None, text_config_dict=None, **kwargs):
17
+ super().__init__(text_config_dict=text_config_dict, vision_config_dict=vision_config_dict, **kwargs)
 
 
 
18
 
19
+ if vision_config_dict is None:
20
+ vision_config_dict = {}
21
+ logger.info("vision_config_dict is None. initializing the ViTConfig with default values.")
22
 
23
+ if text_config_dict is None:
24
+ text_config_dict = {}
25
+ logger.info("text_config_dict is None. Initializing the GPT2Config with default values.")
26
 
27
+ self.vision_config = ViTConfig(**vision_config_dict)
28
+ self.text_config = GPT2Config(**text_config_dict)
29
 
30
  @classmethod
31
+ def from_vit_gpt2_configs(cls, vision_config: ViTConfig, text_config: GPT2Config, **kwargs):
32
+
33
+ return cls(vision_config_dict=vision_config.to_dict(), text_config_dict=text_config.to_dict(), **kwargs)
 
 
 
 
 
34
 
35
  def to_dict(self):
36
  output = copy.deepcopy(self.__dict__)
37
+ output["vision_config"] = self.vision_config.to_dict()
38
+ output["text_config"] = self.text_config.to_dict()
39
  output["model_type"] = self.__class__.model_type
40
+ return output