Commit
·
6854514
1
Parent(s):
a5df595
Refactor get_prompt_attention function to remove unnecessary code
Browse files- tabs/images/handlers.py +12 -8
tabs/images/handlers.py
CHANGED
@@ -185,13 +185,13 @@ def get_control_mode(controlnet_config: ControlNetReq):
|
|
185 |
# return has_nsfw_concepts[1]
|
186 |
|
187 |
|
188 |
-
def get_prompt_attention(pipeline, prompt, negative_prompt):
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
|
196 |
|
197 |
def cleanup(pipeline, loras = None, embeddings = None):
|
@@ -211,7 +211,11 @@ def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq, progress=gr.Prog
|
|
211 |
pipeline = pipeline_args["pipeline"]
|
212 |
try:
|
213 |
progress(0.3, "Getting Prompt Embeddings")
|
214 |
-
|
|
|
|
|
|
|
|
|
215 |
|
216 |
progress(0.5, "Configuring Pipeline")
|
217 |
# Common Args
|
|
|
185 |
# return has_nsfw_concepts[1]
|
186 |
|
187 |
|
188 |
+
# def get_prompt_attention(pipeline, prompt, negative_prompt):
|
189 |
+
# if isinstance(pipeline, flux_pipes):
|
190 |
+
# prompt_embeds, pooled_prompt_embeds = get_weighted_text_embeddings_flux1(pipeline, prompt, device=device)
|
191 |
+
# return prompt_embeds, None, pooled_prompt_embeds, None
|
192 |
+
# elif isinstance(pipeline, sd_pipes):
|
193 |
+
# prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = get_weighted_text_embeddings_sdxl(pipeline, prompt, negative_prompt, device=device)
|
194 |
+
# return prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
195 |
|
196 |
|
197 |
def cleanup(pipeline, loras = None, embeddings = None):
|
|
|
211 |
pipeline = pipeline_args["pipeline"]
|
212 |
try:
|
213 |
progress(0.3, "Getting Prompt Embeddings")
|
214 |
+
# Get Prompt Embeddings
|
215 |
+
if isinstance(pipeline, flux_pipes):
|
216 |
+
positive_prompt_embeds, positive_prompt_pooled = get_weighted_text_embeddings_flux1(pipeline, request.prompt, device=device)
|
217 |
+
elif isinstance(pipeline, sd_pipes):
|
218 |
+
positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_weighted_text_embeddings_sdxl(pipeline, request.prompt, request.negative_prompt, device=device)
|
219 |
|
220 |
progress(0.5, "Configuring Pipeline")
|
221 |
# Common Args
|