boris committed
Commit eed4896
1 Parent(s): 361a994

fix: sinkformer gradient

Files changed (1)
  1. src/dalle_mini/model/modeling.py +19 -1
src/dalle_mini/model/modeling.py CHANGED
@@ -215,8 +215,25 @@ def dot_product_attention_weights(
     # normalize the attention weights
     attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
     for i in range(sinkhorn_iters - 1):
+        # TODO: this is unstable, requires lse space
         axis = -2 if i % 2 == 0 else -1
-        attn_weights /= 1e-8 + jnp.sum(attn_weights, axis=axis, keepdims=True)
+        if mask is not None:
+            attn_weights = jnp.where(
+                mask > 0,
+                attn_weights
+                / (
+                    1e-5
+                    + jax.lax.stop_gradient(
+                        jnp.sum(attn_weights, axis=axis, where=mask, keepdims=True)
+                    )
+                ),
+                0.0,
+            )
+        else:
+            attn_weights = attn_weights / (
+                1e-5
+                + jax.lax.stop_gradient(jnp.sum(attn_weights, axis=axis, keepdims=True))
+            )
 
     # apply attention dropout
     if not deterministic and dropout_rate > 0.0:
@@ -396,6 +413,7 @@ class FlaxBartAttention(FlaxBartAttention):
         query_states,
         key_states,
         bias=attention_bias,
+        mask=attention_mask,
         dropout_rng=dropout_rng,
         dropout_rate=self.dropout,
         broadcast_dropout=True,
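
For context, here is a minimal, self-contained sketch of the pattern this commit introduces: softmax followed by alternating Sinkhorn-style renormalization, with the denominator wrapped in jax.lax.stop_gradient and masked positions zeroed out. The function name sinkhorn_normalize, the eps argument, and the toy shapes are illustrative assumptions, not part of the repository; only the loop body mirrors the diff. As the commit's TODO notes, iterating in plain probability space can be unstable and would ideally be done in log-sum-exp space.

import jax
import jax.numpy as jnp


def sinkhorn_normalize(attn_weights, mask=None, sinkhorn_iters=3, eps=1e-5):
    """Sketch: softmax followed by alternating (Sinkhorn-style) renormalization.

    Each extra normalization divides by a sum whose gradient is cut with
    jax.lax.stop_gradient, and masked positions are forced back to zero.
    """
    attn_weights = jax.nn.softmax(attn_weights, axis=-1)
    for i in range(sinkhorn_iters - 1):
        # alternate between the query axis (-2) and the key axis (-1)
        axis = -2 if i % 2 == 0 else -1
        if mask is not None:
            denom = eps + jax.lax.stop_gradient(
                jnp.sum(attn_weights, axis=axis, where=mask, keepdims=True)
            )
            attn_weights = jnp.where(mask > 0, attn_weights / denom, 0.0)
        else:
            denom = eps + jax.lax.stop_gradient(
                jnp.sum(attn_weights, axis=axis, keepdims=True)
            )
            attn_weights = attn_weights / denom
    return attn_weights


# toy usage: 2 heads, 4 query positions x 4 key positions
scores = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 4))
mask = jnp.ones((2, 4, 4), dtype=bool)
weights = sinkhorn_normalize(scores, mask=mask, sinkhorn_iters=3)
print(weights.sum(axis=-1))  # rows sum to ~1 after the final key-axis pass

Wrapping the denominator in stop_gradient keeps the forward normalization unchanged while preventing the extra division steps from backpropagating through the sums, which appears to be the "sinkformer gradient" issue named in the commit message.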