from __future__ import annotations

import inspect
import logging
import os
import secrets
from base64 import b64decode, b64encode
from io import BytesIO
from itertools import cycle
from math import ceil

import modules
import yaml
from modules import shared
from PIL import Image
from pydantic import BaseModel

from .config import CONFIG_PATH, ENCRYPT_FILE, LOGGER_NAME, MainConfig

log = logging.getLogger(LOGGER_NAME)


def load_config():
    """Load the default config, including options not exposed in the API yet.

    The config is read from `CONFIG_PATH` (relative paths resolve against the
    current working directory). If that file does not exist yet, it is created
    from the defaults of `MainConfig` in `config.py`. An empty config file is
    treated as "no overrides", i.e. all defaults.

    Returns:
        MainConfig: config
    """
    if not os.path.isfile(CONFIG_PATH):
        cfg = MainConfig()
        with open(CONFIG_PATH, "w", encoding="utf-8") as f:
            yaml.safe_dump(cfg.dict(), f)
    with open(CONFIG_PATH, encoding="utf-8") as file:
        obj = yaml.safe_load(file)
    # yaml.safe_load returns None for an empty document; parse_obj(None) would
    # fail, so treat it as an empty mapping (all defaults).
    return MainConfig.parse_obj(obj if obj is not None else {})


def merge_default_config(config: BaseModel, default: BaseModel):
    """Fill unset and None fields of `config` in place from `default`.

    A field of `config` is replaced when either:
    - it was never set, i.e. it is absent from pydantic's `__fields_set__`
      and only holds its model default; or
    - its current value is None, including when it was explicitly set to None.

    The replacement is the attribute with the same name on `default`, or None
    if `default` has no such attribute.

    Args:
        config (BaseModel): Config object (modified in place).
        default (BaseModel): Default to merge from.

    Returns:
        BaseModel: The same `config` object, after modification.
    """
    for field in config.__fields__:
        # Check the field's *value* for None; `field` itself is the name (str).
        if field not in config.__fields_set__ or getattr(config, field) is None:
            setattr(config, field, getattr(default, field, None))
    return config


def prepare_backend(opt: BaseModel):
    """Misc configuration and preparation tasks before calling the internal API.

    Currently includes:
    - Set the global face restorer model (and CodeFormer weight)
    - Set the global SD model to the selected one and reload its weights
    - Set the global img2img upscaler to the selected one
    - Set other misc global webUI/backend settings (color correction, exact
      steps, NSFW filter, inpainting mask weight)
    - Ensure the sample/output folder (`sample_path`) exists

    Each step only runs if `opt` has the corresponding attribute.

    Args:
        opt (BaseModel): Option/Request object
    """
    # the `shared` module handles app state for the underlying codebase
    if hasattr(opt, "face_restorer"):
        shared.opts.face_restoration_model = opt.face_restorer
        shared.opts.code_former_weight = opt.codeformer_weight
    if hasattr(opt, "sd_model"):
        shared.opts.sd_model_checkpoint = opt.sd_model
        modules.sd_models.reload_model_weights(shared.sd_model)
    if hasattr(opt, "upscaler_name"):
        shared.opts.upscaler_for_img2img = opt.upscaler_name
    if hasattr(opt, "color_correct"):
        shared.opts.img2img_color_correction = opt.color_correct
        shared.opts.img2img_fix_steps = opt.do_exact_steps
    if hasattr(opt, "filter_nsfw"):
        shared.opts.filter_nsfw = opt.filter_nsfw
    if hasattr(opt, "inpaint_mask_weight"):
        shared.opts.inpainting_mask_weight = opt.inpaint_mask_weight

    # Ensure the output folder exists
    if hasattr(opt, "sample_path"):
        os.makedirs(opt.sample_path, exist_ok=True)


