from typing import Dict, Any
import base64
import warnings
from io import BytesIO

import torch
from PIL import Image
from controlnet_aux import HEDdetector
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    UniPCMultistepScheduler,
)

warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")

checkpoint = "lllyasviel/control_v11p_sd15_scribble"


class EndpointHandler:
    def __init__(self, path=""):
        # Load the ControlNet scribble model and the HED edge detector once at startup.
        self.controlnet = ControlNetModel.from_pretrained(checkpoint, torch_dtype=torch.float16)
        self.processor = HEDdetector.from_pretrained("lllyasviel/Annotators")

        # Build the full pipeline here rather than per request, so the Stable
        # Diffusion weights are loaded only once instead of on every call.
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
        )
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
        self.pipe.enable_model_cpu_offload()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """
        Args:
            data (:obj:`dict`): The payload, whose ``inputs`` dict holds a
                base64-encoded input image (``image_base64``) and a text
                ``prompt``.

        Returns:
            A dict with the generated image as a base64-encoded JPEG under
            the ``result`` key.
        """
        # Process input.
        inputs = data.pop("inputs", data)
        image_base64 = inputs["image_base64"]
        prompt = inputs["prompt"]

        # Preprocess: decode the image and run the HED detector in scribble
        # mode to produce the control image.
        image = Image.open(BytesIO(base64.b64decode(image_base64)))
        control_image = self.processor(image, scribble=True)

        # Generate with a fixed seed for reproducible outputs.
        generator = torch.manual_seed(0)
        image = self.pipe(
            prompt,
            num_inference_steps=30,
            generator=generator,
            image=control_image,
        ).images[0]

        # Postprocess: encode the generated image as a base64 JPEG string.
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue())
        return {"result": img_str.decode()}
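
# A minimal local smoke test for the handler above, sketching how a client
# would build the payload the docstring describes. This is an illustrative
# assumption, not part of the deployed endpoint: the file names
# "scribble.png" and "output.jpg" and the prompt text are hypothetical, and
# running it locally requires a CUDA GPU and a first-time model download.
if __name__ == "__main__":
    handler = EndpointHandler()

    # Encode a local sketch image the same way a client would before POSTing
    # it to the endpoint.
    with open("scribble.png", "rb") as f:
        image_base64 = base64.b64encode(f.read()).decode()

    payload = {
        "inputs": {
            "image_base64": image_base64,
            "prompt": "a royal chamber with a fancy bed",
        }
    }
    result = handler(payload)

    # Decode the base64 JPEG returned by the handler and save it to disk.
    output = Image.open(BytesIO(base64.b64decode(result["result"])))
    output.save("output.jpg")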