radames committed
Commit a0a2ed9
1 Parent(s): 8d841f8

Upload 21 files

server/config.py CHANGED
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import List
+from typing import List, Literal
 
 import torch
 import os
@@ -24,8 +24,9 @@ class Config:
     ####################################################################
     # Model configuration
     ####################################################################
+    mode: Literal["txt2img", "img2img"] = "txt2img"
     # SD1.x variant model
-    model_id: str = "SimianLuo/LCM_Dreamshaper_v7"
+    model_id: str = "KBlueLeaf/kohaku-v2.1"
     # LCM-LORA model
     lcm_lora_id: str = "latent-consistency/lcm-lora-sdv1-5"
     # TinyVAE model
@@ -34,6 +35,8 @@ class Config:
     device: torch.device = torch.device("cuda")
     # Data type
     dtype: torch.dtype = torch.float16
+    # acceleration
+    acceleration: Literal["none", "xformers", "sfast", "tensorrt"] = "xformers"
 
     ####################################################################
     # Inference configuration
@@ -42,4 +45,4 @@ class Config:
     t_index_list: List[int] = field(default_factory=lambda: [0, 16, 32, 45])
     # Number of warmup steps
    warmup: int = 10
-    safety_checker: bool = SAFETY_CHECKER
+    use_safety_checker: bool = SAFETY_CHECKER
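
The new fields are ordinary dataclass defaults, so callers can override them at construction time. A minimal sketch of how the updated Config might be used (the import path and the override values are assumptions for illustration; the field names, types, and defaults come from the diff above):

    from config import Config  # assumed import path; the dataclass lives in server/config.py

    # Override only the fields introduced in this commit; everything else keeps its default.
    config = Config(
        mode="img2img",             # Literal["txt2img", "img2img"]
        acceleration="tensorrt",    # Literal["none", "xformers", "sfast", "tensorrt"]
        use_safety_checker=False,   # renamed from the old `safety_checker` field
    )
    print(config.model_id, config.dtype)  # "KBlueLeaf/kohaku-v2.1" torch.float16
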
server/main.py CHANGED
@@ -55,14 +55,16 @@ class Api:
         """
         self.config = config
         self.stream_diffusion = StreamDiffusionWrapper(
+            mode=config.mode,
             model_id=config.model_id,
             lcm_lora_id=config.lcm_lora_id,
             vae_id=config.vae_id,
             device=config.device,
             dtype=config.dtype,
+            acceleration=config.acceleration,
             t_index_list=config.t_index_list,
             warmup=config.warmup,
-            safety_checker=config.safety_checker,
+            use_safety_checker=config.use_safety_checker,
         )
         self.app = FastAPI()
         self.app.add_api_route(
@@ -85,8 +87,6 @@ class Api:
         self._predict_lock = asyncio.Lock()
         self._update_prompt_lock = asyncio.Lock()
 
-        self.last_prompt: str = ""
-
     async def _predict(self, inp: PredictInputModel) -> PredictResponseModel:
         """
         Predict an image and return.
@@ -102,7 +102,9 @@ class Api:
             The prediction result.
         """
         async with self._predict_lock:
-            return PredictResponseModel(base64_image=self._pil_to_base64(self.stream_diffusion(inp.prompt)))
+            return PredictResponseModel(
+                base64_image=self._pil_to_base64(self.stream_diffusion(prompt=inp.prompt))
+            )
 
     def _pil_to_base64(self, image: Image.Image, format: str = "JPEG") -> bytes:
         """
server/requirements.txt CHANGED
@@ -2,7 +2,6 @@ xformers
 uvicorn[standard]==0.24.0.post1
 fastapi==0.104
 accelerate
-git+https://github.com/huggingface/diffusers@781775ea56160a6dea3d53fd5005d0d7fca5f10a
 # git+https://github.com/cumulo-autumn/StreamDiffusion.git@main#egg=stream-diffusion
 --extra-index-url https://download.pytorch.org/whl/cu121
 torch
@@ -10,5 +9,4 @@ torchvision
 torchaudio
 triton
 # https://github.com/chengzeyi/stable-fast --index-url https://download.pytorch.org/whl/cu121
-# https://github.com/chengzeyi/stable-fast/releases/download/v0.0.14/stable_fast-0.0.14+torch210cu121-cp310-cp310-manylinux2014_x86_64.whl
-https://github.com/chengzeyi/stable-fast/releases/download/v0.0.15.post1/stable_fast-0.0.15.post1+torch211cu121-cp310-cp310-manylinux2014_x86_64.whl
+https://github.com/chengzeyi/stable-fast/releases/download/v0.0.14/stable_fast-0.0.14+torch210cu121-cp310-cp310-manylinux2014_x86_64.whl
 
server/wrapper.py CHANGED
@@ -1,156 +1,529 @@
-import io
 import os
-from typing import List
 
-import PIL.Image
-import requests
 import torch
 from diffusers import AutoencoderTiny, StableDiffusionPipeline
 
 from streamdiffusion import StreamDiffusion
 from streamdiffusion.image_utils import postprocess_image
 
-
-def download_image(url: str):
-    response = requests.get(url)
-    image = PIL.Image.open(io.BytesIO(response.content))
-    return image
 
 
 class StreamDiffusionWrapper:
     def __init__(
         self,
         model_id: str,
-        lcm_lora_id: str,
-        vae_id: str,
-        device: str,
-        dtype: str,
         t_index_list: List[int],
-        warmup: int,
-        safety_checker: bool,
     ):
         self.device = device
         self.dtype = dtype
-        self.prompt = ""
-        self.batch_size = len(t_index_list)
 
         self.stream = self._load_model(
             model_id=model_id,
             lcm_lora_id=lcm_lora_id,
             vae_id=vae_id,
             t_index_list=t_index_list,
             warmup=warmup,
         )
-        self.safety_checker = None
-        if safety_checker:
-            from transformers import CLIPFeatureExtractor
-            from diffusers.pipelines.stable_diffusion.safety_checker import (
-                StableDiffusionSafetyChecker,
             )
 
-            self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
-                "CompVis/stable-diffusion-safety-checker"
             ).to(self.device)
-            self.feature_extractor = CLIPFeatureExtractor.from_pretrained(
-                "openai/clip-vit-base-patch32"
             )
-            self.nsfw_fallback_img = PIL.Image.new("RGB", (512, 512), (0, 0, 0))
-        self.stream.prepare("")
 
     def _load_model(
         self,
         model_id: str,
-        lcm_lora_id: str,
-        vae_id: str,
         t_index_list: List[int],
-        warmup: int,
     ):
-        if os.path.exists(model_id):
-            pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_single_file(
-                model_id
-            ).to(device=self.device, dtype=self.dtype)
-        else:
             pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
                 model_id
             ).to(device=self.device, dtype=self.dtype)
 
         stream = StreamDiffusion(
             pipe=pipe,
             t_index_list=t_index_list,
             torch_dtype=self.dtype,
-            is_drawing=True,
-        )
-        stream.load_lcm_lora(lcm_lora_id)
-        stream.fuse_lora()
-        stream.vae = AutoencoderTiny.from_pretrained(vae_id).to(
-            device=pipe.device, dtype=pipe.dtype
         )
 
         try:
-            from streamdiffusion.acceleration.tensorrt import accelerate_with_tensorrt
 
-            stream = accelerate_with_tensorrt(
-                stream,
-                "engines",
-                max_batch_size=self.batch_size,
-                engine_build_options={"build_static_batch": False},
-            )
-            print("TensorRT acceleration enabled.")
-        except Exception:
-            print("TensorRT acceleration has failed. Trying to use Stable Fast.")
-            try:
                 from streamdiffusion.acceleration.sfast import (
                     accelerate_with_stable_fast,
                 )
 
                 stream = accelerate_with_stable_fast(stream)
                 print("StableFast acceleration enabled.")
-            except Exception:
-                print("StableFast acceleration has failed. Using normal mode.")
-                pass
 
         stream.prepare(
             "",
             num_inference_steps=50,
             generator=torch.manual_seed(2),
         )
 
-        # warmup
-        for _ in range(warmup):
-            start = torch.cuda.Event(enable_timing=True)
-            end = torch.cuda.Event(enable_timing=True)
-
-            start.record()
-            stream.txt2img()
-            end.record()
-
-            torch.cuda.synchronize()
-
         return stream
-
-    def __call__(self, prompt: str) -> PIL.Image.Image:
-        if self.prompt != prompt:
-            self.stream.update_prompt(prompt)
-            self.prompt = prompt
-            for i in range(self.batch_size):
-                x_output = self.stream.txt2img()
-
-        x_output = self.stream.txt2img()
-        image = postprocess_image(x_output, output_type="pil")[0]
-
-        if self.safety_checker:
-            safety_checker_input = self.feature_extractor(
-                image, return_tensors="pt"
-            ).to(self.device)
-            _, has_nsfw_concept = self.safety_checker(
-                images=x_output,
-                clip_input=safety_checker_input.pixel_values.to(self.dtype),
-            )
-            image = self.nsfw_fallback_img if has_nsfw_concept[0] else image
-
-        return image
-
-
-if __name__ == "__main__":
-    wrapper = StreamDiffusionWrapper(10, 10)
-    wrapper()
+import gc
 import os
+import traceback
+from typing import List, Literal, Optional, Union
 
+import numpy as np
 import torch
 from diffusers import AutoencoderTiny, StableDiffusionPipeline
+from PIL import Image
+from polygraphy import cuda
 
 from streamdiffusion import StreamDiffusion
 from streamdiffusion.image_utils import postprocess_image
 
+torch.set_grad_enabled(False)
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
 
 
 class StreamDiffusionWrapper:
     def __init__(
         self,
         model_id: str,
         t_index_list: List[int],
+        mode: Literal["img2img", "txt2img"] = "img2img",
+        output_type: Literal["pil", "pt", "np", "latent"] = "pil",
+        lcm_lora_id: Optional[str] = None,
+        vae_id: Optional[str] = None,
+        device: Literal["cpu", "cuda"] = "cuda",
+        dtype: torch.dtype = torch.float16,
+        frame_buffer_size: int = 1,
+        width: int = 512,
+        height: int = 512,
+        warmup: int = 10,
+        acceleration: Literal["none", "xformers", "sfast", "tensorrt"] = "xformers",
+        is_drawing: bool = True,
+        device_ids: Optional[List[int]] = None,
+        use_lcm_lora: bool = True,
+        use_tiny_vae: bool = True,
+        enable_similar_image_filter: bool = False,
+        similar_image_filter_threshold: float = 0.98,
+        use_denoising_batch: bool = True,
+        cfg_type: Literal["none", "full", "self", "initialize"] = "none",
+        use_safety_checker: bool = False,
     ):
+        if mode == "txt2img":
+            if cfg_type != "none":
+                raise ValueError(
+                    f"txt2img mode accepts only cfg_type = 'none', but got {cfg_type}"
+                )
+            if use_denoising_batch and frame_buffer_size > 1:
+                raise ValueError(
+                    "txt2img mode cannot use denoising batch with frame_buffer_size > 1."
+                )
+
+        if mode == "img2img":
+            if not use_denoising_batch:
+                raise NotImplementedError(
+                    "img2img mode must use denoising batch for now."
+                )
+
+        self.sd_turbo = "turbo" in model_id
         self.device = device
         self.dtype = dtype
+        self.width = width
+        self.height = height
+        self.mode = mode
+        self.output_type = output_type
+        self.frame_buffer_size = frame_buffer_size
+        self.batch_size = (
+            len(t_index_list) * frame_buffer_size
+            if use_denoising_batch
+            else frame_buffer_size
+        )
+
+        self.use_denoising_batch = use_denoising_batch
+        self.use_safety_checker = use_safety_checker
 
         self.stream = self._load_model(
             model_id=model_id,
             lcm_lora_id=lcm_lora_id,
             vae_id=vae_id,
             t_index_list=t_index_list,
+            acceleration=acceleration,
             warmup=warmup,
+            is_drawing=is_drawing,
+            use_lcm_lora=use_lcm_lora,
+            use_tiny_vae=use_tiny_vae,
+            cfg_type=cfg_type,
         )
+
+        if device_ids is not None:
+            self.stream.unet = torch.nn.DataParallel(
+                self.stream.unet, device_ids=device_ids
             )
 
+        if enable_similar_image_filter:
+            self.stream.enable_similar_image_filter(similar_image_filter_threshold)
+
+    def prepare(
+        self,
+        prompt: str,
+        negative_prompt: str = "",
+        num_inference_steps: int = 50,
+        guidance_scale: float = 1.2,
+        delta: float = 1.0,
+    ) -> None:
+        """
+        Prepares the model for inference.
+
+        Parameters
+        ----------
+        prompt : str
+            The prompt to generate images from.
+        num_inference_steps : int, optional
+            The number of inference steps to perform, by default 50.
+        """
+        self.stream.prepare(
+            prompt,
+            negative_prompt,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            delta=delta,
+        )
+
+    def __call__(
+        self,
+        image: Optional[Union[str, Image.Image, torch.Tensor]] = None,
+        prompt: Optional[str] = None,
+    ) -> Union[Image.Image, List[Image.Image]]:
+        """
+        Performs img2img or txt2img based on the mode.
+
+        Parameters
+        ----------
+        image : Optional[Union[str, Image.Image, torch.Tensor]]
+            The image to generate from.
+        prompt : Optional[str]
+            The prompt to generate images from.
+
+        Returns
+        -------
+        Union[Image.Image, List[Image.Image]]
+            The generated image.
+        """
+        if self.mode == "img2img":
+            return self.img2img(image)
+        else:
+            return self.txt2img(prompt)
+
+    def txt2img(
+        self, prompt: str
+    ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]:
+        """
+        Performs txt2img.
+
+        Parameters
+        ----------
+        prompt : str
+            The prompt to generate images from.
+
+        Returns
+        -------
+        Union[Image.Image, List[Image.Image]]
+            The generated image.
+        """
+        self.stream.update_prompt(prompt)
+
+        if self.sd_turbo:
+            image_tensor = self.stream.txt2img_sd_turbo(self.batch_size)
+        else:
+            image_tensor = self.stream.txt2img(self.frame_buffer_size)
+        image = self.postprocess_image(image_tensor, output_type=self.output_type)
+
+        if self.use_safety_checker:
+            safety_checker_input = self.feature_extractor(
+                image, return_tensors="pt"
            ).to(self.device)
+            _, has_nsfw_concept = self.safety_checker(
+                images=image_tensor.to(self.dtype),
+                clip_input=safety_checker_input.pixel_values.to(self.dtype),
            )
+            image = self.nsfw_fallback_img if has_nsfw_concept[0] else image
+
+        return image
+
+    def img2img(
+        self, image: Union[str, Image.Image, torch.Tensor]
+    ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]:
+        """
+        Performs img2img.
+
+        Parameters
+        ----------
+        image : Union[str, Image.Image, torch.Tensor]
+            The image to generate from.
+
+        Returns
+        -------
+        Image.Image
+            The generated image.
+        """
+        if isinstance(image, str) or isinstance(image, Image.Image):
+            image = self.preprocess_image(image)
+
+        image_tensor = self.stream(image)
+        return self.postprocess_image(image_tensor, output_type=self.output_type)
+
+    def preprocess_image(self, image: Union[str, Image.Image]) -> torch.Tensor:
+        """
+        Preprocesses the image.
+
+        Parameters
+        ----------
+        image : Union[str, Image.Image, torch.Tensor]
+            The image to preprocess.
+
+        Returns
+        -------
+        torch.Tensor
+            The preprocessed image.
+        """
+        if isinstance(image, str):
+            image = Image.open(image).convert("RGB").resize((self.width, self.height))
+        if isinstance(image, Image.Image):
+            image = image.convert("RGB").resize((self.width, self.height))
+
+        return self.stream.image_processor.preprocess(
+            image, self.height, self.width
+        ).to(device=self.device, dtype=self.dtype)
+
+    def postprocess_image(
+        self, image_tensor: torch.Tensor, output_type: str = "pil"
+    ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]:
+        """
+        Postprocesses the image.
+
+        Parameters
+        ----------
+        image_tensor : torch.Tensor
+            The image tensor to postprocess.
+
+        Returns
+        -------
+        Union[Image.Image, List[Image.Image]]
+            The postprocessed image.
+        """
+        if self.frame_buffer_size > 1:
+            return postprocess_image(image_tensor.cpu(), output_type=output_type)
+        else:
+            return postprocess_image(image_tensor.cpu(), output_type=output_type)[0]
 
     def _load_model(
         self,
         model_id: str,
         t_index_list: List[int],
+        lcm_lora_id: Optional[str] = None,
+        vae_id: Optional[str] = None,
+        acceleration: Literal["none", "sfast", "tensorrt"] = "tensorrt",
+        is_drawing: bool = True,
+        warmup: int = 10,
+        use_lcm_lora: bool = True,
+        use_tiny_vae: bool = True,
+        cfg_type: Literal["none", "full", "self", "initialize"] = "self",
     ):
+        """
+        Loads the model.
+
+        This method does the following:
+
+        1. Loads the model from the model_id.
+        2. Loads and fuses the LCM-LoRA model from the lcm_lora_id if needed.
+        3. Loads the VAE model from the vae_id if needed.
+        4. Enables acceleration if needed.
+        5. Prepares the model for inference.
+        6. Warms up the model.
+
+        Parameters
+        ----------
+        model_id : str
+            The model id to load.
+        t_index_list : List[int]
+            The t_index_list to use for inference.
+        lcm_lora_id : Optional[str], optional
+            The lcm_lora_id to load, by default None.
+        vae_id : Optional[str], optional
+            The vae_id to load, by default None.
+        acceleration : Literal["none", "xfomers", "sfast", "tensorrt"], optional
+            The acceleration method to use, by default "tensorrt".
+        warmup : int, optional
+            The number of warmup steps to perform, by default 10.
+        is_drawing : bool, optional
+            Whether to draw the image or not, by default True.
+        use_lcm_lora : bool, optional
+            Whether to use LCM-LoRA or not, by default True.
+        use_tiny_vae : bool, optional
+            Whether to use TinyVAE or not, by default True.
+        cfg_type : Literal["none", "full", "self", "initialize"], optional
+            The cfg_type to use, by default "self".
+        """
+
+        try:  # Load from local directory
             pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
+                model_id,
+            ).to(device=self.device, dtype=self.dtype)
+
+        except ValueError:  # Load from huggingface
+            pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_single_file(
                 model_id
             ).to(device=self.device, dtype=self.dtype)
+        except Exception:  # No model found
+            traceback.print_exc()
+            print("Model load has failed. Doesn't exist.")
+            exit()
+
+        if self.use_safety_checker:
+            from transformers import CLIPFeatureExtractor
+            from diffusers.pipelines.stable_diffusion.safety_checker import (
+                StableDiffusionSafetyChecker,
+            )
+
+            self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
+                "CompVis/stable-diffusion-safety-checker"
+            ).to(pipe.device)
+            self.feature_extractor = CLIPFeatureExtractor.from_pretrained(
+                "openai/clip-vit-base-patch32"
+            )
+            self.nsfw_fallback_img = Image.new("RGB", (512, 512), (0, 0, 0))
 
         stream = StreamDiffusion(
             pipe=pipe,
             t_index_list=t_index_list,
             torch_dtype=self.dtype,
+            width=self.width,
+            height=self.height,
+            is_drawing=is_drawing,
+            frame_buffer_size=self.frame_buffer_size,
+            use_denoising_batch=self.use_denoising_batch,
+            cfg_type=cfg_type,
         )
+        if not self.sd_turbo:
+            if use_lcm_lora:
+                if lcm_lora_id is not None:
+                    stream.load_lcm_lora(
+                        pretrained_model_name_or_path_or_dict=lcm_lora_id
+                    )
+                else:
+                    stream.load_lcm_lora()
+                stream.fuse_lora()
+
+            if use_tiny_vae:
+                if vae_id is not None:
+                    stream.vae = AutoencoderTiny.from_pretrained(vae_id).to(
+                        device=pipe.device, dtype=pipe.dtype
+                    )
+                else:
+                    stream.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd").to(
+                        device=pipe.device, dtype=pipe.dtype
+                    )
 
         try:
+            if acceleration == "xformers":
+                stream.pipe.enable_xformers_memory_efficient_attention()
+            if acceleration == "tensorrt":
+                from streamdiffusion.acceleration.tensorrt import (
+                    TorchVAEEncoder,
+                    compile_unet,
+                    compile_vae_decoder,
+                    compile_vae_encoder,
+                )
+                from streamdiffusion.acceleration.tensorrt.engine import (
+                    AutoencoderKLEngine,
+                    UNet2DConditionModelEngine,
+                )
+                from streamdiffusion.acceleration.tensorrt.models import (
+                    VAE,
+                    UNet,
+                    VAEEncoder,
+                )
 
+                def create_prefix(
+                    max_batch_size: int,
+                    min_batch_size: int,
+                ):
+                    return f"{model_id}--lcm_lora-{use_tiny_vae}--tiny_vae-{use_lcm_lora}--max_batch-{max_batch_size}--min_batch-{min_batch_size}--mode-{self.mode}"
+
+                engine_dir = os.path.join("engines")
+                unet_path = os.path.join(
+                    engine_dir,
+                    create_prefix(
+                        stream.trt_unet_batch_size, stream.trt_unet_batch_size
+                    ),
+                    "unet.engine",
+                )
+                vae_encoder_path = os.path.join(
+                    engine_dir,
+                    create_prefix(
+                        self.batch_size
+                        if self.mode == "txt2img"
+                        else stream.frame_bff_size,
+                        self.batch_size
+                        if self.mode == "txt2img"
+                        else stream.frame_bff_size,
+                    ),
+                    "vae_encoder.engine",
+                )
+                vae_decoder_path = os.path.join(
+                    engine_dir,
+                    create_prefix(
+                        self.batch_size
+                        if self.mode == "txt2img"
+                        else stream.frame_bff_size,
+                        self.batch_size
+                        if self.mode == "txt2img"
+                        else stream.frame_bff_size,
+                    ),
+                    "vae_decoder.engine",
+                )
+
+                if not os.path.exists(unet_path):
+                    os.makedirs(os.path.dirname(unet_path), exist_ok=True)
+                    unet_model = UNet(
+                        fp16=True,
+                        device=stream.device,
+                        max_batch_size=stream.trt_unet_batch_size,
+                        min_batch_size=stream.trt_unet_batch_size,
+                        embedding_dim=stream.text_encoder.config.hidden_size,
+                        unet_dim=stream.unet.config.in_channels,
+                    )
+                    compile_unet(
+                        stream.unet,
+                        unet_model,
+                        unet_path + ".onnx",
+                        unet_path + ".opt.onnx",
+                        unet_path,
+                        opt_batch_size=stream.trt_unet_batch_size,
+                    )
+
+                if not os.path.exists(vae_decoder_path):
+                    os.makedirs(os.path.dirname(vae_decoder_path), exist_ok=True)
+                    stream.vae.forward = stream.vae.decode
+                    vae_decoder_model = VAE(
+                        device=stream.device,
+                        max_batch_size=self.batch_size
+                        if self.mode == "txt2img"
+                        else stream.frame_bff_size,
+                        min_batch_size=self.batch_size
+                        if self.mode == "txt2img"
+                        else stream.frame_bff_size,
+                    )
+                    compile_vae_decoder(
+                        stream.vae,
+                        vae_decoder_model,
+                        vae_decoder_path + ".onnx",
+                        vae_decoder_path + ".opt.onnx",
+                        vae_decoder_path,
+                        opt_batch_size=self.batch_size
+                        if self.mode == "txt2img"
+                        else stream.frame_bff_size,
+                    )
+                    delattr(stream.vae, "forward")
+
+                if not os.path.exists(vae_encoder_path):
+                    os.makedirs(os.path.dirname(vae_encoder_path), exist_ok=True)
+                    vae_encoder = TorchVAEEncoder(stream.vae).to(torch.device("cuda"))
+                    vae_encoder_model = VAEEncoder(
+                        device=stream.device,
+                        max_batch_size=self.batch_size
+                        if self.mode == "txt2img"
+                        else stream.frame_bff_size,
+                        min_batch_size=self.batch_size
+                        if self.mode == "txt2img"
+                        else stream.frame_bff_size,
+                    )
+                    compile_vae_encoder(
+                        vae_encoder,
+                        vae_encoder_model,
+                        vae_encoder_path + ".onnx",
+                        vae_encoder_path + ".opt.onnx",
+                        vae_encoder_path,
+                        opt_batch_size=self.batch_size
+                        if self.mode == "txt2img"
+                        else stream.frame_bff_size,
+                    )
+
+                cuda_steram = cuda.Stream()
+
+                vae_config = stream.vae.config
+                vae_dtype = stream.vae.dtype
+
+                stream.unet = UNet2DConditionModelEngine(
+                    unet_path, cuda_steram, use_cuda_graph=False
+                )
+                stream.vae = AutoencoderKLEngine(
+                    vae_encoder_path,
+                    vae_decoder_path,
+                    cuda_steram,
+                    stream.pipe.vae_scale_factor,
+                    use_cuda_graph=False,
+                )
+                setattr(stream.vae, "config", vae_config)
+                setattr(stream.vae, "dtype", vae_dtype)
+
+                gc.collect()
+                torch.cuda.empty_cache()
+
+                print("TensorRT acceleration enabled.")
+            if acceleration == "sfast":
                from streamdiffusion.acceleration.sfast import (
                    accelerate_with_stable_fast,
                )
 
                stream = accelerate_with_stable_fast(stream)
                print("StableFast acceleration enabled.")
+        except Exception:
+            traceback.print_exc()
+            print("Acceleration has failed. Falling back to normal mode.")
 
         stream.prepare(
+            "",
             "",
             num_inference_steps=50,
+            guidance_scale=1.1
+            if stream.cfg_type in ["full", "self", "initialize"]
+            else 1.0,
             generator=torch.manual_seed(2),
         )
 
         return stream
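
Taken together, the rewritten wrapper is configured once via prepare() and then called with keyword arguments, which is how server/main.py drives it. A minimal txt2img sketch against the signature above (the import path, prompt text, and t_index_list values are illustrative assumptions; the parameter names come from the diff, and a CUDA device plus the streamdiffusion package are required):

    from wrapper import StreamDiffusionWrapper  # assumed import path (server/wrapper.py)

    wrapper = StreamDiffusionWrapper(
        model_id="KBlueLeaf/kohaku-v2.1",
        t_index_list=[0, 16, 32, 45],
        mode="txt2img",
        acceleration="xformers",
        cfg_type="none",            # txt2img mode only accepts cfg_type="none"
        use_safety_checker=False,
    )
    wrapper.prepare(prompt="an impressionist painting of a harbor at dawn")
    image = wrapper(prompt="an impressionist painting of a harbor at dawn")  # PIL.Image with the default output_type
    image.save("txt2img_sample.png")
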
start.sh CHANGED
@@ -1,2 +1,4 @@
+#!/bin/bash
+pip install -r requirements.txt
 cd view && npm run build && cd ..
-cd server && python3 main.py
+cd server && python3 main.py
view/.DS_Store CHANGED
Binary files a/view/.DS_Store and b/view/.DS_Store differ
 
view/src/App.tsx CHANGED
@@ -38,7 +38,7 @@ function App() {
   const fetchImage = useCallback(
     async (index: number) => {
       try {
-        const response = await fetch("/api/predict", {
+        const response = await fetch("api/predict", {
           method: "POST",
           headers: { "Content-Type": "application/json" },
           body: JSON.stringify({ prompt: inputPrompt }),
@@ -63,7 +63,7 @@ function App() {
     const newPrompt = event.target.value;
     const editDistance = calculateEditDistance(lastPrompt, newPrompt);
 
-    if (editDistance >= 2) {
+    if (editDistance >= 4) {
       setInputPrompt(newPrompt);
       setLastPrompt(newPrompt);
       for (let i = 0; i < 16; i++) {
@@ -98,7 +98,7 @@ function App() {
       <Grid
         container
         spacing={1}
-        style={{ maxWidth: "50%", maxHeight: "70%" }}
+        style={{ maxWidth: "60rem", maxHeight: "70%" }}
       >
         {images.map((image, index) => (
           <Grid item xs={3} key={index}>
@@ -106,6 +106,8 @@ function App() {
             src={image}
             alt={`Generated ${index}`}
             style={{
+              display: "block",
+              margin: "0 auto",
               maxWidth: "100%",
               maxHeight: "150px",
               borderRadius: "10px",
@@ -121,7 +123,8 @@ function App() {
         style={{
           marginBottom: "20px",
           marginTop: "20px",
-          width: "640px",
+          width: "100%",
+          maxWidth: "50rem",
           color: "#ffffff",
           borderColor: "#ffffff",
           borderRadius: "10px",