File size: 15,307 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
import base64
import io
import os
import re
import json
from PIL import Image
import gradio as gr
from modules.paths import data_path
from modules import shared, gr_tempdir, script_callbacks, images


# Class of the object gr.update() returns; paste handlers use it to pass ready-made updates through untouched.
type_of_gr_update = type(gr.update())
# tabname -> {"init_img": ..., "fields": ..., "override_settings_component": ...}; filled by add_paste_fields()
paste_fields = dict()
# ParamBinding objects queued by register_paste_params_button() and wired in connect_paste_params_buttons()
registered_param_bindings = list()
if os.environ.get('SD_PASTE_DEBUG', None) is None:
    def debug(*args, **kwargs):  # tracing disabled: accept and ignore everything
        return None
else:
    debug = shared.log.trace
debug('Trace: PASTE')


class ParamBinding:
    """Describe one "send to tab" button and where the data it sends comes from.

    paste_button: button whose click performs the transfer.
    tabname: destination tab; used as the key into paste_fields.
    source_text_component: component holding infotext to parse and paste, if any.
    source_image_component: component holding the image to send, if any.
    source_tabname: tab whose fields are copied directly into the destination, if any.
    override_settings_component: override-settings component; when falsy, the one
        registered for the destination tab is used instead.
    paste_field_names: extra field names copied from source_tabname on top of the defaults
        (a falsy value becomes an empty list).
    """

    def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
        self.paste_button, self.tabname = paste_button, tabname
        self.source_text_component, self.source_image_component = source_text_component, source_image_component
        self.source_tabname = source_tabname
        self.override_settings_component = override_settings_component
        self.paste_field_names = paste_field_names if paste_field_names else []
        debug(f'ParamBinding: {vars(self)}')


def reset():
    """Forget every tab registered through add_paste_fields().

    Note: registered_param_bindings is left untouched.
    """
    paste_fields.clear()


def quote(text):
    """Return *text* JSON-quoted when it contains ',', ':' or a newline; otherwise return it unchanged.

    ',' and ':' separate entries in the "Key: value, Key: value" infotext format, so such values
    are wrapped in double quotes (non-ASCII characters are kept as-is). The check is done on
    str(text); when no quoting is needed the original object is returned, not its string form.
    """
    as_text = str(text)
    if any(separator in as_text for separator in (',', '\n', ':')):
        return json.dumps(text, ensure_ascii=False)
    return text


def unquote(text):
    """Reverse quote(): decode a value wrapped in double quotes as a JSON string.

    Anything that does not both start and end with '"', or that fails to decode as JSON,
    is returned unchanged.
    """
    looks_quoted = len(text) > 0 and text[0] == '"' and text[-1] == '"'
    if not looks_quoted:
        return text
    try:
        decoded = json.loads(text)
    except Exception:
        decoded = text
    return decoded


def image_from_url_text(filedata):
    """Turn a gradio image/gallery payload into a PIL image.

    Accepted inputs:
    - None -> None.
    - A {"is_file": True, "name": path} dict, or a list whose first item is one. The file is opened
      only if gr_tempdir.check_tmp_file() accepts the path (any "?query" suffix is dropped first).
      The infotext returned by images.read_info_from_image() is stored in image.info['parameters'].
      A missing file yields a blank 512x512 RGB placeholder whose 'parameters' holds the error text.
      A rejected path yields None with a warning.
    - Any other list: its first element is used (an empty list yields None).
    - A base64 string, optionally prefixed by a data-URL header such as "data:image/png;base64,".
    Any other dict yields None with a warning.
    """
    if filedata is None:
        return None
    if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
        filedata = filedata[0]
    if type(filedata) == dict and filedata.get("is_file", False):
        filename = filedata["name"]
        is_in_right_dir = gr_tempdir.check_tmp_file(shared.demo, filename)
        if not is_in_right_dir:
            shared.log.warning(f'File access denied: {filename}')
            return None
        filename = filename.rsplit('?', 1)[0]
        if not os.path.exists(filename):
            shared.log.error(f'Image file not found: {filename}')
            image = Image.new('RGB', (512, 512))
            image.info['parameters'] = f'Image file not found: {filename}'
            return image
        image = Image.open(filename)
        geninfo, _items = images.read_info_from_image(image)
        image.info['parameters'] = geninfo
        return image
    if type(filedata) == list:
        if len(filedata) == 0:
            return None
        filedata = filedata[0]
    if type(filedata) == dict:
        shared.log.warning('Incorrect filedata received')
        return None
    # Strip any "data:<mime>;base64," header, not just png/webp/jpeg. Previously other types
    # (gif, bmp, "image/jpg", ...) had the header passed into the base64 decoder, which
    # silently skips ':', ';' and ',' and so produced corrupted bytes.
    if filedata.startswith("data:"):
        header, separator, payload = filedata.partition(",")
        if separator and header.endswith(";base64"):
            filedata = payload
    filedata = base64.decodebytes(filedata.encode('utf-8'))
    image = Image.open(io.BytesIO(filedata))
    # NOTE(review): the result is discarded here, unlike the file branch which stores it in
    # image.info['parameters'] — confirm whether that is intentional.
    images.read_info_from_image(image)
    return image


