from dataclasses import dataclass, asdict
import yaml
from pprint import pprint
from .utils import is_main
import os
from pathlib import Path
import uuid


def load_config(path, config_dir=Path("configs")):
    """Load a yaml config file, appending the .yml extension and falling back to config_dir if the path does not exist."""
    if not path.endswith(".yml"):
        path += ".yml"
    if not os.path.exists(path):
        path = config_dir / path
    with open(path, "r") as stream:
        config = yaml.safe_load(stream)
    return config


@dataclass
class MultimodalConfig:
    """Training, data, and model configuration, loadable from a yaml file via `from_yml`."""

    # Training:
    # ------------------------------------------------------------
    batch_size: int
    train_steps: int
    optimizer_name: str = "AdamW"
    lr: float = 8.0e-4
    image_enc_lr: float = None
    min_lr: float = 0.0
    lr_decay_iters: int = None
    gradient_accumulation_steps: int = 1
    image_size: int = 256
    eval_every: int = 250
    eval_steps: int = 25
    zero_stage: int = 2
    gradient_clipping: float = 1.0
    warmup_num_steps: int = 100
    weight_decay: float = 0.00
    run_blind: bool = False
    fine_tune: bool = False
    load_optimizer: bool = True

    # Checkpointing:
    # ------------------------------------------------------------
    save_every: int = 2500
    save: str = None
    load: str = None

    # Data:
    # ------------------------------------------------------------
    train_dataset_name: str = "conceptual_captions"
    eval_dataset_name: str = "/data/conceptual_captions"
    train_dataset_dir: str = "/data/coco_data"
    eval_dataset_dir: str = "/data/coco_data"
    eval_dataset_pct: float = 0.1

    # Model architecture:
    # ------------------------------------------------------------
    encoder_name: str = "clip"
    tokenizer_name: str = "gpt2"
    lm_name: str = "EleutherAI/gpt-j-6B"
    image_seq_len: int = 2
    pretrained_img_encoder: bool = False
    seq_len: int = None

    # Layer Freezing settings:
    # ------------------------------------------------------------
    freeze_lm: bool = True
    freeze_img_encoder: bool = True

    image_embed_dropout_prob: float = 0.0
    use_image_embed_layernorm: bool = False

    # Adapter settings:
    # ------------------------------------------------------------
    adapter_config: dict = None

    # Classification Finetuning settings:
    # ------------------------------------------------------------
    class_dict: dict = None  # {num_classes: .., ckpt_path: .., classifier_type: .., interface_type: .., interface_position: .., freeze_model: ..}

    # Logging settings:
    # ------------------------------------------------------------
    name: str = None  # name, just used for wandb logging
    log_every: int = 1
    wandb_project: str = "magma"

    def print(self):
        """Pretty-print the config on the main process only."""
        if is_main():
            print("-" * 100)
            pprint(self.__dict__, indent=4)
            print("-" * 100)

    def __post_init__(self):
        """Derive runtime settings: classifier flag, deepspeed scheduler/config dicts, and a run name if none was given."""
        self.is_classifier = self.class_dict is not None
        if self.adapter_config is None:
            self.adapter_config = {}

        # Deepspeed Settings:
        # ------------------------------------------------------------
        if self.lr_decay_iters is None:
            self.lr_scheduler = "WarmupLR"
            self.scheduler_dict = {
                "type": self.lr_scheduler,
                "params": {
                    "warmup_min_lr": self.min_lr,
                    "warmup_max_lr": self.lr,
                    "warmup_num_steps": self.warmup_num_steps,
                },
            }
        else:
            self.lr_scheduler = "WarmupDecayLR"
            self.scheduler_dict = {
                "type": self.lr_scheduler,
                "params": {
                    "total_num_steps": self.lr_decay_iters,
                    "warmup_min_lr": self.min_lr,
                    "warmup_max_lr": self.lr,
                    "warmup_num_steps": self.warmup_num_steps,
                },
            }
        self.deepspeed_config_params = {
            "train_batch_size": self.batch_size,
            "gradient_accumulation_steps": self.gradient_accumulation_steps,
            "gradient_clipping": self.gradient_clipping,
            "fp16": {"enabled": True, "loss_scale_window": 250},
            "scheduler": self.scheduler_dict,
            "zero_optimization": {
                "stage": self.zero_stage,
                "load_from_fp32_weights": False,
            },
        }

        if self.name is None:
            self.name = str(uuid.uuid4())[:8]

    @classmethod
    def from_yml(cls, path):
        return cls(**load_config(path))

    def to_dict(self):
        return asdict(self)
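

# ------------------------------------------------------------
# Usage sketch (illustrative only): the field values below are
# hypothetical, and "example" would refer to a configs/example.yml
# whose keys match the MultimodalConfig fields above. Because of the
# relative import of `is_main`, run this as a module, e.g.
# `python -m <package>.config`, where <package> depends on the repo layout.
# ------------------------------------------------------------
if __name__ == "__main__":
    # Only batch_size and train_steps have no defaults.
    config = MultimodalConfig(batch_size=256, train_steps=15000)
    pprint(config.to_dict(), indent=4)

    # Loading from yaml instead:
    # config = MultimodalConfig.from_yml("example")

    # __post_init__ builds config.deepspeed_config_params, which is
    # presumably what gets passed to deepspeed.initialize(config=...).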