from __future__ import annotations

import logging
import os
import time

import modules
from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse
from modules import shared
from modules.call_queue import wrap_gradio_gpu_call
from PIL import Image, ImageOps
from starlette.concurrency import iterate_in_threadpool

from .config import LOGGER_NAME, NAME_SCRIPT_LOOPBACK, NAME_SCRIPT_UPSCALE
from .script_hack import get_script_info, get_scripts_metadata, process_script_args
from .structs import (
    ConfigResponse,
    ImageResponse,
    Img2ImgRequest,
    Txt2ImgRequest,
    UpscaleRequest,
    UpscaleResponse,
)
from .utils import (
    b64_to_img,
    bytewise_xor,
    get_encrypt_key,
    get_sampler_index,
    get_upscaler_index,
    img_to_b64,
    load_config,
    merge_default_config,
    parse_prompt,
    prepare_backend,
    prepare_mask,
    save_img,
    sddebz_highres_fix,
)

router = APIRouter()

log = logging.getLogger(LOGGER_NAME)

# NOTE: how to run a script
# - get scripts_txt2img/scripts_img2img from modules.scripts
# - construct array args, where 0th element is selected script
# - refer to script.args_from & script.args_to to figure out which elements in
#   array args to populate
#
# The way scripts are handled is they are loaded one by one, appended to a list of
# scripts, with each script taking up "slots" in the input args array.
# So the more scripts, the longer array args would be for the last script.

# NOTE: where to draw the line on what is done by the backend vs the frontend?

# TODO: Create separate Outpainting route, add img2img structs to Upscale route
# - yes I know it's highly inconsistent what should be a route or not, but to prevent
#   incredibly hacky workarounds on the frontend for script calling, it should be
#   done by the backend, which has better access to the script information.
# - Upscale tab UI:
#   - Upscaler dropdown + 0.5x downscale checkbox + SD upscale checkbox
#   - SD upscale checkbox hides 0.5x downscale checkbox, renames upscaler dropdown
#   - to prescaler, and shows modified img2img UI (ofc uses its own cfg namespace)
# - Outpaint tab UI:
#   - modified img2img UI with own cfg namespace
#   - try and hijack more control (Pixel to expand per direction instead of all directions)
#   - self-sketch mode: basically sketch + inpaint but the inpaint mask is auto-calculated
#   - option to select poor man, mk 2 or self-sketch

# TODO: Consider using pipeline directly instead of Gradio API for less surprises & better control


@router.get("/config", response_model=ConfigResponse)
async def get_state():
    """Get information about backend API.

    Returns config from `krita_config.yaml`, other metadata,
    the path to the rendered image and image mask, etc.

    Returns:
        Dict: the plugin config fields plus the absolute sample path and the
            names of the available upscalers, samplers, scripts, face
            restorers and checkpoints.
    """
    opt = load_config().plugin
    prepare_backend(opt)

    # Report an absolute path regardless of how it was written in the config.
    sample_path = os.path.abspath(opt.sample_path)
    return {
        **opt.dict(),
        "sample_path": sample_path,
        "upscalers": [upscaler.name for upscaler in shared.sd_upscalers],
        "samplers": [sampler.name for sampler in modules.sd_samplers.samplers],
        "samplers_img2img": [
            sampler.name for sampler in modules.sd_samplers.samplers_for_img2img
        ],
        "scripts_txt2img": get_scripts_metadata(False),
        "scripts_img2img": get_scripts_metadata(True),
        "face_restorers": [model.name() for model in shared.face_restorers],
        "sd_models": modules.sd_models.checkpoint_tiles(),  # yes internal API has spelling error
    }


