File size: 3,066 Bytes
917ff2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# ------------------------------------------------------------------------------------
# Minimal DALL-E
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------

from typing import Optional, List
from dataclasses import dataclass, field
from omegaconf import OmegaConf


@dataclass
class DataConfig:
    """Schema for the ``dataset`` section of the structured config.

    Used as the ``dataset`` field of both ``DefaultConfig`` and
    ``FineTuningConfig``. Every field carries a default, so ``DataConfig()``
    is a complete section; the two ``Optional`` fields stay ``None`` unless
    they are set explicitly.
    """
    # Optional; None means "not specified".
    dataset: Optional[str] = None
    tokenizer_type: str = "CharBPE"
    # NOTE(review): default (64) equals Stage2Hparams.ctx_len_txt's default,
    # but nothing in this schema keeps the two values in sync.
    context_length: int = 64
    image_resolution: int = 256
    # Transform preset name (a plain string key).
    transforms: str = "dalle-vqvae"
    # Optional; None means "not specified".
    bpe_pdrop: Optional[float] = None


@dataclass
class Stage1Hparams:
    """Hyper-parameters nested under ``Stage1Config.hparams``.

    The two list-valued fields are built by ``default_factory`` so every
    instance owns a fresh list rather than sharing one module-level object.
    """
    double_z: bool = False
    z_channels: int = 256
    resolution: int = 256
    in_channels: int = 3
    out_ch: int = 3
    ch: int = 128
    # Fresh list per instance (see class docstring).
    ch_mult: List[int] = field(default_factory=lambda: list((1, 1, 2, 2, 4)))
    num_res_blocks: int = 2
    # Fresh list per instance (see class docstring).
    attn_resolutions: List[int] = field(default_factory=lambda: list((16,)))
    pdrop: float = 0.0


@dataclass
class Stage2Hparams:
    """Hyper-parameters nested under ``Stage2Config.hparams``.

    Every field is a scalar (or ``None``) with a default, so
    ``Stage2Hparams()`` is a complete section.
    """
    embed_dim: int = 1536
    n_layers: int = 42
    n_heads: int = 24
    n_dense_layers: int = 42
    # NOTE(review): the ctx_len_txt default (64) matches
    # DataConfig.context_length's default; nothing here links the two.
    ctx_len_img: int = 256
    ctx_len_txt: int = 64
    # All *_pdrop fields default to 0.0.
    embd_pdrop: float = 0.0
    resid_pdrop: float = 0.0
    attn_pdrop: float = 0.0
    # Boolean switches: the two bias flags default on, gelu_use_approx off.
    mlp_bias: bool = True
    attn_bias: bool = True
    gelu_use_approx: bool = False
    use_head_txt: bool = True
    # Optional; None means "not specified".
    n_classes: Optional[int] = None


@dataclass
class Stage1Config:
    """Schema for the ``stage1`` section of the structured config.

    ``hparams`` is built with ``default_factory`` so each ``Stage1Config``
    gets its own ``Stage1Hparams``. A plain ``= Stage1Hparams()`` default
    would be one instance shared by every config. It is also rejected by
    ``dataclasses`` on Python 3.11+, because a non-frozen dataclass instance
    is unhashable and so raises ``ValueError`` when the class is defined.
    """
    # Field name is a config key; it intentionally shadows the builtin only
    # inside the class namespace.
    type: str = 'vqgan'
    embed_dim: int = 256
    n_embed: int = 16384
    hparams: Stage1Hparams = field(default_factory=Stage1Hparams)


@dataclass
class Stage2Config:
    """Schema for the ``stage2`` section of the structured config.

    ``hparams`` is built with ``default_factory`` so each ``Stage2Config``
    gets its own ``Stage2Hparams``. A plain ``= Stage2Hparams()`` default
    would be one instance shared by every config. It is also rejected by
    ``dataclasses`` on Python 3.11+, because a non-frozen dataclass instance
    is unhashable and so raises ``ValueError`` when the class is defined.
    """
    # Field name is a config key; it intentionally shadows the builtin only
    # inside the class namespace.
    type: str = 'transformer1d'
    vocab_size_txt: int = 16384
    vocab_size_img: int = 16384
    # Optional; None means "not specified".
    use_cls_cond: Optional[bool] = None
    hparams: Stage2Hparams = field(default_factory=Stage2Hparams)