def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
    """Register the paste targets of one tab in paste_fields.

    tabname: key under which the tab is stored (e.g. 'txt2img').
    init_img: component that receives images sent to this tab.
    fields: list of (component, key) pairs that pasted infotext is written into.
    override_settings_component: optional component receiving settings overrides.
    For 'txt2img' and 'img2img' the fields list is also mirrored onto modules.ui
    (txt2img_paste_fields / img2img_paste_fields) for backwards compatibility with extensions.
    """
    paste_fields[tabname] = dict(init_img=init_img, fields=fields, override_settings_component=override_settings_component)
    import modules.ui  # imported lazily; presumably avoids an import cycle with the UI module — verify
    legacy_attribute = {'txt2img': 'txt2img_paste_fields', 'img2img': 'img2img_paste_fields'}.get(tabname)
    if legacy_attribute is not None:
        setattr(modules.ui, legacy_attribute, fields)


def create_buttons(tabs_list):
    """Create one "send to" button per tab id in *tabs_list*.

    Known tab ids get a friendly label (txt2img -> Text, img2img -> Image, inpaint -> Inpaint,
    extras -> Process, control -> Control); any other id is used as its own label.
    Each button gets elem_id "<tab>_tab". Returns {tab id: gr.Button}.
    """
    labels = {
        'txt2img': 'Text',
        'img2img': 'Image',
        'inpaint': 'Inpaint',
        'extras': 'Process',
        'control': 'Control',
    }
    buttons = {}
    for tab in tabs_list:
        name = labels.get(tab, tab)
        # U+27A0 arrow written as an escape so the label survives re-encoding of this file; the
        # previous literal was its UTF-8 bytes (E2 9E A0) mis-decoded into unrelated characters.
        buttons[tab] = gr.Button(f"\u27a0 {name}", elem_id=f"{tab}_tab")
    return buttons


def bind_buttons(buttons, send_image, send_generate_info):
    """old function for backwards compatibility; do not use this, use register_paste_params_button

    buttons: {destination tabname: button}.
    send_image: image component sent by every button.
    send_generate_info: either a gradio component (its text is pasted as infotext) or a tab
        name (fields are copied from that tab); any other value sets neither.
    """
    for tabname, button in buttons.items():
        is_component = isinstance(send_generate_info, gr.components.Component)
        is_tab_name = isinstance(send_generate_info, str)
        register_paste_params_button(ParamBinding(
            paste_button=button,
            tabname=tabname,
            source_text_component=send_generate_info if is_component else None,
            source_image_component=send_image,
            source_tabname=send_generate_info if is_tab_name else None,
        ))


def register_paste_params_button(binding: ParamBinding):
    """Queue *binding*; its click handlers are attached later by connect_paste_params_buttons()."""
    registered_param_bindings.append(binding)


