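"""Gradio demo for PIA: your Personalized Image Animator via plug-and-play
modules in text-to-image models (https://github.com/open-mmlab/PIA), packaged
for OpenXLab.

On start-up the script logs in to OpenXLab, downloads the PIA checkpoint, the
Stable Diffusion v1-5 components, the IP-Adapter and the style-specific
DreamBooth/VAE weights, then builds one image-to-video pipeline per style and
launches the web UI.

Typical launch (script name assumed; flags are defined by the parser below):

    python app.py --server-name 0.0.0.0 --port 7860
    python app.py --local-debug   # skip downloads, use checkpoints in ./models
"""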
import json
import os
import os.path as osp
import random
from argparse import ArgumentParser
from datetime import datetime

import gradio as gr
import numpy as np
import openxlab
import torch
from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
from omegaconf import OmegaConf
from openxlab.model import download
from PIL import Image

from animatediff.pipelines import I2VPipeline
from animatediff.utils.util import RANGE_LIST, save_videos_grid
sample_idx = 0
scheduler_dict = {
    "DDIM": DDIMScheduler,
    "Euler": EulerDiscreteScheduler,
    "PNDM": PNDMScheduler,
}

css = """
.toolbutton {
    margin: 0em 0em 0em 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""
parser = ArgumentParser()
parser.add_argument('--config', type=str, default='example/config/base.yaml')
parser.add_argument('--server-name', type=str, default='0.0.0.0')
parser.add_argument('--port', type=int, default=7860)
parser.add_argument('--share', action='store_true')
parser.add_argument('--local-debug', action='store_true')
parser.add_argument('--save-path', default='samples')
args = parser.parse_args()

LOCAL_DEBUG = args.local_debug

BASE_CONFIG = 'example/config/base.yaml'
STYLE_CONFIG_LIST = {
    'anime': './example/openxlab/2-animation.yaml',
}
# download models
PIA_PATH = './models/PIA'
VAE_PATH = './models/VAE'
DreamBooth_LoRA_PATH = './models/DreamBooth_LoRA'

if not LOCAL_DEBUG:
    CACHE_PATH = '/home/xlab-app-center/.cache/model'

    PIA_PATH = osp.join(CACHE_PATH, 'PIA')
    VAE_PATH = osp.join(CACHE_PATH, 'VAE')
    DreamBooth_LoRA_PATH = osp.join(CACHE_PATH, 'DreamBooth_LoRA')
    STABLE_DIFFUSION_PATH = osp.join(CACHE_PATH, 'StableDiffusion')
    IP_ADAPTER_PATH = osp.join(CACHE_PATH, 'IP_Adapter')

    os.makedirs(PIA_PATH, exist_ok=True)
    os.makedirs(VAE_PATH, exist_ok=True)
    os.makedirs(DreamBooth_LoRA_PATH, exist_ok=True)
    os.makedirs(STABLE_DIFFUSION_PATH, exist_ok=True)

    openxlab.login(os.environ['OPENXLAB_AK'], os.environ['OPENXLAB_SK'])
    download(model_repo='zhangyiming/PIA-pruned', model_name='PIA', output=PIA_PATH)
    download(model_repo='zhangyiming/Counterfeit-V3.0',
             model_name='Counterfeit-V3.0_fp32_pruned', output=DreamBooth_LoRA_PATH)
    download(model_repo='zhangyiming/kl-f8-anime2_VAE',
             model_name='kl-f8-anime2', output=VAE_PATH)

    # ip_adapter
    download(model_repo='zhangyiming/IP-Adapter',
             model_name='clip_encoder', output=osp.join(IP_ADAPTER_PATH, 'image_encoder'))
    download(model_repo='zhangyiming/IP-Adapter',
             model_name='config', output=osp.join(IP_ADAPTER_PATH, 'image_encoder'))
    download(model_repo='zhangyiming/IP-Adapter',
             model_name='ip_adapter_sd15', output=IP_ADAPTER_PATH)

    # unet
    download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_Unet',
             model_name='unet', output=osp.join(STABLE_DIFFUSION_PATH, 'unet'))
    download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_Unet',
             model_name='config', output=osp.join(STABLE_DIFFUSION_PATH, 'unet'))

    # vae
    download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_VAE',
             model_name='vae', output=osp.join(STABLE_DIFFUSION_PATH, 'vae'))
    download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_VAE',
             model_name='config', output=osp.join(STABLE_DIFFUSION_PATH, 'vae'))

    # text encoder
    download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_TextEncod',
             model_name='text_encoder', output=osp.join(STABLE_DIFFUSION_PATH, 'text_encoder'))
    download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_TextEncod',
             model_name='config', output=osp.join(STABLE_DIFFUSION_PATH, 'text_encoder'))

    # tokenizer
    download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_Tokenizer',
             model_name='merge', output=osp.join(STABLE_DIFFUSION_PATH, 'tokenizer'))
    download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_Tokenizer',
             model_name='special_tokens_map', output=osp.join(STABLE_DIFFUSION_PATH, 'tokenizer'))
    download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_Tokenizer',
             model_name='tokenizer_config', output=osp.join(STABLE_DIFFUSION_PATH, 'tokenizer'))
    download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_Tokenizer',
             model_name='vocab', output=osp.join(STABLE_DIFFUSION_PATH, 'tokenizer'))
    # scheduler
    # use a distinct name so this config dict does not shadow the global
    # `scheduler_dict` that backs the sampler dropdown in the UI
    scheduler_config = {
        "_class_name": "PNDMScheduler",
        "_diffusers_version": "0.6.0",
        "beta_end": 0.012,
        "beta_schedule": "scaled_linear",
        "beta_start": 0.00085,
        "num_train_timesteps": 1000,
        "set_alpha_to_one": False,
        "skip_prk_steps": True,
        "steps_offset": 1,
        "trained_betas": None,
        "clip_sample": False
    }
    os.makedirs(osp.join(STABLE_DIFFUSION_PATH, 'scheduler'), exist_ok=True)
    with open(osp.join(STABLE_DIFFUSION_PATH, 'scheduler', 'scheduler_config.json'), 'w') as file:
        json.dump(scheduler_config, file)
    # model index
    model_index_dict = {
        "_class_name": "StableDiffusionPipeline",
        "_diffusers_version": "0.6.0",
        "feature_extractor": [
            "transformers",
            "CLIPImageProcessor"
        ],
        "safety_checker": [
            "stable_diffusion",
            "StableDiffusionSafetyChecker"
        ],
        "scheduler": [
            "diffusers",
            "PNDMScheduler"
        ],
        "text_encoder": [
            "transformers",
            "CLIPTextModel"
        ],
        "tokenizer": [
            "transformers",
            "CLIPTokenizer"
        ],
        "unet": [
            "diffusers",
            "UNet2DConditionModel"
        ],
        "vae": [
            "diffusers",
            "AutoencoderKL"
        ]
    }
    with open(osp.join(STABLE_DIFFUSION_PATH, 'model_index.json'), 'w') as file:
        json.dump(model_index_dict, file)
else:
    PIA_PATH = './models/PIA'
    VAE_PATH = './models/VAE'
    DreamBooth_LoRA_PATH = './models/DreamBooth_LoRA'
    STABLE_DIFFUSION_PATH = './models/StableDiffusion/sd15'
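# Expected checkpoint layout (as used by load_model_list below):
#   PIA_PATH/pia.ckpt                          -- PIA motion module checkpoint
#   STABLE_DIFFUSION_PATH/{unet,vae,text_encoder,tokenizer,scheduler,model_index.json}
#                                              -- Stable Diffusion v1-5 components
#   DreamBooth_LoRA_PATH/<file from style cfg> -- personalized base / LoRA weights
#   VAE_PATH/<file from style cfg>             -- optional replacement VAE
# When running with --local-debug these files must already exist under ./models.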
def preprocess_img(img_np, max_size: int = 512):
    """Resize the input so its longer edge is at most `max_size` and both
    sides are multiples of 8, as required by the Stable Diffusion latents."""
    ori_image = Image.fromarray(img_np).convert('RGB')
    width, height = ori_image.size

    # downscale if the longer edge exceeds max_size, keeping the aspect ratio
    long_edge = max(width, height)
    if long_edge > max_size:
        scale_factor = max_size / long_edge
    else:
        scale_factor = 1
    width = int(width * scale_factor)
    height = int(height * scale_factor)
    ori_image = ori_image.resize((width, height))

    # round both sides down to the nearest multiple of 8
    if (width % 8 != 0) or (height % 8 != 0):
        in_width = (width // 8) * 8
        in_height = (height // 8) * 8
    else:
        in_width = width
        in_height = height

    in_image = ori_image.resize((in_width, in_height))
    in_image_np = np.array(in_image)
    return in_image_np, in_height, in_width
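# AnimateController builds one I2VPipeline per entry in STYLE_CONFIG_LIST and
# drives generation for the Gradio callbacks. Each style YAML may provide the
# keys read below: 'dreambooth', 'lora', 'lora_alpha', 'vae', 'n_prompt' and
# 'real_ip_adapter_scale'.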
class AnimateController:
    def __init__(self):
        # config dirs
        self.basedir = os.getcwd()
        self.savedir = os.path.join(
            self.basedir, args.save_path, datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
        self.savedir_sample = os.path.join(self.savedir, "sample")
        os.makedirs(self.savedir, exist_ok=True)

        self.inference_config = OmegaConf.load(args.config)
        self.style_configs = {k: OmegaConf.load(
            v) for k, v in STYLE_CONFIG_LIST.items()}

        self.pipeline_dict = self.load_model_list()

    def load_model_list(self):
        pipeline_dict = dict()
        for style, cfg in self.style_configs.items():
            dreambooth_path = cfg.get('dreambooth', 'none')
            if dreambooth_path and dreambooth_path.upper() != 'NONE':
                dreambooth_path = osp.join(
                    DreamBooth_LoRA_PATH, dreambooth_path)
            lora_path = cfg.get('lora', None)
            if lora_path is not None:
                lora_path = osp.join(DreamBooth_LoRA_PATH, lora_path)
            lora_alpha = cfg.get('lora_alpha', 0.0)
            vae_path = cfg.get('vae', None)
            if vae_path is not None:
                vae_path = osp.join(VAE_PATH, vae_path)

            pipeline_dict[style] = I2VPipeline.build_pipeline(
                self.inference_config,
                STABLE_DIFFUSION_PATH,
                unet_path=osp.join(PIA_PATH, 'pia.ckpt'),
                dreambooth_path=dreambooth_path,
                lora_path=lora_path,
                lora_alpha=lora_alpha,
                vae_path=vae_path,
                ip_adapter_path='h94/IP-Adapter',
                ip_adapter_scale=0.1)
        return pipeline_dict

    def fetch_default_n_prompt(self, style: str):
        cfg = self.style_configs[style]
        n_prompt = cfg.get('n_prompt', '')
        ip_adapter_scale = cfg.get('real_ip_adapter_scale', 0)
        gr.Info('Set default negative prompt and ip_adapter_scale.')
        print('Set default negative prompt and ip_adapter_scale.')
        return n_prompt, ip_adapter_scale
    def animate(
        self,
        init_img,
        motion_scale,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        cfg_scale_slider,
        seed_textbox,
        ip_adapter_scale,
        style,
        progress=gr.Progress(),
    ):
        # use the user-provided seed if given, otherwise draw a fresh one
        if seed_textbox != -1 and seed_textbox != "":
            torch.manual_seed(int(seed_textbox))
        else:
            torch.seed()
        seed = torch.initial_seed()

        pipeline = self.pipeline_dict[style]
        init_img, h, w = preprocess_img(init_img)
        sample = pipeline(
            image=init_img,
            prompt=prompt_textbox,
            negative_prompt=negative_prompt_textbox,
            num_inference_steps=sample_step_slider,
            guidance_scale=cfg_scale_slider,
            width=w,
            height=h,
            video_length=16,
            mask_sim_template_idx=motion_scale - 1,
            ip_adapter_scale=ip_adapter_scale,
            progress_fn=progress,
        ).videos

        # save the generated clip and append the generation settings to a log
        save_sample_path = os.path.join(
            self.savedir_sample, f"{sample_idx}.mp4")
        save_videos_grid(sample, save_sample_path)

        sample_config = {
            "prompt": prompt_textbox,
            "n_prompt": negative_prompt_textbox,
            "sampler": sampler_dropdown,
            "num_inference_steps": sample_step_slider,
            "guidance_scale": cfg_scale_slider,
            "width": w,
            "height": h,
            "seed": seed,
            "motion": motion_scale,
        }
        json_str = json.dumps(sample_config, indent=4)
        with open(os.path.join(self.savedir, "logs.json"), "a") as f:
            f.write(json_str)
            f.write("\n\n")

        return save_sample_path
controller = AnimateController()
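# Gradio UI: input image and style dropdown, prompt row with a seasonal "gift"
# button, motion-scale and IP-Adapter sliders, advanced sampling options, and
# a Generate button wired to AnimateController.animate, whose output path is
# shown in the video player.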
def ui():
    with gr.Blocks(css=css) as demo:
        gr.HTML(
            "<div align='center'><font size='7'> <img src=\"file/pia.png\" style=\"height: 72px;\"/ > Your Personalized Image Animator</font></div>"
            "<div align='center'><font size='7'>via Plug-and-Play Modules in Text-to-Image Models </font></div>"
        )
        with gr.Row():
            gr.Markdown(
                "<div align='center'><font size='5'><a href='https://pi-animator.github.io/'>Project Page</a>  "  # noqa
                "<a href='https://arxiv.org/abs/2312.13964/'>Paper</a>  "
                "<a href='https://github.com/open-mmlab/PIA'>Code</a>  "  # noqa
                # "Try More Style: <a href='https://openxlab.org.cn/apps/detail/zhangyiming/PiaPia'>Click Here!</a> </font></div>"  # noqa
                "Try More Style: <a href='https://openxlab.org.cn/apps/detail/zhangyiming/PiaPia'>Click here! </a></font></div>"  # noqa
            )

        with gr.Row(equal_height=False):
            with gr.Column():
                with gr.Row():
                    init_img = gr.Image(label='Input Image')
                    style_dropdown = gr.Dropdown(label='Style', choices=list(
                        STYLE_CONFIG_LIST.keys()), value=list(STYLE_CONFIG_LIST.keys())[0])

                with gr.Row():
                    prompt_textbox = gr.Textbox(label="Prompt", lines=1)
                    gift_button = gr.Button(
                        value='🎁', elem_classes='toolbutton'
                    )

                    def append_gift(prompt):
                        rand = random.randint(0, 2)
                        if rand == 1:
                            prompt = prompt + 'wearing santa hats'
                        elif rand == 2:
                            prompt = prompt + 'lift a Christmas gift'
                        else:
                            prompt = prompt + 'in Christmas suit, lift a Christmas gift'
                        gr.Info('Merry Christmas! Add magic to your prompt!')
                        return prompt

                    gift_button.click(
                        fn=append_gift,
                        inputs=[prompt_textbox],
                        outputs=[prompt_textbox],
                    )

                motion_scale_slider = gr.Slider(
                    label='Motion Scale (Larger value means larger motion but less identity consistency)',
                    value=2, step=1, minimum=1, maximum=len(RANGE_LIST))
                ip_adapter_scale = gr.Slider(
                    label='IP-Adapter Scale',
                    value=controller.fetch_default_n_prompt(
                        list(STYLE_CONFIG_LIST.keys())[0])[1],
                    minimum=0, maximum=1)

                with gr.Accordion('Advanced Options', open=False):
                    negative_prompt_textbox = gr.Textbox(
                        value=controller.fetch_default_n_prompt(
                            list(STYLE_CONFIG_LIST.keys())[0])[0],
                        label="Negative prompt", lines=2)

                    with gr.Row():
                        sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(
                            scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
                        sample_step_slider = gr.Slider(
                            label="Sampling steps", value=20, minimum=10, maximum=100, step=1)

                    cfg_scale_slider = gr.Slider(
                        label="CFG Scale", value=7.5, minimum=0, maximum=20)

                    with gr.Row():
                        seed_textbox = gr.Textbox(label="Seed", value=-1)
                        seed_button = gr.Button(
                            value="\U0001F3B2", elem_classes="toolbutton")
                        seed_button.click(
                            fn=lambda: random.randint(1, int(1e8)),
                            outputs=[seed_textbox],
                            queue=False
                        )

                generate_button = gr.Button(
                    value="Generate", variant='primary')

            result_video = gr.Video(
                label="Generated Animation", interactive=False)

        style_dropdown.change(fn=controller.fetch_default_n_prompt,
                              inputs=[style_dropdown],
                              outputs=[negative_prompt_textbox, ip_adapter_scale],
                              queue=False)

        generate_button.click(
            fn=controller.animate,
            inputs=[
                init_img,
                motion_scale_slider,
                prompt_textbox,
                negative_prompt_textbox,
                sampler_dropdown,
                sample_step_slider,
                cfg_scale_slider,
                seed_textbox,
                ip_adapter_scale,
                style_dropdown,
            ],
            outputs=[result_video]
        )

    return demo
if __name__ == "__main__":
    demo = ui()
    demo.queue(max_size=10)
    demo.launch(server_name=args.server_name,
                server_port=args.port, share=args.share,
                max_threads=10,
                allowed_paths=['pia.png'])