boris committed on
Commit 00d4661
1 Parent(s): b9a1a7d

feat: sinkhorn in lse mode (#155)

Files changed (1)
  1. src/dalle_mini/model/modeling.py +18 -26
src/dalle_mini/model/modeling.py CHANGED
@@ -187,9 +187,11 @@ def dot_product_attention_weights(
     dtype: Any = jnp.float32,
     precision: PrecisionLike = None,
     sinkhorn_iters: int = 1,
+    causal: bool = False,
 ):
     """
     Computes dot-product attention weights given query and key.
+    mask is included into the bias.
 
     Adapted from flax.linen.attention.dot_product_attention_weights"
     """
@@ -207,33 +209,22 @@ def dot_product_attention_weights(
     # apply attention bias: masking, dropout, proximity bias, etc.
     if bias is not None:
         attn_weights = attn_weights + bias
-    # apply attention mask
-    if mask is not None:
-        big_neg = jnp.finfo(dtype).min
-        attn_weights = jnp.where(mask, attn_weights, big_neg)
 
     # normalize the attention weights
-    attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
-    for i in range(sinkhorn_iters - 1):
-        # TODO: this is unstable, requires lse space
-        axis = -2 if i % 2 == 0 else -1
-        if mask is not None:
-            attn_weights = jnp.where(
-                mask > 0,
-                attn_weights
-                / (
-                    1e-5
-                    + jax.lax.stop_gradient(
-                        jnp.sum(attn_weights, axis=axis, where=mask, keepdims=True)
-                    )
-                ),
-                0.0,
-            )
-        else:
-            attn_weights = attn_weights / (
-                1e-5
-                + jax.lax.stop_gradient(jnp.sum(attn_weights, axis=axis, keepdims=True))
-            )
+    if causal or sinkhorn_iters == 1:
+        # sinkhorn does not work for causal (leaks info of future tokens into past)
+        attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
+    else:
+        # adapted from https://github.com/lucidrains/sinkhorn-transformer
+        for i in range(sinkhorn_iters):
+            # when causal, some attn_weights have been set to -inf through bias
+            if i % 2 == 0:
+                attn_weights -= jax.nn.logsumexp(attn_weights, axis=-1, keepdims=True)
+            else:
+                attn_weights -= jax.nn.logsumexp(attn_weights, axis=-2, keepdims=True)
+            if mask is not None:
+                attn_weights = jnp.where(mask, attn_weights, -jnp.inf)
+        attn_weights = jnp.exp(attn_weights).astype(dtype)
 
     # apply attention dropout
     if not deterministic and dropout_rate > 0.0:
@@ -392,7 +383,7 @@ class FlaxBartAttention(FlaxBartAttention):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
+                jnp.full(attention_mask.shape, -jnp.inf).astype(self.dtype),
             )
         else:
             attention_bias = None
@@ -421,6 +412,7 @@ class FlaxBartAttention(FlaxBartAttention):
             dtype=self.dtype,
             precision=None,
             sinkhorn_iters=self.config.sinkhorn_iters,
+            causal=self.causal,
         )
         if self.config.use_cosine_attention:
            # divide by tau
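
For reference, the log-space (logsumexp) Sinkhorn normalization added above can be reproduced as a small standalone JAX sketch. This is an illustration rather than the repository's API: the function name sinkhorn_lse and the n_iters default are invented here, and logits is assumed to have shape [..., q_len, kv_len] with an optional boolean mask of a broadcastable shape.

import jax
import jax.numpy as jnp

def sinkhorn_lse(logits, mask=None, n_iters=8):
    # Alternate row/column normalization in log space, then exponentiate.
    # Subtracting logsumexp instead of dividing by row/column sums avoids the
    # numerical instability flagged by the removed TODO ("requires lse space").
    for i in range(n_iters):
        axis = -1 if i % 2 == 0 else -2  # rows on even steps, columns on odd steps
        logits = logits - jax.nn.logsumexp(logits, axis=axis, keepdims=True)
        if mask is not None:
            # keep masked positions at -inf so they become 0 after exp
            logits = jnp.where(mask, logits, -jnp.inf)
    return jnp.exp(logits)

# Quick check: with enough iterations the result approaches a doubly
# stochastic matrix (rows and columns both sum to ~1).
w = sinkhorn_lse(jax.random.normal(jax.random.PRNGKey(0), (4, 4)), n_iters=20)
print(jnp.sum(w, axis=-1), jnp.sum(w, axis=-2))

Note that the commit keeps plain softmax for the causal path: alternating column normalization mixes rows, so Sinkhorn iterations would leak information from future tokens into past positions, as the added comment states.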