danhtran2mind committed
Commit f56ede2 · verified · Parent: 041f0c4

Upload 68 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete set.
Files changed (50)
  1. .gitattributes +13 -0
  2. LICENSE +21 -0
  3. apps/gradio_app.py +187 -0
  4. apps/gradio_app/__init__.py +0 -0
  5. apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/a_man_is_doing_yoga_in_a_serene_park_0.png +3 -0
  6. apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/config.json +12 -0
  7. apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/yoga.jpg +0 -0
  8. apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/a_man_is_galloping_on_a_horse_0.png +3 -0
  9. apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/config.json +12 -0
  10. apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/ride_bike.jpg +3 -0
  11. apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/a_woman_is_holding_a_baseball_bat_in_her_hand_0.png +3 -0
  12. apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/config.json +12 -0
  13. apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/tennis.jpg +3 -0
  14. apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/a_woman_raises_a_katana_0.png +3 -0
  15. apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/config.json +12 -0
  16. apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/man_and_sword.jpg +0 -0
  17. apps/gradio_app/examples.py +99 -0
  18. apps/gradio_app/inference.py +45 -0
  19. apps/gradio_app/project_info.py +37 -0
  20. apps/gradio_app/setup_scripts.py +59 -0
  21. apps/gradio_app/static/style.css +574 -0
  22. apps/old-gradio_app.py +177 -0
  23. assets/.gitkeep +0 -0
  24. assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/a_man_is_doing_yoga_in_a_serene_park_0.png +3 -0
  25. assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/config.json +12 -0
  26. assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/yoga.jpg +0 -0
  27. assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/a_man_is_galloping_on_a_horse_0.png +3 -0
  28. assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/config.json +12 -0
  29. assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/ride_bike.jpg +3 -0
  30. assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/a_woman_is_holding_a_baseball_bat_in_her_hand_0.png +3 -0
  31. assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/config.json +12 -0
  32. assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/tennis.jpg +3 -0
  33. assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/a_woman_raises_a_katana_0.png +3 -0
  34. assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/config.json +12 -0
  35. assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/man_and_sword.jpg +0 -0
  36. ckpts/.gitignore +2 -0
  37. configs/.gitkeep +0 -0
  38. configs/datasets_info.yaml +3 -0
  39. configs/model_ckpts.yaml +16 -0
  40. data/.gitignore +2 -0
  41. docs/inference/inference_doc.md +176 -0
  42. docs/scripts/download_ckpts_doc.md +29 -0
  43. docs/scripts/download_datasets_doc.md +20 -0
  44. docs/training/training_doc.md +106 -0
  45. notebooks/SD-2.1-Openpose-ControlNet.ipynb +0 -0
  46. requirements/requirements.txt +7 -0
  47. requirements/requirements_compatible.txt +7 -0
  48. scripts/download_ckpts.py +58 -0
  49. scripts/download_datasets.py +48 -0
  50. scripts/setup_third_party.py +38 -0
.gitattributes CHANGED
@@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/a_man_is_doing_yoga_in_a_serene_park_0.png filter=lfs diff=lfs merge=lfs -text
+ apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/a_man_is_galloping_on_a_horse_0.png filter=lfs diff=lfs merge=lfs -text
+ apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/ride_bike.jpg filter=lfs diff=lfs merge=lfs -text
+ apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/a_woman_is_holding_a_baseball_bat_in_her_hand_0.png filter=lfs diff=lfs merge=lfs -text
+ apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/tennis.jpg filter=lfs diff=lfs merge=lfs -text
+ apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/a_woman_raises_a_katana_0.png filter=lfs diff=lfs merge=lfs -text
+ assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/a_man_is_doing_yoga_in_a_serene_park_0.png filter=lfs diff=lfs merge=lfs -text
+ assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/a_man_is_galloping_on_a_horse_0.png filter=lfs diff=lfs merge=lfs -text
+ assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/ride_bike.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/a_woman_is_holding_a_baseball_bat_in_her_hand_0.png filter=lfs diff=lfs merge=lfs -text
+ assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/tennis.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/a_woman_raises_a_katana_0.png filter=lfs diff=lfs merge=lfs -text
+ tests/test_data/a_man_is_doing_yoga_in_a_serene_park_0.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Danh Tran
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
apps/gradio_app.py ADDED
@@ -0,0 +1,187 @@
+ import os
+ import subprocess
+ import gradio as gr
+ import random
+ from gradio_app.inference import run_inference
+ from gradio_app.examples import load_examples, select_example
+ from gradio_app.project_info import (
+     NAME,
+     CONTENT_DESCRIPTION,
+     CONTENT_IN_1,
+     CONTENT_OUT_1
+ )
+
+ def run_setup_script():
+     setup_script = os.path.join(os.path.dirname(__file__), "gradio_app", "setup_scripts.py")
+     try:
+         result = subprocess.run(["python", setup_script], capture_output=True, text=True, check=True)
+         return result.stdout
+     except subprocess.CalledProcessError as e:
+         print(f"Setup script failed with error: {e.stderr}")
+         return f"Setup script failed: {e.stderr}"
+
+ def stop_app():
+     """Function to stop the Gradio app."""
+     try:
+         gr.close_all()  # Close all running Gradio interfaces
+         return "Application stopped successfully."
+     except Exception as e:
+         return f"Error stopping application: {str(e)}"
+
+ def create_gui():
+     try:
+         custom_css = open("apps/gradio_app/static/style.css").read()
+     except FileNotFoundError:
+         print("Error: style.css not found at apps/gradio_app/static/style.css")
+         custom_css = ""  # Fallback to empty CSS if file is missing
+
+     with gr.Blocks(css=custom_css) as demo:
+         gr.Markdown(NAME)
+         gr.HTML(CONTENT_DESCRIPTION)
+         gr.HTML(CONTENT_IN_1)
+
+         with gr.Row():
+             with gr.Column(scale=2):
+                 input_image = gr.Image(type="filepath", label="Input Image")
+                 prompt = gr.Textbox(
+                     label="Prompt",
+                     value="a man is doing yoga"
+                 )
+                 negative_prompt = gr.Textbox(
+                     label="Negative Prompt",
+                     value="monochrome, lowres, bad anatomy, worst quality, low quality"
+                 )
+
+                 with gr.Row():
+                     width = gr.Slider(
+                         minimum=256,
+                         maximum=1024,
+                         value=512,
+                         step=64,
+                         label="Width"
+                     )
+                     height = gr.Slider(
+                         minimum=256,
+                         maximum=1024,
+                         value=512,
+                         step=64,
+                         label="Height"
+                     )
+
+                 with gr.Accordion("Advanced Settings", open=False):
+                     num_steps = gr.Slider(
+                         minimum=1,
+                         maximum=100,
+                         value=30,
+                         step=1,
+                         label="Number of Inference Steps"
+                     )
+                     use_random_seed = gr.Checkbox(label="Use Random Seed", value=False)
+                     seed = gr.Slider(
+                         minimum=0,
+                         maximum=2**32 - 1,
+                         value=42,
+                         step=1,
+                         label="Random Seed",
+                         visible=True
+                     )
+
+                     guidance_scale = gr.Slider(
+                         minimum=1.0,
+                         maximum=20.0,
+                         value=7.5,
+                         step=0.1,
+                         label="Guidance Scale"
+                     )
+                     controlnet_conditioning_scale = gr.Slider(
+                         minimum=0.0,
+                         maximum=1.0,
+                         value=1.0,
+                         step=0.1,
+                         label="ControlNet Conditioning Scale"
+                     )
+
+             with gr.Column(scale=3):
+                 output_images = gr.Image(label="Generated Images")
+                 output_message = gr.Textbox(label="Status")
+
+         submit_button = gr.Button("Generate Images", elem_classes="submit-btn")
+         stop_button = gr.Button("Stop Application", elem_classes="stop-btn")
+
+         def update_seed_visibility(use_random):
+             return gr.update(visible=not use_random)
+
+         use_random_seed.change(
+             fn=update_seed_visibility,
+             inputs=use_random_seed,
+             outputs=seed
+         )
+
+         # Load examples
+         examples_data = load_examples(os.path.join("apps", "gradio_app",
+                                                    "assets", "examples", "Stable-Diffusion-2.1-Openpose-ControlNet"))
+         examples_component = gr.Examples(
+             examples=examples_data,
+             inputs=[
+                 input_image,
+                 prompt,
+                 negative_prompt,
+                 output_images,
+                 num_steps,
+                 seed,
+                 width,
+                 height,
+                 guidance_scale,
+                 controlnet_conditioning_scale,
+                 use_random_seed
+             ],
+             outputs=[
+                 input_image,
+                 prompt,
+                 negative_prompt,
+                 output_images,
+                 num_steps,
+                 seed,
+                 width,
+                 height,
+                 guidance_scale,
+                 controlnet_conditioning_scale,
+                 use_random_seed,
+                 output_message
+             ],
+             fn=select_example,
+             cache_examples=False,
+             label="Examples: Yoga Poses"
+         )
+
+         submit_button.click(
+             fn=run_inference,
+             inputs=[
+                 input_image,
+                 prompt,
+                 negative_prompt,
+                 num_steps,
+                 seed,
+                 width,
+                 height,
+                 guidance_scale,
+                 controlnet_conditioning_scale,
+                 use_random_seed,
+             ],
+             outputs=[output_images, output_message]
+         )
+
+         stop_button.click(
+             fn=stop_app,
+             inputs=[],
+             outputs=[output_message]
+         )
+
+         gr.HTML(CONTENT_OUT_1)
+
+     return demo
+
+ if __name__ == "__main__":
+     run_setup_script()
+     demo = create_gui()
+     demo.launch(share=True)
apps/gradio_app/__init__.py ADDED
File without changes
apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/a_man_is_doing_yoga_in_a_serene_park_0.png ADDED

Git LFS Details

  • SHA256: 3dc2b7efb61afd2d6ceda1b32ec9792a5b07f3ac3d7a96d7acdd2102ddb957b7
  • Pointer size: 131 Bytes
  • Size of remote file: 367 kB
apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/config.json ADDED
@@ -0,0 +1,12 @@
+ {
+     "input_image": "yoga.jpg",
+     "output_image": "a_man_is_doing_yoga_in_a_serene_park_0.png",
+     "prompt": "A man is doing yoga in a serene park.",
+     "negative_prompt": "monochrome, lowres, bad anatomy, ugly, deformed face",
+     "num_steps": 50,
+     "seed": 100,
+     "width": 512,
+     "height": 512,
+     "guidance_scale": 5.5,
+     "controlnet_conditioning_scale": 0.6
+ }
apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/yoga.jpg ADDED
apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/a_man_is_galloping_on_a_horse_0.png ADDED

Git LFS Details

  • SHA256: 2e83cc3b007c2303e276b3ac60a8fa930877e584e3534f12e1441ec83ed9e9fd
  • Pointer size: 132 Bytes
  • Size of remote file: 1.11 MB
apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/config.json ADDED
@@ -0,0 +1,12 @@
+ {
+     "input_image": "ride_bike.jpg",
+     "output_image": "a_man_is_galloping_on_a_horse_0.png",
+     "prompt": "A man is galloping on a horse.",
+     "negative_prompt": "monochrome, lowres, bad anatomy, ugly, deformed face",
+     "num_steps": 100,
+     "seed": 56,
+     "width": 1080,
+     "height": 720,
+     "guidance_scale": 9.5,
+     "controlnet_conditioning_scale": 0.5
+ }
apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/ride_bike.jpg ADDED

Git LFS Details

  • SHA256: 76310cad16fcf71097c9660d46a95ced0992d48bd92469e83fd25ee59f015998
  • Pointer size: 131 Bytes
  • Size of remote file: 164 kB
apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/a_woman_is_holding_a_baseball_bat_in_her_hand_0.png ADDED

Git LFS Details

  • SHA256: a048958e0ed28806ecb7c9834f91b07a464b73cd641fa19b03f39ff542986530
  • Pointer size: 132 Bytes
  • Size of remote file: 1.27 MB
apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/config.json ADDED
@@ -0,0 +1,12 @@
+ {
+     "input_image": "tennis.jpg",
+     "output_image": "a_woman_is_holding_a_baseball_bat_in_her_hand_0.png",
+     "prompt": "A woman is holding a baseball bat in her hand.",
+     "negative_prompt": "monochrome, lowres, bad anatomy, ugly, deformed face",
+     "num_steps": 100,
+     "seed": 765,
+     "width": 990,
+     "height": 720,
+     "guidance_scale": 6.5,
+     "controlnet_conditioning_scale": 0.7
+ }
apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/tennis.jpg ADDED

Git LFS Details

  • SHA256: 259845edb5c365bccb33f9207630d829bb5a839e72bf7d0326f11ae4862694fa
  • Pointer size: 132 Bytes
  • Size of remote file: 5.61 MB
apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/a_woman_raises_a_katana_0.png ADDED

Git LFS Details

  • SHA256: deaa70aba05ab58ea0f9bd16512c6dcc7e0951559037779063045b7c035342f8
  • Pointer size: 131 Bytes
  • Size of remote file: 441 kB
apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/config.json ADDED
@@ -0,0 +1,12 @@
+ {
+     "input_image": "man_and_sword.jpg",
+     "output_image": "a_woman_raises_a_katana_0.png",
+     "prompt": "A woman raises a katana.",
+     "negative_prompt": "body elongated, fragmentation, many hands, ugly, deformed face",
+     "num_steps": 50,
+     "seed": 78,
+     "width": 540,
+     "height": 512,
+     "guidance_scale": 6.5,
+     "controlnet_conditioning_scale": 0.8
+ }
apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/man_and_sword.jpg ADDED
apps/gradio_app/examples.py ADDED
@@ -0,0 +1,99 @@
+ import os
+ import json
+ from PIL import Image
+ import gradio as gr
+
+ def load_examples(examples_base_path=os.path.join("apps", "gradio_app",
+                                                   "assets", "examples", "Stable-Diffusion-2.1-Openpose-ControlNet")):
+
+     """Load example configurations and input images from the Stable-Diffusion-2.1-Openpose-ControlNet directory."""
+     examples = []
+
+     # Iterate through example folders (e.g., '1', '2', '3', '4')
+     for folder in os.listdir(examples_base_path):
+         folder_path = os.path.join(examples_base_path, folder)
+         config_path = os.path.join(folder_path, "config.json")
+
+         if os.path.exists(config_path):
+             try:
+                 with open(config_path, 'r') as f:
+                     config = json.load(f)
+
+                 # Extract configuration fields
+                 input_filename = config["input_image"]
+                 output_filename = config["output_image"]
+                 prompt = config.get("prompt", "a man is doing yoga")
+                 negative_prompt = config.get("negative_prompt", "monochrome, lowres, bad anatomy, worst quality, low quality")
+                 num_steps = config.get("num_steps", 30)
+                 seed = config.get("seed", 42)
+                 width = config.get("width", 512)
+                 height = config.get("height", 512)
+                 guidance_scale = config.get("guidance_scale", 7.5)
+                 controlnet_conditioning_scale = config.get("controlnet_conditioning_scale", 1.0)
+
+                 # Construct absolute path for input image
+                 input_image_path = os.path.join(folder_path, input_filename)
+                 output_image_path = os.path.join(folder_path, output_filename)
+                 # Check if input image exists
+                 if os.path.exists(input_image_path):
+                     input_image_data = Image.open(input_image_path)
+                     output_image_data = Image.open(output_image_path)
+                     # Append example data in the order expected by Gradio inputs
+                     examples.append([
+                         input_image_data,  # Input image
+                         prompt,
+                         negative_prompt,
+                         output_image_data,
+                         num_steps,
+                         seed,
+                         width,
+                         height,
+                         guidance_scale,
+                         controlnet_conditioning_scale,
+                         False  # use_random_seed, hardcoded as per original gr.Examples
+                     ])
+                 else:
+                     print(f"Input image not found at {input_image_path}")
+
+             except json.JSONDecodeError as e:
+                 print(f"Error decoding JSON from {config_path}: {str(e)}")
+             except Exception as e:
+                 print(f"Error processing example in {folder_path}: {str(e)}")
+
+     return examples
+
+ def select_example(evt: gr.SelectData, examples_data):
+     """Handle selection of an example to populate Gradio inputs."""
+     example_index = evt.index
+     # Extract example data
+     # input_image_data, prompt, negative_prompt, output_image_data, num_steps, seed, width, height, guidance_scale, controlnet_conditioning_scale, use_random_seed = examples_data[example_index]
+     (
+         input_image_data,
+         prompt,
+         negative_prompt,
+         output_image_data,
+         num_steps,
+         seed,
+         width,
+         height,
+         guidance_scale,
+         controlnet_conditioning_scale,
+         use_random_seed,
+     ) = examples_data[example_index]
+
+
+     # Return values to update Gradio interface inputs and output message
+     return (
+         input_image_data,  # Input image
+         prompt,  # Prompt
+         negative_prompt,  # Negative prompt
+         output_image_data,  # Output image
+         num_steps,  # Number of inference steps
+         seed,  # Random seed
+         width,  # Width
+         height,  # Height
+         guidance_scale,  # Guidance scale
+         controlnet_conditioning_scale,  # ControlNet conditioning scale
+         use_random_seed,  # Use random seed
+         f"Loaded example {example_index + 1} with prompt: {prompt}"  # Output message
+     )
apps/gradio_app/inference.py ADDED
@@ -0,0 +1,45 @@
+ import random
+ import os
+ import sys
+
+ # Add the project root directory to the Python path
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
+
+ from src.controlnet_image_generator.infer import infer
+
+
+ def run_inference(
+     input_image,
+     prompt,
+     negative_prompt,
+     num_steps,
+     seed,
+     width,
+     height,
+     guidance_scale,
+     controlnet_conditioning_scale,
+     use_random_seed=False,
+ ):
+     config_path = "configs/model_ckpts.yaml"
+
+     if use_random_seed:
+         seed = random.randint(0, 2 ** 32)
+
+     try:
+         result = infer(
+             config_path=config_path,
+             input_image=input_image,
+             image_url=None,
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             num_steps=num_steps,
+             seed=seed,
+             width=width,
+             height=height,
+             guidance_scale=guidance_scale,
+             controlnet_conditioning_scale=float(controlnet_conditioning_scale),
+         )
+         result = list(result)[0]
+         return result, "Inference completed successfully"
+     except Exception as e:
+         return [], f"Error during inference: {str(e)}"
apps/gradio_app/project_info.py ADDED
@@ -0,0 +1,37 @@
+ NAME = """
+ # ControlNet Image Generator 🖌️
+ """.strip()
+
+ CONTENT_DESCRIPTION = """
+ <h3>ControlNet ⚡️ boosts Stable Diffusion with sharp, innovative image generation control 🖌️</h3>
+ """.strip()
+
+ # CONTENT_IN_1 = """
+ # Transforms low-res anime images into sharp, vibrant HD visuals, enhancing textures and details for artwork and games.
+ # """.strip()
+
+ CONTENT_IN_1 = """
+ <p class="source">
+     For more information, you can check out my GitHub repository and HuggingFace Model Hub:<br>
+     Source code:
+     <a class="badge" href="https://github.com/danhtran2mind/CoantrolNet-Image-Generator">
+         <img src="https://img.shields.io/badge/GitHub-danhtran2mind%2FControlNet--Image--Generator-blue?style=flat&logo=github" alt="GitHub Repo">
+     </a>,
+     Model Hub:
+     <a class="badge" href="https://huggingface.co/danhtran2mind/Stable-Diffusion-2.1-Openpose-ControlNet">
+         <img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FStable--Diffusion--2.1--Openpose--ControlNet-yellow?style=flat&logo=huggingface" alt="HuggingFace Model">
+     </a>.
+ </p>
+ """.strip()
+
+ CONTENT_OUT_1 = """
+ <div class="quote-container">
+     <p>
+         This project is built using code from
+         <a class="badge" href="https://github.com/huggingface/diffusers">
+             <img src="https://img.shields.io/badge/Built%20on-huggingface%2Fdiffusers-blue?style=flat&logo=github" alt="Built on huggingface/diffusers">
+         </a>.
+     </p>
+ </div>
+ """.strip()
+
apps/gradio_app/setup_scripts.py ADDED
@@ -0,0 +1,59 @@
+ import subprocess
+ import sys
+ import os
+
+ def run_script(script_path, args=None):
+     """
+     Run a Python script using subprocess with optional arguments and handle errors.
+     Returns True if successful, False otherwise.
+     """
+     if not os.path.isfile(script_path):
+         print(f"Script not found: {script_path}")
+         return False
+
+     try:
+         command = [sys.executable, script_path]
+         if args:
+             command.extend(args)
+         result = subprocess.run(
+             command,
+             check=True,
+             text=True,
+             capture_output=True
+         )
+         print(f"Successfully executed {script_path}")
+         print(result.stdout)
+         return True
+     except subprocess.CalledProcessError as e:
+         print(f"Error executing {script_path}:")
+         print(e.stderr)
+         return False
+     except Exception as e:
+         print(f"Unexpected error executing {script_path}: {str(e)}")
+         return False
+
+ def main():
+     """
+     Main function to execute download_ckpts.py with proper error handling.
+     """
+     scripts_dir = "scripts"
+     scripts = [
+         {
+             "path": os.path.join(scripts_dir, "download_ckpts.py"),
+             "args": []  # Empty list for args to avoid NoneType issues
+         }
+     ]
+
+     for script in scripts:
+         script_path = script["path"]
+         args = script.get("args", [])  # Safely get args with default empty list
+         print(f"Starting execution of {script_path}{' with args: ' + ' '.join(args) if args else ''}\n")
+
+         if not run_script(script_path, args):
+             print(f"Stopping execution due to error in {script_path}")
+             sys.exit(1)
+
+         print(f"Completed execution of {script_path}\n")
+
+ if __name__ == "__main__":
+     main()
apps/gradio_app/static/style.css ADDED
@@ -0,0 +1,574 @@
+ /* @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&display=swap'); */
+ /* ─── palette ───────────────────────────────────────────── */
+ body, .gradio-container {
+     font-family: 'Inter', sans-serif;
+     background: #FFFBF7;
+     color: #0F172A;
+ }
+ a {
+     color: #F97316;
+     text-decoration: none;
+     font-weight: 600;
+ }
+ a:hover { color: #C2410C; }
+ /* ─── headline ──────────────────────────────────────────── */
+ #titlebar {
+     text-align: center;
+     margin-top: 2.4rem;
+     margin-bottom: .9rem;
+ }
+ /* ─── card look ─────────────────────────────────────────── */
+ .gr-block,
+ .gr-box,
+ .gr-row,
+ #cite-wrapper {
+     border: 1px solid #F8C89B;
+     border-radius: 10px;
+     background: #fff;
+     box-shadow: 0 3px 6px rgba(0, 0, 0, .05);
+ }
+ .gr-gallery-item { background: #fff; }
+ /* ─── controls / inputs ─────────────────────────────────── */
+ .gr-button-primary,
+ #copy-btn {
+     background: linear-gradient(90deg, #F97316 0%, #C2410C 100%);
+     border: none;
+     color: #fff;
+     border-radius: 6px;
+     font-weight: 600;
+     transition: transform .12s ease, box-shadow .12s ease;
+ }
+ .gr-button-primary:hover,
+ #copy-btn:hover {
+     transform: translateY(-2px);
+     box-shadow: 0 4px 12px rgba(249, 115, 22, .35);
+ }
+ .gr-dropdown input {
+     border: 1px solid #F9731699;
+ }
+ .preview img,
+ .preview canvas { object-fit: contain !important; }
+ /* ─── hero section ─────────────────────────────────────── */
+ #hero-wrapper { text-align: center; }
+ #hero-badge {
+     display: inline-block;
+     padding: .85rem 1.2rem;
+     border-radius: 8px;
+     background: #FFEAD2;
+     border: 1px solid #F9731655;
+     font-size: .95rem;
+     font-weight: 600;
+     margin-bottom: .5rem;
+ }
+ #hero-links {
+     font-size: .95rem;
+     font-weight: 600;
+     margin-bottom: 1.6rem;
+ }
+ #hero-links img {
+     height: 22px;
+     vertical-align: middle;
+     margin-left: .55rem;
+ }
+ /* ─── score area ───────────────────────────────────────── */
+ #score-area {
+     text-align: center;
+ }
+ .title-container {
+     display: flex;
+     align-items: center;
+     gap: 12px;
+     justify-content: center;
+     margin-bottom: 10px;
+     text-align: center;
+ }
+ .match-badge {
+     display: inline-block;
+     padding: .35rem .9rem;
+     border-radius: 9999px;
+     font-weight: 600;
+     font-size: 1.25rem;
+ }
+ /* ─── citation card ────────────────────────────────────── */
+ #cite-wrapper {
+     position: relative;
+     padding: .9rem 1rem;
+     margin-top: 2rem;
+ }
+ #cite-wrapper code {
+     font-family: SFMono-Regular, Consolas, monospace;
+     font-size: .84rem;
+     white-space: pre-wrap;
+     color: #0F172A;
+ }
+ #copy-btn {
+     position: absolute;
+     top: .55rem;
+     right: .6rem;
+     padding: .18rem .7rem;
+     font-size: .72rem;
+     line-height: 1;
+ }
+ /* ─── dark mode ────────────────────────────────────── */
+ .dark body,
+ .dark .gradio-container {
+     background-color: #332a22;
+     color: #e5e7eb;
+ }
+ .dark .gr-block,
+ .dark .gr-box,
+ .dark .gr-row {
+     background-color: #332a22;
+     border: 1px solid #4b5563;
+ }
+ .dark .gr-dropdown input {
+     background-color: #332a22;
+     color: #f1f5f9;
+     border: 1px solid #F97316aa;
+ }
+ .dark #hero-badge {
+     background: #334155;
+     border: 1px solid #F9731655;
+     color: #fefefe;
+ }
+ .dark #cite-wrapper {
+     background-color: #473f38;
+ }
+ .dark #bibtex {
+     color: #f8fafc !important;
+ }
+ .dark .card {
+     background-color: #473f38;
+ }
+ /* ─── switch logo for light/dark theme ─────────────── */
+ .logo-dark { display: none; }
+ .dark .logo-light { display: none; }
+ .dark .logo-dark { display: inline; }
+
+ /* https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&display=swap */
+
+ /* cyrillic-ext */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 400;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2JL7SUc.woff2) format('woff2');
+     unicode-range: U+0460-052F, U+1C80-1C8A, U+20B4, U+2DE0-2DFF, U+A640-A69F, U+FE2E-FE2F;
+ }
+ /* cyrillic */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 400;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa0ZL7SUc.woff2) format('woff2');
+     unicode-range: U+0301, U+0400-045F, U+0490-0491, U+04B0-04B1, U+2116;
+ }
+ /* greek-ext */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 400;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2ZL7SUc.woff2) format('woff2');
+     unicode-range: U+1F00-1FFF;
+ }
+ /* greek */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 400;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1pL7SUc.woff2) format('woff2');
+     unicode-range: U+0370-0377, U+037A-037F, U+0384-038A, U+038C, U+038E-03A1, U+03A3-03FF;
+ }
+ /* vietnamese */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 400;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2pL7SUc.woff2) format('woff2');
+     unicode-range: U+0102-0103, U+0110-0111, U+0128-0129, U+0168-0169, U+01A0-01A1, U+01AF-01B0, U+0300-0301, U+0303-0304, U+0308-0309, U+0323, U+0329, U+1EA0-1EF9, U+20AB;
+ }
+ /* latin-ext */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 400;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa25L7SUc.woff2) format('woff2');
+     unicode-range: U+0100-02BA, U+02BD-02C5, U+02C7-02CC, U+02CE-02D7, U+02DD-02FF, U+0304, U+0308, U+0329, U+1D00-1DBF, U+1E00-1E9F, U+1EF2-1EFF, U+2020, U+20A0-20AB, U+20AD-20C0, U+2113, U+2C60-2C7F, U+A720-A7FF;
+ }
+ /* latin */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 400;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1ZL7.woff2) format('woff2');
+     unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+0304, U+0308, U+0329, U+2000-206F, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD;
+ }
+ /* cyrillic-ext */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 500;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2JL7SUc.woff2) format('woff2');
+     unicode-range: U+0460-052F, U+1C80-1C8A, U+20B4, U+2DE0-2DFF, U+A640-A69F, U+FE2E-FE2F;
+ }
+ /* cyrillic */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 500;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa0ZL7SUc.woff2) format('woff2');
+     unicode-range: U+0301, U+0400-045F, U+0490-0491, U+04B0-04B1, U+2116;
+ }
+ /* greek-ext */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 500;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2ZL7SUc.woff2) format('woff2');
+     unicode-range: U+1F00-1FFF;
+ }
+ /* greek */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 500;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1pL7SUc.woff2) format('woff2');
+     unicode-range: U+0370-0377, U+037A-037F, U+0384-038A, U+038C, U+038E-03A1, U+03A3-03FF;
+ }
+ /* vietnamese */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 500;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2pL7SUc.woff2) format('woff2');
+     unicode-range: U+0102-0103, U+0110-0111, U+0128-0129, U+0168-0169, U+01A0-01A1, U+01AF-01B0, U+0300-0301, U+0303-0304, U+0308-0309, U+0323, U+0329, U+1EA0-1EF9, U+20AB;
+ }
+ /* latin-ext */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 500;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa25L7SUc.woff2) format('woff2');
+     unicode-range: U+0100-02BA, U+02BD-02C5, U+02C7-02CC, U+02CE-02D7, U+02DD-02FF, U+0304, U+0308, U+0329, U+1D00-1DBF, U+1E00-1E9F, U+1EF2-1EFF, U+2020, U+20A0-20AB, U+20AD-20C0, U+2113, U+2C60-2C7F, U+A720-A7FF;
+ }
+ /* latin */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 500;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1ZL7.woff2) format('woff2');
+     unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+0304, U+0308, U+0329, U+2000-206F, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD;
+ }
+ /* cyrillic-ext */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 600;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2JL7SUc.woff2) format('woff2');
+     unicode-range: U+0460-052F, U+1C80-1C8A, U+20B4, U+2DE0-2DFF, U+A640-A69F, U+FE2E-FE2F;
+ }
+ /* cyrillic */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 600;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa0ZL7SUc.woff2) format('woff2');
+     unicode-range: U+0301, U+0400-045F, U+0490-0491, U+04B0-04B1, U+2116;
+ }
+ /* greek-ext */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 600;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2ZL7SUc.woff2) format('woff2');
+     unicode-range: U+1F00-1FFF;
+ }
+ /* greek */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 600;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1pL7SUc.woff2) format('woff2');
+     unicode-range: U+0370-0377, U+037A-037F, U+0384-038A, U+038C, U+038E-03A1, U+03A3-03FF;
+ }
+ /* vietnamese */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 600;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2pL7SUc.woff2) format('woff2');
+     unicode-range: U+0102-0103, U+0110-0111, U+0128-0129, U+0168-0169, U+01A0-01A1, U+01AF-01B0, U+0300-0301, U+0303-0304, U+0308-0309, U+0323, U+0329, U+1EA0-1EF9, U+20AB;
+ }
+ /* latin-ext */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 600;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa25L7SUc.woff2) format('woff2');
+     unicode-range: U+0100-02BA, U+02BD-02C5, U+02C7-02CC, U+02CE-02D7, U+02DD-02FF, U+0304, U+0308, U+0329, U+1D00-1DBF, U+1E00-1E9F, U+1EF2-1EFF, U+2020, U+20A0-20AB, U+20AD-20C0, U+2113, U+2C60-2C7F, U+A720-A7FF;
+ }
+ /* latin */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 600;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1ZL7.woff2) format('woff2');
+     unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+0304, U+0308, U+0329, U+2000-206F, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD;
+ }
+ /* cyrillic-ext */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 700;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2JL7SUc.woff2) format('woff2');
+     unicode-range: U+0460-052F, U+1C80-1C8A, U+20B4, U+2DE0-2DFF, U+A640-A69F, U+FE2E-FE2F;
+ }
+ /* cyrillic */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 700;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa0ZL7SUc.woff2) format('woff2');
+     unicode-range: U+0301, U+0400-045F, U+0490-0491, U+04B0-04B1, U+2116;
+ }
+ /* greek-ext */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 700;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2ZL7SUc.woff2) format('woff2');
+     unicode-range: U+1F00-1FFF;
+ }
+ /* greek */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 700;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1pL7SUc.woff2) format('woff2');
+     unicode-range: U+0370-0377, U+037A-037F, U+0384-038A, U+038C, U+038E-03A1, U+03A3-03FF;
+ }
+ /* vietnamese */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 700;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2pL7SUc.woff2) format('woff2');
+     unicode-range: U+0102-0103, U+0110-0111, U+0128-0129, U+0168-0169, U+01A0-01A1, U+01AF-01B0, U+0300-0301, U+0303-0304, U+0308-0309, U+0323, U+0329, U+1EA0-1EF9, U+20AB;
+ }
+ /* latin-ext */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 700;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa25L7SUc.woff2) format('woff2');
+     unicode-range: U+0100-02BA, U+02BD-02C5, U+02C7-02CC, U+02CE-02D7, U+02DD-02FF, U+0304, U+0308, U+0329, U+1D00-1DBF, U+1E00-1E9F, U+1EF2-1EFF, U+2020, U+20A0-20AB, U+20AD-20C0, U+2113, U+2C60-2C7F, U+A720-A7FF;
+ }
+ /* latin */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 700;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1ZL7.woff2) format('woff2');
+     unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+0304, U+0308, U+0329, U+2000-206F, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD;
+ }
+ /* cyrillic-ext */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 800;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2JL7SUc.woff2) format('woff2');
+     unicode-range: U+0460-052F, U+1C80-1C8A, U+20B4, U+2DE0-2DFF, U+A640-A69F, U+FE2E-FE2F;
+ }
+ /* cyrillic */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 800;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa0ZL7SUc.woff2) format('woff2');
+     unicode-range: U+0301, U+0400-045F, U+0490-0491, U+04B0-04B1, U+2116;
+ }
+ /* greek-ext */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 800;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2ZL7SUc.woff2) format('woff2');
+     unicode-range: U+1F00-1FFF;
+ }
+ /* greek */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 800;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1pL7SUc.woff2) format('woff2');
+     unicode-range: U+0370-0377, U+037A-037F, U+0384-038A, U+038C, U+038E-03A1, U+03A3-03FF;
+ }
+ /* vietnamese */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 800;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2pL7SUc.woff2) format('woff2');
+     unicode-range: U+0102-0103, U+0110-0111, U+0128-0129, U+0168-0169, U+01A0-01A1, U+01AF-01B0, U+0300-0301, U+0303-0304, U+0308-0309, U+0323, U+0329, U+1EA0-1EF9, U+20AB;
+ }
+ /* latin-ext */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 800;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa25L7SUc.woff2) format('woff2');
+     unicode-range: U+0100-02BA, U+02BD-02C5, U+02C7-02CC, U+02CE-02D7, U+02DD-02FF, U+0304, U+0308, U+0329, U+1D00-1DBF, U+1E00-1E9F, U+1EF2-1EFF, U+2020, U+20A0-20AB, U+20AD-20C0, U+2113, U+2C60-2C7F, U+A720-A7FF;
+ }
+ /* latin */
+ @font-face {
+     font-family: 'Inter';
+     font-style: normal;
+     font-weight: 800;
+     font-display: swap;
+     src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1ZL7.woff2) format('woff2');
+     unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+0304, U+0308, U+0329, U+2000-206F, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD;
+ }
+
+ /* title_css */
+ #title {
+     font-size: 2.6rem;
+     font-weight: 800;
+     margin: 0;
+     line-height: 1.25;
+     color: #0F172A;
+ }
+ /* brand class is passed in title parameter */
+ #title .brand {
+     background: linear-gradient(90deg, #F97316 0%, #C2410C 90%);
+     -webkit-background-clip: text;
+     color: transparent;
+ }
+ .dark #title {
+     color: #f8fafc;
+ }
+ .title-container {
+     display: flex;
+     align-items: center;
+     gap: 12px;
+     justify-content: center;
+     margin-bottom: 10px;
+     text-align: center;
+ }
+
+ /* Dark Mode */
+ @media (prefers-color-scheme: dark) {
+     body { @extend .dark; }
+ }
+ /* Smaller size for input image */
+ .input-image img {
+     max-width: 300px;
+     height: auto;
+ }
+ /* Larger size for output image */
+ .output-image img {
+     max-width: 500px;
+     height: auto;
+ }
+
+ /* Add styling for warning message */
+ .warning-message {
+     color: red;
+     font-size: 14px;
+     margin-top: 5px;
+     display: block;
+ }
+ #warning-text {
+     min-height: 20px; /* Ensure space for warning */
+ }
+ /* Components for Gradio App */
+ .quote-container {
+     border-left: 5px solid #007bff;
+     padding-left: 15px;
+     margin-bottom: 15px;
+     font-style: italic;
+ }
+ .attribution p {
+     margin: 10px 0;
+ }
+ .badge {
+     display: inline-block;
+     border-radius: 4px;
+     text-decoration: none;
+     font-size: 14px;
+     transition: background-color 0.3s;
+ }
+ .badge:hover {
+     background-color: #0056b3;
+ }
+ .badge img {
+     vertical-align: middle;
+     margin-right: 5px;
+ }
+ .source {
+     font-size: 14px;
+ }
+
+ /* Start / Stop buttons */
+ .submit-btn {
+     background-color: #f97316; /* Orange background */
+     color: white;
+     font-weight: bold;
+     padding: 8px 16px;
+     border-radius: 6px;
+     border: none;
+     cursor: pointer;
+     transition: background-color 0.3s ease;
+ }
+
+ .submit-btn:hover {
+     background-color: #f97416de; /* Slightly translucent orange on hover */
+ }
+
+ .stop-btn {
+     background-color: grey; /* Grey background */
+     color: white;
+     font-weight: 600;
+     padding: 8px 16px;
+     border-radius: 6px;
+     border: none;
+     cursor: pointer;
+     transition: background-color 0.3s ease;
+ }
+
+ .stop-btn:hover {
+     background-color: rgba(128, 128, 128, 0.858); /* Semi-transparent grey on hover */
+ }
apps/old-gradio_app.py ADDED
@@ -0,0 +1,177 @@
+ import os
+ import sys
+ import subprocess
+ import gradio as gr
+ import torch
+ import random
+
+ # Add the project root directory to the Python path
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+ from src.controlnet_image_generator.infer import infer
+
+ def run_setup_script():
+     setup_script = os.path.join(os.path.dirname(__file__), "gradio_app", "setup_scripts.py")
+     try:
+         result = subprocess.run(["python", setup_script], capture_output=True, text=True, check=True)
+         return result.stdout
+     except subprocess.CalledProcessError as e:
+         print(f"Setup script failed with error: {e.stderr}")
+         return f"Setup script failed: {e.stderr}"
+
+ def run_inference(
+     input_image,
+     prompt,
+     negative_prompt,
+     num_steps,
+     seed,
+     width,
+     height,
+     guidance_scale,
+     controlnet_conditioning_scale,
+     use_random_seed=False,
+ ):
+     config_path = "configs/model_ckpts.yaml"
+
+     if use_random_seed:
+         seed = random.randint(0, 2 ** 32)
+
+     try:
+         result = infer(
+             config_path=config_path,
+             input_image=input_image,
+             image_url=None,
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             num_steps=num_steps,
+             seed=seed,
+             width=width,
+             height=height,
+             guidance_scale=guidance_scale,
+             controlnet_conditioning_scale=float(controlnet_conditioning_scale),
+         )
+         result = list(result)[0]
+         return result, "Inference completed successfully"
+     except Exception as e:
+         return [], f"Error during inference: {str(e)}"
+
+ def stop_app():
+     """Function to stop the Gradio app."""
+     try:
+         gr.close_all()  # Close all running Gradio interfaces
+         return "Application stopped successfully."
+     except Exception as e:
+         return f"Error stopping application: {str(e)}"
+
+ def create_gui():
+     custom_css = open("apps/gradio_app/static/style.css").read()
+     with gr.Blocks(css=custom_css) as demo:
+         gr.Markdown("# ControlNet Image Generation with Pose Detection")
+
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(type="filepath", label="Input Image")
+                 prompt = gr.Textbox(
+                     label="Prompt",
+                     value="a man is doing yoga"
+                 )
+                 negative_prompt = gr.Textbox(
+                     label="Negative Prompt",
+                     value="monochrome, lowres, bad anatomy, worst quality, low quality"
+                 )
+
+                 with gr.Row():
+                     width = gr.Slider(
+                         minimum=256,
+                         maximum=1024,
+                         value=512,
+                         step=64,
+                         label="Width"
+                     )
+                     height = gr.Slider(
+                         minimum=256,
+                         maximum=1024,
+                         value=512,
+                         step=64,
+                         label="Height"
+                     )
+
+                 with gr.Accordion("Advanced Settings", open=False):
+                     num_steps = gr.Slider(
+                         minimum=1,
+                         maximum=100,
+                         value=30,
+                         step=1,
+                         label="Number of Inference Steps"
+                     )
+                     use_random_seed = gr.Checkbox(label="Use Random Seed", value=False)
+                     seed = gr.Slider(
+                         minimum=0,
+                         maximum=2**32,
+                         value=42,
+                         step=1,
+                         label="Random Seed",
+                         visible=True
+                     )
+
+                     guidance_scale = gr.Slider(
+                         minimum=1.0,
+                         maximum=20.0,
+                         value=7.5,
+                         step=0.1,
+                         label="Guidance Scale"
+                     )
+                     controlnet_conditioning_scale = gr.Slider(
+                         minimum=0.0,
+                         maximum=1.0,
+                         value=1.0,
+                         step=0.1,
+                         label="ControlNet Conditioning Scale"
+                     )
+
+             with gr.Column():
+                 output_images = gr.Image(label="Generated Images")
+                 output_message = gr.Textbox(label="Status")
+
+         # with gr.Row():
+         submit_button = gr.Button("Generate Images", elem_classes="submit-btn")
+         stop_button = gr.Button("Stop Application", elem_classes="stop-btn")
+
+         def update_seed_visibility(use_random):
+             return gr.update(visible=not use_random)
+
+         use_random_seed.change(
+             fn=update_seed_visibility,
+             inputs=use_random_seed,
+             outputs=seed
+         )
+
+         submit_button.click(
+             fn=run_inference,
+             inputs=[
+                 input_image,
+                 prompt,
+                 negative_prompt,
+                 num_steps,
+                 seed,
+                 width,
+                 height,
+                 guidance_scale,
+                 controlnet_conditioning_scale,
+                 use_random_seed,
+             ],
+             outputs=[output_images, output_message]
+         )
+
+         stop_button.click(
+             fn=stop_app,
+             inputs=[],
+             outputs=[output_message]
+         )
+
+     return demo
+
+ if __name__ == "__main__":
+     run_setup_script()
+     demo = create_gui()
+     demo.launch(share=True)
assets/.gitkeep ADDED
File without changes
assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/a_man_is_doing_yoga_in_a_serene_park_0.png ADDED

Git LFS Details

  • SHA256: 3dc2b7efb61afd2d6ceda1b32ec9792a5b07f3ac3d7a96d7acdd2102ddb957b7
  • Pointer size: 131 Bytes
  • Size of remote file: 367 kB
assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/config.json ADDED
@@ -0,0 +1,12 @@
+ {
+     "input_image": "yoga.jpg",
+     "output_image": "a_man_is_doing_yoga_in_a_serene_park_0.png",
+     "prompt": "A man is doing yoga in a serene park.",
+     "negative_prompt": "monochrome, lowres, bad anatomy, ugly, deformed face",
+     "num_steps": 50,
+     "seed": 100,
+     "width": 512,
+     "height": 512,
+     "guidance_scale": 5.5,
+     "controlnet_conditioning_scale": 0.6
+ }
assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/yoga.jpg ADDED
assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/a_man_is_galloping_on_a_horse_0.png ADDED

Git LFS Details

  • SHA256: 2e83cc3b007c2303e276b3ac60a8fa930877e584e3534f12e1441ec83ed9e9fd
  • Pointer size: 132 Bytes
  • Size of remote file: 1.11 MB
assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/config.json ADDED
@@ -0,0 +1,12 @@
+ {
+     "input_image": "ride_bike.jpg",
+     "output_image": "a_man_is_galloping_on_a_horse_0.png",
+     "prompt": "A man is galloping on a horse.",
+     "negative_prompt": "monochrome, lowres, bad anatomy, ugly, deformed face",
+     "num_steps": 100,
+     "seed": 56,
+     "width": 1080,
+     "height": 720,
+     "guidance_scale": 9.5,
+     "controlnet_conditioning_scale": 0.5
+ }
assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/ride_bike.jpg ADDED

Git LFS Details

  • SHA256: 76310cad16fcf71097c9660d46a95ced0992d48bd92469e83fd25ee59f015998
  • Pointer size: 131 Bytes
  • Size of remote file: 164 kB
assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/a_woman_is_holding_a_baseball_bat_in_her_hand_0.png ADDED

Git LFS Details

  • SHA256: a048958e0ed28806ecb7c9834f91b07a464b73cd641fa19b03f39ff542986530
  • Pointer size: 132 Bytes
  • Size of remote file: 1.27 MB
assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/config.json ADDED
@@ -0,0 +1,12 @@
+ {
+     "input_image": "tennis.jpg",
+     "output_image": "a_woman_is_holding_a_baseball_bat_in_her_hand_0.png",
+     "prompt": "A woman is holding a baseball bat in her hand.",
+     "negative_prompt": "monochrome, lowres, bad anatomy, ugly, deformed face",
+     "num_steps": 100,
+     "seed": 765,
+     "width": 990,
+     "height": 720,
+     "guidance_scale": 6.5,
+     "controlnet_conditioning_scale": 0.7
+ }
assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/tennis.jpg ADDED

Git LFS Details

  • SHA256: 259845edb5c365bccb33f9207630d829bb5a839e72bf7d0326f11ae4862694fa
  • Pointer size: 132 Bytes
  • Size of remote file: 5.61 MB
assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/a_woman_raises_a_katana_0.png ADDED

Git LFS Details

  • SHA256: deaa70aba05ab58ea0f9bd16512c6dcc7e0951559037779063045b7c035342f8
  • Pointer size: 131 Bytes
  • Size of remote file: 441 kB
assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/config.json ADDED
@@ -0,0 +1,12 @@
+ {
+     "input_image": "man_and_sword.jpg",
+     "output_image": "a_woman_raises_a_katana_0.png",
+     "prompt": "A woman raises a katana.",
+     "negative_prompt": "body elongated, fragmentation, many hands, ugly, deformed face",
+     "num_steps": 50,
+     "seed": 78,
+     "width": 540,
+     "height": 512,
+     "guidance_scale": 6.5,
+     "controlnet_conditioning_scale": 0.8
+ }
assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/man_and_sword.jpg ADDED
ckpts/.gitignore ADDED
@@ -0,0 +1,2 @@
+ *
+ !.gitignore
configs/.gitkeep ADDED
File without changes
configs/datasets_info.yaml ADDED
@@ -0,0 +1,3 @@
+ - dataset_name: "HighCWu/open_pose_controlnet_subset"
+   local_dir: "HighCWu-open_pose_controlnet_subset"
+   platform: "HuggingFace"
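
The YAML above drives the dataset download step. As a rough, hypothetical sketch (the actual logic lives in scripts/download_datasets.py, which this truncated view does not show), the entries might be consumed like this, assuming huggingface_hub handles HuggingFace-platform entries and that local_dir is resolved under data/:

```python
# Hypothetical sketch only — scripts/download_datasets.py is part of this commit
# but its contents are not visible in this 50-file view.
import yaml
from huggingface_hub import snapshot_download

with open("configs/datasets_info.yaml") as f:
    datasets = yaml.safe_load(f)  # a list of {dataset_name, local_dir, platform}

for ds in datasets:
    if ds["platform"] == "HuggingFace":
        snapshot_download(
            repo_id=ds["dataset_name"],
            repo_type="dataset",                  # dataset repo, not a model repo
            local_dir=f"data/{ds['local_dir']}",  # assumed destination under data/
        )
```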
configs/model_ckpts.yaml ADDED
@@ -0,0 +1,16 @@
+ - model_id: "danhtran2mind/Stable-Diffusion-2.1-Openpose-ControlNet"
+   local_dir: "ckpts/Stable-Diffusion-2.1-Openpose-ControlNet"
+   allow:
+     - diffusion_pytorch_model.safetensors
+     - config.json
+
+ - model_id: "stabilityai/stable-diffusion-2-1"
+   local_dir: "ckpts/stable-diffusion-2-1"
+   deny:
+     - v2-1_768-ema-pruned.ckpt
+     - v2-1_768-ema-pruned.safetensors
+     - v2-1_768-nonema-pruned.ckpt
+     - v2-1_768-nonema-pruned.safetensors
+
+ - model_id: "lllyasviel/ControlNet"
+   local_dir: null
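
This manifest is the config_path that apps/gradio_app/inference.py passes to infer, and scripts/download_ckpts.py reads it to fetch the checkpoints. A hedged sketch of the download side, assuming the allow/deny keys map onto huggingface_hub's allow_patterns/ignore_patterns and that a null local_dir means the entry is skipped:

```python
# Hypothetical sketch only — the real implementation is scripts/download_ckpts.py,
# which this truncated view does not show.
import yaml
from huggingface_hub import snapshot_download

with open("configs/model_ckpts.yaml") as f:
    entries = yaml.safe_load(f)

for entry in entries:
    if entry.get("local_dir") is None:
        continue  # e.g., lllyasviel/ControlNet has no download target here
    snapshot_download(
        repo_id=entry["model_id"],
        local_dir=entry["local_dir"],
        allow_patterns=entry.get("allow"),   # None downloads everything
        ignore_patterns=entry.get("deny"),   # skip the large single-file weights
    )
```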
data/.gitignore ADDED
@@ -0,0 +1,2 @@
+ *
+ !.gitignore
docs/inference/inference_doc.md ADDED
@@ -0,0 +1,176 @@
+ # ControlNet Image Generation with Pose Detection
+
+ This document provides a comprehensive overview of a Python script for image generation using ControlNet with pose detection, integrated with the Stable Diffusion model. The script processes an input image to detect human poses and generates new images based on a text prompt, guided by the detected poses.
+
+ ## Purpose
+
+ The script enables users to generate images that adhere to specific poses extracted from an input image, combining ControlNet's pose conditioning with Stable Diffusion's high-quality image synthesis. It is particularly useful for applications requiring pose-guided image generation, such as creating stylized images of people in specific poses (e.g., yoga, dancing) based on a reference image.
+
+ ## Dependencies
+
+ The script relies on the following Python libraries and custom modules:
+
+ - **Standard Libraries**:
+   - `torch`: For tensor operations and deep learning model handling.
+   - `argparse`: For parsing command-line arguments.
+   - `os`: For file and directory operations.
+   - `sys`: For modifying the Python path to include the project root.
+
+ - **Custom Modules** (assumed to be part of the project structure):
+   - `inference.config_loader`:
+     - `load_config`: Loads model configurations from a YAML file.
+     - `find_config_by_model_id`: Retrieves specific model configurations by ID.
+   - `inference.model_initializer`:
+     - `initialize_controlnet`: Initializes the ControlNet model.
+     - `initialize_pipeline`: Initializes the Stable Diffusion pipeline.
+     - `initialize_controlnet_detector`: Initializes the pose detection model.
+   - `inference.device_manager`:
+     - `setup_device`: Configures the computation device (e.g., CPU or GPU).
+   - `inference.image_processor`:
+     - `load_input_image`: Loads the input image from a local path or URL.
+     - `detect_poses`: Detects human poses in the input image.
+   - `inference.image_generator`:
+     - `generate_images`: Generates images using the pipeline and pose conditions.
+     - `save_images`: Saves generated images to the specified directory.
+
+ ## Script Structure
+
+ The script is organized into the following components:
+
+ 1. **Imports and Path Setup**:
+    - Imports necessary libraries and adds the project root directory to the Python path for accessing custom modules.
+    - Ensures the script can locate custom modules regardless of the execution context.
+
+ 2. **Global Variables**:
+    - Defines three global variables to cache initialized models (see the sketch after this list):
+      - `controlnet_detector`: For pose detection.
+      - `controlnet`: For pose-guided conditioning.
+      - `pipe`: The Stable Diffusion pipeline.
+    - These variables persist across multiple calls to the `infer` function to avoid redundant model initialization.
+
+ 3. **Main Function: `infer`**:
+    - The core function that orchestrates the image generation process.
+    - Takes configurable parameters for input, model settings, and output options.
+
+ 4. **Command-Line Interface**:
+    - Uses `argparse` to provide a user-friendly interface for running the script with customizable parameters.
+
58
+ ## Main Function: `infer`
59
+
60
+ The `infer` function handles the end-to-end process of loading models, processing input images, detecting poses, generating images, and optionally saving the results.
61
+
62
+ ### Parameters
63
+
64
+ | Parameter | Type | Description | Default |
65
+ |-----------|------|-------------|---------|
66
+ | `config_path` | `str` | Path to the configuration YAML file. | `"configs/model_ckpts.yaml"` |
67
+ | `input_image` | `str` | Path to the local input image. Mutually exclusive with `image_url`. | `None` |
68
+ | `image_url` | `str` | URL of the input image. Mutually exclusive with `input_image`. | `None` |
69
+ | `prompt` | `str` | Text prompt for image generation. | `"a man is doing yoga"` |
70
+ | `negative_prompt` | `str` | Negative prompt to avoid undesired features. | `"monochrome, lowres, bad anatomy, worst quality, low quality"` |
71
+ | `num_steps` | `int` | Number of inference steps. | `20` |
72
+ | `seed` | `int` | Random seed for reproducibility. | `2` |
73
+ | `width` | `int` | Width of the generated image (pixels). | `512` |
74
+ | `height` | `int` | Height of the generated image (pixels). | `512` |
75
+ | `guidance_scale` | `float` | Guidance scale for prompt adherence. | `7.5` |
76
+ | `controlnet_conditioning_scale` | `float` | ControlNet conditioning scale for pose influence. | `1.0` |
77
+ | `output_dir` | `str` | Directory to save generated images. | `tests/test_data` |
78
+ | `use_prompt_as_output_name` | `bool` | Use prompt in output filenames. | `False` |
79
+ | `save_output` | `bool` | Save generated images to `output_dir`. | `False` |
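+
+ For orientation, here is a minimal sketch of calling `infer` from Python rather than the CLI. It assumes the script is importable as a module; the import path `infer_script` is illustrative, not the actual file name:
+
+ ```python
+ # Hypothetical import path; adjust to the actual script/module name.
+ from infer_script import infer
+
+ images = infer(
+     config_path="configs/model_ckpts.yaml",
+     input_image="tests/test_data/yoga1.jpg",  # or image_url=... (mutually exclusive)
+     prompt="a man is doing yoga",
+     negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
+     num_steps=20,
+     seed=2,
+     width=512,
+     height=512,
+     guidance_scale=7.5,
+     controlnet_conditioning_scale=1.0,
+     save_output=True,
+     output_dir="tests/test_data",
+ )
+ # One image is returned per detected pose.
+ ```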
+
+ ### Workflow
+
+ 1. **Configuration Loading**:
+    - Loads model configurations from `config_path` using `load_config`.
+    - Retrieves specific configurations for:
+      - Pose detection model (`lllyasviel/ControlNet`).
+      - ControlNet model (`danhtran2mind/Stable-Diffusion-2.1-Openpose-ControlNet`).
+      - Stable Diffusion pipeline (`stabilityai/stable-diffusion-2-1`).
+
+ 2. **Model Initialization**:
+    - Checks whether `controlnet_detector`, `controlnet`, or `pipe` is `None`.
+    - If so, initializes them from the respective configurations to avoid redundant loading.
+
+ 3. **Device Setup**:
+    - Configures the computation device (e.g., CPU or GPU) for the pipeline using `setup_device`.
+
+ 4. **Image Processing**:
+    - Loads the input image from either `input_image` or `image_url` using `load_input_image`.
+    - Detects poses in the input image using `detect_poses` with the `controlnet_detector`.
+
+ 5. **Image Generation**:
+    - Creates a list of random number generators seeded with `seed + i`, one per detected pose (see the sketch after this list).
+    - Generates images using `generate_images`, passing:
+      - The pipeline (`pipe`).
+      - Repeated prompts and negative prompts for each pose.
+      - Detected poses as conditioning inputs.
+      - Generators for reproducibility.
+      - Parameters such as `num_steps`, `guidance_scale`, `controlnet_conditioning_scale`, `width`, and `height`.
+
+ 6. **Output Handling**:
+    - If `save_output` is `True`, saves the generated images to `output_dir` using `save_images`.
+    - If `use_prompt_as_output_name` is `True`, incorporates the prompt into the output filenames.
+    - Returns the list of generated images.
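+
+ The per-pose seeding in step 5 can be sketched as follows, assuming `torch` generators (the usual diffusers convention); the variable names `seed` and `poses` are illustrative:
+
+ ```python
+ import torch
+
+ # One deterministic generator per detected pose: pose i gets seed + i,
+ # so each pose-conditioned image is independently reproducible.
+ seed = 2
+ poses = [...]  # output of detect_poses(...)
+ generators = [
+     torch.Generator(device="cpu").manual_seed(seed + i)
+     for i in range(len(poses))
+ ]
+ ```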
+
+ ## Command-Line Interface
+
+ The script includes a command-line interface using `argparse` for flexible execution.
+
+ ### Arguments Table
+
+ | Argument | Type | Default Value | Description |
+ |----------|------|---------------|-------------|
+ | `--input_image` | `str` | `tests/test_data/yoga1.jpg` | Path to the local input image. Mutually exclusive with `--image_url`. |
+ | `--image_url` | `str` | `None` | URL of the input image (e.g., `https://huggingface.co/datasets/YiYiXu/controlnet-testing/resolve/main/yoga1.jpeg`). Mutually exclusive with `--input_image`. |
+ | `--config_path` | `str` | `configs/model_ckpts.yaml` | Path to the configuration YAML file for model settings. |
+ | `--prompt` | `str` | `"a man is doing yoga"` | Text prompt for image generation. |
+ | `--negative_prompt` | `str` | `"monochrome, lowres, bad anatomy, worst quality, low quality"` | Negative prompt to avoid undesired features in generated images. |
+ | `--num_steps` | `int` | `20` | Number of inference steps for image generation. |
+ | `--seed` | `int` | `2` | Random seed for reproducible generation. |
+ | `--width` | `int` | `512` | Width of the generated image in pixels. |
+ | `--height` | `int` | `512` | Height of the generated image in pixels. |
+ | `--guidance_scale` | `float` | `7.5` | Guidance scale for prompt adherence during generation. |
+ | `--controlnet_conditioning_scale` | `float` | `1.0` | ControlNet conditioning scale to balance pose influence. |
+ | `--output_dir` | `str` | `tests/test_data` | Directory to save generated images. |
+ | `--use_prompt_as_output_name` | Flag | `False` | If set, incorporates the prompt into output image filenames. |
+ | `--save_output` | Flag | `False` | If set, saves generated images to the specified output directory. |
+
+ ### Example Usage
+
+ ```bash
+ python script.py --input_image tests/test_data/yoga1.jpg --prompt "a woman doing yoga in a park" --num_steps 30 --guidance_scale 8.0 --save_output --use_prompt_as_output_name
+ ```
+
+ This command:
+ - Uses the local image `tests/test_data/yoga1.jpg` as input.
+ - Generates images with the prompt `"a woman doing yoga in a park"`.
+ - Runs for 30 inference steps with a guidance scale of 8.0.
+ - Saves the output images to `tests/test_data`, with filenames including the prompt.
+
+ Alternatively, using a URL:
+
+ ```bash
+ python script.py --image_url https://huggingface.co/datasets/YiYiXu/controlnet-testing/resolve/main/yoga1.jpeg --prompt "a person practicing yoga at sunset" --save_output
+ ```
+
+ This command uses an online image and saves the generated images without including the prompt in the filenames.
+
+ ## Notes
+
+ - **Configuration File**: The script assumes a `configs/model_ckpts.yaml` file exists with configurations for the required models (`lllyasviel/ControlNet`, `danhtran2mind/Stable-Diffusion-2.1-Openpose-ControlNet`, `stabilityai/stable-diffusion-2-1`). Ensure this file is correctly formatted and accessible.
+ - **Input Requirements**: The input image (local or URL) should contain at least one person for effective pose detection.
+ - **Model Caching**: Global variables cache the models to improve performance across multiple inferences within the same session (the pattern is sketched below).
+ - **Device Compatibility**: The `setup_device` function determines the computation device. Ensure compatible hardware (e.g., a GPU) is available for optimal performance.
+ - **Output Flexibility**: The script supports generating multiple images if multiple poses are detected, with each image conditioned on one pose.
+ - **Error Handling**: The script assumes the custom modules handle errors appropriately. Users should verify that input paths, URLs, and model configurations are valid.
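+
+ As a rough illustration of the caching pattern described in the **Model Caching** note — a minimal, self-contained sketch; the `_load_*` helpers are stand-ins for the real initializers in `inference.model_initializer`:
+
+ ```python
+ # Module-level caches survive repeated infer() calls in the same process.
+ controlnet_detector = None
+ controlnet = None
+ pipe = None
+
+ def _load_detector():
+     return object()  # stand-in for initialize_controlnet_detector(...)
+
+ def _load_controlnet():
+     return object()  # stand-in for initialize_controlnet(...)
+
+ def _load_pipeline():
+     return object()  # stand-in for initialize_pipeline(...)
+
+ def infer():
+     global controlnet_detector, controlnet, pipe
+     if controlnet_detector is None:  # first call pays the loading cost
+         controlnet_detector = _load_detector()
+     if controlnet is None:
+         controlnet = _load_controlnet()
+     if pipe is None:                 # later calls reuse the cached objects
+         pipe = _load_pipeline()
+ ```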
+
+ ## Potential Improvements
+
+ - Add error handling for invalid inputs or missing configuration files.
+ - Support batch processing for multiple input images.
+ - Allow dynamic model selection via command-line arguments instead of hardcoded model IDs.
+ - Include options for adjusting pose detection sensitivity or other model-specific parameters.
+
+ ## Conclusion
+
+ This script provides a robust framework for pose-guided image generation using ControlNet and Stable Diffusion. Its modular design and command-line interface make it suitable for both one-off experiments and integration into larger workflows. By leveraging pre-trained models and customizable parameters, it enables users to generate high-quality, pose-conditioned images with minimal setup.
docs/scripts/download_ckpts_doc.md ADDED
@@ -0,0 +1,29 @@
+ # Download Model Checkpoints
+
+ This script downloads model checkpoints from the Hugging Face Hub based on configurations specified in a YAML file.
+
+ ## Functionality
+ - **Load Configuration**: Reads a YAML configuration file to get model details.
+ - **Download Model**: Downloads files for the specified models from the Hugging Face Hub to a local directory (an example configuration entry follows this list).
+   - Checks for a valid `local_dir` in the configuration; skips the download if `local_dir` is null.
+   - Creates the local directory if it doesn't exist.
+   - Supports `allow` and `deny` patterns to filter files:
+     - If `allow` patterns are specified, only those files are downloaded.
+     - If no `allow` patterns are provided, all files are downloaded except those matching `deny` patterns.
+   - Uses `hf_hub_download` from the `huggingface_hub` library with symlinks disabled.
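+
+ For reference, the first entry in the repository's `configs/model_ckpts.yaml` has this shape:
+
+ ```yaml
+ - model_id: "danhtran2mind/Stable-Diffusion-2.1-Openpose-ControlNet"
+   local_dir: "ckpts/Stable-Diffusion-2.1-Openpose-ControlNet"
+   allow:
+     - diffusion_pytorch_model.safetensors
+     - config.json
+ ```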
+
+ ## Command-Line Arguments
+ - `--config_path`: Path to the YAML configuration file (defaults to `configs/model_ckpts.yaml`).
+
+ ## Dependencies
+ - `argparse`: For parsing command-line arguments.
+ - `os`: For directory creation.
+ - `yaml`: For reading the configuration file.
+ - `huggingface_hub`: For downloading files from the Hugging Face Hub.
+
+ ## Usage
+ Run the script with:
+ ```bash
+ python scripts/download_ckpts.py --config_path <path_to_yaml>
+ ```
+ The script processes each model in the configuration file, printing the model ID and local directory for each.
docs/scripts/download_datasets_doc.md ADDED
@@ -0,0 +1,20 @@
+ # Download Datasets
+
+ This script downloads datasets from Hugging Face using configuration details specified in a YAML file.
+
+ ## Functionality
+ - **Load Configuration**: Reads dataset details from a YAML configuration file.
+ - **Download Dataset**: Downloads datasets from Hugging Face if the platform is specified as `HuggingFace` in the configuration.
+ - **Command-Line Argument**: Accepts a path to the configuration file via the `--config_path` argument (defaults to `configs/datasets_info.yaml`).
+ - **Dataset Information**: Extracts the dataset name and local storage directory from the configuration, splits the dataset name into user and model hub components, and saves the dataset to the specified directory.
+ - **Verification**: Prints dataset details, including the user name, model hub name, storage location, and dataset information, for confirmation.
+ - **Platform Check**: Only processes datasets from Hugging Face; unsupported platforms are flagged with a message.
+
+ ## Usage
+ Run the script with:
+ ```bash
+ python scripts/download_datasets.py --config_path path/to/config.yaml
+ ```
+
+ The configuration file should contain:
+ - `dataset_name`: Formatted as `user_name/model_hub_name`.
+ - `local_dir`: Directory to save the dataset.
+ - `platform`: Must be set to `HuggingFace` for the script to process the entry (an example configuration follows).
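+
+ For reference, the repository's `configs/datasets_info.yaml` contains a single entry of exactly this shape:
+
+ ```yaml
+ - dataset_name: "HighCWu/open_pose_controlnet_subset"
+   local_dir: "HighCWu-open_pose_controlnet_subset"
+   platform: "HuggingFace"
+ ```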
docs/training/training_doc.md ADDED
@@ -0,0 +1,106 @@
+ # ControlNet Training Documentation
+
+ This document outlines the process for training a ControlNet model using the provided Python scripts (`train.py` and `train_controlnet.py`). The scripts facilitate training a ControlNet model integrated with a Stable Diffusion pipeline for conditional image generation. Below, we describe the training process and provide a detailed table of the command-line arguments used to configure the training.
+
+ ## Overview
+
+ The training process involves two main scripts:
+ 1. **`train.py`**: A wrapper script that executes `train_controlnet.py` with the provided command-line arguments.
+ 2. **`train_controlnet.py`**: The core script that handles the training of the ControlNet model, including dataset preparation, model initialization, the training loop, and validation.
+
+ ### Training Workflow
+ 1. **Argument Parsing**: The script parses command-line arguments to configure the training process, such as model paths, dataset details, and hyperparameters.
+ 2. **Dataset Preparation**: Loads and preprocesses the dataset (either from the HuggingFace Hub or a local directory) with transformations for images and captions.
+ 3. **Model Initialization**: Loads pretrained models (e.g., Stable Diffusion, VAE, UNet, text encoder) and initializes or loads ControlNet weights.
+ 4. **Training Loop**: Trains the ControlNet model using the Accelerate library for distributed training, with support for mixed precision, gradient checkpointing, and learning rate scheduling (see the launch sketch after this list).
+ 5. **Validation**: Periodically validates the model by generating images from validation prompts and images, logging results to TensorBoard or Weights & Biases.
+ 6. **Checkpointing and Saving**: Saves model checkpoints during training and the final model to the output directory. Optionally pushes the model to the HuggingFace Hub.
+ 7. **Model Card Creation**: Generates a model card with training details and example images for documentation.
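+
+ Since the training loop is built on Accelerate, multi-GPU runs would typically go through the standard Accelerate workflow. The following is a hedged sketch, assuming a completed `accelerate config` and assuming `train_controlnet.py` sits alongside `train.py` (the path is not confirmed by this document); the flags are taken from the argument table below:
+
+ ```bash
+ # One-time interactive setup of the Accelerate environment (GPUs, precision, etc.)
+ accelerate config
+
+ # Launch the core training script under Accelerate instead of plain python.
+ accelerate launch src/controlnet_image_generator/train_controlnet.py \
+     --pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1" \
+     --dataset_name="HighCWu/open_pose_controlnet_subset" \
+     --output_dir="controlnet_output" \
+     --mixed_precision="fp16"
+ ```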
+
+ ## Command-Line Arguments
+
+ The following table describes the command-line arguments available in `train_controlnet.py` for configuring the training process:
+
+ | Argument | Type | Default | Description |
+ |----------|------|---------|-------------|
+ | `--pretrained_model_name_or_path` | `str` | None | Path to pretrained model or model identifier from huggingface.co/models. Required. |
+ | `--controlnet_model_name_or_path` | `str` | None | Path to pretrained ControlNet model or model identifier. If not specified, ControlNet weights are initialized from UNet. |
+ | `--revision` | `str` | None | Revision of pretrained model identifier from huggingface.co/models. |
+ | `--variant` | `str` | None | Variant of the model files (e.g., 'fp16'). |
+ | `--tokenizer_name` | `str` | None | Pretrained tokenizer name or path if different from model_name. |
+ | `--output_dir` | `str` | "controlnet-model" | Directory where model predictions and checkpoints are saved. |
+ | `--cache_dir` | `str` | None | Directory for storing downloaded models and datasets. |
+ | `--seed` | `int` | None | Seed for reproducible training. |
+ | `--resolution` | `int` | 512 | Resolution for input images (must be divisible by 8). |
+ | `--train_batch_size` | `int` | 4 | Batch size per device for the training dataloader. |
+ | `--num_train_epochs` | `int` | 1 | Number of training epochs. |
+ | `--max_train_steps` | `int` | None | Total number of training steps. Overrides `num_train_epochs` if provided. |
+ | `--checkpointing_steps` | `int` | 500 | Save a checkpoint every X updates. |
+ | `--checkpoints_total_limit` | `int` | None | Maximum number of checkpoints to store. |
+ | `--resume_from_checkpoint` | `str` | None | Resume training from a previous checkpoint path or "latest". |
+ | `--gradient_accumulation_steps` | `int` | 1 | Number of update steps to accumulate before a backward pass. |
+ | `--gradient_checkpointing` | `flag` | False | Enable gradient checkpointing to save memory at the cost of slower backward passes. |
+ | `--learning_rate` | `float` | 5e-6 | Initial learning rate after warmup. |
+ | `--scale_lr` | `flag` | False | Scale learning rate by number of GPUs, gradient accumulation steps, and batch size. |
+ | `--lr_scheduler` | `str` | "constant" | Learning rate scheduler type: ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]. |
+ | `--lr_warmup_steps` | `int` | 500 | Number of steps for learning rate warmup. |
+ | `--lr_num_cycles` | `int` | 1 | Number of hard resets for cosine_with_restarts scheduler. |
+ | `--lr_power` | `float` | 1.0 | Power factor for polynomial scheduler. |
+ | `--use_8bit_adam` | `flag` | False | Use 8-bit Adam optimizer from bitsandbytes for lower memory usage. |
+ | `--dataloader_num_workers` | `int` | 0 | Number of subprocesses for data loading (0 means main process). |
+ | `--adam_beta1` | `float` | 0.9 | Beta1 parameter for Adam optimizer. |
+ | `--adam_beta2` | `float` | 0.999 | Beta2 parameter for Adam optimizer. |
+ | `--adam_weight_decay` | `float` | 1e-2 | Weight decay for Adam optimizer. |
+ | `--adam_epsilon` | `float` | 1e-08 | Epsilon value for Adam optimizer. |
+ | `--max_grad_norm` | `float` | 1.0 | Maximum gradient norm for clipping. |
+ | `--push_to_hub` | `flag` | False | Push the model to the HuggingFace Hub. |
+ | `--hub_token` | `str` | None | Token for pushing to the HuggingFace Hub. |
+ | `--hub_model_id` | `str` | None | Repository name for syncing with `output_dir`. |
+ | `--logging_dir` | `str` | "logs" | TensorBoard log directory. |
+ | `--allow_tf32` | `flag` | False | Allow TF32 on Ampere GPUs for faster training. |
+ | `--report_to` | `str` | "tensorboard" | Integration for logging: ["tensorboard", "wandb", "comet_ml", "all"]. |
+ | `--mixed_precision` | `str` | None | Mixed precision training: ["no", "fp16", "bf16"]. |
+ | `--enable_xformers_memory_efficient_attention` | `flag` | False | Enable xformers for memory-efficient attention. |
+ | `--set_grads_to_none` | `flag` | False | Set gradients to None instead of zero to save memory. |
+ | `--dataset_name` | `str` | None | Name of the dataset from HuggingFace Hub or local path. |
+ | `--dataset_config_name` | `str` | None | Dataset configuration name. |
+ | `--train_data_dir` | `str` | None | Directory containing training data with `metadata.jsonl`. |
+ | `--image_column` | `str` | "image" | Dataset column for target images. |
+ | `--conditioning_image_column` | `str` | "conditioning_image" | Dataset column for ControlNet conditioning images. |
+ | `--caption_column` | `str` | "text" | Dataset column for captions. |
+ | `--max_train_samples` | `int` | None | Truncate training examples to this number for debugging or quicker training. |
+ | `--proportion_empty_prompts` | `float` | 0 | Proportion of prompts to replace with empty strings (0 to 1). |
+ | `--validation_prompt` | `str` | None | Prompts for validation, evaluated every `validation_steps`. |
+ | `--validation_image` | `str` | None | Paths to ControlNet conditioning images for validation. |
+ | `--num_validation_images` | `int` | 4 | Number of images generated per validation prompt-image pair. |
+ | `--validation_steps` | `int` | 100 | Run validation every X steps. |
+ | `--tracker_project_name` | `str` | "train_controlnet" | Project name for Accelerator trackers. |
+
+ ## Usage Example
+
+ To train a ControlNet model, run the following command:
+
+ ```bash
+ python src/controlnet_image_generator/train.py \
+     --pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1" \
+     --dataset_name="huggingface/controlnet-dataset" \
+     --output_dir="controlnet_output" \
+     --resolution=512 \
+     --train_batch_size=4 \
+     --num_train_epochs=3 \
+     --learning_rate=1e-5 \
+     --validation_prompt="A cat sitting on a chair" \
+     --validation_image="path/to/conditioning_image.png" \
+     --push_to_hub \
+     --hub_model_id="your-username/controlnet-model"
+ ```
+
+ This command trains a ControlNet model using the Stable Diffusion 2.1 pretrained model and the specified dataset, and pushes the trained model to the HuggingFace Hub via `--push_to_hub`.
+
+ ## Notes
+ - Ensure the dataset contains columns for target images, conditioning images, and captions, as specified by `image_column`, `conditioning_image_column`, and `caption_column`.
+ - The resolution must be divisible by 8 to ensure compatibility with the VAE and ControlNet encoder.
+ - Mixed precision training (`fp16` or `bf16`) can reduce memory usage but requires compatible hardware; a memory-saving example follows these notes.
+ - Validation images and prompts must be provided in matching quantities, or as single values to be reused across pairs.
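+
+ As an illustration of combining the memory-related flags from the table above — a hedged example, not a prescribed configuration (effectiveness depends on your GPU and on working `bitsandbytes` installs):
+
+ ```bash
+ python src/controlnet_image_generator/train.py \
+     --pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1" \
+     --dataset_name="HighCWu/open_pose_controlnet_subset" \
+     --output_dir="controlnet_output" \
+     --mixed_precision="fp16" \
+     --gradient_checkpointing \
+     --use_8bit_adam \
+     --set_grads_to_none \
+     --gradient_accumulation_steps=4 \
+     --train_batch_size=1
+ ```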
+
+ For further details, refer to the source scripts or the HuggingFace Diffusers documentation.
notebooks/SD-2.1-Openpose-ControlNet.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
requirements/requirements.txt ADDED
@@ -0,0 +1,7 @@
+ huggingface-hub>=0.33.1
+ bitsandbytes>=0.46.0
+ diffusers>=0.34.0
+ peft>=0.17.0
+ controlnet-aux>=0.0.10
+ accelerate>=1.7.0
+ gradio>=5.39.0
requirements/requirements_compatible.txt ADDED
@@ -0,0 +1,7 @@
+ huggingface-hub==0.34.1
+ bitsandbytes==0.46.0
+ diffusers==0.34.0
+ peft==0.17.0
+ controlnet-aux==0.0.10
+ accelerate==1.7.0
+ gradio==5.39.0
scripts/download_ckpts.py ADDED
@@ -0,0 +1,58 @@
+ import argparse
+ import os
+ import yaml
+ from huggingface_hub import hf_hub_download, list_repo_files
+
+ def load_config(config_path):
+     """Load the model checkpoint configuration from a YAML file."""
+     with open(config_path, 'r') as file:
+         return yaml.safe_load(file)
+
+ def download_model(model_config):
+     """Download one model's files according to its allow/deny patterns."""
+     model_id = model_config["model_id"]
+     local_dir = model_config["local_dir"]
+
+     # A null local_dir marks models that are not downloaded by this script.
+     if local_dir is None:
+         print(f"Skipping download for {model_id}: local_dir is null")
+         return
+
+     os.makedirs(local_dir, exist_ok=True)
+
+     allow_patterns = model_config.get("allow", [])
+     deny_patterns = model_config.get("deny", [])
+
+     if allow_patterns:
+         # Download only the explicitly allowed files.
+         for file in allow_patterns:
+             hf_hub_download(
+                 repo_id=model_id,
+                 filename=file,
+                 local_dir=local_dir,
+                 local_dir_use_symlinks=False
+             )
+     else:
+         print(f"No allow patterns specified for {model_id}. Attempting to download all files except those in deny list.")
+         # Enumerate the repository and skip any file matching a deny pattern.
+         repo_files = list_repo_files(repo_id=model_id)
+         for file in repo_files:
+             if not any(deny_pattern in file for deny_pattern in deny_patterns):
+                 hf_hub_download(
+                     repo_id=model_id,
+                     filename=file,
+                     local_dir=local_dir,
+                     local_dir_use_symlinks=False
+                 )
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Download model checkpoints from Hugging Face Hub")
+     parser.add_argument(
+         "--config_path",
+         type=str,
+         default="configs/model_ckpts.yaml",
+         help="Path to the configuration YAML file"
+     )
+
+     args = parser.parse_args()
+
+     config = load_config(args.config_path)
+
+     for model_config in config:
+         print(f"Processing {model_config['model_id']} (local_dir: {model_config['local_dir']})")
+         download_model(model_config)
scripts/download_datasets.py ADDED
@@ -0,0 +1,48 @@
+ import argparse
+ import yaml
+ from datasets import load_dataset
+
+
+ def load_config(config_path):
+     with open(config_path, 'r') as file:
+         return yaml.safe_load(file)
+
+
+ def download_huggingface_dataset(config):
+     # Get dataset details from config
+     dataset_name = config['dataset_name']
+     local_dir = config['local_dir']
+
+     # Split dataset name into user_name and model_hub_name
+     user_name, model_hub_name = dataset_name.split('/')
+
+     # Login using e.g. `huggingface-cli login` to access this dataset
+     ds = load_dataset(dataset_name, cache_dir=local_dir)
+
+     # Print information for verification
+     print(f"User Name: {user_name}")
+     print(f"Model Hub Name: {model_hub_name}")
+     print(f"Dataset saved to: {local_dir}")
+     print(f"Dataset info: {ds}")
+
+
+ if __name__ == "__main__":
+     # Set up argument parser
+     parser = argparse.ArgumentParser(description="Download dataset from Hugging Face")
+     parser.add_argument('--config_path',
+                         type=str,
+                         default='configs/datasets_info.yaml',
+                         help='Path to the dataset configuration YAML file')
+
+     args = parser.parse_args()
+
+     # Load configuration from YAML file
+     configs = load_config(args.config_path)
+
+     # Iterate through the list of configurations
+     for config in configs:
+         # Download dataset if platform is HuggingFace
+         if config['platform'] == 'HuggingFace':
+             download_huggingface_dataset(config)
+         else:
+             print(f"Unsupported platform: {config['platform']}")
scripts/setup_third_party.py ADDED
@@ -0,0 +1,38 @@
+ import os
+ import subprocess
+ import sys
+ import argparse
+
+ def setup_diffusers(target_dir):
+     # Define paths
+     diffusers_dir = os.path.join(target_dir, "diffusers")
+
+     # Create the target directory if it doesn't exist
+     os.makedirs(target_dir, exist_ok=True)
+
+     # Check if diffusers already exists in the target directory
+     if os.path.exists(diffusers_dir):
+         print(f"Diffusers already exists in {target_dir}. Skipping clone.")
+         return
+
+     # Clone the diffusers repository
+     subprocess.run(["git", "clone", "https://github.com/huggingface/diffusers"],
+                    cwd=target_dir, check=True)
+
+     # Change to the diffusers directory and install it in editable mode,
+     # using the current interpreter's pip to target the right environment
+     original_dir = os.getcwd()
+     os.chdir(diffusers_dir)
+     try:
+         subprocess.run([sys.executable, "-m", "pip", "install", "-e", "."], check=True)
+     finally:
+         os.chdir(original_dir)
+
+     print(f"Diffusers successfully cloned and installed to {diffusers_dir}")
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Setup diffusers in a specified directory.")
+     parser.add_argument("--target-dir", type=str, default="src/third_party",
+                         help="Target directory to clone diffusers into (default: src/third_party)")
+
+     args = parser.parse_args()
+     setup_diffusers(args.target_dir)