radames commited on
Commit
ca822d3
0 Parent(s):
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__/
2
+ venv/
Dockerfile ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04
2
+
3
+ ARG DEBIAN_FRONTEND=noninteractive
4
+
5
+ ENV PYTHONUNBUFFERED=1
6
+
7
+ RUN apt-get update && apt-get install --no-install-recommends -y \
8
+ build-essential \
9
+ python3.9 \
10
+ python3-pip \
11
+ git \
12
+ ffmpeg \
13
+ && apt-get clean && rm -rf /var/lib/apt/lists/*
14
+
15
+ WORKDIR /code
16
+
17
+ COPY ./requirements.txt /code/requirements.txt
18
+
19
+ # Set up a new user named "user" with user ID 1000
20
+ RUN useradd -m -u 1000 user
21
+ # Switch to the "user" user
22
+ USER user
23
+ # Set home to the user's home directory
24
+ ENV HOME=/home/user \
25
+ PATH=/home/user/.local/bin:$PATH \
26
+ PYTHONPATH=$HOME/app \
27
+ PYTHONUNBUFFERED=1 \
28
+ SYSTEM=spaces
29
+
30
+ RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
31
+
32
+ # Set the working directory to the user's home directory
33
+ WORKDIR $HOME/app
34
+
35
+ # Copy the current directory contents into the container at $HOME/app setting the owner to the user
36
+ COPY --chown=user . $HOME/app
37
+
38
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Real-Time Latent Consistency Model
3
+ emoji: 🔥
4
+ colorFrom: gray
5
+ colorTo: indigo
6
+ sdk: docker
7
+ pinned: false
8
+ ---
9
+
10
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import traceback
5
+ from pydantic import BaseModel
6
+
7
+ from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from fastapi.responses import StreamingResponse, JSONResponse
10
+ from fastapi.staticfiles import StaticFiles
11
+
12
+ from diffusers import DiffusionPipeline
13
+ import torch
14
+ from PIL import Image
15
+ import numpy as np
16
+ import gradio as gr
17
+ import io
18
+ import uuid
19
+ import os
20
+ import time
21
+
22
+ MAX_QUEUE_SIZE = 4
23
+ TIMEOUT = float(os.environ.get("TIMEOUT", 0))
24
+ SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
25
+
26
+
27
+ if SAFETY_CHECKER == "True":
28
+ pipe = DiffusionPipeline.from_pretrained(
29
+ "SimianLuo/LCM_Dreamshaper_v7",
30
+ custom_pipeline="latent_consistency_img2img.py",
31
+ custom_revision="main",
32
+ )
33
+ else:
34
+ pipe = DiffusionPipeline.from_pretrained(
35
+ "SimianLuo/LCM_Dreamshaper_v7",
36
+ safety_checker=None,
37
+ custom_pipeline="latent_consistency_img2img.py",
38
+ custom_revision="main",
39
+ )
40
+ pipe.to(torch_device="cuda", torch_dtype=torch.float16)
41
+ user_queue_map = {}
42
+
43
+
44
+ def predict(input_image, prompt, guidance_scale=8.0, strength=0.5, seed=2159232):
45
+ generator = torch.manual_seed(seed)
46
+ # Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps.
47
+ num_inference_steps = 4
48
+ results = pipe(
49
+ prompt=prompt,
50
+ generator=generator,
51
+ image=input_image,
52
+ strength=strength,
53
+ num_inference_steps=num_inference_steps,
54
+ guidance_scale=guidance_scale,
55
+ lcm_origin_steps=20,
56
+ output_type="pil",
57
+ )
58
+ nsfw_content_detected = (
59
+ results.nsfw_content_detected[0]
60
+ if "nsfw_content_detected" in results
61
+ else False
62
+ )
63
+ if nsfw_content_detected:
64
+ return None
65
+ return results.images[0]
66
+
67
+
68
+ app = FastAPI()
69
+ app.add_middleware(
70
+ CORSMiddleware,
71
+ allow_origins=["*"],
72
+ allow_credentials=True,
73
+ allow_methods=["*"],
74
+ allow_headers=["*"],
75
+ )
76
+
77
+
78
+ class InputParams(BaseModel):
79
+ seed: int
80
+ prompt: str
81
+ strength: float
82
+ guidance_scale: float
83
+
84
+
85
+ @app.websocket("/ws")
86
+ async def websocket_endpoint(websocket: WebSocket):
87
+ await websocket.accept()
88
+ if len(user_queue_map) >= MAX_QUEUE_SIZE:
89
+ print("Server is full")
90
+ await websocket.send_json({"status": "error", "message": "Server is full"})
91
+ await websocket.close()
92
+ return
93
+
94
+ try:
95
+ uid = str(uuid.uuid4())
96
+ print(f"New user connected: {uid}")
97
+ await websocket.send_json(
98
+ {"status": "success", "message": "Connected", "userId": uid}
99
+ )
100
+ params = await websocket.receive_json()
101
+ params = InputParams(**params)
102
+ user_queue_map[uid] = {
103
+ "queue": asyncio.Queue(),
104
+ "params": params,
105
+ }
106
+ await handle_websocket_data(websocket, uid)
107
+ except WebSocketDisconnect as e:
108
+ logging.error(f"Error: {e}")
109
+ traceback.print_exc()
110
+ finally:
111
+ print(f"User disconnected: {uid}")
112
+ queue_value = user_queue_map.pop(uid, None)
113
+ queue = queue_value.get("queue", None)
114
+ if queue:
115
+ while not queue.empty():
116
+ try:
117
+ queue.get_nowait()
118
+ except asyncio.QueueEmpty:
119
+ continue
120
+
121
+
122
+ @app.get("/queue_size")
123
+ async def get_queue_size():
124
+ queue_size = len(user_queue_map)
125
+ return JSONResponse({"queue_size": queue_size})
126
+
127
+
128
+ @app.get("/stream/{user_id}")
129
+ async def stream(user_id: uuid.UUID):
130
+ uid = str(user_id)
131
+ user_queue = user_queue_map[uid]
132
+ queue = user_queue["queue"]
133
+ params = user_queue["params"]
134
+ seed = params.seed
135
+ prompt = params.prompt
136
+ strength = params.strength
137
+ guidance_scale = params.guidance_scale
138
+ if not queue:
139
+ return HTTPException(status_code=404, detail="User not found")
140
+
141
+ async def generate():
142
+ while True:
143
+ input_image = await queue.get()
144
+ if input_image is None:
145
+ continue
146
+
147
+ image = predict(input_image, prompt, guidance_scale, strength, seed)
148
+ if image is None:
149
+ continue
150
+ frame_data = io.BytesIO()
151
+ image.save(frame_data, format="JPEG")
152
+ frame_data = frame_data.getvalue()
153
+ if frame_data is not None and len(frame_data) > 0:
154
+ yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame_data + b"\r\n"
155
+
156
+ await asyncio.sleep(1.0 / 120.0)
157
+
158
+ return StreamingResponse(
159
+ generate(), media_type="multipart/x-mixed-replace;boundary=frame"
160
+ )
161
+
162
+
163
+ async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
164
+ uid = str(user_id)
165
+ user_queue = user_queue_map[uid]
166
+ queue = user_queue["queue"]
167
+ if not queue:
168
+ return HTTPException(status_code=404, detail="User not found")
169
+ last_time = time.time()
170
+ try:
171
+ while True:
172
+ data = await websocket.receive_bytes()
173
+ pil_image = Image.open(io.BytesIO(data))
174
+
175
+ while not queue.empty():
176
+ try:
177
+ queue.get_nowait()
178
+ except asyncio.QueueEmpty:
179
+ continue
180
+ await queue.put(pil_image)
181
+ if TIMEOUT > 0 and time.time() - last_time > TIMEOUT:
182
+ await websocket.send_json(
183
+ {
184
+ "status": "timeout",
185
+ "message": "Your session has ended",
186
+ "userId": uid,
187
+ }
188
+ )
189
+ await websocket.close()
190
+ return
191
+
192
+ except Exception as e:
193
+ logging.error(f"Error: {e}")
194
+ traceback.print_exc()
195
+
196
+
197
+ app.mount("/", StaticFiles(directory="public", html=True), name="public")
latent_consistency_img2img.py ADDED
@@ -0,0 +1,924 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16
+ # and https://github.com/hojonathanho/diffusion
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import Any, Dict, List, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import PIL.Image
24
+ import torch
25
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
26
+
27
+ from diffusers import (
28
+ AutoencoderKL,
29
+ ConfigMixin,
30
+ DiffusionPipeline,
31
+ SchedulerMixin,
32
+ UNet2DConditionModel,
33
+ logging,
34
+ )
35
+ from diffusers.configuration_utils import register_to_config
36
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
37
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
38
+ from diffusers.pipelines.stable_diffusion.safety_checker import (
39
+ StableDiffusionSafetyChecker,
40
+ )
41
+ from diffusers.utils import BaseOutput
42
+ from diffusers.utils.torch_utils import randn_tensor
43
+
44
+
45
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
46
+
47
+
48
+ class LatentConsistencyModelImg2ImgPipeline(DiffusionPipeline):
49
+ _optional_components = ["scheduler"]
50
+
51
+ def __init__(
52
+ self,
53
+ vae: AutoencoderKL,
54
+ text_encoder: CLIPTextModel,
55
+ tokenizer: CLIPTokenizer,
56
+ unet: UNet2DConditionModel,
57
+ scheduler: "LCMSchedulerWithTimestamp",
58
+ safety_checker: StableDiffusionSafetyChecker,
59
+ feature_extractor: CLIPImageProcessor,
60
+ requires_safety_checker: bool = True,
61
+ ):
62
+ super().__init__()
63
+
64
+ scheduler = (
65
+ scheduler
66
+ if scheduler is not None
67
+ else LCMSchedulerWithTimestamp(
68
+ beta_start=0.00085,
69
+ beta_end=0.0120,
70
+ beta_schedule="scaled_linear",
71
+ prediction_type="epsilon",
72
+ )
73
+ )
74
+
75
+ self.register_modules(
76
+ vae=vae,
77
+ text_encoder=text_encoder,
78
+ tokenizer=tokenizer,
79
+ unet=unet,
80
+ scheduler=scheduler,
81
+ safety_checker=safety_checker,
82
+ feature_extractor=feature_extractor,
83
+ )
84
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
85
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
86
+
87
+ def _encode_prompt(
88
+ self,
89
+ prompt,
90
+ device,
91
+ num_images_per_prompt,
92
+ prompt_embeds: None,
93
+ ):
94
+ r"""
95
+ Encodes the prompt into text encoder hidden states.
96
+ Args:
97
+ prompt (`str` or `List[str]`, *optional*):
98
+ prompt to be encoded
99
+ device: (`torch.device`):
100
+ torch device
101
+ num_images_per_prompt (`int`):
102
+ number of images that should be generated per prompt
103
+ prompt_embeds (`torch.FloatTensor`, *optional*):
104
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
105
+ provided, text embeddings will be generated from `prompt` input argument.
106
+ """
107
+
108
+ if prompt is not None and isinstance(prompt, str):
109
+ pass
110
+ elif prompt is not None and isinstance(prompt, list):
111
+ len(prompt)
112
+ else:
113
+ prompt_embeds.shape[0]
114
+
115
+ if prompt_embeds is None:
116
+ text_inputs = self.tokenizer(
117
+ prompt,
118
+ padding="max_length",
119
+ max_length=self.tokenizer.model_max_length,
120
+ truncation=True,
121
+ return_tensors="pt",
122
+ )
123
+ text_input_ids = text_inputs.input_ids
124
+ untruncated_ids = self.tokenizer(
125
+ prompt, padding="longest", return_tensors="pt"
126
+ ).input_ids
127
+
128
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[
129
+ -1
130
+ ] and not torch.equal(text_input_ids, untruncated_ids):
131
+ removed_text = self.tokenizer.batch_decode(
132
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
133
+ )
134
+ logger.warning(
135
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
136
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
137
+ )
138
+
139
+ if (
140
+ hasattr(self.text_encoder.config, "use_attention_mask")
141
+ and self.text_encoder.config.use_attention_mask
142
+ ):
143
+ attention_mask = text_inputs.attention_mask.to(device)
144
+ else:
145
+ attention_mask = None
146
+
147
+ prompt_embeds = self.text_encoder(
148
+ text_input_ids.to(device),
149
+ attention_mask=attention_mask,
150
+ )
151
+ prompt_embeds = prompt_embeds[0]
152
+
153
+ if self.text_encoder is not None:
154
+ prompt_embeds_dtype = self.text_encoder.dtype
155
+ elif self.unet is not None:
156
+ prompt_embeds_dtype = self.unet.dtype
157
+ else:
158
+ prompt_embeds_dtype = prompt_embeds.dtype
159
+
160
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
161
+
162
+ bs_embed, seq_len, _ = prompt_embeds.shape
163
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
164
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
165
+ prompt_embeds = prompt_embeds.view(
166
+ bs_embed * num_images_per_prompt, seq_len, -1
167
+ )
168
+
169
+ # Don't need to get uncond prompt embedding because of LCM Guided Distillation
170
+ return prompt_embeds
171
+
172
+ def run_safety_checker(self, image, device, dtype):
173
+ if self.safety_checker is None:
174
+ has_nsfw_concept = None
175
+ else:
176
+ if torch.is_tensor(image):
177
+ feature_extractor_input = self.image_processor.postprocess(
178
+ image, output_type="pil"
179
+ )
180
+ else:
181
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
182
+ safety_checker_input = self.feature_extractor(
183
+ feature_extractor_input, return_tensors="pt"
184
+ ).to(device)
185
+ image, has_nsfw_concept = self.safety_checker(
186
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
187
+ )
188
+ return image, has_nsfw_concept
189
+
190
+ def prepare_latents(
191
+ self,
192
+ image,
193
+ timestep,
194
+ batch_size,
195
+ num_channels_latents,
196
+ height,
197
+ width,
198
+ dtype,
199
+ device,
200
+ latents=None,
201
+ generator=None,
202
+ ):
203
+ shape = (
204
+ batch_size,
205
+ num_channels_latents,
206
+ height // self.vae_scale_factor,
207
+ width // self.vae_scale_factor,
208
+ )
209
+
210
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
211
+ raise ValueError(
212
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
213
+ )
214
+
215
+ image = image.to(device=device, dtype=dtype)
216
+
217
+ # batch_size = batch_size * num_images_per_prompt
218
+ if image.shape[1] == 4:
219
+ init_latents = image
220
+
221
+ else:
222
+ if isinstance(generator, list) and len(generator) != batch_size:
223
+ raise ValueError(
224
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
225
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
226
+ )
227
+
228
+ elif isinstance(generator, list):
229
+ init_latents = [
230
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i])
231
+ for i in range(batch_size)
232
+ ]
233
+ init_latents = torch.cat(init_latents, dim=0)
234
+ else:
235
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
236
+
237
+ init_latents = self.vae.config.scaling_factor * init_latents
238
+
239
+ if (
240
+ batch_size > init_latents.shape[0]
241
+ and batch_size % init_latents.shape[0] == 0
242
+ ):
243
+ # expand init_latents for batch_size
244
+ (
245
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
246
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
247
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
248
+ " your script to pass as many initial images as text prompts to suppress this warning."
249
+ )
250
+ # deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
251
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
252
+ init_latents = torch.cat(
253
+ [init_latents] * additional_image_per_prompt, dim=0
254
+ )
255
+ elif (
256
+ batch_size > init_latents.shape[0]
257
+ and batch_size % init_latents.shape[0] != 0
258
+ ):
259
+ raise ValueError(
260
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
261
+ )
262
+ else:
263
+ init_latents = torch.cat([init_latents], dim=0)
264
+
265
+ shape = init_latents.shape
266
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
267
+
268
+ # get latents
269
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
270
+ latents = init_latents
271
+
272
+ if latents is None:
273
+ latents = torch.randn(shape, dtype=dtype).to(device)
274
+ else:
275
+ latents = latents.to(device)
276
+ # scale the initial noise by the standard deviation required by the scheduler
277
+ latents = latents * self.scheduler.init_noise_sigma
278
+ return latents
279
+
280
+ def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32):
281
+ """
282
+ see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
283
+ Args:
284
+ timesteps: torch.Tensor: generate embedding vectors at these timesteps
285
+ embedding_dim: int: dimension of the embeddings to generate
286
+ dtype: data type of the generated embeddings
287
+ Returns:
288
+ embedding vectors with shape `(len(timesteps), embedding_dim)`
289
+ """
290
+ assert len(w.shape) == 1
291
+ w = w * 1000.0
292
+
293
+ half_dim = embedding_dim // 2
294
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
295
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
296
+ emb = w.to(dtype)[:, None] * emb[None, :]
297
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
298
+ if embedding_dim % 2 == 1: # zero pad
299
+ emb = torch.nn.functional.pad(emb, (0, 1))
300
+ assert emb.shape == (w.shape[0], embedding_dim)
301
+ return emb
302
+
303
+ def get_timesteps(self, num_inference_steps, strength, device):
304
+ # get the original timestep using init_timestep
305
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
306
+
307
+ t_start = max(num_inference_steps - init_timestep, 0)
308
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
309
+
310
+ return timesteps, num_inference_steps - t_start
311
+
312
+ @torch.no_grad()
313
+ def __call__(
314
+ self,
315
+ prompt: Union[str, List[str]] = None,
316
+ image: PipelineImageInput = None,
317
+ strength: float = 0.8,
318
+ height: Optional[int] = 768,
319
+ width: Optional[int] = 768,
320
+ guidance_scale: float = 7.5,
321
+ num_images_per_prompt: Optional[int] = 1,
322
+ latents: Optional[torch.FloatTensor] = None,
323
+ generator: Optional[torch.Generator] = None,
324
+ num_inference_steps: int = 4,
325
+ lcm_origin_steps: int = 50,
326
+ prompt_embeds: Optional[torch.FloatTensor] = None,
327
+ output_type: Optional[str] = "pil",
328
+ return_dict: bool = True,
329
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
330
+ ):
331
+ # 0. Default height and width to unet
332
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
333
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
334
+
335
+ # 2. Define call parameters
336
+ if prompt is not None and isinstance(prompt, str):
337
+ batch_size = 1
338
+ elif prompt is not None and isinstance(prompt, list):
339
+ batch_size = len(prompt)
340
+ else:
341
+ batch_size = prompt_embeds.shape[0]
342
+
343
+ device = self._execution_device
344
+ # do_classifier_free_guidance = guidance_scale > 0.0 # In LCM Implementation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond) , (cfg_scale > 0.0 using CFG)
345
+
346
+ # 3. Encode input prompt
347
+ prompt_embeds = self._encode_prompt(
348
+ prompt,
349
+ device,
350
+ num_images_per_prompt,
351
+ prompt_embeds=prompt_embeds,
352
+ )
353
+
354
+ # 3.5 encode image
355
+ image = self.image_processor.preprocess(image)
356
+
357
+ # 4. Prepare timesteps
358
+ self.scheduler.set_timesteps(strength, num_inference_steps, lcm_origin_steps)
359
+ # timesteps = self.scheduler.timesteps
360
+ # timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, 1.0, device)
361
+ timesteps = self.scheduler.timesteps
362
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
363
+
364
+ print("timesteps: ", timesteps)
365
+
366
+ # 5. Prepare latent variable
367
+ num_channels_latents = self.unet.config.in_channels
368
+ latents = self.prepare_latents(
369
+ image,
370
+ latent_timestep,
371
+ batch_size * num_images_per_prompt,
372
+ num_channels_latents,
373
+ height,
374
+ width,
375
+ prompt_embeds.dtype,
376
+ device,
377
+ latents,
378
+ generator
379
+ )
380
+ bs = batch_size * num_images_per_prompt
381
+
382
+ # 6. Get Guidance Scale Embedding
383
+ w = torch.tensor(guidance_scale).repeat(bs)
384
+ w_embedding = self.get_w_embedding(w, embedding_dim=256).to(
385
+ device=device, dtype=latents.dtype
386
+ )
387
+
388
+ # 7. LCM MultiStep Sampling Loop:
389
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
390
+ for i, t in enumerate(timesteps):
391
+ ts = torch.full((bs,), t, device=device, dtype=torch.long)
392
+ latents = latents.to(prompt_embeds.dtype)
393
+
394
+ # model prediction (v-prediction, eps, x)
395
+ model_pred = self.unet(
396
+ latents,
397
+ ts,
398
+ timestep_cond=w_embedding,
399
+ encoder_hidden_states=prompt_embeds,
400
+ cross_attention_kwargs=cross_attention_kwargs,
401
+ return_dict=False,
402
+ )[0]
403
+
404
+ # compute the previous noisy sample x_t -> x_t-1
405
+ latents, denoised = self.scheduler.step(
406
+ model_pred, i, t, latents, return_dict=False
407
+ )
408
+
409
+ # # call the callback, if provided
410
+ # if i == len(timesteps) - 1:
411
+ progress_bar.update()
412
+
413
+ denoised = denoised.to(prompt_embeds.dtype)
414
+ if not output_type == "latent":
415
+ image = self.vae.decode(
416
+ denoised / self.vae.config.scaling_factor, return_dict=False
417
+ )[0]
418
+ image, has_nsfw_concept = self.run_safety_checker(
419
+ image, device, prompt_embeds.dtype
420
+ )
421
+ else:
422
+ image = denoised
423
+ has_nsfw_concept = None
424
+
425
+ if has_nsfw_concept is None:
426
+ do_denormalize = [True] * image.shape[0]
427
+ else:
428
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
429
+
430
+ image = self.image_processor.postprocess(
431
+ image, output_type=output_type, do_denormalize=do_denormalize
432
+ )
433
+
434
+ if not return_dict:
435
+ return (image, has_nsfw_concept)
436
+
437
+ return StableDiffusionPipelineOutput(
438
+ images=image, nsfw_content_detected=has_nsfw_concept
439
+ )
440
+
441
+
442
+ @dataclass
443
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
444
+ class LCMSchedulerOutput(BaseOutput):
445
+ """
446
+ Output class for the scheduler's `step` function output.
447
+ Args:
448
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
449
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
450
+ denoising loop.
451
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
452
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
453
+ `pred_original_sample` can be used to preview progress or for guidance.
454
+ """
455
+
456
+ prev_sample: torch.FloatTensor
457
+ denoised: Optional[torch.FloatTensor] = None
458
+
459
+
460
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
461
+ def betas_for_alpha_bar(
462
+ num_diffusion_timesteps,
463
+ max_beta=0.999,
464
+ alpha_transform_type="cosine",
465
+ ):
466
+ """
467
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
468
+ (1-beta) over time from t = [0,1].
469
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
470
+ to that part of the diffusion process.
471
+ Args:
472
+ num_diffusion_timesteps (`int`): the number of betas to produce.
473
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
474
+ prevent singularities.
475
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
476
+ Choose from `cosine` or `exp`
477
+ Returns:
478
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
479
+ """
480
+ if alpha_transform_type == "cosine":
481
+
482
+ def alpha_bar_fn(t):
483
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
484
+
485
+ elif alpha_transform_type == "exp":
486
+
487
+ def alpha_bar_fn(t):
488
+ return math.exp(t * -12.0)
489
+
490
+ else:
491
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
492
+
493
+ betas = []
494
+ for i in range(num_diffusion_timesteps):
495
+ t1 = i / num_diffusion_timesteps
496
+ t2 = (i + 1) / num_diffusion_timesteps
497
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
498
+ return torch.tensor(betas, dtype=torch.float32)
499
+
500
+
501
+ def rescale_zero_terminal_snr(betas):
502
+ """
503
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
504
+ Args:
505
+ betas (`torch.FloatTensor`):
506
+ the betas that the scheduler is being initialized with.
507
+ Returns:
508
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
509
+ """
510
+ # Convert betas to alphas_bar_sqrt
511
+ alphas = 1.0 - betas
512
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
513
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
514
+
515
+ # Store old values.
516
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
517
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
518
+
519
+ # Shift so the last timestep is zero.
520
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
521
+
522
+ # Scale so the first timestep is back to the old value.
523
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
524
+
525
+ # Convert alphas_bar_sqrt to betas
526
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
527
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
528
+ alphas = torch.cat([alphas_bar[0:1], alphas])
529
+ betas = 1 - alphas
530
+
531
+ return betas
532
+
533
+
534
+ class LCMSchedulerWithTimestamp(SchedulerMixin, ConfigMixin):
535
+ """
536
+ This class modifies LCMScheduler to add a timestamp argument to set_timesteps
537
+
538
+
539
+ `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
540
+ non-Markovian guidance.
541
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
542
+ methods the library implements for all schedulers such as loading and saving.
543
+ Args:
544
+ num_train_timesteps (`int`, defaults to 1000):
545
+ The number of diffusion steps to train the model.
546
+ beta_start (`float`, defaults to 0.0001):
547
+ The starting `beta` value of inference.
548
+ beta_end (`float`, defaults to 0.02):
549
+ The final `beta` value.
550
+ beta_schedule (`str`, defaults to `"linear"`):
551
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
552
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
553
+ trained_betas (`np.ndarray`, *optional*):
554
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
555
+ clip_sample (`bool`, defaults to `True`):
556
+ Clip the predicted sample for numerical stability.
557
+ clip_sample_range (`float`, defaults to 1.0):
558
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
559
+ set_alpha_to_one (`bool`, defaults to `True`):
560
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
561
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
562
+ otherwise it uses the alpha value at step 0.
563
+ steps_offset (`int`, defaults to 0):
564
+ An offset added to the inference steps. You can use a combination of `offset=1` and
565
+ `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
566
+ Diffusion.
567
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
568
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
569
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
570
+ Video](https://imagen.research.google/video/paper.pdf) paper).
571
+ thresholding (`bool`, defaults to `False`):
572
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
573
+ as Stable Diffusion.
574
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
575
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
576
+ sample_max_value (`float`, defaults to 1.0):
577
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
578
+ timestep_spacing (`str`, defaults to `"leading"`):
579
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
580
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
581
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
582
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
583
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
584
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
585
+ """
586
+
587
+ # _compatibles = [e.name for e in KarrasDiffusionSchedulers]
588
+ order = 1
589
+
590
+ @register_to_config
591
+ def __init__(
592
+ self,
593
+ num_train_timesteps: int = 1000,
594
+ beta_start: float = 0.0001,
595
+ beta_end: float = 0.02,
596
+ beta_schedule: str = "linear",
597
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
598
+ clip_sample: bool = True,
599
+ set_alpha_to_one: bool = True,
600
+ steps_offset: int = 0,
601
+ prediction_type: str = "epsilon",
602
+ thresholding: bool = False,
603
+ dynamic_thresholding_ratio: float = 0.995,
604
+ clip_sample_range: float = 1.0,
605
+ sample_max_value: float = 1.0,
606
+ timestep_spacing: str = "leading",
607
+ rescale_betas_zero_snr: bool = False,
608
+ ):
609
+ if trained_betas is not None:
610
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
611
+ elif beta_schedule == "linear":
612
+ self.betas = torch.linspace(
613
+ beta_start, beta_end, num_train_timesteps, dtype=torch.float32
614
+ )
615
+ elif beta_schedule == "scaled_linear":
616
+ # this schedule is very specific to the latent diffusion model.
617
+ self.betas = (
618
+ torch.linspace(
619
+ beta_start**0.5,
620
+ beta_end**0.5,
621
+ num_train_timesteps,
622
+ dtype=torch.float32,
623
+ )
624
+ ** 2
625
+ )
626
+ elif beta_schedule == "squaredcos_cap_v2":
627
+ # Glide cosine schedule
628
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
629
+ else:
630
+ raise NotImplementedError(
631
+ f"{beta_schedule} does is not implemented for {self.__class__}"
632
+ )
633
+
634
+ # Rescale for zero SNR
635
+ if rescale_betas_zero_snr:
636
+ self.betas = rescale_zero_terminal_snr(self.betas)
637
+
638
+ self.alphas = 1.0 - self.betas
639
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
640
+
641
+ # At every step in ddim, we are looking into the previous alphas_cumprod
642
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
643
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
644
+ # whether we use the final alpha of the "non-previous" one.
645
+ self.final_alpha_cumprod = (
646
+ torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
647
+ )
648
+
649
+ # standard deviation of the initial noise distribution
650
+ self.init_noise_sigma = 1.0
651
+
652
+ # setable values
653
+ self.num_inference_steps = None
654
+ self.timesteps = torch.from_numpy(
655
+ np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)
656
+ )
657
+
658
+ def scale_model_input(
659
+ self, sample: torch.FloatTensor, timestep: Optional[int] = None
660
+ ) -> torch.FloatTensor:
661
+ """
662
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
663
+ current timestep.
664
+ Args:
665
+ sample (`torch.FloatTensor`):
666
+ The input sample.
667
+ timestep (`int`, *optional*):
668
+ The current timestep in the diffusion chain.
669
+ Returns:
670
+ `torch.FloatTensor`:
671
+ A scaled input sample.
672
+ """
673
+ return sample
674
+
675
+ def _get_variance(self, timestep, prev_timestep):
676
+ alpha_prod_t = self.alphas_cumprod[timestep]
677
+ alpha_prod_t_prev = (
678
+ self.alphas_cumprod[prev_timestep]
679
+ if prev_timestep >= 0
680
+ else self.final_alpha_cumprod
681
+ )
682
+ beta_prod_t = 1 - alpha_prod_t
683
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
684
+
685
+ variance = (beta_prod_t_prev / beta_prod_t) * (
686
+ 1 - alpha_prod_t / alpha_prod_t_prev
687
+ )
688
+
689
+ return variance
690
+
691
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
692
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
693
+ """
694
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
695
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
696
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
697
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
698
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
699
+ https://arxiv.org/abs/2205.11487
700
+ """
701
+ dtype = sample.dtype
702
+ batch_size, channels, height, width = sample.shape
703
+
704
+ if dtype not in (torch.float32, torch.float64):
705
+ sample = (
706
+ sample.float()
707
+ ) # upcast for quantile calculation, and clamp not implemented for cpu half
708
+
709
+ # Flatten sample for doing quantile calculation along each image
710
+ sample = sample.reshape(batch_size, channels * height * width)
711
+
712
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
713
+
714
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
715
+ s = torch.clamp(
716
+ s, min=1, max=self.config.sample_max_value
717
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
718
+
719
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
720
+ sample = (
721
+ torch.clamp(sample, -s, s) / s
722
+ ) # "we threshold xt0 to the range [-s, s] and then divide by s"
723
+
724
+ sample = sample.reshape(batch_size, channels, height, width)
725
+ sample = sample.to(dtype)
726
+
727
+ return sample
728
+
729
+ def set_timesteps(
730
+ self,
731
+ stength,
732
+ num_inference_steps: int,
733
+ lcm_origin_steps: int,
734
+ device: Union[str, torch.device] = None,
735
+ ):
736
+ """
737
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
738
+ Args:
739
+ num_inference_steps (`int`):
740
+ The number of diffusion steps used when generating samples with a pre-trained model.
741
+ """
742
+
743
+ if num_inference_steps > self.config.num_train_timesteps:
744
+ raise ValueError(
745
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
746
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
747
+ f" maximal {self.config.num_train_timesteps} timesteps."
748
+ )
749
+
750
+ self.num_inference_steps = num_inference_steps
751
+
752
+ # LCM Timesteps Setting: # Linear Spacing
753
+ c = self.config.num_train_timesteps // lcm_origin_steps
754
+ lcm_origin_timesteps = (
755
+ np.asarray(list(range(1, int(lcm_origin_steps * stength) + 1))) * c - 1
756
+ ) # LCM Training Steps Schedule
757
+ skipping_step = len(lcm_origin_timesteps) // num_inference_steps
758
+ timesteps = lcm_origin_timesteps[::-skipping_step][
759
+ :num_inference_steps
760
+ ] # LCM Inference Steps Schedule
761
+
762
+ self.timesteps = torch.from_numpy(timesteps.copy()).to(device)
763
+
764
+ def get_scalings_for_boundary_condition_discrete(self, t):
765
+ self.sigma_data = 0.5 # Default: 0.5
766
+
767
+ # By dividing 0.1: This is almost a delta function at t=0.
768
+ c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2)
769
+ c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5
770
+ return c_skip, c_out
771
+
772
+ def step(
773
+ self,
774
+ model_output: torch.FloatTensor,
775
+ timeindex: int,
776
+ timestep: int,
777
+ sample: torch.FloatTensor,
778
+ eta: float = 0.0,
779
+ use_clipped_model_output: bool = False,
780
+ generator=None,
781
+ variance_noise: Optional[torch.FloatTensor] = None,
782
+ return_dict: bool = True,
783
+ ) -> Union[LCMSchedulerOutput, Tuple]:
784
+ """
785
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
786
+ process from the learned model outputs (most often the predicted noise).
787
+ Args:
788
+ model_output (`torch.FloatTensor`):
789
+ The direct output from learned diffusion model.
790
+ timestep (`float`):
791
+ The current discrete timestep in the diffusion chain.
792
+ sample (`torch.FloatTensor`):
793
+ A current instance of a sample created by the diffusion process.
794
+ eta (`float`):
795
+ The weight of noise for added noise in diffusion step.
796
+ use_clipped_model_output (`bool`, defaults to `False`):
797
+ If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
798
+ because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
799
+ clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
800
+ `use_clipped_model_output` has no effect.
801
+ generator (`torch.Generator`, *optional*):
802
+ A random number generator.
803
+ variance_noise (`torch.FloatTensor`):
804
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
805
+ itself. Useful for methods such as [`CycleDiffusion`].
806
+ return_dict (`bool`, *optional*, defaults to `True`):
807
+ Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
808
+ Returns:
809
+ [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
810
+ If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
811
+ tuple is returned where the first element is the sample tensor.
812
+ """
813
+ if self.num_inference_steps is None:
814
+ raise ValueError(
815
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
816
+ )
817
+
818
+ # 1. get previous step value
819
+ prev_timeindex = timeindex + 1
820
+ if prev_timeindex < len(self.timesteps):
821
+ prev_timestep = self.timesteps[prev_timeindex]
822
+ else:
823
+ prev_timestep = timestep
824
+
825
+ # 2. compute alphas, betas
826
+ alpha_prod_t = self.alphas_cumprod[timestep]
827
+ alpha_prod_t_prev = (
828
+ self.alphas_cumprod[prev_timestep]
829
+ if prev_timestep >= 0
830
+ else self.final_alpha_cumprod
831
+ )
832
+
833
+ beta_prod_t = 1 - alpha_prod_t
834
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
835
+
836
+ # 3. Get scalings for boundary conditions
837
+ c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
838
+
839
+ # 4. Different Parameterization:
840
+ parameterization = self.config.prediction_type
841
+
842
+ if parameterization == "epsilon": # noise-prediction
843
+ pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
844
+
845
+ elif parameterization == "sample": # x-prediction
846
+ pred_x0 = model_output
847
+
848
+ elif parameterization == "v_prediction": # v-prediction
849
+ pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
850
+
851
+ # 4. Denoise model output using boundary conditions
852
+ denoised = c_out * pred_x0 + c_skip * sample
853
+
854
+ # 5. Sample z ~ N(0, I), For MultiStep Inference
855
+ # Noise is not used for one-step sampling.
856
+ if len(self.timesteps) > 1:
857
+ noise = torch.randn(model_output.shape).to(model_output.device)
858
+ prev_sample = (
859
+ alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
860
+ )
861
+ else:
862
+ prev_sample = denoised
863
+
864
+ if not return_dict:
865
+ return (prev_sample, denoised)
866
+
867
+ return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
868
+
869
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
870
+ def add_noise(
871
+ self,
872
+ original_samples: torch.FloatTensor,
873
+ noise: torch.FloatTensor,
874
+ timesteps: torch.IntTensor,
875
+ ) -> torch.FloatTensor:
876
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
877
+ alphas_cumprod = self.alphas_cumprod.to(
878
+ device=original_samples.device, dtype=original_samples.dtype
879
+ )
880
+ timesteps = timesteps.to(original_samples.device)
881
+
882
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
883
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
884
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
885
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
886
+
887
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
888
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
889
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
890
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
891
+
892
+ noisy_samples = (
893
+ sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
894
+ )
895
+ return noisy_samples
896
+
897
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
898
+ def get_velocity(
899
+ self,
900
+ sample: torch.FloatTensor,
901
+ noise: torch.FloatTensor,
902
+ timesteps: torch.IntTensor,
903
+ ) -> torch.FloatTensor:
904
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
905
+ alphas_cumprod = self.alphas_cumprod.to(
906
+ device=sample.device, dtype=sample.dtype
907
+ )
908
+ timesteps = timesteps.to(sample.device)
909
+
910
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
911
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
912
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
913
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
914
+
915
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
916
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
917
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
918
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
919
+
920
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
921
+ return velocity
922
+
923
+ def __len__(self):
924
+ return self.config.num_train_timesteps
public/index.html ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!doctype html>
2
+ <html>
3
+
4
+ <head>
5
+ <meta charset="UTF-8">
6
+ <title>Real-Time Latent Consistency Model</title>
7
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
8
+ <script
9
+ src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script>
10
+ <script src="https://cdn.tailwindcss.com"></script>
11
+ <script type="module">
12
+
13
+ const seedEl = document.querySelector("#seed");
14
+ const promptEl = document.querySelector("#prompt");
15
+ const guidanceEl = document.querySelector("#guidance-scale");
16
+ const strengthEl = document.querySelector("#strength");
17
+ const startBtn = document.querySelector("#start");
18
+ const stopBtn = document.querySelector("#stop");
19
+ const videoEl = document.querySelector("#webcam");
20
+ const imageEl = document.querySelector("#player");
21
+ const queueSizeEl = document.querySelector("#queue_size");
22
+ const errorEl = document.querySelector("#error");
23
+
24
+ function LCMLive(webcamVideo, liveImage) {
25
+ let websocket;
26
+
27
+ async function start(params) {
28
+ return new Promise((resolve, reject) => {
29
+ const websocketURL = `${window.location.protocol === "https:" ? "wss" : "ws"
30
+ }:${window.location.host}/ws`;
31
+
32
+ const socket = new WebSocket(websocketURL);
33
+ socket.onopen = () => {
34
+ console.log("Connected to websocket");
35
+ };
36
+ socket.onclose = () => {
37
+ console.log("Disconnected from websocket");
38
+ stop();
39
+ resolve({ "status": "disconnected" });
40
+ };
41
+ socket.onerror = (err) => {
42
+ console.error(err);
43
+ reject(err);
44
+ };
45
+ socket.onmessage = (event) => {
46
+ const data = JSON.parse(event.data);
47
+ switch (data.status) {
48
+ case "success":
49
+ socket.send(JSON.stringify(params));
50
+ const userId = data.userId;
51
+ liveImage.src = `/stream/${userId}`;
52
+ initVideoStream();
53
+ break;
54
+ case "timeout":
55
+ stop();
56
+ resolve({ "status": "timeout" });
57
+ case "error":
58
+ stop();
59
+ reject(data.message);
60
+
61
+ }
62
+ };
63
+ websocket = socket;
64
+ })
65
+ }
66
+
67
+ async function videoTimeUpdateHandler() {
68
+ const canvas = new OffscreenCanvas(webcamVideo.videoWidth, webcamVideo.videoHeight);
69
+ const ctx = canvas.getContext("2d");
70
+ ctx.drawImage(webcamVideo, 0, 0, canvas.width, canvas.height);
71
+ const blob = await canvas.convertToBlob({ type: "image/jpeg", quality: 1 });
72
+ websocket.send(blob);
73
+ }
74
+
75
+ function initVideoStream() {
76
+ const constraints = {
77
+ audio: false,
78
+ video: { width: 512, height: 512 },
79
+ };
80
+ navigator.mediaDevices
81
+ .getUserMedia(constraints)
82
+ .then((mediaStream) => {
83
+ webcamVideo.srcObject = mediaStream;
84
+ webcamVideo.onloadedmetadata = () => {
85
+ webcamVideo.play();
86
+ webcamVideo.addEventListener("timeupdate", videoTimeUpdateHandler);
87
+ };
88
+ })
89
+ .catch((err) => {
90
+ console.error(`${err.name}: ${err.message}`);
91
+ });
92
+ }
93
+
94
+ async function stop() {
95
+ websocket.close();
96
+ navigator.mediaDevices.getUserMedia({ video: true }).then((mediaStream) => {
97
+ mediaStream.getTracks().forEach((track) => track.stop());
98
+ });
99
+ webcamVideo.removeEventListener("timeupdate", videoTimeUpdateHandler);
100
+ webcamVideo.srcObject = null;
101
+ }
102
+ return {
103
+ start,
104
+ stop
105
+ }
106
+ }
107
+ function toggleMessage(type) {
108
+ errorEl.hidden = false;
109
+ switch (type) {
110
+ case "error":
111
+ errorEl.innerText = "To many users are using the same GPU, please try again later.";
112
+ errorEl.classList.toggle("bg-red-300", "text-red-900");
113
+ break;
114
+ case "success":
115
+ errorEl.innerText = "Your 2min session has ended, please start training again.";
116
+ errorEl.classList.toggle("bg-green-300", "text-green-900");
117
+ break;
118
+ }
119
+ setTimeout(() => {
120
+ errorEl.hidden = true;
121
+ }, 5000);
122
+ }
123
+
124
+
125
+ const lcmLive = LCMLive(videoEl, imageEl);
126
+ startBtn.addEventListener("click", async () => {
127
+ try {
128
+ const seed = seedEl.value;
129
+ const prompt = promptEl.value;
130
+ const guidance_scale = guidanceEl.value;
131
+ const strength = strengthEl.value;
132
+ startBtn.disabled = true;
133
+ const res = await lcmLive.start({ seed, prompt, guidance_scale, strength });
134
+ startBtn.disabled = false;
135
+ if (res.status === "timeout")
136
+ toggleMessage("success")
137
+ } catch (err) {
138
+ console.log(err);
139
+ toggleMessage("error")
140
+ }
141
+ });
142
+ stopBtn.addEventListener("click", () => {
143
+ lcmLive.stop();
144
+ });
145
+ window.addEventListener("beforeunload", () => {
146
+ lcmLive.stop();
147
+ });
148
+ setInterval(() =>
149
+ fetch("/queue_size")
150
+ .then((res) => res.json())
151
+ .then((data) => {
152
+ queueSizeEl.innerText = data.queue_size;
153
+ })
154
+ .catch((err) => {
155
+ console.log(err);
156
+ })
157
+ , 1000);
158
+ </script>
159
+ </head>
160
+
161
+ <body>
162
+ <div class="fixed right-2 top-2 p-4 font-bold text-sm rounded-lg max-w-xs text-center" id="error">
163
+ </div>
164
+ <main class="container mx-auto px-4 py-4 max-w-4xl flex flex-col gap-4">
165
+ <article class="text-center max-w-xl mx-auto">
166
+ <h1 class="text-3xl font-bold mb-4">Real-Time Latent Consistency Model</h1>
167
+
168
+ <p class="text-sm">
169
+ This demo showcases
170
+ <a href="https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7" target="_blank"
171
+ class="text-blue-500 hover:underline">LCM</a>
172
+ using
173
+ <a href="https://github.com/huggingface/diffusers/tree/main/examples/community#latent-consistency-pipeline"
174
+ target="_blank" class="text-blue-500 hover:underline">Diffusers</a> with a MJPEG
175
+ stream server.
176
+ </p>
177
+ <p class="text-sm">
178
+ To change settings or prompt, stop the current stream and start a new one.
179
+ </p>
180
+ <p class="text-sm">
181
+ There are <span id="queue_size" class="font-bold">0</span> user(s) sharing the same GPU, affecting
182
+ real-time performance. Maximum queue size is 4. <a
183
+ href="https://huggingface.co/spaces/radames/Real-Time-Latent-Consistency-Model?duplicate=true"
184
+ target="_blank" class="text-blue-500 hover:underline">Duplicate</a> and run it on your own GPU.
185
+ </p>
186
+ </article>
187
+ <div>
188
+ <div class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center">
189
+ <textarea type="text" id="prompt" class="font-light w-full px-3 py-2 mx-1 resize-none outline-none"
190
+ title="Prompt" oninput="this.style.height = 0;this.style.height = this.scrollHeight + 'px'"
191
+ placeholder="Add your prompt here...">Portrait of The Terminator with , glare pose, detailed, intricate, full of colour, cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details, unreal engine 5, cinematic, masterpiece</textarea>
192
+ </div>
193
+
194
+ </div>
195
+ <div class="">
196
+ <details>
197
+ <summary class="font-medium cursor-pointer">Advanced Options</summary>
198
+ <div class="grid grid-cols-3 max-w-md items-center gap-3 py-3">
199
+ <label class="text-sm font-medium" for="guidance-scale">Guidance Scale
200
+ </label>
201
+ <input type="range" id="guidance-scale" name="guidance-scale" min="1" max="30" step="0.001"
202
+ value="8.0" oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
203
+ <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
204
+ 8.0</output>
205
+ <label class="text-sm font-medium" for="strength">Strength</label>
206
+ <input type="range" id="strength" name="strength" min="0" max="1" step="0.01" value="0.50"
207
+ oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
208
+ <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
209
+ 0.5</output>
210
+ <label class="text-sm font-medium" for="seed">Seed</label>
211
+ <input type="number" id="seed" name="seed" value="299792458"
212
+ class="font-light border border-gray-700 text-right rounded-md p-2">
213
+ <button
214
+ onclick="document.querySelector('#seed').value = BigInt(Math.floor(Math.random() * 2**64-1))"
215
+ class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-1 w-[50px] rounded disabled:bg-gray-300 disabled:cursor-not-allowed text-sm">
216
+ Rand
217
+ </button>
218
+ </div>
219
+ </details>
220
+ </div>
221
+ <div>
222
+ <button id="start"
223
+ class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed">
224
+ Start
225
+ </button>
226
+ <button id="stop"
227
+ class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed">
228
+ Stop
229
+ </button>
230
+ </div>
231
+ <div class="relative rounded-lg border border-slate-300 overflow-hidden">
232
+ <img id="player" class="w-full aspect-square rounded-lg "
233
+ src="">
234
+ <div class="absolute top-0 left-0 w-1/4 aspect-square">
235
+ <video id="webcam" class="w-full aspect-square relative z-10" playsinline autoplay muted loop></video>
236
+ <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 448" width="100"
237
+ class="w-full p-4 absolute top-0 opacity-20 z-0">
238
+ <path fill="currentColor"
239
+ d="M224 256a128 128 0 1 0 0-256 128 128 0 1 0 0 256zm-45.7 48A178.3 178.3 0 0 0 0 482.3 29.7 29.7 0 0 0 29.7 512h388.6a29.7 29.7 0 0 0 29.7-29.7c0-98.5-79.8-178.3-178.3-178.3h-91.4z" />
240
+ </svg>
241
+ </div>
242
+ </div>
243
+ </main>
244
+ </body>
245
+
246
+ </html>
public/tailwind.config.js ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ diffusers
2
+ transformers
3
+ gradio
4
+ torch
5
+ fastapi
6
+ uvicorn
7
+ Pillow
8
+ accelerate