barreloflube commited on
Commit
9b41963
·
1 Parent(s): 6854514

Refactor get_prompt_attention function to remove unnecessary code

Browse files

Refactor get_prompt_attention function to include device parameter
Refactor progress tracking in generate_image function
Refactor scheduler loading in get_pipe function
Refactor image options in generate_image function

Files changed (1) hide show
  1. tabs/images/handlers.py +3 -3
tabs/images/handlers.py CHANGED
@@ -213,9 +213,9 @@ def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq, progress=gr.Prog
213
  progress(0.3, "Getting Prompt Embeddings")
214
  # Get Prompt Embeddings
215
  if isinstance(pipeline, flux_pipes):
216
- positive_prompt_embeds, positive_prompt_pooled = get_weighted_text_embeddings_flux1(pipeline, request.prompt, device=device)
217
  elif isinstance(pipeline, sd_pipes):
218
- positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_weighted_text_embeddings_sdxl(pipeline, request.prompt, request.negative_prompt, device=device)
219
 
220
  progress(0.5, "Configuring Pipeline")
221
  # Common Args
@@ -236,7 +236,7 @@ def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq, progress=gr.Prog
236
  args['negative_pooled_prompt_embeds'] = negative_prompt_pooled
237
 
238
  if request.controlnet_config:
239
- args['control_images'] = get_controlnet_images(request.controlnet_config.controlnets, request.controlnet_config.control_images, request.height, request.width, request.resize_mode)
240
  args['controlnet_conditioning_scale'] = request.controlnet_config.controlnet_conditioning_scale
241
 
242
  if request.controlnet_config and isinstance(pipeline, flux_pipes):
 
213
  progress(0.3, "Getting Prompt Embeddings")
214
  # Get Prompt Embeddings
215
  if isinstance(pipeline, flux_pipes):
216
+ positive_prompt_embeds, positive_prompt_pooled = get_weighted_text_embeddings_flux1(pipeline, request.prompt)
217
  elif isinstance(pipeline, sd_pipes):
218
+ positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_weighted_text_embeddings_sdxl(pipeline, request.prompt, request.negative_prompt)
219
 
220
  progress(0.5, "Configuring Pipeline")
221
  # Common Args
 
236
  args['negative_pooled_prompt_embeds'] = negative_prompt_pooled
237
 
238
  if request.controlnet_config:
239
+ args['control_image'] = get_controlnet_images(request.controlnet_config.controlnets, request.controlnet_config.control_images, request.height, request.width, request.resize_mode)
240
  args['controlnet_conditioning_scale'] = request.controlnet_config.controlnet_conditioning_scale
241
 
242
  if request.controlnet_config and isinstance(pipeline, flux_pipes):