from diffusers import UnCLIPPipeline, DiffusionPipeline
import torch
import os
from lora_diffusion.cli_lora_pti import *
from lora_diffusion.lora import *
from PIL import Image
import numpy as np
import json
from lora_dataset import PivotalTuningDatasetCapation as PVD

# Module classes that receive injected LoRA layers.
UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"}
TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"}
TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"}
DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE


def save_all(
    unet,
    text_encoder,
    save_path,
    placeholder_token_ids=None,
    placeholder_tokens=None,
    save_lora=True,
    save_ti=True,
    target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
    target_replace_module_unet=DEFAULT_TARGET_REPLACE,
    safe_form=True,
):
    if not safe_form:
        # save textual-inversion embeddings
        if save_ti:
            ti_path = ti_lora_path(save_path)
            learned_embeds_dict = {}
            for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
                learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
                print(
                    f"Current Learned Embeddings for {tok}, id {tok_id}: ",
                    learned_embeds[:4],
                )
                learned_embeds_dict[tok] = learned_embeds.detach().cpu()

            torch.save(learned_embeds_dict, ti_path)
            print("Ti saved to ", ti_path)

        # save LoRA weights for the unet and the text encoder
        if save_lora:
            save_lora_weight(
                unet, save_path, target_replace_module=target_replace_module_unet
            )
            print("Unet saved to ", save_path)

            save_lora_weight(
                text_encoder,
                _text_lora_path(save_path),
                target_replace_module=target_replace_module_text,
            )
            print("Text Encoder saved to ", _text_lora_path(save_path))
    else:
        assert save_path.endswith(
            ".safetensors"
        ), f"Save path : {save_path} should end with .safetensors"

        loras = {}
        embeds = {}

        if save_lora:
            loras["unet"] = (unet, target_replace_module_unet)
            loras["text_encoder"] = (text_encoder, target_replace_module_text)

        if save_ti:
            for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
                learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
                print(
                    f"Current Learned Embeddings for {tok}, id {tok_id}: ",
                    learned_embeds[:4],
                )
                embeds[tok] = learned_embeds.detach().cpu()

        return save_safeloras_with_embeds(loras, embeds, save_path)

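
# The helper below packs everything into one flat tensor dict. Based on its loops,
# the resulting layout looks roughly like this (key names are illustrative):
#
#   weights:  "unet:0:up", "unet:0:down", ..., "text_encoder:0:up", ...,
#             plus one entry per learned placeholder token (e.g. "<0>")
#   metadata: "unet" -> JSON list of target modules,
#             "unet:0:rank" -> LoRA rank as a string,
#             "<0>" -> EMBED_FLAG marker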
def save_safeloras_with_embeds(
    modelmap={},
    embeds={},
    outpath="./lora.safetensors",
):
    """
    Collects the LoRA weights from multiple modules, plus any learned token
    embeddings, into a single flat state dict (safetensors-style layout).

    modelmap is a dictionary of {
        "module name": (module, target_replace_module)
    }

    Note: this version does not write `outpath`; the `safe_save` call is commented
    out below and the packed weights and metadata are returned instead.
    """
    weights = {}
    metadata = {}

    for name, (model, target_replace_module) in modelmap.items():
        metadata[name] = json.dumps(list(target_replace_module))

        for i, (_up, _down) in enumerate(
            extract_lora_as_tensor(model, target_replace_module)
        ):
            rank = _down.shape[0]

            metadata[f"{name}:{i}:rank"] = str(rank)
            weights[f"{name}:{i}:up"] = _up
            weights[f"{name}:{i}:down"] = _down

    for token, tensor in embeds.items():
        metadata[token] = EMBED_FLAG
        weights[token] = tensor

    sorted_dict = {key: value for key, value in sorted(weights.items())}
    state = {}
    state["weights"] = sorted_dict
    state["metadata"] = metadata

    # print(sorted_dict.keys())
    # print('meta', metadata)
    # print(f"Saving weights to {outpath}")
    # safe_save(weights, outpath, metadata)
    return state


def perform_tuning(
    unet,
    vae,
    text_encoder,
    dataloader,
    num_steps,
    scheduler,
    optimizer,
    save_steps: int,
    placeholder_token_ids,
    placeholder_tokens,
    save_path,
    lr_scheduler_lora,
    lora_unet_target_modules,
    lora_clip_target_modules,
    mask_temperature,
    out_name: str,
    tokenizer,
    test_image_path: str,
    cached_latents: bool,
    log_wandb: bool = False,
    wandb_log_prompt_cnt: int = 10,
    class_token: str = "person",
    train_inpainting: bool = False,
):
    progress_bar = tqdm(range(num_steps))
    progress_bar.set_description("Steps")
    global_step = 0

    weight_dtype = torch.float16

    unet.train()
    text_encoder.train()

    if log_wandb:
        preped_clip = prepare_clip_model_sets()

    loss_sum = 0.0

    for epoch in range(math.ceil(num_steps / len(dataloader))):
        for batch in dataloader:
            lr_scheduler_lora.step()

            optimizer.zero_grad()

            loss = loss_step(
                batch,
                unet,
                vae,
                text_encoder,
                scheduler,
                train_inpainting=train_inpainting,
                t_mutliplier=0.8,
                mixed_precision=True,
                mask_temperature=mask_temperature,
                cached_latents=cached_latents,
            )
            loss_sum += loss.detach().item()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                itertools.chain(unet.parameters(), text_encoder.parameters()), 1.0
            )
            optimizer.step()
            progress_bar.update(1)
            logs = {
                "loss": loss.detach().item(),
                "lr": lr_scheduler_lora.get_last_lr()[0],
            }
            progress_bar.set_postfix(**logs)

            global_step += 1

            if global_step % save_steps == 0:
                save_all(
                    unet,
                    text_encoder,
                    placeholder_token_ids=placeholder_token_ids,
                    placeholder_tokens=placeholder_tokens,
                    save_path=os.path.join(
                        save_path, f"step_{global_step}.safetensors"
                    ),
                    target_replace_module_text=lora_clip_target_modules,
                    target_replace_module_unet=lora_unet_target_modules,
                )
                moved = (
                    torch.tensor(list(itertools.chain(*inspect_lora(unet).values())))
                    .mean()
                    .item()
                )

                print("LORA Unet Moved", moved)
                moved = (
                    torch.tensor(
                        list(itertools.chain(*inspect_lora(text_encoder).values()))
                    )
                    .mean()
                    .item()
                )

                print("LORA CLIP Moved", moved)

                if log_wandb:
                    with torch.no_grad():
                        pipe = StableDiffusionPipeline(
                            vae=vae,
                            text_encoder=text_encoder,
                            tokenizer=tokenizer,
                            unet=unet,
                            scheduler=scheduler,
                            safety_checker=None,
                            feature_extractor=None,
                        )

                        # open all images in test_image_path
                        images = []
                        for file in os.listdir(test_image_path):
                            if file.endswith(".png") or file.endswith(".jpg"):
                                images.append(
                                    Image.open(os.path.join(test_image_path, file))
                                )

                        wandb.log({"loss": loss_sum / save_steps})
                        loss_sum = 0.0
                        wandb.log(
                            evaluate_pipe(
                                pipe,
                                target_images=images,
                                class_token=class_token,
                                learnt_token="".join(placeholder_tokens),
                                n_test=wandb_log_prompt_cnt,
                                n_step=50,
                                clip_model_sets=preped_clip,
                            )
                        )

            if global_step >= num_steps:
                break

    return save_all(
        unet,
        text_encoder,
        placeholder_token_ids=placeholder_token_ids,
        placeholder_tokens=placeholder_tokens,
        save_path=os.path.join(save_path, f"{out_name}.safetensors"),
        target_replace_module_text=lora_clip_target_modules,
        target_replace_module_unet=lora_unet_target_modules,
    )

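
# perform_tuning() returns the dict built by save_safeloras_with_embeds() rather than
# writing a .safetensors file. If you do want it on disk, a minimal sketch (assuming
# every value in state["weights"] is a torch.Tensor and every metadata value is a str)
# would be:
#
#   from safetensors.torch import save_file
#   save_file(state["weights"], "lora.safetensors", metadata=state["metadata"])
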
def train(
    images,
    caption,
    pretrained_model_name_or_path: str,
    train_text_encoder: bool = True,
    pretrained_vae_name_or_path: str = None,
    revision: Optional[str] = None,
    perform_inversion: bool = True,
    use_template: Literal[None, "object", "style"] = None,
    train_inpainting: bool = False,
    placeholder_tokens: str = "",
    placeholder_token_at_data: Optional[str] = None,
    initializer_tokens: Optional[str] = None,
    seed: int = 42,
    resolution: int = 512,
    color_jitter: bool = True,
    train_batch_size: int = 1,
    sample_batch_size: int = 1,
    max_train_steps_tuning: int = 1000,
    max_train_steps_ti: int = 1000,
    save_steps: int = 100,
    gradient_accumulation_steps: int = 4,
    gradient_checkpointing: bool = False,
    lora_rank: int = 4,
    lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
    lora_clip_target_modules={"CLIPAttention"},
    lora_dropout_p: float = 0.0,
    lora_scale: float = 1.0,
    use_extended_lora: bool = False,
    clip_ti_decay: bool = True,
    learning_rate_unet: float = 1e-4,
    learning_rate_text: float = 1e-5,
    learning_rate_ti: float = 5e-4,
    continue_inversion: bool = False,
    continue_inversion_lr: Optional[float] = None,
    use_face_segmentation_condition: bool = False,
    cached_latents: bool = True,
    use_mask_captioned_data: bool = False,
    mask_temperature: float = 1.0,
    scale_lr: bool = False,
    lr_scheduler: str = "linear",
    lr_warmup_steps: int = 0,
    lr_scheduler_lora: str = "linear",
    lr_warmup_steps_lora: int = 0,
    weight_decay_ti: float = 0.00,
    weight_decay_lora: float = 0.001,
    use_8bit_adam: bool = False,
    device="cuda:0",
    extra_args: Optional[dict] = None,
    log_wandb: bool = False,
    wandb_log_prompt_cnt: int = 10,
    wandb_project_name: str = "new_pti_project",
    wandb_entity: str = "new_pti_entity",
    proxy_token: str = "person",
    enable_xformers_memory_efficient_attention: bool = False,
    out_name: str = "final_lora",
):
    torch.manual_seed(seed)

    # print(placeholder_tokens, initializer_tokens)
    if len(placeholder_tokens) == 0:
        placeholder_tokens = []
        print("PTI : Placeholder Tokens not given, using null token")
    else:
        placeholder_tokens = placeholder_tokens.split("|")

        assert (
            sorted(placeholder_tokens) == placeholder_tokens
        ), f"Placeholder tokens should be sorted. Use something like {'|'.join(sorted(placeholder_tokens))}"

    if initializer_tokens is None:
        print("PTI : Initializer Tokens not given, doing random inits")
        initializer_tokens = [""] * len(placeholder_tokens)
    else:
        initializer_tokens = initializer_tokens.split("|")

    assert len(initializer_tokens) == len(
        placeholder_tokens
    ), "Number of initializer tokens must match the number of placeholder tokens."

    if proxy_token is not None:
        class_token = proxy_token
    # NOTE: class_token is unconditionally overridden by the joined initializer tokens.
    class_token = "".join(initializer_tokens)

    if placeholder_token_at_data is not None:
        tok, pat = placeholder_token_at_data.split("|")
        token_map = {tok: pat}
    else:
        token_map = {"DUMMY": "".join(placeholder_tokens)}

    print("PTI : Placeholder Tokens", placeholder_tokens)
    print("PTI : Initializer Tokens", initializer_tokens)

    # get the models
    text_encoder, vae, unet, tokenizer, placeholder_token_ids = get_models(
        pretrained_model_name_or_path,
        pretrained_vae_name_or_path,
        revision,
        placeholder_tokens,
        initializer_tokens,
        device=device,
    )

    noise_scheduler = DDPMScheduler.from_config(
        pretrained_model_name_or_path, subfolder="scheduler"
    )

    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if enable_xformers_memory_efficient_attention:
        from diffusers.utils.import_utils import is_xformers_available

        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError(
                "xformers is not available. Make sure it is installed correctly"
            )

    if scale_lr:
        unet_lr = learning_rate_unet * gradient_accumulation_steps * train_batch_size
        text_encoder_lr = (
            learning_rate_text * gradient_accumulation_steps * train_batch_size
        )
        ti_lr = learning_rate_ti * gradient_accumulation_steps * train_batch_size
    else:
        unet_lr = learning_rate_unet
        text_encoder_lr = learning_rate_text
        ti_lr = learning_rate_ti

    train_dataset = PVD(
        images=images,
        caption=caption,
        token_map=token_map,
        use_template=use_template,
        tokenizer=tokenizer,
        size=resolution,
        color_jitter=color_jitter,
        use_face_segmentation_condition=use_face_segmentation_condition,
        use_mask_captioned_data=use_mask_captioned_data,
        train_inpainting=train_inpainting,
    )

    train_dataset.blur_amount = 200

    if train_inpainting:
        assert not cached_latents, "Cached latents not supported for inpainting"

        train_dataloader = inpainting_dataloader(
            train_dataset, train_batch_size, tokenizer, vae, text_encoder
        )
    else:
        print("PTI : Cached latents:", cached_latents)
        train_dataloader = text2img_dataloader(
            train_dataset,
            train_batch_size,
            tokenizer,
            vae,
            text_encoder,
            cached_latents=cached_latents,
        )

    # freeze everything except the newly added placeholder token embeddings
    index_no_updates = torch.arange(len(tokenizer)) != -1

    for tok_id in placeholder_token_ids:
        index_no_updates[tok_id] = False

    unet.requires_grad_(False)
    vae.requires_grad_(False)

    params_to_freeze = itertools.chain(
        text_encoder.text_model.encoder.parameters(),
        text_encoder.text_model.final_layer_norm.parameters(),
        text_encoder.text_model.embeddings.position_embedding.parameters(),
    )
    for param in params_to_freeze:
        param.requires_grad = False

    if cached_latents:
        vae = None

    # STEP 1 : Perform Inversion
    if perform_inversion:
        ti_optimizer = optim.AdamW(
            text_encoder.get_input_embeddings().parameters(),
            lr=ti_lr,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=weight_decay_ti,
        )

        lr_scheduler = get_scheduler(
            lr_scheduler,
            optimizer=ti_optimizer,
            num_warmup_steps=lr_warmup_steps,
            num_training_steps=max_train_steps_ti,
        )

        train_inversion(
            unet,
            vae,
            text_encoder,
            train_dataloader,
            max_train_steps_ti,
            cached_latents=cached_latents,
            accum_iter=gradient_accumulation_steps,
            scheduler=noise_scheduler,
            index_no_updates=index_no_updates,
            optimizer=ti_optimizer,
            lr_scheduler=lr_scheduler,
            save_steps=save_steps,
            placeholder_tokens=placeholder_tokens,
            placeholder_token_ids=placeholder_token_ids,
            save_path="./tmps",
            test_image_path="./tmps",
            log_wandb=log_wandb,
            wandb_log_prompt_cnt=wandb_log_prompt_cnt,
            class_token=class_token,
            train_inpainting=train_inpainting,
            mixed_precision=False,
            tokenizer=tokenizer,
            clip_ti_decay=clip_ti_decay,
        )

        del ti_optimizer
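
    # After inversion, the learned placeholder embeddings live in the text encoder's
    # input embedding table; the LoRA stage below can keep refining them when
    # continue_inversion is enabled.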

    # Next perform Tuning with LoRA:
    if not use_extended_lora:
        unet_lora_params, _ = inject_trainable_lora(
            unet,
            r=lora_rank,
            target_replace_module=lora_unet_target_modules,
            dropout_p=lora_dropout_p,
            scale=lora_scale,
        )
    else:
        print("PTI : USING EXTENDED UNET!!!")
        lora_unet_target_modules = (
            lora_unet_target_modules | UNET_EXTENDED_TARGET_REPLACE
        )
        print("PTI : Will replace modules: ", lora_unet_target_modules)

        unet_lora_params, _ = inject_trainable_lora_extended(
            unet, r=lora_rank, target_replace_module=lora_unet_target_modules
        )
    print(f"PTI : has {len(unet_lora_params)} lora")

    print("PTI : Before training:")
    inspect_lora(unet)

    params_to_optimize = [
        {"params": itertools.chain(*unet_lora_params), "lr": unet_lr},
    ]

    text_encoder.requires_grad_(False)

    if continue_inversion:
        params_to_optimize += [
            {
                "params": text_encoder.get_input_embeddings().parameters(),
                "lr": continue_inversion_lr
                if continue_inversion_lr is not None
                else ti_lr,
            }
        ]
        text_encoder.requires_grad_(True)
        params_to_freeze = itertools.chain(
            text_encoder.text_model.encoder.parameters(),
            text_encoder.text_model.final_layer_norm.parameters(),
            text_encoder.text_model.embeddings.position_embedding.parameters(),
        )
        for param in params_to_freeze:
            param.requires_grad = False
    else:
        text_encoder.requires_grad_(False)

    if train_text_encoder:
        text_encoder_lora_params, _ = inject_trainable_lora(
            text_encoder,
            target_replace_module=lora_clip_target_modules,
            r=lora_rank,
        )
        params_to_optimize += [
            {
                "params": itertools.chain(*text_encoder_lora_params),
                "lr": text_encoder_lr,
            }
        ]
        inspect_lora(text_encoder)

    lora_optimizers = optim.AdamW(params_to_optimize, weight_decay=weight_decay_lora)

    unet.train()
    if train_text_encoder:
        text_encoder.train()

    train_dataset.blur_amount = 70

    lr_scheduler_lora = get_scheduler(
        lr_scheduler_lora,
        optimizer=lora_optimizers,
        num_warmup_steps=lr_warmup_steps_lora,
        num_training_steps=max_train_steps_tuning,
    )

    return perform_tuning(
        unet,
        vae,
        text_encoder,
        train_dataloader,
        max_train_steps_tuning,
        cached_latents=cached_latents,
        scheduler=noise_scheduler,
        optimizer=lora_optimizers,
        save_steps=save_steps,
        placeholder_tokens=placeholder_tokens,
        placeholder_token_ids=placeholder_token_ids,
        save_path="./tmps",
        lr_scheduler_lora=lr_scheduler_lora,
        lora_unet_target_modules=lora_unet_target_modules,
        lora_clip_target_modules=lora_clip_target_modules,
        mask_temperature=mask_temperature,
        tokenizer=tokenizer,
        out_name=out_name,
        test_image_path="./tmps",
        log_wandb=log_wandb,
        wandb_log_prompt_cnt=wandb_log_prompt_cnt,
        class_token=class_token,
        train_inpainting=train_inpainting,
    )

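
# The two helpers below generate the initial reference images that feed the tuning
# pipeline above: semantic_karlo() samples from the Karlo unCLIP pipeline and
# semantic_sd() from Stable Diffusion v1.5. Both cycle through view-prefixed prompts
# ("front view of ", "overhead view of ", ...) and can optionally remove the
# background with carvekit.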
def semantic_karlo(prompt, output_dir, num_initial_image, bg_preprocess=False):
    pipe = UnCLIPPipeline.from_pretrained(
        "kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16
    )
    pipe = pipe.to('cuda')
    view_prompt = ["front view of ", "overhead view of ", "side view of ", "back view of "]

    if bg_preprocess:
        # Please refer to the code at https://github.com/Ir1d/image-background-remove-tool.
        import cv2
        from carvekit.api.high import HiInterface

        interface = HiInterface(
            object_type="object",
            batch_size_seg=5,
            batch_size_matting=1,
            device='cuda' if torch.cuda.is_available() else 'cpu',
            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            fp16=False,
        )

    for i in range(num_initial_image):
        t = ", white background"
        # only the first (front) view gets the white-background suffix
        if i == 0:
            prompt_ = f"{view_prompt[i%4]}{prompt}{t}"
        else:
            prompt_ = f"{view_prompt[i%4]}{prompt}"

        image = pipe(prompt_).images[0]
        fn = f"instance{i}.png"
        os.makedirs(output_dir, exist_ok=True)

        if bg_preprocess:
            # motivated by NeuralLift-360 (removing bg), and Zero-1-to-3 (removing bg and object-centering)
            # NOTE: This option was added during the code organization process.
            # The results reported in the paper were obtained with [bg_preprocess: False] setting.
            img_without_background = interface([image])
            mask = np.array(img_without_background[0]) > 127
            image = np.array(image)
            image[~mask] = [255., 255., 255.]
            # x, y, w, h = cv2.boundingRect(mask.astype(np.uint8))
            # image = image[y:y+h, x:x+w, :]
            image = Image.fromarray(np.array(image))
        image.save(os.path.join(output_dir, fn))


def semantic_sd(prompt, output_dir, num_initial_image, bg_preprocess=False):
    pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe = pipe.to('cuda')
    view_prompt = ["front view of ", "overhead view of ", "side view of ", "back view of "]

    if bg_preprocess:
        # Please refer to the code at https://github.com/Ir1d/image-background-remove-tool.
        import cv2
        from carvekit.api.high import HiInterface

        interface = HiInterface(
            object_type="object",
            batch_size_seg=5,
            batch_size_matting=1,
            device='cuda' if torch.cuda.is_available() else 'cpu',
            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            fp16=False,
        )

    for i in range(num_initial_image):
        t = ", white background"
        # only the first (front) view gets the white-background suffix
        if i == 0:
            prompt_ = f"{view_prompt[i%4]}{prompt}{t}"
        else:
            prompt_ = f"{view_prompt[i%4]}{prompt}"

        image = pipe(prompt_).images[0]
        fn = f"instance{i}.png"
        os.makedirs(output_dir, exist_ok=True)

        if bg_preprocess:
            # motivated by NeuralLift-360 (removing bg), and Zero-1-to-3 (removing bg and object-centering)
            # NOTE: This option was added during the code organization process.
            # The results reported in the paper were obtained with [bg_preprocess: False] setting.
            img_without_background = interface([image])
            mask = np.array(img_without_background[0]) > 127
            image = np.array(image)
            image[~mask] = [255., 255., 255.]
            # x, y, w, h = cv2.boundingRect(mask.astype(np.uint8))
            # image = image[y:y+h, x:x+w, :]
            image = Image.fromarray(np.array(image))
        image.save(os.path.join(output_dir, fn))

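
# Example usage (hypothetical prompt and output directory; carvekit is only needed
# when bg_preprocess=True):
#
#   semantic_karlo("a corgi", "./initial_images", num_initial_image=4, bg_preprocess=True)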
def semantic_coding(images, cfgs, sd, initial):
    ti_step = cfgs.pop('ti_step')
    pt_step = cfgs.pop('pt_step')
    # semantic_model = cfgs.pop('semantic_model')
    prompt = cfgs['sd']['prompt']
    # instance_dir = os.path.join(exp_dir, 'initial_image')
    # weight_dir = os.path.join(exp_dir, 'lora')

    if initial == "":
        initial = None

    state = train(
        images=images,
        caption=initial,
        pretrained_model_name_or_path='runwayml/stable-diffusion-v1-5',
        gradient_checkpointing=True,
        scale_lr=True,
        lora_rank=1,
        cached_latents=False,
        save_steps=max(ti_step, pt_step) + 1,
        max_train_steps_ti=ti_step,
        max_train_steps_tuning=pt_step,
        use_template="object",
        lr_warmup_steps=0,
        lr_warmup_steps_lora=100,
        placeholder_tokens="<0>",
        initializer_tokens=initial,
        continue_inversion=True,
        continue_inversion_lr=1e-4,
        device="cuda:0",
    )
    if initial is not None:
        sd.prompt = prompt.replace(initial, '<0>')
    else:
        sd.prompt = "a <0>"

    return state
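
# End-to-end sketch of how semantic_coding() might be driven (all names here are
# hypothetical; `sd_model` just needs a writable .prompt attribute and `cfgs` the keys
# used above):
#
#   cfgs = {"ti_step": 500, "pt_step": 500, "sd": {"prompt": "a corgi"}}
#   images = [Image.open(p) for p in sorted(glob.glob("./initial_images/*.png"))]
#   state = semantic_coding(images, cfgs, sd_model, initial="corgi")
#   # `state` holds the packed LoRA weights and the learned "<0>" embedding from train()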