from typing import List, Optional
from threading import Lock
from secrets import compare_digest
from fastapi import FastAPI, APIRouter, Depends, Request
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.exceptions import HTTPException
from modules import errors, shared, postprocessing
from modules.api import models, endpoints, script, helpers, server, nvml, generate, process, control


errors.install()


class Api:
    """Register the `/sdapi/v1/*` routes on a FastAPI app and optionally serve them.

    Credentials are collected from `shared.cmd_opts.auth` and
    `shared.cmd_opts.auth_file`; they are enforced through HTTP Basic auth only
    when the server runs with `api_only` (see `add_api_route`).
    """

    def __init__(self, app: FastAPI, queue_lock: Lock):
        """Load credentials, build the request handlers and register every route.

        Args:
            app: FastAPI application that receives the routes.
            queue_lock: lock handed to the generate/process/control handlers and
                held around the extras (postprocessing) calls in this class.

        Raises:
            ValueError: if an auth entry does not contain a ":" separator.
        """
        self.credentials = {}
        if shared.cmd_opts.auth:
            for entry in shared.cmd_opts.auth.split(","):
                self._add_credential(entry)
        if shared.cmd_opts.auth_file:
            with open(shared.cmd_opts.auth_file, 'r', encoding="utf8") as file:
                for line in file:
                    self._add_credential(line)
        self.router = APIRouter()
        self.app = app
        self.queue_lock = queue_lock
        self.generate = generate.APIGenerate(queue_lock)
        self.process = process.APIProcess(queue_lock)
        self.control = control.APIControl(queue_lock)

        # server api
        self.add_api_route("/sdapi/v1/motd", server.get_motd, methods=["GET"], response_model=str)
        self.add_api_route("/sdapi/v1/log", server.get_log_buffer, methods=["GET"], response_model=List[str])
        self.add_api_route("/sdapi/v1/start", self.get_session_start, methods=["GET"])
        self.add_api_route("/sdapi/v1/version", server.get_version, methods=["GET"])
        self.add_api_route("/sdapi/v1/platform", server.get_platform, methods=["GET"])
        self.add_api_route("/sdapi/v1/progress", server.get_progress, methods=["GET"], response_model=models.ResProgress)
        self.add_api_route("/sdapi/v1/interrupt", server.post_interrupt, methods=["POST"])
        self.add_api_route("/sdapi/v1/skip", server.post_skip, methods=["POST"])
        self.add_api_route("/sdapi/v1/shutdown", server.post_shutdown, methods=["POST"])
        self.add_api_route("/sdapi/v1/memory", server.get_memory, methods=["GET"], response_model=models.ResMemory)
        self.add_api_route("/sdapi/v1/options", server.get_config, methods=["GET"], response_model=models.OptionsModel)
        self.add_api_route("/sdapi/v1/options", server.set_config, methods=["POST"])
        self.add_api_route("/sdapi/v1/cmd-flags", server.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
        self.add_api_route("/sdapi/v1/nvml", nvml.get_nvml, methods=["GET"], response_model=List[models.ResNVML])

        # core api using locking
        self.add_api_route("/sdapi/v1/txt2img", self.generate.post_text2img, methods=["POST"], response_model=models.ResTxt2Img)
        self.add_api_route("/sdapi/v1/img2img", self.generate.post_img2img, methods=["POST"], response_model=models.ResImg2Img)
        self.add_api_route("/sdapi/v1/control", self.control.post_control, methods=["POST"], response_model=control.ResControl)
        self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ResProcessImage)
        self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ResProcessBatch)
        self.add_api_route("/sdapi/v1/preprocess", self.process.post_preprocess, methods=["POST"])
        self.add_api_route("/sdapi/v1/mask", self.process.post_mask, methods=["POST"])

        # api dealing with optional scripts
        self.add_api_route("/sdapi/v1/scripts", script.get_scripts_list, methods=["GET"], response_model=models.ResScripts)
        self.add_api_route("/sdapi/v1/script-info", script.get_script_info, methods=["GET"], response_model=List[models.ItemScript])

        # enumerator api
        self.add_api_route("/sdapi/v1/preprocessors", self.process.get_preprocess, methods=["GET"], response_model=List[process.ItemPreprocess])
        self.add_api_route("/sdapi/v1/masking", self.process.get_mask, methods=["GET"], response_model=process.ItemMask)
        self.add_api_route("/sdapi/v1/interrogate", endpoints.get_interrogate, methods=["GET"], response_model=List[str])
        self.add_api_route("/sdapi/v1/samplers", endpoints.get_samplers, methods=["GET"], response_model=List[models.ItemSampler])
        self.add_api_route("/sdapi/v1/upscalers", endpoints.get_upscalers, methods=["GET"], response_model=List[models.ItemUpscaler])
        self.add_api_route("/sdapi/v1/sd-models", endpoints.get_sd_models, methods=["GET"], response_model=List[models.ItemModel])
        self.add_api_route("/sdapi/v1/hypernetworks", endpoints.get_hypernetworks, methods=["GET"], response_model=List[models.ItemHypernetwork])
        self.add_api_route("/sdapi/v1/face-restorers", endpoints.get_face_restorers, methods=["GET"], response_model=List[models.ItemFaceRestorer])
        self.add_api_route("/sdapi/v1/prompt-styles", endpoints.get_prompt_styles, methods=["GET"], response_model=List[models.ItemStyle])
        self.add_api_route("/sdapi/v1/embeddings", endpoints.get_embeddings, methods=["GET"], response_model=models.ResEmbeddings)
        self.add_api_route("/sdapi/v1/sd-vae", endpoints.get_sd_vaes, methods=["GET"], response_model=List[models.ItemVae])
        self.add_api_route("/sdapi/v1/extensions", endpoints.get_extensions_list, methods=["GET"], response_model=List[models.ItemExtension])
        self.add_api_route("/sdapi/v1/extra-networks", endpoints.get_extra_networks, methods=["GET"], response_model=List[models.ItemExtraNetwork])

        # functional api
        self.add_api_route("/sdapi/v1/png-info", endpoints.post_pnginfo, methods=["POST"], response_model=models.ResImageInfo)
        self.add_api_route("/sdapi/v1/interrogate", endpoints.post_interrogate, methods=["POST"])
        self.add_api_route("/sdapi/v1/refresh-checkpoints", endpoints.post_refresh_checkpoints, methods=["POST"])
        self.add_api_route("/sdapi/v1/unload-checkpoint", endpoints.post_unload_checkpoint, methods=["POST"])
        self.add_api_route("/sdapi/v1/reload-checkpoint", endpoints.post_reload_checkpoint, methods=["POST"])
        self.add_api_route("/sdapi/v1/refresh-vae", endpoints.post_refresh_vae, methods=["POST"])

    def _add_credential(self, entry: str):
        """Parse one `user:password` entry and store it in `self.credentials`.

        Blank entries (empty lines, trailing commas) are ignored. Only the first
        ":" separates user from password, so passwords may contain ":". Double
        quotes and surrounding whitespace are removed from both parts.

        Raises:
            ValueError: if a non-blank entry has no ":" separator.
        """
        entry = entry.strip()
        if not entry:
            return
        user, separator, password = entry.partition(":")
        if not separator:
            # the entry itself is not echoed: it may contain a secret
            raise ValueError('Invalid auth entry: expected "user:password"')
        self.credentials[user.replace('"', '').strip()] = password.replace('"', '').strip()

    def add_api_route(self, path: str, endpoint, **kwargs):
        """Register `endpoint` at `path` on the app.

        The route gets the Basic-auth dependency only when credentials are
        configured and the server runs with `api_only`. Extra keyword arguments
        are forwarded to `FastAPI.add_api_route`.
        """
        auth_configured = shared.cmd_opts.auth or shared.cmd_opts.auth_file
        if auth_configured and shared.cmd_opts.api_only:
            return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
        return self.app.add_api_route(path, endpoint, **kwargs)

    def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
        """Validate HTTP Basic credentials against the configured users.

        Returns:
            True when the username is known and the password matches.

        Raises:
            HTTPException: 401 with a `WWW-Authenticate: Basic` header otherwise.
        """
        # this is only needed for api-only since otherwise auth is handled in gradio/routes.py
        expected = self.credentials.get(credentials.username)
        if expected is not None:
            # compare as bytes: compare_digest rejects non-ASCII str arguments with TypeError
            if compare_digest(credentials.password.encode("utf-8"), expected.encode("utf-8")):
                return True
        raise HTTPException(status_code=401, detail="Unauthorized", headers={"WWW-Authenticate": "Basic"})

    def get_session_start(self, req: Request, agent: Optional[str] = None):
        """Log the start of a browser session and return an empty dict.

        The user is looked up in `self.app.tokens` (when the app has that
        attribute) using the `access-token` or `access-token-unsecure` cookie.
        """
        token = req.cookies.get("access-token") or req.cookies.get("access-token-unsecure")
        user = self.app.tokens.get(token) if hasattr(self.app, 'tokens') else None
        # req.client may be None when the client address is not available
        client = req.client.host if req.client is not None else None
        shared.log.info(f'Browser session: user={user} client={client} agent={agent}')
        return {}

    def prepare_img_gen_request(self, request):
        """Translate the `face` / `ip_adapter` shortcut fields into script settings.

        Mutates `request` in place and returns None:
        - a truthy `face` becomes `script_name="face"` plus `script_args`, unless
          a script is already selected or "face" is already in `alwayson_scripts`;
        - a truthy `ip_adapter` becomes an "IP Adapter" entry in
          `alwayson_scripts`, unless that script is already selected or present.
        In both cases the shortcut attribute is deleted once it has been applied.
        """
        always_on = request.alwayson_scripts
        if hasattr(request, "face") and request.face and not request.script_name and (not always_on or "face" not in always_on):
            face = request.face
            request.script_name = "face"
            request.script_args = [
                face.mode,
                face.source_images,
                face.ip_model,
                face.ip_override_sampler,
                face.ip_cache_model,
                face.ip_strength,
                face.ip_structure,
                face.id_strength,
                face.id_conditioning,
                face.id_cache,
                face.pm_trigger,
                face.pm_strength,
                face.pm_start,
                face.fs_cache,
            ]
            del request.face

        always_on = request.alwayson_scripts
        if hasattr(request, "ip_adapter") and request.ip_adapter and request.script_name != "IP Adapter" and (not always_on or "IP Adapter" not in always_on):
            if request.alwayson_scripts is None:
                request.alwayson_scripts = {}
            adapter = request.ip_adapter
            request.alwayson_scripts["IP Adapter"] = {
                "args": [adapter.adapter, adapter.scale, adapter.image]
            }
            del request.ip_adapter

    def set_upscalers(self, req):
        """Return the request's attributes as a dict with the upscaler keys renamed.

        `upscaler_1` / `upscaler_2` become `extras_upscaler_1` /
        `extras_upscaler_2` (None when absent). The dict is a shallow copy of
        `vars(req)`, so the request object itself is left untouched.
        """
        reqDict = dict(vars(req))
        reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
        reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
        return reqDict

    def extras_single_image_api(self, req: models.ReqProcessImage):
        """Run extras postprocessing on a single base64 image while holding the queue lock."""
        reqDict = self.set_upscalers(req)
        reqDict['image'] = helpers.decode_base64_to_image(reqDict['image'])
        with self.queue_lock:
            result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
        return models.ResProcessImage(image=helpers.encode_pil_to_base64(result[0][0]), html_info=result[1])

    def extras_batch_images_api(self, req: models.ReqProcessBatch):
        """Run extras postprocessing on every image in `imageList` while holding the queue lock."""
        reqDict = self.set_upscalers(req)
        image_list = reqDict.pop('imageList', [])
        image_folder = [helpers.decode_base64_to_image(item.data) for item in image_list]
        with self.queue_lock:
            result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
        return models.ResProcessBatch(images=list(map(helpers.encode_pil_to_base64, result[0])), html_info=result[1])

    def launch(self):
        """Start a Uvicorn server for the app and return the server object.

        Listen address, port and TLS key/cert files are read from
        `shared.cmd_opts`.
        """
        config = {
            "listen": shared.cmd_opts.listen,
            "port": shared.cmd_opts.port,
            "keyfile": shared.cmd_opts.tls_keyfile,
            "certfile": shared.cmd_opts.tls_certfile,
            "loop": "auto",  # auto, asyncio, uvloop
            "http": "auto",  # auto, h11, httptools
        }
        from modules.server import UvicornServer
        http_server = UvicornServer(self.app, **config)
        # from modules.server import HypercornServer
        # server = HypercornServer(self.app, **config)
        http_server.start()
        shared.log.info(f'API server: Uvicorn options={config}')
        return http_server


# compatibility items
decode_base64_to_image = helpers.decode_base64_to_image
encode_pil_to_base64 = helpers.encode_pil_to_base64
validate_sampler_name = helpers.validate_sampler_name