# cog-webui-sd / handler.py
# wolverinn
# add docker file
# 9c8d2ad
# raw
# history blame
# 9.3 kB
import os
import sys
import time
import importlib
import signal
import re
from typing import Dict, List, Any
# from fastapi import FastAPI
# from fastapi.middleware.cors import CORSMiddleware
# from fastapi.middleware.gzip import GZipMiddleware
from packaging import version
import logging
def _keep_xformers_record(record):
    """Return False for the xformers 'matching Triton' notice, True for any other record."""
    return 'A matching Triton is not available' not in record.getMessage()


# Attach the predicate as a filter on the "xformers" logger so that the one
# notice is dropped while every other record passes through unchanged.
logging.getLogger("xformers").addFilter(_keep_xformers_record)
from modules import errors
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
import torch
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if any(marker in torch.__version__ for marker in (".dev", "+git")):
    # Preserve the untouched version string before overwriting it.
    torch.__long_version__ = torch.__version__
    # Keep only the leading run of digits and dots (ending on a digit).
    torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
from modules import shared, devices, ui_tempdir
import modules.codeformer_model as codeformer
import modules.face_restoration
import modules.gfpgan_model as gfpgan
import modules.img2img
import modules.lowvram
import modules.paths
import modules.scripts
import modules.sd_hijack
import modules.sd_models
import modules.sd_vae
import modules.txt2img
import modules.script_callbacks
import modules.textual_inversion.textual_inversion
import modules.progress
import modules.ui
from modules import modelloader
from modules.shared import cmd_opts, opts
import modules.hypernetworks.hypernetwork
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
import base64
import io
from fastapi import HTTPException
from io import BytesIO
import piexif
import piexif.helper
from PIL import PngImagePlugin,Image
def initialize():
    """Set up models, option callbacks and the SIGINT handler for inference.

    In order, this calls the model cleanup/setup helpers (checkpoint list,
    CodeFormer, GFPGAN, upscalers, VAE list), loads the Stable Diffusion
    model, records the loaded checkpoint title in ``shared.opts.data``,
    registers ``onchange`` callbacks for the checkpoint / VAE / temp-dir
    options, and installs a SIGINT handler that terminates the process
    immediately.

    The commented-out calls are steps that are intentionally not executed
    here; they are kept for reference.

    Raises:
        SystemExit: with status 1 when ``modules.sd_models.load_model()``
            raises any exception.
    """
    # check_versions()
    # extensions.list_extensions()
    # localization.list_localizations(cmd_opts.localizations_dir)
    # if cmd_opts.ui_debug_mode:
    #     shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
    #     modules.scripts.load_scripts()
    #     return
    modelloader.cleanup_models()
    modules.sd_models.setup_model()
    codeformer.setup_model(cmd_opts.codeformer_models_path)
    gfpgan.setup_model(cmd_opts.gfpgan_models_path)
    modelloader.list_builtin_upscalers()
    # modules.scripts.load_scripts()
    modelloader.load_upscalers()
    modules.sd_vae.refresh_vae_list()
    # modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
    try:
        modules.sd_models.load_model()
    except Exception as e:
        errors.display(e, "loading stable diffusion model")
        print("", file=sys.stderr)
        print("Stable diffusion model failed to load, exiting", file=sys.stderr)
        # sys.exit instead of the bare `exit` builtin: the latter is injected
        # by the `site` module for interactive use and is not guaranteed to
        # exist in every interpreter configuration.
        sys.exit(1)

    shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title

    # Reload weights whenever the corresponding option changes.
    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
    shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
    shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
    shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)

    # shared.reload_hypernetworks()
    # ui_extra_networks.intialize()
    # ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
    # ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
    # ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
    # extra_networks.initialize()
    # extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
    # if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
    #     try:
    #         if not os.path.exists(cmd_opts.tls_keyfile):
    #             print("Invalid path to TLS keyfile given")
    #         if not os.path.exists(cmd_opts.tls_certfile):
    #             print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
    #     except TypeError:
    #         cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
    #         print("TLS setup invalid, running webui without TLS")
    #     else:
    #         print("Running with TLS")

    # make the program just exit at ctrl+c without waiting for anything
    def sigint_handler(sig, frame):
        print(f'Interrupted with signal {sig} in {frame}')
        # os._exit terminates right away, without raising SystemExit.
        os._exit(0)

    signal.signal(signal.SIGINT, sigint_handler)
