skytnt commited on
Commit
4223034
1 Parent(s): 070027b

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +6 -4
pipeline.py CHANGED
@@ -340,13 +340,15 @@ def get_weighted_text_embeddings(
340
  # assign weights to the prompts and normalize in the sense of mean
341
  # TODO: should we normalize by chunk or in a whole (current implementation)?
342
  if (not skip_parsing) and (not skip_weighting):
343
- previous_mean = text_embeddings.mean(axis=[-2, -1])
344
  text_embeddings *= prompt_weights.unsqueeze(-1)
345
- text_embeddings *= (previous_mean / text_embeddings.mean(axis=[-2, -1])).unsqueeze(-1).unsqueeze(-1)
 
346
  if uncond_prompt is not None:
347
- previous_mean = uncond_embeddings.mean(axis=[-2, -1])
348
  uncond_embeddings *= uncond_weights.unsqueeze(-1)
349
- uncond_embeddings *= (previous_mean / uncond_embeddings.mean(axis=[-2, -1])).unsqueeze(-1).unsqueeze(-1)
 
350
 
351
  if uncond_prompt is not None:
352
  return text_embeddings, uncond_embeddings
 
340
  # assign weights to the prompts and normalize in the sense of mean
341
  # TODO: should we normalize by chunk or in a whole (current implementation)?
342
  if (not skip_parsing) and (not skip_weighting):
343
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
344
  text_embeddings *= prompt_weights.unsqueeze(-1)
345
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
346
+ text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
347
  if uncond_prompt is not None:
348
+ previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
349
  uncond_embeddings *= uncond_weights.unsqueeze(-1)
350
+ current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
351
+ uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
352
 
353
  if uncond_prompt is not None:
354
  return text_embeddings, uncond_embeddings