RamAnanth1 committed
Commit fdead57
1 Parent(s): 02fa9c7

Update app.py

Files changed (1)
  1. app.py +77 -3
app.py CHANGED
@@ -20,6 +20,13 @@ canny_checkpoint = "models/control_sd15_canny.pth"
 scribble_checkpoint = "models/control_sd15_scribble.pth"
 pose_checkpoint = "models/control_sd15_openpose.pth"
 
+canny_model = create_model('./models/cldm_v15.yaml').cpu()
+canny_model.load_state_dict(load_state_dict(cached_download(
+    hf_hub_url(REPO_ID, canny_checkpoint)
+), location='cuda'))
+canny_model = canny_model.cuda()
+ddim_sampler = DDIMSampler(canny_model)
+
 pose_model = create_model('./models/cldm_v15.yaml').cpu()
 pose_model.load_state_dict(load_state_dict(cached_download(
     hf_hub_url(REPO_ID, pose_checkpoint)
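The new Canny branch repeats the checkpoint-loading pattern already used for the pose (and scribble) models. Purely as a sketch (load_controlnet is a hypothetical helper, not part of this commit; it assumes the create_model, load_state_dict, cached_download, hf_hub_url and DDIMSampler names already imported in app.py), the repeated pattern could be factored out:

def load_controlnet(config_path, checkpoint, repo_id, device='cuda'):
    # Build the ControlNet architecture from the shared YAML config,
    # then fetch and load the task-specific weights from the Hub cache.
    model = create_model(config_path).cpu()
    state = load_state_dict(cached_download(hf_hub_url(repo_id, checkpoint)), location=device)
    model.load_state_dict(state)
    return model.to(device)

# Hypothetical usage mirroring the lines above:
# canny_model = load_controlnet('./models/cldm_v15.yaml', canny_checkpoint, REPO_ID)
# ddim_sampler = DDIMSampler(canny_model)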
@@ -38,8 +45,38 @@ def process(input_image, prompt, input_control, a_prompt, n_prompt, num_samples,
     # TODO: Add other control tasks
     if input_control == "Scribble":
         return process_scribble(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta)
-    else:
+    elif input_control == "Pose":
         return process_pose(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, image_resolution, ddim_steps, scale, seed, eta)
+
+    return process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold)
+
+def process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
+    with torch.no_grad():
+        img = resize_image(HWC3(input_image), image_resolution)
+        H, W, C = img.shape
+
+        detected_map = apply_canny(img, low_threshold, high_threshold)
+        detected_map = HWC3(detected_map)
+
+        control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+        control = torch.stack([control for _ in range(num_samples)], dim=0)
+        control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+        seed_everything(seed)
+
+        cond = {"c_concat": [control], "c_crossattn": [canny_model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
+        un_cond = {"c_concat": [control], "c_crossattn": [canny_model.get_learned_conditioning([n_prompt] * num_samples)]}
+        shape = (4, H // 8, W // 8)
+
+        samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
+                                                     shape, cond, verbose=False, eta=eta,
+                                                     unconditional_guidance_scale=scale,
+                                                     unconditional_conditioning=un_cond)
+        x_samples = canny_model.decode_first_stage(samples)
+        x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+
+        results = [x_samples[i] for i in range(num_samples)]
+    return [255 - detected_map] + results
 
 def process_scribble(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta):
     with torch.no_grad():
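process_canny relies on apply_canny for the edge map; this is assumed to be a thin wrapper around OpenCV's cv2.Canny (as in the upstream ControlNet annotators), so the effect of the two new threshold sliders can be previewed on its own with a short standalone sketch (bird.png, taken from the examples below, stands in for any RGB test image):

import cv2
import numpy as np

# Preview of the edge-detection step only (assumes apply_canny wraps cv2.Canny).
img = cv2.imread("bird.png")                              # H x W x 3, uint8
low_threshold, high_threshold = 100, 200                  # slider defaults added here

edges = cv2.Canny(img, low_threshold, high_threshold)     # H x W, values 0 or 255
edge_map = np.stack([edges] * 3, axis=-1)                 # rough stand-in for HWC3(...)
cv2.imwrite("bird_edges.png", 255 - edge_map)             # inverted, as in the returned gallery

Raising low_threshold suppresses weak edges; lowering high_threshold keeps more of the weaker contours in the control image.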
@@ -114,6 +151,7 @@ def create_canvas(w, h):
 
 block = gr.Blocks().queue()
 control_task_list = [
+    "Canny Edge Map",
     "Scribble",
     "Pose"
 ]
@@ -135,6 +173,8 @@ with block:
         with gr.Accordion("Advanced options", open=False):
             num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
             image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
+            low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
+            high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
             ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
             scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
             seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True)
@@ -144,9 +184,25 @@ with block:
                 value='longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair,extra digit, fewer digits, cropped, worst quality, low quality')
         with gr.Column():
             result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
-    ips = [input_image, prompt, input_control, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta]
+    ips = [input_image, prompt, input_control, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold]
     run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
     examples_list = [
+        [
+            "bird.png",
+            "bird",
+            "Canny Edge Map",
+            "best quality, extremely detailed",
+            'longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair,extra digit, fewer digits, cropped, worst quality, low quality',
+            1,
+            512,
+            20,
+            9.0,
+            123490213,
+            0.0,
+            100,
+            200
+
+        ],
 
         [
             "turtle.png",
@@ -159,7 +215,25 @@ with block:
             20,
             9.0,
             123490213,
-            0.0
+            0.0,
+            100,
+            200
+
+        ],
+        [
+            "pose1.png",
+            "Chef in the Kitchen",
+            "Pose",
+            "best quality, extremely detailed",
+            'longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair,extra digit, fewer digits, cropped, worst quality, low quality',
+            1,
+            512,
+            20,
+            9.0,
+            123490213,
+            0.0,
+            100,
+            200
 
         ]
     ]
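With low_threshold and high_threshold appended to ips, the example rows appear to mirror ips positionally: every row now carries 13 values in that order, and since process only forwards the two thresholds to process_canny, the trailing 100/200 entries on the non-Canny rows are effectively ignored. A small guard along these lines (hypothetical, not part of this commit) next to examples_list would catch rows drifting out of sync as more control tasks are added:

# Hypothetical guard: each example row must supply one value per Gradio input in ips.
for row in examples_list:
    assert len(row) == len(ips), f"example row has {len(row)} values, expected {len(ips)}"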