Enable live conditioning via webcam

#3
by radames - opened
Files changed (1) hide show
  1. app.py +102 -26
app.py CHANGED
@@ -22,19 +22,49 @@ else:
22
  device = torch.device("cpu")
23
 
24
  model = create_model('./models/cldm_v15.yaml').cpu()
25
- model.load_state_dict(load_state_dict('./models/control_sd15_landmarks.pth', location='cpu'))
 
26
  model = model.to(device)
27
  ddim_sampler = DDIMSampler(model)
28
 
29
  detector = dlib.get_frontal_face_detector()
30
  predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def draw_landmarks(image, landmarks, color="white", radius=2.5):
33
  draw = ImageDraw.Draw(image)
34
  for dot in landmarks:
35
  x, y = dot
36
  draw.ellipse((x-radius, y-radius, x+radius, y+radius), fill=color)
37
 
 
38
  def get_68landmarks_img(img):
39
  gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
40
  faces = detector(gray)
@@ -50,9 +80,14 @@ def get_68landmarks_img(img):
50
  con_img = np.array(con_img)
51
  return con_img
52
 
53
- def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, landmark_direct_mode, strength, scale, seed, eta):
 
 
 
54
  input_image = np.flip(input_image, axis=2)
55
- num_samples = min(num_samples, 2) # Limit the number of samples to 2 for Spaces only
 
 
56
  with torch.no_grad():
57
  img = resize_image(HWC3(input_image), image_resolution)
58
  H, W, C = img.shape
@@ -63,7 +98,8 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
63
  detected_map = get_68landmarks_img(img)
64
  detected_map = HWC3(detected_map)
65
 
66
- control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0
 
67
  control = torch.stack([control for _ in range(num_samples)], dim=0)
68
  control = einops.rearrange(control, 'b h w c -> b c h w').clone()
69
 
@@ -74,14 +110,17 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
74
  if config.save_memory:
75
  model.low_vram_shift(is_diffusing=False)
76
 
