|
import gradio as gr |
|
import torch |
|
import numpy as np |
|
import PIL |
|
import base64 |
|
from io import BytesIO |
|
from PIL import Image |
|
|
|
import retinaface |
|
|
|
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel |
|
from diffusers import UniPCMultistepScheduler |
|
|
|
from spiga.inference.config import ModelConfig |
|
from spiga.inference.framework import SPIGAFramework |
|
import spiga.demo.analyze.track.retinasort.config as cfg |
|
|
|
import matplotlib.pyplot as plt |
|
from matplotlib.path import Path |
|
import matplotlib.patches as patches |
|
|
|
|
|
# ---------------------------------------------------------------------------
# Model setup. Everything below runs once, at import time.
# ---------------------------------------------------------------------------

# Face detector: RetinaFace, configured from SPIGA's retinasort settings.
config = cfg.cfg_retinasort
_retina_cfg = config['retina']
_detector_device = 'cuda' if torch.cuda.is_available() else 'cpu'
face_detector = retinaface.RetinaFaceDetector(
    model=_retina_cfg['model_name'],
    device=_detector_device,
    extra_features=_retina_cfg['extra_features'],
    cfg_postreat=_retina_cfg['postreat'],
)

# Landmark extractor built from the "300wpublic" SPIGA model configuration.
spiga_extractor = SPIGAFramework(ModelConfig("300wpublic"))

# ControlNet weights plus the Stable Diffusion pipeline they plug into,
# both loaded in half precision with the safety checker disabled.
uncanny_controlnet = ControlNetModel.from_pretrained(
    "multimodalart/uncannyfaces_25K",
    torch_dtype=torch.float16,
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base",
    controlnet=uncanny_controlnet,
    safety_checker=None,
    torch_dtype=torch.float16,
)
# Swap the default scheduler for UniPC, reusing the existing scheduler config.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# NOTE(review): unlike the face detector above, the pipeline is moved to CUDA
# unconditionally, so this module cannot be imported on a CPU-only machine.
pipe = pipe.to("cuda")

# One generator, seeded once and shared by every call to `pipe`.
generator = torch.manual_seed(0)
|
|
|
# Markup for the webcam mode: a custom <face-canvas> element with id
# "canvas-root". `toggle` places it into the HTML component when the user
# selects "webcam"; `get_js_image` later looks it up by that id.
canvas_html = "<face-canvas id='canvas-root' style='display:flex;max-width: 500px;margin: 0 auto;'></face-canvas>"

# JavaScript executed in the browser on page load (wired up via `blocks.load`).
# It downloads face-canvas.js as text and injects it into <head> as a module
# script served from a Blob URL.
# NOTE(review): presumably that script defines the <face-canvas> element used
# above -- confirm against the fetched file.
load_js = """
async () => {
  const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/face-canvas.js"
  fetch(url)
    .then(res => res.text())
    .then(text => {
      const script = document.createElement('script');
      script.type = "module"
      script.src = URL.createObjectURL(new Blob([text], { type: 'application/javascript' }));
      document.head.appendChild(script);
    });
}
"""