@dataclass
class WarmupConfig:
    """Learning-rate warmup options.

    NOTE(review): no schema in this file references this class
    (``FineTuningConfig`` has no warmup field). Confirm where it is consumed.
    """
    epoch: int = 1
    multiplier: int = 1
    buffer_epoch: int = 0
    min_lr: float = 0.0
    # Mode selector string.
    mode: str = "fix"
    peak_lr: float = 1.0e-4
    start_from_zero: bool = True


@dataclass
class OptConfig:
    """Schema for the ``optimizer`` section of ``FineTuningConfig``.

    Holds the optimizer settings first, then the LR-scheduler settings.
    ``betas`` is built by ``default_factory`` so every instance owns its
    own list.
    """
    # --- optimizer ---
    opt_type: str = "adamW"
    base_lr: float = 1.0e-4
    weight_decay: float = 1.0e-4
    betas: List[float] = field(default_factory=lambda: list((0.9, 0.99)))
    grad_clip_norm: float = 1.0

    # --- scheduler ---
    sched_type: str = "cosine"
    max_steps: int = 0
    min_lr: float = 0.0


@dataclass
class ExpConfig:
    """Schema for the ``experiment`` section of ``FineTuningConfig``.

    Batch sizes, epoch count, checkpoint/test frequencies and the AMP switch.
    All fields are scalars with defaults.
    """
    # Batch sizes.
    local_batch_size: int = 4
    total_batch_size: int = 512
    valid_batch_size: int = 32
    # Schedule / frequencies.
    epochs: int = 10
    save_ckpt_freq: int = 2
    test_freq: int = 1
    # Boolean switch, on by default.
    use_amp: bool = True


@dataclass
class DefaultConfig:
    """Top-level schema with the ``dataset``, ``stage1`` and ``stage2`` sections.

    Each section is built with ``default_factory`` so every ``DefaultConfig``
    owns fresh section objects. Plain instance defaults (``= DataConfig()``,
    etc.) would be shared across all configs. They are also rejected by
    ``dataclasses`` on Python 3.11+, because non-frozen dataclass instances
    are unhashable and so raise ``ValueError`` at class definition.
    """
    dataset: DataConfig = field(default_factory=DataConfig)
    stage1: Stage1Config = field(default_factory=Stage1Config)
    stage2: Stage2Config = field(default_factory=Stage2Config)


@dataclass
class FineTuningConfig:
    """Top-level schema for fine-tuning.

    Contains the same ``dataset``/``stage1``/``stage2`` sections as
    ``DefaultConfig``, plus ``optimizer`` and ``experiment``.

    Each section is built with ``default_factory`` so every instance owns
    fresh section objects. Plain instance defaults would be shared across all
    configs. They are also rejected by ``dataclasses`` on Python 3.11+,
    because non-frozen dataclass instances are unhashable and so raise
    ``ValueError`` at class definition.
    """
    dataset: DataConfig = field(default_factory=DataConfig)
    stage1: Stage1Config = field(default_factory=Stage1Config)
    stage2: Stage2Config = field(default_factory=Stage2Config)
    optimizer: OptConfig = field(default_factory=OptConfig)
    experiment: ExpConfig = field(default_factory=ExpConfig)


def get_base_config(use_default=True):
    """Build an OmegaConf structured config from one of the schema classes.

    Args:
        use_default: when truthy, the schema is ``DefaultConfig`` (dataset,
            stage1, stage2). Otherwise it is ``FineTuningConfig``, which adds
            the ``optimizer`` and ``experiment`` sections.

    Returns:
        Whatever ``OmegaConf.structured`` returns for the chosen dataclass
        (passed as the class itself, not an instance).
    """
    if use_default:
        schema = DefaultConfig
    else:
        schema = FineTuningConfig
    return OmegaConf.structured(schema)