File size: 21,675 Bytes
b100e1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for attention classes."""

import dataclasses
from typing import Optional
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
from flax import linen as nn
from flax.core import freeze
from flax.linen import partitioning as nn_partitioning
import jax
from jax import random
from jax.nn import initializers
import jax.numpy as jnp
from mt3 import layers
import numpy as np

# Parse absl flags test_srcdir and test_tmpdir.
# This call runs at import time, before any test below executes.
# NOTE(review): presumably this also lets JAX config flags be set via absl
# command-line flags — confirm against the JAX version in use.
jax.config.parse_flags_with_absl()

# Alias used for the array type annotations in this module.
Array = jnp.ndarray
# Short alias for flax's partitioning metadata; it appears in the expected
# `params_axes` collection checked by DenseTest.test_mlp_same_out_dim.
AxisMetadata = nn_partitioning.AxisMetadata  # pylint: disable=invalid-name


class SelfAttention(layers.MultiHeadDotProductAttention):
  """Multi-head dot-product attention in which a sequence attends to itself.

  Test helper: the same tensor is handed to the parent layer both as the query
  input and as the key/value input.
  """

  @nn.compact
  def __call__(self,
               inputs_q: Array,
               mask: Optional[Array] = None,
               bias: Optional[Array] = None,
               deterministic: bool = False):
    # Keys and values are drawn from the very same sequence as the queries.
    inputs_kv = inputs_q
    parent_call = super().__call__
    return parent_call(inputs_q, inputs_kv, mask, bias,
                       deterministic=deterministic)


@dataclasses.dataclass(frozen=True)
class SelfAttentionArgs:
  """Hyperparameters and dummy inputs shared by the attention tests.

  `init_args()` yields keyword arguments for constructing
  `layers.MultiHeadDotProductAttention`; `apply_args()` yields all-ones inputs
  for a self-attention forward pass.
  """
  num_heads: int = 1
  batch_size: int = 2
  head_dim: int = 3
  q_len: int = 5
  features: int = 6
  dropout_rate: float = 0.1
  deterministic: bool = False
  decode: bool = False
  float32_logits: bool = False

  def __post_init__(self):
    # When decoding, the query length must be 1, because autoregressive
    # decoding feeds one position at a time. Raise explicitly rather than
    # `assert` so the check still runs under `python -O`.
    if self.decode and self.q_len != 1:
      raise ValueError(
          f'q_len must be 1 when decode=True, got q_len={self.q_len}.')

  def init_args(self):
    """Returns constructor kwargs for the attention module under test."""
    return dict(
        num_heads=self.num_heads,
        head_dim=self.head_dim,
        dropout_rate=self.dropout_rate,
        float32_logits=self.float32_logits)

  def apply_args(self):
    """Returns all-ones call kwargs for a self-attention forward pass.

    Shapes: `inputs_q` is [batch_size, q_len, features]; `mask` and `bias` are
    both [batch_size, num_heads, q_len, q_len].
    """
    inputs_q = jnp.ones((self.batch_size, self.q_len, self.features))
    mask = jnp.ones((self.batch_size, self.num_heads, self.q_len, self.q_len))
    bias = jnp.ones((self.batch_size, self.num_heads, self.q_len, self.q_len))
    return {
        'inputs_q': inputs_q,
        'mask': mask,
        'bias': bias,
        'deterministic': self.deterministic
    }


class AttentionTest(parameterized.TestCase):
  """Tests for the attention helpers and modules in `layers`."""

  def test_dot_product_attention_shape(self):
    # This test only checks for shape but tries to make sure all code paths are
    # reached: a bias is supplied and dropout is enabled (non-zero rate with
    # deterministic=False).
    dropout_rng = random.PRNGKey(0)
    batch_size, num_heads, q_len, kv_len, qk_depth, v_depth = 1, 2, 3, 4, 5, 6

    query = jnp.ones((batch_size, q_len, num_heads, qk_depth))
    key = jnp.ones((batch_size, kv_len, num_heads, qk_depth))
    value = jnp.ones((batch_size, kv_len, num_heads, v_depth))
    bias = jnp.ones((batch_size, num_heads, q_len, kv_len))

    args = dict(
        query=query,
        key=key,
        value=value,
        bias=bias,
        dropout_rng=dropout_rng,
        dropout_rate=0.5,
        deterministic=False,
    )

    output = layers.dot_product_attention(**args)
    # The output keeps the query's [batch, q_len, heads] layout with the value
    # depth as its last dimension.
    self.assertEqual(output.shape, (batch_size, q_len, num_heads, v_depth))

  def test_make_attention_mask_multiply_pairwise_fn(self):
    # With the default pairwise_fn, entry [i, j] is 1 only when both query
    # position i and key position j are non-padding (token > 0).
    decoder_target_tokens = jnp.array([[7, 0, 0], [8, 5, 0]])
    attention_mask = layers.make_attention_mask(
        decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=jnp.int32)
    expected0 = jnp.array([[1, 0, 0], [0, 0, 0], [0, 0, 0]])
    expected1 = jnp.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]])
    self.assertEqual(attention_mask.shape, (2, 1, 3, 3))
    np.testing.assert_array_equal(attention_mask[0, 0], expected0)
    np.testing.assert_array_equal(attention_mask[1, 0], expected1)

  def test_make_attention_mask_equal_pairwise_fn(self):
    # With pairwise_fn=jnp.equal, positions attend to each other only when
    # they share a segment id.
    segment_ids = jnp.array([[1, 1, 2, 2, 2, 0], [1, 1, 1, 2, 0, 0]])
    attention_mask = layers.make_attention_mask(
        segment_ids, segment_ids, pairwise_fn=jnp.equal, dtype=jnp.int32)
    # Padding is not treated in a special way. So they need to be zeroed out
    # separately.
    expected0 = jnp.array([[1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0],
                           [0, 0, 1, 1, 1, 0], [0, 0, 1, 1, 1, 0],
                           [0, 0, 1, 1, 1, 0], [0, 0, 0, 0, 0, 1]])
    expected1 = jnp.array([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0],
                           [1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0],
                           [0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 1, 1]])
    self.assertEqual(attention_mask.shape, (2, 1, 6, 6))
    np.testing.assert_array_equal(attention_mask[0, 0], expected0)
    np.testing.assert_array_equal(attention_mask[1, 0], expected1)

  def test_make_causal_mask_with_padding(self):
    x = jnp.array([[7, 0, 0], [8, 5, 0]])
    y = layers.make_causal_mask(x)
    self.assertEqual(y.shape, (2, 1, 3, 3))
    # Padding is not treated in a special way. So they need to be zeroed out
    # separately.
    expected_y = jnp.array([[[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]]],
                           jnp.float32)
    np.testing.assert_allclose(y[0], expected_y)
    np.testing.assert_allclose(y[1], expected_y)

  def test_make_causal_mask_extra_batch_dims(self):
    x = jnp.ones((3, 3, 5))
    y = layers.make_causal_mask(x, extra_batch_dims=2)
    # Two extra leading singleton dims precede the [batch..., 1, len, len]
    # causal-mask shape.
    self.assertEqual(y.shape, (1, 1, 3, 3, 1, 5, 5))

  def test_make_causal_mask(self):
    x = jnp.ones((1, 3))
    y = layers.make_causal_mask(x)
    self.assertEqual(y.shape, (1, 1, 3, 3))
    # Lower-triangular: position i may attend to positions 0..i.
    expected_y = jnp.array([[[[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]]]],
                           jnp.float32)
    np.testing.assert_allclose(y, expected_y)

  def test_combine_masks(self):
    # None entries are skipped; the remaining masks are combined elementwise
    # so a position survives only if every mask keeps it.
    masks = [
        jnp.array([0, 1, 0, 1], jnp.float32), None,
        jnp.array([1, 1, 1, 1], jnp.float32),
        jnp.array([1, 1, 1, 0], jnp.float32)
    ]
    y = layers.combine_masks(*masks)
    np.testing.assert_allclose(y, jnp.array([0, 1, 0, 0], jnp.float32))

  def test_combine_biases(self):
    # None entries are skipped; the remaining biases are summed elementwise.
    biases = [
        jnp.array([0, 1, 0, 1], jnp.float32), None,
        jnp.array([0, 1, 1, 1], jnp.float32),
        jnp.array([0, 1, 1, 0], jnp.float32)
    ]
    y = layers.combine_biases(*biases)
    np.testing.assert_allclose(y, jnp.array([0, 3, 2, 2], jnp.float32))

  def test_make_decoder_mask_lm_unpacked(self):
    # Plain causal LM: causal mask intersected with the non-padding mask.
    decoder_target_tokens = jnp.array([6, 7, 3, 0])
    mask = layers.make_decoder_mask(
        decoder_target_tokens=decoder_target_tokens, dtype=jnp.float32)
    expected_mask = jnp.array([[[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0],
                                [0, 0, 0, 0]]])
    np.testing.assert_array_equal(mask, expected_mask)

  def test_make_decoder_mask_lm_packed(self):
    # Packed causal LM: attention additionally stays within each segment.
    decoder_target_tokens = jnp.array([[6, 7, 3, 4, 5, 0]])
    decoder_segment_ids = jnp.array([[1, 1, 1, 2, 2, 0]])
    mask = layers.make_decoder_mask(
        decoder_target_tokens=decoder_target_tokens,
        dtype=jnp.float32,
        decoder_segment_ids=decoder_segment_ids)
    expected_mask = jnp.array([[[[1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0],
                                 [1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0],
                                 [0, 0, 0, 1, 1, 0], [0, 0, 0, 0, 0, 0]]]])
    np.testing.assert_array_equal(mask, expected_mask)

  def test_make_decoder_mask_prefix_lm_unpacked(self):
    # Prefix LM: positions flagged in decoder_causal_attention attend to each
    # other bidirectionally; the rest remain causal.
    decoder_target_tokens = jnp.array([[5, 6, 7, 3, 4, 0]])
    decoder_causal_attention = jnp.array([[1, 1, 1, 0, 0, 0]])
    mask = layers.make_decoder_mask(
        decoder_target_tokens=decoder_target_tokens,
        dtype=jnp.float32,
        decoder_causal_attention=decoder_causal_attention)
    expected_mask = jnp.array(
        [[[[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0],
           [1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 0], [0, 0, 0, 0, 0, 0]]]],
        dtype=jnp.float32)
    np.testing.assert_array_equal(mask, expected_mask)

  def test_make_decoder_mask_prefix_lm_packed(self):
    # Packed prefix LM: the bidirectional prefix is confined to its segment.
    decoder_target_tokens = jnp.array([[5, 6, 7, 8, 3, 4, 0]])
    decoder_segment_ids = jnp.array([[1, 1, 1, 2, 2, 2, 0]])
    decoder_causal_attention = jnp.array([[1, 1, 0, 1, 1, 0, 0]])
    mask = layers.make_decoder_mask(
        decoder_target_tokens=decoder_target_tokens,
        dtype=jnp.float32,
        decoder_causal_attention=decoder_causal_attention,
        decoder_segment_ids=decoder_segment_ids)
    expected_mask = jnp.array([[[[1, 1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0],
                                 [1, 1, 1, 0, 0, 0, 0], [0, 0, 0, 1, 1, 0, 0],
                                 [0, 0, 0, 1, 1, 0, 0], [0, 0, 0, 1, 1, 1, 0],
                                 [0, 0, 0, 0, 0, 0, 0]]]])
    np.testing.assert_array_equal(mask, expected_mask)

  def test_make_decoder_mask_prefix_lm_unpacked_multiple_elements(self):
    # Each batch element uses its own prefix length.
    decoder_target_tokens = jnp.array([[6, 7, 3, 0], [4, 5, 0, 0]])
    decoder_causal_attention = jnp.array([[1, 1, 0, 0], [1, 0, 0, 0]])
    mask = layers.make_decoder_mask(
        decoder_target_tokens=decoder_target_tokens,
        dtype=jnp.float32,
        decoder_causal_attention=decoder_causal_attention)
    expected_mask0 = jnp.array([[1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0],
                                [0, 0, 0, 0]])
    expected_mask1 = jnp.array([[1, 0, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0],
                                [0, 0, 0, 0]])
    self.assertEqual(mask.shape, (2, 1, 4, 4))
    np.testing.assert_array_equal(mask[0, 0], expected_mask0)
    np.testing.assert_array_equal(mask[1, 0], expected_mask1)

  def test_make_decoder_mask_composite_causal_attention(self):
    # decoder_causal_attention is non-contiguous here; every flagged position
    # sees every other flagged position, while the rest remain causal.
    decoder_target_tokens = jnp.array([[6, 7, 3, 4, 8, 9, 0]])
    decoder_causal_attention = jnp.array([[1, 1, 0, 0, 1, 1, 0]])
    mask = layers.make_decoder_mask(
        decoder_target_tokens=decoder_target_tokens,
        dtype=jnp.float32,
        decoder_causal_attention=decoder_causal_attention)
    expected_mask0 = jnp.array([[1, 1, 0, 0, 1, 1, 0], [1, 1, 0, 0, 1, 1, 0],
                                [1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0],
                                [1, 1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1, 0],
                                [0, 0, 0, 0, 0, 0, 0]])

    self.assertEqual(mask.shape, (1, 1, 7, 7))
    np.testing.assert_array_equal(mask[0, 0], expected_mask0)

  def test_make_decoder_mask_composite_causal_attention_packed(self):
    # Same as above, but segment ids additionally block cross-segment
    # attention.
    decoder_target_tokens = jnp.array([[6, 7, 3, 4, 8, 9, 2, 3, 4]])
    decoder_segment_ids = jnp.array([[1, 1, 1, 1, 1, 1, 2, 2, 2]])
    decoder_causal_attention = jnp.array([[1, 1, 0, 0, 1, 1, 1, 1, 0]])
    mask = layers.make_decoder_mask(
        decoder_target_tokens=decoder_target_tokens,
        dtype=jnp.float32,
        decoder_causal_attention=decoder_causal_attention,
        decoder_segment_ids=decoder_segment_ids)
    expected_mask0 = jnp.array([[1, 1, 0, 0, 1, 1, 0, 0, 0],
                                [1, 1, 0, 0, 1, 1, 0, 0, 0],
                                [1, 1, 1, 0, 0, 0, 0, 0, 0],
                                [1, 1, 1, 1, 0, 0, 0, 0, 0],
                                [1, 1, 1, 1, 1, 1, 0, 0, 0],
                                [1, 1, 1, 1, 1, 1, 0, 0, 0],
                                [0, 0, 0, 0, 0, 0, 1, 1, 0],
                                [0, 0, 0, 0, 0, 0, 1, 1, 0],
                                [0, 0, 0, 0, 0, 0, 1, 1, 1]])

    self.assertEqual(mask.shape, (1, 1, 9, 9))
    np.testing.assert_array_equal(mask[0, 0], expected_mask0)

  @parameterized.parameters({'f': 20}, {'f': 22})
  def test_multihead_dot_product_attention(self, f):
    # b: batch, f: emb_dim, q: q_len, k: kv_len, h: num_head, d: head_dim
    b, q, h, d, k = 2, 3, 4, 5, 6

    base_args = SelfAttentionArgs(num_heads=h, head_dim=d, dropout_rate=0)
    args = base_args.init_args()

    np.random.seed(0)
    inputs_q = np.random.randn(b, q, f)
    inputs_kv = np.random.randn(b, k, f)

    # Projection: [b, q, f] -> [b, q, h, d]
    # So the kernels have to be [f, h, d]
    query_kernel = np.random.randn(f, h, d)
    key_kernel = np.random.randn(f, h, d)
    value_kernel = np.random.randn(f, h, d)
    # `out` calculation: [b, q, h, d] -> [b, q, f]
    # So kernel has to be [h, d, f]
    out_kernel = np.random.randn(h, d, f)

    # The module stores kernels flattened to 2-D, so reshape before injecting.
    params = {
        'query': {
            'kernel': query_kernel.reshape(f, -1)
        },
        'key': {
            'kernel': key_kernel.reshape(f, -1)
        },
        'value': {
            'kernel': value_kernel.reshape(f, -1)
        },
        'out': {
            'kernel': out_kernel.reshape(-1, f)
        }
    }
    y = layers.MultiHeadDotProductAttention(**args).apply(
        {'params': freeze(params)}, inputs_q, inputs_kv)

    # Reference computation with plain einsums. No 1/sqrt(d) factor is applied
    # here, so the test expects the module not to scale logits at call time.
    query = np.einsum('bqf,fhd->bqhd', inputs_q, query_kernel)
    key = np.einsum('bkf,fhd->bkhd', inputs_kv, key_kernel)
    value = np.einsum('bkf,fhd->bkhd', inputs_kv, value_kernel)
    logits = np.einsum('bqhd,bkhd->bhqk', query, key)
    weights = nn.softmax(logits, axis=-1)
    combined_value = np.einsum('bhqk,bkhd->bqhd', weights, value)
    y_expected = np.einsum('bqhd,hdf->bqf', combined_value, out_kernel)
    np.testing.assert_allclose(y, y_expected, rtol=1e-5, atol=1e-5)

  def test_multihead_dot_product_attention_caching(self):
    # b: batch, f: qkv_features, k: kv_len, h: num_head, d: head_dim
    b, h, d, k = 2, 3, 4, 5
    f = h * d

    base_args = SelfAttentionArgs(num_heads=h, head_dim=d, dropout_rate=0)
    args = base_args.init_args()

    cache = {
        'cached_key': np.zeros((b, h, d, k)),
        'cached_value': np.zeros((b, h, d, k)),
        'cache_index': np.array(0)
    }
    # Seed so the inputs (and thus the test) are deterministic, as in the
    # sibling tests.
    np.random.seed(0)
    inputs_q = np.random.randn(b, 1, f)
    inputs_kv = np.random.randn(b, 1, f)

    # Mock dense general such that q, k, v projections are replaced by simple
    # reshaping.
    def mock_dense_general(self, x, **kwargs):  # pylint: disable=unused-argument
      return x.reshape(b, -1, h, d)

    with mock.patch.object(
        layers.DenseGeneral, '__call__', new=mock_dense_general):
      _, mutated = layers.MultiHeadDotProductAttention(**args).apply(
          {'cache': freeze(cache)},
          inputs_q,
          inputs_kv,
          decode=True,
          mutable=['cache'])
      updated_cache = mutated['cache']

    # Perform the same mocked projection to generate the expected cache.
    # (key|value): [b, 1, h, d]
    key = mock_dense_general(None, inputs_kv)
    value = mock_dense_general(None, inputs_kv)

    # cached_(key|value): [b, h, d, k]
    # After one decode step, slot 0 holds the new key/value and the index
    # advances to 1.
    cache['cached_key'][:, :, :, 0] = key[:, 0, :, :]
    cache['cached_value'][:, :, :, 0] = value[:, 0, :, :]
    cache['cache_index'] = np.array(1)
    for name, array in cache.items():
      np.testing.assert_allclose(array, updated_cache[name])

  def test_dot_product_attention(self):
    # b: batch, q: q_len, k: kv_len, h: num_head, d: head_dim
    b, q, h, d, k = 2, 3, 4, 5, 6
    np.random.seed(0)
    query = np.random.randn(b, q, h, d)
    key = np.random.randn(b, k, h, d)
    value = np.random.randn(b, k, h, d)
    bias = np.random.randn(b, h, q, k)
    attn_out = layers.dot_product_attention(query, key, value, bias=bias)
    # Reference: softmax(QK^T + bias) V, computed with einsums.
    logits = np.einsum('bqhd,bkhd->bhqk', query, key)
    weights = jax.nn.softmax(logits + bias, axis=-1)
    expected = np.einsum('bhqk,bkhd->bqhd', weights, value)
    np.testing.assert_allclose(attn_out, expected, atol=1e-6)


class EmbeddingTest(parameterized.TestCase):
  """Tests for `layers.Embed`."""

  def test_embedder_raises_exception_for_incorrect_input_type(self):
    """Float-typed ids are rejected with a ValueError when applying."""
    embedder = layers.Embed(num_embeddings=10, features=5)
    int_ids = np.arange(5, dtype=np.int64)[:, np.newaxis]
    variables = embedder.init(jax.random.PRNGKey(0), int_ids)
    float_ids = int_ids.astype(np.float32)
    with self.assertRaisesRegex(
        ValueError, 'Input type must be an integer or unsigned integer.'):
      embedder.apply(variables, float_ids)

  @parameterized.named_parameters(
      dict(
          testcase_name='with_ones',
          init_fn=jax.nn.initializers.ones,
          num_embeddings=10,
          features=5,
          matrix_sum=10 * 5,
      ),
      dict(
          testcase_name='with_zeros',
          init_fn=jax.nn.initializers.zeros,
          num_embeddings=10,
          features=5,
          matrix_sum=0,
      ),
  )
  def test_embedding_initializes_correctly(self, init_fn, num_embeddings,
                                           features, matrix_sum):
    """The embedding table is filled by the supplied `embedding_init`."""
    embedder = layers.Embed(
        num_embeddings=num_embeddings,
        features=features,
        embedding_init=init_fn)
    ids = np.arange(5, dtype=np.int64)[:, np.newaxis]
    params = embedder.init(jax.random.PRNGKey(0), ids)['params']
    table_total = int(np.sum(params['embedding']))
    self.assertEqual(table_total, matrix_sum)

  def test_embedding_matrix_shape(self):
    """The embedding table has shape (num_embeddings, features)."""
    num_embeddings, features = 10, 5
    embedder = layers.Embed(num_embeddings=num_embeddings, features=features)
    ids = np.arange(features, dtype=np.int64)[:, np.newaxis]
    params = embedder.init(jax.random.PRNGKey(0), ids)['params']
    self.assertEqual((num_embeddings, features), params['embedding'].shape)

  def test_embedding_attend(self):
    """`attend` with an all-ones query returns each embedding row's sum."""
    features = 5
    embedder = layers.Embed(num_embeddings=10, features=features)
    seed_ids = np.array([[1]], dtype=np.int64)
    variables = embedder.init(jax.random.PRNGKey(0), seed_ids)
    ones_query = np.ones(features, dtype=np.float32)
    attended = embedder.apply(variables, ones_query, method=embedder.attend)
    # Dotting each row with a ones vector reduces it to the sum of its entries.
    row_sums = np.sum(variables['params']['embedding'], -1)
    np.testing.assert_array_almost_equal(attended, row_sums)


