Spaces:
Runtime error
Runtime error
import json | |
import os | |
import re | |
from abc import ABC, abstractmethod | |
from pathlib import Path | |
import gradio as gr | |
from PIL import Image | |
import fooocus_version | |
import modules.config | |
import modules.sdxl_styles | |
from modules.flags import MetadataScheme, Performance, Steps | |
from modules.flags import SAMPLERS, CIVITAI_NO_KARRAS | |
from modules.util import quote, unquote, extract_styles_from_prompt, is_json, get_file_from_folder_list, calculate_sha256 | |
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)' | |
re_param = re.compile(re_param_code) | |
re_imagesize = re.compile(r"^(\d+)x(\d+)$") | |
hash_cache = {} | |
def load_parameter_button_click(raw_metadata: dict | str, is_generating: bool): | |
loaded_parameter_dict = raw_metadata | |
if isinstance(raw_metadata, str): | |
loaded_parameter_dict = json.loads(raw_metadata) | |
assert isinstance(loaded_parameter_dict, dict) | |
results = [len(loaded_parameter_dict) > 0, 1] | |
get_str('prompt', 'Prompt', loaded_parameter_dict, results) | |
get_str('negative_prompt', 'Negative Prompt', loaded_parameter_dict, results) | |
get_list('styles', 'Styles', loaded_parameter_dict, results) | |
get_str('performance', 'Performance', loaded_parameter_dict, results) | |
get_steps('steps', 'Steps', loaded_parameter_dict, results) | |
get_float('overwrite_switch', 'Overwrite Switch', loaded_parameter_dict, results) | |
get_resolution('resolution', 'Resolution', loaded_parameter_dict, results) | |
get_float('guidance_scale', 'Guidance Scale', loaded_parameter_dict, results) | |
get_float('sharpness', 'Sharpness', loaded_parameter_dict, results) | |
get_adm_guidance('adm_guidance', 'ADM Guidance', loaded_parameter_dict, results) | |
get_str('refiner_swap_method', 'Refiner Swap Method', loaded_parameter_dict, results) | |
get_float('adaptive_cfg', 'CFG Mimicking from TSNR', loaded_parameter_dict, results) | |
get_str('base_model', 'Base Model', loaded_parameter_dict, results) | |
get_str('refiner_model', 'Refiner Model', loaded_parameter_dict, results) | |
get_float('refiner_switch', 'Refiner Switch', loaded_parameter_dict, results) | |
get_str('sampler', 'Sampler', loaded_parameter_dict, results) | |
get_str('scheduler', 'Scheduler', loaded_parameter_dict, results) | |
get_seed('seed', 'Seed', loaded_parameter_dict, results) | |
if is_generating: | |
results.append(gr.update()) | |
else: | |
results.append(gr.update(visible=True)) | |
results.append(gr.update(visible=False)) | |
get_freeu('freeu', 'FreeU', loaded_parameter_dict, results) | |
for i in range(modules.config.default_max_lora_number): | |
get_lora(f'lora_combined_{i + 1}', f'LoRA {i + 1}', loaded_parameter_dict, results) | |
return results | |
def get_str(key: str, fallback: str | None, source_dict: dict, results: list, default=None): | |
try: | |
h = source_dict.get(key, source_dict.get(fallback, default)) | |
assert isinstance(h, str) | |
results.append(h) | |
except: | |
results.append(gr.update()) | |
def get_list(key: str, fallback: str | None, source_dict: dict, results: list, default=None): | |
try: | |
h = source_dict.get(key, source_dict.get(fallback, default)) | |
h = eval(h) | |
assert isinstance(h, list) | |
results.append(h) | |
except: | |
results.append(gr.update()) | |
def get_float(key: str, fallback: str | None, source_dict: dict, results: list, default=None): | |
try: | |
h = source_dict.get(key, source_dict.get(fallback, default)) | |
assert h is not None | |
h = float(h) | |
results.append(h) | |
except: | |
results.append(gr.update()) | |
def get_steps(key: str, fallback: str | None, source_dict: dict, results: list, default=None): | |
try: | |
h = source_dict.get(key, source_dict.get(fallback, default)) | |
assert h is not None | |
h = int(h) | |
# if not in steps or in steps and performance is not the same | |
if h not in iter(Steps) or Steps(h).name.casefold() != source_dict.get('performance', '').replace(' ', '_').casefold(): | |
results.append(h) | |
return | |
results.append(-1) | |
except: | |
results.append(-1) | |
def get_resolution(key: str, fallback: str | None, source_dict: dict, results: list, default=None): | |
try: | |
h = source_dict.get(key, source_dict.get(fallback, default)) | |
width, height = eval(h) | |
formatted = modules.config.add_ratio(f'{width}*{height}') | |
if formatted in modules.config.available_aspect_ratios: | |
results.append(formatted) | |
results.append(-1) | |
results.append(-1) | |
else: | |
results.append(gr.update()) | |
results.append(int(width)) | |
results.append(int(height)) | |
except: | |
results.append(gr.update()) | |
results.append(gr.update()) | |
results.append(gr.update()) | |
def get_seed(key: str, fallback: str | None, source_dict: dict, results: list, default=None): | |
try: | |
h = source_dict.get(key, source_dict.get(fallback, default)) | |
assert h is not None | |
h = int(h) | |
results.append(False) | |
results.append(h) | |
except: | |
results.append(gr.update()) | |
results.append(gr.update()) | |
def get_adm_guidance(key: str, fallback: str | None, source_dict: dict, results: list, default=None): | |
try: | |
h = source_dict.get(key, source_dict.get(fallback, default)) | |
p, n, e = eval(h) | |
results.append(float(p)) | |
results.append(float(n)) | |
results.append(float(e)) | |
except: | |
results.append(gr.update()) | |
results.append(gr.update()) | |
results.append(gr.update()) | |
def get_freeu(key: str, fallback: str | None, source_dict: dict, results: list, default=None): | |
try: | |
h = source_dict.get(key, source_dict.get(fallback, default)) | |
b1, b2, s1, s2 = eval(h) | |
results.append(True) | |
results.append(float(b1)) | |
results.append(float(b2)) | |
results.append(float(s1)) | |
results.append(float(s2)) | |
except: | |
results.append(False) | |
results.append(gr.update()) | |
results.append(gr.update()) | |
results.append(gr.update()) | |
results.append(gr.update()) | |
def get_lora(key: str, fallback: str | None, source_dict: dict, results: list): | |
try: | |
n, w = source_dict.get(key, source_dict.get(fallback)).split(' : ') | |
w = float(w) | |
results.append(True) | |
results.append(n) | |
results.append(w) | |
except: | |
results.append(True) | |
results.append('None') | |
results.append(1) | |
def get_sha256(filepath): | |
global hash_cache | |
if filepath not in hash_cache: | |
hash_cache[filepath] = calculate_sha256(filepath) | |
return hash_cache[filepath] | |
def parse_meta_from_preset(preset_content): | |
assert isinstance(preset_content, dict) | |
preset_prepared = {} | |
items = preset_content | |
for settings_key, meta_key in modules.config.possible_preset_keys.items(): | |
if settings_key == "default_loras": | |
loras = getattr(modules.config, settings_key) | |
if settings_key in items: | |
loras = items[settings_key] | |
for index, lora in enumerate(loras[:5]): | |
preset_prepared[f'lora_combined_{index + 1}'] = ' : '.join(map(str, lora)) | |
elif settings_key == "default_aspect_ratio": | |
if settings_key in items and items[settings_key] is not None: | |
default_aspect_ratio = items[settings_key] | |
width, height = default_aspect_ratio.split('*') | |
else: | |
default_aspect_ratio = getattr(modules.config, settings_key) | |
width, height = default_aspect_ratio.split('×') | |
height = height[:height.index(" ")] | |
preset_prepared[meta_key] = (width, height) | |
else: | |
preset_prepared[meta_key] = items[settings_key] if settings_key in items and items[ | |
settings_key] is not None else getattr(modules.config, settings_key) | |
if settings_key == "default_styles" or settings_key == "default_aspect_ratio": | |
preset_prepared[meta_key] = str(preset_prepared[meta_key]) | |
return preset_prepared | |
class MetadataParser(ABC): | |
def __init__(self): | |
self.raw_prompt: str = '' | |
self.full_prompt: str = '' | |
self.raw_negative_prompt: str = '' | |
self.full_negative_prompt: str = '' | |
self.steps: int = 30 | |
self.base_model_name: str = '' | |
self.base_model_hash: str = '' | |
self.refiner_model_name: str = '' | |
self.refiner_model_hash: str = '' | |
self.loras: list = [] | |
def get_scheme(self) -> MetadataScheme: | |
raise NotImplementedError | |
def parse_json(self, metadata: dict | str) -> dict: | |
raise NotImplementedError | |
def parse_string(self, metadata: dict) -> str: | |
raise NotImplementedError | |
def set_data(self, raw_prompt, full_prompt, raw_negative_prompt, full_negative_prompt, steps, base_model_name, | |
refiner_model_name, loras): | |
self.raw_prompt = raw_prompt | |
self.full_prompt = full_prompt | |
self.raw_negative_prompt = raw_negative_prompt | |
self.full_negative_prompt = full_negative_prompt | |
self.steps = steps | |
self.base_model_name = Path(base_model_name).stem | |
base_model_path = get_file_from_folder_list(base_model_name, modules.config.paths_checkpoints) | |
self.base_model_hash = get_sha256(base_model_path) | |
if refiner_model_name not in ['', 'None']: | |
self.refiner_model_name = Path(refiner_model_name).stem | |
refiner_model_path = get_file_from_folder_list(refiner_model_name, modules.config.paths_checkpoints) | |
self.refiner_model_hash = get_sha256(refiner_model_path) | |
self.loras = [] | |
for (lora_name, lora_weight) in loras: | |
if lora_name != 'None': | |
lora_path = get_file_from_folder_list(lora_name, modules.config.paths_loras) | |
lora_hash = get_sha256(lora_path) | |
self.loras.append((Path(lora_name).stem, lora_weight, lora_hash)) | |
class A1111MetadataParser(MetadataParser): | |
def get_scheme(self) -> MetadataScheme: | |
return MetadataScheme.A1111 | |
fooocus_to_a1111 = { | |
'raw_prompt': 'Raw prompt', | |
'raw_negative_prompt': 'Raw negative prompt', | |
'negative_prompt': 'Negative prompt', | |
'styles': 'Styles', | |
'performance': 'Performance', | |
'steps': 'Steps', | |
'sampler': 'Sampler', | |
'scheduler': 'Scheduler', | |
'guidance_scale': 'CFG scale', | |
'seed': 'Seed', | |
'resolution': 'Size', | |
'sharpness': 'Sharpness', | |
'adm_guidance': 'ADM Guidance', | |
'refiner_swap_method': 'Refiner Swap Method', | |
'adaptive_cfg': 'Adaptive CFG', | |
'overwrite_switch': 'Overwrite Switch', | |
'freeu': 'FreeU', | |
'base_model': 'Model', | |
'base_model_hash': 'Model hash', | |
'refiner_model': 'Refiner', | |
'refiner_model_hash': 'Refiner hash', | |
'lora_hashes': 'Lora hashes', | |
'lora_weights': 'Lora weights', | |
'created_by': 'User', | |
'version': 'Version' | |
} | |
def parse_json(self, metadata: str) -> dict: | |
metadata_prompt = '' | |
metadata_negative_prompt = '' | |
done_with_prompt = False | |
*lines, lastline = metadata.strip().split("\n") | |
if len(re_param.findall(lastline)) < 3: | |
lines.append(lastline) | |
lastline = '' | |
for line in lines: | |
line = line.strip() | |
if line.startswith(f"{self.fooocus_to_a1111['negative_prompt']}:"): | |
done_with_prompt = True | |
line = line[len(f"{self.fooocus_to_a1111['negative_prompt']}:"):].strip() | |
if done_with_prompt: | |
metadata_negative_prompt += ('' if metadata_negative_prompt == '' else "\n") + line | |
else: | |
metadata_prompt += ('' if metadata_prompt == '' else "\n") + line | |
found_styles, prompt, negative_prompt = extract_styles_from_prompt(metadata_prompt, metadata_negative_prompt) | |
data = { | |
'prompt': prompt, | |
'negative_prompt': negative_prompt | |
} | |
for k, v in re_param.findall(lastline): | |
try: | |
if v != '' and v[0] == '"' and v[-1] == '"': | |
v = unquote(v) | |
m = re_imagesize.match(v) | |
if m is not None: | |
data['resolution'] = str((m.group(1), m.group(2))) | |
else: | |
data[list(self.fooocus_to_a1111.keys())[list(self.fooocus_to_a1111.values()).index(k)]] = v | |
except Exception: | |
print(f"Error parsing \"{k}: {v}\"") | |
# workaround for multiline prompts | |
if 'raw_prompt' in data: | |
data['prompt'] = data['raw_prompt'] | |
raw_prompt = data['raw_prompt'].replace("\n", ', ') | |
if metadata_prompt != raw_prompt and modules.sdxl_styles.fooocus_expansion not in found_styles: | |
found_styles.append(modules.sdxl_styles.fooocus_expansion) | |
if 'raw_negative_prompt' in data: | |
data['negative_prompt'] = data['raw_negative_prompt'] | |
data['styles'] = str(found_styles) | |
# try to load performance based on steps, fallback for direct A1111 imports | |
if 'steps' in data and 'performance' not in data: | |
try: | |
data['performance'] = Performance[Steps(int(data['steps'])).name].value | |
except ValueError | KeyError: | |
pass | |
if 'sampler' in data: | |
data['sampler'] = data['sampler'].replace(' Karras', '') | |
# get key | |
for k, v in SAMPLERS.items(): | |
if v == data['sampler']: | |
data['sampler'] = k | |
break | |
for key in ['base_model', 'refiner_model']: | |
if key in data: | |
for filename in modules.config.model_filenames: | |
path = Path(filename) | |
if data[key] == path.stem: | |
data[key] = filename | |
break | |
if 'lora_hashes' in data: | |
lora_filenames = modules.config.lora_filenames.copy() | |
if modules.config.sdxl_lcm_lora in lora_filenames: | |
lora_filenames.remove(modules.config.sdxl_lcm_lora) | |
for li, lora in enumerate(data['lora_hashes'].split(', ')): | |
lora_name, lora_hash, lora_weight = lora.split(': ') | |
for filename in lora_filenames: | |
path = Path(filename) | |
if lora_name == path.stem: | |
data[f'lora_combined_{li + 1}'] = f'{filename} : {lora_weight}' | |
break | |
return data | |
def parse_string(self, metadata: dict) -> str: | |
data = {k: v for _, k, v in metadata} | |
width, height = eval(data['resolution']) | |
sampler = data['sampler'] | |
scheduler = data['scheduler'] | |
if sampler in SAMPLERS and SAMPLERS[sampler] != '': | |
sampler = SAMPLERS[sampler] | |
if sampler not in CIVITAI_NO_KARRAS and scheduler == 'karras': | |
sampler += f' Karras' | |
generation_params = { | |
self.fooocus_to_a1111['steps']: self.steps, | |
self.fooocus_to_a1111['sampler']: sampler, | |
self.fooocus_to_a1111['seed']: data['seed'], | |
self.fooocus_to_a1111['resolution']: f'{width}x{height}', | |
self.fooocus_to_a1111['guidance_scale']: data['guidance_scale'], | |
self.fooocus_to_a1111['sharpness']: data['sharpness'], | |
self.fooocus_to_a1111['adm_guidance']: data['adm_guidance'], | |
self.fooocus_to_a1111['base_model']: Path(data['base_model']).stem, | |
self.fooocus_to_a1111['base_model_hash']: self.base_model_hash, | |
self.fooocus_to_a1111['performance']: data['performance'], | |
self.fooocus_to_a1111['scheduler']: scheduler, | |
# workaround for multiline prompts | |
self.fooocus_to_a1111['raw_prompt']: self.raw_prompt, | |
self.fooocus_to_a1111['raw_negative_prompt']: self.raw_negative_prompt, | |
} | |
if self.refiner_model_name not in ['', 'None']: | |
generation_params |= { | |
self.fooocus_to_a1111['refiner_model']: self.refiner_model_name, | |
self.fooocus_to_a1111['refiner_model_hash']: self.refiner_model_hash | |
} | |
for key in ['adaptive_cfg', 'overwrite_switch', 'refiner_swap_method', 'freeu']: | |
if key in data: | |
generation_params[self.fooocus_to_a1111[key]] = data[key] | |
lora_hashes = [] | |
for index, (lora_name, lora_weight, lora_hash) in enumerate(self.loras): | |
# workaround for Fooocus not knowing LoRA name in LoRA metadata | |
lora_hashes.append(f'{lora_name}: {lora_hash}: {lora_weight}') | |
lora_hashes_string = ', '.join(lora_hashes) | |
generation_params |= { | |
self.fooocus_to_a1111['lora_hashes']: lora_hashes_string, | |
self.fooocus_to_a1111['version']: data['version'] | |
} | |
if modules.config.metadata_created_by != '': | |
generation_params[self.fooocus_to_a1111['created_by']] = modules.config.metadata_created_by | |
generation_params_text = ", ".join( | |
[k if k == v else f'{k}: {quote(v)}' for k, v in generation_params.items() if | |
v is not None]) | |
positive_prompt_resolved = ', '.join(self.full_prompt) | |
negative_prompt_resolved = ', '.join(self.full_negative_prompt) | |
negative_prompt_text = f"\nNegative prompt: {negative_prompt_resolved}" if negative_prompt_resolved else "" | |
return f"{positive_prompt_resolved}{negative_prompt_text}\n{generation_params_text}".strip() | |
class FooocusMetadataParser(MetadataParser): | |
def get_scheme(self) -> MetadataScheme: | |
return MetadataScheme.FOOOCUS | |
def parse_json(self, metadata: dict) -> dict: | |
model_filenames = modules.config.model_filenames.copy() | |
lora_filenames = modules.config.lora_filenames.copy() | |
if modules.config.sdxl_lcm_lora in lora_filenames: | |
lora_filenames.remove(modules.config.sdxl_lcm_lora) | |
for key, value in metadata.items(): | |
if value in ['', 'None']: | |
continue | |
if key in ['base_model', 'refiner_model']: | |
metadata[key] = self.replace_value_with_filename(key, value, model_filenames) | |
elif key.startswith('lora_combined_'): | |
metadata[key] = self.replace_value_with_filename(key, value, lora_filenames) | |
else: | |
continue | |
return metadata | |
def parse_string(self, metadata: list) -> str: | |
for li, (label, key, value) in enumerate(metadata): | |
# remove model folder paths from metadata | |
if key.startswith('lora_combined_'): | |
name, weight = value.split(' : ') | |
name = Path(name).stem | |
value = f'{name} : {weight}' | |
metadata[li] = (label, key, value) | |
res = {k: v for _, k, v in metadata} | |
res['full_prompt'] = self.full_prompt | |
res['full_negative_prompt'] = self.full_negative_prompt | |
res['steps'] = self.steps | |
res['base_model'] = self.base_model_name | |
res['base_model_hash'] = self.base_model_hash | |
if self.refiner_model_name not in ['', 'None']: | |
res['refiner_model'] = self.refiner_model_name | |
res['refiner_model_hash'] = self.refiner_model_hash | |
res['loras'] = self.loras | |
if modules.config.metadata_created_by != '': | |
res['created_by'] = modules.config.metadata_created_by | |
return json.dumps(dict(sorted(res.items()))) | |
def replace_value_with_filename(key, value, filenames): | |
for filename in filenames: | |
path = Path(filename) | |
if key.startswith('lora_combined_'): | |
name, weight = value.split(' : ') | |
if name == path.stem: | |
return f'{filename} : {weight}' | |
elif value == path.stem: | |
return filename | |
def get_metadata_parser(metadata_scheme: MetadataScheme) -> MetadataParser: | |
match metadata_scheme: | |
case MetadataScheme.FOOOCUS: | |
return FooocusMetadataParser() | |
case MetadataScheme.A1111: | |
return A1111MetadataParser() | |
case _: | |
raise NotImplementedError | |
def read_info_from_image(filepath) -> tuple[str | None, MetadataScheme | None]: | |
with Image.open(filepath) as image: | |
items = (image.info or {}).copy() | |
parameters = items.pop('parameters', None) | |
metadata_scheme = items.pop('fooocus_scheme', None) | |
exif = items.pop('exif', None) | |
if parameters is not None and is_json(parameters): | |
parameters = json.loads(parameters) | |
elif exif is not None: | |
exif = image.getexif() | |
# 0x9286 = UserComment | |
parameters = exif.get(0x9286, None) | |
# 0x927C = MakerNote | |
metadata_scheme = exif.get(0x927C, None) | |
if is_json(parameters): | |
parameters = json.loads(parameters) | |
try: | |
metadata_scheme = MetadataScheme(metadata_scheme) | |
except ValueError: | |
metadata_scheme = None | |
# broad fallback | |
if isinstance(parameters, dict): | |
metadata_scheme = MetadataScheme.FOOOCUS | |
if isinstance(parameters, str): | |
metadata_scheme = MetadataScheme.A1111 | |
return parameters, metadata_scheme | |
def get_exif(metadata: str | None, metadata_scheme: str): | |
exif = Image.Exif() | |
# tags see see https://github.com/python-pillow/Pillow/blob/9.2.x/src/PIL/ExifTags.py | |
# 0x9286 = UserComment | |
exif[0x9286] = metadata | |
# 0x0131 = Software | |
exif[0x0131] = 'Fooocus v' + fooocus_version.version | |
# 0x927C = MakerNote | |
exif[0x927C] = metadata_scheme | |
return exif |