File size: 2,702 Bytes
42ae52a
 
 
 
 
 
 
 
 
 
030e485
 
 
 
 
 
 
 
 
42ae52a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import gc
from typing import List, Optional, Dict, Any

import torch
from pydantic import BaseModel
from PIL import Image
from diffusers.schedulers import *
from controlnet_aux.processor import Processor


class ControlNetReq(BaseModel):
    """ControlNet settings for one request, expressed as parallel lists.

    ``controlnets[i]`` names the control type used to preprocess
    ``control_images[i]`` (see ``get_controlnet_images`` in this module, which
    accepts "canny", "depth" and "pose").
    NOTE(review): ``controlnet_conditioning_scale[i]`` is presumably the weight
    for ``controlnets[i]`` -- confirm against the pipeline that consumes it.
    """
    controlnets: List[str] # control type names, e.g. ["canny", "depth", "pose"]; NOTE: "tile" is rejected by get_controlnet_images
    control_images: List[Image.Image]
    controlnet_conditioning_scale: List[float]
    
    class Config:
        # PIL Image.Image is not a type pydantic knows how to validate natively.
        arbitrary_types_allowed=True


class BaseReq(BaseModel):
    """Common parameters for a text-to-image generation request.

    Every field has a default, so an empty request is valid. Subclasses add
    image inputs for img2img and inpainting.
    """
    model: str = ""
    prompt: str = ""
    fast_generation: Optional[bool] = True
    loras: Optional[list] = [] # pydantic copies field defaults per instance, so this list is not shared between requests
    resize_mode: Optional[str] = "resize_and_fill" # resize_only, crop_and_resize, resize_and_fill (the modes handled by resize_images)
    scheduler: Optional[str] = "euler_fl" # scheduler key; NOTE(review): mapping to a diffusers scheduler is not visible here
    height: int = 1024
    width: int = 1024
    num_images_per_prompt: int = 1
    num_inference_steps: int = 8
    guidance_scale: float = 3.5
    seed: Optional[int] = 0
    refiner: bool = False
    vae: bool = True
    controlnet_config: Optional[ControlNetReq] = None # None disables ControlNet conditioning (presumably; confirm in caller)
    custom_addons: Optional[Dict[Any, Any]] = None
    
    class Config:
        # Needed because the nested ControlNetReq carries PIL Image.Image fields.
        arbitrary_types_allowed=True


class BaseImg2ImgReq(BaseReq):
    """Generation request that starts from an input image.

    Adds a required source ``image`` to every ``BaseReq`` field.
    """
    image: Image.Image # required source image
    strength: float = 1.0 # NOTE(review): presumably img2img denoising strength in [0, 1] -- confirm
    
    class Config:
        # PIL Image.Image is not a type pydantic knows how to validate natively.
        arbitrary_types_allowed=True


class BaseInpaintReq(BaseImg2ImgReq):
    """Inpainting request: an img2img request plus a required mask image."""
    mask_image: Image.Image # NOTE(review): which mask values mark the repaint region is decided by the consuming pipeline -- confirm
    
    class Config:
        # PIL Image.Image is not a type pydantic knows how to validate natively.
        arbitrary_types_allowed=True


def resize_images(images: List[Image.Image], height: int, width: int, resize_mode: str):
    """Return the images adjusted to ``width`` x ``height`` according to ``resize_mode``.

    Args:
        images: PIL images to transform. The input list itself is not modified.
        height: Target height in pixels.
        width: Target width in pixels.
        resize_mode: One of:
            - ``"resize_only"``: ``Image.resize`` with PIL's default filter.
            - ``"crop_and_resize"``: crop the box ``(0, 0, width, height)``
              anchored at the top-left corner; no scaling is applied.
            - ``"resize_and_fill"``: ``Image.resize`` with LANCZOS resampling.
            Any other value (including ``None``) leaves each image unchanged.

    Returns:
        A new list holding one output image per input image, in input order.
    """
    # NOTE(review): the mode names suggest aspect-ratio-aware behaviour
    # (crop-then-scale, scale-then-pad) that is not implemented here.
    resized = []
    for image in images:
        if resize_mode == "resize_only":
            image = image.resize((width, height))
        elif resize_mode == "crop_and_resize":
            image = image.crop((0, 0, width, height))
        elif resize_mode == "resize_and_fill":
            image = image.resize((width, height), Image.Resampling.LANCZOS)
        # Collect the (possibly new) image; rebinding the loop variable alone
        # would discard the PIL result, since resize/crop return new objects.
        resized.append(image)
    return resized


def get_controlnet_images(controlnet_config: ControlNetReq, height: int, width: int, resize_mode: str):
    """Preprocess each control image with the annotator for its ControlNet type.

    Args:
        controlnet_config: Request whose ``controlnets`` and ``control_images``
            are parallel lists: the i-th name selects the preprocessor applied
            to the i-th image.
        height: Target height forwarded to ``resize_images``.
        width: Target width forwarded to ``resize_images``.
        resize_mode: Resize strategy forwarded to ``resize_images``.

    Returns:
        List of preprocessed PIL images, one per control image, in input order.

    Raises:
        ValueError: if ``controlnets`` and ``control_images`` have different
            lengths, or if a name is not one of "canny", "depth" or "pose".
    """
    names = controlnet_config.controlnets
    images = controlnet_config.control_images

    # zip() would silently drop unmatched trailing entries, hiding a malformed
    # request, so reject length mismatches explicitly.
    if len(names) != len(images):
        raise ValueError(
            f"Got {len(names)} controlnets but {len(images)} control images"
        )

    # Request-level ControlNet name -> controlnet_aux processor id.
    processor_ids = {
        "canny": "canny",
        "depth": "depth_midas",
        "pose": "openpose_full",
    }
    # Validate every name up front so no preprocessing work is wasted on a
    # request that will fail anyway.
    for name in names:
        if name not in processor_ids:
            raise ValueError(f"Invalid Controlnet: {name}")

    control_images = resize_images(images, height, width, resize_mode)

    # Build each processor at most once per call and reuse it for repeated
    # names; NOTE(review): construction presumably loads detector weights for
    # depth/pose, which is the costly part -- confirm.
    processors = {}
    response_images = []
    for name, image in zip(names, control_images):
        if name not in processors:
            processors[name] = Processor(processor_ids[name])
        response_images.append(processors[name](image, to_pil=True))

    return response_images


def cleanup(pipeline, loras=None):
    """Free resources after a generation run.

    Unloads LoRA weights from ``pipeline`` when ``loras`` is truthy, then runs
    Python's garbage collector and calls ``torch.cuda.empty_cache()``.

    Args:
        pipeline: Object exposing ``unload_lora_weights()``; it is only touched
            when ``loras`` is truthy.
        loras: LoRAs used for the run; ``None`` or an empty collection skips
            the unload step.
    """
    has_loras = bool(loras)
    if has_loras:
        pipeline.unload_lora_weights()
    # Collect before emptying the cache so that objects freed by the collector
    # have a chance to give their cached CUDA blocks back.
    gc.collect()
    torch.cuda.empty_cache()