| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import List, Optional, Dict | |
| from constants import VALIDATION_PROMPTS | |
| from utils.types import PESigmas | |
@dataclass
class LogConfig:
    """ Parameters for logging and saving.

    Declared as a dataclass so the annotated fields below become keyword arguments of
    the generated ``__init__`` (``exp_name`` is required, the rest have defaults).
    """
    # Name of experiment. This will be the name of the output folder
    exp_name: str
    # The output directory where the model predictions and checkpoints will be written
    exp_dir: Path = Path("./outputs")
    # Save interval
    save_steps: int = 250
    # [TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Defaults to `logs`.
    # NOTE(review): presumably resolved relative to the experiment output folder by the trainer — confirm.
    logging_dir: Path = Path("logs")
    # The integration to report the results to. Supported platforms are `"tensorboard"`
    # (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.
    report_to: str = "tensorboard"
    # Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator`
    checkpoints_total_limit: Optional[int] = None
@dataclass
class DataConfig:
    """ Parameters for data.

    Declared as a dataclass so the annotated fields become constructor arguments;
    ``train_data_dir`` and ``placeholder_token`` are required.
    """
    # A folder containing the training data
    train_data_dir: Path
    # A token to use as a placeholder for the concept
    placeholder_token: str
    # Super category token to use for normalizing the mapper output
    super_category_token: Optional[str] = "object"
    # Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process
    dataloader_num_workers: int = 8
    # Choose between 'object' and 'style' - used for selecting the prompts for training
    learnable_property: str = "object"
    # How many times to repeat the training data
    repeats: int = 100
    # The resolution for input images, all the images in the train/validation dataset will be resized to this resolution
    resolution: int = 512
    # Whether to center crop images before resizing to resolution
    center_crop: bool = False
@dataclass
class ModelConfig:
    """ Parameters for defining all models.

    Declared as a dataclass so the fields become constructor arguments and
    ``__post_init__`` runs after construction.
    """
    # Path to pretrained model or model identifier from huggingface.co/models
    pretrained_model_name_or_path: str = "CompVis/stable-diffusion-v1-4"
    # Whether to use our Nested Dropout technique
    use_nested_dropout: bool = True
    # Probability to apply nested dropout during training
    nested_dropout_prob: float = 0.5
    # Whether to normalize the norm of the mapper's output vector
    normalize_mapper_output: bool = True
    # Target norm for the mapper's output vector
    target_norm: Optional[float] = None
    # Whether to use positional encoding over the input to the mapper
    use_positional_encoding: bool = True
    # Sigmas used for computing positional encoding, given as {'sigma_t': ..., 'sigma_l': ...}.
    # Replaced by a PESigmas object in __post_init__ (the annotation describes the input form).
    pe_sigmas: Dict[str, float] = field(default_factory=lambda: {'sigma_t': 0.03, 'sigma_l': 2.0})
    # Number of time anchors for computing our positional encodings
    num_pe_time_anchors: int = 10
    # Whether to output the textual bypass vector
    output_bypass: bool = True
    # Revision of pretrained model identifier from huggingface.co/models
    revision: Optional[str] = None
    # Path to a mapper checkpoint to resume training from (None = no checkpoint)
    mapper_checkpoint_path: Optional[Path] = None

    def __post_init__(self):
        """Convert a ``pe_sigmas`` mapping into a ``PESigmas`` object.

        Only mappings are converted; ``None`` or an already-converted value is left as is,
        so re-running this hook (e.g. through ``dataclasses.replace``) is safe.

        Raises:
            AssertionError: if the mapping does not contain exactly the keys
                ``'sigma_t'`` and ``'sigma_l'``.
        """
        from collections.abc import Mapping

        if isinstance(self.pe_sigmas, Mapping):
            # Check the actual keys rather than only the count, so a mapping with two
            # wrongly-named entries gets a clear message instead of a bare KeyError.
            assert set(self.pe_sigmas) == {'sigma_t', 'sigma_l'}, \
                "Should provide exactly two sigma values: one for time ('sigma_t') and one for layers ('sigma_l')!"
            self.pe_sigmas = PESigmas(sigma_t=self.pe_sigmas['sigma_t'], sigma_l=self.pe_sigmas['sigma_l'])
@dataclass
class EvalConfig:
    """ Parameters for validation.

    Declared as a dataclass so the fields become constructor arguments and
    ``__post_init__`` runs after construction.
    """
    # A list of prompts that will be used during validation to verify that the model is learning.
    # A copy is returned so that mutating one config's prompts never alters the shared constant.
    validation_prompts: List[str] = field(default_factory=lambda: list(VALIDATION_PROMPTS))
    # Number of images that should be generated during validation with `validation_prompts`
    num_validation_images: int = 4
    # Seeds to use for generating the validation images
    validation_seeds: Optional[List[int]] = field(default_factory=lambda: [42, 420, 501, 5456])
    # Run validation every X steps.
    validation_steps: int = 100
    # Number of denoising steps
    num_denoising_steps: int = 50

    def __post_init__(self):
        """Fill in default seeds when none are given and check seed/image counts agree.

        Raises:
            AssertionError: if ``len(validation_seeds) != num_validation_images``.
        """
        if self.validation_seeds is None:
            # One seed per validation image: 0, 1, ..., num_validation_images - 1
            self.validation_seeds = list(range(self.num_validation_images))
        assert len(self.validation_seeds) == self.num_validation_images, \
            "Length of validation_seeds should equal num_validation_images"
@dataclass
class OptimConfig:
    """ Parameters for the optimization process.

    Declared as a dataclass so every field below becomes an optional keyword argument
    of the generated ``__init__``.
    """
    # Total number of training steps to perform.
    max_train_steps: Optional[int] = 1_000
    # Learning rate
    learning_rate: float = 1e-3
    # Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size
    scale_lr: bool = True
    # Batch size (per device) for the training dataloader
    train_batch_size: int = 2
    # Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass
    gradient_checkpointing: bool = False
    # Number of update steps to accumulate before performing a backward/update pass
    gradient_accumulation_steps: int = 4
    # A seed for reproducible training
    seed: Optional[int] = None
    # The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",
    # "constant", "constant_with_warmup"]
    lr_scheduler: str = "constant"
    # Number of steps for the warmup in the lr scheduler
    lr_warmup_steps: int = 0
    # The beta1 parameter for the Adam optimizer
    adam_beta1: float = 0.9
    # The beta2 parameter for the Adam optimizer
    adam_beta2: float = 0.999
    # Weight decay to use
    adam_weight_decay: float = 1e-2
    # Epsilon value for the Adam optimizer
    adam_epsilon: float = 1e-08
    # Whether to use mixed precision. Choose between "no", "fp16" and "bf16" (bfloat16).
    # bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU.
    mixed_precision: str = "no"
    # Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see
    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    allow_tf32: bool = False
@dataclass
class RunConfig:
    """ The main configuration for the coach trainer.

    Declared as a dataclass so each sub-config is a constructor argument; omitted
    sub-configs are built by calling their class with no arguments.
    """
    # Logging / checkpoint-saving parameters
    log: LogConfig = field(default_factory=LogConfig)
    # Training-data parameters
    data: DataConfig = field(default_factory=DataConfig)
    # Model-definition parameters
    model: ModelConfig = field(default_factory=ModelConfig)
    # Validation parameters
    eval: EvalConfig = field(default_factory=EvalConfig)
    # Optimization parameters
    optim: OptimConfig = field(default_factory=OptimConfig)