pcuenq HF staff commited on
Commit
ed82520
1 Parent(s): 4ce86a1

Initial version

Browse files
Files changed (4) hide show
  1. README.md +5 -5
  2. app.py +142 -0
  3. pedro-512.jpg +0 -0
  4. requirements.txt +9 -0
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Uncanny Faces
3
- emoji: 🏢
4
- colorFrom: pink
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 3.23.0
8
  app_file: app.py
9
  pinned: false
10
  ---
1
  ---
2
+ title: ControlNet Openpose
3
+ emoji: 😻
4
+ colorFrom: green
5
+ colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 3.19.1
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import dlib
4
+ import numpy as np
5
+ import PIL
6
+
7
+ # Only used to convert to gray, could do it differently and remove this big dependency
8
+ import cv2
9
+
10
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
11
+ from diffusers import UniPCMultistepScheduler
12
+
13
+ from spiga.inference.config import ModelConfig
14
+ from spiga.inference.framework import SPIGAFramework
15
+
16
+ import matplotlib.pyplot as plt
17
+ from matplotlib.path import Path
18
+ import matplotlib.patches as patches
19
+
20
+ # Bounding boxes
21
+ face_detector = dlib.get_frontal_face_detector()
22
+
23
+ # Landmark extraction
24
+ spiga_extractor = SPIGAFramework(ModelConfig("300wpublic"))
25
+
26
+ uncanny_controlnet = ControlNetModel.from_pretrained(
27
+ "multimodalart/uncannyfaces_25K", torch_dtype=torch.float16
28
+ )
29
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
30
+ "stabilityai/stable-diffusion-2-1-base", controlnet=uncanny_controlnet, safety_checker=None, torch_dtype=torch.float16
31
+ )
32
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
33
+ pipe = pipe.to("cuda")
34
+
35
+ # Generator seed,
36
+ generator = torch.manual_seed(0)
37
+
38
+ def get_bounding_box(image):
39
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
40
+ face = face_detector(gray)[0]
41
+ bbox = [face.left(), face.top(), face.width(), face.height()]
42
+ return bbox
43
+
44
+ def get_landmarks(image, bbox):
45
+ features = spiga_extractor.inference(image, [bbox])
46
+ return features['landmarks'][0]
47
+
48
+ def get_patch(landmarks, color='lime', closed=False):
49
+ contour = landmarks
50
+ ops = [Path.MOVETO] + [Path.LINETO]*(len(contour)-1)
51
+ facecolor = (0, 0, 0, 0) # Transparent fill color, if open
52
+ if closed:
53
+ contour.append(contour[0])
54
+ ops.append(Path.CLOSEPOLY)
55
+ facecolor = color
56
+ path = Path(contour, ops)
57
+ return patches.PathPatch(path, facecolor=facecolor, edgecolor=color, lw=4)
58
+
59
+ def conditioning_from_landmarks(landmarks, size=512):
60
+ # Precisely control output image size
61
+ dpi = 72
62
+ fig, ax = plt.subplots(1, figsize=[size/dpi, size/dpi], tight_layout={'pad':0})
63
+ fig.set_dpi(dpi)
64
+
65
+ black = np.zeros((size, size, 3))
66
+ ax.imshow(black)
67
+
68
+ face_patch = get_patch(landmarks[0:17])
69
+ l_eyebrow = get_patch(landmarks[17:22], color='yellow')
70
+ r_eyebrow = get_patch(landmarks[22:27], color='yellow')
71
+ nose_v = get_patch(landmarks[27:31], color='orange')
72
+ nose_h = get_patch(landmarks[31:36], color='orange')
73
+ l_eye = get_patch(landmarks[36:42], color='magenta', closed=True)
74
+ r_eye = get_patch(landmarks[42:48], color='magenta', closed=True)
75
+ outer_lips = get_patch(landmarks[48:60], color='cyan', closed=True)
76
+ inner_lips = get_patch(landmarks[60:68], color='blue', closed=True)
77
+
78
+ ax.add_patch(face_patch)
79
+ ax.add_patch(l_eyebrow)
80
+ ax.add_patch(r_eyebrow)
81
+ ax.add_patch(nose_v)
82
+ ax.add_patch(nose_h)
83
+ ax.add_patch(l_eye)
84
+ ax.add_patch(r_eye)
85
+ ax.add_patch(outer_lips)
86
+ ax.add_patch(inner_lips)
87
+
88
+ plt.axis('off')
89
+
90
+ fig.canvas.draw()
91
+ buffer, (width, height) = fig.canvas.print_to_buffer()
92
+ assert width == height
93
+ assert width == size
94
+
95
+ buffer = np.frombuffer(buffer, np.uint8).reshape((height, width, 4))
96
+ buffer = buffer[:, :, 0:3]
97
+ plt.close(fig)
98
+ return PIL.Image.fromarray(buffer)
99
+
100
+ def get_conditioning(image):
101
+ # Steps: convert to BGR and then:
102
+ # - Retrieve bounding box using `dlib`
103
+ # - Obtain landmarks using `spiga`
104
+ # - Create conditioning image with custom `matplotlib` code
105
+ # TODO: error if bbox is too small
106
+ image.thumbnail((512, 512))
107
+ image = np.array(image)
108
+ image = image[:, :, ::-1]
109
+ bbox = get_bounding_box(image)
110
+ landmarks = get_landmarks(image, bbox)
111
+ spiga_seg = conditioning_from_landmarks(landmarks)
112
+ return spiga_seg
113
+
114
+ def generate_images(image, prompt):
115
+ conditioning = get_conditioning(image)
116
+ output = pipe(
117
+ prompt,
118
+ conditioning,
119
+ generator=generator,
120
+ num_images_per_prompt=3,
121
+ num_inference_steps=20,
122
+ )
123
+ return [conditioning] + output.images
124
+
125
+
126
+ gr.Interface(
127
+ generate_images,
128
+ inputs=[
129
+ gr.Image(type="pil"),
130
+ gr.Textbox(
131
+ label="Enter your prompt",
132
+ max_lines=1,
133
+ placeholder="best quality, extremely detailed",
134
+ ),
135
+ ],
136
+ outputs=gr.Gallery().style(grid=[2], height="auto"),
137
+ title="Generate controlled outputs with ControlNet and Stable Diffusion. ",
138
+ description="This Space uses pose estimated lines as the additional conditioning.",
139
+ # "happy zombie" instead of "young woman" works great too :)
140
+ examples=[["pedro-512.jpg", "Highly detailed photograph of young woman smiling, with palm trees in the background"]],
141
+ allow_flagging=False,
142
+ ).launch(enable_queue=True)
pedro-512.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
1
+ diffusers
2
+ transformers
3
+ accelerate
4
+ torch
5
+ git+https://github.com/andresprados/SPIGA
6
+ dlib
7
+ opencv-python
8
+ matplotlib
9
+ Pillow