boris committed on
Commit
2c583b3
1 Parent(s): a2dcee4

fix: sinkformer

Browse files
Files changed (1) hide show
  1. src/dalle_mini/model/modeling.py +3 -3
src/dalle_mini/model/modeling.py CHANGED
@@ -211,7 +211,7 @@ def dot_product_attention_weights(
211
  dtype: Any = jnp.float32,
212
  precision: PrecisionLike = None,
213
  sinkhorn_iters: int = 1,
214
- causal: bool = False,
215
  ):
216
  """
217
  Computes dot-product attention weights given query and key.
@@ -239,7 +239,7 @@ def dot_product_attention_weights(
239
  attn_weights = attn_weights + embed_pos
240
 
241
  # normalize the attention weights
242
- if causal or sinkhorn_iters == 1:
243
  # sinkhorn does not work for causal (leaks info of future tokens into past)
244
  attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
245
  else:
@@ -461,7 +461,7 @@ class FlaxBartAttention(FlaxBartAttention):
461
  dtype=self.dtype,
462
  precision=None,
463
  sinkhorn_iters=self.config.sinkhorn_iters,
464
- causal=self.causal,
465
  )
466
  if self.config.use_cosine_attention:
467
  # divide by tau
 
211
  dtype: Any = jnp.float32,
212
  precision: PrecisionLike = None,
213
  sinkhorn_iters: int = 1,
214
+ is_encoder: bool = False,
215
  ):
216
  """
217
  Computes dot-product attention weights given query and key.
 
239
  attn_weights = attn_weights + embed_pos
240
 
241
  # normalize the attention weights
242
+ if not is_encoder or sinkhorn_iters == 1:
243
  # sinkhorn does not work for causal (leaks info of future tokens into past)
244
  attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
245
  else:
 
461
  dtype=self.dtype,
462
  precision=None,
463
  sinkhorn_iters=self.config.sinkhorn_iters,
464
+ is_encoder=self.is_encoder,
465
  )
466
  if self.config.use_cosine_attention:
467
  # divide by tau