def optional(*fields):
    """Decorator that makes fields of a pydantic model optional.

    Use it bare (`@optional`) to make every field optional, or pass the field
    names that should be made optional (`@optional("a", "b")`). It relies on
    pydantic v1's `ModelField.required` attribute.

    Taken from https://github.com/samuelcolvin/pydantic/issues/1223#issuecomment-775363074
    """

    def dec(_cls):
        # `fields` is looked up from the enclosing scope at call time, so the
        # reassignment in the bare-decorator branch below is seen here.
        for field in fields:
            _cls.__fields__[field].required = False
        return _cls

    if fields and inspect.isclass(fields[0]) and issubclass(fields[0], BaseModel):
        # Bare `@optional` usage: the decorated class arrives as the only arg.
        cls = fields[0]
        fields = cls.__fields__
        return dec(cls)
    return dec


def save_img(image: Image.Image, sample_path: str, filename: str):
    """Saves an image.

    Args:
        image (Image): Image to save.
        sample_path (str): Folder to save the image in.
        filename (str): Name to save the image as (its extension selects the
            format used by PIL).

    Returns:
        str: Absolute path where the image was saved.
    """
    path = os.path.join(sample_path, filename)
    image.save(path)
    return os.path.abspath(path)


def img_to_b64(image: Image.Image):
    """Convert an image to a base64-encoded PNG string.

    Args:
        image (Image): Image to encode.

    Returns:
        str: Base64-encoded PNG image.
    """
    buf = BytesIO()
    image.save(buf, format="png")
    return b64encode(buf.getvalue()).decode("utf-8")


def b64_to_img(enc: str):
    """Convert a base64-encoded string to an image.

    Args:
        enc (str): Base64-encoded image.

    Returns:
        Image: Image (opened lazily by PIL from an in-memory buffer).
    """
    return Image.open(BytesIO(b64decode(enc)))


def sddebz_highres_fix(
    base_size: int, max_size: int, orig_width: int, orig_height: int
):
    """Calculate an appropriate render resolution for the model.

    The resolution is based on the model's base input size and the maximum
    input size allowed. The max input size exists because Stable Diffusion
    handles resolutions larger than its base/native input size of 512 poorly.
    This can cause weird issues such as duplicated features in the image.
    Hence, it is typically better to render at a smaller appropriate resolution
    before using other methods to upscale to the original resolution. Setting
    max_size to 512, matching the base_size, imitates how the highres fix works.

    Stable Diffusion also messes up for resolutions smaller than 512. In that
    case, it is better to render at the base resolution before downscaling to
    the original.

    This method requires less user input than the builtin highres fix, which
    uses firstphase_width and firstphase_height.

    The original plugin writer, @sddebz, wrote this. I modified it to `ceil`
    instead of `round` to make selected region resizing easier in the plugin,
    and to avoid rounding to 0.

    Args:
        base_size (int): Native/base input size of the model.
        max_size (int): Max input size to accept.
        orig_width (int): Original width requested.
        orig_height (int): Original height requested (must be non-zero).

    Returns:
        Tuple[int, int]: Appropriate (width, height) to use for the model.
    """

    def rnd(r, x, z=64):
        """Scale dimension x by r, rounded up to a multiple of stride z."""
        return z * ceil(r * x / z)

    ratio = orig_width / orig_height

    # height is smaller dimension
    if orig_width > orig_height:
        width, height = rnd(ratio, base_size), base_size
        if width > max_size:
            width, height = max_size, rnd(1 / ratio, max_size)
    # width is smaller dimension (or the image is square)
    else:
        width, height = base_size, rnd(1 / ratio, base_size)
        if height > max_size:
            width, height = rnd(ratio, max_size), max_size

    new_ratio = width / height

    log.info(
        f"img size: {orig_width}x{orig_height} -> {width}x{height}, "
        f"aspect ratio: {ratio:.2f} -> {new_ratio:.2f}, {100 * (new_ratio - ratio) / ratio :.2f}% change"
    )
    return width, height