77
- cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
78
- un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
 
 
79
  shape = (4, H // 8, W // 8)
80
 
81
  if config.save_memory:
82
  model.low_vram_shift(is_diffusing=True)
83
 
84
- model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
 
85
  samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
86
  shape, cond, verbose=False, eta=eta,
87
  unconditional_guidance_scale=scale,
@@ -91,45 +130,82 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
91
  model.low_vram_shift(is_diffusing=False)
92
 
93
  x_samples = model.decode_first_stage(samples)
94
- x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
 
95
 
96
  results = [x_samples[i] for i in range(num_samples)]
 
97
  return [255 - detected_map] + results
98
 
99
 
 
 
 
 
 
 
 
100
  block = gr.Blocks().queue()
101
  with block:
 
102
  with gr.Row():
103
  gr.Markdown("## Control Stable Diffusion with Face Landmarks")
104
  with gr.Row():
105
  with gr.Column():
106
- input_image = gr.Image(source='upload', type="numpy")
 
 
 
 
 
 
 
 
 
107
  prompt = gr.Textbox(label="Prompt")
108
  run_button = gr.Button(label="Run")
109
  with gr.Accordion("Advanced options", open=False):
110
- num_samples = gr.Slider(label="Images", minimum=1, maximum=2, value=1, step=1)
111
- image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
112
- strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
 
 
 
113
  guess_mode = gr.Checkbox(label='Guess Mode', value=False)
114
- landmark_direct_mode = gr.Checkbox(label='Input Landmark Directly', value=False)
115
- ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
116
- scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
117
- seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
 
 
 
 
118
  eta = gr.Number(label="eta (DDIM)", value=0.0)
119
- a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
 
120
  n_prompt = gr.Textbox(label="Negative Prompt",
121
  value='cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
122
  with gr.Column():
123
- result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
124
- ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, landmark_direct_mode, strength, scale, seed, eta]
 
 
 
125
  gr.Examples(fn=process, examples=[
126
- ["examples/image0.jpg", "a silly clown face", "best quality, extremely detailed", "cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", 1, 512, 20, False, False, 1.0, 9.0, -1, 0.0],
127
- ["examples/image1.png", "a photo of a woman wearing glasses", "best quality, extremely detailed", "cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", 1, 512, 20, False, False, 1.0, 9.0, -1, 0.0],
128
- ["examples/image2.png", "a silly portrait of man with head tilted and a beautiful hair on the side", "best quality, extremely detailed", "cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", 1, 512, 20, False, False, 1.0, 9.0, -1, 0.0],
129
- ["examples/image3.png", "portrait handsome men", "best quality, extremely detailed", "cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", 1, 512, 20, False, False, 1.0, 9.0, -1, 0.0],
130
- ["examples/image4.jpg", "a beautiful woman looking at the sky", "best quality, extremely detailed", "cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", 1, 512, 20, False, False, 1.0, 9.0, -1, 0.0],
131
- ],inputs=ips, outputs=[result_gallery], cache_examples=True)
132
- run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
 
 
 
 
 
 
 
133
 
134
 
135
  block.launch()
 
22
  device = torch.device("cpu")
23
 
24
  model = create_model('./models/cldm_v15.yaml').cpu()
25
+ model.load_state_dict(load_state_dict(
26
+ './models/control_sd15_landmarks.pth', location='cpu'))
27
  model = model.to(device)
28
  ddim_sampler = DDIMSampler(model)
29
 
30
  detector = dlib.get_frontal_face_detector()
31
  predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
32
 
33
+
34
+ canvas_html = "<face-canvas id='canvas-root' data-mode='points' style='display:flex;max-width: 500px;margin: 0 auto;'></face-canvas>"
35
+ load_js = """
36
+ async () => {
37
+ const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/face-canvas.js"
38
+ fetch(url)
39
+ .then(res => res.text())
40
+ .then(text => {
41
+ const script = document.createElement('script');
42
+ script.type = "module"
43
+ script.src = URL.createObjectURL(new Blob([text], { type: 'application/javascript' }));
44
+ document.head.appendChild(script);
45
+ });
46
+ }
47
+ """
48
+ get_js_image = """
49
+ async (input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, landmark_direct_mode, strength, scale, seed, eta, image_file_live_opt) => {
50
+ const canvasEl = document.getElementById("canvas-root");
51
+ const imageData = canvasEl? canvasEl._data : null;
52
+ if(image_file_live_opt === 'webcam'){
53
+ input_image = imageData['image']
54
+ landmark_direct_mode = true
55
+ }
56
+ return [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, landmark_direct_mode, strength, scale, seed, eta, image_file_live_opt]
57
+ }
58
+ """
59
+
60
+
61
  def draw_landmarks(image, landmarks, color="white", radius=2.5):
62
  draw = ImageDraw.Draw(image)
63
  for dot in landmarks:
64
  x, y = dot
65
  draw.ellipse((x-radius, y-radius, x+radius, y+radius), fill=color)
66
 
67
+
68
  def get_68landmarks_img(img):
69
  gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
70
  faces = detector(gray)
 
80
  con_img = np.array(con_img)
81
  return con_img
82
 
83
+
84
+ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, landmark_direct_mode, strength, scale, seed, eta, image_file_live_opt="file"):
85
+ input_image = input_image.convert('RGB')
86
+ input_image = np.array(input_image)
87
  input_image = np.flip(input_image, axis=2)
88
+ print('input_image.shape', input_image.shape)
89
+ # Limit the number of samples to 2 for Spaces only
90
+ num_samples = min(num_samples, 2)
91
  with torch.no_grad():
92
  img = resize_image(HWC3(input_image), image_resolution)
93
  H, W, C = img.shape
 
98
  detected_map = get_68landmarks_img(img)
99
  detected_map = HWC3(detected_map)
100
 
101
+ control = torch.from_numpy(
102
+ detected_map.copy()).float().to(device) / 255.0
103
  control = torch.stack([control for _ in range(num_samples)], dim=0)
104
  control = einops.rearrange(control, 'b h w c -> b c h w').clone()
105
 
 
110
  if config.save_memory:
111
  model.low_vram_shift(is_diffusing=False)
112
 
