diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..0ce3ee943e9e53560cb8db88440c8f9b30003666 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/a_man_is_doing_yoga_in_a_serene_park_0.png filter=lfs diff=lfs merge=lfs -text +apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/a_man_is_galloping_on_a_horse_0.png filter=lfs diff=lfs merge=lfs -text +apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/ride_bike.jpg filter=lfs diff=lfs merge=lfs -text +apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/a_woman_is_holding_a_baseball_bat_in_her_hand_0.png filter=lfs diff=lfs merge=lfs -text +apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/tennis.jpg filter=lfs diff=lfs merge=lfs -text +apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/a_woman_raises_a_katana_0.png filter=lfs diff=lfs merge=lfs -text +assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/a_man_is_doing_yoga_in_a_serene_park_0.png filter=lfs diff=lfs merge=lfs -text +assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/a_man_is_galloping_on_a_horse_0.png filter=lfs diff=lfs merge=lfs -text +assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/ride_bike.jpg filter=lfs diff=lfs merge=lfs -text +assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/a_woman_is_holding_a_baseball_bat_in_her_hand_0.png filter=lfs diff=lfs merge=lfs -text +assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/tennis.jpg filter=lfs diff=lfs merge=lfs -text +assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/a_woman_raises_a_katana_0.png filter=lfs diff=lfs merge=lfs -text +tests/test_data/a_man_is_doing_yoga_in_a_serene_park_0.png filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..8f14f9d876e19bb68bb9d09215203eca58b90f92 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Danh Tran + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/apps/gradio_app.py b/apps/gradio_app.py new file mode 100644 index 0000000000000000000000000000000000000000..5ca104006d7af1362a590cff9bd9c827a842e0b6 --- /dev/null +++ b/apps/gradio_app.py @@ -0,0 +1,187 @@ +import os +import subprocess +import gradio as gr +import random +from gradio_app.inference import run_inference +from gradio_app.examples import load_examples, select_example +from gradio_app.project_info import ( + NAME, + CONTENT_DESCRIPTION, + CONTENT_IN_1, + CONTENT_OUT_1 +) + +def run_setup_script(): + setup_script = os.path.join(os.path.dirname(__file__), "gradio_app", "setup_scripts.py") + try: + result = subprocess.run(["python", setup_script], capture_output=True, text=True, check=True) + return result.stdout + except subprocess.CalledProcessError as e: + print(f"Setup script failed with error: {e.stderr}") + return f"Setup script failed: {e.stderr}" + +def stop_app(): + """Function to stop the Gradio app.""" + try: + gr.Interface.close_all() # Attempt to close all running Gradio interfaces + return "Application stopped successfully." + except Exception as e: + return f"Error stopping application: {str(e)}" + +def create_gui(): + try: + custom_css = open("apps/gradio_app/static/style.css").read() + except FileNotFoundError: + print("Error: style.css not found at gradio_app/static/style.css") + custom_css = "" # Fallback to empty CSS if file is missing + + with gr.Blocks(css=custom_css) as demo: + gr.Markdown(NAME) + gr.HTML(CONTENT_DESCRIPTION) + gr.HTML(CONTENT_IN_1) + + with gr.Row(): + with gr.Column(scale=2): + input_image = gr.Image(type="filepath", label="Input Image") + prompt = gr.Textbox( + label="Prompt", + value="a man is doing yoga" + ) + negative_prompt = gr.Textbox( + label="Negative Prompt", + value="monochrome, lowres, bad anatomy, worst quality, low quality" + ) + + with gr.Row(): + width = gr.Slider( + minimum=256, + maximum=1024, + value=512, + step=64, + label="Width" + ) + height = gr.Slider( + minimum=256, + maximum=1024, + value=512, + step=64, + label="Height" + ) + + with gr.Accordion("Advanced Settings", open=False): + num_steps = gr.Slider( + minimum=1, + maximum=100, + value=30, + step=1, + label="Number of Inference Steps" + ) + use_random_seed = gr.Checkbox(label="Use Random Seed", value=False) + seed = gr.Slider( + minimum=0, + maximum=2**32 - 1, + value=42, + step=1, + label="Random Seed", + visible=True + ) + + guidance_scale = gr.Slider( + minimum=1.0, + maximum=20.0, + value=7.5, + step=0.1, + label="Guidance Scale" + ) + controlnet_conditioning_scale = gr.Slider( + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.1, + label="ControlNet Conditioning Scale" + ) + + with gr.Column(scale=3): + output_images = gr.Image(label="Generated Images") + output_message = gr.Textbox(label="Status") + + submit_button = gr.Button("Generate Images", elem_classes="submit-btn") + stop_button = gr.Button("Stop Application", elem_classes="stop-btn") + + def update_seed_visibility(use_random): + return gr.update(visible=not use_random) + + use_random_seed.change( + fn=update_seed_visibility, + inputs=use_random_seed, + outputs=seed + ) + + # Load examples + examples_data = load_examples(os.path.join("apps", "gradio_app", + "assets", "examples", "Stable-Diffusion-2.1-Openpose-ControlNet")) + examples_component = gr.Examples( + examples=examples_data, + inputs=[ + input_image, + prompt, + negative_prompt, + output_images, + num_steps, + seed, + width, + height, + guidance_scale, + controlnet_conditioning_scale, + use_random_seed + ], + outputs=[ + 
input_image, + prompt, + negative_prompt, + output_images, + num_steps, + seed, + width, + height, + guidance_scale, + controlnet_conditioning_scale, + use_random_seed, + output_message + ], + fn=select_example, + cache_examples=False, + label="Examples: Yoga Poses" + ) + + submit_button.click( + fn=run_inference, + inputs=[ + input_image, + prompt, + negative_prompt, + num_steps, + seed, + width, + height, + guidance_scale, + controlnet_conditioning_scale, + use_random_seed, + ], + outputs=[output_images, output_message] + ) + + stop_button.click( + fn=stop_app, + inputs=[], + outputs=[output_message] + ) + + gr.HTML(CONTENT_OUT_1) + + return demo + +if __name__ == "__main__": + run_setup_script() + demo = create_gui() + demo.launch(share=True) \ No newline at end of file diff --git a/apps/gradio_app/__init__.py b/apps/gradio_app/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/a_man_is_doing_yoga_in_a_serene_park_0.png b/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/a_man_is_doing_yoga_in_a_serene_park_0.png new file mode 100644 index 0000000000000000000000000000000000000000..9e8f4643b712061b6d5db7286b6a32fa8cf68f87 --- /dev/null +++ b/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/a_man_is_doing_yoga_in_a_serene_park_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3dc2b7efb61afd2d6ceda1b32ec9792a5b07f3ac3d7a96d7acdd2102ddb957b7 +size 367280 diff --git a/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/config.json b/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/config.json new file mode 100644 index 0000000000000000000000000000000000000000..fa4aa5b94bc127d1e6998d0ba4c22eb309438965 --- /dev/null +++ b/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/config.json @@ -0,0 +1,12 @@ +{ + "input_image": "yoga.jpg", + "output_image": "a_man_is_doing_yoga_in_a_serene_park_0.png", + "prompt": "A man is doing yoga in a serene park.", + "negative_prompt": "monochrome, lowres, bad anatomy, ugly, deformed face", + "num_steps": 50, + "seed": 100, + "width": 512, + "height": 512, + "guidance_scale": 5.5, + "controlnet_conditioning_scale": 0.6 +} diff --git a/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/yoga.jpg b/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/yoga.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c5f35baf0515a9b66e6388d5c44bb337d6f9366c Binary files /dev/null and b/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/yoga.jpg differ diff --git a/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/a_man_is_galloping_on_a_horse_0.png b/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/a_man_is_galloping_on_a_horse_0.png new file mode 100644 index 0000000000000000000000000000000000000000..4676e42e8654a5401e9cd5c53812e0cda57f3d61 --- /dev/null +++ b/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/a_man_is_galloping_on_a_horse_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2e83cc3b007c2303e276b3ac60a8fa930877e584e3534f12e1441ec83ed9e9fd +size 1112085 diff --git a/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/config.json 
b/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/config.json new file mode 100644 index 0000000000000000000000000000000000000000..b8f49878237f190f7a871a49c1e3f55ab642145a --- /dev/null +++ b/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/config.json @@ -0,0 +1,12 @@ +{ + "input_image": "ride_bike.jpg", + "output_image": "a_man_is_galloping_on_a_horse_0.png", + "prompt": "A man is galloping on a horse.", + "negative_prompt": "monochrome, lowres, bad anatomy, ugly, deformed face", + "num_steps": 100, + "seed": 56, + "width": 1080, + "height": 720, + "guidance_scale": 9.5, + "controlnet_conditioning_scale": 0.5 +} diff --git a/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/ride_bike.jpg b/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/ride_bike.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a9c7f76606bc8d2707e6ba30dbacf31ef31bf096 --- /dev/null +++ b/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/ride_bike.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:76310cad16fcf71097c9660d46a95ced0992d48bd92469e83fd25ee59f015998 +size 163547 diff --git a/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/a_woman_is_holding_a_baseball_bat_in_her_hand_0.png b/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/a_woman_is_holding_a_baseball_bat_in_her_hand_0.png new file mode 100644 index 0000000000000000000000000000000000000000..d02e377e0572ba785480967ae2c874f949c27578 --- /dev/null +++ b/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/a_woman_is_holding_a_baseball_bat_in_her_hand_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a048958e0ed28806ecb7c9834f91b07a464b73cd641fa19b03f39ff542986530 +size 1267884 diff --git a/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/config.json b/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/config.json new file mode 100644 index 0000000000000000000000000000000000000000..bbd04e660f98c1f4ae56dd8b145298555adecb99 --- /dev/null +++ b/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/config.json @@ -0,0 +1,12 @@ +{ + "input_image": "tennis.jpg", + "output_image": "a_woman_is_holding_a_baseball_bat_in_her_hand_0.png", + "prompt": "A woman is holding a baseball bat in her hand.", + "negative_prompt": "monochrome, lowres, bad anatomy, ugly, deformed face", + "num_steps": 100, + "seed": 765, + "width": 990, + "height": 720, + "guidance_scale": 6.5, + "controlnet_conditioning_scale": 0.7 +} diff --git a/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/tennis.jpg b/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/tennis.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5f1ad5e1632ea54f5e509793ce506c69aa419666 --- /dev/null +++ b/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/tennis.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:259845edb5c365bccb33f9207630d829bb5a839e72bf7d0326f11ae4862694fa +size 5611831 diff --git a/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/a_woman_raises_a_katana_0.png b/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/a_woman_raises_a_katana_0.png new file mode 100644 index 
0000000000000000000000000000000000000000..d5c40b562fb7091444d2762646ac5d87e1997863 --- /dev/null +++ b/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/a_woman_raises_a_katana_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:deaa70aba05ab58ea0f9bd16512c6dcc7e0951559037779063045b7c035342f8 +size 440696 diff --git a/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/config.json b/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/config.json new file mode 100644 index 0000000000000000000000000000000000000000..07e73e90d221ce5bd71925bb50bfbc09261e57f7 --- /dev/null +++ b/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/config.json @@ -0,0 +1,12 @@ +{ + "input_image": "man_and_sword.jpg", + "output_image": "a_woman_raises_a_katana_0.png", + "prompt": "A woman raises a katana.", + "negative_prompt": "body elongated, fragmentation, many hands, ugly, deformed face", + "num_steps": 50, + "seed": 78, + "width": 540, + "height": 512, + "guidance_scale": 6.5, + "controlnet_conditioning_scale": 0.8 +} diff --git a/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/man_and_sword.jpg b/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/man_and_sword.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e4212460e2267e6b61768f268e215e7a288247fa Binary files /dev/null and b/apps/gradio_app/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/man_and_sword.jpg differ diff --git a/apps/gradio_app/examples.py b/apps/gradio_app/examples.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7eb9d967ef0eda6df52d31e3bc356d58f77cdf --- /dev/null +++ b/apps/gradio_app/examples.py @@ -0,0 +1,99 @@ +import os +import json +from PIL import Image +import gradio as gr + +def load_examples(examples_base_path=os.path.join("apps", "gradio_app", + "assets", "examples", "Stable-Diffusion-2.1-Openpose-ControlNet")): + + """Load example configurations and input images from the Stable-Diffusion-2.1-Openpose-ControlNet directory.""" + examples = [] + + # Iterate through example folders (e.g., '1', '2', '3', '4') + for folder in os.listdir(examples_base_path): + folder_path = os.path.join(examples_base_path, folder) + config_path = os.path.join(folder_path, "config.json") + + if os.path.exists(config_path): + try: + with open(config_path, 'r') as f: + config = json.load(f) + + # Extract configuration fields + input_filename = config["input_image"] + output_filename = config["output_image"] + prompt = config.get("prompt", "a man is doing yoga") + negative_prompt = config.get("negative_prompt", "monochrome, lowres, bad anatomy, worst quality, low quality") + num_steps = config.get("num_steps", 30) + seed = config.get("seed", 42) + width = config.get("width", 512) + height = config.get("height", 512) + guidance_scale = config.get("guidance_scale", 7.5) + controlnet_conditioning_scale = config.get("controlnet_conditioning_scale", 1.0) + + # Construct absolute path for input image + input_image_path = os.path.join(folder_path, input_filename) + output_image_path = os.path.join(folder_path, output_filename) + # Check if input image exists + if os.path.exists(input_image_path): + input_image_data = Image.open(input_image_path) + output_image_data = Image.open(output_image_path) + # Append example data in the order expected by Gradio inputs + examples.append([ + input_image_data, # Input image + prompt, + 
negative_prompt, + output_image_data, + num_steps, + seed, + width, + height, + guidance_scale, + controlnet_conditioning_scale, + False # use_random_seed, hardcoded as per original gr.Examples + ]) + else: + print(f"Input image not found at {input_image_path}") + + except json.JSONDecodeError as e: + print(f"Error decoding JSON from {config_path}: {str(e)}") + except Exception as e: + print(f"Error processing example in {folder_path}: {str(e)}") + + return examples + +def select_example(evt: gr.SelectData, examples_data): + """Handle selection of an example to populate Gradio inputs.""" + example_index = evt.index + # Extract example data + # input_image_data, prompt, negative_prompt, output_image_data, num_steps, seed, width, height, guidance_scale, controlnet_conditioning_scale, use_random_seed = examples_data[example_index] + ( + input_image_data, + prompt, + negative_prompt, + output_image_data, + num_steps, + seed, + width, + height, + guidance_scale, + controlnet_conditioning_scale, + use_random_seed, + ) = examples_data[example_index] + + + # Return values to update Gradio interface inputs and output message + return ( + input_image_data, # Input image + prompt, # Prompt + negative_prompt, # Negative prompt + output_image_data, # Output image + num_steps, # Number of inference steps + seed, # Random seed + width, # Width + height, # Height + guidance_scale, # Guidance scale + controlnet_conditioning_scale, # ControlNet conditioning scale + use_random_seed, # Use random seed + f"Loaded example {example_index + 1} with prompt: {prompt}" # Output message + ) \ No newline at end of file diff --git a/apps/gradio_app/inference.py b/apps/gradio_app/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..78ac0f15035354fce30413dd283622f7428fd646 --- /dev/null +++ b/apps/gradio_app/inference.py @@ -0,0 +1,45 @@ +import random +import os +import sys + +# Add the project root directory to the Python path +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) + +from src.controlnet_image_generator.infer import infer + + +def run_inference( + input_image, + prompt, + negative_prompt, + num_steps, + seed, + width, + height, + guidance_scale, + controlnet_conditioning_scale, + use_random_seed=False, +): + config_path = "configs/model_ckpts.yaml" + + if use_random_seed: + seed = random.randint(0, 2 ** 32) + + try: + result = infer( + config_path=config_path, + input_image=input_image, + image_url=None, + prompt=prompt, + negative_prompt=negative_prompt, + num_steps=num_steps, + seed=seed, + width=width, + height=height, + guidance_scale=guidance_scale, + controlnet_conditioning_scale=float(controlnet_conditioning_scale), + ) + result = list(result)[0] + return result, "Inference completed successfully" + except Exception as e: + return [], f"Error during inference: {str(e)}" \ No newline at end of file diff --git a/apps/gradio_app/project_info.py b/apps/gradio_app/project_info.py new file mode 100644 index 0000000000000000000000000000000000000000..dfb0a8e926b69403efa94cc14f3aaf9fa82698cf --- /dev/null +++ b/apps/gradio_app/project_info.py @@ -0,0 +1,37 @@ +NAME = """ +# ControlNet Image Generator 🖌️ +""".strip() + +CONTENT_DESCRIPTION = """ +

+ControlNet ⚡️ boosts Stable Diffusion with sharp, innovative image generation control 🖌️
+""".strip()
+
+# CONTENT_IN_1 = """
+# Transforms low-res anime images into sharp, vibrant HD visuals, enhancing textures and details for artwork and games.
+# """.strip()
+
+CONTENT_IN_1 = """
+For more information, you can check out my GitHub repository and HuggingFace Model Hub:
+Source code: GitHub Repo, Model Hub: HuggingFace Model.
+""".strip()
+
+CONTENT_OUT_1 = """
+This project is built using code from Real-ESRGAN.
+""".strip() + diff --git a/apps/gradio_app/setup_scripts.py b/apps/gradio_app/setup_scripts.py new file mode 100644 index 0000000000000000000000000000000000000000..013a7213fd17420d393bde4b92aa8858a2a167ce --- /dev/null +++ b/apps/gradio_app/setup_scripts.py @@ -0,0 +1,59 @@ +import subprocess +import sys +import os + +def run_script(script_path, args=None): + """ + Run a Python script using subprocess with optional arguments and handle errors. + Returns True if successful, False otherwise. + """ + if not os.path.isfile(script_path): + print(f"Script not found: {script_path}") + return False + + try: + command = [sys.executable, script_path] + if args: + command.extend(args) + result = subprocess.run( + command, + check=True, + text=True, + capture_output=True + ) + print(f"Successfully executed {script_path}") + print(result.stdout) + return True + except subprocess.CalledProcessError as e: + print(f"Error executing {script_path}:") + print(e.stderr) + return False + except Exception as e: + print(f"Unexpected error executing {script_path}: {str(e)}") + return False + +def main(): + """ + Main function to execute download_ckpts.py with proper error handling. + """ + scripts_dir = "scripts" + scripts = [ + { + "path": os.path.join(scripts_dir, "download_ckpts.py"), + "args": [] # Empty list for args to avoid NoneType issues + } + ] + + for script in scripts: + script_path = script["path"] + args = script.get("args", []) # Safely get args with default empty list + print(f"Starting execution of {script_path}{' with args: ' + ' '.join(args) if args else ''}\n") + + if not run_script(script_path, args): + print(f"Stopping execution due to error in {script_path}") + sys.exit(1) + + print(f"Completed execution of {script_path}\n") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/apps/gradio_app/static/style.css b/apps/gradio_app/static/style.css new file mode 100644 index 0000000000000000000000000000000000000000..68f0d9f4ec0d617b5e76adea56472de5800f4656 --- /dev/null +++ b/apps/gradio_app/static/style.css @@ -0,0 +1,574 @@ +/* @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&display=swap'); */ +/* ─── palette ───────────────────────────────────────────── */ +body, .gradio-container { + font-family: 'Inter', sans-serif; + background: #FFFBF7; + color: #0F172A; +} +a { + color: #F97316; + text-decoration: none; + font-weight: 600; +} +a:hover { color: #C2410C; } +/* ─── headline ──────────────────────────────────────────── */ +#titlebar { + text-align: center; + margin-top: 2.4rem; + margin-bottom: .9rem; +} +/* ─── card look ─────────────────────────────────────────── */ +.gr-block, +.gr-box, +.gr-row, +#cite-wrapper { + border: 1px solid #F8C89B; + border-radius: 10px; + background: #fff; + box-shadow: 0 3px 6px rgba(0, 0, 0, .05); +} +.gr-gallery-item { background: #fff; } +/* ─── controls / inputs ─────────────────────────────────── */ +.gr-button-primary, +#copy-btn { + background: linear-gradient(90deg, #F97316 0%, #C2410C 100%); + border: none; + color: #fff; + border-radius: 6px; + font-weight: 600; + transition: transform .12s ease, box-shadow .12s ease; +} +.gr-button-primary:hover, +#copy-btn:hover { + transform: translateY(-2px); + box-shadow: 0 4px 12px rgba(249, 115, 22, .35); +} +.gr-dropdown input { + border: 1px solid #F9731699; +} +.preview img, +.preview canvas { object-fit: contain !important; } +/* ─── hero section ─────────────────────────────────────── */ +#hero-wrapper { text-align: center; } +#hero-badge { 
+ display: inline-block; + padding: .85rem 1.2rem; + border-radius: 8px; + background: #FFEAD2; + border: 1px solid #F9731655; + font-size: .95rem; + font-weight: 600; + margin-bottom: .5rem; +} +#hero-links { + font-size: .95rem; + font-weight: 600; + margin-bottom: 1.6rem; +} +#hero-links img { + height: 22px; + vertical-align: middle; + margin-left: .55rem; +} +/* ─── score area ───────────────────────────────────────── */ +#score-area { + text-align: center; +} +.title-container { + display: flex; + align-items: center; + gap: 12px; + justify-content: center; + margin-bottom: 10px; + text-align: center; +} +.match-badge { + display: inline-block; + padding: .35rem .9rem; + border-radius: 9999px; + font-weight: 600; + font-size: 1.25rem; +} +/* ─── citation card ────────────────────────────────────── */ +#cite-wrapper { + position: relative; + padding: .9rem 1rem; + margin-top: 2rem; +} +#cite-wrapper code { + font-family: SFMono-Regular, Consolas, monospace; + font-size: .84rem; + white-space: pre-wrap; + color: #0F172A; +} +#copy-btn { + position: absolute; + top: .55rem; + right: .6rem; + padding: .18rem .7rem; + font-size: .72rem; + line-height: 1; +} +/* ─── dark mode ────────────────────────────────────── */ +.dark body, +.dark .gradio-container { + background-color: #332a22; + color: #e5e7eb; +} +.dark .gr-block, +.dark .gr-box, +.dark .gr-row { + background-color: #332a22; + border: 1px solid #4b5563; +} +.dark .gr-dropdown input { + background-color: #332a22; + color: #f1f5f9; + border: 1px solid #F97316aa; +} +.dark #hero-badge { + background: #334155; + border: 1px solid #F9731655; + color: #fefefe; +} +.dark #cite-wrapper { + background-color: #473f38; +} +.dark #bibtex { + color: #f8fafc !important; +} +.dark .card { + background-color: #473f38; +} +/* ─── switch logo for light/dark theme ─────────────── */ +.logo-dark { display: none; } +.dark .logo-light { display: none; } +.dark .logo-dark { display: inline; } + +/* https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&display=swap */ + +/* cyrillic-ext */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 400; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2JL7SUc.woff2) format('woff2'); + unicode-range: U+0460-052F, U+1C80-1C8A, U+20B4, U+2DE0-2DFF, U+A640-A69F, U+FE2E-FE2F; +} +/* cyrillic */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 400; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa0ZL7SUc.woff2) format('woff2'); + unicode-range: U+0301, U+0400-045F, U+0490-0491, U+04B0-04B1, U+2116; +} +/* greek-ext */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 400; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2ZL7SUc.woff2) format('woff2'); + unicode-range: U+1F00-1FFF; +} +/* greek */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 400; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1pL7SUc.woff2) format('woff2'); + unicode-range: U+0370-0377, U+037A-037F, U+0384-038A, U+038C, U+038E-03A1, U+03A3-03FF; +} +/* vietnamese */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 400; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2pL7SUc.woff2) format('woff2'); + unicode-range: U+0102-0103, U+0110-0111, 
U+0128-0129, U+0168-0169, U+01A0-01A1, U+01AF-01B0, U+0300-0301, U+0303-0304, U+0308-0309, U+0323, U+0329, U+1EA0-1EF9, U+20AB; +} +/* latin-ext */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 400; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa25L7SUc.woff2) format('woff2'); + unicode-range: U+0100-02BA, U+02BD-02C5, U+02C7-02CC, U+02CE-02D7, U+02DD-02FF, U+0304, U+0308, U+0329, U+1D00-1DBF, U+1E00-1E9F, U+1EF2-1EFF, U+2020, U+20A0-20AB, U+20AD-20C0, U+2113, U+2C60-2C7F, U+A720-A7FF; +} +/* latin */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 400; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1ZL7.woff2) format('woff2'); + unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+0304, U+0308, U+0329, U+2000-206F, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; +} +/* cyrillic-ext */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 500; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2JL7SUc.woff2) format('woff2'); + unicode-range: U+0460-052F, U+1C80-1C8A, U+20B4, U+2DE0-2DFF, U+A640-A69F, U+FE2E-FE2F; +} +/* cyrillic */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 500; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa0ZL7SUc.woff2) format('woff2'); + unicode-range: U+0301, U+0400-045F, U+0490-0491, U+04B0-04B1, U+2116; +} +/* greek-ext */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 500; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2ZL7SUc.woff2) format('woff2'); + unicode-range: U+1F00-1FFF; +} +/* greek */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 500; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1pL7SUc.woff2) format('woff2'); + unicode-range: U+0370-0377, U+037A-037F, U+0384-038A, U+038C, U+038E-03A1, U+03A3-03FF; +} +/* vietnamese */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 500; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2pL7SUc.woff2) format('woff2'); + unicode-range: U+0102-0103, U+0110-0111, U+0128-0129, U+0168-0169, U+01A0-01A1, U+01AF-01B0, U+0300-0301, U+0303-0304, U+0308-0309, U+0323, U+0329, U+1EA0-1EF9, U+20AB; +} +/* latin-ext */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 500; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa25L7SUc.woff2) format('woff2'); + unicode-range: U+0100-02BA, U+02BD-02C5, U+02C7-02CC, U+02CE-02D7, U+02DD-02FF, U+0304, U+0308, U+0329, U+1D00-1DBF, U+1E00-1E9F, U+1EF2-1EFF, U+2020, U+20A0-20AB, U+20AD-20C0, U+2113, U+2C60-2C7F, U+A720-A7FF; +} +/* latin */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 500; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1ZL7.woff2) format('woff2'); + unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+0304, U+0308, U+0329, U+2000-206F, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; +} +/* cyrillic-ext */ +@font-face { + font-family: 'Inter'; + 
font-style: normal; + font-weight: 600; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2JL7SUc.woff2) format('woff2'); + unicode-range: U+0460-052F, U+1C80-1C8A, U+20B4, U+2DE0-2DFF, U+A640-A69F, U+FE2E-FE2F; +} +/* cyrillic */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 600; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa0ZL7SUc.woff2) format('woff2'); + unicode-range: U+0301, U+0400-045F, U+0490-0491, U+04B0-04B1, U+2116; +} +/* greek-ext */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 600; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2ZL7SUc.woff2) format('woff2'); + unicode-range: U+1F00-1FFF; +} +/* greek */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 600; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1pL7SUc.woff2) format('woff2'); + unicode-range: U+0370-0377, U+037A-037F, U+0384-038A, U+038C, U+038E-03A1, U+03A3-03FF; +} +/* vietnamese */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 600; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2pL7SUc.woff2) format('woff2'); + unicode-range: U+0102-0103, U+0110-0111, U+0128-0129, U+0168-0169, U+01A0-01A1, U+01AF-01B0, U+0300-0301, U+0303-0304, U+0308-0309, U+0323, U+0329, U+1EA0-1EF9, U+20AB; +} +/* latin-ext */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 600; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa25L7SUc.woff2) format('woff2'); + unicode-range: U+0100-02BA, U+02BD-02C5, U+02C7-02CC, U+02CE-02D7, U+02DD-02FF, U+0304, U+0308, U+0329, U+1D00-1DBF, U+1E00-1E9F, U+1EF2-1EFF, U+2020, U+20A0-20AB, U+20AD-20C0, U+2113, U+2C60-2C7F, U+A720-A7FF; +} +/* latin */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 600; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1ZL7.woff2) format('woff2'); + unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+0304, U+0308, U+0329, U+2000-206F, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; +} +/* cyrillic-ext */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 700; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2JL7SUc.woff2) format('woff2'); + unicode-range: U+0460-052F, U+1C80-1C8A, U+20B4, U+2DE0-2DFF, U+A640-A69F, U+FE2E-FE2F; +} +/* cyrillic */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 700; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa0ZL7SUc.woff2) format('woff2'); + unicode-range: U+0301, U+0400-045F, U+0490-0491, U+04B0-04B1, U+2116; +} +/* greek-ext */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 700; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2ZL7SUc.woff2) format('woff2'); + unicode-range: U+1F00-1FFF; +} +/* greek */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 700; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1pL7SUc.woff2) 
format('woff2'); + unicode-range: U+0370-0377, U+037A-037F, U+0384-038A, U+038C, U+038E-03A1, U+03A3-03FF; +} +/* vietnamese */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 700; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2pL7SUc.woff2) format('woff2'); + unicode-range: U+0102-0103, U+0110-0111, U+0128-0129, U+0168-0169, U+01A0-01A1, U+01AF-01B0, U+0300-0301, U+0303-0304, U+0308-0309, U+0323, U+0329, U+1EA0-1EF9, U+20AB; +} +/* latin-ext */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 700; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa25L7SUc.woff2) format('woff2'); + unicode-range: U+0100-02BA, U+02BD-02C5, U+02C7-02CC, U+02CE-02D7, U+02DD-02FF, U+0304, U+0308, U+0329, U+1D00-1DBF, U+1E00-1E9F, U+1EF2-1EFF, U+2020, U+20A0-20AB, U+20AD-20C0, U+2113, U+2C60-2C7F, U+A720-A7FF; +} +/* latin */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 700; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1ZL7.woff2) format('woff2'); + unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+0304, U+0308, U+0329, U+2000-206F, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; +} +/* cyrillic-ext */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 800; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2JL7SUc.woff2) format('woff2'); + unicode-range: U+0460-052F, U+1C80-1C8A, U+20B4, U+2DE0-2DFF, U+A640-A69F, U+FE2E-FE2F; +} +/* cyrillic */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 800; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa0ZL7SUc.woff2) format('woff2'); + unicode-range: U+0301, U+0400-045F, U+0490-0491, U+04B0-04B1, U+2116; +} +/* greek-ext */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 800; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2ZL7SUc.woff2) format('woff2'); + unicode-range: U+1F00-1FFF; +} +/* greek */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 800; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1pL7SUc.woff2) format('woff2'); + unicode-range: U+0370-0377, U+037A-037F, U+0384-038A, U+038C, U+038E-03A1, U+03A3-03FF; +} +/* vietnamese */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 800; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa2pL7SUc.woff2) format('woff2'); + unicode-range: U+0102-0103, U+0110-0111, U+0128-0129, U+0168-0169, U+01A0-01A1, U+01AF-01B0, U+0300-0301, U+0303-0304, U+0308-0309, U+0323, U+0329, U+1EA0-1EF9, U+20AB; +} +/* latin-ext */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 800; + font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa25L7SUc.woff2) format('woff2'); + unicode-range: U+0100-02BA, U+02BD-02C5, U+02C7-02CC, U+02CE-02D7, U+02DD-02FF, U+0304, U+0308, U+0329, U+1D00-1DBF, U+1E00-1E9F, U+1EF2-1EFF, U+2020, U+20A0-20AB, U+20AD-20C0, U+2113, U+2C60-2C7F, U+A720-A7FF; +} +/* latin */ +@font-face { + font-family: 'Inter'; + font-style: normal; + font-weight: 800; + 
font-display: swap; + src: url(https://fonts.gstatic.com/s/inter/v19/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1ZL7.woff2) format('woff2'); + unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+0304, U+0308, U+0329, U+2000-206F, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; +} + +/* title_css */ +#title { + font-size: 2.6rem; + font-weight: 800; + margin: 0; + line-height: 1.25; + color: #0F172A; +} +/* brand class is passed in title parameter */ +#title .brand { + background: linear-gradient(90deg, #F97316 0%, #C2410C 90%); + -webkit-background-clip: text; + color: transparent; +} +.dark #title { + color: #f8fafc; +} +.title-container { + display: flex; + align-items: center; + gap: 12px; + justify-content: center; + margin-bottom: 10px; + text-align: center; +} + +/* Dark Mode */ +@media (prefers-color-scheme: dark) { + body { @extend .dark; } +} +/* Smaller size for input image */ +.input-image img { + max-width: 300px; + height: auto; +} +/* Larger size for output image */ +.output-image img { + max-width: 500px; + height: auto; +} + +/* Add styling for warning message */ +.warning-message { + color: red; + font-size: 14px; + margin-top: 5px; + display: block; +} +#warning-text { + min-height: 20px; /* Ensure space for warning */ +} +/*Components for Gradio App*/ +.quote-container { + border-left: 5px solid #007bff; + padding-left: 15px; + margin-bottom: 15px; + font-style: italic; +} +.attribution p { + margin: 10px 0; +} +.badge { + display: inline-block; + border-radius: 4px; + text-decoration: none; + font-size: 14px; + transition: background-color 0.3s; +} +.badge:hover { + background-color: #0056b3; +} +.badge img { + vertical-align: middle; + margin-right: 5px; +} +.source { + font-size: 14px; +} + +/* Start- Stop Buttons */ +.submit-btn { + background-color: #f97316; /* Green background */ + color: white; + font-weight: bold; + padding: 8px 16px; + border-radius: 6px; + border: none; + cursor: pointer; + transition: background-color 0.3s ease; +} + +.submit-btn:hover { + background-color: #f97416de; /* Darker green on hover */ +} + +.stop-btn { + background-color: grey; /* Red background */ + color: white; + font-weight: 600; + padding: 8px 16px; + border-radius: 6px; + border: none; + cursor: pointer; + transition: background-color 0.3s ease; +} + +.stop-btn:hover { + background-color: rgba(128, 128, 128, 0.858); /* Darker red on hover */ +} \ No newline at end of file diff --git a/apps/old-gradio_app.py b/apps/old-gradio_app.py new file mode 100644 index 0000000000000000000000000000000000000000..df3493b06ca46bcdeeed11a7568fc0d8711a1439 --- /dev/null +++ b/apps/old-gradio_app.py @@ -0,0 +1,177 @@ +import os +import sys +import subprocess +import gradio as gr +import torch +import random + +# Add the project root directory to the Python path +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from src.controlnet_image_generator.infer import infer + +def run_setup_script(): + setup_script = os.path.join(os.path.dirname(__file__), "gradio_app", "setup_scripts.py") + try: + result = subprocess.run(["python", setup_script], capture_output=True, text=True, check=True) + return result.stdout + except subprocess.CalledProcessError as e: + print(f"Setup script failed with error: {e.stderr}") + return f"Setup script failed: {e.stderr}" + +def run_inference( + input_image, + prompt, + negative_prompt, + num_steps, + seed, + width, + height, + guidance_scale, + controlnet_conditioning_scale, + 
use_random_seed=False, +): + config_path = "configs/model_ckpts.yaml" + + if use_random_seed: + seed = random.randint(0, 2 ** 32) + + try: + result = infer( + config_path=config_path, + input_image=input_image, + image_url=None, + prompt=prompt, + negative_prompt=negative_prompt, + num_steps=num_steps, + seed=seed, + width=width, + height=height, + guidance_scale=guidance_scale, + controlnet_conditioning_scale=float(controlnet_conditioning_scale), + ) + result = list(result)[0] + return result, "Inference completed successfully" + except Exception as e: + return [], f"Error during inference: {str(e)}" + +def stop_app(): + """Function to stop the Gradio app.""" + try: + gr.Interface.close_all() # Attempt to close all running Gradio interfaces + return "Application stopped successfully." + except Exception as e: + return f"Error stopping application: {str(e)}" + +def create_gui(): + cuscustom_css = open("apps/gradio_app/static/style.css").read() + with gr.Blocks(css=cuscustom_css) as demo: + gr.Markdown("# ControlNet Image Generation with Pose Detection") + + with gr.Row(): + with gr.Column(): + input_image = gr.Image(type="filepath", label="Input Image") + prompt = gr.Textbox( + label="Prompt", + value="a man is doing yoga" + ) + negative_prompt = gr.Textbox( + label="Negative Prompt", + value="monochrome, lowres, bad anatomy, worst quality, low quality" + ) + + with gr.Row(): + width = gr.Slider( + minimum=256, + maximum=1024, + value=512, + step=64, + label="Width" + ) + height = gr.Slider( + minimum=256, + maximum=1024, + value=512, + step=64, + label="Height" + ) + + with gr.Accordion("Advanced Settings", open=False): + num_steps = gr.Slider( + minimum=1, + maximum=100, + value=30, + step=1, + label="Number of Inference Steps" + ) + use_random_seed = gr.Checkbox(label="Use Random Seed", value=False) + seed = gr.Slider( + minimum=0, + maximum=2**32, + value=42, + step=1, + label="Random Seed", + visible=True + ) + + guidance_scale = gr.Slider( + minimum=1.0, + maximum=20.0, + value=7.5, + step=0.1, + label="Guidance Scale" + ) + controlnet_conditioning_scale = gr.Slider( + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.1, + label="ControlNet Conditioning Scale" + ) + + with gr.Column(): + output_images = gr.Image(label="Generated Images") + output_message = gr.Textbox(label="Status") + + # with gr.Row(): + submit_button = gr.Button("Generate Images", elem_classes="submit-btn") + stop_button = gr.Button("Stop Application", elem_classes="stop-btn") + + def update_seed_visibility(use_random): + return gr.update(visible=not use_random) + + use_random_seed.change( + fn=update_seed_visibility, + inputs=use_random_seed, + outputs=seed + ) + + submit_button.click( + fn=run_inference, + inputs=[ + input_image, + prompt, + negative_prompt, + num_steps, + seed, + width, + height, + guidance_scale, + controlnet_conditioning_scale, + use_random_seed, + ], + outputs=[output_images, output_message] + ) + + stop_button.click( + fn=stop_app, + inputs=[], + outputs=[output_message] + ) + + return demo + +if __name__ == "__main__": + run_setup_script() + demo = create_gui() + demo.launch(share=True) \ No newline at end of file diff --git a/assets/.gitkeep b/assets/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/a_man_is_doing_yoga_in_a_serene_park_0.png b/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/a_man_is_doing_yoga_in_a_serene_park_0.png new 
file mode 100644 index 0000000000000000000000000000000000000000..9e8f4643b712061b6d5db7286b6a32fa8cf68f87 --- /dev/null +++ b/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/a_man_is_doing_yoga_in_a_serene_park_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3dc2b7efb61afd2d6ceda1b32ec9792a5b07f3ac3d7a96d7acdd2102ddb957b7 +size 367280 diff --git a/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/config.json b/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/config.json new file mode 100644 index 0000000000000000000000000000000000000000..fa4aa5b94bc127d1e6998d0ba4c22eb309438965 --- /dev/null +++ b/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/config.json @@ -0,0 +1,12 @@ +{ + "input_image": "yoga.jpg", + "output_image": "a_man_is_doing_yoga_in_a_serene_park_0.png", + "prompt": "A man is doing yoga in a serene park.", + "negative_prompt": "monochrome, lowres, bad anatomy, ugly, deformed face", + "num_steps": 50, + "seed": 100, + "width": 512, + "height": 512, + "guidance_scale": 5.5, + "controlnet_conditioning_scale": 0.6 +} diff --git a/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/yoga.jpg b/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/yoga.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c5f35baf0515a9b66e6388d5c44bb337d6f9366c Binary files /dev/null and b/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/1/yoga.jpg differ diff --git a/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/a_man_is_galloping_on_a_horse_0.png b/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/a_man_is_galloping_on_a_horse_0.png new file mode 100644 index 0000000000000000000000000000000000000000..4676e42e8654a5401e9cd5c53812e0cda57f3d61 --- /dev/null +++ b/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/a_man_is_galloping_on_a_horse_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2e83cc3b007c2303e276b3ac60a8fa930877e584e3534f12e1441ec83ed9e9fd +size 1112085 diff --git a/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/config.json b/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/config.json new file mode 100644 index 0000000000000000000000000000000000000000..b8f49878237f190f7a871a49c1e3f55ab642145a --- /dev/null +++ b/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/config.json @@ -0,0 +1,12 @@ +{ + "input_image": "ride_bike.jpg", + "output_image": "a_man_is_galloping_on_a_horse_0.png", + "prompt": "A man is galloping on a horse.", + "negative_prompt": "monochrome, lowres, bad anatomy, ugly, deformed face", + "num_steps": 100, + "seed": 56, + "width": 1080, + "height": 720, + "guidance_scale": 9.5, + "controlnet_conditioning_scale": 0.5 +} diff --git a/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/ride_bike.jpg b/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/ride_bike.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a9c7f76606bc8d2707e6ba30dbacf31ef31bf096 --- /dev/null +++ b/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/2/ride_bike.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:76310cad16fcf71097c9660d46a95ced0992d48bd92469e83fd25ee59f015998 +size 163547 diff --git a/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/a_woman_is_holding_a_baseball_bat_in_her_hand_0.png b/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/a_woman_is_holding_a_baseball_bat_in_her_hand_0.png new file mode 
100644 index 0000000000000000000000000000000000000000..d02e377e0572ba785480967ae2c874f949c27578 --- /dev/null +++ b/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/a_woman_is_holding_a_baseball_bat_in_her_hand_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a048958e0ed28806ecb7c9834f91b07a464b73cd641fa19b03f39ff542986530 +size 1267884 diff --git a/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/config.json b/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/config.json new file mode 100644 index 0000000000000000000000000000000000000000..bbd04e660f98c1f4ae56dd8b145298555adecb99 --- /dev/null +++ b/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/config.json @@ -0,0 +1,12 @@ +{ + "input_image": "tennis.jpg", + "output_image": "a_woman_is_holding_a_baseball_bat_in_her_hand_0.png", + "prompt": "A woman is holding a baseball bat in her hand.", + "negative_prompt": "monochrome, lowres, bad anatomy, ugly, deformed face", + "num_steps": 100, + "seed": 765, + "width": 990, + "height": 720, + "guidance_scale": 6.5, + "controlnet_conditioning_scale": 0.7 +} diff --git a/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/tennis.jpg b/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/tennis.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5f1ad5e1632ea54f5e509793ce506c69aa419666 --- /dev/null +++ b/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/3/tennis.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:259845edb5c365bccb33f9207630d829bb5a839e72bf7d0326f11ae4862694fa +size 5611831 diff --git a/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/a_woman_raises_a_katana_0.png b/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/a_woman_raises_a_katana_0.png new file mode 100644 index 0000000000000000000000000000000000000000..d5c40b562fb7091444d2762646ac5d87e1997863 --- /dev/null +++ b/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/a_woman_raises_a_katana_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:deaa70aba05ab58ea0f9bd16512c6dcc7e0951559037779063045b7c035342f8 +size 440696 diff --git a/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/config.json b/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/config.json new file mode 100644 index 0000000000000000000000000000000000000000..07e73e90d221ce5bd71925bb50bfbc09261e57f7 --- /dev/null +++ b/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/config.json @@ -0,0 +1,12 @@ +{ + "input_image": "man_and_sword.jpg", + "output_image": "a_woman_raises_a_katana_0.png", + "prompt": "A woman raises a katana.", + "negative_prompt": "body elongated, fragmentation, many hands, ugly, deformed face", + "num_steps": 50, + "seed": 78, + "width": 540, + "height": 512, + "guidance_scale": 6.5, + "controlnet_conditioning_scale": 0.8 +} diff --git a/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/man_and_sword.jpg b/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/man_and_sword.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e4212460e2267e6b61768f268e215e7a288247fa Binary files /dev/null and b/assets/examples/Stable-Diffusion-2.1-Openpose-ControlNet/4/man_and_sword.jpg differ diff --git a/ckpts/.gitignore b/ckpts/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..a3a0c8b5f48c0260a4cb43aa577f9b18896ee280 --- /dev/null +++ b/ckpts/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore \ No 
newline at end of file diff --git a/configs/.gitkeep b/configs/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configs/datasets_info.yaml b/configs/datasets_info.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8574c745713e008958073a00dfb402415c8d2d5a --- /dev/null +++ b/configs/datasets_info.yaml @@ -0,0 +1,3 @@ +- dataset_name: "HighCWu/open_pose_controlnet_subset" + local_dir: "HighCWu-open_pose_controlnet_subset" + platform: "HuggingFace" \ No newline at end of file diff --git a/configs/model_ckpts.yaml b/configs/model_ckpts.yaml new file mode 100644 index 0000000000000000000000000000000000000000..03f5524bf4bfb18b5af5cfeb7d33100b7f9873f6 --- /dev/null +++ b/configs/model_ckpts.yaml @@ -0,0 +1,16 @@ +- model_id: "danhtran2mind/Stable-Diffusion-2.1-Openpose-ControlNet" + local_dir: "ckpts/Stable-Diffusion-2.1-Openpose-ControlNet" + allow: + - diffusion_pytorch_model.safetensors + - config.json + +- model_id: "stabilityai/stable-diffusion-2-1" + local_dir: "ckpts/stable-diffusion-2-1" + deny: + - v2-1_768-ema-pruned.ckpt + - v2-1_768-ema-pruned.safetensors + - v2-1_768-nonema-pruned.ckpt + - v2-1_768-nonema-pruned.safetensors + +- model_id: "lllyasviel/ControlNet" + local_dir: null diff --git a/data/.gitignore b/data/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..a3a0c8b5f48c0260a4cb43aa577f9b18896ee280 --- /dev/null +++ b/data/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore \ No newline at end of file diff --git a/docs/inference/inference_doc.md b/docs/inference/inference_doc.md new file mode 100644 index 0000000000000000000000000000000000000000..64501efa103f950de250b2d3b012089e84b23c12 --- /dev/null +++ b/docs/inference/inference_doc.md @@ -0,0 +1,176 @@ +# ControlNet Image Generation with Pose Detection + +This document provides a comprehensive overview of a Python script designed for image generation using ControlNet with pose detection, integrated with the Stable Diffusion model. The script processes an input image to detect human poses and generates new images based on a text prompt, guided by the detected poses. + +## Purpose + +The script enables users to generate images that adhere to specific poses extracted from an input image, combining the power of ControlNet for pose conditioning with Stable Diffusion for high-quality image synthesis. It is particularly useful for applications requiring pose-guided image generation, such as creating stylized images of people in specific poses (e.g., yoga, dancing) based on a reference image. + +## Dependencies + +The script relies on the following Python libraries and custom modules: + +- **Standard Libraries**: + - `torch`: For tensor operations and deep learning model handling. + - `argparse`: For parsing command-line arguments. + - `os`: For file and directory operations. + - `sys`: For modifying the Python path to include the project root. + +- **Custom Modules** (assumed to be part of the project structure): + - `inference.config_loader`: + - `load_config`: Loads model configurations from a YAML file. + - `find_config_by_model_id`: Retrieves specific model configurations by ID. + - `inference.model_initializer`: + - `initialize_controlnet`: Initializes the ControlNet model. + - `initialize_pipeline`: Initializes the Stable Diffusion pipeline. + - `initialize_controlnet_detector`: Initializes the pose detection model. 
+ - `inference.device_manager`: + - `setup_device`: Configures the computation device (e.g., CPU or GPU). + - `inference.image_processor`: + - `load_input_image`: Loads the input image from a local path or URL. + - `detect_poses`: Detects human poses in the input image. + - `inference.image_generator`: + - `generate_images`: Generates images using the pipeline and pose conditions. + - `save_images`: Saves generated images to the specified directory. + +## Script Structure + +The script is organized into the following components: + +1. **Imports and Path Setup**: + - Imports necessary libraries and adds the project root directory to the Python path for accessing custom modules. + - Ensures the script can locate custom modules regardless of the execution context. + +2. **Global Variables**: + - Defines three global variables to cache initialized models: + - `controlnet_detector`: For pose detection. + - `controlnet`: For pose-guided conditioning. + - `pipe`: The Stable Diffusion pipeline. + - These variables persist across multiple calls to the `infer` function to avoid redundant model initialization. + +3. **Main Function: `infer`**: + - The core function that orchestrates the image generation process. + - Takes configurable parameters for input, model settings, and output options. + +4. **Command-Line Interface**: + - Uses `argparse` to provide a user-friendly interface for running the script with customizable parameters. + +## Main Function: `infer` + +The `infer` function handles the end-to-end process of loading models, processing input images, detecting poses, generating images, and optionally saving the results. + +### Parameters + +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `config_path` | `str` | Path to the configuration YAML file. | `"configs/model_ckpts.yaml"` | +| `input_image` | `str` | Path to the local input image. Mutually exclusive with `image_url`. | `None` | +| `image_url` | `str` | URL of the input image. Mutually exclusive with `input_image`. | `None` | +| `prompt` | `str` | Text prompt for image generation. | `"a man is doing yoga"` | +| `negative_prompt` | `str` | Negative prompt to avoid undesired features. | `"monochrome, lowres, bad anatomy, worst quality, low quality"` | +| `num_steps` | `int` | Number of inference steps. | `20` | +| `seed` | `int` | Random seed for reproducibility. | `2` | +| `width` | `int` | Width of the generated image (pixels). | `512` | +| `height` | `int` | Height of the generated image (pixels). | `512` | +| `guidance_scale` | `float` | Guidance scale for prompt adherence. | `7.5` | +| `controlnet_conditioning_scale` | `float` | ControlNet conditioning scale for pose influence. | `1.0` | +| `output_dir` | `str` | Directory to save generated images. | `tests/test_data` | +| `use_prompt_as_output_name` | `bool` | Use prompt in output filenames. | `False` | +| `save_output` | `bool` | Save generated images to `output_dir`. | `False` | + +### Workflow + +1. **Configuration Loading**: + - Loads model configurations from `config_path` using `load_config`. + - Retrieves specific configurations for: + - Pose detection model (`lllyasviel/ControlNet`). + - ControlNet model (`danhtran2mind/Stable-Diffusion-2.1-Openpose-ControlNet`). + - Stable Diffusion pipeline (`stabilityai/stable-diffusion-2-1`). + +2. **Model Initialization**: + - Checks if `controlnet_detector`, `controlnet`, or `pipe` are `None`. + - If `None`, initializes them using the respective configurations to avoid redundant loading. 
+ +3. **Device Setup**: + - Configures the computation device (e.g., CPU or GPU) for the pipeline using `setup_device`. + +4. **Image Processing**: + - Loads the input image from either `input_image` or `image_url` using `load_input_image`. + - Detects poses in the input image using `detect_poses` with the `controlnet_detector`. + +5. **Image Generation**: + - Creates a list of random number generators seeded with `seed + i` for each detected pose. + - Generates images using `generate_images`, passing: + - The pipeline (`pipe`). + - Repeated prompts and negative prompts for each pose. + - Detected poses as conditioning inputs. + - Generators for reproducibility. + - Parameters like `num_steps`, `guidance_scale`, `controlnet_conditioning_scale`, `width`, and `height`. + +6. **Output Handling**: + - If `save_output` is `True`, saves the generated images to `output_dir` using `save_images`. + - If `use_prompt_as_output_name` is `True`, incorporates the prompt into the output filenames. + - Returns the list of generated images. + +## Command-Line Interface + +The script includes a command-line interface using `argparse` for flexible execution. + +### Arguments Table + +| Argument | Type | Default Value | Description | +|----------|------|---------------|-------------| +| `--input_image` | `str` | `tests/test_data/yoga1.jpg` | Path to the local input image. Mutually exclusive with `--image_url`. | +| `--image_url` | `str` | `None` | URL of the input image (e.g., `https://huggingface.co/datasets/YiYiXu/controlnet-testing/resolve/main/yoga1.jpeg`). Mutually exclusive with `--input_image`. | +| `--config_path` | `str` | `configs/model_ckpts.yaml` | Path to the configuration YAML file for model settings. | +| `--prompt` | `str` | `"a man is doing yoga"` | Text prompt for image generation. | +| `--negative_prompt` | `str` | `"monochrome, lowres, bad anatomy, worst quality, low quality"` | Negative prompt to avoid undesired features in generated images. | +| `--num_steps` | `int` | `20` | Number of inference steps for image generation. | +| `--seed` | `int` | `2` | Random seed for reproducible generation. | +| `--width` | `int` | `512` | Width of the generated image in pixels. | +| `--height` | `int` | `512` | Height of the generated image in pixels. | +| `--guidance_scale` | `float` | `7.5` | Guidance scale for prompt adherence during generation. | +| `--controlnet_conditioning_scale` | `float` | `1.0` | ControlNet conditioning scale to balance pose influence. | +| `--output_dir` | `str` | `tests/test_data` | Directory to save generated images. | +| `--use_prompt_as_output_name` | Flag | `False` | If set, incorporates the prompt into output image filenames. | +| `--save_output` | Flag | `False` | If set, saves generated images to the specified output directory. | + +### Example Usage + +```bash +python script.py --input_image tests/test_data/yoga1.jpg --prompt "a woman doing yoga in a park" --num_steps 30 --guidance_scale 8.0 --save_output --use_prompt_as_output_name +``` + +This command: +- Uses the local image `tests/test_data/yoga1.jpg` as input. +- Generates images with the prompt `"a woman doing yoga in a park"`. +- Runs for 30 inference steps with a guidance scale of 8.0. +- Saves the output images to `tests/test_data`, with filenames including the prompt. 
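+
+The same run can also be triggered programmatically by importing `infer` from Python (for example, in a notebook or another service). This is a minimal sketch, assuming `src/controlnet_image_generator` is on `PYTHONPATH`; the import path is an assumption and may need adjusting to your layout:
+
+```python
+# Minimal sketch: calling the inference entry point from Python.
+# Assumes src/controlnet_image_generator is on PYTHONPATH so that `infer`
+# and its `inference.*` helpers resolve; adjust the import to your setup.
+from infer import infer
+
+images = infer(
+    config_path="configs/model_ckpts.yaml",
+    input_image="tests/test_data/yoga1.jpg",  # local reference image
+    image_url=None,                           # or pass a URL instead of a local path
+    prompt="a woman doing yoga in a park",
+    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
+    num_steps=30,
+    seed=42,
+    width=512,
+    height=512,
+    guidance_scale=8.0,
+    controlnet_conditioning_scale=1.0,
+    output_dir="tests/test_data",
+    use_prompt_as_output_name=True,
+    save_output=True,
+)
+print(f"Generated {len(images)} image(s)")
+```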
+ +Alternatively, using a URL: + +```bash +python script.py --image_url https://huggingface.co/datasets/YiYiXu/controlnet-testing/resolve/main/yoga1.jpeg --prompt "a person practicing yoga at sunset" --save_output +``` + +This command uses an online image and saves the generated images without using the prompt in filenames. + +## Notes + +- **Configuration File**: The script assumes a `configs/model_ckpts.yaml` file exists with configurations for the required models (`lllyasviel/ControlNet`, `danhtran2mind/Stable-Diffusion-2.1-Openpose-ControlNet`, `stabilityai/stable-diffusion-2-1`). Ensure this file is correctly formatted and accessible. +- **Input Requirements**: The input image (local or URL) should contain at least one person for effective pose detection. +- **Model Caching**: Global variables cache the models to improve performance for multiple inferences within the same session. +- **Device Compatibility**: The `setup_device` function determines the computation device. Ensure compatible hardware (e.g., GPU) is available for optimal performance. +- **Output Flexibility**: The script supports generating multiple images if multiple poses are detected, with each image conditioned on one pose. +- **Error Handling**: The script assumes the custom modules handle errors appropriately. Users should verify that input paths, URLs, and model configurations are valid. + +## Potential Improvements + +- Add error handling for invalid inputs or missing configuration files. +- Support batch processing for multiple input images. +- Allow dynamic model selection via command-line arguments instead of hardcoded model IDs. +- Include options for adjusting pose detection sensitivity or other model-specific parameters. + +## Conclusion + +This script provides a robust framework for pose-guided image generation using ControlNet and Stable Diffusion. Its modular design and command-line interface make it suitable for both one-off experiments and integration into larger workflows. By leveraging pre-trained models and customizable parameters, it enables users to generate high-quality, pose-conditioned images with minimal setup. \ No newline at end of file diff --git a/docs/scripts/download_ckpts_doc.md b/docs/scripts/download_ckpts_doc.md new file mode 100644 index 0000000000000000000000000000000000000000..705c9521d853c9ca57f2f661c7de293ac8927820 --- /dev/null +++ b/docs/scripts/download_ckpts_doc.md @@ -0,0 +1,29 @@ +# Download Model Checkpoints + +This script downloads model checkpoints from the Hugging Face Hub based on configurations specified in a YAML file. + +## Functionality +- **Load Configuration**: Reads a YAML configuration file to get model details. +- **Download Model**: Downloads files for specified models from the Hugging Face Hub to a local directory. + - Checks for a valid `local_dir` in the configuration; skips download if `local_dir` is null. + - Creates the local directory if it doesn't exist. + - Supports `allow` and `deny` patterns to filter files: + - If `allow` patterns are specified, only those files are downloaded. + - If no `allow` patterns are provided, all files are downloaded except those matching `deny` patterns. + - Uses `hf_hub_download` from the `huggingface_hub` library with symlinks disabled. + +## Command-Line Arguments +- `--config_path`: Path to the YAML configuration file (defaults to `configs/model_ckpts.yaml`). + +## Dependencies +- `argparse`: For parsing command-line arguments. +- `os`: For directory creation. +- `yaml`: For reading the configuration file. 
+- `huggingface_hub`: For downloading files from the Hugging Face Hub. + +## Usage +Run the script with: +```bash +python scripts/download_ckpts.py --config_path path/to/config.yaml +``` +The script processes each model in the configuration file, printing the model ID and local directory for each. \ No newline at end of file diff --git a/docs/scripts/download_datasets_doc.md b/docs/scripts/download_datasets_doc.md new file mode 100644 index 0000000000000000000000000000000000000000..babb1a09710ac8a998280daa677e3803d37145cc --- /dev/null +++ b/docs/scripts/download_datasets_doc.md @@ -0,0 +1,20 @@ +# Download Datasets + +This script downloads datasets from Hugging Face using configuration details specified in a YAML file. + +## Functionality +- **Load Configuration**: Reads dataset details from a YAML configuration file. +- **Download Dataset**: Downloads datasets from Hugging Face if the platform is specified as 'HuggingFace' in the configuration. +- **Command-Line Argument**: Accepts a path to the configuration file via the `--config_path` argument (defaults to `configs/datasets_info.yaml`). +- **Dataset Information**: Extracts the dataset name and local storage directory from the configuration, splits the dataset name into user and model hub components, and saves the dataset to the specified directory. +- **Verification**: Prints dataset details, including user name, model hub name, storage location, and dataset information for confirmation. +- **Platform Check**: Only processes datasets from Hugging Face; unsupported platforms are flagged with a message. + +## Usage +Run the script with the command: +`python scripts/download_datasets.py --config_path path/to/config.yaml` + +The configuration file should contain: +- `dataset_name`: Format as `user_name/model_hub_name`. +- `local_dir`: Directory to save the dataset. +- `platform`: Must be set to `HuggingFace` for the script to process it. \ No newline at end of file diff --git a/docs/training/training_doc.md b/docs/training/training_doc.md new file mode 100644 index 0000000000000000000000000000000000000000..1a15ee4f1b77d5b0b2e1f21030b68a3b2018068a --- /dev/null +++ b/docs/training/training_doc.md @@ -0,0 +1,106 @@ +# ControlNet Training Documentation + +This document outlines the process for training a ControlNet model using the provided Python scripts (`train.py` and `train_controlnet.py`). The scripts facilitate training a ControlNet model integrated with a Stable Diffusion pipeline for conditional image generation. Below, we describe the training process and provide a detailed table of the command-line arguments used to configure the training. + +## Overview + +The training process involves two main scripts: +1. **`train.py`**: A wrapper script that executes `train_controlnet.py` with the provided command-line arguments. +2. **`train_controlnet.py`**: The core script that handles the training of the ControlNet model, including dataset preparation, model initialization, training loop, and validation. + +### Training Workflow +1. **Argument Parsing**: The script parses command-line arguments to configure the training process, such as model paths, dataset details, and hyperparameters. +2. **Dataset Preparation**: Loads and preprocesses the dataset (either from HuggingFace Hub or a local directory) with transformations for images and captions. +3. **Model Initialization**: Loads pretrained models (e.g., Stable Diffusion, VAE, UNet, text encoder) and initializes or loads ControlNet weights. +4.
**Training Loop**: Trains the ControlNet model using the Accelerate library for distributed training, with support for mixed precision, gradient checkpointing, and learning rate scheduling. +5. **Validation**: Periodically validates the model by generating images using validation prompts and images, logging results to TensorBoard or Weights & Biases. +6. **Checkpointing and Saving**: Saves model checkpoints during training and the final model to the output directory. Optionally pushes the model to the HuggingFace Hub. +7. **Model Card Creation**: Generates a model card with training details and example images for documentation. + +## Command-Line Arguments + +The following table describes the command-line arguments available in `train_controlnet.py` for configuring the training process: + +| Argument | Type | Default | Description | +|----------|------|---------|-------------| +| `--pretrained_model_name_or_path` | `str` | None | Path to pretrained model or model identifier from huggingface.co/models. Required. | +| `--controlnet_model_name_or_path` | `str` | None | Path to pretrained ControlNet model or model identifier. If not specified, ControlNet weights are initialized from UNet. | +| `--revision` | `str` | None | Revision of pretrained model identifier from huggingface.co/models. | +| `--variant` | `str` | None | Variant of the model files (e.g., 'fp16'). | +| `--tokenizer_name` | `str` | None | Pretrained tokenizer name or path if different from model_name. | +| `--output_dir` | `str` | "controlnet-model" | Directory where model predictions and checkpoints are saved. | +| `--cache_dir` | `str` | None | Directory for storing downloaded models and datasets. | +| `--seed` | `int` | None | Seed for reproducible training. | +| `--resolution` | `int` | 512 | Resolution for input images (must be divisible by 8). | +| `--train_batch_size` | `int` | 4 | Batch size per device for the training dataloader. | +| `--num_train_epochs` | `int` | 1 | Number of training epochs. | +| `--max_train_steps` | `int` | None | Total number of training steps. Overrides `num_train_epochs` if provided. | +| `--checkpointing_steps` | `int` | 500 | Save a checkpoint every X updates. | +| `--checkpoints_total_limit` | `int` | None | Maximum number of checkpoints to store. | +| `--resume_from_checkpoint` | `str` | None | Resume training from a previous checkpoint path or "latest". | +| `--gradient_accumulation_steps` | `int` | 1 | Number of update steps to accumulate before a backward pass. | +| `--gradient_checkpointing` | `flag` | False | Enable gradient checkpointing to save memory at the cost of slower backward passes. | +| `--learning_rate` | `float` | 5e-6 | Initial learning rate after warmup. | +| `--scale_lr` | `flag` | False | Scale learning rate by number of GPUs, gradient accumulation steps, and batch size. | +| `--lr_scheduler` | `str` | "constant" | Learning rate scheduler type: ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]. | +| `--lr_warmup_steps` | `int` | 500 | Number of steps for learning rate warmup. | +| `--lr_num_cycles` | `int` | 1 | Number of hard resets for cosine_with_restarts scheduler. | +| `--lr_power` | `float` | 1.0 | Power factor for polynomial scheduler. | +| `--use_8bit_adam` | `flag` | False | Use 8-bit Adam optimizer from bitsandbytes for lower memory usage. | +| `--dataloader_num_workers` | `int` | 0 | Number of subprocesses for data loading (0 means main process). 
| +| `--adam_beta1` | `float` | 0.9 | Beta1 parameter for Adam optimizer. | +| `--adam_beta2` | `float` | 0.999 | Beta2 parameter for Adam optimizer. | +| `--adam_weight_decay` | `float` | 1e-2 | Weight decay for Adam optimizer. | +| `--adam_epsilon` | `float` | 1e-08 | Epsilon value for Adam optimizer. | +| `--max_grad_norm` | `float` | 1.0 | Maximum gradient norm for clipping. | +| `--push_to_hub` | `flag` | False | Push the model to the HuggingFace Hub. | +| `--hub_token` | `str` | None | Token for pushing to the HuggingFace Hub. | +| `--hub_model_id` | `str` | None | Repository name for syncing with `output_dir`. | +| `--logging_dir` | `str` | "logs" | TensorBoard log directory. | +| `--allow_tf32` | `flag` | False | Allow TF32 on Ampere GPUs for faster training. | +| `--report_to` | `str` | "tensorboard" | Integration for logging: ["tensorboard", "wandb", "comet_ml", "all"]. | +| `--mixed_precision` | `str` | None | Mixed precision training: ["no", "fp16", "bf16"]. | +| `--enable_xformers_memory_efficient_attention` | `flag` | False | Enable xformers for memory-efficient attention. | +| `--set_grads_to_none` | `flag` | False | Set gradients to None instead of zero to save memory. | +| `--dataset_name` | `str` | None | Name of the dataset from HuggingFace Hub or local path. | +| `--dataset_config_name` | `str` | None | Dataset configuration name. | +| `--train_data_dir` | `str` | None | Directory containing training data with `metadata.jsonl`. | +| `--image_column` | `str` | "image" | Dataset column for target images. | +| `--conditioning_image_column` | `str` | "conditioning_image" | Dataset column for ControlNet conditioning images. | +| `--caption_column` | `str` | "text" | Dataset column for captions. | +| `--max_train_samples` | `int` | None | Truncate training examples to this number for debugging or quicker training. | +| `--proportion_empty_prompts` | `float` | 0 | Proportion of prompts to replace with empty strings (0 to 1). | +| `--validation_prompt` | `str` | None | Prompts for validation, evaluated every `validation_steps`. | +| `--validation_image` | `str` | None | Paths to ControlNet conditioning images for validation. | +| `--num_validation_images` | `int` | 4 | Number of images generated per validation prompt-image pair. | +| `--validation_steps` | `int` | 100 | Run validation every X steps. | +| `--tracker_project_name` | `str` | "train_controlnet" | Project name for Accelerator trackers. | + +## Usage Example + +To train a ControlNet model, run the following command: + +```bash +python src/controlnet_image_generator/train.py \ + --pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1" \ + --dataset_name="huggingface/controlnet-dataset" \ + --output_dir="controlnet_output" \ + --resolution=512 \ + --train_batch_size=4 \ + --num_train_epochs=3 \ + --learning_rate=1e-5 \ + --validation_prompt="A cat sitting on a chair" \ + --validation_image="path/to/conditioning_image.png" \ + --push_to_hub \ + --hub_model_id="your-username/controlnet-model" +``` + +This command trains a ControlNet model using the Stable Diffusion 2.1 pretrained model, a specified dataset, and logs results to the HuggingFace Hub. + +## Notes +- Ensure the dataset contains columns for target images, conditioning images, and captions as specified by `image_column`, `conditioning_image_column`, and `caption_column`. +- The resolution must be divisible by 8 to ensure compatibility with the VAE and ControlNet encoder. 
+- Mixed precision training (`fp16` or `bf16`) can reduce memory usage but requires compatible hardware. +- Validation images and prompts must be provided in matching quantities or as single values to be reused. + +For further details, refer to the source scripts or the HuggingFace Diffusers documentation. \ No newline at end of file diff --git a/notebooks/SD-2.1-Openpose-ControlNet.ipynb b/notebooks/SD-2.1-Openpose-ControlNet.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..d174f5fd9d98c317de82b4d18cb4c1845c13ac0a --- /dev/null +++ b/notebooks/SD-2.1-Openpose-ControlNet.ipynb @@ -0,0 +1,1723 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Setup Environment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "if os.path.basename(os.getcwd()) == \"notebooks\":\n", + " os.chdir(\"..\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Install Dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", + "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", + "execution": { + "iopub.execute_input": "2025-07-08T09:44:37.937172Z", + "iopub.status.busy": "2025-07-08T09:44:37.936893Z", + "iopub.status.idle": "2025-07-08T09:49:21.948303Z", + "shell.execute_reply": "2025-07-08T09:49:21.947471Z", + "shell.execute_reply.started": "2025-07-08T09:44:37.937153Z" + }, + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m410.8/410.8 kB\u001b[0m \u001b[31m6.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m72.9/72.9 MB\u001b[0m \u001b[31m14.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0mm\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m2.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0mm00:01\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m1.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:04\u001b[0mm\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m2.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:02\u001b[0mm\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m5.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m5.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m5.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m7.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hCloning into 'diffusers'...\n", + "remote: Enumerating objects: 98145, done.\u001b[K\n", + "remote: Counting objects: 100% 
(222/222), done.\u001b[K\n", + "remote: Compressing objects: 100% (161/161), done.\u001b[K\n", + "remote: Total 98145 (delta 107), reused 91 (delta 52), pack-reused 97923 (from 3)\u001b[K\n", + "Receiving objects: 100% (98145/98145), 73.31 MiB | 22.71 MiB/s, done.\n", + "Resolving deltas: 100% (72392/72392), done.\n", + "/kaggle/working/diffusers\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Checking if build backend supports build_editable ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build editable ... \u001b[?25l\u001b[?25hdone\n", + " Preparing editable metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Building editable for diffusers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n" + ] + } + ], + "source": [ + "# !pip install -q git+https://github.com/huggingface/diffusers.git\n", + "!pip install -q peft==0.15.0 bitsandbytes\n", + "!git clone https://github.com/huggingface/diffusers\n", + "%cd diffusers\n", + "!pip install -e . -q" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "trusted": true + }, + "outputs": [], + "source": [ + "# !pip install git+https://github.com/huggingface/diffusers" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "execution": { + "iopub.execute_input": "2025-07-08T09:49:21.950216Z", + "iopub.status.busy": "2025-07-08T09:49:21.949956Z", + "iopub.status.idle": "2025-07-08T09:49:21.955584Z", + "shell.execute_reply": "2025-07-08T09:49:21.954816Z", + "shell.execute_reply.started": "2025-07-08T09:49:21.950192Z" + }, + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/kaggle/working/diffusers/examples/controlnet\n" + ] + } + ], + "source": [ + "%cd examples/controlnet" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.status.idle": "2025-07-08T12:05:34.302241Z", + "shell.execute_reply": "2025-07-08T12:05:34.301446Z", + "shell.execute_reply.started": "2025-07-08T09:50:31.884323Z" + }, + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Steps: 96%|████▊| 18000/18750 [52:01<1:18:01, 6.24s/it, loss=0.301, lr=0.0002]Configuration saved in ./ckpts/Stable-Diffusion-2.1-Openpose-ControlNet/checkpoint-18000/controlnet/config.json\n", + "Model weights saved in ./ckpts/Stable-Diffusion-2.1-Openpose-ControlNet/checkpoint-18000/controlnet/diffusion_pytorch_model.safetensors\n", + "Steps: 99%|████▉| 18500/18750 [1:44:08<26:02, 6.25s/it, loss=0.338, lr=0.0002]Configuration saved in ./ckpts/Stable-Diffusion-2.1-Openpose-ControlNet/checkpoint-18500/controlnet/config.json\n", + "Model weights saved in ./ckpts/Stable-Diffusion-2.1-Openpose-ControlNet/checkpoint-18500/controlnet/diffusion_pytorch_model.safetensors\n", + "Steps: 100%|█████| 18750/18750 [2:10:12<00:00, 6.26s/it, loss=0.416, lr=0.0002]Configuration saved in ./ckpts/Stable-Diffusion-2.1-Openpose-ControlNet/config.json\n", + "Model weights saved in ./ckpts/Stable-Diffusion-2.1-Openpose-ControlNet/diffusion_pytorch_model.safetensors\n", + "Steps: 100%|█████| 18750/18750 [2:10:16<00:00, 6.25s/it, loss=0.416, lr=0.0002]\n" + ] + } + ], + "source": [ + "!accelerate launch train_controlnet.py \\\n", + " --pretrained_model_name_or_path=\"stabilityai/stable-diffusion-2-1\" \\\n", + " --resume_from_checkpoint 
\"./ckpts/Stable-Diffusion-2.1-Openpose-ControlNet/checkpoint-17500\" \\\n", + " --output_dir=\"./ckpts/Stable-Diffusion-2.1-Openpose-ControlNet\" \\\n", + " --dataset_name=\"HighCWu/open_pose_controlnet_subset\" \\\n", + " --resolution=512 \\\n", + " --learning_rate=2e-4 \\\n", + " --train_batch_size=4 \\\n", + " --gradient_accumulation_steps=2 \\\n", + " --gradient_checkpointing \\\n", + " --use_8bit_adam \\\n", + " --num_train_epochs=50 \\\n", + " --mixed_precision \"fp16\" \\\n", + " --checkpoints_total_limit=2 \\\n", + " --checkpointing_steps=500 \\\n", + " --validation_steps=100\n", + " # --image_column \\\n", + " # --conditioning_image_column \\\n", + " # --caption_column \\\n", + " # --max_train_steps=10000\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Inference" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-07-08T12:15:50.516724Z", + "iopub.status.busy": "2025-07-08T12:15:50.515955Z", + "iopub.status.idle": "2025-07-08T12:17:09.002224Z", + "shell.execute_reply": "2025-07-08T12:17:09.001257Z", + "shell.execute_reply.started": "2025-07-08T12:15:50.516697Z" + }, + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m290.4/290.4 kB\u001b[0m \u001b[31m5.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m0:00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m1.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m0:00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m0:00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m29.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m13.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m3.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m0:00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m82.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25h" + ] + } + ], + "source": [ + "!pip install -q controlnet-aux" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "execution": { + "iopub.execute_input": "2025-07-08T12:17:09.004232Z", + "iopub.status.busy": "2025-07-08T12:17:09.003985Z", + "iopub.status.idle": "2025-07-08T12:17:09.009019Z", + "shell.execute_reply": "2025-07-08T12:17:09.008057Z", + "shell.execute_reply.started": "2025-07-08T12:17:09.004203Z" + }, + "trusted": true + }, + "outputs": [], + "source": [ + "import cv2\n", + "from PIL import Image\n", + "import numpy as np\n", + "# from diffusers.utils import load_image\n", + "from PIL import Image\n", + 
"import PIL\n", + "import requests" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-07-08T12:17:09.010551Z", + "iopub.status.busy": "2025-07-08T12:17:09.010007Z", + "iopub.status.idle": "2025-07-08T12:17:43.288449Z", + "shell.execute_reply": "2025-07-08T12:17:43.287691Z", + "shell.execute_reply.started": "2025-07-08T12:17:09.010526Z" + }, + "trusted": true + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b294b69e673b4cbf84e080d236bf8158", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fetching 63 files: 0%| | 0/63 [00:00 PIL.Image.Image:\n", + " \"\"\"\n", + " Loads `image` to a PIL Image.\n", + "\n", + " Args:\n", + " image (`str` or `PIL.Image.Image`):\n", + " The image to convert to the PIL Image format.\n", + " convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*):\n", + " A conversion method to apply to the image after loading it. When set to `None` the image will be converted\n", + " \"RGB\".\n", + "\n", + " Returns:\n", + " `PIL.Image.Image`:\n", + " A PIL Image.\n", + " \"\"\"\n", + " if isinstance(image, str):\n", + " if image.startswith(\"http://\") or image.startswith(\"https://\"):\n", + " image = PIL.Image.open(requests.get(image, stream=True, timeout=200).raw)\n", + " elif os.path.isfile(image):\n", + " image = PIL.Image.open(image)\n", + " else:\n", + " raise ValueError(\n", + " f\"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path.\"\n", + " )\n", + " elif isinstance(image, PIL.Image.Image):\n", + " image = image\n", + " else:\n", + " raise ValueError(\n", + " \"Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image.\"\n", + " )\n", + "\n", + " image = PIL.ImageOps.exif_transpose(image)\n", + "\n", + " if convert_method is not None:\n", + " image = convert_method(image)\n", + " else:\n", + " image = image.convert(\"RGB\")\n", + "\n", + " return image" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "execution": { + "iopub.execute_input": "2025-07-08T12:17:43.297604Z", + "iopub.status.busy": "2025-07-08T12:17:43.297363Z", + "iopub.status.idle": "2025-07-08T12:17:51.415654Z", + "shell.execute_reply": "2025-07-08T12:17:51.414909Z", + "shell.execute_reply.started": "2025-07-08T12:17:43.297580Z" + }, + "trusted": true + }, + "outputs": [], + "source": [ + "def image_grid(imgs, rows, cols):\n", + " assert len(imgs) == rows * cols\n", + "\n", + " w, h = imgs[0].size\n", + " grid = Image.new(\"RGB\", size=(cols * w, rows * h))\n", + " grid_w, grid_h = grid.size\n", + "\n", + " for i, img in enumerate(imgs):\n", + " grid.paste(img, box=(i % cols * w, i // cols * h))\n", + " return grid\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.status.busy": "2025-07-08T12:13:08.387428Z", + "iopub.status.idle": "2025-07-08T12:13:08.387805Z", + "shell.execute_reply": "2025-07-08T12:13:08.387633Z", + "shell.execute_reply.started": "2025-07-08T12:13:08.387617Z" + }, + "trusted": true + }, + "outputs": [], + "source": [ + "urls = \"yoga1.jpeg\", \"yoga2.jpeg\", \"yoga3.jpeg\", \"yoga4.jpeg\"\n", + "imgs = [\n", + " load_image(\"https://huggingface.co/datasets/YiYiXu/controlnet-testing/resolve/main/\" + url) \n", + " for url in urls\n", + "]\n", + "\n", + "image_grid(imgs, 2, 2)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + 
"metadata": { + "execution": { + "iopub.execute_input": "2025-07-08T12:17:51.417162Z", + "iopub.status.busy": "2025-07-08T12:17:51.416851Z", + "iopub.status.idle": "2025-07-08T12:18:15.234196Z", + "shell.execute_reply": "2025-07-08T12:18:15.233323Z", + "shell.execute_reply.started": "2025-07-08T12:17:51.417138Z" + }, + "trusted": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.11/dist-packages/controlnet_aux/mediapipe_face/mediapipe_face_common.py:7: UserWarning: The module 'mediapipe' is not installed. The package will have limited functionality. Please install it using the command: pip install 'mediapipe'\n", + " warnings.warn(\n", + "/usr/local/lib/python3.11/dist-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers\n", + " warnings.warn(f\"Importing from {__name__} is deprecated, please import via timm.layers\", FutureWarning)\n", + "/usr/local/lib/python3.11/dist-packages/timm/models/registry.py:4: FutureWarning: Importing from timm.models.registry is deprecated, please import via timm.models\n", + " warnings.warn(f\"Importing from {__name__} is deprecated, please import via timm.models\", FutureWarning)\n", + "/usr/local/lib/python3.11/dist-packages/controlnet_aux/segment_anything/modeling/tiny_vit_sam.py:654: UserWarning: Overwriting tiny_vit_5m_224 in registry with controlnet_aux.segment_anything.modeling.tiny_vit_sam.tiny_vit_5m_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.\n", + " return register_model(fn_wrapper)\n", + "/usr/local/lib/python3.11/dist-packages/controlnet_aux/segment_anything/modeling/tiny_vit_sam.py:654: UserWarning: Overwriting tiny_vit_11m_224 in registry with controlnet_aux.segment_anything.modeling.tiny_vit_sam.tiny_vit_11m_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.\n", + " return register_model(fn_wrapper)\n", + "/usr/local/lib/python3.11/dist-packages/controlnet_aux/segment_anything/modeling/tiny_vit_sam.py:654: UserWarning: Overwriting tiny_vit_21m_224 in registry with controlnet_aux.segment_anything.modeling.tiny_vit_sam.tiny_vit_21m_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.\n", + " return register_model(fn_wrapper)\n", + "/usr/local/lib/python3.11/dist-packages/controlnet_aux/segment_anything/modeling/tiny_vit_sam.py:654: UserWarning: Overwriting tiny_vit_21m_384 in registry with controlnet_aux.segment_anything.modeling.tiny_vit_sam.tiny_vit_21m_384. This is because the name being registered conflicts with an existing name. Please check if this is not expected.\n", + " return register_model(fn_wrapper)\n", + "/usr/local/lib/python3.11/dist-packages/controlnet_aux/segment_anything/modeling/tiny_vit_sam.py:654: UserWarning: Overwriting tiny_vit_21m_512 in registry with controlnet_aux.segment_anything.modeling.tiny_vit_sam.tiny_vit_21m_512. This is because the name being registered conflicts with an existing name. 
Please check if this is not expected.\n", + " return register_model(fn_wrapper)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8fe6d8083b7d4535af03ef53d62ba556", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "body_pose_model.pth: 0%| | 0.00/209M [00:00" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from controlnet_aux import OpenposeDetector\n", + "\n", + "model = OpenposeDetector.from_pretrained(\"lllyasviel/ControlNet\")\n", + "\n", + "poses = [model(img) for img in imgs]\n", + "image_grid(poses, 2, 2)\n", + "# poses" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "execution": { + "iopub.execute_input": "2025-07-08T12:18:15.235687Z", + "iopub.status.busy": "2025-07-08T12:18:15.235222Z", + "iopub.status.idle": "2025-07-08T12:19:10.969606Z", + "shell.execute_reply": "2025-07-08T12:19:10.968652Z", + "shell.execute_reply.started": "2025-07-08T12:18:15.235659Z" + }, + "trusted": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-07-08 12:18:18.767595: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", + "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", + "E0000 00:00:1751977098.965819 35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", + "E0000 00:00:1751977099.020166 35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b3a5530ab5704dddbee38d77619485ef", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "model_index.json: 0%| | 0.00/537 [00:00" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "generator = [torch.Generator(device=\"cpu\").manual_seed(2) for i in range(4)]\n", + "prompt = \"a man is doing yoga\"\n", + "output = pipe(\n", + " [prompt] * 4,\n", + " poses,\n", + " negative_prompt=[\"monochrome, lowres, bad anatomy, worst quality, low quality\"] * 4,\n", + " generator=generator,\n", + " num_inference_steps=20,\n", + ")\n", + "image_grid(output.images, 2, 2)\n" + ] + } + ], + "metadata": { + "kaggle": { + "accelerator": "gpu", + "dataSources": [], + "dockerImageVersionId": 31041, + "isGpuEnabled": true, + "isInternetEnabled": true, + "language": "python", + "sourceType": "notebook" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/requirements/requirements.txt b/requirements/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..7dd4fcb30e57403fb5e2260c4ec770863b08a26d --- /dev/null +++ b/requirements/requirements.txt @@ -0,0 +1,7 @@ +huggingface-hub>=0.33.1 +bitsandbytes>=0.46.0 +diffusers>=0.34.0 +peft>=0.17.0 +controlnet-aux>=0.0.10 +accelerate>=1.7.0 
+gradio>=5.39.0 \ No newline at end of file diff --git a/requirements/requirements_compatible.txt b/requirements/requirements_compatible.txt new file mode 100644 index 0000000000000000000000000000000000000000..4c0c8167d481aafa03a998b34b14881a6476e751 --- /dev/null +++ b/requirements/requirements_compatible.txt @@ -0,0 +1,7 @@ +huggingface-hub==0.34.1 +bitsandbytes==0.46.0 +diffusers==0.34.0 +peft==0.17.0 +controlnet-aux==0.0.10 +accelerate==1.7.0 +gradio==5.39.0 \ No newline at end of file diff --git a/scripts/download_ckpts.py b/scripts/download_ckpts.py new file mode 100644 index 0000000000000000000000000000000000000000..3a333f900b1ff5367cb45f1706a10057b26e4df6 --- /dev/null +++ b/scripts/download_ckpts.py @@ -0,0 +1,58 @@ +import argparse +import os +import yaml +from huggingface_hub import hf_hub_download, list_repo_files + +def load_config(config_path): + with open(config_path, 'r') as file: + return yaml.safe_load(file) + +def download_model(model_config): + model_id = model_config["model_id"] + local_dir = model_config["local_dir"] + + if local_dir is None: + print(f"Skipping download for {model_id}: local_dir is null") + return + + os.makedirs(local_dir, exist_ok=True) + + allow_patterns = model_config.get("allow", []) + deny_patterns = model_config.get("deny", []) + + if allow_patterns: + for file in allow_patterns: + hf_hub_download( + repo_id=model_id, + filename=file, + local_dir=local_dir, + local_dir_use_symlinks=False + ) + else: + print(f"No allow patterns specified for {model_id}. Attempting to download all files except those in deny list.") + repo_files = list_repo_files(repo_id=model_id) + for file in repo_files: + if not any(deny_pattern in file for deny_pattern in deny_patterns): + hf_hub_download( + repo_id=model_id, + filename=file, + local_dir=local_dir, + local_dir_use_symlinks=False + ) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Download model checkpoints from Hugging Face Hub") + parser.add_argument( + "--config_path", + type=str, + default="configs/model_ckpts.yaml", + help="Path to the configuration YAML file" + ) + + args = parser.parse_args() + + config = load_config(args.config_path) + + for model_config in config: + print(f"Processing {model_config['model_id']} (local_dir: {model_config['local_dir']})") + download_model(model_config) \ No newline at end of file diff --git a/scripts/download_datasets.py b/scripts/download_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..30fbeabc19184a534cafdcd8b7ce15984e5ee536 --- /dev/null +++ b/scripts/download_datasets.py @@ -0,0 +1,48 @@ +import argparse +import yaml +from datasets import load_dataset + + +def load_config(config_path): + with open(config_path, 'r') as file: + return yaml.safe_load(file) + + +def download_huggingface_dataset(config): + # Get dataset details from config + dataset_name = config['dataset_name'] + local_dir = config['local_dir'] + + # Split dataset name into user_name and model_hub_name + user_name, model_hub_name = dataset_name.split('/') + + # Login using e.g. 
`huggingface-cli login` to access this dataset + ds = load_dataset(dataset_name, cache_dir=local_dir) + + # Print information for verification + print(f"User Name: {user_name}") + print(f"Model Hub Name: {model_hub_name}") + print(f"Dataset saved to: {local_dir}") + print(f"Dataset info: {ds}") + + +if __name__ == "__main__": + # Set up argument parser + parser = argparse.ArgumentParser(description="Download dataset from Hugging Face") + parser.add_argument('--config_path', + type=str, + default='configs/datasets_info.yaml', + help='Path to the dataset configuration YAML file') + + args = parser.parse_args() + + # Load configuration from YAML file + configs = load_config(args.config_path) + + # Iterate through the list of configurations + for config in configs: + # Download dataset if platform is HuggingFace + if config['platform'] == 'HuggingFace': + download_huggingface_dataset(config) + else: + print(f"Unsupported platform: {config['platform']}") \ No newline at end of file diff --git a/scripts/setup_third_party.py b/scripts/setup_third_party.py new file mode 100644 index 0000000000000000000000000000000000000000..c4df051919ba821f4b8e7fb09566c5c9bf06aab7 --- /dev/null +++ b/scripts/setup_third_party.py @@ -0,0 +1,38 @@ +import os +import shutil +import subprocess +import argparse + +def setup_diffusers(target_dir): + # Define paths + diffusers_dir = os.path.join(target_dir, "diffusers") + + # Create third_party directory if it doesn't exist + os.makedirs(target_dir, exist_ok=True) + + # Check if diffusers already exists in third_party + if os.path.exists(diffusers_dir): + print(f"Diffusers already exists in {target_dir}. Skipping clone.") + return + + # Clone diffusers repository + subprocess.run(["git", "clone", "https://github.com/huggingface/diffusers"], + cwd=target_dir, check=True) + + # Change to diffusers directory and install + original_dir = os.getcwd() + os.chdir(diffusers_dir) + try: + subprocess.run(["pip", "install", "-e", "."], check=True) + finally: + os.chdir(original_dir) + + print(f"Diffusers successfully cloned and installed to {diffusers_dir}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Setup diffusers in a specified directory.") + parser.add_argument("--target-dir", type=str, default="src/third_party", + help="Target directory to clone diffusers into (default: src)") + + args = parser.parse_args() + setup_diffusers(args.target_dir) diff --git a/src/controlnet_image_generator/__init__.py b/src/controlnet_image_generator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/controlnet_image_generator/infer.py b/src/controlnet_image_generator/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..3c95fffda1725f7b699b412c997957f99c61996d --- /dev/null +++ b/src/controlnet_image_generator/infer.py @@ -0,0 +1,135 @@ +import torch +import argparse +import os +import sys +# Add the project root directory to the Python path +sys.path.append(os.path.abspath(os.path.dirname(__file__))) + +from inference.config_loader import load_config, find_config_by_model_id +from inference.model_initializer import ( + initialize_controlnet, + initialize_pipeline, + initialize_controlnet_detector +) +from inference.device_manager import setup_device +from inference.image_processor import load_input_image, detect_poses +from inference.image_generator import generate_images, save_images + +# Global variables to store models +global controlnet_detector, 
controlnet, pipe +controlnet_detector = None +controlnet = None +pipe = None + +def infer( + config_path, + input_image, + image_url, + prompt, + negative_prompt, + num_steps, + seed, + width, + height, + guidance_scale, + controlnet_conditioning_scale, + output_dir=None, + use_prompt_as_output_name=None, + save_output=False +): + global controlnet_detector, controlnet, pipe + + # Load configuration + configs = load_config(config_path) + + # Initialize models only if they are not already loaded + if controlnet_detector is None or controlnet is None or pipe is None: + controlnet_detector_config = find_config_by_model_id(configs, "lllyasviel/ControlNet") + controlnet_config = find_config_by_model_id(configs, + "danhtran2mind/Stable-Diffusion-2.1-Openpose-ControlNet") + pipeline_config = find_config_by_model_id(configs, + "stabilityai/stable-diffusion-2-1") + + controlnet_detector = initialize_controlnet_detector(controlnet_detector_config) + controlnet = initialize_controlnet(controlnet_config) + pipe = initialize_pipeline(controlnet, pipeline_config) + + # Setup device + device = setup_device(pipe) + + # Load and process image + demo_image = load_input_image(input_image, image_url) + poses = detect_poses(controlnet_detector, demo_image) + + # Generate images + generators = [torch.Generator(device="cpu").manual_seed(seed + i) for i in range(len(poses))] + output_images = generate_images( + pipe, + [prompt] * len(generators), + poses, + generators, + [negative_prompt] * len(generators), + num_steps, + guidance_scale, + controlnet_conditioning_scale, + width, + height + ) + + # Save images if required + if save_output: + save_images(output_images, output_dir, prompt, use_prompt_as_output_name) + + return output_images + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="ControlNet image generation with pose detection") + image_group = parser.add_mutually_exclusive_group(required=True) + image_group.add_argument("--input_image", type=str, default=None, + help="Path to local input image (default: tests/test_data/yoga1.jpg)") + image_group.add_argument("--image_url", type=str, default=None, + help="URL of input image (e.g., https://huggingface.co/datasets/YiYiXu/controlnet-testing/resolve/main/yoga1.jpeg)") + + parser.add_argument("--config_path", type=str, default="configs/model_ckpts.yaml", + help="Path to configuration YAML file") + parser.add_argument("--prompt", type=str, default="a man is doing yoga", + help="Text prompt for image generation") + parser.add_argument("--negative_prompt", type=str, + default="monochrome, lowres, bad anatomy, worst quality, low quality", + help="Negative prompt for image generation") + parser.add_argument("--num_steps", type=int, default=20, + help="Number of inference steps") + parser.add_argument("--seed", type=int, default=2, + help="Random seed for generation") + parser.add_argument("--width", type=int, default=512, + help="Width of the generated image") + parser.add_argument("--height", type=int, default=512, + help="Height of the generated image") + parser.add_argument("--guidance_scale", type=float, default=7.5, + help="Guidance scale for prompt adherence") + parser.add_argument("--controlnet_conditioning_scale", type=float, default=1.0, + help="ControlNet conditioning scale") + parser.add_argument("--output_dir", type=str, default="tests/test_data", + help="Directory to save generated images") + parser.add_argument("--use_prompt_as_output_name", action="store_true", + help="Use prompt as part of output image filename") + 
parser.add_argument("--save_output", action="store_true", + help="Save generated images to output directory") + + args = parser.parse_args() + infer( + config_path=args.config_path, + input_image=args.input_image, + image_url=args.image_url, + prompt=args.prompt, + negative_prompt=args.negative_prompt, + num_steps=args.num_steps, + seed=args.seed, + width=args.width, + height=args.height, + guidance_scale=args.guidance_scale, + controlnet_conditioning_scale=args.controlnet_conditioning_scale, + output_dir=args.output_dir, + use_prompt_as_output_name=args.use_prompt_as_output_name, + save_output=args.save_output + ) \ No newline at end of file diff --git a/src/controlnet_image_generator/inference/__init__.py b/src/controlnet_image_generator/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/controlnet_image_generator/inference/config_loader.py b/src/controlnet_image_generator/inference/config_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..fc9aa48409a7b9b8e41faa7bdeb1b07b144a6a74 --- /dev/null +++ b/src/controlnet_image_generator/inference/config_loader.py @@ -0,0 +1,14 @@ +import yaml + +def load_config(config_path): + try: + with open(config_path, 'r') as file: + return yaml.safe_load(file) + except Exception as e: + raise ValueError(f"Error loading config file: {e}") + +def find_config_by_model_id(configs, model_id): + for config in configs: + if config['model_id'] == model_id: + return config + raise ValueError(f"No configuration found for model_id: {model_id}") \ No newline at end of file diff --git a/src/controlnet_image_generator/inference/device_manager.py b/src/controlnet_image_generator/inference/device_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..231e1f10e6703227812a3c47266e268c3b346541 --- /dev/null +++ b/src/controlnet_image_generator/inference/device_manager.py @@ -0,0 +1,8 @@ +import torch + +def setup_device(pipe): + device = "cuda" if torch.cuda.is_available() else "cpu" + if device == "cuda": + pipe.enable_model_cpu_offload() + pipe.to(device) + return device \ No newline at end of file diff --git a/src/controlnet_image_generator/inference/image_generator.py b/src/controlnet_image_generator/inference/image_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..138842b4d0472810a8eeb4ee8d6b6f0a7385b761 --- /dev/null +++ b/src/controlnet_image_generator/inference/image_generator.py @@ -0,0 +1,28 @@ +import torch +import os +import re +import uuid +from tqdm import tqdm + +def generate_images(pipe, prompts, pose_images, generators, negative_prompts, num_steps, guidance_scale, controlnet_conditioning_scale, width, height): + return pipe( + prompts, + pose_images, + negative_prompt=negative_prompts, + generator=generators, + num_inference_steps=num_steps, + guidance_scale=guidance_scale, + controlnet_conditioning_scale=controlnet_conditioning_scale, + width=width, + height=height + ).images + +def save_images(images, output_dir, prompt, use_prompt_as_output_name, index_offset=0): + os.makedirs(output_dir, exist_ok=True) + for i, img in enumerate(tqdm(images, desc="Saving images")): + if use_prompt_as_output_name: + sanitized_prompt = re.sub(r'[^\w\s-]', '', prompt).replace(' ', '_').lower() + filename = f"{sanitized_prompt}_{i + index_offset}.png" + else: + filename = f"{uuid.uuid4()}_{i + index_offset}.png" + img.save(os.path.join(output_dir, filename)) \ No newline at end of file diff 
--git a/src/controlnet_image_generator/inference/image_processor.py b/src/controlnet_image_generator/inference/image_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbb87213bed11ff3e07f2b642facc99b299102cb
--- /dev/null
+++ b/src/controlnet_image_generator/inference/image_processor.py
@@ -0,0 +1,16 @@
+from PIL import Image
+from utils.download import load_image
+
+def load_input_image(input_image_path=None, image_url=None):
+    try:
+        if input_image_path:
+            return Image.open(input_image_path).convert("RGB")
+        elif image_url:
+            return load_image(image_url)
+        else:
+            raise ValueError("Either input_image or image_url must be provided")
+    except Exception as e:
+        raise ValueError(f"Error loading image: {e}")
+
+def detect_poses(controlnet_detector, image):
+    return [controlnet_detector(image)]
\ No newline at end of file
diff --git a/src/controlnet_image_generator/inference/model_initializer.py b/src/controlnet_image_generator/inference/model_initializer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88a8fa53ed8d482541294c5db9e4ab268e4aeca
--- /dev/null
+++ b/src/controlnet_image_generator/inference/model_initializer.py
@@ -0,0 +1,29 @@
+import torch
+from controlnet_aux import OpenposeDetector
+from diffusers import (
+    StableDiffusionControlNetPipeline,
+    ControlNetModel,
+    UniPCMultistepScheduler
+)
+
+def initialize_controlnet(config):
+    model_id = config['model_id']
+    local_dir = config.get('local_dir', model_id)
+    return ControlNetModel.from_pretrained(
+        local_dir if local_dir != model_id else model_id,
+        torch_dtype=torch.float16
+    )
+
+def initialize_pipeline(controlnet, config):
+    model_id = config['model_id']
+    local_dir = config.get('local_dir', model_id)
+    pipe = StableDiffusionControlNetPipeline.from_pretrained(
+        local_dir if local_dir != model_id else model_id,
+        controlnet=controlnet,
+        torch_dtype=torch.float16
+    )
+    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+    return pipe
+
+def initialize_controlnet_detector(config):
+    return OpenposeDetector.from_pretrained(config['model_id'])
\ No newline at end of file
diff --git a/src/controlnet_image_generator/old-infer.py b/src/controlnet_image_generator/old-infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4f1552f1673f76a2318a416a0a1db9f77fbbd53
--- /dev/null
+++ b/src/controlnet_image_generator/old-infer.py
@@ -0,0 +1,102 @@
+import cv2
+import torch
+from PIL import Image
+import numpy as np
+import yaml
+import argparse
+from controlnet_aux import OpenposeDetector
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
+from utils.download import load_image
+from utils.plot import image_grid
+
+def load_config(config_path):
+    with open(config_path, 'r') as file:
+        return yaml.safe_load(file)
+
+def initialize_controlnet(config):
+    model_id = config['model_id']
+    local_dir = config.get('local_dir', model_id)
+    return ControlNetModel.from_pretrained(
+        local_dir if local_dir != model_id else model_id,
+        torch_dtype=torch.float16
+    )
+
+def initialize_pipeline(controlnet, config):
+    model_id = config['model_id']
+    local_dir = config.get('local_dir', model_id)
+    pipe = StableDiffusionControlNetPipeline.from_pretrained(
+        local_dir if local_dir != model_id else model_id,
+        controlnet=controlnet,
+        torch_dtype=torch.float16
+    )
+    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+    return pipe
+
+def setup_device(pipe):
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    if device == "cuda":
+        pipe.enable_model_cpu_offload()
+    pipe.to(device)
+    return device
+
+def generate_images(pipe, prompts, pose_images, generators, negative_prompts, num_steps):
+    return pipe(
+        prompts,
+        pose_images,
+        negative_prompt=negative_prompts,
+        generator=generators,
+        num_inference_steps=num_steps
+    ).images
+
+def infer(args):
+    # Load configuration
+    configs = load_config(args.config_path)
+
+    # Initialize models
+    controlnet_detector = OpenposeDetector.from_pretrained(
+        configs[2]['model_id']  # lllyasviel/ControlNet
+    )
+    controlnet = initialize_controlnet(configs[0])
+    pipe = initialize_pipeline(controlnet, configs[1])
+
+    # Setup device
+    device = setup_device(pipe)
+
+    # Load and process image
+    demo_image = load_image(args.image_url)
+    poses = [controlnet_detector(demo_image)]
+
+    # Generate images
+    generators = [torch.Generator(device="cpu").manual_seed(args.seed) for _ in range(len(poses))]
+
+    output_images = generate_images(
+        pipe,
+        [args.prompt] * len(generators),
+        poses,
+        generators,
+        [args.negative_prompt] * len(generators),
+        args.num_steps
+    )
+
+    # Display results
+    # image_grid(output_images, 2, 2)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="ControlNet image generation with pose detection")
+    parser.add_argument("--config_path", type=str, default="configs/model_ckpts.yaml",
+                        help="Path to configuration YAML file")
+    parser.add_argument("--image_url", type=str,
+                        default="https://huggingface.co/datasets/YiYiXu/controlnet-testing/resolve/main/yoga1.jpeg",
+                        help="URL of input image")
+    parser.add_argument("--prompt", type=str, default="a man is doing yoga",
+                        help="Text prompt for image generation")
+    parser.add_argument("--negative_prompt", type=str,
+                        default="monochrome, lowres, bad anatomy, worst quality, low quality",
+                        help="Negative prompt for image generation")
+    parser.add_argument("--num_steps", type=int, default=20,
+                        help="Number of inference steps")
+    parser.add_argument("--seed", type=int, default=2,
+                        help="Random seed for generation")
+    # return parser.parse_args()
+    args = parser.parse_args()
+    infer(args)
\ No newline at end of file
diff --git a/src/controlnet_image_generator/old2-infer.py b/src/controlnet_image_generator/old2-infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e4d47ac07d555daa22718f7d7fb9dfa895b1c7d
--- /dev/null
+++ b/src/controlnet_image_generator/old2-infer.py
@@ -0,0 +1,189 @@
+import cv2
+import torch
+from PIL import Image
+import numpy as np
+import yaml
+import argparse
+from controlnet_aux import OpenposeDetector
+from diffusers import (
+    StableDiffusionControlNetPipeline,
+    ControlNetModel,
+    UniPCMultistepScheduler
+)
+
+from utils.download import load_image
+from utils.plot import image_grid
+import os
+from tqdm import tqdm
+import re
+import uuid
+
+def load_config(config_path):
+    try:
+        with open(config_path, 'r') as file:
+            return yaml.safe_load(file)
+    except Exception as e:
+        raise ValueError(f"Error loading config file: {e}")
+
+def initialize_controlnet(config):
+    model_id = config['model_id']
+    local_dir = config.get('local_dir', model_id)
+    return ControlNetModel.from_pretrained(
+        local_dir if local_dir != model_id else model_id,
+        torch_dtype=torch.float16
+    )
+
+def initialize_pipeline(controlnet, config):
+    model_id = config['model_id']
+    local_dir = config.get('local_dir', model_id)
+    pipe = StableDiffusionControlNetPipeline.from_pretrained(
+        local_dir if local_dir != model_id else model_id,
+        controlnet=controlnet,
+        torch_dtype=torch.float16
+    )
+    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+    return pipe
+
+def setup_device(pipe):
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    if device == "cuda":
+        pipe.enable_model_cpu_offload()
+    pipe.to(device)
+    return device
+
+def generate_images(pipe, prompts, pose_images, generators, negative_prompts, num_steps, guidance_scale, controlnet_conditioning_scale, width, height):
+    return pipe(
+        prompts,
+        pose_images,
+        negative_prompt=negative_prompts,
+        generator=generators,
+        num_inference_steps=num_steps,
+        guidance_scale=guidance_scale,
+        controlnet_conditioning_scale=controlnet_conditioning_scale,
+        width=width,
+        height=height
+    ).images
+
+def infer(args):
+    # Load configuration
+    configs = load_config(args.config_path)
+
+    # Initialize models
+    controlnet_detector = OpenposeDetector.from_pretrained(
+        configs[2]['model_id']  # lllyasviel/ControlNet
+    )
+    controlnet = initialize_controlnet(configs[0])
+    pipe = initialize_pipeline(controlnet, configs[1])
+
+    # Setup device
+    device = setup_device(pipe)
+
+    # Load and process image
+    try:
+        if args.input_image:
+            demo_image = Image.open(args.input_image).convert("RGB")
+        elif args.image_url:
+            demo_image = load_image(args.image_url)
+        else:
+            raise ValueError("Either --input_image or --image_url must be provided")
+    except Exception as e:
+        raise ValueError(f"Error loading image: {e}")
+
+    poses = [controlnet_detector(demo_image)]
+
+    # Generate images
+    generators = [torch.Generator(device="cpu").manual_seed(args.seed + i) for i in range(len(poses))]
+
+    output_images = generate_images(
+        pipe,
+        [args.prompt] * len(generators),
+        poses,
+        generators,
+        [args.negative_prompt] * len(generators),
+        args.num_steps,
+        args.guidance_scale,
+        args.controlnet_conditioning_scale,
+        args.width,
+        args.height
+    )
+
+    # Save images if save_output is True
+    if args.save_output:
+        os.makedirs(args.output_dir, exist_ok=True)
+        for i, img in enumerate(tqdm(output_images, desc="Saving images")):
+            if args.use_prompt_as_output_name:
+                # Sanitize prompt for filename (replace spaces and special characters)
+                sanitized_prompt = re.sub(r'[^\w\s-]', '', args.prompt).replace(' ', '_').lower()
+                filename = f"{sanitized_prompt}_{i}.png"
+            else:
+                # Use UUID for filename
+                filename = f"{uuid.uuid4()}_{i}.png"
+            img.save(os.path.join(args.output_dir, filename))
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="ControlNet image generation with pose detection")
+    # Create mutually exclusive group for input_image and image_url
+    image_group = parser.add_mutually_exclusive_group(required=True)
+    image_group.add_argument("--input_image", type=str, default=None,
+                             help="Path to local input image (default: tests/test_data/yoga1.jpg)")
+    image_group.add_argument("--image_url", type=str, default=None,
+                             help="URL of input image (e.g., https://huggingface.co/datasets/YiYiXu/controlnet-testing/resolve/main/yoga1.jpeg)")
+
+    parser.add_argument("--config_path", type=str, default="configs/model_ckpts.yaml",
+                        help="Path to configuration YAML file")
+    parser.add_argument("--prompt", type=str, default="a man is doing yoga",
+                        help="Text prompt for image generation")
+    parser.add_argument("--negative_prompt", type=str,
+                        default="monochrome, lowres, bad anatomy, worst quality, low quality",
+                        help="Negative prompt for image generation")
+    parser.add_argument("--num_steps", type=int, default=20,
+                        help="Number of inference steps")
+    parser.add_argument("--seed", type=int, default=2,
+                        help="Random seed for generation")
+    parser.add_argument("--width", type=int, default=512,
+                        help="Width of the generated image")
+    parser.add_argument("--height", type=int, default=512,
+                        help="Height of the generated image")
+    parser.add_argument("--guidance_scale", type=float, default=7.5,
+                        help="Guidance scale for prompt adherence")
+    parser.add_argument("--controlnet_conditioning_scale", type=float, default=1.0,
+                        help="ControlNet conditioning scale")
+    parser.add_argument("--output_dir", type=str, default="tests/test_data",
+                        help="Directory to save generated images")
+    parser.add_argument("--use_prompt_as_output_name", action="store_true",
+                        help="Use prompt as part of output image filename")
+    parser.add_argument("--save_output", action="store_true",
+                        help="Save generated images to output directory")
+
+    args = parser.parse_args()
+    infer(args)
+
+# Using image_url
+# python script.py \
+# --config_path configs/model_ckpts.yaml \
+# --image_url https://huggingface.co/datasets/YiYiXu/controlnet-testing/resolve/main/yoga1.jpeg \
+# --prompt "a man is doing yoga in a serene park" \
+# --negative_prompt "monochrome, lowres, bad anatomy" \
+# --num_steps 30 \
+# --seed 42 \
+# --width 512 \
+# --height 512 \
+# --guidance_scale 7.5 \
+# --controlnet_conditioning_scale 0.8 \
+# --output_dir "tests/test_data" \
+# --save_output
+
+# Using input_image
+# python script.py \
+# --config_path configs/model_ckpts.yaml \
+# --input_image "tests/test_data/yoga1.jpg" \
+# --prompt "a man is doing yoga in a serene park" \
+# --negative_prompt "monochrome, lowres, bad anatomy" \
+# --num_steps 30 \
+# --seed 42 \
+# --width 512 \
+# --height 512 \
+# --guidance_scale 7.5 \
+# --controlnet_conditioning_scale 0.8 \
+# --output_dir "tests/test_data" \
+# --save_output
\ No newline at end of file
diff --git a/src/controlnet_image_generator/old3-infer.py b/src/controlnet_image_generator/old3-infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..39cb043abd2dab07796976d23f3c4797d214880f
--- /dev/null
+++ b/src/controlnet_image_generator/old3-infer.py
@@ -0,0 +1,119 @@
+import torch
+import argparse
+from inference.config_loader import load_config, find_config_by_model_id
+from inference.model_initializer import (
+    initialize_controlnet,
+    initialize_pipeline,
+    initialize_controlnet_detector
+)
+from inference.device_manager import setup_device
+from inference.image_processor import load_input_image, detect_poses
+from inference.image_generator import generate_images, save_images
+
+def infer(
+    config_path,
+    input_image,
+    image_url,
+    prompt,
+    negative_prompt,
+    num_steps,
+    seed,
+    width,
+    height,
+    guidance_scale,
+    controlnet_conditioning_scale,
+    output_dir,
+    use_prompt_as_output_name,
+    save_output
+):
+    # Load configuration
+    configs = load_config(config_path)
+
+    # Initialize models
+    controlnet_detector_config = find_config_by_model_id(configs, "lllyasviel/ControlNet")
+    controlnet_config = find_config_by_model_id(configs,
+                                                "danhtran2mind/Stable-Diffusion-2.1-Openpose-ControlNet")
+    pipeline_config = find_config_by_model_id(configs,
+                                              "stabilityai/stable-diffusion-2-1")
+
+    controlnet_detector = initialize_controlnet_detector(controlnet_detector_config)
+    controlnet = initialize_controlnet(controlnet_config)
+    pipe = initialize_pipeline(controlnet, pipeline_config)
+
+    # Setup device
+    device = setup_device(pipe)
+
+    # Load and process image
+    demo_image = load_input_image(input_image, image_url)
+    poses = detect_poses(controlnet_detector, demo_image)
+
+    # Generate images
+    generators = [torch.Generator(device="cpu").manual_seed(seed + i) for i in range(len(poses))]
+    output_images = generate_images(
+        pipe,
+        [prompt] * len(generators),
+        poses,
+        generators,
+        [negative_prompt] * len(generators),
+        num_steps,
+        guidance_scale,
+        controlnet_conditioning_scale,
+        width,
+        height
+    )
+
+    # Save images if required
+    if save_output:
+        save_images(output_images, output_dir, prompt, use_prompt_as_output_name)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="ControlNet image generation with pose detection")
+    image_group = parser.add_mutually_exclusive_group(required=True)
+    image_group.add_argument("--input_image", type=str, default=None,
+                             help="Path to local input image (default: tests/test_data/yoga1.jpg)")
+    image_group.add_argument("--image_url", type=str, default=None,
+                             help="URL of input image (e.g., https://huggingface.co/datasets/YiYiXu/controlnet-testing/resolve/main/yoga1.jpeg)")
+
+    parser.add_argument("--config_path", type=str, default="configs/model_ckpts.yaml",
+                        help="Path to configuration YAML file")
+    parser.add_argument("--prompt", type=str, default="a man is doing yoga",
+                        help="Text prompt for image generation")
+    parser.add_argument("--negative_prompt", type=str,
+                        default="monochrome, lowres, bad anatomy, worst quality, low quality",
+                        help="Negative prompt for image generation")
+    parser.add_argument("--num_steps", type=int, default=20,
+                        help="Number of inference steps")
+    parser.add_argument("--seed", type=int, default=2,
+                        help="Random seed for generation")
+    parser.add_argument("--width", type=int, default=512,
+                        help="Width of the generated image")
+    parser.add_argument("--height", type=int, default=512,
+                        help="Height of the generated image")
+    parser.add_argument("--guidance_scale", type=float, default=7.5,
+                        help="Guidance scale for prompt adherence")
+    parser.add_argument("--controlnet_conditioning_scale", type=float, default=1.0,
+                        help="ControlNet conditioning scale")
+    parser.add_argument("--output_dir", type=str, default="tests/test_data",
+                        help="Directory to save generated images")
+    parser.add_argument("--use_prompt_as_output_name", action="store_true",
+                        help="Use prompt as part of output image filename")
+    parser.add_argument("--save_output", action="store_true",
+                        help="Save generated images to output directory")
+
+    args = parser.parse_args()
+    infer(
+        config_path=args.config_path,
+        input_image=args.input_image,
+        image_url=args.image_url,
+        prompt=args.prompt,
+        negative_prompt=args.negative_prompt,
+        num_steps=args.num_steps,
+        seed=args.seed,
+        width=args.width,
+        height=args.height,
+        guidance_scale=args.guidance_scale,
+        controlnet_conditioning_scale=args.controlnet_conditioning_scale,
+        output_dir=args.output_dir,
+        use_prompt_as_output_name=args.use_prompt_as_output_name,
+        save_output=args.save_output
+    )
\ No newline at end of file
diff --git a/src/controlnet_image_generator/train.py b/src/controlnet_image_generator/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbd46432bf7a8c8d9f156fb2ced776ee82cb6a6d
--- /dev/null
+++ b/src/controlnet_image_generator/train.py
@@ -0,0 +1,22 @@
+import os
+import sys
+import subprocess
+
+def run_controlnet_training(args):
+    """Run train_controlnet.py with the provided command-line arguments."""
+    # Path to train_controlnet.py
+    controlnet_script = os.path.join(os.path.dirname(__file__), "..",
+                                     "third_party", "diffusers", "examples", "controlnet", "train_controlnet.py")
+
+    # Construct the command: python + script path + arguments
+    command = [sys.executable, controlnet_script] + args
+
+    # Run the command
+    try:
+        subprocess.run(command, check=True)
+    except subprocess.CalledProcessError as e:
+        print(f"Error running train_controlnet.py: {e}", file=sys.stderr)
+        sys.exit(e.returncode)
+
+if __name__ == "__main__":
+    run_controlnet_training(sys.argv[1:])
\ No newline at end of file
diff --git a/src/controlnet_image_generator/utils/__init__.py b/src/controlnet_image_generator/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/controlnet_image_generator/utils/download.py b/src/controlnet_image_generator/utils/download.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d276bfc8c9436bd30920a72cb79b70a3e22ef71
--- /dev/null
+++ b/src/controlnet_image_generator/utils/download.py
@@ -0,0 +1,45 @@
+import os
+import PIL
+import requests
+
+def load_image(
+    image, convert_method=None) -> PIL.Image.Image:
+    """
+    Loads `image` to a PIL Image.
+
+    Args:
+        image (`str` or `PIL.Image.Image`):
+            The image to convert to the PIL Image format.
+        convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*):
+            A conversion method to apply to the image after loading it. When set to `None` the image will be converted
+            to "RGB".
+
+    Returns:
+        `PIL.Image.Image`:
+            A PIL Image.
+    """
+    if isinstance(image, str):
+        if image.startswith("http://") or image.startswith("https://"):
+            image = PIL.Image.open(requests.get(image, stream=True, timeout=200).raw)
+        elif os.path.isfile(image):
+            image = PIL.Image.open(image)
+        else:
+            raise ValueError(
+                f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path."
+            )
+    elif isinstance(image, PIL.Image.Image):
+        image = image
+    else:
+        raise ValueError(
+            "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image."
+        )
+
+    image = PIL.ImageOps.exif_transpose(image)
+
+    if convert_method is not None:
+        image = convert_method(image)
+    else:
+        image = image.convert("RGB")
+
+    return image
+
diff --git a/src/controlnet_image_generator/utils/plot.py b/src/controlnet_image_generator/utils/plot.py
new file mode 100644
index 0000000000000000000000000000000000000000..94bcddbf57f516a4dc78c75dee2b78f610b8983d
--- /dev/null
+++ b/src/controlnet_image_generator/utils/plot.py
@@ -0,0 +1,13 @@
+
+from PIL import Image
+
+def image_grid(imgs, rows, cols):
+    assert len(imgs) == rows * cols
+
+    w, h = imgs[0].size
+    grid = Image.new("RGB", size=(cols * w, rows * h))
+    grid_w, grid_h = grid.size
+
+    for i, img in enumerate(imgs):
+        grid.paste(img, box=(i % cols * w, i // cols * h))
+    return grid
diff --git a/src/third_party/.gitkeep b/src/third_party/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tests/.gitkeep b/tests/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tests/test_data/a_man_is_doing_yoga_in_a_serene_park_0.png b/tests/test_data/a_man_is_doing_yoga_in_a_serene_park_0.png
new file mode 100644
index 0000000000000000000000000000000000000000..9e8f4643b712061b6d5db7286b6a32fa8cf68f87
--- /dev/null
+++ b/tests/test_data/a_man_is_doing_yoga_in_a_serene_park_0.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3dc2b7efb61afd2d6ceda1b32ec9792a5b07f3ac3d7a96d7acdd2102ddb957b7
+size 367280
diff --git a/tests/test_data/yoga.jpg b/tests/test_data/yoga.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c5f35baf0515a9b66e6388d5c44bb337d6f9366c
Binary files /dev/null and b/tests/test_data/yoga.jpg differ
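For reference, a minimal sketch of exercising the utility helpers above (load_image and image_grid). It assumes the working directory is src/controlnet_image_generator, since the scripts use top-level imports such as `from utils.download import load_image`; the output filename is illustrative only.

from utils.download import load_image
from utils.plot import image_grid

# Fetch the default demo image used by the inference scripts; load_image converts it to RGB.
img = load_image("https://huggingface.co/datasets/YiYiXu/controlnet-testing/resolve/main/yoga1.jpeg")

# Tile four copies into a 2x2 grid; image_grid asserts len(imgs) == rows * cols.
grid = image_grid([img] * 4, rows=2, cols=2)
grid.save("grid_preview.png")  # illustrative output path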