barreloflube commited on
Commit
6854514
·
1 Parent(s): a5df595

Refactor get_prompt_attention function to remove unnecessary code

Browse files
Files changed (1) hide show
  1. tabs/images/handlers.py +12 -8
tabs/images/handlers.py CHANGED
@@ -185,13 +185,13 @@ def get_control_mode(controlnet_config: ControlNetReq):
185
  # return has_nsfw_concepts[1]
186
 
187
 
188
- def get_prompt_attention(pipeline, prompt, negative_prompt):
189
- if isinstance(pipeline, flux_pipes):
190
- prompt_embeds, pooled_prompt_embeds = get_weighted_text_embeddings_flux1(pipeline, prompt, device=device)
191
- return prompt_embeds, None, pooled_prompt_embeds, None
192
- elif isinstance(pipeline, sd_pipes):
193
- prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = get_weighted_text_embeddings_sdxl(pipeline, prompt, negative_prompt, device=device)
194
- return prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
195
 
196
 
197
  def cleanup(pipeline, loras = None, embeddings = None):
@@ -211,7 +211,11 @@ def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq, progress=gr.Prog
211
  pipeline = pipeline_args["pipeline"]
212
  try:
213
  progress(0.3, "Getting Prompt Embeddings")
214
- positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_prompt_attention(pipeline, request.prompt, request.negative_prompt)
 
 
 
 
215
 
216
  progress(0.5, "Configuring Pipeline")
217
  # Common Args
 
185
  # return has_nsfw_concepts[1]
186
 
187
 
188
+ # def get_prompt_attention(pipeline, prompt, negative_prompt):
189
+ # if isinstance(pipeline, flux_pipes):
190
+ # prompt_embeds, pooled_prompt_embeds = get_weighted_text_embeddings_flux1(pipeline, prompt, device=device)
191
+ # return prompt_embeds, None, pooled_prompt_embeds, None
192
+ # elif isinstance(pipeline, sd_pipes):
193
+ # prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = get_weighted_text_embeddings_sdxl(pipeline, prompt, negative_prompt, device=device)
194
+ # return prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
195
 
196
 
197
  def cleanup(pipeline, loras = None, embeddings = None):
 
211
  pipeline = pipeline_args["pipeline"]
212
  try:
213
  progress(0.3, "Getting Prompt Embeddings")
214
+ # Get Prompt Embeddings
215
+ if isinstance(pipeline, flux_pipes):
216
+ positive_prompt_embeds, positive_prompt_pooled = get_weighted_text_embeddings_flux1(pipeline, request.prompt, device=device)
217
+ elif isinstance(pipeline, sd_pipes):
218
+ positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_weighted_text_embeddings_sdxl(pipeline, request.prompt, request.negative_prompt, device=device)
219
 
220
  progress(0.5, "Configuring Pipeline")
221
  # Common Args