@router.post("/txt2img", response_model=ImageResponse)
def f_txt2img(req: Txt2ImgRequest):
    """Post request for Txt2Img.

    Merges the request with the configured txt2img defaults, runs the webUI
    txt2img entry point through the GPU call queue wrapper, then resizes,
    optionally saves, and base64-encodes the results.

    Args:
        req (Txt2ImgRequest): Request.

    Returns:
        Dict: Outputs and info. `outputs` is empty when no image was produced.
    """
    log.info(f"txt2img:\n{req}")

    opt = load_config().txt2img
    req = merge_default_config(req, opt)
    prepare_backend(req)

    script_ind, script, meta = get_script_info(req.script, False)
    args = process_script_args(script_ind, script, meta, req.script_args)

    width, height = sddebz_highres_fix(
        req.base_size, req.max_size, req.orig_width, req.orig_height
    )

    # NOTE: the call below is positional; the trailing comment on each line
    # names the webUI parameter that the value is bound to.
    output = wrap_gradio_gpu_call(modules.txt2img.txt2img)(
        parse_prompt(req.prompt),  # prompt
        parse_prompt(req.negative_prompt),  # negative_prompt
        "None",  # prompt_style: saved prompt styles (unsupported)
        "None",  # prompt_style2: saved prompt styles (unsupported)
        req.steps,  # steps
        get_sampler_index(req.sampler_name),  # sampler_index
        req.restore_faces,  # restore_faces
        req.tiling,  # tiling
        req.batch_count,  # n_iter
        req.batch_size,  # batch_size
        req.cfg_scale,  # cfg_scale
        req.seed,  # seed
        req.subseed,  # subseed
        req.subseed_strength,  # subseed_strength
        req.seed_resize_from_h,  # seed_resize_from_h
        req.seed_resize_from_w,  # seed_resize_from_w
        req.seed_enable_extras,  # seed_enable_extras
        height,  # height
        width,  # width
        req.highres_fix,  # enable_hr: high res fix
        req.denoising_strength,  # denoising_strength: only applicable if high res fix in use
        0,  # hr_scale (overridden by hr_resize_x/y)
        req.upscaler_name,  # hr_upscaler: upscaler to use for highres fix
        0,  # hr_second_pass_steps: 0 uses same num of steps as generation to refine details
        req.orig_width,  # hr_resize_x
        req.orig_height,  # hr_resize_y
        *args,
    )
    images = output[0]
    info = output[1]

    if images is None or len(images) < 1:
        log.warning("Interrupted!")
        return {"outputs": [], "info": info}

    # With grids enabled and no script selected, the first image is dropped
    # unless the client asked for the grid.
    if shared.opts.return_grid:
        if not req.include_grid and len(images) > 1 and script_ind == 0:
            images = images[1:]

    # Scale back to the requested size when no script ran, or when the output
    # still has the size that was handed to the pipeline.
    if not script or (width == images[0].width and height == images[0].height):
        log.info(
            f"img size: {images[0].width}x{images[0].height}, target: {req.orig_width}x{req.orig_height}"
        )
        images = [
            modules.images.resize_image(0, image, req.orig_width, req.orig_height)
            for image in images
        ]

    # save images for debugging/logging purposes
    if req.save_samples:
        output_paths = [
            save_img(image, opt.sample_path, filename=f"{int(time.time())}_{i}.png")
            for i, image in enumerate(images)
        ]
        log.info(f"saved: {output_paths}")

    images = [img_to_b64(image) for image in images]
    log.info(f"output sizes: {[len(i) for i in images]}")
    log.info("finished txt2img!")
    return {"outputs": images, "info": info}


