File size: 19,632 Bytes
81f4d3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
# Copyright (C) 2023 Deforum LLC
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

# Contact the authors: https://deforum.github.io/

# This helper script is responsible for ControlNet/Deforum integration
# https://github.com/Mikubill/sd-webui-controlnet — controlnet repo

import os
import gradio as gr
import scripts
from PIL import Image
import numpy as np
import importlib
from modules import scripts
from .deforum_controlnet_gradio import hide_ui_by_cn_status, hide_file_textboxes, ToolButton
from .general_utils import count_files_in_folder, clean_gradio_path_strings  # TODO: do it another way
from .video_audio_utilities import vid2frames, convert_image
from .animation_key_frames import ControlNetKeys
from .load_images import load_image
from .general_utils import debug_print

cnet = None
# number of CN model tabs to show in the deforum gui
num_of_models = 5

def find_controlnet():
    global cnet
    if cnet: return cnet
    try:
        cnet = importlib.import_module('extensions.sd-webui-controlnet.scripts.external_code', 'external_code')
    except:
        try:
            cnet = importlib.import_module('extensions-builtin.sd-webui-controlnet.scripts.external_code', 'external_code')
        except:
            pass
    if cnet:
        print(f"\033[0;32m*Deforum ControlNet support: enabled*\033[0m")
        return True
    return None

def controlnet_infotext():
    return """Requires the <a style='color:SteelBlue;' target='_blank' href='https://github.com/Mikubill/sd-webui-controlnet'>ControlNet</a> extension to be installed.</p>
            <p">If Deforum crashes due to CN updates, go <a style='color:Orange;' target='_blank' href='https://github.com/Mikubill/sd-webui-controlnet/issues'>here</a> and report your problem.</p>
           """

def is_controlnet_enabled(controlnet_args):
    for i in range(1, num_of_models + 1):
        if getattr(controlnet_args, f'cn_{i}_enabled', False):
            return True
    return False

