Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Update to the original Space
Browse files
- README.md +1 -1
- app.py +83 -58
- gradio_canny2image.py → app_canny.py +50 -33
- gradio_depth2image.py → app_depth.py +40 -22
- gradio_fake_scribble2image.py → app_fake_scribble.py +37 -22
- gradio_hed2image.py → app_hed.py +37 -22
- gradio_hough2image.py → app_hough.py +41 -25
- gradio_normal2image.py → app_normal.py +41 -23
- gradio_pose2image.py → app_pose.py +43 -22
- gradio_scribble2image.py → app_scribble.py +36 -22
- gradio_scribble2image_interactive.py → app_scribble_interactive.py +36 -22
- gradio_seg2image.py → app_seg.py +40 -22
- model.py +564 -698
- requirements.txt +3 -2
    	
        README.md
    CHANGED
    
    | @@ -4,7 +4,7 @@ emoji: π» | |
| 4 | 
             
            colorFrom: pink
         | 
| 5 | 
             
            colorTo: blue
         | 
| 6 | 
             
            sdk: gradio
         | 
| 7 | 
            -
            sdk_version: 3. | 
| 8 | 
             
            python_version: 3.10.9
         | 
| 9 | 
             
            app_file: app.py
         | 
| 10 | 
             
            pinned: false
         | 
|  | |
| 4 | 
             
            colorFrom: pink
         | 
| 5 | 
             
            colorTo: blue
         | 
| 6 | 
             
            sdk: gradio
         | 
| 7 | 
            +
            sdk_version: 3.20.0
         | 
| 8 | 
             
            python_version: 3.10.9
         | 
| 9 | 
             
            app_file: app.py
         | 
| 10 | 
             
            pinned: false
         | 
    	
        app.py
    CHANGED
    
    | @@ -30,92 +30,117 @@ for name in names: | |
| 30 | 
             
                    continue
         | 
| 31 | 
             
                subprocess.run(shlex.split(command), cwd='ControlNet/annotator/ckpts/')
         | 
| 32 |  | 
| 33 | 
            -
            from  | 
| 34 | 
            -
            from  | 
| 35 | 
            -
            from  | 
| 36 | 
            -
            from  | 
| 37 | 
            -
            from  | 
| 38 | 
            -
            from  | 
| 39 | 
            -
            from  | 
| 40 | 
            -
            from  | 
| 41 | 
            -
            from  | 
| 42 | 
             
                create_demo as create_demo_scribble_interactive
         | 
| 43 | 
            -
            from  | 
| 44 | 
            -
            from model import  | 
| 45 | 
            -
                               DEFAULT_BASE_MODEL_URL, Model)
         | 
| 46 |  | 
| 47 | 
            -
             | 
| 48 | 
            -
            DESCRIPTION = '''# [ControlNet](https://github.com/lllyasviel/ControlNet)
         | 
| 49 | 
            -
             | 
| 50 | 
            -
            This Space is a modified version of [this Space](https://huggingface.co/spaces/hysts/ControlNet).
         | 
| 51 | 
            -
            The original Space uses [Stable Diffusion v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) as the base model, but [Anything v4.0](https://huggingface.co/andite/anything-v4.0) is used in this Space.
         | 
| 52 | 
            -
            '''
         | 
| 53 |  | 
| 54 | 
             
            SPACE_ID = os.getenv('SPACE_ID')
         | 
| 55 | 
            -
            ALLOW_CHANGING_BASE_MODEL = SPACE_ID != 'hysts/ControlNet | 
| 56 | 
            -
             | 
| 57 | 
            -
            if not ALLOW_CHANGING_BASE_MODEL:
         | 
| 58 | 
            -
                DESCRIPTION += 'In this Space, the base model is not allowed to be changed so as not to slow down the demo, but it can be changed if you duplicate the Space.'
         | 
| 59 |  | 
| 60 | 
             
            if SPACE_ID is not None:
         | 
| 61 | 
            -
                DESCRIPTION += f'''<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings | 
| 62 | 
            -
            <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true">
         | 
| 63 | 
            -
            <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
         | 
| 64 | 
            -
            <p/>
         | 
| 65 | 
             
            '''
         | 
| 66 |  | 
| 67 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 68 |  | 
| 69 | 
             
            with gr.Blocks(css='style.css') as demo:
         | 
| 70 | 
             
                gr.Markdown(DESCRIPTION)
         | 
| 71 | 
            -
             | 
| 72 | 
             
                with gr.Tabs():
         | 
| 73 | 
             
                    with gr.TabItem('Canny'):
         | 
| 74 | 
            -
                        create_demo_canny(model.process_canny, | 
|  | |
|  | |
| 75 | 
             
                    with gr.TabItem('Hough'):
         | 
| 76 | 
            -
                        create_demo_hough(model.process_hough, | 
|  | |
|  | |
| 77 | 
             
                    with gr.TabItem('HED'):
         | 
| 78 | 
            -
                        create_demo_hed(model.process_hed, | 
|  | |
|  | |
| 79 | 
             
                    with gr.TabItem('Scribble'):
         | 
| 80 | 
            -
                        create_demo_scribble(model.process_scribble, | 
|  | |
|  | |
| 81 | 
             
                    with gr.TabItem('Scribble Interactive'):
         | 
| 82 | 
             
                        create_demo_scribble_interactive(
         | 
| 83 | 
            -
                            model.process_scribble_interactive, | 
|  | |
|  | |
| 84 | 
             
                    with gr.TabItem('Fake Scribble'):
         | 
| 85 | 
             
                        create_demo_fake_scribble(model.process_fake_scribble,
         | 
| 86 | 
            -
                                                  max_images=MAX_IMAGES | 
|  | |
| 87 | 
             
                    with gr.TabItem('Pose'):
         | 
| 88 | 
            -
                        create_demo_pose(model.process_pose, | 
|  | |
|  | |
| 89 | 
             
                    with gr.TabItem('Segmentation'):
         | 
| 90 | 
            -
                        create_demo_seg(model.process_seg, | 
|  | |
|  | |
| 91 | 
             
                    with gr.TabItem('Depth'):
         | 
| 92 | 
            -
                        create_demo_depth(model.process_depth, | 
|  | |
|  | |
| 93 | 
             
                    with gr.TabItem('Normal map'):
         | 
| 94 | 
            -
                        create_demo_normal(model.process_normal, | 
|  | |
|  | |
| 95 |  | 
| 96 | 
             
                with gr.Accordion(label='Base model', open=False):
         | 
| 97 | 
            -
                    current_base_model = gr.Text(label='Current base model',
         | 
| 98 | 
            -
                                                 value=DEFAULT_BASE_MODEL_URL)
         | 
| 99 | 
             
                    with gr.Row():
         | 
| 100 | 
            -
                         | 
| 101 | 
            -
             | 
| 102 | 
            -
             | 
| 103 | 
            -
             | 
| 104 | 
            -
             | 
| 105 | 
            -
             | 
| 106 | 
            -
                             | 
| 107 | 
            -
             | 
| 108 | 
            -
             | 
| 109 | 
            -
             | 
| 110 | 
            -
             | 
| 111 | 
            -
             | 
| 112 | 
            -
             | 
| 113 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 114 | 
             
                change_base_model_button.click(fn=model.set_base_model,
         | 
| 115 | 
            -
                                               inputs= | 
| 116 | 
            -
                                                   base_model_repo,
         | 
| 117 | 
            -
                                                   base_model_filename,
         | 
| 118 | 
            -
                                               ],
         | 
| 119 | 
             
                                               outputs=current_base_model)
         | 
| 120 |  | 
| 121 | 
             
            demo.queue(api_open=False).launch()
         | 
|  | |
| 30 | 
             
                    continue
         | 
| 31 | 
             
                subprocess.run(shlex.split(command), cwd='ControlNet/annotator/ckpts/')
         | 
| 32 |  | 
| 33 | 
            +
            from app_canny import create_demo as create_demo_canny
         | 
| 34 | 
            +
            from app_depth import create_demo as create_demo_depth
         | 
| 35 | 
            +
            from app_fake_scribble import create_demo as create_demo_fake_scribble
         | 
| 36 | 
            +
            from app_hed import create_demo as create_demo_hed
         | 
| 37 | 
            +
            from app_hough import create_demo as create_demo_hough
         | 
| 38 | 
            +
            from app_normal import create_demo as create_demo_normal
         | 
| 39 | 
            +
            from app_pose import create_demo as create_demo_pose
         | 
| 40 | 
            +
            from app_scribble import create_demo as create_demo_scribble
         | 
| 41 | 
            +
            from app_scribble_interactive import \
         | 
| 42 | 
             
                create_demo as create_demo_scribble_interactive
         | 
| 43 | 
            +
            from app_seg import create_demo as create_demo_seg
         | 
| 44 | 
            +
            from model import Model, download_all_controlnet_weights
         | 
|  | |
| 45 |  | 
| 46 | 
            +
            DESCRIPTION = '# [ControlNet](https://github.com/lllyasviel/ControlNet)'
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 47 |  | 
| 48 | 
             
            SPACE_ID = os.getenv('SPACE_ID')
         | 
| 49 | 
            +
            ALLOW_CHANGING_BASE_MODEL = SPACE_ID != 'hysts/ControlNet'
         | 
|  | |
|  | |
|  | |
| 50 |  | 
| 51 | 
             
            if SPACE_ID is not None:
         | 
| 52 | 
            +
                DESCRIPTION += f'''<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>
         | 
|  | |
|  | |
|  | |
| 53 | 
             
            '''
         | 
| 54 |  | 
| 55 | 
            +
            MAX_IMAGES = int(os.getenv('MAX_IMAGES', '3'))
         | 
| 56 | 
            +
            DEFAULT_NUM_IMAGES = min(MAX_IMAGES, int(os.getenv('DEFAULT_NUM_IMAGES', '1')))
         | 
| 57 | 
            +
             | 
| 58 | 
            +
            if os.getenv('SYSTEM') == 'spaces':
         | 
| 59 | 
            +
                download_all_controlnet_weights()
         | 
| 60 | 
            +
             | 
| 61 | 
            +
            DEFAULT_MODEL_ID = os.getenv('DEFAULT_MODEL_ID',
         | 
| 62 | 
            +
                                         'runwayml/stable-diffusion-v1-5')
         | 
| 63 | 
            +
            model = Model(base_model_id=DEFAULT_MODEL_ID, task_name='canny')
         | 
| 64 |  | 
| 65 | 
             
            with gr.Blocks(css='style.css') as demo:
         | 
| 66 | 
             
                gr.Markdown(DESCRIPTION)
         | 
|  | |
| 67 | 
             
                with gr.Tabs():
         | 
| 68 | 
             
                    with gr.TabItem('Canny'):
         | 
| 69 | 
            +
                        create_demo_canny(model.process_canny,
         | 
| 70 | 
            +
                                          max_images=MAX_IMAGES,
         | 
| 71 | 
            +
                                          default_num_images=DEFAULT_NUM_IMAGES)
         | 
| 72 | 
             
                    with gr.TabItem('Hough'):
         | 
| 73 | 
            +
                        create_demo_hough(model.process_hough,
         | 
| 74 | 
            +
                                          max_images=MAX_IMAGES,
         | 
| 75 | 
            +
                                          default_num_images=DEFAULT_NUM_IMAGES)
         | 
| 76 | 
             
                    with gr.TabItem('HED'):
         | 
| 77 | 
            +
                        create_demo_hed(model.process_hed,
         | 
| 78 | 
            +
                                        max_images=MAX_IMAGES,
         | 
| 79 | 
            +
                                        default_num_images=DEFAULT_NUM_IMAGES)
         | 
| 80 | 
             
                    with gr.TabItem('Scribble'):
         | 
| 81 | 
            +
                        create_demo_scribble(model.process_scribble,
         | 
| 82 | 
            +
                                             max_images=MAX_IMAGES,
         | 
| 83 | 
            +
                                             default_num_images=DEFAULT_NUM_IMAGES)
         | 
| 84 | 
             
                    with gr.TabItem('Scribble Interactive'):
         | 
| 85 | 
             
                        create_demo_scribble_interactive(
         | 
| 86 | 
            +
                            model.process_scribble_interactive,
         | 
| 87 | 
            +
                            max_images=MAX_IMAGES,
         | 
| 88 | 
            +
                            default_num_images=DEFAULT_NUM_IMAGES)
         | 
| 89 | 
             
                    with gr.TabItem('Fake Scribble'):
         | 
| 90 | 
             
                        create_demo_fake_scribble(model.process_fake_scribble,
         | 
| 91 | 
            +
                                                  max_images=MAX_IMAGES,
         | 
| 92 | 
            +
                                                  default_num_images=DEFAULT_NUM_IMAGES)
         | 
| 93 | 
             
                    with gr.TabItem('Pose'):
         | 
| 94 | 
            +
                        create_demo_pose(model.process_pose,
         | 
| 95 | 
            +
                                         max_images=MAX_IMAGES,
         | 
| 96 | 
            +
                                         default_num_images=DEFAULT_NUM_IMAGES)
         | 
| 97 | 
             
                    with gr.TabItem('Segmentation'):
         | 
| 98 | 
            +
                        create_demo_seg(model.process_seg,
         | 
| 99 | 
            +
                                        max_images=MAX_IMAGES,
         | 
| 100 | 
            +
                                        default_num_images=DEFAULT_NUM_IMAGES)
         | 
| 101 | 
             
                    with gr.TabItem('Depth'):
         | 
| 102 | 
            +
                        create_demo_depth(model.process_depth,
         | 
| 103 | 
            +
                                          max_images=MAX_IMAGES,
         | 
| 104 | 
            +
                                          default_num_images=DEFAULT_NUM_IMAGES)
         | 
| 105 | 
             
                    with gr.TabItem('Normal map'):
         | 
| 106 | 
            +
                        create_demo_normal(model.process_normal,
         | 
| 107 | 
            +
                                           max_images=MAX_IMAGES,
         | 
| 108 | 
            +
                                           default_num_images=DEFAULT_NUM_IMAGES)
         | 
| 109 |  | 
| 110 | 
             
                with gr.Accordion(label='Base model', open=False):
         | 
|  | |
|  | |
| 111 | 
             
                    with gr.Row():
         | 
| 112 | 
            +
                        with gr.Column():
         | 
| 113 | 
            +
                            current_base_model = gr.Text(label='Current base model')
         | 
| 114 | 
            +
                        with gr.Column(scale=0.3):
         | 
| 115 | 
            +
                            check_base_model_button = gr.Button('Check current base model')
         | 
| 116 | 
            +
                    with gr.Row():
         | 
| 117 | 
            +
                        with gr.Column():
         | 
| 118 | 
            +
                            new_base_model_id = gr.Text(
         | 
| 119 | 
            +
                                label='New base model',
         | 
| 120 | 
            +
                                max_lines=1,
         | 
| 121 | 
            +
                                placeholder='runwayml/stable-diffusion-v1-5',
         | 
| 122 | 
            +
                                info=
         | 
| 123 | 
            +
                                'The base model must be compatible with Stable Diffusion v1.5.',
         | 
| 124 | 
            +
                                interactive=ALLOW_CHANGING_BASE_MODEL)
         | 
| 125 | 
            +
                        with gr.Column(scale=0.3):
         | 
| 126 | 
            +
                            change_base_model_button = gr.Button('Change base model')
         | 
| 127 | 
            +
                    if not ALLOW_CHANGING_BASE_MODEL:
         | 
| 128 | 
            +
                        gr.Markdown(
         | 
| 129 | 
            +
                            '''The base model is not allowed to be changed in this Space so as not to slow down the demo, but it can be changed if you duplicate the Space. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a>'''
         | 
| 130 | 
            +
                        )
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                gr.Markdown(
         | 
| 133 | 
            +
                    '[Space using Anything-v4.0 as base model](https://huggingface.co/spaces/hysts/ControlNet-with-other-models)'
         | 
| 134 | 
            +
                )
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                check_base_model_button.click(fn=lambda: model.base_model_id,
         | 
| 137 | 
            +
                                              outputs=current_base_model,
         | 
| 138 | 
            +
                                              queue=False)
         | 
| 139 | 
            +
                new_base_model_id.submit(fn=model.set_base_model,
         | 
| 140 | 
            +
                                         inputs=new_base_model_id,
         | 
| 141 | 
            +
                                         outputs=current_base_model)
         | 
| 142 | 
             
                change_base_model_button.click(fn=model.set_base_model,
         | 
| 143 | 
            +
                                               inputs=new_base_model_id,
         | 
|  | |
|  | |
|  | |
| 144 | 
             
                                               outputs=current_base_model)
         | 
| 145 |  | 
| 146 | 
             
            demo.queue(api_open=False).launch()
         | 
    	
        gradio_canny2image.py → app_canny.py
    RENAMED
    
    | @@ -3,7 +3,7 @@ | |
| 3 | 
             
            import gradio as gr
         | 
| 4 |  | 
| 5 |  | 
| 6 | 
            -
            def create_demo(process, max_images=12):
         | 
| 7 | 
             
                with gr.Blocks() as demo:
         | 
| 8 | 
             
                    with gr.Row():
         | 
| 9 | 
             
                        gr.Markdown('## Control Stable Diffusion with Canny Edge Maps')
         | 
| @@ -16,39 +16,40 @@ def create_demo(process, max_images=12): | |
| 16 | 
             
                                num_samples = gr.Slider(label='Images',
         | 
| 17 | 
             
                                                        minimum=1,
         | 
| 18 | 
             
                                                        maximum=max_images,
         | 
| 19 | 
            -
                                                        value= | 
| 20 | 
             
                                                        step=1)
         | 
| 21 | 
             
                                image_resolution = gr.Slider(label='Image Resolution',
         | 
| 22 | 
             
                                                             minimum=256,
         | 
| 23 | 
             
                                                             maximum=768,
         | 
| 24 | 
             
                                                             value=512,
         | 
| 25 | 
             
                                                             step=256)
         | 
| 26 | 
            -
                                 | 
| 27 | 
            -
             | 
| 28 | 
            -
             | 
| 29 | 
            -
             | 
| 30 | 
            -
             | 
| 31 | 
            -
             | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
             | 
| 35 | 
            -
             | 
| 36 | 
            -
             | 
| 37 | 
            -
             | 
| 38 | 
            -
             | 
| 39 | 
            -
             | 
| 40 | 
            -
             | 
| 41 | 
            -
             | 
| 42 | 
            -
             | 
| 43 | 
            -
             | 
| 44 | 
            -
             | 
| 45 | 
            -
             | 
|  | |
|  | |
| 46 | 
             
                                seed = gr.Slider(label='Seed',
         | 
| 47 | 
             
                                                 minimum=-1,
         | 
| 48 | 
             
                                                 maximum=2147483647,
         | 
| 49 | 
             
                                                 step=1,
         | 
| 50 | 
             
                                                 randomize=True)
         | 
| 51 | 
            -
                                eta = gr.Number(label='eta (DDIM)', value=0.0)
         | 
| 52 | 
             
                                a_prompt = gr.Textbox(
         | 
| 53 | 
             
                                    label='Added Prompt',
         | 
| 54 | 
             
                                    value='best quality, extremely detailed')
         | 
| @@ -58,17 +59,33 @@ def create_demo(process, max_images=12): | |
| 58 | 
             
                                    'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
         | 
| 59 | 
             
                                )
         | 
| 60 | 
             
                        with gr.Column():
         | 
| 61 | 
            -
                             | 
| 62 | 
            -
             | 
| 63 | 
            -
             | 
| 64 | 
            -
             | 
| 65 | 
            -
                     | 
| 66 | 
            -
                        input_image, | 
| 67 | 
            -
                         | 
| 68 | 
            -
                         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 69 | 
             
                    ]
         | 
|  | |
| 70 | 
             
                    run_button.click(fn=process,
         | 
| 71 | 
            -
                                     inputs= | 
| 72 | 
            -
                                     outputs= | 
| 73 | 
             
                                     api_name='canny')
         | 
| 74 | 
             
                return demo
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 3 | 
             
            import gradio as gr
         | 
| 4 |  | 
| 5 |  | 
| 6 | 
            +
            def create_demo(process, max_images=12, default_num_images=3):
         | 
| 7 | 
             
                with gr.Blocks() as demo:
         | 
| 8 | 
             
                    with gr.Row():
         | 
| 9 | 
             
                        gr.Markdown('## Control Stable Diffusion with Canny Edge Maps')
         | 
|  | |
| 16 | 
             
                                num_samples = gr.Slider(label='Images',
         | 
| 17 | 
             
                                                        minimum=1,
         | 
| 18 | 
             
                                                        maximum=max_images,
         | 
| 19 | 
            +
                                                        value=default_num_images,
         | 
| 20 | 
             
                                                        step=1)
         | 
| 21 | 
             
                                image_resolution = gr.Slider(label='Image Resolution',
         | 
| 22 | 
             
                                                             minimum=256,
         | 
| 23 | 
             
                                                             maximum=768,
         | 
| 24 | 
             
                                                             value=512,
         | 
| 25 | 
             
                                                             step=256)
         | 
| 26 | 
            +
                                canny_low_threshold = gr.Slider(
         | 
| 27 | 
            +
                                    label='Canny low threshold',
         | 
| 28 | 
            +
                                    minimum=1,
         | 
| 29 | 
            +
                                    maximum=255,
         | 
| 30 | 
            +
                                    value=100,
         | 
| 31 | 
            +
                                    step=1)
         | 
| 32 | 
            +
                                canny_high_threshold = gr.Slider(
         | 
| 33 | 
            +
                                    label='Canny high threshold',
         | 
| 34 | 
            +
                                    minimum=1,
         | 
| 35 | 
            +
                                    maximum=255,
         | 
| 36 | 
            +
                                    value=200,
         | 
| 37 | 
            +
                                    step=1)
         | 
| 38 | 
            +
                                num_steps = gr.Slider(label='Steps',
         | 
| 39 | 
            +
                                                      minimum=1,
         | 
| 40 | 
            +
                                                      maximum=100,
         | 
| 41 | 
            +
                                                      value=20,
         | 
| 42 | 
            +
                                                      step=1)
         | 
| 43 | 
            +
                                guidance_scale = gr.Slider(label='Guidance Scale',
         | 
| 44 | 
            +
                                                           minimum=0.1,
         | 
| 45 | 
            +
                                                           maximum=30.0,
         | 
| 46 | 
            +
                                                           value=9.0,
         | 
| 47 | 
            +
                                                           step=0.1)
         | 
| 48 | 
             
                                seed = gr.Slider(label='Seed',
         | 
| 49 | 
             
                                                 minimum=-1,
         | 
| 50 | 
             
                                                 maximum=2147483647,
         | 
| 51 | 
             
                                                 step=1,
         | 
| 52 | 
             
                                                 randomize=True)
         | 
|  | |
| 53 | 
             
                                a_prompt = gr.Textbox(
         | 
| 54 | 
             
                                    label='Added Prompt',
         | 
| 55 | 
             
                                    value='best quality, extremely detailed')
         | 
|  | |
| 59 | 
             
                                    'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
         | 
| 60 | 
             
                                )
         | 
| 61 | 
             
                        with gr.Column():
         | 
| 62 | 
            +
                            result = gr.Gallery(label='Output',
         | 
| 63 | 
            +
                                                show_label=False,
         | 
| 64 | 
            +
                                                elem_id='gallery').style(grid=2,
         | 
| 65 | 
            +
                                                                         height='auto')
         | 
| 66 | 
            +
                    inputs = [
         | 
| 67 | 
            +
                        input_image,
         | 
| 68 | 
            +
                        prompt,
         | 
| 69 | 
            +
                        a_prompt,
         | 
| 70 | 
            +
                        n_prompt,
         | 
| 71 | 
            +
                        num_samples,
         | 
| 72 | 
            +
                        image_resolution,
         | 
| 73 | 
            +
                        num_steps,
         | 
| 74 | 
            +
                        guidance_scale,
         | 
| 75 | 
            +
                        seed,
         | 
| 76 | 
            +
                        canny_low_threshold,
         | 
| 77 | 
            +
                        canny_high_threshold,
         | 
| 78 | 
             
                    ]
         | 
| 79 | 
            +
                    prompt.submit(fn=process, inputs=inputs, outputs=result)
         | 
| 80 | 
             
                    run_button.click(fn=process,
         | 
| 81 | 
            +
                                     inputs=inputs,
         | 
| 82 | 
            +
                                     outputs=result,
         | 
| 83 | 
             
                                     api_name='canny')
         | 
| 84 | 
             
                return demo
         | 
| 85 | 
            +
             | 
| 86 | 
            +
             | 
| 87 | 
            +
            if __name__ == '__main__':
         | 
| 88 | 
            +
                from model import Model
         | 
| 89 | 
            +
                model = Model()
         | 
| 90 | 
            +
                demo = create_demo(model.process_canny)
         | 
| 91 | 
            +
                demo.queue().launch()
         | 
    	
        gradio_depth2image.py → app_depth.py
    RENAMED
    
    | @@ -3,7 +3,7 @@ | |
| 3 | 
             
            import gradio as gr
         | 
| 4 |  | 
| 5 |  | 
| 6 | 
            -
            def create_demo(process, max_images=12):
         | 
| 7 | 
             
                with gr.Blocks() as demo:
         | 
| 8 | 
             
                    with gr.Row():
         | 
| 9 | 
             
                        gr.Markdown('## Control Stable Diffusion with Depth Maps')
         | 
| @@ -13,10 +13,12 @@ def create_demo(process, max_images=12): | |
| 13 | 
             
                            prompt = gr.Textbox(label='Prompt')
         | 
| 14 | 
             
                            run_button = gr.Button(label='Run')
         | 
| 15 | 
             
                            with gr.Accordion('Advanced options', open=False):
         | 
|  | |
|  | |
| 16 | 
             
                                num_samples = gr.Slider(label='Images',
         | 
| 17 | 
             
                                                        minimum=1,
         | 
| 18 | 
             
                                                        maximum=max_images,
         | 
| 19 | 
            -
                                                        value= | 
| 20 | 
             
                                                        step=1)
         | 
| 21 | 
             
                                image_resolution = gr.Slider(label='Image Resolution',
         | 
| 22 | 
             
                                                             minimum=256,
         | 
| @@ -28,22 +30,21 @@ def create_demo(process, max_images=12): | |
| 28 | 
             
                                                              maximum=1024,
         | 
| 29 | 
             
                                                              value=384,
         | 
| 30 | 
             
                                                              step=1)
         | 
| 31 | 
            -
                                 | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
             | 
| 35 | 
            -
             | 
| 36 | 
            -
                                 | 
| 37 | 
            -
             | 
| 38 | 
            -
             | 
| 39 | 
            -
             | 
| 40 | 
            -
             | 
| 41 | 
             
                                seed = gr.Slider(label='Seed',
         | 
| 42 | 
             
                                                 minimum=-1,
         | 
| 43 | 
             
                                                 maximum=2147483647,
         | 
| 44 | 
             
                                                 step=1,
         | 
| 45 | 
             
                                                 randomize=True)
         | 
| 46 | 
            -
                                eta = gr.Number(label='eta (DDIM)', value=0.0)
         | 
| 47 | 
             
                                a_prompt = gr.Textbox(
         | 
| 48 | 
             
                                    label='Added Prompt',
         | 
| 49 | 
             
                                    value='best quality, extremely detailed')
         | 
| @@ -53,16 +54,33 @@ def create_demo(process, max_images=12): | |
| 53 | 
             
                                    'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
         | 
| 54 | 
             
                                )
         | 
| 55 | 
             
                        with gr.Column():
         | 
| 56 | 
            -
                             | 
| 57 | 
            -
             | 
| 58 | 
            -
             | 
| 59 | 
            -
             | 
| 60 | 
            -
                     | 
| 61 | 
            -
                        input_image, | 
| 62 | 
            -
                         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 63 | 
             
                    ]
         | 
|  | |
| 64 | 
             
                    run_button.click(fn=process,
         | 
| 65 | 
            -
                                     inputs= | 
| 66 | 
            -
                                     outputs= | 
| 67 | 
             
                                     api_name='depth')
         | 
| 68 | 
             
                return demo
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 3 | 
             
            import gradio as gr
         | 
| 4 |  | 
| 5 |  | 
| 6 | 
            +
            def create_demo(process, max_images=12, default_num_images=3):
         | 
| 7 | 
             
                with gr.Blocks() as demo:
         | 
| 8 | 
             
                    with gr.Row():
         | 
| 9 | 
             
                        gr.Markdown('## Control Stable Diffusion with Depth Maps')
         | 
|  | |
| 13 | 
             
                            prompt = gr.Textbox(label='Prompt')
         | 
| 14 | 
             
                            run_button = gr.Button(label='Run')
         | 
| 15 | 
             
                            with gr.Accordion('Advanced options', open=False):
         | 
| 16 | 
            +
                                is_depth_image = gr.Checkbox(label='Is depth image',
         | 
| 17 | 
            +
                                                             value=False)
         | 
| 18 | 
             
                                num_samples = gr.Slider(label='Images',
         | 
| 19 | 
             
                                                        minimum=1,
         | 
| 20 | 
             
                                                        maximum=max_images,
         | 
| 21 | 
            +
                                                        value=default_num_images,
         | 
| 22 | 
             
                                                        step=1)
         | 
| 23 | 
             
                                image_resolution = gr.Slider(label='Image Resolution',
         | 
| 24 | 
             
                                                             minimum=256,
         | 
|  | |
| 30 | 
             
                                                              maximum=1024,
         | 
| 31 | 
             
                                                              value=384,
         | 
| 32 | 
             
                                                              step=1)
         | 
| 33 | 
            +
                                num_steps = gr.Slider(label='Steps',
         | 
| 34 | 
            +
                                                      minimum=1,
         | 
| 35 | 
            +
                                                      maximum=100,
         | 
| 36 | 
            +
                                                      value=20,
         | 
| 37 | 
            +
                                                      step=1)
         | 
| 38 | 
            +
                                guidance_scale = gr.Slider(label='Guidance Scale',
         | 
| 39 | 
            +
                                                           minimum=0.1,
         | 
| 40 | 
            +
                                                           maximum=30.0,
         | 
| 41 | 
            +
                                                           value=9.0,
         | 
| 42 | 
            +
                                                           step=0.1)
         | 
| 43 | 
             
                                seed = gr.Slider(label='Seed',
         | 
| 44 | 
             
                                                 minimum=-1,
         | 
| 45 | 
             
                                                 maximum=2147483647,
         | 
| 46 | 
             
                                                 step=1,
         | 
| 47 | 
             
                                                 randomize=True)
         | 
|  | |
| 48 | 
             
                                a_prompt = gr.Textbox(
         | 
| 49 | 
             
                                    label='Added Prompt',
         | 
| 50 | 
             
                                    value='best quality, extremely detailed')
         | 
|  | |
| 54 | 
             
                                    'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
         | 
| 55 | 
             
                                )
         | 
| 56 | 
             
                        with gr.Column():
         | 
| 57 | 
            +
                            result = gr.Gallery(label='Output',
         | 
| 58 | 
            +
                                                show_label=False,
         | 
| 59 | 
            +
                                                elem_id='gallery').style(grid=2,
         | 
| 60 | 
            +
                                                                         height='auto')
         | 
| 61 | 
            +
                    inputs = [
         | 
| 62 | 
            +
                        input_image,
         | 
| 63 | 
            +
                        prompt,
         | 
| 64 | 
            +
                        a_prompt,
         | 
| 65 | 
            +
                        n_prompt,
         | 
| 66 | 
            +
                        num_samples,
         | 
| 67 | 
            +
                        image_resolution,
         | 
| 68 | 
            +
                        detect_resolution,
         | 
| 69 | 
            +
                        num_steps,
         | 
| 70 | 
            +
                        guidance_scale,
         | 
| 71 | 
            +
                        seed,
         | 
| 72 | 
            +
                        is_depth_image,
         | 
| 73 | 
             
                    ]
         | 
| 74 | 
            +
                    prompt.submit(fn=process, inputs=inputs, outputs=result)
         | 
| 75 | 
             
                    run_button.click(fn=process,
         | 
| 76 | 
            +
                                     inputs=inputs,
         | 
| 77 | 
            +
                                     outputs=result,
         | 
| 78 | 
             
                                     api_name='depth')
         | 
| 79 | 
             
                return demo
         | 
| 80 | 
            +
             | 
| 81 | 
            +
             | 
| 82 | 
            +
            if __name__ == '__main__':
         | 
| 83 | 
            +
                from model import Model
         | 
| 84 | 
            +
                model = Model()
         | 
| 85 | 
            +
                demo = create_demo(model.process_depth)
         | 
| 86 | 
            +
                demo.queue().launch()
         | 
    	
        gradio_fake_scribble2image.py → app_fake_scribble.py
    RENAMED
    
    | @@ -3,7 +3,7 @@ | |
| 3 | 
             
            import gradio as gr
         | 
| 4 |  | 
| 5 |  | 
| 6 | 
            -
            def create_demo(process, max_images=12):
         | 
| 7 | 
             
                with gr.Blocks() as demo:
         | 
| 8 | 
             
                    with gr.Row():
         | 
| 9 | 
             
                        gr.Markdown('## Control Stable Diffusion with Fake Scribble Maps')
         | 
| @@ -16,7 +16,7 @@ def create_demo(process, max_images=12): | |
| 16 | 
             
                                num_samples = gr.Slider(label='Images',
         | 
| 17 | 
             
                                                        minimum=1,
         | 
| 18 | 
             
                                                        maximum=max_images,
         | 
| 19 | 
            -
                                                        value= | 
| 20 | 
             
                                                        step=1)
         | 
| 21 | 
             
                                image_resolution = gr.Slider(label='Image Resolution',
         | 
| 22 | 
             
                                                             minimum=256,
         | 
| @@ -28,22 +28,21 @@ def create_demo(process, max_images=12): | |
| 28 | 
             
                                                              maximum=1024,
         | 
| 29 | 
             
                                                              value=512,
         | 
| 30 | 
             
                                                              step=1)
         | 
| 31 | 
            -
                                 | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
             | 
| 35 | 
            -
             | 
| 36 | 
            -
                                 | 
| 37 | 
            -
             | 
| 38 | 
            -
             | 
| 39 | 
            -
             | 
| 40 | 
            -
             | 
| 41 | 
             
                                seed = gr.Slider(label='Seed',
         | 
| 42 | 
             
                                                 minimum=-1,
         | 
| 43 | 
             
                                                 maximum=2147483647,
         | 
| 44 | 
             
                                                 step=1,
         | 
| 45 | 
             
                                                 randomize=True)
         | 
| 46 | 
            -
                                eta = gr.Number(label='eta (DDIM)', value=0.0)
         | 
| 47 | 
             
                                a_prompt = gr.Textbox(
         | 
| 48 | 
             
                                    label='Added Prompt',
         | 
| 49 | 
             
                                    value='best quality, extremely detailed')
         | 
| @@ -53,16 +52,32 @@ def create_demo(process, max_images=12): | |
| 53 | 
             
                                    'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
         | 
| 54 | 
             
                                )
         | 
| 55 | 
             
                        with gr.Column():
         | 
| 56 | 
            -
                             | 
| 57 | 
            -
             | 
| 58 | 
            -
             | 
| 59 | 
            -
             | 
| 60 | 
            -
                     | 
| 61 | 
            -
                        input_image, | 
| 62 | 
            -
                         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 63 | 
             
                    ]
         | 
|  | |
| 64 | 
             
                    run_button.click(fn=process,
         | 
| 65 | 
            -
                                     inputs= | 
| 66 | 
            -
                                     outputs= | 
| 67 | 
             
                                     api_name='fake_scribble')
         | 
| 68 | 
             
                return demo
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 3 | 
             
            import gradio as gr
         | 
| 4 |  | 
| 5 |  | 
| 6 | 
            +
            def create_demo(process, max_images=12, default_num_images=3):
         | 
| 7 | 
             
                with gr.Blocks() as demo:
         | 
| 8 | 
             
                    with gr.Row():
         | 
| 9 | 
             
                        gr.Markdown('## Control Stable Diffusion with Fake Scribble Maps')
         | 
|  | |
| 16 | 
             
                                num_samples = gr.Slider(label='Images',
         | 
| 17 | 
             
                                                        minimum=1,
         | 
| 18 | 
             
                                                        maximum=max_images,
         | 
| 19 | 
            +
                                                        value=default_num_images,
         | 
| 20 | 
             
                                                        step=1)
         | 
| 21 | 
             
                                image_resolution = gr.Slider(label='Image Resolution',
         | 
| 22 | 
             
                                                             minimum=256,
         | 
|  | |
| 28 | 
             
                                                              maximum=1024,
         | 
| 29 | 
             
                                                              value=512,
         | 
| 30 | 
             
                                                              step=1)
         | 
| 31 | 
            +
                                num_steps = gr.Slider(label='Steps',
         | 
| 32 | 
            +
                                                      minimum=1,
         | 
| 33 | 
            +
                                                      maximum=100,
         | 
| 34 | 
            +
                                                      value=20,
         | 
| 35 | 
            +
                                                      step=1)
         | 
| 36 | 
            +
                                guidance_scale = gr.Slider(label='Guidance Scale',
         | 
| 37 | 
            +
                                                           minimum=0.1,
         | 
| 38 | 
            +
                                                           maximum=30.0,
         | 
| 39 | 
            +
                                                           value=9.0,
         | 
| 40 | 
            +
                                                           step=0.1)
         | 
| 41 | 
             
                                seed = gr.Slider(label='Seed',
         | 
| 42 | 
             
                                                 minimum=-1,
         | 
| 43 | 
             
                                                 maximum=2147483647,
         | 
| 44 | 
             
                                                 step=1,
         | 
| 45 | 
             
                                                 randomize=True)
         | 
|  | |
| 46 | 
             
                                a_prompt = gr.Textbox(
         | 
| 47 | 
             
                                    label='Added Prompt',
         | 
| 48 | 
             
                                    value='best quality, extremely detailed')
         | 
|  | |
| 52 | 
             
                                    'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
         | 
| 53 | 
             
                                )
         | 
