|
import sys
|
|
sys.path.append('./')
|
|
|
|
from adaface.adaface_wrapper import AdaFaceWrapper
|
|
import torch
|
|
from insightface.app import FaceAnalysis
|
|
from PIL import Image
|
|
import numpy as np
|
|
import random
|
|
|
|
import gradio as gr
|
|
import spaces
|
|
import argparse
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument('--adaface_ckpt_path', type=str,
|
|
default='models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt')
|
|
parser.add_argument('--gpu', type=int, default=None)
|
|
parser.add_argument('--ip', type=str, default="0.0.0.0")
|
|
args = parser.parse_args()
|
|
|
|
|
|
MAX_SEED = np.iinfo(np.int32).max
|
|
if torch.cuda.is_available():
|
|
device = "cuda" if args.gpu is None else f"cuda:{args.gpu}"
|
|
else:
|
|
device = "cpu"
|
|
dtype = torch.float16
|
|
|
|
|
|
adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path="models/sar/sar.safetensors",
|
|
adaface_ckpt_path=args.adaface_ckpt_path, device=device)
|
|
|
|
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
|
if randomize_seed:
|
|
seed = random.randint(0, MAX_SEED)
|
|
return seed
|
|
|
|
def swap_to_gallery(images):
|
|
|
|
|
|
|
|
return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(value=images, visible=False)
|
|
|
|
def remove_back_to_files():
|
|
|
|
|
|
|
|
return gr.update(visible=False), gr.update(visible=False), gr.update(value=None, visible=True)
|
|
|
|
def update_out_gallery(images):
|
|
|
|
return gr.update(height=600)
|
|
|
|
@spaces.GPU
|
|
def generate_image(image_paths, guidance_scale, adaface_id_cfg_scale,
|
|
num_images, prompt, negative_prompt, seed, progress=gr.Progress(track_tqdm=True)):
|
|
|
|
if image_paths is None or len(image_paths) == 0:
|
|
raise gr.Error(f"Cannot find any input face image! Please upload a face image.")
|
|
|
|
if prompt is None:
|
|
prompt = ""
|
|
|
|
adaface_subj_embs = \
|
|
adaface.generate_adaface_embeddings(image_folder=None, image_paths=image_paths,
|
|
out_id_embs_scale=adaface_id_cfg_scale, update_text_encoder=True)
|
|
|
|
if adaface_subj_embs is None:
|
|
raise gr.Error(f"Failed to detect any faces! Please try with other images")
|
|
|
|
generator = torch.Generator(device=device).manual_seed(seed)
|
|
print(f"Manual seed: {seed}")
|
|
|
|
noise = torch.randn(num_images, 3, 512, 512, device=device, generator=generator)
|
|
|
|
|
|
samples = adaface(noise, prompt, negative_prompt, guidance_scale=guidance_scale, out_image_count=num_images, generator=generator, verbose=True)
|
|
return samples
|
|
|
|
|
|
title = r"""
|
|
<h1>AdaFace: A Versatile Face Encoder for Zero-Shot Diffusion Model Personalization</h1>
|
|
"""
|
|
|
|
description = r"""
|
|
<b>Official demo</b> for our NeurIPS 2024 submission <b>AdaFace: A Versatile Face Encoder for Zero-Shot Diffusion Model Personalization</b>.<br>
|
|
|
|
❗️**Tips**❗️
|
|
1. Upload one or more images of a person. If multiple faces are detected, we use the largest one.
|
|
2. Increase <b>AdaFace CFG Scale</b> (preferred) or <b>Guidance scale</b> and/or to highlight fine facial features.
|
|
3. AdaFace Text-to-Video: <a href="https://huggingface.co/spaces/adaface-neurips/adaface-animate" style="display: inline-flex; align-items: center;">
|
|
AdaFace-Animate
|
|
<img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow" alt="Hugging Face Spaces" style="margin-left: 5px;">
|
|
</a>
|
|
|
|
**TODO**
|
|
- ControlNet integration.
|
|
"""
|
|
|
|
css = '''
|
|
.gradio-container {width: 85% !important}
|
|
'''
|
|
with gr.Blocks(css=css) as demo:
|
|
|
|
|
|
gr.Markdown(title)
|
|
gr.Markdown(description)
|
|
|
|
with gr.Row():
|
|
with gr.Column():
|
|
|
|
|
|
|
|
img_files = gr.File(
|
|
label="Drag / Select 1 or more photos of a person's face",
|
|
file_types=["image"],
|
|
file_count="multiple"
|
|
)
|
|
uploaded_files_gallery = gr.Gallery(label="Subject images", visible=False, columns=3, rows=1, height=300)
|
|
with gr.Column(visible=False) as clear_button_column:
|
|
remove_and_reupload = gr.ClearButton(value="Remove and upload subject images", components=img_files, size="sm")
|
|
|
|
prompt = gr.Dropdown(label="Prompt",
|
|
info="Try something like 'man/woman walking on the beach'. If the face is not in focus, try adding 'face portrait of' at the beginning.",
|
|
value=None,
|
|
allow_custom_value=True,
|
|
filterable=False,
|
|
choices=[
|
|
"woman ((best quality)), ((masterpiece)), ((realistic)), long highlighted hair, futuristic silver armor suit, confident stance, high-resolution, living room, smiling, head tilted, perfect smooth skin",
|
|
"woman walking on the beach, sunset, orange sky",
|
|
"woman in a white apron and chef hat, garnishing a gourmet dish, full body view, long shot",
|
|
"woman dancing pose among folks in a park, waving hands",
|
|
"woman in iron man costume flying pose, the sky ablaze with hues of orange and purple, full body view, long shot",
|
|
"woman jedi wielding a lightsaber, star wars, full body view, eye level shot",
|
|
"woman playing guitar on a boat, ocean waves",
|
|
"woman with a passion for reading, curled up with a book in a cozy nook near a window",
|
|
"woman running pose in a park, eye level shot",
|
|
"woman in superman costume flying pose, the sky ablaze with hues of orange and purple, full body view, long shot"
|
|
])
|
|
|
|
submit = gr.Button("Submit", variant="primary")
|
|
|
|
negative_prompt = gr.Textbox(
|
|
label="Negative Prompt",
|
|
value="flaws in the eyes, flaws in the face, lowres, non-HDRi, low quality, worst quality, artifacts, noise, text, watermark, glitch, mutated, ugly, disfigured, hands, partially rendered objects, partially rendered eyes, deformed eyeballs, cross-eyed, blurry, mutation, duplicate, out of frame, cropped, mutilated, bad anatomy, deformed, bad proportions, nude, naked, nsfw, topless, bare breasts",
|
|
)
|
|
|
|
adaface_id_cfg_scale = gr.Slider(
|
|
label="AdaFace CFG Scale",
|
|
info="The CFG scale of the AdaFace ID embeddings (influencing fine facial features)",
|
|
minimum=0.5,
|
|
maximum=8.0,
|
|
step=0.5,
|
|
value=4.0,
|
|
)
|
|
|
|
guidance_scale = gr.Slider(
|
|
label="Guidance scale",
|
|
minimum=0.5,
|
|
maximum=8.0,
|
|
step=0.5,
|
|
value=4.0,
|
|
)
|
|
|
|
num_images = gr.Slider(
|
|
label="Number of output images",
|
|
minimum=1,
|
|
maximum=6,
|
|
step=1,
|
|
value=4,
|
|
)
|
|
seed = gr.Slider(
|
|
label="Seed",
|
|
minimum=0,
|
|
maximum=MAX_SEED,
|
|
step=1,
|
|
value=0,
|
|
)
|
|
randomize_seed = gr.Checkbox(label="Randomize seed", value=True, info="Uncheck for reproducible results")
|
|
|
|
with gr.Column():
|
|
out_gallery = gr.Gallery(label="Generated Images", columns=2, rows=2, height=600)
|
|
|
|
img_files.upload(fn=swap_to_gallery, inputs=img_files, outputs=[uploaded_files_gallery, clear_button_column, img_files])
|
|
remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files_gallery, clear_button_column, img_files])
|
|
|
|
submit.click(
|
|
fn=randomize_seed_fn,
|
|
inputs=[seed, randomize_seed],
|
|
outputs=seed,
|
|
queue=False,
|
|
api_name=False,
|
|
).then(
|
|
fn=generate_image,
|
|
inputs=[img_files, guidance_scale, adaface_id_cfg_scale, num_images, prompt, negative_prompt, seed],
|
|
outputs=[out_gallery]
|
|
).then(
|
|
fn=update_out_gallery,
|
|
inputs=[out_gallery],
|
|
outputs=[out_gallery]
|
|
)
|
|
|
|
demo.launch(share=True, server_name=args.ip, ssl_verify=False) |