# JavaScript executed in the browser when the "Generate" button is clicked,
# before `generate_images` runs. The first three inputs are passed through
# unchanged; the fourth is replaced by the `_data` property of the
# #canvas-root element, or null when that element is not in the page.
get_js_image = """
async (image_in_img, prompt, image_file_live_opt, live_conditioning) => {
  const canvasEl = document.getElementById("canvas-root");
  const imageData = canvasEl? canvasEl._data : null;
  return [image_in_img, prompt, image_file_live_opt, imageData]
}
"""
|
|
|
|
|
def get_bounding_box(image):
    """Detect a face in *image* and return its box as ``[x, y, width, height]``.

    Args:
        image: Pixel array accepted by ``PIL.Image.fromarray``.

    Returns:
        ``[x1, y1, x2 - x1, y2 - y1]`` computed from the first four values of
        the first entry in the detector's ``'bbox'`` output.

    Raises:
        Exception: "No face detected" when the detector returns ``None`` or
            an empty ``'bbox'`` list.
    """
    pil_image = Image.fromarray(image)
    # PIL reports size as (width, height); the detector is given
    # (height, width).
    width, height = pil_image.size
    face_detector.set_input_shape(height, width)
    features = face_detector.inference(pil_image)

    # Either condition alone means there is no usable box. The previous `and`
    # dereferenced `features` when it was None and let an empty list fall
    # through to an IndexError below.
    if features is None or len(features['bbox']) == 0:
        raise Exception("No face detected")

    x1, y1, x2, y2 = features['bbox'][0][:4]
    return [x1, y1, x2 - x1, y2 - y1]
|
|
|
|
|
def get_landmarks(image, bbox):
    """Return the landmarks SPIGA finds for the single face inside *bbox*.

    Args:
        image: Image passed straight through to ``spiga_extractor.inference``.
        bbox: One bounding box; it is wrapped in a one-element list because
            the extractor is called with a list of boxes.

    Returns:
        The first entry of the extractor's ``'landmarks'`` output.
    """
    spiga_output = spiga_extractor.inference(image, [bbox])
    per_face_landmarks = spiga_output['landmarks']
    return per_face_landmarks[0]
|
|
|
|
|
def get_patch(landmarks, color='lime', closed=False):
    """Build a matplotlib ``PathPatch`` that connects *landmarks* in order.

    Args:
        landmarks: Sequence of points forming the contour. It is not modified.
        color: Edge colour; also the fill colour when *closed* is true.
        closed: When true, the first point is appended again, the path ends
            with ``CLOSEPOLY`` and the shape is filled with *color*. When
            false, the polyline stays open with a fully transparent fill.

    Returns:
        A ``matplotlib.patches.PathPatch`` with line width 4.
    """
    # Copy first: closing the contour appends a point, and that must neither
    # leak into the caller's sequence nor require it to have ``append``.
    contour = list(landmarks)
    ops = [Path.MOVETO] + [Path.LINETO] * (len(contour) - 1)
    facecolor = (0, 0, 0, 0)
    if closed:
        contour.append(contour[0])
        ops.append(Path.CLOSEPOLY)
        facecolor = color
    path = Path(contour, ops)
    return patches.PathPatch(path, facecolor=facecolor, edgecolor=color, lw=4)
|
|
|
|
|
def conditioning_from_landmarks(landmarks, size=512):
    """Render facial landmarks as a colour-coded drawing on a black square.

    Args:
        landmarks: Sliceable sequence of points; indices 0-67 are used.
            NOTE(review): presumably the 68-point layout produced by the
            "300wpublic" SPIGA model -- confirm.
        size: Side length, in pixels, of the square output image.

    Returns:
        A ``PIL.Image`` with three colour channels, *size* x *size* pixels.

    Raises:
        AssertionError: If the rendered canvas is not *size* x *size*.
    """
    # (slice of landmarks, colour, closed/filled) for each facial part.
    parts = [
        (slice(0, 17), 'lime', False),      # face outline
        (slice(17, 22), 'yellow', False),   # left eyebrow
        (slice(22, 27), 'yellow', False),   # right eyebrow
        (slice(27, 31), 'orange', False),   # nose, vertical part
        (slice(31, 36), 'orange', False),   # nose, horizontal part
        (slice(36, 42), 'magenta', True),   # left eye
        (slice(42, 48), 'magenta', True),   # right eye
        (slice(48, 60), 'cyan', True),      # outer lips
        (slice(60, 68), 'blue', True),      # inner lips
    ]

    dpi = 72
    fig, ax = plt.subplots(
        1, figsize=[size / dpi, size / dpi], tight_layout={'pad': 0})
    try:
        fig.set_dpi(dpi)

        # Black background that also fixes the axes extent to size x size.
        ax.imshow(np.zeros((size, size, 3)))

        for span, color, closed in parts:
            ax.add_patch(get_patch(landmarks[span], color=color, closed=closed))

        # Use the axes object rather than pyplot's "current axes" so the call
        # cannot land on a different figure.
        ax.axis('off')
        fig.canvas.draw()
        buffer, (width, height) = fig.canvas.print_to_buffer()
        assert width == height
        assert width == size

        # The canvas buffer has four channels per pixel; keep the first three.
        pixels = np.frombuffer(buffer, np.uint8).reshape((height, width, 4))
        return PIL.Image.fromarray(pixels[:, :, 0:3])
    finally:
        # Always release the figure, even if rendering fails part-way.
        plt.close(fig)
