from dataclasses import dataclass

from transformers import GPT2Config, CLIPVisionConfig

# Number of tokens the CLIP vision tower emits per checkpoint:
# (image_size // patch_size) ** 2 patches plus one [CLS] token.
PREFIX_MAP = {
    "openai/clip-vit-base-patch32": 50,
    "openai/clip-vit-base-patch16": 197,
    "openai/clip-vit-large-patch14": 257,
    "openai/clip-vit-large-patch14-336": 577,
}
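
# Sanity check (a hypothetical addition, not part of the original config):
# every entry should match the (image_size // patch_size) ** 2 + 1 formula above.
for _name, _tokens in PREFIX_MAP.items():
    _size = 336 if "336" in _name else 224
    _patch = int(_name.rsplit("patch", 1)[1].split("-")[0])
    assert _tokens == (_size // _patch) ** 2 + 1, _name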
					
					
						
# GPT-2 hidden sizes (n_embd) for each checkpoint.
TEXT_HIDDEN_SIZE_MAP = {
    "gpt2": 768,
    "gpt2-medium": 1024,
    "gpt2-large": 1280,
    "gpt2-xl": 1600,
}
					
						
# CLIP vision-tower hidden sizes: 768 for the ViT-B models, 1024 for ViT-L.
IMAGE_HIDDEN_SIZE_MAP = {
    "openai/clip-vit-base-patch32": 768,
    "openai/clip-vit-base-patch16": 768,
    "openai/clip-vit-large-patch14": 1024,
    "openai/clip-vit-large-patch14-336": 1024,
}
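
# These values can also be cross-checked against the Hub configs, e.g.
#   GPT2Config.from_pretrained("gpt2-medium").n_embd            -> 1024
#   CLIPVisionConfig.from_pretrained(
#       "openai/clip-vit-large-patch14").hidden_size            -> 1024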
					
					
						
@dataclass
class CLIPGPT2Config:
    image_model: str = "openai/clip-vit-base-patch32"
    freeze_image_model: bool = True
    text_model: str = "gpt2-large"
    freeze_text_model: bool = True
    linear_mapping_type: str = "linear"
    add_image_token: bool = True
    freeze_ln: bool = False
    image_from_pretrained: bool = True
    text_from_pretrained: bool = True

    def __post_init__(self):
        # Derive the prefix length and hidden sizes from the chosen
        # checkpoints, then load the full configs from the Hub.
        self.prefix_length = PREFIX_MAP[self.image_model]
        self.image_hidden_size = IMAGE_HIDDEN_SIZE_MAP[self.image_model]
        self.text_hidden_size = TEXT_HIDDEN_SIZE_MAP[self.text_model]
        # Only the -336 checkpoint expects 336x336 inputs; the rest use 224x224.
        self.image_resize = 336 if "336" in self.image_model else 224
        self.text_config = GPT2Config.from_pretrained(self.text_model)
        self.image_config = CLIPVisionConfig.from_pretrained(self.image_model)
        # bool coerces to int, so this reserves one extra slot for the image token.
        self.vocab_size = self.text_config.vocab_size + self.add_image_token
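
# Example usage (a minimal sketch; from_pretrained fetches the GPT-2 and CLIP
# configs from the Hugging Face Hub on first use):
if __name__ == "__main__":
    config = CLIPGPT2Config()
    print(config.prefix_length)      # 50    (ViT-B/32: 49 patches + [CLS])
    print(config.text_hidden_size)   # 1280  (gpt2-large)
    print(config.image_hidden_size)  # 768   (ViT-B/32)
    print(config.vocab_size)         # 50258 (50257 GPT-2 tokens + 1 image token)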