def connect_paste_params_buttons():
    """Attach click handlers for every binding queued via register_paste_params_button().

    For each binding whose destination tab was registered with add_paste_fields():
    - An image source is sent to the tab's init_img component. When the tab has "Size-1"/"Size-2"
      fields, width and height are sent too (see send_image_and_dimensions).
    - A text source is parsed as infotext and pasted into the tab's fields (see connect_paste).
    - A source tab has selected fields copied across, matched by field name: Prompt,
      Negative prompt, Steps, Face restoration, Seed when shared.opts.send_seed is set, and the
      binding's paste_field_names.
    - Finally the JS function switch_to_<tabname> is called to show the destination tab.
    Bindings for unregistered tabs are skipped and logged via debug().
    """
    binding: ParamBinding
    for binding in registered_param_bindings:
        if binding.tabname not in paste_fields:
            debug(f"Not registered: tab={binding.tabname}")
            continue
        destination = paste_fields[binding.tabname]
        destination_image_component = destination["init_img"]
        fields = destination["fields"]
        override_settings_component = binding.override_settings_component or destination["override_settings_component"]
        destination_width_component = next((field for field, name in fields or [] if name == "Size-1"), None)
        destination_height_component = next((field for field, name in fields or [] if name == "Size-2"), None)

        if binding.source_image_component and destination_image_component:
            if isinstance(binding.source_image_component, gr.Gallery):
                func = send_image_and_dimensions if destination_width_component else image_from_url_text
                jsfunc = "extract_image_from_gallery"
            else:
                func = send_image_and_dimensions if destination_width_component else lambda x: x
                jsfunc = None
            binding.paste_button.click(
                fn=func,
                _js=jsfunc,
                inputs=[binding.source_image_component],
                outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
                show_progress=False,
            )
        if binding.source_text_component is not None and fields is not None:
            connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
        if binding.source_tabname is not None and fields is not None:
            # an unregistered source tab used to raise KeyError here
            source_fields = (paste_fields.get(binding.source_tabname) or {}).get("fields") or []
            paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
            # Pair source and destination components by field name. The two tabs may register their
            # fields in different orders or subsets; pairing the independently filtered lists by
            # position would send values to the wrong components or mismatch the counts.
            source_by_name = {name: field for field, name in source_fields if name in paste_field_names}
            pairs = [(source_by_name[name], field) for field, name in fields if name in paste_field_names and name in source_by_name]
            if pairs:
                binding.paste_button.click(
                    fn=lambda *x: x,
                    inputs=[source for source, _target in pairs],
                    outputs=[target for _source, target in pairs],
                    show_progress=False,
                )
            else:
                debug(f"No shared fields: source={binding.source_tabname} tab={binding.tabname}")
        binding.paste_button.click(
            fn=None,
            _js=f"switch_to_{binding.tabname}",
            inputs=[],
            outputs=[],
            show_progress=False,
        )


def send_image_and_dimensions(x):
    """Resolve *x* to an image and return (image, width, height).

    *x* is used directly when it is already a PIL image; otherwise it goes through
    image_from_url_text(). Width and height are returned only when shared.opts.send_size is
    enabled and an image was resolved; otherwise both are no-op gr.update() values.
    """
    img = x if isinstance(x, Image.Image) else image_from_url_text(x)
    if not (shared.opts.send_size and isinstance(img, Image.Image)):
        return img, gr.update(), gr.update()
    return img, img.width, img.height


def parse_generation_parameters(infotext):
    """Parse a generation infotext into a dict of parameters.

    The text is split into the prompt, an optional "Negative prompt:" section, and a parameter
    section of "Key: value, Key: value" entries.
    - Integer and decimal values become int or float; "True"/"False" become bools.
    - "WxH" values additionally produce "<Key>-1" (W) and "<Key>-2" (H) entries.
    - "VAE: TAESD" also sets "Full quality" to False.
    - Keys Hashes/Lora/Embeddings/Prompt/Negative prompt are dropped from the parameter set.
    For string input the result always contains "Prompt" and "Negative prompt";
    non-string input returns {}.
    """
    if not isinstance(infotext, str):
        return {}
    debug(f'Parse infotext: {infotext}')
    # key: "multi word key"; value: a double-quoted string or anything up to the next comma or line
    # break. Values must not span lines, otherwise a comma-free negative prompt swallowed the first
    # key of the parameter line that follows it (e.g. "Steps").
    re_param = re.compile(r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,\n]*)(?:,|\n|$)') # multi-word: value
    re_size = re.compile(r"^(\d+)x(\d+)$") # int x int
    # normalize label casing; every replacement keeps the length unchanged, so offsets found in
    # `sanitized` are also valid offsets into `infotext`
    sanitized = infotext.replace('prompt:', 'Prompt:').replace('negative prompt:', 'Negative prompt:').replace('Negative Prompt', 'Negative prompt')
    # blank out <...>, (...) and {...} spans with same-length whitespace so ':' and ',' inside them
    # are not mistaken for key/value separators
    sanitized = re.sub(r'<[^>]*>', lambda match: ' ' * len(match.group()), sanitized)
    sanitized = re.sub(r'\([^)]*\)', lambda match: ' ' * len(match.group()), sanitized)
    sanitized = re.sub(r'\{[^}]*\}', lambda match: ' ' * len(match.group()), sanitized)

    params = dict(re_param.findall(sanitized))
    debug(f"Parse params: {params}")
    params = { k.strip():params[k].strip() for k in params if k.strip().lower() not in ['hashes', 'lora', 'embeddings', 'prompt', 'negative prompt']} # remove some keys
    first_param = next(iter(params)) if params else None
    params_idx = sanitized.find(f'{first_param}:') if first_param else -1
    if params_idx < 0: # no parameter section: the text runs to the end (slicing to -1 used to drop the last character)
        params_idx = len(infotext)
    negative_idx = infotext.find("Negative prompt:")

    prompt = infotext[:params_idx] if negative_idx == -1 else infotext[:negative_idx] # prompt can be with or without negative prompt
    negative = infotext[negative_idx:params_idx] if negative_idx >= 0 else ''

    for k, v in params.copy().items(): # iterate a snapshot: size entries are added while looping
        if len(v) > 0 and v[0] == '"' and v[-1] == '"':
            v = unquote(v)
        m = re_size.match(v)
        if v.replace('.', '', 1).isdecimal(): # isdecimal, not isdigit: '²'.isdigit() is True but int('²') raises
            params[k] = float(v) if '.' in v else int(v)
        elif v == "True":
            params[k] = True
        elif v == "False":
            params[k] = False
        elif m is not None:
            params[f"{k}-1"] = int(m.group(1))
            params[f"{k}-2"] = int(m.group(2))
        elif k == 'VAE' and v == 'TAESD':
            params["Full quality"] = False
        else:
            params[k] = v
    params["Prompt"] = prompt.replace('Prompt:', '').strip()
    params["Negative prompt"] = negative.replace('Negative prompt:', '').strip()
    debug(f"Parse: {params}")
    return params


