hysts (HF staff) committed
Commit b44d4b1 • 1 parent: b7075f8

Update to the original Space

README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 😻
  colorFrom: pink
  colorTo: blue
  sdk: gradio
- sdk_version: 3.18.0
+ sdk_version: 3.20.0
  python_version: 3.10.9
  app_file: app.py
  pinned: false
app.py CHANGED
@@ -30,92 +30,117 @@ for name in names:
          continue
      subprocess.run(shlex.split(command), cwd='ControlNet/annotator/ckpts/')

- from gradio_canny2image import create_demo as create_demo_canny
- from gradio_depth2image import create_demo as create_demo_depth
- from gradio_fake_scribble2image import create_demo as create_demo_fake_scribble
- from gradio_hed2image import create_demo as create_demo_hed
- from gradio_hough2image import create_demo as create_demo_hough
- from gradio_normal2image import create_demo as create_demo_normal
- from gradio_pose2image import create_demo as create_demo_pose
- from gradio_scribble2image import create_demo as create_demo_scribble
- from gradio_scribble2image_interactive import \
+ from app_canny import create_demo as create_demo_canny
+ from app_depth import create_demo as create_demo_depth
+ from app_fake_scribble import create_demo as create_demo_fake_scribble
+ from app_hed import create_demo as create_demo_hed
+ from app_hough import create_demo as create_demo_hough
+ from app_normal import create_demo as create_demo_normal
+ from app_pose import create_demo as create_demo_pose
+ from app_scribble import create_demo as create_demo_scribble
+ from app_scribble_interactive import \
      create_demo as create_demo_scribble_interactive
- from gradio_seg2image import create_demo as create_demo_seg
- from model import (DEFAULT_BASE_MODEL_FILENAME, DEFAULT_BASE_MODEL_REPO,
-                    DEFAULT_BASE_MODEL_URL, Model)
+ from app_seg import create_demo as create_demo_seg
+ from model import Model, download_all_controlnet_weights

- MAX_IMAGES = 1
- DESCRIPTION = '''# [ControlNet](https://github.com/lllyasviel/ControlNet)
-
- This Space is a modified version of [this Space](https://huggingface.co/spaces/hysts/ControlNet).
- The original Space uses [Stable Diffusion v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) as the base model, but [Anything v4.0](https://huggingface.co/andite/anything-v4.0) is used in this Space.
- '''
+ DESCRIPTION = '# [ControlNet](https://github.com/lllyasviel/ControlNet)'

  SPACE_ID = os.getenv('SPACE_ID')
- ALLOW_CHANGING_BASE_MODEL = SPACE_ID != 'hysts/ControlNet-with-other-models'
-
- if not ALLOW_CHANGING_BASE_MODEL:
-     DESCRIPTION += 'In this Space, the base model is not allowed to be changed so as not to slow down the demo, but it can be changed if you duplicate the Space.'
+ ALLOW_CHANGING_BASE_MODEL = SPACE_ID != 'hysts/ControlNet'

  if SPACE_ID is not None:
-     DESCRIPTION += f'''<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.<br/>
- <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true">
- <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
- <p/>
+     DESCRIPTION += f'''<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>
  '''

- model = Model()
+ MAX_IMAGES = int(os.getenv('MAX_IMAGES', '3'))
+ DEFAULT_NUM_IMAGES = min(MAX_IMAGES, int(os.getenv('DEFAULT_NUM_IMAGES', '1')))
+
+ if os.getenv('SYSTEM') == 'spaces':
+     download_all_controlnet_weights()
+
+ DEFAULT_MODEL_ID = os.getenv('DEFAULT_MODEL_ID',
+                              'runwayml/stable-diffusion-v1-5')
+ model = Model(base_model_id=DEFAULT_MODEL_ID, task_name='canny')

  with gr.Blocks(css='style.css') as demo:
      gr.Markdown(DESCRIPTION)
-
      with gr.Tabs():
          with gr.TabItem('Canny'):
-             create_demo_canny(model.process_canny, max_images=MAX_IMAGES)
+             create_demo_canny(model.process_canny,
+                               max_images=MAX_IMAGES,
+                               default_num_images=DEFAULT_NUM_IMAGES)
          with gr.TabItem('Hough'):
-             create_demo_hough(model.process_hough, max_images=MAX_IMAGES)
+             create_demo_hough(model.process_hough,
+                               max_images=MAX_IMAGES,
+                               default_num_images=DEFAULT_NUM_IMAGES)
          with gr.TabItem('HED'):
-             create_demo_hed(model.process_hed, max_images=MAX_IMAGES)
+             create_demo_hed(model.process_hed,
+                             max_images=MAX_IMAGES,
+                             default_num_images=DEFAULT_NUM_IMAGES)
          with gr.TabItem('Scribble'):
-             create_demo_scribble(model.process_scribble, max_images=MAX_IMAGES)
+             create_demo_scribble(model.process_scribble,
+                                  max_images=MAX_IMAGES,
+                                  default_num_images=DEFAULT_NUM_IMAGES)
          with gr.TabItem('Scribble Interactive'):
              create_demo_scribble_interactive(
-                 model.process_scribble_interactive, max_images=MAX_IMAGES)
+                 model.process_scribble_interactive,
+                 max_images=MAX_IMAGES,
+                 default_num_images=DEFAULT_NUM_IMAGES)
          with gr.TabItem('Fake Scribble'):
              create_demo_fake_scribble(model.process_fake_scribble,
-                                       max_images=MAX_IMAGES)
+                                       max_images=MAX_IMAGES,
+                                       default_num_images=DEFAULT_NUM_IMAGES)
          with gr.TabItem('Pose'):
-             create_demo_pose(model.process_pose, max_images=MAX_IMAGES)
+             create_demo_pose(model.process_pose,
+                              max_images=MAX_IMAGES,
+                              default_num_images=DEFAULT_NUM_IMAGES)
          with gr.TabItem('Segmentation'):
-             create_demo_seg(model.process_seg, max_images=MAX_IMAGES)
+             create_demo_seg(model.process_seg,
+                             max_images=MAX_IMAGES,
+                             default_num_images=DEFAULT_NUM_IMAGES)
          with gr.TabItem('Depth'):
-             create_demo_depth(model.process_depth, max_images=MAX_IMAGES)
+             create_demo_depth(model.process_depth,
+                               max_images=MAX_IMAGES,
+                               default_num_images=DEFAULT_NUM_IMAGES)
          with gr.TabItem('Normal map'):
-             create_demo_normal(model.process_normal, max_images=MAX_IMAGES)
+             create_demo_normal(model.process_normal,
+                                max_images=MAX_IMAGES,
+                                default_num_images=DEFAULT_NUM_IMAGES)

      with gr.Accordion(label='Base model', open=False):
-         current_base_model = gr.Text(label='Current base model',
-                                      value=DEFAULT_BASE_MODEL_URL)
          with gr.Row():
-             base_model_repo = gr.Text(label='Base model repo',
-                                       max_lines=1,
-                                       placeholder=DEFAULT_BASE_MODEL_REPO,
-                                       interactive=ALLOW_CHANGING_BASE_MODEL)
-             base_model_filename = gr.Text(
-                 label='Base model file',
-                 max_lines=1,
-                 placeholder=DEFAULT_BASE_MODEL_FILENAME,
-                 interactive=ALLOW_CHANGING_BASE_MODEL)
-         change_base_model_button = gr.Button('Change base model')
-         gr.Markdown(
-             '''- You can use other base models by specifying the repository name and filename.
-             The base model must be compatible with Stable Diffusion v1.5.''')
-
+             with gr.Column():
+                 current_base_model = gr.Text(label='Current base model')
+             with gr.Column(scale=0.3):
+                 check_base_model_button = gr.Button('Check current base model')
+         with gr.Row():
+             with gr.Column():
+                 new_base_model_id = gr.Text(
+                     label='New base model',
+                     max_lines=1,
+                     placeholder='runwayml/stable-diffusion-v1-5',
+                     info=
+                     'The base model must be compatible with Stable Diffusion v1.5.',
+                     interactive=ALLOW_CHANGING_BASE_MODEL)
+             with gr.Column(scale=0.3):
+                 change_base_model_button = gr.Button('Change base model')
+         if not ALLOW_CHANGING_BASE_MODEL:
+             gr.Markdown(
+                 '''The base model is not allowed to be changed in this Space so as not to slow down the demo, but it can be changed if you duplicate the Space. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a>'''
+             )
+
+     gr.Markdown(
+         '[Space using Anything-v4.0 as base model](https://huggingface.co/spaces/hysts/ControlNet-with-other-models)'
+     )
+
+     check_base_model_button.click(fn=lambda: model.base_model_id,
+                                   outputs=current_base_model,
+                                   queue=False)
+     new_base_model_id.submit(fn=model.set_base_model,
+                              inputs=new_base_model_id,
+                              outputs=current_base_model)
      change_base_model_button.click(fn=model.set_base_model,
-                                    inputs=[
-                                        base_model_repo,
-                                        base_model_filename,
-                                    ],
+                                    inputs=new_base_model_id,
                                     outputs=current_base_model)

  demo.queue(api_open=False).launch()
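
Note: the refactored app.py above takes its runtime settings from environment variables instead of hard-coded constants. The sketch below is a minimal illustration of that clamping and fallback logic, using only the variable names and defaults that appear in the diff; the override values are hypothetical, as they would be set in a duplicated Space's settings.

import os

# Hypothetical overrides, e.g. configured in a duplicated Space's settings.
os.environ['MAX_IMAGES'] = '2'
os.environ['DEFAULT_NUM_IMAGES'] = '5'

# Same logic as the configuration block added to app.py in this commit.
MAX_IMAGES = int(os.getenv('MAX_IMAGES', '3'))
DEFAULT_NUM_IMAGES = min(MAX_IMAGES, int(os.getenv('DEFAULT_NUM_IMAGES', '1')))
DEFAULT_MODEL_ID = os.getenv('DEFAULT_MODEL_ID',
                             'runwayml/stable-diffusion-v1-5')

print(MAX_IMAGES)          # 2
print(DEFAULT_NUM_IMAGES)  # requested 5, clamped down to MAX_IMAGES = 2
print(DEFAULT_MODEL_ID)    # runwayml/stable-diffusion-v1-5 (fallback when unset)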
gradio_canny2image.py β†’ app_canny.py RENAMED
@@ -3,7 +3,7 @@
  import gradio as gr


- def create_demo(process, max_images=12):
+ def create_demo(process, max_images=12, default_num_images=3):
      with gr.Blocks() as demo:
          with gr.Row():
              gr.Markdown('## Control Stable Diffusion with Canny Edge Maps')
@@ -16,39 +16,40 @@ def create_demo(process, max_images=12):
                      num_samples = gr.Slider(label='Images',
                                              minimum=1,
                                              maximum=max_images,
-                                             value=1,
+                                             value=default_num_images,
                                              step=1)
                      image_resolution = gr.Slider(label='Image Resolution',
                                                   minimum=256,
                                                   maximum=768,
                                                   value=512,
                                                   step=256)
-                     low_threshold = gr.Slider(label='Canny low threshold',
-                                               minimum=1,
-                                               maximum=255,
-                                               value=100,
-                                               step=1)
-                     high_threshold = gr.Slider(label='Canny high threshold',
-                                                minimum=1,
-                                                maximum=255,
-                                                value=200,
-                                                step=1)
-                     ddim_steps = gr.Slider(label='Steps',
-                                            minimum=1,
-                                            maximum=100,
-                                            value=20,
-                                            step=1)
-                     scale = gr.Slider(label='Guidance Scale',
-                                       minimum=0.1,
-                                       maximum=30.0,
-                                       value=9.0,
-                                       step=0.1)
+                     canny_low_threshold = gr.Slider(
+                         label='Canny low threshold',
+                         minimum=1,
+                         maximum=255,
+                         value=100,
+                         step=1)
+                     canny_high_threshold = gr.Slider(
+                         label='Canny high threshold',
+                         minimum=1,
+                         maximum=255,
+                         value=200,
+                         step=1)
+                     num_steps = gr.Slider(label='Steps',
+                                           minimum=1,
+                                           maximum=100,
+                                           value=20,
+                                           step=1)
+                     guidance_scale = gr.Slider(label='Guidance Scale',
+                                                minimum=0.1,
+                                                maximum=30.0,
+                                                value=9.0,
+                                                step=0.1)
                      seed = gr.Slider(label='Seed',
                                       minimum=-1,
                                       maximum=2147483647,
                                       step=1,
                                       randomize=True)
-                     eta = gr.Number(label='eta (DDIM)', value=0.0)
                      a_prompt = gr.Textbox(
                          label='Added Prompt',
                          value='best quality, extremely detailed')
@@ -58,17 +59,33 @@ def create_demo(process, max_images=12):
                          'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
                      )
              with gr.Column():
-                 result_gallery = gr.Gallery(label='Output',
-                                             show_label=False,
-                                             elem_id='gallery').style(
-                                                 grid=2, height='auto')
-         ips = [
-             input_image, prompt, a_prompt, n_prompt, num_samples,
-             image_resolution, ddim_steps, scale, seed, eta, low_threshold,
-             high_threshold
+                 result = gr.Gallery(label='Output',
+                                     show_label=False,
+                                     elem_id='gallery').style(grid=2,
+                                                              height='auto')
+         inputs = [
+             input_image,
+             prompt,
+             a_prompt,
+             n_prompt,
+             num_samples,
+             image_resolution,
+             num_steps,
+             guidance_scale,
+             seed,
+             canny_low_threshold,
+             canny_high_threshold,
          ]
+         prompt.submit(fn=process, inputs=inputs, outputs=result)
          run_button.click(fn=process,
-                          inputs=ips,
-                          outputs=[result_gallery],
+                          inputs=inputs,
+                          outputs=result,
                           api_name='canny')
      return demo
+
+
+ if __name__ == '__main__':
+     from model import Model
+     model = Model()
+     demo = create_demo(model.process_canny)
+     demo.queue().launch()
gradio_depth2image.py β†’ app_depth.py RENAMED
@@ -3,7 +3,7 @@
3
  import gradio as gr
4
 
5
 
6
- def create_demo(process, max_images=12):
7
  with gr.Blocks() as demo:
8
  with gr.Row():
9
  gr.Markdown('## Control Stable Diffusion with Depth Maps')
@@ -13,10 +13,12 @@ def create_demo(process, max_images=12):
13
  prompt = gr.Textbox(label='Prompt')
14
  run_button = gr.Button(label='Run')
15
  with gr.Accordion('Advanced options', open=False):
 
 
16
  num_samples = gr.Slider(label='Images',
17
  minimum=1,
18
  maximum=max_images,
19
- value=1,
20
  step=1)
21
  image_resolution = gr.Slider(label='Image Resolution',
22
  minimum=256,
@@ -28,22 +30,21 @@ def create_demo(process, max_images=12):
28
  maximum=1024,
29
  value=384,
30
  step=1)
31
- ddim_steps = gr.Slider(label='Steps',
32
- minimum=1,
33
- maximum=100,
34
- value=20,
35
- step=1)
36
- scale = gr.Slider(label='Guidance Scale',
37
- minimum=0.1,
38
- maximum=30.0,
39
- value=9.0,
40
- step=0.1)
41
  seed = gr.Slider(label='Seed',
42
  minimum=-1,
43
  maximum=2147483647,
44
  step=1,
45
  randomize=True)
46
- eta = gr.Number(label='eta (DDIM)', value=0.0)
47
  a_prompt = gr.Textbox(
48
  label='Added Prompt',
49
  value='best quality, extremely detailed')
@@ -53,16 +54,33 @@ def create_demo(process, max_images=12):
53
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
54
  )
55
  with gr.Column():
56
- result_gallery = gr.Gallery(label='Output',
57
- show_label=False,
58
- elem_id='gallery').style(
59
- grid=2, height='auto')
60
- ips = [
61
- input_image, prompt, a_prompt, n_prompt, num_samples,
62
- image_resolution, detect_resolution, ddim_steps, scale, seed, eta
 
 
 
 
 
 
 
 
 
63
  ]
 
64
  run_button.click(fn=process,
65
- inputs=ips,
66
- outputs=[result_gallery],
67
  api_name='depth')
68
  return demo
 
 
 
 
 
 
 
 
3
  import gradio as gr
4
 
5
 
6
+ def create_demo(process, max_images=12, default_num_images=3):
7
  with gr.Blocks() as demo:
8
  with gr.Row():
9
  gr.Markdown('## Control Stable Diffusion with Depth Maps')
 
13
  prompt = gr.Textbox(label='Prompt')
14
  run_button = gr.Button(label='Run')
15
  with gr.Accordion('Advanced options', open=False):
16
+ is_depth_image = gr.Checkbox(label='Is depth image',
17
+ value=False)
18
  num_samples = gr.Slider(label='Images',
19
  minimum=1,
20
  maximum=max_images,
21
+ value=default_num_images,
22
  step=1)
23
  image_resolution = gr.Slider(label='Image Resolution',
24
  minimum=256,
 
30
  maximum=1024,
31
  value=384,
32
  step=1)
33
+ num_steps = gr.Slider(label='Steps',
34
+ minimum=1,
35
+ maximum=100,
36
+ value=20,
37
+ step=1)
38
+ guidance_scale = gr.Slider(label='Guidance Scale',
39
+ minimum=0.1,
40
+ maximum=30.0,
41
+ value=9.0,
42
+ step=0.1)
43
  seed = gr.Slider(label='Seed',
44
  minimum=-1,
45
  maximum=2147483647,
46
  step=1,
47
  randomize=True)
 
48
  a_prompt = gr.Textbox(
49
  label='Added Prompt',
50
  value='best quality, extremely detailed')
 
54
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
55
  )
56
  with gr.Column():
57
+ result = gr.Gallery(label='Output',
58
+ show_label=False,
59
+ elem_id='gallery').style(grid=2,
60
+ height='auto')
61
+ inputs = [
62
+ input_image,
63
+ prompt,
64
+ a_prompt,
65
+ n_prompt,
66
+ num_samples,
67
+ image_resolution,
68
+ detect_resolution,
69
+ num_steps,
70
+ guidance_scale,
71
+ seed,
72
+ is_depth_image,
73
  ]
74
+ prompt.submit(fn=process, inputs=inputs, outputs=result)
75
  run_button.click(fn=process,
76
+ inputs=inputs,
77
+ outputs=result,
78
  api_name='depth')
79
  return demo
80
+
81
+
82
+ if __name__ == '__main__':
83
+ from model import Model
84
+ model = Model()
85
+ demo = create_demo(model.process_depth)
86
+ demo.queue().launch()
gradio_fake_scribble2image.py β†’ app_fake_scribble.py RENAMED
@@ -3,7 +3,7 @@
3
  import gradio as gr
4
 
5
 
6
- def create_demo(process, max_images=12):
7
  with gr.Blocks() as demo:
8
  with gr.Row():
9
  gr.Markdown('## Control Stable Diffusion with Fake Scribble Maps')
@@ -16,7 +16,7 @@ def create_demo(process, max_images=12):
16
  num_samples = gr.Slider(label='Images',
17
  minimum=1,
18
  maximum=max_images,
19
- value=1,
20
  step=1)
21
  image_resolution = gr.Slider(label='Image Resolution',
22
  minimum=256,
@@ -28,22 +28,21 @@ def create_demo(process, max_images=12):
28
  maximum=1024,
29
  value=512,
30
  step=1)
31
- ddim_steps = gr.Slider(label='Steps',
32
- minimum=1,
33
- maximum=100,
34
- value=20,
35
- step=1)
36
- scale = gr.Slider(label='Guidance Scale',
37
- minimum=0.1,
38
- maximum=30.0,
39
- value=9.0,
40
- step=0.1)
41
  seed = gr.Slider(label='Seed',
42
  minimum=-1,
43
  maximum=2147483647,
44
  step=1,
45
  randomize=True)
46
- eta = gr.Number(label='eta (DDIM)', value=0.0)
47
  a_prompt = gr.Textbox(
48
  label='Added Prompt',
49
  value='best quality, extremely detailed')
@@ -53,16 +52,32 @@ def create_demo(process, max_images=12):
53
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
54
  )
55
  with gr.Column():
56
- result_gallery = gr.Gallery(label='Output',
57
- show_label=False,
58
- elem_id='gallery').style(
59
- grid=2, height='auto')
60
- ips = [
61
- input_image, prompt, a_prompt, n_prompt, num_samples,
62
- image_resolution, detect_resolution, ddim_steps, scale, seed, eta
 
 
 
 
 
 
 
 
63
  ]
 
64
  run_button.click(fn=process,
65
- inputs=ips,
66
- outputs=[result_gallery],
67
  api_name='fake_scribble')
68
  return demo
 
 
 
 
 
 
 
 
3
  import gradio as gr
4
 
5
 
6
+ def create_demo(process, max_images=12, default_num_images=3):
7
  with gr.Blocks() as demo:
8
  with gr.Row():
9
  gr.Markdown('## Control Stable Diffusion with Fake Scribble Maps')
 
16
  num_samples = gr.Slider(label='Images',
17
  minimum=1,
18
  maximum=max_images,
19
+ value=default_num_images,
20
  step=1)
21
  image_resolution = gr.Slider(label='Image Resolution',
22
  minimum=256,
 
28
  maximum=1024,
29
  value=512,
30
  step=1)
31
+ num_steps = gr.Slider(label='Steps',
32
+ minimum=1,
33
+ maximum=100,
34
+ value=20,
35
+ step=1)
36
+ guidance_scale = gr.Slider(label='Guidance Scale',
37
+ minimum=0.1,
38
+ maximum=30.0,
39
+ value=9.0,
40
+ step=0.1)
41
  seed = gr.Slider(label='Seed',
42
  minimum=-1,
43
  maximum=2147483647,
44
  step=1,
45
  randomize=True)
 
46
  a_prompt = gr.Textbox(
47
  label='Added Prompt',
48
  value='best quality, extremely detailed')
 
52
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
53
  )
54
  with gr.Column():
55
+ result = gr.Gallery(label='Output',
56
+ show_label=False,
57
+ elem_id='gallery').style(grid=2,
58
+ height='auto')
59
+ inputs = [
60
+ input_image,
61
+ prompt,
62
+ a_prompt,
63
+ n_prompt,
64
+ num_samples,
65
+ image_resolution,
66
+ detect_resolution,
67
+ num_steps,
68
+ guidance_scale,
69
+ seed,
70
  ]
71
+ prompt.submit(fn=process, inputs=inputs, outputs=result)
72
  run_button.click(fn=process,
73
+ inputs=inputs,
74
+ outputs=result,
75
  api_name='fake_scribble')
76
  return demo
77
+
78
+
79
+ if __name__ == '__main__':
80
+ from model import Model
81
+ model = Model()
82
+ demo = create_demo(model.process_fake_scribble)
83
+ demo.queue().launch()
gradio_hed2image.py β†’ app_hed.py RENAMED
@@ -3,7 +3,7 @@
3
  import gradio as gr
4
 
5
 
6
- def create_demo(process, max_images=12):
7
  with gr.Blocks() as demo:
8
  with gr.Row():
9
  gr.Markdown('## Control Stable Diffusion with HED Maps')
@@ -16,7 +16,7 @@ def create_demo(process, max_images=12):
16
  num_samples = gr.Slider(label='Images',
17
  minimum=1,
18
  maximum=max_images,
19
- value=1,
20
  step=1)
21
  image_resolution = gr.Slider(label='Image Resolution',
22
  minimum=256,
@@ -28,22 +28,21 @@ def create_demo(process, max_images=12):
28
  maximum=1024,
29
  value=512,
30
  step=1)
31
- ddim_steps = gr.Slider(label='Steps',
32
- minimum=1,
33
- maximum=100,
34
- value=20,
35
- step=1)
36
- scale = gr.Slider(label='Guidance Scale',
37
- minimum=0.1,
38
- maximum=30.0,
39
- value=9.0,
40
- step=0.1)
41
  seed = gr.Slider(label='Seed',
42
  minimum=-1,
43
  maximum=2147483647,
44
  step=1,
45
  randomize=True)
46
- eta = gr.Number(label='eta (DDIM)', value=0.0)
47
  a_prompt = gr.Textbox(
48
  label='Added Prompt',
49
  value='best quality, extremely detailed')
@@ -53,16 +52,32 @@ def create_demo(process, max_images=12):
53
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
54
  )
55
  with gr.Column():
56
- result_gallery = gr.Gallery(label='Output',
57
- show_label=False,
58
- elem_id='gallery').style(
59
- grid=2, height='auto')
60
- ips = [
61
- input_image, prompt, a_prompt, n_prompt, num_samples,
62
- image_resolution, detect_resolution, ddim_steps, scale, seed, eta
 
 
 
 
 
 
 
 
63
  ]
 
64
  run_button.click(fn=process,
65
- inputs=ips,
66
- outputs=[result_gallery],
67
  api_name='hed')
68
  return demo
 
 
 
 
 
 
 
 
3
  import gradio as gr
4
 
5
 
6
+ def create_demo(process, max_images=12, default_num_images=3):
7
  with gr.Blocks() as demo:
8
  with gr.Row():
9
  gr.Markdown('## Control Stable Diffusion with HED Maps')
 
16
  num_samples = gr.Slider(label='Images',
17
  minimum=1,
18
  maximum=max_images,
19
+ value=default_num_images,
20
  step=1)
21
  image_resolution = gr.Slider(label='Image Resolution',
22
  minimum=256,
 
28
  maximum=1024,
29
  value=512,
30
  step=1)
31
+ num_steps = gr.Slider(label='Steps',
32
+ minimum=1,
33
+ maximum=100,
34
+ value=20,
35
+ step=1)
36
+ guidance_scale = gr.Slider(label='Guidance Scale',
37
+ minimum=0.1,
38
+ maximum=30.0,
39
+ value=9.0,
40
+ step=0.1)
41
  seed = gr.Slider(label='Seed',
42
  minimum=-1,
43
  maximum=2147483647,
44
  step=1,
45
  randomize=True)
 
46
  a_prompt = gr.Textbox(
47
  label='Added Prompt',
48
  value='best quality, extremely detailed')
 
52
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
53
  )
54
  with gr.Column():
55
+ result = gr.Gallery(label='Output',
56
+ show_label=False,
57
+ elem_id='gallery').style(grid=2,
58
+ height='auto')
59
+ inputs = [
60
+ input_image,
61
+ prompt,
62
+ a_prompt,
63
+ n_prompt,
64
+ num_samples,
65
+ image_resolution,
66
+ detect_resolution,
67
+ num_steps,
68
+ guidance_scale,
69
+ seed,
70
  ]
71
+ prompt.submit(fn=process, inputs=inputs, outputs=result)
72
  run_button.click(fn=process,
73
+ inputs=inputs,
74
+ outputs=result,
75
  api_name='hed')
76
  return demo
77
+
78
+
79
+ if __name__ == '__main__':
80
+ from model import Model
81
+ model = Model()
82
+ demo = create_demo(model.process_hed)
83
+ demo.queue().launch()
gradio_hough2image.py β†’ app_hough.py RENAMED
@@ -3,7 +3,7 @@
3
  import gradio as gr
4
 
5
 
6
- def create_demo(process, max_images=12):
7
  with gr.Blocks() as demo:
8
  with gr.Row():
9
  gr.Markdown('## Control Stable Diffusion with Hough Line Maps')
@@ -16,7 +16,7 @@ def create_demo(process, max_images=12):
16
  num_samples = gr.Slider(label='Images',
17
  minimum=1,
18
  maximum=max_images,
19
- value=1,
20
  step=1)
21
  image_resolution = gr.Slider(label='Image Resolution',
22
  minimum=256,
@@ -28,34 +28,33 @@ def create_demo(process, max_images=12):
28
  maximum=1024,
29
  value=512,
30
  step=1)
31
- value_threshold = gr.Slider(
32
  label='Hough value threshold (MLSD)',
33
  minimum=0.01,
34
  maximum=2.0,
35
  value=0.1,
36
  step=0.01)
37
- distance_threshold = gr.Slider(
38
  label='Hough distance threshold (MLSD)',
39
  minimum=0.01,
40
  maximum=20.0,
41
  value=0.1,
42
  step=0.01)
43
- ddim_steps = gr.Slider(label='Steps',
44
- minimum=1,
45
- maximum=100,
46
- value=20,
47
- step=1)
48
- scale = gr.Slider(label='Guidance Scale',
49
- minimum=0.1,
50
- maximum=30.0,
51
- value=9.0,
52
- step=0.1)
53
  seed = gr.Slider(label='Seed',
54
  minimum=-1,
55
  maximum=2147483647,
56
  step=1,
57
  randomize=True)
58
- eta = gr.Number(label='eta (DDIM)', value=0.0)
59
  a_prompt = gr.Textbox(
60
  label='Added Prompt',
61
  value='best quality, extremely detailed')
@@ -65,17 +64,34 @@ def create_demo(process, max_images=12):
65
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
66
  )
67
  with gr.Column():
68
- result_gallery = gr.Gallery(label='Output',
69
- show_label=False,
70
- elem_id='gallery').style(
71
- grid=2, height='auto')
72
- ips = [
73
- input_image, prompt, a_prompt, n_prompt, num_samples,
74
- image_resolution, detect_resolution, ddim_steps, scale, seed, eta,
75
- value_threshold, distance_threshold
 
 
 
 
 
 
 
 
 
76
  ]
 
77
  run_button.click(fn=process,
78
- inputs=ips,
79
- outputs=[result_gallery],
80
  api_name='hough')
81
  return demo
 
 
 
 
 
 
 
 
3
  import gradio as gr
4
 
5
 
6
+ def create_demo(process, max_images=12, default_num_images=3):
7
  with gr.Blocks() as demo:
8
  with gr.Row():
9
  gr.Markdown('## Control Stable Diffusion with Hough Line Maps')
 
16
  num_samples = gr.Slider(label='Images',
17
  minimum=1,
18
  maximum=max_images,
19
+ value=default_num_images,
20
  step=1)
21
  image_resolution = gr.Slider(label='Image Resolution',
22
  minimum=256,
 
28
  maximum=1024,
29
  value=512,
30
  step=1)
31
+ mlsd_value_threshold = gr.Slider(
32
  label='Hough value threshold (MLSD)',
33
  minimum=0.01,
34
  maximum=2.0,
35
  value=0.1,
36
  step=0.01)
37
+ mlsd_distance_threshold = gr.Slider(
38
  label='Hough distance threshold (MLSD)',
39
  minimum=0.01,
40
  maximum=20.0,
41
  value=0.1,
42
  step=0.01)
43
+ num_steps = gr.Slider(label='Steps',
44
+ minimum=1,
45
+ maximum=100,
46
+ value=20,
47
+ step=1)
48
+ guidance_scale = gr.Slider(label='Guidance Scale',
49
+ minimum=0.1,
50
+ maximum=30.0,
51
+ value=9.0,
52
+ step=0.1)
53
  seed = gr.Slider(label='Seed',
54
  minimum=-1,
55
  maximum=2147483647,
56
  step=1,
57
  randomize=True)
 
58
  a_prompt = gr.Textbox(
59
  label='Added Prompt',
60
  value='best quality, extremely detailed')
 
64
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
65
  )
66
  with gr.Column():
67
+ result = gr.Gallery(label='Output',
68
+ show_label=False,
69
+ elem_id='gallery').style(grid=2,
70
+ height='auto')
71
+ inputs = [
72
+ input_image,
73
+ prompt,
74
+ a_prompt,
75
+ n_prompt,
76
+ num_samples,
77
+ image_resolution,
78
+ detect_resolution,
79
+ num_steps,
80
+ guidance_scale,
81
+ seed,
82
+ mlsd_value_threshold,
83
+ mlsd_distance_threshold,
84
  ]
85
+ prompt.submit(fn=process, inputs=inputs, outputs=result)
86
  run_button.click(fn=process,
87
+ inputs=inputs,
88
+ outputs=result,
89
  api_name='hough')
90
  return demo
91
+
92
+
93
+ if __name__ == '__main__':
94
+ from model import Model
95
+ model = Model()
96
+ demo = create_demo(model.process_hough)
97
+ demo.queue().launch()
gradio_normal2image.py β†’ app_normal.py RENAMED
@@ -3,7 +3,7 @@
3
  import gradio as gr
4
 
5
 
6
- def create_demo(process, max_images=12):
7
  with gr.Blocks() as demo:
8
  with gr.Row():
9
  gr.Markdown('## Control Stable Diffusion with Normal Maps')
@@ -13,10 +13,12 @@ def create_demo(process, max_images=12):
13
  prompt = gr.Textbox(label='Prompt')
14
  run_button = gr.Button(label='Run')
15
  with gr.Accordion('Advanced options', open=False):
 
 
16
  num_samples = gr.Slider(label='Images',
17
  minimum=1,
18
  maximum=max_images,
19
- value=1,
20
  step=1)
21
  image_resolution = gr.Slider(label='Image Resolution',
22
  minimum=256,
@@ -34,22 +36,21 @@ def create_demo(process, max_images=12):
34
  maximum=1.0,
35
  value=0.4,
36
  step=0.01)
37
- ddim_steps = gr.Slider(label='Steps',
38
- minimum=1,
39
- maximum=100,
40
- value=20,
41
- step=1)
42
- scale = gr.Slider(label='Guidance Scale',
43
- minimum=0.1,
44
- maximum=30.0,
45
- value=9.0,
46
- step=0.1)
47
  seed = gr.Slider(label='Seed',
48
  minimum=-1,
49
  maximum=2147483647,
50
  step=1,
51
  randomize=True)
52
- eta = gr.Number(label='eta (DDIM)', value=0.0)
53
  a_prompt = gr.Textbox(
54
  label='Added Prompt',
55
  value='best quality, extremely detailed')
@@ -59,17 +60,34 @@ def create_demo(process, max_images=12):
59
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
60
  )
61
  with gr.Column():
62
- result_gallery = gr.Gallery(label='Output',
63
- show_label=False,
64
- elem_id='gallery').style(
65
- grid=2, height='auto')
66
- ips = [
67
- input_image, prompt, a_prompt, n_prompt, num_samples,
68
- image_resolution, detect_resolution, ddim_steps, scale, seed, eta,
69
- bg_threshold
 
 
 
 
 
 
 
 
 
70
  ]
 
71
  run_button.click(fn=process,
72
- inputs=ips,
73
- outputs=[result_gallery],
74
  api_name='normal')
75
  return demo
 
 
 
 
 
 
 
 
3
  import gradio as gr
4
 
5
 
6
+ def create_demo(process, max_images=12, default_num_images=3):
7
  with gr.Blocks() as demo:
8
  with gr.Row():
9
  gr.Markdown('## Control Stable Diffusion with Normal Maps')
 
13
  prompt = gr.Textbox(label='Prompt')
14
  run_button = gr.Button(label='Run')
15
  with gr.Accordion('Advanced options', open=False):
16
+ is_normal_image = gr.Checkbox(label='Is normal image',
17
+ value=False)
18
  num_samples = gr.Slider(label='Images',
19
  minimum=1,
20
  maximum=max_images,
21
+ value=default_num_images,
22
  step=1)
23
  image_resolution = gr.Slider(label='Image Resolution',
24
  minimum=256,
 
36
  maximum=1.0,
37
  value=0.4,
38
  step=0.01)
39
+ num_steps = gr.Slider(label='Steps',
40
+ minimum=1,
41
+ maximum=100,
42
+ value=20,
43
+ step=1)
44
+ guidance_scale = gr.Slider(label='Guidance Scale',
45
+ minimum=0.1,
46
+ maximum=30.0,
47
+ value=9.0,
48
+ step=0.1)
49
  seed = gr.Slider(label='Seed',
50
  minimum=-1,
51
  maximum=2147483647,
52
  step=1,
53
  randomize=True)
 
54
  a_prompt = gr.Textbox(
55
  label='Added Prompt',
56
  value='best quality, extremely detailed')
 
60
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
61
  )
62
  with gr.Column():
63
+ result = gr.Gallery(label='Output',
64
+ show_label=False,
65
+ elem_id='gallery').style(grid=2,
66
+ height='auto')
67
+ inputs = [
68
+ input_image,
69
+ prompt,
70
+ a_prompt,
71
+ n_prompt,
72
+ num_samples,
73
+ image_resolution,
74
+ detect_resolution,
75
+ num_steps,
76
+ guidance_scale,
77
+ seed,
78
+ bg_threshold,
79
+ is_normal_image,
80
  ]
81
+ prompt.submit(fn=process, inputs=inputs, outputs=result)
82
  run_button.click(fn=process,
83
+ inputs=inputs,
84
+ outputs=result,
85
  api_name='normal')
86
  return demo
87
+
88
+
89
+ if __name__ == '__main__':
90
+ from model import Model
91
+ model = Model()
92
+ demo = create_demo(model.process_normal)
93
+ demo.queue().launch()
gradio_pose2image.py β†’ app_pose.py RENAMED
@@ -3,7 +3,7 @@
3
  import gradio as gr
4
 
5
 
6
- def create_demo(process, max_images=12):
7
  with gr.Blocks() as demo:
8
  with gr.Row():
9
  gr.Markdown('## Control Stable Diffusion with Human Pose')
@@ -13,10 +13,15 @@ def create_demo(process, max_images=12):
13
  prompt = gr.Textbox(label='Prompt')
14
  run_button = gr.Button(label='Run')
15
  with gr.Accordion('Advanced options', open=False):
 
 
 
 
 
16
  num_samples = gr.Slider(label='Images',
17
  minimum=1,
18
  maximum=max_images,
19
- value=1,
20
  step=1)
21
  image_resolution = gr.Slider(label='Image Resolution',
22
  minimum=256,
@@ -28,22 +33,21 @@ def create_demo(process, max_images=12):
28
  maximum=1024,
29
  value=512,
30
  step=1)
31
- ddim_steps = gr.Slider(label='Steps',
32
- minimum=1,
33
- maximum=100,
34
- value=20,
35
- step=1)
36
- scale = gr.Slider(label='Guidance Scale',
37
- minimum=0.1,
38
- maximum=30.0,
39
- value=9.0,
40
- step=0.1)
41
  seed = gr.Slider(label='Seed',
42
  minimum=-1,
43
  maximum=2147483647,
44
  step=1,
45
  randomize=True)
46
- eta = gr.Number(label='eta (DDIM)', value=0.0)
47
  a_prompt = gr.Textbox(
48
  label='Added Prompt',
49
  value='best quality, extremely detailed')
@@ -53,16 +57,33 @@ def create_demo(process, max_images=12):
53
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
54
  )
55
  with gr.Column():
56
- result_gallery = gr.Gallery(label='Output',
57
- show_label=False,
58
- elem_id='gallery').style(
59
- grid=2, height='auto')
60
- ips = [
61
- input_image, prompt, a_prompt, n_prompt, num_samples,
62
- image_resolution, detect_resolution, ddim_steps, scale, seed, eta
 
 
 
 
 
 
 
 
 
63
  ]
 
64
  run_button.click(fn=process,
65
- inputs=ips,
66
- outputs=[result_gallery],
67
  api_name='pose')
68
  return demo
 
 
 
 
 
 
 
 
3
  import gradio as gr
4
 
5
 
6
+ def create_demo(process, max_images=12, default_num_images=3):
7
  with gr.Blocks() as demo:
8
  with gr.Row():
9
  gr.Markdown('## Control Stable Diffusion with Human Pose')
 
13
  prompt = gr.Textbox(label='Prompt')
14
  run_button = gr.Button(label='Run')
15
  with gr.Accordion('Advanced options', open=False):
16
+ is_pose_image = gr.Checkbox(label='Is pose image',
17
+ value=False)
18
+ gr.Markdown(
19
+ 'You can use [PoseMaker2](https://huggingface.co/spaces/jonigata/PoseMaker2) to create pose images.'
20
+ )
21
  num_samples = gr.Slider(label='Images',
22
  minimum=1,
23
  maximum=max_images,
24
+ value=default_num_images,
25
  step=1)
26
  image_resolution = gr.Slider(label='Image Resolution',
27
  minimum=256,
 
33
  maximum=1024,
34
  value=512,
35
  step=1)
36
+ num_steps = gr.Slider(label='Steps',
37
+ minimum=1,
38
+ maximum=100,
39
+ value=20,
40
+ step=1)
41
+ guidance_scale = gr.Slider(label='Guidance Scale',
42
+ minimum=0.1,
43
+ maximum=30.0,
44
+ value=9.0,
45
+ step=0.1)
46
  seed = gr.Slider(label='Seed',
47
  minimum=-1,
48
  maximum=2147483647,
49
  step=1,
50
  randomize=True)
 
51
  a_prompt = gr.Textbox(
52
  label='Added Prompt',
53
  value='best quality, extremely detailed')
 
57
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
58
  )
59
  with gr.Column():
60
+ result = gr.Gallery(label='Output',
61
+ show_label=False,
62
+ elem_id='gallery').style(grid=2,
63
+ height='auto')
64
+ inputs = [
65
+ input_image,
66
+ prompt,
67
+ a_prompt,
68
+ n_prompt,
69
+ num_samples,
70
+ image_resolution,
71
+ detect_resolution,
72
+ num_steps,
73
+ guidance_scale,
74
+ seed,
75
+ is_pose_image,
76
  ]
77
+ prompt.submit(fn=process, inputs=inputs, outputs=result)
78
  run_button.click(fn=process,
79
+ inputs=inputs,
80
+ outputs=result,
81
  api_name='pose')
82
  return demo
83
+
84
+
85
+ if __name__ == '__main__':
86
+ from model import Model
87
+ model = Model()
88
+ demo = create_demo(model.process_pose)
89
+ demo.queue().launch()
gradio_scribble2image.py β†’ app_scribble.py RENAMED
@@ -3,7 +3,7 @@
3
  import gradio as gr
4
 
5
 
6
- def create_demo(process, max_images=12):
7
  with gr.Blocks() as demo:
8
  with gr.Row():
9
  gr.Markdown('## Control Stable Diffusion with Scribble Maps')
@@ -16,29 +16,28 @@ def create_demo(process, max_images=12):
16
  num_samples = gr.Slider(label='Images',
17
  minimum=1,
18
  maximum=max_images,
19
- value=1,
20
  step=1)
21
  image_resolution = gr.Slider(label='Image Resolution',
22
  minimum=256,
23
  maximum=768,
24
  value=512,
25
  step=256)
26
- ddim_steps = gr.Slider(label='Steps',
27
- minimum=1,
28
- maximum=100,
29
- value=20,
30
- step=1)
31
- scale = gr.Slider(label='Guidance Scale',
32
- minimum=0.1,
33
- maximum=30.0,
34
- value=9.0,
35
- step=0.1)
36
  seed = gr.Slider(label='Seed',
37
  minimum=-1,
38
  maximum=2147483647,
39
  step=1,
40
  randomize=True)
41
- eta = gr.Number(label='eta (DDIM)', value=0.0)
42
  a_prompt = gr.Textbox(
43
  label='Added Prompt',
44
  value='best quality, extremely detailed')
@@ -48,16 +47,31 @@ def create_demo(process, max_images=12):
48
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
49
  )
50
  with gr.Column():
51
- result_gallery = gr.Gallery(label='Output',
52
- show_label=False,
53
- elem_id='gallery').style(
54
- grid=2, height='auto')
55
- ips = [
56
- input_image, prompt, a_prompt, n_prompt, num_samples,
57
- image_resolution, ddim_steps, scale, seed, eta
 
 
 
 
 
 
 
58
  ]
 
59
  run_button.click(fn=process,
60
- inputs=ips,
61
- outputs=[result_gallery],
62
  api_name='scribble')
63
  return demo
 
 
 
 
 
 
 
 
3
  import gradio as gr
4
 
5
 
6
+ def create_demo(process, max_images=12, default_num_images=3):
7
  with gr.Blocks() as demo:
8
  with gr.Row():
9
  gr.Markdown('## Control Stable Diffusion with Scribble Maps')
 
16
  num_samples = gr.Slider(label='Images',
17
  minimum=1,
18
  maximum=max_images,
19
+ value=default_num_images,
20
  step=1)
21
  image_resolution = gr.Slider(label='Image Resolution',
22
  minimum=256,
23
  maximum=768,
24
  value=512,
25
  step=256)
26
+ num_steps = gr.Slider(label='Steps',
27
+ minimum=1,
28
+ maximum=100,
29
+ value=20,
30
+ step=1)
31
+ guidance_scale = gr.Slider(label='Guidance Scale',
32
+ minimum=0.1,
33
+ maximum=30.0,
34
+ value=9.0,
35
+ step=0.1)
36
  seed = gr.Slider(label='Seed',
37
  minimum=-1,
38
  maximum=2147483647,
39
  step=1,
40
  randomize=True)
 
41
  a_prompt = gr.Textbox(
42
  label='Added Prompt',
43
  value='best quality, extremely detailed')
 
47
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
48
  )
49
  with gr.Column():
50
+ result = gr.Gallery(label='Output',
51
+ show_label=False,
52
+ elem_id='gallery').style(grid=2,
53
+ height='auto')
54
+ inputs = [
55
+ input_image,
56
+ prompt,
57
+ a_prompt,
58
+ n_prompt,
59
+ num_samples,
60
+ image_resolution,
61
+ num_steps,
62
+ guidance_scale,
63
+ seed,
64
  ]
65
+ prompt.submit(fn=process, inputs=inputs, outputs=result)
66
  run_button.click(fn=process,
67
+ inputs=inputs,
68
+ outputs=result,
69
  api_name='scribble')
70
  return demo
71
+
72
+
73
+ if __name__ == '__main__':
74
+ from model import Model
75
+ model = Model()
76
+ demo = create_demo(model.process_scribble)
77
+ demo.queue().launch()
gradio_scribble2image_interactive.py β†’ app_scribble_interactive.py RENAMED
@@ -8,7 +8,7 @@ def create_canvas(w, h):
8
  return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
9
 
10
 
11
- def create_demo(process, max_images=12):
12
  with gr.Blocks() as demo:
13
  with gr.Row():
14
  gr.Markdown(
@@ -37,7 +37,7 @@ def create_demo(process, max_images=12):
37
  )
38
  create_button.click(fn=create_canvas,
39
  inputs=[canvas_width, canvas_height],
40
- outputs=[input_image],
41
  queue=False)
42
  prompt = gr.Textbox(label='Prompt')
43
  run_button = gr.Button(label='Run')
@@ -45,29 +45,28 @@ def create_demo(process, max_images=12):
45
  num_samples = gr.Slider(label='Images',
46
  minimum=1,
47
  maximum=max_images,
48
- value=1,
49
  step=1)
50
  image_resolution = gr.Slider(label='Image Resolution',
51
  minimum=256,
52
  maximum=768,
53
  value=512,
54
  step=256)
55
- ddim_steps = gr.Slider(label='Steps',
56
- minimum=1,
57
- maximum=100,
58
- value=20,
59
- step=1)
60
- scale = gr.Slider(label='Guidance Scale',
61
- minimum=0.1,
62
- maximum=30.0,
63
- value=9.0,
64
- step=0.1)
65
  seed = gr.Slider(label='Seed',
66
  minimum=-1,
67
  maximum=2147483647,
68
  step=1,
69
  randomize=True)
70
- eta = gr.Number(label='eta (DDIM)', value=0.0)
71
  a_prompt = gr.Textbox(
72
  label='Added Prompt',
73
  value='best quality, extremely detailed')
@@ -77,13 +76,28 @@ def create_demo(process, max_images=12):
77
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
78
  )
79
  with gr.Column():
80
- result_gallery = gr.Gallery(label='Output',
81
- show_label=False,
82
- elem_id='gallery').style(
83
- grid=2, height='auto')
84
- ips = [
85
- input_image, prompt, a_prompt, n_prompt, num_samples,
86
- image_resolution, ddim_steps, scale, seed, eta
 
 
 
 
 
 
 
87
  ]
88
- run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
 
89
  return demo
 
 
 
 
 
 
 
 
8
  return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
9
 
10
 
11
+ def create_demo(process, max_images=12, default_num_images=3):
12
  with gr.Blocks() as demo:
13
  with gr.Row():
14
  gr.Markdown(
 
37
  )
38
  create_button.click(fn=create_canvas,
39
  inputs=[canvas_width, canvas_height],
40
+ outputs=input_image,
41
  queue=False)
42
  prompt = gr.Textbox(label='Prompt')
43
  run_button = gr.Button(label='Run')
 
45
  num_samples = gr.Slider(label='Images',
46
  minimum=1,
47
  maximum=max_images,
48
+ value=default_num_images,
49
  step=1)
50
  image_resolution = gr.Slider(label='Image Resolution',
51
  minimum=256,
52
  maximum=768,
53
  value=512,
54
  step=256)
55
+ num_steps = gr.Slider(label='Steps',
56
+ minimum=1,
57
+ maximum=100,
58
+ value=20,
59
+ step=1)
60
+ guidance_scale = gr.Slider(label='Guidance Scale',
61
+ minimum=0.1,
62
+ maximum=30.0,
63
+ value=9.0,
64
+ step=0.1)
65
  seed = gr.Slider(label='Seed',
66
  minimum=-1,
67
  maximum=2147483647,
68
  step=1,
69
  randomize=True)
 
70
  a_prompt = gr.Textbox(
71
  label='Added Prompt',
72
  value='best quality, extremely detailed')
 
76
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
77
  )
78
  with gr.Column():
79
+ result = gr.Gallery(label='Output',
80
+ show_label=False,
81
+ elem_id='gallery').style(grid=2,
82
+ height='auto')
83
+ inputs = [
84
+ input_image,
85
+ prompt,
86
+ a_prompt,
87
+ n_prompt,
88
+ num_samples,
89
+ image_resolution,
90
+ num_steps,
91
+ guidance_scale,
92
+ seed,
93
  ]
94
+ prompt.submit(fn=process, inputs=inputs, outputs=result)
95
+ run_button.click(fn=process, inputs=inputs, outputs=result)
96
  return demo
97
+
98
+
99
+ if __name__ == '__main__':
100
+ from model import Model
101
+ model = Model()
102
+ demo = create_demo(model.process_scribble_interactive)
103
+ demo.queue().launch()
gradio_seg2image.py β†’ app_seg.py RENAMED
@@ -3,7 +3,7 @@
3
  import gradio as gr
4
 
5
 
6
- def create_demo(process, max_images=12):
7
  with gr.Blocks() as demo:
8
  with gr.Row():
9
  gr.Markdown('## Control Stable Diffusion with Segmentation Maps')
@@ -13,10 +13,12 @@ def create_demo(process, max_images=12):
13
  prompt = gr.Textbox(label='Prompt')
14
  run_button = gr.Button(label='Run')
15
  with gr.Accordion('Advanced options', open=False):
 
 
16
  num_samples = gr.Slider(label='Images',
17
  minimum=1,
18
  maximum=max_images,
19
- value=1,
20
  step=1)
21
  image_resolution = gr.Slider(label='Image Resolution',
22
  minimum=256,
@@ -29,22 +31,21 @@ def create_demo(process, max_images=12):
29
  maximum=1024,
30
  value=512,
31
  step=1)
32
- ddim_steps = gr.Slider(label='Steps',
33
- minimum=1,
34
- maximum=100,
35
- value=20,
36
- step=1)
37
- scale = gr.Slider(label='Guidance Scale',
38
- minimum=0.1,
39
- maximum=30.0,
40
- value=9.0,
41
- step=0.1)
42
  seed = gr.Slider(label='Seed',
43
  minimum=-1,
44
  maximum=2147483647,
45
  step=1,
46
  randomize=True)
47
- eta = gr.Number(label='eta (DDIM)', value=0.0)
48
  a_prompt = gr.Textbox(
49
  label='Added Prompt',
50
  value='best quality, extremely detailed')
@@ -54,16 +55,33 @@ def create_demo(process, max_images=12):
54
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
55
  )
56
  with gr.Column():
57
- result_gallery = gr.Gallery(label='Output',
58
- show_label=False,
59
- elem_id='gallery').style(
60
- grid=2, height='auto')
61
- ips = [
62
- input_image, prompt, a_prompt, n_prompt, num_samples,
63
- image_resolution, detect_resolution, ddim_steps, scale, seed, eta
 
 
 
 
 
 
 
 
 
64
  ]
 
65
  run_button.click(fn=process,
66
- inputs=ips,
67
- outputs=[result_gallery],
68
  api_name='seg')
69
  return demo
 
 
 
 
 
 
 
 
3
  import gradio as gr
4
 
5
 
6
+ def create_demo(process, max_images=12, default_num_images=3):
7
  with gr.Blocks() as demo:
8
  with gr.Row():
9
  gr.Markdown('## Control Stable Diffusion with Segmentation Maps')
 
13
  prompt = gr.Textbox(label='Prompt')
14
  run_button = gr.Button(label='Run')
15
  with gr.Accordion('Advanced options', open=False):
16
+ is_segmentation_map = gr.Checkbox(
17
+ label='Is segmentation map', value=False)
18
  num_samples = gr.Slider(label='Images',
19
  minimum=1,
20
  maximum=max_images,
21
+ value=default_num_images,
22
  step=1)
23
  image_resolution = gr.Slider(label='Image Resolution',
24
  minimum=256,
 
31
  maximum=1024,
32
  value=512,
33
  step=1)
34
+ num_steps = gr.Slider(label='Steps',
35
+ minimum=1,
36
+ maximum=100,
37
+ value=20,
38
+ step=1)
39
+ guidance_scale = gr.Slider(label='Guidance Scale',
40
+ minimum=0.1,
41
+ maximum=30.0,
42
+ value=9.0,
43
+ step=0.1)
44
  seed = gr.Slider(label='Seed',
45
  minimum=-1,
46
  maximum=2147483647,
47
  step=1,
48
  randomize=True)
 
49
  a_prompt = gr.Textbox(
50
  label='Added Prompt',
51
  value='best quality, extremely detailed')
 
55
  'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
56
  )
57
  with gr.Column():
58
+ result = gr.Gallery(label='Output',
59
+ show_label=False,
60
+ elem_id='gallery').style(grid=2,
61
+ height='auto')
62
+ inputs = [
63
+ input_image,
64
+ prompt,
65
+ a_prompt,
66
+ n_prompt,
67
+ num_samples,
68
+ image_resolution,
69
+ detect_resolution,
70
+ num_steps,
71
+ guidance_scale,
72
+ seed,
73
+ is_segmentation_map,
74
  ]
75
+ prompt.submit(fn=process, inputs=inputs, outputs=result)
76
  run_button.click(fn=process,
77
+ inputs=inputs,
78
+ outputs=result,
79
  api_name='seg')
80
  return demo
81
+
82
+
83
+ if __name__ == '__main__':
84
+ from model import Model
85
+ model = Model()
86
+ demo = create_demo(model.process_seg)
87
+ demo.queue().launch()
model.py CHANGED
@@ -3,21 +3,20 @@
3
  from __future__ import annotations
4
 
5
  import pathlib
6
- import random
7
- import shlex
8
- import subprocess
9
  import sys
10
 
11
  import cv2
12
- import einops
13
  import numpy as np
 
14
  import torch
15
- from huggingface_hub import hf_hub_url
16
- from pytorch_lightning import seed_everything
 
17
 
18
- sys.path.append('ControlNet')
 
 
19
 
20
- import config
21
  from annotator.canny import apply_canny
22
  from annotator.hed import apply_hed, nms
23
  from annotator.midas import apply_midas
@@ -25,733 +24,600 @@ from annotator.mlsd import apply_mlsd
25
  from annotator.openpose import apply_openpose
26
  from annotator.uniformer import apply_uniformer
27
  from annotator.util import HWC3, resize_image
28
- from cldm.model import create_model, load_state_dict
29
- from ldm.models.diffusion.ddim import DDIMSampler
30
  from share import *
31
 
32
- MODEL_NAMES = {
33
- 'canny': 'control_canny-fp16.safetensors',
34
- 'hough': 'control_mlsd-fp16.safetensors',
35
- 'hed': 'control_hed-fp16.safetensors',
36
- 'scribble': 'control_scribble-fp16.safetensors',
37
- 'pose': 'control_openpose-fp16.safetensors',
38
- 'seg': 'control_seg-fp16.safetensors',
39
- 'depth': 'control_depth-fp16.safetensors',
40
- 'normal': 'control_normal-fp16.safetensors',
41
  }
42
- MODEL_REPO = 'webui/ControlNet-modules-safetensors'
43
 
44
- DEFAULT_BASE_MODEL_REPO = 'andite/anything-v4.0'
45
- DEFAULT_BASE_MODEL_FILENAME = 'anything-v4.0-pruned.safetensors'
46
- DEFAULT_BASE_MODEL_URL = 'https://huggingface.co/andite/anything-v4.0/resolve/main/anything-v4.0-pruned.safetensors'
 
47
 
48
 
49
  class Model:
50
  def __init__(self,
51
- model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
52
- model_dir: str = 'models'):
53
- self.device = torch.device(
54
- 'cuda:0' if torch.cuda.is_available() else 'cpu')
55
- self.model = create_model(model_config_path).to(self.device)
56
- self.ddim_sampler = DDIMSampler(self.model)
57
  self.task_name = ''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- self.base_model_url = ''
60
- self.model_dir = pathlib.Path(model_dir)
61
- self.model_dir.mkdir(exist_ok=True, parents=True)
62
-
63
- self.download_models()
64
- self.set_base_model(DEFAULT_BASE_MODEL_REPO,
65
- DEFAULT_BASE_MODEL_FILENAME)
66
-
67
- def set_base_model(self, model_id: str, filename: str) -> str:
68
- if not model_id or not filename:
69
- return self.base_model_url
70
- base_model_url = hf_hub_url(model_id, filename)
71
- if base_model_url != self.base_model_url:
72
- self.load_base_model(base_model_url)
73
- self.base_model_url = base_model_url
74
- return self.base_model_url
75
-
76
- def download_base_model(self, model_url: str) -> pathlib.Path:
77
- self.model_dir.mkdir(exist_ok=True, parents=True)
78
- model_name = model_url.split('/')[-1]
79
- out_path = self.model_dir / model_name
80
- if not out_path.exists():
81
- subprocess.run(shlex.split(f'wget {model_url} -O {out_path}'))
82
- return out_path
83
-
84
- def load_base_model(self, model_url: str) -> None:
85
- model_path = self.download_base_model(model_url)
86
- self.model.load_state_dict(load_state_dict(model_path,
87
- location=self.device.type),
88
- strict=False)
89
-
90
- def load_weight(self, task_name: str) -> None:
91
  if task_name == self.task_name:
92
  return
93
- weight_path = self.get_weight_path(task_name)
94
- self.model.control_model.load_state_dict(
95
- load_state_dict(weight_path, location=self.device.type))
 
 
 
96
  self.task_name = task_name
97
 
98
- def get_weight_path(self, task_name: str) -> str:
99
- if 'scribble' in task_name:
100
- task_name = 'scribble'
101
- return f'{self.model_dir}/{MODEL_NAMES[task_name]}'
102
-
103
- def download_models(self) -> None:
104
- self.model_dir.mkdir(exist_ok=True, parents=True)
105
- for name in MODEL_NAMES.values():
106
- out_path = self.model_dir / name
107
- if out_path.exists():
108
- continue
109
- model_url = hf_hub_url(MODEL_REPO, name)
110
- subprocess.run(shlex.split(f'wget {model_url} -O {out_path}'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  @torch.inference_mode()
113
- def process_canny(self, input_image, prompt, a_prompt, n_prompt,
114
- num_samples, image_resolution, ddim_steps, scale, seed,
115
- eta, low_threshold, high_threshold):
116
- self.load_weight('canny')
117
-
118
- img = resize_image(HWC3(input_image), image_resolution)
119
- H, W, C = img.shape
120
-
121
- detected_map = apply_canny(img, low_threshold, high_threshold)
122
- detected_map = HWC3(detected_map)
123
-
124
- control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
125
- control = torch.stack([control for _ in range(num_samples)], dim=0)
126
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
127
-
128
- if seed == -1:
129
- seed = random.randint(0, 65535)
130
- seed_everything(seed)
131
-
132
- if config.save_memory:
133
- self.model.low_vram_shift(is_diffusing=False)
134
-
135
- cond = {
136
- 'c_concat': [control],
137
- 'c_crossattn': [
138
- self.model.get_learned_conditioning(
139
- [prompt + ', ' + a_prompt] * num_samples)
140
- ]
141
- }
142
- un_cond = {
143
- 'c_concat': [control],
144
- 'c_crossattn':
145
- [self.model.get_learned_conditioning([n_prompt] * num_samples)]
146
- }
147
- shape = (4, H // 8, W // 8)
148
-
149
- if config.save_memory:
150
- self.model.low_vram_shift(is_diffusing=True)
151
-
152
- samples, intermediates = self.ddim_sampler.sample(
153
- ddim_steps,
154
- num_samples,
155
- shape,
156
- cond,
157
- verbose=False,
158
- eta=eta,
159
- unconditional_guidance_scale=scale,
160
- unconditional_conditioning=un_cond)
161
-
162
- if config.save_memory:
163
- self.model.low_vram_shift(is_diffusing=False)
164
-
165
- x_samples = self.model.decode_first_stage(samples)
166
- x_samples = (
167
- einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
168
- 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
169
-
170
- results = [x_samples[i] for i in range(num_samples)]
171
- return [255 - detected_map] + results
172
 
173
- @torch.inference_mode()
174
- def process_hough(self, input_image, prompt, a_prompt, n_prompt,
175
- num_samples, image_resolution, detect_resolution,
176
- ddim_steps, scale, seed, eta, value_threshold,
177
- distance_threshold):
178
- self.load_weight('hough')
179
 
180
- input_image = HWC3(input_image)
181
- detected_map = apply_mlsd(resize_image(input_image, detect_resolution),
182
- value_threshold, distance_threshold)
183
- detected_map = HWC3(detected_map)
184
- img = resize_image(input_image, image_resolution)
185
- H, W, C = img.shape
186
-
187
- detected_map = cv2.resize(detected_map, (W, H),
188
- interpolation=cv2.INTER_NEAREST)
189
-
190
- control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
191
- control = torch.stack([control for _ in range(num_samples)], dim=0)
192
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
193
-
194
- if seed == -1:
195
- seed = random.randint(0, 65535)
196
- seed_everything(seed)
197
-
198
- if config.save_memory:
199
- self.model.low_vram_shift(is_diffusing=False)
200
-
201
- cond = {
202
- 'c_concat': [control],
203
- 'c_crossattn': [
204
- self.model.get_learned_conditioning(
205
- [prompt + ', ' + a_prompt] * num_samples)
206
- ]
207
- }
208
- un_cond = {
209
- 'c_concat': [control],
210
- 'c_crossattn':
211
- [self.model.get_learned_conditioning([n_prompt] * num_samples)]
212
- }
213
- shape = (4, H // 8, W // 8)
214
-
215
- if config.save_memory:
216
- self.model.low_vram_shift(is_diffusing=True)
217
-
218
- samples, intermediates = self.ddim_sampler.sample(
219
- ddim_steps,
220
- num_samples,
221
- shape,
222
- cond,
223
- verbose=False,
224
- eta=eta,
225
- unconditional_guidance_scale=scale,
226
- unconditional_conditioning=un_cond)
227
-
228
- if config.save_memory:
229
- self.model.low_vram_shift(is_diffusing=False)
230
-
231
- x_samples = self.model.decode_first_stage(samples)
232
- x_samples = (
233
- einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
234
- 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
235
-
236
- results = [x_samples[i] for i in range(num_samples)]
237
- return [
238
- 255 - cv2.dilate(detected_map,
239
- np.ones(shape=(3, 3), dtype=np.uint8),
240
- iterations=1)
241
- ] + results
242
 
243
  @torch.inference_mode()
244
- def process_hed(self, input_image, prompt, a_prompt, n_prompt, num_samples,
245
- image_resolution, detect_resolution, ddim_steps, scale,
246
- seed, eta):
247
- self.load_weight('hed')
248
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  input_image = HWC3(input_image)
250
- detected_map = apply_hed(resize_image(input_image, detect_resolution))
251
- detected_map = HWC3(detected_map)
252
- img = resize_image(input_image, image_resolution)
253
- H, W, C = img.shape
254
-
255
- detected_map = cv2.resize(detected_map, (W, H),
256
- interpolation=cv2.INTER_LINEAR)
257
-
258
- control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
259
- control = torch.stack([control for _ in range(num_samples)], dim=0)
260
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
261
-
262
- if seed == -1:
263
- seed = random.randint(0, 65535)
264
- seed_everything(seed)
265
-
266
- if config.save_memory:
267
- self.model.low_vram_shift(is_diffusing=False)
268
-
269
- cond = {
270
- 'c_concat': [control],
271
- 'c_crossattn': [
272
- self.model.get_learned_conditioning(
273
- [prompt + ', ' + a_prompt] * num_samples)
274
- ]
275
- }
276
- un_cond = {
277
- 'c_concat': [control],
278
- 'c_crossattn':
279
- [self.model.get_learned_conditioning([n_prompt] * num_samples)]
280
- }
281
- shape = (4, H // 8, W // 8)
282
-
283
- if config.save_memory:
284
- self.model.low_vram_shift(is_diffusing=True)
285
-
286
- samples, intermediates = self.ddim_sampler.sample(
287
- ddim_steps,
288
- num_samples,
289
- shape,
290
- cond,
291
- verbose=False,
292
- eta=eta,
293
- unconditional_guidance_scale=scale,
294
- unconditional_conditioning=un_cond)
295
-
296
- if config.save_memory:
297
- self.model.low_vram_shift(is_diffusing=False)
298
-
299
- x_samples = self.model.decode_first_stage(samples)
300
- x_samples = (
301
- einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
302
- 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
303
-
304
- results = [x_samples[i] for i in range(num_samples)]
305
- return [detected_map] + results
306
 
307
  @torch.inference_mode()
308
- def process_scribble(self, input_image, prompt, a_prompt, n_prompt,
309
- num_samples, image_resolution, ddim_steps, scale,
310
- seed, eta):
311
- self.load_weight('scribble')
312
-
313
- img = resize_image(HWC3(input_image), image_resolution)
314
- H, W, C = img.shape
315
-
316
- detected_map = np.zeros_like(img, dtype=np.uint8)
317
- detected_map[np.min(img, axis=2) < 127] = 255
318
-
319
- control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
320
- control = torch.stack([control for _ in range(num_samples)], dim=0)
321
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
322
-
323
- if seed == -1:
324
- seed = random.randint(0, 65535)
325
- seed_everything(seed)
326
-
327
- if config.save_memory:
328
- self.model.low_vram_shift(is_diffusing=False)
329
-
330
- cond = {
331
- 'c_concat': [control],
332
- 'c_crossattn': [
333
- self.model.get_learned_conditioning(
334
- [prompt + ', ' + a_prompt] * num_samples)
335
- ]
336
- }
337
- un_cond = {
338
- 'c_concat': [control],
339
- 'c_crossattn':
340
- [self.model.get_learned_conditioning([n_prompt] * num_samples)]
341
- }
342
- shape = (4, H // 8, W // 8)
343
-
344
- if config.save_memory:
345
- self.model.low_vram_shift(is_diffusing=True)
346
-
347
- samples, intermediates = self.ddim_sampler.sample(
348
- ddim_steps,
349
- num_samples,
350
- shape,
351
- cond,
352
- verbose=False,
353
- eta=eta,
354
- unconditional_guidance_scale=scale,
355
- unconditional_conditioning=un_cond)
356
-
357
- if config.save_memory:
358
- self.model.low_vram_shift(is_diffusing=False)
359
-
360
- x_samples = self.model.decode_first_stage(samples)
361
- x_samples = (
362
- einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
363
- 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
364
-
365
- results = [x_samples[i] for i in range(num_samples)]
366
- return [255 - detected_map] + results
367
 
368
  @torch.inference_mode()
369
- def process_scribble_interactive(self, input_image, prompt, a_prompt,
370
- n_prompt, num_samples, image_resolution,
371
- ddim_steps, scale, seed, eta):
372
- self.load_weight('scribble')
373
-
374
- img = resize_image(HWC3(input_image['mask'][:, :, 0]),
375
- image_resolution)
376
- H, W, C = img.shape
377
-
378
- detected_map = np.zeros_like(img, dtype=np.uint8)
379
- detected_map[np.min(img, axis=2) > 127] = 255
380
-
381
- control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
382
- control = torch.stack([control for _ in range(num_samples)], dim=0)
383
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
384
-
385
- if seed == -1:
386
- seed = random.randint(0, 65535)
387
- seed_everything(seed)
388
-
389
- if config.save_memory:
390
- self.model.low_vram_shift(is_diffusing=False)
391
-
392
- cond = {
393
- 'c_concat': [control],
394
- 'c_crossattn': [
395
- self.model.get_learned_conditioning(
396
- [prompt + ', ' + a_prompt] * num_samples)
397
- ]
398
- }
399
- un_cond = {
400
- 'c_concat': [control],
401
- 'c_crossattn':
402
- [self.model.get_learned_conditioning([n_prompt] * num_samples)]
403
- }
404
- shape = (4, H // 8, W // 8)
405
-
406
- if config.save_memory:
407
- self.model.low_vram_shift(is_diffusing=True)
408
-
409
- samples, intermediates = self.ddim_sampler.sample(
410
- ddim_steps,
411
- num_samples,
412
- shape,
413
- cond,
414
- verbose=False,
415
- eta=eta,
416
- unconditional_guidance_scale=scale,
417
- unconditional_conditioning=un_cond)
418
-
419
- if config.save_memory:
420
- self.model.low_vram_shift(is_diffusing=False)
421
-
422
- x_samples = self.model.decode_first_stage(samples)
423
- x_samples = (
424
- einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
425
- 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
426
-
427
- results = [x_samples[i] for i in range(num_samples)]
428
- return [255 - detected_map] + results
429
 
430
  @torch.inference_mode()
431
- def process_fake_scribble(self, input_image, prompt, a_prompt, n_prompt,
432
- num_samples, image_resolution, detect_resolution,
433
- ddim_steps, scale, seed, eta):
434
- self.load_weight('scribble')
435
-
 
436
  input_image = HWC3(input_image)
437
- detected_map = apply_hed(resize_image(input_image, detect_resolution))
438
- detected_map = HWC3(detected_map)
439
- img = resize_image(input_image, image_resolution)
440
- H, W, C = img.shape
441
-
442
- detected_map = cv2.resize(detected_map, (W, H),
443
- interpolation=cv2.INTER_LINEAR)
444
- detected_map = nms(detected_map, 127, 3.0)
445
- detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
446
- detected_map[detected_map > 4] = 255
447
- detected_map[detected_map < 255] = 0
448
-
449
- control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
450
- control = torch.stack([control for _ in range(num_samples)], dim=0)
451
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
452
-
453
- if seed == -1:
454
- seed = random.randint(0, 65535)
455
- seed_everything(seed)
456
-
457
- if config.save_memory:
458
- self.model.low_vram_shift(is_diffusing=False)
459
-
460
- cond = {
461
- 'c_concat': [control],
462
- 'c_crossattn': [
463
- self.model.get_learned_conditioning(
464
- [prompt + ', ' + a_prompt] * num_samples)
465
- ]
466
- }
467
- un_cond = {
468
- 'c_concat': [control],
469
- 'c_crossattn':
470
- [self.model.get_learned_conditioning([n_prompt] * num_samples)]
471
- }
472
- shape = (4, H // 8, W // 8)
473
-
474
- if config.save_memory:
475
- self.model.low_vram_shift(is_diffusing=True)
476
-
477
- samples, intermediates = self.ddim_sampler.sample(
478
- ddim_steps,
479
- num_samples,
480
- shape,
481
- cond,
482
- verbose=False,
483
- eta=eta,
484
- unconditional_guidance_scale=scale,
485
- unconditional_conditioning=un_cond)
486
-
487
- if config.save_memory:
488
- self.model.low_vram_shift(is_diffusing=False)
489
-
490
- x_samples = self.model.decode_first_stage(samples)
491
- x_samples = (
492
- einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
493
- 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
494
-
495
- results = [x_samples[i] for i in range(num_samples)]
496
- return [255 - detected_map] + results
497
 
498
- @torch.inference_mode()
499
- def process_pose(self, input_image, prompt, a_prompt, n_prompt,
500
- num_samples, image_resolution, detect_resolution,
501
- ddim_steps, scale, seed, eta):
502
- self.load_weight('pose')
 
503
 
504
- input_image = HWC3(input_image)
505
- detected_map, _ = apply_openpose(
506
- resize_image(input_image, detect_resolution))
507
- detected_map = HWC3(detected_map)
508
- img = resize_image(input_image, image_resolution)
509
- H, W, C = img.shape
510
-
511
- detected_map = cv2.resize(detected_map, (W, H),
512
- interpolation=cv2.INTER_NEAREST)
513
-
514
- control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
515
- control = torch.stack([control for _ in range(num_samples)], dim=0)
516
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
517
-
518
- if seed == -1:
519
- seed = random.randint(0, 65535)
520
- seed_everything(seed)
521
-
522
- if config.save_memory:
523
- self.model.low_vram_shift(is_diffusing=False)
524
-
525
- cond = {
526
- 'c_concat': [control],
527
- 'c_crossattn': [
528
- self.model.get_learned_conditioning(
529
- [prompt + ', ' + a_prompt] * num_samples)
530
- ]
531
- }
532
- un_cond = {
533
- 'c_concat': [control],
534
- 'c_crossattn':
535
- [self.model.get_learned_conditioning([n_prompt] * num_samples)]
536
- }
537
- shape = (4, H // 8, W // 8)
538
-
539
- if config.save_memory:
540
- self.model.low_vram_shift(is_diffusing=True)
541
-
542
- samples, intermediates = self.ddim_sampler.sample(
543
- ddim_steps,
544
- num_samples,
545
- shape,
546
- cond,
547
- verbose=False,
548
- eta=eta,
549
- unconditional_guidance_scale=scale,
550
- unconditional_conditioning=un_cond)
551
-
552
- if config.save_memory:
553
- self.model.low_vram_shift(is_diffusing=False)
554
-
555
- x_samples = self.model.decode_first_stage(samples)
556
- x_samples = (
557
- einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
558
- 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
559
-
560
- results = [x_samples[i] for i in range(num_samples)]
561
- return [detected_map] + results
562
 
563
- @torch.inference_mode()
564
- def process_seg(self, input_image, prompt, a_prompt, n_prompt, num_samples,
565
- image_resolution, detect_resolution, ddim_steps, scale,
566
- seed, eta):
567
- self.load_weight('seg')
568
 
569
  input_image = HWC3(input_image)
570
- detected_map = apply_uniformer(
571
- resize_image(input_image, detect_resolution))
572
- img = resize_image(input_image, image_resolution)
573
- H, W, C = img.shape
574
-
575
- detected_map = cv2.resize(detected_map, (W, H),
576
- interpolation=cv2.INTER_NEAREST)
577
-
578
- control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
579
- control = torch.stack([control for _ in range(num_samples)], dim=0)
580
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
581
-
582
- if seed == -1:
583
- seed = random.randint(0, 65535)
584
- seed_everything(seed)
585
-
586
- if config.save_memory:
587
- self.model.low_vram_shift(is_diffusing=False)
588
-
589
- cond = {
590
- 'c_concat': [control],
591
- 'c_crossattn': [
592
- self.model.get_learned_conditioning(
593
- [prompt + ', ' + a_prompt] * num_samples)
594
- ]
595
- }
596
- un_cond = {
597
- 'c_concat': [control],
598
- 'c_crossattn':
599
- [self.model.get_learned_conditioning([n_prompt] * num_samples)]
600
- }
601
- shape = (4, H // 8, W // 8)
602
-
603
- if config.save_memory:
604
- self.model.low_vram_shift(is_diffusing=True)
605
-
606
- samples, intermediates = self.ddim_sampler.sample(
607
- ddim_steps,
608
- num_samples,
609
- shape,
610
- cond,
611
- verbose=False,
612
- eta=eta,
613
- unconditional_guidance_scale=scale,
614
- unconditional_conditioning=un_cond)
615
-
616
- if config.save_memory:
617
- self.model.low_vram_shift(is_diffusing=False)
618
-
619
- x_samples = self.model.decode_first_stage(samples)
620
- x_samples = (
621
- einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
622
- 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
623
-
624
- results = [x_samples[i] for i in range(num_samples)]
625
- return [detected_map] + results
626
 
627
  @torch.inference_mode()
628
- def process_depth(self, input_image, prompt, a_prompt, n_prompt,
629
- num_samples, image_resolution, detect_resolution,
630
- ddim_steps, scale, seed, eta):
631
- self.load_weight('depth')
632
-
633
  input_image = HWC3(input_image)
634
- detected_map, _ = apply_midas(
635
- resize_image(input_image, detect_resolution))
636
- detected_map = HWC3(detected_map)
637
- img = resize_image(input_image, image_resolution)
638
- H, W, C = img.shape
639
-
640
- detected_map = cv2.resize(detected_map, (W, H),
641
- interpolation=cv2.INTER_LINEAR)
642
-
643
- control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
644
- control = torch.stack([control for _ in range(num_samples)], dim=0)
645
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
646
-
647
- if seed == -1:
648
- seed = random.randint(0, 65535)
649
- seed_everything(seed)
650
-
651
- if config.save_memory:
652
- self.model.low_vram_shift(is_diffusing=False)
653
-
654
- cond = {
655
- 'c_concat': [control],
656
- 'c_crossattn': [
657
- self.model.get_learned_conditioning(
658
- [prompt + ', ' + a_prompt] * num_samples)
659
- ]
660
- }
661
- un_cond = {
662
- 'c_concat': [control],
663
- 'c_crossattn':
664
- [self.model.get_learned_conditioning([n_prompt] * num_samples)]
665
- }
666
- shape = (4, H // 8, W // 8)
667
-
668
- if config.save_memory:
669
- self.model.low_vram_shift(is_diffusing=True)
670
-
671
- samples, intermediates = self.ddim_sampler.sample(
672
- ddim_steps,
673
- num_samples,
674
- shape,
675
- cond,
676
- verbose=False,
677
- eta=eta,
678
- unconditional_guidance_scale=scale,
679
- unconditional_conditioning=un_cond)
680
-
681
- if config.save_memory:
682
- self.model.low_vram_shift(is_diffusing=False)
683
-
684
- x_samples = self.model.decode_first_stage(samples)
685
- x_samples = (
686
- einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
687
- 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
688
-
689
- results = [x_samples[i] for i in range(num_samples)]
690
- return [detected_map] + results
691
 
692
  @torch.inference_mode()
693
- def process_normal(self, input_image, prompt, a_prompt, n_prompt,
694
- num_samples, image_resolution, detect_resolution,
695
- ddim_steps, scale, seed, eta, bg_threshold):
696
- self.load_weight('normal')
698
  input_image = HWC3(input_image)
699
- _, detected_map = apply_midas(resize_image(input_image,
700
- detect_resolution),
701
- bg_th=bg_threshold)
702
- detected_map = HWC3(detected_map)
703
- img = resize_image(input_image, image_resolution)
704
- H, W, C = img.shape
705
-
706
- detected_map = cv2.resize(detected_map, (W, H),
707
- interpolation=cv2.INTER_LINEAR)
708
-
709
- control = torch.from_numpy(
710
- detected_map[:, :, ::-1].copy()).float().cuda() / 255.0
711
- control = torch.stack([control for _ in range(num_samples)], dim=0)
712
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
713
-
714
- if seed == -1:
715
- seed = random.randint(0, 65535)
716
- seed_everything(seed)
717
-
718
- if config.save_memory:
719
- self.model.low_vram_shift(is_diffusing=False)
720
-
721
- cond = {
722
- 'c_concat': [control],
723
- 'c_crossattn': [
724
- self.model.get_learned_conditioning(
725
- [prompt + ', ' + a_prompt] * num_samples)
726
- ]
727
- }
728
- un_cond = {
729
- 'c_concat': [control],
730
- 'c_crossattn':
731
- [self.model.get_learned_conditioning([n_prompt] * num_samples)]
732
- }
733
- shape = (4, H // 8, W // 8)
734
-
735
- if config.save_memory:
736
- self.model.low_vram_shift(is_diffusing=True)
737
-
738
- samples, intermediates = self.ddim_sampler.sample(
739
- ddim_steps,
740
- num_samples,
741
- shape,
742
- cond,
743
- verbose=False,
744
- eta=eta,
745
- unconditional_guidance_scale=scale,
746
- unconditional_conditioning=un_cond)
747
-
748
- if config.save_memory:
749
- self.model.low_vram_shift(is_diffusing=False)
750
-
751
- x_samples = self.model.decode_first_stage(samples)
752
- x_samples = (
753
- einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
754
- 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
755
-
756
- results = [x_samples[i] for i in range(num_samples)]
757
- return [detected_map] + results
 
3
  from __future__ import annotations
4
 
5
  import pathlib
 
 
 
6
  import sys
7
 
8
  import cv2
 
9
  import numpy as np
10
+ import PIL.Image
11
  import torch
12
+ from diffusers import (ControlNetModel, DiffusionPipeline,
13
+ StableDiffusionControlNetPipeline,
14
+ UniPCMultistepScheduler)
15
 
16
+ repo_dir = pathlib.Path(__file__).parent
17
+ submodule_dir = repo_dir / 'ControlNet'
18
+ sys.path.append(submodule_dir.as_posix())
19
 
 
20
  from annotator.canny import apply_canny
21
  from annotator.hed import apply_hed, nms
22
  from annotator.midas import apply_midas
23
  from annotator.mlsd import apply_mlsd
24
  from annotator.openpose import apply_openpose
25
  from annotator.uniformer import apply_uniformer
26
  from annotator.util import HWC3, resize_image
 
 
27
  from share import *
28
 
29
+ CONTROLNET_MODEL_IDS = {
30
+ 'canny': 'lllyasviel/sd-controlnet-canny',
31
+ 'hough': 'lllyasviel/sd-controlnet-mlsd',
32
+ 'hed': 'lllyasviel/sd-controlnet-hed',
33
+ 'scribble': 'lllyasviel/sd-controlnet-scribble',
34
+ 'pose': 'lllyasviel/sd-controlnet-openpose',
35
+ 'seg': 'lllyasviel/sd-controlnet-seg',
36
+ 'depth': 'lllyasviel/sd-controlnet-depth',
37
+ 'normal': 'lllyasviel/sd-controlnet-normal',
38
  }
 
39
 
40
+
41
+ def download_all_controlnet_weights() -> None:
42
+ for model_id in CONTROLNET_MODEL_IDS.values():
43
+ ControlNetModel.from_pretrained(model_id)
44
 
45
 
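The CONTROLNET_MODEL_IDS table above maps each task name used by the demo tabs to a ControlNet checkpoint on the Hub, and download_all_controlnet_weights() just iterates over it. A minimal warm-up sketch, assuming the file shown here is model.py as in the imports earlier in the diff; the __main__ guard and the print are illustrative, not part of the Space:

# Hypothetical warm-up script: pre-fetch every ControlNet checkpoint so the
# first request in the demo does not block on a download.
from model import CONTROLNET_MODEL_IDS, download_all_controlnet_weights

if __name__ == '__main__':
    print(f'fetching {len(CONTROLNET_MODEL_IDS)} ControlNet checkpoints...')
    download_all_controlnet_weights()  # one ControlNetModel.from_pretrained() call per task
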
46
  class Model:
47
  def __init__(self,
48
+ base_model_id: str = 'runwayml/stable-diffusion-v1-5',
49
+ task_name: str = 'canny'):
50
+ self.base_model_id = ''
 
 
 
51
  self.task_name = ''
52
+ self.pipe = self.load_pipe(base_model_id, task_name)
53
+
54
+ def load_pipe(self, base_model_id: str, task_name) -> DiffusionPipeline:
55
+ if base_model_id == self.base_model_id and task_name == self.task_name:
56
+ return self.pipe
57
+ model_id = CONTROLNET_MODEL_IDS[task_name]
58
+ controlnet = ControlNetModel.from_pretrained(model_id,
59
+ torch_dtype=torch.float16)
60
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
61
+ base_model_id,
62
+ safety_checker=None,
63
+ controlnet=controlnet,
64
+ torch_dtype=torch.float16)
65
+ pipe.scheduler = UniPCMultistepScheduler.from_config(
66
+ pipe.scheduler.config)
67
+ pipe.enable_xformers_memory_efficient_attention()
68
+ pipe.enable_model_cpu_offload()
69
+ self.base_model_id = base_model_id
70
+ self.task_name = task_name
71
+ return pipe
72
+
73
+ def set_base_model(self, base_model_id: str) -> str:
74
+ self.pipe = self.load_pipe(base_model_id, self.task_name)
75
+ return self.base_model_id
76
 
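load_pipe() rebuilds the whole StableDiffusionControlNetPipeline whenever the base model or the task changes, and set_base_model() returns whichever base model is loaded afterwards. A short usage sketch; 'andite/anything-v4.0' is the repo id mentioned in the Space description, everything else is illustrative:

# Sketch: swap the Diffusers-format base model at runtime while keeping the
# currently selected ControlNet task ('canny' by default).
from model import Model

model = Model()  # defaults: runwayml/stable-diffusion-v1-5 + 'canny'
current = model.set_base_model('andite/anything-v4.0')
print('base model now in use:', current)
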
77
+ def load_controlnet_weight(self, task_name: str) -> None:
78
  if task_name == self.task_name:
79
  return
80
+ model_id = CONTROLNET_MODEL_IDS[task_name]
81
+ controlnet = ControlNetModel.from_pretrained(model_id,
82
+ torch_dtype=torch.float16)
83
+ from accelerate import cpu_offload_with_hook
84
+ cpu_offload_with_hook(controlnet, torch.device('cuda:0'))
85
+ self.pipe.controlnet = controlnet
86
  self.task_name = task_name
87
 
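load_controlnet_weight() swaps only pipe.controlnet, re-registering the new ControlNet for CPU offload with accelerate's cpu_offload_with_hook, so changing tasks does not reload the base Stable Diffusion weights. A small sketch with illustrative task names:

# Sketch: hot-swap the ControlNet while the base pipeline stays in place.
from model import Model

model = Model(task_name='canny')
model.load_controlnet_weight('depth')  # loads lllyasviel/sd-controlnet-depth
assert model.task_name == 'depth'      # base_model_id is unchanged
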
88
+ def get_prompt(self, prompt: str, additional_prompt: str) -> str:
89
+ if not prompt:
90
+ prompt = additional_prompt
91
+ else:
92
+ prompt = f'{prompt}, {additional_prompt}'
93
+ return prompt
94
+
95
+ def run_pipe(
96
+ self,
97
+ prompt: str,
98
+ negative_prompt: str,
99
+ control_image: PIL.Image.Image,
100
+ num_images: int,
101
+ num_steps: int,
102
+ guidance_scale: float,
103
+ seed: int,
104
+ ) -> list[PIL.Image.Image]:
105
+ generator = torch.Generator().manual_seed(seed)
106
+ return self.pipe(prompt=prompt,
107
+ negative_prompt=negative_prompt,
108
+ guidance_scale=guidance_scale,
109
+ num_images_per_prompt=num_images,
110
+ num_inference_steps=num_steps,
111
+ generator=generator,
112
+ image=control_image).images
113
+
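run_pipe() is the single place where generation happens; the process_* methods below differ only in how they build control_image. For reference, a standalone sketch of the equivalent bare diffusers call, using the same APIs as above; the prompts, the blank control image and the numeric values are placeholders:

# Standalone sketch of what run_pipe() wraps.
import PIL.Image
import torch
from diffusers import (ControlNetModel, StableDiffusionControlNetPipeline,
                       UniPCMultistepScheduler)

controlnet = ControlNetModel.from_pretrained('lllyasviel/sd-controlnet-canny',
                                             torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    'runwayml/stable-diffusion-v1-5', controlnet=controlnet,
    safety_checker=None, torch_dtype=torch.float16)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

control_image = PIL.Image.new('RGB', (512, 512))  # placeholder edge map
images = pipe(prompt='a house, best quality',
              negative_prompt='lowres, bad anatomy',
              image=control_image,
              num_inference_steps=20,
              guidance_scale=9.0,
              num_images_per_prompt=1,
              generator=torch.Generator().manual_seed(0)).images
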
114
+ @staticmethod
115
+ def preprocess_canny(
116
+ input_image: np.ndarray,
117
+ image_resolution: int,
118
+ low_threshold: int,
119
+ high_threshold: int,
120
+ ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
121
+ image = resize_image(HWC3(input_image), image_resolution)
122
+ control_image = apply_canny(image, low_threshold, high_threshold)
123
+ control_image = HWC3(control_image)
124
+ vis_control_image = 255 - control_image
125
+ return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
126
+ vis_control_image)
127
 
128
  @torch.inference_mode()
129
+ def process_canny(
130
+ self,
131
+ input_image: np.ndarray,
132
+ prompt: str,
133
+ additional_prompt: str,
134
+ negative_prompt: str,
135
+ num_images: int,
136
+ image_resolution: int,
137
+ num_steps: int,
138
+ guidance_scale: float,
139
+ seed: int,
140
+ low_threshold: int,
141
+ high_threshold: int,
142
+ ) -> list[PIL.Image.Image]:
143
+ control_image, vis_control_image = self.preprocess_canny(
144
+ input_image=input_image,
145
+ image_resolution=image_resolution,
146
+ low_threshold=low_threshold,
147
+ high_threshold=high_threshold,
148
+ )
149
+ self.load_controlnet_weight('canny')
150
+ results = self.run_pipe(
151
+ prompt=self.get_prompt(prompt, additional_prompt),
152
+ negative_prompt=negative_prompt,
153
+ control_image=control_image,
154
+ num_images=num_images,
155
+ num_steps=num_steps,
156
+ guidance_scale=guidance_scale,
157
+ seed=seed,
158
+ )
159
+ return [vis_control_image] + results
160
+
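An end-to-end sketch of calling the method above with arguments mirroring the Gradio inputs; the zero-filled array and all values are placeholders:

# Illustrative call; outputs[0] is the inverted Canny map, outputs[1:] the generated images.
import numpy as np
from model import Model

model = Model()
dummy = np.zeros((512, 512, 3), dtype=np.uint8)  # stand-in for a user upload
outputs = model.process_canny(dummy,
                              prompt='a modern house',
                              additional_prompt='best quality, extremely detailed',
                              negative_prompt='lowres, bad anatomy, worst quality',
                              num_images=1,
                              image_resolution=512,
                              num_steps=20,
                              guidance_scale=9.0,
                              seed=0,
                              low_threshold=100,
                              high_threshold=200)
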
161
+ @staticmethod
162
+ def preprocess_hough(
163
+ input_image: np.ndarray,
164
+ image_resolution: int,
165
+ detect_resolution: int,
166
+ value_threshold: float,
167
+ distance_threshold: float,
168
+ ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
169
+ input_image = HWC3(input_image)
170
+ control_image = apply_mlsd(
171
+ resize_image(input_image, detect_resolution), value_threshold,
172
+ distance_threshold)
173
+ control_image = HWC3(control_image)
174
+ image = resize_image(input_image, image_resolution)
175
+ H, W = image.shape[:2]
176
+ control_image = cv2.resize(control_image, (W, H),
177
+ interpolation=cv2.INTER_NEAREST)
178
 
179
+ vis_control_image = 255 - cv2.dilate(
180
+ control_image, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
181
 
182
+ return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
183
+ vis_control_image)
184
 
185
  @torch.inference_mode()
186
+ def process_hough(
187
+ self,
188
+ input_image: np.ndarray,
189
+ prompt: str,
190
+ additional_prompt: str,
191
+ negative_prompt: str,
192
+ num_images: int,
193
+ image_resolution: int,
194
+ detect_resolution: int,
195
+ num_steps: int,
196
+ guidance_scale: float,
197
+ seed: int,
198
+ value_threshold: float,
199
+ distance_threshold: float,
200
+ ) -> list[PIL.Image.Image]:
201
+ control_image, vis_control_image = self.preprocess_hough(
202
+ input_image=input_image,
203
+ image_resolution=image_resolution,
204
+ detect_resolution=detect_resolution,
205
+ value_threshold=value_threshold,
206
+ distance_threshold=distance_threshold,
207
+ )
208
+ self.load_controlnet_weight('hough')
209
+ results = self.run_pipe(
210
+ prompt=self.get_prompt(prompt, additional_prompt),
211
+ negative_prompt=negative_prompt,
212
+ control_image=control_image,
213
+ num_images=num_images,
214
+ num_steps=num_steps,
215
+ guidance_scale=guidance_scale,
216
+ seed=seed,
217
+ )
218
+ return [vis_control_image] + results
219
+
220
+ @staticmethod
221
+ def preprocess_hed(
222
+ input_image: np.ndarray,
223
+ image_resolution: int,
224
+ detect_resolution: int,
225
+ ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
226
  input_image = HWC3(input_image)
227
+ control_image = apply_hed(resize_image(input_image, detect_resolution))
228
+ control_image = HWC3(control_image)
229
+ image = resize_image(input_image, image_resolution)
230
+ H, W = image.shape[:2]
231
+ control_image = cv2.resize(control_image, (W, H),
232
+ interpolation=cv2.INTER_LINEAR)
233
+ return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
234
+ control_image)
235
 
236
  @torch.inference_mode()
237
+ def process_hed(
238
+ self,
239
+ input_image: np.ndarray,
240
+ prompt: str,
241
+ additional_prompt: str,
242
+ negative_prompt: str,
243
+ num_images: int,
244
+ image_resolution: int,
245
+ detect_resolution: int,
246
+ num_steps: int,
247
+ guidance_scale: float,
248
+ seed: int,
249
+ ) -> list[PIL.Image.Image]:
250
+ control_image, vis_control_image = self.preprocess_hed(
251
+ input_image=input_image,
252
+ image_resolution=image_resolution,
253
+ detect_resolution=detect_resolution,
254
+ )
255
+ self.load_controlnet_weight('hed')
256
+ results = self.run_pipe(
257
+ prompt=self.get_prompt(prompt, additional_prompt),
258
+ negative_prompt=negative_prompt,
259
+ control_image=control_image,
260
+ num_images=num_images,
261
+ num_steps=num_steps,
262
+ guidance_scale=guidance_scale,
263
+ seed=seed,
264
+ )
265
+ return [vis_control_image] + results
266
+
267
+ @staticmethod
268
+ def preprocess_scribble(
269
+ input_image: np.ndarray,
270
+ image_resolution: int,
271
+ ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
272
+ image = resize_image(HWC3(input_image), image_resolution)
273
+ control_image = np.zeros_like(image, dtype=np.uint8)
274
+ control_image[np.min(image, axis=2) < 127] = 255
275
+ vis_control_image = 255 - control_image
276
+ return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
277
+ vis_control_image)
278
 
279
  @torch.inference_mode()
280
+ def process_scribble(
281
+ self,
282
+ input_image: np.ndarray,
283
+ prompt: str,
284
+ additional_prompt: str,
285
+ negative_prompt: str,
286
+ num_images: int,
287
+ image_resolution: int,
288
+ num_steps: int,
289
+ guidance_scale: float,
290
+ seed: int,
291
+ ) -> list[PIL.Image.Image]:
292
+ control_image, vis_control_image = self.preprocess_scribble(
293
+ input_image=input_image,
294
+ image_resolution=image_resolution,
295
+ )
296
+ self.load_controlnet_weight('scribble')
297
+ results = self.run_pipe(
298
+ prompt=self.get_prompt(prompt, additional_prompt),
299
+ negative_prompt=negative_prompt,
300
+ control_image=control_image,
301
+ num_images=num_images,
302
+ num_steps=num_steps,
303
+ guidance_scale=guidance_scale,
304
+ seed=seed,
305
+ )
306
+ return [vis_control_image] + results
307
+
308
+ @staticmethod
309
+ def preprocess_scribble_interactive(
310
+ input_image: np.ndarray,
311
+ image_resolution: int,
312
+ ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
313
+ image = resize_image(HWC3(input_image['mask'][:, :, 0]),
314
+ image_resolution)
315
+ control_image = np.zeros_like(image, dtype=np.uint8)
316
+ control_image[np.min(image, axis=2) > 127] = 255
317
+ vis_control_image = 255 - control_image
318
+ return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
319
+ vis_control_image)
320
 
321
  @torch.inference_mode()
322
+ def process_scribble_interactive(
323
+ self,
324
+ input_image: np.ndarray,
325
+ prompt: str,
326
+ additional_prompt: str,
327
+ negative_prompt: str,
328
+ num_images: int,
329
+ image_resolution: int,
330
+ num_steps: int,
331
+ guidance_scale: float,
332
+ seed: int,
333
+ ) -> list[PIL.Image.Image]:
334
+ control_image, vis_control_image = self.preprocess_scribble_interactive(
335
+ input_image=input_image,
336
+ image_resolution=image_resolution,
337
+ )
338
+ self.load_controlnet_weight('scribble')
339
+ results = self.run_pipe(
340
+ prompt=self.get_prompt(prompt, additional_prompt),
341
+ negative_prompt=negative_prompt,
342
+ control_image=control_image,
343
+ num_images=num_images,
344
+ num_steps=num_steps,
345
+ guidance_scale=guidance_scale,
346
+ seed=seed,
347
+ )
348
+ return [vis_control_image] + results
349
+
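Unlike the other tasks, the interactive scribble path receives the value of a Gradio sketch component, a dict whose 'mask' entry holds the drawn strokes (only channel 0 is read). A sketch of the assumed input format; the shapes and the extra 'image' key are assumptions:

# Assumed input for the interactive scribble methods, based on the
# input_image['mask'][:, :, 0] access: bright pixels (>127) are strokes.
import numpy as np
from model import Model

mask = np.zeros((512, 512, 4), dtype=np.uint8)
mask[200:210, 100:400] = 255  # one horizontal stroke
sketch_value = {'image': np.zeros((512, 512, 3), dtype=np.uint8), 'mask': mask}

control, vis = Model.preprocess_scribble_interactive(sketch_value,
                                                     image_resolution=512)
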
350
+ @staticmethod
351
+ def preprocess_fake_scribble(
352
+ input_image: np.ndarray,
353
+ image_resolution: int,
354
+ detect_resolution: int,
355
+ ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
356
  input_image = HWC3(input_image)
357
+ control_image = apply_hed(resize_image(input_image, detect_resolution))
358
+ control_image = HWC3(control_image)
359
+ image = resize_image(input_image, image_resolution)
360
+ H, W = image.shape[:2]
361
 
362
+ control_image = cv2.resize(control_image, (W, H),
363
+ interpolation=cv2.INTER_LINEAR)
364
+ control_image = nms(control_image, 127, 3.0)
365
+ control_image = cv2.GaussianBlur(control_image, (0, 0), 3.0)
366
+ control_image[control_image > 4] = 255
367
+ control_image[control_image < 255] = 0
368
 
369
+ vis_control_image = 255 - control_image
370
 
371
+ return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
372
+ vis_control_image)
 
 
 
373
 
374
+ @torch.inference_mode()
375
+ def process_fake_scribble(
376
+ self,
377
+ input_image: np.ndarray,
378
+ prompt: str,
379
+ additional_prompt: str,
380
+ negative_prompt: str,
381
+ num_images: int,
382
+ image_resolution: int,
383
+ detect_resolution: int,
384
+ num_steps: int,
385
+ guidance_scale: float,
386
+ seed: int,
387
+ ) -> list[PIL.Image.Image]:
388
+ control_image, vis_control_image = self.preprocess_fake_scribble(
389
+ input_image=input_image,
390
+ image_resolution=image_resolution,
391
+ detect_resolution=detect_resolution,
392
+ )
393
+ self.load_controlnet_weight('scribble')
394
+ results = self.run_pipe(
395
+ prompt=self.get_prompt(prompt, additional_prompt),
396
+ negative_prompt=negative_prompt,
397
+ control_image=control_image,
398
+ num_images=num_images,
399
+ num_steps=num_steps,
400
+ guidance_scale=guidance_scale,
401
+ seed=seed,
402
+ )
403
+ return [vis_control_image] + results
404
+
405
+ @staticmethod
406
+ def preprocess_pose(
407
+ input_image: np.ndarray,
408
+ image_resolution: int,
409
+ detect_resolution: int,
410
+ is_pose_image: bool,
411
+ ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
412
  input_image = HWC3(input_image)
413
+ if not is_pose_image:
414
+ control_image, _ = apply_openpose(
415
+ resize_image(input_image, detect_resolution))
416
+ control_image = HWC3(control_image)
417
+ image = resize_image(input_image, image_resolution)
418
+ H, W = image.shape[:2]
419
+ control_image = cv2.resize(control_image, (W, H),
420
+ interpolation=cv2.INTER_NEAREST)
421
+ else:
422
+ control_image = input_image
423
+
424
+ return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
425
+ control_image)
426
 
427
  @torch.inference_mode()
428
+ def process_pose(
429
+ self,
430
+ input_image: np.ndarray,
431
+ prompt: str,
432
+ additional_prompt: str,
433
+ negative_prompt: str,
434
+ num_images: int,
435
+ image_resolution: int,
436
+ detect_resolution: int,
437
+ num_steps: int,
438
+ guidance_scale: float,
439
+ seed: int,
440
+ is_pose_image: bool,
441
+ ) -> list[PIL.Image.Image]:
442
+ control_image, vis_control_image = self.preprocess_pose(
443
+ input_image=input_image,
444
+ image_resolution=image_resolution,
445
+ detect_resolution=detect_resolution,
446
+ is_pose_image=is_pose_image,
447
+ )
448
+ self.load_controlnet_weight('pose')
449
+ results = self.run_pipe(
450
+ prompt=self.get_prompt(prompt, additional_prompt),
451
+ negative_prompt=negative_prompt,
452
+ control_image=control_image,
453
+ num_images=num_images,
454
+ num_steps=num_steps,
455
+ guidance_scale=guidance_scale,
456
+ seed=seed,
457
+ )
458
+ return [vis_control_image] + results
459
+
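The new is_pose_image flag lets callers pass an already rendered pose map, which is then used as the control image unchanged (in this branch it is not resized to image_resolution). A sketch with a placeholder pose map and illustrative values:

# Sketch: feed a pre-rendered OpenPose skeleton directly, skipping detection.
import numpy as np
from model import Model

pose_map = np.zeros((512, 512, 3), dtype=np.uint8)  # placeholder skeleton rendering
model = Model(task_name='pose')
outputs = model.process_pose(pose_map,
                             prompt='a dancer on stage',
                             additional_prompt='best quality',
                             negative_prompt='lowres, bad anatomy',
                             num_images=1,
                             image_resolution=512,
                             detect_resolution=512,
                             num_steps=20,
                             guidance_scale=9.0,
                             seed=0,
                             is_pose_image=True)
# outputs[0] echoes the pose map; outputs[1:] are the generated images.
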
460
+ @staticmethod
461
+ def preprocess_seg(
462
+ input_image: np.ndarray,
463
+ image_resolution: int,
464
+ detect_resolution: int,
465
+ is_segmentation_map: bool,
466
+ ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
467
  input_image = HWC3(input_image)
468
+ if not is_segmentation_map:
469
+ control_image = apply_uniformer(
470
+ resize_image(input_image, detect_resolution))
471
+ image = resize_image(input_image, image_resolution)
472
+ H, W = image.shape[:2]
473
+ control_image = cv2.resize(control_image, (W, H),
474
+ interpolation=cv2.INTER_NEAREST)
475
+ else:
476
+ control_image = input_image
477
+ return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
478
+ control_image)
479
 
480
  @torch.inference_mode()
481
+ def process_seg(
482
+ self,
483
+ input_image: np.ndarray,
484
+ prompt: str,
485
+ additional_prompt: str,
486
+ negative_prompt: str,
487
+ num_images: int,
488
+ image_resolution: int,
489
+ detect_resolution: int,
490
+ num_steps: int,
491
+ guidance_scale: float,
492
+ seed: int,
493
+ is_segmentation_map: bool,
494
+ ) -> list[PIL.Image.Image]:
495
+ control_image, vis_control_image = self.preprocess_seg(
496
+ input_image=input_image,
497
+ image_resolution=image_resolution,
498
+ detect_resolution=detect_resolution,
499
+ is_segmentation_map=is_segmentation_map,
500
+ )
501
+ self.load_controlnet_weight('seg')
502
+ results = self.run_pipe(
503
+ prompt=self.get_prompt(prompt, additional_prompt),
504
+ negative_prompt=negative_prompt,
505
+ control_image=control_image,
506
+ num_images=num_images,
507
+ num_steps=num_steps,
508
+ guidance_scale=guidance_scale,
509
+ seed=seed,
510
+ )
511
+ return [vis_control_image] + results
512
+
513
+ @staticmethod
514
+ def preprocess_depth(
515
+ input_image: np.ndarray,
516
+ image_resolution: int,
517
+ detect_resolution: int,
518
+ is_depth_image: bool,
519
+ ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
520
+ input_image = HWC3(input_image)
521
+ if not is_depth_image:
522
+ control_image, _ = apply_midas(
523
+ resize_image(input_image, detect_resolution))
524
+ control_image = HWC3(control_image)
525
+ image = resize_image(input_image, image_resolution)
526
+ H, W = image.shape[:2]
527
+ control_image = cv2.resize(control_image, (W, H),
528
+ interpolation=cv2.INTER_LINEAR)
529
+ else:
530
+ control_image = input_image
531
+ return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
532
+ control_image)
533
 
534
+ @torch.inference_mode()
535
+ def process_depth(
536
+ self,
537
+ input_image: np.ndarray,
538
+ prompt: str,
539
+ additional_prompt: str,
540
+ negative_prompt: str,
541
+ num_images: int,
542
+ image_resolution: int,
543
+ detect_resolution: int,
544
+ num_steps: int,
545
+ guidance_scale: float,
546
+ seed: int,
547
+ is_depth_image: bool,
548
+ ) -> list[PIL.Image.Image]:
549
+ control_image, vis_control_image = self.preprocess_depth(
550
+ input_image=input_image,
551
+ image_resolution=image_resolution,
552
+ detect_resolution=detect_resolution,
553
+ is_depth_image=is_depth_image,
554
+ )
555
+ self.load_controlnet_weight('depth')
556
+ results = self.run_pipe(
557
+ prompt=self.get_prompt(prompt, additional_prompt),
558
+ negative_prompt=negative_prompt,
559
+ control_image=control_image,
560
+ num_images=num_images,
561
+ num_steps=num_steps,
562
+ guidance_scale=guidance_scale,
563
+ seed=seed,
564
+ )
565
+ return [vis_control_image] + results
566
+
567
+ @staticmethod
568
+ def preprocess_normal(
569
+ input_image: np.ndarray,
570
+ image_resolution: int,
571
+ detect_resolution: int,
572
+ bg_threshold: float,
573
+ is_normal_image: bool,
574
+ ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
575
  input_image = HWC3(input_image)
576
+ if not is_normal_image:
577
+ _, control_image = apply_midas(resize_image(
578
+ input_image, detect_resolution),
579
+ bg_th=bg_threshold)
580
+ control_image = HWC3(control_image)
581
+ image = resize_image(input_image, image_resolution)
582
+ H, W = image.shape[:2]
583
+ control_image = cv2.resize(control_image, (W, H),
584
+ interpolation=cv2.INTER_LINEAR)
585
+ else:
586
+ control_image = input_image
587
+ return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
588
+ control_image)
589
+
590
+ @torch.inference_mode()
591
+ def process_normal(
592
+ self,
593
+ input_image: np.ndarray,
594
+ prompt: str,
595
+ additional_prompt: str,
596
+ negative_prompt: str,
597
+ num_images: int,
598
+ image_resolution: int,
599
+ detect_resolution: int,
600
+ num_steps: int,
601
+ guidance_scale: float,
602
+ seed: int,
603
+ bg_threshold: float,
604
+ is_normal_image: bool,
605
+ ) -> list[PIL.Image.Image]:
606
+ control_image, vis_control_image = self.preprocess_normal(
607
+ input_image=input_image,
608
+ image_resolution=image_resolution,
609
+ detect_resolution=detect_resolution,
610
+ bg_threshold=bg_threshold,
611
+ is_normal_image=is_normal_image,
612
+ )
613
+ self.load_controlnet_weight('normal')
614
+ results = self.run_pipe(
615
+ prompt=self.get_prompt(prompt, additional_prompt),
616
+ negative_prompt=negative_prompt,
617
+ control_image=control_image,
618
+ num_images=num_images,
619
+ num_steps=num_steps,
620
+ guidance_scale=guidance_scale,
621
+ seed=seed,
622
+ )
623
+ return [vis_control_image] + results
 
requirements.txt CHANGED
@@ -1,8 +1,9 @@
1
  addict==2.4.0
2
  albumentations==1.3.0
3
  einops==0.6.0
4
- gradio==3.18.0
5
- huggingface-hub==0.12.0
 
6
  imageio==2.25.0
7
  imageio-ffmpeg==0.4.8
8
  kornia==0.6.9
 
1
  addict==2.4.0
2
  albumentations==1.3.0
3
  einops==0.6.0
4
+ git+https://github.com/huggingface/accelerate@78151f8
5
+ git+https://github.com/huggingface/diffusers@fa6d52d
6
+ gradio==3.20.0
7
  imageio==2.25.0
8
  imageio-ffmpeg==0.4.8
9
  kornia==0.6.9