import os
import random
from glob import glob
from datetime import datetime

import numpy as np
import torch
import torch.nn.functional as F
import gradio as gr
from PIL import Image
from torchvision.utils import save_image
from diffusers import StableDiffusionPipeline, DDIMScheduler

from utils.diffuser_utils import MasaCtrlPipeline
from utils.masactrl_utils import (AttentionBase,
                                  regiter_attention_editor_diffusers)
from utils.free_lunch_utils import (register_upblock2d,
                                    register_crossattn_upblock2d,
                                    register_free_upblock2d,
                                    register_free_crossattn_upblock2d)
from utils.style_attn_control import MaskPromptedStyleAttentionControl
css = """
.toolbutton {
    margin-bottom: 0em;
max-width: 2.5em;
min-width: 2.5em !important;
height: 2.5em;
}
"""
class GlobalText:
def __init__(self):
# config dirs
self.basedir = os.getcwd()
self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
self.personalized_model_dir = './models/Stable-diffusion'
self.lora_model_dir = './models/Lora'
self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
self.savedir_sample = os.path.join(self.savedir, "sample")
self.savedir_mask = os.path.join(self.savedir, "mask")
self.stable_diffusion_list = ["runwayml/stable-diffusion-v1-5",
"stabilityai/stable-diffusion-2-1"]
self.personalized_model_list = []
self.lora_model_list = []
# config models
self.tokenizer = None
self.text_encoder = None
self.vae = None
self.unet = None
self.pipeline = None
self.lora_loaded = None
self.personal_model_loaded = None
self.lora_model_state_dict = {}
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
self.refresh_stable_diffusion()
self.refresh_personalized_model()
self.reset_start_code()
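    # Load the base Stable Diffusion checkpoint into a MasaCtrl pipeline with a DDIM scheduler.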
def load_base_pipeline(self, model_path):
print(f'loading {model_path} model')
scheduler = DDIMScheduler.from_pretrained(model_path,subfolder="scheduler")
self.pipeline = MasaCtrlPipeline.from_pretrained(model_path,
scheduler=scheduler).to(self.device)
def refresh_stable_diffusion(self):
self.load_base_pipeline(self.stable_diffusion_list[0])
self.lora_loaded = None
self.personal_model_loaded = None
return self.stable_diffusion_list[0]
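    # Rescan the checkpoint and LoRA folders for .safetensors files, keyed by basename.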
def refresh_personalized_model(self):
personalized_model_list = glob(os.path.join(self.personalized_model_dir, "**/*.safetensors"), recursive=True)
self.personalized_model_list = {os.path.basename(file): file for file in personalized_model_list}
lora_model_list = glob(os.path.join(self.lora_model_dir, "**/*.safetensors"), recursive=True)
self.lora_model_list = {os.path.basename(file): file for file in lora_model_list}
def update_stable_diffusion(self, stable_diffusion_dropdown):
self.load_base_pipeline(stable_diffusion_dropdown)
self.lora_loaded = None
self.personal_model_loaded = None
        return gr.Dropdown()
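    # Replace the VAE, UNet, and text encoder with weights from a personalized checkpoint.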
def update_base_model(self, base_model_dropdown):
if self.pipeline is None:
            gr.Info("Please select a pretrained model path.")
return None
else:
base_model = self.personalized_model_list[base_model_dropdown]
mid_model = StableDiffusionPipeline.from_single_file(base_model)
self.pipeline.vae = mid_model.vae
self.pipeline.unet = mid_model.unet
self.pipeline.text_encoder = mid_model.text_encoder
self.pipeline.to(self.device)
            self.personal_model_loaded = os.path.splitext(base_model_dropdown)[0]
            print(f'Loaded {base_model_dropdown} successfully.')
return gr.Dropdown()
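    # Load (or unload, when "none" is selected) LoRA weights and fuse them at the given alpha.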
def update_lora_model(self, lora_model_dropdown,lora_alpha_slider):
if self.pipeline is None:
            gr.Info("Please select a pretrained model path.")
return None
else:
if lora_model_dropdown == "none":
self.pipeline.unfuse_lora()
self.pipeline.unload_lora_weights()
self.lora_loaded = None
                print("LoRA unloaded; base weights restored.")
else:
                lora_model_path = self.lora_model_list[lora_model_dropdown]
self.pipeline.unfuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline.load_lora_weights(lora_model_path)
self.pipeline.fuse_lora(lora_alpha_slider)
                self.lora_loaded = os.path.splitext(lora_model_dropdown)[0]
                print(f'Loaded LoRA {lora_model_dropdown} successfully.')
return gr.Dropdown()
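    # Full stylization pass: (1) DDIM-invert the style/content pair (or reuse cached
    # latents), (2) register the masked style-attention controller and optional FreeU
    # patches, (3) run the three-branch pipeline, and (4) save and return the results.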
def generate(self, source, style, source_mask, style_mask,
start_step, start_layer, Style_attn_step,
Method, Style_Guidance, ddim_steps, scale, seed, de_bug,
target_prompt, negative_prompt_textbox,
inter_latents,
freeu, b1, b2, s1, s2,
width_slider,height_slider,
):
os.makedirs(self.savedir, exist_ok=True)
os.makedirs(self.savedir_sample, exist_ok=True)
os.makedirs(self.savedir_mask, exist_ok=True)
model = self.pipeline
        # Seed the RNG; -1 or an empty textbox means "pick a random seed".
        if str(seed) not in ("-1", ""):
            torch.manual_seed(int(seed))
        else:
            torch.seed()
        seed = torch.initial_seed()
sample_count = len(os.listdir(self.savedir_sample))
os.makedirs(os.path.join(self.savedir_mask, f"results_{sample_count}"), exist_ok=True)
        # Two reference branches (style, content) plus one synthesis branch, all
        # conditioned on the same target prompt.
        ref_prompt = [target_prompt, target_prompt]
        prompts = ref_prompt + [target_prompt]
        source_image, style_image, source_mask, style_mask = load_mask_images(
            source, style, source_mask, style_mask, self.device,
            width_slider, height_slider,
            out_dir=os.path.join(self.savedir_mask, f"results_{sample_count}"))
        with torch.no_grad():
            # Invert the style/content pair into DDIM noise once; reuse the cached
            # start code and intermediate latents until the inputs change.
if self.start_code is None and self.latents_list is None:
content_style = torch.cat([style_image, source_image], dim=0)
editor = AttentionBase()
regiter_attention_editor_diffusers(model, editor)
st_code, latents_list = model.invert(content_style,
ref_prompt,
guidance_scale=scale,
num_inference_steps=ddim_steps,
return_intermediates=True)
start_code = torch.cat([st_code, st_code[1:]], dim=0)
self.start_code = start_code
self.latents_list = latents_list
else:
start_code = self.start_code
latents_list = self.latents_list
                print('Using previous latents.')
            # Mask modes: "Without mask", "Only masked region", "Separate Background Foreground".
            if Method == "Without mask":
                style_mask = None
                source_mask = None
                only_masked_region = False
            elif Method == "Only masked region":
                assert style_mask is not None and source_mask is not None
                only_masked_region = True
            else:
                assert style_mask is not None and source_mask is not None
                only_masked_region = False
controller = MaskPromptedStyleAttentionControl(start_step, start_layer,
style_attn_step=Style_attn_step,
style_guidance=Style_Guidance,
style_mask=style_mask,
source_mask=source_mask,
only_masked_region=only_masked_region,
guidance=scale,
de_bug=de_bug,
)
            if freeu:
                # Patch the UNet up-blocks with FreeU (b1/b2 backbone scales, s1/s2 skip scales).
                print(f'Run with FreeU: b1={b1}, b2={b2}, s1={s1}, s2={s2}')
                if Method != "Without mask":
                    register_free_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2, source_mask=source_mask)
                    register_free_crossattn_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2, source_mask=source_mask)
                else:
                    register_free_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2, source_mask=None)
                    register_free_crossattn_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2, source_mask=None)
            else:
                print('Run without FreeU.')
                register_upblock2d(model)
                register_crossattn_upblock2d(model)
            regiter_attention_editor_diffusers(model, controller)
            # Synthesize the [style-recon, content-recon, stylized-output] batch.
            generate_image = model(prompts,
width=width_slider,
height=height_slider,
latents=start_code,
guidance_scale=scale,
num_inference_steps=ddim_steps,
ref_intermediate_latents=latents_list if inter_latents else None,
neg_prompt=negative_prompt_textbox,
return_intermediates=False,)
            save_file_name = f"results_{sample_count}_step{start_step}_layer{start_layer}SG{Style_Guidance}_style_attn_step{Style_attn_step}.jpg"
            if self.lora_loaded is not None:
                save_file_name = f"lora_{self.lora_loaded}_" + save_file_name
            if self.personal_model_loaded is not None:
                save_file_name = f"personal_{self.personal_model_loaded}_" + save_file_name
            save_file_path = os.path.join(self.savedir_sample, save_file_name)
            # De-normalize the inputs from [-1, 1] to [0, 1] and save a 3-up strip: content | style | result.
            save_image(torch.cat([source_image / 2 + 0.5, style_image / 2 + 0.5, generate_image[2:]], dim=0),
                       save_file_path, nrow=3, padding=0)
generate_image = generate_image.cpu().permute(0, 2, 3, 1).numpy()
return [
generate_image[0],
generate_image[1],
generate_image[2],
]
    def reset_start_code(self):
        # Drop the cached inversion so the next generate() call re-inverts the inputs.
self.start_code = None
self.latents_list = None
global_text = GlobalText()
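# Convert the Gradio image/mask payloads into model tensors: images are normalized to
# [-1, 1] and resized to (height, width); masks are scaled to [0, 1], reduced to one
# channel, and downsampled to the latent resolution (height//8, width//8).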
def load_mask_images(source, style, source_mask, style_mask, device, width, height, out_dir=None):
if isinstance(source['image'], np.ndarray):
source_image = torch.from_numpy(source['image']).to(device) / 127.5 - 1.
else:
source_image = torch.from_numpy(np.array(source['image'])).to(device) / 127.5 - 1.
    source_image = source_image.unsqueeze(0).permute(0, 3, 1, 2)
    source_image = F.interpolate(source_image, (height, width))
    # Save the masks (hand-drawn sketch or uploaded array) next to the results for inspection.
    if out_dir is not None:
        if source_mask is None:
            source['mask'].save(os.path.join(out_dir, 'source_mask.jpg'))
        else:
            Image.fromarray(source_mask).save(os.path.join(out_dir, 'source_mask.jpg'))
        if style_mask is None:
            style['mask'].save(os.path.join(out_dir, 'style_mask.jpg'))
        else:
            Image.fromarray(style_mask).save(os.path.join(out_dir, 'style_mask.jpg'))
    source_mask = torch.from_numpy(np.array(source['mask']) if source_mask is None else source_mask).to(device) / 255.
    source_mask = source_mask.unsqueeze(0).permute(0, 3, 1, 2)[:, :1]
    source_mask = F.interpolate(source_mask, (height // 8, width // 8))
    if isinstance(style['image'], np.ndarray):
        style_image = torch.from_numpy(style['image']).to(device) / 127.5 - 1.
else:
style_image = torch.from_numpy(np.array(style['image'])).to(device) / 127.5 - 1.
    style_image = style_image.unsqueeze(0).permute(0, 3, 1, 2)
    style_image = F.interpolate(style_image, (height, width))
    style_mask = torch.from_numpy(np.array(style['mask']) if style_mask is None else style_mask).to(device) / 255.
    style_mask = style_mask.unsqueeze(0).permute(0, 3, 1, 2)[:, :1]
    style_mask = F.interpolate(style_mask, (height // 8, width // 8))
    return source_image, style_image, source_mask, style_mask
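# Build the Gradio interface: model selection, PortraitDiff generation configs, and a
# SAM tab for creating masks from point prompts.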
def ui():
with gr.Blocks(css=css) as demo:
gr.Markdown(
"""
# [Portrait Diffusion: Training-free Face Stylization with Chain-of-Painting](https://arxiv.org/abs/00000)
Jin Liu, Huaibo Huang, Chao Jin, Ran He* (*Corresponding Author)<br>
[Arxiv Report](https://arxiv.org/abs/0000) | [Project Page](https://www.github.io/) | [Github](https://github.com/)
"""
)
with gr.Column(variant="panel"):
gr.Markdown(
"""
### 1. Select a pretrained model.
"""
)
with gr.Row():
stable_diffusion_dropdown = gr.Dropdown(
label="Pretrained Model Path",
choices=global_text.stable_diffusion_list,
interactive=True,
allow_custom_value=True
)
stable_diffusion_dropdown.change(fn=global_text.update_stable_diffusion, inputs=[stable_diffusion_dropdown], outputs=[stable_diffusion_dropdown])
stable_diffusion_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
def update_stable_diffusion():
global_text.refresh_stable_diffusion()
stable_diffusion_refresh_button.click(fn=update_stable_diffusion, inputs=[], outputs=[])
base_model_dropdown = gr.Dropdown(
label="Select a ckpt model (optional)",
choices=sorted(list(global_text.personalized_model_list.keys())),
interactive=True,
allow_custom_value=True,
)
base_model_dropdown.change(fn=global_text.update_base_model, inputs=[base_model_dropdown], outputs=[base_model_dropdown])
lora_model_dropdown = gr.Dropdown(
label="Select a LoRA model (optional)",
choices=["none"] + sorted(list(global_text.lora_model_list.keys())),
value="none",
interactive=True,
allow_custom_value=True,
)
lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True)
lora_model_dropdown.change(fn=global_text.update_lora_model, inputs=[lora_model_dropdown,lora_alpha_slider], outputs=[lora_model_dropdown])
personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
def update_personalized_model():
global_text.refresh_personalized_model()
return [
gr.Dropdown(choices=sorted(list(global_text.personalized_model_list.keys()))),
gr.Dropdown(choices=["none"] + sorted(list(global_text.lora_model_list.keys())))
]
personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])
with gr.Column(variant="panel"):
gr.Markdown(
"""
### 2. Configs for PortraitDiff.
"""
)
with gr.Tab("Configs"):
with gr.Row():
                    source_image = gr.Image(label="Source Image", elem_id="source_img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGB", height=512)
                    style_image = gr.Image(label="Style Image", elem_id="style_img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGB", height=512)
with gr.Row():
prompt_textbox = gr.Textbox(label="Prompt", value='head', lines=1)
negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=1)
# output_dir = gr.Textbox(label="output_dir", value='./results/')
with gr.Row().style(equal_height=False):
with gr.Column():
width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64)
height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64)
Method = gr.Dropdown(
                            ["Without mask", "Only masked region", "Separate Background Foreground"],
value="Without mask",
label="Mask", info="Select how to use masks")
with gr.Tab('Base Configs'):
with gr.Row():
# sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
ddim_steps = gr.Slider(label="DDIM Steps", value=50, minimum=10, maximum=100, step=1)
Style_attn_step = gr.Slider(label="Step of Style Attention Control",
minimum=0,
maximum=50,
value=35,
step=1)
start_step = gr.Slider(label="Step of Attention Control",
minimum=0,
maximum=150,
value=0,
step=1)
start_layer = gr.Slider(label="Layer of Style Attention Control",
minimum=0,
maximum=16,
value=10,
step=1)
Style_Guidance = gr.Slider(label="Style Guidance Scale",
minimum=0,
maximum=4,
value=1.2,
step=0.05)
cfg_scale_slider = gr.Slider(label="CFG Scale", value=0, minimum=0, maximum=20)
with gr.Tab('FreeU'):
with gr.Row():
freeu = gr.Checkbox(label="Free Upblock", value=False)
de_bug = gr.Checkbox(value=False,label='DeBug')
inter_latents = gr.Checkbox(value=True,label='Use intermediate latents')
with gr.Row():
b1 = gr.Slider(label='b1:',
minimum=-1,
maximum=2,
step=0.01,
value=1.3)
b2 = gr.Slider(label='b2:',
minimum=-1,
maximum=2,
step=0.01,
value=1.5)
with gr.Row():
s1 = gr.Slider(label='s1: ',
minimum=0,
maximum=2,
step=0.1,
value=1.0)
s2 = gr.Slider(label='s2:',
minimum=0,
maximum=2,
step=0.1,
value=1.0)
with gr.Row():
seed_textbox = gr.Textbox(label="Seed", value=-1)
seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                            seed_button.click(fn=lambda: random.randint(1, int(1e8)), inputs=[], outputs=[seed_textbox])
with gr.Column():
generate_button = gr.Button(value="Generate", variant='primary')
generate_image = gr.Image(label="Image with PortraitDiff", interactive=False, type='numpy', height=512,)
with gr.Row():
recons_content = gr.Image(label="reconstructed content", type="pil", image_mode="RGB", height=256)
recons_style = gr.Image(label="reconstructed style", type="pil", image_mode="RGB", height=256)
with gr.Tab("SAM"):
with gr.Column():
                    add_or_remove = gr.Radio(["Add Mask", "Remove Area"], value="Add Mask", label="Point label (foreground/background)")
with gr.Row():
sam_source_btn = gr.Button(value="SAM Source")
send_source_btn = gr.Button(value="Send Source")
sam_style_btn = gr.Button(value="SAM Style")
send_style_btn = gr.Button(value="Send Style")
with gr.Row():
source_image_sam = gr.Image(label="Source Image SAM", elem_id="SourceimgSAM", source="upload", interactive=True, type="pil", image_mode="RGB", height=512)
style_image_sam = gr.Image(label="Style Image SAM", elem_id="StyleimgSAM", source="upload", interactive=True, type="pil", image_mode="RGB", height=512)
with gr.Row():
                        source_image_with_points = gr.Image(label="Source Image with points", elem_id="source_image_with_points", type="pil", image_mode="RGB", height=256)
source_mask = gr.Image(label="Source Mask", elem_id="img2maskimg", source="upload", interactive=True, type="numpy", image_mode="RGB", height=256)
style_image_with_points = gr.Image(label="Style Image with points", elem_id="style_image_with_points", type="pil", image_mode="RGB", height=256)
style_mask = gr.Image(label="Style Mask", elem_id="img2maskimg", source="upload", interactive=True, type="numpy", image_mode="RGB", height=256)
gr.Examples(
[[os.path.join(os.path.dirname(__file__), "images/content/1.jpg"),
os.path.join(os.path.dirname(__file__), "images/style/1.jpg")],
],
[source_image, style_image]
)
        inputs = [
            source_image, style_image, source_mask, style_mask,
            start_step, start_layer, Style_attn_step,
            Method, Style_Guidance, ddim_steps, cfg_scale_slider, seed_textbox, de_bug,
            prompt_textbox, negative_prompt_textbox, inter_latents,
            freeu, b1, b2, s1, s2,
            width_slider, height_slider,
        ]
generate_button.click(
fn=global_text.generate,
inputs=inputs,
outputs=[recons_style,recons_content,generate_image]
)
source_image.upload(global_text.reset_start_code, inputs=[], outputs=[])
style_image.upload(global_text.reset_start_code, inputs=[], outputs=[])
ddim_steps.change(fn=global_text.reset_start_code, inputs=[], outputs=[])
return demo
if __name__ == "__main__":
    demo = ui()
    # Launch on the default local address; pass server_name="0.0.0.0" to expose it on the network.
    demo.launch()