radames (HF staff) committed
Commit 3e47535
1 Parent(s): ae27e5e
Dockerfile CHANGED
@@ -36,4 +36,5 @@ WORKDIR $HOME/app
 # Copy the current directory contents into the container at $HOME/app setting the owner to the user
 COPY --chown=user . $HOME/app
 
-CMD ["uvicorn", "app-img2img:app", "--host", "0.0.0.0", "--port", "7860"]
+CMD ["uvicorn", "app-img2img:app", "--host", "0.0.0.0", "--port", "7860"]
+# CMD ["uvicorn", "app-txt2img:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -20,12 +20,22 @@ You need CUDA and Python
 `SAFETY_CHECKER`: disabled if you want NSFW filter off
 `MAX_QUEUE_SIZE`: limit number of users on current app instance
 
+### image to image
 ```bash
 python -m venv venv
 source venv/bin/activate
 pip install -r requirements.txt
 uvicorn "app-img2img:app" --host 0.0.0.0 --port 7860 --reload
 ```
+
+### text to image
+
+```bash
+python -m venv venv
+source venv/bin/activate
+pip install -r requirements.txt
+uvicorn "app-txt2img:app" --host 0.0.0.0 --port 7860 --reload
+```
 or with environment variables
 ```bash
 TIMEOUT=120 SAFETY_CHECKER=True MAX_QUEUE_SIZE=4 uvicorn "app-img2img:app" --host 0.0.0.0 --port 7860 --reload
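Note that the server scripts read these variables as plain environment strings; in particular, `SAFETY_CHECKER` only enables the NSFW filter when it is exactly the string `"True"`. A minimal sketch of the parsing, mirroring the top of `app-txt2img.py` added in this commit:

```python
import os

# Mirrors app-txt2img.py: unset values fall back to 0 / None.
MAX_QUEUE_SIZE = int(os.environ.get("MAX_QUEUE_SIZE", 0))  # 0 disables the queue limit
TIMEOUT = float(os.environ.get("TIMEOUT", 0))              # 0 disables the session timeout
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)    # filter active only when "True"

use_safety_checker = SAFETY_CHECKER == "True"
print(MAX_QUEUE_SIZE, TIMEOUT, use_safety_checker)
```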
app-img2img.py CHANGED
@@ -67,7 +67,7 @@ def predict(input_image, prompt, guidance_scale=8.0, strength=0.5, seed=2159232)
         strength=strength,
         num_inference_steps=num_inference_steps,
         guidance_scale=guidance_scale,
-        lcm_origin_steps=20,
+        lcm_origin_steps=50,
         output_type="pil",
     )
     nsfw_content_detected = (
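The bumped `lcm_origin_steps` value is forwarded to `LCMScheduler.set_timesteps` in `latent_consistency_txt2img.py` (added below) and determines which training timesteps the few-step sampler visits. A small sketch of that schedule computation, assuming the scheduler default of `num_train_timesteps=1000` and 4 inference steps for illustration:

```python
import numpy as np

# Re-derives the schedule from LCMScheduler.set_timesteps (added in this commit).
num_train_timesteps = 1000   # LCMScheduler default
num_inference_steps = 4      # assumed here for illustration

for lcm_origin_steps in (20, 50):
    c = num_train_timesteps // lcm_origin_steps
    # LCM training steps schedule
    lcm_origin_timesteps = np.asarray(list(range(1, lcm_origin_steps + 1))) * c - 1
    skipping_step = len(lcm_origin_timesteps) // num_inference_steps
    # LCM inference steps schedule
    print(lcm_origin_steps, lcm_origin_timesteps[::-skipping_step][:num_inference_steps])
# 20 -> [999 749 499 249], 50 -> [999 759 519 279]
```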
app-txt2img.py ADDED
@@ -0,0 +1,205 @@
+import asyncio
+import json
+import logging
+import traceback
+from pydantic import BaseModel
+
+from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse, JSONResponse
+from fastapi.staticfiles import StaticFiles
+
+from diffusers import DiffusionPipeline, AutoencoderTiny
+from compel import Compel
+import torch
+from PIL import Image
+import numpy as np
+import gradio as gr
+import io
+import uuid
+import os
+import time
+
+MAX_QUEUE_SIZE = int(os.environ.get("MAX_QUEUE_SIZE", 0))
+TIMEOUT = float(os.environ.get("TIMEOUT", 0))
+SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
+
+print(f"TIMEOUT: {TIMEOUT}")
+print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
+print(f"MAX_QUEUE_SIZE: {MAX_QUEUE_SIZE}")
+
+if SAFETY_CHECKER == "True":
+    pipe = DiffusionPipeline.from_pretrained(
+        "SimianLuo/LCM_Dreamshaper_v7",
+        custom_pipeline="latent_consistency_txt2img.py",
+        custom_revision="main",
+    )
+else:
+    pipe = DiffusionPipeline.from_pretrained(
+        "SimianLuo/LCM_Dreamshaper_v7",
+        safety_checker=None,
+        custom_pipeline="latent_consistency_txt2img.py",
+        custom_revision="main",
+    )
+pipe.vae = AutoencoderTiny.from_pretrained(
+    "madebyollin/taesd", torch_dtype=torch.float16, use_safetensors=True
+)
+pipe.set_progress_bar_config(disable=True)
+pipe.to(torch_device="cuda", torch_dtype=torch.float16)
+pipe.unet.to(memory_format=torch.channels_last)
+compel_proc = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder, truncate_long_prompts=False)
+pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+user_queue_map = {}
+
+# warmup to trigger torch.compile compilation
+pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
+
+def predict(prompt, guidance_scale=8.0, seed=2159232):
+    generator = torch.manual_seed(seed)
+    prompt_embeds = compel_proc(prompt)
+    # Can be set to 1~50 steps. LCM supports fast inference even with <= 4 steps. Recommended: 1~8 steps.
+    num_inference_steps = 8
+    results = pipe(
+        prompt_embeds=prompt_embeds,
+        generator=generator,
+        num_inference_steps=num_inference_steps,
+        guidance_scale=guidance_scale,
+        lcm_origin_steps=50,
+        output_type="pil",
+    )
+    nsfw_content_detected = (
+        results.nsfw_content_detected[0]
+        if "nsfw_content_detected" in results
+        else False
+    )
+    if nsfw_content_detected:
+        return None
+    return results.images[0]
+
+
+app = FastAPI()
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+class InputParams(BaseModel):
+    prompt: str
+    seed: int
+    guidance_scale: float
+
+
+@app.websocket("/ws")
+async def websocket_endpoint(websocket: WebSocket):
+    await websocket.accept()
+    if MAX_QUEUE_SIZE > 0 and len(user_queue_map) >= MAX_QUEUE_SIZE:
+        print("Server is full")
+        await websocket.send_json({"status": "error", "message": "Server is full"})
+        await websocket.close()
+        return
+
+    try:
+        uid = str(uuid.uuid4())
+        print(f"New user connected: {uid}")
+        await websocket.send_json(
+            {"status": "success", "message": "Connected", "userId": uid}
+        )
+        user_queue_map[uid] = {
+            "queue": asyncio.Queue(),
+        }
+        await websocket.send_json(
+            {"status": "start", "message": "Start Streaming", "userId": uid}
+        )
+        await handle_websocket_data(websocket, uid)
+    except WebSocketDisconnect as e:
+        logging.error(f"WebSocket Error: {e}, {uid}")
+        traceback.print_exc()
+    finally:
+        print(f"User disconnected: {uid}")
+        queue_value = user_queue_map.pop(uid, None)
+        queue = queue_value.get("queue", None)
+        if queue:
+            while not queue.empty():
+                try:
+                    queue.get_nowait()
+                except asyncio.QueueEmpty:
+                    continue
+
+
+@app.get("/queue_size")
+async def get_queue_size():
+    queue_size = len(user_queue_map)
+    return JSONResponse({"queue_size": queue_size})
+
+
+@app.get("/stream/{user_id}")
+async def stream(user_id: uuid.UUID):
+    uid = str(user_id)
+    try:
+        user_queue = user_queue_map[uid]
+        queue = user_queue["queue"]
+
+        async def generate():
+            while True:
+                params = await queue.get()
+                if params is None:
+                    continue
+
+                image = predict(params.prompt, params.guidance_scale, params.seed)
+                if image is None:
+                    continue
+                frame_data = io.BytesIO()
+                image.save(frame_data, format="JPEG")
+                frame_data = frame_data.getvalue()
+                if frame_data is not None and len(frame_data) > 0:
+                    yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame_data + b"\r\n"
+
+                await asyncio.sleep(1.0 / 120.0)
+
+        return StreamingResponse(
+            generate(), media_type="multipart/x-mixed-replace;boundary=frame"
+        )
+    except Exception as e:
+        logging.error(f"Streaming Error: {e}, {user_queue_map}")
+        traceback.print_exc()
+        return HTTPException(status_code=404, detail="User not found")
+
+
+async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
+    uid = str(user_id)
+    user_queue = user_queue_map[uid]
+    queue = user_queue["queue"]
+    if not queue:
+        return HTTPException(status_code=404, detail="User not found")
+    last_time = time.time()
+    try:
+        while True:
+            params = await websocket.receive_json()
+            params = InputParams(**params)
+            while not queue.empty():
+                try:
+                    queue.get_nowait()
+                except asyncio.QueueEmpty:
+                    continue
+            await queue.put(params)
+            if TIMEOUT > 0 and time.time() - last_time > TIMEOUT:
+                await websocket.send_json(
+                    {
+                        "status": "timeout",
+                        "message": "Your session has ended",
+                        "userId": uid,
+                    }
+                )
+                await websocket.close()
+                return
+
+    except Exception as e:
+        logging.error(f"Error: {e}")
+        traceback.print_exc()
+
+
+app.mount("/", StaticFiles(directory="txt2img", html=True), name="public")
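For reference, a hypothetical client for the endpoints above; this is only a sketch, assuming the uvicorn host/port from the README and the third-party `websockets` package. The first two server messages carry the `userId`, each subsequent JSON message replaces the queued parameters, and `/stream/{userId}` serves the resulting MJPEG frames.

```python
import asyncio
import json

import websockets  # pip install websockets


async def main():
    async with websockets.connect("ws://localhost:7860/ws") as ws:
        hello = json.loads(await ws.recv())  # {"status": "success", ..., "userId": ...}
        await ws.recv()                      # {"status": "start", ...}
        user_id = hello["userId"]
        print(f"open http://localhost:7860/stream/{user_id} to watch the MJPEG stream")
        await ws.send(json.dumps({
            "prompt": "a photo of a red fox in the snow",  # hypothetical prompt
            "seed": 2159232,
            "guidance_scale": 8.0,
        }))
        await asyncio.sleep(10)  # keep the socket open while frames are generated


asyncio.run(main())
```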
latent_consistency_txt2img.py ADDED
@@ -0,0 +1,836 @@
+# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+
+import math
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import (
+    AutoencoderKL,
+    ConfigMixin,
+    DiffusionPipeline,
+    SchedulerMixin,
+    UNet2DConditionModel,
+    logging,
+)
+from diffusers.configuration_utils import register_to_config
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import (
+    StableDiffusionSafetyChecker,
+)
+from diffusers.utils import BaseOutput
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class LatentConsistencyModelPipeline(DiffusionPipeline):
+    _optional_components = ["scheduler"]
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        tokenizer: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        scheduler: "LCMScheduler",
+        safety_checker: StableDiffusionSafetyChecker,
+        feature_extractor: CLIPImageProcessor,
+        requires_safety_checker: bool = True,
+    ):
+        super().__init__()
+
+        scheduler = (
+            scheduler
+            if scheduler is not None
+            else LCMScheduler(
+                beta_start=0.00085,
+                beta_end=0.0120,
+                beta_schedule="scaled_linear",
+                prediction_type="epsilon",
+            )
+        )
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+        )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        prompt_embeds: None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+        """
+
+        if prompt is not None and isinstance(prompt, str):
+            pass
+        elif prompt is not None and isinstance(prompt, list):
+            len(prompt)
+        else:
+            prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = self.tokenizer(
+                prompt, padding="longest", return_tensors="pt"
+            ).input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[
+                -1
+            ] and not torch.equal(text_input_ids, untruncated_ids):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+
+            if (
+                hasattr(self.text_encoder.config, "use_attention_mask")
+                and self.text_encoder.config.use_attention_mask
+            ):
+                attention_mask = text_inputs.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            prompt_embeds = self.text_encoder(
+                text_input_ids.to(device),
+                attention_mask=attention_mask,
+            )
+            prompt_embeds = prompt_embeds[0]
+
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(
+            bs_embed * num_images_per_prompt, seq_len, -1
+        )
+
+        # Don't need to get uncond prompt embedding because of LCM Guided Distillation
+        return prompt_embeds
+
+    def run_safety_checker(self, image, device, dtype):
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(
+                    image, output_type="pil"
+                )
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(
+                feature_extractor_input, return_tensors="pt"
+            ).to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
+        return image, has_nsfw_concept
+
+    def prepare_latents(
+        self,
+        batch_size,
+        num_channels_latents,
+        height,
+        width,
+        dtype,
+        device,
+        latents=None,
+        generator=None,
+    ):
+        shape = (
+            batch_size,
+            num_channels_latents,
+            height // self.vae_scale_factor,
+            width // self.vae_scale_factor,
+        )
+        if generator is None:
+            generator = torch.Generator()
+            generator.manual_seed(torch.randint(0, 2 ** 32, (1,)).item())
+
+        if latents is None:
+            latents = torch.randn(shape, dtype=dtype, generator=generator).to(device)
+        else:
+            latents = latents.to(device)
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
+    def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+        Args:
+            timesteps: torch.Tensor: generate embedding vectors at these timesteps
+            embedding_dim: int: dimension of the embeddings to generate
+            dtype: data type of the generated embeddings
+        Returns:
+            embedding vectors with shape `(len(timesteps), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        height: Optional[int] = 768,
+        width: Optional[int] = 768,
+        guidance_scale: float = 7.5,
+        num_images_per_prompt: Optional[int] = 1,
+        latents: Optional[torch.FloatTensor] = None,
+        generator: Optional[torch.Generator] = None,
+        num_inference_steps: int = 4,
+        lcm_origin_steps: int = 50,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+        # do_classifier_free_guidance = guidance_scale > 0.0  # In LCM Implementation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond), (cfg_scale > 0.0 using CFG)
+
+        # 3. Encode input prompt
+        prompt_embeds = self._encode_prompt(
+            prompt,
+            device,
+            num_images_per_prompt,
+            prompt_embeds=prompt_embeds,
+        )
+
+        # 4. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, lcm_origin_steps)
+        timesteps = self.scheduler.timesteps
+
+        # 5. Prepare latent variable
+        num_channels_latents = self.unet.config.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            latents,
+            generator
+        )
+        bs = batch_size * num_images_per_prompt
+
+        # 6. Get Guidance Scale Embedding
+        w = torch.tensor(guidance_scale).repeat(bs)
+        w_embedding = self.get_w_embedding(w, embedding_dim=256).to(
+            device=device, dtype=latents.dtype
+        )
+
+        # 7. LCM MultiStep Sampling Loop:
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                ts = torch.full((bs,), t, device=device, dtype=torch.long)
+                latents = latents.to(prompt_embeds.dtype)
+
+                # model prediction (v-prediction, eps, x)
+                model_pred = self.unet(
+                    latents,
+                    ts,
+                    timestep_cond=w_embedding,
+                    encoder_hidden_states=prompt_embeds,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    return_dict=False,
+                )[0]
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents, denoised = self.scheduler.step(
+                    model_pred, i, t, latents, return_dict=False
+                )
+
+                # # call the callback, if provided
+                # if i == len(timesteps) - 1:
+                progress_bar.update()
+
+        denoised = denoised.to(prompt_embeds.dtype)
+        if not output_type == "latent":
+            image = self.vae.decode(
+                denoised / self.vae.config.scaling_factor, return_dict=False
+            )[0]
+            image, has_nsfw_concept = self.run_safety_checker(
+                image, device, prompt_embeds.dtype
+            )
+        else:
+            image = denoised
+            has_nsfw_concept = None
+
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+        image = self.image_processor.postprocess(
+            image, output_type=output_type, do_denormalize=do_denormalize
+        )
+
+        if not return_dict:
+            return (image, has_nsfw_concept)
+
+        return StableDiffusionPipelineOutput(
+            images=image, nsfw_content_detected=has_nsfw_concept
+        )
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
+class LCMSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's `step` function output.
+    Args:
+        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.FloatTensor
+    denoised: Optional[torch.FloatTensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+    num_diffusion_timesteps,
+    max_beta=0.999,
+    alpha_transform_type="cosine",
+):
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+    (1-beta) over time from t = [0,1].
+    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+    to that part of the diffusion process.
+    Args:
+        num_diffusion_timesteps (`int`): the number of betas to produce.
+        max_beta (`float`): the maximum beta to use; use values lower than 1 to
+            prevent singularities.
+        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+            Choose from `cosine` or `exp`
+    Returns:
+        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+    """
+    if alpha_transform_type == "cosine":
+
+        def alpha_bar_fn(t):
+            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+    elif alpha_transform_type == "exp":
+
+        def alpha_bar_fn(t):
+            return math.exp(t * -12.0)
+
+    else:
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
+
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+    return torch.tensor(betas, dtype=torch.float32)
+
+
+def rescale_zero_terminal_snr(betas):
+    """
+    Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+    Args:
+        betas (`torch.FloatTensor`):
+            the betas that the scheduler is being initialized with.
+    Returns:
+        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+    """
+    # Convert betas to alphas_bar_sqrt
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+    alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+    # Store old values.
+    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+    # Shift so the last timestep is zero.
+    alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+    # Scale so the first timestep is back to the old value.
+    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+    # Convert alphas_bar_sqrt to betas
+    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
+    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
+    alphas = torch.cat([alphas_bar[0:1], alphas])
+    betas = 1 - alphas
+
+    return betas
+
+
+class LCMScheduler(SchedulerMixin, ConfigMixin):
+    """
+    `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
+    non-Markovian guidance.
+    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+    methods the library implements for all schedulers such as loading and saving.
+    Args:
+        num_train_timesteps (`int`, defaults to 1000):
+            The number of diffusion steps to train the model.
+        beta_start (`float`, defaults to 0.0001):
+            The starting `beta` value of inference.
+        beta_end (`float`, defaults to 0.02):
+            The final `beta` value.
+        beta_schedule (`str`, defaults to `"linear"`):
+            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        trained_betas (`np.ndarray`, *optional*):
+            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+        clip_sample (`bool`, defaults to `True`):
+            Clip the predicted sample for numerical stability.
+        clip_sample_range (`float`, defaults to 1.0):
+            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+        set_alpha_to_one (`bool`, defaults to `True`):
+            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
+            there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+            otherwise it uses the alpha value at step 0.
+        steps_offset (`int`, defaults to 0):
+            An offset added to the inference steps. You can use a combination of `offset=1` and
+            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
+            Diffusion.
+        prediction_type (`str`, defaults to `epsilon`, *optional*):
+            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
+            Video](https://imagen.research.google/video/paper.pdf) paper).
+        thresholding (`bool`, defaults to `False`):
+            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+            as Stable Diffusion.
+        dynamic_thresholding_ratio (`float`, defaults to 0.995):
+            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+        sample_max_value (`float`, defaults to 1.0):
+            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+        timestep_spacing (`str`, defaults to `"leading"`):
+            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+        rescale_betas_zero_snr (`bool`, defaults to `False`):
+            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+            dark samples instead of limiting it to samples with medium brightness. Loosely related to
+            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+    """
+
+    # _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+    order = 1
+
+    @register_to_config
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.0001,
+        beta_end: float = 0.02,
+        beta_schedule: str = "linear",
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+        clip_sample: bool = True,
+        set_alpha_to_one: bool = True,
+        steps_offset: int = 0,
+        prediction_type: str = "epsilon",
+        thresholding: bool = False,
+        dynamic_thresholding_ratio: float = 0.995,
+        clip_sample_range: float = 1.0,
+        sample_max_value: float = 1.0,
+        timestep_spacing: str = "leading",
+        rescale_betas_zero_snr: bool = False,
+    ):
+        if trained_betas is not None:
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+        elif beta_schedule == "linear":
+            self.betas = torch.linspace(
+                beta_start, beta_end, num_train_timesteps, dtype=torch.float32
+            )
+        elif beta_schedule == "scaled_linear":
+            # this schedule is very specific to the latent diffusion model.
+            self.betas = (
+                torch.linspace(
+                    beta_start**0.5,
+                    beta_end**0.5,
+                    num_train_timesteps,
+                    dtype=torch.float32,
+                )
+                ** 2
+            )
+        elif beta_schedule == "squaredcos_cap_v2":
+            # Glide cosine schedule
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
+        else:
+            raise NotImplementedError(
+                f"{beta_schedule} is not implemented for {self.__class__}"
+            )
+
+        # Rescale for zero SNR
+        if rescale_betas_zero_snr:
+            self.betas = rescale_zero_terminal_snr(self.betas)
+
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+        # At every step in ddim, we are looking into the previous alphas_cumprod
+        # For the final step, there is no previous alphas_cumprod because we are already at 0
+        # `set_alpha_to_one` decides whether we set this parameter simply to one or
+        # whether we use the final alpha of the "non-previous" one.
+        self.final_alpha_cumprod = (
+            torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+        )
+
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0
+
+        # setable values
+        self.num_inference_steps = None
+        self.timesteps = torch.from_numpy(
+            np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)
+        )
+
+    def scale_model_input(
+        self, sample: torch.FloatTensor, timestep: Optional[int] = None
+    ) -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+        Args:
+            sample (`torch.FloatTensor`):
+                The input sample.
+            timestep (`int`, *optional*):
+                The current timestep in the diffusion chain.
+        Returns:
+            `torch.FloatTensor`:
+                A scaled input sample.
+        """
+        return sample
+
+    def _get_variance(self, timestep, prev_timestep):
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = (
+            self.alphas_cumprod[prev_timestep]
+            if prev_timestep >= 0
+            else self.final_alpha_cumprod
+        )
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+        variance = (beta_prod_t_prev / beta_prod_t) * (
+            1 - alpha_prod_t / alpha_prod_t_prev
+        )
+
+        return variance
+
+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+        """
+        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+        photorealism as well as better image-text alignment, especially when using very large guidance weights."
+        https://arxiv.org/abs/2205.11487
+        """
+        dtype = sample.dtype
+        batch_size, channels, height, width = sample.shape
+
+        if dtype not in (torch.float32, torch.float64):
+            sample = (
+                sample.float()
+            )  # upcast for quantile calculation, and clamp not implemented for cpu half
+
+        # Flatten sample for doing quantile calculation along each image
+        sample = sample.reshape(batch_size, channels * height * width)
+
+        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"
+
+        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+        s = torch.clamp(
+            s, min=1, max=self.config.sample_max_value
+        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
+        sample = (
+            torch.clamp(sample, -s, s) / s
+        )  # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+        sample = sample.reshape(batch_size, channels, height, width)
+        sample = sample.to(dtype)
+
+        return sample
+
+    def set_timesteps(
+        self,
+        num_inference_steps: int,
+        lcm_origin_steps: int,
+        device: Union[str, torch.device] = None,
+    ):
+        """
+        Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+        Args:
+            num_inference_steps (`int`):
+                The number of diffusion steps used when generating samples with a pre-trained model.
+        """
+
+        if num_inference_steps > self.config.num_train_timesteps:
+            raise ValueError(
+                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                f" maximal {self.config.num_train_timesteps} timesteps."
+            )
+
+        self.num_inference_steps = num_inference_steps
+
+        # LCM Timesteps Setting:  # Linear Spacing
+        c = self.config.num_train_timesteps // lcm_origin_steps
+        lcm_origin_timesteps = (
+            np.asarray(list(range(1, lcm_origin_steps + 1))) * c - 1
+        )  # LCM Training Steps Schedule
+        skipping_step = len(lcm_origin_timesteps) // num_inference_steps
+        timesteps = lcm_origin_timesteps[::-skipping_step][
+            :num_inference_steps
+        ]  # LCM Inference Steps Schedule
+
+        self.timesteps = torch.from_numpy(timesteps.copy()).to(device)
+
+    def get_scalings_for_boundary_condition_discrete(self, t):
+        self.sigma_data = 0.5  # Default: 0.5
+
+        # By dividing 0.1: This is almost a delta function at t=0.
+        c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2)
+        c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5
+        return c_skip, c_out
+
+    def step(
+        self,
+        model_output: torch.FloatTensor,
+        timeindex: int,
+        timestep: int,
+        sample: torch.FloatTensor,
+        eta: float = 0.0,
+        use_clipped_model_output: bool = False,
+        generator=None,
+        variance_noise: Optional[torch.FloatTensor] = None,
+        return_dict: bool = True,
+    ) -> Union[LCMSchedulerOutput, Tuple]:
+        """
+        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+        process from the learned model outputs (most often the predicted noise).
+        Args:
+            model_output (`torch.FloatTensor`):
+                The direct output from learned diffusion model.
+            timestep (`float`):
+                The current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor`):
+                A current instance of a sample created by the diffusion process.
+            eta (`float`):
+                The weight of noise for added noise in diffusion step.
+            use_clipped_model_output (`bool`, defaults to `False`):
+                If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
+                because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
+                clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
+                `use_clipped_model_output` has no effect.
+            generator (`torch.Generator`, *optional*):
+                A random number generator.
+            variance_noise (`torch.FloatTensor`):
+                Alternative to generating noise with `generator` by directly providing the noise for the variance
+                itself. Useful for methods such as [`CycleDiffusion`].
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
+        Returns:
+            [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
+                If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
+                tuple is returned where the first element is the sample tensor.
+        """
+        if self.num_inference_steps is None:
+            raise ValueError(
+                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+            )
+
+        # 1. get previous step value
+        prev_timeindex = timeindex + 1
+        if prev_timeindex < len(self.timesteps):
+            prev_timestep = self.timesteps[prev_timeindex]
+        else:
+            prev_timestep = timestep
+
+        # 2. compute alphas, betas
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = (
+            self.alphas_cumprod[prev_timestep]
+            if prev_timestep >= 0
+            else self.final_alpha_cumprod
+        )
+
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+        # 3. Get scalings for boundary conditions
+        c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
+
+        # 4. Different Parameterization:
+        parameterization = self.config.prediction_type
+
+        if parameterization == "epsilon":  # noise-prediction
+            pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
+
+        elif parameterization == "sample":  # x-prediction
+            pred_x0 = model_output
+
+        elif parameterization == "v_prediction":  # v-prediction
+            pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
+
+        # 4. Denoise model output using boundary conditions
+        denoised = c_out * pred_x0 + c_skip * sample
+
+        # 5. Sample z ~ N(0, I), For MultiStep Inference
+        # Noise is not used for one-step sampling.
+        if len(self.timesteps) > 1:
+            noise = torch.randn(model_output.shape).to(model_output.device)
+            prev_sample = (
+                alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
+            )
+        else:
+            prev_sample = denoised
+
+        if not return_dict:
+            return (prev_sample, denoised)
+
+        return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
+
+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+    def add_noise(
+        self,
+        original_samples: torch.FloatTensor,
+        noise: torch.FloatTensor,
+        timesteps: torch.IntTensor,
+    ) -> torch.FloatTensor:
+        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+        alphas_cumprod = self.alphas_cumprod.to(
+            device=original_samples.device, dtype=original_samples.dtype
+        )
+        timesteps = timesteps.to(original_samples.device)
+
+        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+        noisy_samples = (
+            sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+        )
+        return noisy_samples
+
+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
+    def get_velocity(
+        self,
+        sample: torch.FloatTensor,
+        noise: torch.FloatTensor,
+        timesteps: torch.IntTensor,
+    ) -> torch.FloatTensor:
+        # Make sure alphas_cumprod and timestep have same device and dtype as sample
+        alphas_cumprod = self.alphas_cumprod.to(
+            device=sample.device, dtype=sample.dtype
+        )
+        timesteps = timesteps.to(sample.device)
+
+        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(sample.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+        return velocity
+
+    def __len__(self):
+        return self.config.num_train_timesteps
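Outside the FastAPI apps, the custom pipeline above can be exercised on its own; a minimal sketch, assuming a CUDA GPU and mirroring how app-txt2img.py in this commit loads it (the prompt and output filename are placeholders):

```python
import torch
from diffusers import DiffusionPipeline

# custom_pipeline points at latent_consistency_txt2img.py, exactly as in app-txt2img.py above.
pipe = DiffusionPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7",
    custom_pipeline="latent_consistency_txt2img.py",
    custom_revision="main",
)
pipe.to(torch_device="cuda", torch_dtype=torch.float16)

result = pipe(
    prompt="a photo of a red fox in the snow",  # hypothetical prompt
    num_inference_steps=4,   # LCM needs only a few steps
    lcm_origin_steps=50,     # matches the value used by the apps in this commit
    guidance_scale=8.0,
    output_type="pil",
)
result.images[0].save("lcm_txt2img.png")
```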
txt2img/index.html ADDED
@@ -0,0 +1,218 @@
+<!doctype html>
+<html>
+
+<head>
+    <meta charset="UTF-8">
+    <title>Real-Time Latent Consistency Model</title>
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <script
+        src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script>
+    <script src="https://cdn.tailwindcss.com"></script>
+    <script type="module">
+
+        const seedEl = document.querySelector("#seed");
+        const promptEl = document.querySelector("#prompt");
+        const guidanceEl = document.querySelector("#guidance-scale");
+        const startBtn = document.querySelector("#start");
+        const stopBtn = document.querySelector("#stop");
+        const videoEl = document.querySelector("#webcam");
+        const imageEl = document.querySelector("#player");
+        const queueSizeEl = document.querySelector("#queue_size");
+        const errorEl = document.querySelector("#error");
+
+        function LCMLive(promptEl, liveImage, seedEl, guidanceEl) {
+            let websocket;
+
+            async function start() {
+                return new Promise((resolve, reject) => {
+                    const websocketURL = `${window.location.protocol === "https:" ? "wss" : "ws"
+                        }:${window.location.host}/ws`;
+
+                    const socket = new WebSocket(websocketURL);
+                    socket.onopen = () => {
+                        console.log("Connected to websocket");
+                    };
+                    socket.onclose = () => {
+                        console.log("Disconnected from websocket");
+                        stop();
+                        resolve({ "status": "disconnected" });
+                    };
+                    socket.onerror = (err) => {
+                        console.error(err);
+                        reject(err);
+                    };
+                    socket.onmessage = (event) => {
+                        const data = JSON.parse(event.data);
+                        switch (data.status) {
+                            case "success":
+                                break;
+                            case "start":
+                                const userId = data.userId;
+                                initPromptStream(userId);
+                                break;
+                            case "timeout":
+                                stop();
+                                resolve({ "status": "timeout" });
+                                break;
+                            case "error":
+                                stop();
+                                reject(data.message);
+                        }
+                    };
+                    websocket = socket;
+                })
+            }
+
+            async function promptUpdateStream() {
+                const prompt = promptEl.value;
+                const seed = seedEl.value;
+                const guidance_scale = guidanceEl.value;
+                websocket.send(JSON.stringify({
+                    prompt: prompt,
+                    seed: seed,
+                    guidance_scale: guidance_scale,
+                }));
+            }
+
+            function initPromptStream(userId) {
+                liveImage.src = `/stream/${userId}`;
+                promptEl.addEventListener("input", promptUpdateStream);
+                seedEl.addEventListener("input", promptUpdateStream);
+                guidanceEl.addEventListener("change", promptUpdateStream);
+            }
+
+            async function stop() {
+                websocket.close();
+                promptEl.removeEventListener("input", promptUpdateStream);
+                seedEl.removeEventListener("input", promptUpdateStream);
+                guidanceEl.removeEventListener("change", promptUpdateStream);
+            }
+            return {
+                start,
+                stop
+            }
+        }
+        function toggleMessage(type) {
+            errorEl.hidden = false;
+            switch (type) {
+                case "error":
+                    errorEl.innerText = "Too many users are using the same GPU, please try again later.";
+                    errorEl.classList.toggle("bg-red-300", "text-red-900");
+                    break;
+                case "success":
+                    errorEl.innerText = "Your session has ended, please start a new one.";
+                    errorEl.classList.toggle("bg-green-300", "text-green-900");
+                    break;
+            }
+            setTimeout(() => {
+                errorEl.hidden = true;
+            }, 2000);
+        }
+
+
+        const lcmLive = LCMLive(promptEl, imageEl, seedEl, guidanceEl);
+        startBtn.addEventListener("click", async () => {
+            try {
+                startBtn.disabled = true;
+                const res = await lcmLive.start();
+                startBtn.disabled = false;
+                if (res.status === "timeout")
+                    toggleMessage("success")
+            } catch (err) {
+                console.log(err);
+                toggleMessage("error")
+            }
+        });
+        stopBtn.addEventListener("click", () => {
+            lcmLive.stop();
+        });
+        window.addEventListener("beforeunload", () => {
+            lcmLive.stop();
+        });
+        setInterval(() =>
+            fetch("/queue_size")
+                .then((res) => res.json())
+                .then((data) => {
+                    queueSizeEl.innerText = data.queue_size;
+                })
+                .catch((err) => {
+                    console.log(err);
+                })
+            , 5000);
+    </script>
+</head>
+
+<body>
+    <div class="fixed right-2 top-2 p-4 font-bold text-sm rounded-lg max-w-xs text-center" id="error">
+    </div>
+    <main class="container mx-auto px-4 py-4 max-w-4xl flex flex-col gap-4">
+        <article class="text-center max-w-xl mx-auto">
+            <h1 class="text-3xl font-bold">Real-Time Latent Consistency Model</h1>
+            <h2 class="text-2xl font-bold mb-4">Text to Image</h2>
+            <p class="text-sm">
+                This demo showcases
+                <a href="https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7" target="_blank"
+                    class="text-blue-500 underline hover:no-underline">LCM</a> Text to Image model
+                using
+                <a href="https://github.com/huggingface/diffusers/tree/main/examples/community#latent-consistency-pipeline"
+                    target="_blank" class="text-blue-500 underline hover:no-underline">Diffusers</a> with an MJPEG
+                stream server.
+            </p>
+            <p class="text-sm">
+                There are <span id="queue_size" class="font-bold">0</span> user(s) sharing the same GPU, affecting
+                real-time performance. Maximum queue size is 10. <a
+                    href="https://huggingface.co/spaces/radames/Real-Time-Latent-Consistency-Model?duplicate=true"
+                    target="_blank" class="text-blue-500 underline hover:no-underline">Duplicate</a> and run it on your own GPU.
+            </p>
+        </article>
+        <div>
+            <h2 class="font-medium">Prompt</h2>
+            <p class="text-sm text-gray-500">
+                Start your session and type your prompt here, accepts
+                <a href="https://github.com/damian0815/compel/blob/main/doc/syntax.md" target="_blank" class="text-blue-500 underline hover:no-underline">Compel</a> syntax.
+            </p>
+            <div class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center">
+                <textarea type="text" id="prompt" class="font-light w-full px-3 py-2 mx-1 outline-none"
+                    title="Start your session and type your prompt here, you can see the result in real-time."
+                    placeholder="Add your prompt here...">Portrait of The Terminator with , glare pose, detailed, intricate, full of colour, cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details, unreal engine 5, cinematic, masterpiece</textarea>
+            </div>
+
+        </div>
+        <div class="">
+            <details>
+                <summary class="font-medium cursor-pointer">Advanced Options</summary>
+                <div class="grid grid-cols-3 max-w-md items-center gap-3 py-3">
+                    <label class="text-sm font-medium" for="guidance-scale">Guidance Scale
+                    </label>
+                    <input type="range" id="guidance-scale" name="guidance-scale" min="1" max="30" step="0.001"
+                        value="8.0" oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
+                    <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
+                        8.0</output>
+                    <label class="text-sm font-medium" for="seed">Seed</label>
+                    <input type="number" id="seed" name="seed" value="299792458"
+                        class="font-light border border-gray-700 text-right rounded-md p-2">
+                    <button
+                        onclick="document.querySelector('#seed').value = Math.floor(Math.random() * Number.MAX_SAFE_INTEGER); document.querySelector('#seed').dispatchEvent(new Event('input'));"
+                        class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-1 w-[50px] rounded disabled:bg-gray-300 disabled:cursor-not-allowed text-sm">
+                        Rand
+                    </button>
+                </div>
+            </details>
+        </div>
+        <div>
+            <button id="start"
+                class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed">
+                Start
+            </button>
+            <button id="stop"
+                class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed">
+                Stop
+            </button>
+        </div>
+        <div class="relative rounded-lg border border-slate-300 overflow-hidden">
+            <img id="player" class="w-full aspect-square rounded-lg "
+                src="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII=">
+        </div>
+    </main>
+</body>
+
+</html>
txt2img/tailwind.config.js ADDED
File without changes