# VikramSingh178's picture
# refactor: Update SDXL-LoRA inference pipeline to load multiple adapter weights
# ebbf256
# raw
# history blame
# 2.39 kB
# Hugging Face Hub identifiers used across training/inference.
MODEL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"
ADAPTER_NAME = "VikramSingh178/sdxl-lora-finetune-product-caption"
ADAPTER_NAME_2 = "VikramSingh178/Products10k-SDXL-Lora"
VAE_NAME = "madebyollin/sdxl-vae-fp16-fix"
# Dataset and experiment-tracking identifiers.
DATASET_NAME = "hahminlew/kream-product-blip-captions"
PROJECT_NAME = "Product Photography"
PRODUCTS_10k_DATASET = "VikramSingh178/Products-10k-BLIP-captions"
CAPTIONING_MODEL_NAME = "Salesforce/blip-image-captioning-base"
class Config:
    """Hyperparameter container for the SDXL LoRA fine-tuning run.

    Every setting is stored as a plain instance attribute so the object
    can be passed around like an ``argparse.Namespace``.
    """

    def __init__(self):
        # Populate all settings in one shot; dict order matches attribute
        # order so ``vars(config)`` stays identical to the original layout.
        defaults = {
            # Model / data sources
            "pretrained_model_name_or_path": MODEL_NAME,
            "pretrained_vae_model_name_or_path": VAE_NAME,
            "revision": None,
            "variant": None,
            "dataset_name": PRODUCTS_10k_DATASET,
            "dataset_config_name": None,
            "train_data_dir": None,
            "image_column": 'image',
            "caption_column": 'text',
            # Validation
            "validation_prompt": None,
            "num_validation_images": 4,
            "validation_epochs": 1,
            "max_train_samples": 7,
            # Output / reproducibility
            "output_dir": "output",
            "cache_dir": None,
            "seed": 42,
            # Image preprocessing
            "resolution": 512,
            "center_crop": True,
            "random_flip": True,
            # Training loop
            "train_text_encoder": False,
            "train_batch_size": 64,
            "num_train_epochs": 400,
            "max_train_steps": None,
            "checkpointing_steps": 500,
            "checkpoints_total_limit": None,
            "resume_from_checkpoint": None,
            "gradient_accumulation_steps": 1,
            "gradient_checkpointing": False,
            # Optimizer / LR schedule
            "learning_rate": 1e-4,
            "scale_lr": False,
            "lr_scheduler": "constant",
            "lr_warmup_steps": 500,
            "snr_gamma": None,
            "allow_tf32": True,
            "dataloader_num_workers": 0,
            "use_8bit_adam": True,
            "adam_beta1": 0.9,
            "adam_beta2": 0.999,
            "adam_weight_decay": 1e-2,
            "adam_epsilon": 1e-08,
            "max_grad_norm": 1.0,
            # Hub / logging
            "push_to_hub": True,
            "hub_token": None,
            "prediction_type": None,
            "hub_model_id": None,
            "logging_dir": "logs",
            "report_to": "wandb",
            # Precision / hardware
            "mixed_precision": 'fp16',
            "local_rank": -1,
            "enable_xformers_memory_efficient_attention": False,
            "noise_offset": 0,
            # LoRA
            "rank": 4,
            "debug_loss": False,
        }
        self.__dict__.update(defaults)