# Hugging Face Space (status shown on the Space page: Runtime error)
import gradio as gr
import torch
import dlib
import numpy as np
import PIL
from PIL import Image
# Only used to convert to gray, could do it differently and remove this big dependency
import cv2
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers import UniPCMultistepScheduler
from spiga.inference.config import ModelConfig
from spiga.inference.framework import SPIGAFramework
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
from matplotlib.path import Path
import matplotlib.patches as patches
# ---------------------------------------------------------------------------
# Face analysis: dlib finds the face box, SPIGA places landmarks inside it.
# ---------------------------------------------------------------------------
face_detector = dlib.get_frontal_face_detector()
spiga_extractor = SPIGAFramework(ModelConfig("300wpublic"))

# ---------------------------------------------------------------------------
# Diffusion: SD 2.1 base steered by the "uncannyfaces_25K" ControlNet, both
# loaded in half precision.
# ---------------------------------------------------------------------------
uncanny_controlnet = ControlNetModel.from_pretrained(
    "multimodalart/uncannyfaces_25K",
    torch_dtype=torch.float16,
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base",
    controlnet=uncanny_controlnet,
    safety_checker=None,  # no safety checker is loaded for this pipeline
    torch_dtype=torch.float16,
)
# Swap in the UniPC multistep scheduler, built from the default scheduler's config.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

