tested and added
Browse files- handler.py +21 -12
handler.py
CHANGED
@@ -23,31 +23,40 @@ class EndpointHandler():
|
|
23 |
self.pipe = self.pipe.to(device)
|
24 |
|
25 |
|
26 |
-
def __call__(self, data: Any) ->
|
27 |
"""
|
28 |
Args:
|
29 |
-
data (
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
"""
|
34 |
-
|
35 |
negative_prompt = data.pop("negative_prompt", None)
|
36 |
height = data.pop("height", 512)
|
37 |
width = data.pop("width", 512)
|
|
|
38 |
guidance_scale = data.pop("guidance_scale", 7.5)
|
39 |
|
40 |
-
#
|
41 |
with autocast(device.type):
|
42 |
if negative_prompt is None:
|
43 |
-
image = self.pipe(prompt
|
|
|
44 |
else:
|
45 |
-
image = self.pipe(prompt
|
|
|
46 |
|
47 |
-
#
|
48 |
buffered = BytesIO()
|
49 |
image.save(buffered, format="JPEG")
|
50 |
img_str = base64.b64encode(buffered.getvalue())
|
51 |
|
52 |
-
#
|
53 |
-
return {"image": img_str.decode()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
self.pipe = self.pipe.to(device)
|
24 |
|
25 |
|
26 |
+
def __call__(self, data: Any) -> Dict[str, str]:
|
27 |
"""
|
28 |
Args:
|
29 |
+
data (Any): Includes the input data and the parameters for the inference.
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
Dict[str, str]: Dictionary with the base64 encoded image.
|
33 |
"""
|
34 |
+
positive_prompt = data.pop("positive_prompt", "")
|
35 |
negative_prompt = data.pop("negative_prompt", None)
|
36 |
height = data.pop("height", 512)
|
37 |
width = data.pop("width", 512)
|
38 |
+
|
39 |
guidance_scale = data.pop("guidance_scale", 7.5)
|
40 |
|
41 |
+
# Run inference pipeline
|
42 |
with autocast(device.type):
|
43 |
if negative_prompt is None:
|
44 |
+
image = self.pipe(prompt=positive_prompt, height=height, width=width, guidance_scale=float(guidance_scale))
|
45 |
+
image = image.images[0]
|
46 |
else:
|
47 |
+
image = self.pipe(prompt=positive_prompt, negative_prompt=negative_prompt, height=height, width=width, guidance_scale=float(guidance_scale))
|
48 |
+
image = image.images[0]
|
49 |
|
50 |
+
# Encode image as base64
|
51 |
buffered = BytesIO()
|
52 |
image.save(buffered, format="JPEG")
|
53 |
img_str = base64.b64encode(buffered.getvalue())
|
54 |
|
55 |
+
# Postprocess the prediction
|
56 |
+
return {"image": img_str.decode()}
|
57 |
+
|
58 |
+
def decode_base64_image(self, image_string):
|
59 |
+
base64_image = base64.b64decode(image_string)
|
60 |
+
buffer = BytesIO(base64_image)
|
61 |
+
image = Image.open(buffer)
|
62 |
+
return image
|