|
|
|
|
|
def get_conditioning(image):
    """Turn a face photo into the landmark drawing used as conditioning.

    Args:
        image: ``PIL.Image`` in any mode. It is not modified.

    Returns:
        The ``PIL.Image`` produced by ``conditioning_from_landmarks`` for the
        first face detected in the image.

    Raises:
        Exception: Propagated from ``get_bounding_box`` when no face is found.
    """
    # convert() returns a new image, so the in-place thumbnail() below no
    # longer shrinks the caller's object. Forcing RGB also guarantees exactly
    # three channels for the flip (RGBA, greyscale and palette uploads
    # previously produced a wrong or failing channel reversal).
    image = image.convert('RGB')
    image.thumbnail((512, 512))

    pixels = np.array(image)
    pixels = pixels[:, :, ::-1]  # reverse channel order: RGB -> BGR

    bbox = get_bounding_box(pixels)
    landmarks = get_landmarks(pixels, bbox)
    return conditioning_from_landmarks(landmarks)
|
|
|
|
|
def generate_images(image_in_img, prompt, image_file_live_opt='file', live_conditioning=None):
    """Generate images from a prompt, conditioned on a face pose.

    Args:
        image_in_img: Uploaded ``PIL.Image``; used when
            *image_file_live_opt* is ``'file'``.
        prompt: Text prompt passed to the diffusion pipeline.
        image_file_live_opt: ``'file'`` to derive the conditioning from
            *image_in_img*, or ``'webcam'`` to use the ready-made conditioning
            image in *live_conditioning*.
        live_conditioning: Dict whose ``'image'`` entry is a data URL
            (``"<header>,<base64 payload>"``), or ``None`` when the browser
            has no canvas data.

    Returns:
        A list holding the conditioning image followed by the generated
        images.

    Raises:
        gr.Error: If the input required by the selected mode is missing, the
            mode is unknown, or any step of the generation fails.
    """
    # The browser sends null when the canvas element is absent, and the
    # default is None too, so never assume a dict here.
    live_image = None
    if isinstance(live_conditioning, dict):
        live_image = live_conditioning.get('image')

    # Validate the input that the *selected* mode actually needs.
    if image_file_live_opt == 'file':
        if image_in_img is None:
            raise gr.Error("Please provide an image")
    elif image_file_live_opt == 'webcam':
        if not live_image:
            raise gr.Error("Please provide an image")
    else:
        raise gr.Error(f"Unknown image source: {image_file_live_opt!r}")

    try:
        if image_file_live_opt == 'file':
            conditioning = get_conditioning(image_in_img)
        else:
            # Data URL: keep only the base64 payload after the first comma.
            image_data = base64.b64decode(live_image.split(',')[1])
            conditioning = Image.open(BytesIO(image_data)).convert(
                'RGB').resize((512, 512))

        output = pipe(
            prompt,
            conditioning,
            generator=generator,
            num_images_per_prompt=3,
            num_inference_steps=20,
        )
        return [conditioning] + output.images
    except Exception as e:
        # Surface the failure in the UI while keeping the original traceback
        # chained for the server logs.
        raise gr.Error(str(e)) from e
