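# Instruct-Pix2pix / handler.py
#
# The lines below are page metadata that was captured together with the source
# (apparently from a web file viewer); they are not part of the program:
#   avatar alt text : Zhibinhong's picture
#   last commit     : "Update handler.py" (cebbecd)
#   viewer links    : raw, history, blame, contribute, delete
#   scan status     : No virus
#   file size       : 1.2 kB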
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler
import base64
from io import BytesIO
from PIL import Image
import json
class EndpointHandler():
    """Serve InstructPix2Pix image edits from a base64-encoded request.

    The handler loads the ``timbrooks/instruct-pix2pix`` pipeline once at
    construction time and then, per call, edits one input image according to
    a text instruction and returns the result as a base64-encoded PNG.
    """

    def __init__(self, path=""):
        """Load the pipeline in float16, move it to CUDA and set the scheduler.

        Args:
            path: Accepted for interface compatibility; it is not used here
                because the model is always loaded from the fixed hub id.
        """
        model_id = "timbrooks/instruct-pix2pix"
        self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            model_id, torch_dtype=torch.float16, safety_checker=None
        )
        self.pipe.to("cuda")
        # Replace the default scheduler with Euler-ancestral, reusing the
        # configuration of the scheduler that shipped with the checkpoint.
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            self.pipe.scheduler.config
        )

    def __call__(self, data):
        """Edit one image according to a text instruction.

        Args:
            data: Request payload of the form
                ``{"inputs": {"image": <base64 image>, "text": <instruction>}}``.
                The payload is read only; it is not modified.

        Returns:
            ``{"image": <base64-encoded PNG as a str>}``.

        Raises:
            KeyError: If ``data`` has no ``"inputs"`` entry.
            ValueError: If ``"image"`` or ``"text"`` is missing from the inputs.
        """
        info = data['inputs']

        # Look the fields up without popping so the caller's payload is left
        # intact, and fail early with a clear message instead of passing the
        # whole request dict on as the image or prompt.
        image = info.get("image")
        prompt = info.get("text")
        missing = [
            key for key, value in (("image", image), ("text", prompt))
            if value is None
        ]
        if missing:
            raise ValueError(
                "missing required input field(s): " + ", ".join(missing)
            )

        raw_image = Image.open(BytesIO(base64.b64decode(image))).convert('RGB')
        images = self.pipe(
            prompt,
            image=raw_image,
            num_inference_steps=25,
            image_guidance_scale=1,
        ).images
        img = images[0]

        # Encode through an in-memory buffer rather than a fixed file on disk:
        # a shared "./1.png" can be overwritten by a concurrent request and
        # requires a writable working directory.
        buffer = BytesIO()
        img.save(buffer, format="PNG")
        encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8')
        return {'image': encoded_string}
if __name__ == "__main__":
    # Script entry point: build a handler, which loads the model as a side
    # effect of construction. No request is issued here.
    my_handler = EndpointHandler(path=".")