# NOTE(review): not referenced in this part of the module; presumably used elsewhere — verify before removing.
settings_map = {}


# (infotext parameter name, shared.opts setting name) pairs.
# Read by create_override_settings_dict() and by connect_paste() (paste_settings) to turn pasted
# infotext parameters into settings overrides. Several infotext names may map to the same
# setting (e.g. 'Token merging ratio' and 'ToMe'); in create_override_settings_dict() the pair
# listed later wins when both are present.
infotext_to_setting_name_mapping = [
    ('Backend', 'sd_backend'),
    ('Model hash', 'sd_model_checkpoint'),
    ('Refiner', 'sd_model_refiner'),
    ('VAE', 'sd_vae'),
    ('Parser', 'prompt_attention'),
    ('Color correction', 'img2img_color_correction'),
    # Samplers
    ('Sampler Eta', 'scheduler_eta'),
    ('Sampler ENSD', 'eta_noise_seed_delta'),
    ('Sampler order', 'schedulers_solver_order'),
    # Samplers diffusers
    ('Sampler beta schedule', 'schedulers_beta_schedule'),
    ('Sampler beta start', 'schedulers_beta_start'),
    ('Sampler beta end', 'schedulers_beta_end'),
    ('Sampler DPM solver', 'schedulers_dpm_solver'),
    # Samplers original
    ('Sampler brownian', 'schedulers_brownian_noise'),
    ('Sampler discard', 'schedulers_discard_penultimate'),
    ('Sampler dyn threshold', 'schedulers_use_thresholding'),
    ('Sampler karras', 'schedulers_use_karras'),
    ('Sampler low order', 'schedulers_use_loworder'),
    ('Sampler quantization', 'enable_quantization'),
    ('Sampler sigma', 'schedulers_sigma'),
    ('Sampler sigma min', 's_min'),
    ('Sampler sigma max', 's_max'),
    ('Sampler sigma churn', 's_churn'),
    ('Sampler sigma uncond', 's_min_uncond'),
    ('Sampler sigma noise', 's_noise'),
    ('Sampler sigma tmin', 's_tmin'),
    ('Sampler ENSM', 'initial_noise_multiplier'), # img2img only
    ('UniPC skip type', 'uni_pc_skip_type'),
    ('UniPC variant', 'uni_pc_variant'),
    # Inpainting
    ('Mask weight', 'inpainting_mask_weight'),
    # Token Merging ('Token merging ratio' and 'ToMe' are aliases for the same setting)
    ('Token merging ratio', 'token_merging_ratio'),
    ('ToMe', 'token_merging_ratio'),
    ('ToMe hires', 'token_merging_ratio_hr'),
    ('ToMe img2img', 'token_merging_ratio_img2img'),
]


