"""Generates an image for a source prompt, then a set of object-level variations of it
via prompt-mixing, with optional attention-based shape localization, real-image
inversion, and background preservation."""

import json
import os
from dataclasses import dataclass, field
from typing import List

import pyrallis
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.utils import save_image
from tqdm import tqdm

from src.diffusion_model_wrapper import (DiffusionModelWrapper, get_stable_diffusion_model,
                                         get_stable_diffusion_config, generate_original_image)
from src.null_text_inversion import invert_image
from src.prompt_mixing import PromptMixing
from src.prompt_to_prompt_controllers import AttentionStore, AttentionReplace
from src.prompt_utils import get_proxy_prompts


def save_args_dict(args, similar_words):
    # Save the run configuration (plus the chosen proxy words) to the experiment directory.
    exp_path = os.path.join(args.exp_dir, args.prompt.replace(' ', '-'), f"seed={args.seed}_{args.exp_name}")
    os.makedirs(exp_path, exist_ok=True)

    args_dict = vars(args)
    args_dict['similar_words'] = similar_words
    with open(os.path.join(exp_path, "opt.json"), 'w') as fp:
        json.dump(args_dict, fp, sort_keys=True, indent=4)

    return exp_path


def setup(args):
    ldm_stable = get_stable_diffusion_model(args)
    ldm_stable_config = get_stable_diffusion_config(args)
    return ldm_stable, ldm_stable_config


def main(ldm_stable, ldm_stable_config, args):
    similar_words, prompts, another_prompts = get_proxy_prompts(args, ldm_stable)
    exp_path = save_args_dict(args, similar_words)
    images = []

    # For a real input image, reload the model and run null-text inversion to obtain
    # the initial latent and the per-step unconditional embeddings.
    x_t = None
    uncond_embeddings = None
    if args.real_image_path != "":
        ldm_stable, ldm_stable_config = setup(args)
        x_t, uncond_embeddings = invert_image(args, ldm_stable, ldm_stable_config, prompts, exp_path)

    # Generate the original image once, keeping its latents, object mask, and averaged
    # attention maps for reuse when generating the variations.
    image, x_t, orig_all_latents, orig_mask, average_attention = generate_original_image(
        args, ldm_stable, ldm_stable_config, prompts, x_t, uncond_embeddings)
    save_image(ToTensor()(image[0]), f"{exp_path}/{similar_words[0]}.jpg")
    save_image(torch.from_numpy(orig_mask).float(), f"{exp_path}/{similar_words[0]}_mask.jpg")
    images.append(image[0])

    # +1 accounts for the start-of-text token prepended during tokenization.
    object_of_interest_index = args.prompt.split().index('{word}') + 1
    pm = PromptMixing(args, object_of_interest_index, average_attention)

    do_other_obj_self_attn_masking = len(args.objects_to_preserve) > 0 and args.end_preserved_obj_self_attn_masking > 0
    do_self_or_cross_attn_inject = args.cross_attn_inject_steps != 0.0 or args.self_attn_inject_steps != 0.0
    if do_other_obj_self_attn_masking:
        print("Applying self-attention masking for the preserved objects")
    if do_self_or_cross_attn_inject:
        print(f'Injecting self-attention for {args.self_attn_inject_steps} steps')
        print(f'Injecting cross-attention for {args.cross_attn_inject_steps} steps')

    # Skip the first proxy prompt: it corresponds to the original image generated above.
    another_prompts_dataloader = DataLoader(another_prompts[1:], batch_size=args.batch_size, shuffle=False)
    for another_prompt_batch in tqdm(another_prompts_dataloader):
        batch_size = len(another_prompt_batch["word"])
        batch_prompts = prompts * batch_size
        batch_another_prompt = another_prompt_batch["prompt"]
        # Add the original prompt to the batch so its attention is available for
        # injection and for preserved-object masking.
        if do_self_or_cross_attn_inject or do_other_obj_self_attn_masking:
            batch_prompts.append(prompts[0])
            batch_another_prompt.insert(0, prompts[0])

        if do_self_or_cross_attn_inject:
            controller = AttentionReplace(batch_another_prompt, ldm_stable.tokenizer, ldm_stable.device,
                                          ldm_stable_config["low_resource"],
                                          ldm_stable_config["num_diffusion_steps"],
                                          cross_replace_steps=args.cross_attn_inject_steps,
                                          self_replace_steps=args.self_attn_inject_steps)
        else:
            controller = AttentionStore(ldm_stable_config["low_resource"])

        diffusion_model_wrapper = DiffusionModelWrapper(args, ldm_stable, ldm_stable_config, controller,
                                                        prompt_mixing=pm)
        with torch.no_grad():
            image, x_t, _, mask = diffusion_model_wrapper.forward(batch_prompts, latent=x_t,
                                                                  other_prompt=batch_another_prompt,
                                                                  post_background=args.background_post_process,
                                                                  orig_all_latents=orig_all_latents,
                                                                  orig_mask=orig_mask,
                                                                  uncond_embeddings=uncond_embeddings)

        for i in range(batch_size):
            # Output index 0 holds the original-prompt image when it was added to the batch.
            image_index = i + 1 if do_self_or_cross_attn_inject or do_other_obj_self_attn_masking else i
            save_image(ToTensor()(image[image_index]), f"{exp_path}/{another_prompt_batch['word'][i]}.jpg")
            if mask is not None:
                save_image(torch.from_numpy(mask).float(), f"{exp_path}/{another_prompt_batch['word'][i]}_mask.jpg")
            images.append(image[image_index])

    # Save all images in one grid; use the largest divisor of the image count in [2, 7]
    # as the row length, falling back to 8 when none divides evenly (the original max()
    # call raised ValueError in that case).
    images = [ToTensor()(image) for image in images]
    save_image(images, f"{exp_path}/grid.jpg",
               nrow=min(max([i for i in range(2, 8) if len(images) % i == 0], default=8), 8))
    return images, similar_words


@dataclass
class LPMConfig:
    # general config
    seed: int = 10
    batch_size: int = 1
    exp_dir: str = "results"
    exp_name: str = ""
    display_images: bool = False
    gpu_id: int = 0

    # Stable Diffusion config
    auth_token: str = ""
    low_resource: bool = True
    num_diffusion_steps: int = 50
    guidance_scale: float = 7.5
    max_num_words: int = 77

    # prompt-mixing
    prompt: str = "a {word} in the field eats an apple"
    object_of_interest: str = "snake"  # the object for which variations are generated
    proxy_words: List[str] = field(default_factory=list)  # leave empty for automatic proxy words
    number_of_variations: int = 20
    start_prompt_range: int = 7   # step at which prompt-mixing begins
    end_prompt_range: int = 17    # step at which prompt-mixing ends

    # attention-based shape localization
    objects_to_preserve: List[str] = field(default_factory=list)  # objects to which shape localization is applied
    remove_obj_from_self_mask: bool = True  # if True, remove the object of interest from the self-attention mask
    obj_pixels_injection_threshold: float = 0.05
    end_preserved_obj_self_attn_masking: int = 40

    # real image
    real_image_path: str = ""

    # controllable background preservation
    background_post_process: bool = True
    background_nouns: List[str] = field(default_factory=list)  # objects to take from the original image in addition to the background
    num_segments: int = 5  # number of clusters for the segmentation
    background_segment_threshold: float = 0.3  # threshold for labeling the segments
    background_blend_timestep: int = 35  # number of steps before background blending

    # other
    cross_attn_inject_steps: float = 0.0
    self_attn_inject_steps: float = 0.0


if __name__ == '__main__':
    args = pyrallis.parse(config_class=LPMConfig)
    print(args)

    stable, stable_config = setup(args)
    main(stable, stable_config, args)
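# Example invocation (a sketch, not taken from the repo's own docs; it assumes this
# script is saved as main.py, and relies on pyrallis exposing each LPMConfig field
# as a CLI flag of the same name):
#
#   python main.py --prompt "a {word} in the field eats an apple" \
#                  --object_of_interest "snake" \
#                  --number_of_variations 20 \
#                  --seed 10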