radames committed
Commit
ca39fe5
1 Parent(s): 9b122ea

avoid recalc prompt embeds

Files changed (2)
  1. app-img2img.py +16 -6
  2. app-txt2img.py +10 -6
app-img2img.py CHANGED
@@ -26,6 +26,8 @@ TIMEOUT = float(os.environ.get("TIMEOUT", 0))
 SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
 WIDTH = 512
 HEIGHT = 512
+# disable the tiny autoencoder for higher quality at the cost of speed
+USE_TINY_AUTOENCODER = True

 # check if MPS is available (OSX only, M1/M2/M3 chips)
 mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
@@ -58,9 +60,11 @@ else:
         custom_pipeline="latent_consistency_img2img.py",
         custom_revision="main",
     )
-pipe.vae = AutoencoderTiny.from_pretrained(
-    "madebyollin/taesd", torch_dtype=torch_dtype, use_safetensors=True
-)
+
+if USE_TINY_AUTOENCODER:
+    pipe.vae = AutoencoderTiny.from_pretrained(
+        "madebyollin/taesd", torch_dtype=torch_dtype, use_safetensors=True
+    )
 pipe.set_progress_bar_config(disable=True)
 pipe.to(torch_device=torch_device, torch_dtype=torch_dtype).to(device)
 pipe.unet.to(memory_format=torch.channels_last)
@@ -89,9 +93,8 @@ class InputParams(BaseModel):
     height: int = HEIGHT


-def predict(input_image: Image.Image, params: InputParams):
+def predict(input_image: Image.Image, params: InputParams, prompt_embeds: torch.Tensor = None):
     generator = torch.manual_seed(params.seed)
-    prompt_embeds = compel_proc(params.prompt)
     # Can be set to 1~50 steps. LCM supports fast inference even at <= 4 steps. Recommended: 1~8 steps.
     num_inference_steps = 3
     results = pipe(
@@ -173,18 +176,25 @@ async def stream(user_id: uuid.UUID):
     try:
         user_queue = user_queue_map[uid]
         queue = user_queue["queue"]
-
         async def generate():
+            last_prompt: str = None
+            prompt_embeds: torch.Tensor = None
             while True:
                 data = await queue.get()
                 input_image = data["image"]
                 params = data["params"]
                 if input_image is None:
                     continue
+                # avoid recalculating prompt embeds when the prompt is unchanged
+                if last_prompt != params.prompt:
+                    print("new prompt")
+                    prompt_embeds = compel_proc(params.prompt)
+                    last_prompt = params.prompt

                 image = predict(
                     input_image,
                     params,
+                    prompt_embeds,
                 )
                 if image is None:
                     continue
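
A minimal sketch of the caching pattern this commit introduces, assuming a Compel-style encoder callable like the app's compel_proc (prompt string in, embedding tensor out): re-encode only when the prompt actually changes, then hand the cached embeddings to predict(). The PromptCache wrapper below is a hypothetical illustration; the app itself keeps last_prompt / prompt_embeds as plain locals inside generate(), as shown in the hunk above.

import torch
from typing import Callable, Optional


class PromptCache:
    """Cache prompt embeddings so the text encoder only runs when the prompt changes."""

    def __init__(self, encode_fn: Callable[[str], torch.Tensor]):
        # encode_fn is assumed to behave like compel_proc: prompt string -> embedding tensor
        self.encode_fn = encode_fn
        self.last_prompt: Optional[str] = None
        self.prompt_embeds: Optional[torch.Tensor] = None

    def __call__(self, prompt: str) -> torch.Tensor:
        if prompt != self.last_prompt:
            # new prompt: recompute and remember the embeddings
            self.prompt_embeds = self.encode_fn(prompt)
            self.last_prompt = prompt
        return self.prompt_embeds


# usage (hypothetical, assuming compel_proc and predict as defined in app-img2img.py):
# embed_cache = PromptCache(compel_proc)
# image = predict(input_image, params, embed_cache(params.prompt))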
app-txt2img.py CHANGED
@@ -27,6 +27,9 @@ TIMEOUT = float(os.environ.get("TIMEOUT", 0))
 SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
 WIDTH = 512
 HEIGHT = 512
+# disable the tiny autoencoder for higher quality at the cost of speed
+USE_TINY_AUTOENCODER = True
+
 # check if MPS is available (OSX only, M1/M2/M3 chips)
 mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -57,9 +60,10 @@ else:
         custom_pipeline="latent_consistency_txt2img.py",
         custom_revision="main",
     )
-pipe.vae = AutoencoderTiny.from_pretrained(
-    "madebyollin/taesd", torch_dtype=torch_dtype, use_safetensors=True
-)
+if USE_TINY_AUTOENCODER:
+    pipe.vae = AutoencoderTiny.from_pretrained(
+        "madebyollin/taesd", torch_dtype=torch_dtype, use_safetensors=True
+    )
 pipe.set_progress_bar_config(disable=True)
 pipe.to(torch_device=torch_device, torch_dtype=torch_dtype).to(device)
 pipe.unet.to(memory_format=torch.channels_last)
@@ -68,9 +72,9 @@ pipe.unet.to(memory_format=torch.channels_last)
 if psutil.virtual_memory().total < 64 * 1024**3:
     pipe.enable_attention_slicing()

-# if not mps_available:
-#     pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-#     pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
+if not mps_available:
+    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+    pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)

 compel_proc = Compel(
     tokenizer=pipe.tokenizer,
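
For reference, a rough standalone sketch of the two toggles this file now uses: the optional TAESD tiny-VAE swap and the torch.compile warmup that this commit un-comments for non-MPS devices. Model IDs, flags, and call signatures mirror the diff; the surrounding scaffolding (dtype/device selection, and the assumption that latent_consistency_txt2img.py sits in the working directory) is illustrative rather than part of the commit.

import torch
from diffusers import AutoencoderTiny, DiffusionPipeline

USE_TINY_AUTOENCODER = True

mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# load the LCM text-to-image pipeline; custom_pipeline refers to the local
# community pipeline file used by the app (assumed to be present on disk)
pipe = DiffusionPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7",
    custom_pipeline="latent_consistency_txt2img.py",
    custom_revision="main",
)

if USE_TINY_AUTOENCODER:
    # TAESD is a distilled, much smaller VAE: faster decoding, slightly lower quality
    pipe.vae = AutoencoderTiny.from_pretrained(
        "madebyollin/taesd", torch_dtype=torch_dtype, use_safetensors=True
    )

pipe.set_progress_bar_config(disable=True)
pipe.to(torch_device=device, torch_dtype=torch_dtype)

if not mps_available:
    # compile the UNet and run one warmup generation so the (slow) compilation
    # happens at startup instead of on the first user request; skipped on MPS,
    # matching the guard in the app
    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
    pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)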