import os
import argparse

import numpy as np
import spaces
import torch
from peft import LoraConfig

from pipeline_dedit_sdxl import DEditSDXLPipeline
from pipeline_dedit_sd import DEditSDPipeline
from utils import load_image, load_mask, load_mask_edit
from utils_mask import (
    process_mask_move_torch,
    process_mask_remove_torch,
    mask_union_torch,
    mask_substract_torch,
    create_outer_edge_mask_torch,
    check_mask_overlap_torch,
    check_cover_all_torch,
    visualize_mask_list,
    get_mask_difference_torch,
    save_mask_list_to_npys,
)
def run_main(
    name="example_tmp",
    name_2=None,
    mask_np_list=None,
    mask_label_list=None,
    image_gt=None,
    dpm="sd",
    resolution=512,
    seed=42,
    embedding_learning_rate=1e-4,
    max_emb_train_steps=200,
    diffusion_model_learning_rate=5e-5,
    max_diffusion_train_steps=200,
    train_batch_size=1,
    gradient_accumulation_steps=1,
    num_tokens=1,
    load_trained=False,
    num_sampling_steps=50,
    guidance_scale=3,
    strength=0.8,
    train_full_lora=False,
    lora_rank=4,
    lora_alpha=4,
    prompt_auxin_list=None,
    prompt_auxin_idx_list=None,
    load_edited_mask=False,
    load_edited_processed_mask=False,
    edge_thickness=20,
    num_imgs=1,
    active_mask_list=None,
    tgt_index=None,
    recon=False,
    recon_an_item=False,
    recon_prompt=None,
    text=False,
    tgt_prompt=None,
    image=False,
    src_index=None,
    tgt_name=None,
    move_resize=False,
    tgt_indices_list=None,
    delta_x_list=None,
    delta_y_list=None,
    priority_list=None,
    force_mask_remain=None,
    resize_list=None,
    remove=False,
    load_edited_removemask=False,
):
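    """Fine-tune a D-Edit pipeline on the segmented input image(s), then run the requested edits.

    The boolean flags select the operation: recon (reconstruction), text
    (text-guided editing of one segment), image (grafting a segment from a
    second image), move_resize (moving/resizing segments), and remove
    (deleting a segment and filling the hole). Results are written to a
    subfolder of the output directory, one per operation.
    """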
    # Make training and sampling reproducible.
    torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)

    base_input_folder = "."
    base_output_folder = "."
    input_folder = os.path.join(base_input_folder, name)

    # Masks arrive as numpy arrays; convert them to torch tensors. (They could
    # alternatively be loaded from input_folder with load_mask / load_image,
    # as the commented-out code in earlier revisions of this file did.)
    mask_list = []
    for mask_np in mask_np_list:
        mask_list.append(torch.from_numpy(mask_np.astype(np.uint8)))
    assert mask_list[0].shape[0] == resolution, "Segmentation should be done on size {}".format(resolution)
    if image:
        # Image-guided editing needs a second (source) image plus its masks.
        input_folder_2 = os.path.join(base_input_folder, name_2)
        mask_list_2, mask_label_list_2 = load_mask(input_folder_2)
        assert mask_list_2[0].shape[0] == resolution, "Segmentation should be done on size {}".format(resolution)
        try:
            image_gt_2 = load_image(os.path.join(input_folder_2, "img_{}.png".format(resolution)), size=resolution)
        except Exception:
            image_gt_2 = load_image(os.path.join(input_folder_2, "img_{}.jpg".format(resolution)), size=resolution)
        output_dir = os.path.join(base_output_folder, name + "_" + name_2)
    else:
        output_dir = os.path.join(base_output_folder, name)
    os.makedirs(output_dir, exist_ok=True)

    # Select the diffusion backbone.
    if dpm == "sd":
        if image:
            pipe = DEditSDPipeline(mask_list, mask_label_list, mask_list_2, mask_label_list_2, resolution=resolution, num_tokens=num_tokens)
        else:
            pipe = DEditSDPipeline(mask_list, mask_label_list, resolution=resolution, num_tokens=num_tokens)
    elif dpm == "sdxl":
        if image:
            pipe = DEditSDXLPipeline(mask_list, mask_label_list, mask_list_2, mask_label_list_2, resolution=resolution, num_tokens=num_tokens)
        else:
            pipe = DEditSDXLPipeline(mask_list, mask_label_list, resolution=resolution, num_tokens=num_tokens)
    else:
        raise NotImplementedError("dpm must be 'sd' or 'sdxl'")
    set_string_list = pipe.set_string_list
    if prompt_auxin_list is not None:
        # Splice auxiliary prompts around the learned tokens; "*" marks where
        # the segment's token string is substituted.
        for auxin_idx, auxin_prompt in zip(prompt_auxin_idx_list, prompt_auxin_list):
            set_string_list[auxin_idx] = auxin_prompt.replace("*", set_string_list[auxin_idx])
    print(set_string_list)
    if image:
        set_string_list_2 = pipe.set_string_list_2
        print(set_string_list_2)
    if load_trained:
        # Reload previously fine-tuned weights from output_dir.
        unet_save_path = os.path.join(output_dir, "unet.pt")
        unet_state_dict = torch.load(unet_save_path)
        text_encoder1_save_path = os.path.join(output_dir, "text_encoder1.pt")
        text_encoder1_state_dict = torch.load(text_encoder1_save_path)
        if dpm == "sdxl":
            text_encoder2_save_path = os.path.join(output_dir, "text_encoder2.pt")
            text_encoder2_state_dict = torch.load(text_encoder2_save_path)
        if "lora" in "".join(unet_state_dict.keys()):
            # The checkpoint contains LoRA weights: recreate the adapter on the
            # attention projections before loading the state dict.
            unet_lora_config = LoraConfig(
                r=lora_rank,
                lora_alpha=lora_alpha,
                init_lora_weights="gaussian",
                target_modules=["to_k", "to_q", "to_v", "to_out.0"],
            )
            pipe.unet.add_adapter(unet_lora_config)
        pipe.unet.load_state_dict(unet_state_dict)
        pipe.text_encoder.load_state_dict(text_encoder1_state_dict)
        if dpm == "sdxl":
            pipe.text_encoder_2.load_state_dict(text_encoder2_state_dict)
    else:
        # Two-stage fine-tuning: first learn the per-segment token embeddings,
        # then fine-tune the diffusion model weights.
        if image:
            pipe.mask_list = [m.cuda() for m in pipe.mask_list]
            pipe.mask_list_2 = [m.cuda() for m in pipe.mask_list_2]
            pipe.train_emb_2imgs(
                image_gt,
                image_gt_2,
                set_string_list,
                set_string_list_2,
                gradient_accumulation_steps=gradient_accumulation_steps,
                embedding_learning_rate=embedding_learning_rate,
                max_emb_train_steps=max_emb_train_steps,
                train_batch_size=train_batch_size,
            )
            pipe.train_model_2imgs(
                image_gt,
                image_gt_2,
                set_string_list,
                set_string_list_2,
                gradient_accumulation_steps=gradient_accumulation_steps,
                max_diffusion_train_steps=max_diffusion_train_steps,
                diffusion_model_learning_rate=diffusion_model_learning_rate,
                train_batch_size=train_batch_size,
                train_full_lora=train_full_lora,
                lora_rank=lora_rank,
                lora_alpha=lora_alpha,
            )
        else:
            pipe.mask_list = [m.cuda() for m in pipe.mask_list]
            pipe.train_emb(
                image_gt,
                set_string_list,
                gradient_accumulation_steps=gradient_accumulation_steps,
                embedding_learning_rate=embedding_learning_rate,
                max_emb_train_steps=max_emb_train_steps,
                train_batch_size=train_batch_size,
            )
            pipe.train_model(
                image_gt,
                set_string_list,
                gradient_accumulation_steps=gradient_accumulation_steps,
                max_diffusion_train_steps=max_diffusion_train_steps,
                diffusion_model_learning_rate=diffusion_model_learning_rate,
                train_batch_size=train_batch_size,
                train_full_lora=train_full_lora,
                lora_rank=lora_rank,
                lora_alpha=lora_alpha,
            )

        # Persist the fine-tuned weights so later runs can set load_trained=True.
        unet_save_path = os.path.join(output_dir, "unet.pt")
        torch.save(pipe.unet.state_dict(), unet_save_path)
        text_encoder1_save_path = os.path.join(output_dir, "text_encoder1.pt")
        torch.save(pipe.text_encoder.state_dict(), text_encoder1_save_path)
        if dpm == "sdxl":
            text_encoder2_save_path = os.path.join(output_dir, "text_encoder2.pt")
            torch.save(pipe.text_encoder_2.state_dict(), text_encoder2_save_path)
    if recon:
        output_dir = os.path.join(output_dir, "recon")
        os.makedirs(output_dir, exist_ok=True)
        if recon_an_item:
            # Reconstruct a single item: collapse to one all-ones mask and one
            # prompt built from the target segment's token string.
            mask_list = [torch.from_numpy(np.ones_like(mask_list[0].numpy()))]
            tgt_string = set_string_list[tgt_index]
            tgt_string = recon_prompt.replace("*", tgt_string)
            set_string_list = [tgt_string]
            print(set_string_list)
        save_path = os.path.join(output_dir, "out_recon.png")
        pipe.inference_with_mask(
            save_path,
            guidance_scale=guidance_scale,
            num_sampling_steps=num_sampling_steps,
            seed=seed,
            num_imgs=num_imgs,
            set_string_list=set_string_list,
            mask_list=mask_list,
        )
    if text:
        print("*** Text-guided editing")
        output_dir = os.path.join(output_dir, "text")
        os.makedirs(output_dir, exist_ok=True)
        save_path = os.path.join(output_dir, "out_text.png")
        # Swap in the user's prompt for the target segment; that segment (plus
        # any extra segments in active_mask_list) becomes the active region.
        set_string_list[tgt_index] = tgt_prompt
        mask_active = torch.zeros_like(mask_list[0])
        mask_active = mask_union_torch(mask_active, mask_list[tgt_index])
        if active_mask_list is not None:
            for midx in active_mask_list:
                mask_active = mask_union_torch(mask_active, mask_list[midx])
        if load_edited_mask:
            mask_list_edited, mask_label_list_edited = load_mask_edit(input_folder)
            mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
            mask_active = mask_union_torch(mask_active, mask_diff)
            mask_list = mask_list_edited
            save_path = os.path.join(output_dir, "out_textEdited.png")
        mask_hard = mask_substract_torch(torch.ones_like(mask_list[0]), mask_active)
        mask_soft = create_outer_edge_mask_torch(mask_active, edge_thickness=edge_thickness)
        mask_hard = mask_substract_torch(mask_hard, mask_soft)
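        # Region semantics, as the helper names suggest: mask_hard marks pixels
        # kept frozen from the original image, mask_soft is a thin band
        # (edge_thickness pixels) around the active region that is blended for
        # a smooth seam, and the remaining active area is re-generated under
        # the new prompts. The remove, image, and move_resize branches below
        # build their masks the same way.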
        pipe.inference_with_mask(
            save_path,
            orig_image=image_gt,
            set_string_list=set_string_list,
            guidance_scale=guidance_scale,
            strength=strength,
            num_imgs=num_imgs,
            mask_hard=mask_hard,
            mask_soft=mask_soft,
            mask_list=mask_list,
            seed=seed,
            num_sampling_steps=num_sampling_steps,
        )
    if remove:
        output_dir = os.path.join(output_dir, "remove")
        os.makedirs(output_dir, exist_ok=True)
        save_path = os.path.join(output_dir, "out_remove.png")
        mask_active = torch.zeros_like(mask_list[0])
        if load_edited_mask:
            mask_list_edited, _ = load_mask_edit(input_folder)
            mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
            mask_active = mask_union_torch(mask_active, mask_diff)
            mask_list = mask_list_edited
        if load_edited_processed_mask:
            # Load masks that were manually edited or drawn after removing one index.
            mask_list_processed, _ = load_mask_edit(output_dir)
            mask_remain = get_mask_difference_torch(mask_list_processed, mask_list)
        else:
            # Generate masks after removing one index, reassigning the freed
            # pixels with a nearest-neighbor rule.
            mask_list_processed, mask_remain = process_mask_remove_torch(mask_list, tgt_index)
            save_mask_list_to_npys(output_dir, mask_list_processed, mask_label_list, name="mask")
            visualize_mask_list(mask_list_processed, os.path.join(output_dir, "seg_removed.png"))
        check_cover_all_torch(*mask_list_processed)
        mask_active = mask_union_torch(mask_active, mask_remain)
        if active_mask_list is not None:
            for midx in active_mask_list:
                mask_active = mask_union_torch(mask_active, mask_list[midx])
        mask_hard = 1 - mask_active
        mask_soft = create_outer_edge_mask_torch(mask_remain, edge_thickness=edge_thickness)
        mask_hard = mask_substract_torch(mask_hard, mask_soft)
        pipe.inference_with_mask(
            save_path,
            orig_image=image_gt,
            guidance_scale=guidance_scale,
            strength=strength,
            num_imgs=num_imgs,
            mask_hard=mask_hard,
            mask_soft=mask_soft,
            mask_list=mask_list_processed,
            seed=seed,
            num_sampling_steps=num_sampling_steps,
        )
    if image:
        output_dir = os.path.join(output_dir, "image")
        os.makedirs(output_dir, exist_ok=True)
        save_path = os.path.join(output_dir, "out_image.png")
        mask_active = torch.zeros_like(mask_list[0])
        if None not in (tgt_name, src_index, tgt_index):
            # Decide which image is the target (edited) and which is the source.
            if tgt_name == name:
                set_string_list_tgt = set_string_list
                set_string_list_src = set_string_list_2
                image_tgt = image_gt
                if load_edited_mask:
                    mask_list_edited, _ = load_mask_edit(input_folder)
                    mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
                    mask_active = mask_union_torch(mask_active, mask_diff)
                    mask_list = mask_list_edited
                    save_path = os.path.join(output_dir, "out_imageEdited.png")
                mask_list_tgt = mask_list
            elif tgt_name == name_2:
                set_string_list_tgt = set_string_list_2
                set_string_list_src = set_string_list
                image_tgt = image_gt_2
                if load_edited_mask:
                    mask_list_2_edited, _ = load_mask_edit(input_folder_2)
                    mask_diff = get_mask_difference_torch(mask_list_2_edited, mask_list_2)
                    mask_active = mask_union_torch(mask_active, mask_diff)
                    mask_list_2 = mask_list_2_edited
                    save_path = os.path.join(output_dir, "out_imageEdited.png")
                mask_list_tgt = mask_list_2
            else:
                raise ValueError("tgt_name should be either name or name_2")
            # Graft the source segment's token string into the target prompt list.
            set_string_list_tgt[tgt_index] = set_string_list_src[src_index]
            mask_active = mask_list_tgt[tgt_index]
            mask_frozen = (1 - mask_active.float()).to(mask_active.device)
            mask_soft = create_outer_edge_mask_torch(mask_active.cpu(), edge_thickness=edge_thickness)
            mask_hard = mask_substract_torch(mask_frozen.cpu(), mask_soft.cpu())
            mask_list_tgt = [m.cuda() for m in mask_list_tgt]
            pipe.inference_with_mask(
                save_path,
                set_string_list=set_string_list_tgt,
                mask_list=mask_list_tgt,
                guidance_scale=guidance_scale,
                num_sampling_steps=num_sampling_steps,
                mask_hard=mask_hard.cuda(),
                mask_soft=mask_soft.cuda(),
                num_imgs=num_imgs,
                orig_image=image_tgt,
                strength=strength,
            )
    if move_resize:
        output_dir = os.path.join(output_dir, "move_resize")
        os.makedirs(output_dir, exist_ok=True)
        save_path = os.path.join(output_dir, "out_moveresize.png")
        mask_active = torch.zeros_like(mask_list[0])
        if load_edited_mask:
            mask_list_edited, _ = load_mask_edit(input_folder)
            mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
            mask_active = mask_union_torch(mask_active, mask_diff)
            mask_list = mask_list_edited
            # save_path = os.path.join(output_dir, "out_moveresizeEdited.png")
        if load_edited_processed_mask:
            mask_list_processed, _ = load_mask_edit(output_dir)
            mask_remain = get_mask_difference_torch(mask_list_processed, mask_list)
        else:
            # Shift/resize the selected masks, resolving overlaps via
            # priority_list; mask_remain marks pixels the moved segments vacated.
            mask_list_processed, mask_remain = process_mask_move_torch(
                mask_list,
                tgt_indices_list,
                delta_x_list,
                delta_y_list,
                priority_list,
                force_mask_remain=force_mask_remain,
                resize_list=resize_list,
            )
            save_mask_list_to_npys(output_dir, mask_list_processed, mask_label_list, name="mask")
            visualize_mask_list(mask_list_processed, os.path.join(output_dir, "seg_move_resize.png"))
        active_idxs = tgt_indices_list
        mask_active = mask_union_torch(mask_active, *[m for midx, m in enumerate(mask_list_processed) if midx in active_idxs])
        mask_active = mask_union_torch(mask_remain, mask_active)
        if active_mask_list is not None:
            for midx in active_mask_list:
                mask_active = mask_union_torch(mask_active, mask_list_processed[midx])
        mask_frozen = 1 - mask_active.float()
        mask_soft = create_outer_edge_mask_torch(mask_active, edge_thickness=edge_thickness)
        mask_hard = mask_substract_torch(mask_frozen, mask_soft)
        check_mask_overlap_torch(mask_hard, mask_soft)
        pipe.inference_with_mask(
            save_path,
            strength=strength,
            orig_image=image_gt,
            guidance_scale=guidance_scale,
            num_sampling_steps=num_sampling_steps,
            num_imgs=num_imgs,
            mask_hard=mask_hard,
            mask_soft=mask_soft,
            mask_list=mask_list_processed,
            seed=seed,
        )
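
# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption, not the Space's actual entry point: the
# app presumably builds mask_np_list and image_gt from its UI and calls
# run_main itself). The folder name, mask files, and labels below are
# hypothetical; the masks are assumed to be the .npy files produced by
# save_mask_list_to_npys.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import glob

    example_dir = "example_tmp"  # hypothetical input folder
    mask_np_list = [
        np.load(p) for p in sorted(glob.glob(os.path.join(example_dir, "mask*.npy")))
    ]
    mask_label_list = ["background", "object"]  # hypothetical, one label per mask
    image_gt = load_image(os.path.join(example_dir, "img_512.png"), size=512)

    # Fine-tune embeddings and the diffusion model on the single image, then
    # reconstruct it as a sanity check.
    run_main(
        name=example_dir,
        mask_np_list=mask_np_list,
        mask_label_list=mask_label_list,
        image_gt=image_gt,
        dpm="sd",
        resolution=512,
        recon=True,
    )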