chill_watcher / handler.py
wolverinn
init with model
99a5bb3
raw
history blame contribute delete
No virus
8.96 kB
import os
import sys
import time
import importlib
import signal
import re
from typing import Dict, List, Any
# from fastapi import FastAPI
# from fastapi.middleware.cors import CORSMiddleware
# from fastapi.middleware.gzip import GZipMiddleware
from packaging import version
import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
import torch
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__long_version__ = torch.__version__
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
import modules.codeformer_model as codeformer
import modules.face_restoration
import modules.gfpgan_model as gfpgan
import modules.img2img
import modules.lowvram
import modules.paths
import modules.scripts
import modules.sd_hijack
import modules.sd_models
import modules.sd_vae
import modules.txt2img
import modules.script_callbacks
import modules.textual_inversion.textual_inversion
import modules.progress
import modules.ui
from modules import modelloader
from modules.shared import cmd_opts, opts
import modules.hypernetworks.hypernetwork
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
import base64
import io
from fastapi import HTTPException
from io import BytesIO
import piexif
import piexif.helper
from PIL import PngImagePlugin,Image
def initialize():
# check_versions()
# extensions.list_extensions()
# localization.list_localizations(cmd_opts.localizations_dir)
# if cmd_opts.ui_debug_mode:
# shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
# modules.scripts.load_scripts()
# return
modelloader.cleanup_models()
modules.sd_models.setup_model()
codeformer.setup_model(cmd_opts.codeformer_models_path)
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
modelloader.list_builtin_upscalers()
# modules.scripts.load_scripts()
modelloader.load_upscalers()
modules.sd_vae.refresh_vae_list()
# modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
try:
modules.sd_models.load_model()
except Exception as e:
errors.display(e, "loading stable diffusion model")
print("", file=sys.stderr)
print("Stable diffusion model failed to load, exiting", file=sys.stderr)
exit(1)
shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
# shared.reload_hypernetworks()
# ui_extra_networks.intialize()
# ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
# ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
# ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
# extra_networks.initialize()
# extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
# if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
# try:
# if not os.path.exists(cmd_opts.tls_keyfile):
# print("Invalid path to TLS keyfile given")
# if not os.path.exists(cmd_opts.tls_certfile):
# print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
# except TypeError:
# cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
# print("TLS setup invalid, running webui without TLS")
# else:
# print("Running with TLS")
# make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame):
print(f'Interrupted with signal {sig} in {frame}')
os._exit(0)
signal.signal(signal.SIGINT, sigint_handler)
class EndpointHandler():
def __init__(self, path=""):
# Preload all the elements you are going to need at inference.
# pseudo:
# self.model= load_model(path)
initialize()
self.shared = shared
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
data args:
inputs (:obj: `str` | `PIL.Image` | `np.array`)
kwargs
Return:
A :obj:`list` | `dict`: will be serialized and returned
"""
args = {
"outpath_samples": "C:\\Users\\wolvz\\Desktop",
"prompt": "lora:koreanDollLikeness_v15:0.66, best quality, ultra high res, (photorealistic:1.4), 1girl, beige sweater, black choker, smile, laughing, bare shoulders, solo focus, ((full body), (brown hair:1), looking at viewer",
"negative_prompt": "paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans",
"sampler_name": "DPM++ SDE Karras",
"steps": 20, # 25
"cfg_scale": 8,
"width": 512,
"height": 768,
"seed": -1,
}
if "prompt" in data.keys():
args["prompt"] = data["prompt"]
p = StableDiffusionProcessingTxt2Img(sd_model=self.shared.sd_model, **args)
processed = process_images(p)
single_image_b64 = encode_pil_to_base64(processed.images[0])
return {
"img_data": single_image_b64,
}
def manual_hack():
initialize()
args = {
"outpath_samples": "C:\\Users\\wolvz\\Desktop",
"prompt": "lora:koreanDollLikeness_v15:0.66, best quality, ultra high res, (photorealistic:1.4), 1girl, beige sweater, black choker, smile, laughing, bare shoulders, solo focus, ((full body), (brown hair:1), looking at viewer",
"negative_prompt": "paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans",
"sampler_name": "DPM++ SDE Karras",
"steps": 20, # 25
"cfg_scale": 8,
"width": 512,
"height": 768,
"seed": -1,
}
p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
processed = process_images(p)
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as err:
raise HTTPException(status_code=500, detail="Invalid encoded image")
def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes:
if opts.samples_format.lower() == 'png':
use_metadata = False
metadata = PngImagePlugin.PngInfo()
for key, value in image.info.items():
if isinstance(key, str) and isinstance(value, str):
metadata.add_text(key, value)
use_metadata = True
image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
parameters = image.info.get('parameters', None)
exif_bytes = piexif.dump({
"Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
})
if opts.samples_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality)
else:
image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality)
else:
raise HTTPException(status_code=500, detail="Invalid image format")
bytes_data = output_bytes.getvalue()
return base64.b64encode(bytes_data)
if __name__ == "__main__":
# manual_hack()
handler = EndpointHandler("./")
res = handler.__call__({})
# print(res)