# File size: 5,122 Bytes
# 1976a91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# inference handler for lightning ai

import re
import os
import logging
# import json
from pydantic import BaseModel
from typing import Any, Dict, Optional, TYPE_CHECKING
from dataclasses import dataclass
# Silence xformers' "A matching Triton is not available" log records; every
# other message from the "xformers" logger still passes through the filter.
logging.getLogger("xformers").addFilter(
    lambda log_record: "A matching Triton is not available" not in log_record.getMessage()
)

import lightning as L
from lightning.app.components.serve import PythonServer, Text
from lightning.app import BuildConfig


# Request schema for PyTorchServer (its default `input_type`).
# `predict` reads only `prompt`; a falsy value keeps predict's default prompt.
# (Comments rather than a docstring: pydantic would copy a class docstring
# into the generated schema description.)
class _DefaultInputData(BaseModel):
    prompt: str  # text prompt for txt2img generation

# Response schema for PyTorchServer (its default `output_type`); matches the
# dict returned by `PyTorchServer.predict`.
class _DefaultOutputData(BaseModel):
    img_data: str  # first generated image, base64-encoded and decoded to text
    parameters: str  # the image's info["parameters"] entry, or "" when absent


@dataclass
class CustomBuildConfig(BuildConfig):
    """Cloud build config that downloads model files and installs dependencies.

    The commands from ``build_commands`` fetch the checkpoint, VAE and two Lora
    files into ``/content/models/...``, install pinned CUDA 11.7 wheels of
    torch/torchvision, and install the libGL/libglib system packages.
    """

    def build_commands(self):
        """Return the shell commands to run, in order, while building the image.

        Returns:
            list[str]: four ``wget`` downloads (checkpoint, VAE, two Lora
            files), then two ``pip3 install`` commands, then the ``apt-get``
            update/install commands.
        """
        dir_path = "/content/"
        hf_models_url = "https://huggingface.co/Hardy01/chill_watcher/resolve/main/models"
        # (target directory under dir_path, file path under hf_models_url)
        downloads = [
            ("models/Stable-diffusion", "Stable-diffusion/chilloutmix_NiPrunedFp32Fix.safetensors"),
            ("models/VAE", "VAE/vae-ft-mse-840000-ema-pruned.ckpt"),
            ("models/Lora", "Lora/koreanDollLikeness_v10.safetensors"),
            ("models/Lora", "Lora/taiwanDollLikeness_v10.safetensors"),
        ]
        download_cmds = [
            "wget -P {} {}/{}".format(os.path.join(dir_path, target_dir), hf_models_url, remote_file)
            for target_dir, remote_file in downloads
        ]

        torch_index_url = "https://download.pytorch.org/whl/cu117"
        pip_cmds = [
            "pip3 install {} --extra-index-url {}".format(package, torch_index_url)
            for package in ("torch==1.13.1+cu117", "torchvision==0.14.1+cu117")
        ]

        # libGL / libglib address "ImportError: libGL.so.1: cannot open shared
        # object file" (see the linked answer):
        # https://stackoverflow.com/questions/55313610/importerror-libgl-so-1-cannot-open-shared-object-file-no-such-file-or-directo
        # `-y` matters: the build is non-interactive, so without it apt-get can
        # stop at its "Do you want to continue?" prompt and abort the install.
        apt_cmds = [
            "sudo apt-get update",
            "sudo apt-get install -y libgl1-mesa-glx",
            "sudo apt-get install -y libglib2.0-0",
        ]
        return download_cmds + pip_cmds + apt_cmds


class PyTorchServer(PythonServer):
    """Lightning ``PythonServer`` that serves Stable Diffusion txt2img requests.

    A non-empty ``prompt`` in the request replaces the default prompt. The
    response carries the first generated image as base64 text plus that
    image's recorded generation parameters.
    """

    def __init__(
        self,
        input_type: type = _DefaultInputData,
        output_type: type = _DefaultOutputData,
        **kwargs: Any,
    ):
        super().__init__(input_type=input_type, output_type=output_type, **kwargs)

        # Use the custom build config defined in this file for cloud builds.
        self.cloud_build_config = CustomBuildConfig()

    def setup(self):
        """Import torch, normalize its version string, then call ``handler.initialize``."""
        # Dependencies must be installed before these packages can be
        # imported, so the imports happen here rather than at module level.
        import torch

        # Truncate the version number of nightly/local builds of PyTorch so it
        # does not cause exceptions with CodeFormer or safetensors.
        if ".dev" in torch.__version__ or "+git" in torch.__version__:
            version_match = re.search(r'[\d.]+[\d]', torch.__version__)
            # Leave the version untouched if no numeric prefix is found,
            # instead of crashing on `None.group`.
            if version_match is not None:
                torch.__long_version__ = torch.__version__
                torch.__version__ = version_match.group(0)

        from handler import initialize
        initialize()

    def predict(self, request):
        """Generate one txt2img image and return it with its parameters.

        Args:
            request: presumably an instance of this server's ``input_type``;
                only its ``prompt`` attribute is read. A falsy prompt keeps
                the default prompt below.

        Returns:
            dict with ``"img_data"`` (base64 text of the first generated image)
            and ``"parameters"`` (that image's ``info["parameters"]``, or "").
        """
        # Imported lazily, like the imports in setup().
        from modules.api.api import encode_pil_to_base64
        from modules import shared
        from modules.processing import StableDiffusionProcessingTxt2Img, process_images

        logger = logging.getLogger(__name__)

        # Default generation settings; only "prompt" is overridden per request.
        # NOTE(review): the default prompt names "koreanDollLikeness_v15", but
        # CustomBuildConfig downloads koreanDollLikeness_v10, and the tag is
        # not wrapped in <...> (presumably the webui expects
        # "<lora:name:weight>") -- confirm the Lora is actually applied.
        args = {
            "do_not_save_samples": True,
            "do_not_save_grid": True,
            "outpath_samples": "/content/desktop",
            "prompt": "lora:koreanDollLikeness_v15:0.66, best quality, ultra high res, (photorealistic:1.4), 1girl, beige sweater, black choker, smile, laughing, bare shoulders, solo focus, ((full body), (brown hair:1), looking at viewer",
            "negative_prompt": "paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans, (ugly:1.331), (duplicate:1.331), (morbid:1.21), (mutilated:1.21), (tranny:1.331), mutated hands, (poorly drawn hands:1.331), blurry, 3hands,4fingers,3arms, bad anatomy, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts,poorly drawn face,mutation,deformed",
            "sampler_name": "DPM++ SDE Karras",
            "steps": 20,  # 25
            "cfg_scale": 8,
            "width": 512,
            "height": 768,
            "seed": -1,
        }
        # Log the raw request at debug level instead of unconditionally
        # printing it to stdout.
        logger.debug("received request: %s", request)
        if request.prompt:
            prompt = request.prompt
            logger.info("get prompt from request: %s", prompt)
            args["prompt"] = prompt
        p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
        processed = process_images(p)
        # Only the first generated image is returned.
        first_image = processed.images[0]
        single_image_b64 = encode_pil_to_base64(first_image).decode('utf-8')
        return {
            "img_data": single_image_b64,
            "parameters": first_image.info.get('parameters', ""),
        }


# Module-level app entry point; deploy with: lightning run app app.py --cloud
component = PyTorchServer(
    cloud_compute=L.CloudCompute("gpu", disk_size=20, idle_timeout=30),
)
app = L.LightningApp(component)