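# Gradio demo for "Portrait Diffusion: Training-free Face Stylization with
# Chain-of-Painting": masked style-attention control on top of a MasaCtrl
# Stable Diffusion pipeline, with optional personalized checkpoints, LoRA,
# and FreeU.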
import os
import random
from datetime import datetime
from glob import glob

import numpy as np
import torch
import torch.nn.functional as F
import gradio as gr
from PIL import Image
from torchvision.utils import save_image
from diffusers import DDIMScheduler, StableDiffusionPipeline

from utils.diffuser_utils import MasaCtrlPipeline
from utils.masactrl_utils import (AttentionBase,
                                  regiter_attention_editor_diffusers)
from utils.free_lunch_utils import (register_upblock2d,
                                    register_crossattn_upblock2d,
                                    register_free_upblock2d,
                                    register_free_crossattn_upblock2d)
from utils.style_attn_control import MaskPromptedStyleAttentionControl
css = """ | |
.toolbutton { | |
margin-buttom: 0em 0em 0em 0em; | |
max-width: 2.5em; | |
min-width: 2.5em !important; | |
height: 2.5em; | |
} | |
""" | |
class GlobalText:
    def __init__(self):
        # config dirs
        self.basedir = os.getcwd()
        self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
        self.personalized_model_dir = './models/Stable-diffusion'
        self.lora_model_dir = './models/Lora'
        self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
        self.savedir_sample = os.path.join(self.savedir, "sample")
        self.savedir_mask = os.path.join(self.savedir, "mask")

        self.stable_diffusion_list = ["runwayml/stable-diffusion-v1-5",
                                      "stabilityai/stable-diffusion-2-1"]
        self.personalized_model_list = []
        self.lora_model_list = []

        # config models
        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.unet = None
        self.pipeline = None
        self.lora_loaded = None
        self.personal_model_loaded = None
        self.lora_model_state_dict = {}
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        self.refresh_stable_diffusion()
        self.refresh_personalized_model()
        self.reset_start_code()
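
    # Load a base Stable Diffusion checkpoint into the MasaCtrl pipeline with
    # a DDIM scheduler; the DDIM inversion in generate() depends on it.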
    def load_base_pipeline(self, model_path):
        print(f'loading {model_path} model')
        scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
        self.pipeline = MasaCtrlPipeline.from_pretrained(model_path,
                                                         scheduler=scheduler).to(self.device)

    def refresh_stable_diffusion(self):
        self.load_base_pipeline(self.stable_diffusion_list[0])
        self.lora_loaded = None
        self.personal_model_loaded = None
        return self.stable_diffusion_list[0]

    def refresh_personalized_model(self):
        personalized_model_list = glob(os.path.join(self.personalized_model_dir, "**/*.safetensors"), recursive=True)
        self.personalized_model_list = {os.path.basename(file): file for file in personalized_model_list}

        lora_model_list = glob(os.path.join(self.lora_model_dir, "**/*.safetensors"), recursive=True)
        self.lora_model_list = {os.path.basename(file): file for file in lora_model_list}

    def update_stable_diffusion(self, stable_diffusion_dropdown):
        self.load_base_pipeline(stable_diffusion_dropdown)
        self.lora_loaded = None
        self.personal_model_loaded = None
        return gr.Dropdown.update()
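
    # Swap in the VAE, UNet, and text encoder from a personalized
    # .safetensors checkpoint without rebuilding the whole pipeline.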
    def update_base_model(self, base_model_dropdown):
        if self.pipeline is None:
            gr.Info("Please select a pretrained model path.")
            return None
        else:
            base_model = self.personalized_model_list[base_model_dropdown]
            mid_model = StableDiffusionPipeline.from_single_file(base_model)
            self.pipeline.vae = mid_model.vae
            self.pipeline.unet = mid_model.unet
            self.pipeline.text_encoder = mid_model.text_encoder
            self.pipeline.to(self.device)
            self.personal_model_loaded = base_model_dropdown.split('.')[0]

            print(f'Loaded {base_model_dropdown} successfully.')
            return gr.Dropdown.update()
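
    # Load LoRA weights and fuse them into the pipeline at the requested
    # strength, or restore the base weights when "none" is selected.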
    def update_lora_model(self, lora_model_dropdown, lora_alpha_slider):
        if self.pipeline is None:
            gr.Info("Please select a pretrained model path.")
            return None
        else:
            if lora_model_dropdown == "none":
                self.pipeline.unfuse_lora()
                self.pipeline.unload_lora_weights()
                self.lora_loaded = None
                print("Restored the base model weights (LoRA unloaded).")
            else:
                lora_model_path = self.lora_model_list[lora_model_dropdown]
                self.pipeline.unfuse_lora()
                self.pipeline.unload_lora_weights()
                self.pipeline.load_lora_weights(lora_model_path)
                # Pass the strength by keyword: the first positional argument
                # of fuse_lora() is not the scale.
                self.pipeline.fuse_lora(lora_scale=lora_alpha_slider)
                self.lora_loaded = lora_model_dropdown.split('.')[0]

                print(f'Loaded {lora_model_dropdown} successfully.')
            return gr.Dropdown.update()
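
    # Main entry point: DDIM-invert the style/content pair (cached across
    # calls), register the masked style-attention controller, then sample
    # the stylized image.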
    def generate(self, source, style, source_mask, style_mask,
                 start_step, start_layer, Style_attn_step,
                 Method, Style_Guidance, ddim_steps, scale, seed, de_bug,
                 target_prompt, negative_prompt_textbox,
                 inter_latents,
                 freeu, b1, b2, s1, s2,
                 width_slider, height_slider,
                 ):
        os.makedirs(self.savedir, exist_ok=True)
        os.makedirs(self.savedir_sample, exist_ok=True)
        os.makedirs(self.savedir_mask, exist_ok=True)
        model = self.pipeline

        # The seed textbox yields a string; -1 or an empty value means "random".
        if seed not in (-1, "-1", ""):
            torch.manual_seed(int(seed))
        else:
            torch.seed()
        seed = torch.initial_seed()

        sample_count = len(os.listdir(self.savedir_sample))
        os.makedirs(os.path.join(self.savedir_mask, f"results_{sample_count}"), exist_ok=True)

        # Three prompts: two for the inverted style/content pair and one for
        # the stylized output.
        ref_prompt = [target_prompt, target_prompt]
        prompts = ref_prompt + [target_prompt]

        source_image, style_image, source_mask, style_mask = load_mask_images(
            source, style, source_mask, style_mask, self.device,
            width_slider, height_slider,
            out_dir=os.path.join(self.savedir_mask, f"results_{sample_count}"))

        with torch.no_grad():
            # Invert both images into noise once and cache the result; later
            # calls reuse the cached start code and intermediate latents.
            if self.start_code is None and self.latents_list is None:
                content_style = torch.cat([style_image, source_image], dim=0)
                editor = AttentionBase()
                regiter_attention_editor_diffusers(model, editor)
                st_code, latents_list = model.invert(content_style,
                                                     ref_prompt,
                                                     guidance_scale=scale,
                                                     num_inference_steps=ddim_steps,
                                                     return_intermediates=True)
                start_code = torch.cat([st_code, st_code[1:]], dim=0)
                self.start_code = start_code
                self.latents_list = latents_list
            else:
                start_code = self.start_code
                latents_list = self.latents_list
                print('------------------ Use previous latents ------------------')
#["Without mask", "Only masked region", "Seperate Background Foreground"] | |
if Method == "Without mask": | |
style_mask = None | |
source_mask = None | |
only_masked_region = False | |
elif Method == "Only masked region": | |
assert style_mask is not None and source_mask is not None | |
only_masked_region = True | |
else: | |
assert style_mask is not None and source_mask is not None | |
only_masked_region = False | |
controller = MaskPromptedStyleAttentionControl(start_step, start_layer, | |
style_attn_step=Style_attn_step, | |
style_guidance=Style_Guidance, | |
style_mask=style_mask, | |
source_mask=source_mask, | |
only_masked_region=only_masked_region, | |
guidance=scale, | |
de_bug=de_bug, | |
) | |
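
            # Optionally patch the UNet up-blocks with FreeU; the register_free_*
            # variants here also take the source mask into account when one is used.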
            if freeu:
                print(f'++++++++++++++++++ Run with FreeU {b1}_{b2}_{s1}_{s2} ++++++++++++++++')
                if Method != "Without mask":
                    register_free_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2, source_mask=source_mask)
                    register_free_crossattn_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2, source_mask=source_mask)
                else:
                    register_free_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2, source_mask=None)
                    register_free_crossattn_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2, source_mask=None)
            else:
                print('++++++++++++++++++ Run without FreeU ++++++++++++++++')
                register_upblock2d(model)
                register_crossattn_upblock2d(model)

            regiter_attention_editor_diffusers(model, controller)

            # Sample three images in one batch: style reconstruction, content
            # reconstruction, and the stylized result.
            generate_image = model(prompts,
                                   width=width_slider,
                                   height=height_slider,
                                   latents=start_code,
                                   guidance_scale=scale,
                                   num_inference_steps=ddim_steps,
                                   ref_intermediate_latents=latents_list if inter_latents else None,
                                   neg_prompt=negative_prompt_textbox,
                                   return_intermediates=False)

        save_file_name = f"results_{sample_count}_step{start_step}_layer{start_layer}SG{Style_Guidance}_style_attn_step{Style_attn_step}.jpg"
        if self.lora_loaded is not None:
            save_file_name = f"lora_{self.lora_loaded}_" + save_file_name
        if self.personal_model_loaded is not None:
            save_file_name = f"personal_{self.personal_model_loaded}_" + save_file_name
        save_file_path = os.path.join(self.savedir_sample, save_file_name)

        # The images were normalized to [-1, 1]; map them back to [0, 1] for saving.
        save_image(torch.cat([source_image / 2 + 0.5, style_image / 2 + 0.5, generate_image[2:]], dim=0),
                   save_file_path, nrow=3, padding=0)

        # NCHW tensors -> NHWC numpy arrays for the Gradio image outputs.
        generate_image = generate_image.cpu().permute(0, 2, 3, 1).numpy()
        return [
            generate_image[0],
            generate_image[1],
            generate_image[2],
        ]
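
    # Drop the cached inversion results so the next generate() call re-inverts.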
    def reset_start_code(self):
        self.start_code = None
        self.latents_list = None


global_text = GlobalText()
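

# Normalize the source/style images to [-1, 1] tensors at the requested size
# and downsample the masks to the latent grid (1/8 of the pixel resolution).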
def load_mask_images(source, style, source_mask, style_mask, device, width, height, out_dir=None):
    # Images arrive from Gradio either as numpy arrays or as PIL images.
    if isinstance(source['image'], np.ndarray):
        source_image = torch.from_numpy(source['image']).to(device) / 127.5 - 1.
    else:
        source_image = torch.from_numpy(np.array(source['image'])).to(device) / 127.5 - 1.
    source_image = source_image.unsqueeze(0).permute(0, 3, 1, 2)
    source_image = F.interpolate(source_image, (height, width))

    if out_dir is not None:
        if source_mask is None:
            source['mask'].save(os.path.join(out_dir, 'source_mask.jpg'))
        else:
            Image.fromarray(source_mask).save(os.path.join(out_dir, 'source_mask.jpg'))
        if style_mask is None:
            style['mask'].save(os.path.join(out_dir, 'style_mask.jpg'))
        else:
            Image.fromarray(style_mask).save(os.path.join(out_dir, 'style_mask.jpg'))

    # Keep a single channel and downsample the mask to the latent resolution.
    source_mask = torch.from_numpy(np.array(source['mask']) if source_mask is None else source_mask).to(device) / 255.
    source_mask = source_mask.unsqueeze(0).permute(0, 3, 1, 2)[:, :1]
    source_mask = F.interpolate(source_mask, (height // 8, width // 8))

    if isinstance(style['image'], np.ndarray):
        style_image = torch.from_numpy(style['image']).to(device) / 127.5 - 1.
    else:
        style_image = torch.from_numpy(np.array(style['image'])).to(device) / 127.5 - 1.
    style_image = style_image.unsqueeze(0).permute(0, 3, 1, 2)
    style_image = F.interpolate(style_image, (height, width))

    style_mask = torch.from_numpy(np.array(style['mask']) if style_mask is None else style_mask).to(device) / 255.
    style_mask = style_mask.unsqueeze(0).permute(0, 3, 1, 2)[:, :1]
    style_mask = F.interpolate(style_mask, (height // 8, width // 8))

    return source_image, style_image, source_mask, style_mask
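

# Build the Gradio UI: model selection, PortraitDiff configs, and a SAM tab
# for preparing source/style masks.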
def ui():
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # [Portrait Diffusion: Training-free Face Stylization with Chain-of-Painting](https://arxiv.org/abs/00000)
            Jin Liu, Huaibo Huang, Chao Jin, Ran He* (*Corresponding Author)<br>
            [Arxiv Report](https://arxiv.org/abs/0000) | [Project Page](https://www.github.io/) | [Github](https://github.com/)
            """
        )

        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                ### 1. Select a pretrained model.
                """
            )
            with gr.Row():
                stable_diffusion_dropdown = gr.Dropdown(
                    label="Pretrained Model Path",
                    choices=global_text.stable_diffusion_list,
                    interactive=True,
                    allow_custom_value=True,
                )
                stable_diffusion_dropdown.change(fn=global_text.update_stable_diffusion, inputs=[stable_diffusion_dropdown], outputs=[stable_diffusion_dropdown])

                stable_diffusion_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")

                def update_stable_diffusion():
                    global_text.refresh_stable_diffusion()

                stable_diffusion_refresh_button.click(fn=update_stable_diffusion, inputs=[], outputs=[])

                base_model_dropdown = gr.Dropdown(
                    label="Select a ckpt model (optional)",
                    choices=sorted(list(global_text.personalized_model_list.keys())),
                    interactive=True,
                    allow_custom_value=True,
                )
                base_model_dropdown.change(fn=global_text.update_base_model, inputs=[base_model_dropdown], outputs=[base_model_dropdown])

                lora_model_dropdown = gr.Dropdown(
                    label="Select a LoRA model (optional)",
                    choices=["none"] + sorted(list(global_text.lora_model_list.keys())),
                    value="none",
                    interactive=True,
                    allow_custom_value=True,
                )
                lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True)
                lora_model_dropdown.change(fn=global_text.update_lora_model, inputs=[lora_model_dropdown, lora_alpha_slider], outputs=[lora_model_dropdown])

                personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")

                def update_personalized_model():
                    global_text.refresh_personalized_model()
                    return [
                        gr.Dropdown.update(choices=sorted(list(global_text.personalized_model_list.keys()))),
                        gr.Dropdown.update(choices=["none"] + sorted(list(global_text.lora_model_list.keys())))
                    ]

                personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])

        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                ### 2. Configs for PortraitDiff.
                """
            )
            with gr.Tab("Configs"):
                with gr.Row():
                    source_image = gr.Image(label="Source Image", elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGB", height=512)
                    style_image = gr.Image(label="Style Image", elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGB", height=512)

                with gr.Row():
                    prompt_textbox = gr.Textbox(label="Prompt", value='head', lines=1)
                    negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=1)

                with gr.Row().style(equal_height=False):
                    with gr.Column():
                        width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64)
                        height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64)
                        Method = gr.Dropdown(
                            ["Without mask", "Only masked region", "Separate Background Foreground"],
                            value="Without mask",
                            label="Mask", info="Select how to use masks")

                        with gr.Tab('Base Configs'):
                            with gr.Row():
                                ddim_steps = gr.Slider(label="DDIM Steps", value=50, minimum=10, maximum=100, step=1)
                                Style_attn_step = gr.Slider(label="Step of Style Attention Control",
                                                            minimum=0,
                                                            maximum=50,
                                                            value=35,
                                                            step=1)
                                start_step = gr.Slider(label="Step of Attention Control",
                                                       minimum=0,
                                                       maximum=150,
                                                       value=0,
                                                       step=1)
                                start_layer = gr.Slider(label="Layer of Style Attention Control",
                                                        minimum=0,
                                                        maximum=16,
                                                        value=10,
                                                        step=1)
                                Style_Guidance = gr.Slider(label="Style Guidance Scale",
                                                           minimum=0,
                                                           maximum=4,
                                                           value=1.2,
                                                           step=0.05)
                                cfg_scale_slider = gr.Slider(label="CFG Scale", value=0, minimum=0, maximum=20)

                        with gr.Tab('FreeU'):
                            with gr.Row():
                                freeu = gr.Checkbox(label="Free Upblock", value=False)
                                de_bug = gr.Checkbox(value=False, label='DeBug')
                                inter_latents = gr.Checkbox(value=True, label='Use intermediate latents')
                            with gr.Row():
                                b1 = gr.Slider(label='b1:', minimum=-1, maximum=2, step=0.01, value=1.3)
                                b2 = gr.Slider(label='b2:', minimum=-1, maximum=2, step=0.01, value=1.5)
                            with gr.Row():
                                s1 = gr.Slider(label='s1:', minimum=0, maximum=2, step=0.1, value=1.0)
                                s2 = gr.Slider(label='s2:', minimum=0, maximum=2, step=0.1, value=1.0)

                        with gr.Row():
                            seed_textbox = gr.Textbox(label="Seed", value=-1)
                            seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                            # random.randint requires integer bounds, so cast 1e8.
                            seed_button.click(fn=lambda: random.randint(1, int(1e8)), inputs=[], outputs=[seed_textbox])

                    with gr.Column():
                        generate_button = gr.Button(value="Generate", variant='primary')
                        generate_image = gr.Image(label="Image with PortraitDiff", interactive=False, type='numpy', height=512)

                        with gr.Row():
                            recons_content = gr.Image(label="Reconstructed content", type="pil", image_mode="RGB", height=256)
                            recons_style = gr.Image(label="Reconstructed style", type="pil", image_mode="RGB", height=256)

            with gr.Tab("SAM"):
                with gr.Column():
                    add_or_remove = gr.Radio(["Add Mask", "Remove Area"], value="Add Mask", label="Point label (foreground/background)")
                    with gr.Row():
                        sam_source_btn = gr.Button(value="SAM Source")
                        send_source_btn = gr.Button(value="Send Source")
                        sam_style_btn = gr.Button(value="SAM Style")
                        send_style_btn = gr.Button(value="Send Style")
                with gr.Row():
                    source_image_sam = gr.Image(label="Source Image SAM", elem_id="SourceimgSAM", source="upload", interactive=True, type="pil", image_mode="RGB", height=512)
                    style_image_sam = gr.Image(label="Style Image SAM", elem_id="StyleimgSAM", source="upload", interactive=True, type="pil", image_mode="RGB", height=512)
                with gr.Row():
                    source_image_with_points = gr.Image(label="Source Image with points", elem_id="source_image_with_points", type="pil", image_mode="RGB", height=256)
                    source_mask = gr.Image(label="Source Mask", elem_id="img2maskimg", source="upload", interactive=True, type="numpy", image_mode="RGB", height=256)
                    style_image_with_points = gr.Image(label="Style Image with points", elem_id="style_image_with_points", type="pil", image_mode="RGB", height=256)
                    style_mask = gr.Image(label="Style Mask", elem_id="img2maskimg", source="upload", interactive=True, type="numpy", image_mode="RGB", height=256)

        gr.Examples(
            [[os.path.join(os.path.dirname(__file__), "images/content/1.jpg"),
              os.path.join(os.path.dirname(__file__), "images/style/1.jpg")],
             ],
            [source_image, style_image]
        )

        inputs = [
            source_image, style_image, source_mask, style_mask,
            start_step, start_layer, Style_attn_step,
            Method, Style_Guidance, ddim_steps, cfg_scale_slider, seed_textbox, de_bug,
            prompt_textbox, negative_prompt_textbox, inter_latents,
            freeu, b1, b2, s1, s2,
            width_slider, height_slider,
        ]

        generate_button.click(
            fn=global_text.generate,
            inputs=inputs,
            outputs=[recons_style, recons_content, generate_image]
        )

        # New inputs invalidate the cached DDIM-inversion latents.
        source_image.upload(global_text.reset_start_code, inputs=[], outputs=[])
        style_image.upload(global_text.reset_start_code, inputs=[], outputs=[])
        ddim_steps.change(fn=global_text.reset_start_code, inputs=[], outputs=[])

    return demo


if __name__ == "__main__":
    demo = ui()
    demo.launch()