|
|
|
|
|
def toggle(choice):
    """Switch the UI between the upload widget and the webcam canvas.

    Args:
        choice: Selected radio option, ``"file"`` or ``"webcam"``.

    Returns:
        A pair of ``gr.update`` results -- first for the image component,
        second for the HTML canvas component. Exactly one of them is made
        visible, the image value is always cleared, and the canvas receives
        ``canvas_html`` only in webcam mode. Any other *choice* yields
        ``None``.
    """
    if choice not in ("file", "webcam"):
        return None

    use_webcam = choice == "webcam"
    image_update = gr.update(visible=not use_webcam, value=None)
    canvas_update = gr.update(
        visible=use_webcam,
        value=canvas_html if use_webcam else None,
    )
    return image_update, canvas_update
|
|
|
|
|
# ---------------------------------------------------------------------------
# UI definition. Components are created in the order they appear here, and
# the app is launched at import time.
# ---------------------------------------------------------------------------
with gr.Blocks() as blocks:
    gr.Markdown("""
## Generate Uncanny Faces with ControlNet Stable Diffusion
[Check out our blog to see how this was done (and train your own controlnet)](https://huggingface.co/blog/train-your-controlnet)
""")
    with gr.Row():
        # Hidden input slot: its value is overridden in the browser by
        # `get_js_image` when the Generate button is clicked.
        live_conditioning = gr.JSON(value={}, visible=False)
        with gr.Column():
            # Input-mode selector; "file" is the initial mode.
            image_file_live_opt = gr.Radio(["file", "webcam"], value="file",
                                           label="How would you like to upload your image?")
            # Only one of these two is visible at a time (see `toggle`).
            image_in_img = gr.Image(source="upload", visible=True, type="pil")
            canvas = gr.HTML(None, elem_id="canvas_html", visible=False)

            # Swap the visible input widget whenever the mode changes.
            image_file_live_opt.change(fn=toggle,
                                       inputs=[image_file_live_opt],
                                       outputs=[image_in_img, canvas],
                                       queue=False)
            prompt = gr.Textbox(
                label="Enter your prompt",
                max_lines=1,
                placeholder="best quality, extremely detailed",
            )
            run_button = gr.Button("Generate")
        with gr.Column():
            # Shows the list returned by `generate_images`.
            gallery = gr.Gallery().style(grid=[2], height="auto")
    # `_js` runs in the browser first and supplies the canvas data as the
    # fourth argument to `generate_images`.
    run_button.click(fn=generate_images,
                     inputs=[image_in_img, prompt,
                             image_file_live_opt, live_conditioning],
                     outputs=[gallery],
                     _js=get_js_image)
    # No Python callback here: only the JavaScript in `load_js` runs on load.
    blocks.load(None, None, None, _js=load_js)
    # Examples supply only the image and prompt, so `generate_images` runs
    # with its defaults (file mode, no live conditioning). cache_examples=True
    # asks Gradio to cache the example outputs.
    gr.Examples(fn=generate_images,
                examples=[
                    ["./examples/pedro-512.jpg",
                     "Highly detailed photograph of young woman smiling, with palm trees in the background"],
                    ["./examples/image1.jpg",
                     "Highly detailed photograph of a scary clown"],
                    ["./examples/image0.jpg",
                     "Highly detailed photograph of Madonna"],
                ],
                inputs=[image_in_img, prompt],
                outputs=[gallery],
                cache_examples=True)
    gr.Markdown('''
This Space was trained on synthetic 3D faces to learn how to keep a pose - however it also learned that all faces are synthetic 3D faces, [learn more on our blog](https://huggingface.co/blog/train-your-controlnet), it uses a custom visualization based on SPIGA face landmarks for conditioning.
''')

# Start the Gradio server.
blocks.launch()
|
|