| 54 | 
             
                        with gr.Column():
         | 
| 55 | 
            +
                            result = gr.Gallery(label='Output',
         | 
| 56 | 
            +
                                                show_label=False,
         | 
| 57 | 
            +
                                                elem_id='gallery').style(grid=2,
         | 
| 58 | 
            +
                                                                         height='auto')
         | 
| 59 | 
            +
                    inputs = [
         | 
| 60 | 
            +
                        input_image,
         | 
| 61 | 
            +
                        prompt,
         | 
| 62 | 
            +
                        a_prompt,
         | 
| 63 | 
            +
                        n_prompt,
         | 
| 64 | 
            +
                        num_samples,
         | 
| 65 | 
            +
                        image_resolution,
         | 
| 66 | 
            +
                        detect_resolution,
         | 
| 67 | 
            +
                        num_steps,
         | 
| 68 | 
            +
                        guidance_scale,
         | 
| 69 | 
            +
                        seed,
         | 
| 70 | 
             
                    ]
         | 
| 71 | 
            +
                    prompt.submit(fn=process, inputs=inputs, outputs=result)
         | 
| 72 | 
             
                    run_button.click(fn=process,
         | 
| 73 | 
            +
                                     inputs=inputs,
         | 
| 74 | 
            +
                                     outputs=result,
         | 
| 75 | 
             
                                     api_name='fake_scribble')
         | 
| 76 | 
             
                return demo
         | 
| 77 | 
            +
             | 
| 78 | 
            +
             | 
| 79 | 
            +
            if __name__ == '__main__':
         | 
| 80 | 
            +
                from model import Model
         | 
| 81 | 
            +
                model = Model()
         | 
| 82 | 
            +
                demo = create_demo(model.process_fake_scribble)
         | 
| 83 | 
            +
                demo.queue().launch()
         | 
    	
        gradio_hed2image.py → app_hed.py
    RENAMED
    
    | @@ -3,7 +3,7 @@ | |
| 3 | 
             
            import gradio as gr
         | 
| 4 |  | 
| 5 |  | 
| 6 | 
            -
            def create_demo(process, max_images=12):
         | 
| 7 | 
             
                with gr.Blocks() as demo:
         | 
| 8 | 
             
                    with gr.Row():
         | 
| 9 | 
             
                        gr.Markdown('## Control Stable Diffusion with HED Maps')
         | 
| @@ -16,7 +16,7 @@ def create_demo(process, max_images=12): | |
| 16 | 
             
                                num_samples = gr.Slider(label='Images',
         | 
| 17 | 
             
                                                        minimum=1,
         | 
| 18 | 
             
                                                        maximum=max_images,
         | 
| 19 | 
            -
                                                        value= | 
| 20 | 
             
                                                        step=1)
         | 
| 21 | 
             
                                image_resolution = gr.Slider(label='Image Resolution',
         | 
| 22 | 
             
                                                             minimum=256,
         | 
| @@ -28,22 +28,21 @@ def create_demo(process, max_images=12): | |
| 28 | 
             
                                                              maximum=1024,
         | 
| 29 | 
             
                                                              value=512,
         | 
| 30 | 
             
                                                              step=1)
         | 
| 31 | 
            -
                                 | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
             | 
| 35 | 
            -
             | 
| 36 | 
            -
                                 | 
| 37 | 
            -
             | 
| 38 | 
            -
             | 
| 39 | 
            -
             | 
| 40 | 
            -
             | 
| 41 | 
             
                                seed = gr.Slider(label='Seed',
         | 
| 42 | 
             
                                                 minimum=-1,
         | 
| 43 | 
             
                                                 maximum=2147483647,
         | 
| 44 | 
             
                                                 step=1,
         | 
| 45 | 
             
                                                 randomize=True)
         | 
| 46 | 
            -
                                eta = gr.Number(label='eta (DDIM)', value=0.0)
         | 
| 47 | 
             
                                a_prompt = gr.Textbox(
         | 
| 48 | 
             
                                    label='Added Prompt',
         | 
| 49 | 
             
                                    value='best quality, extremely detailed')
         | 
| @@ -53,16 +52,32 @@ def create_demo(process, max_images=12): | |
| 53 | 
             
                                    'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
         | 
| 54 | 
             
                                )
         | 
| 55 | 
             
                        with gr.Column():
         | 
| 56 | 
            -
                             | 
| 57 | 
            -
             | 
| 58 | 
            -
             | 
| 59 | 
            -
             | 
| 60 | 
            -
                     | 
| 61 | 
            -
                        input_image, | 
| 62 | 
            -
                         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 63 | 
             
                    ]
         | 
