# ControlNet-endpoint / handler.py
from typing import Any, Dict
import base64
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
import torch
import controlnet_hinter
# Set the device; this handler requires a GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    raise ValueError("This handler needs to run on a GPU")
# Set the mixed-precision dtype: bfloat16 on compute capability >= 8 (Ampere or newer), float16 otherwise
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
# ControlNet variants supported by this endpoint (depth only, for now)
CONTROLNET_MAPPING = {
    "depth": {
        "model_id": "lllyasviel/sd-controlnet-depth",
        "hinter": controlnet_hinter.hint_depth,
    },
}
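# More control types could be registered the same way. A sketch (assuming the
# standard lllyasviel SD 1.5 checkpoints on the Hub and the matching
# controlnet_hinter preprocessors; neither is wired up in this handler):
#
#     "canny": {
#         "model_id": "lllyasviel/sd-controlnet-canny",
#         "hinter": controlnet_hinter.hint_canny,
#     },
#     "openpose": {
#         "model_id": "lllyasviel/sd-controlnet-openpose",
#         "hinter": controlnet_hinter.hint_openpose,
#     },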
class EndpointHandler():
    def __init__(self, path=""):
        # Load the default ControlNet (depth) weights
        self.control_type = "depth"
        self.controlnet = ControlNetModel.from_pretrained(
            CONTROLNET_MAPPING[self.control_type]["model_id"],
            torch_dtype=dtype,
        ).to(device)
        # Load the Stable Diffusion pipeline wired to that ControlNet
        self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            self.stable_diffusion_id,
            controlnet=self.controlnet,
            torch_dtype=dtype,
            safety_checker=None,
        ).to(device)
        # Fixed-seed generator so outputs are reproducible across calls
        self.generator = torch.Generator(device="cpu").manual_seed(3)
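        # Optional (a sketch, not part of the original handler): diffusers'
        # attention slicing lowers peak VRAM at some speed cost on smaller GPUs.
        # self.pipe.enable_attention_slicing()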
    def __call__(self, data: Any) -> Dict[str, str]:
        """Run one depth-ControlNet-conditioned generation.

        Expects an A1111-style JSON payload, for example:

            {
                "prompt": "a beautiful landscape",
                "negative_prompt": "blur",
                "width": 1024,
                "height": 1024,
                "steps": 30,
                "cfg_scale": 7,
                "alwayson_scripts": {
                    "controlnet": {
                        "args": [
                            {
                                "enabled": True,
                                "input_image": "image in base64",
                                "model": "control_sd15_depth [fef5e48e]",
                                "control_mode": "Balanced"
                            }
                        ]
                    }
                }
            }
        """
        # Extract generation parameters from the payload
        prompt = data.get("prompt", None)
        negative_prompt = data.get("negative_prompt", None)
        width = data.get("width", None)
        height = data.get("height", None)
        num_inference_steps = data.get("steps", 30)
        guidance_scale = data.get("cfg_scale", 7)
        # Extract the first ControlNet unit from the A1111-style args list
        controlnet_config = data.get("alwayson_scripts", {}).get("controlnet", {}).get("args", [{}])[0]
        # StableDiffusionControlNetPipeline needs the conditioning image up
        # front, so decode the ControlNet input and derive its depth hint
        # before generation rather than after.
        if not controlnet_config.get("enabled", False):
            raise ValueError("An enabled ControlNet unit with an input_image is required")
        input_image = self.decode_base64_image(controlnet_config.get("input_image", ""))
        control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](input_image)
        # The payload's "model" and "control_mode" fields are accepted but not
        # used: only the depth ControlNet loaded in __init__ is applied.

        # Run the ControlNet-conditioned Stable Diffusion process
        out = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=control_image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            height=height,
            width=width,
            controlnet_conditioning_scale=1.0,
            generator=self.generator,
        )

        # Return the generated image as base64
        return {"image": self.encode_base64_image(out.images[0])}
    def encode_base64_image(self, image):
        # Encode a PIL image as a base64 PNG string
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def decode_base64_image(self, image_string):
        # Decode a base64 string into a PIL image, normalized to RGB so the
        # hinter always receives three channels
        buffer = BytesIO(base64.b64decode(image_string))
        return Image.open(buffer).convert("RGB")
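

# Minimal local smoke test (a sketch: assumes a CUDA machine and a local file
# named "input.jpg"; the filename and prompt are placeholders, not part of the
# deployed endpoint).
if __name__ == "__main__":
    handler = EndpointHandler()
    with open("input.jpg", "rb") as f:
        input_b64 = base64.b64encode(f.read()).decode("utf-8")
    payload = {
        "prompt": "a beautiful landscape",
        "negative_prompt": "blur",
        "width": 512,
        "height": 512,
        "steps": 30,
        "cfg_scale": 7,
        "alwayson_scripts": {
            "controlnet": {
                "args": [{"enabled": True, "input_image": input_b64}]
            }
        },
    }
    result = handler(payload)
    # Round-trip the returned base64 PNG back to an image file
    Image.open(BytesIO(base64.b64decode(result["image"]))).save("output.png")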