RamAnanth1 commited on
Commit
02fa9c7
1 Parent(s): 6585503

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -62
app.py CHANGED
@@ -20,13 +20,6 @@ canny_checkpoint = "models/control_sd15_canny.pth"
20
  scribble_checkpoint = "models/control_sd15_scribble.pth"
21
  pose_checkpoint = "models/control_sd15_openpose.pth"
22
 
23
- canny_model = create_model('./models/cldm_v15.yaml').cpu()
24
- canny_model.load_state_dict(load_state_dict(cached_download(
25
- hf_hub_url(REPO_ID, canny_checkpoint)
26
- ), location='cuda'))
27
- canny_model = canny_model.cuda()
28
- ddim_sampler = DDIMSampler(canny_model)
29
-
30
  pose_model = create_model('./models/cldm_v15.yaml').cpu()
31
  pose_model.load_state_dict(load_state_dict(cached_download(
32
  hf_hub_url(REPO_ID, pose_checkpoint)
@@ -34,10 +27,10 @@ pose_model.load_state_dict(load_state_dict(cached_download(
34
  pose_model = pose_model.cuda()
35
  ddim_sampler_pose = DDIMSampler(pose_model)
36
 
37
- scribble_model = create_model('./models/cldm_v15.yaml')
38
  scribble_model.load_state_dict(load_state_dict(cached_download(
39
  hf_hub_url(REPO_ID, scribble_checkpoint)
40
- ), location='cpu'))
41
  scribble_model = canny_model.cuda()
42
  ddim_sampler_scribble = DDIMSampler(scribble_model)
43
 
@@ -45,38 +38,9 @@ def process(input_image, prompt, input_control, a_prompt, n_prompt, num_samples,
45
  # TODO: Add other control tasks
46
  if input_control == "Scribble":
47
  return process_scribble(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta)
48
- elif input_control == "Pose":
49
  return process_pose(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, image_resolution, ddim_steps, scale, seed, eta)
50
- return process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold)
51
 
52
- def process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
53
- with torch.no_grad():
54
- img = resize_image(HWC3(input_image), image_resolution)
55
- H, W, C = img.shape
56
-
57
- detected_map = apply_canny(img, low_threshold, high_threshold)
58
- detected_map = HWC3(detected_map)
59
-
60
- control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
61
- control = torch.stack([control for _ in range(num_samples)], dim=0)
62
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
63
-
64
- seed_everything(seed)
65
-
66
- cond = {"c_concat": [control], "c_crossattn": [canny_model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
67
- un_cond = {"c_concat": [control], "c_crossattn": [canny_model.get_learned_conditioning([n_prompt] * num_samples)]}
68
- shape = (4, H // 8, W // 8)
69
-
70
- samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
71
- shape, cond, verbose=False, eta=eta,
72
- unconditional_guidance_scale=scale,
73
- unconditional_conditioning=un_cond)
74
- x_samples = canny_model.decode_first_stage(samples)
75
- x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
76
-
77
- results = [x_samples[i] for i in range(num_samples)]
78
- return [255 - detected_map] + results
79
-
80
  def process_scribble(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta):
81
  with torch.no_grad():
82
  img = resize_image(HWC3(input_image), image_resolution)
@@ -150,7 +114,6 @@ def create_canvas(w, h):
150
 
151
  block = gr.Blocks().queue()
152
  control_task_list = [
153
- "Canny Edge Map",
154
  "Scribble",
155
  "Pose"
156
  ]
@@ -172,8 +135,6 @@ with block:
172
  with gr.Accordion("Advanced options", open=False):
173
  num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
174
  image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
175
- low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
176
- high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
177
  ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
178
  scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
179
  seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True)
@@ -183,25 +144,10 @@ with block:
183
  value='longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair,extra digit, fewer digits, cropped, worst quality, low quality')
184
  with gr.Column():
185
  result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
186
- ips = [input_image, prompt, input_control, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold]
187
  run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
188
  examples_list = [
189
- [
190
- "bird.png",
191
- "bird",
192
- "Canny Edge Map",
193
- "best quality, extremely detailed",
194
- 'longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair,extra digit, fewer digits, cropped, worst quality, low quality',
195
- 1,
196
- 512,
197
- 20,
198
- 9.0,
199
- 123490213,
200
- 0.0,
201
- 100,
202
- 200
203
-
204
- ],
205
  [
206
  "turtle.png",
207
  "turtle",
@@ -213,9 +159,7 @@ with block:
213
  20,
214
  9.0,
215
  123490213,
216
- 0.0,
217
- 100,
218
- 200
219
 
220
  ]
221
  ]
 
20
  scribble_checkpoint = "models/control_sd15_scribble.pth"
21
  pose_checkpoint = "models/control_sd15_openpose.pth"
22
 
 
 
 
 
 
 
 
23
  pose_model = create_model('./models/cldm_v15.yaml').cpu()
24
  pose_model.load_state_dict(load_state_dict(cached_download(
25
  hf_hub_url(REPO_ID, pose_checkpoint)
 
27
  pose_model = pose_model.cuda()
28
  ddim_sampler_pose = DDIMSampler(pose_model)
29
 
30
+ scribble_model = create_model('./models/cldm_v15.yaml').cpu()
31
  scribble_model.load_state_dict(load_state_dict(cached_download(
32
  hf_hub_url(REPO_ID, scribble_checkpoint)
33
+ ), location='cuda'))
34
  scribble_model = canny_model.cuda()
35
  ddim_sampler_scribble = DDIMSampler(scribble_model)
36
 
 
38
  # TODO: Add other control tasks
39
  if input_control == "Scribble":
40
  return process_scribble(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta)
41
+ else:
42
  return process_pose(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, image_resolution, ddim_steps, scale, seed, eta)
 
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  def process_scribble(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta):
45
  with torch.no_grad():
46
  img = resize_image(HWC3(input_image), image_resolution)
 
114
 
115
  block = gr.Blocks().queue()
116
  control_task_list = [
 
117
  "Scribble",
118
  "Pose"
119
  ]
 
135
  with gr.Accordion("Advanced options", open=False):
136
  num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
137
  image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
 
 
138
  ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
139
  scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
140
  seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True)
 
144
  value='longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair,extra digit, fewer digits, cropped, worst quality, low quality')
145
  with gr.Column():
146
  result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
147
+ ips = [input_image, prompt, input_control, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta]
148
  run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
149
  examples_list = [
150
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  [
152
  "turtle.png",
153
  "turtle",
 
159
  20,
160
  9.0,
161
  123490213,
162
+ 0.0
 
 
163
 
164
  ]
165
  ]