class EndpointHandler():
    """Callable endpoint that renders a single txt2img image per request."""

    def __init__(self, path=""):
        """Preload everything needed at inference time.

        Args:
            path: Accepted but not used by this implementation.
        """
        initialize()
        self.shared = shared

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Generate one image with the built-in settings, optionally overriding
        the prompt from the request.

        data args:
            inputs (:obj:`dict`, optional): when it is a dict containing a
                ``"prompt"`` key, that value replaces the default prompt.
                A missing, empty or non-dict ``inputs`` leaves the default
                prompt in place.
        Return:
            A :obj:`dict` with ``"img_data"`` (base64 text of the first
            generated image) and ``"parameters"`` (the image's
            ``info["parameters"]`` entry, or ``""`` when absent).
        """
        args = {
            # todo: don't output png
            # NOTE(review): hard-coded Windows desktop path from a dev
            # machine; confirm the intended output directory for deployment.
            "outpath_samples": "C:\\Users\\wolvz\\Desktop",
            "prompt": "lora:koreanDollLikeness_v15:0.66, best quality, ultra high res, (photorealistic:1.4), 1girl, beige sweater, black choker, smile, laughing, bare shoulders, solo focus, ((full body), (brown hair:1), looking at viewer",
            "negative_prompt": "paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans, (ugly:1.331), (duplicate:1.331), (morbid:1.21), (mutilated:1.21), (tranny:1.331), mutated hands, (poorly drawn hands:1.331), blurry, 3hands,4fingers,3arms, bad anatomy, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts,poorly drawn face,mutation,deformed",
            "sampler_name": "DPM++ SDE Karras",
            "steps": 20,  # 25
            "cfg_scale": 8,
            "width": 512,
            "height": 768,
            "seed": -1,
        }

        # Use .get so a request without an "inputs" key falls back to the
        # defaults instead of raising KeyError.
        inputs = (data or {}).get("inputs")
        if isinstance(inputs, dict) and "prompt" in inputs:
            prompt = inputs["prompt"]
            print("get prompt from request: ", prompt)
            args["prompt"] = prompt

        p = StableDiffusionProcessingTxt2Img(sd_model=self.shared.sd_model, **args)
        processed = process_images(p)

        # Only the first generated image is returned.
        image = processed.images[0]
        single_image_b64 = encode_pil_to_base64(image).decode('utf-8')
        return {
            "img_data": single_image_b64,
            "parameters": image.info.get('parameters', ""),
        }
def manual_hack():
    """Run initialize() and then one txt2img job with fixed settings.

    Nothing is returned; the result of process_images() is discarded.
    """
    initialize()
    txt2img_kwargs = dict(
        # todo: don't output res
        outpath_samples="C:\\Users\\wolvz\\Desktop",
        prompt="lora:koreanDollLikeness_v15:0.66, best quality, ultra high res, (photorealistic:1.4), 1girl, beige sweater, black choker, smile, laughing, bare shoulders, solo focus, ((full body), (brown hair:1), looking at viewer",
        negative_prompt="paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans",
        sampler_name="DPM++ SDE Karras",
        steps=20,  # 25
        cfg_scale=8,
        width=512,
        height=768,
        seed=-1,
    )
    job = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **txt2img_kwargs)
    process_images(job)
def decode_base64_to_image(encoding):
    """Decode base64 text (optionally a ``data:image/...`` URI) into an image.

    Args:
        encoding: Base64 string, or a data URI starting with ``data:image/``
            whose payload follows the first comma.

    Returns:
        The object returned by ``Image.open`` for the decoded bytes.

    Raises:
        HTTPException: status 500, detail "Invalid encoded image", when the
            payload cannot be base64-decoded or opened as an image. The
            underlying error is attached as ``__cause__``.
    """
    if encoding.startswith("data:image/"):
        # The payload is everything after the first comma. Splitting on the
        # comma alone tolerates extra ';'-separated header parameters, and a
        # URI without a comma yields an empty payload that is rejected below
        # rather than raising IndexError here.
        encoding = encoding.partition(",")[2]
    try:
        return Image.open(BytesIO(base64.b64decode(encoding)))
    except Exception as err:
        raise HTTPException(status_code=500, detail="Invalid encoded image") from err
def encode_pil_to_base64(image):
    """Serialize *image* in the format named by ``opts.samples_format`` and
    return the base64-encoded bytes.

    PNG output carries every str->str entry of ``image.info`` as a text
    chunk; JPEG/WEBP output carries ``image.info['parameters']`` as an EXIF
    UserComment. Any other configured format raises HTTPException (500).
    """
    with io.BytesIO() as buffer:
        fmt = opts.samples_format.lower()

        if fmt == 'png':
            text_chunks = PngImagePlugin.PngInfo()
            has_text = False
            for key, value in image.info.items():
                # Only string keys with string values are copied.
                if not (isinstance(key, str) and isinstance(value, str)):
                    continue
                text_chunks.add_text(key, value)
                has_text = True
            pnginfo = text_chunks if has_text else None
            image.save(buffer, format="PNG", pnginfo=pnginfo, quality=opts.jpeg_quality)

        elif fmt in ("jpg", "jpeg", "webp"):
            generation_text = image.info.get('parameters', None) or ""
            user_comment = piexif.helper.UserComment.dump(generation_text, encoding="unicode")
            exif_bytes = piexif.dump({"Exif": {piexif.ExifIFD.UserComment: user_comment}})
            # "jpg"/"jpeg" are written as JPEG, the remaining option as WEBP.
            target_format = "JPEG" if fmt in ("jpg", "jpeg") else "WEBP"
            image.save(buffer, format=target_format, exif=exif_bytes, quality=opts.jpeg_quality)

        else:
            raise HTTPException(status_code=500, detail="Invalid image format")

        raw = buffer.getvalue()

    return base64.b64encode(raw)
if __name__ == "__main__":
    # Local smoke test: build the handler and run one request.
    # manual_hack()
    handler = EndpointHandler("./")
    # Pass an explicit, empty "inputs" payload so the handler falls back to
    # its built-in default prompt; a bare {} has no "inputs" key at all.
    res = handler({"inputs": {}})
    # print(res)