# ------------------------------------------------------------------------------------
# Minimal DALL-E
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------
from typing import Optional, List
from dataclasses import dataclass, field
from omegaconf import OmegaConf

@dataclass
class DataConfig:
    dataset: Optional[str] = None
    tokenizer_type: str = 'CharBPE'
    context_length: int = 64
    image_resolution: int = 256
    transforms: str = 'dalle-vqvae'
    bpe_pdrop: Optional[float] = None

@dataclass
class Stage1Hparams:
    double_z: bool = False
    z_channels: int = 256
    resolution: int = 256
    in_channels: int = 3
    out_ch: int = 3
    ch: int = 128
    ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
    num_res_blocks: int = 2
    attn_resolutions: List[int] = field(default_factory=lambda: [16])
    pdrop: float = 0.0

@dataclass
class Stage2Hparams:
    embed_dim: int = 1536
    n_layers: int = 42
    n_heads: int = 24
    n_dense_layers: int = 42
    ctx_len_img: int = 256
    ctx_len_txt: int = 64
    embd_pdrop: float = 0.0
    resid_pdrop: float = 0.0
    attn_pdrop: float = 0.0
    mlp_bias: bool = True
    attn_bias: bool = True
    gelu_use_approx: bool = False
    use_head_txt: bool = True
    n_classes: Optional[int] = None

@dataclass
class Stage1Config:
    type: str = 'vqgan'
    embed_dim: int = 256
    n_embed: int = 16384
    # default_factory avoids a shared mutable default (instance defaults are rejected by dataclasses on Python 3.11+)
    hparams: Stage1Hparams = field(default_factory=Stage1Hparams)

@dataclass
class Stage2Config:
    type: str = 'transformer1d'
    vocab_size_txt: int = 16384
    vocab_size_img: int = 16384
    use_cls_cond: Optional[bool] = None
    hparams: Stage2Hparams = field(default_factory=Stage2Hparams)

@dataclass
class WarmupConfig:
    epoch: int = 1
    multiplier: int = 1
    buffer_epoch: int = 0
    min_lr: float = 0.0
    mode: str = 'fix'
    peak_lr: float = 1e-4
    start_from_zero: bool = True

@dataclass
class OptConfig:
    opt_type: str = 'adamW'
    learning_rate: float = 5e-5
    weight_decay: float = 1e-4
    betas: List[float] = field(default_factory=lambda: [0.9, 0.99])
    grad_clip_norm: float = 1.0
    sched_type: str = 'cosine'
    max_steps: int = 0
    min_lr: float = 1e-6

@dataclass
class ExpConfig:
    per_gpu_train_batch_size: int = 4
    per_gpu_eval_batch_size: int = 32
    num_train_epochs: int = 10
    save_ckpt_freq: int = 1
    test_freq: int = 10
    use_amp: bool = True

@dataclass
class PrefixModelConfig:
    model_name_or_path: Optional[str] = ''
    prefix_model_name_or_path: str = ''
    prefix_mode: str = 'activation'
    tuning_mode: str = 'finetune'
    top_k_layers: int = 2
    parameterize_mode: str = 'mlp'
    optim_prefix: bool = False
    preseqlen: int = 10
    prefix_dropout: float = 0.1
    init_random: bool = False
    hidden_dim_prefix: int = 512
    lowdata: bool = False
    lowdata_token: str = ''
    init_shallow: bool = False
    init_shallow_word: bool = False
    teacher_dropout: float = 0.1
    gumbel: bool = False
    replay_buffer: bool = False

@dataclass
class PromptModelConfig:
    model_name_or_path: Optional[str] = ''
    prefix_model_name_or_path: str = ''
    tuning_mode: str = 'prompt'
    preseqlen: int = 10
    prefix_dropout: float = 0.1

@dataclass
class StoryModelConfig:
    model_name_or_path: Optional[str] = ''
    prefix_model_name_or_path: str = ''
    tuning_mode: str = 'story'
    preseqlen: int = 10
    prefix_dropout: float = 0.1
    prompt: bool = False
    story_len: int = 4
    sent_embed: int = 256
    condition: bool = False
    clip_embed: bool = False

@dataclass
class DefaultConfig:
    dataset: DataConfig = field(default_factory=DataConfig)
    stage1: Stage1Config = field(default_factory=Stage1Config)
    stage2: Stage2Config = field(default_factory=Stage2Config)

@dataclass
class FineTuningConfig:
    dataset: DataConfig = field(default_factory=DataConfig)
    stage1: Stage1Config = field(default_factory=Stage1Config)
    stage2: Stage2Config = field(default_factory=Stage2Config)
    optimizer: OptConfig = field(default_factory=OptConfig)
    experiment: ExpConfig = field(default_factory=ExpConfig)

@dataclass
class PrefixTuningConfig:
    dataset: DataConfig = field(default_factory=DataConfig)
    stage1: Stage1Config = field(default_factory=Stage1Config)
    stage2: Stage2Config = field(default_factory=Stage2Config)
    prefix: PrefixModelConfig = field(default_factory=PrefixModelConfig)
    optimizer: OptConfig = field(default_factory=OptConfig)
    experiment: ExpConfig = field(default_factory=ExpConfig)

@dataclass
class PromptTuningConfig:
    dataset: DataConfig = field(default_factory=DataConfig)
    stage1: Stage1Config = field(default_factory=Stage1Config)
    stage2: Stage2Config = field(default_factory=Stage2Config)
    prompt: PromptModelConfig = field(default_factory=PromptModelConfig)
    optimizer: OptConfig = field(default_factory=OptConfig)
    experiment: ExpConfig = field(default_factory=ExpConfig)

@dataclass
class StoryConfig:
    dataset: DataConfig = field(default_factory=DataConfig)
    stage1: Stage1Config = field(default_factory=Stage1Config)
    stage2: Stage2Config = field(default_factory=Stage2Config)
    story: StoryModelConfig = field(default_factory=StoryModelConfig)
    optimizer: OptConfig = field(default_factory=OptConfig)
    experiment: ExpConfig = field(default_factory=ExpConfig)

def get_base_config(mode):
    if mode == 'default':
        return OmegaConf.structured(DefaultConfig)
    elif mode == 'finetuning':
        return OmegaConf.structured(FineTuningConfig)
    elif mode == 'prefixtuning':
        return OmegaConf.structured(PrefixTuningConfig)
    elif mode == 'prompt_tuning':
        return OmegaConf.structured(PromptTuningConfig)
    elif mode == 'story':
        return OmegaConf.structured(StoryConfig)
    else:
        raise ValueError(f'unknown config mode: {mode}')
    # return OmegaConf.structured(DefaultConfig if use_default else FineTuningConfig)
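

# Usage sketch: one way a training script might consume these configs — build the
# structured defaults with get_base_config() and then merge overrides on top with
# OmegaConf (overrides win; field names and types are validated against the schema).
# The override values below are illustrative only, not recommended settings.
if __name__ == '__main__':
    base = get_base_config('story')  # structured StoryConfig with all defaults
    overrides = OmegaConf.create({
        'optimizer': {'learning_rate': 1e-4},
        'experiment': {'num_train_epochs': 50},
    })
    cfg = OmegaConf.merge(base, overrides)
    print(OmegaConf.to_yaml(cfg))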