Spaces:
Paused
Paused
import ast
import json
import re
from abc import ABC, abstractmethod
from pathlib import Path

import gradio as gr
from PIL import Image

import fooocus_version
import modules.config
import modules.sdxl_styles
from modules.flags import MetadataScheme, Performance, Steps
from modules.flags import SAMPLERS, CIVITAI_NO_KARRAS
from modules.util import quote, unquote, extract_styles_from_prompt, is_json, get_file_from_folder_list, sha256
# One "Key: value" pair of an A1111-style parameter line. The key starts with a word
# character and may contain spaces, hyphens and slashes; the value is either a
# double-quoted string (backslash escapes allowed) or everything up to the next comma.
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)

# A whole "<width>x<height>" value, e.g. "1024x1024"; both numbers are captured.
re_imagesize = re.compile(r'^(\d+)x(\d+)$')

# File path -> hash, filled lazily by get_sha256().
hash_cache: dict = {}
def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool):
    """Turn stored generation metadata into the ordered list of UI updates.

    Args:
        raw_metadata: the metadata dict, or its JSON text.
        is_generating: whether a generation is currently running; controls the
            update emitted for the first of the two button outputs.

    Returns:
        A list whose first item says whether any metadata was present, followed by
        the values / ``gr.update()`` placeholders appended by the get_* helpers,
        in a fixed order that must match the UI outputs.
    """
    parsed = json.loads(raw_metadata) if isinstance(raw_metadata, str) else raw_metadata
    assert isinstance(parsed, dict)

    results = [len(parsed) > 0]

    # (getter, snake_case key, human-readable fallback label) in output order.
    field_getters = (
        (get_image_number, 'image_number', 'Image Number'),
        (get_str, 'prompt', 'Prompt'),
        (get_str, 'negative_prompt', 'Negative Prompt'),
        (get_list, 'styles', 'Styles'),
        (get_str, 'performance', 'Performance'),
        (get_steps, 'steps', 'Steps'),
        (get_float, 'overwrite_switch', 'Overwrite Switch'),
        (get_resolution, 'resolution', 'Resolution'),
        (get_float, 'guidance_scale', 'Guidance Scale'),
        (get_float, 'sharpness', 'Sharpness'),
        (get_adm_guidance, 'adm_guidance', 'ADM Guidance'),
        (get_str, 'refiner_swap_method', 'Refiner Swap Method'),
        (get_float, 'adaptive_cfg', 'CFG Mimicking from TSNR'),
        (get_str, 'base_model', 'Base Model'),
        (get_str, 'refiner_model', 'Refiner Model'),
        (get_float, 'refiner_switch', 'Refiner Switch'),
        (get_str, 'sampler', 'Sampler'),
        (get_str, 'scheduler', 'Scheduler'),
        (get_seed, 'seed', 'Seed'),
    )
    for getter, key, label in field_getters:
        getter(key, label, parsed, results)

    # Two button updates: the first is left alone while generating, otherwise made
    # visible; the second is always hidden.
    results.append(gr.update() if is_generating else gr.update(visible=True))
    results.append(gr.update(visible=False))

    get_freeu('freeu', 'FreeU', parsed, results)

    for slot in range(1, modules.config.default_max_lora_number + 1):
        get_lora(f'lora_combined_{slot}', f'LoRA {slot}', parsed, results)

    return results
