RamAnanth1 committed
Commit 6f9700a
1 Parent(s): 1f955e8

Update model.py

Files changed (1):
  1. model.py +2 -464
model.py CHANGED
@@ -27,14 +27,9 @@ from annotator.uniformer import apply_uniformer
 from annotator.util import HWC3, resize_image
 
 CONTROLNET_MODEL_IDS = {
-    'canny': 'lllyasviel/sd-controlnet-canny',
-    'hough': 'lllyasviel/sd-controlnet-mlsd',
-    'hed': 'lllyasviel/sd-controlnet-hed',
-    'scribble': 'lllyasviel/sd-controlnet-scribble',
-    'pose': 'lllyasviel/sd-controlnet-openpose',
-    'seg': 'lllyasviel/sd-controlnet-seg',
+
     'depth': 'lllyasviel/sd-controlnet-depth',
-    'normal': 'lllyasviel/sd-controlnet-normal',
+
 }
 
 
@@ -131,405 +126,6 @@ class Model:
                          generator=generator,
                          image=control_image).images
 
-    @staticmethod
-    def preprocess_canny(
-        input_image: np.ndarray,
-        image_resolution: int,
-        low_threshold: int,
-        high_threshold: int,
-    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
-        image = resize_image(HWC3(input_image), image_resolution)
-        control_image = apply_canny(image, low_threshold, high_threshold)
-        control_image = HWC3(control_image)
-        vis_control_image = 255 - control_image
-        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
-            vis_control_image)
-
-    @torch.inference_mode()
-    def process_canny(
-        self,
-        input_image: np.ndarray,
-        prompt: str,
-        additional_prompt: str,
-        negative_prompt: str,
-        num_images: int,
-        image_resolution: int,
-        num_steps: int,
-        guidance_scale: float,
-        seed: int,
-        low_threshold: int,
-        high_threshold: int,
-    ) -> list[PIL.Image.Image]:
-        control_image, vis_control_image = self.preprocess_canny(
-            input_image=input_image,
-            image_resolution=image_resolution,
-            low_threshold=low_threshold,
-            high_threshold=high_threshold,
-        )
-        self.load_controlnet_weight('canny')
-        results = self.run_pipe(
-            prompt=self.get_prompt(prompt, additional_prompt),
-            negative_prompt=negative_prompt,
-            control_image=control_image,
-            num_images=num_images,
-            num_steps=num_steps,
-            guidance_scale=guidance_scale,
-            seed=seed,
-        )
-        return [vis_control_image] + results
-
-    @staticmethod
-    def preprocess_hough(
-        input_image: np.ndarray,
-        image_resolution: int,
-        detect_resolution: int,
-        value_threshold: float,
-        distance_threshold: float,
-    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
-        input_image = HWC3(input_image)
-        control_image = apply_mlsd(
-            resize_image(input_image, detect_resolution), value_threshold,
-            distance_threshold)
-        control_image = HWC3(control_image)
-        image = resize_image(input_image, image_resolution)
-        H, W = image.shape[:2]
-        control_image = cv2.resize(control_image, (W, H),
-                                   interpolation=cv2.INTER_NEAREST)
-
-        vis_control_image = 255 - cv2.dilate(
-            control_image, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
-
-        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
-            vis_control_image)
-
-    @torch.inference_mode()
-    def process_hough(
-        self,
-        input_image: np.ndarray,
-        prompt: str,
-        additional_prompt: str,
-        negative_prompt: str,
-        num_images: int,
-        image_resolution: int,
-        detect_resolution: int,
-        num_steps: int,
-        guidance_scale: float,
-        seed: int,
-        value_threshold: float,
-        distance_threshold: float,
-    ) -> list[PIL.Image.Image]:
-        control_image, vis_control_image = self.preprocess_hough(
-            input_image=input_image,
-            image_resolution=image_resolution,
-            detect_resolution=detect_resolution,
-            value_threshold=value_threshold,
-            distance_threshold=distance_threshold,
-        )
-        self.load_controlnet_weight('hough')
-        results = self.run_pipe(
-            prompt=self.get_prompt(prompt, additional_prompt),
-            negative_prompt=negative_prompt,
-            control_image=control_image,
-            num_images=num_images,
-            num_steps=num_steps,
-            guidance_scale=guidance_scale,
-            seed=seed,
-        )
-        return [vis_control_image] + results
-
-    @staticmethod
-    def preprocess_hed(
-        input_image: np.ndarray,
-        image_resolution: int,
-        detect_resolution: int,
-    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
-        input_image = HWC3(input_image)
-        control_image = apply_hed(resize_image(input_image, detect_resolution))
-        control_image = HWC3(control_image)
-        image = resize_image(input_image, image_resolution)
-        H, W = image.shape[:2]
-        control_image = cv2.resize(control_image, (W, H),
-                                   interpolation=cv2.INTER_LINEAR)
-        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
-            control_image)
-
-    @torch.inference_mode()
-    def process_hed(
-        self,
-        input_image: np.ndarray,
-        prompt: str,
-        additional_prompt: str,
-        negative_prompt: str,
-        num_images: int,
-        image_resolution: int,
-        detect_resolution: int,
-        num_steps: int,
-        guidance_scale: float,
-        seed: int,
-    ) -> list[PIL.Image.Image]:
-        control_image, vis_control_image = self.preprocess_hed(
-            input_image=input_image,
-            image_resolution=image_resolution,
-            detect_resolution=detect_resolution,
-        )
-        self.load_controlnet_weight('hed')
-        results = self.run_pipe(
-            prompt=self.get_prompt(prompt, additional_prompt),
-            negative_prompt=negative_prompt,
-            control_image=control_image,
-            num_images=num_images,
-            num_steps=num_steps,
-            guidance_scale=guidance_scale,
-            seed=seed,
-        )
-        return [vis_control_image] + results
-
-    @staticmethod
-    def preprocess_scribble(
-        input_image: np.ndarray,
-        image_resolution: int,
-    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
-        image = resize_image(HWC3(input_image), image_resolution)
-        control_image = np.zeros_like(image, dtype=np.uint8)
-        control_image[np.min(image, axis=2) < 127] = 255
-        vis_control_image = 255 - control_image
-        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
-            vis_control_image)
-
-    @torch.inference_mode()
-    def process_scribble(
-        self,
-        input_image: np.ndarray,
-        prompt: str,
-        additional_prompt: str,
-        negative_prompt: str,
-        num_images: int,
-        image_resolution: int,
-        num_steps: int,
-        guidance_scale: float,
-        seed: int,
-    ) -> list[PIL.Image.Image]:
-        control_image, vis_control_image = self.preprocess_scribble(
-            input_image=input_image,
-            image_resolution=image_resolution,
-        )
-        self.load_controlnet_weight('scribble')
-        results = self.run_pipe(
-            prompt=self.get_prompt(prompt, additional_prompt),
-            negative_prompt=negative_prompt,
-            control_image=control_image,
-            num_images=num_images,
-            num_steps=num_steps,
-            guidance_scale=guidance_scale,
-            seed=seed,
-        )
-        return [vis_control_image] + results
-
-    @staticmethod
-    def preprocess_scribble_interactive(
-        input_image: np.ndarray,
-        image_resolution: int,
-    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
-        image = resize_image(HWC3(input_image['mask'][:, :, 0]),
-                             image_resolution)
-        control_image = np.zeros_like(image, dtype=np.uint8)
-        control_image[np.min(image, axis=2) > 127] = 255
-        vis_control_image = 255 - control_image
-        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
-            vis_control_image)
-
-    @torch.inference_mode()
-    def process_scribble_interactive(
-        self,
-        input_image: np.ndarray,
-        prompt: str,
-        additional_prompt: str,
-        negative_prompt: str,
-        num_images: int,
-        image_resolution: int,
-        num_steps: int,
-        guidance_scale: float,
-        seed: int,
-    ) -> list[PIL.Image.Image]:
-        control_image, vis_control_image = self.preprocess_scribble_interactive(
-            input_image=input_image,
-            image_resolution=image_resolution,
-        )
-        self.load_controlnet_weight('scribble')
-        results = self.run_pipe(
-            prompt=self.get_prompt(prompt, additional_prompt),
-            negative_prompt=negative_prompt,
-            control_image=control_image,
-            num_images=num_images,
-            num_steps=num_steps,
-            guidance_scale=guidance_scale,
-            seed=seed,
-        )
-        return [vis_control_image] + results
-
-    @staticmethod
-    def preprocess_fake_scribble(
-        input_image: np.ndarray,
-        image_resolution: int,
-        detect_resolution: int,
-    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
-        input_image = HWC3(input_image)
-        control_image = apply_hed(resize_image(input_image, detect_resolution))
-        control_image = HWC3(control_image)
-        image = resize_image(input_image, image_resolution)
-        H, W = image.shape[:2]
-
-        control_image = cv2.resize(control_image, (W, H),
-                                   interpolation=cv2.INTER_LINEAR)
-        control_image = nms(control_image, 127, 3.0)
-        control_image = cv2.GaussianBlur(control_image, (0, 0), 3.0)
-        control_image[control_image > 4] = 255
-        control_image[control_image < 255] = 0
-
-        vis_control_image = 255 - control_image
-
-        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
-            vis_control_image)
-
-    @torch.inference_mode()
-    def process_fake_scribble(
-        self,
-        input_image: np.ndarray,
-        prompt: str,
-        additional_prompt: str,
-        negative_prompt: str,
-        num_images: int,
-        image_resolution: int,
-        detect_resolution: int,
-        num_steps: int,
-        guidance_scale: float,
-        seed: int,
-    ) -> list[PIL.Image.Image]:
-        control_image, vis_control_image = self.preprocess_fake_scribble(
-            input_image=input_image,
-            image_resolution=image_resolution,
-            detect_resolution=detect_resolution,
-        )
-        self.load_controlnet_weight('scribble')
-        results = self.run_pipe(
-            prompt=self.get_prompt(prompt, additional_prompt),
-            negative_prompt=negative_prompt,
-            control_image=control_image,
-            num_images=num_images,
-            num_steps=num_steps,
-            guidance_scale=guidance_scale,
-            seed=seed,
-        )
-        return [vis_control_image] + results
-
-    @staticmethod
-    def preprocess_pose(
-        input_image: np.ndarray,
-        image_resolution: int,
-        detect_resolution: int,
-        is_pose_image: bool,
-    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
-        input_image = HWC3(input_image)
-        if not is_pose_image:
-            control_image, _ = apply_openpose(
-                resize_image(input_image, detect_resolution))
-            control_image = HWC3(control_image)
-            image = resize_image(input_image, image_resolution)
-            H, W = image.shape[:2]
-            control_image = cv2.resize(control_image, (W, H),
-                                       interpolation=cv2.INTER_NEAREST)
-        else:
-            control_image = resize_image(input_image, image_resolution)
-
-        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
-            control_image)
-
-    @torch.inference_mode()
-    def process_pose(
-        self,
-        input_image: np.ndarray,
-        prompt: str,
-        additional_prompt: str,
-        negative_prompt: str,
-        num_images: int,
-        image_resolution: int,
-        detect_resolution: int,
-        num_steps: int,
-        guidance_scale: float,
-        seed: int,
-        is_pose_image: bool,
-    ) -> list[PIL.Image.Image]:
-        control_image, vis_control_image = self.preprocess_pose(
-            input_image=input_image,
-            image_resolution=image_resolution,
-            detect_resolution=detect_resolution,
-            is_pose_image=is_pose_image,
-        )
-        self.load_controlnet_weight('pose')
-        results = self.run_pipe(
-            prompt=self.get_prompt(prompt, additional_prompt),
-            negative_prompt=negative_prompt,
-            control_image=control_image,
-            num_images=num_images,
-            num_steps=num_steps,
-            guidance_scale=guidance_scale,
-            seed=seed,
-        )
-        return [vis_control_image] + results
-
-    @staticmethod
-    def preprocess_seg(
-        input_image: np.ndarray,
-        image_resolution: int,
-        detect_resolution: int,
-        is_segmentation_map: bool,
-    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
-        input_image = HWC3(input_image)
-        if not is_segmentation_map:
-            control_image = apply_uniformer(
-                resize_image(input_image, detect_resolution))
-            image = resize_image(input_image, image_resolution)
-            H, W = image.shape[:2]
-            control_image = cv2.resize(control_image, (W, H),
-                                       interpolation=cv2.INTER_NEAREST)
-        else:
-            control_image = resize_image(input_image, image_resolution)
-        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
-            control_image)
-
-    @torch.inference_mode()
-    def process_seg(
-        self,
-        input_image: np.ndarray,
-        prompt: str,
-        additional_prompt: str,
-        negative_prompt: str,
-        num_images: int,
-        image_resolution: int,
-        detect_resolution: int,
-        num_steps: int,
-        guidance_scale: float,
-        seed: int,
-        is_segmentation_map: bool,
-    ) -> list[PIL.Image.Image]:
-        control_image, vis_control_image = self.preprocess_seg(
-            input_image=input_image,
-            image_resolution=image_resolution,
-            detect_resolution=detect_resolution,
-            is_segmentation_map=is_segmentation_map,
-        )
-        self.load_controlnet_weight('seg')
-        results = self.run_pipe(
-            prompt=self.get_prompt(prompt, additional_prompt),
-            negative_prompt=negative_prompt,
-            control_image=control_image,
-            num_images=num_images,
-            num_steps=num_steps,
-            guidance_scale=guidance_scale,
-            seed=seed,
-        )
-        return [vis_control_image] + results
-
     @staticmethod
     def preprocess_depth(
         input_image: np.ndarray,
@@ -583,61 +179,3 @@ class Model:
             seed=seed,
         )
         return [vis_control_image] + results
-
-    @staticmethod
-    def preprocess_normal(
-        input_image: np.ndarray,
-        image_resolution: int,
-        detect_resolution: int,
-        bg_threshold: float,
-        is_normal_image: bool,
-    ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
-        input_image = HWC3(input_image)
-        if not is_normal_image:
-            _, control_image = apply_midas(resize_image(
-                input_image, detect_resolution),
-                                           bg_th=bg_threshold)
-            control_image = HWC3(control_image)
-            image = resize_image(input_image, image_resolution)
-            H, W = image.shape[:2]
-            control_image = cv2.resize(control_image, (W, H),
-                                       interpolation=cv2.INTER_LINEAR)
-        else:
-            control_image = resize_image(input_image, image_resolution)
-        return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
-            control_image)
-
-    @torch.inference_mode()
-    def process_normal(
-        self,
-        input_image: np.ndarray,
-        prompt: str,
-        additional_prompt: str,
-        negative_prompt: str,
-        num_images: int,
-        image_resolution: int,
-        detect_resolution: int,
-        num_steps: int,
-        guidance_scale: float,
-        seed: int,
-        bg_threshold: float,
-        is_normal_image: bool,
-    ) -> list[PIL.Image.Image]:
-        control_image, vis_control_image = self.preprocess_normal(
-            input_image=input_image,
-            image_resolution=image_resolution,
-            detect_resolution=detect_resolution,
-            bg_threshold=bg_threshold,
-            is_normal_image=is_normal_image,
-        )
-        self.load_controlnet_weight('normal')
-        results = self.run_pipe(
-            prompt=self.get_prompt(prompt, additional_prompt),
-            negative_prompt=negative_prompt,
-            control_image=control_image,
-            num_images=num_images,
-            num_steps=num_steps,
-            guidance_scale=guidance_scale,
-            seed=seed,
-        )
-        return [vis_control_image] + results
 
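After this commit only the depth path survives in model.py; the other preprocessors are gone from the Space. For readers who want to reproduce one of the deleted steps outside the app, here is a minimal stand-in for the removed preprocess_canny, assuming an RGB uint8 input. It calls cv2.Canny directly instead of the annotator's apply_canny wrapper and inlines rough equivalents of resize_image and HWC3 from annotator.util, so the helper name and defaults below are illustrative, not this repo's exact code.

import cv2
import numpy as np
import PIL.Image


def preprocess_canny_sketch(
        input_image: np.ndarray,
        image_resolution: int = 512,
        low_threshold: int = 100,
        high_threshold: int = 200,
) -> tuple[PIL.Image.Image, PIL.Image.Image]:
    # Rough stand-in for annotator.util.resize_image: scale the shorter
    # side to image_resolution (the original also snaps to multiples of 64).
    h, w = input_image.shape[:2]
    scale = image_resolution / min(h, w)
    image = cv2.resize(input_image, (round(w * scale), round(h * scale)))
    # Edge map; the annotator's apply_canny is a thin wrapper over cv2.Canny.
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, low_threshold, high_threshold)
    # Stand-in for HWC3: replicate the single channel into three channels.
    control_image = np.stack([edges] * 3, axis=2)
    # The app shows the user an inverted copy (dark edges on white).
    vis_control_image = 255 - control_image
    return (PIL.Image.fromarray(control_image),
            PIL.Image.fromarray(vis_control_image))

The returned pair mirrors what the deleted process_canny produced: the first image conditions the ControlNet via run_pipe, the second exists only for display.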
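With CONTROLNET_MODEL_IDS reduced to the single 'depth' entry, the Space now only ever loads lllyasviel/sd-controlnet-depth. Below is a minimal sketch of that surviving path using the diffusers API these model IDs correspond to; the base checkpoint, prompts, and the zeroed placeholder control image are assumptions for illustration, not values taken from this repo.

import numpy as np
import PIL.Image
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

# The one checkpoint left in CONTROLNET_MODEL_IDS after this commit.
controlnet = ControlNetModel.from_pretrained(
    'lllyasviel/sd-controlnet-depth', torch_dtype=torch.float16)

# Assumed base model; the Space's actual base checkpoint may differ.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    'runwayml/stable-diffusion-v1-5',
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to('cuda')

# Placeholder control image; in the app this comes from preprocess_depth,
# which runs apply_midas and resizes the result to the target resolution.
depth_map = PIL.Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))

generator = torch.Generator().manual_seed(0)
images = pipe(prompt='a tidy room, best quality',
              negative_prompt='lowres, bad anatomy',
              image=depth_map,
              num_inference_steps=20,
              guidance_scale=9.0,
              generator=generator).images

Keeping a one-entry dict rather than hard-coding the ID preserves the load_controlnet_weight('depth') lookup that the remaining process method relies on, so the depth path runs unchanged.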