from typing import Any, Dict, Optional

from transformers import PretrainedConfig
# Maps each published ArlowGPT checkpoint id to the URL of its config.json on
# the Hugging Face Hub.
# NOTE(review): newer transformers releases may no longer consult
# ``*_PRETRAINED_CONFIG_ARCHIVE_MAP`` tables; confirm whether anything in this
# project still reads this one.
ARLOW_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP: Dict[str, str] = {
    "yuchenxie/arlow-gpt": "https://huggingface.co/yuchenxie/arlow-gpt/resolve/main/config.json",
}
class ArlowGPTConfig(PretrainedConfig):
    """Configuration for ArlowGPT, a model built from a CLIP and a GPT-2 component.

    This class only records settings: it does not download or instantiate
    either component. Each argument is stored as a same-named instance
    attribute, so it is included when the ``PretrainedConfig`` base class
    serializes the configuration.

    Args:
        clip_model_name: Identifier of the CLIP component (default
            ``"yuchenxie/CLiP"``, presumably a Hugging Face Hub repo id).
        gpt2_model_name: Identifier of the GPT-2 component (default
            ``"yuchenxie/GPT-2"``, presumably a Hub repo id).
        clip_config: Optional dict of settings for the CLIP component, stored
            as given. NOTE(review): how the model consumes it is not visible
            here; confirm against the model code.
        gpt2_config: Optional dict of settings for the GPT-2 component, stored
            as given.
        projection_dim: Stored as ``projection_dim``. Presumably the width of
            the projection between CLIP features and GPT-2; confirm against the
            model code.
        vocab_size: Stored as ``vocab_size``. The default 50257 presumably
            matches the standard GPT-2 BPE tokenizer vocabulary.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``, which
            handles the generic, model-agnostic options.
    """

    # Identifier the base class records as ``model_type`` for this config.
    model_type = "arlow_gpt"

    def __init__(
        self,
        clip_model_name: str = "yuchenxie/CLiP",
        gpt2_model_name: str = "yuchenxie/GPT-2",
        clip_config: Optional[Dict[str, Any]] = None,
        gpt2_config: Optional[Dict[str, Any]] = None,
        projection_dim: int = 768,
        vocab_size: int = 50257,
        **kwargs,
    ):
        # Let the base class consume the generic options first; the
        # model-specific attributes are then set on top of it.
        super().__init__(**kwargs)
        self.clip_model_name = clip_model_name
        self.gpt2_model_name = gpt2_model_name
        self.clip_config = clip_config
        self.gpt2_config = gpt2_config
        self.projection_dim = projection_dim
        self.vocab_size = vocab_size