113
+ cond = {"c_concat": [control], "c_crossattn": [
114
+ model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
115
+ un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [
116
+ model.get_learned_conditioning([n_prompt] * num_samples)]}
117
  shape = (4, H // 8, W // 8)
118
 
119
  if config.save_memory:
120
  model.low_vram_shift(is_diffusing=True)
121
 
122
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else (
123
+ [strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
124
  samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
125
  shape, cond, verbose=False, eta=eta,
126
  unconditional_guidance_scale=scale,
 
130
  model.low_vram_shift(is_diffusing=False)
131
 
132
  x_samples = model.decode_first_stage(samples)
133
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c')
134
+ * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
135
 
136
  results = [x_samples[i] for i in range(num_samples)]
137
+
138
  return [255 - detected_map] + results
139
 
140
 
141
+ def toggle(choice):
142
+ if choice == "file":
143
+ return gr.update(visible=True, value=None), gr.update(visible=False, value=None)
144
+ elif choice == "webcam":
145
+ return gr.update(visible=False, value=None), gr.update(visible=True, value=canvas_html)
146
+
147
+
148
  block = gr.Blocks().queue()
149
  with block:
150
+ live_conditioning = gr.JSON(value={}, visible=False)
151
  with gr.Row():
152
  gr.Markdown("## Control Stable Diffusion with Face Landmarks")
153
  with gr.Row():
154
  with gr.Column():
155
+ image_file_live_opt = gr.Radio(["file", "webcam"], value="file",
156
+ label="How would you like to upload your image?")
157
+ input_image = gr.Image(source="upload", visible=True, type="pil")
158
+ canvas = gr.HTML(None, elem_id="canvas_html", visible=False)
159
+
160
+ image_file_live_opt.change(fn=toggle,
161
+ inputs=[image_file_live_opt],
162
+ outputs=[input_image, canvas],
163
+ queue=False)
164
+
165
  prompt = gr.Textbox(label="Prompt")
166
  run_button = gr.Button(label="Run")
167
  with gr.Accordion("Advanced options", open=False):
168
+ num_samples = gr.Slider(
169
+ label="Images", minimum=1, maximum=2, value=1, step=1)
170
+ image_resolution = gr.Slider(
171
+ label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
172
+ strength = gr.Slider(
173
+ label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
174
  guess_mode = gr.Checkbox(label='Guess Mode', value=False)
175
+ landmark_direct_mode = gr.Checkbox(
176
+ label='Input Landmark Directly', value=False)
177
+ ddim_steps = gr.Slider(
178
+ label="Steps", minimum=1, maximum=100, value=20, step=1)
179
+ scale = gr.Slider(label="Guidance Scale",
180
+ minimum=0.1, maximum=30.0, value=9.0, step=0.1)
181
+ seed = gr.Slider(label="Seed", minimum=-1,
182
+ maximum=2147483647, step=1, randomize=True)
183
  eta = gr.Number(label="eta (DDIM)", value=0.0)
184
+ a_prompt = gr.Textbox(
185
+ label="Added Prompt", value='best quality, extremely detailed')
186
  n_prompt = gr.Textbox(label="Negative Prompt",
187
  value='cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
188
  with gr.Column():
189
+ result_gallery = gr.Gallery(
190
+ label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
191
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution,
192
+ ddim_steps, guess_mode, landmark_direct_mode, strength, scale, seed, eta]
193
+
194
  gr.Examples(fn=process, examples=[
195
+ ["examples/image0.jpg", "a silly clown face", "best quality, extremely detailed",
196
+ "cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", 1, 512, 20, False, False, 1.0, 9.0, -1, 0.0],
197
+ ["examples/image1.png", "a photo of a woman wearing glasses", "best quality, extremely detailed",
198
+ "cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", 1, 512, 20, False, False, 1.0, 9.0, -1, 0.0],
199
+ ["examples/image2.png", "a silly portrait of man with head tilted and a beautiful hair on the side", "best quality, extremely detailed",
200
+ "cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", 1, 512, 20, False, False, 1.0, 9.0, -1, 0.0],
201
+ ["examples/image3.png", "portrait handsome men", "best quality, extremely detailed",
202
+ "cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", 1, 512, 20, False, False, 1.0, 9.0, -1, 0.0],
203
+ ["examples/image4.jpg", "a beautiful woman looking at the sky", "best quality, extremely detailed",
204
+ "cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", 1, 512, 20, False, False, 1.0, 9.0, -1, 0.0],
205
+ ], inputs=ips, outputs=[result_gallery], cache_examples=True)
206
+ run_button.click(fn=process, inputs=ips + [image_file_live_opt],
207
+ outputs=[result_gallery], _js=get_js_image)
208
+ block.load(None, None, None, _js=load_js)
209
 
210
 
211
  block.launch()