mrcuddle commited on
Commit
736c1ad
·
verified ·
1 Parent(s): c6a254a

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +41 -0
handler.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ from io import BytesIO
3
+ from typing import Dict, Any
4
+
5
+ import torch
6
+ from PIL import Image
7
+ from diffusers import StableDiffusionXLPipeline
8
+
9
+
10
+ # helper decoder
11
+ def decode_base64_image(image_string):
12
+ base64_image = base64.b64decode(image_string)
13
+ buffer = BytesIO(base64_image)
14
+ return Image.open(buffer)
15
+
16
+
17
+ class EndpointHandler:
18
+ def __init__(self, path=""):
19
+ self.pipe = StableDiffusionXLPipeline.from_pretrained("/repository/roses",
20
+ torch_dtype=torch.float16, revision="fp16")
21
+ self.pipe = self.pipe.to("cuda")
22
+
23
+ def __call__(self, data: Any) -> Dict[str, str]:
24
+ """
25
+ Return predict value.
26
+ :param data: A dictionary contains `inputs` and optional `image` field.
27
+ :return: A dictionary with `image` field contains image in base64.
28
+ """
29
+ prompts = data.pop("inputs", None)
30
+ encoded_image = data.pop("image", None)
31
+ init_image = None
32
+ if encoded_image:
33
+ init_image = decode_base64_image(encoded_image)
34
+ init_image.thumbnail((768, 768))
35
+
36
+ image = self.pipe(prompts, init_image=init_image).images[0]
37
+ buffered = BytesIO()
38
+ image.save(buffered, format="png")
39
+ img_str = base64.b64encode(buffered.getvalue())
40
+
41
+ return {"image": img_str.decode()}