# NeTI / config.py
# Uploaded by neural-ti in commit 3eb1ce9 ("Upload 17 files")
from collections.abc import Mapping
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Dict

from constants import VALIDATION_PROMPTS
from utils.types import PESigmas
@dataclass
class LogConfig:
    """Where and how often experiment outputs, checkpoints and logs are written."""

    # Experiment name; also used as the name of the output folder
    exp_name: str
    # Directory that receives model predictions and checkpoints
    exp_dir: Path = Path("./outputs")
    # Number of training steps between saves
    save_steps: int = 250
    # Directory for TensorBoard (https://www.tensorflow.org/tensorboard) logs; defaults to `logs`.
    # NOTE(review): whether this path is resolved relative to `exp_dir` is decided by the
    # training code, not here — confirm there.
    logging_dir: Path = Path("logs")
    # Integration to report results to: "tensorboard" (default), "wandb" or "comet_ml".
    # Use "all" to report to every supported integration.
    report_to: str = "tensorboard"
    # Maximum number of checkpoints to keep; forwarded as `total_limit` to the `Accelerator`
    checkpoints_total_limit: Optional[int] = None
@dataclass
class DataConfig:
    """Training-data location, the learned placeholder token, and image preprocessing options."""

    # Folder containing the training images
    train_data_dir: Path
    # Placeholder token that stands in for the learned concept
    placeholder_token: str
    # Super-category token used for normalizing the mapper output
    super_category_token: Optional[str] = "object"
    # Number of data-loading subprocesses; 0 loads the data in the main process
    dataloader_num_workers: int = 8
    # Either 'object' or 'style'; selects the prompt templates used for training
    learnable_property: str = "object"
    # Number of times the training data is repeated
    repeats: int = 100
    # Resolution that all train/validation input images are resized to
    resolution: int = 512
    # Whether images are center-cropped before being resized to `resolution`
    center_crop: bool = False
@dataclass
class ModelConfig:
    """ Parameters for defining all models """
    # Path to pretrained model or model identifier from huggingface.co/models
    pretrained_model_name_or_path: str = "CompVis/stable-diffusion-v1-4"
    # Whether to use our Nested Dropout technique
    use_nested_dropout: bool = True
    # Probability to apply nested dropout during training
    nested_dropout_prob: float = 0.5
    # Whether to normalize the norm of the mapper's output vector
    normalize_mapper_output: bool = True
    # Target norm for the mapper's output vector
    target_norm: Optional[float] = None
    # Whether to use positional encoding over the input to the mapper
    use_positional_encoding: bool = True
    # Sigmas used for computing positional encoding. Given as a mapping with exactly the keys
    # 'sigma_t' and 'sigma_l'; __post_init__ converts it into a `PESigmas` instance.
    pe_sigmas: Dict[str, float] = field(default_factory=lambda: {'sigma_t': 0.03, 'sigma_l': 2.0})
    # Number of time anchors for computing our positional encodings
    num_pe_time_anchors: int = 10
    # Whether to output the textual bypass vector
    output_bypass: bool = True
    # Revision of pretrained model identifier from huggingface.co/models
    revision: Optional[str] = None
    # Path to a previously saved mapper checkpoint to resume training from (None = no checkpoint)
    mapper_checkpoint_path: Optional[Path] = None

    def __post_init__(self):
        """Validate `pe_sigmas` and convert it from a mapping into a `PESigmas` instance.

        `None` and values that are already `PESigmas` are left untouched, so re-running this hook
        (e.g. through `dataclasses.replace`) is safe.

        Raises:
            ValueError: if `pe_sigmas` is a mapping whose keys are not exactly 'sigma_t' and 'sigma_l'.
            TypeError: if `pe_sigmas` is neither None, a mapping, nor a `PESigmas`.
        """
        if self.pe_sigmas is None:
            return
        if isinstance(self.pe_sigmas, Mapping):
            # Check key names, not just the count: a misspelled key would otherwise surface
            # as a bare KeyError below.
            if set(self.pe_sigmas) != {'sigma_t', 'sigma_l'}:
                raise ValueError(
                    "Should provide exactly two sigma values: 'sigma_t' for time and 'sigma_l' for layers! "
                    f"Got keys: {list(self.pe_sigmas)}"
                )
            self.pe_sigmas = PESigmas(sigma_t=self.pe_sigmas['sigma_t'], sigma_l=self.pe_sigmas['sigma_l'])
        elif not isinstance(self.pe_sigmas, PESigmas):
            raise TypeError(
                f"pe_sigmas must be a mapping with keys 'sigma_t' and 'sigma_l', got {type(self.pe_sigmas).__name__}"
            )
