Update pipeline.py
Browse files- pipeline.py +6 -4
pipeline.py
CHANGED
@@ -340,13 +340,15 @@ def get_weighted_text_embeddings(
|
|
340 |
# assign weights to the prompts and normalize in the sense of mean
|
341 |
# TODO: should we normalize by chunk or in a whole (current implementation)?
|
342 |
if (not skip_parsing) and (not skip_weighting):
|
343 |
-
previous_mean = text_embeddings.mean(axis=[-2, -1])
|
344 |
text_embeddings *= prompt_weights.unsqueeze(-1)
|
345 |
-
|
|
|
346 |
if uncond_prompt is not None:
|
347 |
-
previous_mean = uncond_embeddings.mean(axis=[-2, -1])
|
348 |
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
349 |
-
|
|
|
350 |
|
351 |
if uncond_prompt is not None:
|
352 |
return text_embeddings, uncond_embeddings
|
|
|
340 |
# assign weights to the prompts and normalize in the sense of mean
|
341 |
# TODO: should we normalize by chunk or in a whole (current implementation)?
|
342 |
if (not skip_parsing) and (not skip_weighting):
|
343 |
+
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
344 |
text_embeddings *= prompt_weights.unsqueeze(-1)
|
345 |
+
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
346 |
+
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
347 |
if uncond_prompt is not None:
|
348 |
+
previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
349 |
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
350 |
+
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
351 |
+
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
352 |
|
353 |
if uncond_prompt is not None:
|
354 |
return text_embeddings, uncond_embeddings
|