saassa commited on
Commit
45b10c9
1 Parent(s): 53654df

Create TES UPDATE TO THIS

Browse files
Files changed (1) hide show
  1. TES UPDATE TO THIS +144 -0
TES UPDATE TO THIS ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ %%writefile handler.py
2
+ from typing import Dict, List, Any
3
+ import base64
4
+ from PIL import Image
5
+ from io import BytesIO
6
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, AutoencoderKL, StableDiffusionXLControlNetPipeline, AutoPipelineForText2Image
7
+ import torch
8
+ from diffusers.utils import load_image
9
+
10
+ import numpy as np
11
+ import cv2
12
+ import controlnet_hinter
13
+
14
+ # ADDED AUTO PIPE
15
+ # set device
16
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
+ if device.type != 'cuda':
18
+ raise ValueError("need to run on GPU")
19
+ # set mixed precision dtype
20
+ dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
21
+
22
+ # controlnet mapping for controlnet id and control hinter
23
+ CONTROLNET_MAPPING = {
24
+ "canny_edge": {
25
+ "model_id": "lllyasviel/sd-controlnet-canny",
26
+ "hinter": controlnet_hinter.hint_canny
27
+ },
28
+ "pose": {
29
+ "model_id": "lllyasviel/sd-controlnet-openpose",
30
+ "hinter": controlnet_hinter.hint_openpose
31
+ },
32
+ "depth": {
33
+ "model_id": "lllyasviel/sd-controlnet-depth",
34
+ "hinter": controlnet_hinter.hint_depth
35
+ },
36
+ "scribble": {
37
+ "model_id": "lllyasviel/sd-controlnet-scribble",
38
+ "hinter": controlnet_hinter.hint_scribble,
39
+ },
40
+ "segmentation": {
41
+ "model_id": "lllyasviel/sd-controlnet-seg",
42
+ "hinter": controlnet_hinter.hint_segmentation,
43
+ },
44
+ "normal": {
45
+ "model_id": "lllyasviel/sd-controlnet-normal",
46
+ "hinter": controlnet_hinter.hint_normal,
47
+ },
48
+ "hed": {
49
+ "model_id": "lllyasviel/sd-controlnet-hed",
50
+ "hinter": controlnet_hinter.hint_hed,
51
+ },
52
+ "hough": {
53
+ "model_id": "lllyasviel/sd-controlnet-mlsd",
54
+ "hinter": controlnet_hinter.hint_hough,
55
+ }
56
+ }
57
+
58
+
59
+ class EndpointHandler():
60
+ def __init__(self, path=""):
61
+ # define default controlnet id and load controlnet
62
+ self.control_type = "normal"
63
+ self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"], torch_dtype=dtype).to(device)
64
+
65
+ # Load StableDiffusionControlNetPipeline
66
+ self.stable_diffusion_id = "stablediffusionapi/disney-pixar-cartoon"
67
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
68
+ controlnet=self.controlnet,
69
+ torch_dtype=dtype,
70
+ safety_checker=None).to(device)
71
+
72
+ # Define Generator with seed
73
+ # COMMENTED self.generator = torch.Generator(device="cpu").manual_seed(3)
74
+
75
+ def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
76
+ """
77
+ :param data: A dictionary contains `inputs` and optional `image` field.
78
+ :return: A dictionary with `image` field contains image in base64.
79
+ """
80
+ prompt = data.pop("inputs", None)
81
+ image = data.pop("image", None)
82
+ controlnet_type = data.pop("controlnet_type", None)
83
+ stablediffusion_id = data.pop("stablediffusionid", None) # Get the stablediffusionid from the request data
84
+
85
+ if stablediffusion_id is not None and stablediffusion_id != self.stable_diffusion_id:
86
+ # Change the Stable Diffusion model to the new model ID
87
+ self.stable_diffusion_id = stablediffusion_id
88
+ # Reinitialize the pipeline with the new model ID
89
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
90
+ self.stable_diffusion_id,
91
+ controlnet=self.controlnet,
92
+ torch_dtype=dtype,
93
+ safety_checker=None
94
+ ).to(device)
95
+
96
+ # Check if neither prompt nor image is provided
97
+ if prompt is None and image is None:
98
+ return {"error": "Please provide a prompt and base64 encoded image."}
99
+
100
+ # Check if a new controlnet is provided
101
+ if controlnet_type is not None and controlnet_type != self.control_type:
102
+ print(f"changing controlnet from {self.control_type} to {controlnet_type} using {CONTROLNET_MAPPING[controlnet_type]['model_id']} model")
103
+ self.control_type = controlnet_type
104
+ self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],
105
+ torch_dtype=dtype).to(device)
106
+ self.pipe.controlnet = self.controlnet
107
+
108
+ # hyperparameters
109
+ negative_prompt = data.pop("negative_prompt", None)
110
+ num_inference_steps = data.pop("num_inference_steps", 150)
111
+ guidance_scale = data.pop("guidance_scale", 5)
112
+ negative_prompt = data.pop("negative_prompt", None)
113
+ height = data.pop("height", None)
114
+ width = data.pop("width", None)
115
+ controlnet_conditioning_scale = data.pop("controlnet_conditioning_scale", 1.0)
116
+
117
+ # process image
118
+ image = self.decode_base64_image(image)
119
+ control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](image)
120
+
121
+ # run inference pipeline
122
+ out = self.pipe(
123
+ prompt=prompt,
124
+ negative_prompt=negative_prompt,
125
+ image=control_image,
126
+ num_inference_steps=num_inference_steps,
127
+ guidance_scale=guidance_scale,
128
+ num_images_per_prompt=1,
129
+ height=height,
130
+ width=width,
131
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
132
+ guess_mode=True,
133
+ )
134
+
135
+ # generator=self.generator COMMENTED from self.pipe
136
+ # return the first generated PIL image
137
+ return out.images[0]
138
+
139
+ # helper to decode input image
140
+ def decode_base64_image(self, image_string):
141
+ base64_image = base64.b64decode(image_string)
142
+ buffer = BytesIO(base64_image)
143
+ image = Image.open(buffer)
144
+ return image