radames commited on
Commit
0dc4cd8
1 Parent(s): 9c37edb

covert to blocks, enable webcam

Browse files

raise errors if face is not detected
add extra examples

Files changed (4) hide show
  1. app.py +80 -35
  2. examples/image0.jpg +0 -0
  3. examples/image1.jpg +0 -0
  4. examples/pedro-512.jpg +0 -0
app.py CHANGED
@@ -35,16 +35,22 @@ pipe = pipe.to("cuda")
35
  # Generator seed,
36
  generator = torch.manual_seed(0)
37
 
 
38
  def get_bounding_box(image):
39
  gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
40
- face = face_detector(gray)[0]
 
 
 
41
  bbox = [face.left(), face.top(), face.width(), face.height()]
42
  return bbox
43
 
 
44
  def get_landmarks(image, bbox):
45
  features = spiga_extractor.inference(image, [bbox])
46
  return features['landmarks'][0]
47
 
 
48
  def get_patch(landmarks, color='lime', closed=False):
49
  contour = landmarks
50
  ops = [Path.MOVETO] + [Path.LINETO]*(len(contour)-1)
@@ -56,10 +62,12 @@ def get_patch(landmarks, color='lime', closed=False):
56
  path = Path(contour, ops)
57
  return patches.PathPatch(path, facecolor=facecolor, edgecolor=color, lw=4)
58
 
 
59
  def conditioning_from_landmarks(landmarks, size=512):
60
  # Precisely control output image size
61
  dpi = 72
62
- fig, ax = plt.subplots(1, figsize=[size/dpi, size/dpi], tight_layout={'pad':0})
 
63
  fig.set_dpi(dpi)
64
 
65
  black = np.zeros((size, size, 3))
@@ -86,17 +94,16 @@ def conditioning_from_landmarks(landmarks, size=512):
86
  ax.add_patch(inner_lips)
87
 
88
  plt.axis('off')
89
-
90
  fig.canvas.draw()
91
  buffer, (width, height) = fig.canvas.print_to_buffer()
92
  assert width == height
93
  assert width == size
94
-
95
  buffer = np.frombuffer(buffer, np.uint8).reshape((height, width, 4))
96
  buffer = buffer[:, :, 0:3]
97
  plt.close(fig)
98
  return PIL.Image.fromarray(buffer)
99
 
 
100
  def get_conditioning(image):
101
  # Steps: convert to BGR and then:
102
  # - Retrieve bounding box using `dlib`
@@ -109,34 +116,72 @@ def get_conditioning(image):
109
  bbox = get_bounding_box(image)
110
  landmarks = get_landmarks(image, bbox)
111
  spiga_seg = conditioning_from_landmarks(landmarks)