def create_override_settings_dict(text_pairs):
    """Convert "Key: value" strings into a {setting_name: value} overrides dict.

    Only keys listed in infotext_to_setting_name_mapping are kept. Values are converted with
    shared.opts.cast_value(). Keys and values are whitespace-stripped. A None or empty
    *text_pairs* yields {}. Entries without a ':' are skipped with a warning instead of
    raising ValueError.
    """
    res = {}
    params = {}
    for pair in text_pairs or []:
        k, separator, v = pair.partition(":")
        if not separator:
            shared.log.warning(f'Override settings: invalid entry="{pair}"')
            continue
        params[k.strip()] = v.strip()
    for param_name, setting_name in infotext_to_setting_name_mapping:
        value = params.get(param_name, None)
        if value is None:
            continue
        res[setting_name] = shared.opts.cast_value(setting_name, value)
    return res


def connect_paste(button, local_paste_fields, input_comp, override_settings_component, tabname):
    """Attach infotext-paste handlers to *button*.

    button: gradio button that triggers the paste.
    local_paste_fields: list of (component, key) pairs. *key* is either a parameter name looked up
        in the parsed infotext, or a callable that receives the parsed params dict and returns
        the value (or a ready-made gr.update) for that component.
    input_comp: component whose value is the infotext to parse.
    override_settings_component: optional dropdown that receives "Key: value" entries for infotext
        parameters (see infotext_to_setting_name_mapping) that differ from the current settings.
    tabname: selects the JS function recalculate_prompts_<tabname> run after the paste.
    """

    def paste_func(prompt):
        # An empty paste falls back to data_path/params.txt unless hide_ui_dir_config is set.
        # The parentheses matter: previously a None prompt read the file even with the flag set.
        if (prompt is None or len(prompt.strip()) == 0) and not shared.cmd_opts.hide_ui_dir_config:
            filename = os.path.join(data_path, "params.txt")
            if os.path.exists(filename):
                with open(filename, "r", encoding="utf8") as file:
                    prompt = file.read()
                shared.log.debug(f'Paste prompt: type="params" prompt="{prompt}"')
            else:
                prompt = ''
        else:
            if prompt is None: # only reachable when hide_ui_dir_config is set; keep callbacks on str input
                prompt = ''
            shared.log.debug(f'Paste prompt: type="current" prompt="{prompt}"')
        params = parse_generation_parameters(prompt)
        script_callbacks.infotext_pasted_callback(prompt, params)
        res = []
        applied = {}
        for output, key in local_paste_fields:
            v = key(params) if callable(key) else params.get(key, None)
            if v is None:
                res.append(gr.update())
            elif isinstance(v, type_of_gr_update):
                res.append(v)
                applied[key] = v
            else:
                try:
                    # convert to the type of the component's current value
                    valtype = type(output.value)
                    val = False if valtype == bool and v == "False" else valtype(v)
                    res.append(gr.update(value=val))
                    applied[key] = val
                except Exception as e: # value does not fit the component: leave it unchanged
                    debug(f'Parse skip: key={key} value={v} error={e}')
                    res.append(gr.update())
        debug(f"Parse apply: {applied}")
        return res

    if override_settings_component is not None:
        def paste_settings(params):
            # collect "Key: value" entries for infotext params whose setting differs from the current value
            vals = {}
            for param_name, setting_name in infotext_to_setting_name_mapping:
                v = params.get(param_name, None)
                if v is None:
                    continue
                if shared.opts.disable_weights_auto_swap and setting_name in ('sd_model_checkpoint', 'sd_model_refiner', 'sd_backend', 'sd_vae'):
                    continue
                v = shared.opts.cast_value(setting_name, v)
                current_value = getattr(shared.opts, setting_name, None)
                if v == current_value:
                    continue
                if type(current_value) == str and v == os.path.splitext(current_value)[0]: # same value apart from a file extension
                    continue
                vals[param_name] = v
            vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
            shared.log.debug(f'Settings overrides: {vals_pairs}')
            return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0)
        # paste_func reads local_paste_fields from this scope at call time, so it also emits the
        # override dropdown update and matches the outputs list below
        local_paste_fields = local_paste_fields + [(override_settings_component, paste_settings)]

    button.click(
        fn=paste_func,
        inputs=[input_comp],
        outputs=[x[0] for x in local_paste_fields],
        show_progress=False,
    )
    button.click(
        fn=None,
        _js=f"recalculate_prompts_{tabname}",
        inputs=[],
        outputs=[],
        show_progress=False,
    )