timothymhowe commited on
Commit
80b2ac3
1 Parent(s): de09dd9

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +36 -0
handler.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import torch
3
+ from torch import autocast
4
+ from diffusers import StableDiffusionPipeline
5
+ import base64
6
+ from io import BytesIO
7
+
8
+
9
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
+
11
+ if device.type != 'cuda':
12
+ raise ValueError("Must run SDXL on a GPU instance.")
13
+
14
+
15
+ class EndpointHandler():
16
+
17
+
18
+ def __init__(self,path=""):
19
+ self.pipe = StableDiffusionPipeline.from_pretrained(path,torch_dtype=torch.float16)
20
+ self.pipe = self.pipe.to(device)
21
+
22
+
23
+ def __call__(self):
24
+ """
25
+ """
26
+
27
+ inputs = data.pop("inputs",data)
28
+
29
+ with autocast(device.type):
30
+ image = self.pipe(inputs,guidance_scale=9)["sample"][0]
31
+
32
+ buffer = BytesIO()
33
+ image.save(buffer, format="JPEG")
34
+ img_str = base64.b64decode(buffer.getvalue())
35
+
36
+ return {"image": img_str.decode}