@router.post("/img2img", response_model=ImageResponse)
def f_img2img(req: Img2ImgRequest):
    """Post request for Img2Img.

    Merges the request with the configured img2img defaults, runs the webUI
    img2img entry point through the GPU call queue wrapper, then resizes,
    re-applies the inpaint mask as an alpha channel (mode 1 only), optionally
    saves, and base64-encodes the results.

    Args:
        req (Img2ImgRequest): Request.

    Returns:
        Dict: Outputs and info. `outputs` is empty when no image was produced.
    """
    # The base64 image payloads are excluded to keep the log readable.
    log.info(f"img2img:\n{req.dict(exclude={'src_img', 'mask_img'})}")

    opt = load_config().img2img
    req = merge_default_config(req, opt)
    prepare_backend(req)

    script_ind, script, meta = get_script_info(req.script, True)
    args = process_script_args(script_ind, script, meta, req.script_args)

    image = b64_to_img(req.src_img)
    # A mask is only prepared in mode 1 and only when the client sent one.
    mask = (
        prepare_mask(b64_to_img(req.mask_img))
        if req.mode == 1 and req.mask_img is not None
        else None
    )

    orig_width, orig_height = image.size

    if script and script.title() == NAME_SCRIPT_UPSCALE:
        # in SD upscale mode, width & height determines tile size
        width = height = req.base_size
    else:
        width, height = sddebz_highres_fix(
            req.base_size, req.max_size, orig_width, orig_height
        )

    # NOTE:
    # - image & mask repeated due to Gradio API have separate tabs for each mode...
    # - mask is used only in inpaint mode
    # - mask_mode determines whether init_img_with_mask or init_img_inpaint is used,
    #   I dont know why
    # - new color sketch functionality in webUI is irrelevant so None is used for their options.
    # - the internal code for img2img is confusing and duplicative...

    output = wrap_gradio_gpu_call(modules.img2img.img2img)(
        req.mode,  # mode
        parse_prompt(req.prompt),  # prompt
        parse_prompt(req.negative_prompt),  # negative_prompt
        "None",  # prompt_style: saved prompt styles (unsupported)
        "None",  # prompt_style2: saved prompt styles (unsupported)
        image,  # init_img
        {"image": image, "mask": mask},  # init_img_with_mask
        None,  # init_img_with_mask_orig # only used by webUI color sketch if init_img_with_mask isn't dict
        image,  # init_img_inpaint
        mask,  # init_mask_inpaint
        # using 1 for uploaded mask mode; processing done by prepare_mask to ensure its correct
        1,  # mask_mode: internally checks if equal 0. 1 enables alpha mask (remove erased parts)
        req.steps,  # steps
        get_sampler_index(req.sampler_name),  # sampler_index
        0,  # req.mask_blur,  # mask_blur
        None,  # mask_alpha # only used by webUI color sketch if init_img_with_mask isn't dict
        req.inpainting_fill,  # inpainting_fill
        req.restore_faces,  # restore_faces
        req.tiling,  # tiling
        req.batch_count,  # n_iter
        req.batch_size,  # batch_size
        req.cfg_scale,  # cfg_scale
        req.denoising_strength,  # denoising_strength
        req.seed,  # seed
        req.subseed,  # subseed
        req.subseed_strength,  # subseed_strength
        req.seed_resize_from_h,  # seed_resize_from_h
        req.seed_resize_from_w,  # seed_resize_from_w
        req.seed_enable_extras,  # seed_enable_extras
        height,  # height
        width,  # width
        req.resize_mode,  # resize_mode
        False,  # req.inpaint_full_res,  # inpaint_full_res
        0,  # req.inpaint_full_res_padding,  # inpaint_full_res_padding
        req.invert_mask,  # inpainting_mask_invert
        "",  # img2img_batch_input_dir (unsupported)
        "",  # img2img_batch_output_dir (unsupported)
        *args,
    )
    images = output[0]
    info = output[1]

    if images is None or len(images) < 1:
        log.warning("Interrupted!")
        return {"outputs": [], "info": info}

    if shared.opts.return_grid:
        if not req.include_grid and len(images) > 1 and script_ind == 0:
            images = images[1:]
        # This is a workaround.
        # NOTE(review): presumably the loopback script prepends its own grid
        # when grids are enabled, so it is dropped unconditionally - confirm.
        if script and script.title() == NAME_SCRIPT_LOOPBACK and len(images) > 1:
            images = images[1:]

    # NOTE: this is a dumb assumption:
    # if size of image is different from size given to pipeline (after sddebz fix)
    # then it must be intentional (i.e. SD Upscale/outpaint) so dont scale back
    if not script or (width == images[0].width and height == images[0].height):
        log.info(
            f"img Size: {images[0].width}x{images[0].height}, target: {orig_width}x{orig_height}"
        )
        images = [
            modules.images.resize_image(0, image, orig_width, orig_height)
            for image in images
        ]

    # `mask` is None when mode 1 was requested without a mask image; merging
    # would then fail after the generation already finished, so the outputs
    # are returned without an alpha channel in that case.
    if req.mode == 1 and mask is not None:
        # The alpha band is identical for every output, so build it once.
        alpha = ImageOps.invert(mask) if req.invert_mask else mask

        def apply_mask(img):
            """Mask inpaint using original mask, including alpha."""
            r, g, b = img.split()  # img2img/inpaint gives rgb image
            return Image.merge("RGBA", (r, g, b, alpha))

        images = [apply_mask(x) for x in images]

    # save images for debugging/logging purposes
    if req.save_samples:
        output_paths = [
            save_img(image, opt.sample_path, filename=f"{int(time.time())}_{i}.png")
            for i, image in enumerate(images)
        ]
        log.info(f"saved: {output_paths}")

    images = [img_to_b64(image) for image in images]
    log.info(f"output sizes: {[len(i) for i in images]}")
    log.info("finished img2img!")
    return {"outputs": images, "info": info}


