Text Generation
Transformers
PyTorch
mpt
Composer
MosaicML
llm-foundry
custom_code
text-generation-inference
Alex Birch committed on
Commit
ec8ea9d
1 Parent(s): 07e555c

prefer NamedTuple

Browse files
Files changed (1) hide show
  1. attention.py +1 -1
attention.py CHANGED
@@ -121,7 +121,7 @@ def scaled_multihead_dot_product_attention(
121
  out = attn_weight.matmul(v)
122
  out = rearrange(out, 'b h s d -> b s (h d)')
123
  if needs_weights:
124
- return (out, attn_weight)
125
  return AttnFnOutput(out, None)
126
 
127
  def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
 
121
  out = attn_weight.matmul(v)
122
  out = rearrange(out, 'b h s d -> b s (h d)')
123
  if needs_weights:
124
+ return AttnFnOutput(out, attn_weight)
125
  return AttnFnOutput(out, None)
126
 
127
  def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):