Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import spaces | |
import torch | |
import random | |
import numpy as np | |
import gradio as gr | |
from glob import glob | |
from datetime import datetime | |
from diffusers import StableDiffusionPipeline,AutoencoderKL | |
from diffusers import DDIMScheduler, LCMScheduler, EulerDiscreteScheduler | |
import torch.nn.functional as F | |
from PIL import Image,ImageDraw | |
from utils.pipeline import ZePoPipeline | |
from utils.attn_control import AttentionStyle | |
from torchvision.utils import save_image | |
import utils.ptp_utils as ptp_utils | |
import torchvision.transforms as transforms | |
# Probe for the optional xformers package; the flag is consulted later to
# decide whether memory-efficient attention can be switched on.
try:
    import xformers
except ImportError:
    is_xformers = False
else:
    is_xformers = True
# Custom CSS injected into the Blocks app; `.toolbutton` sizes the small
# square icon buttons (refresh / dice).
# The original declaration was `margin-buttom: 0em 0em 0em 0em;`, which is not
# a valid CSS property/value pair and was therefore ignored by browsers.
css = """
.toolbutton {
margin-bottom: 0em;
max-width: 2.5em;
min-width: 2.5em !important;
height: 2.5em;
}
"""
# import sys | |
# sys.setrecursionlimit(100000) | |
class GlobalText:
    """Shared, mutable state behind the Gradio demo.

    Holds filesystem locations, the currently loaded ``ZePoPipeline`` and the
    image-path lists that back the content / style / results galleries.
    """

    # Thumbnails per gallery page; the ``max_*_index`` attributes derive from it.
    GALLERY_PAGE_SIZE = 12

    def __init__(self):
        # config dirs
        self.basedir = os.getcwd()
        self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
        self.personalized_model_dir = './models/Stable-diffusion'
        self.lora_model_dir = './models/Lora'
        self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
        self.savedir_sample = os.path.join(self.savedir, "sample")
        # self.savedir_mask = os.path.join(self.savedir, "mask")

        self.stable_diffusion_list = ["SimianLuo/LCM_Dreamshaper_v7"]
        self.personalized_model_list = []
        self.lora_model_list = []

        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.unet = None
        self.pipeline = None
        self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.lora_model_state_dict = {}
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        # Placeholder swapped in for flagged outputs; generate() converts it
        # to a float array in [0, 1].
        self.nsfw_image = Image.open('./data/nsfw.jpg')

        # Gallery / bookkeeping state is initialised here so it can be read
        # safely before the first refresh or generation.
        self.source_paths = []
        self.style_paths = []
        self.results_paths = []
        self.max_source_index = 0
        self.max_style_index = 0
        self.max_results_index = 0
        self.scheduler = None
        self.personal_model_loaded = None
        self.sample_count = 0

    def _last_page_index(self, paths):
        """Return the zero-based index of the last non-empty gallery page."""
        # (len - 1) // size, not len // size: an exact multiple of the page
        # size must not produce a trailing empty page.
        return max(0, (len(paths) - 1) // self.GALLERY_PAGE_SIZE)

    def init_source_image_path(self, source_path):
        """Scan *source_path* and return the first page of content image paths."""
        self.source_paths = sorted(glob(os.path.join(source_path, '*')))
        self.max_source_index = self._last_page_index(self.source_paths)
        return self.source_paths[:self.GALLERY_PAGE_SIZE]

    def init_style_image_path(self, style_path):
        """Scan *style_path* and return the first page of style image paths."""
        self.style_paths = sorted(glob(os.path.join(style_path, '*')))
        self.max_style_index = self._last_page_index(self.style_paths)
        return self.style_paths[:self.GALLERY_PAGE_SIZE]

    def init_results_image_path(self):
        """Re-list saved samples (newest first) and return the first page."""
        if os.path.isdir(self.savedir_sample):
            results_paths = [
                os.path.join(self.savedir_sample, file)
                for file in os.listdir(self.savedir_sample)
            ]
        else:
            # Nothing has been generated yet, so the directory does not exist.
            results_paths = []
        self.results_paths = sorted(results_paths, key=os.path.getctime, reverse=True)
        self.max_results_index = self._last_page_index(self.results_paths)
        return self.results_paths[:self.GALLERY_PAGE_SIZE]

    def load_base_pipeline(self, model_path):
        """Load ``ZePoPipeline`` (with an LCM scheduler) from *model_path*.

        The pipeline is created with ``self.torch_dtype`` and moved to
        ``self.device`` so the CPU fallback chosen in ``__init__`` is honoured.

        Returns:
            An argument-less ``gr.Dropdown()`` for the bound output component.
        """
        if not model_path:
            gr.Info("Please select a pretrained model path.")
            return gr.Dropdown()

        time_start = datetime.now()
        self.scheduler = 'LCM'
        scheduler = LCMScheduler.from_pretrained(model_path, subfolder="scheduler")
        self.pipeline = ZePoPipeline.from_pretrained(
            model_path,
            scheduler=scheduler,
            torch_dtype=self.torch_dtype,
        ).to(self.device)
        # Only request xformers attention when actually running on CUDA.
        if is_xformers and self.device.type == "cuda":
            self.pipeline.enable_xformers_memory_efficient_attention()
        time_end = datetime.now()
        print(f'Load {model_path} successful in {time_end-time_start}')
        return gr.Dropdown()

    def refresh_stable_diffusion(self, model_path):
        """(Re)load the pipeline from *model_path*; return the first known model id."""
        self.load_base_pipeline(model_path)
        return self.stable_diffusion_list[0]

    def update_base_model(self, base_model_dropdown):
        """Swap the VAE, UNet and text encoder for those of a single-file checkpoint.

        Returns:
            ``None`` when no pipeline is loaded yet, otherwise ``gr.Dropdown()``.
        """
        if self.pipeline is None:
            gr.Info(f"Please select a pretrained model path.")
            return None

        # NOTE(review): ``personalized_model_list`` is created as a list in
        # __init__, yet it is indexed with the dropdown value, which is treated
        # as a file name (``.split('.')``) below. Confirm what populates it.
        base_model = self.personalized_model_list[base_model_dropdown]
        mid_model = StableDiffusionPipeline.from_single_file(base_model)
        self.pipeline.vae = mid_model.vae
        self.pipeline.unet = mid_model.unet
        self.pipeline.text_encoder = mid_model.text_encoder
        self.pipeline.to(self.device)
        self.personal_model_loaded = base_model_dropdown.split('.')[0]
        print(f'load {base_model_dropdown} model success!')
        return gr.Dropdown()

    def generate(self, source, style,
                 num_steps, co_feat_step, strength,
                 start_ac_layer, end_ac_layer,
                 sty_guidance, cfg_scale, mix_q_scale,
                 Scheduler, save_intermediate, seed, de_bug,
                 target_prompt, negative_prompt_textbox,
                 width_slider, height_slider,
                 tome_sx, tome_sy, tome_ratio, tome,
                 ):
        """Run one stylization pass and save the outputs under ``self.savedir``.

        Args:
            source, style: PIL images; both are resized to (width, height).
            num_steps: forwarded as ``num_inference_steps`` and to ``AttentionStyle``.
            co_feat_step: forwarded to the pipeline as ``fix_step_index``.
            strength: forwarded to the pipeline as ``strength``.
            start_ac_layer, end_ac_layer: forwarded to ``AttentionStyle``.
            sty_guidance: forwarded to ``AttentionStyle`` as ``style_guidance``.
            cfg_scale: forwarded to the pipeline as ``guidance_scale``.
            mix_q_scale: forwarded to ``AttentionStyle``.
            Scheduler: one of ``'DDIM'``, ``'LCM'``, ``'EulerDiscrete'``; any
                other value leaves the current scheduler untouched.
            save_intermediate, de_bug: forwarded to the pipeline.
            seed: ``'-1'`` or empty requests a fresh random seed; anything
                else must parse as an integer.
            target_prompt: repeated three times to form the prompt batch.
            negative_prompt_textbox: forwarded as ``negative_prompt``.
            width_slider, height_slider: output size in pixels.
            tome_sx, tome_sy, tome_ratio, tome: forwarded to
                ``ptp_utils.register_attention_control``.

        Returns:
            A list with the first three entries of ``results.images``.

        Raises:
            gr.Error: if an input image is missing or the seed is not an integer.
        """
        if source is None or style is None:
            raise gr.Error("Please provide both a source image and a style image.")

        os.makedirs(self.savedir, exist_ok=True)
        os.makedirs(self.savedir_sample, exist_ok=True)

        if self.pipeline is None:
            self.refresh_stable_diffusion(self.stable_diffusion_list[-1])
        model = self.pipeline

        # Built locally so the scheduler classes are only looked up at call time.
        scheduler_classes = {
            'DDIM': DDIMScheduler,
            'LCM': LCMScheduler,
            'EulerDiscrete': EulerDiscreteScheduler,
        }
        scheduler_cls = scheduler_classes.get(Scheduler)
        if scheduler_cls is not None:
            model.scheduler = scheduler_cls.from_config(model.scheduler.config)
            self.scheduler = Scheduler
            print(f"Successful adoption of {Scheduler} scheduler")

        # Accept ints as well as strings, and tolerate surrounding whitespace.
        seed_text = "" if seed is None else str(seed).strip()
        if seed_text not in ("", "-1"):
            try:
                torch.manual_seed(int(seed_text))
            except ValueError as err:
                raise gr.Error(f"Seed must be an integer, got {seed_text!r}.") from err
        else:
            torch.seed()
        seed = torch.initial_seed()
        print(f"Seed: {seed}")

        self.sample_count = len(os.listdir(self.savedir_sample))
        prompts = [target_prompt] * 3

        # PIL's resize() takes (width, height) and requires integers.
        width, height = int(width_slider), int(height_slider)
        source = source.resize((width, height))
        style = style.resize((width, height))

        with torch.no_grad():
            controller = AttentionStyle(
                num_steps,
                start_ac_layer,
                end_ac_layer,
                style_guidance=sty_guidance,
                mix_q_scale=mix_q_scale,
                de_bug=de_bug,
            )
            ptp_utils.register_attention_control(
                model,
                controller,
                tome,
                sx=tome_sx,
                sy=tome_sy,
                ratio=tome_ratio,
                de_bug=de_bug,
            )

            time_begin = datetime.now()
            results = model(
                prompt=prompts,
                negative_prompt=negative_prompt_textbox,
                image=source,
                style=style,
                num_inference_steps=num_steps,
                eta=0.0,
                guidance_scale=cfg_scale,
                strength=strength,
                save_intermediate=save_intermediate,
                fix_step_index=co_feat_step,
                de_bug=de_bug,
                callback=None,
            )

        generate_image = results.images

        # ``nsfw_content_detected`` may be None, so fall back to an empty list.
        nsfw_flags = results.nsfw_content_detected or []
        placeholder = None
        for idx, has_nsfw_concept in enumerate(nsfw_flags):
            if not has_nsfw_concept:
                continue
            if placeholder is None:
                # (width, height) matches the resize applied to the inputs, so
                # the placeholder array lines up with the generated images.
                resized = self.nsfw_image.convert("RGB").resize((width, height))
                placeholder = np.array(resized).astype(np.float32) / 255.0
            generate_image[idx] = placeholder

        time_end = datetime.now()
        print('generate one image with time {}'.format(time_end-time_begin))

        save_file_name = f"{self.sample_count}_step{num_steps}_sl{start_ac_layer}_el{end_ac_layer}_ST{strength}_CF{co_feat_step}_STG{sty_guidance}_MQ{mix_q_scale}_CFG{cfg_scale}_seed{seed}.jpg"
        save_file_path = os.path.join(self.savedir, save_file_name)
        # assumes results.images is channel-last (N, H, W, C), as the permute implies
        batch = torch.tensor(np.asarray(generate_image)).permute(0, 3, 1, 2)
        # Full row of outputs goes to savedir; index 2 onward goes to the sample dir.
        save_image(batch, save_file_path, nrow=3, padding=0)
        save_image(batch[2:], os.path.join(self.savedir_sample, save_file_name), nrow=3, padding=0)
        self.init_results_image_path()

        return [
            generate_image[0],
            generate_image[1],
            generate_image[2],
        ]
# Module-level singleton: the UI callbacks built in ui() read and mutate this
# one shared GlobalText instance (loaded pipeline, gallery path lists).
global_text = GlobalText()
def ui():
    """Build the ZePo Gradio Blocks app and wire up its callbacks.

    Returns:
        The assembled ``gr.Blocks`` instance; the caller launches it.
    """
    num_gallery_images = 12  # thumbnails shown per gallery page

    def gallery_page(paths, index):
        """Return the slice of *paths* for page *index*, clamped to valid pages."""
        last_page = max(0, (len(paths) - 1) // num_gallery_images)
        page = min(max(int(index), 0), last_page)
        return paths[page * num_gallery_images:(page + 1) * num_gallery_images]

    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            "\n"
            "# [ZePo: Zero-Shot Portrait Stylization with Faster Sampling](https://arxiv.org/abs/2408.05492)\n"
            "Jin Liu, Huaibo Huang, Jie Cao, Ran He<br>\n"
            "[Arxiv](https://arxiv.org/abs/2408.05492) | [Github](https://github.com/liujin112/ZePo)\n"
        )

        with gr.Column(variant="panel"):
            gr.Markdown(
                "\n"
                "### 1. Select a pretrained model.\n"
            )
            with gr.Row():
                stable_diffusion_dropdown = gr.Dropdown(
                    label="Pretrained Model Path",
                    choices=global_text.stable_diffusion_list,
                    interactive=True,
                    allow_custom_value=True,
                )
                stable_diffusion_dropdown.change(
                    fn=global_text.load_base_pipeline,
                    inputs=[stable_diffusion_dropdown],
                    outputs=[stable_diffusion_dropdown],
                )
                stable_diffusion_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")

                def update_stable_diffusion(stable_diffusion_dropdown):
                    """Force a reload of the selected model and keep it selected."""
                    # With nothing selected, fall back to the first known model.
                    model_path = stable_diffusion_dropdown or global_text.stable_diffusion_list[0]
                    global_text.refresh_stable_diffusion(model_path)
                    # The handler is bound to the dropdown as output; returning
                    # None (as before) would clear the selection.
                    return model_path

                stable_diffusion_refresh_button.click(
                    fn=update_stable_diffusion,
                    inputs=[stable_diffusion_dropdown],
                    outputs=[stable_diffusion_dropdown],
                )

        with gr.Column(variant="panel"):
            gr.Markdown(
                "\n"
                "### 2. Configs for ZePo.\n"
            )
            with gr.Tab("Configs"):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            source_image = gr.Image(label="Source Image", elem_id="img2maskimg", sources="upload", type="pil", image_mode="RGB", height=256)
                            style_image = gr.Image(label="Style Image", elem_id="img2maskimg", sources="upload", type="pil", image_mode="RGB", height=256)
                        generate_image = gr.Image(label="Image with PortraitDiff", type="pil", interactive=True, image_mode="RGB", height=512)
                        with gr.Row():
                            recons_content = gr.Image(label="reconstructed content", type="pil", image_mode="RGB", height=256)
                            recons_style = gr.Image(label="reconstructed style", type="pil", image_mode="RGB", height=256)

                        prompt_textbox = gr.Textbox(label="Prompt", value='head', lines=1)
                        negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=1)

                        with gr.Row(equal_height=False):
                            with gr.Column():
                                with gr.Tab("Resolution"):
                                    width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64)
                                    height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64)
                                    Scheduler = gr.Dropdown(
                                        ["DDIM", "LCM", "EulerDiscrete"],
                                        value="LCM",
                                        label="Scheduler", info="Select a Scheduler")

                                with gr.Tab("Content Gallery"):
                                    with gr.Row():
                                        source_path = gr.Textbox(value='./data/content', label="Source Path")
                                        refresh_source_list_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
                                    source_gallery_index = gr.Slider(label="Index", value=0, minimum=0, maximum=50, step=1)
                                    source_image_gallery = gr.Gallery(value=[], columns=4, label="Source Image List")
                                    refresh_source_list_button.click(
                                        fn=global_text.init_source_image_path,
                                        inputs=[source_path],
                                        outputs=[source_image_gallery],
                                    )

                                    def update_source_list(index):
                                        """Show the requested page of content images."""
                                        # getattr: the list may not exist before the first refresh.
                                        return gallery_page(getattr(global_text, "source_paths", []), index)

                                    source_gallery_index.change(
                                        fn=update_source_list,
                                        inputs=[source_gallery_index],
                                        outputs=[source_image_gallery],
                                    )

                                with gr.Tab("Style Gallery"):
                                    with gr.Row():
                                        style_path = gr.Textbox(value='./data/style', label="style Path")
                                        refresh_style_list_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
                                    style_gallery_index = gr.Slider(label="Index", value=0, minimum=0, maximum=50, step=1)
                                    style_image_gallery = gr.Gallery(value=[], columns=4, label="style Image List")
                                    refresh_style_list_button.click(
                                        fn=global_text.init_style_image_path,
                                        inputs=[style_path],
                                        outputs=[style_image_gallery],
                                    )

                                    def update_style_list(index):
                                        """Show the requested page of style images."""
                                        return gallery_page(getattr(global_text, "style_paths", []), index)

                                    style_gallery_index.change(
                                        fn=update_style_list,
                                        inputs=[style_gallery_index],
                                        outputs=[style_image_gallery],
                                    )

                        with gr.Row():
                            generate_button = gr.Button(value="Generate", variant='primary')

                        with gr.Tab('Base Configs'):
                            num_steps = gr.Slider(label="Total Steps", value=4, minimum=0, maximum=25, step=1)
                            strength = gr.Slider(label="Noisy Ratio", value=0.5, minimum=0, maximum=1, step=0.01, info="How much noise applied to souce image, 50% for better balance.")
                            co_feat_step = gr.Slider(label="Consistency Feature Extract Step", value=99, minimum=0, maximum=999, step=1)
                            with gr.Row():
                                start_ac_layer = gr.Slider(label="Start Layer of AC",
                                                           minimum=0,
                                                           maximum=16,
                                                           value=8,
                                                           step=1)
                                end_ac_layer = gr.Slider(label="End Layer of AC",
                                                         minimum=0,
                                                         maximum=16,
                                                         value=16,
                                                         step=1)
                            with gr.Row():
                                Style_Guidance = gr.Slider(label="Style Guidance Scale",
                                                           minimum=-1,
                                                           maximum=3,
                                                           value=1.2,
                                                           step=0.01,
                                                           )
                                mix_q_scale = gr.Slider(label='Query Mix Ratio',
                                                        minimum=0,
                                                        maximum=2,
                                                        step=0.05,
                                                        value=1.0,
                                                        )
                            cfg_scale_slider = gr.Slider(label="CFG Scale", value=2.5, minimum=0, maximum=20, info="Classifier-free guidance scale.")
                            with gr.Row():
                                save_intermediate = gr.Checkbox(label="save_intermediate", value=False)
                                de_bug = gr.Checkbox(value=False, label='DeBug')

                        with gr.Tab('ToMe'):
                            with gr.Row():
                                tome = gr.Checkbox(label="Token Merge", value=True)
                                tome_ratio = gr.Slider(label='ratio: ',
                                                       minimum=0,
                                                       maximum=1,
                                                       step=0.1,
                                                       value=0.5)
                            with gr.Row():
                                tome_sx = gr.Slider(label='sx:',
                                                    minimum=0,
                                                    maximum=64,
                                                    step=2,
                                                    value=2)
                                tome_sy = gr.Slider(label='sy:',
                                                    minimum=0,
                                                    maximum=64,
                                                    step=2,
                                                    value=2)

                        with gr.Row():
                            seed_textbox = gr.Textbox(label="Seed", value=-1)
                            seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                        # 10**16 rather than 1e16: randint() needs integer bounds.
                        seed_button.click(fn=lambda: random.randint(1, 10**16), inputs=[], outputs=[seed_textbox])

            # Order must match the positional parameters of GlobalText.generate.
            inputs = [
                source_image, style_image,
                num_steps, co_feat_step, strength,
                start_ac_layer, end_ac_layer,
                Style_Guidance, cfg_scale_slider, mix_q_scale,
                Scheduler, save_intermediate, seed_textbox, de_bug,
                prompt_textbox, negative_prompt_textbox,
                width_slider, height_slider,
                tome_sx, tome_sy, tome_ratio, tome,
            ]

            generate_button.click(
                fn=global_text.generate,
                inputs=inputs,
                outputs=[recons_style, recons_content, generate_image],
            )

            # The original passed ["Example 1"] as the third positional
            # argument, which gr.Examples treats as ``outputs``; strings are
            # not components, so it is omitted here.
            gr.Examples(
                examples=[
                    ["./data/content/27032.jpg", "./data/style/27.jpg", 4, 0.8, 0.5, 8, 8427921159605868845],
                    ["./data/content/29812.jpg", "./data/style/47.jpg", 4, 0.5, 0.65, 11, 8119359809263726691],
                ],
                inputs=[source_image, style_image, num_steps, strength, mix_q_scale, start_ac_layer, seed_textbox],
            )

    return demo
# Script entry point: build the Blocks app and start serving it.
if __name__ == "__main__":
    demo = ui()
    # NOTE(review): show_error=True presumably surfaces callback errors in the
    # browser; confirm against the installed Gradio version's launch() docs.
    demo.launch(show_error=True)