Seg2Sat-endpoint / handler.py
rgres's picture
Upload folder using huggingface_hub
15e781b
raw history blame
No virus
2.24 kB
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from typing import Dict, List, Any
from io import BytesIO
from PIL import Image
import base64
import torch
# Select the compute device; this handler refuses to start without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    raise ValueError("need to run on GPU")

# Mixed-precision dtype: bfloat16 is available on compute capability 8.x
# (Ampere) and every newer generation (e.g. 9.x Hopper), so prefer it there;
# older GPUs fall back to float16.
dtype = (
    torch.bfloat16
    if torch.cuda.get_device_capability(device)[0] >= 8
    else torch.float16
)
class EndpointHandler:
    """Inference Endpoint handler for the Seg2Sat ControlNet.

    Loads Stable Diffusion 2.1 base together with the
    ``rgres/Seg2Sat-sd-controlnet`` ControlNet and turns a request holding a
    text prompt plus a base64-encoded conditioning image into a generated
    PIL image.
    """

    def __init__(self, path=""):
        """Load the ControlNet and the Stable Diffusion pipeline onto ``device``.

        :param path: Repository path supplied by the endpoint runtime. Unused
            here: both models are fetched by their hub ids instead.
        """
        self.stable_diffusion_id = "stabilityai/stable-diffusion-2-1-base"
        # Load the ControlNet in the same dtype as the rest of the pipeline.
        # The pipeline's ``torch_dtype`` below is not re-applied to a component
        # that is passed in already built, so a hard-coded float16 here would
        # leave the ControlNet in float16 next to a bfloat16 UNet whenever the
        # module-level ``dtype`` is bfloat16.
        controlnet = ControlNetModel.from_pretrained(
            "rgres/Seg2Sat-sd-controlnet", torch_dtype=dtype
        )
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            self.stable_diffusion_id,
            controlnet=controlnet,
            torch_dtype=dtype,
            safety_checker=None,
        ).to(device)

    def __call__(self, data: Dict[str, Any]) -> "Image.Image | Dict[str, str]":
        """Generate one image conditioned on the supplied image.

        :param data: Request payload (not modified). Recognised keys:
            ``prompt`` (required) - text prompt;
            ``image`` (required) - base64-encoded conditioning image, optionally
            as a ``data:...;base64,`` URL;
            ``steps`` (int-like, default 30) - number of inference steps;
            ``seed`` (int-like, default 0) - seed for the CPU random generator.
        :return: The first image produced by the pipeline, or a dict with an
            ``error`` message when a required field is missing, ``steps`` /
            ``seed`` are not integers, or ``image`` cannot be decoded.
        """
        prompt = data.get("prompt")
        image = data.get("image")
        # Both fields are required: without an image the base64 decode fails,
        # and the pipeline cannot run without a prompt.
        if prompt is None or image is None:
            return {"error": "Please provide a prompt and base64 encoded image."}

        try:
            steps = int(data.get("steps", 30))
            seed = int(data.get("seed", 0))
        except (TypeError, ValueError):
            return {"error": "`steps` and `seed` must be integers."}

        try:
            conditioning_image = self.decode_base64_image(image)
        except (ValueError, OSError):
            # binascii.Error (bad base64) subclasses ValueError; PIL raises an
            # OSError subclass for bytes it cannot identify as an image.
            return {"error": "`image` is not a valid base64 encoded image."}

        # Fresh generator per request so the same seed reproduces the same output.
        generator = torch.Generator(device="cpu").manual_seed(seed)
        result = self.pipe(
            prompt=prompt,
            image=conditioning_image,
            num_inference_steps=steps,
            generator=generator,
        )
        # The pipeline returns a batch of images; this endpoint serves the first.
        return result.images[0]

    @staticmethod
    def _strip_data_url_prefix(image_string):
        """Return the base64 payload of ``image_string``, dropping a ``data:`` URL header."""
        # A plain base64 string can never start with "data:" (":" is not in the
        # base64 alphabet), so this only triggers for real data URLs.
        if isinstance(image_string, str) and image_string.startswith("data:"):
            _, _, image_string = image_string.partition(",")
        return image_string

    def decode_base64_image(self, image_string):
        """Decode a base64 string (optionally a ``data:`` URL) into a PIL image.

        :param image_string: Base64-encoded image bytes, as ``str`` or ``bytes``.
        :return: The image opened by ``PIL.Image.open``.
        :raises binascii.Error: If the payload is not valid base64.
        :raises PIL.UnidentifiedImageError: If the bytes are not a known image format.
        """
        payload = base64.b64decode(self._strip_data_url_prefix(image_string))
        image = Image.open(BytesIO(payload))
        return image