def setup_controlnet_ui_raw():
    cnet = find_controlnet()
    cn_models = cnet.get_models()
    cn_preprocessors = cnet.get_modules()

    cn_modules = cnet.get_modules_detail()
    preprocessor_sliders_config = {}

    for config_name, config_values in cn_modules.items():
        sliders = config_values.get('sliders', [])
        preprocessor_sliders_config[config_name] = sliders

    model_free_preprocessors = ["reference_only", "reference_adain", "reference_adain+attn"]
    flag_preprocessor_resolution = "Preprocessor Resolution"

    def build_sliders(module, pp):
        grs = []
        if module not in preprocessor_sliders_config:
            grs += [
                gr.update(label=flag_preprocessor_resolution, value=512, minimum=64, maximum=2048, step=1, visible=not pp, interactive=not pp),
                gr.update(visible=False, interactive=False),
                gr.update(visible=False, interactive=False),
                gr.update(visible=True)
            ]
        else:
            for slider_config in preprocessor_sliders_config[module]:
                if isinstance(slider_config, dict):
                    visible = True
                    if slider_config['name'] == flag_preprocessor_resolution:
                        visible = not pp
                    grs.append(gr.update(
                        label=slider_config['name'],
                        value=slider_config['value'],
                        minimum=slider_config['min'],
                        maximum=slider_config['max'],
                        step=slider_config['step'] if 'step' in slider_config else 1,
                        visible=visible,
                        interactive=visible))
                else:
                    grs.append(gr.update(visible=False, interactive=False))
            while len(grs) < 3:
                grs.append(gr.update(visible=False, interactive=False))
            grs.append(gr.update(visible=True))
        if module in model_free_preprocessors:
            grs += [gr.update(visible=False, value='None'), gr.update(visible=False)]
        else:
            grs += [gr.update(visible=True), gr.update(visible=True)]
        return grs

    refresh_symbol = '\U0001f504'  # 🔄
    switch_values_symbol = '\U000021C5'  # ⇅
    model_dropdowns = []
    infotext_fields = []

    def create_model_in_tab_ui(cn_id):
        with gr.Row():
            enabled = gr.Checkbox(label="Enable", value=False, interactive=True)
            pixel_perfect = gr.Checkbox(label="Pixel Perfect", value=False, visible=False, interactive=True)
            low_vram = gr.Checkbox(label="Low VRAM", value=False, visible=False, interactive=True)
            overwrite_frames = gr.Checkbox(label='Overwrite input frames', value=True, visible=False, interactive=True)
        with gr.Row(visible=False) as mod_row:
            module = gr.Dropdown(cn_preprocessors, label=f"Preprocessor", value="none", interactive=True)
            model = gr.Dropdown(cn_models, label=f"Model", value="None", interactive=True)
            refresh_models = ToolButton(value=refresh_symbol)
            refresh_models.click(refresh_all_models, model, model)
        with gr.Row(visible=False) as weight_row:
            weight = gr.Textbox(label="Weight schedule", lines=1, value='0:(1)', interactive=True)
        with gr.Row(visible=False) as start_cs_row:
            guidance_start = gr.Textbox(label="Starting Control Step schedule", lines=1, value='0:(0.0)', interactive=True)
        with gr.Row(visible=False) as end_cs_row:
            guidance_end = gr.Textbox(label="Ending Control Step schedule", lines=1, value='0:(1.0)', interactive=True)
            model_dropdowns.append(model)
        with gr.Column(visible=False) as advanced_column:
            processor_res = gr.Slider(label="Annotator resolution", value=64, minimum=64, maximum=2048, interactive=False)
            threshold_a = gr.Slider(label="Threshold A", value=64, minimum=64, maximum=1024, interactive=False)
            threshold_b = gr.Slider(label="Threshold B", value=64, minimum=64, maximum=1024, interactive=False)
        with gr.Row(visible=False) as vid_path_row:
            vid_path = gr.Textbox(value='', label="ControlNet Input Video/ Image Path", interactive=True)
        with gr.Row(visible=False) as mask_vid_path_row:  # invisible temporarily since 26-04-23 until masks are fixed
            mask_vid_path = gr.Textbox(value='', label="ControlNet Mask Video/ Image Path (*NOT WORKING, kept in UI for CN's devs testing!*)", interactive=True)
        with gr.Row(visible=False) as control_mode_row:
            control_mode = gr.Radio(choices=["Balanced", "My prompt is more important", "ControlNet is more important"], value="Balanced", label="Control Mode", interactive=True)
        with gr.Row(visible=False) as env_row:
            resize_mode = gr.Radio(choices=["Outer Fit (Shrink to Fit)", "Inner Fit (Scale to Fit)", "Just Resize"], value="Inner Fit (Scale to Fit)", label="Resize Mode", interactive=True)
        with gr.Row(visible=False) as control_loopback_row:
            loopback_mode = gr.Checkbox(label="LoopBack mode", value=False, interactive=True)
        hide_output_list = [pixel_perfect, low_vram, mod_row, module, weight_row, start_cs_row, end_cs_row, env_row, overwrite_frames, vid_path_row, control_mode_row, mask_vid_path_row,
                            control_loopback_row]  # add mask_vid_path_row when masks are working again
        for cn_output in hide_output_list:
            enabled.change(fn=hide_ui_by_cn_status, inputs=enabled, outputs=cn_output)
        module.change(build_sliders, inputs=[module, pixel_perfect], outputs=[processor_res, threshold_a, threshold_b, advanced_column, model, refresh_models])
        # hide vid/image input fields
        loopback_outs = [vid_path_row, mask_vid_path_row]
        for loopback_output in loopback_outs:
            loopback_mode.change(fn=hide_file_textboxes, inputs=loopback_mode, outputs=loopback_output)
        # handle pixel perfect ui changes
        pixel_perfect.change(build_sliders, inputs=[module, pixel_perfect], outputs=[processor_res, threshold_a, threshold_b, advanced_column, model, refresh_models])
        infotext_fields.extend([
            (module, f"ControlNet Preprocessor"),
            (model, f"ControlNet Model"),
            (weight, f"ControlNet Weight"),
        ])

        return {key: value for key, value in locals().items() if key in [
            "enabled", "pixel_perfect", "low_vram", "module", "model", "weight",
            "guidance_start", "guidance_end", "processor_res", "threshold_a", "threshold_b", "resize_mode", "control_mode",
            "overwrite_frames", "vid_path", "mask_vid_path", "loopback_mode"
        ]}

    def refresh_all_models(*inputs):
        cn_models = cnet.get_models(update=True)
        dd = inputs[0]
        selected = dd if dd in cn_models else "None"
        return gr.Dropdown.update(value=selected, choices=cn_models)

    with gr.TabItem('ControlNet'):
        gr.HTML(controlnet_infotext())
        with gr.Tabs():
            model_params = {}
            for i in range(1, num_of_models + 1):
                with gr.Tab(f"CN Model {i}"):
                    model_params[i] = create_model_in_tab_ui(i)

                    for key, value in model_params[i].items():
                        locals()[f"cn_{i}_{key}"] = value

    return locals()

def setup_controlnet_ui():
    if not find_controlnet():
        gr.HTML("""<a style='target='_blank' href='https://github.com/Mikubill/sd-webui-controlnet'>ControlNet not found. Please install it :)</a>""", elem_id='controlnet_not_found_html_msg')
        return {}

    try:
        return setup_controlnet_ui_raw()
    except Exception as e:
        print(f"'ControlNet UI setup failed with error: '{e}'!")
        gr.HTML(f"""
                Failed to setup ControlNet UI, check the reason in your commandline log. Please, downgrade your CN extension to <a style='color:Orange;' target='_blank' href='https://github.com/Mikubill/sd-webui-controlnet/archive/c9340671d6d59e5a79fc404f78f747f969f87374.zip'>c9340671d6d59e5a79fc404f78f747f969f87374</a> or report the problem <a style='color:Orange;' target='_blank' href='https://github.com/Mikubill/sd-webui-controlnet/issues'>here</a>.
                """, elem_id='controlnet_not_found_html_msg')
        return {}

def controlnet_component_names():
    if not find_controlnet():
        return []

    return [f'cn_{i}_{component}' for i in range(1, num_of_models + 1) for component in [
        'overwrite_frames', 'vid_path', 'mask_vid_path', 'enabled',
        'low_vram', 'pixel_perfect',
        'module', 'model', 'weight', 'guidance_start', 'guidance_end',
        'processor_res', 'threshold_a', 'threshold_b', 'resize_mode', 'control_mode', 'loopback_mode'
    ]]

def process_with_controlnet(p, args, anim_args, controlnet_args, root, parseq_adapter, is_img2img=True, frame_idx=0):
    CnSchKeys = ControlNetKeys(anim_args, controlnet_args) if not parseq_adapter.use_parseq else parseq_adapter.cn_keys

    def read_cn_data(cn_idx):
        cn_mask_np, cn_image_np = None, None
        # Loopback mode ENABLED:
        if getattr(controlnet_args, f'cn_{cn_idx}_loopback_mode'):
            # On very first frame, check if use init enabled, and if init image is provided
            if frame_idx == 0 and args.use_init and (args.init_image is not None or args.init_image_box is not None):
                cn_image_np = load_image(args.init_image, args.init_image_box)
                # convert to uint8 for compatibility with CN
                cn_image_np = np.array(cn_image_np).astype('uint8')
            # Not first frame, use previous img (init_sample)
            elif frame_idx > 0 and root.init_sample:
                cn_image_np = np.array(root.init_sample).astype('uint8')
        else:  # loopback mode is DISABLED
            cn_inputframes = os.path.join(args.outdir, f'controlnet_{cn_idx}_inputframes')  # set input frames folder path
            if os.path.exists(cn_inputframes):
                if count_files_in_folder(cn_inputframes) == 1:
                    cn_frame_path = os.path.join(cn_inputframes, "000000000.jpg")
                    print(f'Reading ControlNet *static* base frame at {cn_frame_path}')
                else:
                    cn_frame_path = os.path.join(cn_inputframes, f"{frame_idx:09}.jpg")
                    print(f'Reading ControlNet {cn_idx} base frame #{frame_idx} at {cn_frame_path}')
                if os.path.exists(cn_frame_path):
                    cn_image_np = np.array(Image.open(cn_frame_path).convert("RGB")).astype('uint8')
            cn_maskframes = os.path.join(args.outdir, f'controlnet_{cn_idx}_maskframes')  # set mask frames folder path
            if os.path.exists(cn_maskframes):
                if count_files_in_folder(cn_maskframes) == 1:
                    cn_mask_frame_path = os.path.join(cn_inputframes, "000000000.jpg")
                    print(f'Reading ControlNet *static* mask frame at {cn_mask_frame_path}')
                else:
                    cn_mask_frame_path = os.path.join(args.outdir, f'controlnet_{cn_idx}_maskframes', f"{frame_idx:09}.jpg")
                    print(f'Reading ControlNet {cn_idx} mask frame #{frame_idx} at {cn_mask_frame_path}')
                if os.path.exists(cn_mask_frame_path):
                    cn_mask_np = np.array(Image.open(cn_mask_frame_path).convert("RGB")).astype('uint8')

        return cn_mask_np, cn_image_np

    cnet = find_controlnet()
    cn_data = [read_cn_data(i) for i in range(1, num_of_models + 1)]

    # Check if any loopback_mode is set to True
    any_loopback_mode = any(getattr(controlnet_args, f'cn_{i}_loopback_mode') for i in range(1, num_of_models + 1))

    cn_inputframes_list = [os.path.join(args.outdir, f'controlnet_{i}_inputframes') for i in range(1, num_of_models + 1)]

    if not any(os.path.exists(cn_inputframes) for cn_inputframes in cn_inputframes_list) and not any_loopback_mode:
        print(f'\033[33mNeither the base nor the masking frames for ControlNet were found. Using the regular pipeline\033[0m')

    p.scripts = scripts.scripts_img2img if is_img2img else scripts.scripts_txt2img

    def create_cnu_dict(cn_args, prefix, img_np, mask_np, frame_idx, CnSchKeys):

        keys = [
            "enabled", "module", "model", "weight", "resize_mode", "control_mode", "low_vram", "pixel_perfect",
            "processor_res", "threshold_a", "threshold_b", "guidance_start", "guidance_end"
        ]
        cnu = {k: getattr(cn_args, f"{prefix}_{k}") for k in keys}
        model_num = int(prefix.split('_')[-1])  # Extract model number from prefix (e.g., "cn_1" -> 1)
        if 1 <= model_num <= 5:
            # if in loopmode and no init image (img_np, after processing in this case) provided, disable CN unit for the very first frame. Will be enabled in the next frame automatically
            if getattr(cn_args, f"cn_{model_num}_loopback_mode") and frame_idx == 0 and img_np is None:
                cnu['enabled'] = False
            cnu['weight'] = getattr(CnSchKeys, f"cn_{model_num}_weight_schedule_series")[frame_idx]
            cnu['guidance_start'] = getattr(CnSchKeys, f"cn_{model_num}_guidance_start_schedule_series")[frame_idx]
            cnu['guidance_end'] = getattr(CnSchKeys, f"cn_{model_num}_guidance_end_schedule_series")[frame_idx]
            if cnu['enabled']:
                debug_print(f"ControlNet {model_num}: weight={cnu['weight']}, guidance_start={cnu['guidance_start']}, guidance_end={cnu['guidance_end']}")
        cnu['image'] = {'image': img_np, 'mask': mask_np} if mask_np is not None else img_np

        return cnu

    masks_np, images_np = zip(*cn_data)

    cn_units = [cnet.ControlNetUnit(**create_cnu_dict(controlnet_args, f"cn_{i + 1}", img_np, mask_np, frame_idx, CnSchKeys))
                for i, (img_np, mask_np) in enumerate(zip(images_np, masks_np))]

    p.script_args = {"enabled": True}
    cnet.update_cn_script_in_processing(p, cn_units, is_img2img=is_img2img, is_ui=False)

