import os import json import re from turtle import width import torch import random import numpy as np import gradio as gr from glob import glob from omegaconf import OmegaConf from datetime import datetime from safetensors import safe_open from diffusers import AutoencoderKL,UNet2DConditionModel,StableDiffusionPipeline from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler from diffusers.utils.import_utils import is_xformers_available from transformers import CLIPTextModel, CLIPTokenizer from utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint from utils.convert_lora_safetensor_to_diffusers import convert_lora import torch.nn.functional as F from PIL import Image from utils.diffuser_utils import MasaCtrlPipeline from utils.masactrl_utils import (AttentionBase, regiter_attention_editor_diffusers) from utils.free_lunch_utils import register_upblock2d,register_crossattn_upblock2d,register_free_upblock2d, register_free_crossattn_upblock2d from utils.style_attn_control import MaskPromptedStyleAttentionControl from torchvision.utils import save_image from diffusers.models.attention_processor import AttnProcessor2_0 css = """ .toolbutton { margin-buttom: 0em 0em 0em 0em; max-width: 2.5em; min-width: 2.5em !important; height: 2.5em; } """ class GlobalText: def __init__(self): # config dirs self.basedir = os.getcwd() self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion") self.personalized_model_dir = './models/Stable-diffusion' self.lora_model_dir = './models/Lora' self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S")) self.savedir_sample = os.path.join(self.savedir, "sample") self.savedir_mask = os.path.join(self.savedir, "mask") self.stable_diffusion_list = ["runwayml/stable-diffusion-v1-5", "stabilityai/stable-diffusion-2-1"] self.personalized_model_list = [] self.lora_model_list = [] # config models self.tokenizer = None self.text_encoder = None self.vae = None self.unet = None self.pipeline = None self.lora_loaded = None self.personal_model_loaded = None self.lora_model_state_dict = {} self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") self.refresh_stable_diffusion() self.refresh_personalized_model() self.reset_start_code() def load_base_pipeline(self, model_path): print(f'loading {model_path} model') scheduler = DDIMScheduler.from_pretrained(model_path,subfolder="scheduler") self.pipeline = MasaCtrlPipeline.from_pretrained(model_path, scheduler=scheduler).to(self.device) def refresh_stable_diffusion(self): self.load_base_pipeline(self.stable_diffusion_list[0]) self.lora_loaded = None self.personal_model_loaded = None return self.stable_diffusion_list[0] def refresh_personalized_model(self): personalized_model_list = glob(os.path.join(self.personalized_model_dir, "**/*.safetensors"), recursive=True) self.personalized_model_list = {os.path.basename(file): file for file in personalized_model_list} lora_model_list = glob(os.path.join(self.lora_model_dir, "**/*.safetensors"), recursive=True) self.lora_model_list = {os.path.basename(file): file for file in lora_model_list} def update_stable_diffusion(self, stable_diffusion_dropdown): self.load_base_pipeline(stable_diffusion_dropdown) self.lora_loaded = None self.personal_model_loaded = None return gr.Dropdown.update() def update_base_model(self, base_model_dropdown): if self.pipeline is None: gr.Info(f"Please select a pretrained model path.") return None else: base_model = self.personalized_model_list[base_model_dropdown] mid_model = StableDiffusionPipeline.from_single_file(base_model) self.pipeline.vae = mid_model.vae self.pipeline.unet = mid_model.unet self.pipeline.text_encoder = mid_model.text_encoder self.pipeline.to(self.device) self.personal_model_loaded = base_model_dropdown.split('.')[0] print(f'load {base_model_dropdown} model success!') return gr.Dropdown() def update_lora_model(self, lora_model_dropdown,lora_alpha_slider): if self.pipeline is None: gr.Info(f"Please select a pretrained model path.") return None else: if lora_model_dropdown == "none": self.pipeline.unfuse_lora() self.pipeline.unload_lora_weights() self.lora_loaded = None # self.personal_model_loaded = None print("Restore lora.") else: lora_model_path = self.lora_model_list[lora_model_dropdown]#os.path.join(self.lora_model_dir, lora_model_dropdown) # self.lora_model_state_dict = {} # if lora_model_dropdown == "none": pass # else: # with safe_open(lora_model_dropdown, framework="pt", device="cpu") as f: # for key in f.keys(): # self.lora_model_state_dict[key] = f.get_tensor(key) # convert_lora(self.pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider) self.pipeline.unfuse_lora() self.pipeline.unload_lora_weights() self.pipeline.load_lora_weights(lora_model_path) self.pipeline.fuse_lora(lora_alpha_slider) self.lora_loaded = lora_model_dropdown.split('.')[0] print(f'load {lora_model_dropdown} model success!') return gr.Dropdown() def generate(self, source, style, source_mask, style_mask, start_step, start_layer, Style_attn_step, Method, Style_Guidance, ddim_steps, scale, seed, de_bug, target_prompt, negative_prompt_textbox, inter_latents, freeu, b1, b2, s1, s2, width_slider,height_slider, ): os.makedirs(self.savedir, exist_ok=True) os.makedirs(self.savedir_sample, exist_ok=True) os.makedirs(self.savedir_mask, exist_ok=True) model = self.pipeline if seed != -1 and seed != "": torch.manual_seed(int(seed)) else: torch.seed() seed = torch.initial_seed() sample_count = len(os.listdir(self.savedir_sample)) os.makedirs(os.path.join(self.savedir_mask, f"results_{sample_count}"), exist_ok=True) # ref_prompt = [source_prompt, target_prompt] # prompts = ref_prompt+[''] ref_prompt = [target_prompt, target_prompt] prompts = ref_prompt+[target_prompt] source_image,style_image,source_mask,style_mask = load_mask_images(source,style,source_mask,style_mask,self.device,width_slider,height_slider,out_dir=os.path.join(self.savedir_mask, f"results_{sample_count}")) # global START_CODE, LATENTS_LIST with torch.no_grad(): #import pdb;pdb.set_trace() #prev_source if self.start_code is None and self.latents_list is None: content_style = torch.cat([style_image, source_image], dim=0) editor = AttentionBase() regiter_attention_editor_diffusers(model, editor) st_code, latents_list = model.invert(content_style, ref_prompt, guidance_scale=scale, num_inference_steps=ddim_steps, return_intermediates=True) start_code = torch.cat([st_code, st_code[1:]], dim=0) self.start_code = start_code self.latents_list = latents_list else: start_code = self.start_code latents_list = self.latents_list print('------------------------------------------ Use previous latents ------------------------------------------ ') #["Without mask", "Only masked region", "Seperate Background Foreground"] if Method == "Without mask": style_mask = None source_mask = None only_masked_region = False elif Method == "Only masked region": assert style_mask is not None and source_mask is not None only_masked_region = True else: assert style_mask is not None and source_mask is not None only_masked_region = False controller = MaskPromptedStyleAttentionControl(start_step, start_layer, style_attn_step=Style_attn_step, style_guidance=Style_Guidance, style_mask=style_mask, source_mask=source_mask, only_masked_region=only_masked_region, guidance=scale, de_bug=de_bug, ) if freeu: print(f'++++++++++++++++++ Run with FreeU {b1}_{b2}_{s1}_{s2} ++++++++++++++++') if Method != "Without mask": register_free_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s1,source_mask=source_mask) register_free_crossattn_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s1,source_mask=source_mask) else: register_free_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s1,source_mask=None) register_free_crossattn_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s1,source_mask=None) else: print(f'++++++++++++++++++ Run without FreeU ++++++++++++++++') register_upblock2d(model) register_crossattn_upblock2d(model) regiter_attention_editor_diffusers(model, controller) regiter_attention_editor_diffusers(model, controller) # inference the synthesized image generate_image= model(prompts, width=width_slider, height=height_slider, latents=start_code, guidance_scale=scale, num_inference_steps=ddim_steps, ref_intermediate_latents=latents_list if inter_latents else None, neg_prompt=negative_prompt_textbox, return_intermediates=False,) # os.makedirs(os.path.join(output_dir, f"results_{sample_count}")) save_file_name = f"results_{sample_count}_step{start_step}_layer{start_layer}SG{Style_Guidance}_style_attn_step{Style_attn_step}.jpg" if self.lora_loaded != None: save_file_name = f"lora_{self.lora_loaded}_" + save_file_name if self.personal_model_loaded != None: save_file_name = f"personal_{self.personal_model_loaded}_" + save_file_name #f"results_{sample_count}_step{start_step}_layer{start_layer}SG{Style_Guidance}_style_attn_step{Style_attn_step}_lora_{self.lora_loaded}.jpg" save_file_path = os.path.join(self.savedir_sample, save_file_name) #save_file_name = os.path.join(output_dir, f"results_style_{style_name}", f"{content_name}.jpg") save_image(torch.cat([source_image/2 + 0.5, style_image/2 + 0.5, generate_image[2:]], dim=0), save_file_path, nrow=3, padding=0) # global OUTPUT_RESULT # OUTPUT_RESULT = save_file_name generate_image = generate_image.cpu().permute(0, 2, 3, 1).numpy() #save_gif(latents_list, os.path.join(output_dir, f"results_{sample_count}",'output_latents_list.gif')) # import pdb;pdb.set_trace() #gif_dir = os.path.join(output_dir, f"results_{sample_count}",'output_latents_list.gif') return [ generate_image[0], generate_image[1], generate_image[2], ] def reset_start_code(self,): self.start_code = None self.latents_list = None global_text = GlobalText() def load_mask_images(source,style,source_mask,style_mask,device,width,height,out_dir=None): # invert the image into noise map if isinstance(source['image'], np.ndarray): source_image = torch.from_numpy(source['image']).to(device) / 127.5 - 1. else: source_image = torch.from_numpy(np.array(source['image'])).to(device) / 127.5 - 1. source_image = source_image.unsqueeze(0).permute(0, 3, 1, 2) source_image = F.interpolate(source_image, (height,width )) if out_dir is not None and source_mask is None: source['mask'].save(os.path.join(out_dir,'source_mask.jpg')) else: Image.fromarray(source_mask).save(os.path.join(out_dir,'source_mask.jpg')) if out_dir is not None and style_mask is None: style['mask'].save(os.path.join(out_dir,'style_mask.jpg')) else: Image.fromarray(style_mask).save(os.path.join(out_dir,'style_mask.jpg')) # save source['mask'] # import pdb;pdb.set_trace() source_mask = torch.from_numpy(np.array(source['mask']) if source_mask is None else source_mask).to(device) / 255. source_mask = source_mask.unsqueeze(0).permute(0, 3, 1, 2)[:,:1] source_mask = F.interpolate(source_mask, (height//8,width//8)) if isinstance(source['image'], np.ndarray): style_image = torch.from_numpy(style['image']).to(device) / 127.5 - 1. else: style_image = torch.from_numpy(np.array(style['image'])).to(device) / 127.5 - 1. style_image = style_image.unsqueeze(0).permute(0, 3, 1, 2) style_image = F.interpolate(style_image, (height,width)) style_mask = torch.from_numpy(np.array(style['mask']) if style_mask is None else style_mask ).to(device) / 255. style_mask = style_mask.unsqueeze(0).permute(0, 3, 1, 2)[:,:1] style_mask = F.interpolate(style_mask, (height//8,width//8)) return source_image,style_image,source_mask,style_mask def ui(): with gr.Blocks(css=css) as demo: gr.Markdown( """ # [Portrait Diffusion: Training-free Face Stylization with Chain-of-Painting](https://arxiv.org/abs/00000) Jin Liu, Huaibo Huang, Chao Jin, Ran He* (*Corresponding Author)
[Arxiv Report](https://arxiv.org/abs/0000) | [Project Page](https://www.github.io/) | [Github](https://github.com/) """ ) with gr.Column(variant="panel"): gr.Markdown( """ ### 1. Select a pretrained model. """ ) with gr.Row(): stable_diffusion_dropdown = gr.Dropdown( label="Pretrained Model Path", choices=global_text.stable_diffusion_list, interactive=True, allow_custom_value=True ) stable_diffusion_dropdown.change(fn=global_text.update_stable_diffusion, inputs=[stable_diffusion_dropdown], outputs=[stable_diffusion_dropdown]) stable_diffusion_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton") def update_stable_diffusion(): global_text.refresh_stable_diffusion() stable_diffusion_refresh_button.click(fn=update_stable_diffusion, inputs=[], outputs=[]) base_model_dropdown = gr.Dropdown( label="Select a ckpt model (optional)", choices=sorted(list(global_text.personalized_model_list.keys())), interactive=True, allow_custom_value=True, ) base_model_dropdown.change(fn=global_text.update_base_model, inputs=[base_model_dropdown], outputs=[base_model_dropdown]) lora_model_dropdown = gr.Dropdown( label="Select a LoRA model (optional)", choices=["none"] + sorted(list(global_text.lora_model_list.keys())), value="none", interactive=True, allow_custom_value=True, ) lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True) lora_model_dropdown.change(fn=global_text.update_lora_model, inputs=[lora_model_dropdown,lora_alpha_slider], outputs=[lora_model_dropdown]) personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton") def update_personalized_model(): global_text.refresh_personalized_model() return [ gr.Dropdown(choices=sorted(list(global_text.personalized_model_list.keys()))), gr.Dropdown(choices=["none"] + sorted(list(global_text.lora_model_list.keys()))) ] personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown]) with gr.Column(variant="panel"): gr.Markdown( """ ### 2. Configs for PortraitDiff. """ ) with gr.Tab("Configs"): with gr.Row(): source_image = gr.Image(label="Source Image", elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGB", height=512) style_image = gr.Image(label="Style Image", elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGB", height=512) with gr.Row(): prompt_textbox = gr.Textbox(label="Prompt", value='head', lines=1) negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=1) # output_dir = gr.Textbox(label="output_dir", value='./results/') with gr.Row().style(equal_height=False): with gr.Column(): width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64) height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64) Method = gr.Dropdown( ["Without mask", "Only masked region", "Seperate Background Foreground"], value="Without mask", label="Mask", info="Select how to use masks") with gr.Tab('Base Configs'): with gr.Row(): # sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0]) ddim_steps = gr.Slider(label="DDIM Steps", value=50, minimum=10, maximum=100, step=1) Style_attn_step = gr.Slider(label="Step of Style Attention Control", minimum=0, maximum=50, value=35, step=1) start_step = gr.Slider(label="Step of Attention Control", minimum=0, maximum=150, value=0, step=1) start_layer = gr.Slider(label="Layer of Style Attention Control", minimum=0, maximum=16, value=10, step=1) Style_Guidance = gr.Slider(label="Style Guidance Scale", minimum=0, maximum=4, value=1.2, step=0.05) cfg_scale_slider = gr.Slider(label="CFG Scale", value=0, minimum=0, maximum=20) with gr.Tab('FreeU'): with gr.Row(): freeu = gr.Checkbox(label="Free Upblock", value=False) de_bug = gr.Checkbox(value=False,label='DeBug') inter_latents = gr.Checkbox(value=True,label='Use intermediate latents') with gr.Row(): b1 = gr.Slider(label='b1:', minimum=-1, maximum=2, step=0.01, value=1.3) b2 = gr.Slider(label='b2:', minimum=-1, maximum=2, step=0.01, value=1.5) with gr.Row(): s1 = gr.Slider(label='s1: ', minimum=0, maximum=2, step=0.1, value=1.0) s2 = gr.Slider(label='s2:', minimum=0, maximum=2, step=0.1, value=1.0) with gr.Row(): seed_textbox = gr.Textbox(label="Seed", value=-1) seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton") seed_button.click(fn=lambda: random.randint(1, 1e8), inputs=[], outputs=[seed_textbox]) with gr.Column(): generate_button = gr.Button(value="Generate", variant='primary') generate_image = gr.Image(label="Image with PortraitDiff", interactive=False, type='numpy', height=512,) with gr.Row(): recons_content = gr.Image(label="reconstructed content", type="pil", image_mode="RGB", height=256) recons_style = gr.Image(label="reconstructed style", type="pil", image_mode="RGB", height=256) with gr.Tab("SAM"): with gr.Column(): add_or_remove = gr.Radio(["Add Mask", "Remove Area"], value="Add Mask", label="Point_label (foreground/background)") with gr.Row(): sam_source_btn = gr.Button(value="SAM Source") send_source_btn = gr.Button(value="Send Source") sam_style_btn = gr.Button(value="SAM Style") send_style_btn = gr.Button(value="Send Style") with gr.Row(): source_image_sam = gr.Image(label="Source Image SAM", elem_id="SourceimgSAM", source="upload", interactive=True, type="pil", image_mode="RGB", height=512) style_image_sam = gr.Image(label="Style Image SAM", elem_id="StyleimgSAM", source="upload", interactive=True, type="pil", image_mode="RGB", height=512) with gr.Row(): source_image_with_points = gr.Image(label="source Image with points", elem_id="style_image_with_points", type="pil", image_mode="RGB", height=256) source_mask = gr.Image(label="Source Mask", elem_id="img2maskimg", source="upload", interactive=True, type="numpy", image_mode="RGB", height=256) style_image_with_points = gr.Image(label="Style Image with points", elem_id="style_image_with_points", type="pil", image_mode="RGB", height=256) style_mask = gr.Image(label="Style Mask", elem_id="img2maskimg", source="upload", interactive=True, type="numpy", image_mode="RGB", height=256) gr.Examples( [[os.path.join(os.path.dirname(__file__), "images/content/1.jpg"), os.path.join(os.path.dirname(__file__), "images/style/1.jpg")], ], [source_image, style_image] ) inputs = [ source_image, style_image, source_mask, style_mask, start_step, start_layer, Style_attn_step, Method, Style_Guidance,ddim_steps, cfg_scale_slider, seed_textbox, de_bug, prompt_textbox, negative_prompt_textbox, inter_latents, freeu, b1, b2, s1, s2, width_slider,height_slider, ] generate_button.click( fn=global_text.generate, inputs=inputs, outputs=[recons_style,recons_content,generate_image] ) source_image.upload(global_text.reset_start_code, inputs=[], outputs=[]) style_image.upload(global_text.reset_start_code, inputs=[], outputs=[]) ddim_steps.change(fn=global_text.reset_start_code, inputs=[], outputs=[]) return demo if __name__ == "__main__": demo = ui() demo.launch(server_name="172.18.32.44")