Ziga Kerec committed
Commit f67c5e5
1 Parent(s): 2f17452

pushed new endpoint

Files changed (3)
  1. __pycache__/handler.cpython-310.pyc +0 -0
  2. handler.py +63 -14
  3. test.py +3 -1
__pycache__/handler.cpython-310.pyc ADDED
Binary file (2.27 kB).
 
handler.py CHANGED
@@ -1,20 +1,69 @@
 from typing import Dict, List, Any
+import torch
+from diffusers import DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline
+from PIL import Image
+import base64
+from io import BytesIO
+
+
+# set device
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+if device.type != 'cuda':
+    raise ValueError("need to run on GPU")
 
 class EndpointHandler():
     def __init__(self, path=""):
-        # Preload all the elements you are going to need at inference.
-        # pseudo:
-        # self.model= load_model(path)
-        print("Loading model")
+        # load the StableDiffusionInpaintPipeline in half precision
+        self.pipe = StableDiffusionInpaintPipeline.from_pretrained(path, torch_dtype=torch.float16)
+        # use DPMSolverMultistepScheduler
+        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config)
+        # move to device
+        self.pipe = self.pipe.to(device)
 
-    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+    def __call__(self, data: Dict[str, Any]) -> Image.Image:
         """
-        data args:
-            inputs (:obj: `str` | `PIL.Image` | `np.array`)
-            kwargs
-        Return:
-            A :obj:`list` | `dict`: will be serialized and returned
+        :param data: A dictionary containing `inputs` (the prompt) and optional
+                     base64-encoded `image` and `mask_image` fields.
+        :return: The first generated `PIL.Image`, which the endpoint serializes.
         """
-
-        # pseudo
-        # self.model(input)
+        inputs = data.pop("inputs", data)
+        encoded_image = data.pop("image", None)
+        encoded_mask_image = data.pop("mask_image", None)
+
+        # hyperparameters
+        num_inference_steps = data.pop("num_inference_steps", 25)
+        guidance_scale = data.pop("guidance_scale", 7.5)
+        negative_prompt = data.pop("negative_prompt", None)
+        height = data.pop("height", None)
+        width = data.pop("width", None)
+
+        # decode image and mask if both were provided
+        if encoded_image is not None and encoded_mask_image is not None:
+            image = self.decode_base64_image(encoded_image)
+            mask_image = self.decode_base64_image(encoded_mask_image)
+        else:
+            image = None
+            mask_image = None
+
+        # run inference pipeline
+        out = self.pipe(inputs,
+                        image=image,
+                        mask_image=mask_image,
+                        num_inference_steps=num_inference_steps,
+                        guidance_scale=guidance_scale,
+                        num_images_per_prompt=1,
+                        negative_prompt=negative_prompt,
+                        height=height,
+                        width=width)
+
+        # return the first generated PIL image
+        return out.images[0]
+
+    # helper to decode a base64-encoded input image
+    def decode_base64_image(self, image_string):
+        base64_image = base64.b64decode(image_string)
+        buffer = BytesIO(base64_image)
+        image = Image.open(buffer)
+        return image
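
The new handler expects the `image` and `mask_image` payload fields as base64 strings, which `decode_base64_image` turns back into PIL images. For reference, a minimal client-side sketch of the matching encoding step (the `encode_base64_image` helper name is illustrative and not part of this commit):

import base64
from io import BytesIO

from PIL import Image

def encode_base64_image(image: Image.Image, fmt: str = "PNG") -> str:
    # serialize the PIL image into an in-memory buffer, then base64-encode the bytes
    buffer = BytesIO()
    image.save(buffer, format=fmt)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")

The resulting strings go into the `image` and `mask_image` fields of the request body, alongside the `inputs` prompt.
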
test.py CHANGED
@@ -1,7 +1,7 @@
 from handler import EndpointHandler
 
 # init handler
-my_handler = EndpointHandler(path=".")
+my_handler = EndpointHandler(path="./")
 
 # # prepare sample payload
 # non_holiday_payload = {"inputs": "I am quite excited how this will turn out", "date": "2022-08-08"}
@@ -15,3 +15,5 @@ my_handler = EndpointHandler(path=".")
 # print("non_holiday_pred", non_holiday_pred)
 # print("holiday_payload", holiday_payload)
 
+data = {"inputs": "Hello"}
+output = my_handler(data)
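
The added smoke test only sends a prompt; for inpainting, the pipeline also needs a base image and a mask. A fuller local test sketch, assuming the model weights sit in the repository root, a CUDA GPU is available (the handler raises otherwise), and `dog.png` / `dog_mask.png` are hypothetical file names:

import base64

from handler import EndpointHandler

def to_b64(path: str) -> str:
    # read raw image bytes and base64-encode them, matching decode_base64_image()
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

my_handler = EndpointHandler(path="./")

payload = {
    "inputs": "a white cat sitting on a park bench",  # prompt, illustrative
    "image": to_b64("dog.png"),                       # base image, hypothetical file
    "mask_image": to_b64("dog_mask.png"),             # white areas are repainted, hypothetical file
    "num_inference_steps": 25,
    "guidance_scale": 7.5,
}

# __call__ returns the first generated PIL image
result = my_handler(payload)
result.save("inpainted.png")
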