Spaces:
Build error
Build error
File size: 36,298 Bytes
b100e1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 |
# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for t5x.decoding."""
import functools
from typing import Mapping, Tuple
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import api
from jax.experimental import host_callback as hcb
import jax.numpy as jnp
import numpy as np
from t5x import decoding
EOS_ID: int = 1  # End-of-sequence token id passed to every decode call below.
NEG_INF = decoding.NEG_INF  # Logit value `decoding` uses for masked-out tokens.
class DecodeTest(parameterized.TestCase):
  """Tests for sampling, beam search and cache helpers in `t5x.decoding`."""

  @staticmethod
  def _make_mocked_categorical(rng0, max_decode_len):
    """Builds a fake `jax.random.categorical` that replays scripted tokens.

    The fake ignores the logits. It recovers the while-loop step `i` of
    `decoding._temperature_sample_single_trial` by replaying the rng splits
    starting from `rng0` and finding which split equals the rng it was given.
    It then returns the scripted tokens for that step: [2, 3] at every step,
    except that sequence 1 emits EOS (1) at step 3 and sequence 0 at step 4.

    Args:
      rng0: the PRNG key that is also passed to the decode function.
      max_decode_len: number of steps to script; must be at least 5.

    Returns:
      A callable taking `(rng_input, logits)`, suitable for patching over
      `jax.random.categorical`.
    """
    ret = [np.array([2, 3]) for _ in range(max_decode_len)]
    # Sequence 1 outputs EOS=1 when i = 3 where `i` is the while loop counter
    # of `decoding._temperature_sample_single_trial`.
    ret[3] = np.array([2, 1])
    # Sequence 0 outputs EOS=1 when i = 4.
    ret[4] = np.array([1, 3])
    ret = jnp.array(ret)

    def mocked_categorical(rng_input, logits):  # pylint: disable=unused-argument
      """Ignores logits and returns tokens based only on `rng_input`."""
      rng = rng0
      k = 0
      # Mimic the rng split done in `decoding.sample_loop_body_fn`.
      for j in range(max_decode_len):
        rng1, rng = jax.random.split(rng)
        # We want to sift out `j` for which rng1 == rng_input.
        # rngs are a pair of ints, so sum the bool and divide by 2.
        k += j * (rng1 == rng_input).sum() // 2
      # `k` at this point is equal to the while loop variable `i` of the caller.
      return ret[k]

    return mocked_categorical

  def test_temperature_sample_uneven_prefix(self):
    """Prompts of different lengths, given by `initial_index`, are kept."""

    def token_to_logits(ids, cache):
      del ids
      del cache
      # Always sample id 2 for batch element 0 and id 3 for element 1.
      logits = np.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]],
                        dtype=np.float32)
      return logits, {}

    inputs = np.array([[0, 5, 7, 1, 0, 0], [0, 6, 1, 0, 0, 0]])
    sampled_sequences, _ = decoding._temperature_sample_single_trial(
        inputs, {},
        token_to_logits,
        EOS_ID,
        jax.random.PRNGKey(0),
        topk=0,
        initial_index=np.array([3, 2]))
    expected = np.array([[5, 7, 1, 2, 2, 2], [6, 1, 3, 3, 3, 3]])
    np.testing.assert_array_equal(expected, sampled_sequences)

  def test_temperature_sample_no_prefix(self):
    """Without a prompt, every position is filled by sampling."""
    batch, max_decode_len = 2, 3

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # Always sample id 2 for batch element 0 and id 3 for element 1.
      logits = np.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]],
                        dtype=np.float32)
      return logits, {}

    inputs = np.zeros((batch, max_decode_len), dtype=np.int32)
    sampled_sequences, _ = decoding._temperature_sample_single_trial(
        inputs, {}, token_to_logits, EOS_ID, jax.random.PRNGKey(0), topk=0)
    expected = [[2, 2, 2], [3, 3, 3]]
    np.testing.assert_array_equal(expected, sampled_sequences)

  def test_temperature_sample_prefix(self):
    """Prompt tokens are kept and sampling continues after them."""

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # Always sample id 2 for batch element 0 and id 3 for element 1.
      logits = np.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]],
                        dtype=np.float32)
      return logits, {}

    # batch element 0 has length 3 prefix and element 1 has length 2.
    inputs = np.array([[0, 5, 6, 7, 0], [0, 8, 9, 0, 0]], dtype=np.int32)
    sampled_sequences, _ = decoding._temperature_sample_single_trial(
        inputs, {}, token_to_logits, EOS_ID, jax.random.PRNGKey(0), topk=0)
    expected = [[5, 6, 7, 2, 2], [8, 9, 3, 3, 3]]
    np.testing.assert_array_equal(expected, sampled_sequences)

  def test_temperature_sample_with_zero_temperature(self):
    """Temperature 0 picks the argmax even among nearly-equal large logits."""
    batch, max_decode_len = 2, 3

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # Use very large logits that are close to one another.
      logits = np.array(
          [[1700.47, 1700.48, 1700.51, 1700.45], [3.2, 4.8, -5.3, 5.6]],
          dtype=np.float32)
      return logits, {}

    inputs = np.zeros((batch, max_decode_len), dtype=np.int32)
    sampled_sequences, _ = decoding._temperature_sample_single_trial(
        inputs, {},
        token_to_logits,
        EOS_ID,
        jax.random.PRNGKey(0),
        topk=4,
        temperature=0.0)
    expected = [[2, 2, 2], [3, 3, 3]]
    np.testing.assert_array_equal(expected, sampled_sequences)

  def test_temperature_sample_prefix_ending_with_eos(self):
    """A prompt that ends in EOS does not stop decoding after the prompt."""

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # Always sample id 2 for batch element 0 and id 3 for element 1.
      logits = np.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]],
                        dtype=np.float32)
      return logits, {}

    # batch element 0 has length 4 prefix (including the initial dummy token and
    # the last eos) and element 1 has length 3.
    inputs = np.array([[0, 5, 6, 1, 0], [0, 8, 1, 0, 0]], dtype=np.int32)
    sampled_sequences, _ = decoding._temperature_sample_single_trial(
        inputs, {}, token_to_logits, EOS_ID, jax.random.PRNGKey(0), topk=1)
    expected = [[5, 6, 1, 2, 2], [8, 1, 3, 3, 3]]
    np.testing.assert_array_equal(expected, sampled_sequences)

  def test_temperature_sample_with_state_callback(self):
    """`state_callback_fn` can rewrite the sequences between decode steps."""

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # A distribution with roughly all probability mass in sample id 3
      logits = np.array([[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]],
                        dtype=np.float32)
      return logits, {}

    def state_callback_fn(state):
      i, sequences, cache, cur_token, ended, rng, log_prob = state

      def callback_fn(current_index_and_sequences):
        """Add EOS token after first time token id 3 has been sampled."""
        current_index, host_sequences = current_index_and_sequences
        host_sequences = np.array(host_sequences)
        for row, index in enumerate(current_index):
          if host_sequences[row, index] == 3:
            host_sequences[row, index + 1] = EOS_ID
        return host_sequences

      # Run the numpy rewrite on the host; `jax.ShapeDtypeStruct` is the public
      # name for the result-shape spec (previously taken from `jax._src.api`).
      sequences = hcb.call(
          callback_fn, (i, sequences),
          result_shape=jax.ShapeDtypeStruct(sequences.shape, sequences.dtype))
      return i, sequences, cache, cur_token, ended, rng, log_prob

    inputs = np.array([[0, 5, 6, 7, 0], [0, 8, 9, 0, 0]], dtype=np.int32)
    sampled_sequences, _ = decoding._temperature_sample_single_trial(
        inputs, {},
        token_to_logits,
        EOS_ID,
        jax.random.PRNGKey(0),
        topk=0,
        temperature=0.0,
        state_callback_fn=state_callback_fn)
    expected = [[5, 6, 7, 3, EOS_ID], [8, 9, 3, EOS_ID, 0]]
    np.testing.assert_array_equal(expected, sampled_sequences)

  def test_temperature_sample_with_logit_callback(self):
    """`logit_callback_fn` can rewrite the logits before sampling."""

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # uniform distribution over targets from model
      logits = np.array([[-1e7, -1e7, -1e7, -1e7], [-1e7, -1e7, -1e7, -1e7]],
                        dtype=np.float32)
      return logits, {}

    def logit_callback_fn(logits, state):
      del state  # unused
      # Rewrite logits to always sample id 2 for batch element 0 and
      # id 3 for element 1.
      logits[0, 2] = 0
      logits[1, 3] = 0
      return logits

    # batch element 0 has length 3 prefix and element 1 has length 2.
    inputs = np.array([[0, 5, 6, 7, 0], [0, 8, 9, 0, 0]], dtype=np.int32)
    sampled_sequences, _ = decoding._temperature_sample_single_trial(
        inputs, {},
        token_to_logits,
        EOS_ID,
        jax.random.PRNGKey(0),
        topk=0,
        temperature=0.0,
        logit_callback_fn=logit_callback_fn)
    expected = [[5, 6, 7, 2, 2], [8, 9, 3, 3, 3]]
    np.testing.assert_array_equal(expected, sampled_sequences)

  def test_temperature_sample_prefix_ending_with_eos_early_stop(self):
    """Each sequence stops after sampling EOS; later positions stay 0."""
    batch, max_decode_len = 2, 7
    rng0 = jax.random.PRNGKey(0)
    mocked_categorical = self._make_mocked_categorical(rng0, max_decode_len)

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # These values are not used in this test because random.categorical is
      # directly mocked.
      dummy_logits = np.zeros((batch, 4), dtype=np.float32)
      return dummy_logits, {}

    inputs = np.array([[0, 5, 1, 0, 0, 0, 0], [0, 8, 0, 0, 0, 0, 0]],
                      dtype=np.int32)
    with mock.patch.object(jax.random, 'categorical', new=mocked_categorical):
      sampled_sequences, _ = decoding._temperature_sample_single_trial(
          inputs, {}, token_to_logits, EOS_ID, rng0, topk=0)
    expected = [[5, 1, 2, 2, 1, 0, 0], [8, 3, 3, 1, 0, 0, 0]]
    np.testing.assert_array_equal(expected, sampled_sequences)

  def test_greedy_decoding_topk_sample_log_probs(self):
    """topk=1 scores are 0 when rescaled, else the sum of model log probs."""

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # Sample [2, 3] with probability [0.6, 0.4].
      logits = np.array([[-1e7, -1e7, -0.510825624, -0.916290732]],
                        dtype=np.float32)
      return logits, {}

    inputs = np.array([[0, 2, 2, 2, 0]], dtype=np.int32)
    sampled_sequences, sampled_log_probs = (
        decoding._temperature_sample_single_trial(
            inputs, {},
            token_to_logits,
            EOS_ID,
            jax.random.PRNGKey(0),
            topk=1,
            rescale_log_probs=True))
    expected_sequence = [[2, 2, 2, 2, 2]]
    expected_log_probs = [0.0]
    np.testing.assert_array_equal(expected_sequence, sampled_sequences)
    np.testing.assert_array_almost_equal(expected_log_probs, sampled_log_probs)

    inputs = np.array([[0, 2, 2, 3, 0]], dtype=np.int32)
    sampled_sequences, sampled_log_probs = (
        decoding._temperature_sample_single_trial(
            inputs, {},
            token_to_logits,
            EOS_ID,
            jax.random.PRNGKey(0),
            topk=1,
            rescale_log_probs=False))
    expected_sequence = [[2, 2, 3, 2, 2]]
    # Two sampled tokens of id 2, each with log prob log(0.6).
    expected_log_probs = [-1.02165125]
    np.testing.assert_array_equal(expected_sequence, sampled_sequences)
    np.testing.assert_array_almost_equal(expected_log_probs, sampled_log_probs)

  def test_temperature_sample_log_prob(self):
    """The log prob sums log-softmax over sampled tokens up to and incl. EOS."""
    batch, max_decode_len = 2, 7
    rng0 = jax.random.PRNGKey(0)
    mocked_categorical = self._make_mocked_categorical(rng0, max_decode_len)

    # Seeded so that failures are reproducible.
    logits = np.random.RandomState(0).randn(batch, 4)
    token_to_logits = lambda ids, cache: (logits, {})
    inputs = np.array([[0, 5, 1, 0, 0, 0, 0], [0, 8, 0, 0, 0, 0, 0]],
                      dtype=np.int32)
    with mock.patch.object(jax.random, 'categorical', new=mocked_categorical):
      sampled_sequences, log_prob = decoding._temperature_sample_single_trial(
          inputs, {}, token_to_logits, EOS_ID, rng0, topk=0)

    log_probs = jax.nn.log_softmax(logits)
    expected = [[5, 1, 2, 2, 1, 0, 0], [8, 3, 3, 1, 0, 0, 0]]
    # Prompt tokens do not contribute; only the sampled tokens are scored.
    expected_log_prob = np.array([
        log_probs[0, 2] + log_probs[0, 2] + log_probs[0, 1],
        log_probs[1, 3] + log_probs[1, 3] + log_probs[1, 1]
    ])
    np.testing.assert_array_equal(expected, sampled_sequences)
    np.testing.assert_allclose(expected_log_prob, log_prob, atol=1e-5)

  def test_temperature_sample_num_decodes(self):
    """`num_decodes` expands the batch and returns decodes sorted by score."""
    num_decodes = 3
    rng0 = jax.random.PRNGKey(0)
    inputs = np.array([[0, 5, 1, 0], [0, 8, 7, 0]], dtype=np.int32)
    with mock.patch.object(decoding,
                           '_temperature_sample_single_trial') as mocked:
      # expanded_decodes: [batch * num_decodes, max_decode_len]
      expanded_decodes = np.array([[5, 1, 4, 4], [5, 1, 5, 5], [5, 1, 3, 3],
                                   [8, 7, 5, 5], [8, 7, 3, 3], [8, 7, 4, 4]])
      # expanded_log_prob: [batch * num_decodes]
      expanded_log_prob = np.array([-2.3, -1.3, -3.6, -0.5, -2.5, -1.9])
      mocked.return_value = expanded_decodes, expanded_log_prob
      decodes, scores = decoding.temperature_sample(
          inputs, {}, mock.Mock(), EOS_ID, rng0, num_decodes=num_decodes)
      expanded_inputs = jnp.array([[0, 5, 1, 0], [0, 5, 1, 0], [0, 5, 1, 0],
                                   [0, 8, 7, 0], [0, 8, 7, 0], [0, 8, 7, 0]])
      # Test that the actual decode function is called with the expanded values.
      np.testing.assert_array_equal(mocked.call_args[0][0], expanded_inputs)
    np.testing.assert_array_equal(decodes,
                                  [[[5, 1, 3, 3], [5, 1, 4, 4], [5, 1, 5, 5]],
                                   [[8, 7, 3, 3], [8, 7, 4, 4], [8, 7, 5, 5]]])
    np.testing.assert_allclose(scores, [[-3.6, -2.3, -1.3], [-2.5, -1.9, -0.5]])

  def test_temperature_sample_num_decodes_with_initial_index(self):
    """`num_decodes` also expands `initial_index` and maps the cache index."""
    num_decodes = 3
    rng0 = jax.random.PRNGKey(0)
    inputs = np.array([[0, 5, 1, 0], [0, 8, 7, 0]], dtype=np.int32)
    initial_index = np.array([1, 2], dtype=np.int32)
    with mock.patch.object(decoding,
                           '_temperature_sample_single_trial') as mocked:
      with mock.patch.object(decoding, 'cache_map') as mocked_cache_map:
        # expanded_decodes: [batch * num_decodes, max_decode_len]
        expanded_decodes = np.array([[5, 1, 4, 4], [5, 1, 5, 5], [5, 1, 3, 3],
                                     [8, 7, 5, 5], [8, 7, 3, 3], [8, 7, 4, 4]])
        # expanded_log_prob: [batch * num_decodes]
        expanded_log_prob = np.array([-2.3, -1.3, -3.6, -0.5, -2.5, -1.9])
        mocked.return_value = expanded_decodes, expanded_log_prob
        decodes, scores = decoding.temperature_sample(
            inputs, {},
            mock.Mock(),
            EOS_ID,
            rng0,
            num_decodes=num_decodes,
            initial_index=initial_index)
        expanded_inputs = jnp.array([[0, 5, 1, 0], [0, 5, 1, 0], [0, 5, 1, 0],
                                     [0, 8, 7, 0], [0, 8, 7, 0], [0, 8, 7, 0]])
        expanded_initial_index = np.array([1, 1, 1, 2, 2, 2], dtype=np.int32)
        # Test that the actual decode function is called with the expanded
        # values.
        np.testing.assert_array_equal(mocked.call_args[0][0], expanded_inputs)
        np.testing.assert_array_equal(mocked.call_args[1]['initial_index'],
                                      expanded_initial_index)
        # Test that the function was applied to the index in the cache map
        self.assertTrue(mocked_cache_map.call_args[1]['apply_to_index'])
    np.testing.assert_array_equal(decodes,
                                  [[[5, 1, 3, 3], [5, 1, 4, 4], [5, 1, 5, 5]],
                                   [[8, 7, 3, 3], [8, 7, 4, 4], [8, 7, 5, 5]]])
    np.testing.assert_allclose(scores, [[-3.6, -2.3, -1.3], [-2.5, -1.9, -0.5]])

  @parameterized.named_parameters(
      dict(
          testcase_name='no_initial_index',
          initial_index=None,
          expected_calls=6,
      ),
      dict(
          testcase_name='initial_index',
          initial_index=np.array([1, 2], dtype=np.int32),
          expected_calls=4,
      ),
      dict(
          testcase_name='lower_initial_index',
          initial_index=np.array([1, 1], dtype=np.int32),
          expected_calls=5,  # we decode 4 tokens out of the prompt
      ),
  )
  def test_temperature_sample_max_decode_steps_with_initial_index(
      self, initial_index, expected_calls):
    """`max_decode_steps` bounds the output and the number of model calls."""
    max_decode_steps = 4
    rng0 = jax.random.PRNGKey(0)
    inputs = np.array([[0, 2, 0, 0, 0, 0, 0, 0], [0, 2, 2, 0, 0, 0, 0, 0]],
                      dtype=np.int32)
    token_to_logits = mock.Mock()
    token_to_logits.return_value = (np.array(
        [[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]], dtype=np.float32), {})
    # to unroll while loop
    with jax.disable_jit():
      decodes, scores = decoding.temperature_sample(
          inputs, {},
          token_to_logits,
          EOS_ID,
          rng0,
          initial_index=initial_index,
          topk=4,
          max_decode_steps=max_decode_steps)
    self.assertLen(token_to_logits.call_args_list, expected_calls)
    expected_output = np.array([[2, 3, 3, 3, 3, 0, 0, 0],
                                [2, 2, 3, 3, 3, 3, 0, 0]])
    expected_output = jnp.expand_dims(expected_output, 1)
    np.testing.assert_array_equal(decodes, expected_output)
    np.testing.assert_array_equal(scores, [[0.], [0.]])

  def test_temperature_sample_max_decode_steps_endpad(self):
    """Each row decodes at most `max_decode_steps` tokens past its prompt."""
    max_decode_steps = 4
    rng0 = jax.random.PRNGKey(0)
    inputs = np.array([[0, 2, 0, 0, 0, 0, 0, 0], [0, 2, 2, 2, 2, 2, 2, 0],
                       [0, 2, 2, 2, 0, 0, 0, 0]],
                      dtype=np.int32)
    initial_index = np.array([1, 6, 0])
    token_to_logits = mock.Mock()
    token_to_logits.return_value = (np.array(
        [[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]],
        dtype=np.float32), {})
    # to unroll while loop
    with jax.disable_jit():
      decodes, scores = decoding.temperature_sample(
          inputs, {},
          token_to_logits,
          EOS_ID,
          rng0,
          initial_index=initial_index,
          topk=4,
          max_decode_steps=max_decode_steps)

    # `inputs[2]` starts from index 0. So it requires 3 calls to
    # `token_to_logits` to exit the prompt (these generated tokens are
    # overridden) and 4 more calls to fill the rest. `inputs[0]` only need 4
    # calls. In the last 3 calls, it generates but MUST NOT populate the
    # sequences because it is already ended.
    self.assertLen(token_to_logits.call_args_list, 7)
    expected_output = np.array(
        [[2, 3, 3, 3, 3, 0, 0, 0], [2, 2, 2, 2, 2, 2, 3, 3],
         [2, 2, 2, 3, 3, 3, 3, 0]],
        dtype=np.int32)
    expected_output = jnp.expand_dims(expected_output, 1)
    np.testing.assert_array_equal(decodes, expected_output)
    np.testing.assert_allclose(scores, [[0.], [0.], [0.]])

  def test_temperature_sample_max_decode_steps_docstring_ex4(self):
    """Two decode steps per row with per-row `initial_index` (docstring ex. 4)."""
    max_decode_steps = 2
    rng0 = jax.random.PRNGKey(0)
    inputs = np.array([[0, 2, 0, 0, 0, 0, 0, 0], [0, 3, 4, 0, 0, 0, 0, 0]],
                      dtype=np.int32)
    initial_index = np.array([1, 2])
    token_to_logits = mock.Mock()
    token_to_logits.return_value = (np.array(
        [[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]], dtype=np.float32), {})
    # to unroll while loop
    with jax.disable_jit():
      decodes, _ = decoding.temperature_sample(
          inputs, {},
          token_to_logits,
          EOS_ID,
          rng0,
          initial_index=initial_index,
          topk=4,
          max_decode_steps=max_decode_steps)
    self.assertLen(token_to_logits.call_args_list, 2)
    expected_output = np.array(
        [[2, 2, 2, 0, 0, 0, 0, 0], [3, 4, 3, 3, 0, 0, 0, 0]], dtype=np.int32)
    expected_output = jnp.expand_dims(expected_output, 1)
    np.testing.assert_array_equal(decodes, expected_output)

  def test_temperature_sample_max_decode_steps_hard_limit(self):
    """`max_decode_steps_hard_limit` caps a larger `max_decode_steps`."""
    max_decode_steps = 10
    max_decode_steps_hard_limit = 4
    rng0 = jax.random.PRNGKey(0)
    inputs = np.array([[0, 2, 0, 0, 0, 0, 0, 0], [0, 2, 2, 0, 0, 0, 0, 0]],
                      dtype=np.int32)
    token_to_logits = mock.Mock()
    token_to_logits.return_value = (np.array(
        [[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]], dtype=np.float32), {})
    # to unroll while loop
    with jax.disable_jit():
      decodes, scores = decoding.temperature_sample(
          inputs, {},
          token_to_logits,
          EOS_ID,
          rng0,
          topk=4,
          max_decode_steps=max_decode_steps,
          max_decode_steps_hard_limit=max_decode_steps_hard_limit)
    expected_output = np.array([[2, 3, 3, 3, 3, 0, 0, 0],
                                [2, 2, 3, 3, 3, 3, 0, 0]])
    expected_output = jnp.expand_dims(expected_output, 1)
    np.testing.assert_array_equal(decodes, expected_output)
    np.testing.assert_array_equal(scores, [[0.], [0.]])

  def test_temperature_sample_topp(self):
    """Top-p sampling becomes greedy when only the top token fits in `topp`."""
    rng0 = jax.random.PRNGKey(0)
    inputs = np.zeros((1, 20), dtype=np.int32)
    token_to_logits = mock.Mock()

    # logits correspond to (0.3, 0, 0.1, 0.6)
    token_to_logits.return_value = (np.array([[-1.2, -1e7, -2.3, -0.51]],
                                             dtype=np.float32), {})
    decodes, scores = decoding.temperature_sample(
        inputs, {}, token_to_logits, EOS_ID, rng0, topp=0.55,
        topk=0)  # anything under 0.6 will trigger deterministic decoding.

    expected_output = np.array([[3] * 20])
    expected_output = jnp.expand_dims(expected_output, 1)
    np.testing.assert_array_equal(decodes, expected_output)
    np.testing.assert_array_equal(scores, [[0.]])

    # temperature is applied first, so the distribution becomes
    # (0.27, 0, 0.069, 0.65), so if topp is 0.63, it should become greedy.
    decodes, scores = decoding.temperature_sample(
        inputs, {},
        token_to_logits,
        EOS_ID,
        rng0,
        temperature=0.8,
        topp=0.63,
        topk=0)

    expected_output = np.array([[3] * 20])
    expected_output = jnp.expand_dims(expected_output, 1)
    np.testing.assert_array_equal(decodes, expected_output)
    np.testing.assert_array_equal(scores, [[0.]])

  def test_dynamic_topp_max_decode_steps(self):
    """`temperature`, `topp` and `max_decode_steps` may be traced under jit."""
    rng0 = jax.random.PRNGKey(0)
    inputs = np.zeros((1, 20), dtype=np.int32)
    token_to_logits = mock.Mock()

    # logits correspond to (0.3, 0, 0.1, 0.6)
    token_to_logits.return_value = (np.array([[-1.2, -1e7, -2.3, -0.51]],
                                             dtype=np.float32), {})

    def dynamic_decode_fn(inputs, temperature, topp, max_decode_steps):
      return decoding.temperature_sample(
          inputs, {},
          token_to_logits,
          EOS_ID,
          rng0,
          temperature=temperature,
          topp=topp,
          topk=0,
          max_decode_steps=max_decode_steps)

    dynamic_decode_fn_jit = jax.jit(dynamic_decode_fn)
    decodes, scores = dynamic_decode_fn_jit(inputs, 0.8, 0.63, 10)

    expected_output = np.array([[3] * 10 + [0] * 10])
    expected_output = jnp.expand_dims(expected_output, 1)
    np.testing.assert_array_equal(decodes, expected_output)
    np.testing.assert_array_equal(scores, [[0.]])

  def test_topp_log_probs(self):
    """The sampler sees temperature-scaled logits with top-p masked tokens."""
    rng0 = jax.random.PRNGKey(0)
    inputs = np.zeros((1, 1), dtype=np.int32)
    token_to_logits = mock.Mock()

    # logits correspond to (0.3, 0, 0.1, 0.6)
    token_to_logits.return_value = (np.array([[-1.2, NEG_INF, -2.3, -0.51]],
                                             dtype=np.float32), {})

    with jax.disable_jit():
      # this lets us see logits after topp and topk are applied
      with mock.patch.object(jax.random, 'categorical') as mocked:
        mocked.return_value = jnp.array([0], dtype=jnp.int32)
        decodes, _ = decoding.temperature_sample(
            inputs, {},
            token_to_logits,
            EOS_ID,
            rng0,
            temperature=1.4,
            topp=0.7,
            topk=0)

    self.assertLen(token_to_logits.call_args_list, 1)
    np.testing.assert_array_equal(decodes, jnp.asarray([[[0]]]))

    # Logits are divided by the temperature (1.4); token 2 falls outside topp.
    np.testing.assert_array_almost_equal(
        mocked.call_args_list[0][0][1],
        jnp.asarray([[-0.85714293, NEG_INF, NEG_INF, -0.36428571]]))

  def test_add_beam_dim(self):
    """`add_beam_dim` tiles each batch row along a new beam axis."""
    x = np.array([[0, 5, 1, 0], [0, 8, 6, 9]], dtype=np.int32)
    y = decoding.add_beam_dim(x, beam_size=3)
    self.assertEqual(y.shape, (2, 3, 4))
    np.testing.assert_array_equal([[[0, 5, 1, 0], [0, 5, 1, 0], [0, 5, 1, 0]],
                                   [[0, 8, 6, 9], [0, 8, 6, 9], [0, 8, 6, 9]]],
                                  y)

  def test_flat_batch_beam_expand(self):
    """`flat_batch_beam_expand` repeats each row `beam_size` times, flat."""
    x = np.array([[0, 5, 1, 0], [0, 8, 6, 9]], dtype=np.int32)
    np.testing.assert_array_equal(
        [[0, 5, 1, 0], [0, 5, 1, 0], [0, 8, 6, 9], [0, 8, 6, 9]],
        decoding.flat_batch_beam_expand(x, beam_size=2))

  def test_top_k_two_stage(self):
    """`top_k_two_stage` agrees with `jax.lax.top_k`."""
    # Seeded so that failures are reproducible.
    shuffler = np.random.RandomState(0)

    def _test_top_k(batch_size, k):
      # Pick sufficiently large seq_len.
      seq_len = 2047 * k * batch_size
      seq = np.arange(seq_len)
      shuffler.shuffle(seq)
      x = jnp.reshape(seq, (batch_size, seq_len // batch_size)).astype(
          jnp.float32)
      np.testing.assert_almost_equal(
          decoding.top_k_two_stage(x, k), jax.lax.top_k(x, k), decimal=5)

    # Test small batch cases (batch={1,8}, k=16).
    _test_top_k(1, 16)
    _test_top_k(8, 16)
    # Test large batch cases (batch={9,32}, k=11).
    _test_top_k(9, 11)
    _test_top_k(32, 11)

  def test_cache_map(self):
    """`cache_map` maps cached keys/values but leaves `cache_index` alone."""
    cache = {
        'layers_0': {
            'cached_key': jnp.ones([3, 6]),
            'cached_values': jnp.ones([3, 6]),
            'cache_index': jnp.ones([3]),
        },
        'layers_1': {
            'self_attention': {
                'cached_key': jnp.ones([2, 7]),
                'cached_values': jnp.ones([5, 8]),
                'cache_index': jnp.array(1),
            },
            'encoder_decoder_attention': {
                'cached_key': jnp.ones([10, 12, 2]),
                'cached_values': jnp.ones([4, 7, 2]),
                'cache_index': jnp.ones([4, 5, 6]),
            }
        },
    }

    fn = functools.partial(jnp.add, 4)

    gold_cache = {
        'layers_0': {
            'cached_key': fn(jnp.ones([3, 6])),
            'cached_values': fn(jnp.ones([3, 6])),
            'cache_index': jnp.ones([3]),
        },
        'layers_1': {
            'self_attention': {
                'cached_key': fn(jnp.ones([2, 7])),
                'cached_values': fn(jnp.ones([5, 8])),
                'cache_index': jnp.array(1),
            },
            'encoder_decoder_attention': {
                'cached_key': fn(jnp.ones([10, 12, 2])),
                'cached_values': fn(jnp.ones([4, 7, 2])),
                'cache_index': jnp.ones([4, 5, 6]),
            }
        }
    }

    # `jax.tree_multimap` is deprecated/removed; `tree_util.tree_map` accepts
    # several trees and zips their leaves.
    jax.tree_util.tree_map(np.testing.assert_array_equal,
                           decoding.cache_map(fn, cache), gold_cache)

  def test_cache_map_with_index(self):
    """With `apply_to_index=True`, `cache_index` entries are mapped too."""
    cache = {
        'layers_0': {
            'cached_key': jnp.ones([3, 6]),
            'cached_values': jnp.ones([3, 6]),
            'cache_index': jnp.ones([3]),
        },
        'layers_1': {
            'relpos_bias': {
                'cached_bias': jnp.ones([1, 5, 3]),
            },
            'self_attention': {
                'cached_key': jnp.ones([2, 7]),
                'cached_values': jnp.ones([5, 8]),
                'cache_index': jnp.array(1),
            },
            'encoder_decoder_attention': {
                'cached_key': jnp.ones([10, 12, 2]),
                'cached_values': jnp.ones([4, 7, 2]),
                'cache_index': jnp.ones([4, 5, 6]),
            }
        },
        'position_embedder': {
            'position_embedder_index': jnp.array(-1),
        },
    }

    fn = functools.partial(jnp.add, 8)

    # `cached_bias` and `position_embedder_index` are expected to stay as-is.
    gold_cache = {
        'layers_0': {
            'cached_key': fn(jnp.ones([3, 6])),
            'cached_values': fn(jnp.ones([3, 6])),
            'cache_index': fn(jnp.ones([3])),
        },
        'layers_1': {
            'relpos_bias': {
                'cached_bias': jnp.ones([1, 5, 3]),
            },
            'self_attention': {
                'cached_key': fn(jnp.ones([2, 7])),
                'cached_values': fn(jnp.ones([5, 8])),
                'cache_index': fn(jnp.array(1)),
            },
            'encoder_decoder_attention': {
                'cached_key': fn(jnp.ones([10, 12, 2])),
                'cached_values': fn(jnp.ones([4, 7, 2])),
                'cache_index': fn(jnp.ones([4, 5, 6])),
            }
        },
        'position_embedder': {
            'position_embedder_index': jnp.array(-1),
        },
    }

    # `jax.tree_multimap` is deprecated/removed; `tree_util.tree_map` accepts
    # several trees and zips their leaves.
    jax.tree_util.tree_map(np.testing.assert_array_equal,
                           decoding.cache_map(fn, cache, apply_to_index=True),
                           gold_cache)

  def test_beam_search(self):
    """Beam search finds the two tied best paths of a toy Markov model."""
    # Toy problem, we have 4 states, A, B, START, END, (plus PAD).
    # Scores are given by a first-order Markov model.
    batch_size = 2
    beam_size = 2
    # PAD doesn't matter for this test, but part of the contract for beam_search
    # is giving the PAD token id 0.
    states = ['PAD', 'A', 'B', 'START-', '-END']
    num_states = len(states)
    decode_length = 7

    # Edge potentials (written inside edges for diagonals):
    #            1      -1     1      -1
    #         A ---- A ---- A ---- A ---- A
    #       0   \ -1   \  1   \ -1   \  1   0
    # START      X      X      X      X      END
    #       0   / -1   /  1   / -1   /  1   0
    #         B ---- B ---- B ---- B ---- B
    #            1      -1     1      -1

    # put the above edge potentials in a 3-tensor
    ab_edge_potentials = np.asarray([[[1, -1], [-1, 1]], [[-1, 1], [1, -1]],
                                     [[1, -1], [-1, 1]], [[-1, 1], [1, -1]]])
    # now we have to add on the START, END states
    # and PAD at 0
    edge_potentials = np.ones([6, 5, 5]) * NEG_INF
    edge_potentials[1:5, 1:3, 1:3] = ab_edge_potentials
    # START can go to either A or B for free at t0
    edge_potentials[0, 3, 1] = 0
    edge_potentials[0, 3, 2] = 0
    # either A or B can go to END for free at t5
    edge_potentials[5, 1, 4] = 0
    edge_potentials[5, 2, 4] = 0
    # PAD can go to anything for free (doesn't matter for this test)
    edge_potentials[:, 0, :] = 0

    edge_potentials = jnp.asarray(edge_potentials)

    # at time 0, we start with state=START=3
    logits0 = jnp.asarray([NEG_INF, NEG_INF, NEG_INF, 0, NEG_INF])

    # add dummy flattened batch x beam dim for broadcasting
    logits0 = jnp.expand_dims(logits0, axis=0)
    edge_potentials = jnp.expand_dims(edge_potentials, axis=0)

    def tokens_to_logits(
        token_indices: jnp.ndarray, state_cache: Mapping[str, jnp.ndarray]
    ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
      cur_iter = state_cache['cur_iter']
      # grab edge potentials for the current timestep
      cur_edge_potentials = jnp.take_along_axis(
          edge_potentials,
          jnp.reshape(
              jnp.maximum(0, cur_iter[:, 0].astype(jnp.int32) - 1),
              (batch_size * beam_size, 1, 1, 1)),
          axis=1)
      cur_edge_potentials = jnp.squeeze(cur_edge_potentials, axis=1)
      # get "logits" from edge potentials for requested tokens (except at t0)
      cur_logits = jnp.matmul(
          jnp.reshape(
              jax.nn.one_hot(token_indices, num_states, axis=1),
              (batch_size * beam_size, 1, num_states)), cur_edge_potentials)
      cur_logits = jnp.squeeze(cur_logits, axis=1)
      # use our START-only logits for t0, otherwise use the edge potentials
      logits_for_tokens = jnp.where(cur_iter == 0, logits0, cur_logits)
      # update state in the cache
      new_cache = state_cache.copy()
      new_cache['cur_iter'] = cur_iter + 1
      return logits_for_tokens, new_cache

    init_cache = {}
    init_cache['cur_iter'] = jnp.zeros((batch_size, 1))

    top_scoring, _ = decoding.beam_search(
        inputs=np.zeros([batch_size, decode_length]),
        cache=init_cache,
        tokens_to_logits=tokens_to_logits,
        eos_id=4,
        num_decodes=beam_size,
        alpha=0.0,
        max_decode_len=decode_length)

    # The two top scoring sequences should be a tie between
    # START-AABBA-END
    # and
    # START-BBAAB-END
    # (and greedy beam search will find both these with just two beams)

    top_scoring_strings = [
        ''.join(states[tok]
                for tok in top_scoring[0, i, :])
        for i in range(beam_size)
    ]

    expected = ['START-AABBA-END', 'START-BBAAB-END']
    np.testing.assert_array_equal(expected, top_scoring_strings)

  def test_beam_search_force_decode_prefix(self):
    """Beam search keeps prompts and scores only the generated tokens."""
    beam_size = 2

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # Use id 2 then 3 for batch element 0 and id 3 then 2 for element 1.
      logits = np.repeat(
          np.expand_dims(
              np.array([[-1e7, -1e10, -0.1, -0.9, -1e4, -1e4, -1e4, -1e4],
                        [-1e7, -1e10, -0.9, -0.1, -1e4, -1e4, -1e4, -1e4]],
                       dtype=np.float32),
              axis=1), [beam_size],
          axis=1)
      logits = decoding.flatten_beam_dim(logits)
      return logits, {}

    # batch element 0 has length 1 and element 1 has length 2.
    inputs = np.array([[0, 7, 0, 0, 0], [0, 4, 5, 0, 0]], dtype=np.int32)
    rolled_inputs = np.array([[7, 0, 0, 0, 0], [4, 5, 0, 0, 0]], dtype=np.int32)
    beam_search_sequences, decoding_scores = decoding.beam_search(
        inputs, {}, token_to_logits, EOS_ID, num_decodes=beam_size, alpha=0)

    # Prefixes are forced depending on inputs.
    # Beam search sequences and corresponding scores are in ascending order.
    self.assertTrue(np.all(np.diff(decoding_scores) >= 0))
    expected = np.array([[[7, 3, 2, 2, 2], [7, 2, 2, 2, 2]],
                         [[4, 5, 2, 3, 3], [4, 5, 3, 3, 3]]])
    np.testing.assert_array_equal(expected, beam_search_sequences)

    expected_scores = []
    batch_logits = np.array([[-1e7, -1e10, -0.1, -0.9, -1e4, -1e4, -1e4, -1e4],
                             [-1e7, -1e10, -0.9, -0.1, -1e4, -1e4, -1e4, -1e4]],
                            dtype=np.float32)
    for batch, logits, prompt in zip(expected, batch_logits, rolled_inputs):
      # The logits are static, so one log-softmax serves every beam and step.
      log_probs = jax.nn.log_softmax(logits)
      beam_expected_scores = []
      for beam in batch:
        # Forced prompt tokens contribute 0; generated tokens their log prob.
        beam_scores = [
            0 if prompt_token != 0 else log_probs[token]
            for token, prompt_token in zip(beam, prompt)
        ]
        beam_expected_scores.append(sum(beam_scores))
      expected_scores.append(beam_expected_scores)
    np.testing.assert_allclose(expected_scores, decoding_scores, atol=1e-5)

  def test_beam_search_force_decode_no_prefix(self):
    """Without a prompt, beam search returns beams in ascending score order."""
    beam_size = 2

    def token_to_logits(ids, cache):  # pylint: disable=unused-argument
      # Use id 2 then 3 for batch element 0 and id 3 then 2 for element 1.
      logits = np.repeat(
          np.expand_dims(
              np.array([[-1e7, -1e10, -0.1, -0.9], [-1e7, -1e10, -0.9, -0.1]],
                       dtype=np.float32),
              axis=1), [beam_size],
          axis=1)
      logits = decoding.flatten_beam_dim(logits)
      return logits, {}

    # No prefix is passed.
    inputs = np.array([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], dtype=np.int32)
    beam_search_sequences, decoding_scores = decoding.beam_search(
        inputs, {}, token_to_logits, EOS_ID, num_decodes=beam_size)

    # Beam search sequences and corresponding scores are in ascending order.
    self.assertTrue(np.all(np.diff(decoding_scores) >= 0))
    expected = np.array([[[3, 2, 2, 2, 2], [2, 2, 2, 2, 2]],
                         [[2, 3, 3, 3, 3], [3, 3, 3, 3, 3]]])
    np.testing.assert_array_equal(expected, beam_search_sequences)
if __name__ == '__main__':
  # Run the tests defined in this module through absl's test runner.
  absltest.main()
|