hysts HF staff committed on
Commit
0ae9725
•
1 Parent(s): 83f6448

Update to use diffusers

Browse files
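In short: sampling no longer goes through the vendored ControlNet/DDIM code; model.py now builds a diffusers StableDiffusionControlNetPipeline. A minimal sketch of that pattern (model IDs and API calls are taken from the new model.py in the diff below; the base model choice, prompt, and defaults here are only illustrative):

import PIL.Image
import torch
from diffusers import (ControlNetModel, StableDiffusionControlNetPipeline,
                       UniPCMultistepScheduler)

# ControlNet weights now come from the Hub instead of wget-ed .pth/.safetensors files.
controlnet = ControlNetModel.from_pretrained('lllyasviel/sd-controlnet-canny',
                                             torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    'andite/anything-v4.0',  # base model hard-coded in the new model.py (marked FIXME there)
    controlnet=controlnet,
    safety_checker=None,
    torch_dtype=torch.float16)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_xformers_memory_efficient_attention()
pipe.enable_model_cpu_offload()

control_image = PIL.Image.new('RGB', (512, 512))  # stand-in for a real Canny edge map
images = pipe('a cat', image=control_image, num_inference_steps=20,
              guidance_scale=9.0, generator=torch.Generator().manual_seed(0)).images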
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🌖
  colorFrom: pink
  colorTo: blue
  sdk: gradio
- sdk_version: 3.18.0
+ sdk_version: 3.20.0
  python_version: 3.10.9
  app_file: app.py
  pinned: false
gradio_canny2image.py CHANGED
@@ -23,33 +23,33 @@ def create_demo(process, max_images=12):
23
  maximum=768,
24
  value=512,
25
  step=256)
26
- low_threshold = gr.Slider(label='Canny low threshold',
27
- minimum=1,
28
- maximum=255,
29
- value=100,
30
- step=1)
31
- high_threshold = gr.Slider(label='Canny high threshold',
32
- minimum=1,
33
- maximum=255,
34
- value=200,
35
- step=1)
36
- ddim_steps = gr.Slider(label='Steps',
37
- minimum=1,
38
- maximum=100,
39
- value=20,
40
- step=1)
41
- scale = gr.Slider(label='Guidance Scale',
42
- minimum=0.1,
43
- maximum=30.0,
44
- value=9.0,
45
- step=0.1)
 
 
46
  seed = gr.Slider(label='Seed',
47
  minimum=-1,
48
  maximum=2147483647,
49
  step=1,
50
- randomize=True,
51
- queue=False)
52
- eta = gr.Number(label='eta (DDIM)', value=0.0)
53
  a_prompt = gr.Textbox(
54
  label='Added Prompt',
55
  value='best quality, extremely detailed')
@@ -59,17 +59,25 @@ def create_demo(process, max_images=12):
59
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
60
  )
61
  with gr.Column():
62
- result_gallery = gr.Gallery(label='Output',
63
- show_label=False,
64
- elem_id='gallery').style(
65
- grid=2, height='auto')
66
- ips = [
67
- input_image, prompt, a_prompt, n_prompt, num_samples,
68
- image_resolution, ddim_steps, scale, seed, eta, low_threshold,
69
- high_threshold
 
 
 
 
 
 
 
 
70
  ]
71
  run_button.click(fn=process,
72
- inputs=ips,
73
- outputs=[result_gallery],
74
  api_name='canny')
75
  return demo
 
23
  maximum=768,
24
  value=512,
25
  step=256)
26
+ canny_low_threshold = gr.Slider(
27
+ label='Canny low threshold',
28
+ minimum=1,
29
+ maximum=255,
30
+ value=100,
31
+ step=1)
32
+ canny_high_threshold = gr.Slider(
33
+ label='Canny high threshold',
34
+ minimum=1,
35
+ maximum=255,
36
+ value=200,
37
+ step=1)
38
+ num_steps = gr.Slider(label='Steps',
39
+ minimum=1,
40
+ maximum=100,
41
+ value=20,
42
+ step=1)
43
+ guidance_scale = gr.Slider(label='Guidance Scale',
44
+ minimum=0.1,
45
+ maximum=30.0,
46
+ value=9.0,
47
+ step=0.1)
48
  seed = gr.Slider(label='Seed',
49
  minimum=-1,
50
  maximum=2147483647,
51
  step=1,
52
+ randomize=True)
 
 
53
  a_prompt = gr.Textbox(
54
  label='Added Prompt',
55
  value='best quality, extremely detailed')
 
59
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
60
  )
61
  with gr.Column():
62
+ result = gr.Gallery(label='Output',
63
+ show_label=False,
64
+ elem_id='gallery').style(grid=2,
65
+ height='auto')
66
+ inputs = [
67
+ input_image,
68
+ prompt,
69
+ a_prompt,
70
+ n_prompt,
71
+ num_samples,
72
+ image_resolution,
73
+ num_steps,
74
+ guidance_scale,
75
+ seed,
76
+ canny_low_threshold,
77
+ canny_high_threshold,
78
  ]
79
  run_button.click(fn=process,
80
+ inputs=inputs,
81
+ outputs=result,
82
  api_name='canny')
83
  return demo
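The other gradio_*.py files below get the same treatment as this one: ddim_steps/scale become num_steps/guidance_scale, the DDIM-only eta field and the queue=False argument are dropped, and the gallery is passed to outputs directly. For context, a demo built this way would presumably be launched from app.py roughly as follows (app.py is not part of this commit, so the wiring below is an assumption):

# Hypothetical wiring; app.py is not shown in this diff.
from gradio_canny2image import create_demo
from model import Model

model = Model()
demo = create_demo(model.process_canny)  # the Model method for this task
demo.queue().launch()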
gradio_depth2image.py CHANGED
@@ -28,23 +28,21 @@ def create_demo(process, max_images=12):
28
  maximum=1024,
29
  value=384,
30
  step=1)
31
- ddim_steps = gr.Slider(label='Steps',
32
- minimum=1,
33
- maximum=100,
34
- value=20,
35
- step=1)
36
- scale = gr.Slider(label='Guidance Scale',
37
- minimum=0.1,
38
- maximum=30.0,
39
- value=9.0,
40
- step=0.1)
41
  seed = gr.Slider(label='Seed',
42
  minimum=-1,
43
  maximum=2147483647,
44
  step=1,
45
- randomize=True,
46
- queue=False)
47
- eta = gr.Number(label='eta (DDIM)', value=0.0)
48
  a_prompt = gr.Textbox(
49
  label='Added Prompt',
50
  value='best quality, extremely detailed')
@@ -54,16 +52,24 @@ def create_demo(process, max_images=12):
54
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
55
  )
56
  with gr.Column():
57
- result_gallery = gr.Gallery(label='Output',
58
- show_label=False,
59
- elem_id='gallery').style(
60
- grid=2, height='auto')
61
- ips = [
62
- input_image, prompt, a_prompt, n_prompt, num_samples,
63
- image_resolution, detect_resolution, ddim_steps, scale, seed, eta
 
 
 
 
 
 
 
 
64
  ]
65
  run_button.click(fn=process,
66
- inputs=ips,
67
- outputs=[result_gallery],
68
  api_name='depth')
69
  return demo
 
28
  maximum=1024,
29
  value=384,
30
  step=1)
31
+ num_steps = gr.Slider(label='Steps',
32
+ minimum=1,
33
+ maximum=100,
34
+ value=20,
35
+ step=1)
36
+ guidance_scale = gr.Slider(label='Guidance Scale',
37
+ minimum=0.1,
38
+ maximum=30.0,
39
+ value=9.0,
40
+ step=0.1)
41
  seed = gr.Slider(label='Seed',
42
  minimum=-1,
43
  maximum=2147483647,
44
  step=1,
45
+ randomize=True)
 
 
46
  a_prompt = gr.Textbox(
47
  label='Added Prompt',
48
  value='best quality, extremely detailed')
 
52
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
53
  )
54
  with gr.Column():
55
+ result = gr.Gallery(label='Output',
56
+ show_label=False,
57
+ elem_id='gallery').style(grid=2,
58
+ height='auto')
59
+ inputs = [
60
+ input_image,
61
+ prompt,
62
+ a_prompt,
63
+ n_prompt,
64
+ num_samples,
65
+ image_resolution,
66
+ detect_resolution,
67
+ num_steps,
68
+ guidance_scale,
69
+ seed,
70
  ]
71
  run_button.click(fn=process,
72
+ inputs=inputs,
73
+ outputs=result,
74
  api_name='depth')
75
  return demo
gradio_fake_scribble2image.py CHANGED
@@ -28,23 +28,21 @@ def create_demo(process, max_images=12):
28
  maximum=1024,
29
  value=512,
30
  step=1)
31
- ddim_steps = gr.Slider(label='Steps',
32
- minimum=1,
33
- maximum=100,
34
- value=20,
35
- step=1)
36
- scale = gr.Slider(label='Guidance Scale',
37
- minimum=0.1,
38
- maximum=30.0,
39
- value=9.0,
40
- step=0.1)
41
  seed = gr.Slider(label='Seed',
42
  minimum=-1,
43
  maximum=2147483647,
44
  step=1,
45
- randomize=True,
46
- queue=False)
47
- eta = gr.Number(label='eta (DDIM)', value=0.0)
48
  a_prompt = gr.Textbox(
49
  label='Added Prompt',
50
  value='best quality, extremely detailed')
@@ -54,16 +52,24 @@ def create_demo(process, max_images=12):
54
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
55
  )
56
  with gr.Column():
57
- result_gallery = gr.Gallery(label='Output',
58
- show_label=False,
59
- elem_id='gallery').style(
60
- grid=2, height='auto')
61
- ips = [
62
- input_image, prompt, a_prompt, n_prompt, num_samples,
63
- image_resolution, detect_resolution, ddim_steps, scale, seed, eta
 
 
 
 
 
 
 
 
64
  ]
65
  run_button.click(fn=process,
66
- inputs=ips,
67
- outputs=[result_gallery],
68
  api_name='fake_scribble')
69
  return demo
 
28
  maximum=1024,
29
  value=512,
30
  step=1)
31
+ num_steps = gr.Slider(label='Steps',
32
+ minimum=1,
33
+ maximum=100,
34
+ value=20,
35
+ step=1)
36
+ guidance_scale = gr.Slider(label='Guidance Scale',
37
+ minimum=0.1,
38
+ maximum=30.0,
39
+ value=9.0,
40
+ step=0.1)
41
  seed = gr.Slider(label='Seed',
42
  minimum=-1,
43
  maximum=2147483647,
44
  step=1,
45
+ randomize=True)
 
 
46
  a_prompt = gr.Textbox(
47
  label='Added Prompt',
48
  value='best quality, extremely detailed')
 
52
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
53
  )
54
  with gr.Column():
55
+ result = gr.Gallery(label='Output',
56
+ show_label=False,
57
+ elem_id='gallery').style(grid=2,
58
+ height='auto')
59
+ inputs = [
60
+ input_image,
61
+ prompt,
62
+ a_prompt,
63
+ n_prompt,
64
+ num_samples,
65
+ image_resolution,
66
+ detect_resolution,
67
+ num_steps,
68
+ guidance_scale,
69
+ seed,
70
  ]
71
  run_button.click(fn=process,
72
+ inputs=inputs,
73
+ outputs=result,
74
  api_name='fake_scribble')
75
  return demo
gradio_hed2image.py CHANGED
@@ -28,23 +28,21 @@ def create_demo(process, max_images=12):
28
  maximum=1024,
29
  value=512,
30
  step=1)
31
- ddim_steps = gr.Slider(label='Steps',
32
- minimum=1,
33
- maximum=100,
34
- value=20,
35
- step=1)
36
- scale = gr.Slider(label='Guidance Scale',
37
- minimum=0.1,
38
- maximum=30.0,
39
- value=9.0,
40
- step=0.1)
41
  seed = gr.Slider(label='Seed',
42
  minimum=-1,
43
  maximum=2147483647,
44
  step=1,
45
- randomize=True,
46
- queue=False)
47
- eta = gr.Number(label='eta (DDIM)', value=0.0)
48
  a_prompt = gr.Textbox(
49
  label='Added Prompt',
50
  value='best quality, extremely detailed')
@@ -54,16 +52,24 @@ def create_demo(process, max_images=12):
54
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
55
  )
56
  with gr.Column():
57
- result_gallery = gr.Gallery(label='Output',
58
- show_label=False,
59
- elem_id='gallery').style(
60
- grid=2, height='auto')
61
- ips = [
62
- input_image, prompt, a_prompt, n_prompt, num_samples,
63
- image_resolution, detect_resolution, ddim_steps, scale, seed, eta
 
 
 
 
 
 
 
 
64
  ]
65
  run_button.click(fn=process,
66
- inputs=ips,
67
- outputs=[result_gallery],
68
  api_name='hed')
69
  return demo
 
28
  maximum=1024,
29
  value=512,
30
  step=1)
31
+ num_steps = gr.Slider(label='Steps',
32
+ minimum=1,
33
+ maximum=100,
34
+ value=20,
35
+ step=1)
36
+ guidance_scale = gr.Slider(label='Guidance Scale',
37
+ minimum=0.1,
38
+ maximum=30.0,
39
+ value=9.0,
40
+ step=0.1)
41
  seed = gr.Slider(label='Seed',
42
  minimum=-1,
43
  maximum=2147483647,
44
  step=1,
45
+ randomize=True)
 
 
46
  a_prompt = gr.Textbox(
47
  label='Added Prompt',
48
  value='best quality, extremely detailed')
 
52
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
53
  )
54
  with gr.Column():
55
+ result = gr.Gallery(label='Output',
56
+ show_label=False,
57
+ elem_id='gallery').style(grid=2,
58
+ height='auto')
59
+ inputs = [
60
+ input_image,
61
+ prompt,
62
+ a_prompt,
63
+ n_prompt,
64
+ num_samples,
65
+ image_resolution,
66
+ detect_resolution,
67
+ num_steps,
68
+ guidance_scale,
69
+ seed,
70
  ]
71
  run_button.click(fn=process,
72
+ inputs=inputs,
73
+ outputs=result,
74
  api_name='hed')
75
  return demo
gradio_hough2image.py CHANGED
@@ -28,35 +28,33 @@ def create_demo(process, max_images=12):
28
  maximum=1024,
29
  value=512,
30
  step=1)
31
- value_threshold = gr.Slider(
32
  label='Hough value threshold (MLSD)',
33
  minimum=0.01,
34
  maximum=2.0,
35
  value=0.1,
36
  step=0.01)
37
- distance_threshold = gr.Slider(
38
  label='Hough distance threshold (MLSD)',
39
  minimum=0.01,
40
  maximum=20.0,
41
  value=0.1,
42
  step=0.01)
43
- ddim_steps = gr.Slider(label='Steps',
44
- minimum=1,
45
- maximum=100,
46
- value=20,
47
- step=1)
48
- scale = gr.Slider(label='Guidance Scale',
49
- minimum=0.1,
50
- maximum=30.0,
51
- value=9.0,
52
- step=0.1)
53
  seed = gr.Slider(label='Seed',
54
  minimum=-1,
55
  maximum=2147483647,
56
  step=1,
57
- randomize=True,
58
- queue=False)
59
- eta = gr.Number(label='eta (DDIM)', value=0.0)
60
  a_prompt = gr.Textbox(
61
  label='Added Prompt',
62
  value='best quality, extremely detailed')
@@ -66,17 +64,26 @@ def create_demo(process, max_images=12):
66
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
67
  )
68
  with gr.Column():
69
- result_gallery = gr.Gallery(label='Output',
70
- show_label=False,
71
- elem_id='gallery').style(
72
- grid=2, height='auto')
73
- ips = [
74
- input_image, prompt, a_prompt, n_prompt, num_samples,
75
- image_resolution, detect_resolution, ddim_steps, scale, seed, eta,
76
- value_threshold, distance_threshold
 
 
 
 
 
 
 
 
 
77
  ]
78
  run_button.click(fn=process,
79
- inputs=ips,
80
- outputs=[result_gallery],
81
  api_name='hough')
82
  return demo
 
28
  maximum=1024,
29
  value=512,
30
  step=1)
31
+ mlsd_value_threshold = gr.Slider(
32
  label='Hough value threshold (MLSD)',
33
  minimum=0.01,
34
  maximum=2.0,
35
  value=0.1,
36
  step=0.01)
37
+ mlsd_distance_threshold = gr.Slider(
38
  label='Hough distance threshold (MLSD)',
39
  minimum=0.01,
40
  maximum=20.0,
41
  value=0.1,
42
  step=0.01)
43
+ num_steps = gr.Slider(label='Steps',
44
+ minimum=1,
45
+ maximum=100,
46
+ value=20,
47
+ step=1)
48
+ guidance_scale = gr.Slider(label='Guidance Scale',
49
+ minimum=0.1,
50
+ maximum=30.0,
51
+ value=9.0,
52
+ step=0.1)
53
  seed = gr.Slider(label='Seed',
54
  minimum=-1,
55
  maximum=2147483647,
56
  step=1,
57
+ randomize=True)
 
 
58
  a_prompt = gr.Textbox(
59
  label='Added Prompt',
60
  value='best quality, extremely detailed')
 
64
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
65
  )
66
  with gr.Column():
67
+ result = gr.Gallery(label='Output',
68
+ show_label=False,
69
+ elem_id='gallery').style(grid=2,
70
+ height='auto')
71
+ inputs = [
72
+ input_image,
73
+ prompt,
74
+ a_prompt,
75
+ n_prompt,
76
+ num_samples,
77
+ image_resolution,
78
+ detect_resolution,
79
+ num_steps,
80
+ guidance_scale,
81
+ seed,
82
+ mlsd_value_threshold,
83
+ mlsd_distance_threshold,
84
  ]
85
  run_button.click(fn=process,
86
+ inputs=inputs,
87
+ outputs=result,
88
  api_name='hough')
89
  return demo
gradio_normal2image.py CHANGED
@@ -34,23 +34,21 @@ def create_demo(process, max_images=12):
34
  maximum=1.0,
35
  value=0.4,
36
  step=0.01)
37
- ddim_steps = gr.Slider(label='Steps',
38
- minimum=1,
39
- maximum=100,
40
- value=20,
41
- step=1)
42
- scale = gr.Slider(label='Guidance Scale',
43
- minimum=0.1,
44
- maximum=30.0,
45
- value=9.0,
46
- step=0.1)
47
  seed = gr.Slider(label='Seed',
48
  minimum=-1,
49
  maximum=2147483647,
50
  step=1,
51
- randomize=True,
52
- queue=False)
53
- eta = gr.Number(label='eta (DDIM)', value=0.0)
54
  a_prompt = gr.Textbox(
55
  label='Added Prompt',
56
  value='best quality, extremely detailed')
@@ -60,17 +58,25 @@ def create_demo(process, max_images=12):
60
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
61
  )
62
  with gr.Column():
63
- result_gallery = gr.Gallery(label='Output',
64
- show_label=False,
65
- elem_id='gallery').style(
66
- grid=2, height='auto')
67
- ips = [
68
- input_image, prompt, a_prompt, n_prompt, num_samples,
69
- image_resolution, detect_resolution, ddim_steps, scale, seed, eta,
70
- bg_threshold
 
 
 
 
 
 
 
 
71
  ]
72
  run_button.click(fn=process,
73
- inputs=ips,
74
- outputs=[result_gallery],
75
  api_name='normal')
76
  return demo
 
34
  maximum=1.0,
35
  value=0.4,
36
  step=0.01)
37
+ num_steps = gr.Slider(label='Steps',
38
+ minimum=1,
39
+ maximum=100,
40
+ value=20,
41
+ step=1)
42
+ guidance_scale = gr.Slider(label='Guidance Scale',
43
+ minimum=0.1,
44
+ maximum=30.0,
45
+ value=9.0,
46
+ step=0.1)
47
  seed = gr.Slider(label='Seed',
48
  minimum=-1,
49
  maximum=2147483647,
50
  step=1,
51
+ randomize=True)
 
 
52
  a_prompt = gr.Textbox(
53
  label='Added Prompt',
54
  value='best quality, extremely detailed')
 
58
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
59
  )
60
  with gr.Column():
61
+ result = gr.Gallery(label='Output',
62
+ show_label=False,
63
+ elem_id='gallery').style(grid=2,
64
+ height='auto')
65
+ inputs = [
66
+ input_image,
67
+ prompt,
68
+ a_prompt,
69
+ n_prompt,
70
+ num_samples,
71
+ image_resolution,
72
+ detect_resolution,
73
+ num_steps,
74
+ guidance_scale,
75
+ seed,
76
+ bg_threshold,
77
  ]
78
  run_button.click(fn=process,
79
+ inputs=inputs,
80
+ outputs=result,
81
  api_name='normal')
82
  return demo
gradio_pose2image.py CHANGED
@@ -28,23 +28,21 @@ def create_demo(process, max_images=12):
28
  maximum=1024,
29
  value=512,
30
  step=1)
31
- ddim_steps = gr.Slider(label='Steps',
32
- minimum=1,
33
- maximum=100,
34
- value=20,
35
- step=1)
36
- scale = gr.Slider(label='Guidance Scale',
37
- minimum=0.1,
38
- maximum=30.0,
39
- value=9.0,
40
- step=0.1)
41
  seed = gr.Slider(label='Seed',
42
  minimum=-1,
43
  maximum=2147483647,
44
  step=1,
45
- randomize=True,
46
- queue=False)
47
- eta = gr.Number(label='eta (DDIM)', value=0.0)
48
  a_prompt = gr.Textbox(
49
  label='Added Prompt',
50
  value='best quality, extremely detailed')
@@ -54,16 +52,24 @@ def create_demo(process, max_images=12):
54
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
55
  )
56
  with gr.Column():
57
- result_gallery = gr.Gallery(label='Output',
58
- show_label=False,
59
- elem_id='gallery').style(
60
- grid=2, height='auto')
61
- ips = [
62
- input_image, prompt, a_prompt, n_prompt, num_samples,
63
- image_resolution, detect_resolution, ddim_steps, scale, seed, eta
 
 
 
 
 
 
 
 
64
  ]
65
  run_button.click(fn=process,
66
- inputs=ips,
67
- outputs=[result_gallery],
68
  api_name='pose')
69
  return demo
 
28
  maximum=1024,
29
  value=512,
30
  step=1)
31
+ num_steps = gr.Slider(label='Steps',
32
+ minimum=1,
33
+ maximum=100,
34
+ value=20,
35
+ step=1)
36
+ guidance_scale = gr.Slider(label='Guidance Scale',
37
+ minimum=0.1,
38
+ maximum=30.0,
39
+ value=9.0,
40
+ step=0.1)
41
  seed = gr.Slider(label='Seed',
42
  minimum=-1,
43
  maximum=2147483647,
44
  step=1,
45
+ randomize=True)
 
 
46
  a_prompt = gr.Textbox(
47
  label='Added Prompt',
48
  value='best quality, extremely detailed')
 
52
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
53
  )
54
  with gr.Column():
55
+ result = gr.Gallery(label='Output',
56
+ show_label=False,
57
+ elem_id='gallery').style(grid=2,
58
+ height='auto')
59
+ inputs = [
60
+ input_image,
61
+ prompt,
62
+ a_prompt,
63
+ n_prompt,
64
+ num_samples,
65
+ image_resolution,
66
+ detect_resolution,
67
+ num_steps,
68
+ guidance_scale,
69
+ seed,
70
  ]
71
  run_button.click(fn=process,
72
+ inputs=inputs,
73
+ outputs=result,
74
  api_name='pose')
75
  return demo
gradio_scribble2image.py CHANGED
@@ -23,23 +23,21 @@ def create_demo(process, max_images=12):
23
  maximum=768,
24
  value=512,
25
  step=256)
26
- ddim_steps = gr.Slider(label='Steps',
27
- minimum=1,
28
- maximum=100,
29
- value=20,
30
- step=1)
31
- scale = gr.Slider(label='Guidance Scale',
32
- minimum=0.1,
33
- maximum=30.0,
34
- value=9.0,
35
- step=0.1)
36
  seed = gr.Slider(label='Seed',
37
  minimum=-1,
38
  maximum=2147483647,
39
  step=1,
40
- randomize=True,
41
- queue=False)
42
- eta = gr.Number(label='eta (DDIM)', value=0.0)
43
  a_prompt = gr.Textbox(
44
  label='Added Prompt',
45
  value='best quality, extremely detailed')
@@ -49,16 +47,23 @@ def create_demo(process, max_images=12):
49
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
50
  )
51
  with gr.Column():
52
- result_gallery = gr.Gallery(label='Output',
53
- show_label=False,
54
- elem_id='gallery').style(
55
- grid=2, height='auto')
56
- ips = [
57
- input_image, prompt, a_prompt, n_prompt, num_samples,
58
- image_resolution, ddim_steps, scale, seed, eta
 
 
 
 
 
 
 
59
  ]
60
  run_button.click(fn=process,
61
- inputs=ips,
62
- outputs=[result_gallery],
63
  api_name='scribble')
64
  return demo
 
23
  maximum=768,
24
  value=512,
25
  step=256)
26
+ num_steps = gr.Slider(label='Steps',
27
+ minimum=1,
28
+ maximum=100,
29
+ value=20,
30
+ step=1)
31
+ guidance_scale = gr.Slider(label='Guidance Scale',
32
+ minimum=0.1,
33
+ maximum=30.0,
34
+ value=9.0,
35
+ step=0.1)
36
  seed = gr.Slider(label='Seed',
37
  minimum=-1,
38
  maximum=2147483647,
39
  step=1,
40
+ randomize=True)
 
 
41
  a_prompt = gr.Textbox(
42
  label='Added Prompt',
43
  value='best quality, extremely detailed')
 
47
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
48
  )
49
  with gr.Column():
50
+ result = gr.Gallery(label='Output',
51
+ show_label=False,
52
+ elem_id='gallery').style(grid=2,
53
+ height='auto')
54
+ inputs = [
55
+ input_image,
56
+ prompt,
57
+ a_prompt,
58
+ n_prompt,
59
+ num_samples,
60
+ image_resolution,
61
+ num_steps,
62
+ guidance_scale,
63
+ seed,
64
  ]
65
  run_button.click(fn=process,
66
+ inputs=inputs,
67
+ outputs=result,
68
  api_name='scribble')
69
  return demo
gradio_scribble2image_interactive.py CHANGED
@@ -37,7 +37,7 @@ def create_demo(process, max_images=12):
37
  )
38
  create_button.click(fn=create_canvas,
39
  inputs=[canvas_width, canvas_height],
40
- outputs=[input_image],
41
  queue=False)
42
  prompt = gr.Textbox(label='Prompt')
43
  run_button = gr.Button(label='Run')
@@ -52,23 +52,21 @@ def create_demo(process, max_images=12):
52
  maximum=768,
53
  value=512,
54
  step=256)
55
- ddim_steps = gr.Slider(label='Steps',
56
- minimum=1,
57
- maximum=100,
58
- value=20,
59
- step=1)
60
- scale = gr.Slider(label='Guidance Scale',
61
- minimum=0.1,
62
- maximum=30.0,
63
- value=9.0,
64
- step=0.1)
65
  seed = gr.Slider(label='Seed',
66
  minimum=-1,
67
  maximum=2147483647,
68
  step=1,
69
- randomize=True,
70
- queue=False)
71
- eta = gr.Number(label='eta (DDIM)', value=0.0)
72
  a_prompt = gr.Textbox(
73
  label='Added Prompt',
74
  value='best quality, extremely detailed')
@@ -78,13 +76,20 @@ def create_demo(process, max_images=12):
78
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
79
  )
80
  with gr.Column():
81
- result_gallery = gr.Gallery(label='Output',
82
- show_label=False,
83
- elem_id='gallery').style(
84
- grid=2, height='auto')
85
- ips = [
86
- input_image, prompt, a_prompt, n_prompt, num_samples,
87
- image_resolution, ddim_steps, scale, seed, eta
 
 
 
 
 
 
 
88
  ]
89
- run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
90
  return demo
 
37
  )
38
  create_button.click(fn=create_canvas,
39
  inputs=[canvas_width, canvas_height],
40
+ outputs=input_image,
41
  queue=False)
42
  prompt = gr.Textbox(label='Prompt')
43
  run_button = gr.Button(label='Run')
 
52
  maximum=768,
53
  value=512,
54
  step=256)
55
+ num_steps = gr.Slider(label='Steps',
56
+ minimum=1,
57
+ maximum=100,
58
+ value=20,
59
+ step=1)
60
+ guidance_scale = gr.Slider(label='Guidance Scale',
61
+ minimum=0.1,
62
+ maximum=30.0,
63
+ value=9.0,
64
+ step=0.1)
65
  seed = gr.Slider(label='Seed',
66
  minimum=-1,
67
  maximum=2147483647,
68
  step=1,
69
+ randomize=True)
 
 
70
  a_prompt = gr.Textbox(
71
  label='Added Prompt',
72
  value='best quality, extremely detailed')
 
76
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
77
  )
78
  with gr.Column():
79
+ result = gr.Gallery(label='Output',
80
+ show_label=False,
81
+ elem_id='gallery').style(grid=2,
82
+ height='auto')
83
+ inputs = [
84
+ input_image,
85
+ prompt,
86
+ a_prompt,
87
+ n_prompt,
88
+ num_samples,
89
+ image_resolution,
90
+ num_steps,
91
+ guidance_scale,
92
+ seed,
93
  ]
94
+ run_button.click(fn=process, inputs=inputs, outputs=result)
95
  return demo
gradio_seg2image.py CHANGED
@@ -29,23 +29,21 @@ def create_demo(process, max_images=12):
29
  maximum=1024,
30
  value=512,
31
  step=1)
32
- ddim_steps = gr.Slider(label='Steps',
33
- minimum=1,
34
- maximum=100,
35
- value=20,
36
- step=1)
37
- scale = gr.Slider(label='Guidance Scale',
38
- minimum=0.1,
39
- maximum=30.0,
40
- value=9.0,
41
- step=0.1)
42
  seed = gr.Slider(label='Seed',
43
  minimum=-1,
44
  maximum=2147483647,
45
  step=1,
46
- randomize=True,
47
- queue=False)
48
- eta = gr.Number(label='eta (DDIM)', value=0.0)
49
  a_prompt = gr.Textbox(
50
  label='Added Prompt',
51
  value='best quality, extremely detailed')
@@ -55,16 +53,24 @@ def create_demo(process, max_images=12):
55
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
56
  )
57
  with gr.Column():
58
- result_gallery = gr.Gallery(label='Output',
59
- show_label=False,
60
- elem_id='gallery').style(
61
- grid=2, height='auto')
62
- ips = [
63
- input_image, prompt, a_prompt, n_prompt, num_samples,
64
- image_resolution, detect_resolution, ddim_steps, scale, seed, eta
 
 
 
 
 
 
 
 
65
  ]
66
  run_button.click(fn=process,
67
- inputs=ips,
68
- outputs=[result_gallery],
69
  api_name='seg')
70
  return demo
 
29
  maximum=1024,
30
  value=512,
31
  step=1)
32
+ num_steps = gr.Slider(label='Steps',
33
+ minimum=1,
34
+ maximum=100,
35
+ value=20,
36
+ step=1)
37
+ guidance_scale = gr.Slider(label='Guidance Scale',
38
+ minimum=0.1,
39
+ maximum=30.0,
40
+ value=9.0,
41
+ step=0.1)
42
  seed = gr.Slider(label='Seed',
43
  minimum=-1,
44
  maximum=2147483647,
45
  step=1,
46
+ randomize=True)
 
 
47
  a_prompt = gr.Textbox(
48
  label='Added Prompt',
49
  value='best quality, extremely detailed')
 
53
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
54
  )
55
  with gr.Column():
56
+ result = gr.Gallery(label='Output',
57
+ show_label=False,
58
+ elem_id='gallery').style(grid=2,
59
+ height='auto')
60
+ inputs = [
61
+ input_image,
62
+ prompt,
63
+ a_prompt,
64
+ n_prompt,
65
+ num_samples,
66
+ image_resolution,
67
+ detect_resolution,
68
+ num_steps,
69
+ guidance_scale,
70
+ seed,
71
  ]
72
  run_button.click(fn=process,
73
+ inputs=inputs,
74
+ outputs=result,
75
  api_name='seg')
76
  return demo
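model.py below carries the substance of the commit: the per-task process_* methods no longer hand-build a DDIM sampler; they funnel through shared preprocess_*/process helpers, and seeding moves from pytorch_lightning's seed_everything to a torch.Generator. A sketch restating the new run_pipe call from the diff below (names mirror its parameters; nothing beyond the diff is assumed):

import torch
from diffusers import StableDiffusionControlNetPipeline


def run_pipe(pipe: StableDiffusionControlNetPipeline, prompt: str,
             negative_prompt: str, control_image, num_images: int,
             num_steps: int, guidance_scale: float, seed: int):
    # One shared diffusers call replaces the per-task DDIM sampling loops.
    generator = torch.Generator().manual_seed(seed)
    return pipe(prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale,
                num_images_per_prompt=num_images,
                num_inference_steps=num_steps,
                generator=generator,
                image=control_image).images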
model.py CHANGED
@@ -3,20 +3,20 @@
3
  from __future__ import annotations
4
 
5
  import pathlib
6
- import random
7
- import shlex
8
- import subprocess
9
  import sys
10
 
11
  import cv2
12
- import einops
13
  import numpy as np
 
14
  import torch
15
- from pytorch_lightning import seed_everything
 
 
16
 
17
- sys.path.append('ControlNet')
 
 
18
 
19
- import config
20
  from annotator.canny import apply_canny
21
  from annotator.hed import apply_hed, nms
22
  from annotator.midas import apply_midas
@@ -24,743 +24,594 @@ from annotator.mlsd import apply_mlsd
24
  from annotator.openpose import apply_openpose
25
  from annotator.uniformer import apply_uniformer
26
  from annotator.util import HWC3, resize_image
27
- from cldm.model import create_model, load_state_dict
28
- from ldm.models.diffusion.ddim import DDIMSampler
29
  from share import *
30
 
31
- ORIGINAL_MODEL_NAMES = {
32
- 'canny': 'control_sd15_canny.pth',
33
- 'hough': 'control_sd15_mlsd.pth',
34
- 'hed': 'control_sd15_hed.pth',
35
- 'scribble': 'control_sd15_scribble.pth',
36
- 'pose': 'control_sd15_openpose.pth',
37
- 'seg': 'control_sd15_seg.pth',
38
- 'depth': 'control_sd15_depth.pth',
39
- 'normal': 'control_sd15_normal.pth',
40
  }
41
- ORIGINAL_WEIGHT_ROOT = 'https://huggingface.co/lllyasviel/ControlNet/resolve/main/models/'
42
-
43
- LIGHTWEIGHT_MODEL_NAMES = {
44
- 'canny': 'control_canny-fp16.safetensors',
45
- 'hough': 'control_mlsd-fp16.safetensors',
46
- 'hed': 'control_hed-fp16.safetensors',
47
- 'scribble': 'control_scribble-fp16.safetensors',
48
- 'pose': 'control_openpose-fp16.safetensors',
49
- 'seg': 'control_seg-fp16.safetensors',
50
- 'depth': 'control_depth-fp16.safetensors',
51
- 'normal': 'control_normal-fp16.safetensors',
52
- }
53
- LIGHTWEIGHT_WEIGHT_ROOT = 'https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/'
54
 
55
 
56
  class Model:
57
- def __init__(self,
58
- model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
59
- model_dir: str = 'models',
60
- use_lightweight: bool = True):
61
- self.device = torch.device(
62
- 'cuda:0' if torch.cuda.is_available() else 'cpu')
63
- self.model = create_model(model_config_path).to(self.device)
64
- self.ddim_sampler = DDIMSampler(self.model)
65
- self.task_name = ''
66
-
67
- self.model_dir = pathlib.Path(model_dir)
68
- self.model_dir.mkdir(exist_ok=True, parents=True)
69
-
70
- self.use_lightweight = use_lightweight
71
- if use_lightweight:
72
- self.model_names = LIGHTWEIGHT_MODEL_NAMES
73
- self.weight_root = LIGHTWEIGHT_WEIGHT_ROOT
74
- base_model_url = 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors'
75
- self.load_base_model(base_model_url)
76
- else:
77
- self.model_names = ORIGINAL_MODEL_NAMES
78
- self.weight_root = ORIGINAL_WEIGHT_ROOT
79
-
80
- self.download_models()
81
-
82
- def download_base_model(self, model_url: str) -> pathlib.Path:
83
- model_name = model_url.split('/')[-1]
84
- out_path = self.model_dir / model_name
85
- if not out_path.exists():
86
- subprocess.run(shlex.split(f'wget {model_url} -O {out_path}'))
87
- return out_path
88
-
89
- def load_base_model(self, model_url: str) -> None:
90
- model_path = self.download_base_model(model_url)
91
- self.model.load_state_dict(load_state_dict(model_path,
92
- location=self.device.type),
93
- strict=False)
94
-
95
- def load_weight(self, task_name: str) -> None:
96
  if task_name == self.task_name:
97
  return
98
- weight_path = self.get_weight_path(task_name)
99
- if not self.use_lightweight:
100
- self.model.load_state_dict(
101
- load_state_dict(weight_path, location=self.device))
102
- else:
103
- self.model.control_model.load_state_dict(
104
- load_state_dict(weight_path, location=self.device.type))
105
  self.task_name = task_name
106
 
107
- def get_weight_path(self, task_name: str) -> str:
108
- if 'scribble' in task_name:
109
- task_name = 'scribble'
110
- return f'{self.model_dir}/{self.model_names[task_name]}'
111
-
112
- def download_models(self) -> None:
113
- self.model_dir.mkdir(exist_ok=True, parents=True)
114
- for name in self.model_names.values():
115
- out_path = self.model_dir / name
116
- if out_path.exists():
117
- continue
118
- subprocess.run(
119
- shlex.split(f'wget {self.weight_root}{name} -O {out_path}'))
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  @torch.inference_mode()
122
- def process_canny(self, input_image, prompt, a_prompt, n_prompt,
123
- num_samples, image_resolution, ddim_steps, scale, seed,
124
- eta, low_threshold, high_threshold):
125
- self.load_weight('canny')
126
-
127
- img = resize_image(HWC3(input_image), image_resolution)
128
- H, W, C = img.shape
129
-
130
- detected_map = apply_canny(img, low_threshold, high_threshold)
131
- detected_map = HWC3(detected_map)
132
-
133
- control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
134
- control = torch.stack([control for _ in range(num_samples)], dim=0)
135
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
136
-
137
- if seed == -1:
138
- seed = random.randint(0, 65535)
139
- seed_everything(seed)
140
-
141
- if config.save_memory:
142
- self.model.low_vram_shift(is_diffusing=False)
143
-
144
- cond = {
145
- 'c_concat': [control],
146
- 'c_crossattn': [
147
- self.model.get_learned_conditioning(
148
- [prompt + ', ' + a_prompt] * num_samples)
149
- ]
150
- }
151
- un_cond = {
152
- 'c_concat': [control],
153
- 'c_crossattn':
154
- [self.model.get_learned_conditioning([n_prompt] * num_samples)]
155
- }
156
- shape = (4, H // 8, W // 8)
157
-
158
- if config.save_memory:
159
- self.model.low_vram_shift(is_diffusing=True)
160
-
161
- samples, intermediates = self.ddim_sampler.sample(
162
- ddim_steps,
163
- num_samples,
164
- shape,
165
- cond,
166
- verbose=False,
167
- eta=eta,
168
- unconditional_guidance_scale=scale,
169
- unconditional_conditioning=un_cond)
170
-
171
- if config.save_memory:
172
- self.model.low_vram_shift(is_diffusing=False)
173
-
174
- x_samples = self.model.decode_first_stage(samples)
175
- x_samples = (
176
- einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
177
- 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
178
-
179
- results = [x_samples[i] for i in range(num_samples)]
180
- return [255 - detected_map] + results
181
 
182
- @torch.inference_mode()
183
- def process_hough(self, input_image, prompt, a_prompt, n_prompt,
184
- num_samples, image_resolution, detect_resolution,
185
- ddim_steps, scale, seed, eta, value_threshold,
186
- distance_threshold):
187
- self.load_weight('hough')
188
 
189
- input_image = HWC3(input_image)
190
- detected_map = apply_mlsd(resize_image(input_image, detect_resolution),
191
- value_threshold, distance_threshold)
192
- detected_map = HWC3(detected_map)
193
- img = resize_image(input_image, image_resolution)
194
- H, W, C = img.shape
195
-
196
- detected_map = cv2.resize(detected_map, (W, H),
197
- interpolation=cv2.INTER_NEAREST)
198
-
199
- control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
200
- control = torch.stack([control for _ in range(num_samples)], dim=0)
201
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
202
-
203
- if seed == -1:
204
- seed = random.randint(0, 65535)
205
- seed_everything(seed)
206
-
207
- if config.save_memory:
208
- self.model.low_vram_shift(is_diffusing=False)
209
-
210
- cond = {
211
- 'c_concat': [control],
212
- 'c_crossattn': [
213
- self.model.get_learned_conditioning(
214
- [prompt + ', ' + a_prompt] * num_samples)
215
- ]
216
- }
217
- un_cond = {
218
- 'c_concat': [control],
219
- 'c_crossattn':
220
- [self.model.get_learned_conditioning([n_prompt] * num_samples)]
221
- }
222
- shape = (4, H // 8, W // 8)
223
-
224
- if config.save_memory:
225
- self.model.low_vram_shift(is_diffusing=True)
226
-
227
- samples, intermediates = self.ddim_sampler.sample(
228
- ddim_steps,
229
- num_samples,
230
- shape,
231
- cond,
232
- verbose=False,
233
- eta=eta,
234
- unconditional_guidance_scale=scale,
235
- unconditional_conditioning=un_cond)
236
-
237
- if config.save_memory:
238
- self.model.low_vram_shift(is_diffusing=False)
239
-
240
- x_samples = self.model.decode_first_stage(samples)
241
- x_samples = (
242
- einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
243
- 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
244
-
245
- results = [x_samples[i] for i in range(num_samples)]
246
- return [
247
- 255 - cv2.dilate(detected_map,
248
- np.ones(shape=(3, 3), dtype=np.uint8),
249
- iterations=1)
250
- ] + results
251
 
252
  @torch.inference_mode()
253
- def process_hed(self, input_image, prompt, a_prompt, n_prompt, num_samples,
254
- image_resolution, detect_resolution, ddim_steps, scale,
255
- seed, eta):
256
- self.load_weight('hed')
257
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  input_image = HWC3(input_image)
259
- detected_map = apply_hed(resize_image(input_image, detect_resolution))
260
- detected_map = HWC3(detected_map)
261
- img = resize_image(input_image, image_resolution)
262
- H, W, C = img.shape
263
-
264
- detected_map = cv2.resize(detected_map, (W, H),
265
- interpolation=cv2.INTER_LINEAR)
266
-
267
- control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
268
- control = torch.stack([control for _ in range(num_samples)], dim=0)
269
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
270
-
271
- if seed == -1:
272
- seed = random.randint(0, 65535)
273
- seed_everything(seed)
274
-
275
- if config.save_memory:
276
- self.model.low_vram_shift(is_diffusing=False)
277
-
278
- cond = {
279
- 'c_concat': [control],
280
- 'c_crossattn': [
281
- self.model.get_learned_conditioning(
282
- [prompt + ', ' + a_prompt] * num_samples)
283
- ]
284
- }
285
- un_cond = {
286
- 'c_concat': [control],
287
- 'c_crossattn':
288
- [self.model.get_learned_conditioning([n_prompt] * num_samples)]
289
- }
290
- shape = (4, H // 8, W // 8)
291
-
292
- if config.save_memory:
293
- self.model.low_vram_shift(is_diffusing=True)
294
-
295
- samples, intermediates = self.ddim_sampler.sample(
296
- ddim_steps,
297
- num_samples,
298
- shape,
299
- cond,
300
- verbose=False,
301
- eta=eta,
302
- unconditional_guidance_scale=scale,
303
- unconditional_conditioning=un_cond)
304
-
305
- if config.save_memory:
306
- self.model.low_vram_shift(is_diffusing=False)
307
-
308
- x_samples = self.model.decode_first_stage(samples)
309
- x_samples = (
310
- einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
311
- 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
312
-
313
- results = [x_samples[i] for i in range(num_samples)]
314
- return [detected_map] + results
315
 
316
  @torch.inference_mode()
317
- def process_scribble(self, input_image, prompt, a_prompt, n_prompt,
318
- num_samples, image_resolution, ddim_steps, scale,
319
- seed, eta):
320
- self.load_weight('scribble')
321
-
322
- img = resize_image(HWC3(input_image), image_resolution)
323
- H, W, C = img.shape
324
-
325
- detected_map = np.zeros_like(img, dtype=np.uint8)
326
- detected_map[np.min(img, axis=2) < 127] = 255
327
-
328
- control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
329
- control = torch.stack([control for _ in range(num_samples)], dim=0)
330
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
331
-
332
- if seed == -1:
333
- seed = random.randint(0, 65535)
334
- seed_everything(seed)
335
-
336
- if config.save_memory:
337
- self.model.low_vram_shift(is_diffusing=False)
338
-
339
- cond = {
340
- 'c_concat': [control],
341
- 'c_crossattn': [
342
- self.model.get_learned_conditioning(
343
- [prompt + ', ' + a_prompt] * num_samples)
344
- ]
345
- }
346
- un_cond = {
347
- 'c_concat': [control],
348
- 'c_crossattn':
349
- [self.model.get_learned_conditioning([n_prompt] * num_samples)]
350
- }
351
- shape = (4, H // 8, W // 8)
352
-
353
- if config.save_memory:
354
- self.model.low_vram_shift(is_diffusing=True)
355
-
356
- samples, intermediates = self.ddim_sampler.sample(
357
- ddim_steps,
358
- num_samples,
359
- shape,
360
- cond,
361
- verbose=False,
362
- eta=eta,
363
- unconditional_guidance_scale=scale,
364
- unconditional_conditioning=un_cond)
365
-
366
- if config.save_memory:
367
- self.model.low_vram_shift(is_diffusing=False)
368
-
369
- x_samples = self.model.decode_first_stage(samples)
370
- x_samples = (
371
- einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
372
- 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
373
-
374
- results = [x_samples[i] for i in range(num_samples)]
375
- return [255 - detected_map] + results
376
 
377
  @torch.inference_mode()
378
- def process_scribble_interactive(self, input_image, prompt, a_prompt,
379
- n_prompt, num_samples, image_resolution,
380
- ddim_steps, scale, seed, eta):
381
- self.load_weight('scribble')
382
-
383
- img = resize_image(HWC3(input_image['mask'][:, :, 0]),
384
- image_resolution)
385
- H, W, C = img.shape
386
-
387
- detected_map = np.zeros_like(img, dtype=np.uint8)
388
- detected_map[np.min(img, axis=2) > 127] = 255
389
-
390
- control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
391
- control = torch.stack([control for _ in range(num_samples)], dim=0)
392
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
393
-
394
- if seed == -1:
395
- seed = random.randint(0, 65535)
396
- seed_everything(seed)
397
-
398
- if config.save_memory:
399
- self.model.low_vram_shift(is_diffusing=False)
400
-
401
- cond = {
402
- 'c_concat': [control],
403
- 'c_crossattn': [
404
- self.model.get_learned_conditioning(
405
- [prompt + ', ' + a_prompt] * num_samples)
406
- ]
407
- }
408
- un_cond = {
409
- 'c_concat': [control],
410
- 'c_crossattn':
411
- [self.model.get_learned_conditioning([n_prompt] * num_samples)]
412
- }
413
- shape = (4, H // 8, W // 8)
414
-
415
- if config.save_memory:
416
- self.model.low_vram_shift(is_diffusing=True)
417
-
418
- samples, intermediates = self.ddim_sampler.sample(
419
- ddim_steps,
420
- num_samples,
421
- shape,
422
- cond,
423
- verbose=False,
424
- eta=eta,
425
- unconditional_guidance_scale=scale,
426
- unconditional_conditioning=un_cond)
427
-
428
- if config.save_memory:
429
- self.model.low_vram_shift(is_diffusing=False)
430
-
431
- x_samples = self.model.decode_first_stage(samples)
432
- x_samples = (
433
- einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
434
- 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
435
-
436
- results = [x_samples[i] for i in range(num_samples)]
437
- return [255 - detected_map] + results
438
 
439
  @torch.inference_mode()
440
- def process_fake_scribble(self, input_image, prompt, a_prompt, n_prompt,
441
- num_samples, image_resolution, detect_resolution,
442
- ddim_steps, scale, seed, eta):
443
- self.load_weight('scribble')
444
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
  input_image = HWC3(input_image)
446
- detected_map = apply_hed(resize_image(input_image, detect_resolution))
447
- detected_map = HWC3(detected_map)
448
- img = resize_image(input_image, image_resolution)
449
- H, W, C = img.shape
450
-
451
- detected_map = cv2.resize(detected_map, (W, H),
452
- interpolation=cv2.INTER_LINEAR)
453
- detected_map = nms(detected_map, 127, 3.0)
454
- detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
455
- detected_map[detected_map > 4] = 255
456
- detected_map[detected_map < 255] = 0
457
-
458
- control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
459
- control = torch.stack([control for _ in range(num_samples)], dim=0)
460
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
461
-
462
- if seed == -1:
463
- seed = random.randint(0, 65535)
464
- seed_everything(seed)
465
-
466
- if config.save_memory:
467
- self.model.low_vram_shift(is_diffusing=False)
468
-
469
- cond = {
470
- 'c_concat': [control],
471
- 'c_crossattn': [
472
- self.model.get_learned_conditioning(
473
- [prompt + ', ' + a_prompt] * num_samples)
474
- ]
475
- }
476
- un_cond = {
477
- 'c_concat': [control],
478
- 'c_crossattn':
479
- [self.model.get_learned_conditioning([n_prompt] * num_samples)]
480
- }
481
- shape = (4, H // 8, W // 8)
482
-
483
- if config.save_memory:
484
- self.model.low_vram_shift(is_diffusing=True)
485
-
486
- samples, intermediates = self.ddim_sampler.sample(
487
- ddim_steps,
488
- num_samples,
489
- shape,
490
- cond,
491
- verbose=False,
492
- eta=eta,
493
- unconditional_guidance_scale=scale,
494
- unconditional_conditioning=un_cond)
495
-
496
- if config.save_memory:
497
- self.model.low_vram_shift(is_diffusing=False)
498
-
499
- x_samples = self.model.decode_first_stage(samples)
500
- x_samples = (
501
- einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
502
- 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
503
-
504
- results = [x_samples[i] for i in range(num_samples)]
505
- return [255 - detected_map] + results
506
 
507
- @torch.inference_mode()
508
- def process_pose(self, input_image, prompt, a_prompt, n_prompt,
509
- num_samples, image_resolution, detect_resolution,
510
- ddim_steps, scale, seed, eta):
511
- self.load_weight('pose')
 
 
 
512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
  input_image = HWC3(input_image)
514
- detected_map, _ = apply_openpose(
515
  resize_image(input_image, detect_resolution))
516
- detected_map = HWC3(detected_map)
517
- img = resize_image(input_image, image_resolution)
518
- H, W, C = img.shape
519
-
520
- detected_map = cv2.resize(detected_map, (W, H),
521
- interpolation=cv2.INTER_NEAREST)
522
-
523
- control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
524
- control = torch.stack([control for _ in range(num_samples)], dim=0)
525
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
526
-
527
- if seed == -1:
528
- seed = random.randint(0, 65535)
529
- seed_everything(seed)
530
-
531
- if config.save_memory:
532
- self.model.low_vram_shift(is_diffusing=False)
533
-
534
- cond = {
535
- 'c_concat': [control],
536
- 'c_crossattn': [
537
- self.model.get_learned_conditioning(
538
- [prompt + ', ' + a_prompt] * num_samples)
539
- ]
540
- }
541
- un_cond = {
542
- 'c_concat': [control],
543
- 'c_crossattn':
544
- [self.model.get_learned_conditioning([n_prompt] * num_samples)]
545
- }
546
- shape = (4, H // 8, W // 8)
547
-
548
- if config.save_memory:
549
- self.model.low_vram_shift(is_diffusing=True)
550
-
551
- samples, intermediates = self.ddim_sampler.sample(
552
- ddim_steps,
553
- num_samples,
554
- shape,
555
- cond,
556
- verbose=False,
557
- eta=eta,
558
- unconditional_guidance_scale=scale,
559
- unconditional_conditioning=un_cond)
560
-
561
- if config.save_memory:
562
- self.model.low_vram_shift(is_diffusing=False)
563
-
564
- x_samples = self.model.decode_first_stage(samples)
565
- x_samples = (
566
- einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
567
- 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
568
-
569
- results = [x_samples[i] for i in range(num_samples)]
570
- return [detected_map] + results
571
 
572
- @torch.inference_mode()
573
- def process_seg(self, input_image, prompt, a_prompt, n_prompt, num_samples,
574
- image_resolution, detect_resolution, ddim_steps, scale,
575
- seed, eta):
576
- self.load_weight('seg')
577
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
578
  input_image = HWC3(input_image)
579
- detected_map = apply_uniformer(
580
  resize_image(input_image, detect_resolution))
581
- img = resize_image(input_image, image_resolution)
582
- H, W, C = img.shape
583
-
584
- detected_map = cv2.resize(detected_map, (W, H),
585
- interpolation=cv2.INTER_NEAREST)
586
-
587
- control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
588
- control = torch.stack([control for _ in range(num_samples)], dim=0)
589
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
590
-
591
- if seed == -1:
592
- seed = random.randint(0, 65535)
593
- seed_everything(seed)
594
-
595
- if config.save_memory:
596
- self.model.low_vram_shift(is_diffusing=False)
597
-
598
- cond = {
599
- 'c_concat': [control],
600
- 'c_crossattn': [
601
- self.model.get_learned_conditioning(
602
- [prompt + ', ' + a_prompt] * num_samples)
603
- ]
604
- }
605
- un_cond = {
606
- 'c_concat': [control],
607
- 'c_crossattn':
608
- [self.model.get_learned_conditioning([n_prompt] * num_samples)]
609
- }
610
- shape = (4, H // 8, W // 8)
611
-
612
- if config.save_memory:
613
- self.model.low_vram_shift(is_diffusing=True)
614
-
615
- samples, intermediates = self.ddim_sampler.sample(
616
- ddim_steps,
617
- num_samples,
618
- shape,
619
- cond,
620
- verbose=False,
621
- eta=eta,
622
- unconditional_guidance_scale=scale,
623
- unconditional_conditioning=un_cond)
624
-
625
- if config.save_memory:
626
- self.model.low_vram_shift(is_diffusing=False)
627
-
628
- x_samples = self.model.decode_first_stage(samples)
629
- x_samples = (
630
- einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
631
- 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
632
-
633
- results = [x_samples[i] for i in range(num_samples)]
634
- return [detected_map] + results
635
 
636
  @torch.inference_mode()
637
- def process_depth(self, input_image, prompt, a_prompt, n_prompt,
638
- num_samples, image_resolution, detect_resolution,
639
- ddim_steps, scale, seed, eta):
640
- self.load_weight('depth')
641
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
642
  input_image = HWC3(input_image)
643
- detected_map, _ = apply_midas(
644
  resize_image(input_image, detect_resolution))
645
- detected_map = HWC3(detected_map)
646
- img = resize_image(input_image, image_resolution)
647
- H, W, C = img.shape
648
-
649
- detected_map = cv2.resize(detected_map, (W, H),
650
- interpolation=cv2.INTER_LINEAR)
651
-
652
- control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
653
- control = torch.stack([control for _ in range(num_samples)], dim=0)
654
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
655
-
656
- if seed == -1:
657
- seed = random.randint(0, 65535)
658
- seed_everything(seed)
659
-
660
- if config.save_memory:
661
- self.model.low_vram_shift(is_diffusing=False)
662
-
663
- cond = {
664
- 'c_concat': [control],
665
- 'c_crossattn': [
666
- self.model.get_learned_conditioning(
667
- [prompt + ', ' + a_prompt] * num_samples)
668
- ]
669
- }
670
- un_cond = {
671
- 'c_concat': [control],
672
- 'c_crossattn':
673
- [self.model.get_learned_conditioning([n_prompt] * num_samples)]
674
- }
675
- shape = (4, H // 8, W // 8)
676
-
677
- if config.save_memory:
678
- self.model.low_vram_shift(is_diffusing=True)
679
-
680
- samples, intermediates = self.ddim_sampler.sample(
681
- ddim_steps,
682
- num_samples,
683
- shape,
684
- cond,
685
- verbose=False,
686
- eta=eta,
687
- unconditional_guidance_scale=scale,
688
- unconditional_conditioning=un_cond)
689
-
690
- if config.save_memory:
691
- self.model.low_vram_shift(is_diffusing=False)
692
-
693
- x_samples = self.model.decode_first_stage(samples)
694
- x_samples = (
695
- einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
696
- 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
697
-
698
- results = [x_samples[i] for i in range(num_samples)]
699
- return [detected_map] + results
700
 
701
  @torch.inference_mode()
702
- def process_normal(self, input_image, prompt, a_prompt, n_prompt,
703
- num_samples, image_resolution, detect_resolution,
704
- ddim_steps, scale, seed, eta, bg_threshold):
705
- self.load_weight('normal')
706
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
707
  input_image = HWC3(input_image)
708
- _, detected_map = apply_midas(resize_image(input_image,
709
- detect_resolution),
710
- bg_th=bg_threshold)
711
- detected_map = HWC3(detected_map)
712
- img = resize_image(input_image, image_resolution)
713
- H, W, C = img.shape
714
-
715
- detected_map = cv2.resize(detected_map, (W, H),
716
- interpolation=cv2.INTER_LINEAR)
717
-
718
- control = torch.from_numpy(
719
- detected_map[:, :, ::-1].copy()).float().cuda() / 255.0
720
- control = torch.stack([control for _ in range(num_samples)], dim=0)
721
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
722
-
723
- if seed == -1:
724
- seed = random.randint(0, 65535)
725
- seed_everything(seed)
726
-
727
- if config.save_memory:
728
- self.model.low_vram_shift(is_diffusing=False)
729
-
730
- cond = {
731
- 'c_concat': [control],
732
- 'c_crossattn': [
733
- self.model.get_learned_conditioning(
734
- [prompt + ', ' + a_prompt] * num_samples)
735
- ]
736
- }
737
- un_cond = {
738
- 'c_concat': [control],
739
- 'c_crossattn':
740
- [self.model.get_learned_conditioning([n_prompt] * num_samples)]
741
- }
742
- shape = (4, H // 8, W // 8)
743
-
744
- if config.save_memory:
745
- self.model.low_vram_shift(is_diffusing=True)
746
-
747
- samples, intermediates = self.ddim_sampler.sample(
748
- ddim_steps,
749
- num_samples,
750
- shape,
751
- cond,
752
- verbose=False,
753
- eta=eta,
754
- unconditional_guidance_scale=scale,
755
- unconditional_conditioning=un_cond)
756
-
757
- if config.save_memory:
758
- self.model.low_vram_shift(is_diffusing=False)
759
-
760
- x_samples = self.model.decode_first_stage(samples)
761
- x_samples = (
762
- einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
763
- 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
764
-
765
- results = [x_samples[i] for i in range(num_samples)]
766
- return [detected_map] + results
 
3
  from __future__ import annotations
4
 
5
  import pathlib
 
 
 
6
  import sys
7
 
8
  import cv2
 
9
  import numpy as np
10
+ import PIL.Image
11
  import torch
12
+ from diffusers import (ControlNetModel, DiffusionPipeline,
13
+ StableDiffusionControlNetPipeline,
14
+ UniPCMultistepScheduler)
15
 
16
+ repo_dir = pathlib.Path(__file__).parent
17
+ submodule_dir = repo_dir / 'ControlNet'
18
+ sys.path.append(submodule_dir.as_posix())
19
 
 
20
  from annotator.canny import apply_canny
21
  from annotator.hed import apply_hed, nms
22
  from annotator.midas import apply_midas
 
24
  from annotator.openpose import apply_openpose
25
  from annotator.uniformer import apply_uniformer
26
  from annotator.util import HWC3, resize_image
 
 
27
  from share import *
28
 
29
+ CONTROLNET_MODEL_IDS = {
30
+ 'canny': 'lllyasviel/sd-controlnet-canny',
31
+ 'hough': 'lllyasviel/sd-controlnet-mlsd',
32
+ 'hed': 'lllyasviel/sd-controlnet-hed',
33
+ 'scribble': 'lllyasviel/sd-controlnet-scribble',
34
+ 'pose': 'lllyasviel/sd-controlnet-openpose',
35
+ 'seg': 'lllyasviel/sd-controlnet-seg',
36
+ 'depth': 'lllyasviel/sd-controlnet-depth',
37
+ 'normal': 'lllyasviel/sd-controlnet-normal',
38
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
 
41
  class Model:
42
+ def __init__(self):
43
+ # FIXME
44
+ self.base_model_id = 'andite/anything-v4.0'
45
+ self.task_name = 'pose'
46
+ self.pipe = self.load_pipe()
47
+
48
+ def load_pipe(self) -> DiffusionPipeline:
49
+ model_id = CONTROLNET_MODEL_IDS[self.task_name]
50
+ controlnet = ControlNetModel.from_pretrained(model_id,
51
+ torch_dtype=torch.float16)
52
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
53
+ self.base_model_id,
54
+ safety_checker=None,
55
+ controlnet=controlnet,
56
+ torch_dtype=torch.float16)
57
+ pipe.scheduler = UniPCMultistepScheduler.from_config(
58
+ pipe.scheduler.config)
59
+ pipe.enable_xformers_memory_efficient_attention()
60
+ pipe.enable_model_cpu_offload()
61
+ return pipe
62
+
63
+ def load_controlnet_weight(self, task_name: str) -> None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  if task_name == self.task_name:
65
  return
66
+ model_id = CONTROLNET_MODEL_IDS[task_name]
67
+ controlnet = ControlNetModel.from_pretrained(model_id,
68
+ torch_dtype=torch.float16)
69
+ from accelerate import cpu_offload_with_hook
70
+ cpu_offload_with_hook(controlnet, torch.device('cuda:0'))
71
+ self.pipe.controlnet = controlnet
 
72
  self.task_name = task_name
73
 
74
+ def get_prompt(self, prompt: str, additional_prompt: str) -> str:
75
+ if not prompt:
76
+ prompt = additional_prompt
77
+ else:
78
+ prompt = f'{prompt}, {additional_prompt}'
79
+ return prompt
80
+
81
+ def run_pipe(
82
+ self,
83
+ prompt: str,
84
+ negative_prompt: str,
85
+ control_image: PIL.Image.Image,
86
+ num_images: int,
87
+ num_steps: int,
88
+ guidance_scale: float,
89
+ seed: int,
90
+ ):
91
+ generator = torch.Generator().manual_seed(seed)
92
+ return self.pipe(prompt=prompt,
93
+ negative_prompt=negative_prompt,
94
+ guidance_scale=guidance_scale,
95
+ num_images_per_prompt=num_images,
96
+ num_inference_steps=num_steps,
97
+ generator=generator,
98
+ image=control_image)
99
+
100
+ def process(
101
+ self,
102
+ task_name: str,
103
+ prompt: str,
104
+ additional_prompt: str,
105
+ negative_prompt: str,
106
+ control_image: PIL.Image.Image,
107
+ vis_control_image: PIL.Image.Image,
108
+ num_samples: int,
109
+ num_steps: int,
110
+ guidance_scale: float,
111
+ seed: int,
112
+ ):
113
+ self.load_controlnet_weight(task_name)
114
+ results = self.run_pipe(
115
+ prompt=self.get_prompt(prompt, additional_prompt),
116
+ negative_prompt=negative_prompt,
117
+ control_image=control_image,
118
+ num_images=num_samples,
119
+ num_steps=num_steps,
120
+ guidance_scale=guidance_scale,
121
+ seed=seed,
122
+ )
123
+ return [vis_control_image] + results.images
124
+
125
+ def preprocess_canny(
126
+ self,
127
+ input_image: np.ndarray,
128
+ image_resolution: int,
129
+ low_threshold: int,
130
+ high_threshold: int,
131
+ ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
132
+ image = resize_image(HWC3(input_image), image_resolution)
133
+ control_image = apply_canny(image, low_threshold, high_threshold)
134
+ control_image = HWC3(control_image)
135
+ vis_control_image = 255 - control_image
136
+ return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
137
+ vis_control_image)
138
 
139
  @torch.inference_mode()
140
+ def process_canny(
141
+ self,
142
+ input_image: np.ndarray,
143
+ prompt: str,
144
+ additional_prompt: str,
145
+ negative_prompt: str,
146
+ num_samples: int,
147
+ image_resolution: int,
148
+ num_steps: int,
149
+ guidance_scale: float,
150
+ seed: int,
151
+ low_threshold: int,
152
+ high_threshold: int,
153
+ ) -> list[PIL.Image.Image]:
154
+ control_image, vis_control_image = self.preprocess_canny(
155
+ input_image=input_image,
156
+ image_resolution=image_resolution,
157
+ low_threshold=low_threshold,
158
+ high_threshold=high_threshold,
159
+ )
160
+ return self.process(
161
+ task_name='canny',
162
+ prompt=prompt,
163
+ additional_prompt=additional_prompt,
164
+ negative_prompt=negative_prompt,
165
+ control_image=control_image,
166
+ vis_control_image=vis_control_image,
167
+ num_samples=num_samples,
168
+ num_steps=num_steps,
169
+ guidance_scale=guidance_scale,
170
+ seed=seed,
171
+ )
172
+
173
+ def preprocess_hough(
174
+ self,
175
+ input_image: np.ndarray,
176
+ image_resolution: int,
177
+ detect_resolution: int,
178
+ value_threshold: float,
179
+ distance_threshold: float,
180
+ ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
181
+ input_image = HWC3(input_image)
182
+ control_image = apply_mlsd(
183
+ resize_image(input_image, detect_resolution), value_threshold,
184
+ distance_threshold)
185
+ control_image = HWC3(control_image)
186
+ image = resize_image(input_image, image_resolution)
187
+ H, W = image.shape[:2]
188
+ control_image = cv2.resize(control_image, (W, H),
189
+ interpolation=cv2.INTER_NEAREST)
190
 
191
+ vis_control_image = 255 - cv2.dilate(
192
+ control_image, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
193
 
194
+ return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
195
+ vis_control_image)
196
 
197
  @torch.inference_mode()
198
+ def process_hough(
199
+ self,
200
+ input_image: np.ndarray,
201
+ prompt: str,
202
+ additional_prompt: str,
203
+ negative_prompt: str,
204
+ num_samples: int,
205
+ image_resolution: int,
206
+ detect_resolution: int,
207
+ num_steps: int,
208
+ guidance_scale: float,
209
+ seed: int,
210
+ value_threshold: float,
211
+ distance_threshold: float,
212
+ ) -> list[PIL.Image.Image]:
213
+ control_image, vis_control_image = self.preprocess_hough(
214
+ input_image=input_image,
215
+ image_resolution=image_resolution,
216
+ detect_resolution=detect_resolution,
217
+ value_threshold=value_threshold,
218
+ distance_threshold=distance_threshold,
219
+ )
220
+ return self.process(
221
+ task_name='hough',
222
+ prompt=prompt,
223
+ additional_prompt=additional_prompt,
224
+ negative_prompt=negative_prompt,
225
+ control_image=control_image,
226
+ vis_control_image=vis_control_image,
227
+ num_samples=num_samples,
228
+ num_steps=num_steps,
229
+ guidance_scale=guidance_scale,
230
+ seed=seed,
231
+ )
232
+
233
+ def preprocess_hed(
234
+ self,
235
+ input_image: np.ndarray,
236
+ image_resolution: int,
237
+ detect_resolution: int,
238
+ ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
239
  input_image = HWC3(input_image)
240
+ control_image = apply_hed(resize_image(input_image, detect_resolution))
241
+ control_image = HWC3(control_image)
242
+ image = resize_image(input_image, image_resolution)
243
+ H, W = image.shape[:2]
244
+ control_image = cv2.resize(control_image, (W, H),
245
+ interpolation=cv2.INTER_LINEAR)
246
+ return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
247
+ control_image)
248
 
249
  @torch.inference_mode()
250
+ def process_hed(
251
+ self,
252
+ input_image: np.ndarray,
253
+ prompt: str,
254
+ additional_prompt: str,
255
+ negative_prompt: str,
256
+ num_samples: int,
257
+ image_resolution: int,
258
+ detect_resolution: int,
259
+ num_steps: int,
260
+ guidance_scale: float,
261
+ seed: int,
262
+ ) -> list[PIL.Image.Image]:
263
+ control_image, vis_control_image = self.preprocess_hed(
264
+ input_image=input_image,
265
+ image_resolution=image_resolution,
266
+ detect_resolution=detect_resolution,
267
+ )
268
+ return self.process(
269
+ task_name='hed',
270
+ prompt=prompt,
271
+ additional_prompt=additional_prompt,
272
+ negative_prompt=negative_prompt,
273
+ control_image=control_image,
274
+ vis_control_image=vis_control_image,
275
+ num_samples=num_samples,
276
+ num_steps=num_steps,
277
+ guidance_scale=guidance_scale,
278
+ seed=seed,
279
+ )
280
+
281
+ def preprocess_scribble(
282
+ self,
283
+ input_image: np.ndarray,
284
+ image_resolution: int,
285
+ ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
286
+ image = resize_image(HWC3(input_image), image_resolution)
287
+ control_image = np.zeros_like(image, dtype=np.uint8)
288
+ control_image[np.min(image, axis=2) < 127] = 255
289
+ vis_control_image = 255 - control_image
290
+ return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
291
+ vis_control_image)
292
 
293
  @torch.inference_mode()
294
+ def process_scribble(
295
+ self,
296
+ input_image: np.ndarray,
297
+ prompt: str,
298
+ additional_prompt: str,
299
+ negative_prompt: str,
300
+ num_samples: int,
301
+ image_resolution: int,
302
+ num_steps: int,
303
+ guidance_scale: float,
304
+ seed: int,
305
+ ) -> list[PIL.Image.Image]:
306
+ control_image, vis_control_image = self.preprocess_scribble(
307
+ input_image=input_image,
308
+ image_resolution=image_resolution,
309
+ )
310
+ return self.process(
311
+ task_name='scribble',
312
+ prompt=prompt,
313
+ additional_prompt=additional_prompt,
314
+ negative_prompt=negative_prompt,
315
+ control_image=control_image,
316
+ vis_control_image=vis_control_image,
317
+ num_samples=num_samples,
318
+ num_steps=num_steps,
319
+ guidance_scale=guidance_scale,
320
+ seed=seed,
321
+ )
322
+
323
+ def preprocess_scribble_interactive(
324
+ self,
325
+ input_image: np.ndarray,
326
+ image_resolution: int,
327
+ ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
328
+ image = resize_image(HWC3(input_image['mask'][:, :, 0]),
329
+ image_resolution)
330
+ control_image = np.zeros_like(image, dtype=np.uint8)
331
+ control_image[np.min(image, axis=2) > 127] = 255
332
+ vis_control_image = 255 - control_image
333
+ return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
334
+ vis_control_image)
335
 
336
  @torch.inference_mode()
337
+ def process_scribble_interactive(
338
+ self,
339
+ input_image: np.ndarray,
340
+ prompt: str,
341
+ additional_prompt: str,
342
+ negative_prompt: str,
343
+ num_samples: int,
344
+ image_resolution: int,
345
+ num_steps: int,
346
+ guidance_scale: float,
347
+ seed: int,
348
+ ) -> list[PIL.Image.Image]:
349
+ control_image, vis_control_image = self.preprocess_scribble_interactive(
350
+ input_image=input_image,
351
+ image_resolution=image_resolution,
352
+ )
353
+ return self.process(
354
+ task_name='scribble',
355
+ prompt=prompt,
356
+ additional_prompt=additional_prompt,
357
+ negative_prompt=negative_prompt,
358
+ control_image=control_image,
359
+ vis_control_image=vis_control_image,
360
+ num_samples=num_samples,
361
+ num_steps=num_steps,
362
+ guidance_scale=guidance_scale,
363
+ seed=seed,
364
+ )
365
+
366
+ def preprocess_fake_scribble(
367
+ self,
368
+ input_image: np.ndarray,
369
+ image_resolution: int,
370
+ detect_resolution: int,
371
+ ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
372
  input_image = HWC3(input_image)
373
+ control_image = apply_hed(resize_image(input_image, detect_resolution))
374
+ control_image = HWC3(control_image)
375
+ image = resize_image(input_image, image_resolution)
376
+ H, W = image.shape[:2]
377
 
378
+ control_image = cv2.resize(control_image, (W, H),
379
+ interpolation=cv2.INTER_LINEAR)
380
+ control_image = nms(control_image, 127, 3.0)
381
+ control_image = cv2.GaussianBlur(control_image, (0, 0), 3.0)
382
+ control_image[control_image > 4] = 255
383
+ control_image[control_image < 255] = 0
384
+
385
+ vis_control_image = 255 - control_image
386
 
387
+ return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
388
+ vis_control_image)
389
+
390
+ @torch.inference_mode()
391
+ def process_fake_scribble(
392
+ self,
393
+ input_image: np.ndarray,
394
+ prompt: str,
395
+ additional_prompt: str,
396
+ negative_prompt: str,
397
+ num_samples: int,
398
+ image_resolution: int,
399
+ detect_resolution: int,
400
+ num_steps: int,
401
+ guidance_scale: float,
402
+ seed: int,
403
+ ) -> list[PIL.Image.Image]:
404
+ control_image, vis_control_image = self.preprocess_fake_scribble(
405
+ input_image=input_image,
406
+ image_resolution=image_resolution,
407
+ detect_resolution=detect_resolution,
408
+ )
409
+ return self.process(
410
+ task_name='scribble',
411
+ prompt=prompt,
412
+ additional_prompt=additional_prompt,
413
+ negative_prompt=negative_prompt,
414
+ control_image=control_image,
415
+ vis_control_image=vis_control_image,
416
+ num_samples=num_samples,
417
+ num_steps=num_steps,
418
+ guidance_scale=guidance_scale,
419
+ seed=seed,
420
+ )
421
+
422
+ def preprocess_pose(
423
+ self,
424
+ input_image: np.ndarray,
425
+ image_resolution: int,
426
+ detect_resolution: int,
427
+ ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
428
  input_image = HWC3(input_image)
429
+ control_image, _ = apply_openpose(
430
  resize_image(input_image, detect_resolution))
431
+ control_image = HWC3(control_image)
432
+ image = resize_image(input_image, image_resolution)
433
+ H, W = image.shape[:2]
434
 
435
+ control_image = cv2.resize(control_image, (W, H),
436
+ interpolation=cv2.INTER_NEAREST)
437
 
438
+ return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
439
+ control_image)
440
+
441
+ @torch.inference_mode()
442
+ def process_pose(
443
+ self,
444
+ input_image: np.ndarray,
445
+ prompt: str,
446
+ additional_prompt: str,
447
+ negative_prompt: str,
448
+ num_samples: int,
449
+ image_resolution: int,
450
+ detect_resolution: int,
451
+ num_steps: int,
452
+ guidance_scale: float,
453
+ seed: int,
454
+ ) -> list[PIL.Image.Image]:
455
+ control_image, vis_control_image = self.preprocess_pose(
456
+ input_image=input_image,
457
+ image_resolution=image_resolution,
458
+ detect_resolution=detect_resolution,
459
+ )
460
+ return self.process(
461
+ task_name='pose',
462
+ prompt=prompt,
463
+ additional_prompt=additional_prompt,
464
+ negative_prompt=negative_prompt,
465
+ control_image=control_image,
466
+ vis_control_image=vis_control_image,
467
+ num_samples=num_samples,
468
+ num_steps=num_steps,
469
+ guidance_scale=guidance_scale,
470
+ seed=seed,
471
+ )
472
+
473
+ def preprocess_seg(
474
+ self,
475
+ input_image: np.ndarray,
476
+ image_resolution: int,
477
+ detect_resolution: int,
478
+ ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
479
  input_image = HWC3(input_image)
480
+ control_image = apply_uniformer(
481
  resize_image(input_image, detect_resolution))
482
+ image = resize_image(input_image, image_resolution)
483
+ H, W = image.shape[:2]
484
+ control_image = cv2.resize(control_image, (W, H),
485
+ interpolation=cv2.INTER_NEAREST)
486
+ return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
487
+ control_image)
488
 
489
  @torch.inference_mode()
490
+ def process_seg(
491
+ self,
492
+ input_image: np.ndarray,
493
+ prompt: str,
494
+ additional_prompt: str,
495
+ negative_prompt: str,
496
+ num_samples: int,
497
+ image_resolution: int,
498
+ detect_resolution: int,
499
+ num_steps: int,
500
+ guidance_scale: float,
501
+ seed: int,
502
+ ) -> list[PIL.Image.Image]:
503
+ control_image, vis_control_image = self.preprocess_seg(
504
+ input_image=input_image,
505
+ image_resolution=image_resolution,
506
+ detect_resolution=detect_resolution,
507
+ )
508
+ return self.process(
509
+ task_name='seg',
510
+ prompt=prompt,
511
+ additional_prompt=additional_prompt,
512
+ negative_prompt=negative_prompt,
513
+ control_image=control_image,
514
+ vis_control_image=vis_control_image,
515
+ num_samples=num_samples,
516
+ num_steps=num_steps,
517
+ guidance_scale=guidance_scale,
518
+ seed=seed,
519
+ )
520
+
521
+ def preprocess_depth(
522
+ self,
523
+ input_image: np.ndarray,
524
+ image_resolution: int,
525
+ detect_resolution: int,
526
+ ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
527
  input_image = HWC3(input_image)
528
+ control_image, _ = apply_midas(
529
  resize_image(input_image, detect_resolution))
530
+ control_image = HWC3(control_image)
531
+ return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
532
+ control_image)
533
 
534
  @torch.inference_mode()
535
+ def process_depth(
536
+ self,
537
+ input_image: np.ndarray,
538
+ prompt: str,
539
+ additional_prompt: str,
540
+ negative_prompt: str,
541
+ num_samples: int,
542
+ image_resolution: int,
543
+ detect_resolution: int,
544
+ num_steps: int,
545
+ guidance_scale: float,
546
+ seed: int,
547
+ ) -> list[PIL.Image.Image]:
548
+ control_image, vis_control_image = self.preprocess_depth(
549
+ input_image=input_image,
550
+ image_resolution=image_resolution,
551
+ detect_resolution=detect_resolution,
552
+ )
553
+ return self.process(
554
+ task_name='depth',
555
+ prompt=prompt,
556
+ additional_prompt=additional_prompt,
557
+ negative_prompt=negative_prompt,
558
+ control_image=control_image,
559
+ vis_control_image=vis_control_image,
560
+ num_samples=num_samples,
561
+ num_steps=num_steps,
562
+ guidance_scale=guidance_scale,
563
+ seed=seed,
564
+ )
565
+
566
+ def preprocess_normal(
567
+ self,
568
+ input_image: np.ndarray,
569
+ image_resolution: int,
570
+ detect_resolution: int,
571
+ bg_threshold,
572
+ ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
573
  input_image = HWC3(input_image)
574
+ _, control_image = apply_midas(resize_image(input_image,
575
+ detect_resolution),
576
+ bg_th=bg_threshold)
577
+ control_image = HWC3(control_image)
578
+ image = resize_image(input_image, image_resolution)
579
+ H, W = image.shape[:2]
580
+ control_image = cv2.resize(control_image, (W, H),
581
+ interpolation=cv2.INTER_LINEAR)
582
+ return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
583
+ control_image)
584
+
585
+ @torch.inference_mode()
586
+ def process_normal(
587
+ self,
588
+ input_image: np.ndarray,
589
+ prompt: str,
590
+ additional_prompt: str,
591
+ negative_prompt: str,
592
+ num_samples: int,
593
+ image_resolution: int,
594
+ detect_resolution: int,
595
+ num_steps: int,
596
+ guidance_scale: float,
597
+ seed: int,
598
+ bg_threshold,
599
+ ) -> list[PIL.Image.Image]:
600
+ control_image, vis_control_image = self.preprocess_normal(
601
+ input_image=input_image,
602
+ image_resolution=image_resolution,
603
+ detect_resolution=detect_resolution,
604
+ bg_threshold=bg_threshold,
605
+ )
606
+ return self.process(
607
+ task_name='normal',
608
+ prompt=prompt,
609
+ additional_prompt=additional_prompt,
610
+ negative_prompt=negative_prompt,
611
+ control_image=control_image,
612
+ vis_control_image=vis_control_image,
613
+ num_samples=num_samples,
614
+ num_steps=num_steps,
615
+ guidance_scale=guidance_scale,
616
+ seed=seed,
617
+ )
 
requirements.txt CHANGED
@@ -1,7 +1,9 @@
1
  addict==2.4.0
2
  albumentations==1.3.0
3
  einops==0.6.0
4
- gradio==3.18.0
5
  imageio==2.25.0
6
  imageio-ffmpeg==0.4.8
7
  kornia==0.6.9
 
1
  addict==2.4.0
2
  albumentations==1.3.0
3
  einops==0.6.0
4
+ git+https://github.com/huggingface/accelerate@78151f8
5
+ git+https://github.com/huggingface/diffusers@fa6d52d
6
+ gradio==3.20.0
7
  imageio==2.25.0
8
  imageio-ffmpeg==0.4.8
9
  kornia==0.6.9
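
The two new git requirements pin the diffusers and accelerate revisions that supply the ControlNet pipeline and the CPU-offload hook used in model.py above. As orientation, the diffusers-side setup those pins enable looks roughly like the sketch below; the base Stable Diffusion checkpoint and the canny ControlNet id are assumptions, not taken from this diff.

    # Sketch only: assumed model ids are runwayml/stable-diffusion-v1-5 and
    # lllyasviel/sd-controlnet-canny; the Space may load different checkpoints.
    import torch
    from diffusers import (ControlNetModel, StableDiffusionControlNetPipeline,
                           UniPCMultistepScheduler)

    controlnet = ControlNetModel.from_pretrained('lllyasviel/sd-controlnet-canny',
                                                 torch_dtype=torch.float16)
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        'runwayml/stable-diffusion-v1-5',
        controlnet=controlnet,
        safety_checker=None,
        torch_dtype=torch.float16)
    # Same post-construction steps as in the diff above.
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_xformers_memory_efficient_attention()
    pipe.enable_model_cpu_offload()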