# v1_1_scribble / handler.py
# Last commit 29011eb by serhatderya: "Update handler.py".
from typing import Dict, List, Any
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import os
from huggingface_hub import HfApi
from pathlib import Path
from diffusers.utils import load_image
from PIL import Image
import numpy as np
from controlnet_aux import PidiNetDetector, HEDdetector
from diffusers import (
ControlNetModel,
StableDiffusionControlNetPipeline,
UniPCMultistepScheduler,
)
from io import BytesIO
import base64
import warnings
# Silence the "TypedStorage is deprecated" UserWarning (presumably emitted by
# torch internals while weights are loaded, not by this handler — confirm).
warnings.filterwarnings(
    action="ignore",
    message="TypedStorage is deprecated",
    category=UserWarning,
)

# Hub model id of the ControlNet checkpoint that EndpointHandler loads
# (the scribble variant, per the id).
checkpoint = "lllyasviel/control_v11p_sd15_scribble"
class EndpointHandler:
    """Inference Endpoint handler: scribble-conditioned Stable Diffusion.

    The input image is converted to an HED scribble map, which conditions a
    Stable Diffusion v1.5 + ControlNet pipeline together with a text prompt.
    All models are loaded once in ``__init__`` and reused for every request.
    """

    def __init__(self, path=""):
        """Load the ControlNet, the HED annotator and the diffusion pipeline.

        Args:
            path: Local repository path passed in by the endpoint runtime.
                Not used here: every model is fetched from the Hub by id.
        """
        self.model = ControlNetModel.from_pretrained(checkpoint, torch_dtype=torch.float16)
        self.processor = HEDdetector.from_pretrained("lllyasviel/Annotators")

        # Build the full pipeline a single time. Creating it inside __call__
        # reloaded the base Stable Diffusion weights on every request.
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=self.model,
            torch_dtype=torch.float16,
        )
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
        # Offload hooks are installed once and persist across calls.
        self.pipe.enable_model_cpu_offload()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Generate an image from a prompt guided by a scribble of the input.

        Args:
            data: Request payload. Fields are read from ``data["inputs"]`` when
                that key exists, otherwise from ``data`` itself:

                - ``image_base64`` (str, required): base64-encoded input image.
                - ``prompt`` (str, required): text prompt.
                - ``num_inference_steps`` (int, optional): defaults to 30.
                - ``seed`` (int, optional): generator seed, defaults to 0.

        Returns:
            ``{"result": <base64-encoded JPEG of the generated image>}``.

        Raises:
            KeyError: if ``image_base64`` or ``prompt`` is missing.
        """
        # Read without mutating the caller's payload (the old ``pop`` removed
        # the "inputs" key from ``data``).
        inputs = data.get("inputs", data)
        image_base64 = inputs["image_base64"]
        prompt = inputs["prompt"]
        num_inference_steps = int(inputs.get("num_inference_steps", 30))
        seed = int(inputs.get("seed", 0))

        image = Image.open(BytesIO(base64.b64decode(image_base64)))
        if image.mode not in ("RGB", "RGBA", "L"):
            # Palette, LA, bilevel, CMYK, ... images would otherwise reach the
            # annotator as index/2-channel/bool arrays; RGBA keeps any
            # transparency. RGB, RGBA and L are passed through unchanged.
            image = image.convert("RGBA")
        control_image = self.processor(image, scribble=True)

        # Dedicated generator: same seed-0 default as before, but without
        # reseeding torch's global RNG the way ``torch.manual_seed`` does.
        generator = torch.Generator(device="cpu").manual_seed(seed)
        generated = self.pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            generator=generator,
            image=control_image,
        ).images[0]

        # Encode the result as base64 JPEG text for the JSON response.
        buffered = BytesIO()
        generated.save(buffered, format="JPEG")
        return {"result": base64.b64encode(buffered.getvalue()).decode()}