hussain2030 commited on
Commit
87bede3
1 Parent(s): 5fc6a9a

Update modeling_jais.py

Browse files

torch.zeros to torch.empty

Files changed (1) hide show
  1. modeling_jais.py +1 -1
modeling_jais.py CHANGED
@@ -268,7 +268,7 @@ class JAISAttention(nn.Module):
268
  _, _, k_seq_len, _ = key.size()
269
 
270
  # Preallocate attn_weights for `baddbmm`
271
- attn_weights = torch.zeros(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
272
 
273
  # Compute Scale Factor
274
  scale_factor = 1.0
 
268
  _, _, k_seq_len, _ = key.size()
269
 
270
  # Preallocate attn_weights for `baddbmm`
271
+ attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
272
 
273
  # Compute Scale Factor
274
  scale_factor = 1.0