@router.post("/upscale", response_model=UpscaleResponse)
def f_upscale(req: UpscaleRequest):
    """Post request for upscaling.

    Converts the source image to RGB, optionally halves it first, and runs
    the upscaler selected by name.

    Args:
        req (UpscaleRequest): Request.

    Returns:
        Dict: Output, or None when the selected upscaler is named "None".
    """
    # The base64 image payload is excluded to keep the log readable.
    log.info(f"upscale:\n{req.dict(exclude={'src_img'})}")

    opt = load_config().upscale
    req = merge_default_config(req, opt)
    prepare_backend(req)

    image = b64_to_img(req.src_img).convert("RGB")
    orig_width, orig_height = image.size

    upscaler_index = get_upscaler_index(req.upscaler_name)
    upscaler = shared.sd_upscalers[upscaler_index]

    if upscaler.name == "None":
        log.info("No upscaler selected, will do nothing")
        # NOTE(review): this returns None although the route declares
        # response_model=UpscaleResponse - confirm the client expects an
        # empty response here before changing it.
        return

    if req.downscale_first:
        image = modules.images.resize_image(0, image, orig_width // 2, orig_height // 2)

    image = upscaler.scaler.upscale(image, upscaler.scale, upscaler.data_path)
    if req.save_samples:
        output_path = save_img(
            image, opt.sample_path, filename=f"{int(time.time())}.png"
        )
        log.info(f"saved: {output_path}")

    output = img_to_b64(image)
    log.info(f"output size: {len(output)}")
    log.info("finished upscale!")
    return {"output": output}


async def app_encryption_middleware(req: Request, call_next):
    """Used to decrypt/encrypt HTTP request body.

    Requests carrying an `X-Encrypted-Body` header have their body XOR-ed
    with the configured key before being passed on, and the response body is
    XOR-ed with the same key on the way out. Other requests pass through
    untouched.

    Args:
        req (Request): Incoming request.
        call_next: Callable that forwards the request and returns the response.

    Returns:
        StreamingResponse: The (possibly re-encrypted) response.

    Raises:
        AssertionError: If the header names a method other than "XOR", or if
            no encryption key is available for an encrypted request.
    """
    is_encrypted = "X-Encrypted-Body" in req.headers
    # only supported method now is XOR
    # Raised explicitly instead of via `assert` so the check still runs under
    # `python -O`; the exception type is kept for backward compatibility.
    if is_encrypted and req.headers["X-Encrypted-Body"] != "XOR":
        raise AssertionError("Unsupported X-Encrypted-Body method.")

    if is_encrypted:
        key = get_encrypt_key()
        if key is None:
            raise AssertionError("Unable to decrypt request without key.")
        body = await req.body()
        body = bytewise_xor(body, key)

        # NOTE: FastAPI refuses to work with requests that have already been consumed idk why
        # so a fresh Request is built whose receive() replays the decrypted body.
        async def receive():
            return {"type": "http.request", "body": body, "more_body": False}

        req = Request(req.scope, receive, req._send)

    res: StreamingResponse = await call_next(req)
    if is_encrypted:
        # Echo the method back so the client knows the body is encrypted.
        res.headers["X-Encrypted-Body"] = req.headers["X-Encrypted-Body"]
        body = [bytewise_xor(chunk, key) async for chunk in res.body_iterator]
        res.body_iterator = iterate_in_threadpool(iter(body))
    return res