OrderAndChaos commited on
Commit
71b5721
1 Parent(s): a7b0604

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +79 -0
handler.py CHANGED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler
4
+ from PIL import Image
5
+ import base64
6
+ from io import BytesIO
7
+
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
+ if device.type != 'cuda':
11
+ raise ValueError("need to run on GPU")
12
+
13
+ class EndpointHandler:
14
+ def __init__(self, path="lllyasviel/control_v11p_sd15_inpaint"):
15
+ self.controlnet = ControlNetModel.from_pretrained(path, torch_dtype=torch.float32).to(device)
16
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
17
+ "runwayml/stable-diffusion-v1-5",
18
+ controlnet=self.controlnet,
19
+ torch_dtype=torch.float32
20
+ ).to(device)
21
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
22
+ self.generator = torch.Generator(device=device)
23
+
24
+ def __call__(self, data):
25
+ # Decode the images from base64
26
+ original_image = decode_image(data["image"])
27
+ mask_image = decode_image(data["mask_image"])
28
+
29
+ num_inference_steps = data.pop("num_inference_steps", 30)
30
+ guidance_scale = data.pop("guidance_scale", 7.5)
31
+ negative_prompt = data.pop("negative_prompt", None)
32
+ controlnet_conditioning_scale = data.pop("controlnet_conditioning_scale", 1.0)
33
+
34
+ height = data.pop("height", None)
35
+ width = data.pop("width", None)
36
+
37
+ # Create inpainting condition
38
+ control_image = self.make_inpaint_condition(original_image, mask_image)
39
+
40
+ # Inpaint the image
41
+ output_image = self.pipe(
42
+ data["inputs"],
43
+ negative_prompt=negative_prompt,
44
+ num_inference_steps=num_inference_steps,
45
+ guidance_scale=guidance_scale,
46
+ num_images_per_prompt=1,
47
+ generator=self.generator,
48
+ image=control_image,
49
+ height=height,
50
+ width=width,
51
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
52
+ ).images[0]
53
+
54
+ # Save the output image to bytes
55
+ output_bytes = save_image_to_bytes(output_image)
56
+
57
+ return output_bytes
58
+
59
+ def make_inpaint_condition(self, image, mask):
60
+ image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
61
+ mask = np.array(mask.convert("L"))
62
+ assert image.shape[0:1] == mask.shape[0:1], "image and image_mask must have the same image size"
63
+ image[mask < 128] = -1.0 # Set as masked pixel
64
+ image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
65
+ image = torch.from_numpy(image).to(device)
66
+ return image
67
+
68
+
69
+ def decode_image(encoded_image):
70
+ image_bytes = base64.b64decode(encoded_image)
71
+ image = Image.open(BytesIO(image_bytes))
72
+ return image
73
+
74
+
75
+ def save_image_to_bytes(image):
76
+ output_bytes = BytesIO()
77
+ image.save(output_bytes, format="PNG")
78
+ output_bytes.seek(0)
79
+ return output_bytes.getvalue()