|  | |
| 64 | 
             
                    run_button.click(fn=process,
         | 
| 65 | 
            -
                                     inputs= | 
| 66 | 
            -
                                     outputs= | 
| 67 | 
             
                                     api_name='hed')
         | 
| 68 | 
             
                return demo
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 3 | 
             
            import gradio as gr
         | 
| 4 |  | 
| 5 |  | 
| 6 | 
            +
            def create_demo(process, max_images=12, default_num_images=3):
         | 
| 7 | 
             
                with gr.Blocks() as demo:
         | 
| 8 | 
             
                    with gr.Row():
         | 
| 9 | 
             
                        gr.Markdown('## Control Stable Diffusion with HED Maps')
         | 
|  | |
| 16 | 
             
                                num_samples = gr.Slider(label='Images',
         | 
| 17 | 
             
                                                        minimum=1,
         | 
| 18 | 
             
                                                        maximum=max_images,
         | 
| 19 | 
            +
                                                        value=default_num_images,
         | 
| 20 | 
             
                                                        step=1)
         | 
| 21 | 
             
                                image_resolution = gr.Slider(label='Image Resolution',
         | 
| 22 | 
             
                                                             minimum=256,
         | 
|  | |
| 28 | 
             
                                                              maximum=1024,
         | 
| 29 | 
             
                                                              value=512,
         | 
| 30 | 
             
                                                              step=1)
         | 
| 31 | 
            +
                                num_steps = gr.Slider(label='Steps',
         | 
| 32 | 
            +
                                                      minimum=1,
         | 
| 33 | 
            +
                                                      maximum=100,
         | 
| 34 | 
            +
                                                      value=20,
         | 
| 35 | 
            +
                                                      step=1)
         | 
| 36 | 
            +
                                guidance_scale = gr.Slider(label='Guidance Scale',
         | 
| 37 | 
            +
                                                           minimum=0.1,
         | 
| 38 | 
            +
                                                           maximum=30.0,
         | 
| 39 | 
            +
                                                           value=9.0,
         | 
| 40 | 
            +
                                                           step=0.1)
         | 
| 41 | 
             
                                seed = gr.Slider(label='Seed',
         | 
| 42 | 
             
                                                 minimum=-1,
         | 
| 43 | 
             
                                                 maximum=2147483647,
         | 
| 44 | 
             
                                                 step=1,
         | 
| 45 | 
             
                                                 randomize=True)
         | 
|  | |
| 46 | 
             
                                a_prompt = gr.Textbox(
         | 
| 47 | 
             
                                    label='Added Prompt',
         | 
| 48 | 
             
                                    value='best quality, extremely detailed')
         | 
|  | |
| 52 | 
             
                                    'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
         | 
| 53 | 
             
                                )
         | 
| 54 | 
             
                        with gr.Column():
         | 
| 55 | 
            +
                            result = gr.Gallery(label='Output',
         | 
| 56 | 
            +
                                                show_label=False,
         | 
| 57 | 
            +
                                                elem_id='gallery').style(grid=2,
         | 
| 58 | 
            +
                                                                         height='auto')
         | 
| 59 | 
            +
                    inputs = [
         | 
| 60 | 
            +
                        input_image,
         | 
| 61 | 
            +
                        prompt,
         | 
| 62 | 
            +
                        a_prompt,
         | 
| 63 | 
            +
                        n_prompt,
         | 
| 64 | 
            +
                        num_samples,
         | 
| 65 | 
            +
                        image_resolution,
         | 
| 66 | 
            +
                        detect_resolution,
         | 
| 67 | 
            +
                        num_steps,
         | 
| 68 | 
            +
                        guidance_scale,
         | 
| 69 | 
            +
                        seed,
         | 
| 70 | 
             
                    ]
         | 
| 71 | 
            +
                    prompt.submit(fn=process, inputs=inputs, outputs=result)
         | 
| 72 | 
             
                    run_button.click(fn=process,
         | 
| 73 | 
            +
                                     inputs=inputs,
         | 
| 74 | 
            +
                                     outputs=result,
         | 
| 75 | 
             
                                     api_name='hed')
         | 
| 76 | 
             
                return demo
         | 
| 77 | 
            +
             | 
| 78 | 
            +
             | 
| 79 | 
            +
            if __name__ == '__main__':
         | 
| 80 | 
            +
                from model import Model
         | 
| 81 | 
            +
                model = Model()
         | 
| 82 | 
            +
                demo = create_demo(model.process_hed)
         | 
| 83 | 
            +
                demo.queue().launch()
         | 
    	
        gradio_hough2image.py → app_hough.py
    RENAMED
    
    | @@ -3,7 +3,7 @@ | |
| 3 | 
             
            import gradio as gr
         | 
| 4 |  | 
| 5 |  | 
| 6 | 
            -
            def create_demo(process, max_images=12):
         | 
| 7 | 
             
                with gr.Blocks() as demo:
         | 
| 8 | 
             
                    with gr.Row():
         | 
| 9 | 
             
                        gr.Markdown('## Control Stable Diffusion with Hough Line Maps')
         | 
| @@ -16,7 +16,7 @@ def create_demo(process, max_images=12): | |
| 16 | 
             
                                num_samples = gr.Slider(label='Images',
         | 
| 17 | 
             
                                                        minimum=1,
         | 
| 18 | 
             
                                                        maximum=max_images,
         | 
| 19 | 
            -
                                                        value= | 
| 20 | 
             
                                                        step=1)
         | 
| 21 | 
             
                                image_resolution = gr.Slider(label='Image Resolution',
         | 
| 22 | 
             
                                                             minimum=256,
         | 
| @@ -28,34 +28,33 @@ def create_demo(process, max_images=12): | |
| 28 | 
             
                                                              maximum=1024,
         | 
| 29 | 
             
                                                              value=512,
         | 
| 30 | 
             
                                                              step=1)
         | 
| 31 | 
            -
                                 | 
| 32 | 
             
                                    label='Hough value threshold (MLSD)',
         | 
| 33 | 
             
                                    minimum=0.01,
         | 
| 34 | 
             
                                    maximum=2.0,
         | 
| 35 | 
             
                                    value=0.1,
         | 
| 36 | 
             
                                    step=0.01)
         | 
| 37 | 
            -
                                 | 
| 38 | 
             
                                    label='Hough distance threshold (MLSD)',
         | 
| 39 | 
             
                                    minimum=0.01,
         | 
| 40 | 
             
                                    maximum=20.0,
         | 
| 41 | 
             
                                    value=0.1,
         | 
| 42 | 
             
                                    step=0.01)
         | 
| 43 | 
            -
                                 | 
| 44 | 
            -
             | 
| 45 | 
            -
             | 
| 46 | 
            -
             | 
| 47 | 
            -
             | 
| 48 | 
            -
                                 | 
| 49 | 
            -
             | 
| 50 | 
            -
             | 
| 51 | 
            -
             | 
| 52 | 
            -
             | 
| 53 | 
             
                                seed = gr.Slider(label='Seed',
         | 
| 54 | 
             
                                                 minimum=-1,
         | 
| 55 | 
             
                                                 maximum=2147483647,
         | 
| 56 | 
             
                                                 step=1,
         | 
| 57 | 
             
                                                 randomize=True)
         | 
| 58 | 
            -
                                eta = gr.Number(label='eta (DDIM)', value=0.0)
         | 
| 59 | 
             
                                a_prompt = gr.Textbox(
         | 
| 60 | 
             
                                    label='Added Prompt',
         | 
| 61 | 
             
                                    value='best quality, extremely detailed')
         | 
| @@ -65,17 +64,34 @@ def create_demo(process, max_images=12): | |
| 65 | 
             
                                    'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
         | 
| 66 | 
             
                                )
         | 
| 67 | 
             
                        with gr.Column():
         | 
| 68 | 
            -
                             | 
| 69 | 
            -
             | 
| 70 | 
            -
             | 
| 71 | 
            -
             | 
| 72 | 
            -
                     | 
| 73 | 
            -
                        input_image, | 
| 74 | 
            -
                         | 
| 75 | 
            -
                         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 76 | 
             
                    ]
         | 
|  | |
| 77 | 
             
                    run_button.click(fn=process,
         | 
| 78 | 
            -
                                     inputs= | 
| 79 | 
            -
                                     outputs= | 
| 80 | 
             
                                     api_name='hough')
         | 
| 81 | 
             
                return demo
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 3 | 
             
            import gradio as gr
         | 
| 4 |  | 
| 5 |  | 
| 6 | 
            +
            def create_demo(process, max_images=12, default_num_images=3):
         | 
| 7 | 
             
                with gr.Blocks() as demo:
         | 
| 8 | 
             
                    with gr.Row():
         | 
| 9 | 
             
                        gr.Markdown('## Control Stable Diffusion with Hough Line Maps')
         | 
|  | |
| 16 | 
             
                                num_samples = gr.Slider(label='Images',
         | 
| 17 | 
             
                                                        minimum=1,
         | 
| 18 | 
             
                                                        maximum=max_images,
         | 
| 19 | 
            +
                                                        value=default_num_images,
         | 
| 20 | 
             
                                                        step=1)
         | 
| 21 | 
             
                                image_resolution = gr.Slider(label='Image Resolution',
         | 
| 22 | 
             
                                                             minimum=256,
         | 
|  | |
| 28 | 
             
                                                              maximum=1024,
         | 
| 29 | 
             
                                                              value=512,
         | 
| 30 | 
             
                                                              step=1)
         | 
| 31 | 
            +
                                mlsd_value_threshold = gr.Slider(
         | 
| 32 | 
             
                                    label='Hough value threshold (MLSD)',
         | 
| 33 | 
             
                                    minimum=0.01,
         | 
| 34 | 
             
                                    maximum=2.0,
         | 
| 35 | 
             
                                    value=0.1,
         | 
| 36 | 
             
                                    step=0.01)
         | 
| 37 | 
            +
                                mlsd_distance_threshold = gr.Slider(
         | 
| 38 | 
             
                                    label='Hough distance threshold (MLSD)',
         | 
| 39 | 
             
                                    minimum=0.01,
         | 
| 40 | 
             
                                    maximum=20.0,
         | 
| 41 | 
             
                                    value=0.1,
         | 
| 42 | 
             
                                    step=0.01)
         | 
| 43 | 
            +
                                num_steps = gr.Slider(label='Steps',
         | 
| 44 | 
            +
                                                      minimum=1,
         | 
| 45 | 
            +
                                                      maximum=100,
         | 
| 46 | 
            +
                                                      value=20,
         | 
| 47 | 
            +
                                                      step=1)
         | 
| 48 | 
            +
                                guidance_scale = gr.Slider(label='Guidance Scale',
         | 
| 49 | 
            +
                                                           minimum=0.1,
         | 
| 50 | 
            +
                                                           maximum=30.0,
         | 
| 51 | 
            +
                                                           value=9.0,
         | 
| 52 | 
            +
                                                           step=0.1)
         | 
| 53 | 
             
                                seed = gr.Slider(label='Seed',
         | 
| 54 | 
             
                                                 minimum=-1,
         | 
| 55 | 
             
                                                 maximum=2147483647,
         | 
| 56 | 
             
                                                 step=1,
         | 
| 57 | 
             
                                                 randomize=True)
         | 
|  | |
| 58 | 
             
                                a_prompt = gr.Textbox(
         | 
| 59 | 
             
                                    label='Added Prompt',
         | 
| 60 | 
             
                                    value='best quality, extremely detailed')
         | 
|  | |
| 64 | 
             
                                    'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
         | 
| 65 | 
             
                                )
         | 
| 66 | 
             
                        with gr.Column():
         | 
| 67 | 
            +
                            result = gr.Gallery(label='Output',
         | 
| 68 | 
            +
                                                show_label=False,
         | 
| 69 | 
            +
                                                elem_id='gallery').style(grid=2,
         | 
| 70 | 
            +
                                                                         height='auto')
         | 
| 71 | 
            +
                    inputs = [
         | 
| 72 | 
            +
                        input_image,
         | 
| 73 | 
            +
                        prompt,
         | 
| 74 | 
            +
                        a_prompt,
         | 
| 75 | 
            +
                        n_prompt,
         | 
| 76 | 
            +
                        num_samples,
         | 
| 77 | 
            +
                        image_resolution,
         | 
| 78 | 
            +
                        detect_resolution,
         | 
| 79 | 
            +
                        num_steps,
         | 
| 80 | 
            +
                        guidance_scale,
         | 
| 81 | 
            +
                        seed,
         | 
| 82 | 
            +
                        mlsd_value_threshold,
         | 
| 83 | 
            +
                        mlsd_distance_threshold,
         | 
| 84 | 
             
                    ]
         | 
| 85 | 
            +
                    prompt.submit(fn=process, inputs=inputs, outputs=result)
         | 
| 86 | 
             
                    run_button.click(fn=process,
         | 
| 87 | 
            +
                                     inputs=inputs,
         | 
| 88 | 
            +
                                     outputs=result,
         | 
| 89 | 
             
                                     api_name='hough')
         | 
| 90 | 
             
                return demo
         | 
| 91 | 
            +
             | 
| 92 | 
            +
             | 
| 93 | 
            +
            if __name__ == '__main__':
         | 
| 94 | 
            +
                from model import Model
         | 
| 95 | 
            +
                model = Model()
         | 
| 96 | 
            +
                demo = create_demo(model.process_hough)
         | 
| 97 | 
            +
                demo.queue().launch()
         | 
    	
        gradio_normal2image.py → app_normal.py
    RENAMED
    
    | @@ -3,7 +3,7 @@ | |
| 3 | 
             
            import gradio as gr
         | 
| 4 |  | 
| 5 |  | 
| 6 | 
            -
            def create_demo(process, max_images=12):
         | 
| 7 | 
             
                with gr.Blocks() as demo:
         | 
| 8 | 
             
                    with gr.Row():
         | 
| 9 | 
             
                        gr.Markdown('## Control Stable Diffusion with Normal Maps')
         | 
| @@ -13,10 +13,12 @@ def create_demo(process, max_images=12): | |
| 13 | 
             
                            prompt = gr.Textbox(label='Prompt')
         | 
| 14 | 
             
                            run_button = gr.Button(label='Run')
         | 
| 15 | 
             
                            with gr.Accordion('Advanced options', open=False):
         | 
|  | |
|  | |
| 16 | 
             
                                num_samples = gr.Slider(label='Images',
         | 
| 17 | 
             
                                                        minimum=1,
         | 
| 18 | 
             
                                                        maximum=max_images,
         | 
| 19 | 
            -
                                                        value= | 
| 20 | 
             
                                                        step=1)
         | 
| 21 | 
             
                                image_resolution = gr.Slider(label='Image Resolution',
         | 
| 22 | 
             
                                                             minimum=256,
         | 
| @@ -34,22 +36,21 @@ def create_demo(process, max_images=12): | |
| 34 | 
             
                                    maximum=1.0,
         | 
| 35 | 
             
                                    value=0.4,
         | 
| 36 | 
             
                                    step=0.01)
         | 
| 37 | 
            -
                                 | 
| 38 | 
            -
             | 
| 39 | 
            -
             | 
| 40 | 
            -
             | 
| 41 | 
            -
             | 
| 42 | 
            -
                                 | 
| 43 | 
            -
             | 
| 44 | 
            -
             | 
| 45 | 
            -
             | 
| 46 | 
            -
             | 
| 47 | 
             
                                seed = gr.Slider(label='Seed',
         | 
| 48 | 
             
                                                 minimum=-1,
         | 
| 49 | 
             
                                                 maximum=2147483647,
         | 
| 50 | 
             
                                                 step=1,
         | 
| 51 | 
             
                                                 randomize=True)
         | 
| 52 | 
            -
                                eta = gr.Number(label='eta (DDIM)', value=0.0)
         | 
| 53 | 
             
                                a_prompt = gr.Textbox(
         | 
| 54 | 
             
                                    label='Added Prompt',
         | 
| 55 | 
             
                                    value='best quality, extremely detailed')
         | 
| @@ -59,17 +60,34 @@ def create_demo(process, max_images=12): | |
| 59 | 
             
                                    'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
         | 
| 60 | 
             
                                )
         | 
| 61 | 
             
                        with gr.Column():
         | 
| 62 | 
            -
                             | 
| 63 | 
            -
             | 
| 64 | 
            -
             | 
| 65 | 
            -
             | 
| 66 | 
            -
                     | 
| 67 | 
            -
                        input_image, | 
| 68 | 
            -
                         | 
| 69 | 
            -
                         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 70 | 
             
                    ]
         | 
|  | |
| 71 | 
             
                    run_button.click(fn=process,
         | 
| 72 | 
            -
                                     inputs= | 
| 73 | 
            -
                                     outputs= | 
| 74 | 
             
                                     api_name='normal')
         | 
| 75 | 
             
                return demo
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 3 | 
             
            import gradio as gr
         | 
| 4 |  | 
| 5 |  | 
| 6 | 
            +
            def create_demo(process, max_images=12, default_num_images=3):
         | 
| 7 | 
             
                with gr.Blocks() as demo:
         | 
| 8 | 
             
                    with gr.Row():
         | 
| 9 | 
             
                        gr.Markdown('## Control Stable Diffusion with Normal Maps')
         | 
|  | |
| 13 | 
             
                            prompt = gr.Textbox(label='Prompt')
         | 
| 14 | 
             
                            run_button = gr.Button(label='Run')
         | 
| 15 | 
             
                            with gr.Accordion('Advanced options', open=False):
         | 
| 16 | 
            +
                                is_normal_image = gr.Checkbox(label='Is normal image',
         | 
| 17 | 
            +
                                                              value=False)
         | 
| 18 | 
             
                                num_samples = gr.Slider(label='Images',
         | 
| 19 | 
             
                                                        minimum=1,
         | 
| 20 | 
             
                                                        maximum=max_images,
         | 
| 21 | 
            +
                                                        value=default_num_images,
         | 
| 22 | 
             
                                                        step=1)
         | 
| 23 | 
             
                                image_resolution = gr.Slider(label='Image Resolution',
         | 
| 24 | 
             
                                                             minimum=256,
         | 
|  | |
| 36 | 
             
                                    maximum=1.0,
         | 
| 37 | 
             
                                    value=0.4,
         | 
| 38 | 
             
                                    step=0.01)
         | 
| 39 | 
            +
                                num_steps = gr.Slider(label='Steps',
         | 
| 40 | 
            +
                                                      minimum=1,
         | 
| 41 | 
            +
                                                      maximum=100,
         | 
| 42 | 
            +
                                                      value=20,
         | 
| 43 | 
            +
                                                      step=1)
         | 
| 44 | 
            +
                                guidance_scale = gr.Slider(label='Guidance Scale',
         | 
| 45 | 
            +
                                                           minimum=0.1,
         | 
| 46 | 
            +
                                                           maximum=30.0,
         | 
| 47 | 
            +
                                                           value=9.0,
         | 
| 48 | 
            +
                                                           step=0.1)
         | 
| 49 | 
             
                                seed = gr.Slider(label='Seed',
         | 
| 50 | 
             
                                                 minimum=-1,
         | 
| 51 | 
             
                                                 maximum=2147483647,
         | 
| 52 | 
             
                                                 step=1,
         | 
| 53 | 
             
                                                 randomize=True)
         | 
|  | |
| 54 | 
             
                                a_prompt = gr.Textbox(
         | 
| 55 | 
             
                                    label='Added Prompt',
         | 
| 56 | 
             
                                    value='best quality, extremely detailed')
         | 
|  | |
| 60 | 
             
                                    'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
         | 
| 61 | 
             
                                )
         | 
| 62 | 
             
                        with gr.Column():
         | 
| 63 | 
            +
                            result = gr.Gallery(label='Output',
         | 
| 64 | 
            +
                                                show_label=False,
         | 
| 65 | 
            +
                                                elem_id='gallery').style(grid=2,
         | 
| 66 | 
            +
                                                                         height='auto')
         | 
| 67 | 
            +
                    inputs = [
         | 
| 68 | 
            +
                        input_image,
         | 
| 69 | 
            +
                        prompt,
         | 
| 70 | 
            +
                        a_prompt,
         | 
| 71 | 
            +
                        n_prompt,
         | 
| 72 | 
            +
                        num_samples,
         | 
| 73 | 
            +
                        image_resolution,
         | 
| 74 | 
            +
                        detect_resolution,
         | 
| 75 | 
            +
                        num_steps,
         | 
| 76 | 
            +
                        guidance_scale,
         | 
| 77 | 
            +
                        seed,
         | 
| 78 | 
            +
                        bg_threshold,
         | 
| 79 | 
            +
                        is_normal_image,
         | 
| 80 | 
             
                    ]
         | 
| 81 | 
            +
                    prompt.submit(fn=process, inputs=inputs, outputs=result)
         | 
| 82 | 
             
                    run_button.click(fn=process,
         | 
| 83 | 
            +
                                     inputs=inputs,
         | 
| 84 | 
            +
                                     outputs=result,
         | 
| 85 | 
             
                                     api_name='normal')
         | 
| 86 | 
             
                return demo
         | 
| 87 | 
            +
             | 
| 88 | 
            +
             | 
| 89 | 
            +
            if __name__ == '__main__':
         | 
| 90 | 
            +
                from model import Model
         | 
| 91 | 
            +
                model = Model()
         | 
| 92 | 
            +
                demo = create_demo(model.process_normal)
         | 
