|
from typing import Dict, List, Any |
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
|
import torch |
|
import os |
|
from huggingface_hub import HfApi |
|
from pathlib import Path |
|
from diffusers.utils import load_image |
|
from PIL import Image |
|
import numpy as np |
|
from controlnet_aux import PidiNetDetector, HEDdetector |
|
from diffusers import ( |
|
ControlNetModel, |
|
StableDiffusionControlNetPipeline, |
|
UniPCMultistepScheduler, |
|
) |
|
from io import BytesIO |
|
import base64 |
|
import warnings |
|
# Suppress the "TypedStorage is deprecated" UserWarning (presumably emitted by
# torch while the model weights are being loaded) so it does not flood the logs.
warnings.filterwarnings(
    "ignore",
    message="TypedStorage is deprecated",
    category=UserWarning,
)

# Hub repository id of the scribble-conditioned ControlNet (SD 1.5 family)
# that EndpointHandler loads at start-up.
checkpoint = "lllyasviel/control_v11p_sd15_scribble"
|
|
|
|
|
class EndpointHandler:
    """Inference Endpoints handler: scribble image + prompt -> generated image.

    The request image is converted to an HED scribble control map, which
    conditions a Stable Diffusion + ControlNet pipeline; the first generated
    image is returned as base64-encoded JPEG text.
    """

    # Base Stable Diffusion checkpoint the ControlNet is attached to.
    # NOTE(review): the "runwayml" repository appears to have been removed from
    # the Hub; the mirror "stable-diffusion-v1-5/stable-diffusion-v1-5" may be
    # needed instead -- confirm before deploying.
    base_model_id = "runwayml/stable-diffusion-v1-5"

    # Defaults used when a request does not override them.
    default_num_inference_steps = 30
    default_seed = 0

    def __init__(self, path=""):
        """Load the ControlNet, the HED annotator and the diffusion pipeline once.

        Args:
            path: Path supplied by the endpoint runtime. Unused: every model
                here is loaded by its Hub repository id.
        """
        self.model = ControlNetModel.from_pretrained(checkpoint, torch_dtype=torch.float16)
        self.processor = HEDdetector.from_pretrained("lllyasviel/Annotators")

        # Build the pipeline here instead of inside __call__, which would
        # otherwise reload the full base-model weights on every request.
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            self.base_model_id, controlnet=self.model, torch_dtype=torch.float16
        )
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
        self.pipe.enable_model_cpu_offload()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Generate an image conditioned on a prompt and a scribble image.

        Args:
            data: Request payload, either holding the fields below directly or
                nesting them under an ``"inputs"`` key:

                - ``image_base64`` (str, required): base64-encoded input image.
                - ``prompt`` (str, required): text prompt.
                - ``num_inference_steps`` (int, optional): denoising steps,
                  default 30.
                - ``seed`` (int, optional): generator seed, default 0.

        Returns:
            ``{"result": <base64-encoded JPEG of the generated image>}``.

        Raises:
            KeyError: if ``image_base64`` or ``prompt`` is missing.
        """
        # ``get`` rather than ``pop`` so the caller's payload is not mutated.
        inputs = data.get("inputs", data)
        image_base64 = inputs["image_base64"]
        prompt = inputs["prompt"]
        num_inference_steps = int(
            inputs.get("num_inference_steps", self.default_num_inference_steps)
        )
        seed = int(inputs.get("seed", self.default_seed))

        image = self._decode_image(image_base64)
        control_image = self.processor(image, scribble=True)

        # Dedicated CPU generator: deterministic for a given seed, without
        # re-seeding torch's process-wide default RNG like torch.manual_seed.
        generator = torch.Generator(device="cpu").manual_seed(seed)
        generated = self.pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            generator=generator,
            image=control_image,
        ).images[0]

        return {"result": self._encode_jpeg(generated)}

    @staticmethod
    def _decode_image(image_base64: str) -> "Image.Image":
        """Decode base64 text into a PIL image in a mode the annotator can handle."""
        image = Image.open(BytesIO(base64.b64decode(image_base64)))
        # Palette ("P"), "LA", "CMYK", 16-bit, etc. do not map to plain
        # grayscale/RGB/RGBA pixel arrays; normalise them to RGB. RGB, RGBA and
        # L are passed through unchanged, as before.
        # NOTE(review): assumes the HED annotator expects 1/3/4-channel uint8
        # input -- confirm against the installed controlnet_aux version.
        if image.mode not in ("RGB", "RGBA", "L"):
            image = image.convert("RGB")
        return image

    @staticmethod
    def _encode_jpeg(image) -> str:
        """Serialise a PIL image as base64-encoded JPEG text."""
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode()