Suraj Narayanan Sasikumar commited on
Commit
db7a329
1 Parent(s): 086fc01

handler for endpoint

Browse files
Files changed (1) hide show
  1. handler.py +55 -0
handler.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import torch
3
+ from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLPipeline
4
+ from PIL import Image
5
+ import base64
6
+ from io import BytesIO
7
+
8
+
9
+ # set device
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+
12
+ if device.type != "cuda":
13
+ raise ValueError("need to run on GPU")
14
+
15
+
16
+ class EndpointHandler:
17
+ def __init__(self, path=""):
18
+ # load StableDiffusionInpaintPipeline pipeline
19
+ self.pipe = StableDiffusionXLPipeline.from_pretrained(
20
+ path, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
21
+ )
22
+ # use DPMSolverMultistepScheduler
23
+ self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
24
+ self.pipe.scheduler.config
25
+ )
26
+ # move to device
27
+ self.pipe = self.pipe.to(device)
28
+
29
+ def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
30
+ """
31
+ :param data: A dictionary contains `inputs` and optional `image` field.
32
+ :return: A dictionary with `image` field contains image in base64.
33
+ """
34
+ prompt = data.pop("inputs", data)
35
+
36
+ # hyperparamters
37
+ num_inference_steps = data.pop("num_inference_steps", 30)
38
+ guidance_scale = data.pop("guidance_scale", 8)
39
+ negative_prompt = data.pop("negative_prompt", None)
40
+ height = data.pop("height", None)
41
+ width = data.pop("width", None)
42
+
43
+ # run inference pipeline
44
+ out = self.pipe(
45
+ prompt,
46
+ num_inference_steps=num_inference_steps,
47
+ guidance_scale=guidance_scale,
48
+ num_images_per_prompt=1,
49
+ negative_prompt=negative_prompt,
50
+ height=height,
51
+ width=width,
52
+ )
53
+
54
+ # return first generate PIL image
55
+ return out.images[0]