| 93 | 
            +
                demo.queue().launch()
         | 
    	
        gradio_pose2image.py → app_pose.py
    RENAMED
    
    | @@ -3,7 +3,7 @@ | |
| 3 | 
             
            import gradio as gr
         | 
| 4 |  | 
| 5 |  | 
| 6 | 
            -
            def create_demo(process, max_images=12):
         | 
| 7 | 
             
                with gr.Blocks() as demo:
         | 
| 8 | 
             
                    with gr.Row():
         | 
| 9 | 
             
                        gr.Markdown('## Control Stable Diffusion with Human Pose')
         | 
| @@ -13,10 +13,15 @@ def create_demo(process, max_images=12): | |
| 13 | 
             
                            prompt = gr.Textbox(label='Prompt')
         | 
| 14 | 
             
                            run_button = gr.Button(label='Run')
         | 
| 15 | 
             
                            with gr.Accordion('Advanced options', open=False):
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 16 | 
             
                                num_samples = gr.Slider(label='Images',
         | 
| 17 | 
             
                                                        minimum=1,
         | 
| 18 | 
             
                                                        maximum=max_images,
         | 
| 19 | 
            -
                                                        value= | 
| 20 | 
             
                                                        step=1)
         | 
| 21 | 
             
                                image_resolution = gr.Slider(label='Image Resolution',
         | 
| 22 | 
             
                                                             minimum=256,
         | 
| @@ -28,22 +33,21 @@ def create_demo(process, max_images=12): | |
| 28 | 
             
                                                              maximum=1024,
         | 
| 29 | 
             
                                                              value=512,
         | 
| 30 | 
             
                                                              step=1)
         | 
| 31 | 
            -
                                 | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
             | 
| 35 | 
            -
             | 
| 36 | 
            -
                                 | 
| 37 | 
            -
             | 
| 38 | 
            -
             | 
| 39 | 
            -
             | 
| 40 | 
            -
             | 
| 41 | 
             
                                seed = gr.Slider(label='Seed',
         | 
| 42 | 
             
                                                 minimum=-1,
         | 
| 43 | 
             
                                                 maximum=2147483647,
         | 
| 44 | 
             
                                                 step=1,
         | 
| 45 | 
             
                                                 randomize=True)
         | 
| 46 | 
            -
                                eta = gr.Number(label='eta (DDIM)', value=0.0)
         | 
| 47 | 
             
                                a_prompt = gr.Textbox(
         | 
| 48 | 
             
                                    label='Added Prompt',
         | 
| 49 | 
             
                                    value='best quality, extremely detailed')
         | 
| @@ -53,16 +57,33 @@ def create_demo(process, max_images=12): | |
| 53 | 
             
                                    'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
         | 
| 54 | 
             
                                )
         | 
| 55 | 
             
                        with gr.Column():
         | 
| 56 | 
            -
                             | 
| 57 | 
            -
             | 
| 58 | 
            -
             | 
| 59 | 
            -
             | 
| 60 | 
            -
                     | 
| 61 | 
            -
                        input_image, | 
| 62 | 
            -
                         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 63 | 
             
                    ]
         | 
|  | |
| 64 | 
             
                    run_button.click(fn=process,
         | 
| 65 | 
            -
                                     inputs= | 
| 66 | 
            -
                                     outputs= | 
| 67 | 
             
                                     api_name='pose')
         | 
| 68 | 
             
                return demo
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 3 | 
             
            import gradio as gr
         | 
| 4 |  | 
| 5 |  | 
| 6 | 
            +
            def create_demo(process, max_images=12, default_num_images=3):
         | 
| 7 | 
             
                with gr.Blocks() as demo:
         | 
| 8 | 
             
                    with gr.Row():
         | 
| 9 | 
             
                        gr.Markdown('## Control Stable Diffusion with Human Pose')
         | 
|  | |
| 13 | 
             
                            prompt = gr.Textbox(label='Prompt')
         | 
| 14 | 
             
                            run_button = gr.Button(label='Run')
         | 
| 15 | 
             
                            with gr.Accordion('Advanced options', open=False):
         | 
| 16 | 
            +
                                is_pose_image = gr.Checkbox(label='Is pose image',
         | 
| 17 | 
            +
                                                            value=False)
         | 
| 18 | 
            +
                                gr.Markdown(
         | 
| 19 | 
            +
                                    'You can use [PoseMaker2](https://huggingface.co/spaces/jonigata/PoseMaker2) to create pose images.'
         | 
| 20 | 
            +
                                )
         | 
| 21 | 
             
                                num_samples = gr.Slider(label='Images',
         | 
| 22 | 
             
                                                        minimum=1,
         | 
| 23 | 
             
                                                        maximum=max_images,
         | 
| 24 | 
            +
                                                        value=default_num_images,
         | 
| 25 | 
             
                                                        step=1)
         | 
| 26 | 
             
                                image_resolution = gr.Slider(label='Image Resolution',
         | 
| 27 | 
             
                                                             minimum=256,
         | 
|  | |
| 33 | 
             
                                                              maximum=1024,
         | 
| 34 | 
             
                                                              value=512,
         | 
| 35 | 
             
                                                              step=1)
         | 
| 36 | 
            +
                                num_steps = gr.Slider(label='Steps',
         | 
| 37 | 
            +
                                                      minimum=1,
         | 
| 38 | 
            +
                                                      maximum=100,
         | 
| 39 | 
            +
                                                      value=20,
         | 
| 40 | 
            +
                                                      step=1)
         | 
| 41 | 
            +
                                guidance_scale = gr.Slider(label='Guidance Scale',
         | 
| 42 | 
            +
                                                           minimum=0.1,
         | 
| 43 | 
            +
                                                           maximum=30.0,
         | 
| 44 | 
            +
                                                           value=9.0,
         | 
| 45 | 
            +
                                                           step=0.1)
         | 
| 46 | 
             
                                seed = gr.Slider(label='Seed',
         | 
| 47 | 
             
                                                 minimum=-1,
         | 
| 48 | 
             
                                                 maximum=2147483647,
         | 
| 49 | 
             
                                                 step=1,
         | 
| 50 | 
             
                                                 randomize=True)
         | 
|  | |
| 51 | 
             
                                a_prompt = gr.Textbox(
         | 
| 52 | 
             
                                    label='Added Prompt',
         | 
| 53 | 
             
                                    value='best quality, extremely detailed')
         | 
|  | |
| 57 | 
             
                                    'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
         | 
| 58 | 
             
                                )
         | 
| 59 | 
             
                        with gr.Column():
         | 
| 60 | 
            +
                            result = gr.Gallery(label='Output',
         | 
| 61 | 
            +
                                                show_label=False,
         | 
| 62 | 
            +
                                                elem_id='gallery').style(grid=2,
         | 
| 63 | 
            +
                                                                         height='auto')
         | 
| 64 | 
            +
                    inputs = [
         | 
| 65 | 
            +
                        input_image,
         | 
| 66 | 
            +
                        prompt,
         | 
| 67 | 
            +
                        a_prompt,
         | 
| 68 | 
            +
                        n_prompt,
         | 
| 69 | 
            +
                        num_samples,
         | 
| 70 | 
            +
                        image_resolution,
         | 
| 71 | 
            +
                        detect_resolution,
         | 
| 72 | 
            +
                        num_steps,
         | 
| 73 | 
            +
                        guidance_scale,
         | 
| 74 | 
            +
                        seed,
         | 
| 75 | 
            +
                        is_pose_image,
         | 
| 76 | 
             
                    ]
         | 
| 77 | 
            +
                    prompt.submit(fn=process, inputs=inputs, outputs=result)
         | 
| 78 | 
             
                    run_button.click(fn=process,
         | 
| 79 | 
            +
                                     inputs=inputs,
         | 
| 80 | 
            +
                                     outputs=result,
         | 
| 81 | 
             
                                     api_name='pose')
         | 
| 82 | 
             
                return demo
         | 
| 83 | 
            +
             | 
| 84 | 
            +
             | 
| 85 | 
            +
            if __name__ == '__main__':
         | 
| 86 | 
            +
                from model import Model
         | 
| 87 | 
            +
                model = Model()
         | 
| 88 | 
            +
                demo = create_demo(model.process_pose)
         | 
| 89 | 
            +
                demo.queue().launch()
         | 
    	
        gradio_scribble2image.py → app_scribble.py
    RENAMED
    
    | @@ -3,7 +3,7 @@ | |
| 3 | 
             
            import gradio as gr
         | 
| 4 |  | 
| 5 |  | 
| 6 | 
            -
            def create_demo(process, max_images=12):
         | 
| 7 | 
             
                with gr.Blocks() as demo:
         | 
| 8 | 
             
                    with gr.Row():
         | 
| 9 | 
             
                        gr.Markdown('## Control Stable Diffusion with Scribble Maps')
         | 
| @@ -16,29 +16,28 @@ def create_demo(process, max_images=12): | |
| 16 | 
             
                                num_samples = gr.Slider(label='Images',
         | 
| 17 | 
             
                                                        minimum=1,
         | 
| 18 | 
             
                                                        maximum=max_images,
         | 
| 19 | 
            -
                                                        value= | 
| 20 | 
             
                                                        step=1)
         | 
| 21 | 
             
                                image_resolution = gr.Slider(label='Image Resolution',
         | 
| 22 | 
             
                                                             minimum=256,
         | 
| 23 | 
             
                                                             maximum=768,
         | 
| 24 | 
             
                                                             value=512,
         | 
| 25 | 
             
                                                             step=256)
         | 
| 26 | 
            -
                                 | 
| 27 | 
            -
             | 
| 28 | 
            -
             | 
| 29 | 
            -
             | 
| 30 | 
            -
             | 
| 31 | 
            -
                                 | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
             | 
| 35 | 
            -
             | 
| 36 | 
             
                                seed = gr.Slider(label='Seed',
         | 
| 37 | 
             
                                                 minimum=-1,
         | 
| 38 | 
             
                                                 maximum=2147483647,
         | 
| 39 | 
             
                                                 step=1,
         | 
| 40 | 
             
                                                 randomize=True)
         | 
| 41 | 
            -
                                eta = gr.Number(label='eta (DDIM)', value=0.0)
         | 
| 42 | 
             
                                a_prompt = gr.Textbox(
         | 
| 43 | 
             
                                    label='Added Prompt',
         | 
| 44 | 
             
                                    value='best quality, extremely detailed')
         | 
| @@ -48,16 +47,31 @@ def create_demo(process, max_images=12): | |
| 48 | 
             
                                    'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
         | 
| 49 | 
             
                                )
         | 
| 50 | 
             
                        with gr.Column():
         | 
| 51 | 
            -
                             | 
| 52 | 
            -
             | 
| 53 | 
            -
             | 
| 54 | 
            -
             | 
| 55 | 
            -
                     | 
| 56 | 
            -
                        input_image, | 
| 57 | 
            -
                         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 58 | 
             
                    ]
         | 
|  | |
| 59 | 
             
                    run_button.click(fn=process,
         | 
| 60 | 
            -
                                     inputs= | 
| 61 | 
            -
                                     outputs= | 
| 62 | 
             
                                     api_name='scribble')
         | 
| 63 | 
             
                return demo
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 3 | 
             
            import gradio as gr
         | 
| 4 |  | 
| 5 |  | 
| 6 | 
            +
            def create_demo(process, max_images=12, default_num_images=3):
         | 
| 7 | 
             
                with gr.Blocks() as demo:
         | 
| 8 | 
             
                    with gr.Row():
         | 
| 9 | 
             
                        gr.Markdown('## Control Stable Diffusion with Scribble Maps')
         | 
|  | |
| 16 | 
             
                                num_samples = gr.Slider(label='Images',
         | 
| 17 | 
             
                                                        minimum=1,
         | 
| 18 | 
             
                                                        maximum=max_images,
         | 
| 19 | 
            +
                                                        value=default_num_images,
         | 
| 20 | 
             
                                                        step=1)
         | 
| 21 | 
             
                                image_resolution = gr.Slider(label='Image Resolution',
         | 
| 22 | 
             
                                                             minimum=256,
         | 
| 23 | 
             
                                                             maximum=768,
         | 
| 24 | 
             
                                                             value=512,
         | 
| 25 | 
             
                                                             step=256)
         | 
| 26 | 
            +
                                num_steps = gr.Slider(label='Steps',
         | 
| 27 | 
            +
                                                      minimum=1,
         | 
| 28 | 
            +
                                                      maximum=100,
         | 
| 29 | 
            +
                                                      value=20,
         | 
| 30 | 
            +
                                                      step=1)
         | 
| 31 | 
            +
                                guidance_scale = gr.Slider(label='Guidance Scale',
         | 
| 32 | 
            +
                                                           minimum=0.1,
         | 
| 33 | 
            +
                                                           maximum=30.0,
         | 
| 34 | 
            +
                                                           value=9.0,
         | 
| 35 | 
            +
                                                           step=0.1)
         | 
| 36 | 
             
                                seed = gr.Slider(label='Seed',
         | 
| 37 | 
             
                                                 minimum=-1,
         | 
| 38 | 
             
                                                 maximum=2147483647,
         | 
| 39 | 
             
                                                 step=1,
         | 
| 40 | 
             
                                                 randomize=True)
         | 
|  | |
| 41 | 
             
                                a_prompt = gr.Textbox(
         | 
| 42 | 
             
                                    label='Added Prompt',
         | 
| 43 | 
             
                                    value='best quality, extremely detailed')
         | 
|  | |
| 47 | 
             
                                    'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
         | 
| 48 | 
             
                                )
         | 
| 49 | 
             
                        with gr.Column():
         | 
| 50 | 
            +
                            result = gr.Gallery(label='Output',
         | 
| 51 | 
            +
                                                show_label=False,
         | 
| 52 | 
            +
                                                elem_id='gallery').style(grid=2,
         | 
| 53 | 
            +
                                                                         height='auto')
         | 
| 54 | 
            +
                    inputs = [
         | 
| 55 | 
            +
                        input_image,
         | 
| 56 | 
            +
                        prompt,
         | 
| 57 | 
            +
                        a_prompt,
         | 
| 58 | 
            +
                        n_prompt,
         | 
| 59 | 
            +
                        num_samples,
         | 
| 60 | 
            +
                        image_resolution,
         | 
| 61 | 
            +
                        num_steps,
         | 
| 62 | 
            +
                        guidance_scale,
         | 
| 63 | 
            +
                        seed,
         | 
| 64 | 
             
                    ]
         | 
| 65 | 
            +
                    prompt.submit(fn=process, inputs=inputs, outputs=result)
         | 
| 66 | 
             
                    run_button.click(fn=process,
         | 
| 67 | 
            +
                                     inputs=inputs,
         | 
| 68 | 
            +
                                     outputs=result,
         | 
| 69 | 
             
                                     api_name='scribble')
         | 
| 70 | 
             
                return demo
         | 
| 71 | 
            +
             | 
| 72 | 
            +
             | 
| 73 | 
            +
            if __name__ == '__main__':
         | 
| 74 | 
            +
                from model import Model
         | 
| 75 | 
            +
                model = Model()
         | 
| 76 | 
            +
                demo = create_demo(model.process_scribble)
         | 
| 77 | 
            +
                demo.queue().launch()
         | 
    	
        gradio_scribble2image_interactive.py → app_scribble_interactive.py
    RENAMED
    
    | @@ -8,7 +8,7 @@ def create_canvas(w, h): | |
| 8 | 
             
                return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
         | 
| 9 |  | 
| 10 |  | 
| 11 | 
            -
            def create_demo(process, max_images=12):
         | 
| 12 | 
             
                with gr.Blocks() as demo:
         | 
| 13 | 
             
                    with gr.Row():
         | 
| 14 | 
             
                        gr.Markdown(
         | 
| @@ -37,7 +37,7 @@ def create_demo(process, max_images=12): | |
| 37 | 
             
                            )
         | 
| 38 | 
             
                            create_button.click(fn=create_canvas,
         | 
| 39 | 
             
                                                inputs=[canvas_width, canvas_height],
         | 
| 40 | 
            -
                                                outputs= | 
| 41 | 
             
                                                queue=False)
         | 
| 42 | 
             
                            prompt = gr.Textbox(label='Prompt')
         | 
| 43 | 
             
                            run_button = gr.Button(label='Run')
         | 
| @@ -45,29 +45,28 @@ def create_demo(process, max_images=12): | |
| 45 | 
             
                                num_samples = gr.Slider(label='Images',
         | 
| 46 | 
             
                                                        minimum=1,
         | 
| 47 | 
             
                                                        maximum=max_images,
         | 
| 48 | 
            -
                                                        value= | 
| 49 | 
             
                                                        step=1)
         | 
| 50 | 
             
                                image_resolution = gr.Slider(label='Image Resolution',
         | 
| 51 | 
             
                                                             minimum=256,
         | 
| 52 | 
             
                                                             maximum=768,
         | 
| 53 | 
             
                                                             value=512,
         | 
| 54 | 
             
                                                             step=256)
         | 
| 55 | 
            -
                                 | 
| 56 | 
            -
             | 
| 57 | 
            -
             | 
| 58 | 
            -
             | 
| 59 | 
            -
             | 
| 60 | 
            -
                                 | 
| 61 | 
            -
             | 
| 62 | 
            -
             | 
| 63 | 
            -
             | 
| 64 | 
            -
             | 
| 65 | 
             
                                seed = gr.Slider(label='Seed',
         | 
| 66 | 
             
                                                 minimum=-1,
         | 
| 67 | 
             
                                                 maximum=2147483647,
         | 
| 68 | 
             
                                                 step=1,
         | 
| 69 | 
             
                                                 randomize=True)
         | 
| 70 | 
            -
                                eta = gr.Number(label='eta (DDIM)', value=0.0)
         | 
| 71 | 
             
                                a_prompt = gr.Textbox(
         | 
| 72 | 
             
                                    label='Added Prompt',
         | 
| 73 | 
             
                                    value='best quality, extremely detailed')
         | 
| @@ -77,13 +76,28 @@ def create_demo(process, max_images=12): | |
| 77 | 
             
                                    'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
         | 
| 78 | 
             
                                )
         | 
| 79 | 
             
                        with gr.Column():
         | 
| 80 | 
            -
                             | 
| 81 | 
            -
             | 
| 82 | 
            -
             | 
| 83 | 
            -
             | 
| 84 | 
            -
                     | 
| 85 | 
            -
                        input_image, | 
| 86 | 
            -
                         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 87 | 
             
                    ]
         | 
| 88 | 
            -
                     | 
|  | |
| 89 | 
             
                return demo
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 8 | 
             
                return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
         | 
| 9 |  | 
| 10 |  | 
| 11 | 
            +
            def create_demo(process, max_images=12, default_num_images=3):
         | 
| 12 | 
             
                with gr.Blocks() as demo:
         | 
| 13 | 
             
                    with gr.Row():
         | 
| 14 | 
             
                        gr.Markdown(
         | 
|  | |
| 37 | 
             
                            )
         | 
| 38 | 
             
                            create_button.click(fn=create_canvas,
         | 
| 39 | 
             
                                                inputs=[canvas_width, canvas_height],
         | 
| 40 | 
            +
                                                outputs=input_image,
         | 
| 41 | 
             
                                                queue=False)
         | 
| 42 | 
             
                            prompt = gr.Textbox(label='Prompt')
         | 
| 43 | 
             
                            run_button = gr.Button(label='Run')
         | 
|  | |
| 45 | 
             
                                num_samples = gr.Slider(label='Images',
         | 
| 46 | 
             
                                                        minimum=1,
         | 
| 47 | 
             
                                                        maximum=max_images,
         | 
| 48 | 
            +
                                                        value=default_num_images,
         | 
| 49 | 
             
                                                        step=1)
         | 
| 50 | 
             
                                image_resolution = gr.Slider(label='Image Resolution',
         | 
| 51 | 
             
                                                             minimum=256,
         | 
| 52 | 
             
                                                             maximum=768,
         | 
| 53 | 
             
                                                             value=512,
         | 
| 54 | 
             
                                                             step=256)
         | 
| 55 | 
            +
                                num_steps = gr.Slider(label='Steps',
         | 
| 56 | 
            +
                                                      minimum=1,
         | 
| 57 | 
            +
                                                      maximum=100,
         | 
| 58 | 
            +
                                                      value=20,
         | 
| 59 | 
            +
                                                      step=1)
         | 
| 60 | 
            +
                                guidance_scale = gr.Slider(label='Guidance Scale',
         | 
| 61 | 
            +
                                                           minimum=0.1,
         | 
| 62 | 
            +
                                                           maximum=30.0,
         | 
| 63 | 
            +
                                                           value=9.0,
         | 
| 64 | 
            +
                                                           step=0.1)
         | 
| 65 | 
             
                                seed = gr.Slider(label='Seed',
         | 
| 66 | 
             
                                                 minimum=-1,
         | 
| 67 | 
             
                                                 maximum=2147483647,
         | 
| 68 | 
             
                                                 step=1,
         | 
| 69 | 
             
                                                 randomize=True)
         | 
|  | |
| 70 | 
             
                                a_prompt = gr.Textbox(
         | 
| 71 | 
             
                                    label='Added Prompt',
         | 
| 72 | 
             
                                    value='best quality, extremely detailed')
         | 
|  | |
| 76 | 
             
                                    'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
         | 
| 77 | 
             
                                )
         | 
| 78 | 
             
                        with gr.Column():
         | 
