philschmid committed
Commit fdae70a
1 Parent(s): a399857

Upload 2 files

Files changed (2):
  1. handler.py +130 -0
  2. requirements.txt +6 -0
handler.py ADDED
@@ -0,0 +1,130 @@
+ from typing import Any
+ import base64
+ from PIL import Image
+ from io import BytesIO
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+ import torch
+
+ import numpy as np
+ import cv2
+ import controlnet_hinter
+
+ # set device
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ if device.type != 'cuda':
+     raise ValueError("need to run on GPU")
+ # set mixed precision dtype (bfloat16 on compute capability 8.x GPUs, float16 otherwise)
+ dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
+
+ # mapping from controlnet type to model id and hint function
+ CONTROLNET_MAPPING = {
+     "canny_edge": {
+         "model_id": "lllyasviel/sd-controlnet-canny",
+         "hinter": controlnet_hinter.hint_canny,
+     },
+     "pose": {
+         "model_id": "lllyasviel/sd-controlnet-openpose",
+         "hinter": controlnet_hinter.hint_openpose,
+     },
+     "depth": {
+         "model_id": "lllyasviel/sd-controlnet-depth",
+         "hinter": controlnet_hinter.hint_depth,
+     },
+     "scribble": {
+         "model_id": "lllyasviel/sd-controlnet-scribble",
+         "hinter": controlnet_hinter.hint_scribble,
+     },
+     "segmentation": {
+         "model_id": "lllyasviel/sd-controlnet-seg",
+         "hinter": controlnet_hinter.hint_segmentation,
+     },
+     "normal": {
+         "model_id": "lllyasviel/sd-controlnet-normal",
+         "hinter": controlnet_hinter.hint_normal,
+     },
+     "hed": {
+         "model_id": "lllyasviel/sd-controlnet-hed",
+         "hinter": controlnet_hinter.hint_hed,
+     },
+     "hough": {
+         "model_id": "lllyasviel/sd-controlnet-mlsd",
+         "hinter": controlnet_hinter.hint_hough,
+     },
+ }
+
+
+ class EndpointHandler():
+     def __init__(self, path=""):
+         # define default controlnet type and load the matching controlnet
+         self.control_type = "normal"
+         self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],
+                                                           torch_dtype=dtype).to(device)
+
+         # load StableDiffusionControlNetPipeline
+         self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
+         self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
+                                                                       controlnet=self.controlnet,
+                                                                       torch_dtype=dtype,
+                                                                       safety_checker=None).to(device)
+         # define generator with a fixed seed for reproducible outputs
+         self.generator = torch.Generator(device="cpu").manual_seed(3)
+
+     def __call__(self, data: Any) -> Image.Image:
+         """
+         :param data: A dictionary containing a `prompt`, a base64-encoded `image`,
+                      and optional `controlnet_type` and generation hyperparameters.
+         :return: The first generated image as a `PIL.Image`.
+         """
+         prompt = data.pop("prompt", None)
+         image = data.pop("image", None)
+         controlnet_type = data.pop("controlnet_type", None)
+
+         # both a prompt and an image are required
+         if prompt is None or image is None:
+             return {"error": "Please provide a prompt and a base64 encoded image."}
+
+         # check if a different controlnet is requested and swap it in
+         if controlnet_type is not None and controlnet_type != self.control_type:
+             print(f"Loading {controlnet_type} controlnet...")
+             self.control_type = controlnet_type
+             self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],
+                                                               torch_dtype=dtype).to(device)
+             self.pipe.controlnet = self.controlnet
+
+         # hyperparameters
+         num_inference_steps = data.pop("num_inference_steps", 30)
+         guidance_scale = data.pop("guidance_scale", 7.5)
+         negative_prompt = data.pop("negative_prompt", None)
+         height = data.pop("height", None)
+         width = data.pop("width", None)
+         controlnet_conditioning_scale = data.pop("controlnet_conditioning_scale", 1.0)
+
+         # decode the input image and compute the control hint
+         image = self.decode_base64_image(image)
+         control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](image)
+
+         # run inference pipeline
+         out = self.pipe(
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             image=control_image,
+             num_inference_steps=num_inference_steps,
+             guidance_scale=guidance_scale,
+             num_images_per_prompt=1,
+             height=height,
+             width=width,
+             controlnet_conditioning_scale=controlnet_conditioning_scale,
+             generator=self.generator,
+         )
+
+         # return the first generated PIL image
+         return out.images[0]
+
+     # helper to decode the base64-encoded input image
+     def decode_base64_image(self, image_string):
+         base64_image = base64.b64decode(image_string)
+         buffer = BytesIO(base64_image)
+         image = Image.open(buffer)
+         return image
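A minimal local smoke test for this handler might look like the sketch below. It assumes a CUDA GPU, the dependencies from requirements.txt installed, and a hypothetical sample image `input.jpg` in the working directory; the `handler` module name simply matches the uploaded file.

import base64
from handler import EndpointHandler

# instantiate the handler (downloads the default "normal" controlnet and SD 1.5)
handler = EndpointHandler()

# base64-encode a conditioning image (input.jpg is a hypothetical sample file)
with open("input.jpg", "rb") as f:
    encoded_image = base64.b64encode(f.read()).decode("utf-8")

request = {
    "prompt": "a red brick house in the woods, photorealistic",
    "image": encoded_image,
    "controlnet_type": "canny_edge",
    "num_inference_steps": 30,
}

# returns the first generated image as a PIL.Image
image = handler(request)
image.save("result.png")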
requirements.txt ADDED
@@ -0,0 +1,6 @@
+
+ git+https://github.com/huggingface/diffusers.git
+ safetensors
+ xformers
+ opencv-python
+ controlnet_hinter==0.0.5
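Once deployed, a request against an Inference Endpoint running this handler might look like the sketch below. The endpoint URL and token are placeholders, and how the request body is mapped to the handler's `data` argument and how the returned PIL image is serialized depend on the serving toolkit, so treat this as an assumption-laden illustration rather than a fixed contract.

import base64
import requests

ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HF_TOKEN = "hf_..."  # placeholder access token

# base64-encode the conditioning image (input.jpg is a hypothetical sample file)
with open("input.jpg", "rb") as f:
    encoded_image = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "prompt": "a futuristic cityscape, highly detailed",
    "image": encoded_image,
    "controlnet_type": "depth",
    "guidance_scale": 7.5,
}

response = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}", "Content-Type": "application/json"},
    json=payload,
)
# response handling depends on how the toolkit serializes the generated image
print(response.status_code)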