# inference handler for lightning ai
import re
import os
import logging
# import json
from pydantic import BaseModel
from typing import Any, Dict, Optional, TYPE_CHECKING
from dataclasses import dataclass

logging.getLogger("xformers").addFilter(
    lambda record: 'A matching Triton is not available' not in record.getMessage()
)

import lightning as L
from lightning.app.components.serve import PythonServer, Text
from lightning.app import BuildConfig


# Request/response schemas used by the serving endpoint.
class _DefaultInputData(BaseModel):
    prompt: str


class _DefaultOutputData(BaseModel):
    img_data: str
    parameters: str


@dataclass
class CustomBuildConfig(BuildConfig):
    # Downloads the checkpoint/VAE/LoRA weights and installs system libraries
    # on the cloud machine before the server starts.
    def build_commands(self):
        dir_path = "/content/"
        model_path = os.path.join(dir_path, "models/Stable-diffusion")
        # model_url = "https://huggingface.co/Hardy01/chill_watcher/resolve/main/models/Stable-diffusion/chilloutmix_NiPrunedFp32Fix.safetensors"
        model_url = "https://huggingface.co/Hardy01/chill_watcher/resolve/main/models/Stable-diffusion/chilloutmix_NiPrunedFp32Fix.safetensors"
        download_cmd = "wget -P {} {}".format(str(model_path), model_url)

        vae_url = "https://huggingface.co/Hardy01/chill_watcher/resolve/main/models/VAE/vae-ft-mse-840000-ema-pruned.ckpt"
        vae_path = os.path.join(dir_path, "models/VAE")
        down2 = "wget -P {} {}".format(str(vae_path), vae_url)

        lora_url1 = "https://huggingface.co/Hardy01/chill_watcher/resolve/main/models/Lora/koreanDollLikeness_v10.safetensors"
        lora_url2 = "https://huggingface.co/Hardy01/chill_watcher/resolve/main/models/Lora/taiwanDollLikeness_v10.safetensors"
        lora_path = os.path.join(dir_path, "models/Lora")
        down3 = "wget -P {} {}".format(str(lora_path), lora_url1)
        down4 = "wget -P {} {}".format(str(lora_path), lora_url2)

        # https://stackoverflow.com/questions/55313610/importerror-libgl-so-1-cannot-open-shared-object-file-no-such-file-or-directo
        cmd1 = "pip3 install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117"
        cmd2 = "pip3 install torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117"
        cmd_31 = "sudo apt-get update"
        cmd3 = "sudo apt-get install libgl1-mesa-glx"
        cmd4 = "sudo apt-get install libglib2.0-0"
        return [download_cmd, down2, down3, down4, cmd_31, cmd3, cmd4]


class PyTorchServer(PythonServer):
    def __init__(
        self,
        input_type: type = _DefaultInputData,
        output_type: type = _DefaultOutputData,
        **kwargs: Any,
    ):
        super().__init__(input_type=input_type, output_type=output_type, **kwargs)
        # Use the custom build config
        self.cloud_build_config = CustomBuildConfig()

    def setup(self):
        # need to install dependencies first to import packages
        import torch

        # Truncate version number of nightly/local build of PyTorch to not cause
        # exceptions with CodeFormer or Safetensors
        if ".dev" in torch.__version__ or "+git" in torch.__version__:
            torch.__long_version__ = torch.__version__
            torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)

        from handler import initialize
        initialize()

    def predict(self, request):
        from modules.api.api import encode_pil_to_base64
        from modules import shared
        from modules.processing import StableDiffusionProcessingTxt2Img, process_images

        # Default txt2img arguments; the prompt is overridden by the request below.
        args = {
            "do_not_save_samples": True,
            "do_not_save_grid": True,
            "outpath_samples": "/content/desktop",
            "prompt": "lora:koreanDollLikeness_v15:0.66, best quality, ultra high res, (photorealistic:1.4), 1girl, beige sweater, black choker, smile, laughing, bare shoulders, solo focus, ((full body), (brown hair:1), looking at viewer",
            "negative_prompt": "paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans, (ugly:1.331), (duplicate:1.331), (morbid:1.21), (mutilated:1.21), (tranny:1.331), mutated hands, (poorly drawn hands:1.331), blurry, 3hands,4fingers,3arms, bad anatomy, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts,poorly drawn face,mutation,deformed",
            "sampler_name": "DPM++ SDE Karras",
            "steps": 20,  # 25
            "cfg_scale": 8,
            "width": 512,
            "height": 768,
            "seed": -1,
        }
        print("&&&&&&&&&&&&&&&&&&&&&&&&", request)
        if request.prompt:
            prompt = request.prompt
            print("get prompt from request: ", prompt)
            args["prompt"] = prompt

        p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
        processed = process_images(p)
        single_image_b64 = encode_pil_to_base64(processed.images[0]).decode('utf-8')
        return {
            "img_data": single_image_b64,
            "parameters": processed.images[0].info.get('parameters', ""),
        }


component = PyTorchServer(
    cloud_compute=L.CloudCompute('gpu', disk_size=20, idle_timeout=30)
)
# lightning run app app.py --cloud
app = L.LightningApp(component)
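
# --- Example client call (a minimal sketch, not part of the app) ---------------------
# Assumes PythonServer's default POST /predict route and that the app is already
# running; `<host>:<port>` is a placeholder for the URL printed by `lightning run app`.
# The request/response fields match _DefaultInputData / _DefaultOutputData above.
#
#   import base64
#   import requests
#
#   resp = requests.post(
#       "http://<host>:<port>/predict",
#       json={"prompt": "best quality, ultra high res, 1girl, looking at viewer"},
#   )
#   payload = resp.json()
#   with open("out.png", "wb") as f:
#       f.write(base64.b64decode(payload["img_data"]))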