| 79 | 
            +
                            result = gr.Gallery(label='Output',
         | 
| 80 | 
            +
                                                show_label=False,
         | 
| 81 | 
            +
                                                elem_id='gallery').style(grid=2,
         | 
| 82 | 
            +
                                                                         height='auto')
         | 
| 83 | 
            +
                    inputs = [
         | 
| 84 | 
            +
                        input_image,
         | 
| 85 | 
            +
                        prompt,
         | 
| 86 | 
            +
                        a_prompt,
         | 
| 87 | 
            +
                        n_prompt,
         | 
| 88 | 
            +
                        num_samples,
         | 
| 89 | 
            +
                        image_resolution,
         | 
| 90 | 
            +
                        num_steps,
         | 
| 91 | 
            +
                        guidance_scale,
         | 
| 92 | 
            +
                        seed,
         | 
| 93 | 
             
                    ]
         | 
| 94 | 
            +
                    prompt.submit(fn=process, inputs=inputs, outputs=result)
         | 
| 95 | 
            +
                    run_button.click(fn=process, inputs=inputs, outputs=result)
         | 
| 96 | 
             
                return demo
         | 
| 97 | 
            +
             | 
| 98 | 
            +
             | 
| 99 | 
            +
            if __name__ == '__main__':
         | 
| 100 | 
            +
                from model import Model
         | 
| 101 | 
            +
                model = Model()
         | 
| 102 | 
            +
                demo = create_demo(model.process_scribble_interactive)
         | 
| 103 | 
            +
                demo.queue().launch()
         | 
    	
        gradio_seg2image.py → app_seg.py
    RENAMED
    
    | @@ -3,7 +3,7 @@ | |
| 3 | 
             
            import gradio as gr
         | 
| 4 |  | 
| 5 |  | 
| 6 | 
            -
            def create_demo(process, max_images=12):
         | 
| 7 | 
             
                with gr.Blocks() as demo:
         | 
| 8 | 
             
                    with gr.Row():
         | 
| 9 | 
             
                        gr.Markdown('## Control Stable Diffusion with Segmentation Maps')
         | 
| @@ -13,10 +13,12 @@ def create_demo(process, max_images=12): | |
| 13 | 
             
                            prompt = gr.Textbox(label='Prompt')
         | 
| 14 | 
             
                            run_button = gr.Button(label='Run')
         | 
| 15 | 
             
                            with gr.Accordion('Advanced options', open=False):
         | 
|  | |
|  | |
| 16 | 
             
                                num_samples = gr.Slider(label='Images',
         | 
| 17 | 
             
                                                        minimum=1,
         | 
| 18 | 
             
                                                        maximum=max_images,
         | 
| 19 | 
            -
                                                        value= | 
| 20 | 
             
                                                        step=1)
         | 
| 21 | 
             
                                image_resolution = gr.Slider(label='Image Resolution',
         | 
| 22 | 
             
                                                             minimum=256,
         | 
| @@ -29,22 +31,21 @@ def create_demo(process, max_images=12): | |
| 29 | 
             
                                    maximum=1024,
         | 
| 30 | 
             
                                    value=512,
         | 
| 31 | 
             
                                    step=1)
         | 
| 32 | 
            -
                                 | 
| 33 | 
            -
             | 
| 34 | 
            -
             | 
| 35 | 
            -
             | 
| 36 | 
            -
             | 
| 37 | 
            -
                                 | 
| 38 | 
            -
             | 
| 39 | 
            -
             | 
| 40 | 
            -
             | 
| 41 | 
            -
             | 
| 42 | 
             
                                seed = gr.Slider(label='Seed',
         | 
| 43 | 
             
                                                 minimum=-1,
         | 
| 44 | 
             
                                                 maximum=2147483647,
         | 
| 45 | 
             
                                                 step=1,
         | 
| 46 | 
             
                                                 randomize=True)
         | 
| 47 | 
            -
                                eta = gr.Number(label='eta (DDIM)', value=0.0)
         | 
| 48 | 
             
                                a_prompt = gr.Textbox(
         | 
| 49 | 
             
                                    label='Added Prompt',
         | 
| 50 | 
             
                                    value='best quality, extremely detailed')
         | 
| @@ -54,16 +55,33 @@ def create_demo(process, max_images=12): | |
| 54 | 
             
                                    'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
         | 
| 55 | 
             
                                )
         | 
| 56 | 
             
                        with gr.Column():
         | 
| 57 | 
            -
                             | 
| 58 | 
            -
             | 
| 59 | 
            -
             | 
| 60 | 
            -
             | 
| 61 | 
            -
                     | 
| 62 | 
            -
                        input_image, | 
| 63 | 
            -
                         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 64 | 
             
                    ]
         | 
|  | |
| 65 | 
             
                    run_button.click(fn=process,
         | 
| 66 | 
            -
                                     inputs= | 
| 67 | 
            -
                                     outputs= | 
| 68 | 
             
                                     api_name='seg')
         | 
| 69 | 
             
                return demo
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 3 | 
             
            import gradio as gr
         | 
| 4 |  | 
| 5 |  | 
| 6 | 
            +
            def create_demo(process, max_images=12, default_num_images=3):
         | 
| 7 | 
             
                with gr.Blocks() as demo:
         | 
| 8 | 
             
                    with gr.Row():
         | 
| 9 | 
             
                        gr.Markdown('## Control Stable Diffusion with Segmentation Maps')
         | 
|  | |
| 13 | 
             
                            prompt = gr.Textbox(label='Prompt')
         | 
| 14 | 
             
                            run_button = gr.Button(label='Run')
         | 
| 15 | 
             
                            with gr.Accordion('Advanced options', open=False):
         | 
| 16 | 
            +
                                is_segmentation_map = gr.Checkbox(
         | 
| 17 | 
            +
                                    label='Is segmentation map', value=False)
         | 
| 18 | 
             
                                num_samples = gr.Slider(label='Images',
         | 
| 19 | 
             
                                                        minimum=1,
         | 
| 20 | 
             
                                                        maximum=max_images,
         | 
| 21 | 
            +
                                                        value=default_num_images,
         | 
| 22 | 
             
                                                        step=1)
         | 
| 23 | 
             
                                image_resolution = gr.Slider(label='Image Resolution',
         | 
| 24 | 
             
                                                             minimum=256,
         | 
|  | |
| 31 | 
             
                                    maximum=1024,
         | 
| 32 | 
             
                                    value=512,
         | 
| 33 | 
             
                                    step=1)
         | 
| 34 | 
            +
                                num_steps = gr.Slider(label='Steps',
         | 
| 35 | 
            +
                                                      minimum=1,
         | 
| 36 | 
            +
                                                      maximum=100,
         | 
| 37 | 
            +
                                                      value=20,
         | 
| 38 | 
            +
                                                      step=1)
         | 
| 39 | 
            +
                                guidance_scale = gr.Slider(label='Guidance Scale',
         | 
| 40 | 
            +
                                                           minimum=0.1,
         | 
| 41 | 
            +
                                                           maximum=30.0,
         | 
| 42 | 
            +
                                                           value=9.0,
         | 
| 43 | 
            +
                                                           step=0.1)
         | 
| 44 | 
             
                                seed = gr.Slider(label='Seed',
         | 
| 45 | 
             
                                                 minimum=-1,
         | 
| 46 | 
             
                                                 maximum=2147483647,
         | 
| 47 | 
             
                                                 step=1,
         | 
| 48 | 
             
                                                 randomize=True)
         | 
|  | |
| 49 | 
             
                                a_prompt = gr.Textbox(
         | 
| 50 | 
             
                                    label='Added Prompt',
         | 
| 51 | 
             
                                    value='best quality, extremely detailed')
         | 
|  | |
| 55 | 
             
                                    'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
         | 
| 56 | 
             
                                )
         | 
| 57 | 
             
                        with gr.Column():
         | 
| 58 | 
            +
                            result = gr.Gallery(label='Output',
         | 
| 59 | 
            +
                                                show_label=False,
         | 
| 60 | 
            +
                                                elem_id='gallery').style(grid=2,
         | 
| 61 | 
            +
                                                                         height='auto')
         | 
| 62 | 
            +
                    inputs = [
         | 
| 63 | 
            +
                        input_image,
         | 
| 64 | 
            +
                        prompt,
         | 
| 65 | 
            +
                        a_prompt,
         | 
| 66 | 
            +
                        n_prompt,
         | 
| 67 | 
            +
                        num_samples,
         | 
| 68 | 
            +
                        image_resolution,
         | 
| 69 | 
            +
                        detect_resolution,
         | 
| 70 | 
            +
                        num_steps,
         | 
| 71 | 
            +
                        guidance_scale,
         | 
| 72 | 
            +
                        seed,
         | 
| 73 | 
            +
                        is_segmentation_map,
         | 
| 74 | 
             
                    ]
         | 
| 75 | 
            +
                    prompt.submit(fn=process, inputs=inputs, outputs=result)
         | 
| 76 | 
             
                    run_button.click(fn=process,
         | 
| 77 | 
            +
                                     inputs=inputs,
         | 
| 78 | 
            +
                                     outputs=result,
         | 
| 79 | 
             
                                     api_name='seg')
         | 
| 80 | 
             
                return demo
         | 
| 81 | 
            +
             | 
| 82 | 
            +
             | 
| 83 | 
            +
            if __name__ == '__main__':
         | 
| 84 | 
            +
                from model import Model
         | 
| 85 | 
            +
                model = Model()
         | 
| 86 | 
            +
                demo = create_demo(model.process_seg)
         | 
| 87 | 
            +
                demo.queue().launch()
         | 
    	
        model.py
    CHANGED
    
    | @@ -3,21 +3,20 @@ | |
| 3 | 
             
            from __future__ import annotations
         | 
| 4 |  | 
| 5 | 
             
            import pathlib
         | 
| 6 | 
            -
            import random
         | 
| 7 | 
            -
            import shlex
         | 
| 8 | 
            -
            import subprocess
         | 
| 9 | 
             
            import sys
         | 
| 10 |  | 
| 11 | 
             
            import cv2
         | 
| 12 | 
            -
            import einops
         | 
| 13 | 
             
            import numpy as np
         | 
|  | |
| 14 | 
             
            import torch
         | 
| 15 | 
            -
            from  | 
| 16 | 
            -
             | 
|  | |
| 17 |  | 
| 18 | 
            -
             | 
|  | |
|  | |
| 19 |  | 
| 20 | 
            -
            import config
         | 
| 21 | 
             
            from annotator.canny import apply_canny
         | 
| 22 | 
             
            from annotator.hed import apply_hed, nms
         | 
| 23 | 
             
            from annotator.midas import apply_midas
         | 
| @@ -25,733 +24,600 @@ from annotator.mlsd import apply_mlsd | |
| 25 | 
             
            from annotator.openpose import apply_openpose
         | 
| 26 | 
             
            from annotator.uniformer import apply_uniformer
         | 
| 27 | 
             
            from annotator.util import HWC3, resize_image
         | 
| 28 | 
            -
            from cldm.model import create_model, load_state_dict
         | 
| 29 | 
            -
            from ldm.models.diffusion.ddim import DDIMSampler
         | 
| 30 | 
             
            from share import *
         | 
| 31 |  | 
| 32 | 
            -
             | 
| 33 | 
            -
                'canny': ' | 
| 34 | 
            -
                'hough': ' | 
| 35 | 
            -
                'hed': ' | 
| 36 | 
            -
                'scribble': ' | 
| 37 | 
            -
                'pose': ' | 
| 38 | 
            -
                'seg': ' | 
| 39 | 
            -
                'depth': ' | 
| 40 | 
            -
                'normal': ' | 
| 41 | 
             
            }
         | 
| 42 | 
            -
            MODEL_REPO = 'webui/ControlNet-modules-safetensors'
         | 
| 43 |  | 
| 44 | 
            -
             | 
| 45 | 
            -
             | 
| 46 | 
            -
             | 
|  | |
| 47 |  | 
| 48 |  | 
| 49 | 
             
            class Model:
         | 
| 50 | 
             
                def __init__(self,
         | 
| 51 | 
            -
                              | 
| 52 | 
            -
                              | 
| 53 | 
            -
                    self. | 
| 54 | 
            -
                        'cuda:0' if torch.cuda.is_available() else 'cpu')
         | 
| 55 | 
            -
                    self.model = create_model(model_config_path).to(self.device)
         | 
| 56 | 
            -
                    self.ddim_sampler = DDIMSampler(self.model)
         | 
| 57 | 
             
                    self.task_name = ''
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 58 |  | 
| 59 | 
            -
             | 
| 60 | 
            -
                    self.model_dir = pathlib.Path(model_dir)
         | 
| 61 | 
            -
                    self.model_dir.mkdir(exist_ok=True, parents=True)
         | 
| 62 | 
            -
             | 
| 63 | 
            -
                    self.download_models()
         | 
| 64 | 
            -
                    self.set_base_model(DEFAULT_BASE_MODEL_REPO,
         | 
| 65 | 
            -
                                        DEFAULT_BASE_MODEL_FILENAME)
         | 
| 66 | 
            -
             | 
| 67 | 
            -
                def set_base_model(self, model_id: str, filename: str) -> str:
         | 
| 68 | 
            -
                    if not model_id or not filename:
         | 
| 69 | 
            -
                        return self.base_model_url
         | 
| 70 | 
            -
                    base_model_url = hf_hub_url(model_id, filename)
         | 
| 71 | 
            -
                    if base_model_url != self.base_model_url:
         | 
| 72 | 
            -
                        self.load_base_model(base_model_url)
         | 
| 73 | 
            -
                        self.base_model_url = base_model_url
         | 
| 74 | 
            -
                    return self.base_model_url
         | 
| 75 | 
            -
             | 
| 76 | 
            -
                def download_base_model(self, model_url: str) -> pathlib.Path:
         | 
| 77 | 
            -
                    self.model_dir.mkdir(exist_ok=True, parents=True)
         | 
| 78 | 
            -
                    model_name = model_url.split('/')[-1]
         | 
| 79 | 
            -
                    out_path = self.model_dir / model_name
         | 
| 80 | 
            -
                    if not out_path.exists():
         | 
| 81 | 
            -
                        subprocess.run(shlex.split(f'wget {model_url} -O {out_path}'))
         | 
| 82 | 
            -
                    return out_path
         | 
| 83 | 
            -
             | 
| 84 | 
            -
                def load_base_model(self, model_url: str) -> None:
         | 
| 85 | 
            -
                    model_path = self.download_base_model(model_url)
         | 
| 86 | 
            -
                    self.model.load_state_dict(load_state_dict(model_path,
         | 
| 87 | 
            -
                                                               location=self.device.type),
         | 
| 88 | 
            -
                                               strict=False)
         | 
| 89 | 
            -
             | 
| 90 | 
            -
                def load_weight(self, task_name: str) -> None:
         | 
| 91 | 
             
                    if task_name == self.task_name:
         | 
| 92 | 
             
                        return
         | 
| 93 | 
            -
                     | 
| 94 | 
            -
                     | 
| 95 | 
            -
             | 
|  | |
|  | |
|  | |
| 96 | 
             
                    self.task_name = task_name
         | 
| 97 |  | 
| 98 | 
            -
                def  | 
| 99 | 
            -
                    if  | 
| 100 | 
            -
                         | 
| 101 | 
            -
                     | 
| 102 | 
            -
             | 
| 103 | 
            -
             | 
| 104 | 
            -
             | 
| 105 | 
            -
             | 
| 106 | 
            -
             | 
| 107 | 
            -
             | 
| 108 | 
            -
             | 
| 109 | 
            -
             | 
| 110 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 111 |  | 
| 112 | 
             
                @torch.inference_mode()
         | 
| 113 | 
            -
                def process_canny( | 
| 114 | 
            -
             | 
| 115 | 
            -
             | 
| 116 | 
            -
                     | 
| 117 | 
            -
             | 
| 118 | 
            -
                     | 
| 119 | 
            -
                     | 
| 120 | 
            -
             | 
| 121 | 
            -
                     | 
| 122 | 
            -
                     | 
| 123 | 
            -
             | 
| 124 | 
            -
                     | 
| 125 | 
            -
                     | 
| 126 | 
            -
             | 
| 127 | 
            -
             | 
| 128 | 
            -
             | 
| 129 | 
            -
                         | 
| 130 | 
            -
             | 
| 131 | 
            -
             | 
| 132 | 
            -
                     | 
| 133 | 
            -
             | 
| 134 | 
            -
             | 
| 135 | 
            -
             | 
| 136 | 
            -
                         | 
| 137 | 
            -
                         | 
| 138 | 
            -
             | 
| 139 | 
            -
             | 
| 140 | 
            -
                         | 
| 141 | 
            -
             | 
| 142 | 
            -
                     | 
| 143 | 
            -
             | 
| 144 | 
            -
             | 
| 145 | 
            -
             | 
| 146 | 
            -
             | 
| 147 | 
            -
                     | 
| 148 | 
            -
             | 
| 149 | 
            -
                     | 
| 150 | 
            -
             | 
| 151 | 
            -
             | 
| 152 | 
            -
             | 
| 153 | 
            -
             | 
| 154 | 
            -
             | 
| 155 | 
            -
                         | 
| 156 | 
            -
                         | 
| 157 | 
            -
             | 
| 158 | 
            -
             | 
| 159 | 
            -
             | 
| 160 | 
            -
             | 
| 161 | 
            -
             | 
| 162 | 
            -
                    if config.save_memory:
         | 
| 163 | 
            -
                        self.model.low_vram_shift(is_diffusing=False)
         | 
| 164 | 
            -
             | 
| 165 | 
            -
                    x_samples = self.model.decode_first_stage(samples)
         | 
| 166 | 
            -
                    x_samples = (
         | 
| 167 | 
            -
                        einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
         | 
| 168 | 
            -
                        127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
         | 
| 169 | 
            -
             | 
| 170 | 
            -
                    results = [x_samples[i] for i in range(num_samples)]
         | 
| 171 | 
            -
                    return [255 - detected_map] + results
         | 
| 172 |  | 
| 173 | 
            -
             | 
| 174 | 
            -
             | 
| 175 | 
            -
                                  num_samples, image_resolution, detect_resolution,
         | 
| 176 | 
            -
                                  ddim_steps, scale, seed, eta, value_threshold,
         | 
| 177 | 
            -
                                  distance_threshold):
         | 
| 178 | 
            -
                    self.load_weight('hough')
         | 
| 179 |  | 
| 180 | 
            -
                     | 
| 181 | 
            -
             | 
| 182 | 
            -
                                              value_threshold, distance_threshold)
         | 
| 183 | 
            -
                    detected_map = HWC3(detected_map)
         | 
| 184 | 
            -
                    img = resize_image(input_image, image_resolution)
         | 
| 185 | 
            -
                    H, W, C = img.shape
         | 
| 186 | 
            -
             | 
| 187 | 
            -
                    detected_map = cv2.resize(detected_map, (W, H),
         | 
| 188 | 
            -
                                              interpolation=cv2.INTER_NEAREST)
         | 
| 189 | 
            -
             | 
| 190 | 
            -
                    control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
         | 
| 191 | 
            -
                    control = torch.stack([control for _ in range(num_samples)], dim=0)
         | 
| 192 | 
            -
                    control = einops.rearrange(control, 'b h w c -> b c h w').clone()
         | 
| 193 | 
            -
             | 
| 194 | 
            -
                    if seed == -1:
         | 
| 195 | 
            -
                        seed = random.randint(0, 65535)
         | 
| 196 | 
            -
                    seed_everything(seed)
         | 
| 197 | 
            -
             | 
| 198 | 
            -
                    if config.save_memory:
         | 
| 199 | 
            -
                        self.model.low_vram_shift(is_diffusing=False)
         | 
| 200 | 
            -
             | 
| 201 | 
            -
                    cond = {
         | 
| 202 | 
            -
                        'c_concat': [control],
         | 
| 203 | 
            -
                        'c_crossattn': [
         | 
| 204 | 
            -
                            self.model.get_learned_conditioning(
         | 
| 205 | 
            -
                                [prompt + ', ' + a_prompt] * num_samples)
         | 
| 206 | 
            -
                        ]
         | 
| 207 | 
            -
                    }
         | 
| 208 | 
            -
                    un_cond = {
         | 
| 209 | 
            -
                        'c_concat': [control],
         | 
| 210 | 
            -
                        'c_crossattn':
         | 
| 211 | 
            -
                        [self.model.get_learned_conditioning([n_prompt] * num_samples)]
         | 
| 212 | 
            -
                    }
         | 