class DenseTest(parameterized.TestCase):
  """Tests for `layers.DenseGeneral` and `layers.MlpBlock`."""

  def test_dense_general_no_bias(self):
    # All-ones kernel and input: each output equals the number of contracted
    # input features (3).
    rng = random.PRNGKey(0)
    x = jnp.ones((1, 3))
    model = layers.DenseGeneral(
        features=4,
        kernel_init=initializers.ones,
    )
    y, _ = model.init_with_output(rng, x)
    self.assertEqual(y.shape, (1, 4))
    np.testing.assert_allclose(y, np.full((1, 4), 3.))

  def test_dense_general_two_features(self):
    rng = random.PRNGKey(0)
    x = jnp.ones((1, 3))
    model = layers.DenseGeneral(
        features=(2, 2),
        kernel_init=initializers.ones,
    )
    y, _ = model.init_with_output(rng, x)
    # We transform the last input dimension to two output dimensions (2, 2).
    np.testing.assert_allclose(y, np.full((1, 2, 2), 3.))

  def test_dense_general_two_axes(self):
    rng = random.PRNGKey(0)
    x = jnp.ones((1, 2, 2))
    model = layers.DenseGeneral(
        features=3,
        axis=(-2, 2),  # Note: this is the same as (1, 2).
        kernel_init=initializers.ones,
    )
    y, _ = model.init_with_output(rng, x)
    # We transform the last two input dimensions (2, 2) to one output dimension.
    np.testing.assert_allclose(y, np.full((1, 3), 4.))

  def test_mlp_same_out_dim(self):
    # The expected parameter and output values below are golden numbers tied
    # to PRNGKey(0) and the xavier_uniform initializer; they change if the
    # JAX PRNG implementation or the initializer changes.
    module = layers.MlpBlock(
        intermediate_dim=4,
        activations=('relu',),
        kernel_init=nn.initializers.xavier_uniform(),
        dtype=jnp.float32,
    )
    inputs = np.array(
        [
            # Batch 1.
            [[1, 1], [1, 1], [1, 2]],
            # Batch 2.
            [[2, 2], [3, 1], [2, 2]],
        ],
        dtype=np.float32)
    params = module.init(random.PRNGKey(0), inputs, deterministic=True)
    # `jax.tree_util.tree_map` is used because the top-level `jax.tree_map`
    # alias is deprecated and removed in newer JAX releases.
    self.assertEqual(
        jax.tree_util.tree_map(lambda a: a.tolist(), params), {
            'params': {
                'wi': {
                    'kernel': [[
                        -0.8675811290740967, 0.08417510986328125,
                        0.022586345672607422, -0.9124102592468262
                    ],
                               [
                                   -0.19464373588562012, 0.49809837341308594,
                                   0.7808468341827393, 0.9267289638519287
                               ]],
                },
                'wo': {
                    'kernel': [[0.01154780387878418, 0.1397249698638916],
                               [0.974980354309082, 0.5903260707855225],
                               [-0.05997943878173828, 0.616570234298706],
                               [0.2934272289276123, 0.8181164264678955]],
                },
            },
            'params_axes': {
                'wi': {
                    'kernel_axes': AxisMetadata(names=('embed', 'mlp')),
                },
                'wo': {
                    'kernel_axes': AxisMetadata(names=('mlp', 'embed')),
                },
            },
        })
    result = module.apply(params, inputs, deterministic=True)
    np.testing.assert_allclose(
        result.tolist(),
        [[[0.5237172245979309, 0.8508185744285583],
          [0.5237172245979309, 0.8508185744285583],
          [1.2344461679458618, 2.3844780921936035]],
         [[1.0474344491958618, 1.7016371488571167],
          [0.6809444427490234, 0.9663378596305847],
          [1.0474344491958618, 1.7016371488571167]]],
        rtol=1e-6,
    )


if __name__ == '__main__':
  # Run every test case in this module via absltest when executed directly.
  absltest.main()