# torch.manual_seed seeds torch's global RNG once, at import time, and returns
# the default torch.Generator; this single object is shared by every request.
generator = torch.manual_seed(0)
def get_bounding_box(image):
    """Return the bounding box of the first face dlib detects in *image*.

    Args:
        image: BGR image as a numpy array (the colour order matches the
            ``cv2.COLOR_BGR2GRAY`` conversion below).

    Returns:
        ``[left, top, width, height]`` of the first detection, in pixels.

    Raises:
        gr.Error: if no face is detected. The Gradio UI then shows a readable
            message instead of an opaque ``IndexError``.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    faces = face_detector(gray)
    if len(faces) == 0:
        raise gr.Error("No face was detected in the input image.")
    # Only the first detection is used, even when several faces are present.
    face = faces[0]
    return [face.left(), face.top(), face.width(), face.height()]
def get_landmarks(image, bbox):
    """Return the SPIGA landmarks for the face inside *bbox*.

    Args:
        image: image array handed unchanged to ``spiga_extractor.inference``.
        bbox: ``[left, top, width, height]`` box, as built by ``get_bounding_box``.

    Returns:
        The landmarks of the first (and only) box submitted.
    """
    per_face_landmarks = spiga_extractor.inference(image, [bbox])["landmarks"]
    return per_face_landmarks[0]
def get_patch(landmarks, color='lime', closed=False):
    """Build a matplotlib patch tracing a sequence of landmark points in order.

    Args:
        landmarks: sequence of ``(x, y)`` points; it is never modified.
        color: edge colour; also used as the fill colour when ``closed`` is true.
        closed: if true, close the polyline back to its first point and fill it.
            Otherwise draw an open polyline with a transparent fill.

    Returns:
        A ``matplotlib.patches.PathPatch`` drawn with line width 4.
    """
    # Work on a copy so closing the contour never appends to the caller's
    # sequence; this also accepts numpy arrays, which have no ``append``.
    contour = list(landmarks)
    ops = [Path.MOVETO] + [Path.LINETO] * (len(contour) - 1)
    facecolor = (0, 0, 0, 0)  # Transparent fill color, if open
    if closed:
        # CLOSEPOLY ignores its vertex, but Path needs one vertex per code.
        contour.append(contour[0])
        ops.append(Path.CLOSEPOLY)
        facecolor = color
    path = Path(contour, ops)
    return patches.PathPatch(path, facecolor=facecolor, edgecolor=color, lw=4)
def conditioning_from_landmarks(landmarks, size=512):
    """Render facial landmarks as a colour-coded line drawing on black.

    Each facial region (slices of the 68-point landmark list) is drawn with its
    own colour onto a black ``size`` x ``size`` canvas. The result is used as
    the ControlNet conditioning image.

    Args:
        landmarks: sequence of at least 68 ``(x, y)`` points, in pixel
            coordinates of the ``size`` x ``size`` canvas.
        size: width and height of the output image in pixels.

    Returns:
        An RGB ``PIL.Image.Image`` of ``size`` x ``size`` pixels.

    Raises:
        RuntimeError: if matplotlib renders a canvas of an unexpected size.
    """
    # (start, end, colour, closed) for each region, drawn in this order.
    regions = (
        (0, 17, 'lime', False),     # face contour
        (17, 22, 'yellow', False),  # left eyebrow
        (22, 27, 'yellow', False),  # right eyebrow
        (27, 31, 'orange', False),  # vertical nose line
        (31, 36, 'orange', False),  # horizontal nose line
        (36, 42, 'magenta', True),  # left eye
        (42, 48, 'magenta', True),  # right eye
        (48, 60, 'cyan', True),     # outer lips
        (60, 68, 'blue', True),     # inner lips
    )

    # Precisely control output image size: size/dpi inches at dpi pixels per inch.
    dpi = 72
    # An object-oriented Figure with its own Agg canvas, rather than pyplot:
    # nothing is registered in pyplot's global figure manager. No figure can
    # leak if rendering fails, and requests share no global state.
    fig = Figure(figsize=(size / dpi, size / dpi), dpi=dpi, tight_layout={'pad': 0})
    canvas = FigureCanvasAgg(fig)
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(np.zeros((size, size, 3)))  # black background
    for start, end, color, closed in regions:
        ax.add_patch(get_patch(landmarks[start:end], color=color, closed=closed))
    ax.axis('off')

    # print_to_buffer renders the figure and returns raw RGBA bytes.
    buffer, (width, height) = canvas.print_to_buffer()
    if (width, height) != (size, size):
        raise RuntimeError(
            f"Expected a {size}x{size} canvas, got {width}x{height}"
        )
    rgba = np.frombuffer(buffer, np.uint8).reshape((height, width, 4))
    return Image.fromarray(rgba[:, :, 0:3])  # drop the alpha channel
def get_conditioning(image):
    """Build the ControlNet conditioning image for the face in *image*.

    Steps: convert to BGR and then:
    - Retrieve bounding box using `dlib`
    - Obtain landmarks using `spiga`
    - Create conditioning image with custom `matplotlib` code

    Args:
        image: input ``PIL.Image.Image`` in any mode, or ``None`` when the
            Gradio image input was left empty. It is not modified.

    Returns:
        The conditioning image produced by ``conditioning_from_landmarks``
        (512x512 with its default size).

    Raises:
        gr.Error: if no image was provided.
    """
    # TODO: error if bbox is too small
    if image is None:
        raise gr.Error("Please upload an image containing a face.")
    # convert() returns a new image, so thumbnail() below does not shrink the
    # caller's image in place. Forcing RGB also handles RGBA, greyscale and
    # palette uploads, which would otherwise break the channel reversal or the
    # BGR-to-gray conversion.
    image = image.convert("RGB")
    image.thumbnail((512, 512))
    # RGB -> BGR channel order, matching get_bounding_box's cv2 conversion.
    image = np.array(image)[:, :, ::-1]
    bbox = get_bounding_box(image)
    landmarks = get_landmarks(image, bbox)
    return conditioning_from_landmarks(landmarks)
def generate_images(image, prompt, seed=0):
    """Gradio handler: build the landmark conditioning and run the pipeline.

    Args:
        image: input ``PIL.Image.Image`` from the Gradio image component.
        prompt: text prompt for Stable Diffusion.
        seed: seed for a fresh per-request generator. Identical inputs then
            give identical outputs, regardless of earlier requests.

    Returns:
        A list with the conditioning image followed by the three generated images.
    """
    conditioning = get_conditioning(image)
    # A new generator per call. A single module-level generator seeded once at
    # import would keep advancing, so results would depend on how many requests
    # had already been served. CPU device matches torch.manual_seed's default
    # generator.
    rng = torch.Generator(device="cpu").manual_seed(seed)
    output = pipe(
        prompt,
        conditioning,
        generator=rng,
        num_images_per_prompt=3,
        num_inference_steps=20,
    )
    return [conditioning] + output.images
# Build the web UI: an image and a prompt in, a gallery out (the conditioning
# image followed by the generated images).
gr.Interface(
    generate_images,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(
            label="Enter your prompt",
            max_lines=1,
            placeholder="best quality, extremely detailed",
        ),
    ],
    outputs=gr.Gallery().style(grid=[2], height="auto"),
    title="Generate controlled outputs with ControlNet and Stable Diffusion. ",
    # The conditioning is drawn from facial landmarks, not from pose estimation.
    description="This Space uses facial landmark lines as the additional conditioning.",
    # "happy zombie" instead of "young woman" works great too :)
    examples=[["pedro-512.jpg", "Highly detailed photograph of young woman smiling, with palm trees in the background"]],
    # Gradio takes a string here ("never" / "auto" / "manual"); booleans are deprecated.
    allow_flagging="never",
).queue().launch()  # .queue() replaces the deprecated launch(enable_queue=True)