| 213 | 
            -
                    shape = (4, H // 8, W // 8)
         | 
| 214 | 
            -
             | 
| 215 | 
            -
                    if config.save_memory:
         | 
| 216 | 
            -
                        self.model.low_vram_shift(is_diffusing=True)
         | 
| 217 | 
            -
             | 
| 218 | 
            -
                    samples, intermediates = self.ddim_sampler.sample(
         | 
| 219 | 
            -
                        ddim_steps,
         | 
| 220 | 
            -
                        num_samples,
         | 
| 221 | 
            -
                        shape,
         | 
| 222 | 
            -
                        cond,
         | 
| 223 | 
            -
                        verbose=False,
         | 
| 224 | 
            -
                        eta=eta,
         | 
| 225 | 
            -
                        unconditional_guidance_scale=scale,
         | 
| 226 | 
            -
                        unconditional_conditioning=un_cond)
         | 
| 227 | 
            -
             | 
| 228 | 
            -
                    if config.save_memory:
         | 
| 229 | 
            -
                        self.model.low_vram_shift(is_diffusing=False)
         | 
| 230 | 
            -
             | 
| 231 | 
            -
                    x_samples = self.model.decode_first_stage(samples)
         | 
| 232 | 
            -
                    x_samples = (
         | 
| 233 | 
            -
                        einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
         | 
| 234 | 
            -
                        127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
         | 
| 235 | 
            -
             | 
| 236 | 
            -
                    results = [x_samples[i] for i in range(num_samples)]
         | 
| 237 | 
            -
                    return [
         | 
| 238 | 
            -
                        255 - cv2.dilate(detected_map,
         | 
| 239 | 
            -
                                         np.ones(shape=(3, 3), dtype=np.uint8),
         | 
| 240 | 
            -
                                         iterations=1)
         | 
| 241 | 
            -
                    ] + results
         | 
| 242 |  | 
| 243 | 
             
                @torch.inference_mode()
         | 
| 244 | 
            -
                def  | 
| 245 | 
            -
             | 
| 246 | 
            -
             | 
| 247 | 
            -
                     | 
| 248 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 249 | 
             
                    input_image = HWC3(input_image)
         | 
| 250 | 
            -
                     | 
| 251 | 
            -
                     | 
| 252 | 
            -
                     | 
| 253 | 
            -
                    H, W | 
| 254 | 
            -
             | 
| 255 | 
            -
             | 
| 256 | 
            -
             | 
| 257 | 
            -
             | 
| 258 | 
            -
                    control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
         | 
| 259 | 
            -
                    control = torch.stack([control for _ in range(num_samples)], dim=0)
         | 
| 260 | 
            -
                    control = einops.rearrange(control, 'b h w c -> b c h w').clone()
         | 
| 261 | 
            -
             | 
| 262 | 
            -
                    if seed == -1:
         | 
| 263 | 
            -
                        seed = random.randint(0, 65535)
         | 
| 264 | 
            -
                    seed_everything(seed)
         | 
| 265 | 
            -
             | 
| 266 | 
            -
                    if config.save_memory:
         | 
| 267 | 
            -
                        self.model.low_vram_shift(is_diffusing=False)
         | 
| 268 | 
            -
             | 
| 269 | 
            -
                    cond = {
         | 
| 270 | 
            -
                        'c_concat': [control],
         | 
| 271 | 
            -
                        'c_crossattn': [
         | 
| 272 | 
            -
                            self.model.get_learned_conditioning(
         | 
| 273 | 
            -
                                [prompt + ', ' + a_prompt] * num_samples)
         | 
| 274 | 
            -
                        ]
         | 
| 275 | 
            -
                    }
         | 
| 276 | 
            -
                    un_cond = {
         | 
| 277 | 
            -
                        'c_concat': [control],
         | 
| 278 | 
            -
                        'c_crossattn':
         | 
| 279 | 
            -
                        [self.model.get_learned_conditioning([n_prompt] * num_samples)]
         | 
| 280 | 
            -
                    }
         | 
| 281 | 
            -
                    shape = (4, H // 8, W // 8)
         | 
| 282 | 
            -
             | 
| 283 | 
            -
                    if config.save_memory:
         | 
| 284 | 
            -
                        self.model.low_vram_shift(is_diffusing=True)
         | 
| 285 | 
            -
             | 
| 286 | 
            -
                    samples, intermediates = self.ddim_sampler.sample(
         | 
| 287 | 
            -
                        ddim_steps,
         | 
| 288 | 
            -
                        num_samples,
         | 
| 289 | 
            -
                        shape,
         | 
| 290 | 
            -
                        cond,
         | 
| 291 | 
            -
                        verbose=False,
         | 
| 292 | 
            -
                        eta=eta,
         | 
| 293 | 
            -
                        unconditional_guidance_scale=scale,
         | 
| 294 | 
            -
                        unconditional_conditioning=un_cond)
         | 
| 295 | 
            -
             | 
| 296 | 
            -
                    if config.save_memory:
         | 
| 297 | 
            -
                        self.model.low_vram_shift(is_diffusing=False)
         | 
| 298 | 
            -
             | 
| 299 | 
            -
                    x_samples = self.model.decode_first_stage(samples)
         | 
| 300 | 
            -
                    x_samples = (
         | 
| 301 | 
            -
                        einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
         | 
| 302 | 
            -
                        127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
         | 
| 303 | 
            -
             | 
| 304 | 
            -
                    results = [x_samples[i] for i in range(num_samples)]
         | 
| 305 | 
            -
                    return [detected_map] + results
         | 
| 306 |  | 
| 307 | 
             
                @torch.inference_mode()
         | 
| 308 | 
            -
                def  | 
| 309 | 
            -
             | 
| 310 | 
            -
             | 
| 311 | 
            -
                     | 
| 312 | 
            -
             | 
| 313 | 
            -
                     | 
| 314 | 
            -
                     | 
| 315 | 
            -
             | 
| 316 | 
            -
                     | 
| 317 | 
            -
                     | 
| 318 | 
            -
             | 
| 319 | 
            -
                     | 
| 320 | 
            -
             | 
| 321 | 
            -
                     | 
| 322 | 
            -
             | 
| 323 | 
            -
             | 
| 324 | 
            -
                         | 
| 325 | 
            -
                     | 
| 326 | 
            -
             | 
| 327 | 
            -
                     | 
| 328 | 
            -
                        self. | 
| 329 | 
            -
             | 
| 330 | 
            -
             | 
| 331 | 
            -
                         | 
| 332 | 
            -
                         | 
| 333 | 
            -
             | 
| 334 | 
            -
             | 
| 335 | 
            -
             | 
| 336 | 
            -
                     | 
| 337 | 
            -
             | 
| 338 | 
            -
             | 
| 339 | 
            -
             | 
| 340 | 
            -
             | 
| 341 | 
            -
                     | 
| 342 | 
            -
             | 
| 343 | 
            -
             | 
| 344 | 
            -
                     | 
| 345 | 
            -
             | 
| 346 | 
            -
             | 
| 347 | 
            -
                     | 
| 348 | 
            -
                         | 
| 349 | 
            -
                        num_samples,
         | 
| 350 | 
            -
                        shape,
         | 
| 351 | 
            -
                        cond,
         | 
| 352 | 
            -
                        verbose=False,
         | 
| 353 | 
            -
                        eta=eta,
         | 
| 354 | 
            -
                        unconditional_guidance_scale=scale,
         | 
| 355 | 
            -
                        unconditional_conditioning=un_cond)
         | 
| 356 | 
            -
             | 
| 357 | 
            -
                    if config.save_memory:
         | 
| 358 | 
            -
                        self.model.low_vram_shift(is_diffusing=False)
         | 
| 359 | 
            -
             | 
| 360 | 
            -
                    x_samples = self.model.decode_first_stage(samples)
         | 
| 361 | 
            -
                    x_samples = (
         | 
| 362 | 
            -
                        einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
         | 
| 363 | 
            -
                        127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
         | 
| 364 | 
            -
             | 
| 365 | 
            -
                    results = [x_samples[i] for i in range(num_samples)]
         | 
| 366 | 
            -
                    return [255 - detected_map] + results
         | 
| 367 |  | 
| 368 | 
             
                @torch.inference_mode()
         | 
| 369 | 
            -
                def  | 
| 370 | 
            -
             | 
| 371 | 
            -
             | 
| 372 | 
            -
                     | 
| 373 | 
            -
             | 
| 374 | 
            -
                     | 
| 375 | 
            -
             | 
| 376 | 
            -
                     | 
| 377 | 
            -
             | 
| 378 | 
            -
                     | 
| 379 | 
            -
                     | 
| 380 | 
            -
             | 
| 381 | 
            -
                     | 
| 382 | 
            -
             | 
| 383 | 
            -
             | 
| 384 | 
            -
             | 
| 385 | 
            -
                     | 
| 386 | 
            -
             | 
| 387 | 
            -
             | 
| 388 | 
            -
             | 
| 389 | 
            -
             | 
| 390 | 
            -
                         | 
| 391 | 
            -
             | 
| 392 | 
            -
             | 
| 393 | 
            -
                         | 
| 394 | 
            -
             | 
| 395 | 
            -
             | 
| 396 | 
            -
             | 
| 397 | 
            -
             | 
| 398 | 
            -
             | 
| 399 | 
            -
                     | 
| 400 | 
            -
             | 
| 401 | 
            -
             | 
| 402 | 
            -
             | 
| 403 | 
            -
             | 
| 404 | 
            -
                     | 
| 405 | 
            -
             | 
| 406 | 
            -
                     | 
| 407 | 
            -
             | 
| 408 | 
            -
             | 
| 409 | 
            -
                    samples, intermediates = self.ddim_sampler.sample(
         | 
| 410 | 
            -
                        ddim_steps,
         | 
| 411 | 
            -
                        num_samples,
         | 
| 412 | 
            -
                        shape,
         | 
| 413 | 
            -
                        cond,
         | 
| 414 | 
            -
                        verbose=False,
         | 
| 415 | 
            -
                        eta=eta,
         | 
| 416 | 
            -
                        unconditional_guidance_scale=scale,
         | 
| 417 | 
            -
                        unconditional_conditioning=un_cond)
         | 
| 418 | 
            -
             | 
| 419 | 
            -
                    if config.save_memory:
         | 
| 420 | 
            -
                        self.model.low_vram_shift(is_diffusing=False)
         | 
| 421 | 
            -
             | 
| 422 | 
            -
                    x_samples = self.model.decode_first_stage(samples)
         | 
| 423 | 
            -
                    x_samples = (
         | 
| 424 | 
            -
                        einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
         | 
| 425 | 
            -
                        127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
         | 
| 426 | 
            -
             | 
| 427 | 
            -
                    results = [x_samples[i] for i in range(num_samples)]
         | 
| 428 | 
            -
                    return [255 - detected_map] + results
         | 
| 429 |  | 
| 430 | 
             
                @torch.inference_mode()
         | 
| 431 | 
            -
                def  | 
| 432 | 
            -
             | 
| 433 | 
            -
             | 
| 434 | 
            -
                     | 
| 435 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 436 | 
             
                    input_image = HWC3(input_image)
         | 
| 437 | 
            -
                     | 
| 438 | 
            -
                     | 
| 439 | 
            -
                     | 
| 440 | 
            -
                    H, W | 
| 441 | 
            -
             | 
| 442 | 
            -
                    detected_map = cv2.resize(detected_map, (W, H),
         | 
| 443 | 
            -
                                              interpolation=cv2.INTER_LINEAR)
         | 
| 444 | 
            -
                    detected_map = nms(detected_map, 127, 3.0)
         | 
| 445 | 
            -
                    detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
         | 
| 446 | 
            -
                    detected_map[detected_map > 4] = 255
         | 
| 447 | 
            -
                    detected_map[detected_map < 255] = 0
         | 
| 448 | 
            -
             | 
| 449 | 
            -
                    control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
         | 
| 450 | 
            -
                    control = torch.stack([control for _ in range(num_samples)], dim=0)
         | 
| 451 | 
            -
                    control = einops.rearrange(control, 'b h w c -> b c h w').clone()
         | 
| 452 | 
            -
             | 
| 453 | 
            -
                    if seed == -1:
         | 
| 454 | 
            -
                        seed = random.randint(0, 65535)
         | 
| 455 | 
            -
                    seed_everything(seed)
         | 
| 456 | 
            -
             | 
| 457 | 
            -
                    if config.save_memory:
         | 
| 458 | 
            -
                        self.model.low_vram_shift(is_diffusing=False)
         | 
| 459 | 
            -
             | 
| 460 | 
            -
                    cond = {
         | 
| 461 | 
            -
                        'c_concat': [control],
         | 
| 462 | 
            -
                        'c_crossattn': [
         | 
| 463 | 
            -
                            self.model.get_learned_conditioning(
         | 
| 464 | 
            -
                                [prompt + ', ' + a_prompt] * num_samples)
         | 
| 465 | 
            -
                        ]
         | 
| 466 | 
            -
                    }
         | 
| 467 | 
            -
                    un_cond = {
         | 
| 468 | 
            -
                        'c_concat': [control],
         | 
| 469 | 
            -
                        'c_crossattn':
         | 
| 470 | 
            -
                        [self.model.get_learned_conditioning([n_prompt] * num_samples)]
         | 
| 471 | 
            -
                    }
         | 
| 472 | 
            -
                    shape = (4, H // 8, W // 8)
         | 
| 473 | 
            -
             | 
| 474 | 
            -
                    if config.save_memory:
         | 
| 475 | 
            -
                        self.model.low_vram_shift(is_diffusing=True)
         | 
| 476 | 
            -
             | 
| 477 | 
            -
                    samples, intermediates = self.ddim_sampler.sample(
         | 
| 478 | 
            -
                        ddim_steps,
         | 
| 479 | 
            -
                        num_samples,
         | 
| 480 | 
            -
                        shape,
         | 
| 481 | 
            -
                        cond,
         | 
| 482 | 
            -
                        verbose=False,
         | 
| 483 | 
            -
                        eta=eta,
         | 
| 484 | 
            -
                        unconditional_guidance_scale=scale,
         | 
| 485 | 
            -
                        unconditional_conditioning=un_cond)
         | 
| 486 | 
            -
             | 
| 487 | 
            -
                    if config.save_memory:
         | 
| 488 | 
            -
                        self.model.low_vram_shift(is_diffusing=False)
         | 
| 489 | 
            -
             | 
| 490 | 
            -
                    x_samples = self.model.decode_first_stage(samples)
         | 
| 491 | 
            -
                    x_samples = (
         | 
| 492 | 
            -
                        einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
         | 
| 493 | 
            -
                        127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
         | 
| 494 | 
            -
             | 
| 495 | 
            -
                    results = [x_samples[i] for i in range(num_samples)]
         | 
| 496 | 
            -
                    return [255 - detected_map] + results
         | 
| 497 |  | 
| 498 | 
            -
             | 
| 499 | 
            -
             | 
| 500 | 
            -
             | 
| 501 | 
            -
             | 
| 502 | 
            -
                     | 
|  | |
| 503 |  | 
| 504 | 
            -
                     | 
| 505 | 
            -
                    detected_map, _ = apply_openpose(
         | 
| 506 | 
            -
                        resize_image(input_image, detect_resolution))
         | 
| 507 | 
            -
                    detected_map = HWC3(detected_map)
         | 
| 508 | 
            -
                    img = resize_image(input_image, image_resolution)
         | 
| 509 | 
            -
                    H, W, C = img.shape
         | 
| 510 | 
            -
             | 
| 511 | 
            -
                    detected_map = cv2.resize(detected_map, (W, H),
         | 
| 512 | 
            -
                                              interpolation=cv2.INTER_NEAREST)
         | 
| 513 | 
            -
             | 
| 514 | 
            -
                    control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
         | 
| 515 | 
            -
                    control = torch.stack([control for _ in range(num_samples)], dim=0)
         | 
| 516 | 
            -
                    control = einops.rearrange(control, 'b h w c -> b c h w').clone()
         | 
| 517 | 
            -
             | 
| 518 | 
            -
                    if seed == -1:
         | 
| 519 | 
            -
                        seed = random.randint(0, 65535)
         | 
| 520 | 
            -
                    seed_everything(seed)
         | 
| 521 | 
            -
             | 
| 522 | 
            -
                    if config.save_memory:
         | 
| 523 | 
            -
                        self.model.low_vram_shift(is_diffusing=False)
         | 
| 524 | 
            -
             | 
| 525 | 
            -
                    cond = {
         | 
| 526 | 
            -
                        'c_concat': [control],
         | 
| 527 | 
            -
                        'c_crossattn': [
         | 
| 528 | 
            -
                            self.model.get_learned_conditioning(
         | 
| 529 | 
            -
                                [prompt + ', ' + a_prompt] * num_samples)
         | 
| 530 | 
            -
                        ]
         | 
| 531 | 
            -
                    }
         | 
| 532 | 
            -
                    un_cond = {
         | 
| 533 | 
            -
                        'c_concat': [control],
         | 
| 534 | 
            -
                        'c_crossattn':
         | 
| 535 | 
            -
                        [self.model.get_learned_conditioning([n_prompt] * num_samples)]
         | 
| 536 | 
            -
                    }
         | 
| 537 | 
            -
                    shape = (4, H // 8, W // 8)
         | 
| 538 | 
            -
             | 
| 539 | 
            -
                    if config.save_memory:
         | 
| 540 | 
            -
                        self.model.low_vram_shift(is_diffusing=True)
         | 
| 541 | 
            -
             | 
| 542 | 
            -
                    samples, intermediates = self.ddim_sampler.sample(
         | 
| 543 | 
            -
                        ddim_steps,
         | 
| 544 | 
            -
                        num_samples,
         | 
| 545 | 
            -
                        shape,
         | 
| 546 | 
            -
                        cond,
         | 
| 547 | 
            -
                        verbose=False,
         | 
| 548 | 
            -
                        eta=eta,
         | 
| 549 | 
            -
                        unconditional_guidance_scale=scale,
         | 
| 550 | 
            -
                        unconditional_conditioning=un_cond)
         | 
| 551 | 
            -
             | 
| 552 | 
            -
                    if config.save_memory:
         | 
| 553 | 
            -
                        self.model.low_vram_shift(is_diffusing=False)
         | 
| 554 | 
            -
             | 
| 555 | 
            -
                    x_samples = self.model.decode_first_stage(samples)
         | 
| 556 | 
            -
                    x_samples = (
         | 
| 557 | 
            -
                        einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
         | 
| 558 | 
            -
                        127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
         | 
| 559 | 
            -
             | 
| 560 | 
            -
                    results = [x_samples[i] for i in range(num_samples)]
         | 
| 561 | 
            -
                    return [detected_map] + results
         | 
| 562 |  | 
| 563 | 
            -
             | 
| 564 | 
            -
             | 
| 565 | 
            -
                                image_resolution, detect_resolution, ddim_steps, scale,
         | 
| 566 | 
            -
                                seed, eta):
         | 
| 567 | 
            -
                    self.load_weight('seg')
         | 
| 568 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 569 | 
             
                    input_image = HWC3(input_image)
         | 
| 570 | 
            -
                     | 
| 571 | 
            -
                         | 
| 572 | 
            -
             | 
| 573 | 
            -
             | 
| 574 | 
            -
             | 
| 575 | 
            -
             | 
| 576 | 
            -
             | 
| 577 | 
            -
             | 
| 578 | 
            -
                     | 
| 579 | 
            -
             | 
| 580 | 
            -
             | 
| 581 | 
            -
             | 
| 582 | 
            -
             | 
| 583 | 
            -
                        seed = random.randint(0, 65535)
         | 
| 584 | 
            -
                    seed_everything(seed)
         | 
| 585 | 
            -
             | 
| 586 | 
            -
                    if config.save_memory:
         | 
| 587 | 
            -
                        self.model.low_vram_shift(is_diffusing=False)
         | 
| 588 | 
            -
             | 
| 589 | 
            -
                    cond = {
         | 
| 590 | 
            -
                        'c_concat': [control],
         | 
| 591 | 
            -
                        'c_crossattn': [
         | 
| 592 | 
            -
                            self.model.get_learned_conditioning(
         | 
| 593 | 
            -
                                [prompt + ', ' + a_prompt] * num_samples)
         | 
| 594 | 
            -
                        ]
         | 
| 595 | 
            -
                    }
         | 
| 596 | 
            -
                    un_cond = {
         | 
| 597 | 
            -
                        'c_concat': [control],
         | 
| 598 | 
            -
                        'c_crossattn':
         | 
| 599 | 
            -
                        [self.model.get_learned_conditioning([n_prompt] * num_samples)]
         | 
| 600 | 
            -
                    }
         | 
| 601 | 
            -
                    shape = (4, H // 8, W // 8)
         | 
| 602 | 
            -
             | 
| 603 | 
            -
                    if config.save_memory:
         | 
| 604 | 
            -
                        self.model.low_vram_shift(is_diffusing=True)
         | 
| 605 | 
            -
             | 
| 606 | 
            -
                    samples, intermediates = self.ddim_sampler.sample(
         | 
| 607 | 
            -
                        ddim_steps,
         | 
| 608 | 
            -
                        num_samples,
         | 
| 609 | 
            -
                        shape,
         | 
| 610 | 
            -
                        cond,
         | 
| 611 | 
            -
                        verbose=False,
         | 
| 612 | 
            -
                        eta=eta,
         | 
| 613 | 
            -
                        unconditional_guidance_scale=scale,
         | 
| 614 | 
            -
                        unconditional_conditioning=un_cond)
         | 
| 615 | 
            -
             | 
| 616 | 
            -
                    if config.save_memory:
         | 
| 617 | 
            -
                        self.model.low_vram_shift(is_diffusing=False)
         | 
| 618 | 
            -
             | 
| 619 | 
            -
                    x_samples = self.model.decode_first_stage(samples)
         | 
| 620 | 
            -
                    x_samples = (
         | 
| 621 | 
            -
                        einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
         | 
| 622 | 
            -
                        127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
         | 
| 623 | 
            -
             | 
| 624 | 
            -
                    results = [x_samples[i] for i in range(num_samples)]
         | 
| 625 | 
            -
                    return [detected_map] + results
         | 
| 626 |  | 
| 627 | 
             
                @torch.inference_mode()
         | 
| 628 | 
            -
                def  | 
| 629 | 
            -
             | 
| 630 | 
            -
             | 
| 631 | 
            -
                     | 
| 632 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 633 | 
             
                    input_image = HWC3(input_image)
         | 
| 634 | 
            -
                     | 
| 635 | 
            -
                         | 
| 636 | 
            -
             | 
| 637 | 
            -
             | 
| 638 | 
            -
             | 
| 639 | 
            -
             | 
| 640 | 
            -
             | 
| 641 | 
            -
             | 
| 642 | 
            -
             | 
| 643 | 
            -
                     | 
| 644 | 
            -
             | 
| 645 | 
            -
                    control = einops.rearrange(control, 'b h w c -> b c h w').clone()
         | 
| 646 | 
            -
             | 
| 647 | 
            -
                    if seed == -1:
         | 
| 648 | 
            -
                        seed = random.randint(0, 65535)
         | 
| 649 | 
            -
                    seed_everything(seed)
         | 
| 650 | 
            -
             | 
| 651 | 
            -
                    if config.save_memory:
         | 
| 652 | 
            -
                        self.model.low_vram_shift(is_diffusing=False)
         | 
| 653 | 
            -
             | 
| 654 | 
            -
                    cond = {
         | 
| 655 | 
            -
                        'c_concat': [control],
         | 
| 656 | 
            -
                        'c_crossattn': [
         | 
| 657 | 
            -
                            self.model.get_learned_conditioning(
         | 
| 658 | 
            -
                                [prompt + ', ' + a_prompt] * num_samples)
         | 
| 659 | 
            -
                        ]
         | 
| 660 | 
            -
                    }
         | 
| 661 | 
            -
                    un_cond = {
         | 
| 662 | 
            -
                        'c_concat': [control],
         | 
| 663 | 
            -
                        'c_crossattn':
         | 
| 664 | 
            -
                        [self.model.get_learned_conditioning([n_prompt] * num_samples)]
         | 
| 665 | 
            -
                    }
         | 
| 666 | 
            -
                    shape = (4, H // 8, W // 8)
         | 
| 667 | 
            -
             | 
| 668 | 
            -
                    if config.save_memory:
         | 
| 669 | 
            -
                        self.model.low_vram_shift(is_diffusing=True)
         | 
| 670 | 
            -
             | 
| 671 | 
            -
                    samples, intermediates = self.ddim_sampler.sample(
         | 
| 672 | 
            -
                        ddim_steps,
         | 
| 673 | 
            -
                        num_samples,
         | 
| 674 | 
            -
                        shape,
         | 
| 675 | 
            -
                        cond,
         | 
| 676 | 
            -
                        verbose=False,
         | 
| 677 | 
            -
                        eta=eta,
         | 
| 678 | 
            -
                        unconditional_guidance_scale=scale,
         | 
| 679 | 
            -
                        unconditional_conditioning=un_cond)
         | 
| 680 | 
            -
             | 
| 681 | 
            -
                    if config.save_memory:
         | 
| 682 | 
            -
                        self.model.low_vram_shift(is_diffusing=False)
         | 
| 683 | 
            -
             | 
| 684 | 
            -
                    x_samples = self.model.decode_first_stage(samples)
         | 
| 685 | 
            -
                    x_samples = (
         | 
| 686 | 
            -
                        einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
         | 
| 687 | 
            -
                        127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
         | 
| 688 | 
            -
             | 
| 689 | 
            -
                    results = [x_samples[i] for i in range(num_samples)]
         | 
| 690 | 
            -
                    return [detected_map] + results
         | 
| 691 |  | 
| 692 | 
             
                @torch.inference_mode()
         | 
| 693 | 
            -
                def  | 
| 694 | 
            -
             | 
| 695 | 
            -
             | 
| 696 | 
            -
                     | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 697 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 698 | 
             
                    input_image = HWC3(input_image)
         | 
| 699 | 
            -
                     | 
| 700 | 
            -
             | 
| 701 | 
            -
             | 
| 702 | 
            -
             | 
| 703 | 
            -
             | 
| 704 | 
            -
             | 
| 705 | 
            -
             | 
| 706 | 
            -
             | 
| 707 | 
            -
             | 
| 708 | 
            -
             | 
| 709 | 
            -
             | 
| 710 | 
            -
             | 
| 711 | 
            -
             | 
| 712 | 
            -
             | 
| 713 | 
            -
             | 
| 714 | 
            -
             | 
| 715 | 
            -
             | 
| 716 | 
            -
                     | 
| 717 | 
            -
             | 
| 718 | 
            -
                     | 
| 719 | 
            -
             | 
| 720 | 
            -
             | 
| 721 | 
            -
                     | 
| 722 | 
            -
             | 
| 723 | 
            -
             | 
| 724 | 
            -
             | 
| 725 | 
            -
             | 
| 726 | 
            -
             | 
| 727 | 
            -
                     | 
| 728 | 
            -
             | 
| 729 | 
            -
             | 
| 730 | 
            -
                         | 
| 731 | 
            -
                         | 
| 732 | 
            -
             | 
| 733 | 
            -
             | 
| 734 | 
            -
             | 
| 735 | 
            -
                     | 
| 736 | 
            -
             | 
| 737 | 
            -
             | 
| 738 | 
            -
             | 
| 739 | 
            -
                         | 
| 740 | 
            -
                         | 
| 741 | 
            -
                         | 
| 742 | 
            -
                         | 
| 743 | 
            -
                         | 
| 744 | 
            -
                         | 
| 745 | 
            -
             | 
| 746 | 
            -
             | 
| 747 | 
            -
             | 
| 748 | 
            -
                    if config.save_memory:
         | 
| 749 | 
            -
                        self.model.low_vram_shift(is_diffusing=False)
         | 
| 750 | 
            -
             | 
| 751 | 
            -
                    x_samples = self.model.decode_first_stage(samples)
         | 
| 752 | 
            -
                    x_samples = (
         | 
| 753 | 
            -
                        einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
         | 
| 754 | 
            -
                        127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
         | 
| 755 | 
            -
             | 
| 756 | 
            -
                    results = [x_samples[i] for i in range(num_samples)]
         | 
| 757 | 
            -
                    return [detected_map] + results
         | 
|  | |
| 3 | 
             
            from __future__ import annotations
         | 
| 4 |  | 
| 5 | 
             
            import pathlib
         | 
|  | |
|  | |
|  | |
| 6 | 
             
            import sys
         | 
| 7 |  | 
| 8 | 
             
            import cv2
         | 
|  | |
| 9 | 
             
            import numpy as np
         | 
| 10 | 
            +
            import PIL.Image
         | 
| 11 | 
             
            import torch
         | 
| 12 | 
            +
            from diffusers import (ControlNetModel, DiffusionPipeline,
         | 
| 13 | 
            +
                                   StableDiffusionControlNetPipeline,
         | 
| 14 | 
            +
                                   UniPCMultistepScheduler)
         | 
| 15 |  | 
| 16 | 
            +
            repo_dir = pathlib.Path(__file__).parent
         | 
| 17 | 
            +
            submodule_dir = repo_dir / 'ControlNet'
         | 
| 18 | 
            +
            sys.path.append(submodule_dir.as_posix())
         | 
| 19 |  | 
|  | |
| 20 | 
             
            from annotator.canny import apply_canny
         | 
| 21 | 
             
            from annotator.hed import apply_hed, nms
         | 
| 22 | 
             
            from annotator.midas import apply_midas
         | 
|  | |
| 24 | 
             
            from annotator.openpose import apply_openpose
         | 
| 25 | 
             
            from annotator.uniformer import apply_uniformer
         | 
| 26 | 
             
            from annotator.util import HWC3, resize_image
         | 
|  | |
|  | |
| 27 | 
             
            from share import *
         | 
| 28 |  | 
| 29 | 
            +
            # Task name -> ControlNet model id. The values are what this module
            # hands to ``ControlNetModel.from_pretrained``.
            CONTROLNET_MODEL_IDS = dict(
                canny='lllyasviel/sd-controlnet-canny',
                hough='lllyasviel/sd-controlnet-mlsd',
                hed='lllyasviel/sd-controlnet-hed',
                scribble='lllyasviel/sd-controlnet-scribble',
                pose='lllyasviel/sd-controlnet-openpose',
                seg='lllyasviel/sd-controlnet-seg',
                depth='lllyasviel/sd-controlnet-depth',
                normal='lllyasviel/sd-controlnet-normal',
            )
         | 
|  | |
| 39 |  | 
| 40 | 
            +
             | 
| 41 | 
            +
            def download_all_controlnet_weights() -> None:
                """Load every ControlNet listed in ``CONTROLNET_MODEL_IDS`` once.

                Each model id is passed to ``ControlNetModel.from_pretrained`` and
                the returned model is discarded; presumably this is used to warm
                the local download cache ahead of time (confirm against caller).
                """
                for task_name in CONTROLNET_MODEL_IDS:
                    ControlNetModel.from_pretrained(CONTROLNET_MODEL_IDS[task_name])
         | 
| 44 |  | 
| 45 |  | 
| 46 | 
             
            class Model:
         | 
| 47 | 
             
                def __init__(self,
         | 
| 48 | 
            +
                             base_model_id: str = 'runwayml/stable-diffusion-v1-5',
         | 
| 49 | 
            +
                             task_name: str = 'canny'):
         | 
| 50 | 
            +
                    self.base_model_id = ''
         | 
|  | |
|  | |
|  | |
| 51 | 
             
                    self.task_name = ''
         | 
| 52 | 
            +
                    self.pipe = self.load_pipe(base_model_id, task_name)
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                def load_pipe(self, base_model_id: str, task_name) -> DiffusionPipeline:
         | 
| 55 | 
            +
                    if base_model_id == self.base_model_id and task_name == self.task_name:
         | 
| 56 | 
            +
                        return self.pipe
         | 
| 57 | 
            +
                    model_id = CONTROLNET_MODEL_IDS[task_name]
         | 
| 58 | 
            +
                    controlnet = ControlNetModel.from_pretrained(model_id,
         | 
| 59 | 
            +
                                                                 torch_dtype=torch.float16)
         | 
| 60 | 
            +
                    pipe = StableDiffusionControlNetPipeline.from_pretrained(
         | 
| 61 | 
            +
                        base_model_id,
         | 
| 62 | 
            +
                        safety_checker=None,
         | 
| 63 | 
            +
                        controlnet=controlnet,
         | 
| 64 | 
            +
                        torch_dtype=torch.float16)
         | 
| 65 | 
            +
                    pipe.scheduler = UniPCMultistepScheduler.from_config(
         | 
| 66 | 
            +
                        pipe.scheduler.config)
         | 
| 67 | 
            +
                    pipe.enable_xformers_memory_efficient_attention()
         | 
| 68 | 
            +
                    pipe.enable_model_cpu_offload()
         | 
| 69 | 
            +
                    self.base_model_id = base_model_id
         | 
| 70 | 
            +
                    self.task_name = task_name
         | 
| 71 | 
            +
                    return pipe
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                def set_base_model(self, base_model_id: str) -> str:
         | 
| 74 | 
            +
                    self.pipe = self.load_pipe(base_model_id, self.task_name)
         | 
| 75 | 
            +
                    return self.base_model_id
         | 
| 76 |  | 
| 77 | 
            +
                def load_controlnet_weight(self, task_name: str) -> None:
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 78 | 
             
                    if task_name == self.task_name:
         | 
| 79 | 
             
                        return
         | 
| 80 | 
            +
                    model_id = CONTROLNET_MODEL_IDS[task_name]
         | 
| 81 | 
            +
                    controlnet = ControlNetModel.from_pretrained(model_id,
         | 
| 82 | 
            +
                                                                 torch_dtype=torch.float16)
         | 
| 83 | 
            +
                    from accelerate import cpu_offload_with_hook
         | 
| 84 | 
            +
                    cpu_offload_with_hook(controlnet, torch.device('cuda:0'))
         | 
| 85 | 
            +
                    self.pipe.controlnet = controlnet
         | 
| 86 | 
             
                    self.task_name = task_name
         | 
| 87 |  | 
| 88 | 
            +
                def get_prompt(self, prompt: str, additional_prompt: str) -> str:
                    """Combine the user prompt and the additional prompt.

                    Non-empty parts are joined with ``', '``. An empty part is left
                    out entirely, so the result never starts or ends with a dangling
                    separator.

                    Args:
                        prompt: Main prompt text; may be empty.
                        additional_prompt: Extra text appended after ``prompt``; may
                            be empty.

                    Returns:
                        The combined prompt, or ``''`` when both parts are empty.
                    """
                    # Joining only the truthy parts avoids producing "prompt, "
                    # when ``additional_prompt`` is empty.
                    return ', '.join(
                        part for part in (prompt, additional_prompt) if part)
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                def run_pipe(
         | 
| 96 | 
            +
                    self,
         | 
| 97 | 
            +
                    prompt: str,
         | 
| 98 | 
            +
                    negative_prompt: str,
         | 
| 99 | 
            +
                    control_image: PIL.Image.Image,
         | 
| 100 | 
            +
                    num_images: int,
         | 
| 101 | 
            +
                    num_steps: int,
         | 
| 102 | 
            +
                    guidance_scale: float,
         | 
| 103 | 
            +
                    seed: int,
         | 
| 104 | 
            +
                ) -> list[PIL.Image.Image]:
         | 
| 105 | 
            +
                    generator = torch.Generator().manual_seed(seed)
         | 
| 106 | 
            +
                    return self.pipe(prompt=prompt,
         | 
| 107 | 
            +
                                     negative_prompt=negative_prompt,
         | 
| 108 | 
            +
                                     guidance_scale=guidance_scale,
         | 
| 109 | 
            +
                                     num_images_per_prompt=num_images,
         | 
| 110 | 
            +
                                     num_inference_steps=num_steps,
         | 
| 111 | 
            +
                                     generator=generator,
         | 
| 112 | 
            +
                                     image=control_image).images
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                @staticmethod
                def preprocess_canny(
                    input_image: np.ndarray,
                    image_resolution: int,
                    low_threshold: int,
                    high_threshold: int,
                ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
                    """Build the Canny control image and its inverted preview.

                    The input is passed through ``HWC3`` and ``resize_image``, then
                    ``apply_canny`` with the two thresholds; the edge map goes through
                    ``HWC3`` again.

                    Returns:
                        ``(control_image, vis_control_image)`` as PIL images, where
                        the second is ``255 - control`` computed on the array.
                    """
                    resized = resize_image(HWC3(input_image), image_resolution)
                    edge_map = HWC3(apply_canny(resized, low_threshold, high_threshold))
                    inverted_edge_map = 255 - edge_map
                    control = PIL.Image.fromarray(edge_map)
                    preview = PIL.Image.fromarray(inverted_edge_map)
                    return control, preview
         | 
| 127 |  | 
| 128 | 
             
                @torch.inference_mode()
                def process_canny(
                    self,
                    input_image: np.ndarray,
                    prompt: str,
                    additional_prompt: str,
                    negative_prompt: str,
                    num_images: int,
                    image_resolution: int,
                    num_steps: int,
                    guidance_scale: float,
                    seed: int,
                    low_threshold: int,
                    high_threshold: int,
                ) -> list[PIL.Image.Image]:
                    """Run the pipeline conditioned on the `preprocess_canny` output.

                    Steps: preprocess the input, request the 'canny' weights through
                    `load_controlnet_weight`, then call `run_pipe` with the prompt
                    assembled by `get_prompt`.

                    Returns:
                        The preview image followed by whatever `run_pipe` returned.
                    """
                    condition, preview = self.preprocess_canny(
                        input_image=input_image,
                        image_resolution=image_resolution,
                        low_threshold=low_threshold,
                        high_threshold=high_threshold,
                    )
                    self.load_controlnet_weight('canny')
                    full_prompt = self.get_prompt(prompt, additional_prompt)
                    generated = self.run_pipe(
                        prompt=full_prompt,
                        negative_prompt=negative_prompt,
                        control_image=condition,
                        num_images=num_images,
                        num_steps=num_steps,
                        guidance_scale=guidance_scale,
                        seed=seed,
                    )
                    # The preview always comes first in the returned list.
                    return [preview] + generated
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                @staticmethod
                def preprocess_hough(
                    input_image: np.ndarray,
                    image_resolution: int,
                    detect_resolution: int,
                    value_threshold: float,
                    distance_threshold: float,
                ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
                    """Build the control image and a preview via `apply_mlsd`.

                    Detection runs on the input resized to `detect_resolution`; the
                    result is then resized (nearest neighbour) to the height/width
                    that `resize_image` yields for `image_resolution`.

                    Returns:
                        `(control, preview)` as PIL images, where `preview` is the
                        inverse of the control image after one 3x3 dilation.
                    """
                    source = HWC3(input_image)
                    detect_input = resize_image(source, detect_resolution)
                    line_map = HWC3(
                        apply_mlsd(detect_input, value_threshold, distance_threshold))

                    # Only the output size is needed from the second resize.
                    out_h, out_w = resize_image(source, image_resolution).shape[:2]
                    line_map = cv2.resize(line_map, (out_w, out_h),
                                          interpolation=cv2.INTER_NEAREST)

                    kernel = np.ones((3, 3), np.uint8)
                    thickened = cv2.dilate(line_map, kernel, iterations=1)
                    inverted = 255 - thickened
                    return (PIL.Image.fromarray(line_map),
                            PIL.Image.fromarray(inverted))
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 184 |  | 
| 185 | 
             
                @torch.inference_mode()
                def process_hough(
                    self,
                    input_image: np.ndarray,
                    prompt: str,
                    additional_prompt: str,
                    negative_prompt: str,
                    num_images: int,
                    image_resolution: int,
                    detect_resolution: int,
                    num_steps: int,
                    guidance_scale: float,
                    seed: int,
                    value_threshold: float,
                    distance_threshold: float,
                ) -> list[PIL.Image.Image]:
                    """Run the pipeline conditioned on the `preprocess_hough` output.

                    Steps: preprocess the input, request the 'hough' weights through
                    `load_controlnet_weight`, then call `run_pipe` with the prompt
                    assembled by `get_prompt`.

                    Returns:
                        The preview image followed by whatever `run_pipe` returned.
                    """
                    condition, preview = self.preprocess_hough(
                        input_image=input_image,
                        image_resolution=image_resolution,
                        detect_resolution=detect_resolution,
                        value_threshold=value_threshold,
                        distance_threshold=distance_threshold,
                    )
                    self.load_controlnet_weight('hough')
                    full_prompt = self.get_prompt(prompt, additional_prompt)
                    generated = self.run_pipe(
                        prompt=full_prompt,
                        negative_prompt=negative_prompt,
                        control_image=condition,
                        num_images=num_images,
                        num_steps=num_steps,
                        guidance_scale=guidance_scale,
                        seed=seed,
                    )
                    # The preview always comes first in the returned list.
                    return [preview] + generated
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                @staticmethod
                def preprocess_hed(
                    input_image: np.ndarray,
                    image_resolution: int,
                    detect_resolution: int,
                ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
                    """Build the control image via `apply_hed`.

                    Detection runs on the input resized to `detect_resolution`; the
                    result is then resized (bilinear) to the height/width that
                    `resize_image` yields for `image_resolution`.

                    Returns:
                        Two PIL images built from the same array: the control image
                        and its preview (no inversion is applied here).
                    """
                    source = HWC3(input_image)
                    hed_map = HWC3(apply_hed(resize_image(source, detect_resolution)))

                    # Only the output size is needed from the second resize.
                    out_h, out_w = resize_image(source, image_resolution).shape[:2]
                    hed_map = cv2.resize(hed_map, (out_w, out_h),
                                         interpolation=cv2.INTER_LINEAR)

                    control = PIL.Image.fromarray(hed_map)
                    preview = PIL.Image.fromarray(hed_map)
                    return control, preview
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 235 |  | 
| 236 | 
             
                @torch.inference_mode()
                def process_hed(
                    self,
                    input_image: np.ndarray,
                    prompt: str,
                    additional_prompt: str,
                    negative_prompt: str,
                    num_images: int,
                    image_resolution: int,
                    detect_resolution: int,
                    num_steps: int,
                    guidance_scale: float,
                    seed: int,
                ) -> list[PIL.Image.Image]:
                    """Run the pipeline conditioned on the `preprocess_hed` output.

                    Steps: preprocess the input, request the 'hed' weights through
                    `load_controlnet_weight`, then call `run_pipe` with the prompt
                    assembled by `get_prompt`.

                    Returns:
                        The preview image followed by whatever `run_pipe` returned.
                    """
                    condition, preview = self.preprocess_hed(
                        input_image=input_image,
                        image_resolution=image_resolution,
                        detect_resolution=detect_resolution,
                    )
                    self.load_controlnet_weight('hed')
                    full_prompt = self.get_prompt(prompt, additional_prompt)
                    generated = self.run_pipe(
                        prompt=full_prompt,
                        negative_prompt=negative_prompt,
                        control_image=condition,
                        num_images=num_images,
                        num_steps=num_steps,
                        guidance_scale=guidance_scale,
                        seed=seed,
                    )
                    # The preview always comes first in the returned list.
                    return [preview] + generated
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                @staticmethod
                def preprocess_scribble(
                    input_image: np.ndarray,
                    image_resolution: int,
                ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
                    """Threshold an uploaded scribble into a two-level control image.

                    After `HWC3` and `resize_image`, every pixel whose minimum
                    channel value is below 127 becomes 255 in the control image;
                    all other pixels stay 0.

                    Returns:
                        `(control, preview)` as PIL images, where `preview` is
                        `255 - control`.
                    """
                    resized = resize_image(HWC3(input_image), image_resolution)
                    is_stroke = np.min(resized, axis=2) < 127

                    binary = np.zeros_like(resized, dtype=np.uint8)
                    binary[is_stroke] = 255
                    inverted = 255 - binary
                    return (PIL.Image.fromarray(binary),
                            PIL.Image.fromarray(inverted))
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 278 |  | 
| 279 | 
             
                @torch.inference_mode()
                def process_scribble(
                    self,
                    input_image: np.ndarray,
                    prompt: str,
                    additional_prompt: str,
                    negative_prompt: str,
                    num_images: int,
                    image_resolution: int,
                    num_steps: int,
                    guidance_scale: float,
                    seed: int,
                ) -> list[PIL.Image.Image]:
                    """Run the pipeline conditioned on the `preprocess_scribble` output.

                    Steps: preprocess the input, request the 'scribble' weights
                    through `load_controlnet_weight`, then call `run_pipe` with the
                    prompt assembled by `get_prompt`.

                    Returns:
                        The preview image followed by whatever `run_pipe` returned.
                    """
                    condition, preview = self.preprocess_scribble(
                        input_image=input_image,
                        image_resolution=image_resolution,
                    )
                    self.load_controlnet_weight('scribble')
                    full_prompt = self.get_prompt(prompt, additional_prompt)
                    generated = self.run_pipe(
                        prompt=full_prompt,
                        negative_prompt=negative_prompt,
                        control_image=condition,
                        num_images=num_images,
                        num_steps=num_steps,
                        guidance_scale=guidance_scale,
                        seed=seed,
                    )
                    # The preview always comes first in the returned list.
                    return [preview] + generated
         | 
| 307 | 
            +
             | 
| 308 | 
            +
                @staticmethod
                def preprocess_scribble_interactive(
                    input_image: dict[str, np.ndarray],
                    image_resolution: int,
                ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
                    """Threshold a drawn mask into a two-level control image.

                    Args:
                        input_image: Mapping that must contain a 'mask' entry; only
                            channel 0 of that array is used. (Presumably the value
                            produced by a sketch-enabled image input - confirm
                            against the caller.)
                        image_resolution: Target resolution passed to
                            `resize_image`.

                    Returns:
                        `(control, preview)` as PIL images. Pixels whose minimum
                        channel value is above 127 become 255 in the control image;
                        `preview` is `255 - control`.

                    Raises:
                        KeyError: If `input_image` has no 'mask' entry.
                    """
                    # Only the first channel of the mask carries the strokes used here.
                    mask_channel = input_image['mask'][:, :, 0]
                    image = resize_image(HWC3(mask_channel), image_resolution)
                    control_image = np.zeros_like(image, dtype=np.uint8)
                    control_image[np.min(image, axis=2) > 127] = 255
                    vis_control_image = 255 - control_image
                    return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
                        vis_control_image)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 320 |  | 
| 321 | 
             
                @torch.inference_mode()
                def process_scribble_interactive(
                    self,
                    input_image: dict[str, np.ndarray],
                    prompt: str,
                    additional_prompt: str,
                    negative_prompt: str,
                    num_images: int,
                    image_resolution: int,
                    num_steps: int,
                    guidance_scale: float,
                    seed: int,
                ) -> list[PIL.Image.Image]:
                    """Run the pipeline conditioned on an interactively drawn mask.

                    Args:
                        input_image: Mapping forwarded unchanged to
                            `preprocess_scribble_interactive`, which reads its
                            'mask' entry.
                        prompt: Main prompt, combined with `additional_prompt` by
                            `get_prompt`.
                        additional_prompt: Extra prompt text for `get_prompt`.
                        negative_prompt: Forwarded to `run_pipe`.
                        num_images: Forwarded to `run_pipe`.
                        image_resolution: Resolution used by the preprocessing step.
                        num_steps: Forwarded to `run_pipe`.
                        guidance_scale: Forwarded to `run_pipe`.
                        seed: Forwarded to `run_pipe`.

                    Returns:
                        The preview image followed by whatever `run_pipe` returned.
                    """
                    control_image, vis_control_image = self.preprocess_scribble_interactive(
                        input_image=input_image,
                        image_resolution=image_resolution,
                    )
                    # Interactive scribbles use the same 'scribble' weights as uploads.
                    self.load_controlnet_weight('scribble')
                    results = self.run_pipe(
                        prompt=self.get_prompt(prompt, additional_prompt),
                        negative_prompt=negative_prompt,
                        control_image=control_image,
                        num_images=num_images,
                        num_steps=num_steps,
                        guidance_scale=guidance_scale,
                        seed=seed,
                    )
                    return [vis_control_image] + results
         | 
| 349 | 
            +
             | 
| 350 | 
            +
                @staticmethod
                def preprocess_fake_scribble(
                    input_image: np.ndarray,
                    image_resolution: int,
                    detect_resolution: int,
                ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
                    """Derive a scribble-like control image from the `apply_hed` output.

                    Detection runs at `detect_resolution`; the map is resized
                    (bilinear) to the size `resize_image` yields for
                    `image_resolution`, passed through `nms(…, 127, 3.0)`, blurred
                    with a Gaussian (sigma 3.0) and finally forced to the two
                    levels 0 and 255.

                    Returns:
                        `(control, preview)` as PIL images, where `preview` is
                        `255 - control`.
                    """
                    source = HWC3(input_image)
                    hed_map = HWC3(apply_hed(resize_image(source, detect_resolution)))

                    # Only the output size is needed from the second resize.
                    out_h, out_w = resize_image(source, image_resolution).shape[:2]
                    hed_map = cv2.resize(hed_map, (out_w, out_h),
                                         interpolation=cv2.INTER_LINEAR)

                    strokes = nms(hed_map, 127, 3.0)
                    strokes = cv2.GaussianBlur(strokes, (0, 0), 3.0)
                    # Two in-place passes: values above 4 go to 255, the rest to 0.
                    strokes[strokes > 4] = 255
                    strokes[strokes < 255] = 0

                    inverted = 255 - strokes
                    return (PIL.Image.fromarray(strokes),
                            PIL.Image.fromarray(inverted))
         | 
|  | |
|  | |
|  | |
| 373 |  | 
| 374 | 
            +
                @torch.inference_mode()
                def process_fake_scribble(
                    self,
                    input_image: np.ndarray,
                    prompt: str,
                    additional_prompt: str,
                    negative_prompt: str,
                    num_images: int,
                    image_resolution: int,
                    detect_resolution: int,
                    num_steps: int,
                    guidance_scale: float,
                    seed: int,
                ) -> list[PIL.Image.Image]:
                    """Run the pipeline conditioned on `preprocess_fake_scribble` output.

                    Steps: preprocess the input, request the 'scribble' weights
                    through `load_controlnet_weight`, then call `run_pipe` with the
                    prompt assembled by `get_prompt`.

                    Returns:
                        The preview image followed by whatever `run_pipe` returned.
                    """
                    condition, preview = self.preprocess_fake_scribble(
                        input_image=input_image,
                        image_resolution=image_resolution,
                        detect_resolution=detect_resolution,
                    )
                    self.load_controlnet_weight('scribble')
                    full_prompt = self.get_prompt(prompt, additional_prompt)
                    generated = self.run_pipe(
                        prompt=full_prompt,
                        negative_prompt=negative_prompt,
                        control_image=condition,
                        num_images=num_images,
                        num_steps=num_steps,
                        guidance_scale=guidance_scale,
                        seed=seed,
                    )
                    # The preview always comes first in the returned list.
                    return [preview] + generated
         | 
| 404 | 
            +
             | 
| 405 | 
            +
                @staticmethod
                def preprocess_pose(
                    input_image: np.ndarray,
                    image_resolution: int,
                    detect_resolution: int,
                    is_pose_image: bool,
                ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
                    """Build the control image via `apply_openpose`, or pass it through.

                    When `is_pose_image` is true the `HWC3`-converted input is used
                    as the control image as-is (no detection, no resizing).
                    Otherwise detection runs at `detect_resolution` and the result
                    is resized (nearest neighbour) to the size `resize_image` yields
                    for `image_resolution`.

                    Returns:
                        Two PIL images built from the same array: the control image
                        and its preview.
                    """
                    source = HWC3(input_image)
                    if is_pose_image:
                        pose_map = source
                    else:
                        # The second value returned by apply_openpose is not used.
                        detected, _ = apply_openpose(
                            resize_image(source, detect_resolution))
                        detected = HWC3(detected)
                        out_h, out_w = resize_image(source,
                                                    image_resolution).shape[:2]
                        pose_map = cv2.resize(detected, (out_w, out_h),
                                              interpolation=cv2.INTER_NEAREST)

                    control = PIL.Image.fromarray(pose_map)
                    preview = PIL.Image.fromarray(pose_map)
                    return control, preview
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 426 |  | 
| 427 | 
             
                @torch.inference_mode()
                def process_pose(
                    self,
                    input_image: np.ndarray,
                    prompt: str,
                    additional_prompt: str,
                    negative_prompt: str,
                    num_images: int,
                    image_resolution: int,
                    detect_resolution: int,
                    num_steps: int,
                    guidance_scale: float,
                    seed: int,
                    is_pose_image: bool,
                ) -> list[PIL.Image.Image]:
                    """Run the pipeline conditioned on the `preprocess_pose` output.

                    Steps: preprocess the input, request the 'pose' weights through
                    `load_controlnet_weight`, then call `run_pipe` with the prompt
                    assembled by `get_prompt`.

                    Returns:
                        The preview image followed by whatever `run_pipe` returned.
                    """
                    condition, preview = self.preprocess_pose(
                        input_image=input_image,
                        image_resolution=image_resolution,
                        detect_resolution=detect_resolution,
                        is_pose_image=is_pose_image,
                    )
                    self.load_controlnet_weight('pose')
                    full_prompt = self.get_prompt(prompt, additional_prompt)
                    generated = self.run_pipe(
                        prompt=full_prompt,
                        negative_prompt=negative_prompt,
                        control_image=condition,
                        num_images=num_images,
                        num_steps=num_steps,
                        guidance_scale=guidance_scale,
                        seed=seed,
                    )
                    # The preview always comes first in the returned list.
                    return [preview] + generated
         | 
| 459 | 
            +
             | 
| 460 | 
            +
                @staticmethod
                def preprocess_seg(
                    input_image: np.ndarray,
                    image_resolution: int,
                    detect_resolution: int,
                    is_segmentation_map: bool,
                ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
                    """Build the control image via `apply_uniformer`, or pass it through.

                    When `is_segmentation_map` is true the `HWC3`-converted input is
                    used as the control image as-is (no detection, no resizing).
                    Otherwise `apply_uniformer` runs at `detect_resolution` and its
                    output is resized (nearest neighbour) to the size `resize_image`
                    yields for `image_resolution`.

                    Returns:
                        Two PIL images built from the same array: the control image
                        and its preview.
                    """
                    source = HWC3(input_image)
                    if is_segmentation_map:
                        seg_map = source
                    else:
                        detected = apply_uniformer(
                            resize_image(source, detect_resolution))
                        out_h, out_w = resize_image(source,
                                                    image_resolution).shape[:2]
                        seg_map = cv2.resize(detected, (out_w, out_h),
                                             interpolation=cv2.INTER_NEAREST)

                    control = PIL.Image.fromarray(seg_map)
                    preview = PIL.Image.fromarray(seg_map)
                    return control, preview
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 479 |  | 
| 480 | 
             
                @torch.inference_mode()
                def process_seg(
                    self,
                    input_image: np.ndarray,
                    prompt: str,
                    additional_prompt: str,
                    negative_prompt: str,
                    num_images: int,
                    image_resolution: int,
                    detect_resolution: int,
                    num_steps: int,
                    guidance_scale: float,
                    seed: int,
                    is_segmentation_map: bool,
                ) -> list[PIL.Image.Image]:
                    """Run the pipeline conditioned on the `preprocess_seg` output.

                    Steps: preprocess the input, request the 'seg' weights through
                    `load_controlnet_weight`, then call `run_pipe` with the prompt
                    assembled by `get_prompt`.

                    Returns:
                        The preview image followed by whatever `run_pipe` returned.
                    """
                    condition, preview = self.preprocess_seg(
                        input_image=input_image,
                        image_resolution=image_resolution,
                        detect_resolution=detect_resolution,
                        is_segmentation_map=is_segmentation_map,
                    )
                    self.load_controlnet_weight('seg')
                    full_prompt = self.get_prompt(prompt, additional_prompt)
                    generated = self.run_pipe(
                        prompt=full_prompt,
                        negative_prompt=negative_prompt,
                        control_image=condition,
                        num_images=num_images,
                        num_steps=num_steps,
                        guidance_scale=guidance_scale,
                        seed=seed,
                    )
                    # The preview always comes first in the returned list.
                    return [preview] + generated
         | 
| 512 | 
            +
             | 
| 513 | 
            +
                @staticmethod
         | 
| 514 | 
            +
                def preprocess_depth(
         | 
| 515 | 
            +
                    input_image: np.ndarray,
         | 
| 516 | 
            +
                    image_resolution: int,
         | 
| 517 | 
            +
                    detect_resolution: int,
         | 
| 518 | 
            +
                    is_depth_image: bool,
         | 
| 519 | 
            +
                ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
         | 
| 520 | 
            +
                    input_image = HWC3(input_image)
         | 
| 521 | 
            +
                    if not is_depth_image:
         | 
| 522 | 
            +
                        control_image, _ = apply_midas(
         | 
| 523 | 
            +
                            resize_image(input_image, detect_resolution))
         | 
| 524 | 
            +
                        control_image = HWC3(control_image)
         | 
| 525 | 
            +
                        image = resize_image(input_image, image_resolution)
         | 
| 526 | 
            +
                        H, W = image.shape[:2]
         | 
| 527 | 
            +
                        control_image = cv2.resize(control_image, (W, H),
         | 
| 528 | 
            +
                                                   interpolation=cv2.INTER_LINEAR)
         | 
| 529 | 
            +
                    else:
         | 
| 530 | 
            +
                        control_image = input_image
         | 
| 531 | 
            +
                    return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
         | 
| 532 | 
            +
                        control_image)
         | 
| 533 |  | 
| 534 | 
            +
                @torch.inference_mode()
         | 
| 535 | 
            +
                def process_depth(
         | 
| 536 | 
            +
                    self,
         | 
| 537 | 
            +
                    input_image: np.ndarray,
         | 
| 538 | 
            +
                    prompt: str,
         | 
| 539 | 
            +
                    additional_prompt: str,
         | 
| 540 | 
            +
                    negative_prompt: str,
         | 
| 541 | 
            +
                    num_images: int,
         | 
| 542 | 
            +
                    image_resolution: int,
         | 
| 543 | 
            +
                    detect_resolution: int,
         | 
| 544 | 
            +
                    num_steps: int,
         | 
| 545 | 
            +
                    guidance_scale: float,
         | 
| 546 | 
            +
                    seed: int,
         | 
| 547 | 
            +
                    is_depth_image: bool,
         | 
| 548 | 
            +
                ) -> list[PIL.Image.Image]:
         | 
| 549 | 
            +
                    control_image, vis_control_image = self.preprocess_depth(
         | 
| 550 | 
            +
                        input_image=input_image,
         | 
| 551 | 
            +
                        image_resolution=image_resolution,
         | 
| 552 | 
            +
                        detect_resolution=detect_resolution,
         | 
| 553 | 
            +
                        is_depth_image=is_depth_image,
         | 
| 554 | 
            +
                    )
         | 
| 555 | 
            +
                    self.load_controlnet_weight('depth')
         | 
| 556 | 
            +
                    results = self.run_pipe(
         | 
| 557 | 
            +
                        prompt=self.get_prompt(prompt, additional_prompt),
         | 
| 558 | 
            +
                        negative_prompt=negative_prompt,
         | 
| 559 | 
            +
                        control_image=control_image,
         | 
| 560 | 
            +
                        num_images=num_images,
         | 
| 561 | 
            +
                        num_steps=num_steps,
         | 
| 562 | 
            +
                        guidance_scale=guidance_scale,
         | 
| 563 | 
            +
                        seed=seed,
         | 
| 564 | 
            +
                    )
         | 
| 565 | 
            +
                    return [vis_control_image] + results
         | 
| 566 | 
            +
             | 
| 567 | 
            +
                @staticmethod
         | 
| 568 | 
            +
                def preprocess_normal(
         | 
| 569 | 
            +
                    input_image: np.ndarray,
         | 
| 570 | 
            +
                    image_resolution: int,
         | 
| 571 | 
            +
                    detect_resolution: int,
         | 
| 572 | 
            +
                    bg_threshold: float,
         | 
| 573 | 
            +
                    is_normal_image: bool,
         | 
| 574 | 
            +
                ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
         | 
| 575 | 
             
                    input_image = HWC3(input_image)
         | 
| 576 | 
            +
                    if not is_normal_image:
         | 
| 577 | 
            +
                        _, control_image = apply_midas(resize_image(
         | 
| 578 | 
            +
                            input_image, detect_resolution),
         | 
| 579 | 
            +
                                                       bg_th=bg_threshold)
         | 
| 580 | 
            +
                        control_image = HWC3(control_image)
         | 
| 581 | 
            +
                        image = resize_image(input_image, image_resolution)
         | 
| 582 | 
            +
                        H, W = image.shape[:2]
         | 
| 583 | 
            +
                        control_image = cv2.resize(control_image, (W, H),
         | 
| 584 | 
            +
                                                   interpolation=cv2.INTER_LINEAR)
         | 
| 585 | 
            +
                    else:
         | 
| 586 | 
            +
                        control_image = input_image
         | 
| 587 | 
            +
                    return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
         | 
| 588 | 
            +
                        control_image)
         | 
| 589 | 
            +
             | 
| 590 | 
            +
                @torch.inference_mode()
         | 
| 591 | 
            +
                def process_normal(
         | 
| 592 | 
            +
                    self,
         | 
| 593 | 
            +
                    input_image: np.ndarray,
         | 
| 594 | 
            +
                    prompt: str,
         | 
| 595 | 
            +
                    additional_prompt: str,
         | 
| 596 | 
            +
                    negative_prompt: str,
         | 
| 597 | 
            +
                    num_images: int,
         | 
| 598 | 
            +
                    image_resolution: int,
         | 
| 599 | 
            +
                    detect_resolution: int,
         | 
| 600 | 
            +
                    num_steps: int,
         | 
| 601 | 
            +
                    guidance_scale: float,
         | 
| 602 | 
            +
                    seed: int,
         | 
| 603 | 
            +
                    bg_threshold: float,
         | 
| 604 | 
            +
                    is_normal_image: bool,
         | 
| 605 | 
            +
                ) -> list[PIL.Image.Image]:
         | 
| 606 | 
            +
                    control_image, vis_control_image = self.preprocess_normal(
         | 
| 607 | 
            +
                        input_image=input_image,
         | 
| 608 | 
            +
                        image_resolution=image_resolution,
         | 
| 609 | 
            +
                        detect_resolution=detect_resolution,
         | 
| 610 | 
            +
                        bg_threshold=bg_threshold,
         | 
| 611 | 
            +
                        is_normal_image=is_normal_image,
         | 
| 612 | 
            +
                    )
         | 
| 613 | 
            +
                    self.load_controlnet_weight('normal')
         | 
| 614 | 
            +
                    results = self.run_pipe(
         | 
| 615 | 
            +
                        prompt=self.get_prompt(prompt, additional_prompt),
         | 
| 616 | 
            +
                        negative_prompt=negative_prompt,
         | 
| 617 | 
            +
                        control_image=control_image,
         | 
| 618 | 
            +
                        num_images=num_images,
         | 
| 619 | 
            +
                        num_steps=num_steps,
         | 
| 620 | 
            +
                        guidance_scale=guidance_scale,
         | 
| 621 | 
            +
                        seed=seed,
         | 
| 622 | 
            +
                    )
         | 
| 623 | 
            +
                    return [vis_control_image] + results
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        requirements.txt
    CHANGED
    
    | @@ -1,8 +1,9 @@ | |
| 1 | 
             
            addict==2.4.0
         | 
| 2 | 
             
            albumentations==1.3.0
         | 
| 3 | 
             
            einops==0.6.0
         | 
| 4 | 
            -
             | 
| 5 | 
            -
            huggingface | 
|  | |
| 6 | 
             
            imageio==2.25.0
         | 
| 7 | 
             
            imageio-ffmpeg==0.4.8
         | 
| 8 | 
             
            kornia==0.6.9
         | 
|  | |
| 1 | 
             
            addict==2.4.0
         | 
| 2 | 
             
            albumentations==1.3.0
         | 
| 3 | 
             
            einops==0.6.0
         | 
| 4 | 
            +
            git+https://github.com/huggingface/accelerate@78151f8
         | 
| 5 | 
            +
            git+https://github.com/huggingface/diffusers@fa6d52d
         | 
| 6 | 
            +
            gradio==3.20.0
         | 
| 7 | 
             
            imageio==2.25.0
         | 
| 8 | 
             
            imageio-ffmpeg==0.4.8
         | 
| 9 | 
             
            kornia==0.6.9
         | 