| def get_str(key: str, fallback: str | None, source_dict: dict, results: list, default=None): | |
| try: | |
| h = source_dict.get(key, source_dict.get(fallback, default)) | |
| assert isinstance(h, str) | |
| results.append(h) | |
| except: | |
| results.append(gr.update()) | |
| def get_list(key: str, fallback: str | None, source_dict: dict, results: list, default=None): | |
| try: | |
| h = source_dict.get(key, source_dict.get(fallback, default)) | |
| h = eval(h) | |
| assert isinstance(h, list) | |
| results.append(h) | |
| except: | |
| results.append(gr.update()) | |
| def get_float(key: str, fallback: str | None, source_dict: dict, results: list, default=None): | |
| try: | |
| h = source_dict.get(key, source_dict.get(fallback, default)) | |
| assert h is not None | |
| h = float(h) | |
| results.append(h) | |
| except: | |
| results.append(gr.update()) | |
def get_image_number(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
    """Append the number of images to generate.

    The stored value is converted with ``int`` and capped at
    ``modules.config.default_max_image_number``. It is also kept at 1 or above,
    the same value used when the entry is missing or not an integer.
    """
    try:
        # int(None) raises TypeError, so a missing value falls through to the except.
        count = int(source_dict.get(key, source_dict.get(fallback, default)))
        count = min(count, modules.config.default_max_image_number)
    except Exception:
        count = 1
    results.append(max(count, 1))
def get_steps(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
    """Append the stored step count, or -1 when it adds no information.

    -1 is appended when the value is missing or not an integer. It is also appended
    when the steps equal a ``Steps`` member whose name matches the stored
    performance (spaces replaced by underscores, compared case-insensitively),
    because the performance setting already implies them. Otherwise the integer
    step count is appended.
    """
    try:
        steps = int(source_dict.get(key, source_dict.get(fallback, default)))
        # The performance may be stored under its label as well (mirroring the
        # get_str('performance', 'Performance', ...) lookup); a stored None means "none".
        performance = source_dict.get('performance', source_dict.get('Performance')) or ''
        implied_by_performance = (
            steps in iter(Steps)
            and Steps(steps).name.casefold() == performance.replace(' ', '_').casefold()
        )
    except Exception:
        results.append(-1)
        return
    results.append(-1 if implied_by_performance else steps)
def get_resolution(key: str, fallback: str | None, source_dict: dict, results: list, default=None):
    """Append exactly three updates (aspect-ratio choice, width, height) for a resolution.

    The stored value is the text of a ``(width, height)`` tuple literal.

    - If ``modules.config.add_ratio("W*H")`` is one of
      ``modules.config.available_aspect_ratios``, that choice is appended followed by
      -1, -1.
    - Otherwise the aspect ratio is left unchanged and the explicit integer width
      and height are appended.
    - On any failure three no-op ``gr.update()`` are appended.

    The three outputs are built first and appended together. The caller's output
    list therefore stays aligned even if a conversion fails half-way.
    """
    try:
        raw = source_dict.get(key, source_dict.get(fallback, default))
        # SECURITY: never eval() metadata read from image files.
        width, height = ast.literal_eval(raw)
        formatted = modules.config.add_ratio(f'{width}*{height}')
        if formatted in modules.config.available_aspect_ratios:
            updates = [formatted, -1, -1]
        else:
            updates = [gr.update(), int(width), int(height)]
    except Exception:
        updates = [gr.update(), gr.update(), gr.update()]
    results.extend(updates)
| def get_seed(key: str, fallback: str | None, source_dict: dict, results: list, default=None): | |
| try: | |
| h = source_dict.get(key, source_dict.get(fallback, default)) | |
| assert h is not None | |
| h = int(h) | |
| results.append(False) | |
| results.append(h) | |
| except: | |
| results.append(gr.update()) | |
| results.append(gr.update()) | |
| def get_adm_guidance(key: str, fallback: str | None, source_dict: dict, results: list, default=None): | |
| try: | |
| h = source_dict.get(key, source_dict.get(fallback, default)) | |
| p, n, e = eval(h) | |
| results.append(float(p)) | |
| results.append(float(n)) | |
| results.append(float(e)) | |
| except: | |
| results.append(gr.update()) | |
| results.append(gr.update()) | |
| results.append(gr.update()) | |
| def get_freeu(key: str, fallback: str | None, source_dict: dict, results: list, default=None): | |
| try: | |
| h = source_dict.get(key, source_dict.get(fallback, default)) | |
| b1, b2, s1, s2 = eval(h) | |
| results.append(True) | |
| results.append(float(b1)) | |
| results.append(float(b2)) | |
| results.append(float(s1)) | |
| results.append(float(s2)) | |
| except: | |
| results.append(False) | |
| results.append(gr.update()) | |
| results.append(gr.update()) | |
| results.append(gr.update()) | |
| results.append(gr.update()) | |
| def get_lora(key: str, fallback: str | None, source_dict: dict, results: list): | |
| try: | |
| split_data = source_dict.get(key, source_dict.get(fallback)).split(' : ') | |
| enabled = True | |
| name = split_data[0] | |
| weight = split_data[1] | |
| if len(split_data) == 3: | |
| enabled = split_data[0] == 'True' | |
| name = split_data[1] | |
| weight = split_data[2] | |
| weight = float(weight) | |
| results.append(enabled) | |
| results.append(name) | |
| results.append(weight) | |
| except: | |
| results.append(True) | |
| results.append('None') | |
| results.append(1) | |
def get_sha256(filepath):
    """Return the SHA-256 hash of ``filepath``, computing it at most once per path.

    Results are memoised in the module-level ``hash_cache`` dict (mutated, never
    rebound, so no ``global`` statement is needed).
    """
    try:
        return hash_cache[filepath]
    except KeyError:
        digest = hash_cache[filepath] = sha256(filepath)
        return digest
def parse_meta_from_preset(preset_content):
    """Build a metadata dict from a preset, falling back to ``modules.config`` defaults.

    For every ``(settings_key, meta_key)`` in ``modules.config.possible_preset_keys``,
    the preset's value is used when present and not None. Otherwise the attribute
    of ``modules.config`` with the same name is used. Special cases:

    - ``"default_loras"``: each LoRA (a sequence of fields) becomes
      ``'lora_combined_<n>'`` = its fields joined with ``' : '``. At most
      ``modules.config.default_max_lora_number`` LoRAs are kept.
    - ``"default_aspect_ratio"``: stored as ``str((width, height))``. Preset values
      use ``"W*H"``; the config default is split on ``'×'`` and anything after the
      first space of the height part is dropped.
    - ``"default_styles"``: stored as ``str()`` of the value.
    """
    assert isinstance(preset_content, dict)
    preset_prepared = {}
    items = preset_content

    for settings_key, meta_key in modules.config.possible_preset_keys.items():
        if settings_key == "default_loras":
            loras = getattr(modules.config, settings_key)
            # Treat an explicit None like a missing key, as the other branches do.
            if items.get(settings_key) is not None:
                loras = items[settings_key]
            # Honour the configured slot count instead of a hard-coded 5; the UI
            # loader iterates modules.config.default_max_lora_number slots.
            for index, lora in enumerate(loras[:modules.config.default_max_lora_number]):
                preset_prepared[f'lora_combined_{index + 1}'] = ' : '.join(map(str, lora))
        elif settings_key == "default_aspect_ratio":
            if items.get(settings_key) is not None:
                width, height = items[settings_key].split('*')
            else:
                width, height = getattr(modules.config, settings_key).split('×')
                # Assumes the config default looks like "W×H <label>"; partition also
                # copes with a value that has no trailing label (index() would raise).
                height = height.partition(' ')[0]
            preset_prepared[meta_key] = (width, height)
        else:
            value = items.get(settings_key)
            preset_prepared[meta_key] = value if value is not None else getattr(modules.config, settings_key)

        if settings_key in ("default_styles", "default_aspect_ratio"):
            preset_prepared[meta_key] = str(preset_prepared[meta_key])

    return preset_prepared
| class MetadataParser(ABC): | |
| def __init__(self): | |
| self.raw_prompt: str = '' | |
| self.full_prompt: str = '' | |
| self.raw_negative_prompt: str = '' | |
| self.full_negative_prompt: str = '' | |
| self.steps: int = 30 | |
| self.base_model_name: str = '' | |
| self.base_model_hash: str = '' | |
| self.refiner_model_name: str = '' | |
| self.refiner_model_hash: str = '' | |
| self.loras: list = [] | |
| def get_scheme(self) -> MetadataScheme: | |
| raise NotImplementedError | |
| def parse_json(self, metadata: dict | str) -> dict: | |
| raise NotImplementedError | |
| def parse_string(self, metadata: dict) -> str: | |
| raise NotImplementedError | |
| def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_prompt, steps, base_model_name, | |
| refiner_model_name, loras): | |
| self.raw_prompt = raw_prompt | |
| self.full_prompt = full_prompt | |
| self.raw_negative_prompt = raw_negative_prompt | |
| self.full_negative_prompt = full_negative_prompt | |
| self.steps = steps | |
| self.base_model_name = Path(base_model_name).stem | |
| base_model_path = get_file_from_folder_list(base_model_name, modules.config.paths_checkpoints) | |
| self.base_model_hash = get_sha256(base_model_path) | |
| if refiner_model_name not in ['', 'None']: | |
| self.refiner_model_name = Path(refiner_model_name).stem | |
| refiner_model_path = get_file_from_folder_list(refiner_model_name, modules.config.paths_checkpoints) | |
| self.refiner_model_hash = get_sha256(refiner_model_path) | |
| self.loras = [] | |
| for (lora_name, lora_weight) in loras: | |
| if lora_name != 'None': | |
| lora_path = get_file_from_folder_list(lora_name, modules.config.paths_loras) | |
| lora_hash = get_sha256(lora_path) | |
| self.loras.append((Path(lora_name).stem, lora_weight, lora_hash)) | |
| def remove_special_loras(lora_filenames): | |
| for lora_to_remove in modules.config.loras_metadata_remove: | |
| if lora_to_remove in lora_filenames: | |
| lora_filenames.remove(lora_to_remove) | |
class A1111MetadataParser(MetadataParser):
    """Parser for the AUTOMATIC1111 / Civitai plain-text parameter format.

    The text consists of the positive prompt, an optional section starting with
    ``"Negative prompt:"``, and a final line of comma-separated ``Key: value`` pairs.
    """

    def get_scheme(self) -> MetadataScheme:
        return MetadataScheme.A1111

    # Fooocus parameter key -> label used in the A1111 parameter line.
    fooocus_to_a1111 = {
        'raw_prompt': 'Raw prompt',
        'raw_negative_prompt': 'Raw negative prompt',
        'negative_prompt': 'Negative prompt',
        'styles': 'Styles',
        'performance': 'Performance',
        'steps': 'Steps',
        'sampler': 'Sampler',
        'scheduler': 'Scheduler',
        'guidance_scale': 'CFG scale',
        'seed': 'Seed',
        'resolution': 'Size',
        'sharpness': 'Sharpness',
        'adm_guidance': 'ADM Guidance',
        'refiner_swap_method': 'Refiner Swap Method',
        'adaptive_cfg': 'Adaptive CFG',
        'overwrite_switch': 'Overwrite Switch',
        'freeu': 'FreeU',
        'base_model': 'Model',
        'base_model_hash': 'Model hash',
        'refiner_model': 'Refiner',
        'refiner_model_hash': 'Refiner hash',
        'lora_hashes': 'Lora hashes',
        'lora_weights': 'Lora weights',
        'created_by': 'User',
        'version': 'Version'
    }

    def parse_json(self, metadata: str) -> dict:
        """Parse A1111 parameter text into a dict keyed by Fooocus parameter names.

        Args:
            metadata: the full parameter text.

        Returns:
            A dict of string values with keys such as 'prompt', 'negative_prompt',
            'styles', 'resolution', 'sampler', 'base_model' and 'lora_combined_<n>'.
            Model and LoRA names are replaced by matching local filenames when found.
        """
        metadata_prompt = ''
        metadata_negative_prompt = ''
        done_with_prompt = False
        negative_prefix = f"{self.fooocus_to_a1111['negative_prompt']}:"

        *lines, lastline = metadata.strip().split("\n")
        # The last line counts as the parameter line only if it has at least 3 pairs.
        if len(re_param.findall(lastline)) < 3:
            lines.append(lastline)
            lastline = ''

        for line in lines:
            line = line.strip()
            if line.startswith(negative_prefix):
                done_with_prompt = True
                line = line[len(negative_prefix):].strip()
            if done_with_prompt:
                metadata_negative_prompt += ('' if metadata_negative_prompt == '' else "\n") + line
            else:
                metadata_prompt += ('' if metadata_prompt == '' else "\n") + line

        found_styles, prompt, negative_prompt = extract_styles_from_prompt(metadata_prompt, metadata_negative_prompt)

        data = {
            'prompt': prompt,
            'negative_prompt': negative_prompt
        }

        # Reverse lookup built once instead of a linear list search per parameter.
        a1111_to_fooocus = {label: key for key, label in self.fooocus_to_a1111.items()}
        for k, v in re_param.findall(lastline):
            try:
                if v != '' and v[0] == '"' and v[-1] == '"':
                    v = unquote(v)

                m = re_imagesize.match(v)
                if m is not None:
                    data['resolution'] = str((m.group(1), m.group(2)))
                else:
                    data[a1111_to_fooocus[k]] = v
            except Exception:
                print(f"Error parsing \"{k}: {v}\"")

        # workaround for multiline prompts
        if 'raw_prompt' in data:
            data['prompt'] = data['raw_prompt']
            raw_prompt = data['raw_prompt'].replace("\n", ', ')
            if metadata_prompt != raw_prompt and modules.sdxl_styles.fooocus_expansion not in found_styles:
                found_styles.append(modules.sdxl_styles.fooocus_expansion)

        if 'raw_negative_prompt' in data:
            data['negative_prompt'] = data['raw_negative_prompt']

        data['styles'] = str(found_styles)

        # try to load performance based on steps, fallback for direct A1111 imports
        if 'steps' in data and 'performance' not in data:
            try:
                data['performance'] = Performance[Steps(int(data['steps'])).name].value
            except (ValueError, KeyError):
                # Must be a tuple: `ValueError | KeyError` is a types.UnionType, which an
                # except clause rejects with TypeError as soon as an exception occurs.
                pass

        if 'sampler' in data:
            sampler = data['sampler'].replace(' Karras', '')
            # Map the display name back to the Fooocus sampler key, if known.
            data['sampler'] = next((k for k, v in SAMPLERS.items() if v == sampler), sampler)

        for key in ['base_model', 'refiner_model']:
            if key in data:
                for filename in modules.config.model_filenames:
                    if data[key] == Path(filename).stem:
                        data[key] = filename
                        break

        lora_data = ''
        if 'lora_weights' in data and data['lora_weights'] != '':
            lora_data = data['lora_weights']
        elif 'lora_hashes' in data and data['lora_hashes'] != '' and data['lora_hashes'].split(', ')[0].count(':') == 2:
            lora_data = data['lora_hashes']

        if lora_data != '':
            lora_filenames = modules.config.lora_filenames.copy()
            # Called through the class: works whether or not it is declared a staticmethod.
            MetadataParser.remove_special_loras(lora_filenames)
            for li, lora in enumerate(lora_data.split(', ')):
                lora_split = lora.split(': ')
                if len(lora_split) < 2:
                    # Malformed entry without a "name: weight" separator; skip it rather
                    # than abort the whole import with an IndexError.
                    continue
                lora_name = lora_split[0]
                lora_weight = lora_split[2] if len(lora_split) == 3 else lora_split[1]
                for filename in lora_filenames:
                    if lora_name == Path(filename).stem:
                        data[f'lora_combined_{li + 1}'] = f'{filename} : {lora_weight}'
                        break

        return data

    def parse_string(self, metadata: list) -> str:
        """Render generation metadata as A1111 parameter text.

        Args:
            metadata: ``(label, key, value)`` triples. Keys 'resolution' (text of a
                ``(width, height)`` tuple), 'sampler', 'scheduler', 'seed',
                'guidance_scale', 'sharpness', 'adm_guidance', 'base_model',
                'performance' and 'version' must be present.

        Returns:
            The prompt, an optional ``Negative prompt:`` line, then one line of
            ``Key: value`` pairs. Model / LoRA names and hashes come from set_data().
        """
        data = {k: v for _, k, v in metadata}

        # Parse the tuple text as a literal; never eval() stored metadata.
        width, height = ast.literal_eval(data['resolution'])

        sampler = data['sampler']
        scheduler = data['scheduler']
        if sampler in SAMPLERS and SAMPLERS[sampler] != '':
            sampler = SAMPLERS[sampler]
            if sampler not in CIVITAI_NO_KARRAS and scheduler == 'karras':
                sampler += ' Karras'

        generation_params = {
            self.fooocus_to_a1111['steps']: self.steps,
            self.fooocus_to_a1111['sampler']: sampler,
            self.fooocus_to_a1111['seed']: data['seed'],
            self.fooocus_to_a1111['resolution']: f'{width}x{height}',
            self.fooocus_to_a1111['guidance_scale']: data['guidance_scale'],
            self.fooocus_to_a1111['sharpness']: data['sharpness'],
            self.fooocus_to_a1111['adm_guidance']: data['adm_guidance'],
            self.fooocus_to_a1111['base_model']: Path(data['base_model']).stem,
            self.fooocus_to_a1111['base_model_hash']: self.base_model_hash,

            self.fooocus_to_a1111['performance']: data['performance'],
            self.fooocus_to_a1111['scheduler']: scheduler,
            # workaround for multiline prompts
            self.fooocus_to_a1111['raw_prompt']: self.raw_prompt,
            self.fooocus_to_a1111['raw_negative_prompt']: self.raw_negative_prompt,
        }

        if self.refiner_model_name not in ['', 'None']:
            generation_params |= {
                self.fooocus_to_a1111['refiner_model']: self.refiner_model_name,
                self.fooocus_to_a1111['refiner_model_hash']: self.refiner_model_hash
            }

        for key in ['adaptive_cfg', 'overwrite_switch', 'refiner_swap_method', 'freeu']:
            if key in data:
                generation_params[self.fooocus_to_a1111[key]] = data[key]

        if len(self.loras) > 0:
            # workaround for Fooocus not knowing LoRA name in LoRA metadata
            lora_hashes = [f'{lora_name}: {lora_hash}' for lora_name, _, lora_hash in self.loras]
            lora_weights = [f'{lora_name}: {lora_weight}' for lora_name, lora_weight, _ in self.loras]
            generation_params[self.fooocus_to_a1111['lora_hashes']] = ', '.join(lora_hashes)
            generation_params[self.fooocus_to_a1111['lora_weights']] = ', '.join(lora_weights)

        generation_params[self.fooocus_to_a1111['version']] = data['version']

        if modules.config.metadata_created_by != '':
            generation_params[self.fooocus_to_a1111['created_by']] = modules.config.metadata_created_by

        generation_params_text = ", ".join(
            [k if k == v else f'{k}: {quote(v)}' for k, v in generation_params.items() if
             v is not None])
        positive_prompt_resolved = ', '.join(self.full_prompt)
        negative_prompt_resolved = ', '.join(self.full_negative_prompt)
        negative_prompt_text = f"\nNegative prompt: {negative_prompt_resolved}" if negative_prompt_resolved else ""
        return f"{positive_prompt_resolved}{negative_prompt_text}\n{generation_params_text}".strip()
class FooocusMetadataParser(MetadataParser):
    """Parser for Fooocus' own JSON metadata format."""

    def get_scheme(self) -> MetadataScheme:
        return MetadataScheme.FOOOCUS

    def parse_json(self, metadata: dict) -> dict:
        """Replace model / LoRA names in ``metadata`` with local filenames, in place.

        Values equal to ``''`` or ``'None'`` are left untouched. Other
        'base_model' / 'refiner_model' / 'lora_combined_*' values become the
        matching local filename, or None when there is no match. Returns the same
        (mutated) dict.
        """
        model_filenames = modules.config.model_filenames.copy()
        lora_filenames = modules.config.lora_filenames.copy()
        # Called through the class: works whether or not it is declared a staticmethod.
        MetadataParser.remove_special_loras(lora_filenames)

        # Only existing keys are reassigned, so mutating while iterating is safe.
        for key, value in metadata.items():
            if value in ['', 'None']:
                continue
            if key in ['base_model', 'refiner_model']:
                metadata[key] = self.replace_value_with_filename(key, value, model_filenames)
            elif key.startswith('lora_combined_'):
                metadata[key] = self.replace_value_with_filename(key, value, lora_filenames)

        return metadata

    def parse_string(self, metadata: list) -> str:
        """Serialise generation metadata as a JSON object with sorted keys.

        Args:
            metadata: ``(label, key, value)`` triples. 'lora_combined_*' values
                (``"name : weight"`` or ``"enabled : name : weight"``) are rewritten in
                place so that only the file stem of the name remains.

        Returns:
            JSON text containing all metadata values plus prompts, steps, model
            names/hashes and LoRAs recorded by set_data().

        Raises:
            ValueError: if a LoRA value does not have 2 or 3 ``' : '``-separated fields.
        """
        for li, (label, key, value) in enumerate(metadata):
            # remove model folder paths from metadata
            if key.startswith('lora_combined_'):
                parts = value.split(' : ')
                if len(parts) not in (2, 3):
                    raise ValueError(f'Unexpected LoRA metadata value: {value!r}')
                parts[-2] = Path(parts[-2]).stem
                metadata[li] = (label, key, ' : '.join(parts))

        res = {k: v for _, k, v in metadata}

        res['full_prompt'] = self.full_prompt
        res['full_negative_prompt'] = self.full_negative_prompt
        res['steps'] = self.steps
        res['base_model'] = self.base_model_name
        res['base_model_hash'] = self.base_model_hash

        if self.refiner_model_name not in ['', 'None']:
            res['refiner_model'] = self.refiner_model_name
            res['refiner_model_hash'] = self.refiner_model_hash

        res['loras'] = self.loras

        if modules.config.metadata_created_by != '':
            res['created_by'] = modules.config.metadata_created_by

        return json.dumps(dict(sorted(res.items())))

    @staticmethod
    def replace_value_with_filename(key, value, filenames):
        """Return ``value`` with its model / LoRA name replaced by a matching filename.

        For 'lora_combined_*' keys, ``value`` is ``"name : weight"`` or
        ``"enabled : name : weight"``. The name field is matched against the stems of
        ``filenames``, and the other fields are kept. For other keys, ``value`` itself
        is matched against the stems.

        Returns:
            The rewritten value, or None when nothing matches or the LoRA value is
            malformed.
        """
        if key.startswith('lora_combined_'):
            parts = value.split(' : ')
            if len(parts) not in (2, 3):
                return None
            name = parts[-2]
            for filename in filenames:
                if Path(filename).stem == name:
                    parts[-2] = filename
                    return ' : '.join(parts)
            return None

        for filename in filenames:
            if Path(filename).stem == value:
                return filename
        return None
def get_metadata_parser(metadata_scheme: MetadataScheme) -> MetadataParser:
    """Return a fresh parser instance for ``metadata_scheme``.

    Raises:
        NotImplementedError: if no parser exists for the given scheme.
    """
    if metadata_scheme == MetadataScheme.FOOOCUS:
        return FooocusMetadataParser()
    if metadata_scheme == MetadataScheme.A1111:
        return A1111MetadataParser()
    raise NotImplementedError
def read_info_from_image(filepath) -> tuple[dict | str | None, MetadataScheme | None]:
    """Read generation metadata embedded in an image file.

    The ``'parameters'`` entry of ``image.info`` is checked first, with the scheme
    taken from ``'fooocus_scheme'``. If it is absent or not JSON and the image has
    EXIF data, the metadata is read from EXIF UserComment (0x9286) and the scheme
    from MakerNote (0x927C), the tags written by get_exif(). JSON text is decoded
    into a dict.

    Returns:
        ``(parameters, scheme)``. A dict result is reported as FOOOCUS and a str as
        A1111, overriding the stored scheme. Otherwise the stored scheme is used if
        it is a valid ``MetadataScheme``, else None.
    """
    # All image access, including getexif(), happens while the file is still open.
    with Image.open(filepath) as image:
        items = (image.info or {}).copy()
        parameters = items.pop('parameters', None)
        metadata_scheme = items.pop('fooocus_scheme', None)
        exif = items.pop('exif', None)

        if parameters is not None and is_json(parameters):
            parameters = json.loads(parameters)
        elif exif is not None:
            exif = image.getexif()
            # 0x9286 = UserComment
            parameters = exif.get(0x9286, None)
            # 0x927C = MakerNote
            metadata_scheme = exif.get(0x927C, None)

            # Same None guard as above: UserComment may be missing.
            if parameters is not None and is_json(parameters):
                parameters = json.loads(parameters)

    try:
        metadata_scheme = MetadataScheme(metadata_scheme)
    except ValueError:
        metadata_scheme = None

    # broad fallback: the type of the parameters decides the scheme
    if isinstance(parameters, dict):
        metadata_scheme = MetadataScheme.FOOOCUS

    if isinstance(parameters, str):
        metadata_scheme = MetadataScheme.A1111

    return parameters, metadata_scheme
def get_exif(metadata: str | None, metadata_scheme: str):
    """Build an ``Image.Exif`` block carrying Fooocus generation metadata.

    Tag numbers follow
    https://github.com/python-pillow/Pillow/blob/9.2.x/src/PIL/ExifTags.py
    """
    exif = Image.Exif()
    tag_values = (
        (0x9286, metadata),  # UserComment: the serialised metadata
        (0x0131, 'Fooocus v' + fooocus_version.version),  # Software
        (0x927C, metadata_scheme),  # MakerNote: the metadata scheme
    )
    for tag, value in tag_values:
        exif[tag] = value
    return exif