def parse_prompt(val):
    """Parse different representations of prompt/negative prompt.

    Accepted forms:
    - None -> ""
    - str -> returned as-is
    - list -> items joined with ", "
    - dict -> `{item: weight}` rendered as space-separated terms, where a
      None weight yields `item` and any other weight yields `(item:weight)`

    Args:
        val (Any): Prompt to parse.

    Raises:
        SyntaxError: Value of the prompt key cannot be parsed.

    Returns:
        str: Correctly formatted prompt.
    """
    if val is None:
        return ""

    # Below cases are meant for prompts read from the yaml config
    if isinstance(val, str):
        return val
    if isinstance(val, list):
        return ", ".join(val)
    if isinstance(val, dict):
        prompt = ""
        for item, weight in val.items():
            if prompt != "":
                prompt += " "
            if weight is None:
                prompt += f"{item}"
            else:
                prompt += f"({item}:{weight})"
        return prompt
    raise SyntaxError(f"prompt field in {CONFIG_PATH} is invalid")


def get_sampler_index(sampler_name: str):
    """Get index of sampler by name or alias.

    Args:
        sampler_name (str): Exact name (or one of the aliases) of the sampler.

    Raises:
        KeyError: Sampler cannot be found.

    Returns:
        int: Index of sampler in `modules.sd_samplers.samplers`.
    """
    for index, sampler in enumerate(modules.sd_samplers.samplers):
        if sampler_name == sampler.name or sampler_name in sampler.aliases:
            return index
    raise KeyError(f"sampler not found: {sampler_name}")


def get_upscaler_index(upscaler_name: str):
    """Get index of upscaler by name.

    Args:
        upscaler_name (str): Exact name of upscaler.

    Raises:
        KeyError: Upscaler cannot be found.

    Returns:
        int: Index of upscaler in `shared.sd_upscalers`.
    """
    for index, upscaler in enumerate(shared.sd_upscalers):
        if upscaler.name == upscaler_name:
            return index
    raise KeyError(f"upscaler not found: {upscaler_name}")


def prepare_mask(mask: Image.Image):
    """Prepare mask for usage.

    Args:
        mask (Image): Mask image; must have an alpha ("A") band.

    Raises:
        ValueError: `mask` has no alpha channel (raised by PIL).

    Returns:
        Image: The mask's alpha channel as a single-band image.
    """
    return mask.getchannel("A")


def bytewise_xor(msg: bytes, key: bytes):
    """XOR each byte of `msg` with `key`, repeating `key` as needed.

    Used for decrypting/encrypting request/response bodies; applying it twice
    with the same key restores the original message. An empty key leaves the
    message unchanged.
    """
    if not key:
        # cycle(b"") yields nothing, so zip() would truncate msg to b"".
        return bytes(msg)
    return bytes(v ^ k for v, k in zip(msg, cycle(key)))


def get_encrypt_key():
    """Read the encryption key from `ENCRYPT_FILE`.

    If the file does not exist, a random key (32 hex characters) is generated,
    written to `ENCRYPT_FILE`, and returned. If the file exists but cannot be
    read or decoded as UTF-8, a warning is logged and None is returned.

    Raises:
        OSError: The key file is missing and a new one cannot be written.

    Returns:
        bytes | None: UTF-8 encoded key with surrounding whitespace stripped,
        or None if an existing key file could not be read.
    """
    try:
        with open(ENCRYPT_FILE, encoding="utf-8") as f:
            return f.read().strip().encode("utf-8")
    except FileNotFoundError:
        pass  # handled below by creating a new key file
    except (OSError, UnicodeDecodeError) as e:
        # File exists but is unreadable (e.g. a directory, bad permissions,
        # or not valid UTF-8): best-effort, so report and carry on without a key.
        log.warning(
            f"Could not read encryption key from {os.path.abspath(ENCRYPT_FILE)}: {e}"
        )
        return None

    log.warning(
        f"Encryption key file doesn't exist at {os.path.abspath(ENCRYPT_FILE)}."
    )
    log.warning("Creating random encryption key.")
    key = secrets.token_hex(16)
    with open(ENCRYPT_FILE, "w", encoding="utf-8") as f:
        f.write(key)
    log.warning(
        f"Key in {ENCRYPT_FILE} is completely optional. It can be used to encrypt messages between backend & Krita and is editable."
    )
    # token_hex output is ASCII with no surrounding whitespace, so this equals
    # what re-reading the file would return.
    return key.encode("utf-8")