@dataclass
class EvalConfig:
    """ Parameters for validation """
    # A list of prompts that will be used during validation to verify that the model is learning.
    # Each instance gets its own copy, so mutating it cannot alter the shared VALIDATION_PROMPTS.
    validation_prompts: List[str] = field(default_factory=lambda: list(VALIDATION_PROMPTS))
    # Number of images generated during validation; must equal the number of `validation_seeds`
    num_validation_images: int = 4
    # Seeds for generating the validation images, one per image. If None, seeds
    # 0..num_validation_images-1 are used. Note: the 4 default seeds mean that changing
    # `num_validation_images` alone requires also passing matching seeds (or None).
    validation_seeds: Optional[List[int]] = field(default_factory=lambda: [42, 420, 501, 5456])
    # Run validation every X steps.
    validation_steps: int = 100
    # Number of denoising steps
    num_denoising_steps: int = 50

    def __post_init__(self):
        """Fill in default seeds when none are given and check there is one seed per image.

        Raises:
            ValueError: if the number of seeds differs from `num_validation_images`.
        """
        if self.validation_seeds is None:
            self.validation_seeds = list(range(self.num_validation_images))
        # Explicit raise instead of `assert`, so the check also runs under `python -O`.
        if len(self.validation_seeds) != self.num_validation_images:
            raise ValueError(
                "Length of validation_seeds should equal num_validation_images "
                f"(got {len(self.validation_seeds)} seeds for {self.num_validation_images} images)"
            )
@dataclass
class OptimConfig:
    """Optimizer, learning-rate schedule, batching and numeric-precision settings."""

    # Total number of training steps to run
    max_train_steps: Optional[int] = 1000
    # Base learning rate
    learning_rate: float = 0.001
    # Whether to scale the learning rate by the number of GPUs, gradient accumulation steps and batch size
    scale_lr: bool = True
    # Per-device batch size of the training dataloader
    train_batch_size: int = 2
    # Whether to use gradient checkpointing, trading a slower backward pass for lower memory use
    gradient_checkpointing: bool = False
    # Number of update steps to accumulate before performing a backward/update pass
    gradient_accumulation_steps: int = 4
    # Seed for reproducible training
    seed: Optional[int] = None
    # Learning-rate scheduler: one of "linear", "cosine", "cosine_with_restarts", "polynomial",
    # "constant" or "constant_with_warmup"
    lr_scheduler: str = "constant"
    # Number of warmup steps for the learning-rate scheduler
    lr_warmup_steps: int = 0
    # Adam beta1
    adam_beta1: float = 0.9
    # Adam beta2
    adam_beta2: float = 0.999
    # Weight decay
    adam_weight_decay: float = 0.01
    # Adam epsilon
    adam_epsilon: float = 1e-8
    # Mixed precision mode: "no" (default), "fp16" or "bf16" (bfloat16).
    # bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU.
    mixed_precision: str = "no"
    # Whether to allow TF32 on Ampere GPUs, which can speed up training. See
    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    allow_tf32: bool = False
@dataclass
class RunConfig:
    """Top-level configuration for the coach trainer, grouping all sub-configurations.

    `LogConfig` (exp_name) and `DataConfig` (train_data_dir, placeholder_token) have required
    fields, so their default factories cannot be called without arguments: `RunConfig()` with no
    arguments raises TypeError. Pass `log` and `data` explicitly.
    """

    # Logging and checkpointing settings
    log: LogConfig = field(default_factory=LogConfig)
    # Training data and placeholder-token settings
    data: DataConfig = field(default_factory=DataConfig)
    # Model, mapper and positional-encoding settings
    model: ModelConfig = field(default_factory=ModelConfig)
    # Validation settings
    eval: EvalConfig = field(default_factory=EvalConfig)
    # Optimization settings
    optim: OptimConfig = field(default_factory=OptimConfig)