112
- return spiga_seg
113
-
114
- def generate_images(image, prompt):
115
- conditioning = get_conditioning(image)
116
- output = pipe(
117
- prompt,
118
- conditioning,
119
- generator=generator,
120
- num_images_per_prompt=3,
121
- num_inference_steps=20,
122
- )
123
- return [conditioning] + output.images
124
-
125
-
126
- gr.Interface(
127
- generate_images,
128
- inputs=[
129
- gr.Image(type="pil"),
130
- gr.Textbox(
131
- label="Enter your prompt",
132
- max_lines=1,
133
- placeholder="best quality, extremely detailed",
134
- ),
135
- ],
136
- outputs=gr.Gallery().style(grid=[2], height="auto"),
137
- title="Generate controlled outputs with ControlNet and Stable Diffusion. ",
138
- description="This Space uses a custom visualization based on SPIGA face landmarks for conditioning.",
139
- # "happy zombie" instead of "young woman" works great too :)
140
- examples=[["pedro-512.jpg", "Highly detailed photograph of young woman smiling, with palm trees in the background"]],
141
- allow_flagging=False,
142
- ).launch(enable_queue=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  # Generator seed,
36
  generator = torch.manual_seed(0)
37
 
38
+
39
  def get_bounding_box(image):
40
  gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
41
+ faces = face_detector(gray)
42
+ if len(faces) == 0:
43
+ raise Exception("No face detected in image")
44
+ face = faces[0]
45
  bbox = [face.left(), face.top(), face.width(), face.height()]
46
  return bbox
47
 
48
+
49
  def get_landmarks(image, bbox):
50
  features = spiga_extractor.inference(image, [bbox])
51
  return features['landmarks'][0]
52
 
53
+
54
  def get_patch(landmarks, color='lime', closed=False):
55
  contour = landmarks
56
  ops = [Path.MOVETO] + [Path.LINETO]*(len(contour)-1)
 
62
  path = Path(contour, ops)
63
  return patches.PathPatch(path, facecolor=facecolor, edgecolor=color, lw=4)
64
 
65
+
66
  def conditioning_from_landmarks(landmarks, size=512):
67
  # Precisely control output image size
68
  dpi = 72
69
+ fig, ax = plt.subplots(
70
+ 1, figsize=[size/dpi, size/dpi], tight_layout={'pad': 0})
71
  fig.set_dpi(dpi)
72
 
73
  black = np.zeros((size, size, 3))
 
94
  ax.add_patch(inner_lips)
95
 
96
  plt.axis('off')
 
97
  fig.canvas.draw()
98
  buffer, (width, height) = fig.canvas.print_to_buffer()
99
  assert width == height
100
  assert width == size
 
101
  buffer = np.frombuffer(buffer, np.uint8).reshape((height, width, 4))
102
  buffer = buffer[:, :, 0:3]
103
  plt.close(fig)
104
  return PIL.Image.fromarray(buffer)
105
 
106
+
107
  def get_conditioning(image):
108
  # Steps: convert to BGR and then:
109
  # - Retrieve bounding box using `dlib`
 
116
  bbox = get_bounding_box(image)
117
  landmarks = get_landmarks(image, bbox)
118
  spiga_seg = conditioning_from_landmarks(landmarks)
119
+ return spiga_seg
120
+
121
+
122
+ def generate_images(image, prompt, image_video=None):
123
+ if image is None and image_video is None:
124
+ raise gr.Error("Please provide an image")
125
+ if image_video is not None:
126
+ image = image_video
127
+ try:
128
+ conditioning = get_conditioning(image)
129
+ output = pipe(
130
+ prompt,
131
+ conditioning,
132
+ generator=generator,
133
+ num_images_per_prompt=3,
134
+ num_inference_steps=20,
135
+ )
136
+ return [conditioning] + output.images
137
+ except Exception as e:
138
+ raise gr.Error(str(e))
139
+
140
+
141
+ def toggle(choice):
142
+ if choice == "webcam":
143
+ return gr.update(visible=True, value=None), gr.update(visible=False, value=None)
144
+ else:
145
+ return gr.update(visible=False, value=None), gr.update(visible=True, value=None)
146
+
147
+
148
+ with gr.Blocks() as blocks:
149
+ gr.Markdown("""
150
+ ## Generate controlled outputs with ControlNet and Stable Diffusion.
151
+ This Space uses a custom visualization based on SPIGA face landmarks for conditioning.
152
+ """)
153
+ with gr.Row():
154
+ with gr.Column():
155
+ image_or_file_opt = gr.Radio(["file", "webcam"], value="file",
156
+ label="How would you like to upload your image?")
157
+ image_in_video = gr.Image(
158
+ source="webcam", type="pil", visible=False)
159
+ image_in_img = gr.Image(
160
+ source="upload", visible=True, type="pil")
161
+ image_or_file_opt.change(fn=toggle, inputs=[image_or_file_opt],
162
+ outputs=[image_in_video, image_in_img], queue=False)
163
+ prompt = gr.Textbox(
164
+ label="Enter your prompt",
165
+ max_lines=1,
166
+ placeholder="best quality, extremely detailed",
167
+ )
168
+ run_button = gr.Button("Generate")
169
+ with gr.Column():
170
+ gallery = gr.Gallery().style(grid=[2], height="auto")
171
+ run_button.click(fn=generate_images,
172
+ inputs=[image_in_img, prompt, image_in_video],
173
+ outputs=[gallery])
174
+ gr.Examples(fn=generate_images,
175
+ examples=[
176
+ ["./examples/pedro-512.jpg",
177
+ "Highly detailed photograph of young woman smiling, with palm trees in the background"],
178
+ ["./examples/image1.jpg",
179
+ "Highly detailed photograph of a scary clown"],
180
+ ["./examples/image0.jpg",
181
+ "Highly detailed photograph of Barack Obama"],
182
+ ],
183
+ inputs=[image_in_img, prompt],
184
+ outputs=[gallery],
185
+ cache_examples=True)
186
+
187
+ blocks.launch()
examples/image0.jpg ADDED
examples/image1.jpg ADDED
examples/pedro-512.jpg ADDED