def process_controlnet_input_frames(args, anim_args, controlnet_args, video_path, mask_path, outdir_suffix, id):
    if (video_path or mask_path) and getattr(controlnet_args, f'cn_{id}_enabled'):
        frame_path = os.path.join(args.outdir, f'controlnet_{id}_{outdir_suffix}')
        os.makedirs(frame_path, exist_ok=True)

        accepted_image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')
        if video_path and video_path.lower().endswith(accepted_image_extensions):
            convert_image(video_path, os.path.join(frame_path, '000000000.jpg'))
            print(f"Copied CN Model {id}'s single input image to inputframes folder!")
        elif mask_path and mask_path.lower().endswith(accepted_image_extensions):
            convert_image(mask_path, os.path.join(frame_path, '000000000.jpg'))
            print(f"Copied CN Model {id}'s single input image to inputframes *mask* folder!")
        else:
            print(f'Unpacking ControlNet {id} {"video mask" if mask_path else "base video"}')
            print(f"Exporting Video Frames to {frame_path}...")
            vid2frames(
                video_path=video_path or mask_path,
                video_in_frame_path=frame_path,
                n=1 if anim_args.animation_mode != 'Video Input' else anim_args.extract_nth_frame,
                overwrite=getattr(controlnet_args, f'cn_{id}_overwrite_frames'),
                extract_from_frame=0 if anim_args.animation_mode != 'Video Input' else anim_args.extract_from_frame,
                extract_to_frame=(anim_args.max_frames - 1) if anim_args.animation_mode != 'Video Input' else anim_args.extract_to_frame,
                numeric_files_output=True
            )
            print(f"Loading {anim_args.max_frames} input frames from {frame_path} and saving video frames to {args.outdir}")
            print(f'ControlNet {id} {"video mask" if mask_path else "base video"} unpacked!')

def unpack_controlnet_vids(args, anim_args, controlnet_args):
    # this func gets called from render.py once for an entire animation run -->
    # tries to trigger an extraction of CN input frames (regular + masks) from video or image
    for i in range(1, num_of_models + 1):
        # LoopBack mode is enabled, no need to extract a video or copy an init image
        if getattr(controlnet_args, f'cn_{i}_loopback_mode'):
            print(f"ControlNet #{i} is in LoopBack mode, skipping video/ image extraction stage.")
            continue
        vid_path = clean_gradio_path_strings(getattr(controlnet_args, f'cn_{i}_vid_path', None))
        mask_path = clean_gradio_path_strings(getattr(controlnet_args, f'cn_{i}_mask_vid_path', None))

        if vid_path:  # Process base video, if available
            process_controlnet_input_frames(args, anim_args, controlnet_args, vid_path, None, 'inputframes', i)

        if mask_path:  # Process mask video, if available
            process_controlnet_input_frames(args, anim_args, controlnet_args, None, mask_path, 'maskframes', i)