# GPT-2V / configuration_arlow_gpt.py
# Last commit by yuchenxie: "Update configuration_arlow_gpt.py" (232b8ae, verified)
# File size at that commit: 896 bytes
from transformers import PretrainedConfig
from typing import Dict, Optional
# Maps each published checkpoint id to the URL of that checkpoint's hosted
# ``config.json``. Currently a single entry.
ARLOW_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
    [
        (
            "yuchenxie/arlow-gpt",
            "https://huggingface.co/yuchenxie/arlow-gpt/resolve/main/config.json",
        ),
    ]
)
class ArlowGPTConfig(PretrainedConfig):
    """Configuration object for the ``arlow_gpt`` model type.

    Records which CLIP and GPT-2 components to use (by name, plus an optional
    config dict for each) and two integer hyperparameters. Every argument is
    stored unchanged as an attribute of the same name. Any extra keyword
    arguments are forwarded as-is to ``PretrainedConfig.__init__``.

    Args:
        clip_model_name: Identifier of the CLIP component
            (default ``"yuchenxie/CLiP"``).
        gpt2_model_name: Identifier of the GPT-2 component
            (default ``"yuchenxie/GPT-2"``).
        clip_config: Optional dict for the CLIP component, stored exactly as
            given (default ``None``).
        gpt2_config: Optional dict for the GPT-2 component, stored exactly as
            given (default ``None``).
        projection_dim: Stored as ``self.projection_dim`` (default 768).
            NOTE(review): presumably the width of a projection between the two
            components; the consumer is not visible here, so confirm.
        vocab_size: Stored as ``self.vocab_size`` (default 50257).
        **kwargs: Passed straight through to ``PretrainedConfig.__init__``.
    """

    model_type = "arlow_gpt"

    def __init__(
        self,
        clip_model_name: str = "yuchenxie/CLiP",
        gpt2_model_name: str = "yuchenxie/GPT-2",
        clip_config: Optional[Dict] = None,
        gpt2_config: Optional[Dict] = None,
        projection_dim: int = 768,
        vocab_size: int = 50257,
        **kwargs
    ):
        # Initialise the base config first. The stores below run afterwards, so
        # their values win if the base initialiser set any of the same names.
        super().__init__(**kwargs)

        # Component identifiers.
        self.clip_model_name, self.gpt2_model_name = clip_model_name, gpt2_model_name
        # Optional per-component config dicts, kept exactly as passed in.
        self.clip_config, self.gpt2_config = clip_config, gpt2_config
        # Integer hyperparameters.
        self.projection_dim, self.vocab_size = projection_dim, vocab_size