aai / old2 /modules /helpers /common_helpers.py
mantrakp
Refactor flux_helpers.py to enable or disable Vae
37112ef
raw
history blame
2.82 kB
import gc
from typing import List, Optional, Dict, Any
import torch
from pydantic import BaseModel
from PIL import Image
from diffusers.schedulers import *
from controlnet_aux.processor import Processor
class ControlNetReq(BaseModel):
    """ControlNet inputs attached to a generation request.

    ``controlnets`` and ``control_images`` are paired by position when
    ``get_controlnet_images`` (in this file) zips them together.
    ``controlnet_conditioning_scale`` presumably follows the same ordering;
    it is not read in this file -- TODO confirm against the pipeline code.
    """

    # Controlnet type names, e.g. ["canny", "depth"]. get_controlnet_images in
    # this file accepts "canny", "depth" and "pose" and raises ValueError for
    # anything else.
    controlnets: List[str]
    # One input image per entry in ``controlnets``.
    control_images: List[Image.Image]
    controlnet_conditioning_scale: List[float]

    class Config:
        # Image.Image is not a type pydantic validates natively, so arbitrary
        # types have to be allowed for the model to be defined.
        arbitrary_types_allowed=True
class BaseReq(BaseModel):
    """Common fields of a generation request.

    Every field has a default, so an empty ``BaseReq()`` is valid.
    ``BaseImg2ImgReq`` in this file extends this model with an input image.
    How each field is interpreted is decided by the pipeline code that
    consumes the request, which is outside this file.
    """

    model: str = ""
    prompt: str = ""
    negative_prompt: Optional[str] = ""
    fast_generation: Optional[bool] = True
    loras: Optional[list] = []
    embeddings: Optional[list] = None
    # One of "resize_only", "crop_and_resize", "resize_and_fill" -- the modes
    # checked by resize_images in this file.
    resize_mode: Optional[str] = "resize_and_fill"
    scheduler: Optional[str] = "euler_fl"
    height: int = 1024
    width: int = 1024
    num_images_per_prompt: int = 1
    num_inference_steps: int = 8
    clip_skip: Optional[int] = None
    guidance_scale: float = 3.5
    seed: Optional[int] = 0
    refiner: bool = False
    # NOTE(review): presumably toggles use of a separate VAE in the pipeline
    # helpers -- confirm against the consumer of this flag.
    vae: bool = True
    # ControlNet inputs; None when the request carries no ControlNet config.
    controlnet_config: Optional[ControlNetReq] = None
    custom_addons: Optional[Dict[Any, Any]] = None

    class Config:
        # Permit field types that pydantic has no built-in validator for.
        arbitrary_types_allowed=True
class BaseImg2ImgReq(BaseReq):
    """``BaseReq`` plus a required source image and a strength value."""

    # Required: there is no default, so the request must supply an image.
    image: Image.Image
    # Defaults to 1.0. NOTE(review): presumably the img2img denoising
    # strength -- confirm against the pipeline that reads it.
    strength: float = 1.0

    class Config:
        # Image.Image is not a type pydantic validates natively.
        arbitrary_types_allowed=True
class BaseInpaintReq(BaseImg2ImgReq):
    """``BaseImg2ImgReq`` plus a required mask image."""

    # Required: there is no default, so the request must supply a mask.
    mask_image: Image.Image

    class Config:
        # Image.Image is not a type pydantic validates natively.
        arbitrary_types_allowed=True
def resize_images(images: List[Image.Image], height: int, width: int, resize_mode: str):
    """Return new images produced by applying ``resize_mode`` to each input.

    ``Image.resize`` and ``Image.crop`` return new images rather than
    modifying the receiver, so the results are collected into a fresh list.
    The input list and its images are left untouched.

    Args:
        images: Images to process.
        height: Target height in pixels.
        width: Target width in pixels.
        resize_mode: ``"resize_only"`` resizes to ``(width, height)`` with the
            library's default resampling; ``"crop_and_resize"`` crops the box
            ``(0, 0, width, height)``; ``"resize_and_fill"`` resizes to
            ``(width, height)`` with LANCZOS resampling. Any other value
            passes the image through unchanged.

    Returns:
        A list with one processed image per input image, in the same order.
    """
    processed = []
    for image in images:
        if resize_mode == "resize_only":
            image = image.resize((width, height))
        elif resize_mode == "crop_and_resize":
            # NOTE(review): despite the mode name this only crops from the
            # top-left corner; no resize follows. Kept as-is -- confirm intent.
            image = image.crop((0, 0, width, height))
        elif resize_mode == "resize_and_fill":
            image = image.resize((width, height), Image.Resampling.LANCZOS)
        # Collect the transformed image; rebinding the loop variable alone
        # would discard it.
        processed.append(image)
    return processed
def get_controlnet_images(controlnet_config: ControlNetReq, height: int, width: int, resize_mode: str):
    """Run each control image through the processor matching its controlnet.

    The control images are first passed through ``resize_images``. They are
    then paired by position with ``controlnet_config.controlnets``; pairing
    stops at the shorter of the two lists.

    Args:
        controlnet_config: Source of the controlnet names and control images.
        height: Target height forwarded to ``resize_images``.
        width: Target width forwarded to ``resize_images``.
        resize_mode: Resize mode forwarded to ``resize_images``.

    Returns:
        A list with one processor output per (controlnet, image) pair. The
        processor is called with ``to_pil=True``.

    Raises:
        ValueError: If a controlnet name is not "canny", "depth" or "pose".
    """
    # Request-level controlnet name -> controlnet_aux processor id.
    processor_ids = {
        "canny": "canny",
        "depth": "depth_midas",
        "pose": "openpose_full",
    }
    prepared = resize_images(
        controlnet_config.control_images, height, width, resize_mode
    )

    outputs = []
    for controlnet, source in zip(controlnet_config.controlnets, prepared):
        if controlnet not in processor_ids:
            raise ValueError(f"Invalid Controlnet: {controlnet}")
        detector = Processor(processor_ids[controlnet])
        outputs.append(detector(source, to_pil=True))
    return outputs
def cleanup(pipeline, loras = None):
    """Unload LoRA weights if any were given, then release memory.

    Args:
        pipeline: Object exposing ``unload_lora_weights()``; that method is
            called only when ``loras`` is truthy.
        loras: Any value; a falsy value (``None``, empty list, ...) skips the
            unload step.

    Returns:
        None.
    """
    if loras:
        pipeline.unload_lora_weights()
    # Collect Python garbage first, then ask PyTorch to drop its cached CUDA
    # memory.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()