tolgacangoz commited on
Commit
2cdf78d
1 Parent(s): dc031b3

Upload matryoshka.py

Browse files
Files changed (1) hide show
  1. matryoshka.py +2 -2
matryoshka.py CHANGED
@@ -4002,7 +4002,7 @@ class MatryoshkaPipeline(
4002
  prompt_attention_mask = torch.cat(
4003
  [
4004
  prompt_attention_mask,
4005
- torch.zeros(batch_size, max_len - len(prompt_attention_mask[0]), dtype=torch.long),
4006
  ],
4007
  dim=1,
4008
  )
@@ -4014,7 +4014,7 @@ class MatryoshkaPipeline(
4014
  negative_prompt_attention_mask = torch.cat(
4015
  [
4016
  negative_prompt_attention_mask,
4017
- torch.zeros(batch_size, max_len - len(negative_prompt_attention_mask[0]), dtype=torch.long),
4018
  ],
4019
  dim=1,
4020
  )
 
4002
  prompt_attention_mask = torch.cat(
4003
  [
4004
  prompt_attention_mask,
4005
+ torch.zeros(batch_size, max_len - len(prompt_attention_mask[0]), dtype=torch.long, device=device),
4006
  ],
4007
  dim=1,
4008
  )
 
4014
  negative_prompt_attention_mask = torch.cat(
4015
  [
4016
  negative_prompt_attention_mask,
4017
+ torch.zeros(batch_size, max_len - len(negative_prompt_attention_mask[0]), dtype=torch.long, device=device),
4018
  ],
4019
  dim=1,
4020
  )