radames commited on
Commit
d056e0b
1 Parent(s): d3237c9

add contronet canny

Browse files
app-controlnet.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import traceback
5
+ from pydantic import BaseModel
6
+
7
+ from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from fastapi.responses import StreamingResponse, JSONResponse
10
+ from fastapi.staticfiles import StaticFiles
11
+
12
+ from diffusers import AutoencoderTiny, ControlNetModel
13
+ from latent_consistency_controlnet import LatentConsistencyModelPipeline_controlnet
14
+ from compel import Compel
15
+ import torch
16
+
17
+ from canny_gpu import SobelOperator
18
+ # from controlnet_aux import OpenposeDetector
19
+ # import cv2
20
+
21
+ try:
22
+ import intel_extension_for_pytorch as ipex
23
+ except:
24
+ pass
25
+ from PIL import Image
26
+ import numpy as np
27
+ import gradio as gr
28
+ import io
29
+ import uuid
30
+ import os
31
+ import time
32
+ import psutil
33
+
34
+
35
+ MAX_QUEUE_SIZE = int(os.environ.get("MAX_QUEUE_SIZE", 0))
36
+ TIMEOUT = float(os.environ.get("TIMEOUT", 0))
37
+ SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
38
+ WIDTH = 512
39
+ HEIGHT = 512
40
+ # disable tiny autoencoder for better quality speed tradeoff
41
+ USE_TINY_AUTOENCODER = True
42
+
43
+ # check if MPS is available OSX only M1/M2/M3 chips
44
+ mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
45
+ xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
46
+ device = torch.device(
47
+ "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
48
+ )
49
+
50
+ # change to torch.float16 to save GPU memory
51
+ torch_dtype = torch.float16
52
+
53
+ print(f"TIMEOUT: {TIMEOUT}")
54
+ print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
55
+ print(f"MAX_QUEUE_SIZE: {MAX_QUEUE_SIZE}")
56
+ print(f"device: {device}")
57
+
58
+ if mps_available:
59
+ device = torch.device("mps")
60
+ device = "cpu"
61
+ torch_dtype = torch.float32
62
+
63
+ controlnet_canny = ControlNetModel.from_pretrained(
64
+ "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch_dtype
65
+ ).to(device)
66
+
67
+ canny_torch = SobelOperator(device=device)
68
+ # controlnet_pose = ControlNetModel.from_pretrained(
69
+ # "lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch_dtype
70
+ # ).to(device)
71
+ # controlnet_depth = ControlNetModel.from_pretrained(
72
+ # "lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch_dtype
73
+ # ).to(device)
74
+
75
+
76
+ # pose_processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
77
+
78
+ if SAFETY_CHECKER == "True":
79
+ pipe = LatentConsistencyModelPipeline_controlnet.from_pretrained(
80
+ "SimianLuo/LCM_Dreamshaper_v7",
81
+ controlnet=controlnet_canny,
82
+ scheduler=None,
83
+ )
84
+ else:
85
+ pipe = LatentConsistencyModelPipeline_controlnet.from_pretrained(
86
+ "SimianLuo/LCM_Dreamshaper_v7",
87
+ safety_checker=None,
88
+ controlnet=controlnet_canny,
89
+ scheduler=None,
90
+ )
91
+
92
+ if USE_TINY_AUTOENCODER:
93
+ pipe.vae = AutoencoderTiny.from_pretrained(
94
+ "madebyollin/taesd", torch_dtype=torch_dtype, use_safetensors=True
95
+ )
96
+ pipe.set_progress_bar_config(disable=True)
97
+ pipe.to(device=device, dtype=torch_dtype).to(device)
98
+ pipe.unet.to(memory_format=torch.channels_last)
99
+
100
+ if psutil.virtual_memory().total < 64 * 1024**3:
101
+ pipe.enable_attention_slicing()
102
+
103
+ # if not mps_available and not xpu_available:
104
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
105
+ # pipe(prompt="warmup", image=[Image.new("RGB", (512, 512))])
106
+
107
+ compel_proc = Compel(
108
+ tokenizer=pipe.tokenizer,
109
+ text_encoder=pipe.text_encoder,
110
+ truncate_long_prompts=False,
111
+ )
112
+ user_queue_map = {}
113
+
114
+
115
+ class InputParams(BaseModel):
116
+ seed: int = 2159232
117
+ prompt: str
118
+ guidance_scale: float = 8.0
119
+ strength: float = 0.5
120
+ steps: int = 4
121
+ lcm_steps: int = 50
122
+ width: int = WIDTH
123
+ height: int = HEIGHT
124
+ controlnet_scale: float = 0.8
125
+ controlnet_start: float = 0.0
126
+ controlnet_end: float = 1.0
127
+ canny_low_threshold: float = 0.31
128
+ canny_high_threshold: float = 0.78
129
+
130
+ def predict(
131
+ input_image: Image.Image, params: InputParams, prompt_embeds: torch.Tensor = None
132
+ ):
133
+ generator = torch.manual_seed(params.seed)
134
+
135
+ control_image = canny_torch(input_image, params.canny_low_threshold, params.canny_high_threshold)
136
+ print(params.canny_low_threshold, params.canny_high_threshold)
137
+ results = pipe(
138
+ control_image=control_image,
139
+ prompt_embeds=prompt_embeds,
140
+ generator=generator,
141
+ image=input_image,
142
+ strength=params.strength,
143
+ num_inference_steps=params.steps,
144
+ guidance_scale=params.guidance_scale,
145
+ width=params.width,
146
+ height=params.height,
147
+ lcm_origin_steps=params.lcm_steps,
148
+ output_type="pil",
149
+ controlnet_conditioning_scale=params.controlnet_scale,
150
+ control_guidance_start=params.controlnet_start,
151
+ control_guidance_end=params.controlnet_end,
152
+ )
153
+ nsfw_content_detected = (
154
+ results.nsfw_content_detected[0]
155
+ if "nsfw_content_detected" in results
156
+ else False
157
+ )
158
+ if nsfw_content_detected:
159
+ return None
160
+ return results.images[0]
161
+
162
+
163
+ app = FastAPI()
164
+ app.add_middleware(
165
+ CORSMiddleware,
166
+ allow_origins=["*"],
167
+ allow_credentials=True,
168
+ allow_methods=["*"],
169
+ allow_headers=["*"],
170
+ )
171
+
172
+
173
+ @app.websocket("/ws")
174
+ async def websocket_endpoint(websocket: WebSocket):
175
+ await websocket.accept()
176
+ if MAX_QUEUE_SIZE > 0 and len(user_queue_map) >= MAX_QUEUE_SIZE:
177
+ print("Server is full")
178
+ await websocket.send_json({"status": "error", "message": "Server is full"})
179
+ await websocket.close()
180
+ return
181
+
182
+ try:
183
+ uid = str(uuid.uuid4())
184
+ print(f"New user connected: {uid}")
185
+ await websocket.send_json(
186
+ {"status": "success", "message": "Connected", "userId": uid}
187
+ )
188
+ user_queue_map[uid] = {"queue": asyncio.Queue()}
189
+ await websocket.send_json(
190
+ {"status": "start", "message": "Start Streaming", "userId": uid}
191
+ )
192
+ await handle_websocket_data(websocket, uid)
193
+ except WebSocketDisconnect as e:
194
+ logging.error(f"WebSocket Error: {e}, {uid}")
195
+ traceback.print_exc()
196
+ finally:
197
+ print(f"User disconnected: {uid}")
198
+ queue_value = user_queue_map.pop(uid, None)
199
+ queue = queue_value.get("queue", None)
200
+ if queue:
201
+ while not queue.empty():
202
+ try:
203
+ queue.get_nowait()
204
+ except asyncio.QueueEmpty:
205
+ continue
206
+
207
+
208
+ @app.get("/queue_size")
209
+ async def get_queue_size():
210
+ queue_size = len(user_queue_map)
211
+ return JSONResponse({"queue_size": queue_size})
212
+
213
+
214
+ @app.get("/stream/{user_id}")
215
+ async def stream(user_id: uuid.UUID):
216
+ uid = str(user_id)
217
+ try:
218
+ user_queue = user_queue_map[uid]
219
+ queue = user_queue["queue"]
220
+
221
+ async def generate():
222
+ last_prompt: str = None
223
+ prompt_embeds: torch.Tensor = None
224
+ while True:
225
+ data = await queue.get()
226
+ input_image = data["image"]
227
+ params = data["params"]
228
+ if input_image is None:
229
+ continue
230
+ # avoid recalculate prompt embeds
231
+ if last_prompt != params.prompt:
232
+ print("new prompt")
233
+ prompt_embeds = compel_proc(params.prompt)
234
+ last_prompt = params.prompt
235
+
236
+ image = predict(
237
+ input_image,
238
+ params,
239
+ prompt_embeds,
240
+ )
241
+ if image is None:
242
+ continue
243
+ frame_data = io.BytesIO()
244
+ image.save(frame_data, format="JPEG")
245
+ frame_data = frame_data.getvalue()
246
+ if frame_data is not None and len(frame_data) > 0:
247
+ yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame_data + b"\r\n"
248
+
249
+ await asyncio.sleep(1.0 / 120.0)
250
+
251
+ return StreamingResponse(
252
+ generate(), media_type="multipart/x-mixed-replace;boundary=frame"
253
+ )
254
+ except Exception as e:
255
+ logging.error(f"Streaming Error: {e}, {user_queue_map}")
256
+ traceback.print_exc()
257
+ return HTTPException(status_code=404, detail="User not found")
258
+
259
+
260
+ async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
261
+ uid = str(user_id)
262
+ user_queue = user_queue_map[uid]
263
+ queue = user_queue["queue"]
264
+ if not queue:
265
+ return HTTPException(status_code=404, detail="User not found")
266
+ last_time = time.time()
267
+ try:
268
+ while True:
269
+ data = await websocket.receive_bytes()
270
+ params = await websocket.receive_json()
271
+ params = InputParams(**params)
272
+ pil_image = Image.open(io.BytesIO(data))
273
+
274
+ while not queue.empty():
275
+ try:
276
+ queue.get_nowait()
277
+ except asyncio.QueueEmpty:
278
+ continue
279
+ await queue.put({"image": pil_image, "params": params})
280
+ if TIMEOUT > 0 and time.time() - last_time > TIMEOUT:
281
+ await websocket.send_json(
282
+ {
283
+ "status": "timeout",
284
+ "message": "Your session has ended",
285
+ "userId": uid,
286
+ }
287
+ )
288
+ await websocket.close()
289
+ return
290
+
291
+ except Exception as e:
292
+ logging.error(f"Error: {e}")
293
+ traceback.print_exc()
294
+
295
+
296
+ app.mount("/", StaticFiles(directory="controlnet", html=True), name="public")
canny_gpu.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision.transforms import ToTensor, ToPILImage
4
+ from PIL import Image
5
+
6
+ class SobelOperator(nn.Module):
7
+ def __init__(self, device="cuda"):
8
+ super(SobelOperator, self).__init__()
9
+ self.device = device
10
+ self.edge_conv_x = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False).to(
11
+ self.device
12
+ )
13
+ self.edge_conv_y = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False).to(
14
+ self.device
15
+ )
16
+
17
+ sobel_kernel_x = torch.tensor(
18
+ [[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]], device=self.device
19
+ )
20
+ sobel_kernel_y = torch.tensor(
21
+ [[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]], device=self.device
22
+ )
23
+
24
+ self.edge_conv_x.weight = nn.Parameter(sobel_kernel_x.view((1, 1, 3, 3)))
25
+ self.edge_conv_y.weight = nn.Parameter(sobel_kernel_y.view((1, 1, 3, 3)))
26
+
27
+ @torch.no_grad()
28
+ def forward(self, image: Image.Image, low_threshold: float, high_threshold: float):
29
+ # Convert PIL image to PyTorch tensor
30
+ image_gray = image.convert("L")
31
+ image_tensor = ToTensor()(image_gray).unsqueeze(0).to(self.device)
32
+
33
+ # Compute gradients
34
+ edge_x = self.edge_conv_x(image_tensor)
35
+ edge_y = self.edge_conv_y(image_tensor)
36
+ edge = torch.sqrt(edge_x**2 + edge_y**2)
37
+
38
+ # Apply thresholding
39
+ edge = edge / edge.max() # Normalize to 0-1
40
+ edge[edge >= high_threshold] = 1.0
41
+ edge[edge <= low_threshold] = 0.0
42
+
43
+ # Convert the result back to a PIL image
44
+ return ToPILImage()(edge.squeeze(0).cpu())
controlnet/index.html ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!doctype html>
2
+ <html>
3
+
4
+ <head>
5
+ <meta charset="UTF-8">
6
+ <title>Real-Time Latent Consistency Model ControlNet</title>
7
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
8
+ <script
9
+ src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script>
10
+ <script src="https://cdn.jsdelivr.net/npm/piexifjs@1.0.6/piexif.min.js"></script>
11
+ <script src="https://cdn.tailwindcss.com"></script>
12
+ <style type="text/tailwindcss">
13
+ .button {
14
+ @apply bg-gray-700 hover:bg-gray-800 text-white font-normal p-2 rounded disabled:bg-gray-300 dark:disabled:bg-gray-700 disabled:cursor-not-allowed dark:disabled:text-black
15
+ }
16
+ </style>
17
+ <script type="module">
18
+ // you can change the size of the input image to 768x768 if you have a powerful GPU
19
+ const getValue = (id) => document.querySelector(`${id}`).value;
20
+ const startBtn = document.querySelector("#start");
21
+ const stopBtn = document.querySelector("#stop");
22
+ const videoEl = document.querySelector("#webcam");
23
+ const imageEl = document.querySelector("#player");
24
+ const queueSizeEl = document.querySelector("#queue_size");
25
+ const errorEl = document.querySelector("#error");
26
+ const snapBtn = document.querySelector("#snap");
27
+ const webcamsEl = document.querySelector("#webcams");
28
+
29
+ function LCMLive(webcamVideo, liveImage) {
30
+ let websocket;
31
+
32
+ async function start() {
33
+ return new Promise((resolve, reject) => {
34
+ const websocketURL = `${window.location.protocol === "https:" ? "wss" : "ws"
35
+ }:${window.location.host}/ws`;
36
+
37
+ const socket = new WebSocket(websocketURL);
38
+ socket.onopen = () => {
39
+ console.log("Connected to websocket");
40
+ };
41
+ socket.onclose = () => {
42
+ console.log("Disconnected from websocket");
43
+ stop();
44
+ resolve({ "status": "disconnected" });
45
+ };
46
+ socket.onerror = (err) => {
47
+ console.error(err);
48
+ reject(err);
49
+ };
50
+ socket.onmessage = (event) => {
51
+ const data = JSON.parse(event.data);
52
+ switch (data.status) {
53
+ case "success":
54
+ break;
55
+ case "start":
56
+ const userId = data.userId;
57
+ initVideoStream(userId);
58
+ break;
59
+ case "timeout":
60
+ stop();
61
+ resolve({ "status": "timeout" });
62
+ case "error":
63
+ stop();
64
+ reject(data.message);
65
+
66
+ }
67
+ };
68
+ websocket = socket;
69
+ })
70
+ }
71
+ function switchCamera() {
72
+ const constraints = {
73
+ audio: false,
74
+ video: { width: 1024, height: 768, deviceId: mediaDevices[webcamsEl.value].deviceId }
75
+ };
76
+ navigator.mediaDevices
77
+ .getUserMedia(constraints)
78
+ .then((mediaStream) => {
79
+ webcamVideo.removeEventListener("timeupdate", videoTimeUpdateHandler);
80
+ webcamVideo.srcObject = mediaStream;
81
+ webcamVideo.onloadedmetadata = () => {
82
+ webcamVideo.play();
83
+ webcamVideo.addEventListener("timeupdate", videoTimeUpdateHandler);
84
+ };
85
+ })
86
+ .catch((err) => {
87
+ console.error(`${err.name}: ${err.message}`);
88
+ });
89
+ }
90
+
91
+ async function videoTimeUpdateHandler() {
92
+ const dimension = getValue("input[name=dimension]:checked");
93
+ const [WIDTH, HEIGHT] = JSON.parse(dimension);
94
+
95
+ const canvas = new OffscreenCanvas(WIDTH, HEIGHT);
96
+ const videoW = webcamVideo.videoWidth;
97
+ const videoH = webcamVideo.videoHeight;
98
+ const aspectRatio = WIDTH / HEIGHT;
99
+
100
+ const ctx = canvas.getContext("2d");
101
+ ctx.drawImage(webcamVideo, videoW / 2 - videoH * aspectRatio / 2, 0, videoH * aspectRatio, videoH, 0, 0, WIDTH, HEIGHT)
102
+ const blob = await canvas.convertToBlob({ type: "image/jpeg", quality: 1 });
103
+ websocket.send(blob);
104
+ websocket.send(JSON.stringify({
105
+ "seed": getValue("#seed"),
106
+ "prompt": getValue("#prompt"),
107
+ "guidance_scale": getValue("#guidance-scale"),
108
+ "strength": getValue("#strength"),
109
+ "steps": getValue("#steps"),
110
+ "lcm_steps": getValue("#lcm_steps"),
111
+ "width": WIDTH,
112
+ "height": HEIGHT,
113
+ "controlnet_scale": getValue("#controlnet_scale"),
114
+ "controlnet_start": getValue("#controlnet_start"),
115
+ "controlnet_end": getValue("#controlnet_end"),
116
+ "canny_low_threshold": getValue("#canny_low_threshold"),
117
+ "canny_high_threshold": getValue("#canny_high_threshold"),
118
+ }));
119
+ }
120
+ let mediaDevices = [];
121
+ async function initVideoStream(userId) {
122
+ liveImage.src = `/stream/${userId}`;
123
+ await navigator.mediaDevices.enumerateDevices()
124
+ .then(devices => {
125
+ const cameras = devices.filter(device => device.kind === 'videoinput');
126
+ mediaDevices = cameras;
127
+ webcamsEl.innerHTML = "";
128
+ cameras.forEach((camera, index) => {
129
+ const option = document.createElement("option");
130
+ option.value = index;
131
+ option.innerText = camera.label;
132
+ webcamsEl.appendChild(option);
133
+ option.selected = index === 0;
134
+ });
135
+ webcamsEl.addEventListener("change", switchCamera);
136
+ })
137
+ .catch(err => {
138
+ console.error(err);
139
+ });
140
+ const constraints = {
141
+ audio: false,
142
+ video: { width: 1024, height: 768, deviceId: mediaDevices[0].deviceId }
143
+ };
144
+ navigator.mediaDevices
145
+ .getUserMedia(constraints)
146
+ .then((mediaStream) => {
147
+ webcamVideo.srcObject = mediaStream;
148
+ webcamVideo.onloadedmetadata = () => {
149
+ webcamVideo.play();
150
+ webcamVideo.addEventListener("timeupdate", videoTimeUpdateHandler);
151
+ };
152
+ })
153
+ .catch((err) => {
154
+ console.error(`${err.name}: ${err.message}`);
155
+ });
156
+ }
157
+
158
+
159
+ async function stop() {
160
+ websocket.close();
161
+ navigator.mediaDevices.getUserMedia({ video: true }).then((mediaStream) => {
162
+ mediaStream.getTracks().forEach((track) => track.stop());
163
+ });
164
+ webcamVideo.removeEventListener("timeupdate", videoTimeUpdateHandler);
165
+ webcamsEl.removeEventListener("change", switchCamera);
166
+ webcamVideo.srcObject = null;
167
+ }
168
+ return {
169
+ start,
170
+ stop
171
+ }
172
+ }
173
+ function toggleMessage(type) {
174
+ errorEl.hidden = false;
175
+ errorEl.scrollIntoView();
176
+ switch (type) {
177
+ case "error":
178
+ errorEl.innerText = "To many users are using the same GPU, please try again later.";
179
+ errorEl.classList.toggle("bg-red-300", "text-red-900");
180
+ break;
181
+ case "success":
182
+ errorEl.innerText = "Your session has ended, please start a new one.";
183
+ errorEl.classList.toggle("bg-green-300", "text-green-900");
184
+ break;
185
+ }
186
+ setTimeout(() => {
187
+ errorEl.hidden = true;
188
+ }, 2000);
189
+ }
190
+ function snapImage() {
191
+ try {
192
+ const zeroth = {};
193
+ const exif = {};
194
+ const gps = {};
195
+ zeroth[piexif.ImageIFD.Make] = "LCM Image-to-Image ControNet";
196
+ zeroth[piexif.ImageIFD.ImageDescription] = `prompt: ${getValue("#prompt")} | seed: ${getValue("#seed")} | guidance_scale: ${getValue("#guidance-scale")} | strength: ${getValue("#strength")} | controlnet_start: ${getValue("#controlnet_start")} | controlnet_end: ${getValue("#controlnet_end")} | lcm_steps: ${getValue("#lcm_steps")} | steps: ${getValue("#steps")}`;
197
+ zeroth[piexif.ImageIFD.Software] = "https://github.com/radames/Real-Time-Latent-Consistency-Model";
198
+ exif[piexif.ExifIFD.DateTimeOriginal] = new Date().toISOString();
199
+
200
+ const exifObj = { "0th": zeroth, "Exif": exif, "GPS": gps };
201
+ const exifBytes = piexif.dump(exifObj);
202
+
203
+ const canvas = document.createElement("canvas");
204
+ canvas.width = imageEl.naturalWidth;
205
+ canvas.height = imageEl.naturalHeight;
206
+ const ctx = canvas.getContext("2d");
207
+ ctx.drawImage(imageEl, 0, 0);
208
+ const dataURL = canvas.toDataURL("image/jpeg");
209
+ const withExif = piexif.insert(exifBytes, dataURL);
210
+
211
+ const a = document.createElement("a");
212
+ a.href = withExif;
213
+ a.download = `lcm_txt_2_img${Date.now()}.png`;
214
+ a.click();
215
+ } catch (err) {
216
+ console.log(err);
217
+ }
218
+ }
219
+
220
+
221
+ const lcmLive = LCMLive(videoEl, imageEl);
222
+ startBtn.addEventListener("click", async () => {
223
+ try {
224
+ startBtn.disabled = true;
225
+ snapBtn.disabled = false;
226
+ const res = await lcmLive.start();
227
+ startBtn.disabled = false;
228
+ if (res.status === "timeout")
229
+ toggleMessage("success")
230
+ } catch (err) {
231
+ console.log(err);
232
+ toggleMessage("error")
233
+ startBtn.disabled = false;
234
+ }
235
+ });
236
+ stopBtn.addEventListener("click", () => {
237
+ lcmLive.stop();
238
+ });
239
+ window.addEventListener("beforeunload", () => {
240
+ lcmLive.stop();
241
+ });
242
+ snapBtn.addEventListener("click", snapImage);
243
+ setInterval(() =>
244
+ fetch("/queue_size")
245
+ .then((res) => res.json())
246
+ .then((data) => {
247
+ queueSizeEl.innerText = data.queue_size;
248
+ })
249
+ .catch((err) => {
250
+ console.log(err);
251
+ })
252
+ , 5000);
253
+ </script>
254
+ </head>
255
+
256
+ <body class="text-black dark:bg-gray-900 dark:text-white">
257
+ <div class="fixed right-2 top-2 p-4 font-bold text-sm rounded-lg max-w-xs text-center" id="error">
258
+ </div>
259
+ <main class="container mx-auto px-4 py-4 max-w-4xl flex flex-col gap-4">
260
+ <article class="text-center max-w-xl mx-auto">
261
+ <h1 class="text-3xl font-bold">Real-Time Latent Consistency Model</h1>
262
+ <h2 class="text-2xl font-bold mb-4">ControlNet</h2>
263
+ <p class="text-sm">
264
+ This demo showcases
265
+ <a href="https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7" target="_blank"
266
+ class="text-blue-500 underline hover:no-underline">LCM</a> Image to Image pipeline
267
+ using
268
+ <a href="https://github.com/huggingface/diffusers/tree/main/examples/community#latent-consistency-pipeline"
269
+ target="_blank" class="text-blue-500 underline hover:no-underline">Diffusers</a> with a MJPEG
270
+ stream server.
271
+ </p>
272
+ <p class="text-sm">
273
+ There are <span id="queue_size" class="font-bold">0</span> user(s) sharing the same GPU, affecting
274
+ real-time performance. Maximum queue size is 4. <a
275
+ href="https://huggingface.co/spaces/radames/Real-Time-Latent-Consistency-Model?duplicate=true"
276
+ target="_blank" class="text-blue-500 underline hover:no-underline">Duplicate</a> and run it on your
277
+ own GPU.
278
+ </p>
279
+ </article>
280
+ <div>
281
+ <h2 class="font-medium">Prompt</h2>
282
+ <p class="text-sm text-gray-500">
283
+ Change the prompt to generate different images, accepts <a
284
+ href="https://github.com/damian0815/compel/blob/main/doc/syntax.md" target="_blank"
285
+ class="text-blue-500 underline hover:no-underline">Compel</a> syntax.
286
+ </p>
287
+ <div class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center">
288
+ <textarea type="text" id="prompt" class="font-light w-full px-3 py-2 mx-1 outline-none dark:text-black"
289
+ title="Prompt, this is an example, feel free to modify"
290
+ placeholder="Add your prompt here...">Portrait of The Terminator with , glare pose, detailed, intricate, full of colour, cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details, unreal engine 5, cinematic, masterpiece</textarea>
291
+ </div>
292
+ </div>
293
+ <div class="">
294
+ <details>
295
+ <summary class="font-medium cursor-pointer">Advanced Options</summary>
296
+ <div class="grid grid-cols-3 sm:grid-cols-6 items-center gap-3 py-3">
297
+ <label for="webcams" class="text-sm font-medium">Camera Options: </label>
298
+ <select id="webcams" class="text-sm border-2 border-gray-500 rounded-md font-light dark:text-black">
299
+ </select>
300
+ <div></div>
301
+ <label class="text-sm font-medium " for="steps">Inference Steps
302
+ </label>
303
+ <input type="range" id="steps" name="steps" min="1" max="20" value="4"
304
+ oninput="this.nextElementSibling.value = Number(this.value)">
305
+ <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
306
+ 4</output>
307
+ <!-- -->
308
+ <label class="text-sm font-medium" for="lcm_steps">LCM Inference Steps
309
+ </label>
310
+ <input type="range" id="lcm_steps" name="lcm_steps" min="2" max="60" value="50"
311
+ oninput="this.nextElementSibling.value = Number(this.value)">
312
+ <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
313
+ 50</output>
314
+ <!-- -->
315
+ <label class="text-sm font-medium" for="guidance-scale">Guidance Scale
316
+ </label>
317
+ <input type="range" id="guidance-scale" name="guidance-scale" min="0" max="30" step="0.001"
318
+ value="8.0" oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
319
+ <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
320
+ 8.0</output>
321
+ <!-- -->
322
+ <label class="text-sm font-medium" for="strength">Strength</label>
323
+ <input type="range" id="strength" name="strength" min="0.1" max="1" step="0.001" value="0.50"
324
+ oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
325
+ <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
326
+ 0.5</output>
327
+ <!-- -->
328
+ <label class="text-sm font-medium" for="controlnet_scale">ControlNet Condition Scale</label>
329
+ <input type="range" id="controlnet_scale" name="controlnet_scale" min="0.0" max="1" step="0.001"
330
+ value="0.80" oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
331
+ <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
332
+ 0.8</output>
333
+ <!-- -->
334
+ <label class="text-sm font-medium" for="controlnet_start">ControlNet Guidance Start</label>
335
+ <input type="range" id="controlnet_start" name="controlnet_start" min="0.0" max="1.0" step="0.001"
336
+ value="0.0" oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
337
+ <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
338
+ 0.0</output>
339
+ <!-- -->
340
+ <label class="text-sm font-medium" for="controlnet_end">ControlNet Guidance End</label>
341
+ <input type="range" id="controlnet_end" name="controlnet_end" min="0.0" max="1.0" step="0.001"
342
+ value="1.0" oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
343
+ <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
344
+ 1.0</output>
345
+ <!-- -->
346
+ <label class="text-sm font-medium" for="canny_low_threshold">Canny Low Threshold</label>
347
+ <input type="range" id="canny_low_threshold" name="canny_low_threshold" min="0.0" max="1.0"
348
+ step="0.001" value="0.2"
349
+ oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
350
+ <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
351
+ 0.2</output>
352
+ <!-- -->
353
+ <label class="text-sm font-medium" for="canny_high_threshold">Canny High Threshold</label>
354
+ <input type="range" id="canny_high_threshold" name="canny_high_threshold" min="0.0" max="1.0"
355
+ step="0.001" value="0.8"
356
+ oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
357
+ <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
358
+ 0.8</output>
359
+ <!-- -->
360
+ <label class="text-sm font-medium" for="seed">Seed</label>
361
+ <input type="number" id="seed" name="seed" value="299792458"
362
+ class="font-light border border-gray-700 text-right rounded-md p-2 dark:text-black">
363
+ <button
364
+ onclick="document.querySelector('#seed').value = Math.floor(Math.random() * Number.MAX_SAFE_INTEGER)"
365
+ class="button">
366
+ Rand
367
+ </button>
368
+ <!-- -->
369
+ <!-- -->
370
+ <label class="text-sm font-medium" for="dimension">Image Dimensions</label>
371
+ <div class="col-span-2 flex gap-2">
372
+ <div class="flex gap-1">
373
+ <input type="radio" id="dimension512" name="dimension" value="[512,512]" checked
374
+ class="cursor-pointer">
375
+ <label for="dimension512" class="text-sm cursor-pointer">512x512</label>
376
+ </div>
377
+ <div class="flex gap-1">
378
+ <input type="radio" id="dimension768" name="dimension" value="[768,768]"
379
+ lass="cursor-pointer">
380
+ <label for="dimension768" class="text-sm cursor-pointer">768x768</label>
381
+ </div>
382
+ </div>
383
+ <!-- -->
384
+ </div>
385
+ </details>
386
+ </div>
387
+ <div class="flex gap-3">
388
+ <button id="start" class="button">
389
+ Start
390
+ </button>
391
+ <button id="stop" class="button">
392
+ Stop
393
+ </button>
394
+ <button id="snap" disabled class="button ml-auto">
395
+ Snapshot
396
+ </button>
397
+ </div>
398
+ <div class="relative rounded-lg border border-slate-300 overflow-hidden">
399
+ <img id="player" class="w-full aspect-square rounded-lg"
400
+ src="">
401
+ <div class="absolute top-0 left-0 w-1/4 aspect-square">
402
+ <video id="webcam" class="w-full aspect-square relative z-10 object-cover" playsinline autoplay muted
403
+ loop></video>
404
+ <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 448" width="100"
405
+ class="w-full p-4 absolute top-0 opacity-20 z-0">
406
+ <path fill="currentColor"
407
+ d="M224 256a128 128 0 1 0 0-256 128 128 0 1 0 0 256zm-45.7 48A178.3 178.3 0 0 0 0 482.3 29.7 29.7 0 0 0 29.7 512h388.6a29.7 29.7 0 0 0 29.7-29.7c0-98.5-79.8-178.3-178.3-178.3h-91.4z" />
408
+ </svg>
409
+ </div>
410
+ </div>
411
+ </main>
412
+ </body>
413
+
414
+ </html>
controlnet/tailwind.config.js ADDED
File without changes
latent_consistency_controlnet.py ADDED
@@ -0,0 +1,1094 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16
+ # and https://github.com/hojonathanho/diffusion
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import Any, Dict, List, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
25
+
26
+ from diffusers import (
27
+ AutoencoderKL,
28
+ AutoencoderTiny,
29
+ ConfigMixin,
30
+ DiffusionPipeline,
31
+ SchedulerMixin,
32
+ UNet2DConditionModel,
33
+ ControlNetModel,
34
+ logging,
35
+ )
36
+ from diffusers.configuration_utils import register_to_config
37
+ from diffusers.image_processor import VaeImageProcessor, PipelineImageInput
38
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
39
+ from diffusers.pipelines.stable_diffusion.safety_checker import (
40
+ StableDiffusionSafetyChecker,
41
+ )
42
+ from diffusers.utils import BaseOutput
43
+
44
+ from diffusers.utils.torch_utils import randn_tensor, is_compiled_module
45
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
46
+
47
+
48
+ import PIL.Image
49
+
50
+
51
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
+
53
+ class LatentConsistencyModelPipeline_controlnet(DiffusionPipeline):
54
+ _optional_components = ["scheduler"]
55
+
56
+ def __init__(
57
+ self,
58
+ vae: AutoencoderKL,
59
+ text_encoder: CLIPTextModel,
60
+ tokenizer: CLIPTokenizer,
61
+ controlnet: Union[
62
+ ControlNetModel,
63
+ List[ControlNetModel],
64
+ Tuple[ControlNetModel],
65
+ MultiControlNetModel,
66
+ ],
67
+ unet: UNet2DConditionModel,
68
+ scheduler: "LCMScheduler",
69
+ safety_checker: StableDiffusionSafetyChecker,
70
+ feature_extractor: CLIPImageProcessor,
71
+ requires_safety_checker: bool = True,
72
+ ):
73
+ super().__init__()
74
+
75
+ scheduler = (
76
+ scheduler
77
+ if scheduler is not None
78
+ else LCMScheduler_X(
79
+ beta_start=0.00085,
80
+ beta_end=0.0120,
81
+ beta_schedule="scaled_linear",
82
+ prediction_type="epsilon",
83
+ )
84
+ )
85
+
86
+ self.register_modules(
87
+ vae=vae,
88
+ text_encoder=text_encoder,
89
+ tokenizer=tokenizer,
90
+ unet=unet,
91
+ controlnet=controlnet,
92
+ scheduler=scheduler,
93
+ safety_checker=safety_checker,
94
+ feature_extractor=feature_extractor,
95
+ )
96
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
97
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
98
+ self.control_image_processor = VaeImageProcessor(
99
+ vae_scale_factor=self.vae_scale_factor,
100
+ do_convert_rgb=True,
101
+ do_normalize=False,
102
+ )
103
+
104
+ def _encode_prompt(
105
+ self,
106
+ prompt,
107
+ device,
108
+ num_images_per_prompt,
109
+ prompt_embeds: None,
110
+ ):
111
+ r"""
112
+ Encodes the prompt into text encoder hidden states.
113
+ Args:
114
+ prompt (`str` or `List[str]`, *optional*):
115
+ prompt to be encoded
116
+ device: (`torch.device`):
117
+ torch device
118
+ num_images_per_prompt (`int`):
119
+ number of images that should be generated per prompt
120
+ prompt_embeds (`torch.FloatTensor`, *optional*):
121
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
122
+ provided, text embeddings will be generated from `prompt` input argument.
123
+ """
124
+
125
+ if prompt is not None and isinstance(prompt, str):
126
+ pass
127
+ elif prompt is not None and isinstance(prompt, list):
128
+ len(prompt)
129
+ else:
130
+ prompt_embeds.shape[0]
131
+
132
+ if prompt_embeds is None:
133
+ text_inputs = self.tokenizer(
134
+ prompt,
135
+ padding="max_length",
136
+ max_length=self.tokenizer.model_max_length,
137
+ truncation=True,
138
+ return_tensors="pt",
139
+ )
140
+ text_input_ids = text_inputs.input_ids
141
+ untruncated_ids = self.tokenizer(
142
+ prompt, padding="longest", return_tensors="pt"
143
+ ).input_ids
144
+
145
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[
146
+ -1
147
+ ] and not torch.equal(text_input_ids, untruncated_ids):
148
+ removed_text = self.tokenizer.batch_decode(
149
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
150
+ )
151
+ logger.warning(
152
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
153
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
154
+ )
155
+
156
+ if (
157
+ hasattr(self.text_encoder.config, "use_attention_mask")
158
+ and self.text_encoder.config.use_attention_mask
159
+ ):
160
+ attention_mask = text_inputs.attention_mask.to(device)
161
+ else:
162
+ attention_mask = None
163
+
164
+ prompt_embeds = self.text_encoder(
165
+ text_input_ids.to(device),
166
+ attention_mask=attention_mask,
167
+ )
168
+ prompt_embeds = prompt_embeds[0]
169
+
170
+ if self.text_encoder is not None:
171
+ prompt_embeds_dtype = self.text_encoder.dtype
172
+ elif self.unet is not None:
173
+ prompt_embeds_dtype = self.unet.dtype
174
+ else:
175
+ prompt_embeds_dtype = prompt_embeds.dtype
176
+
177
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
178
+
179
+ bs_embed, seq_len, _ = prompt_embeds.shape
180
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
181
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
182
+ prompt_embeds = prompt_embeds.view(
183
+ bs_embed * num_images_per_prompt, seq_len, -1
184
+ )
185
+
186
+ # Don't need to get uncond prompt embedding because of LCM Guided Distillation
187
+ return prompt_embeds
188
+
189
+ def run_safety_checker(self, image, device, dtype):
190
+ if self.safety_checker is None:
191
+ has_nsfw_concept = None
192
+ else:
193
+ if torch.is_tensor(image):
194
+ feature_extractor_input = self.image_processor.postprocess(
195
+ image, output_type="pil"
196
+ )
197
+ else:
198
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
199
+ safety_checker_input = self.feature_extractor(
200
+ feature_extractor_input, return_tensors="pt"
201
+ ).to(device)
202
+ image, has_nsfw_concept = self.safety_checker(
203
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
204
+ )
205
+ return image, has_nsfw_concept
206
+
207
+ def prepare_control_image(
208
+ self,
209
+ image,
210
+ width,
211
+ height,
212
+ batch_size,
213
+ num_images_per_prompt,
214
+ device,
215
+ dtype,
216
+ do_classifier_free_guidance=False,
217
+ guess_mode=False,
218
+ ):
219
+ image = self.control_image_processor.preprocess(
220
+ image, height=height, width=width
221
+ ).to(dtype=dtype)
222
+ image_batch_size = image.shape[0]
223
+
224
+ if image_batch_size == 1:
225
+ repeat_by = batch_size
226
+ else:
227
+ # image batch size is the same as prompt batch size
228
+ repeat_by = num_images_per_prompt
229
+
230
+ image = image.repeat_interleave(repeat_by, dim=0)
231
+
232
+ image = image.to(device=device, dtype=dtype)
233
+
234
+ if do_classifier_free_guidance and not guess_mode:
235
+ image = torch.cat([image] * 2)
236
+
237
+ return image
238
+
239
+ def prepare_latents(
240
+ self,
241
+ image,
242
+ timestep,
243
+ batch_size,
244
+ num_channels_latents,
245
+ height,
246
+ width,
247
+ dtype,
248
+ device,
249
+ latents=None,
250
+ generator=None,
251
+ ):
252
+ shape = (
253
+ batch_size,
254
+ num_channels_latents,
255
+ height // self.vae_scale_factor,
256
+ width // self.vae_scale_factor,
257
+ )
258
+
259
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
260
+ raise ValueError(
261
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
262
+ )
263
+
264
+ image = image.to(device=device, dtype=dtype)
265
+
266
+ # batch_size = batch_size * num_images_per_prompt
267
+
268
+ if image.shape[1] == 4:
269
+ init_latents = image
270
+
271
+ else:
272
+ if isinstance(generator, list) and len(generator) != batch_size:
273
+ raise ValueError(
274
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
275
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
276
+ )
277
+
278
+ elif isinstance(generator, list):
279
+ if isinstance(self.vae, AutoencoderTiny):
280
+ init_latents = [
281
+ self.vae.encode(image[i : i + 1]).latents
282
+ for i in range(batch_size)
283
+ ]
284
+ else:
285
+ init_latents = [
286
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i])
287
+ for i in range(batch_size)
288
+ ]
289
+ init_latents = torch.cat(init_latents, dim=0)
290
+ else:
291
+ if isinstance(self.vae, AutoencoderTiny):
292
+ init_latents = self.vae.encode(image).latents
293
+ else:
294
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
295
+
296
+ init_latents = self.vae.config.scaling_factor * init_latents
297
+
298
+ if (
299
+ batch_size > init_latents.shape[0]
300
+ and batch_size % init_latents.shape[0] == 0
301
+ ):
302
+ # expand init_latents for batch_size
303
+ deprecation_message = (
304
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
305
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
306
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
307
+ " your script to pass as many initial images as text prompts to suppress this warning."
308
+ )
309
+ # deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
310
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
311
+ init_latents = torch.cat(
312
+ [init_latents] * additional_image_per_prompt, dim=0
313
+ )
314
+ elif (
315
+ batch_size > init_latents.shape[0]
316
+ and batch_size % init_latents.shape[0] != 0
317
+ ):
318
+ raise ValueError(
319
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
320
+ )
321
+ else:
322
+ init_latents = torch.cat([init_latents], dim=0)
323
+
324
+ shape = init_latents.shape
325
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
326
+
327
+ # get latents
328
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
329
+ latents = init_latents
330
+
331
+ return latents
332
+
333
+ if latents is None:
334
+ latents = torch.randn(shape, dtype=dtype).to(device)
335
+ else:
336
+ latents = latents.to(device)
337
+ # scale the initial noise by the standard deviation required by the scheduler
338
+ latents = latents * self.scheduler.init_noise_sigma
339
+ return latents
340
+
341
+ def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32):
342
+ """
343
+ see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
344
+ Args:
345
+ timesteps: torch.Tensor: generate embedding vectors at these timesteps
346
+ embedding_dim: int: dimension of the embeddings to generate
347
+ dtype: data type of the generated embeddings
348
+ Returns:
349
+ embedding vectors with shape `(len(timesteps), embedding_dim)`
350
+ """
351
+ assert len(w.shape) == 1
352
+ w = w * 1000.0
353
+
354
+ half_dim = embedding_dim // 2
355
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
356
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
357
+ emb = w.to(dtype)[:, None] * emb[None, :]
358
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
359
+ if embedding_dim % 2 == 1: # zero pad
360
+ emb = torch.nn.functional.pad(emb, (0, 1))
361
+ assert emb.shape == (w.shape[0], embedding_dim)
362
+ return emb
363
+
364
+ def get_timesteps(self, num_inference_steps, strength, device):
365
+ # get the original timestep using init_timestep
366
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
367
+
368
+ t_start = max(num_inference_steps - init_timestep, 0)
369
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
370
+
371
+ return timesteps, num_inference_steps - t_start
372
+
373
+ @torch.no_grad()
374
+ def __call__(
375
+ self,
376
+ prompt: Union[str, List[str]] = None,
377
+ image: PipelineImageInput = None,
378
+ control_image: PipelineImageInput = None,
379
+ strength: float = 0.8,
380
+ height: Optional[int] = 768,
381
+ width: Optional[int] = 768,
382
+ guidance_scale: float = 7.5,
383
+ num_images_per_prompt: Optional[int] = 1,
384
+ latents: Optional[torch.FloatTensor] = None,
385
+ generator: Optional[torch.Generator] = None,
386
+ num_inference_steps: int = 4,
387
+ lcm_origin_steps: int = 50,
388
+ prompt_embeds: Optional[torch.FloatTensor] = None,
389
+ output_type: Optional[str] = "pil",
390
+ return_dict: bool = True,
391
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
392
+ controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
393
+ guess_mode: bool = True,
394
+ control_guidance_start: Union[float, List[float]] = 0.0,
395
+ control_guidance_end: Union[float, List[float]] = 1.0,
396
+ ):
397
+ controlnet = (
398
+ self.controlnet._orig_mod
399
+ if is_compiled_module(self.controlnet)
400
+ else self.controlnet
401
+ )
402
+ # 0. Default height and width to unet
403
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
404
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
405
+ if not isinstance(control_guidance_start, list) and isinstance(
406
+ control_guidance_end, list
407
+ ):
408
+ control_guidance_start = len(control_guidance_end) * [
409
+ control_guidance_start
410
+ ]
411
+ elif not isinstance(control_guidance_end, list) and isinstance(
412
+ control_guidance_start, list
413
+ ):
414
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
415
+ elif not isinstance(control_guidance_start, list) and not isinstance(
416
+ control_guidance_end, list
417
+ ):
418
+ mult = (
419
+ len(controlnet.nets)
420
+ if isinstance(controlnet, MultiControlNetModel)
421
+ else 1
422
+ )
423
+ control_guidance_start, control_guidance_end = mult * [
424
+ control_guidance_start
425
+ ], mult * [control_guidance_end]
426
+ # 2. Define call parameters
427
+ if prompt is not None and isinstance(prompt, str):
428
+ batch_size = 1
429
+ elif prompt is not None and isinstance(prompt, list):
430
+ batch_size = len(prompt)
431
+ else:
432
+ batch_size = prompt_embeds.shape[0]
433
+
434
+ device = self._execution_device
435
+ # do_classifier_free_guidance = guidance_scale > 0.0 # In LCM Implementation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond) , (cfg_scale > 0.0 using CFG)
436
+ global_pool_conditions = (
437
+ controlnet.config.global_pool_conditions
438
+ if isinstance(controlnet, ControlNetModel)
439
+ else controlnet.nets[0].config.global_pool_conditions
440
+ )
441
+ guess_mode = guess_mode or global_pool_conditions
442
+ # 3. Encode input prompt
443
+ prompt_embeds = self._encode_prompt(
444
+ prompt,
445
+ device,
446
+ num_images_per_prompt,
447
+ prompt_embeds=prompt_embeds,
448
+ )
449
+
450
+ # 3.5 encode image
451
+ image = self.image_processor.preprocess(image)
452
+
453
+ if isinstance(controlnet, ControlNetModel):
454
+ control_image = self.prepare_control_image(
455
+ image=control_image,
456
+ width=width,
457
+ height=height,
458
+ batch_size=batch_size * num_images_per_prompt,
459
+ num_images_per_prompt=num_images_per_prompt,
460
+ device=device,
461
+ dtype=controlnet.dtype,
462
+ guess_mode=guess_mode,
463
+ )
464
+ elif isinstance(controlnet, MultiControlNetModel):
465
+ control_images = []
466
+
467
+ for control_image_ in control_image:
468
+ control_image_ = self.prepare_control_image(
469
+ image=control_image_,
470
+ width=width,
471
+ height=height,
472
+ batch_size=batch_size * num_images_per_prompt,
473
+ num_images_per_prompt=num_images_per_prompt,
474
+ device=device,
475
+ dtype=controlnet.dtype,
476
+ do_classifier_free_guidance=do_classifier_free_guidance,
477
+ guess_mode=guess_mode,
478
+ )
479
+
480
+ control_images.append(control_image_)
481
+
482
+ control_image = control_images
483
+ else:
484
+ assert False
485
+
486
+ # 4. Prepare timesteps
487
+ self.scheduler.set_timesteps(strength, num_inference_steps, lcm_origin_steps)
488
+ # timesteps = self.scheduler.timesteps
489
+ # timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, 1.0, device)
490
+ timesteps = self.scheduler.timesteps
491
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
492
+
493
+ print("timesteps: ", timesteps)
494
+
495
+ # 5. Prepare latent variable
496
+ num_channels_latents = self.unet.config.in_channels
497
+ latents = self.prepare_latents(
498
+ image,
499
+ latent_timestep,
500
+ batch_size * num_images_per_prompt,
501
+ num_channels_latents,
502
+ height,
503
+ width,
504
+ prompt_embeds.dtype,
505
+ device,
506
+ latents,
507
+ )
508
+ bs = batch_size * num_images_per_prompt
509
+
510
+ # 6. Get Guidance Scale Embedding
511
+ w = torch.tensor(guidance_scale).repeat(bs)
512
+ w_embedding = self.get_w_embedding(w, embedding_dim=256).to(
513
+ device=device, dtype=latents.dtype
514
+ )
515
+ controlnet_keep = []
516
+ for i in range(len(timesteps)):
517
+ keeps = [
518
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
519
+ for s, e in zip(control_guidance_start, control_guidance_end)
520
+ ]
521
+ controlnet_keep.append(
522
+ keeps[0] if isinstance(controlnet, ControlNetModel) else keeps
523
+ )
524
+ # 7. LCM MultiStep Sampling Loop:
525
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
526
+ for i, t in enumerate(timesteps):
527
+ ts = torch.full((bs,), t, device=device, dtype=torch.long)
528
+ latents = latents.to(prompt_embeds.dtype)
529
+ if guess_mode:
530
+ # Infer ControlNet only for the conditional batch.
531
+ control_model_input = latents
532
+ control_model_input = self.scheduler.scale_model_input(
533
+ control_model_input, ts
534
+ )
535
+ controlnet_prompt_embeds = prompt_embeds
536
+ else:
537
+ control_model_input = latents
538
+ controlnet_prompt_embeds = prompt_embeds
539
+ if isinstance(controlnet_keep[i], list):
540
+ cond_scale = [
541
+ c * s
542
+ for c, s in zip(
543
+ controlnet_conditioning_scale, controlnet_keep[i]
544
+ )
545
+ ]
546
+ else:
547
+ controlnet_cond_scale = controlnet_conditioning_scale
548
+ if isinstance(controlnet_cond_scale, list):
549
+ controlnet_cond_scale = controlnet_cond_scale[0]
550
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
551
+
552
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
553
+ control_model_input,
554
+ ts,
555
+ encoder_hidden_states=controlnet_prompt_embeds,
556
+ controlnet_cond=control_image,
557
+ conditioning_scale=cond_scale,
558
+ guess_mode=guess_mode,
559
+ return_dict=False,
560
+ )
561
+ # model prediction (v-prediction, eps, x)
562
+ model_pred = self.unet(
563
+ latents,
564
+ ts,
565
+ timestep_cond=w_embedding,
566
+ encoder_hidden_states=prompt_embeds,
567
+ cross_attention_kwargs=cross_attention_kwargs,
568
+ down_block_additional_residuals=down_block_res_samples,
569
+ mid_block_additional_residual=mid_block_res_sample,
570
+ return_dict=False,
571
+ )[0]
572
+
573
+ # compute the previous noisy sample x_t -> x_t-1
574
+ latents, denoised = self.scheduler.step(
575
+ model_pred, i, t, latents, return_dict=False
576
+ )
577
+
578
+ # # call the callback, if provided
579
+ # if i == len(timesteps) - 1:
580
+ progress_bar.update()
581
+
582
+ denoised = denoised.to(prompt_embeds.dtype)
583
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
584
+ self.unet.to("cpu")
585
+ self.controlnet.to("cpu")
586
+ torch.cuda.empty_cache()
587
+ if not output_type == "latent":
588
+ image = self.vae.decode(
589
+ denoised / self.vae.config.scaling_factor, return_dict=False
590
+ )[0]
591
+ image, has_nsfw_concept = self.run_safety_checker(
592
+ image, device, prompt_embeds.dtype
593
+ )
594
+ else:
595
+ image = denoised
596
+ has_nsfw_concept = None
597
+
598
+ if has_nsfw_concept is None:
599
+ do_denormalize = [True] * image.shape[0]
600
+ else:
601
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
602
+
603
+ image = self.image_processor.postprocess(
604
+ image, output_type=output_type, do_denormalize=do_denormalize
605
+ )
606
+
607
+ if not return_dict:
608
+ return (image, has_nsfw_concept)
609
+
610
+ return StableDiffusionPipelineOutput(
611
+ images=image, nsfw_content_detected=has_nsfw_concept
612
+ )
613
+
614
+
615
+ @dataclass
616
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
617
+ class LCMSchedulerOutput(BaseOutput):
618
+ """
619
+ Output class for the scheduler's `step` function output.
620
+ Args:
621
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
622
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
623
+ denoising loop.
624
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
625
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
626
+ `pred_original_sample` can be used to preview progress or for guidance.
627
+ """
628
+
629
+ prev_sample: torch.FloatTensor
630
+ denoised: Optional[torch.FloatTensor] = None
631
+
632
+
633
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
634
+ def betas_for_alpha_bar(
635
+ num_diffusion_timesteps,
636
+ max_beta=0.999,
637
+ alpha_transform_type="cosine",
638
+ ):
639
+ """
640
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
641
+ (1-beta) over time from t = [0,1].
642
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
643
+ to that part of the diffusion process.
644
+ Args:
645
+ num_diffusion_timesteps (`int`): the number of betas to produce.
646
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
647
+ prevent singularities.
648
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
649
+ Choose from `cosine` or `exp`
650
+ Returns:
651
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
652
+ """
653
+ if alpha_transform_type == "cosine":
654
+
655
+ def alpha_bar_fn(t):
656
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
657
+
658
+ elif alpha_transform_type == "exp":
659
+
660
+ def alpha_bar_fn(t):
661
+ return math.exp(t * -12.0)
662
+
663
+ else:
664
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
665
+
666
+ betas = []
667
+ for i in range(num_diffusion_timesteps):
668
+ t1 = i / num_diffusion_timesteps
669
+ t2 = (i + 1) / num_diffusion_timesteps
670
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
671
+ return torch.tensor(betas, dtype=torch.float32)
672
+
673
+
674
+ def rescale_zero_terminal_snr(betas):
675
+ """
676
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
677
+ Args:
678
+ betas (`torch.FloatTensor`):
679
+ the betas that the scheduler is being initialized with.
680
+ Returns:
681
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
682
+ """
683
+ # Convert betas to alphas_bar_sqrt
684
+ alphas = 1.0 - betas
685
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
686
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
687
+
688
+ # Store old values.
689
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
690
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
691
+
692
+ # Shift so the last timestep is zero.
693
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
694
+
695
+ # Scale so the first timestep is back to the old value.
696
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
697
+
698
+ # Convert alphas_bar_sqrt to betas
699
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
700
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
701
+ alphas = torch.cat([alphas_bar[0:1], alphas])
702
+ betas = 1 - alphas
703
+
704
+ return betas
705
+
706
+
707
+ class LCMScheduler_X(SchedulerMixin, ConfigMixin):
708
+ """
709
+ `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
710
+ non-Markovian guidance.
711
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
712
+ methods the library implements for all schedulers such as loading and saving.
713
+ Args:
714
+ num_train_timesteps (`int`, defaults to 1000):
715
+ The number of diffusion steps to train the model.
716
+ beta_start (`float`, defaults to 0.0001):
717
+ The starting `beta` value of inference.
718
+ beta_end (`float`, defaults to 0.02):
719
+ The final `beta` value.
720
+ beta_schedule (`str`, defaults to `"linear"`):
721
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
722
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
723
+ trained_betas (`np.ndarray`, *optional*):
724
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
725
+ clip_sample (`bool`, defaults to `True`):
726
+ Clip the predicted sample for numerical stability.
727
+ clip_sample_range (`float`, defaults to 1.0):
728
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
729
+ set_alpha_to_one (`bool`, defaults to `True`):
730
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
731
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
732
+ otherwise it uses the alpha value at step 0.
733
+ steps_offset (`int`, defaults to 0):
734
+ An offset added to the inference steps. You can use a combination of `offset=1` and
735
+ `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
736
+ Diffusion.
737
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
738
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
739
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
740
+ Video](https://imagen.research.google/video/paper.pdf) paper).
741
+ thresholding (`bool`, defaults to `False`):
742
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
743
+ as Stable Diffusion.
744
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
745
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
746
+ sample_max_value (`float`, defaults to 1.0):
747
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
748
+ timestep_spacing (`str`, defaults to `"leading"`):
749
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
750
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
751
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
752
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
753
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
754
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
755
+ """
756
+
757
+ # _compatibles = [e.name for e in KarrasDiffusionSchedulers]
758
+ order = 1
759
+
760
+ @register_to_config
761
+ def __init__(
762
+ self,
763
+ num_train_timesteps: int = 1000,
764
+ beta_start: float = 0.0001,
765
+ beta_end: float = 0.02,
766
+ beta_schedule: str = "linear",
767
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
768
+ clip_sample: bool = True,
769
+ set_alpha_to_one: bool = True,
770
+ steps_offset: int = 0,
771
+ prediction_type: str = "epsilon",
772
+ thresholding: bool = False,
773
+ dynamic_thresholding_ratio: float = 0.995,
774
+ clip_sample_range: float = 1.0,
775
+ sample_max_value: float = 1.0,
776
+ timestep_spacing: str = "leading",
777
+ rescale_betas_zero_snr: bool = False,
778
+ ):
779
+ if trained_betas is not None:
780
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
781
+ elif beta_schedule == "linear":
782
+ self.betas = torch.linspace(
783
+ beta_start, beta_end, num_train_timesteps, dtype=torch.float32
784
+ )
785
+ elif beta_schedule == "scaled_linear":
786
+ # this schedule is very specific to the latent diffusion model.
787
+ self.betas = (
788
+ torch.linspace(
789
+ beta_start**0.5,
790
+ beta_end**0.5,
791
+ num_train_timesteps,
792
+ dtype=torch.float32,
793
+ )
794
+ ** 2
795
+ )
796
+ elif beta_schedule == "squaredcos_cap_v2":
797
+ # Glide cosine schedule
798
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
799
+ else:
800
+ raise NotImplementedError(
801
+ f"{beta_schedule} does is not implemented for {self.__class__}"
802
+ )
803
+
804
+ # Rescale for zero SNR
805
+ if rescale_betas_zero_snr:
806
+ self.betas = rescale_zero_terminal_snr(self.betas)
807
+
808
+ self.alphas = 1.0 - self.betas
809
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
810
+
811
+ # At every step in ddim, we are looking into the previous alphas_cumprod
812
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
813
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
814
+ # whether we use the final alpha of the "non-previous" one.
815
+ self.final_alpha_cumprod = (
816
+ torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
817
+ )
818
+
819
+ # standard deviation of the initial noise distribution
820
+ self.init_noise_sigma = 1.0
821
+
822
+ # setable values
823
+ self.num_inference_steps = None
824
+ self.timesteps = torch.from_numpy(
825
+ np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)
826
+ )
827
+
828
+ def scale_model_input(
829
+ self, sample: torch.FloatTensor, timestep: Optional[int] = None
830
+ ) -> torch.FloatTensor:
831
+ """
832
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
833
+ current timestep.
834
+ Args:
835
+ sample (`torch.FloatTensor`):
836
+ The input sample.
837
+ timestep (`int`, *optional*):
838
+ The current timestep in the diffusion chain.
839
+ Returns:
840
+ `torch.FloatTensor`:
841
+ A scaled input sample.
842
+ """
843
+ return sample
844
+
845
+ def _get_variance(self, timestep, prev_timestep):
846
+ alpha_prod_t = self.alphas_cumprod[timestep]
847
+ alpha_prod_t_prev = (
848
+ self.alphas_cumprod[prev_timestep]
849
+ if prev_timestep >= 0
850
+ else self.final_alpha_cumprod
851
+ )
852
+ beta_prod_t = 1 - alpha_prod_t
853
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
854
+
855
+ variance = (beta_prod_t_prev / beta_prod_t) * (
856
+ 1 - alpha_prod_t / alpha_prod_t_prev
857
+ )
858
+
859
+ return variance
860
+
861
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
862
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
863
+ """
864
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
865
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
866
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
867
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
868
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
869
+ https://arxiv.org/abs/2205.11487
870
+ """
871
+ dtype = sample.dtype
872
+ batch_size, channels, height, width = sample.shape
873
+
874
+ if dtype not in (torch.float32, torch.float64):
875
+ sample = (
876
+ sample.float()
877
+ ) # upcast for quantile calculation, and clamp not implemented for cpu half
878
+
879
+ # Flatten sample for doing quantile calculation along each image
880
+ sample = sample.reshape(batch_size, channels * height * width)
881
+
882
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
883
+
884
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
885
+ s = torch.clamp(
886
+ s, min=1, max=self.config.sample_max_value
887
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
888
+
889
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
890
+ sample = (
891
+ torch.clamp(sample, -s, s) / s
892
+ ) # "we threshold xt0 to the range [-s, s] and then divide by s"
893
+
894
+ sample = sample.reshape(batch_size, channels, height, width)
895
+ sample = sample.to(dtype)
896
+
897
+ return sample
898
+
899
+ def set_timesteps(
900
+ self,
901
+ stength,
902
+ num_inference_steps: int,
903
+ lcm_origin_steps: int,
904
+ device: Union[str, torch.device] = None,
905
+ ):
906
+ """
907
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
908
+ Args:
909
+ num_inference_steps (`int`):
910
+ The number of diffusion steps used when generating samples with a pre-trained model.
911
+ """
912
+
913
+ if num_inference_steps > self.config.num_train_timesteps:
914
+ raise ValueError(
915
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
916
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
917
+ f" maximal {self.config.num_train_timesteps} timesteps."
918
+ )
919
+
920
+ self.num_inference_steps = num_inference_steps
921
+
922
+ # LCM Timesteps Setting: # Linear Spacing
923
+ c = self.config.num_train_timesteps // lcm_origin_steps
924
+ lcm_origin_timesteps = (
925
+ np.asarray(list(range(1, int(lcm_origin_steps * stength) + 1))) * c - 1
926
+ ) # LCM Training Steps Schedule
927
+ skipping_step = max(len(lcm_origin_timesteps) // num_inference_steps, 1)
928
+ timesteps = lcm_origin_timesteps[::-skipping_step][
929
+ :num_inference_steps
930
+ ] # LCM Inference Steps Schedule
931
+
932
+ self.timesteps = torch.from_numpy(timesteps.copy()).to(device)
933
+
934
+ def get_scalings_for_boundary_condition_discrete(self, t):
935
+ self.sigma_data = 0.5 # Default: 0.5
936
+
937
+ # By dividing 0.1: This is almost a delta function at t=0.
938
+ c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2)
939
+ c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5
940
+ return c_skip, c_out
941
+
942
+ def step(
943
+ self,
944
+ model_output: torch.FloatTensor,
945
+ timeindex: int,
946
+ timestep: int,
947
+ sample: torch.FloatTensor,
948
+ eta: float = 0.0,
949
+ use_clipped_model_output: bool = False,
950
+ generator=None,
951
+ variance_noise: Optional[torch.FloatTensor] = None,
952
+ return_dict: bool = True,
953
+ ) -> Union[LCMSchedulerOutput, Tuple]:
954
+ """
955
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
956
+ process from the learned model outputs (most often the predicted noise).
957
+ Args:
958
+ model_output (`torch.FloatTensor`):
959
+ The direct output from learned diffusion model.
960
+ timestep (`float`):
961
+ The current discrete timestep in the diffusion chain.
962
+ sample (`torch.FloatTensor`):
963
+ A current instance of a sample created by the diffusion process.
964
+ eta (`float`):
965
+ The weight of noise for added noise in diffusion step.
966
+ use_clipped_model_output (`bool`, defaults to `False`):
967
+ If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
968
+ because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
969
+ clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
970
+ `use_clipped_model_output` has no effect.
971
+ generator (`torch.Generator`, *optional*):
972
+ A random number generator.
973
+ variance_noise (`torch.FloatTensor`):
974
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
975
+ itself. Useful for methods such as [`CycleDiffusion`].
976
+ return_dict (`bool`, *optional*, defaults to `True`):
977
+ Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
978
+ Returns:
979
+ [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
980
+ If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
981
+ tuple is returned where the first element is the sample tensor.
982
+ """
983
+ if self.num_inference_steps is None:
984
+ raise ValueError(
985
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
986
+ )
987
+
988
+ # 1. get previous step value
989
+ prev_timeindex = timeindex + 1
990
+ if prev_timeindex < len(self.timesteps):
991
+ prev_timestep = self.timesteps[prev_timeindex]
992
+ else:
993
+ prev_timestep = timestep
994
+
995
+ # 2. compute alphas, betas
996
+ alpha_prod_t = self.alphas_cumprod[timestep]
997
+ alpha_prod_t_prev = (
998
+ self.alphas_cumprod[prev_timestep]
999
+ if prev_timestep >= 0
1000
+ else self.final_alpha_cumprod
1001
+ )
1002
+
1003
+ beta_prod_t = 1 - alpha_prod_t
1004
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
1005
+
1006
+ # 3. Get scalings for boundary conditions
1007
+ c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
1008
+
1009
+ # 4. Different Parameterization:
1010
+ parameterization = self.config.prediction_type
1011
+
1012
+ if parameterization == "epsilon": # noise-prediction
1013
+ pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
1014
+
1015
+ elif parameterization == "sample": # x-prediction
1016
+ pred_x0 = model_output
1017
+
1018
+ elif parameterization == "v_prediction": # v-prediction
1019
+ pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
1020
+
1021
+ # 4. Denoise model output using boundary conditions
1022
+ denoised = c_out * pred_x0 + c_skip * sample
1023
+
1024
+ # 5. Sample z ~ N(0, I), For MultiStep Inference
1025
+ # Noise is not used for one-step sampling.
1026
+ if len(self.timesteps) > 1:
1027
+ noise = torch.randn(model_output.shape).to(model_output.device)
1028
+ prev_sample = (
1029
+ alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
1030
+ )
1031
+ else:
1032
+ prev_sample = denoised
1033
+
1034
+ if not return_dict:
1035
+ return (prev_sample, denoised)
1036
+
1037
+ return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
1038
+
1039
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
1040
+ def add_noise(
1041
+ self,
1042
+ original_samples: torch.FloatTensor,
1043
+ noise: torch.FloatTensor,
1044
+ timesteps: torch.IntTensor,
1045
+ ) -> torch.FloatTensor:
1046
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
1047
+ alphas_cumprod = self.alphas_cumprod.to(
1048
+ device=original_samples.device, dtype=original_samples.dtype
1049
+ )
1050
+ timesteps = timesteps.to(original_samples.device)
1051
+
1052
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
1053
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
1054
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
1055
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
1056
+
1057
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
1058
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
1059
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
1060
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
1061
+
1062
+ noisy_samples = (
1063
+ sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
1064
+ )
1065
+ return noisy_samples
1066
+
1067
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
1068
+ def get_velocity(
1069
+ self,
1070
+ sample: torch.FloatTensor,
1071
+ noise: torch.FloatTensor,
1072
+ timesteps: torch.IntTensor,
1073
+ ) -> torch.FloatTensor:
1074
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
1075
+ alphas_cumprod = self.alphas_cumprod.to(
1076
+ device=sample.device, dtype=sample.dtype
1077
+ )
1078
+ timesteps = timesteps.to(sample.device)
1079
+
1080
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
1081
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
1082
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
1083
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
1084
+
1085
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
1086
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
1087
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
1088
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
1089
+
1090
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
1091
+ return velocity
1092
+
1093
+ def __len__(self):
1094
+ return self.config.num_train_timesteps
requirements.txt CHANGED
@@ -7,4 +7,5 @@ fastapi==0.104.0
7
  uvicorn==0.23.2
8
  Pillow==10.1.0
9
  accelerate==0.24.0
10
- compel==2.0.2
 
 
7
  uvicorn==0.23.2
8
  Pillow==10.1.0
9
  accelerate==0.24.0
10
+ compel==2.0.2
11
+ controlnet-aux==0.0.7