diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..7a4a3ea2424c09fbe48d455aed1eaa94d9124835 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/README.md b/README.md index 8ca4c5a54f5bb66badb176f682aa5204edc323df..8f3f63f4e6250207c55acd09838bd1f27ee63984 100644 --- a/README.md +++ b/README.md @@ -10,5 +10,3 @@ app_file: app.py pinned: false license: apache-2.0 --- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/mt3/__init__.py b/mt3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5bfc2b18cf373ea06a7d667dc33e7e8e64b97566 --- /dev/null +++ b/mt3/__init__.py @@ -0,0 +1,33 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base module for MT3.""" + +from mt3 import datasets +from mt3 import event_codec +from mt3 import inference +from mt3 import layers +from mt3 import metrics +from mt3 import metrics_utils +from mt3 import models +from mt3 import network +from mt3 import note_sequences +from mt3 import preprocessors +from mt3 import run_length_encoding +from mt3 import spectrograms +from mt3 import summaries +from mt3 import tasks +from mt3 import vocabularies + +from mt3.version import __version__ diff --git a/mt3/datasets.py b/mt3/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..dea031f97025c6dd60c2abce039287ee3e6e95eb --- /dev/null +++ b/mt3/datasets.py @@ -0,0 +1,325 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Dataset configurations.""" + +import dataclasses +from typing import Mapping, Sequence, Union + +from mt3 import note_sequences +import tensorflow as tf + + + +@dataclasses.dataclass +class InferEvalSplit: + # key in dictionary containing all dataset splits + name: str + # task name suffix (each eval split is a separate task) + suffix: str + # whether or not to include in the mixture of all eval tasks + include_in_mixture: bool = True + + +@dataclasses.dataclass +class DatasetConfig: + """Configuration for a transcription dataset.""" + # dataset name + name: str + # mapping from split name to path + paths: Mapping[str, str] + # mapping from feature name to feature + features: Mapping[str, Union[tf.io.FixedLenFeature, + tf.io.FixedLenSequenceFeature]] + # training split name + train_split: str + # training eval split name + train_eval_split: str + # list of infer eval split specs + infer_eval_splits: Sequence[InferEvalSplit] + # list of track specs to be used for metrics + track_specs: Sequence[note_sequences.TrackSpec] = dataclasses.field( + default_factory=list) + +MAESTROV1_CONFIG = DatasetConfig( + name='maestrov1', + paths={ + 'train': + 'gs://magentadata/datasets/maestro/v1.0.0/maestro-v1.0.0_ns_wav_train.tfrecord-?????-of-00010', + 'train_subset': + 'gs://magentadata/datasets/maestro/v1.0.0/maestro-v1.0.0_ns_wav_train.tfrecord-00002-of-00010', + 'validation': + 'gs://magentadata/datasets/maestro/v1.0.0/maestro-v1.0.0_ns_wav_validation.tfrecord-?????-of-00010', + 'validation_subset': + 'gs://magentadata/datasets/maestro/v1.0.0/maestro-v1.0.0_ns_wav_validation.tfrecord-0000[06]-of-00010', + 'test': + 'gs://magentadata/datasets/maestro/v1.0.0/maestro-v1.0.0_ns_wav_test.tfrecord-?????-of-00010' + }, + features={ + 'audio': tf.io.FixedLenFeature([], dtype=tf.string), + 'sequence': tf.io.FixedLenFeature([], dtype=tf.string), + 'id': tf.io.FixedLenFeature([], dtype=tf.string) + }, + train_split='train', + train_eval_split='validation_subset', + infer_eval_splits=[ + InferEvalSplit(name='train', suffix='eval_train_full', + include_in_mixture=False), + InferEvalSplit(name='train_subset', suffix='eval_train'), + InferEvalSplit(name='validation', suffix='validation_full', + include_in_mixture=False), + InferEvalSplit(name='validation_subset', suffix='validation'), + InferEvalSplit(name='test', suffix='test', include_in_mixture=False) + ]) + + +MAESTROV3_CONFIG = DatasetConfig( + name='maestrov3', + paths={ + 'train': + 'gs://magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0_ns_wav_train.tfrecord-?????-of-00025', + 'train_subset': + 'gs://magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0_ns_wav_train.tfrecord-00004-of-00025', + 'validation': + 'gs://magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0_ns_wav_validation.tfrecord-?????-of-00025', + 'validation_subset': + 'gs://magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0_ns_wav_validation.tfrecord-0002?-of-00025', + 'test': + 'gs://magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0_ns_wav_test.tfrecord-?????-of-00025' + }, + features={ + 'audio': tf.io.FixedLenFeature([], dtype=tf.string), + 'sequence': tf.io.FixedLenFeature([], dtype=tf.string), + 'id': tf.io.FixedLenFeature([], dtype=tf.string) + }, + train_split='train', + train_eval_split='validation_subset', + infer_eval_splits=[ + InferEvalSplit(name='train', suffix='eval_train_full', + include_in_mixture=False), + InferEvalSplit(name='train_subset', suffix='eval_train'), + InferEvalSplit(name='validation', suffix='validation_full', + include_in_mixture=False), + 
InferEvalSplit(name='validation_subset', suffix='validation'), + InferEvalSplit(name='test', suffix='test', include_in_mixture=False) + ]) + + +GUITARSET_CONFIG = DatasetConfig( + name='guitarset', + paths={ + 'train': + 'gs://mt3/data/datasets/guitarset/train.tfrecord-?????-of-00019', + 'validation': + 'gs://mt3/data/datasets/guitarset/validation.tfrecord-?????-of-00006', + }, + features={ + 'sequence': tf.io.FixedLenFeature([], dtype=tf.string), + 'audio': tf.io.FixedLenFeature([], dtype=tf.string), + 'velocity_range': tf.io.FixedLenFeature([], dtype=tf.string), + 'id': tf.io.FixedLenFeature([], dtype=tf.string), + }, + train_split='train', + train_eval_split='validation', + infer_eval_splits=[ + InferEvalSplit(name='train', suffix='eval_train'), + InferEvalSplit(name='validation', suffix='validation'), + ]) + + +URMP_CONFIG = DatasetConfig( + name='urmp', + paths={ + 'train': 'gs://mt3/data/datasets/urmp/train.tfrecord', + 'validation': 'gs://mt3/data/datasets/urmp/validation.tfrecord', + }, + features={ + 'id': tf.io.FixedLenFeature([], dtype=tf.string), + 'tracks': tf.io.FixedLenSequenceFeature( + [], dtype=tf.int64, allow_missing=True), + 'inst_names': tf.io.FixedLenSequenceFeature( + [], dtype=tf.string, allow_missing=True), + 'audio': tf.io.FixedLenFeature([], dtype=tf.string), + 'sequence': tf.io.FixedLenFeature([], dtype=tf.string), + 'instrument_sequences': tf.io.FixedLenSequenceFeature( + [], dtype=tf.string, allow_missing=True), + }, + train_split='train', + train_eval_split='validation', + infer_eval_splits=[ + InferEvalSplit(name='train', suffix='eval_train'), + InferEvalSplit(name='validation', suffix='validation') + ]) + + +MUSICNET_CONFIG = DatasetConfig( + name='musicnet', + paths={ + 'train': + 'gs://mt3/data/datasets/musicnet/musicnet-train.tfrecord-?????-of-00036', + 'validation': + 'gs://mt3/data/datasets/musicnet/musicnet-validation.tfrecord-?????-of-00005', + 'test': + 'gs://mt3/data/datasets/musicnet/musicnet-test.tfrecord-?????-of-00003' + }, + features={ + 'id': tf.io.FixedLenFeature([], dtype=tf.string), + 'sample_rate': tf.io.FixedLenFeature([], dtype=tf.float32), + 'audio': tf.io.FixedLenSequenceFeature( + [], dtype=tf.float32, allow_missing=True), + 'sequence': tf.io.FixedLenFeature([], dtype=tf.string) + }, + train_split='train', + train_eval_split='validation', + infer_eval_splits=[ + InferEvalSplit(name='train', suffix='eval_train'), + InferEvalSplit(name='validation', suffix='validation'), + InferEvalSplit(name='test', suffix='test', include_in_mixture=False) + ]) + + +MUSICNET_EM_CONFIG = DatasetConfig( + name='musicnet_em', + paths={ + 'train': + 'gs://mt3/data/datasets/musicnet_em/train.tfrecord-?????-of-00103', + 'validation': + 'gs://mt3/data/datasets/musicnet_em/validation.tfrecord-?????-of-00005', + 'test': + 'gs://mt3/data/datasets/musicnet_em/test.tfrecord-?????-of-00006' + }, + features={ + 'id': tf.io.FixedLenFeature([], dtype=tf.string), + 'sample_rate': tf.io.FixedLenFeature([], dtype=tf.float32), + 'audio': tf.io.FixedLenSequenceFeature( + [], dtype=tf.float32, allow_missing=True), + 'sequence': tf.io.FixedLenFeature([], dtype=tf.string) + }, + train_split='train', + train_eval_split='validation', + infer_eval_splits=[ + InferEvalSplit(name='train', suffix='eval_train'), + InferEvalSplit(name='validation', suffix='validation'), + InferEvalSplit(name='test', suffix='test', include_in_mixture=False) + ]) + + +CERBERUS4_CONFIG = DatasetConfig( + name='cerberus4', + paths={ + 'train': + 
'gs://mt3/data/datasets/cerberus4/slakh_multi_cerberus_train_bass:drums:guitar:piano.tfrecord-?????-of-00286', + 'train_subset': + 'gs://mt3/data/datasets/cerberus4/slakh_multi_cerberus_train_bass:drums:guitar:piano.tfrecord-00000-of-00286', + 'validation': + 'gs://mt3/data/datasets/cerberus4/slakh_multi_cerberus_validation_bass:drums:guitar:piano.tfrecord-?????-of-00212', + 'validation_subset': + 'gs://mt3/data/datasets/cerberus4/slakh_multi_cerberus_validation_bass:drums:guitar:piano.tfrecord-0000?-of-00212', + 'test': + 'gs://mt3/data/datasets/cerberus4/slakh_multi_cerberus_test_bass:drums:guitar:piano.tfrecord-?????-of-00106' + }, + features={ + 'audio_sample_rate': tf.io.FixedLenFeature([], dtype=tf.int64), + 'inst_names': tf.io.FixedLenSequenceFeature( + [], dtype=tf.string, allow_missing=True), + 'midi_class': tf.io.FixedLenSequenceFeature( + [], dtype=tf.int64, allow_missing=True), + 'mix': tf.io.FixedLenSequenceFeature( + [], dtype=tf.float32, allow_missing=True), + 'note_sequences': tf.io.FixedLenSequenceFeature( + [], dtype=tf.string, allow_missing=True), + 'plugin_name': tf.io.FixedLenSequenceFeature( + [], dtype=tf.int64, allow_missing=True), + 'program_num': tf.io.FixedLenSequenceFeature( + [], dtype=tf.int64, allow_missing=True), + 'slakh_class': tf.io.FixedLenSequenceFeature( + [], dtype=tf.int64, allow_missing=True), + 'src_ids': tf.io.FixedLenSequenceFeature( + [], dtype=tf.string, allow_missing=True), + 'stems': tf.io.FixedLenSequenceFeature( + [], dtype=tf.float32, allow_missing=True), + 'stems_shape': tf.io.FixedLenFeature([2], dtype=tf.int64), + 'target_type': tf.io.FixedLenFeature([], dtype=tf.string), + 'track_id': tf.io.FixedLenFeature([], dtype=tf.string), + }, + train_split='train', + train_eval_split='validation_subset', + infer_eval_splits=[ + InferEvalSplit(name='train', suffix='eval_train_full', + include_in_mixture=False), + InferEvalSplit(name='train_subset', suffix='eval_train'), + InferEvalSplit(name='validation', suffix='validation_full', + include_in_mixture=False), + InferEvalSplit(name='validation_subset', suffix='validation'), + InferEvalSplit(name='test', suffix='test', include_in_mixture=False) + ], + track_specs=[ + note_sequences.TrackSpec('bass', program=32), + note_sequences.TrackSpec('drums', is_drum=True), + note_sequences.TrackSpec('guitar', program=24), + note_sequences.TrackSpec('piano', program=0) + ]) + + +SLAKH_CONFIG = DatasetConfig( + name='slakh', + paths={ + 'train': + 'gs://mt3/data/datasets/slakh/slakh_multi_full_subsets_10_train_all_inst.tfrecord-?????-of-02307', + 'train_subset': + 'gs://mt3/data/datasets/slakh/slakh_multi_full_subsets_10_train_all_inst.tfrecord-00000-of-02307', + 'validation': + 'gs://mt3/data/datasets/slakh/slakh_multi_full_validation_all_inst.tfrecord-?????-of-00168', + 'validation_subset': + 'gs://mt3/data/datasets/slakh/slakh_multi_full_validation_all_inst.tfrecord-0000?-of-00168', + 'test': + 'gs://mt3/data/datasets/slakh/slakh_multi_full_test_all_inst.tfrecord-?????-of-00109' + }, + features={ + 'audio_sample_rate': tf.io.FixedLenFeature([], dtype=tf.int64), + 'inst_names': tf.io.FixedLenSequenceFeature([], dtype=tf.string, + allow_missing=True), + 'midi_class': tf.io.FixedLenSequenceFeature([], dtype=tf.int64, + allow_missing=True), + 'mix': tf.io.FixedLenSequenceFeature([], dtype=tf.float32, + allow_missing=True), + 'note_sequences': tf.io.FixedLenSequenceFeature([], dtype=tf.string, + allow_missing=True), + 'plugin_name': tf.io.FixedLenSequenceFeature([], dtype=tf.int64, + allow_missing=True), + 
'program_num': tf.io.FixedLenSequenceFeature([], dtype=tf.int64, + allow_missing=True), + 'slakh_class': tf.io.FixedLenSequenceFeature([], dtype=tf.int64, + allow_missing=True), + 'src_ids': tf.io.FixedLenSequenceFeature([], dtype=tf.string, + allow_missing=True), + 'stems': tf.io.FixedLenSequenceFeature([], dtype=tf.float32, + allow_missing=True), + 'stems_shape': tf.io.FixedLenFeature([2], dtype=tf.int64), + 'target_type': tf.io.FixedLenFeature([], dtype=tf.string), + 'track_id': tf.io.FixedLenFeature([], dtype=tf.string), + }, + train_split='train', + train_eval_split='validation_subset', + infer_eval_splits=[ + InferEvalSplit(name='train', suffix='eval_train_full', + include_in_mixture=False), + InferEvalSplit(name='train_subset', suffix='eval_train'), + InferEvalSplit(name='validation', suffix='validation_full', + include_in_mixture=False), + InferEvalSplit(name='validation_subset', suffix='validation'), + InferEvalSplit(name='test', suffix='test', include_in_mixture=False) + ]) + + diff --git a/mt3/event_codec.py b/mt3/event_codec.py new file mode 100644 index 0000000000000000000000000000000000000000..4486f6b7b298341cf1899d832c07bbf0fc83e5d3 --- /dev/null +++ b/mt3/event_codec.py @@ -0,0 +1,112 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Encode and decode events.""" + +import dataclasses +from typing import List, Tuple + + +@dataclasses.dataclass +class EventRange: + type: str + min_value: int + max_value: int + + +@dataclasses.dataclass +class Event: + type: str + value: int + + +class Codec: + """Encode and decode events. + + Useful for declaring what certain ranges of a vocabulary should be used for. + This is intended to be used from Python before encoding or after decoding with + GenericTokenVocabulary. This class is more lightweight and does not include + things like EOS or UNK token handling. + + To ensure that 'shift' events are always the first block of the vocab and + start at 0, that event type is required and specified separately. + """ + + def __init__(self, max_shift_steps: int, steps_per_second: float, + event_ranges: List[EventRange]): + """Define Codec. + + Args: + max_shift_steps: Maximum number of shift steps that can be encoded. + steps_per_second: Shift steps will be interpreted as having a duration of + 1 / steps_per_second. + event_ranges: Other supported event types and their ranges. + """ + self.steps_per_second = steps_per_second + self._shift_range = EventRange( + type='shift', min_value=0, max_value=max_shift_steps) + self._event_ranges = [self._shift_range] + event_ranges + # Ensure all event types have unique names. + assert len(self._event_ranges) == len( + set([er.type for er in self._event_ranges])) + + @property + def num_classes(self) -> int: + return sum(er.max_value - er.min_value + 1 for er in self._event_ranges) + + # The next couple methods are simplified special case methods just for shift + # events that are intended to be used from within autograph functions. 
+ + def is_shift_event_index(self, index: int) -> bool: + return (self._shift_range.min_value <= index) and ( + index <= self._shift_range.max_value) + + @property + def max_shift_steps(self) -> int: + return self._shift_range.max_value + + def encode_event(self, event: Event) -> int: + """Encode an event to an index.""" + offset = 0 + for er in self._event_ranges: + if event.type == er.type: + if not er.min_value <= event.value <= er.max_value: + raise ValueError( + f'Event value {event.value} is not within valid range ' + f'[{er.min_value}, {er.max_value}] for type {event.type}') + return offset + event.value - er.min_value + offset += er.max_value - er.min_value + 1 + + raise ValueError(f'Unknown event type: {event.type}') + + def event_type_range(self, event_type: str) -> Tuple[int, int]: + """Return [min_id, max_id] for an event type.""" + offset = 0 + for er in self._event_ranges: + if event_type == er.type: + return offset, offset + (er.max_value - er.min_value) + offset += er.max_value - er.min_value + 1 + + raise ValueError(f'Unknown event type: {event_type}') + + def decode_event_index(self, index: int) -> Event: + """Decode an event index to an Event.""" + offset = 0 + for er in self._event_ranges: + if offset <= index <= offset + er.max_value - er.min_value: + return Event( + type=er.type, value=er.min_value + index - offset) + offset += er.max_value - er.min_value + 1 + + raise ValueError(f'Unknown event index: {index}') diff --git a/mt3/event_codec_test.py b/mt3/event_codec_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3d88269b39da933402100f27f651cf3c800ac9da --- /dev/null +++ b/mt3/event_codec_test.py @@ -0,0 +1,55 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
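Before the tests, a brief illustration of how the `Codec` in `event_codec.py` lays out its index space: 'shift' always occupies the first block starting at 0, and each subsequent `EventRange` gets the next contiguous block. A minimal sketch, separate from the patch itself, whose index values mirror the unit test below:

```python
# Illustrative Codec usage; the index values mirror the encode/decode test below.
from mt3 import event_codec

codec = event_codec.Codec(
    max_shift_steps=100,
    steps_per_second=100,
    event_ranges=[event_codec.EventRange('pitch', min_value=0, max_value=127)])

# 'shift' covers indices 0..100 (101 values), so 'pitch' starts at offset 101.
assert codec.encode_event(event_codec.Event(type='shift', value=5)) == 5
assert codec.encode_event(event_codec.Event(type='pitch', value=60)) == 161
assert codec.decode_event_index(163) == event_codec.Event(type='pitch', value=62)
assert codec.num_classes == 101 + 128
```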
+ +"""Tests for event_codec.""" + +from absl.testing import absltest +from mt3 import event_codec + +Event = event_codec.Event +EventRange = event_codec.EventRange + + +class EventCodecTest(absltest.TestCase): + + def test_encode_decode(self): + ec = event_codec.Codec( + max_shift_steps=100, + steps_per_second=100, + event_ranges=[EventRange('pitch', min_value=0, max_value=127)]) + events = [ + Event(type='pitch', value=60), + Event(type='shift', value=5), + Event(type='pitch', value=62), + ] + encoded = [ec.encode_event(e) for e in events] + self.assertSequenceEqual([161, 5, 163], encoded) + + decoded = [ec.decode_event_index(idx) for idx in encoded] + self.assertSequenceEqual(events, decoded) + + def test_shift_steps(self): + ec = event_codec.Codec( + max_shift_steps=100, + steps_per_second=100, + event_ranges=[EventRange('pitch', min_value=0, max_value=127)]) + + self.assertEqual(100, ec.max_shift_steps) + self.assertFalse(ec.is_shift_event_index(-1)) + self.assertTrue(ec.is_shift_event_index(0)) + self.assertTrue(ec.is_shift_event_index(100)) + self.assertFalse(ec.is_shift_event_index(101)) + +if __name__ == '__main__': + absltest.main() diff --git a/mt3/gin/eval.gin b/mt3/gin/eval.gin new file mode 100644 index 0000000000000000000000000000000000000000..0671bbdf81cf5b306b2273bddf8becd815d034f4 --- /dev/null +++ b/mt3/gin/eval.gin @@ -0,0 +1,72 @@ +# Defaults for eval.py. +# +# You must also include a binding for MODEL. +# +# Required to be set: +# +# - TASK_PREFIX +# - TASK_FEATURE_LENGTHS +# - CHECKPOINT_PATH +# - EVAL_OUTPUT_DIR +# +# Commonly overridden options: +# +# - DatasetConfig.split +# - DatasetConfig.batch_size +# - DatasetConfig.use_cached +# - RestoreCheckpointConfig.mode +# - PjitPartitioner.num_partitions + +from __gin__ import dynamic_registration + +import __main__ as eval_script +from mt3 import preprocessors +from mt3 import tasks +from mt3 import vocabularies +from t5x import partitioning +from t5x import utils + +# Must be overridden +TASK_PREFIX = %gin.REQUIRED +TASK_FEATURE_LENGTHS = %gin.REQUIRED +CHECKPOINT_PATH = %gin.REQUIRED +EVAL_OUTPUT_DIR = %gin.REQUIRED + +# Number of velocity bins: set to 1 (no velocity) or 127 +NUM_VELOCITY_BINS = %gin.REQUIRED +VOCAB_CONFIG = @vocabularies.VocabularyConfig() +vocabularies.VocabularyConfig.num_velocity_bins = %NUM_VELOCITY_BINS + +# Program granularity: set to 'flat', 'midi_class', or 'full' +PROGRAM_GRANULARITY = %gin.REQUIRED +preprocessors.map_midi_programs.granularity_type = %PROGRAM_GRANULARITY + +TASK_SUFFIX = 'test' +tasks.construct_task_name: + task_prefix = %TASK_PREFIX + vocab_config = %VOCAB_CONFIG + task_suffix = %TASK_SUFFIX + +eval_script.evaluate: + model = %MODEL # imported from separate gin file + dataset_cfg = @utils.DatasetConfig() + partitioner = @partitioning.PjitPartitioner() + restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() + output_dir = %EVAL_OUTPUT_DIR + +utils.DatasetConfig: + mixture_or_task_name = @tasks.construct_task_name() + task_feature_lengths = %TASK_FEATURE_LENGTHS + split = 'eval' + batch_size = 32 + shuffle = False + seed = 42 + use_cached = True + pack = False + use_custom_packing_ops = False + +partitioning.PjitPartitioner.num_partitions = 1 + +utils.RestoreCheckpointConfig: + path = %CHECKPOINT_PATH + mode = 'specific' diff --git a/mt3/gin/infer.gin b/mt3/gin/infer.gin new file mode 100644 index 0000000000000000000000000000000000000000..9e5b478a54a64518d15bbd9f188fa68aec93b0de --- /dev/null +++ b/mt3/gin/infer.gin @@ -0,0 +1,92 @@ +# Defaults for infer.py. 
+# +# You must also include a binding for MODEL. +# +# Required to be set: +# +# - TASK_PREFIX +# - TASK_FEATURE_LENGTHS +# - CHECKPOINT_PATH +# - INFER_OUTPUT_DIR +# +# Commonly overridden options: +# +# - infer.mode +# - infer.checkpoint_period +# - infer.shard_id +# - infer.num_shards +# - DatasetConfig.split +# - DatasetConfig.batch_size +# - DatasetConfig.use_cached +# - RestoreCheckpointConfig.is_tensorflow +# - RestoreCheckpointConfig.mode +# - PjitPartitioner.num_partitions + +from __gin__ import dynamic_registration + +import __main__ as infer_script +from mt3 import inference +from mt3 import preprocessors +from mt3 import tasks +from mt3 import vocabularies +from t5x import partitioning +from t5x import utils + +# Must be overridden +TASK_PREFIX = %gin.REQUIRED +TASK_FEATURE_LENGTHS = %gin.REQUIRED +CHECKPOINT_PATH = %gin.REQUIRED +INFER_OUTPUT_DIR = %gin.REQUIRED + +# Number of velocity bins: set to 1 (no velocity) or 127 +NUM_VELOCITY_BINS = %gin.REQUIRED +VOCAB_CONFIG = @vocabularies.VocabularyConfig() +vocabularies.VocabularyConfig.num_velocity_bins = %NUM_VELOCITY_BINS + +# Program granularity: set to 'flat', 'midi_class', or 'full' +PROGRAM_GRANULARITY = %gin.REQUIRED +preprocessors.map_midi_programs.granularity_type = %PROGRAM_GRANULARITY + +TASK_SUFFIX = 'test' +tasks.construct_task_name: + task_prefix = %TASK_PREFIX + vocab_config = %VOCAB_CONFIG + task_suffix = %TASK_SUFFIX + +ONSETS_ONLY = %gin.REQUIRED +USE_TIES = %gin.REQUIRED +inference.write_inferences_to_file: + vocab_config = %VOCAB_CONFIG + onsets_only = %ONSETS_ONLY + use_ties = %USE_TIES + +infer_script.infer: + mode = 'predict' + model = %MODEL # imported from separate gin file + output_dir = %INFER_OUTPUT_DIR + dataset_cfg = @utils.DatasetConfig() + partitioner = @partitioning.PjitPartitioner() + restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() + # This is a hack, but pass an extremely large value here to make sure the + # entire dataset fits in a single epoch. Otherwise, segments from a single + # example may end up in different epochs after splitting. + checkpoint_period = 1000000 + shard_id = 0 + num_shards = 1 + write_fn = @inference.write_inferences_to_file + +utils.DatasetConfig: + mixture_or_task_name = @tasks.construct_task_name() + task_feature_lengths = %TASK_FEATURE_LENGTHS + use_cached = True + split = 'eval' + batch_size = 32 + shuffle = False + seed = 0 + pack = False + +partitioning.PjitPartitioner.num_partitions = 1 + +utils.RestoreCheckpointConfig: + path = %CHECKPOINT_PATH + mode = 'specific' diff --git a/mt3/gin/ismir2021.gin b/mt3/gin/ismir2021.gin new file mode 100644 index 0000000000000000000000000000000000000000..7e39b75862dd051f3212766730ef9977185cc72f --- /dev/null +++ b/mt3/gin/ismir2021.gin @@ -0,0 +1,9 @@ +# Configuration for ISMIR 2021 piano-only model. + +TASK_PREFIX = 'maestrov3_notes' +TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 1024} +TRAIN_STEPS = 400000 +NUM_VELOCITY_BINS = 127 +PROGRAM_GRANULARITY = 'flat' +ONSETS_ONLY = False +USE_TIES = False diff --git a/mt3/gin/ismir2022/base.gin b/mt3/gin/ismir2022/base.gin new file mode 100644 index 0000000000000000000000000000000000000000..e6d2c25edf73405f53a81b675aaaec0219c5a61d --- /dev/null +++ b/mt3/gin/ismir2022/base.gin @@ -0,0 +1,10 @@ +# T5.1.1 Base model. 
+include 'model.gin' + +network.T5Config: + emb_dim = 768 + num_heads = 12 + num_encoder_layers = 12 + num_decoder_layers = 12 + head_dim = 64 + mlp_dim = 2048 \ No newline at end of file diff --git a/mt3/gin/ismir2022/finetune.gin b/mt3/gin/ismir2022/finetune.gin new file mode 100644 index 0000000000000000000000000000000000000000..de4d46ed98079d7de34edc2d6111be44b286bd27 --- /dev/null +++ b/mt3/gin/ismir2022/finetune.gin @@ -0,0 +1,25 @@ +from __gin__ import dynamic_registration + +from mt3 import network +from t5x import utils + +include 'train.gin' + +TASK_PREFIX = 'mega_notes_ties' +TASK_FEATURE_LENGTHS = {'inputs': 256, 'targets': 1024} +TRAIN_STEPS = 150000 +BATCH_SIZE = 256 +LABEL_SMOOTHING = 0.0 +NUM_VELOCITY_BINS = 1 +PROGRAM_GRANULARITY = 'full' +ONSETS_ONLY = False +USE_TIES = True +MAX_EXAMPLES_PER_MIX = None + +network.T5Config.dropout_rate = 0.1 + +CHECKPOINT_PATH = %gin.REQUIRED +utils.CheckpointConfig.restore = @utils.RestoreCheckpointConfig() +utils.RestoreCheckpointConfig: + path = %CHECKPOINT_PATH + mode = 'specific' \ No newline at end of file diff --git a/mt3/gin/ismir2022/pretrain.gin b/mt3/gin/ismir2022/pretrain.gin new file mode 100644 index 0000000000000000000000000000000000000000..47ddc24452dd4dc6c0d95a4d9bade19b48eee77e --- /dev/null +++ b/mt3/gin/ismir2022/pretrain.gin @@ -0,0 +1,13 @@ +include 'train.gin' + +TASK_FEATURE_LENGTHS = {'inputs': 256, 'targets': 1024} +TRAIN_STEPS = 500000 +BATCH_SIZE = 1024 +LABEL_SMOOTHING = 0.1 +NUM_VELOCITY_BINS = 1 +PROGRAM_GRANULARITY = 'full' +ONSETS_ONLY = False +USE_TIES = True +MAX_EXAMPLES_PER_MIX = 8 + +network.T5Config.dropout_rate = 0.0 diff --git a/mt3/gin/ismir2022/small.gin b/mt3/gin/ismir2022/small.gin new file mode 100644 index 0000000000000000000000000000000000000000..ef289b4684a5524a89f9a5becd40ca3d6e41d474 --- /dev/null +++ b/mt3/gin/ismir2022/small.gin @@ -0,0 +1,2 @@ +# T5.1.1 Small model. +include 'model.gin' diff --git a/mt3/gin/local_tiny.gin b/mt3/gin/local_tiny.gin new file mode 100644 index 0000000000000000000000000000000000000000..2533dce23681df9b89f2a8354f154b1d1b846a9c --- /dev/null +++ b/mt3/gin/local_tiny.gin @@ -0,0 +1,63 @@ +# A gin file to make the Transformer models tiny for faster local testing. +# +# When testing locally with CPU, there are a few things that we need. +# - tiny model size +# - small enough batch size +# - small sequence length +# - deterministic dataset pipeline +# +# This gin file adds such configs. To use this gin file, add it on top of the +# existing full-scale gin files. The ordering of the gin files matters, so this +# file should be added after all the others in order to override the same +# configurables.
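As a side note on the ordering requirement described above: gin applies bindings in parse order, so a binding parsed later overrides an earlier one for the same parameter. A self-contained sketch of that behavior using plain gin-config (the `build_network` function is hypothetical and not part of MT3; the real configs additionally go through t5x's dynamic registration):

```python
# Hypothetical sketch: later-parsed gin bindings win, which is why a tiny
# override file such as local_tiny.gin must be listed after the full-scale
# gin files it overrides.
import gin


@gin.configurable
def build_network(emb_dim=512, num_heads=6):
  return {'emb_dim': emb_dim, 'num_heads': num_heads}


gin.parse_config('build_network.emb_dim = 768\nbuild_network.num_heads = 12')
gin.parse_config('build_network.emb_dim = 8\nbuild_network.num_heads = 4')

assert build_network() == {'emb_dim': 8, 'num_heads': 4}
```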
+ +from __gin__ import dynamic_registration + +from t5x import partitioning +from t5x import trainer +from t5x import utils +from t5x.examples.t5 import network + +import __main__ as train_script + +train_script.train.random_seed = 42 # dropout seed +train/utils.DatasetConfig.seed = 42 # dataset seed + +TASK_FEATURE_LENGTHS = {"inputs": 8, "targets": 16} +LABEL_SMOOTHING = 0.0 + +# Network specification overrides +network.Transformer.config = @network.T5Config() +network.T5Config: + dtype = 'float32' + emb_dim = 8 + num_heads = 4 + num_encoder_layers = 2 + num_decoder_layers = 2 + head_dim = 3 + mlp_dim = 16 + mlp_activations = ('gelu', 'linear') + dropout_rate = 0.0 + logits_via_embedding = False + +TRAIN_STEPS = 3 + +train/utils.DatasetConfig: + batch_size = 8 + shuffle = False + +train_eval/utils.DatasetConfig.batch_size = 8 + +train_script.train: + eval_period = 3 + eval_steps = 3 + +trainer.Trainer.num_microbatches = 0 +partitioning.PjitPartitioner: + num_partitions = 1 + model_parallel_submesh = None + +utils.CheckpointConfig: + restore = None + +infer_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS diff --git a/mt3/gin/model.gin b/mt3/gin/model.gin new file mode 100644 index 0000000000000000000000000000000000000000..e5d18efdadb2f4c61a5ae28650eb86eca4d282c2 --- /dev/null +++ b/mt3/gin/model.gin @@ -0,0 +1,60 @@ +# T5.1.1 Small model. +from __gin__ import dynamic_registration + +from mt3 import models +from mt3 import network +from mt3 import spectrograms +from mt3 import vocabularies +import seqio +from t5x import adafactor + +# ------------------- Loss HParam ---------------------------------------------- +Z_LOSS = 0.0001 +LABEL_SMOOTHING = 0.0 +LOSS_NORMALIZING_FACTOR = None +models.ContinuousInputsEncoderDecoderModel: + z_loss = %Z_LOSS + label_smoothing = %LABEL_SMOOTHING + loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR + +# Output vocabulary +VOCAB_CONFIG = %gin.REQUIRED +OUTPUT_VOCABULARY = @vocabularies.vocabulary_from_codec() +vocabularies.vocabulary_from_codec.codec = @vocabularies.build_codec() +vocabularies.build_codec.vocab_config = %VOCAB_CONFIG + +# ------------------- Optimizer ------------------------------------------------ +# `learning_rate` is set by `Trainer.learning_rate_fn`. 
+OPTIMIZER = @adafactor.Adafactor() +adafactor.Adafactor: + decay_rate = 0.8 + step_offset = 0 + logical_factor_rules = @adafactor.standard_logical_factor_rules() + +# ------------------- Model ---------------------------------------------------- +SPECTROGRAM_CONFIG = @spectrograms.SpectrogramConfig() +MODEL = @models.ContinuousInputsEncoderDecoderModel() +models.ContinuousInputsEncoderDecoderModel: + module = @network.Transformer() + input_vocabulary = @seqio.vocabularies.PassThroughVocabulary() + output_vocabulary = %OUTPUT_VOCABULARY + optimizer_def = %OPTIMIZER + input_depth = @spectrograms.input_depth() +seqio.vocabularies.PassThroughVocabulary.size = 0 +spectrograms.input_depth.spectrogram_config = %SPECTROGRAM_CONFIG + +# ------------------- Network specification ------------------------------------ +network.Transformer.config = @network.T5Config() +network.T5Config: + vocab_size = @vocabularies.num_embeddings() + dtype = 'float32' + emb_dim = 512 + num_heads = 6 + num_encoder_layers = 8 + num_decoder_layers = 8 + head_dim = 64 + mlp_dim = 1024 + mlp_activations = ('gelu', 'linear') + dropout_rate = 0.1 + logits_via_embedding = False +vocabularies.num_embeddings.vocabulary = %OUTPUT_VOCABULARY diff --git a/mt3/gin/mt3.gin b/mt3/gin/mt3.gin new file mode 100644 index 0000000000000000000000000000000000000000..e7b44f4913a5c4a51310ab2c208e311d8639f1f2 --- /dev/null +++ b/mt3/gin/mt3.gin @@ -0,0 +1,9 @@ +# Configuration for MT3 multi-task multitrack model. + +TASK_PREFIX = 'mega_notes_ties' +TASK_FEATURE_LENGTHS = {'inputs': 256, 'targets': 1024} +TRAIN_STEPS = 1000000 +NUM_VELOCITY_BINS = 1 +PROGRAM_GRANULARITY = 'full' +ONSETS_ONLY = False +USE_TIES = True diff --git a/mt3/gin/train.gin b/mt3/gin/train.gin new file mode 100644 index 0000000000000000000000000000000000000000..ffcbb6823fae8bf2f3e7fd40752b575a34693b20 --- /dev/null +++ b/mt3/gin/train.gin @@ -0,0 +1,148 @@ +# Defaults for training with train.py. +# +# You must also include a binding for MODEL. +# +# Required to be set: +# +# - TASK_PREFIX +# - TASK_FEATURE_LENGTHS +# - TRAIN_STEPS +# - MODEL_DIR +# +# Commonly overridden options: +# - BATCH_SIZE +# - PjitPartitioner.num_partitions +# - Trainer.num_microbatches +# - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess +# on the fly. + +from __gin__ import dynamic_registration + +import __main__ as train_script +import seqio +from mt3 import mixing +from mt3 import preprocessors +from mt3 import tasks +from mt3 import vocabularies +from t5x import gin_utils +from t5x import partitioning +from t5x import utils +from t5x import trainer + +# Must be overridden +TASK_PREFIX = %gin.REQUIRED +TASK_FEATURE_LENGTHS = %gin.REQUIRED +TRAIN_STEPS = %gin.REQUIRED +MODEL_DIR = %gin.REQUIRED + +# Commonly overridden +TRAIN_TASK_SUFFIX = 'train' +EVAL_TASK_SUFFIX = 'eval' +USE_CACHED_TASKS = True +BATCH_SIZE = 256 + +# Sometimes overridden +EVAL_STEPS = 20 + +# Convenience overrides. +EVALUATOR_USE_MEMORY_CACHE = True +EVALUATOR_NUM_EXAMPLES = None # Use all examples in the infer_eval dataset. +JSON_WRITE_N_RESULTS = 0 # Don't write any inferences. 
+ +# Number of velocity bins: set to 1 (no velocity) or 127 +NUM_VELOCITY_BINS = %gin.REQUIRED +VOCAB_CONFIG = @vocabularies.VocabularyConfig() +vocabularies.VocabularyConfig.num_velocity_bins = %NUM_VELOCITY_BINS + +# Program granularity: set to 'flat', 'midi_class', or 'full' +PROGRAM_GRANULARITY = %gin.REQUIRED +preprocessors.map_midi_programs.granularity_type = %PROGRAM_GRANULARITY + +# Maximum number of examples per mix, or None for no mixing +MAX_EXAMPLES_PER_MIX = None +mixing.mix_transcription_examples.max_examples_per_mix = %MAX_EXAMPLES_PER_MIX + +train/tasks.construct_task_name: + task_prefix = %TASK_PREFIX + vocab_config = %VOCAB_CONFIG + task_suffix = %TRAIN_TASK_SUFFIX + +eval/tasks.construct_task_name: + task_prefix = %TASK_PREFIX + vocab_config = %VOCAB_CONFIG + task_suffix = %EVAL_TASK_SUFFIX + +train_script.train: + model = %MODEL # imported from separate gin file + model_dir = %MODEL_DIR + train_dataset_cfg = @train/utils.DatasetConfig() + train_eval_dataset_cfg = @train_eval/utils.DatasetConfig() + infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig() + checkpoint_cfg = @utils.CheckpointConfig() + partitioner = @partitioning.PjitPartitioner() + trainer_cls = @trainer.Trainer + total_steps = %TRAIN_STEPS + eval_steps = %EVAL_STEPS + eval_period = 5000 + random_seed = None # use faster, hardware RNG + summarize_config_fn = @gin_utils.summarize_gin_config + inference_evaluator_cls = @seqio.Evaluator + +seqio.Evaluator: + logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] + num_examples = %EVALUATOR_NUM_EXAMPLES + use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE + +seqio.JSONLogger: + write_n_results = %JSON_WRITE_N_RESULTS + +train/utils.DatasetConfig: + mixture_or_task_name = @train/tasks.construct_task_name() + task_feature_lengths = %TASK_FEATURE_LENGTHS + split = 'train' + batch_size = %BATCH_SIZE + shuffle = True + seed = None # use a new seed each run/restart + use_cached = %USE_CACHED_TASKS + pack = False + +train_eval/utils.DatasetConfig: + mixture_or_task_name = @train/tasks.construct_task_name() + task_feature_lengths = %TASK_FEATURE_LENGTHS + split = 'eval' + batch_size = %BATCH_SIZE + shuffle = False + seed = 42 + use_cached = %USE_CACHED_TASKS + pack = False + +infer_eval/utils.DatasetConfig: + mixture_or_task_name = @eval/tasks.construct_task_name() + task_feature_lengths = %TASK_FEATURE_LENGTHS + split = 'eval' + batch_size = %BATCH_SIZE + shuffle = False + seed = 42 + use_cached = %USE_CACHED_TASKS + pack = False + +utils.CheckpointConfig: + restore = None + save = @utils.SaveCheckpointConfig() +utils.SaveCheckpointConfig: + period = 5000 + dtype = 'float32' + keep = None # keep all checkpoints + save_dataset = False # don't checkpoint dataset state + +partitioning.PjitPartitioner: + num_partitions = 1 + model_parallel_submesh = None + +trainer.Trainer: + num_microbatches = None + learning_rate_fn = @utils.create_learning_rate_scheduler() +utils.create_learning_rate_scheduler: + factors = 'constant' + base_learning_rate = 0.001 + warmup_steps = 1000 diff --git a/mt3/inference.py b/mt3/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..f63b0353f63be3162144bd70a381d7c36aad8097 --- /dev/null +++ b/mt3/inference.py @@ -0,0 +1,138 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functions for MT3 inference.""" + +import functools +import json + +from typing import Any, Optional, Sequence + +import gin + +from mt3 import metrics_utils +from mt3 import note_sequences +from mt3 import tasks +from mt3 import vocabularies + +import note_seq +import seqio +import tensorflow as tf + + +def write_inferences_to_file( + path: str, + inferences: Sequence[Any], + task_ds: tf.data.Dataset, + mode: str, + vocabulary: Optional[seqio.Vocabulary] = None, + vocab_config=gin.REQUIRED, + onsets_only=gin.REQUIRED, + use_ties=gin.REQUIRED) -> None: + """Writes model predictions, ground truth transcriptions, and input audio. + + For now this only works for transcription tasks with ties. + + Args: + path: File path to write to. + inferences: Model inferences, output of predict_batch. + task_ds: Original task dataset. + mode: Prediction mode; must be 'predict' as 'score' is not supported. + vocabulary: Task output vocabulary. + vocab_config: Vocabulary config object. + onsets_only: If True, only predict onsets. + use_ties: If True, use "tie" representation. + """ + if mode == 'score': + raise ValueError('`score` mode currently not supported in MT3') + if not vocabulary: + raise ValueError('`vocabulary` parameter required in `predict` mode') + + if onsets_only and use_ties: + raise ValueError('ties not compatible with onset-only transcription') + if onsets_only: + encoding_spec = note_sequences.NoteOnsetEncodingSpec + elif not use_ties: + encoding_spec = note_sequences.NoteEncodingSpec + else: + encoding_spec = note_sequences.NoteEncodingWithTiesSpec + + codec = vocabularies.build_codec(vocab_config) + + targets = [] + predictions = [] + + for inp, output in zip(task_ds.as_numpy_iterator(), inferences): + tokens = tasks.trim_eos(vocabulary.decode_tf(output).numpy()) + + start_time = inp['input_times'][0] + # Round down to nearest symbolic token step. + start_time -= start_time % (1 / codec.steps_per_second) + + targets.append({ + 'unique_id': inp['unique_id'][0], + 'ref_ns': inp['sequence'][0] if inp['sequence'][0] else None, + }) + + predictions.append({ + 'unique_id': inp['unique_id'][0], + 'est_tokens': tokens, + 'start_time': start_time, + # Input audio is not part of the "prediction" but the below call to + # metrics_utils.event_predictions_to_ns handles the concatenation. + 'raw_inputs': inp['raw_inputs'] + }) + + # The first target for each full example contains the NoteSequence; just + # organize by ID. 
+ full_targets = {} + for target in targets: + if target['ref_ns']: + full_targets[target['unique_id']] = { + 'ref_ns': note_seq.NoteSequence.FromString(target['ref_ns']) + } + + full_predictions = metrics_utils.combine_predictions_by_id( + predictions=predictions, + combine_predictions_fn=functools.partial( + metrics_utils.event_predictions_to_ns, + codec=codec, + encoding_spec=encoding_spec)) + + assert sorted(full_targets.keys()) == sorted(full_predictions.keys()) + + full_target_prediction_pairs = [ + (full_targets[id], full_predictions[id]) + for id in sorted(full_targets.keys()) + ] + + def note_to_dict(note): + return { + 'start_time': note.start_time, + 'end_time': note.end_time, + 'pitch': note.pitch, + 'velocity': note.velocity, + 'program': note.program, + 'is_drum': note.is_drum + } + + with tf.io.gfile.GFile(path, 'w') as f: + for target, prediction in full_target_prediction_pairs: + json_dict = { + 'id': target['ref_ns'].id, + 'est_notes': + [note_to_dict(note) for note in prediction['est_ns'].notes] + } + json_str = json.dumps(json_dict, cls=seqio.TensorAndNumpyEncoder) + f.write(json_str + '\n') diff --git a/mt3/layers.py b/mt3/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..25ff3f4de7b8b13254fb4baabb9ef7c0a30f3c81 --- /dev/null +++ b/mt3/layers.py @@ -0,0 +1,830 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dense attention classes and mask/weighting functions.""" + +# pylint: disable=attribute-defined-outside-init,g-bare-generic + +import dataclasses +import functools +import operator +from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union + +from flax import linen as nn +from flax.linen import partitioning as nn_partitioning +import jax +from jax import lax +from jax import random +import jax.numpy as jnp +import numpy as np + + +# from flax.linen.partitioning import param_with_axes, with_sharding_constraint +param_with_axes = nn_partitioning.param_with_axes +with_sharding_constraint = nn_partitioning.with_sharding_constraint + + +# Type annotations +Array = jnp.ndarray +DType = jnp.dtype +PRNGKey = jnp.ndarray +Shape = Iterable[int] +Activation = Callable[..., Array] +# Parameter initializers. +Initializer = Callable[[PRNGKey, Shape, DType], Array] + +default_embed_init = nn.initializers.variance_scaling( + 1.0, 'fan_in', 'normal', out_axis=0) + + +def sinusoidal(min_scale: float = 1.0, + max_scale: float = 10000.0, + dtype: DType = jnp.float32) -> Initializer: + """Creates 1D Sinusoidal Position Embedding Initializer. + + Args: + min_scale: Minimum frequency-scale in sine grating. + max_scale: Maximum frequency-scale in sine grating. + dtype: The DType of the returned values. + + Returns: + The sinusoidal initialization function. 
+ """ + + def init(key: PRNGKey, shape: Shape, dtype: DType = dtype) -> Array: + """Sinusoidal init.""" + del key + if dtype != np.float32: + raise ValueError('The sinusoidal initializer only supports float32.') + if len(list(shape)) != 2: + raise ValueError( + f'Expected a 2D shape (max_len, features), but got {shape}.') + max_len, features = shape + pe = np.zeros((max_len, features), dtype=dtype) + position = np.arange(0, max_len)[:, np.newaxis] + scale_factor = -np.log(max_scale / min_scale) / (features // 2 - 1) + div_term = min_scale * np.exp(np.arange(0, features // 2) * scale_factor) + pe[:, :features // 2] = np.sin(position * div_term) + pe[:, features // 2:2 * (features // 2)] = np.cos(position * div_term) + return jnp.array(pe) + + return init + + +def dot_product_attention(query: Array, + key: Array, + value: Array, + bias: Optional[Array] = None, + dropout_rng: Optional[PRNGKey] = None, + dropout_rate: float = 0., + deterministic: bool = False, + dtype: DType = jnp.float32, + float32_logits: bool = False): + """Computes dot-product attention given query, key, and value. + + This is the core function for applying attention based on + https://arxiv.org/abs/1706.03762. It calculates the attention weights given + query and key and combines the values using the attention weights. + + Args: + query: queries for calculating attention with shape of `[batch, q_length, + num_heads, qk_depth_per_head]`. + key: keys for calculating attention with shape of `[batch, kv_length, + num_heads, qk_depth_per_head]`. + value: values to be used in attention with shape of `[batch, kv_length, + num_heads, v_depth_per_head]`. + bias: bias for the attention weights. This should be broadcastable to the + shape `[batch, num_heads, q_length, kv_length]` This can be used for + incorporating causal masks, padding masks, proximity bias, etc. + dropout_rng: JAX PRNGKey: to be used for dropout + dropout_rate: dropout rate + deterministic: bool, deterministic or not (to apply dropout) + dtype: the dtype of the computation (default: float32) + float32_logits: bool, if True then compute logits in float32 to avoid + numerical issues with bfloat16. + + Returns: + Output of shape `[batch, length, num_heads, v_depth_per_head]`. + """ + assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' + assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], ( + 'q, k, v batch dims must match.') + assert query.shape[-2] == key.shape[-2] == value.shape[-2], ( + 'q, k, v num_heads must match.') + assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' + assert query.shape[-1] == key.shape[-1], 'q, k depths must match.' + + # Casting logits and softmax computation for float32 for model stability. + if float32_logits: + query = query.astype(jnp.float32) + key = key.astype(jnp.float32) + + # `attn_weights`: [batch, num_heads, q_length, kv_length] + attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key) + + # Apply attention bias: masking, dropout, proximity bias, etc. + if bias is not None: + attn_weights = attn_weights + bias.astype(attn_weights.dtype) + + # Normalize the attention weights across `kv_length` dimension. + attn_weights = jax.nn.softmax(attn_weights).astype(dtype) + + # Apply attention dropout. + if not deterministic and dropout_rate > 0.: + keep_prob = 1.0 - dropout_rate + # T5 broadcasts along the "length" dim, but unclear which one that + # corresponds to in positional dimensions here, assuming query dim. 
+ dropout_shape = list(attn_weights.shape) + dropout_shape[-2] = 1 + keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) + keep = jnp.broadcast_to(keep, attn_weights.shape) + multiplier = ( + keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype)) + attn_weights = attn_weights * multiplier + + # Take the linear combination of `value`. + return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value) + + +dynamic_vector_slice_in_dim = jax.vmap( + lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) + + +class MultiHeadDotProductAttention(nn.Module): + """Multi-head dot-product attention. + + Attributes: + num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) + should be divisible by the number of heads. + head_dim: dimension of each head. + dtype: the dtype of the computation. + dropout_rate: dropout rate + kernel_init: initializer for the kernel of the Dense layers. + float32_logits: bool, if True then compute logits in float32 to avoid + numerical issues with bfloat16. + """ + + num_heads: int + head_dim: int + dtype: DType = jnp.float32 + dropout_rate: float = 0. + kernel_init: Initializer = nn.initializers.variance_scaling( + 1.0, 'fan_in', 'normal') + float32_logits: bool = False # computes logits in float32 for stability. + + @nn.compact + def __call__(self, + inputs_q: Array, + inputs_kv: Array, + mask: Optional[Array] = None, + bias: Optional[Array] = None, + *, + decode: bool = False, + deterministic: bool = False) -> Array: + """Applies multi-head dot product attention on the input data. + + Projects the inputs into multi-headed query, key, and value vectors, + applies dot-product attention and project the results to an output vector. + + There are two modes: decoding and non-decoding (e.g., training). The mode is + determined by `decode` argument. For decoding, this method is called twice, + first to initialize the cache and then for an actual decoding process. The + two calls are differentiated by the presence of 'cached_key' in the variable + dict. In the cache initialization stage, the cache variables are initialized + as zeros and will be filled in the subsequent decoding process. + + In the cache initialization call, `inputs_q` has a shape [batch, length, + q_features] and `inputs_kv`: [batch, length, kv_features]. During the + incremental decoding stage, query, key and value all have the shape [batch, + 1, qkv_features] corresponding to a single step. + + Args: + inputs_q: input queries of shape `[batch, q_length, q_features]`. + inputs_kv: key/values of shape `[batch, kv_length, kv_features]`. + mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`. + bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`. + decode: Whether to prepare and use an autoregressive cache. + deterministic: Disables dropout if set to True. + + Returns: + output of shape `[batch, length, q_features]`. + """ + projection = functools.partial( + DenseGeneral, + axis=-1, + features=(self.num_heads, self.head_dim), + kernel_axes=('embed', 'joined_kv'), + dtype=self.dtype) + + # NOTE: T5 does not explicitly rescale the attention logits by + # 1/sqrt(depth_kq)! This is folded into the initializers of the + # linear transformations, which is equivalent under Adafactor. 
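+    # For example, with head_dim=64 the query kernel is initialized at 1/8 of
+    # its usual scale, so query.key logits at initialization match the usual
+    # logits / sqrt(head_dim) scaling without an explicit rescale.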
+ depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) + query_init = lambda *args: self.kernel_init(*args) / depth_scaling + + # Project inputs_q to multi-headed q/k/v + # dimensions are then [batch, length, num_heads, head_dim] + query = projection(kernel_init=query_init, name='query')(inputs_q) + key = projection(kernel_init=self.kernel_init, name='key')(inputs_kv) + value = projection(kernel_init=self.kernel_init, name='value')(inputs_kv) + + query = with_sharding_constraint(query, ('batch', 'length', 'heads', 'kv')) + key = with_sharding_constraint(key, ('batch', 'length', 'heads', 'kv')) + value = with_sharding_constraint(value, ('batch', 'length', 'heads', 'kv')) + + if decode: + # Detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable('cache', 'cached_key') + # The key and value have dimension [batch, length, num_heads, head_dim], + # but we cache them as [batch, num_heads, head_dim, length] as a TPU + # fusion optimization. This also enables the "scatter via one-hot + # broadcast" trick, which means we do a one-hot broadcast instead of a + # scatter/gather operations, resulting in a 3-4x speedup in practice. + swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3]) + cached_key = self.variable('cache', 'cached_key', jnp.zeros, + swap_dims(key.shape), key.dtype) + cached_value = self.variable('cache', 'cached_value', jnp.zeros, + swap_dims(value.shape), value.dtype) + cache_index = self.variable('cache', 'cache_index', + lambda: jnp.array(0, dtype=jnp.int32)) + if is_initialized: + batch, num_heads, head_dim, length = (cached_key.value.shape) + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + # Sanity shape check of cached key against input query. + expected_shape = (batch, 1, num_heads, head_dim) + if expected_shape != query.shape: + raise ValueError('Autoregressive cache shape error, ' + 'expected query shape %s instead got %s.' % + (expected_shape, query.shape)) + + # Create a OHE of the current index. NOTE: the index is increased below. + cur_index = cache_index.value + one_hot_indices = jax.nn.one_hot(cur_index, length, dtype=key.dtype) + # In order to update the key, value caches with the current key and + # value, we move the length axis to the back, similar to what we did for + # the cached ones above. + # Note these are currently the key and value of a single position, since + # we feed one position at a time. + one_token_key = jnp.moveaxis(key, -3, -1) + one_token_value = jnp.moveaxis(value, -3, -1) + # Update key, value caches with our new 1d spatial slices. + # We implement an efficient scatter into the cache via one-hot + # broadcast and addition. + key = cached_key.value + one_token_key * one_hot_indices + value = cached_value.value + one_token_value * one_hot_indices + cached_key.value = key + cached_value.value = value + cache_index.value = cache_index.value + 1 + # Move the keys and values back to their original shapes. + key = jnp.moveaxis(key, -1, -3) + value = jnp.moveaxis(value, -1, -3) + + # Causal mask for cached decoder self-attention: our single query + # position should only attend to those key positions that have already + # been generated and cached, not the remaining zero elements. + mask = combine_masks( + mask, + jnp.broadcast_to( + jnp.arange(length) <= cur_index, + # (1, 1, length) represent (head dim, query length, key length) + # query length is 1 because during decoding we deal with one + # index. 
+ # The same mask is applied to all batch elements and heads. + (batch, 1, 1, length))) + + # Grab the correct relative attention bias during decoding. This is + # only required during single step decoding. + if bias is not None: + # The bias is a full attention matrix, but during decoding we only + # have to take a slice of it. + # This is equivalent to bias[..., cur_index:cur_index+1, :]. + bias = dynamic_vector_slice_in_dim( + jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2) + + # Convert the boolean attention mask to an attention bias. + if mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + mask > 0, + jnp.full(mask.shape, 0.).astype(self.dtype), + jnp.full(mask.shape, -1e10).astype(self.dtype)) + else: + attention_bias = None + + # Add provided bias term (e.g. relative position embedding). + if bias is not None: + attention_bias = combine_biases(attention_bias, bias) + + dropout_rng = None + if not deterministic and self.dropout_rate > 0.: + dropout_rng = self.make_rng('dropout') + + # Apply attention. + x = dot_product_attention( + query, + key, + value, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout_rate, + deterministic=deterministic, + dtype=self.dtype, + float32_logits=self.float32_logits) + + # Back to the original inputs dimensions. + out = DenseGeneral( + features=inputs_q.shape[-1], # output dim is set to the input dim. + axis=(-2, -1), + kernel_init=self.kernel_init, + kernel_axes=('joined_kv', 'embed'), + dtype=self.dtype, + name='out')( + x) + return out + + +def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: + # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. + return tuple([ax if ax >= 0 else ndim + ax for ax in axes]) + + +def _canonicalize_tuple(x): + if isinstance(x, Iterable): + return tuple(x) + else: + return (x,) + + +#------------------------------------------------------------------------------ +# DenseGeneral for attention layers. +#------------------------------------------------------------------------------ +class DenseGeneral(nn.Module): + """A linear transformation (without bias) with flexible axes. + + Attributes: + features: tuple with numbers of output features. + axis: tuple with axes to apply the transformation on. + dtype: the dtype of the computation (default: float32). + kernel_init: initializer function for the weight matrix. + """ + features: Union[Iterable[int], int] + axis: Union[Iterable[int], int] = -1 + dtype: DType = jnp.float32 + kernel_init: Initializer = nn.initializers.variance_scaling( + 1.0, 'fan_in', 'truncated_normal') + kernel_axes: Tuple[str, ...] = () + + @nn.compact + def __call__(self, inputs: Array) -> Array: + """Applies a linear transformation to the inputs along multiple dimensions. + + Args: + inputs: The nd-array to be transformed. + + Returns: + The transformed input. 
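+
+    Example (illustrative): with `inputs` of shape `[batch, length, emb]`,
+    `axis=-1`, and `features=(num_heads, head_dim)`, the kernel has shape
+    `[emb, num_heads, head_dim]` and the output shape is
+    `[batch, length, num_heads, head_dim]`, as used by the attention
+    projections above.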
+ """ + features = _canonicalize_tuple(self.features) + axis = _canonicalize_tuple(self.axis) + + inputs = jnp.asarray(inputs, self.dtype) + axis = _normalize_axes(axis, inputs.ndim) + + kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features + kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), + np.prod(features)) + kernel = param_with_axes( + 'kernel', + self.kernel_init, + kernel_param_shape, + jnp.float32, + axes=self.kernel_axes) + kernel = jnp.asarray(kernel, self.dtype) + kernel = jnp.reshape(kernel, kernel_shape) + + contract_ind = tuple(range(0, len(axis))) + return lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ()))) + + +def _convert_to_activation_function( + fn_or_string: Union[str, Callable]) -> Callable: + """Convert a string to an activation function.""" + if fn_or_string == 'linear': + return lambda x: x + elif isinstance(fn_or_string, str): + return getattr(nn, fn_or_string) + elif callable(fn_or_string): + return fn_or_string + else: + raise ValueError("don't know how to convert %s to an activation function" % + (fn_or_string,)) + + +class MlpBlock(nn.Module): + """Transformer MLP / feed-forward block. + + Attributes: + intermediate_dim: Shared dimension of hidden layers. + activations: Type of activations for each layer. Each element is either + 'linear', a string function name in flax.linen, or a function. + kernel_init: Kernel function, passed to the dense layers. + deterministic: Whether the dropout layers should be deterministic. + intermediate_dropout_rate: Dropout rate used after the intermediate layers. + dtype: Type for the dense layer. + """ + intermediate_dim: int = 2048 + activations: Sequence[Union[str, Callable]] = ('relu',) + kernel_init: Initializer = nn.initializers.variance_scaling( + 1.0, 'fan_in', 'truncated_normal') + intermediate_dropout_rate: float = 0.1 + dtype: Any = jnp.float32 + + @nn.compact + def __call__(self, inputs, decode: bool = False, deterministic: bool = False): + """Applies Transformer MlpBlock module.""" + # Iterate over specified MLP input activation functions. + # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu. + activations = [] + for idx, act_fn in enumerate(self.activations): + dense_name = 'wi' if len(self.activations) == 1 else f'wi_{idx}' + x = DenseGeneral( + self.intermediate_dim, + dtype=self.dtype, + kernel_init=self.kernel_init, + kernel_axes=('embed', 'mlp'), + name=dense_name)( + inputs) + x = _convert_to_activation_function(act_fn)(x) + activations.append(x) + + # Take elementwise product of above intermediate activations. + x = functools.reduce(operator.mul, activations) + # Apply dropout and final dense output projection. + x = nn.Dropout( + rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))( + x, deterministic=deterministic) # Broadcast along length. + x = with_sharding_constraint(x, ('batch', 'length', 'mlp')) + output = DenseGeneral( + inputs.shape[-1], + dtype=self.dtype, + kernel_init=self.kernel_init, + kernel_axes=('mlp', 'embed'), + name='wo')( + x) + return output + + +class Embed(nn.Module): + """A parameterized function from integers [0, n) to d-dimensional vectors. + + Attributes: + num_embeddings: number of embeddings. + features: number of feature dimensions for each embedding. + dtype: the dtype of the embedding vectors (default: float32). + embedding_init: embedding initializer. + one_hot: performs the gather with a one-hot contraction rather than a true + gather. This is currently needed for SPMD partitioning. 
+ """ + num_embeddings: int + features: int + cast_input_dtype: Optional[DType] = None + dtype: DType = jnp.float32 + attend_dtype: Optional[DType] = None + embedding_init: Initializer = default_embed_init + one_hot: bool = False + embedding: Array = dataclasses.field(init=False) + + def setup(self): + self.embedding = param_with_axes( + 'embedding', + self.embedding_init, (self.num_embeddings, self.features), + jnp.float32, + axes=('vocab', 'embed')) + + def __call__(self, inputs: Array) -> Array: + """Embeds the inputs along the last dimension. + + Args: + inputs: input data, all dimensions are considered batch dimensions. + + Returns: + Output which is embedded input data. The output shape follows the input, + with an additional `features` dimension appended. + """ + if self.cast_input_dtype: + inputs = inputs.astype(self.cast_input_dtype) + if not jnp.issubdtype(inputs.dtype, jnp.integer): + raise ValueError('Input type must be an integer or unsigned integer.') + if self.one_hot: + iota = lax.iota(jnp.int32, self.num_embeddings) + one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype) + output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype)) + else: + output = jnp.asarray(self.embedding, self.dtype)[inputs] + output = with_sharding_constraint(output, ('batch', 'length', 'embed')) + return output + + def attend(self, query: Array) -> Array: + """Attend over the embedding using a query array. + + Args: + query: array with last dimension equal the feature depth `features` of the + embedding. + + Returns: + An array with final dim `num_embeddings` corresponding to the batched + inner-product of the array of query vectors against each embedding. + Commonly used for weight-sharing between embeddings and logit transform + in NLP models. + """ + dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype + return jnp.dot(query, jnp.asarray(self.embedding, dtype).T) + + +class FixedEmbed(nn.Module): + """Fixed (not learnable) embeddings specified by the initializer function. + + Attributes: + init_fn: The initializer function that defines the embeddings. + max_length: The maximum supported length. + dtype: The DType to use for the embeddings. + """ + features: int + max_length: int = 2048 + embedding_init: Initializer = sinusoidal() + dtype: jnp.dtype = jnp.float32 + + def setup(self): + # The key is set to None because sinusoid init is deterministic. + shape = (self.max_length, self.features) + self.embedding = self.embedding_init(None, shape, self.dtype) # pylint: disable=too-many-function-args + + @nn.compact + def __call__(self, + inputs, + *, + decode: bool = False): + """Returns the fixed position embeddings specified by the initializer. + + Args: + inputs: [batch_size, seq_len] input position indices. + decode: True if running in single-position autoregressive decode mode. + + Returns: + The fixed position embeddings [batch_size, seq_len, features]. + """ + # We use a cache position index for tracking decoding position. + if decode: + position_embedder_index = self.variable( + 'cache', 'position_embedder_index', + lambda: jnp.array(-1, dtype=jnp.uint32)) + i = position_embedder_index.value + position_embedder_index.value = i + 1 + return jax.lax.dynamic_slice(self.embedding, jnp.array((i, 0)), + np.array((1, self.features))) + + return jnp.take(self.embedding, inputs, axis=0) + + +#------------------------------------------------------------------------------ +# T5 Layernorm - no subtraction of mean or bias. 
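+# (This is RMS normalization: y = x * scale / sqrt(mean(x**2) + epsilon).)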
+#------------------------------------------------------------------------------ +class LayerNorm(nn.Module): + """T5 Layer normalization operating on the last axis of the input data.""" + epsilon: float = 1e-6 + dtype: Any = jnp.float32 + scale_init: Initializer = nn.initializers.ones + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + """Applies layer normalization on the input.""" + x = jnp.asarray(x, jnp.float32) + features = x.shape[-1] + mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) + y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype) + scale = param_with_axes( + 'scale', self.scale_init, (features,), jnp.float32, axes=('embed',)) + + scale = jnp.asarray(scale, self.dtype) + return y * scale + + +#------------------------------------------------------------------------------ +# Mask-making utility functions. +#------------------------------------------------------------------------------ +def make_attention_mask(query_input: Array, + key_input: Array, + pairwise_fn: Callable = jnp.multiply, + extra_batch_dims: int = 0, + dtype: DType = jnp.float32) -> Array: + """Mask-making helper for attention weights. + + In case of 1d inputs (i.e., `[batch, len_q]`, `[batch, len_kv]`, the + attention weights will be `[batch, heads, len_q, len_kv]` and this + function will produce `[batch, 1, len_q, len_kv]`. + + Args: + query_input: a batched, flat input of query_length size + key_input: a batched, flat input of key_length size + pairwise_fn: broadcasting elementwise comparison function + extra_batch_dims: number of extra batch dims to add singleton axes for, none + by default + dtype: mask return dtype + + Returns: + A `[batch, 1, len_q, len_kv]` shaped mask for 1d attention. + """ + # [batch, len_q, len_kv] + mask = pairwise_fn( + # [batch, len_q] -> [batch, len_q, 1] + jnp.expand_dims(query_input, axis=-1), + # [batch, len_q] -> [batch, 1, len_kv] + jnp.expand_dims(key_input, axis=-2)) + + # [batch, 1, len_q, len_kv]. This creates the head dim. + mask = jnp.expand_dims(mask, axis=-3) + mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims))) + return mask.astype(dtype) + + +def make_causal_mask(x: Array, + extra_batch_dims: int = 0, + dtype: DType = jnp.float32) -> Array: + """Make a causal mask for self-attention. + + In case of 1d inputs (i.e., `[batch, len]`, the self-attention weights + will be `[batch, heads, len, len]` and this function will produce a + causal mask of shape `[batch, 1, len, len]`. + + Note that a causal mask does not depend on the values of x; it only depends on + the shape. If x has padding elements, they will not be treated in a special + manner. + + Args: + x: input array of shape `[batch, len]` + extra_batch_dims: number of batch dims to add singleton axes for, none by + default + dtype: mask return dtype + + Returns: + A `[batch, 1, len, len]` shaped causal mask for 1d attention. + """ + idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape) + return make_attention_mask( + idxs, + idxs, + jnp.greater_equal, + extra_batch_dims=extra_batch_dims, + dtype=dtype) + + +def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32): + """Combine attention masks. + + Args: + *masks: set of attention mask arguments to combine, some can be None. + dtype: final mask dtype + + Returns: + Combined mask, reduced by logical and, returns None if no masks given. 
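+
+  For example, combining a causal mask with a padding mask keeps only those
+  positions that both masks allow.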
+ """ + masks = [m for m in masks if m is not None] + if not masks: + return None + assert all(map(lambda x: x.ndim == masks[0].ndim, masks)), ( + f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}') + mask, *other_masks = masks + for other_mask in other_masks: + mask = jnp.logical_and(mask, other_mask) + return mask.astype(dtype) + + +def combine_biases(*masks: Optional[Array]): + """Combine attention biases. + + Args: + *masks: set of attention bias arguments to combine, some can be None. + + Returns: + Combined mask, reduced by summation, returns None if no masks given. + """ + masks = [m for m in masks if m is not None] + if not masks: + return None + assert all(map(lambda x: x.ndim == masks[0].ndim, masks)), ( + f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}') + mask, *other_masks = masks + for other_mask in other_masks: + mask = mask + other_mask + return mask + + +def make_decoder_mask(decoder_target_tokens: Array, + dtype: DType, + decoder_causal_attention: Optional[Array] = None, + decoder_segment_ids: Optional[Array] = None) -> Array: + """Compute the self-attention mask for a decoder. + + Decoder mask is formed by combining a causal mask, a padding mask and an + optional packing mask. If decoder_causal_attention is passed, it makes the + masking non-causal for positions that have value of 1. + + A prefix LM is applied to a dataset which has a notion of "inputs" and + "targets", e.g., a machine translation task. The inputs and targets are + concatenated to form a new target. `decoder_target_tokens` is the concatenated + decoder output tokens. + + The "inputs" portion of the concatenated sequence can attend to other "inputs" + tokens even for those at a later time steps. In order to control this + behavior, `decoder_causal_attention` is necessary. This is a binary mask with + a value of 1 indicating that the position belonged to "inputs" portion of the + original dataset. + + Example: + + Suppose we have a dataset with two examples. + + ds = [{"inputs": [6, 7], "targets": [8]}, + {"inputs": [3, 4], "targets": [5]}] + + After the data preprocessing with packing, the two examples are packed into + one example with the following three fields (some fields are skipped for + simplicity). + + decoder_target_tokens = [[6, 7, 8, 3, 4, 5, 0]] + decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]] + decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]] + + where each array has [batch, length] shape with batch size being 1. Then, + this function computes the following mask. + + mask = [[[[1, 1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0]]]] + + mask[b, 1, :, :] represents the mask for the example `b` in the batch. + Because mask is for a self-attention layer, the mask's shape is a square of + shape [query length, key length]. + + mask[b, 1, i, j] = 1 means that the query token at position i can attend to + the key token at position j. + + Args: + decoder_target_tokens: decoder output tokens. [batch, length] + dtype: dtype of the output mask. + decoder_causal_attention: a binary mask indicating which position should + only attend to earlier positions in the sequence. Others will attend + bidirectionally. [batch, length] + decoder_segment_ids: decoder segmentation info for packed examples. [batch, + length] + + Returns: + the combined decoder mask. + """ + masks = [] + # The same mask is applied to all attention heads. 
So the head dimension is 1, + # i.e., the mask will be broadcast along the heads dim. + # [batch, 1, length, length] + causal_mask = make_causal_mask(decoder_target_tokens, dtype=dtype) + + # Positions with value 1 in `decoder_causal_attneition` can attend + # bidirectionally. + if decoder_causal_attention is not None: + # [batch, 1, length, length] + inputs_mask = make_attention_mask( + decoder_causal_attention, + decoder_causal_attention, + jnp.logical_and, + dtype=dtype) + masks.append(jnp.logical_or(causal_mask, inputs_mask).astype(dtype)) + else: + masks.append(causal_mask) + + # Padding mask. + masks.append( + make_attention_mask( + decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=dtype)) + + # Packing mask + if decoder_segment_ids is not None: + masks.append( + make_attention_mask( + decoder_segment_ids, decoder_segment_ids, jnp.equal, dtype=dtype)) + + return combine_masks(*masks, dtype=dtype) diff --git a/mt3/layers_test.py b/mt3/layers_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2c7c3b4a029fa32b66c7b7eed7631f6422dd8268 --- /dev/null +++ b/mt3/layers_test.py @@ -0,0 +1,545 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for attention classes.""" + +import dataclasses +from typing import Optional +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized +from flax import linen as nn +from flax.core import freeze +from flax.linen import partitioning as nn_partitioning +import jax +from jax import random +from jax.nn import initializers +import jax.numpy as jnp +from mt3 import layers +import numpy as np + +# Parse absl flags test_srcdir and test_tmpdir. +jax.config.parse_flags_with_absl() + +Array = jnp.ndarray +AxisMetadata = nn_partitioning.AxisMetadata # pylint: disable=invalid-name + + +class SelfAttention(layers.MultiHeadDotProductAttention): + """Self-attention special case of multi-head dot-product attention.""" + + @nn.compact + def __call__(self, + inputs_q: Array, + mask: Optional[Array] = None, + bias: Optional[Array] = None, + deterministic: bool = False): + return super().__call__( + inputs_q, inputs_q, mask, bias, deterministic=deterministic) + + +@dataclasses.dataclass(frozen=True) +class SelfAttentionArgs: + num_heads: int = 1 + batch_size: int = 2 + # qkv_features: int = 3 + head_dim: int = 3 + # out_features: int = 4 + q_len: int = 5 + features: int = 6 + dropout_rate: float = 0.1 + deterministic: bool = False + decode: bool = False + float32_logits: bool = False + + def __post_init__(self): + # If we are doing decoding, the query length should be 1, because are doing + # autoregressive decoding where we feed one position at a time. 
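+    # In other words, decode=True is only valid together with q_len == 1.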
+ assert not self.decode or self.q_len == 1 + + def init_args(self): + return dict( + num_heads=self.num_heads, + head_dim=self.head_dim, + dropout_rate=self.dropout_rate, + float32_logits=self.float32_logits) + + def apply_args(self): + inputs_q = jnp.ones((self.batch_size, self.q_len, self.features)) + mask = jnp.ones((self.batch_size, self.num_heads, self.q_len, self.q_len)) + bias = jnp.ones((self.batch_size, self.num_heads, self.q_len, self.q_len)) + return { + 'inputs_q': inputs_q, + 'mask': mask, + 'bias': bias, + 'deterministic': self.deterministic + } + + +class AttentionTest(parameterized.TestCase): + + def test_dot_product_attention_shape(self): + # This test only checks for shape but tries to make sure all code paths are + # reached. + dropout_rng = random.PRNGKey(0) + batch_size, num_heads, q_len, kv_len, qk_depth, v_depth = 1, 2, 3, 4, 5, 6 + + query = jnp.ones((batch_size, q_len, num_heads, qk_depth)) + key = jnp.ones((batch_size, kv_len, num_heads, qk_depth)) + value = jnp.ones((batch_size, kv_len, num_heads, v_depth)) + bias = jnp.ones((batch_size, num_heads, q_len, kv_len)) + + args = dict( + query=query, + key=key, + value=value, + bias=bias, + dropout_rng=dropout_rng, + dropout_rate=0.5, + deterministic=False, + ) + + output = layers.dot_product_attention(**args) + self.assertEqual(output.shape, (batch_size, q_len, num_heads, v_depth)) + + def test_make_attention_mask_multiply_pairwise_fn(self): + decoder_target_tokens = jnp.array([[7, 0, 0], [8, 5, 0]]) + attention_mask = layers.make_attention_mask( + decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=jnp.int32) + expected0 = jnp.array([[1, 0, 0], [0, 0, 0], [0, 0, 0]]) + expected1 = jnp.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]]) + self.assertEqual(attention_mask.shape, (2, 1, 3, 3)) + np.testing.assert_array_equal(attention_mask[0, 0], expected0) + np.testing.assert_array_equal(attention_mask[1, 0], expected1) + + def test_make_attention_mask_equal_pairwise_fn(self): + segment_ids = jnp.array([[1, 1, 2, 2, 2, 0], [1, 1, 1, 2, 0, 0]]) + attention_mask = layers.make_attention_mask( + segment_ids, segment_ids, pairwise_fn=jnp.equal, dtype=jnp.int32) + # Padding is not treated in a special way. So they need to be zeroed out + # separately. + expected0 = jnp.array([[1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0], [0, 0, 1, 1, 1, 0], + [0, 0, 1, 1, 1, 0], [0, 0, 0, 0, 0, 1]]) + expected1 = jnp.array([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 1, 1]]) + self.assertEqual(attention_mask.shape, (2, 1, 6, 6)) + np.testing.assert_array_equal(attention_mask[0, 0], expected0) + np.testing.assert_array_equal(attention_mask[1, 0], expected1) + + def test_make_causal_mask_with_padding(self): + x = jnp.array([[7, 0, 0], [8, 5, 0]]) + y = layers.make_causal_mask(x) + self.assertEqual(y.shape, (2, 1, 3, 3)) + # Padding is not treated in a special way. So they need to be zeroed out + # separately. 
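+    # Hence the expected mask below is plain lower-triangular for both batch
+    # entries, even though the inputs contain padding (zero) tokens.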
+ expected_y = jnp.array([[[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]]], + jnp.float32) + np.testing.assert_allclose(y[0], expected_y) + np.testing.assert_allclose(y[1], expected_y) + + def test_make_causal_mask_extra_batch_dims(self): + x = jnp.ones((3, 3, 5)) + y = layers.make_causal_mask(x, extra_batch_dims=2) + self.assertEqual(y.shape, (1, 1, 3, 3, 1, 5, 5)) + + def test_make_causal_mask(self): + x = jnp.ones((1, 3)) + y = layers.make_causal_mask(x) + self.assertEqual(y.shape, (1, 1, 3, 3)) + expected_y = jnp.array([[[[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]]]], + jnp.float32) + np.testing.assert_allclose(y, expected_y) + + def test_combine_masks(self): + masks = [ + jnp.array([0, 1, 0, 1], jnp.float32), None, + jnp.array([1, 1, 1, 1], jnp.float32), + jnp.array([1, 1, 1, 0], jnp.float32) + ] + y = layers.combine_masks(*masks) + np.testing.assert_allclose(y, jnp.array([0, 1, 0, 0], jnp.float32)) + + def test_combine_biases(self): + masks = [ + jnp.array([0, 1, 0, 1], jnp.float32), None, + jnp.array([0, 1, 1, 1], jnp.float32), + jnp.array([0, 1, 1, 0], jnp.float32) + ] + y = layers.combine_biases(*masks) + np.testing.assert_allclose(y, jnp.array([0, 3, 2, 2], jnp.float32)) + + def test_make_decoder_mask_lm_unpacked(self): + decoder_target_tokens = jnp.array([6, 7, 3, 0]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, dtype=jnp.float32) + expected_mask = jnp.array([[[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], + [0, 0, 0, 0]]]) + np.testing.assert_array_equal(mask, expected_mask) + + def test_make_decoder_mask_lm_packed(self): + decoder_target_tokens = jnp.array([[6, 7, 3, 4, 5, 0]]) + decoder_segment_ids = jnp.array([[1, 1, 1, 2, 2, 0]]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=jnp.float32, + decoder_segment_ids=decoder_segment_ids) + expected_mask = jnp.array([[[[1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 0], [0, 0, 0, 0, 0, 0]]]]) + np.testing.assert_array_equal(mask, expected_mask) + + def test_make_decoder_mask_prefix_lm_unpacked(self): + decoder_target_tokens = jnp.array([[5, 6, 7, 3, 4, 0]]) + decoder_causal_attention = jnp.array([[1, 1, 1, 0, 0, 0]]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=jnp.float32, + decoder_causal_attention=decoder_causal_attention) + expected_mask = jnp.array( + [[[[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 0], [0, 0, 0, 0, 0, 0]]]], + dtype=jnp.float32) + np.testing.assert_array_equal(mask, expected_mask) + + def test_make_decoder_mask_prefix_lm_packed(self): + decoder_target_tokens = jnp.array([[5, 6, 7, 8, 3, 4, 0]]) + decoder_segment_ids = jnp.array([[1, 1, 1, 2, 2, 2, 0]]) + decoder_causal_attention = jnp.array([[1, 1, 0, 1, 1, 0, 0]]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=jnp.float32, + decoder_causal_attention=decoder_causal_attention, + decoder_segment_ids=decoder_segment_ids) + expected_mask = jnp.array([[[[1, 1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0], [0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 0, 0], [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0]]]]) + np.testing.assert_array_equal(mask, expected_mask) + + def test_make_decoder_mask_prefix_lm_unpacked_multiple_elements(self): + decoder_target_tokens = jnp.array([[6, 7, 3, 0], [4, 5, 0, 0]]) + decoder_causal_attention = jnp.array([[1, 1, 0, 0], [1, 0, 0, 0]]) + mask = 
layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=jnp.float32, + decoder_causal_attention=decoder_causal_attention) + expected_mask0 = jnp.array([[1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], + [0, 0, 0, 0]]) + expected_mask1 = jnp.array([[1, 0, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0], + [0, 0, 0, 0]]) + self.assertEqual(mask.shape, (2, 1, 4, 4)) + np.testing.assert_array_equal(mask[0, 0], expected_mask0) + np.testing.assert_array_equal(mask[1, 0], expected_mask1) + + def test_make_decoder_mask_composite_causal_attention(self): + decoder_target_tokens = jnp.array([[6, 7, 3, 4, 8, 9, 0]]) + decoder_causal_attention = jnp.array([[1, 1, 0, 0, 1, 1, 0]]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=jnp.float32, + decoder_causal_attention=decoder_causal_attention) + expected_mask0 = jnp.array([[1, 1, 0, 0, 1, 1, 0], [1, 1, 0, 0, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0]]) + + self.assertEqual(mask.shape, (1, 1, 7, 7)) + np.testing.assert_array_equal(mask[0, 0], expected_mask0) + + def test_make_decoder_mask_composite_causal_attention_packed(self): + decoder_target_tokens = jnp.array([[6, 7, 3, 4, 8, 9, 2, 3, 4]]) + decoder_segment_ids = jnp.array([[1, 1, 1, 1, 1, 1, 2, 2, 2]]) + decoder_causal_attention = jnp.array([[1, 1, 0, 0, 1, 1, 1, 1, 0]]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=jnp.float32, + decoder_causal_attention=decoder_causal_attention, + decoder_segment_ids=decoder_segment_ids) + expected_mask0 = jnp.array([[1, 1, 0, 0, 1, 1, 0, 0, 0], + [1, 1, 0, 0, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1]]) + + self.assertEqual(mask.shape, (1, 1, 9, 9)) + np.testing.assert_array_equal(mask[0, 0], expected_mask0) + + @parameterized.parameters({'f': 20}, {'f': 22}) + def test_multihead_dot_product_attention(self, f): + # b: batch, f: emb_dim, q: q_len, k: kv_len, h: num_head, d: head_dim + b, q, h, d, k = 2, 3, 4, 5, 6 + + base_args = SelfAttentionArgs(num_heads=h, head_dim=d, dropout_rate=0) + args = base_args.init_args() + + np.random.seed(0) + inputs_q = np.random.randn(b, q, f) + inputs_kv = np.random.randn(b, k, f) + + # Projection: [b, q, f] -> [b, q, h, d] + # So the kernels have to be [f, h, d] + query_kernel = np.random.randn(f, h, d) + key_kernel = np.random.randn(f, h, d) + value_kernel = np.random.randn(f, h, d) + # `out` calculation: [b, q, h, d] -> [b, q, f] + # So kernel has to be [h, d, f] + out_kernel = np.random.randn(h, d, f) + + params = { + 'query': { + 'kernel': query_kernel.reshape(f, -1) + }, + 'key': { + 'kernel': key_kernel.reshape(f, -1) + }, + 'value': { + 'kernel': value_kernel.reshape(f, -1) + }, + 'out': { + 'kernel': out_kernel.reshape(-1, f) + } + } + y = layers.MultiHeadDotProductAttention(**args).apply( + {'params': freeze(params)}, inputs_q, inputs_kv) + + query = np.einsum('bqf,fhd->bqhd', inputs_q, query_kernel) + key = np.einsum('bkf,fhd->bkhd', inputs_kv, key_kernel) + value = np.einsum('bkf,fhd->bkhd', inputs_kv, value_kernel) + logits = np.einsum('bqhd,bkhd->bhqk', query, key) + weights = nn.softmax(logits, axis=-1) + combined_value = np.einsum('bhqk,bkhd->bqhd', weights, value) + y_expected = np.einsum('bqhd,hdf->bqf', combined_value, out_kernel) + 
np.testing.assert_allclose(y, y_expected, rtol=1e-5, atol=1e-5) + + def test_multihead_dot_product_attention_caching(self): + # b: batch, f: qkv_features, k: kv_len, h: num_head, d: head_dim + b, h, d, k = 2, 3, 4, 5 + f = h * d + + base_args = SelfAttentionArgs(num_heads=h, head_dim=d, dropout_rate=0) + args = base_args.init_args() + + cache = { + 'cached_key': np.zeros((b, h, d, k)), + 'cached_value': np.zeros((b, h, d, k)), + 'cache_index': np.array(0) + } + inputs_q = np.random.randn(b, 1, f) + inputs_kv = np.random.randn(b, 1, f) + + # Mock dense general such that q, k, v projections are replaced by simple + # reshaping. + def mock_dense_general(self, x, **kwargs): # pylint: disable=unused-argument + return x.reshape(b, -1, h, d) + + with mock.patch.object( + layers.DenseGeneral, '__call__', new=mock_dense_general): + _, mutated = layers.MultiHeadDotProductAttention(**args).apply( + {'cache': freeze(cache)}, + inputs_q, + inputs_kv, + decode=True, + mutable=['cache']) + updated_cache = mutated['cache'] + + # Perform the same mocked projection to generate the expected cache. + # (key|value): [b, 1, h, d] + key = mock_dense_general(None, inputs_kv) + value = mock_dense_general(None, inputs_kv) + + # cached_(key|value): [b, h, d, k] + cache['cached_key'][:, :, :, 0] = key[:, 0, :, :] + cache['cached_value'][:, :, :, 0] = value[:, 0, :, :] + cache['cache_index'] = np.array(1) + for name, array in cache.items(): + np.testing.assert_allclose(array, updated_cache[name]) + + def test_dot_product_attention(self): + # b: batch, f: emb_dim, q: q_len, k: kv_len, h: num_head, d: head_dim + b, q, h, d, k = 2, 3, 4, 5, 6 + np.random.seed(0) + query = np.random.randn(b, q, h, d) + key = np.random.randn(b, k, h, d) + value = np.random.randn(b, k, h, d) + bias = np.random.randn(b, h, q, k) + attn_out = layers.dot_product_attention(query, key, value, bias=bias) + logits = np.einsum('bqhd,bkhd->bhqk', query, key) + weights = jax.nn.softmax(logits + bias, axis=-1) + expected = np.einsum('bhqk,bkhd->bqhd', weights, value) + np.testing.assert_allclose(attn_out, expected, atol=1e-6) + + +class EmbeddingTest(parameterized.TestCase): + + def test_embedder_raises_exception_for_incorrect_input_type(self): + """Tests that inputs are integers and that an exception is raised if not.""" + embed = layers.Embed(num_embeddings=10, features=5) + inputs = np.expand_dims(np.arange(5, dtype=np.int64), 1) + variables = embed.init(jax.random.PRNGKey(0), inputs) + bad_inputs = inputs.astype(np.float32) + with self.assertRaisesRegex( + ValueError, 'Input type must be an integer or unsigned integer.'): + _ = embed.apply(variables, bad_inputs) + + @parameterized.named_parameters( + { + 'testcase_name': 'with_ones', + 'init_fn': jax.nn.initializers.ones, + 'num_embeddings': 10, + 'features': 5, + 'matrix_sum': 5 * 10, + }, { + 'testcase_name': 'with_zeros', + 'init_fn': jax.nn.initializers.zeros, + 'num_embeddings': 10, + 'features': 5, + 'matrix_sum': 0, + }) + def test_embedding_initializes_correctly(self, init_fn, num_embeddings, + features, matrix_sum): + """Tests if the Embed class initializes with the requested initializer.""" + embed = layers.Embed( + num_embeddings=num_embeddings, + features=features, + embedding_init=init_fn) + inputs = np.expand_dims(np.arange(5, dtype=np.int64), 1) + variables = embed.init(jax.random.PRNGKey(0), inputs) + embedding_matrix = variables['params']['embedding'] + self.assertEqual(int(np.sum(embedding_matrix)), matrix_sum) + + def test_embedding_matrix_shape(self): + """Tests that the 
embedding matrix has the right shape.""" + num_embeddings = 10 + features = 5 + embed = layers.Embed(num_embeddings=num_embeddings, features=features) + inputs = np.expand_dims(np.arange(features, dtype=np.int64), 1) + variables = embed.init(jax.random.PRNGKey(0), inputs) + embedding_matrix = variables['params']['embedding'] + self.assertEqual((num_embeddings, features), embedding_matrix.shape) + + def test_embedding_attend(self): + """Tests that attending with ones returns sum of embedding vectors.""" + features = 5 + embed = layers.Embed(num_embeddings=10, features=features) + inputs = np.array([[1]], dtype=np.int64) + variables = embed.init(jax.random.PRNGKey(0), inputs) + query = np.ones(features, dtype=np.float32) + result = embed.apply(variables, query, method=embed.attend) + expected = np.sum(variables['params']['embedding'], -1) + np.testing.assert_array_almost_equal(result, expected) + + +class DenseTest(parameterized.TestCase): + + def test_dense_general_no_bias(self): + rng = random.PRNGKey(0) + x = jnp.ones((1, 3)) + model = layers.DenseGeneral( + features=4, + kernel_init=initializers.ones, + ) + y, _ = model.init_with_output(rng, x) + self.assertEqual(y.shape, (1, 4)) + np.testing.assert_allclose(y, np.full((1, 4), 3.)) + + def test_dense_general_two_features(self): + rng = random.PRNGKey(0) + x = jnp.ones((1, 3)) + model = layers.DenseGeneral( + features=(2, 2), + kernel_init=initializers.ones, + ) + y, _ = model.init_with_output(rng, x) + # We transform the last input dimension to two output dimensions (2, 2). + np.testing.assert_allclose(y, np.full((1, 2, 2), 3.)) + + def test_dense_general_two_axes(self): + rng = random.PRNGKey(0) + x = jnp.ones((1, 2, 2)) + model = layers.DenseGeneral( + features=3, + axis=(-2, 2), # Note: this is the same as (1, 2). + kernel_init=initializers.ones, + ) + y, _ = model.init_with_output(rng, x) + # We transform the last two input dimensions (2, 2) to one output dimension. + np.testing.assert_allclose(y, np.full((1, 3), 4.)) + + def test_mlp_same_out_dim(self): + module = layers.MlpBlock( + intermediate_dim=4, + activations=('relu',), + kernel_init=nn.initializers.xavier_uniform(), + dtype=jnp.float32, + ) + inputs = np.array( + [ + # Batch 1. + [[1, 1], [1, 1], [1, 2]], + # Batch 2. 
+ [[2, 2], [3, 1], [2, 2]], + ], + dtype=np.float32) + params = module.init(random.PRNGKey(0), inputs, deterministic=True) + self.assertEqual( + jax.tree_map(lambda a: a.tolist(), params), { + 'params': { + 'wi': { + 'kernel': [[ + -0.8675811290740967, 0.08417510986328125, + 0.022586345672607422, -0.9124102592468262 + ], + [ + -0.19464373588562012, 0.49809837341308594, + 0.7808468341827393, 0.9267289638519287 + ]], + }, + 'wo': { + 'kernel': [[0.01154780387878418, 0.1397249698638916], + [0.974980354309082, 0.5903260707855225], + [-0.05997943878173828, 0.616570234298706], + [0.2934272289276123, 0.8181164264678955]], + }, + }, + 'params_axes': { + 'wi': { + 'kernel_axes': AxisMetadata(names=('embed', 'mlp')), + }, + 'wo': { + 'kernel_axes': AxisMetadata(names=('mlp', 'embed')), + }, + }, + }) + result = module.apply(params, inputs, deterministic=True) + np.testing.assert_allclose( + result.tolist(), + [[[0.5237172245979309, 0.8508185744285583], + [0.5237172245979309, 0.8508185744285583], + [1.2344461679458618, 2.3844780921936035]], + [[1.0474344491958618, 1.7016371488571167], + [0.6809444427490234, 0.9663378596305847], + [1.0474344491958618, 1.7016371488571167]]], + rtol=1e-6, + ) + + +if __name__ == '__main__': + absltest.main() diff --git a/mt3/metrics.py b/mt3/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..42f2358f187fbb9a4213310b2e054ab6fc85405d --- /dev/null +++ b/mt3/metrics.py @@ -0,0 +1,392 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transcription metrics.""" + +import collections +import copy +import functools +from typing import Any, Iterable, Mapping, Optional, Sequence + +import mir_eval + +from mt3 import event_codec +from mt3 import metrics_utils +from mt3 import note_sequences +from mt3 import spectrograms +from mt3 import summaries +from mt3 import vocabularies + +import note_seq +import numpy as np +import seqio + + +def _program_aware_note_scores( + ref_ns: note_seq.NoteSequence, + est_ns: note_seq.NoteSequence, + granularity_type: str +) -> Mapping[str, float]: + """Compute precision/recall/F1 for notes taking program into account. + + For non-drum tracks, uses onsets and offsets. For drum tracks, uses onsets + only. Applies MIDI program map of specified granularity type. + + Args: + ref_ns: Reference NoteSequence with ground truth labels. + est_ns: Estimated NoteSequence. + granularity_type: String key in vocabularies.PROGRAM_GRANULARITIES dict. + + Returns: + A dictionary containing precision, recall, and F1 score. 
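+    Scores are reported for all notes combined as well as separately for drum
+    and non-drum notes.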
+ """ + program_map_fn = vocabularies.PROGRAM_GRANULARITIES[ + granularity_type].program_map_fn + + ref_ns = copy.deepcopy(ref_ns) + for note in ref_ns.notes: + if not note.is_drum: + note.program = program_map_fn(note.program) + + est_ns = copy.deepcopy(est_ns) + for note in est_ns.notes: + if not note.is_drum: + note.program = program_map_fn(note.program) + + program_and_is_drum_tuples = ( + set((note.program, note.is_drum) for note in ref_ns.notes) | + set((note.program, note.is_drum) for note in est_ns.notes) + ) + + drum_precision_sum = 0.0 + drum_precision_count = 0 + drum_recall_sum = 0.0 + drum_recall_count = 0 + + nondrum_precision_sum = 0.0 + nondrum_precision_count = 0 + nondrum_recall_sum = 0.0 + nondrum_recall_count = 0 + + for program, is_drum in program_and_is_drum_tuples: + est_track = note_sequences.extract_track(est_ns, program, is_drum) + ref_track = note_sequences.extract_track(ref_ns, program, is_drum) + + est_intervals, est_pitches, unused_est_velocities = ( + note_seq.sequences_lib.sequence_to_valued_intervals(est_track)) + ref_intervals, ref_pitches, unused_ref_velocities = ( + note_seq.sequences_lib.sequence_to_valued_intervals(ref_track)) + + args = { + 'ref_intervals': ref_intervals, 'ref_pitches': ref_pitches, + 'est_intervals': est_intervals, 'est_pitches': est_pitches + } + if is_drum: + args['offset_ratio'] = None + + precision, recall, unused_f_measure, unused_avg_overlap_ratio = ( + mir_eval.transcription.precision_recall_f1_overlap(**args)) + + if is_drum: + drum_precision_sum += precision * len(est_intervals) + drum_precision_count += len(est_intervals) + drum_recall_sum += recall * len(ref_intervals) + drum_recall_count += len(ref_intervals) + else: + nondrum_precision_sum += precision * len(est_intervals) + nondrum_precision_count += len(est_intervals) + nondrum_recall_sum += recall * len(ref_intervals) + nondrum_recall_count += len(ref_intervals) + + precision_sum = drum_precision_sum + nondrum_precision_sum + precision_count = drum_precision_count + nondrum_precision_count + recall_sum = drum_recall_sum + nondrum_recall_sum + recall_count = drum_recall_count + nondrum_recall_count + + precision = (precision_sum / precision_count) if precision_count else 0 + recall = (recall_sum / recall_count) if recall_count else 0 + f_measure = mir_eval.util.f_measure(precision, recall) + + drum_precision = ((drum_precision_sum / drum_precision_count) + if drum_precision_count else 0) + drum_recall = ((drum_recall_sum / drum_recall_count) + if drum_recall_count else 0) + drum_f_measure = mir_eval.util.f_measure(drum_precision, drum_recall) + + nondrum_precision = ((nondrum_precision_sum / nondrum_precision_count) + if nondrum_precision_count else 0) + nondrum_recall = ((nondrum_recall_sum / nondrum_recall_count) + if nondrum_recall_count else 0) + nondrum_f_measure = mir_eval.util.f_measure(nondrum_precision, nondrum_recall) + + return { + f'Onset + offset + program precision ({granularity_type})': precision, + f'Onset + offset + program recall ({granularity_type})': recall, + f'Onset + offset + program F1 ({granularity_type})': f_measure, + f'Drum onset precision ({granularity_type})': drum_precision, + f'Drum onset recall ({granularity_type})': drum_recall, + f'Drum onset F1 ({granularity_type})': drum_f_measure, + f'Nondrum onset + offset + program precision ({granularity_type})': + nondrum_precision, + f'Nondrum onset + offset + program recall ({granularity_type})': + nondrum_recall, + f'Nondrum onset + offset + program F1 ({granularity_type})': + 
nondrum_f_measure + } + + +def _note_onset_tolerance_sweep( + ref_ns: note_seq.NoteSequence, est_ns: note_seq.NoteSequence, + tolerances: Iterable[float] = (0.01, 0.02, 0.05, 0.1, 0.2, 0.5) +) -> Mapping[str, float]: + """Compute note precision/recall/F1 across a range of tolerances.""" + est_intervals, est_pitches, unused_est_velocities = ( + note_seq.sequences_lib.sequence_to_valued_intervals(est_ns)) + ref_intervals, ref_pitches, unused_ref_velocities = ( + note_seq.sequences_lib.sequence_to_valued_intervals(ref_ns)) + + scores = {} + + for tol in tolerances: + precision, recall, f_measure, _ = ( + mir_eval.transcription.precision_recall_f1_overlap( + ref_intervals=ref_intervals, ref_pitches=ref_pitches, + est_intervals=est_intervals, est_pitches=est_pitches, + onset_tolerance=tol, offset_min_tolerance=tol)) + + scores[f'Onset + offset precision ({tol})'] = precision + scores[f'Onset + offset recall ({tol})'] = recall + scores[f'Onset + offset F1 ({tol})'] = f_measure + + return scores + + +def transcription_metrics( + targets: Sequence[Mapping[str, Any]], + predictions: Sequence[Mapping[str, Any]], + codec: event_codec.Codec, + spectrogram_config: spectrograms.SpectrogramConfig, + onsets_only: bool, + use_ties: bool, + track_specs: Optional[Sequence[note_sequences.TrackSpec]] = None, + num_summary_examples: int = 5, + frame_fps: float = 62.5, + frame_velocity_threshold: int = 30, +) -> Mapping[str, seqio.metrics.MetricValue]: + """Compute mir_eval transcription metrics.""" + if onsets_only and use_ties: + raise ValueError('Ties not compatible with onset-only transcription.') + if onsets_only: + encoding_spec = note_sequences.NoteOnsetEncodingSpec + elif not use_ties: + encoding_spec = note_sequences.NoteEncodingSpec + else: + encoding_spec = note_sequences.NoteEncodingWithTiesSpec + + # The first target for each full example contains the NoteSequence; just + # organize by ID. + full_targets = {} + for target in targets: + if target['ref_ns']: + full_targets[target['unique_id']] = {'ref_ns': target['ref_ns']} + + # Gather all predictions for the same ID and concatenate them in time order, + # to construct full-length predictions. + full_predictions = metrics_utils.combine_predictions_by_id( + predictions=predictions, + combine_predictions_fn=functools.partial( + metrics_utils.event_predictions_to_ns, + codec=codec, + encoding_spec=encoding_spec)) + + assert sorted(full_targets.keys()) == sorted(full_predictions.keys()) + + full_target_prediction_pairs = [ + (full_targets[id], full_predictions[id]) + for id in sorted(full_targets.keys()) + ] + + scores = collections.defaultdict(list) + all_track_pianorolls = collections.defaultdict(list) + for target, prediction in full_target_prediction_pairs: + scores['Invalid events'].append(prediction['est_invalid_events']) + scores['Dropped events'].append(prediction['est_dropped_events']) + + def remove_drums(ns): + ns_drumless = note_seq.NoteSequence() + ns_drumless.CopyFrom(ns) + del ns_drumless.notes[:] + ns_drumless.notes.extend([note for note in ns.notes if not note.is_drum]) + return ns_drumless + + est_ns_drumless = remove_drums(prediction['est_ns']) + ref_ns_drumless = remove_drums(target['ref_ns']) + + # Whether or not there are separate tracks, compute metrics for the full + # NoteSequence minus drums. 
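+    # These parallel lists start with the full (drumless) mix; per-track
+    # entries are appended below when track_specs is provided.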
+ est_tracks = [est_ns_drumless] + ref_tracks = [ref_ns_drumless] + use_track_offsets = [not onsets_only] + use_track_velocities = [not onsets_only] + track_instrument_names = [''] + + if track_specs is not None: + # Compute transcription metrics separately for each track. + for spec in track_specs: + est_tracks.append(note_sequences.extract_track( + prediction['est_ns'], spec.program, spec.is_drum)) + ref_tracks.append(note_sequences.extract_track( + target['ref_ns'], spec.program, spec.is_drum)) + use_track_offsets.append(not onsets_only and not spec.is_drum) + use_track_velocities.append(not onsets_only) + track_instrument_names.append(spec.name) + + for est_ns, ref_ns, use_offsets, use_velocities, instrument_name in zip( + est_tracks, ref_tracks, use_track_offsets, use_track_velocities, + track_instrument_names): + track_scores = {} + + est_intervals, est_pitches, est_velocities = ( + note_seq.sequences_lib.sequence_to_valued_intervals(est_ns)) + + ref_intervals, ref_pitches, ref_velocities = ( + note_seq.sequences_lib.sequence_to_valued_intervals(ref_ns)) + + # Precision / recall / F1 using onsets (and pitches) only. + precision, recall, f_measure, avg_overlap_ratio = ( + mir_eval.transcription.precision_recall_f1_overlap( + ref_intervals=ref_intervals, + ref_pitches=ref_pitches, + est_intervals=est_intervals, + est_pitches=est_pitches, + offset_ratio=None)) + del avg_overlap_ratio + track_scores['Onset precision'] = precision + track_scores['Onset recall'] = recall + track_scores['Onset F1'] = f_measure + + if use_offsets: + # Precision / recall / F1 using onsets and offsets. + precision, recall, f_measure, avg_overlap_ratio = ( + mir_eval.transcription.precision_recall_f1_overlap( + ref_intervals=ref_intervals, + ref_pitches=ref_pitches, + est_intervals=est_intervals, + est_pitches=est_pitches)) + del avg_overlap_ratio + track_scores['Onset + offset precision'] = precision + track_scores['Onset + offset recall'] = recall + track_scores['Onset + offset F1'] = f_measure + + if use_velocities: + # Precision / recall / F1 using onsets and velocities (no offsets). + precision, recall, f_measure, avg_overlap_ratio = ( + mir_eval.transcription_velocity.precision_recall_f1_overlap( + ref_intervals=ref_intervals, + ref_pitches=ref_pitches, + ref_velocities=ref_velocities, + est_intervals=est_intervals, + est_pitches=est_pitches, + est_velocities=est_velocities, + offset_ratio=None)) + track_scores['Onset + velocity precision'] = precision + track_scores['Onset + velocity recall'] = recall + track_scores['Onset + velocity F1'] = f_measure + + if use_offsets and use_velocities: + # Precision / recall / F1 using onsets, offsets, and velocities. + precision, recall, f_measure, avg_overlap_ratio = ( + mir_eval.transcription_velocity.precision_recall_f1_overlap( + ref_intervals=ref_intervals, + ref_pitches=ref_pitches, + ref_velocities=ref_velocities, + est_intervals=est_intervals, + est_pitches=est_pitches, + est_velocities=est_velocities)) + track_scores['Onset + offset + velocity precision'] = precision + track_scores['Onset + offset + velocity recall'] = recall + track_scores['Onset + offset + velocity F1'] = f_measure + + # Calculate framewise metrics. 
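+      # Pianorolls are sampled at frame_fps (62.5 fps, i.e. 16 ms frames, by
+      # default); reference frames with velocity at or below
+      # frame_velocity_threshold are treated as inactive by frame_metrics.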
+ is_drum = all([n.is_drum for n in ref_ns.notes]) + ref_pr = metrics_utils.get_prettymidi_pianoroll( + ref_ns, frame_fps, is_drum=is_drum) + est_pr = metrics_utils.get_prettymidi_pianoroll( + est_ns, frame_fps, is_drum=is_drum) + all_track_pianorolls[instrument_name].append((est_pr, ref_pr)) + frame_precision, frame_recall, frame_f1 = metrics_utils.frame_metrics( + ref_pr, est_pr, velocity_threshold=frame_velocity_threshold) + track_scores['Frame Precision'] = frame_precision + track_scores['Frame Recall'] = frame_recall + track_scores['Frame F1'] = frame_f1 + + for metric_name, metric_value in track_scores.items(): + if instrument_name: + scores[f'{instrument_name}/{metric_name}'].append(metric_value) + else: + scores[metric_name].append(metric_value) + + # Add program-aware note metrics for all program granularities. + # Note that this interacts with the training program granularity; in + # particular granularities *higher* than the training granularity are likely + # to have poor metrics. + for granularity_type in vocabularies.PROGRAM_GRANULARITIES: + for name, score in _program_aware_note_scores( + target['ref_ns'], prediction['est_ns'], + granularity_type=granularity_type).items(): + scores[name].append(score) + + # Add (non-program-aware) note metrics across a range of onset/offset + # tolerances. + for name, score in _note_onset_tolerance_sweep( + ref_ns=ref_ns_drumless, est_ns=est_ns_drumless).items(): + scores[name].append(score) + + mean_scores = {k: np.mean(v) for k, v in scores.items()} + + score_histograms = {'%s (hist)' % k: seqio.metrics.Histogram(np.array(v)) + for k, v in scores.items()} + + # Pick several examples to summarize. + targets_to_summarize, predictions_to_summarize = zip( + *full_target_prediction_pairs[:num_summary_examples]) + + # Compute audio summaries. + audio_summaries = summaries.audio_summaries( + targets=targets_to_summarize, + predictions=predictions_to_summarize, + spectrogram_config=spectrogram_config) + + # Compute transcription summaries. + transcription_summaries = summaries.transcription_summaries( + targets=targets_to_summarize, + predictions=predictions_to_summarize, + spectrogram_config=spectrogram_config, + ns_feature_suffix='ns', + track_specs=track_specs) + + pianorolls_to_summarize = { + k: v[:num_summary_examples] for k, v in all_track_pianorolls.items() + } + + prettymidi_pianoroll_summaries = summaries.prettymidi_pianoroll( + pianorolls_to_summarize, fps=frame_fps) + + return { + **mean_scores, + **score_histograms, + **audio_summaries, + **transcription_summaries, + **prettymidi_pianoroll_summaries, + } diff --git a/mt3/metrics_utils.py b/mt3/metrics_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b9297f1cac5a3904265238b2486dc3f8cabcca58 --- /dev/null +++ b/mt3/metrics_utils.py @@ -0,0 +1,196 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utilities for transcription metrics.""" + +import collections +import functools + +from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, TypeVar + +from mt3 import event_codec +from mt3 import note_sequences +from mt3 import run_length_encoding + +import note_seq +import numpy as np +import pretty_midi +import sklearn + +S = TypeVar('S') +T = TypeVar('T') + +CombineExamplesFunctionType = Callable[[Sequence[Mapping[str, Any]]], + Mapping[str, Any]] + + +def _group_predictions_by_id( + predictions: Sequence[Mapping[str, T]] +) -> Mapping[str, Sequence[T]]: + predictions_by_id = collections.defaultdict(list) + for pred in predictions: + predictions_by_id[pred['unique_id']].append(pred) + return predictions_by_id + + +def combine_predictions_by_id( + predictions: Sequence[Mapping[str, Any]], + combine_predictions_fn: CombineExamplesFunctionType +) -> Mapping[str, Mapping[str, Any]]: + """Concatenate predicted examples, grouping by ID and sorting by time.""" + predictions_by_id = _group_predictions_by_id(predictions) + return { + id: combine_predictions_fn(preds) + for id, preds in predictions_by_id.items() + } + + +def decode_and_combine_predictions( + predictions: Sequence[Mapping[str, Any]], + init_state_fn: Callable[[], S], + begin_segment_fn: Callable[[S], None], + decode_tokens_fn: Callable[[S, Sequence[int], int, Optional[int]], + Tuple[int, int]], + flush_state_fn: Callable[[S], T] +) -> Tuple[T, int, int]: + """Decode and combine a sequence of predictions to a full result. + + For time-based events, this usually means concatenation. + + Args: + predictions: List of predictions, each of which is a dictionary containing + estimated tokens ('est_tokens') and start time ('start_time') fields. + init_state_fn: Function that takes no arguments and returns an initial + decoding state. + begin_segment_fn: Function that updates the decoding state at the beginning + of a segment. + decode_tokens_fn: Function that takes a decoding state, estimated tokens + (for a single segment), start time, and max time, and processes the + tokens, updating the decoding state in place. Also returns the number of + invalid and dropped events for the segment. + flush_state_fn: Function that flushes the final decoding state into the + result. + + Returns: + result: The full combined decoding. + total_invalid_events: Total number of invalid event tokens across all + predictions. + total_dropped_events: Total number of dropped event tokens across all + predictions. + """ + sorted_predictions = sorted(predictions, key=lambda pred: pred['start_time']) + + state = init_state_fn() + total_invalid_events = 0 + total_dropped_events = 0 + + for pred_idx, pred in enumerate(sorted_predictions): + begin_segment_fn(state) + + # Depending on the audio token hop length, each symbolic token could be + # associated with multiple audio frames. Since we split up the audio frames + # into segments for prediction, this could lead to overlap. To prevent + # overlap issues, ensure that the current segment does not make any + # predictions for the time period covered by the subsequent segment. 
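+    # For example, if segments start at 0.0s, 0.4s, and 0.8s, tokens from the
+    # first segment may only produce events in [0.0s, 0.4s); events decoded
+    # past that point are counted in the dropped-events total.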
+ max_decode_time = None + if pred_idx < len(sorted_predictions) - 1: + max_decode_time = sorted_predictions[pred_idx + 1]['start_time'] + + invalid_events, dropped_events = decode_tokens_fn( + state, pred['est_tokens'], pred['start_time'], max_decode_time) + + total_invalid_events += invalid_events + total_dropped_events += dropped_events + + return flush_state_fn(state), total_invalid_events, total_dropped_events + + +def event_predictions_to_ns( + predictions: Sequence[Mapping[str, Any]], codec: event_codec.Codec, + encoding_spec: note_sequences.NoteEncodingSpecType +) -> Mapping[str, Any]: + """Convert a sequence of predictions to a combined NoteSequence.""" + ns, total_invalid_events, total_dropped_events = decode_and_combine_predictions( + predictions=predictions, + init_state_fn=encoding_spec.init_decoding_state_fn, + begin_segment_fn=encoding_spec.begin_decoding_segment_fn, + decode_tokens_fn=functools.partial( + run_length_encoding.decode_events, + codec=codec, + decode_event_fn=encoding_spec.decode_event_fn), + flush_state_fn=encoding_spec.flush_decoding_state_fn) + + # Also concatenate raw inputs from all predictions. + sorted_predictions = sorted(predictions, key=lambda pred: pred['start_time']) + raw_inputs = np.concatenate( + [pred['raw_inputs'] for pred in sorted_predictions], axis=0) + start_times = [pred['start_time'] for pred in sorted_predictions] + + return { + 'raw_inputs': raw_inputs, + 'start_times': start_times, + 'est_ns': ns, + 'est_invalid_events': total_invalid_events, + 'est_dropped_events': total_dropped_events, + } + + +def get_prettymidi_pianoroll(ns: note_seq.NoteSequence, fps: float, + is_drum: bool): + """Convert NoteSequence to pianoroll through pretty_midi.""" + for note in ns.notes: + if is_drum or note.end_time - note.start_time < 0.05: + # Give all drum notes a fixed length, and all others a min length + note.end_time = note.start_time + 0.05 + + pm = note_seq.note_sequence_to_pretty_midi(ns) + end_time = pm.get_end_time() + cc = [ + # all sound off + pretty_midi.ControlChange(number=120, value=0, time=end_time), + # all notes off + pretty_midi.ControlChange(number=123, value=0, time=end_time) + ] + pm.instruments[0].control_changes = cc + if is_drum: + # If inst.is_drum is set, pretty_midi will return an all zero pianoroll. + for inst in pm.instruments: + inst.is_drum = False + pianoroll = pm.get_piano_roll(fs=fps) + return pianoroll + + +def frame_metrics(ref_pianoroll: np.ndarray, + est_pianoroll: np.ndarray, + velocity_threshold: int) -> Tuple[float, float, float]: + """Frame Precision, Recall, and F1.""" + # Pad to same length + if ref_pianoroll.shape[1] > est_pianoroll.shape[1]: + diff = ref_pianoroll.shape[1] - est_pianoroll.shape[1] + est_pianoroll = np.pad(est_pianoroll, [(0, 0), (0, diff)], mode='constant') + elif est_pianoroll.shape[1] > ref_pianoroll.shape[1]: + diff = est_pianoroll.shape[1] - ref_pianoroll.shape[1] + ref_pianoroll = np.pad(ref_pianoroll, [(0, 0), (0, diff)], mode='constant') + + # For ref, remove any notes that are too quiet (consistent with Cerberus.) + ref_frames_bool = ref_pianoroll > velocity_threshold + # For est, keep all predicted notes. 
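+  # Both pianorolls are binarized and flattened, so precision / recall / F1
+  # are computed over all (pitch, frame) cells; labels=[True, False] puts the
+  # "note active" class first, hence the [0] indexing in the return below.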
+ est_frames_bool = est_pianoroll > 0 + + precision, recall, f1, _ = sklearn.metrics.precision_recall_fscore_support( + ref_frames_bool.flatten(), + est_frames_bool.flatten(), + labels=[True, False]) + + return precision[0], recall[0], f1[0] diff --git a/mt3/metrics_utils_test.py b/mt3/metrics_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..be4f977dddce00c0a546eb31a657c62980d05a70 --- /dev/null +++ b/mt3/metrics_utils_test.py @@ -0,0 +1,259 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for metrics_utils.""" + +from mt3 import event_codec +from mt3 import metrics_utils +from mt3 import note_sequences + +import note_seq +import numpy as np +import tensorflow as tf + + +class MetricsUtilsTest(tf.test.TestCase): + + def test_event_predictions_to_ns(self): + predictions = [ + { + 'raw_inputs': [0, 0], + 'start_time': 0.0, + 'est_tokens': [20, 160], + }, + { + 'raw_inputs': [1, 1], + 'start_time': 0.4, + # These last 2 events should be dropped. + 'est_tokens': [20, 161, 50, 162], + }, + { + 'raw_inputs': [2, 2], + 'start_time': 0.8, + 'est_tokens': [163, 20, 164] + }, + ] + expected_ns = note_seq.NoteSequence(ticks_per_quarter=220) + expected_ns.notes.add( + pitch=59, + velocity=100, + start_time=0.20, + end_time=0.21) + expected_ns.notes.add( + pitch=60, + velocity=100, + start_time=0.60, + end_time=0.61) + expected_ns.notes.add( + pitch=62, + velocity=100, + start_time=0.80, + end_time=0.81) + expected_ns.notes.add( + pitch=63, + velocity=100, + start_time=1.00, + end_time=1.01) + expected_ns.total_time = 1.01 + + codec = event_codec.Codec( + max_shift_steps=100, + steps_per_second=100, + event_ranges=[ + event_codec.EventRange('pitch', note_seq.MIN_MIDI_PITCH, + note_seq.MAX_MIDI_PITCH)]) + res = metrics_utils.event_predictions_to_ns( + predictions, codec=codec, + encoding_spec=note_sequences.NoteOnsetEncodingSpec) + self.assertProtoEquals(expected_ns, res['est_ns']) + self.assertEqual(0, res['est_invalid_events']) + self.assertEqual(2, res['est_dropped_events']) + np.testing.assert_array_equal([0, 0, 1, 1, 2, 2], res['raw_inputs']) + + def test_event_predictions_to_ns_with_offsets(self): + predictions = [ + { + 'raw_inputs': [0, 0], + 'start_time': 0.0, + 'est_tokens': [20, 356, 160], + }, + { + 'raw_inputs': [1, 1], + 'start_time': 0.4, + 'est_tokens': [20, 292, 161], + }, + { + 'raw_inputs': [2, 2], + 'start_time': 0.8, + 'est_tokens': [20, 229, 160, 161] + }, + ] + expected_ns = note_seq.NoteSequence(ticks_per_quarter=220) + expected_ns.notes.add( + pitch=59, + velocity=127, + start_time=0.20, + end_time=1.00) + expected_ns.notes.add( + pitch=60, + velocity=63, + start_time=0.60, + end_time=1.00) + expected_ns.total_time = 1.00 + + codec = event_codec.Codec( + max_shift_steps=100, + steps_per_second=100, + event_ranges=[ + event_codec.EventRange('pitch', note_seq.MIN_MIDI_PITCH, + note_seq.MAX_MIDI_PITCH), + event_codec.EventRange('velocity', 0, 127) + ]) + res = metrics_utils.event_predictions_to_ns( + 
predictions, codec=codec, encoding_spec=note_sequences.NoteEncodingSpec) + self.assertProtoEquals(expected_ns, res['est_ns']) + self.assertEqual(0, res['est_invalid_events']) + self.assertEqual(0, res['est_dropped_events']) + np.testing.assert_array_equal([0, 0, 1, 1, 2, 2], res['raw_inputs']) + + def test_event_predictions_to_ns_multitrack(self): + predictions = [ + { + 'raw_inputs': [0, 0], + 'start_time': 0.0, + 'est_tokens': [20, 517, 356, 160], + }, + { + 'raw_inputs': [1, 1], + 'start_time': 0.4, + 'est_tokens': [20, 356, 399], + }, + { + 'raw_inputs': [2, 2], + 'start_time': 0.8, + 'est_tokens': [20, 517, 229, 160] + }, + ] + expected_ns = note_seq.NoteSequence(ticks_per_quarter=220) + expected_ns.notes.add( + pitch=42, + velocity=127, + start_time=0.60, + end_time=0.61, + is_drum=True, + instrument=9) + expected_ns.notes.add( + pitch=59, + velocity=127, + start_time=0.20, + end_time=1.00, + program=32) + expected_ns.total_time = 1.00 + + codec = event_codec.Codec( + max_shift_steps=100, + steps_per_second=100, + event_ranges=[ + event_codec.EventRange('pitch', note_seq.MIN_MIDI_PITCH, + note_seq.MAX_MIDI_PITCH), + event_codec.EventRange('velocity', 0, 127), + event_codec.EventRange('drum', note_seq.MIN_MIDI_PITCH, + note_seq.MAX_MIDI_PITCH), + event_codec.EventRange('program', note_seq.MIN_MIDI_PROGRAM, + note_seq.MAX_MIDI_PROGRAM) + ]) + res = metrics_utils.event_predictions_to_ns( + predictions, codec=codec, encoding_spec=note_sequences.NoteEncodingSpec) + self.assertProtoEquals(expected_ns, res['est_ns']) + self.assertEqual(0, res['est_invalid_events']) + self.assertEqual(0, res['est_dropped_events']) + np.testing.assert_array_equal([0, 0, 1, 1, 2, 2], res['raw_inputs']) + + def test_event_predictions_to_ns_multitrack_ties(self): + predictions = [ + { + 'raw_inputs': [0, 0], + 'start_time': 0.0, + 'est_tokens': [613, # no tied notes + 20, 517, 356, 160], + }, + { + 'raw_inputs': [1, 1], + 'start_time': 0.4, + 'est_tokens': [517, 160, 613, # tied note + 20, 356, 399], + }, + { + 'raw_inputs': [2, 2], + 'start_time': 0.8, + 'est_tokens': [613] # no tied notes, causing active note to end + }, + ] + expected_ns = note_seq.NoteSequence(ticks_per_quarter=220) + expected_ns.notes.add( + pitch=42, + velocity=127, + start_time=0.60, + end_time=0.61, + is_drum=True, + instrument=9) + expected_ns.notes.add( + pitch=59, + velocity=127, + start_time=0.20, + end_time=0.80, + program=32) + expected_ns.total_time = 0.80 + + codec = event_codec.Codec( + max_shift_steps=100, + steps_per_second=100, + event_ranges=[ + event_codec.EventRange('pitch', note_seq.MIN_MIDI_PITCH, + note_seq.MAX_MIDI_PITCH), + event_codec.EventRange('velocity', 0, 127), + event_codec.EventRange('drum', note_seq.MIN_MIDI_PITCH, + note_seq.MAX_MIDI_PITCH), + event_codec.EventRange('program', note_seq.MIN_MIDI_PROGRAM, + note_seq.MAX_MIDI_PROGRAM), + event_codec.EventRange('tie', 0, 0) + ]) + res = metrics_utils.event_predictions_to_ns( + predictions, codec=codec, + encoding_spec=note_sequences.NoteEncodingWithTiesSpec) + self.assertProtoEquals(expected_ns, res['est_ns']) + self.assertEqual(0, res['est_invalid_events']) + self.assertEqual(0, res['est_dropped_events']) + np.testing.assert_array_equal([0, 0, 1, 1, 2, 2], res['raw_inputs']) + + def test_frame_metrics(self): + ref = np.zeros(shape=(128, 5)) + est = np.zeros(shape=(128, 5)) + + # one overlapping note, two false positives, two false negatives + ref[10, 0] = 127 + ref[10, 1] = 127 + ref[10, 2] = 127 + + est[10, 2] = 127 + est[10, 3] = 127 + est[10, 4] = 127 + + 
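+    # Only frame 2 of pitch 10 overlaps, so precision = 1 / (1 + 2) and
+    # recall = 1 / (1 + 2), i.e. both should come out to 1/3.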
prec, rec, _ = metrics_utils.frame_metrics(ref, est, velocity_threshold=1) + np.testing.assert_approx_equal(prec, 1/3) + np.testing.assert_approx_equal(rec, 1/3) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mt3/mixing.py b/mt3/mixing.py new file mode 100644 index 0000000000000000000000000000000000000000..aa2c9b2682105059633bf61f0e39a2e127a356ce --- /dev/null +++ b/mt3/mixing.py @@ -0,0 +1,91 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functions for mixing (in the audio sense) multiple transcription examples.""" + +from typing import Callable, Optional, Sequence + +import gin + +from mt3 import event_codec +from mt3 import run_length_encoding + +import numpy as np +import seqio +import tensorflow as tf + + +@gin.configurable +def mix_transcription_examples( + ds: tf.data.Dataset, + sequence_length: seqio.preprocessors.SequenceLengthType, + output_features: seqio.preprocessors.OutputFeaturesType, + codec: event_codec.Codec, + inputs_feature_key: str = 'inputs', + targets_feature_keys: Sequence[str] = ('targets',), + max_examples_per_mix: Optional[int] = None, + shuffle_buffer_size: int = seqio.SHUFFLE_BUFFER_SIZE +) -> Callable[..., tf.data.Dataset]: + """Preprocessor that mixes together "batches" of transcription examples. + + Args: + ds: Dataset of individual transcription examples, each of which should + have an 'inputs' field containing 1D audio samples (currently only + audio encoders that use raw samples as an intermediate representation + are supported), and a 'targets' field containing run-length encoded + note events. + sequence_length: Dictionary mapping feature key to length. + output_features: Dictionary mapping feature key to spec. + codec: An event_codec.Codec used to interpret the target events. + inputs_feature_key: Feature key for inputs which will be mixed as audio. + targets_feature_keys: List of feature keys for targets, each of which will + be merged (separately) as run-length encoded note events. + max_examples_per_mix: Maximum number of individual examples to mix together. + shuffle_buffer_size: Size of shuffle buffer to use for shuffle prior to + mixing. + + Returns: + Dataset containing mixed examples. + """ + if max_examples_per_mix is None: + return ds + + # TODO(iansimon): is there a way to use seqio's seed? + ds = tf.data.Dataset.sample_from_datasets([ + ds.shuffle( + buffer_size=shuffle_buffer_size // max_examples_per_mix + ).padded_batch(batch_size=i) for i in range(1, max_examples_per_mix + 1) + ]) + + def mix_inputs(ex): + samples = tf.reduce_sum(ex[inputs_feature_key], axis=0) + norm = tf.linalg.norm(samples, ord=np.inf) + ex[inputs_feature_key] = tf.math.divide_no_nan(samples, norm) + return ex + ds = ds.map(mix_inputs, num_parallel_calls=tf.data.experimental.AUTOTUNE) + + max_tokens = sequence_length['targets'] + if output_features['targets'].add_eos: + # Leave room to insert an EOS token. 
+ max_tokens -= 1 + + def mix_targets(ex): + for k in targets_feature_keys: + ex[k] = run_length_encoding.merge_run_length_encoded_targets( + targets=ex[k], + codec=codec) + return ex + ds = ds.map(mix_targets, num_parallel_calls=tf.data.experimental.AUTOTUNE) + + return ds diff --git a/mt3/models.py b/mt3/models.py new file mode 100644 index 0000000000000000000000000000000000000000..7986a18b1425c2a7869985ad986d4204fb8f8e6b --- /dev/null +++ b/mt3/models.py @@ -0,0 +1,152 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Feature converter and model for continuous inputs.""" + +from typing import Mapping +import seqio +from t5x import decoding +from t5x import models +import tensorflow as tf + + +class ContinuousInputsEncDecFeatureConverter(seqio.FeatureConverter): + """Feature converter for an encoder-decoder with continuous inputs.""" + + TASK_FEATURES = { + "inputs": seqio.FeatureConverter.FeatureSpec(dtype=tf.float32, rank=2), + "targets": seqio.FeatureConverter.FeatureSpec(dtype=tf.int32), + } + MODEL_FEATURES = { + "encoder_input_tokens": + seqio.FeatureConverter.FeatureSpec(dtype=tf.float32, rank=2), + "decoder_target_tokens": + seqio.FeatureConverter.FeatureSpec(dtype=tf.int32), + "decoder_input_tokens": + seqio.FeatureConverter.FeatureSpec(dtype=tf.int32), + "decoder_loss_weights": + seqio.FeatureConverter.FeatureSpec(dtype=tf.int32), + } + PACKING_FEATURE_DTYPES = { + "encoder_segment_ids": tf.int32, + "decoder_segment_ids": tf.int32, + "encoder_positions": tf.int32, + "decoder_positions": tf.int32 + } + + def _convert_features( + self, ds: tf.data.Dataset, + task_feature_lengths: Mapping[str, int]) -> tf.data.Dataset: + """Convert the dataset to be fed to the encoder-decoder model. + + The conversion process involves three steps + + 1. Each feature in the `task_feature_lengths` is trimmed/padded and + optionally packed depending on the value of self.pack. + 2. "inputs" fields are mapped to the encoder input and "targets" are mapped + to decoder input (after being shifted) and target. + + All the keys in the `task_feature_lengths` should be present in the input + dataset, which may contain some extra features that are not in the + `task_feature_lengths`. They will not be included in the output dataset. + One common scenario is the "inputs_pretokenized" and "targets_pretokenized" + fields. + + Args: + ds: an input tf.data.Dataset to be converted. + task_feature_lengths: a mapping from feature to its length. + + Returns: + ds: the converted dataset. + """ + + def convert_example( + features: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]: + # targets_segment_id is present only for a packed dataset. + decoder_input_tokens = seqio.autoregressive_inputs( + features["targets"], + sequence_id=features.get("targets_segment_ids", None)) + + d = {"encoder_input_tokens": features["inputs"], + "decoder_target_tokens": features["targets"], + "decoder_input_tokens": decoder_input_tokens, + # Loss is computed for all but the padding positions. 
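+           # (seqio.non_padding_position returns 1 where the target id differs
+           # from the default pad id of 0, and 0 elsewhere.)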
+ "decoder_loss_weights": + seqio.non_padding_position(features["targets"])} + + if self.pack: + d["encoder_segment_ids"] = features["inputs_segment_ids"] + d["decoder_segment_ids"] = features["targets_segment_ids"] + d["encoder_positions"] = features["inputs_positions"] + d["decoder_positions"] = features["targets_positions"] + + return d + + ds = self._pack_or_pad(ds, task_feature_lengths) + return ds.map( + convert_example, num_parallel_calls=tf.data.experimental.AUTOTUNE) + + def get_model_feature_lengths( + self, task_feature_lengths: Mapping[str, int]) -> Mapping[str, int]: + """Define the length relationship between input and output features.""" + encoder_length = task_feature_lengths["inputs"] + decoder_length = task_feature_lengths["targets"] + + model_feature_lengths = { + "encoder_input_tokens": encoder_length, + "decoder_target_tokens": decoder_length, + "decoder_input_tokens": decoder_length, + "decoder_loss_weights": decoder_length + } + if self.pack: + model_feature_lengths["encoder_segment_ids"] = encoder_length + model_feature_lengths["decoder_segment_ids"] = decoder_length + model_feature_lengths["encoder_positions"] = encoder_length + model_feature_lengths["decoder_positions"] = decoder_length + + return model_feature_lengths + + +class ContinuousInputsEncoderDecoderModel(models.EncoderDecoderModel): + """Encoder-decoder model with continuous inputs.""" + + FEATURE_CONVERTER_CLS = ContinuousInputsEncDecFeatureConverter + + def __init__(self, module, input_vocabulary, output_vocabulary, optimizer_def, + input_depth, decode_fn=decoding.beam_search, label_smoothing=0.0, + z_loss=0.0, loss_normalizing_factor=None): + super().__init__( + module=module, + input_vocabulary=input_vocabulary, + output_vocabulary=output_vocabulary, + optimizer_def=optimizer_def, + decode_fn=decode_fn, + label_smoothing=label_smoothing, + z_loss=z_loss, + loss_normalizing_factor=loss_normalizing_factor) + self._input_depth = input_depth + + def get_initial_variables(self, rng, input_shapes, input_types=None): + """Hacky override to bypass eval/infer inability to handle rank-3 inputs.""" + encoder_shape = input_shapes["encoder_input_tokens"] + if len(encoder_shape) == 2: + input_shapes = { + "encoder_input_tokens": (*encoder_shape, self._input_depth), + **{k: v for k, v in input_shapes.items() + if k != "encoder_input_tokens"} + } + else: + assert encoder_shape[-1] == self._input_depth + return super().get_initial_variables( + rng=rng, input_shapes=input_shapes, input_types=input_types) diff --git a/mt3/network.py b/mt3/network.py new file mode 100644 index 0000000000000000000000000000000000000000..19c0ac03194d72f740714013529c089e317e7b99 --- /dev/null +++ b/mt3/network.py @@ -0,0 +1,409 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""T5.1.1 Transformer model.""" + +from typing import Any, Sequence + +from flax import linen as nn +from flax import struct +import jax.numpy as jnp +from mt3 import layers + + +@struct.dataclass +class T5Config: + """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + vocab_size: int + # Activation dtypes. + dtype: Any = jnp.float32 + emb_dim: int = 512 + num_heads: int = 8 + num_encoder_layers: int = 6 + num_decoder_layers: int = 6 + head_dim: int = 64 + mlp_dim: int = 2048 + # Activation functions are retrieved from Flax. + mlp_activations: Sequence[str] = ('relu',) + dropout_rate: float = 0.1 + # If `True`, the embedding weights are used in the decoder output layer. + logits_via_embedding: bool = False + + +class EncoderLayer(nn.Module): + """Transformer encoder layer.""" + config: T5Config + + @nn.compact + def __call__(self, inputs, encoder_mask=None, deterministic=False): + cfg = self.config + + # Attention block. + assert inputs.ndim == 3 + x = layers.LayerNorm( + dtype=cfg.dtype, name='pre_attention_layer_norm')( + inputs) + # [batch, length, emb_dim] -> [batch, length, emb_dim] + x = layers.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + head_dim=cfg.head_dim, + dropout_rate=cfg.dropout_rate, + name='attention')( + x, x, encoder_mask, deterministic=deterministic) + x = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + x, deterministic=deterministic) + x = x + inputs + + # MLP block. + y = layers.LayerNorm(dtype=cfg.dtype, name='pre_mlp_layer_norm')(x) + # [batch, length, emb_dim] -> [batch, length, emb_dim] + y = layers.MlpBlock( + intermediate_dim=cfg.mlp_dim, + activations=cfg.mlp_activations, + intermediate_dropout_rate=cfg.dropout_rate, + dtype=cfg.dtype, + name='mlp', + )(y, deterministic=deterministic) + y = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + y, deterministic=deterministic) + y = y + x + + return y + + +class DecoderLayer(nn.Module): + """Transformer decoder layer that attends to the encoder.""" + config: T5Config + + @nn.compact + def __call__(self, + inputs, + encoded, + decoder_mask=None, + encoder_decoder_mask=None, + deterministic=False, + decode=False, + max_decode_length=None): + cfg = self.config + + # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] + x = layers.LayerNorm( + dtype=cfg.dtype, name='pre_self_attention_layer_norm')( + inputs) + + # Self-attention block + x = layers.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + head_dim=cfg.head_dim, + dropout_rate=cfg.dropout_rate, + name='self_attention')( + x, + x, + decoder_mask, + deterministic=deterministic, + decode=decode) + x = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + x, deterministic=deterministic) + x = x + inputs + + # Encoder-Decoder block. + y = layers.LayerNorm( + dtype=cfg.dtype, name='pre_cross_attention_layer_norm')( + x) + y = layers.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + head_dim=cfg.head_dim, + dropout_rate=cfg.dropout_rate, + name='encoder_decoder_attention')( + y, encoded, encoder_decoder_mask, deterministic=deterministic) + y = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + y, deterministic=deterministic) + y = y + x + + # MLP block. 
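+    # (Same structure as the encoder layer's MLP block: LayerNorm -> MlpBlock
+    # -> Dropout, followed by a residual connection.)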
+ z = layers.LayerNorm(dtype=cfg.dtype, name='pre_mlp_layer_norm')(y) + z = layers.MlpBlock( + intermediate_dim=cfg.mlp_dim, + activations=cfg.mlp_activations, + intermediate_dropout_rate=cfg.dropout_rate, + dtype=cfg.dtype, + name='mlp', + )(z, deterministic=deterministic) + z = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + z, deterministic=deterministic) + z = z + y + + return z + + +class Encoder(nn.Module): + """A stack of encoder layers.""" + config: T5Config + + @nn.compact + def __call__(self, + encoder_input_tokens, + encoder_mask=None, + deterministic=False): + cfg = self.config + assert encoder_input_tokens.ndim == 3 # [batch, length, depth] + + seq_length = encoder_input_tokens.shape[-2] + inputs_positions = jnp.arange(seq_length)[None, :] + + # [batch, length, depth] -> [batch, length, emb_dim] + x = layers.DenseGeneral( + cfg.emb_dim, + dtype=cfg.dtype, + kernel_init=nn.linear.default_kernel_init, + kernel_axes=('vocab', 'embed'), + name='continuous_inputs_projection')(encoder_input_tokens) + x = x + layers.FixedEmbed(features=cfg.emb_dim)(inputs_positions) + x = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + x, deterministic=deterministic) + x = x.astype(cfg.dtype) + + for lyr in range(cfg.num_encoder_layers): + # [batch, length, emb_dim] -> [batch, length, emb_dim] + x = EncoderLayer( + config=cfg, + name=f'layers_{lyr}')(x, encoder_mask, deterministic) + + x = layers.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x) + return nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic) + + +class Decoder(nn.Module): + """A stack of decoder layers as a part of an encoder-decoder architecture.""" + config: T5Config + + @nn.compact + def __call__(self, + encoded, + decoder_input_tokens, + decoder_positions=None, + decoder_mask=None, + encoder_decoder_mask=None, + deterministic=False, + decode=False, + max_decode_length=None): + cfg = self.config + assert decoder_input_tokens.ndim == 2 # [batch, len] + + seq_length = decoder_input_tokens.shape[-1] + decoder_positions = jnp.arange(seq_length)[None, :] + + # [batch, length] -> [batch, length, emb_dim] + y = layers.Embed( + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, + dtype=cfg.dtype, + attend_dtype=jnp.float32, # for logit training stability + embedding_init=nn.initializers.normal(stddev=1.0), + one_hot=True, + name='token_embedder')(decoder_input_tokens.astype('int32')) + y = y + layers.FixedEmbed(features=cfg.emb_dim)( + decoder_positions, decode=decode) + y = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + y, deterministic=deterministic) + y = y.astype(cfg.dtype) + + for lyr in range(cfg.num_decoder_layers): + # [batch, length, emb_dim] -> [batch, length, emb_dim] + y = DecoderLayer( + config=cfg, name=f'layers_{lyr}')( + y, + encoded, + decoder_mask=decoder_mask, + encoder_decoder_mask=encoder_decoder_mask, + deterministic=deterministic, + decode=decode, + max_decode_length=max_decode_length) + + y = layers.LayerNorm(dtype=cfg.dtype, name='decoder_norm')(y) + y = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + y, deterministic=deterministic) + + # [batch, length, emb_dim] -> [batch, length, vocab_size] + if cfg.logits_via_embedding: + # Use the transpose of embedding matrix for logit transform. + logits = self.shared_embedding.attend(y) + # Correctly normalize pre-softmax logits for this shared case. + logits = logits / jnp.sqrt(y.shape[-1]) + else: + logits = layers.DenseGeneral( + cfg.vocab_size, + dtype=jnp.float32, # Use float32 for stabiliity. 
+ kernel_axes=('embed', 'vocab'), + name='logits_dense')( + y) + return logits + + +class Transformer(nn.Module): + """An encoder-decoder Transformer model.""" + config: T5Config + + def setup(self): + cfg = self.config + + self.encoder = Encoder(config=cfg) + self.decoder = Decoder(config=cfg) + + def encode(self, + encoder_input_tokens, + encoder_segment_ids=None, + enable_dropout=True): + """Applies Transformer encoder-branch on the inputs.""" + cfg = self.config + assert encoder_input_tokens.ndim == 3 # (batch, length, depth) + + # Make padding attention mask; we don't actually mask out any input + # positions, letting the model potentially attend to the zero vector used as + # padding. + encoder_mask = layers.make_attention_mask( + jnp.ones(encoder_input_tokens.shape[:-1]), + jnp.ones(encoder_input_tokens.shape[:-1]), + dtype=cfg.dtype) + # Add segmentation block-diagonal attention mask if using segmented data. + if encoder_segment_ids is not None: + encoder_mask = layers.combine_masks( + encoder_mask, + layers.make_attention_mask( + encoder_segment_ids, + encoder_segment_ids, + jnp.equal, + dtype=cfg.dtype)) + + return self.encoder( + encoder_input_tokens, encoder_mask, deterministic=not enable_dropout) + + def decode( + self, + encoded, + encoder_input_tokens, # only needed for masks + decoder_input_tokens, + decoder_target_tokens, + encoder_segment_ids=None, + decoder_segment_ids=None, + decoder_positions=None, + enable_dropout=True, + decode=False, + max_decode_length=None): + """Applies Transformer decoder-branch on encoded-input and target.""" + cfg = self.config + + # Make padding attention masks. + if decode: + # Do not mask decoder attention based on targets padding at + # decoding/inference time. + decoder_mask = None + encoder_decoder_mask = layers.make_attention_mask( + jnp.ones_like(decoder_target_tokens), + jnp.ones(encoder_input_tokens.shape[:-1]), + dtype=cfg.dtype) + else: + decoder_mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=cfg.dtype, + decoder_segment_ids=decoder_segment_ids) + encoder_decoder_mask = layers.make_attention_mask( + decoder_target_tokens > 0, + jnp.ones(encoder_input_tokens.shape[:-1]), + dtype=cfg.dtype) + + # Add segmentation block-diagonal attention masks if using segmented data. + if encoder_segment_ids is not None: + if decode: + raise ValueError( + 'During decoding, packing should not be used but ' + '`encoder_segment_ids` was passed to `Transformer.decode`.') + + encoder_decoder_mask = layers.combine_masks( + encoder_decoder_mask, + layers.make_attention_mask( + decoder_segment_ids, + encoder_segment_ids, + jnp.equal, + dtype=cfg.dtype)) + + logits = self.decoder( + encoded, + decoder_input_tokens=decoder_input_tokens, + decoder_positions=decoder_positions, + decoder_mask=decoder_mask, + encoder_decoder_mask=encoder_decoder_mask, + deterministic=not enable_dropout, + decode=decode, + max_decode_length=max_decode_length) + return logits.astype(self.config.dtype) + + def __call__(self, + encoder_input_tokens, + decoder_input_tokens, + decoder_target_tokens, + encoder_segment_ids=None, + decoder_segment_ids=None, + encoder_positions=None, + decoder_positions=None, + *, + enable_dropout: bool = True, + decode: bool = False): + """Applies Transformer model on the inputs. + + This method requires both decoder_target_tokens and decoder_input_tokens, + which is a shifted version of the former. For a packed dataset, it usually + has additional processing applied. 
For example, the first element of each + sequence has id 0 instead of the shifted EOS id from the previous sequence. + + Args: + encoder_input_tokens: input data to the encoder. + decoder_input_tokens: input token to the decoder. + decoder_target_tokens: target token to the decoder. + encoder_segment_ids: encoder segmentation info for packed examples. + decoder_segment_ids: decoder segmentation info for packed examples. + encoder_positions: encoder subsequence positions for packed examples. + decoder_positions: decoder subsequence positions for packed examples. + enable_dropout: Ensables dropout if set to True. + decode: Whether to prepare and use an autoregressive cache. + + Returns: + logits array from full transformer. + """ + encoded = self.encode( + encoder_input_tokens, + encoder_segment_ids=encoder_segment_ids, + enable_dropout=enable_dropout) + + return self.decode( + encoded, + encoder_input_tokens, # only used for masks + decoder_input_tokens, + decoder_target_tokens, + encoder_segment_ids=encoder_segment_ids, + decoder_segment_ids=decoder_segment_ids, + decoder_positions=decoder_positions, + enable_dropout=enable_dropout, + decode=decode) diff --git a/mt3/note_sequences.py b/mt3/note_sequences.py new file mode 100644 index 0000000000000000000000000000000000000000..bac1a137ddcce6fcd35c5d3b639372be05920fb6 --- /dev/null +++ b/mt3/note_sequences.py @@ -0,0 +1,446 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper functions that operate on NoteSequence protos.""" + +import dataclasses +import itertools + +from typing import MutableMapping, MutableSet, Optional, Sequence, Tuple + +from mt3 import event_codec +from mt3 import run_length_encoding +from mt3 import vocabularies + +import note_seq + +DEFAULT_VELOCITY = 100 +DEFAULT_NOTE_DURATION = 0.01 + +# Quantization can result in zero-length notes; enforce a minimum duration. 
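+# (Enforced by _add_note_to_sequence below whenever decoded notes are written
+# to the output NoteSequence.)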
+MIN_NOTE_DURATION = 0.01 + + +@dataclasses.dataclass +class TrackSpec: + name: str + program: int = 0 + is_drum: bool = False + + +def extract_track(ns, program, is_drum): + track = note_seq.NoteSequence(ticks_per_quarter=220) + track_notes = [note for note in ns.notes + if note.program == program and note.is_drum == is_drum] + track.notes.extend(track_notes) + track.total_time = (max(note.end_time for note in track.notes) + if track.notes else 0.0) + return track + + +def trim_overlapping_notes(ns: note_seq.NoteSequence) -> note_seq.NoteSequence: + """Trim overlapping notes from a NoteSequence, dropping zero-length notes.""" + ns_trimmed = note_seq.NoteSequence() + ns_trimmed.CopyFrom(ns) + channels = set((note.pitch, note.program, note.is_drum) + for note in ns_trimmed.notes) + for pitch, program, is_drum in channels: + notes = [note for note in ns_trimmed.notes if note.pitch == pitch + and note.program == program and note.is_drum == is_drum] + sorted_notes = sorted(notes, key=lambda note: note.start_time) + for i in range(1, len(sorted_notes)): + if sorted_notes[i - 1].end_time > sorted_notes[i].start_time: + sorted_notes[i - 1].end_time = sorted_notes[i].start_time + valid_notes = [note for note in ns_trimmed.notes + if note.start_time < note.end_time] + del ns_trimmed.notes[:] + ns_trimmed.notes.extend(valid_notes) + return ns_trimmed + + +def assign_instruments(ns: note_seq.NoteSequence) -> None: + """Assign instrument numbers to notes; modifies NoteSequence in place.""" + program_instruments = {} + for note in ns.notes: + if note.program not in program_instruments and not note.is_drum: + num_instruments = len(program_instruments) + note.instrument = (num_instruments if num_instruments < 9 + else num_instruments + 1) + program_instruments[note.program] = note.instrument + elif note.is_drum: + note.instrument = 9 + else: + note.instrument = program_instruments[note.program] + + +def validate_note_sequence(ns: note_seq.NoteSequence) -> None: + """Raise ValueError if NoteSequence contains invalid notes.""" + for note in ns.notes: + if note.start_time >= note.end_time: + raise ValueError('note has start time >= end time: %f >= %f' % + (note.start_time, note.end_time)) + if note.velocity == 0: + raise ValueError('note has zero velocity') + + +def note_arrays_to_note_sequence( + onset_times: Sequence[float], + pitches: Sequence[int], + offset_times: Optional[Sequence[float]] = None, + velocities: Optional[Sequence[int]] = None, + programs: Optional[Sequence[int]] = None, + is_drums: Optional[Sequence[bool]] = None +) -> note_seq.NoteSequence: + """Convert note onset / offset / pitch / velocity arrays to NoteSequence.""" + ns = note_seq.NoteSequence(ticks_per_quarter=220) + for onset_time, offset_time, pitch, velocity, program, is_drum in itertools.zip_longest( + onset_times, [] if offset_times is None else offset_times, + pitches, [] if velocities is None else velocities, + [] if programs is None else programs, + [] if is_drums is None else is_drums): + if offset_time is None: + offset_time = onset_time + DEFAULT_NOTE_DURATION + if velocity is None: + velocity = DEFAULT_VELOCITY + if program is None: + program = 0 + if is_drum is None: + is_drum = False + ns.notes.add( + start_time=onset_time, + end_time=offset_time, + pitch=pitch, + velocity=velocity, + program=program, + is_drum=is_drum) + ns.total_time = max(ns.total_time, offset_time) + assign_instruments(ns) + return ns + + +@dataclasses.dataclass +class NoteEventData: + pitch: int + velocity: Optional[int] = None + program: 
Optional[int] = None + is_drum: Optional[bool] = None + instrument: Optional[int] = None + + +def note_sequence_to_onsets( + ns: note_seq.NoteSequence +) -> Tuple[Sequence[float], Sequence[NoteEventData]]: + """Extract note onsets and pitches from NoteSequence proto.""" + # Sort by pitch to use as a tiebreaker for subsequent stable sort. + notes = sorted(ns.notes, key=lambda note: note.pitch) + return ([note.start_time for note in notes], + [NoteEventData(pitch=note.pitch) for note in notes]) + + +def note_sequence_to_onsets_and_offsets( + ns: note_seq.NoteSequence, +) -> Tuple[Sequence[float], Sequence[NoteEventData]]: + """Extract onset & offset times and pitches from a NoteSequence proto. + + The onset & offset times will not necessarily be in sorted order. + + Args: + ns: NoteSequence from which to extract onsets and offsets. + + Returns: + times: A list of note onset and offset times. + values: A list of NoteEventData objects where velocity is zero for note + offsets. + """ + # Sort by pitch and put offsets before onsets as a tiebreaker for subsequent + # stable sort. + notes = sorted(ns.notes, key=lambda note: note.pitch) + times = ([note.end_time for note in notes] + + [note.start_time for note in notes]) + values = ([NoteEventData(pitch=note.pitch, velocity=0) for note in notes] + + [NoteEventData(pitch=note.pitch, velocity=note.velocity) + for note in notes]) + return times, values + + +def note_sequence_to_onsets_and_offsets_and_programs( + ns: note_seq.NoteSequence, +) -> Tuple[Sequence[float], Sequence[NoteEventData]]: + """Extract onset & offset times and pitches & programs from a NoteSequence. + + The onset & offset times will not necessarily be in sorted order. + + Args: + ns: NoteSequence from which to extract onsets and offsets. + + Returns: + times: A list of note onset and offset times. + values: A list of NoteEventData objects where velocity is zero for note + offsets. + """ + # Sort by program and pitch and put offsets before onsets as a tiebreaker for + # subsequent stable sort. 
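+  # Note that drum notes contribute onset events only; their offsets are
+  # skipped below, since drums are given a fixed duration at decoding time.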
+ notes = sorted(ns.notes, + key=lambda note: (note.is_drum, note.program, note.pitch)) + times = ([note.end_time for note in notes if not note.is_drum] + + [note.start_time for note in notes]) + values = ([NoteEventData(pitch=note.pitch, velocity=0, + program=note.program, is_drum=False) + for note in notes if not note.is_drum] + + [NoteEventData(pitch=note.pitch, velocity=note.velocity, + program=note.program, is_drum=note.is_drum) + for note in notes]) + return times, values + + +@dataclasses.dataclass +class NoteEncodingState: + """Encoding state for note transcription, keeping track of active pitches.""" + # velocity bin for active pitches and programs + active_pitches: MutableMapping[Tuple[int, int], int] = dataclasses.field( + default_factory=dict) + + +def note_event_data_to_events( + state: Optional[NoteEncodingState], + value: NoteEventData, + codec: event_codec.Codec, +) -> Sequence[event_codec.Event]: + """Convert note event data to a sequence of events.""" + if value.velocity is None: + # onsets only, no program or velocity + return [event_codec.Event('pitch', value.pitch)] + else: + num_velocity_bins = vocabularies.num_velocity_bins_from_codec(codec) + velocity_bin = vocabularies.velocity_to_bin( + value.velocity, num_velocity_bins) + if value.program is None: + # onsets + offsets + velocities only, no programs + if state is not None: + state.active_pitches[(value.pitch, 0)] = velocity_bin + return [event_codec.Event('velocity', velocity_bin), + event_codec.Event('pitch', value.pitch)] + else: + if value.is_drum: + # drum events use a separate vocabulary + return [event_codec.Event('velocity', velocity_bin), + event_codec.Event('drum', value.pitch)] + else: + # program + velocity + pitch + if state is not None: + state.active_pitches[(value.pitch, value.program)] = velocity_bin + return [event_codec.Event('program', value.program), + event_codec.Event('velocity', velocity_bin), + event_codec.Event('pitch', value.pitch)] + + +def note_encoding_state_to_events( + state: NoteEncodingState +) -> Sequence[event_codec.Event]: + """Output program and pitch events for active notes plus a final tie event.""" + events = [] + for pitch, program in sorted( + state.active_pitches.keys(), key=lambda k: k[::-1]): + if state.active_pitches[(pitch, program)]: + events += [event_codec.Event('program', program), + event_codec.Event('pitch', pitch)] + events.append(event_codec.Event('tie', 0)) + return events + + +@dataclasses.dataclass +class NoteDecodingState: + """Decoding state for note transcription.""" + current_time: float = 0.0 + # velocity to apply to subsequent pitch events (zero for note-off) + current_velocity: int = DEFAULT_VELOCITY + # program to apply to subsequent pitch events + current_program: int = 0 + # onset time and velocity for active pitches and programs + active_pitches: MutableMapping[Tuple[int, int], + Tuple[float, int]] = dataclasses.field( + default_factory=dict) + # pitches (with programs) to continue from previous segment + tied_pitches: MutableSet[Tuple[int, int]] = dataclasses.field( + default_factory=set) + # whether or not we are in the tie section at the beginning of a segment + is_tie_section: bool = False + # partially-decoded NoteSequence + note_sequence: note_seq.NoteSequence = dataclasses.field( + default_factory=lambda: note_seq.NoteSequence(ticks_per_quarter=220)) + + +def decode_note_onset_event( + state: NoteDecodingState, + time: float, + event: event_codec.Event, + codec: event_codec.Codec, +) -> None: + """Process note onset event and update 
decoding state.""" + if event.type == 'pitch': + state.note_sequence.notes.add( + start_time=time, end_time=time + DEFAULT_NOTE_DURATION, + pitch=event.value, velocity=DEFAULT_VELOCITY) + state.note_sequence.total_time = max(state.note_sequence.total_time, + time + DEFAULT_NOTE_DURATION) + else: + raise ValueError('unexpected event type: %s' % event.type) + + +def _add_note_to_sequence( + ns: note_seq.NoteSequence, + start_time: float, end_time: float, pitch: int, velocity: int, + program: int = 0, is_drum: bool = False +) -> None: + end_time = max(end_time, start_time + MIN_NOTE_DURATION) + ns.notes.add( + start_time=start_time, end_time=end_time, + pitch=pitch, velocity=velocity, program=program, is_drum=is_drum) + ns.total_time = max(ns.total_time, end_time) + + +def decode_note_event( + state: NoteDecodingState, + time: float, + event: event_codec.Event, + codec: event_codec.Codec +) -> None: + """Process note event and update decoding state.""" + if time < state.current_time: + raise ValueError('event time < current time, %f < %f' % ( + time, state.current_time)) + state.current_time = time + if event.type == 'pitch': + pitch = event.value + if state.is_tie_section: + # "tied" pitch + if (pitch, state.current_program) not in state.active_pitches: + raise ValueError('inactive pitch/program in tie section: %d/%d' % + (pitch, state.current_program)) + if (pitch, state.current_program) in state.tied_pitches: + raise ValueError('pitch/program is already tied: %d/%d' % + (pitch, state.current_program)) + state.tied_pitches.add((pitch, state.current_program)) + elif state.current_velocity == 0: + # note offset + if (pitch, state.current_program) not in state.active_pitches: + raise ValueError('note-off for inactive pitch/program: %d/%d' % + (pitch, state.current_program)) + onset_time, onset_velocity = state.active_pitches.pop( + (pitch, state.current_program)) + _add_note_to_sequence( + state.note_sequence, start_time=onset_time, end_time=time, + pitch=pitch, velocity=onset_velocity, program=state.current_program) + else: + # note onset + if (pitch, state.current_program) in state.active_pitches: + # The pitch is already active; this shouldn't really happen but we'll + # try to handle it gracefully by ending the previous note and starting a + # new one. 
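+        # (As in the note-off case above: close out the stored onset at the
+        # current time, then record a fresh onset below.)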
+ onset_time, onset_velocity = state.active_pitches.pop( + (pitch, state.current_program)) + _add_note_to_sequence( + state.note_sequence, start_time=onset_time, end_time=time, + pitch=pitch, velocity=onset_velocity, program=state.current_program) + state.active_pitches[(pitch, state.current_program)] = ( + time, state.current_velocity) + elif event.type == 'drum': + # drum onset (drums have no offset) + if state.current_velocity == 0: + raise ValueError('velocity cannot be zero for drum event') + offset_time = time + DEFAULT_NOTE_DURATION + _add_note_to_sequence( + state.note_sequence, start_time=time, end_time=offset_time, + pitch=event.value, velocity=state.current_velocity, is_drum=True) + elif event.type == 'velocity': + # velocity change + num_velocity_bins = vocabularies.num_velocity_bins_from_codec(codec) + velocity = vocabularies.bin_to_velocity(event.value, num_velocity_bins) + state.current_velocity = velocity + elif event.type == 'program': + # program change + state.current_program = event.value + elif event.type == 'tie': + # end of tie section; end active notes that weren't declared tied + if not state.is_tie_section: + raise ValueError('tie section end event when not in tie section') + for (pitch, program) in list(state.active_pitches.keys()): + if (pitch, program) not in state.tied_pitches: + onset_time, onset_velocity = state.active_pitches.pop((pitch, program)) + _add_note_to_sequence( + state.note_sequence, + start_time=onset_time, end_time=state.current_time, + pitch=pitch, velocity=onset_velocity, program=program) + state.is_tie_section = False + else: + raise ValueError('unexpected event type: %s' % event.type) + + +def begin_tied_pitches_section(state: NoteDecodingState) -> None: + """Begin the tied pitches section at the start of a segment.""" + state.tied_pitches = set() + state.is_tie_section = True + + +def flush_note_decoding_state( + state: NoteDecodingState +) -> note_seq.NoteSequence: + """End all active notes and return resulting NoteSequence.""" + for onset_time, _ in state.active_pitches.values(): + state.current_time = max(state.current_time, onset_time + MIN_NOTE_DURATION) + for (pitch, program) in list(state.active_pitches.keys()): + onset_time, onset_velocity = state.active_pitches.pop((pitch, program)) + _add_note_to_sequence( + state.note_sequence, start_time=onset_time, end_time=state.current_time, + pitch=pitch, velocity=onset_velocity, program=program) + assign_instruments(state.note_sequence) + return state.note_sequence + + +class NoteEncodingSpecType(run_length_encoding.EventEncodingSpec): + pass + + +# encoding spec for modeling note onsets only +NoteOnsetEncodingSpec = NoteEncodingSpecType( + init_encoding_state_fn=lambda: None, + encode_event_fn=note_event_data_to_events, + encoding_state_to_events_fn=None, + init_decoding_state_fn=NoteDecodingState, + begin_decoding_segment_fn=lambda state: None, + decode_event_fn=decode_note_onset_event, + flush_decoding_state_fn=lambda state: state.note_sequence) + + +# encoding spec for modeling onsets and offsets +NoteEncodingSpec = NoteEncodingSpecType( + init_encoding_state_fn=lambda: None, + encode_event_fn=note_event_data_to_events, + encoding_state_to_events_fn=None, + init_decoding_state_fn=NoteDecodingState, + begin_decoding_segment_fn=lambda state: None, + decode_event_fn=decode_note_event, + flush_decoding_state_fn=flush_note_decoding_state) + + +# encoding spec for modeling onsets and offsets, with a "tie" section at the +# beginning of each segment listing already-active notes 
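+# During decoding, the events before each segment's first 'tie' token
+# re-declare notes still active from the previous segment; any active note
+# that is not re-declared is ended at that point.  A rough usage sketch, with
+# `segments` (a list of (start_time, tokens) pairs) and `codec` assumed to be
+# supplied by the caller:
+#
+#   state = NoteEncodingWithTiesSpec.init_decoding_state_fn()
+#   for start_time, tokens in segments:
+#     NoteEncodingWithTiesSpec.begin_decoding_segment_fn(state)
+#     run_length_encoding.decode_events(
+#         state=state, tokens=tokens, start_time=start_time, max_time=None,
+#         codec=codec,
+#         decode_event_fn=NoteEncodingWithTiesSpec.decode_event_fn)
+#   ns = NoteEncodingWithTiesSpec.flush_decoding_state_fn(state)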
+NoteEncodingWithTiesSpec = NoteEncodingSpecType( + init_encoding_state_fn=NoteEncodingState, + encode_event_fn=note_event_data_to_events, + encoding_state_to_events_fn=note_encoding_state_to_events, + init_decoding_state_fn=NoteDecodingState, + begin_decoding_segment_fn=begin_tied_pitches_section, + decode_event_fn=decode_note_event, + flush_decoding_state_fn=flush_note_decoding_state) diff --git a/mt3/note_sequences_test.py b/mt3/note_sequences_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7cda0c7ee6043591fa85adcda2712b566fd54889 --- /dev/null +++ b/mt3/note_sequences_test.py @@ -0,0 +1,505 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for note_sequences.""" + +from mt3 import event_codec +from mt3 import note_sequences +from mt3 import run_length_encoding + +import note_seq +import numpy as np +import tensorflow as tf + +codec = event_codec.Codec( + max_shift_steps=100, + steps_per_second=100, + event_ranges=[ + event_codec.EventRange('pitch', note_seq.MIN_MIDI_PITCH, + note_seq.MAX_MIDI_PITCH), + event_codec.EventRange('velocity', 0, 127), + event_codec.EventRange('drum', note_seq.MIN_MIDI_PITCH, + note_seq.MAX_MIDI_PITCH), + event_codec.EventRange('program', note_seq.MIN_MIDI_PROGRAM, + note_seq.MAX_MIDI_PROGRAM), + event_codec.EventRange('tie', 0, 0) + ]) + + +class RunLengthEncodingTest(tf.test.TestCase): + + def test_encode_and_index_note_sequence(self): + ns = note_seq.NoteSequence() + ns.notes.add(start_time=1.0, + end_time=1.1, + pitch=61, + velocity=100) + ns.notes.add(start_time=2.0, + end_time=2.1, + pitch=62, + velocity=100) + ns.notes.add(start_time=3.0, + end_time=3.1, + pitch=63, + velocity=100) + ns.total_time = ns.notes[-1].end_time + + frame_times = np.arange(0, 4, step=.001) + + event_times, event_values = note_sequences.note_sequence_to_onsets(ns) + events, event_start_indices, event_end_indices, _, _ = run_length_encoding.encode_and_index_events( + state=None, event_times=event_times, event_values=event_values, + encode_event_fn=note_sequences.note_event_data_to_events, + codec=codec, frame_times=frame_times) + + self.assertEqual(len(frame_times), len(event_start_indices)) + self.assertEqual(len(frame_times), len(event_end_indices)) + self.assertLen(events, 403) + expected_events = ([1] * 100 + + [162] + + [1] * 100 + + [163] + + [1] * 100 + + [164] + + [1] * 100) + np.testing.assert_array_equal(expected_events, events) + + self.assertEqual(event_start_indices[0], 0) + self.assertEqual(event_end_indices[0], 0) + + self.assertEqual(162, events[100]) + self.assertEqual(1.0, frame_times[1000]) + self.assertEqual(event_start_indices[1000], 100) + self.assertEqual(event_end_indices[1000], 100) + + self.assertEqual(163, events[201]) + self.assertEqual(2.0, frame_times[2000]) + self.assertEqual(event_start_indices[2000], 201) + self.assertEqual(event_end_indices[2000], 201) + + self.assertEqual(164, events[302]) + self.assertEqual(3.0, frame_times[3000]) + 
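+    # (With max_shift_steps=100, pitch p is encoded as token 101 + p, so the
+    # token 164 checked above is the onset of pitch 63; frame index 3000
+    # corresponds to t=3.0s given the 1 ms frame spacing.)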
self.assertEqual(event_start_indices[3000], 302) + self.assertEqual(event_end_indices[3000], 302) + + self.assertEqual(1, events[-1]) + self.assertEqual(3.999, frame_times[-1]) + self.assertEqual(event_start_indices[-1], 402) + self.assertEqual(event_end_indices[-1], len(expected_events)) + + def test_encode_and_index_note_sequence_velocity(self): + ns = note_seq.NoteSequence() + ns.notes.add(start_time=1.0, + end_time=3.0, + pitch=61, + velocity=1) + ns.notes.add(start_time=2.0, + end_time=4.0, + pitch=62, + velocity=127) + ns.total_time = ns.notes[-1].end_time + + frame_times = np.arange(0, 4, step=.001) + + event_times, event_values = ( + note_sequences.note_sequence_to_onsets_and_offsets(ns)) + events, event_start_indices, event_end_indices, _, _ = run_length_encoding.encode_and_index_events( + state=None, event_times=event_times, event_values=event_values, + encode_event_fn=note_sequences.note_event_data_to_events, + codec=codec, frame_times=frame_times) + + self.assertEqual(len(frame_times), len(event_start_indices)) + self.assertEqual(len(frame_times), len(event_end_indices)) + self.assertLen(events, 408) + expected_events = ([1] * 100 + + [230, 162] + + [1] * 100 + + [356, 163] + + [1] * 100 + + [229, 162] + + [1] * 100 + + [229, 163]) + np.testing.assert_array_equal(expected_events, events) + + self.assertEqual(event_start_indices[0], 0) + self.assertEqual(event_end_indices[0], 0) + + self.assertEqual(230, events[100]) + self.assertEqual(162, events[101]) + self.assertEqual(1.0, frame_times[1000]) + self.assertEqual(event_start_indices[1000], 100) + self.assertEqual(event_end_indices[1000], 100) + + self.assertEqual(356, events[202]) + self.assertEqual(163, events[203]) + self.assertEqual(2.0, frame_times[2000]) + self.assertEqual(event_start_indices[2000], 202) + self.assertEqual(event_end_indices[2000], 202) + + self.assertEqual(229, events[304]) + self.assertEqual(162, events[305]) + self.assertEqual(3.0, frame_times[3000]) + self.assertEqual(event_start_indices[3000], 304) + self.assertEqual(event_end_indices[3000], 304) + + self.assertEqual(229, events[406]) + self.assertEqual(163, events[407]) + self.assertEqual(3.999, frame_times[-1]) + self.assertEqual(event_start_indices[-1], 405) + self.assertEqual(event_end_indices[-1], len(expected_events)) + + def test_encode_and_index_note_sequence_multitrack(self): + ns = note_seq.NoteSequence() + ns.notes.add(start_time=0.0, + end_time=1.0, + pitch=37, + velocity=127, + is_drum=True) + ns.notes.add(start_time=1.0, + end_time=3.0, + pitch=61, + velocity=127, + program=0) + ns.notes.add(start_time=2.0, + end_time=4.0, + pitch=62, + velocity=127, + program=40) + ns.total_time = ns.notes[-1].end_time + + frame_times = np.arange(0, 4, step=.001) + + event_times, event_values = ( + note_sequences.note_sequence_to_onsets_and_offsets_and_programs(ns)) + (tokens, event_start_indices, event_end_indices, state_tokens, + state_event_indices) = run_length_encoding.encode_and_index_events( + state=note_sequences.NoteEncodingState(), + event_times=event_times, event_values=event_values, + encode_event_fn=note_sequences.note_event_data_to_events, + codec=codec, frame_times=frame_times, + encoding_state_to_events_fn=( + note_sequences.note_encoding_state_to_events)) + + self.assertEqual(len(frame_times), len(event_start_indices)) + self.assertEqual(len(frame_times), len(event_end_indices)) + self.assertEqual(len(frame_times), len(state_event_indices)) + self.assertLen(tokens, 414) + + expected_events = ( + [event_codec.Event('velocity', 127), 
event_codec.Event('drum', 37)] + + [event_codec.Event('shift', 1)] * 100 + + [event_codec.Event('program', 0), + event_codec.Event('velocity', 127), event_codec.Event('pitch', 61)] + + [event_codec.Event('shift', 1)] * 100 + + [event_codec.Event('program', 40), + event_codec.Event('velocity', 127), event_codec.Event('pitch', 62)] + + [event_codec.Event('shift', 1)] * 100 + + [event_codec.Event('program', 0), + event_codec.Event('velocity', 0), event_codec.Event('pitch', 61)] + + [event_codec.Event('shift', 1)] * 100 + + [event_codec.Event('program', 40), + event_codec.Event('velocity', 0), event_codec.Event('pitch', 62)]) + expected_tokens = [codec.encode_event(e) for e in expected_events] + np.testing.assert_array_equal(expected_tokens, tokens) + + expected_state_events = [ + event_codec.Event('tie', 0), # state prior to first drum + event_codec.Event('tie', 0), # state prior to first onset + event_codec.Event('program', 0), # state prior to second onset + event_codec.Event('pitch', 61), # | + event_codec.Event('tie', 0), # | + event_codec.Event('program', 0), # state prior to first offset + event_codec.Event('pitch', 61), # | + event_codec.Event('program', 40), # | + event_codec.Event('pitch', 62), # | + event_codec.Event('tie', 0), # | + event_codec.Event('program', 40), # state prior to second offset + event_codec.Event('pitch', 62), # | + event_codec.Event('tie', 0) # | + ] + expected_state_tokens = [codec.encode_event(e) + for e in expected_state_events] + np.testing.assert_array_equal(expected_state_tokens, state_tokens) + + self.assertEqual(event_start_indices[0], 0) + self.assertEqual(event_end_indices[0], 0) + self.assertEqual(state_event_indices[0], 0) + + self.assertEqual(1.0, frame_times[1000]) + self.assertEqual(event_start_indices[1000], 102) + self.assertEqual(event_end_indices[1000], 102) + self.assertEqual(state_event_indices[1000], 1) + + self.assertEqual(2.0, frame_times[2000]) + self.assertEqual(event_start_indices[2000], 205) + self.assertEqual(event_end_indices[2000], 205) + self.assertEqual(state_event_indices[2000], 2) + + self.assertEqual(3.0, frame_times[3000]) + self.assertEqual(event_start_indices[3000], 308) + self.assertEqual(event_end_indices[3000], 308) + self.assertEqual(state_event_indices[3000], 5) + + self.assertEqual(3.999, frame_times[-1]) + self.assertEqual(event_start_indices[-1], 410) + self.assertEqual(event_end_indices[-1], len(expected_events)) + self.assertEqual(state_event_indices[-1], 10) + + def test_encode_and_index_note_sequence_last_token_alignment(self): + ns = note_seq.NoteSequence() + ns.notes.add(start_time=0.0, + end_time=0.1, + pitch=60, + velocity=100) + ns.total_time = ns.notes[-1].end_time + + frame_times = np.arange(0, 1.008, step=.008) + + event_times, event_values = note_sequences.note_sequence_to_onsets(ns) + events, event_start_indices, event_end_indices, _, _ = run_length_encoding.encode_and_index_events( + state=None, + event_times=event_times, + event_values=event_values, + encode_event_fn=note_sequences.note_event_data_to_events, + codec=codec, + frame_times=frame_times) + + self.assertEqual(len(frame_times), len(event_start_indices)) + self.assertEqual(len(frame_times), len(event_end_indices)) + self.assertLen(events, 102) + expected_events = [161] + [1] * 101 + + np.testing.assert_array_equal(expected_events, events) + + self.assertEqual(event_start_indices[0], 0) + self.assertEqual(event_end_indices[0], 0) + self.assertEqual(event_start_indices[125], 101) + self.assertEqual(event_end_indices[125], 102) + + def 
test_decode_note_sequence_events(self): + events = [25, 161, 50, 162] + + decoding_state = note_sequences.NoteDecodingState() + invalid_ids, dropped_events = run_length_encoding.decode_events( + state=decoding_state, tokens=events, start_time=0, max_time=None, + codec=codec, decode_event_fn=note_sequences.decode_note_onset_event) + ns = note_sequences.flush_note_decoding_state(decoding_state) + + self.assertEqual(0, invalid_ids) + self.assertEqual(0, dropped_events) + expected_ns = note_seq.NoteSequence(ticks_per_quarter=220) + expected_ns.notes.add( + pitch=60, + velocity=100, + start_time=0.25, + end_time=0.26) + expected_ns.notes.add( + pitch=61, + velocity=100, + start_time=0.50, + end_time=0.51) + expected_ns.total_time = 0.51 + self.assertProtoEquals(expected_ns, ns) + + def test_decode_note_sequence_events_onsets_only(self): + events = [5, 161, 25, 162] + + decoding_state = note_sequences.NoteDecodingState() + invalid_ids, dropped_events = run_length_encoding.decode_events( + state=decoding_state, tokens=events, start_time=0, max_time=None, + codec=codec, decode_event_fn=note_sequences.decode_note_onset_event) + ns = note_sequences.flush_note_decoding_state(decoding_state) + + self.assertEqual(0, invalid_ids) + self.assertEqual(0, dropped_events) + expected_ns = note_seq.NoteSequence(ticks_per_quarter=220) + expected_ns.notes.add( + pitch=60, + velocity=100, + start_time=0.05, + end_time=0.06) + expected_ns.notes.add( + pitch=61, + velocity=100, + start_time=0.25, + end_time=0.26) + expected_ns.total_time = 0.26 + self.assertProtoEquals(expected_ns, ns) + + def test_decode_note_sequence_events_velocity(self): + events = [5, 356, 161, 25, 229, 161] + + decoding_state = note_sequences.NoteDecodingState() + invalid_ids, dropped_events = run_length_encoding.decode_events( + state=decoding_state, tokens=events, start_time=0, max_time=None, + codec=codec, decode_event_fn=note_sequences.decode_note_event) + ns = note_sequences.flush_note_decoding_state(decoding_state) + + self.assertEqual(0, invalid_ids) + self.assertEqual(0, dropped_events) + expected_ns = note_seq.NoteSequence(ticks_per_quarter=220) + expected_ns.notes.add( + pitch=60, + velocity=127, + start_time=0.05, + end_time=0.25) + expected_ns.total_time = 0.25 + self.assertProtoEquals(expected_ns, ns) + + def test_decode_note_sequence_events_missing_offset(self): + events = [5, 356, 161, 10, 161, 25, 229, 161] + + decoding_state = note_sequences.NoteDecodingState() + invalid_ids, dropped_events = run_length_encoding.decode_events( + state=decoding_state, tokens=events, start_time=0, max_time=None, + codec=codec, decode_event_fn=note_sequences.decode_note_event) + ns = note_sequences.flush_note_decoding_state(decoding_state) + + self.assertEqual(0, invalid_ids) + self.assertEqual(0, dropped_events) + expected_ns = note_seq.NoteSequence(ticks_per_quarter=220) + expected_ns.notes.add( + pitch=60, + velocity=127, + start_time=0.05, + end_time=0.10) + expected_ns.notes.add( + pitch=60, + velocity=127, + start_time=0.10, + end_time=0.25) + expected_ns.total_time = 0.25 + self.assertProtoEquals(expected_ns, ns) + + def test_decode_note_sequence_events_multitrack(self): + events = [5, 525, 356, 161, 15, 356, 394, 25, 525, 229, 161] + + decoding_state = note_sequences.NoteDecodingState() + invalid_ids, dropped_events = run_length_encoding.decode_events( + state=decoding_state, tokens=events, start_time=0, max_time=None, + codec=codec, decode_event_fn=note_sequences.decode_note_event) + ns = 
note_sequences.flush_note_decoding_state(decoding_state) + + self.assertEqual(0, invalid_ids) + self.assertEqual(0, dropped_events) + expected_ns = note_seq.NoteSequence(ticks_per_quarter=220) + expected_ns.notes.add( + pitch=37, + velocity=127, + start_time=0.15, + end_time=0.16, + instrument=9, + is_drum=True) + expected_ns.notes.add( + pitch=60, + velocity=127, + start_time=0.05, + end_time=0.25, + program=40) + expected_ns.total_time = 0.25 + self.assertProtoEquals(expected_ns, ns) + + def test_decode_note_sequence_events_invalid_tokens(self): + events = [5, -1, 161, -2, 25, 162, 9999] + + decoding_state = note_sequences.NoteDecodingState() + invalid_events, dropped_events = run_length_encoding.decode_events( + state=decoding_state, tokens=events, start_time=0, max_time=None, + codec=codec, decode_event_fn=note_sequences.decode_note_onset_event) + ns = note_sequences.flush_note_decoding_state(decoding_state) + + self.assertEqual(3, invalid_events) + self.assertEqual(0, dropped_events) + expected_ns = note_seq.NoteSequence(ticks_per_quarter=220) + expected_ns.notes.add( + pitch=60, + velocity=100, + start_time=0.05, + end_time=0.06) + expected_ns.notes.add( + pitch=61, + velocity=100, + start_time=0.25, + end_time=0.26) + expected_ns.total_time = 0.26 + self.assertProtoEquals(expected_ns, ns) + + def test_decode_note_sequence_events_allow_event_at_exactly_max_time(self): + events = [161, 25, 162] + + decoding_state = note_sequences.NoteDecodingState() + invalid_ids, dropped_events = run_length_encoding.decode_events( + state=decoding_state, tokens=events, start_time=1.0, max_time=1.25, + codec=codec, decode_event_fn=note_sequences.decode_note_onset_event) + ns = note_sequences.flush_note_decoding_state(decoding_state) + + self.assertEqual(0, invalid_ids) + self.assertEqual(0, dropped_events) + expected_ns = note_seq.NoteSequence(ticks_per_quarter=220) + expected_ns.notes.add( + pitch=60, + velocity=100, + start_time=1.00, + end_time=1.01) + expected_ns.notes.add( + pitch=61, + velocity=100, + start_time=1.25, + end_time=1.26) + expected_ns.total_time = 1.26 + self.assertProtoEquals(expected_ns, ns) + + def test_decode_note_sequence_events_dropped_events(self): + events = [5, 161, 30, 162] + + decoding_state = note_sequences.NoteDecodingState() + invalid_ids, dropped_events = run_length_encoding.decode_events( + state=decoding_state, tokens=events, start_time=1.0, max_time=1.25, + codec=codec, decode_event_fn=note_sequences.decode_note_onset_event) + ns = note_sequences.flush_note_decoding_state(decoding_state) + + self.assertEqual(0, invalid_ids) + self.assertEqual(2, dropped_events) + expected_ns = note_seq.NoteSequence(ticks_per_quarter=220) + expected_ns.notes.add( + pitch=60, + velocity=100, + start_time=1.05, + end_time=1.06) + expected_ns.total_time = 1.06 + self.assertProtoEquals(expected_ns, ns) + + def test_decode_note_sequence_events_invalid_events(self): + events = [25, 230, 50, 161] + + decoding_state = note_sequences.NoteDecodingState() + invalid_ids, dropped_events = run_length_encoding.decode_events( + state=decoding_state, tokens=events, start_time=0, max_time=None, + codec=codec, decode_event_fn=note_sequences.decode_note_onset_event) + ns = note_sequences.flush_note_decoding_state(decoding_state) + + self.assertEqual(1, invalid_ids) + self.assertEqual(0, dropped_events) + expected_ns = note_seq.NoteSequence(ticks_per_quarter=220) + expected_ns.notes.add( + pitch=60, + velocity=100, + start_time=0.50, + end_time=0.51) + expected_ns.total_time = 0.51 + 
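+ # Token 230 is a velocity event, which the onset-only decoder counts as
+ # invalid; only the pitch-60 onset at 0.5 seconds should be decoded.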
self.assertProtoEquals(expected_ns, ns) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mt3/preprocessors.py b/mt3/preprocessors.py new file mode 100644 index 0000000000000000000000000000000000000000..f0094d407bcec5da69238bfc0370077558407420 --- /dev/null +++ b/mt3/preprocessors.py @@ -0,0 +1,669 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transcription preprocessors.""" + +from typing import Any, Callable, Mapping, Optional, Sequence, Tuple + +from absl import logging +import gin +from immutabledict import immutabledict +import librosa + +from mt3 import event_codec +from mt3 import note_sequences +from mt3 import run_length_encoding +from mt3 import spectrograms +from mt3 import vocabularies + +import note_seq +import numpy as np +import seqio +import tensorflow as tf + + +def add_unique_id(ds: tf.data.Dataset) -> tf.data.Dataset: + """Add unique integer ID to each example in a dataset.""" + def add_id_field(i, ex): + ex['unique_id'] = [i] + return ex + return ds.enumerate().map( + add_id_field, num_parallel_calls=tf.data.experimental.AUTOTUNE) + + +@seqio.map_over_dataset +def pad_notesequence_array(ex): + """Pad the NoteSequence array so that it can later be "split".""" + ex['sequence'] = tf.pad(tf.expand_dims(ex['sequence'], 0), + [[0, len(ex['input_times']) - 1]]) + return ex + + +@seqio.map_over_dataset +def add_dummy_targets(ex): + """Add dummy targets; used in eval when targets are not actually used.""" + ex['targets'] = np.array([], dtype=np.int32) + return ex + + +def _audio_to_frames( + samples: Sequence[float], + spectrogram_config: spectrograms.SpectrogramConfig, +) -> Tuple[Sequence[Sequence[int]], np.ndarray]: + """Convert audio samples to non-overlapping frames and frame times.""" + frame_size = spectrogram_config.hop_width + logging.info('Padding %d samples to multiple of %d', len(samples), frame_size) + samples = np.pad(samples, + [0, frame_size - len(samples) % frame_size], + mode='constant') + + frames = spectrograms.split_audio(samples, spectrogram_config) + + num_frames = len(samples) // frame_size + logging.info('Encoded %d samples to %d frames (%d samples each)', + len(samples), num_frames, frame_size) + + times = np.arange(num_frames) / spectrogram_config.frames_per_second + return frames, times + + +def _include_inputs(ds, input_record, fields_to_omit=('audio',)): + """Include fields from input record (other than audio) in dataset records.""" + def include_inputs_fn(output_record): + for key in set(input_record.keys()) - set(output_record.keys()): + output_record[key] = input_record[key] + for key in fields_to_omit: + del output_record[key] + return output_record + return ds.map(include_inputs_fn, + num_parallel_calls=tf.data.experimental.AUTOTUNE) + + +def tokenize_transcription_example( + ds: tf.data.Dataset, spectrogram_config: spectrograms.SpectrogramConfig, + codec: event_codec.Codec, is_training_data: bool, + onsets_only: bool, include_ties: bool, audio_is_samples: bool, + id_feature_key: 
Optional[str] = None +) -> tf.data.Dataset: + """Tokenize a note transcription example for run-length encoding. + + Outputs include: + inputs: audio sample frames, num_frames-by-frame_size + input_time: timestamp for each frame + targets: symbolic sequence of note-related events + input_event_start_indices: start target index for every input index + input_event_end_indices: end target index for every input index + + Args: + ds: Input dataset. + spectrogram_config: Spectrogram configuration. + codec: Event vocabulary codec. + is_training_data: Unused. + onsets_only: If True, include only onset events (not offset, velocity, or + program). + include_ties: If True, also write state events containing active notes to + support a "tie" section after run-length encoding. + audio_is_samples: If True, audio is floating-point samples instead of + serialized WAV. + id_feature_key: If not None, replace sequence ID with specified key field + from the dataset. + + Returns: + Dataset with the outputs described above. + """ + del is_training_data + + if onsets_only and include_ties: + raise ValueError('Ties not supported when only modeling onsets.') + + def tokenize(sequence, audio, sample_rate, example_id=None): + ns = note_seq.NoteSequence.FromString(sequence) + note_sequences.validate_note_sequence(ns) + + if example_id is not None: + ns.id = example_id + + if audio_is_samples: + samples = audio + if sample_rate != spectrogram_config.sample_rate: + samples = librosa.resample( + samples, sample_rate, spectrogram_config.sample_rate) + else: + samples = note_seq.audio_io.wav_data_to_samples_librosa( + audio, sample_rate=spectrogram_config.sample_rate) + + logging.info('Got samples for %s::%s with length %d', + ns.id, ns.filename, len(samples)) + + frames, frame_times = _audio_to_frames(samples, spectrogram_config) + + if onsets_only: + times, values = note_sequences.note_sequence_to_onsets(ns) + else: + ns = note_seq.apply_sustain_control_changes(ns) + times, values = ( + note_sequences.note_sequence_to_onsets_and_offsets_and_programs(ns)) + + # The original NoteSequence can have a lot of control changes we don't need; + # delete them. 
+ del ns.control_changes[:] + + (events, event_start_indices, event_end_indices, + state_events, state_event_indices) = ( + run_length_encoding.encode_and_index_events( + state=note_sequences.NoteEncodingState() if include_ties else None, + event_times=times, + event_values=values, + encode_event_fn=note_sequences.note_event_data_to_events, + codec=codec, + frame_times=frame_times, + encoding_state_to_events_fn=( + note_sequences.note_encoding_state_to_events + if include_ties else None))) + + yield { + 'inputs': frames, + 'input_times': frame_times, + 'targets': events, + 'input_event_start_indices': event_start_indices, + 'input_event_end_indices': event_end_indices, + 'state_events': state_events, + 'input_state_event_indices': state_event_indices, + 'sequence': ns.SerializeToString() + } + + def process_record(input_record): + if audio_is_samples and 'sample_rate' not in input_record: + raise ValueError('Must provide sample rate when audio is samples.') + + args = [ + input_record['sequence'], + input_record['audio'], + input_record['sample_rate'] if 'sample_rate' in input_record else 0 + ] + if id_feature_key is not None: + args.append(input_record[id_feature_key]) + + ds = tf.data.Dataset.from_generator( + tokenize, + output_signature={ + 'inputs': + tf.TensorSpec( + shape=(None, spectrogram_config.hop_width), + dtype=tf.float32), + 'input_times': + tf.TensorSpec(shape=(None,), dtype=tf.float32), + 'targets': + tf.TensorSpec(shape=(None,), dtype=tf.int32), + 'input_event_start_indices': + tf.TensorSpec(shape=(None,), dtype=tf.int32), + 'input_event_end_indices': + tf.TensorSpec(shape=(None,), dtype=tf.int32), + 'state_events': + tf.TensorSpec(shape=(None,), dtype=tf.int32), + 'input_state_event_indices': + tf.TensorSpec(shape=(None,), dtype=tf.int32), + 'sequence': + tf.TensorSpec(shape=(), dtype=tf.string) + }, + args=args) + + ds = _include_inputs(ds, input_record) + return ds + + tokenized_records = ds.flat_map(process_record) + return tokenized_records + + +def tokenize_guitarset_example( + ds: tf.data.Dataset, spectrogram_config: spectrograms.SpectrogramConfig, + codec: event_codec.Codec, is_training_data: bool, + onsets_only: bool, include_ties: bool +) -> tf.data.Dataset: + """Tokenize a GuitarSet transcription example.""" + def _preprocess_example(ex, name): + assert 'inst_names' not in ex, 'Key `inst_names` is already populated.' 
+ ex['inst_names'] = [name] + ex['instrument_sequences'] = [ex.pop('sequence')] + return ex + + ds = ds.map( + lambda x: _preprocess_example(x, 'Clean Guitar'), + num_parallel_calls=tf.data.experimental.AUTOTUNE) + ds = tokenize_example_with_program_lookup( + ds, + spectrogram_config=spectrogram_config, + codec=codec, + is_training_data=is_training_data, + inst_name_to_program_fn=guitarset_instrument_to_program, + onsets_only=onsets_only, + include_ties=include_ties, + id_feature_key='id') + return ds + + +def guitarset_instrument_to_program(instrument: str) -> int: + """GuitarSet is all guitar, return the first MIDI guitar program.""" + if instrument == 'Clean Guitar': + return 24 + else: + raise ValueError('Unknown GuitarSet instrument: %s' % instrument) + + +def tokenize_example_with_program_lookup( + ds: tf.data.Dataset, + spectrogram_config: spectrograms.SpectrogramConfig, + codec: event_codec.Codec, + is_training_data: bool, + onsets_only: bool, + include_ties: bool, + inst_name_to_program_fn: Callable[[str], int], + id_feature_key: Optional[str] = None +) -> tf.data.Dataset: + """Tokenize an example, optionally looking up and assigning program numbers. + + This can be used by any dataset where a mapping function can be used to + map from the inst_names feature to a set of program numbers. + + Args: + ds: Input dataset. + spectrogram_config: Spectrogram configuration. + codec: Event vocabulary codec. + is_training_data: Unused. + onsets_only: If True, include only onset events (not offset & velocity). + include_ties: If True, include tie events. + inst_name_to_program_fn: A function used to map the instrument names + in the `inst_names` feature of each example to a MIDI program number. + id_feature_key: If not None, replace sequence ID with specified key field + from the dataset. + + Returns: + Dataset with the outputs described above. + """ + del is_training_data + + def tokenize(sequences, inst_names, audio, example_id=None): + # Add all the notes from the tracks to a single NoteSequence. + ns = note_seq.NoteSequence(ticks_per_quarter=220) + tracks = [note_seq.NoteSequence.FromString(seq) for seq in sequences] + assert len(tracks) == len(inst_names) + for track, inst_name in zip(tracks, inst_names): + program = inst_name_to_program_fn( + inst_name.decode()) + + # Note that there are no pitch bends in URMP data; the below block will + # raise PitchBendError if one is encountered. + add_track_to_notesequence(ns, track, program=program, is_drum=False, + ignore_pitch_bends=False) + + note_sequences.assign_instruments(ns) + note_sequences.validate_note_sequence(ns) + + if example_id is not None: + ns.id = example_id + + samples = note_seq.audio_io.wav_data_to_samples_librosa( + audio, sample_rate=spectrogram_config.sample_rate) + + logging.info('Got samples for %s::%s with length %d', + ns.id, ns.filename, len(samples)) + + frames, frame_times = _audio_to_frames(samples, spectrogram_config) + + if onsets_only: + times, values = note_sequences.note_sequence_to_onsets(ns) + else: + times, values = ( + note_sequences.note_sequence_to_onsets_and_offsets_and_programs(ns)) + + # The original NoteSequence can have a lot of control changes we don't need; + # delete them. 
+ del ns.control_changes[:] + + (events, event_start_indices, event_end_indices, + state_events, state_event_indices) = ( + run_length_encoding.encode_and_index_events( + state=note_sequences.NoteEncodingState() if include_ties else None, + event_times=times, + event_values=values, + encode_event_fn=note_sequences.note_event_data_to_events, + codec=codec, + frame_times=frame_times, + encoding_state_to_events_fn=( + note_sequences.note_encoding_state_to_events + if include_ties else None))) + + yield { + 'inputs': frames, + 'input_times': frame_times, + 'targets': events, + 'input_event_start_indices': event_start_indices, + 'input_event_end_indices': event_end_indices, + 'state_events': state_events, + 'input_state_event_indices': state_event_indices, + 'sequence': ns.SerializeToString() + } + + def process_record(input_record): + args = [ + input_record['instrument_sequences'], + input_record['inst_names'], + input_record['audio'], + ] + if id_feature_key is not None: + args.append(input_record[id_feature_key]) + + ds = tf.data.Dataset.from_generator( + tokenize, + output_signature={ + 'inputs': + tf.TensorSpec( + shape=(None, spectrogram_config.hop_width), + dtype=tf.float32), + 'input_times': + tf.TensorSpec(shape=(None,), dtype=tf.float32), + 'targets': + tf.TensorSpec(shape=(None,), dtype=tf.int32), + 'input_event_start_indices': + tf.TensorSpec(shape=(None,), dtype=tf.int32), + 'input_event_end_indices': + tf.TensorSpec(shape=(None,), dtype=tf.int32), + 'state_events': + tf.TensorSpec(shape=(None,), dtype=tf.int32), + 'input_state_event_indices': + tf.TensorSpec(shape=(None,), dtype=tf.int32), + 'sequence': + tf.TensorSpec(shape=(), dtype=tf.string) + }, + args=args) + + ds = _include_inputs(ds, input_record) + return ds + + tokenized_records = ds.flat_map(process_record) + return tokenized_records + + +_URMP_INSTRUMENT_PROGRAMS = immutabledict({ + 'vn': 40, # violin + 'va': 41, # viola + 'vc': 42, # cello + 'db': 43, # double bass + 'tpt': 56, # trumpet + 'tbn': 57, # trombone + 'tba': 58, # tuba + 'hn': 60, # French horn + 'sax': 64, # saxophone + 'ob': 68, # oboe + 'bn': 70, # bassoon + 'cl': 71, # clarinet + 'fl': 73 # flute +}) + + +def urmp_instrument_to_program(urmp_instrument: str) -> int: + """Fetch the program number associated with a given URMP instrument code.""" + if urmp_instrument not in _URMP_INSTRUMENT_PROGRAMS: + raise ValueError('unknown URMP instrument: %s' % urmp_instrument) + return _URMP_INSTRUMENT_PROGRAMS[urmp_instrument] + + +_SLAKH_CLASS_PROGRAMS = immutabledict({ + 'Acoustic Piano': 0, + 'Electric Piano': 4, + 'Chromatic Percussion': 8, + 'Organ': 16, + 'Acoustic Guitar': 24, + 'Clean Electric Guitar': 26, + 'Distorted Electric Guitar': 29, + 'Acoustic Bass': 32, + 'Electric Bass': 33, + 'Violin': 40, + 'Viola': 41, + 'Cello': 42, + 'Contrabass': 43, + 'Orchestral Harp': 46, + 'Timpani': 47, + 'String Ensemble': 48, + 'Synth Strings': 50, + 'Choir and Voice': 52, + 'Orchestral Hit': 55, + 'Trumpet': 56, + 'Trombone': 57, + 'Tuba': 58, + 'French Horn': 60, + 'Brass Section': 61, + 'Soprano/Alto Sax': 64, + 'Tenor Sax': 66, + 'Baritone Sax': 67, + 'Oboe': 68, + 'English Horn': 69, + 'Bassoon': 70, + 'Clarinet': 71, + 'Pipe': 73, + 'Synth Lead': 80, + 'Synth Pad': 88 +}) + + +def slakh_class_to_program_and_is_drum(slakh_class: str) -> Tuple[int, bool]: + """Map Slakh class string to program number and boolean indicating drums.""" + if slakh_class == 'Drums': + return 0, True + elif slakh_class not in _SLAKH_CLASS_PROGRAMS: + raise ValueError('unknown Slakh 
class: %s' % slakh_class) + else: + return _SLAKH_CLASS_PROGRAMS[slakh_class], False + + +class PitchBendError(Exception): + pass + + +def add_track_to_notesequence(ns: note_seq.NoteSequence, + track: note_seq.NoteSequence, + program: int, is_drum: bool, + ignore_pitch_bends: bool): + """Add a track to a NoteSequence.""" + if track.pitch_bends and not ignore_pitch_bends: + raise PitchBendError + track_sus = note_seq.apply_sustain_control_changes(track) + for note in track_sus.notes: + note.program = program + note.is_drum = is_drum + ns.notes.extend([note]) + ns.total_time = max(ns.total_time, note.end_time) + + +def tokenize_slakh_example( + ds: tf.data.Dataset, + spectrogram_config: spectrograms.SpectrogramConfig, + codec: event_codec.Codec, + is_training_data: bool, + onsets_only: bool, + include_ties: bool, + track_specs: Optional[Sequence[note_sequences.TrackSpec]], + ignore_pitch_bends: bool +) -> tf.data.Dataset: + """Tokenize a Slakh multitrack note transcription example.""" + def tokenize(sequences, samples, sample_rate, inst_names, example_id): + if sample_rate != spectrogram_config.sample_rate: + samples = librosa.resample( + samples, sample_rate, spectrogram_config.sample_rate) + + frames, frame_times = _audio_to_frames(samples, spectrogram_config) + + # Add all the notes from the tracks to a single NoteSequence. + ns = note_seq.NoteSequence(ticks_per_quarter=220) + tracks = [note_seq.NoteSequence.FromString(seq) for seq in sequences] + assert len(tracks) == len(inst_names) + if track_specs: + # Specific tracks expected. + assert len(tracks) == len(track_specs) + for track, spec, inst_name in zip(tracks, track_specs, inst_names): + # Make sure the instrument name matches what we expect. + assert inst_name.decode() == spec.name + try: + add_track_to_notesequence(ns, track, + program=spec.program, is_drum=spec.is_drum, + ignore_pitch_bends=ignore_pitch_bends) + except PitchBendError: + # TODO(iansimon): is there a way to count these? + return + else: + for track, inst_name in zip(tracks, inst_names): + # Instrument name should be Slakh class. + program, is_drum = slakh_class_to_program_and_is_drum( + inst_name.decode()) + try: + add_track_to_notesequence(ns, track, program=program, is_drum=is_drum, + ignore_pitch_bends=ignore_pitch_bends) + except PitchBendError: + # TODO(iansimon): is there a way to count these? + return + + note_sequences.assign_instruments(ns) + note_sequences.validate_note_sequence(ns) + if is_training_data: + # Trim overlapping notes in training (as our event vocabulary cannot + # represent them), but preserve original NoteSequence for eval. 
+ ns = note_sequences.trim_overlapping_notes(ns) + + ns.id = example_id + + if onsets_only: + times, values = note_sequences.note_sequence_to_onsets(ns) + else: + times, values = ( + note_sequences.note_sequence_to_onsets_and_offsets_and_programs(ns)) + + (events, event_start_indices, event_end_indices, + state_events, state_event_indices) = ( + run_length_encoding.encode_and_index_events( + state=note_sequences.NoteEncodingState() if include_ties else None, + event_times=times, + event_values=values, + encode_event_fn=note_sequences.note_event_data_to_events, + codec=codec, + frame_times=frame_times, + encoding_state_to_events_fn=( + note_sequences.note_encoding_state_to_events + if include_ties else None))) + + yield { + 'inputs': frames, + 'input_times': frame_times, + 'targets': events, + 'input_event_start_indices': event_start_indices, + 'input_event_end_indices': event_end_indices, + 'state_events': state_events, + 'input_state_event_indices': state_event_indices, + 'sequence': ns.SerializeToString() + } + + def process_record(input_record): + ds = tf.data.Dataset.from_generator( + tokenize, + output_signature={ + 'inputs': + tf.TensorSpec( + shape=(None, spectrogram_config.hop_width), + dtype=tf.float32), + 'input_times': + tf.TensorSpec(shape=(None,), dtype=tf.float32), + 'targets': + tf.TensorSpec(shape=(None,), dtype=tf.int32), + 'input_event_start_indices': + tf.TensorSpec(shape=(None,), dtype=tf.int32), + 'input_event_end_indices': + tf.TensorSpec(shape=(None,), dtype=tf.int32), + 'state_events': + tf.TensorSpec(shape=(None,), dtype=tf.int32), + 'input_state_event_indices': + tf.TensorSpec(shape=(None,), dtype=tf.int32), + 'sequence': + tf.TensorSpec(shape=(), dtype=tf.string) + }, + args=[ + input_record['note_sequences'], input_record['mix'], + input_record['audio_sample_rate'], input_record['inst_names'], + input_record['track_id'] + ]) + + ds = _include_inputs(ds, input_record, fields_to_omit=['mix', 'stems']) + return ds + + tokenized_records = ds.flat_map(process_record) + return tokenized_records + + + + +@seqio.map_over_dataset +def compute_spectrograms(ex, spectrogram_config): + samples = spectrograms.flatten_frames(ex['inputs']) + ex['inputs'] = spectrograms.compute_spectrogram(samples, spectrogram_config) + ex['raw_inputs'] = samples + return ex + + +def handle_too_long(dataset: tf.data.Dataset, + output_features: seqio.preprocessors.OutputFeaturesType, + sequence_length: seqio.preprocessors.SequenceLengthType, + skip: bool = False) -> tf.data.Dataset: + """Handle sequences that are too long, by either failing or skipping them.""" + def max_length_for_key(key): + max_length = sequence_length[key] + if output_features[key].add_eos: + max_length -= 1 + return max_length + + if skip: + # Drop examples where one of the features is longer than its maximum + # sequence length. + def is_not_too_long(ex): + return not tf.reduce_any( + [k in output_features and len(v) > max_length_for_key(k) + for k, v in ex.items()]) + dataset = dataset.filter(is_not_too_long) + + def assert_not_too_long(key: str, value: tf.Tensor) -> tf.Tensor: + if key in output_features: + max_length = max_length_for_key(key) + tf.debugging.assert_less_equal( + tf.shape(value)[0], max_length, + f'Value for "{key}" field exceeds maximum length') + return value + + # Assert that no examples have features longer than their maximum sequence + # length. 
+ return dataset.map( + lambda ex: {k: assert_not_too_long(k, v) for k, v in ex.items()}, + num_parallel_calls=tf.data.experimental.AUTOTUNE) + + +@gin.configurable +def map_midi_programs( + ds: tf.data.Dataset, + codec: event_codec.Codec, + granularity_type: str = 'full', + feature_key: str = 'targets' +) -> Mapping[str, Any]: + """Apply MIDI program map to token sequences.""" + granularity = vocabularies.PROGRAM_GRANULARITIES[granularity_type] + def _map_program_tokens(ex): + ex[feature_key] = granularity.tokens_map_fn(ex[feature_key], codec) + return ex + return ds.map(_map_program_tokens, + num_parallel_calls=tf.data.experimental.AUTOTUNE) diff --git a/mt3/pytest.ini b/mt3/pytest.ini new file mode 100644 index 0000000000000000000000000000000000000000..5f1cd9f4e5ef281a47dc69180674da15e6b28f56 --- /dev/null +++ b/mt3/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +python_files = *_test.py +log_level = INFO \ No newline at end of file diff --git a/mt3/run_length_encoding.py b/mt3/run_length_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..04b71a76eac17c385c446165178f85a77d193708 --- /dev/null +++ b/mt3/run_length_encoding.py @@ -0,0 +1,423 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for run length encoding.""" + +import dataclasses +from typing import Any, Callable, Mapping, MutableMapping, Tuple, Optional, Sequence, TypeVar + +from absl import logging +from mt3 import event_codec + +import numpy as np +import seqio +import tensorflow as tf + +Event = event_codec.Event + +# These should be type variables, but unfortunately those are incompatible with +# dataclasses. 
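+# The aliases below are placeholders for the per-task event data, encoding
+# state, and decoding state types consumed by the callbacks in
+# EventEncodingSpec.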
+EventData = Any +EncodingState = Any +DecodingState = Any +DecodeResult = Any + +T = TypeVar('T', bound=EventData) +ES = TypeVar('ES', bound=EncodingState) +DS = TypeVar('DS', bound=DecodingState) + + +@dataclasses.dataclass +class EventEncodingSpec: + """Spec for encoding events.""" + # initialize encoding state + init_encoding_state_fn: Callable[[], EncodingState] + # convert EventData into zero or more events, updating encoding state + encode_event_fn: Callable[[EncodingState, EventData, event_codec.Codec], + Sequence[event_codec.Event]] + # convert encoding state (at beginning of segment) into events + encoding_state_to_events_fn: Optional[Callable[[EncodingState], + Sequence[event_codec.Event]]] + # create empty decoding state + init_decoding_state_fn: Callable[[], DecodingState] + # update decoding state when entering new segment + begin_decoding_segment_fn: Callable[[DecodingState], None] + # consume time and Event and update decoding state + decode_event_fn: Callable[ + [DecodingState, float, event_codec.Event, event_codec.Codec], None] + # flush decoding state into result + flush_decoding_state_fn: Callable[[DecodingState], DecodeResult] + + +def encode_and_index_events( + state: ES, + event_times: Sequence[float], + event_values: Sequence[T], + encode_event_fn: Callable[[ES, T, event_codec.Codec], + Sequence[event_codec.Event]], + codec: event_codec.Codec, + frame_times: Sequence[float], + encoding_state_to_events_fn: Optional[ + Callable[[ES], Sequence[event_codec.Event]]] = None, +) -> Tuple[Sequence[int], Sequence[int], Sequence[int], + Sequence[int], Sequence[int]]: + """Encode a sequence of timed events and index to audio frame times. + + Encodes time shifts as repeated single step shifts for later run length + encoding. + + Optionally, also encodes a sequence of "state events", keeping track of the + current encoding state at each audio frame. This can be used e.g. to prepend + events representing the current state to a targets segment. + + Args: + state: Initial event encoding state. + event_times: Sequence of event times. + event_values: Sequence of event values. + encode_event_fn: Function that transforms event value into a sequence of one + or more event_codec.Event objects. + codec: An event_codec.Codec object that maps Event objects to indices. + frame_times: Time for every audio frame. + encoding_state_to_events_fn: Function that transforms encoding state into a + sequence of one or more event_codec.Event objects. + + Returns: + events: Encoded events and shifts. + event_start_indices: Corresponding start event index for every audio frame. + Note: one event can correspond to multiple audio indices due to sampling + rate differences. This makes splitting sequences tricky because the same + event can appear at the end of one sequence and the beginning of + another. + event_end_indices: Corresponding end event index for every audio frame. Used + to ensure when slicing that one chunk ends where the next begins. Should + always be true that event_end_indices[i] = event_start_indices[i + 1]. + state_events: Encoded "state" events representing the encoding state before + each event. + state_event_indices: Corresponding state event index for every audio frame. 
+ """ + indices = np.argsort(event_times, kind='stable') + event_steps = [round(event_times[i] * codec.steps_per_second) + for i in indices] + event_values = [event_values[i] for i in indices] + + events = [] + state_events = [] + event_start_indices = [] + state_event_indices = [] + + cur_step = 0 + cur_event_idx = 0 + cur_state_event_idx = 0 + + def fill_event_start_indices_to_cur_step(): + while(len(event_start_indices) < len(frame_times) and + frame_times[len(event_start_indices)] < + cur_step / codec.steps_per_second): + event_start_indices.append(cur_event_idx) + state_event_indices.append(cur_state_event_idx) + + for event_step, event_value in zip(event_steps, event_values): + while event_step > cur_step: + events.append(codec.encode_event(Event(type='shift', value=1))) + cur_step += 1 + fill_event_start_indices_to_cur_step() + cur_event_idx = len(events) + cur_state_event_idx = len(state_events) + if encoding_state_to_events_fn: + # Dump state to state events *before* processing the next event, because + # we want to capture the state prior to the occurrence of the event. + for e in encoding_state_to_events_fn(state): + state_events.append(codec.encode_event(e)) + for e in encode_event_fn(state, event_value, codec): + events.append(codec.encode_event(e)) + + # After the last event, continue filling out the event_start_indices array. + # The inequality is not strict because if our current step lines up exactly + # with (the start of) an audio frame, we need to add an additional shift event + # to "cover" that frame. + while cur_step / codec.steps_per_second <= frame_times[-1]: + events.append(codec.encode_event(Event(type='shift', value=1))) + cur_step += 1 + fill_event_start_indices_to_cur_step() + cur_event_idx = len(events) + + # Now fill in event_end_indices. We need this extra array to make sure that + # when we slice events, each slice ends exactly where the subsequent slice + # begins. + event_end_indices = event_start_indices[1:] + [len(events)] + + events = np.array(events) + state_events = np.array(state_events) + event_start_indices = np.array(event_start_indices) + event_end_indices = np.array(event_end_indices) + state_event_indices = np.array(state_event_indices) + + return (events, event_start_indices, event_end_indices, + state_events, state_event_indices) + + +@seqio.map_over_dataset +def extract_target_sequence_with_indices(features, state_events_end_token=None): + """Extract target sequence corresponding to audio token segment.""" + target_start_idx = features['input_event_start_indices'][0] + target_end_idx = features['input_event_end_indices'][-1] + + features['targets'] = features['targets'][target_start_idx:target_end_idx] + + if state_events_end_token is not None: + # Extract the state events corresponding to the audio start token, and + # prepend them to the targets array. + state_event_start_idx = features['input_state_event_indices'][0] + state_event_end_idx = state_event_start_idx + 1 + while features['state_events'][ + state_event_end_idx - 1] != state_events_end_token: + state_event_end_idx += 1 + features['targets'] = tf.concat([ + features['state_events'][state_event_start_idx:state_event_end_idx], + features['targets'] + ], axis=0) + + return features + + +def remove_redundant_state_changes_fn( + codec: event_codec.Codec, + feature_key: str = 'targets', + state_change_event_types: Sequence[str] = () +) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]: + """Return preprocessing function that removes redundant state change events. 
+ + Args: + codec: The event_codec.Codec used to interpret the events. + feature_key: The feature key for which to remove redundant state changes. + state_change_event_types: A list of event types that represent state + changes; tokens corresponding to these event types will be interpreted + as state changes and redundant ones will be removed. + + Returns: + A preprocessing function that removes redundant state change events. + """ + state_change_event_ranges = [codec.event_type_range(event_type) + for event_type in state_change_event_types] + + def remove_redundant_state_changes( + features: MutableMapping[str, Any], + ) -> Mapping[str, Any]: + """Remove redundant tokens e.g. duplicate velocity changes from sequence.""" + current_state = tf.zeros(len(state_change_event_ranges), dtype=tf.int32) + output = tf.constant([], dtype=tf.int32) + + for event in features[feature_key]: + # Let autograph know that the shape of 'output' will change during the + # loop. + tf.autograph.experimental.set_loop_options( + shape_invariants=[(output, tf.TensorShape([None]))]) + is_redundant = False + for i, (min_index, max_index) in enumerate(state_change_event_ranges): + if (min_index <= event) and (event <= max_index): + if current_state[i] == event: + is_redundant = True + current_state = tf.tensor_scatter_nd_update( + current_state, indices=[[i]], updates=[event]) + if not is_redundant: + output = tf.concat([output, [event]], axis=0) + + features[feature_key] = output + return features + + return seqio.map_over_dataset(remove_redundant_state_changes) + + +def run_length_encode_shifts_fn( + codec: event_codec.Codec, + feature_key: str = 'targets' +) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]: + """Return a function that run-length encodes shifts for a given codec. + + Args: + codec: The Codec to use for shift events. + feature_key: The feature key for which to run-length encode shifts. + + Returns: + A preprocessing function that run-length encodes single-step shifts. + """ + def run_length_encode_shifts( + features: MutableMapping[str, Any] + ) -> Mapping[str, Any]: + """Combine leading/interior shifts, trim trailing shifts. + + Args: + features: Dict of features to process. + + Returns: + A dict of features. + """ + events = features[feature_key] + + shift_steps = 0 + total_shift_steps = 0 + output = tf.constant([], dtype=tf.int32) + + for event in events: + # Let autograph know that the shape of 'output' will change during the + # loop. + tf.autograph.experimental.set_loop_options( + shape_invariants=[(output, tf.TensorShape([None]))]) + if codec.is_shift_event_index(event): + shift_steps += 1 + total_shift_steps += 1 + + else: + # Once we've reached a non-shift event, RLE all previous shift events + # before outputting the non-shift event. + if shift_steps > 0: + shift_steps = total_shift_steps + while shift_steps > 0: + output_steps = tf.minimum(codec.max_shift_steps, shift_steps) + output = tf.concat([output, [output_steps]], axis=0) + shift_steps -= output_steps + output = tf.concat([output, [event]], axis=0) + + features[feature_key] = output + return features + + return seqio.map_over_dataset(run_length_encode_shifts) + + +def merge_run_length_encoded_targets( + targets: np.ndarray, + codec: event_codec.Codec +) -> Sequence[int]: + """Merge multiple tracks of target events into a single stream. + + Args: + targets: A 2D array (# tracks by # events) of integer event values. + codec: The event_codec.Codec used to interpret the events. + + Returns: + A 1D array of merged events. 
+ """ + num_tracks = tf.shape(targets)[0] + targets_length = tf.shape(targets)[1] + + current_step = 0 + current_offsets = tf.zeros(num_tracks, dtype=tf.int32) + + output = tf.constant([], dtype=tf.int32) + done = tf.constant(False) + + while not done: + # Let autograph know that the shape of 'output' will change during the loop. + tf.autograph.experimental.set_loop_options( + shape_invariants=[(output, tf.TensorShape([None]))]) + + # Determine which targets track has the earliest next step. + next_step = codec.max_shift_steps + 1 + next_track = -1 + for i in range(num_tracks): + if (current_offsets[i] == targets_length or + targets[i][current_offsets[i]] == 0): + # Already reached the end of this targets track. + # (Zero is technically a valid shift event but we never actually use it; + # it is always padding.) + continue + if not codec.is_shift_event_index(targets[i][current_offsets[i]]): + # The only way we would be at a non-shift event is if we have not yet + # reached the first shift event, which means we're at step zero. + next_step = 0 + next_track = i + elif targets[i][current_offsets[i]] < next_step: + next_step = targets[i][current_offsets[i]] + next_track = i + + if next_track == -1: + # We've already merged all of the target tracks in their entirety. + done = tf.constant(True) + break + + if next_step == current_step and next_step > 0: + # We don't need to include the shift event itself as it's the same step as + # the previous shift. + start_offset = current_offsets[next_track] + 1 + else: + start_offset = current_offsets[next_track] + + # Merge in events up to but not including the next shift. + end_offset = start_offset + 1 + while end_offset < targets_length and not codec.is_shift_event_index( + targets[next_track][end_offset]): + end_offset += 1 + output = tf.concat( + [output, targets[next_track][start_offset:end_offset]], axis=0) + + current_step = next_step + current_offsets = tf.tensor_scatter_nd_update( + current_offsets, indices=[[next_track]], updates=[end_offset]) + + return output + + +def decode_events( + state: DS, + tokens: np.ndarray, + start_time: int, + max_time: Optional[int], + codec: event_codec.Codec, + decode_event_fn: Callable[[DS, float, event_codec.Event, event_codec.Codec], + None], +) -> Tuple[int, int]: + """Decode a series of tokens, maintaining a decoding state object. + + Args: + state: Decoding state object; will be modified in-place. + tokens: event tokens to convert. + start_time: offset start time if decoding in the middle of a sequence. + max_time: Events at or beyond this time will be dropped. + codec: An event_codec.Codec object that maps indices to Event objects. + decode_event_fn: Function that consumes an Event (and the current time) and + updates the decoding state. + + Returns: + invalid_events: number of events that could not be decoded. + dropped_events: number of events dropped due to max_time restriction. 
+ """ + invalid_events = 0 + dropped_events = 0 + cur_steps = 0 + cur_time = start_time + token_idx = 0 + for token_idx, token in enumerate(tokens): + try: + event = codec.decode_event_index(token) + except ValueError: + invalid_events += 1 + continue + if event.type == 'shift': + cur_steps += event.value + cur_time = start_time + cur_steps / codec.steps_per_second + if max_time and cur_time > max_time: + dropped_events = len(tokens) - token_idx + break + else: + cur_steps = 0 + try: + decode_event_fn(state, cur_time, event, codec) + except ValueError: + invalid_events += 1 + logging.info( + 'Got invalid event when decoding event %s at time %f. ' + 'Invalid event counter now at %d.', + event, cur_time, invalid_events, exc_info=True) + continue + return invalid_events, dropped_events diff --git a/mt3/run_length_encoding_test.py b/mt3/run_length_encoding_test.py new file mode 100644 index 0000000000000000000000000000000000000000..848c900d9ddcf6c3d8f367b79c44552a3e8f8c89 --- /dev/null +++ b/mt3/run_length_encoding_test.py @@ -0,0 +1,107 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for run_length_encoding.""" + +from mt3 import event_codec +from mt3 import run_length_encoding + +import note_seq +import numpy as np +import seqio +import tensorflow as tf + +assert_dataset = seqio.test_utils.assert_dataset +codec = event_codec.Codec( + max_shift_steps=100, + steps_per_second=100, + event_ranges=[ + event_codec.EventRange('pitch', note_seq.MIN_MIDI_PITCH, + note_seq.MAX_MIDI_PITCH), + event_codec.EventRange('velocity', 0, 127), + event_codec.EventRange('drum', note_seq.MIN_MIDI_PITCH, + note_seq.MAX_MIDI_PITCH), + event_codec.EventRange('program', note_seq.MIN_MIDI_PROGRAM, + note_seq.MAX_MIDI_PROGRAM), + event_codec.EventRange('tie', 0, 0) + ]) +run_length_encode_shifts = run_length_encoding.run_length_encode_shifts_fn( + codec=codec) + + +class RunLengthEncodingTest(tf.test.TestCase): + + def test_remove_redundant_state_changes(self): + og_dataset = tf.data.Dataset.from_tensors({ + 'targets': [3, 525, 356, 161, 2, 525, 356, 161, 355, 394] + }) + + assert_dataset( + run_length_encoding.remove_redundant_state_changes_fn( + codec=codec, + state_change_event_types=['velocity', 'program'])(og_dataset), + { + 'targets': [3, 525, 356, 161, 2, 161, 355, 394], + }) + + def test_run_length_encode_shifts(self): + og_dataset = tf.data.Dataset.from_tensors({ + 'targets': [1, 1, 1, 161, 1, 1, 1, 162, 1, 1, 1] + }) + + assert_dataset( + run_length_encode_shifts(og_dataset), + { + 'targets': [3, 161, 6, 162], + }) + + def test_run_length_encode_shifts_beyond_max_length(self): + og_dataset = tf.data.Dataset.from_tensors({ + 'targets': [1] * 202 + [161, 1, 1, 1] + }) + + assert_dataset( + run_length_encode_shifts(og_dataset), + { + 'targets': [100, 100, 2, 161], + }) + + def test_run_length_encode_shifts_simultaneous(self): + og_dataset = tf.data.Dataset.from_tensors({ + 'targets': [1, 1, 1, 161, 162, 1, 1, 1] + }) + + assert_dataset( + 
run_length_encode_shifts(og_dataset), + { + 'targets': [3, 161, 162], + }) + + def test_merge_run_length_encoded_targets(self): + # pylint: disable=bad-whitespace + targets = np.array([ + [ 3, 161, 162, 5, 163], + [160, 164, 3, 165, 0] + ]) + # pylint: enable=bad-whitespace + merged_targets = run_length_encoding.merge_run_length_encoded_targets( + targets=targets, codec=codec) + expected_merged_targets = [ + 160, 164, 3, 161, 162, 165, 5, 163 + ] + np.testing.assert_array_equal(expected_merged_targets, merged_targets) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mt3/scripts/dump_task.py b/mt3/scripts/dump_task.py new file mode 100644 index 0000000000000000000000000000000000000000..9491e035cab9ba940f2c8ecd4ef1acf1a9075ed1 --- /dev/null +++ b/mt3/scripts/dump_task.py @@ -0,0 +1,80 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simple debugging utility for printing out task contents.""" + +import re + +from absl import app +from absl import flags + +import mt3.tasks # pylint: disable=unused-import + +import seqio +import tensorflow as tf + + +FLAGS = flags.FLAGS + +flags.DEFINE_string("task", None, "A registered Task.") +flags.DEFINE_string("task_cache_dir", None, "Directory to use for task cache.") +flags.DEFINE_integer("max_examples", 10, + "Maximum number of examples (-1 for no limit).") +flags.DEFINE_string("format_string", "targets = {targets}", + "Format for printing examples.") +flags.DEFINE_string("split", "train", + "Which split of the dataset, e.g. train or validation.") +flags.DEFINE_integer("sequence_length_inputs", 256, + "Sequence length for inputs.") +flags.DEFINE_integer("sequence_length_targets", 1024, + "Sequence length for targets.") + + +def main(_): + if FLAGS.task_cache_dir: + seqio.add_global_cache_dirs([FLAGS.task_cache_dir]) + + task = seqio.get_mixture_or_task(FLAGS.task) + + ds = task.get_dataset( + sequence_length={ + "inputs": FLAGS.sequence_length_inputs, + "targets": FLAGS.sequence_length_targets, + }, + split=FLAGS.split, + use_cached=bool(FLAGS.task_cache_dir), + shuffle=False) + + keys = re.findall(r"{([\w+]+)}", FLAGS.format_string) + def _example_to_string(ex): + key_to_string = {} + for k in keys: + if k in ex: + v = ex[k].numpy().tolist() + key_to_string[k] = task.output_features[k].vocabulary.decode(v) + else: + key_to_string[k] = "" + return FLAGS.format_string.format(**key_to_string) + + for ex in ds.take(FLAGS.max_examples): + for k, v in ex.items(): + print(f"{k}: {tf.shape(v)}") + print(_example_to_string(ex)) + print() + + +if __name__ == "__main__": + flags.mark_flags_as_required(["task"]) + + app.run(main) diff --git a/mt3/scripts/extract_monophonic_examples.py b/mt3/scripts/extract_monophonic_examples.py new file mode 100644 index 0000000000000000000000000000000000000000..38a3b136e87e78edc46617a25e5e5072d994fbd3 --- /dev/null +++ b/mt3/scripts/extract_monophonic_examples.py @@ -0,0 +1,251 @@ +# Copyright 2022 The MT3 Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Detect monophonic tracks and extract notes.""" + +import collections +import os + +from absl import app +from absl import flags +from absl import logging + +import ddsp +import librosa +import note_seq +import numpy as np +import scipy +import tensorflow as tf + + +_INPUT_DIR = flags.DEFINE_string( + 'input_dir', None, + 'Input directory containing WAV files.') +_OUTPUT_TFRECORD_PATH = flags.DEFINE_string( + 'output_tfrecord_path', None, + 'Path to the output TFRecord containing tf.train.Example protos with ' + 'monophonic tracks and inferred NoteSequence protos.') + + +CREPE_SAMPLE_RATE = 16000 +CREPE_FRAME_RATE = 100 + +MONOPHONIC_CONFIDENCE_THRESHOLD = 0.95 # confidence must be greater than this +MONOPHONIC_CONFIDENCE_FRAC = 0.2 # for this fraction of frames + +# split input audio into clips +CLIP_LENGTH_SECONDS = 5 + + +def is_monophonic_heuristic(f0_confidence): + """Heuristic to check for monophonicity using f0 confidence.""" + return (np.sum(f0_confidence >= MONOPHONIC_CONFIDENCE_THRESHOLD) / + len(f0_confidence) >= MONOPHONIC_CONFIDENCE_FRAC) + + +# HMM parameters for modeling notes and F0 tracks. +F0_MIDI_SIGMA = 0.2 +OCTAVE_ERROR_PROB = 0.05 +NOTES_PER_SECOND = 2 +NOTE_CHANGE_PROB = NOTES_PER_SECOND / CREPE_FRAME_RATE +F0_CONFIDENCE_EXP = 7.5 + + +def f0_hmm_matrices(f0_hz, f0_confidence): + """Observation and transition matrices for hidden Markov model of F0.""" + f0_midi = librosa.hz_to_midi(f0_hz) + f0_midi_diff = f0_midi[:, np.newaxis] - np.arange(128)[np.newaxis, :] + + # Compute the probability of each pitch at each frame, taking octave errors + # into account. + f0_midi_prob_octave_correct = scipy.stats.norm.pdf( + f0_midi_diff, scale=F0_MIDI_SIGMA) + f0_midi_prob_octave_low = scipy.stats.norm.pdf( + f0_midi_diff + 12, scale=F0_MIDI_SIGMA) + f0_midi_prob_octave_high = scipy.stats.norm.pdf( + f0_midi_diff - 12, scale=F0_MIDI_SIGMA) + + # distribution of pitch values given note + f0_midi_loglik = ((1 - OCTAVE_ERROR_PROB) * f0_midi_prob_octave_correct + + 0.5 * OCTAVE_ERROR_PROB * f0_midi_prob_octave_low + + 0.5 * OCTAVE_ERROR_PROB * f0_midi_prob_octave_high) + # (uniform) distribution of pitch values given rest + f0_midi_rest_loglik = -np.log(128) + + # Here we interpret confidence, after adjusting by exponent, as P(not rest). + f0_confidence_prob = np.power(f0_confidence, F0_CONFIDENCE_EXP)[:, np.newaxis] + + obs_loglik = np.concatenate([ + # probability of note (normalized by number of possible notes) + f0_midi_loglik + np.log(f0_confidence_prob) - np.log(128), + # probability of rest + f0_midi_rest_loglik + np.log(1.0 - f0_confidence_prob) + ], axis=1) + + # Normalize to adjust P(confidence | note) by uniform P(note). + # TODO(iansimon): Not sure how correct this is but it doesn't affect the path. 
+ obs_loglik += np.log(129) + + trans_prob = ((NOTE_CHANGE_PROB / 128) * np.ones(129) + + (1 - NOTE_CHANGE_PROB - NOTE_CHANGE_PROB / 128) * np.eye(129)) + trans_loglik = np.log(trans_prob) + + return obs_loglik, trans_loglik + + +def hmm_forward(obs_loglik, trans_loglik): + """Forward algorithm for a hidden Markov model.""" + n, k = obs_loglik.shape + trans = np.exp(trans_loglik) + + loglik = 0.0 + + l = obs_loglik[0] - np.log(k) + c = scipy.special.logsumexp(l) + loglik += c + + for i in range(1, n): + p = np.exp(l - c) + l = np.log(np.dot(p, trans)) + obs_loglik[i] + c = scipy.special.logsumexp(l) + loglik += c + + return loglik + + +def hmm_viterbi(obs_loglik, trans_loglik): + """Viterbi algorithm for a hidden Markov model.""" + n, k = obs_loglik.shape + + loglik_matrix = np.zeros_like(obs_loglik) + path_matrix = np.zeros_like(obs_loglik, dtype=np.int32) + + loglik_matrix[0, :] = obs_loglik[0, :] - np.log(k) + + for i in range(1, n): + mat = np.tile(loglik_matrix[i - 1][:, np.newaxis], [1, 129]) + trans_loglik + path_matrix[i, :] = mat.argmax(axis=0) + loglik_matrix[i, :] = mat[path_matrix[i, :], range(129)] + obs_loglik[i] + + path = [np.argmax(loglik_matrix[-1])] + for i in range(n, 1, -1): + path.append(path_matrix[i - 1, path[-1]]) + + return [(pitch if pitch < 128 else None) for pitch in path[::-1]] + + +def pitches_to_notesequence(pitches): + """Convert sequence of pitches output by Viterbi to NoteSequence proto.""" + ns = note_seq.NoteSequence(ticks_per_quarter=220) + current_pitch = None + start_time = None + for frame, pitch in enumerate(pitches): + time = frame / CREPE_FRAME_RATE + if pitch != current_pitch: + if current_pitch is not None: + ns.notes.add( + pitch=current_pitch, velocity=100, + start_time=start_time, end_time=time) + current_pitch = pitch + start_time = time + if current_pitch is not None: + ns.notes.add( + pitch=current_pitch, velocity=100, + start_time=start_time, end_time=len(pitches) / CREPE_FRAME_RATE) + if ns.notes: + ns.total_time = ns.notes[-1].end_time + return ns + + +# Per-frame log likelihood threshold below which an F0 track will be discarded. +# Note that this is dependent on the HMM parameters specified above, so if those +# change then this threshold should also change. 
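# As used in extract_note_sequence() below, this threshold is compared against
# the total forward log likelihood divided by the number of frames: at
# CREPE_FRAME_RATE = 100, a full 5-second clip has 500 frames, so it is kept
# only if its total log likelihood is at least 0.3 * 500 = 150.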
+PER_FRAME_LOGLIK_THRESHOLD = 0.3 + + +def extract_note_sequence(crepe, samples, counters): + """Use CREPE to attempt to extract a monophonic NoteSequence from audio.""" + f0_hz, f0_confidence = crepe.predict_f0_and_confidence( + samples[np.newaxis, :], viterbi=False) + + f0_hz = f0_hz[0].numpy() + f0_confidence = f0_confidence[0].numpy() + + if not is_monophonic_heuristic(f0_confidence): + counters['not_monophonic'] += 1 + return None + + obs_loglik, trans_loglik = f0_hmm_matrices(f0_hz, f0_confidence) + + loglik = hmm_forward(obs_loglik, trans_loglik) + if loglik / len(obs_loglik) < PER_FRAME_LOGLIK_THRESHOLD: + counters['low_likelihood'] += 1 + return None + + pitches = hmm_viterbi(obs_loglik, trans_loglik) + ns = pitches_to_notesequence(pitches) + + counters['extracted_monophonic_sequence'] += 1 + return ns + + +def process_wav_file(wav_filename, crepe, counters): + """Extract monophonic transcription examples from a WAV file.""" + wav_data = tf.io.gfile.GFile(wav_filename, 'rb').read() + samples = note_seq.audio_io.wav_data_to_samples_librosa( + wav_data, sample_rate=CREPE_SAMPLE_RATE) + clip_length_samples = int(CREPE_SAMPLE_RATE * CLIP_LENGTH_SECONDS) + for start_sample in range(0, len(samples), clip_length_samples): + clip_samples = samples[start_sample:start_sample + clip_length_samples] + if len(clip_samples) < clip_length_samples: + clip_samples = np.pad( + clip_samples, [(0, clip_length_samples - len(clip_samples))]) + ns = extract_note_sequence(crepe, clip_samples, counters) + if ns: + feature = { + 'audio': tf.train.Feature( + float_list=tf.train.FloatList(value=clip_samples.tolist())), + 'filename': tf.train.Feature( + bytes_list=tf.train.BytesList(value=[wav_filename.encode()])), + 'offset': tf.train.Feature( + int64_list=tf.train.Int64List(value=[start_sample])), + 'sampling_rate': tf.train.Feature( + float_list=tf.train.FloatList(value=[CREPE_SAMPLE_RATE])), + 'sequence': tf.train.Feature( + bytes_list=tf.train.BytesList(value=[ns.SerializeToString()])) + } + yield tf.train.Example(features=tf.train.Features(feature=feature)) + + +def main(unused_argv): + flags.mark_flags_as_required(['input_dir', 'output_tfrecord_path']) + crepe = ddsp.spectral_ops.PretrainedCREPE('full') + counters = collections.defaultdict(int) + with tf.io.TFRecordWriter(_OUTPUT_TFRECORD_PATH.value) as writer: + for filename in tf.io.gfile.listdir(_INPUT_DIR.value): + if not filename.endswith('.wav'): + logging.info('skipping %s...', filename) + counters['non_wav_files_skipped'] += 1 + continue + logging.info('processing %s...', filename) + for ex in process_wav_file( + os.path.join(_INPUT_DIR.value, filename), crepe, counters): + writer.write(ex.SerializeToString()) + counters['wav_files_processed'] += 1 + for k, v in counters.items(): + logging.info('COUNTER: %s = %d', k, v) + + +if __name__ == '__main__': + app.run(main) diff --git a/mt3/spectrograms.py b/mt3/spectrograms.py new file mode 100644 index 0000000000000000000000000000000000000000..c8a4eebd200d40808e99db9a707262567dd4e298 --- /dev/null +++ b/mt3/spectrograms.py @@ -0,0 +1,82 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
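# The tf.train.Example protos written by extract_monophonic_examples.py above
# can be read back with a feature spec matching the fields serialized in
# process_wav_file(); a minimal sketch (the output path is hypothetical):
#
#   import tensorflow as tf
#
#   feature_spec = {
#       'audio': tf.io.VarLenFeature(tf.float32),
#       'filename': tf.io.FixedLenFeature([], tf.string),
#       'offset': tf.io.FixedLenFeature([], tf.int64),
#       'sampling_rate': tf.io.FixedLenFeature([], tf.float32),
#       'sequence': tf.io.FixedLenFeature([], tf.string),  # serialized NoteSequence
#   }
#   ds = tf.data.TFRecordDataset('monophonic_examples.tfrecord')
#   ds = ds.map(lambda x: tf.io.parse_single_example(x, feature_spec))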
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Audio spectrogram functions.""" + +import dataclasses + +from ddsp import spectral_ops +import tensorflow as tf + +# defaults for spectrogram config +DEFAULT_SAMPLE_RATE = 16000 +DEFAULT_HOP_WIDTH = 128 +DEFAULT_NUM_MEL_BINS = 512 + +# fixed constants; add these to SpectrogramConfig before changing +FFT_SIZE = 2048 +MEL_LO_HZ = 20.0 + + +@dataclasses.dataclass +class SpectrogramConfig: + """Spectrogram configuration parameters.""" + sample_rate: int = DEFAULT_SAMPLE_RATE + hop_width: int = DEFAULT_HOP_WIDTH + num_mel_bins: int = DEFAULT_NUM_MEL_BINS + + @property + def abbrev_str(self): + s = '' + if self.sample_rate != DEFAULT_SAMPLE_RATE: + s += 'sr%d' % self.sample_rate + if self.hop_width != DEFAULT_HOP_WIDTH: + s += 'hw%d' % self.hop_width + if self.num_mel_bins != DEFAULT_NUM_MEL_BINS: + s += 'mb%d' % self.num_mel_bins + return s + + @property + def frames_per_second(self): + return self.sample_rate / self.hop_width + + +def split_audio(samples, spectrogram_config): + """Split audio into frames.""" + return tf.signal.frame( + samples, + frame_length=spectrogram_config.hop_width, + frame_step=spectrogram_config.hop_width, + pad_end=True) + + +def compute_spectrogram(samples, spectrogram_config): + """Compute a mel spectrogram.""" + overlap = 1 - (spectrogram_config.hop_width / FFT_SIZE) + return spectral_ops.compute_logmel( + samples, + bins=spectrogram_config.num_mel_bins, + lo_hz=MEL_LO_HZ, + overlap=overlap, + fft_size=FFT_SIZE, + sample_rate=spectrogram_config.sample_rate) + + +def flatten_frames(frames): + """Convert frames back into a flat array of samples.""" + return tf.reshape(frames, [-1]) + + +def input_depth(spectrogram_config): + return spectrogram_config.num_mel_bins diff --git a/mt3/summaries.py b/mt3/summaries.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c0ced11a1ad41d3bbadc72ebc7ff466b0e0d71 --- /dev/null +++ b/mt3/summaries.py @@ -0,0 +1,471 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
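# A small sketch of how the spectrogram helpers above fit together, using the
# default config and one second of silence purely as a stand-in input:
#
#   import tensorflow as tf
#   from mt3 import spectrograms
#
#   config = spectrograms.SpectrogramConfig()             # 16 kHz, hop width 128
#   samples = tf.zeros([config.sample_rate], tf.float32)  # 1 second of audio
#   frames = spectrograms.split_audio(samples, config)    # shape [125, 128]
#   mel = spectrograms.compute_spectrogram(samples, config)
#   assert config.frames_per_second == 125.0              # 16000 / 128
#   assert spectrograms.input_depth(config) == 512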
+ +"""TensorBoard summaries and utilities.""" + +from typing import Any, Mapping, Optional, Sequence, Tuple + +import librosa + +from mt3 import note_sequences +from mt3 import spectrograms + +import note_seq +from note_seq import midi_synth +from note_seq import sequences_lib +from note_seq.protobuf import music_pb2 + +import numpy as np +import seqio + + +_DEFAULT_AUDIO_SECONDS = 30.0 +_DEFAULT_PIANOROLL_FRAMES_PER_SECOND = 15 + +# TODO(iansimon): pick a SoundFont; for some reason the default is all organ + + +def _extract_example_audio( + examples: Sequence[Mapping[str, Any]], + sample_rate: float, + num_seconds: float, + audio_key: str = 'raw_inputs' +) -> np.ndarray: + """Extract audio from examples. + + Args: + examples: List of examples containing raw audio. + sample_rate: Number of samples per second. + num_seconds: Number of seconds of audio to include. + audio_key: Dictionary key for the raw audio. + + Returns: + An n-by-num_samples numpy array of samples. + """ + n = len(examples) + num_samples = round(num_seconds * sample_rate) + all_samples = np.zeros([n, num_samples]) + for i, ex in enumerate(examples): + samples = ex[audio_key][:num_samples] + all_samples[i, :len(samples)] = samples + return all_samples + + +def _example_to_note_sequence( + example: Mapping[str, Sequence[float]], + ns_feature_name: str, + note_onset_feature_name: str, + note_offset_feature_name: str, + note_frequency_feature_name: str, + note_confidence_feature_name: str, + num_seconds: float +) -> music_pb2.NoteSequence: + """Extract NoteSequence from example.""" + if ns_feature_name: + ns = example[ns_feature_name] + + else: + onset_times = np.array(example[note_onset_feature_name]) + pitches = librosa.hz_to_midi( + example[note_frequency_feature_name]).round().astype(int) + assert len(onset_times) == len(pitches) + + if note_offset_feature_name or note_confidence_feature_name: + offset_times = ( + example[note_offset_feature_name] + if note_offset_feature_name + else onset_times + note_sequences.DEFAULT_NOTE_DURATION + ) + assert len(onset_times) == len(offset_times) + + confidences = (np.array(example[note_confidence_feature_name]) + if note_confidence_feature_name else None) + velocities = np.ceil( + note_seq.MAX_MIDI_VELOCITY * confidences if confidences is not None + else note_sequences.DEFAULT_VELOCITY * np.ones_like(onset_times) + ).astype(int) + assert len(onset_times) == len(velocities) + + ns = note_sequences.note_arrays_to_note_sequence( + onset_times=onset_times, offset_times=offset_times, + pitches=pitches, velocities=velocities) + + else: + ns = note_sequences.note_arrays_to_note_sequence( + onset_times=onset_times, pitches=pitches) + + return sequences_lib.trim_note_sequence(ns, 0, num_seconds) + + +def _synthesize_example_notes( + examples: Sequence[Mapping[str, Sequence[float]]], + ns_feature_name: str, + note_onset_feature_name: str, + note_offset_feature_name: str, + note_frequency_feature_name: str, + note_confidence_feature_name: str, + sample_rate: float, + num_seconds: float, +) -> np.ndarray: + """Synthesize example notes to audio. + + Args: + examples: List of example dictionaries, containing either serialized + NoteSequence protos or note onset times and pitches. + ns_feature_name: Name of serialized NoteSequence feature. + note_onset_feature_name: Name of note onset times feature. + note_offset_feature_name: Name of note offset times feature. + note_frequency_feature_name: Name of note frequencies feature. 
+ note_confidence_feature_name: Name of note confidences (velocities) feature. + sample_rate: Sample rate at which to synthesize. + num_seconds: Number of seconds to synthesize for each example. + + Returns: + An n-by-num_samples numpy array of samples. + """ + if (ns_feature_name is not None) == (note_onset_feature_name is not None): + raise ValueError( + 'must specify exactly one of NoteSequence feature and onset feature') + + n = len(examples) + num_samples = round(num_seconds * sample_rate) + + all_samples = np.zeros([n, num_samples]) + + for i, ex in enumerate(examples): + ns = _example_to_note_sequence( + ex, + ns_feature_name=ns_feature_name, + note_onset_feature_name=note_onset_feature_name, + note_offset_feature_name=note_offset_feature_name, + note_frequency_feature_name=note_frequency_feature_name, + note_confidence_feature_name=note_confidence_feature_name, + num_seconds=num_seconds) + fluidsynth = midi_synth.fluidsynth + samples = fluidsynth(ns, sample_rate=sample_rate) + if len(samples) > num_samples: + samples = samples[:num_samples] + all_samples[i, :len(samples)] = samples + + return all_samples + + +def _examples_to_pianorolls( + targets: Sequence[Mapping[str, Sequence[float]]], + predictions: Sequence[Mapping[str, Sequence[float]]], + ns_feature_suffix: str, + note_onset_feature_suffix: str, + note_offset_feature_suffix: str, + note_frequency_feature_suffix: str, + note_confidence_feature_suffix: str, + track_specs: Optional[Sequence[note_sequences.TrackSpec]], + num_seconds: float, + frames_per_second: float +) -> Tuple[np.ndarray, np.ndarray]: + """Generate pianoroll images from example notes. + + Args: + targets: List of target dictionaries, containing either serialized + NoteSequence protos or note onset times and pitches. + predictions: List of prediction dictionaries, containing either serialized + NoteSequence protos or note onset times and pitches. + ns_feature_suffix: Suffix of serialized NoteSequence feature. + note_onset_feature_suffix: Suffix of note onset times feature. + note_offset_feature_suffix: Suffix of note offset times feature. + note_frequency_feature_suffix: Suffix of note frequencies feature. + note_confidence_feature_suffix: Suffix of note confidences (velocities) + feature. + track_specs: Optional list of TrackSpec objects to indicate a set of tracks + into which each NoteSequence should be split. Tracks will be stacked + vertically in the pianorolls + num_seconds: Number of seconds to show for each example. + frames_per_second: Number of pianoroll frames per second. + + Returns: + onset_pianorolls: An n-by-num_pitches-by-num_frames-by-4 numpy array of + pianoroll images showing only onsets. + full_pianorolls: An n-by-num_pitches-by-num_frames-by-4 numpy array of + pianoroll images. 
+ """ + if (ns_feature_suffix is not None) == (note_onset_feature_suffix is not None): + raise ValueError( + 'must specify exactly one of NoteSequence feature and onset feature') + + def ex_to_ns(example, prefix): + return _example_to_note_sequence( + example=example, + ns_feature_name=(prefix + ns_feature_suffix + if ns_feature_suffix else None), + note_onset_feature_name=(prefix + note_onset_feature_suffix + if note_onset_feature_suffix else None), + note_offset_feature_name=(prefix + note_offset_feature_suffix + if note_offset_feature_suffix else None), + note_frequency_feature_name=( + prefix + note_frequency_feature_suffix + if note_frequency_feature_suffix else None), + note_confidence_feature_name=( + prefix + note_confidence_feature_suffix + if note_confidence_feature_suffix else None), + num_seconds=num_seconds) + + n = len(targets) + num_pitches = note_seq.MAX_MIDI_PITCH - note_seq.MIN_MIDI_PITCH + 1 + num_frames = round(num_seconds * frames_per_second) + num_tracks = len(track_specs) if track_specs else 1 + pianoroll_height = num_tracks * num_pitches + (num_tracks - 1) + + onset_images = np.zeros([n, pianoroll_height, num_frames, 3]) + full_images = np.zeros([n, pianoroll_height, num_frames, 3]) + + for i, (target, pred) in enumerate(zip(targets, predictions)): + target_ns, pred_ns = [ + ex_to_ns(ex, prefix) + for (ex, prefix) in [(target, 'ref_'), (pred, 'est_')] + ] + + # Show lines at frame boundaries. To ensure that these lines are drawn with + # the same downsampling and frame selection logic as the real NoteSequences, + # use this hack to draw the lines with a NoteSequence that contains notes + # across all pitches at all frame start times. + start_times_ns = note_seq.NoteSequence() + start_times_ns.CopyFrom(target_ns) + del start_times_ns.notes[:] + for start_time in pred['start_times']: + if start_time < target_ns.total_time: + for pitch in range( + note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH + 1): + start_times_ns.notes.add( + pitch=pitch, + velocity=100, + start_time=start_time, + end_time=start_time + (1 / frames_per_second)) + + start_time_roll = sequences_lib.sequence_to_pianoroll( + start_times_ns, + frames_per_second=frames_per_second, + min_pitch=note_seq.MIN_MIDI_PITCH, + max_pitch=note_seq.MAX_MIDI_PITCH, + onset_mode='length_ms') + num_start_time_frames = min(len(start_time_roll.onsets), num_frames) + + if track_specs is not None: + target_tracks = [note_sequences.extract_track(target_ns, + spec.program, spec.is_drum) + for spec in track_specs] + pred_tracks = [note_sequences.extract_track(pred_ns, + spec.program, spec.is_drum) + for spec in track_specs] + else: + target_tracks = [target_ns] + pred_tracks = [pred_ns] + + for j, (target_track, pred_track) in enumerate(zip(target_tracks[::-1], + pred_tracks[::-1])): + target_roll = sequences_lib.sequence_to_pianoroll( + target_track, + frames_per_second=frames_per_second, + min_pitch=note_seq.MIN_MIDI_PITCH, + max_pitch=note_seq.MAX_MIDI_PITCH, + onset_mode='length_ms') + pred_roll = sequences_lib.sequence_to_pianoroll( + pred_track, + frames_per_second=frames_per_second, + min_pitch=note_seq.MIN_MIDI_PITCH, + max_pitch=note_seq.MAX_MIDI_PITCH, + onset_mode='length_ms') + + num_target_frames = min(len(target_roll.onsets), num_frames) + num_pred_frames = min(len(pred_roll.onsets), num_frames) + + start_offset = j * (num_pitches + 1) + end_offset = (j + 1) * (num_pitches + 1) - 1 + + # Onsets + onset_images[ + i, start_offset:end_offset, :num_start_time_frames, 0 + ] = 
start_time_roll.onsets[:num_start_time_frames, :].T + onset_images[ + i, start_offset:end_offset, :num_target_frames, 1 + ] = target_roll.onsets[:num_target_frames, :].T + onset_images[ + i, start_offset:end_offset, :num_pred_frames, 2 + ] = pred_roll.onsets[:num_pred_frames, :].T + + # Full notes + full_images[ + i, start_offset:end_offset, :num_start_time_frames, 0 + ] = start_time_roll.onsets[:num_start_time_frames, :].T + full_images[ + i, start_offset:end_offset, :num_target_frames, 1 + ] = target_roll.active[:num_target_frames, :].T + full_images[ + i, start_offset:end_offset, :num_pred_frames, 2 + ] = pred_roll.active[:num_pred_frames, :].T + + # Add separator between tracks. + if j < num_tracks - 1: + onset_images[i, end_offset, :, 0] = 1 + full_images[i, end_offset, :, 0] = 1 + + return onset_images[:, ::-1, :, :], full_images[:, ::-1, :, :] + + +def prettymidi_pianoroll( + track_pianorolls: Mapping[str, Sequence[Tuple[np.ndarray, np.ndarray]]], + fps: float, + num_seconds=_DEFAULT_AUDIO_SECONDS +) -> Mapping[str, seqio.metrics.MetricValue]: + """Create summary from given pianorolls.""" + max_len = int(num_seconds * fps) + summaries = {} + for inst_name, all_prs in track_pianorolls.items(): + + est_prs, ref_prs = zip(*all_prs) + + bs = len(ref_prs) + pianoroll_image_batch = np.zeros(shape=(bs, 128, max_len, 3)) + for i in range(bs): + ref_pr = ref_prs[i][:, :max_len] + est_pr = est_prs[i][:, :max_len] + + pianoroll_image_batch[i, :, :est_pr.shape[1], 2] = est_pr + pianoroll_image_batch[i, :, :ref_pr.shape[1], 1] = ref_pr + if not inst_name: + inst_name = 'all instruments' + + summaries[f'{inst_name} pretty_midi pianoroll'] = seqio.metrics.Image( + image=pianoroll_image_batch, max_outputs=bs) + + return summaries + + +def audio_summaries( + targets: Sequence[Mapping[str, Sequence[float]]], + predictions: Sequence[Mapping[str, Sequence[float]]], + spectrogram_config: spectrograms.SpectrogramConfig, + num_seconds: float = _DEFAULT_AUDIO_SECONDS +) -> Mapping[str, seqio.metrics.MetricValue]: + """Compute audio summaries for a list of examples. + + Args: + targets: List of targets, unused as we pass the input audio tokens via + predictions. + predictions: List of predictions, including input audio tokens. + spectrogram_config: Spectrogram configuration. + num_seconds: Number of seconds of audio to include in the summaries. + Longer audio will be cropped (from the beginning), shorter audio will be + padded with silence (at the end). + + Returns: + A dictionary mapping "audio" to the audio summaries. 
+ """ + del targets + samples = _extract_example_audio( + examples=predictions, + sample_rate=spectrogram_config.sample_rate, + num_seconds=num_seconds) + return { + 'audio': seqio.metrics.Audio( + audiodata=samples[:, :, np.newaxis], + sample_rate=spectrogram_config.sample_rate, + max_outputs=samples.shape[0]) + } + + +def transcription_summaries( + targets: Sequence[Mapping[str, Sequence[float]]], + predictions: Sequence[Mapping[str, Sequence[float]]], + spectrogram_config: spectrograms.SpectrogramConfig, + ns_feature_suffix: Optional[str] = None, + note_onset_feature_suffix: Optional[str] = None, + note_offset_feature_suffix: Optional[str] = None, + note_frequency_feature_suffix: Optional[str] = None, + note_confidence_feature_suffix: Optional[str] = None, + track_specs: Optional[Sequence[note_sequences.TrackSpec]] = None, + num_seconds: float = _DEFAULT_AUDIO_SECONDS, + pianoroll_frames_per_second: float = _DEFAULT_PIANOROLL_FRAMES_PER_SECOND, +) -> Mapping[str, seqio.metrics.MetricValue]: + """Compute note transcription summaries for multiple examples. + + Args: + targets: List of targets containing ground truth. + predictions: List of predictions, including raw input audio. + spectrogram_config: The spectrogram configuration. + ns_feature_suffix: Suffix of serialized NoteSequence feature. + note_onset_feature_suffix: Suffix of note onset times feature. + note_offset_feature_suffix: Suffix of note offset times feature. + note_frequency_feature_suffix: Suffix of note frequencies feature. + note_confidence_feature_suffix: Suffix of note confidences (velocities) + feature. + track_specs: Optional list of TrackSpec objects to indicate a set of tracks + into which each NoteSequence should be split. + num_seconds: Number of seconds of audio to include in the summaries. + Longer audio will be cropped (from the beginning), shorter audio will be + padded with silence (at the end). + pianoroll_frames_per_second: Temporal resolution of pianoroll images. + + Returns: + A dictionary of input, ground truth, and transcription summaries. 
+ """ + audio_samples = _extract_example_audio( + examples=predictions, + sample_rate=spectrogram_config.sample_rate, + num_seconds=num_seconds) + + def synthesize(examples, prefix): + return _synthesize_example_notes( + examples=examples, + ns_feature_name=(prefix + ns_feature_suffix + if ns_feature_suffix else None), + note_onset_feature_name=(prefix + note_onset_feature_suffix + if note_onset_feature_suffix else None), + note_offset_feature_name=(prefix + note_offset_feature_suffix + if note_offset_feature_suffix else None), + note_frequency_feature_name=( + prefix + note_frequency_feature_suffix + if note_frequency_feature_suffix else None), + note_confidence_feature_name=( + prefix + note_confidence_feature_suffix + if note_confidence_feature_suffix else None), + sample_rate=spectrogram_config.sample_rate, + num_seconds=num_seconds) + + synthesized_predictions = synthesize(predictions, 'est_') + + onset_pianoroll_images, full_pianoroll_images = _examples_to_pianorolls( + targets=targets, + predictions=predictions, + ns_feature_suffix=ns_feature_suffix, + note_onset_feature_suffix=note_onset_feature_suffix, + note_offset_feature_suffix=note_offset_feature_suffix, + note_frequency_feature_suffix=note_frequency_feature_suffix, + note_confidence_feature_suffix=note_confidence_feature_suffix, + track_specs=track_specs, + num_seconds=num_seconds, + frames_per_second=pianoroll_frames_per_second) + + return { + 'input_with_transcription': seqio.metrics.Audio( + audiodata=np.stack([audio_samples, synthesized_predictions], axis=2), + sample_rate=spectrogram_config.sample_rate, + max_outputs=audio_samples.shape[0]), + + 'pianoroll': seqio.metrics.Image( + image=full_pianoroll_images, + max_outputs=full_pianoroll_images.shape[0]), + + 'onset_pianoroll': seqio.metrics.Image( + image=onset_pianoroll_images, + max_outputs=onset_pianoroll_images.shape[0]), + } diff --git a/mt3/tasks.py b/mt3/tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..287fa5116a7adbd3283f605cdd5fda668896c9fd --- /dev/null +++ b/mt3/tasks.py @@ -0,0 +1,402 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transcription task definitions.""" + +import functools +from typing import Optional, Sequence + +from mt3 import datasets +from mt3 import event_codec +from mt3 import metrics +from mt3 import mixing +from mt3 import preprocessors +from mt3 import run_length_encoding +from mt3 import spectrograms +from mt3 import vocabularies + +import note_seq +import numpy as np +import seqio +import t5 +import tensorflow as tf + +# Split audio frame sequences into this length before the cache placeholder. 
+MAX_NUM_CACHED_FRAMES = 2000 + +seqio.add_global_cache_dirs(['gs://mt3/data/cache_tasks/']) + + +def construct_task_name( + task_prefix: str, + spectrogram_config=spectrograms.SpectrogramConfig(), + vocab_config=vocabularies.VocabularyConfig(), + task_suffix: Optional[str] = None +) -> str: + """Construct task name from prefix, config, and optional suffix.""" + fields = [task_prefix] + if spectrogram_config.abbrev_str: + fields.append(spectrogram_config.abbrev_str) + if vocab_config.abbrev_str: + fields.append(vocab_config.abbrev_str) + if task_suffix: + fields.append(task_suffix) + return '_'.join(fields) + + +def trim_eos(tokens: Sequence[int]) -> np.ndarray: + """If EOS is present, remove it and everything after.""" + tokens = np.array(tokens, np.int32) + if vocabularies.DECODED_EOS_ID in tokens: + tokens = tokens[:np.argmax(tokens == vocabularies.DECODED_EOS_ID)] + return tokens + + +def postprocess(tokens, example, is_target, codec): + """Transcription postprocessing function.""" + tokens = trim_eos(tokens) + + if is_target: + return { + 'unique_id': example['unique_id'][0], + 'ref_ns': (note_seq.NoteSequence.FromString(example['sequence'][0]) + if example['sequence'][0] else None), + 'ref_tokens': tokens, + } + + start_time = example['input_times'][0] + # Round down to nearest symbolic token step. + start_time -= start_time % (1 / codec.steps_per_second) + + return { + 'unique_id': example['unique_id'][0], + 'raw_inputs': example['raw_inputs'], + 'est_tokens': tokens, + 'start_time': start_time + } + + +def add_transcription_task_to_registry( + dataset_config: datasets.DatasetConfig, + spectrogram_config: spectrograms.SpectrogramConfig, + vocab_config: vocabularies.VocabularyConfig, + tokenize_fn, # TODO(iansimon): add type signature + onsets_only: bool, + include_ties: bool, + skip_too_long: bool = False +) -> None: + """Add note transcription task to seqio.TaskRegistry.""" + codec = vocabularies.build_codec(vocab_config) + vocabulary = vocabularies.vocabulary_from_codec(codec) + + output_features = { + 'targets': seqio.Feature(vocabulary=vocabulary), + 'inputs': seqio.ContinuousFeature(dtype=tf.float32, rank=2) + } + + task_name = 'onsets' if onsets_only else 'notes' + if include_ties: + task_name += '_ties' + task_prefix = f'{dataset_config.name}_{task_name}' + + train_task_name = construct_task_name( + task_prefix=task_prefix, + spectrogram_config=spectrogram_config, + vocab_config=vocab_config, + task_suffix='train') + + mixture_task_names = [] + + tie_token = codec.encode_event(event_codec.Event('tie', 0)) + track_specs = (dataset_config.track_specs + if dataset_config.track_specs else None) + + # Add transcription training task. 
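  # For a concrete sense of the naming scheme: with the default spectrogram
  # config (empty abbrev_str) and a single velocity bin ('vb1'), a task_prefix
  # of 'maestrov3_notes_ties' yields the train_task_name
  # 'maestrov3_notes_ties_vb1_train', which is the form reconstructed by the
  # mixture definitions at the bottom of this file.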
+ seqio.TaskRegistry.add( + train_task_name, + source=seqio.TFExampleDataSource( + split_to_filepattern={ + 'train': dataset_config.paths[dataset_config.train_split], + 'eval': dataset_config.paths[dataset_config.train_eval_split] + }, + feature_description=dataset_config.features), + output_features=output_features, + preprocessors=[ + functools.partial( + tokenize_fn, + spectrogram_config=spectrogram_config, codec=codec, + is_training_data=True, onsets_only=onsets_only, + include_ties=include_ties), + functools.partial( + t5.data.preprocessors.split_tokens, + max_tokens_per_segment=MAX_NUM_CACHED_FRAMES, + feature_key='inputs', + additional_feature_keys=[ + 'input_event_start_indices', 'input_event_end_indices', + 'input_state_event_indices' + ], + passthrough_feature_keys=['targets', 'state_events']), + seqio.CacheDatasetPlaceholder(), + functools.partial( + t5.data.preprocessors.select_random_chunk, + feature_key='inputs', + additional_feature_keys=[ + 'input_event_start_indices', 'input_event_end_indices', + 'input_state_event_indices' + ], + passthrough_feature_keys=['targets', 'state_events'], + uniform_random_start=True), + functools.partial( + run_length_encoding.extract_target_sequence_with_indices, + state_events_end_token=tie_token if include_ties else None), + functools.partial(preprocessors.map_midi_programs, codec=codec), + run_length_encoding.run_length_encode_shifts_fn( + codec, + feature_key='targets'), + functools.partial( + mixing.mix_transcription_examples, + codec=codec, + targets_feature_keys=['targets']), + run_length_encoding.remove_redundant_state_changes_fn( + feature_key='targets', codec=codec, + state_change_event_types=['velocity', 'program']), + functools.partial( + preprocessors.compute_spectrograms, + spectrogram_config=spectrogram_config), + functools.partial(preprocessors.handle_too_long, skip=skip_too_long), + functools.partial( + seqio.preprocessors.tokenize_and_append_eos, + copy_pretokenized=False) + ], + postprocess_fn=None, + metric_fns=[], + ) + + # Add transcription eval tasks. + for split in dataset_config.infer_eval_splits: + eval_task_name = construct_task_name( + task_prefix=task_prefix, + spectrogram_config=spectrogram_config, + vocab_config=vocab_config, + task_suffix=split.suffix) + + if split.include_in_mixture: + mixture_task_names.append(eval_task_name) + + seqio.TaskRegistry.add( + eval_task_name, + source=seqio.TFExampleDataSource( + split_to_filepattern={'eval': dataset_config.paths[split.name]}, + feature_description=dataset_config.features), + output_features=output_features, + preprocessors=[ + functools.partial( + tokenize_fn, + spectrogram_config=spectrogram_config, codec=codec, + is_training_data='train' in split.name, onsets_only=onsets_only, + include_ties=include_ties), + seqio.CacheDatasetPlaceholder(), + preprocessors.add_unique_id, + preprocessors.pad_notesequence_array, + functools.partial( + t5.data.preprocessors.split_tokens_to_inputs_length, + feature_key='inputs', + additional_feature_keys=['input_times', 'sequence'], + passthrough_feature_keys=['unique_id']), + # Add dummy targets as they are dropped during the above split to + # avoid memory blowups, but expected to be present by seqio; the + # evaluation metrics currently only use the target NoteSequence. 
+ preprocessors.add_dummy_targets, + functools.partial( + preprocessors.compute_spectrograms, + spectrogram_config=spectrogram_config), + functools.partial(preprocessors.handle_too_long, skip=False), + functools.partial( + seqio.preprocessors.tokenize_and_append_eos, + copy_pretokenized=False) + ], + postprocess_fn=functools.partial(postprocess, codec=codec), + metric_fns=[ + functools.partial( + metrics.transcription_metrics, + codec=codec, + spectrogram_config=spectrogram_config, + onsets_only=onsets_only, + use_ties=include_ties, + track_specs=track_specs) + ], + ) + + seqio.MixtureRegistry.add( + construct_task_name( + task_prefix=task_prefix, spectrogram_config=spectrogram_config, + vocab_config=vocab_config, task_suffix='eval'), + mixture_task_names, + default_rate=1) + + +# Just use default spectrogram config. +SPECTROGRAM_CONFIG = spectrograms.SpectrogramConfig() + +# Create two vocabulary configs, one default and one with only on-off velocity. +VOCAB_CONFIG_FULL = vocabularies.VocabularyConfig() +VOCAB_CONFIG_NOVELOCITY = vocabularies.VocabularyConfig(num_velocity_bins=1) + +# Transcribe MAESTRO v1. +add_transcription_task_to_registry( + dataset_config=datasets.MAESTROV1_CONFIG, + spectrogram_config=SPECTROGRAM_CONFIG, + vocab_config=VOCAB_CONFIG_FULL, + tokenize_fn=functools.partial( + preprocessors.tokenize_transcription_example, + audio_is_samples=False, + id_feature_key='id'), + onsets_only=False, + include_ties=False) + +# Transcribe MAESTRO v3. +add_transcription_task_to_registry( + dataset_config=datasets.MAESTROV3_CONFIG, + spectrogram_config=SPECTROGRAM_CONFIG, + vocab_config=VOCAB_CONFIG_FULL, + tokenize_fn=functools.partial( + preprocessors.tokenize_transcription_example, + audio_is_samples=False, + id_feature_key='id'), + onsets_only=False, + include_ties=False) + +# Transcribe MAESTRO v3 without velocities, with ties. +add_transcription_task_to_registry( + dataset_config=datasets.MAESTROV3_CONFIG, + spectrogram_config=SPECTROGRAM_CONFIG, + vocab_config=VOCAB_CONFIG_NOVELOCITY, + tokenize_fn=functools.partial( + preprocessors.tokenize_transcription_example, + audio_is_samples=False, + id_feature_key='id'), + onsets_only=False, + include_ties=True) + +# Transcribe GuitarSet, with ties. +add_transcription_task_to_registry( + dataset_config=datasets.GUITARSET_CONFIG, + spectrogram_config=SPECTROGRAM_CONFIG, + vocab_config=VOCAB_CONFIG_NOVELOCITY, + tokenize_fn=preprocessors.tokenize_guitarset_example, + onsets_only=False, + include_ties=True) + +# Transcribe URMP mixes, with ties. +add_transcription_task_to_registry( + dataset_config=datasets.URMP_CONFIG, + spectrogram_config=SPECTROGRAM_CONFIG, + vocab_config=VOCAB_CONFIG_NOVELOCITY, + tokenize_fn=functools.partial( + preprocessors.tokenize_example_with_program_lookup, + inst_name_to_program_fn=preprocessors.urmp_instrument_to_program, + id_feature_key='id'), + onsets_only=False, + include_ties=True) + +# Transcribe MusicNet, with ties. +add_transcription_task_to_registry( + dataset_config=datasets.MUSICNET_CONFIG, + spectrogram_config=SPECTROGRAM_CONFIG, + vocab_config=VOCAB_CONFIG_NOVELOCITY, + tokenize_fn=functools.partial( + preprocessors.tokenize_transcription_example, + audio_is_samples=True, + id_feature_key='id'), + onsets_only=False, + include_ties=True) + +# Transcribe MusicNetEM, with ties. 
+add_transcription_task_to_registry( + dataset_config=datasets.MUSICNET_EM_CONFIG, + spectrogram_config=SPECTROGRAM_CONFIG, + vocab_config=VOCAB_CONFIG_NOVELOCITY, + tokenize_fn=functools.partial( + preprocessors.tokenize_transcription_example, + audio_is_samples=True, + id_feature_key='id'), + onsets_only=False, + include_ties=True) + +# Transcribe Cerberus4 (piano-guitar-bass-drums quartets), with ties. +add_transcription_task_to_registry( + dataset_config=datasets.CERBERUS4_CONFIG, + spectrogram_config=SPECTROGRAM_CONFIG, + vocab_config=VOCAB_CONFIG_NOVELOCITY, + tokenize_fn=functools.partial( + preprocessors.tokenize_slakh_example, + track_specs=datasets.CERBERUS4_CONFIG.track_specs, + ignore_pitch_bends=True), + onsets_only=False, + include_ties=True) + +# Transcribe 10 random sub-mixes of each song from Slakh, with ties. +add_transcription_task_to_registry( + dataset_config=datasets.SLAKH_CONFIG, + spectrogram_config=SPECTROGRAM_CONFIG, + vocab_config=VOCAB_CONFIG_NOVELOCITY, + tokenize_fn=functools.partial( + preprocessors.tokenize_slakh_example, + track_specs=None, + ignore_pitch_bends=True), + onsets_only=False, + include_ties=True) + + +# Construct task names to include in transcription mixture. +MIXTURE_DATASET_NAMES = [ + 'maestrov3', 'guitarset', 'urmp', 'musicnet_em', 'cerberus4', 'slakh' +] +MIXTURE_TRAIN_TASK_NAMES = [] +MIXTURE_EVAL_TASK_NAMES = [] +MIXTURE_TEST_TASK_NAMES = [] +for dataset_name in MIXTURE_DATASET_NAMES: + MIXTURE_TRAIN_TASK_NAMES.append( + construct_task_name(task_prefix=f'{dataset_name}_notes_ties', + spectrogram_config=SPECTROGRAM_CONFIG, + vocab_config=VOCAB_CONFIG_NOVELOCITY, + task_suffix='train')) + MIXTURE_EVAL_TASK_NAMES.append( + construct_task_name(task_prefix=f'{dataset_name}_notes_ties', + spectrogram_config=SPECTROGRAM_CONFIG, + vocab_config=VOCAB_CONFIG_NOVELOCITY, + task_suffix='validation')) +MIXING_TEMPERATURE = 10 / 3 + +# Add the mixture of all transcription tasks, with ties. +seqio.MixtureRegistry.add( + construct_task_name( + task_prefix='mega_notes_ties', + spectrogram_config=SPECTROGRAM_CONFIG, + vocab_config=VOCAB_CONFIG_NOVELOCITY, + task_suffix='train'), + MIXTURE_TRAIN_TASK_NAMES, + default_rate=functools.partial( + seqio.mixing_rate_num_examples, + temperature=MIXING_TEMPERATURE)) +seqio.MixtureRegistry.add( + construct_task_name( + task_prefix='mega_notes_ties', + spectrogram_config=SPECTROGRAM_CONFIG, + vocab_config=VOCAB_CONFIG_NOVELOCITY, + task_suffix='eval'), + MIXTURE_EVAL_TASK_NAMES, + default_rate=functools.partial( + seqio.mixing_rate_num_examples, + temperature=MIXING_TEMPERATURE)) diff --git a/mt3/version.py b/mt3/version.py new file mode 100644 index 0000000000000000000000000000000000000000..bfaf4a4165e6a3ba1bc2c1e03c3d016694da67e5 --- /dev/null +++ b/mt3/version.py @@ -0,0 +1,16 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
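# Note on the mixture rates defined in mt3/tasks.py above: with
# seqio.mixing_rate_num_examples and MIXING_TEMPERATURE = 10 / 3, each task
# should (assuming the usual seqio behavior) be sampled in proportion to its
# example count raised to 1 / temperature = 0.3, so a dataset with 100x more
# examples is only sampled roughly 100 ** 0.3, i.e. about 4x, as often.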
+ +"""MT3 version.""" +__version__ = '0.0.1' diff --git a/mt3/vocabularies.py b/mt3/vocabularies.py new file mode 100644 index 0000000000000000000000000000000000000000..e786d8673628f56678bdab573902d3052aa90f0f --- /dev/null +++ b/mt3/vocabularies.py @@ -0,0 +1,282 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model vocabulary.""" + +import dataclasses +import math + +from typing import Callable, Optional, Sequence +from mt3 import event_codec + +import note_seq +import seqio +import t5.data +import tensorflow as tf + + +DECODED_EOS_ID = -1 +DECODED_INVALID_ID = -2 + +# defaults for vocabulary config +DEFAULT_STEPS_PER_SECOND = 100 +DEFAULT_MAX_SHIFT_SECONDS = 10 +DEFAULT_NUM_VELOCITY_BINS = 127 + + +@dataclasses.dataclass +class VocabularyConfig: + """Vocabulary configuration parameters.""" + steps_per_second: int = DEFAULT_STEPS_PER_SECOND + max_shift_seconds: int = DEFAULT_MAX_SHIFT_SECONDS + num_velocity_bins: int = DEFAULT_NUM_VELOCITY_BINS + + @property + def abbrev_str(self): + s = '' + if self.steps_per_second != DEFAULT_STEPS_PER_SECOND: + s += 'ss%d' % self.steps_per_second + if self.max_shift_seconds != DEFAULT_MAX_SHIFT_SECONDS: + s += 'ms%d' % self.max_shift_seconds + if self.num_velocity_bins != DEFAULT_NUM_VELOCITY_BINS: + s += 'vb%d' % self.num_velocity_bins + return s + + +def num_velocity_bins_from_codec(codec: event_codec.Codec): + """Get number of velocity bins from event codec.""" + lo, hi = codec.event_type_range('velocity') + return hi - lo + + +def velocity_to_bin(velocity, num_velocity_bins): + if velocity == 0: + return 0 + else: + return math.ceil(num_velocity_bins * velocity / note_seq.MAX_MIDI_VELOCITY) + + +def bin_to_velocity(velocity_bin, num_velocity_bins): + if velocity_bin == 0: + return 0 + else: + return int(note_seq.MAX_MIDI_VELOCITY * velocity_bin / num_velocity_bins) + + +def drop_programs(tokens, codec: event_codec.Codec): + """Drops program change events from a token sequence.""" + min_program_id, max_program_id = codec.event_type_range('program') + return tokens[(tokens < min_program_id) | (tokens > max_program_id)] + + +def programs_to_midi_classes(tokens, codec): + """Modifies program events to be the first program in the MIDI class.""" + min_program_id, max_program_id = codec.event_type_range('program') + is_program = (tokens >= min_program_id) & (tokens <= max_program_id) + return tf.where( + is_program, + min_program_id + 8 * ((tokens - min_program_id) // 8), + tokens) + + +@dataclasses.dataclass +class ProgramGranularity: + # both tokens_map_fn and program_map_fn should be idempotent + tokens_map_fn: Callable[[Sequence[int], event_codec.Codec], Sequence[int]] + program_map_fn: Callable[[int], int] + + +PROGRAM_GRANULARITIES = { + # "flat" granularity; drop program change tokens and set NoteSequence + # programs to zero + 'flat': ProgramGranularity( + tokens_map_fn=drop_programs, + program_map_fn=lambda program: 0), + + # map each program to the first program in its MIDI class + 'midi_class': 
ProgramGranularity( + tokens_map_fn=programs_to_midi_classes, + program_map_fn=lambda program: 8 * (program // 8)), + + # leave programs as is + 'full': ProgramGranularity( + tokens_map_fn=lambda tokens, codec: tokens, + program_map_fn=lambda program: program) +} + + +def build_codec(vocab_config: VocabularyConfig): + """Build event codec.""" + event_ranges = [ + event_codec.EventRange('pitch', note_seq.MIN_MIDI_PITCH, + note_seq.MAX_MIDI_PITCH), + # velocity bin 0 is used for note-off + event_codec.EventRange('velocity', 0, vocab_config.num_velocity_bins), + # used to indicate that a pitch is present at the beginning of a segment + # (only has an "off" event as when using ties all pitch events until the + # "tie" event belong to the tie section) + event_codec.EventRange('tie', 0, 0), + event_codec.EventRange('program', note_seq.MIN_MIDI_PROGRAM, + note_seq.MAX_MIDI_PROGRAM), + event_codec.EventRange('drum', note_seq.MIN_MIDI_PITCH, + note_seq.MAX_MIDI_PITCH), + ] + + return event_codec.Codec( + max_shift_steps=(vocab_config.steps_per_second * + vocab_config.max_shift_seconds), + steps_per_second=vocab_config.steps_per_second, + event_ranges=event_ranges) + + +def vocabulary_from_codec(codec: event_codec.Codec) -> seqio.Vocabulary: + return GenericTokenVocabulary( + codec.num_classes, extra_ids=t5.data.DEFAULT_EXTRA_IDS) + + +class GenericTokenVocabulary(seqio.Vocabulary): + """Vocabulary with pass-through encoding of tokens.""" + + def __init__(self, regular_ids: int, extra_ids: int = 0): + # The special tokens: 0=PAD, 1=EOS, and 2=UNK + self._num_special_tokens = 3 + self._num_regular_tokens = regular_ids + super().__init__(extra_ids=extra_ids) + + @property + def eos_id(self) -> Optional[int]: + return 1 + + @property + def unk_id(self) -> Optional[int]: + return 2 + + @property + def _base_vocab_size(self) -> int: + """Number of ids. + + Returns: + an integer, the vocabulary size + """ + return self._num_special_tokens + self._num_regular_tokens + + def _encode(self, token_ids: Sequence[int]) -> Sequence[int]: + """Encode a list of tokens ids as a list of integers. + + To keep the first few ids for special tokens, increase ids by the number + of special tokens. + + Args: + token_ids: array of token ids. + + Returns: + a list of integers (not terminated by EOS) + """ + encoded = [] + for token_id in token_ids: + if not 0 <= token_id < self._num_regular_tokens: + raise ValueError( + f'token_id {token_id} does not fall within valid range of ' + f'[0, {self._num_regular_tokens})') + encoded.append(token_id + self._num_special_tokens) + + return encoded + + def _decode(self, ids: Sequence[int]) -> Sequence[int]: + """Decode a list of integers to a list of token ids. + + The special tokens of PAD and UNK as well as extra_ids will be + replaced with DECODED_INVALID_ID in the output. If EOS is present, it will + be the final token in the decoded output and will be represented by + DECODED_EOS_ID. + + Args: + ids: a list of integers + + Returns: + a list of token ids. + """ + # convert all the extra ids to INVALID_ID + def _decode_id(encoded_id): + if encoded_id == self.eos_id: + return DECODED_EOS_ID + elif encoded_id < self._num_special_tokens: + return DECODED_INVALID_ID + elif encoded_id >= self._base_vocab_size: + return DECODED_INVALID_ID + else: + return encoded_id - self._num_special_tokens + ids = [_decode_id(int(i)) for i in ids] + return ids + + def _encode_tf(self, token_ids: tf.Tensor) -> tf.Tensor: + """Encode a list of tokens to a tf.Tensor. 
+ + Args: + token_ids: array of audio token ids. + + Returns: + a 1d tf.Tensor with dtype tf.int32 + """ + with tf.control_dependencies( + [tf.debugging.assert_less( + token_ids, tf.cast(self._num_regular_tokens, token_ids.dtype)), + tf.debugging.assert_greater_equal( + token_ids, tf.cast(0, token_ids.dtype)) + ]): + tf_ids = token_ids + self._num_special_tokens + return tf_ids + + def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor: + """Decode in TensorFlow. + + The special tokens of PAD and UNK as well as extra_ids will be + replaced with DECODED_INVALID_ID in the output. If EOS is present, it and + all following tokens in the decoded output and will be represented by + DECODED_EOS_ID. + + Args: + ids: a 1d tf.Tensor with dtype tf.int32 + + Returns: + a 1d tf.Tensor with dtype tf.int32 + """ + # Create a mask that is true from the first EOS position onward. + # First, create an array that is True whenever there is an EOS, then cumsum + # that array so that every position after and including the first True is + # >1, then cast back to bool for the final mask. + eos_and_after = tf.cumsum( + tf.cast(tf.equal(ids, self.eos_id), tf.int32), exclusive=False, axis=-1) + eos_and_after = tf.cast(eos_and_after, tf.bool) + + return tf.where( + eos_and_after, + DECODED_EOS_ID, + tf.where( + tf.logical_and( + tf.greater_equal(ids, self._num_special_tokens), + tf.less(ids, self._base_vocab_size)), + ids - self._num_special_tokens, + DECODED_INVALID_ID)) + + def __eq__(self, other): + their_extra_ids = other.extra_ids + their_num_regular_tokens = other._num_regular_tokens + return (self.extra_ids == their_extra_ids and + self._num_regular_tokens == their_num_regular_tokens) + + +def num_embeddings(vocabulary: GenericTokenVocabulary) -> int: + """Vocabulary size as a multiple of 128 for TPU efficiency.""" + return 128 * math.ceil(vocabulary.vocab_size / 128) diff --git a/mt3/vocabularies_test.py b/mt3/vocabularies_test.py new file mode 100644 index 0000000000000000000000000000000000000000..257b593c428aeb57b7bd1bf57bc1dcea481e75c0 --- /dev/null +++ b/mt3/vocabularies_test.py @@ -0,0 +1,114 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
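# A couple of concrete values implied by vocabularies.py above (illustrative
# only; the tests that follow exercise the same behavior):
#
#   from mt3 import vocabularies
#
#   # 'midi_class' granularity snaps programs to the start of their MIDI class:
#   gran = vocabularies.PROGRAM_GRANULARITIES['midi_class']
#   gran.program_map_fn(25)   # -> 24 (steel-string guitar -> start of guitars)
#   gran.program_map_fn(40)   # -> 40 (violin is already first in the strings class)
#
#   # GenericTokenVocabulary reserves ids 0-2 for PAD/EOS/UNK, so regular
#   # token k is encoded as k + 3, matching the offsets in the tests below.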
+ +"""Tests for vocabularies.""" + +from absl.testing import absltest +from mt3 import vocabularies + +import numpy as np +import tensorflow.compat.v2 as tf + +tf.compat.v1.enable_eager_execution() + + +class VocabulariesTest(absltest.TestCase): + + def test_velocity_quantization(self): + self.assertEqual(0, vocabularies.velocity_to_bin(0, num_velocity_bins=1)) + self.assertEqual(0, vocabularies.velocity_to_bin(0, num_velocity_bins=127)) + self.assertEqual(0, vocabularies.bin_to_velocity(0, num_velocity_bins=1)) + self.assertEqual(0, vocabularies.bin_to_velocity(0, num_velocity_bins=127)) + + self.assertEqual( + 1, + vocabularies.velocity_to_bin( + vocabularies.bin_to_velocity(1, num_velocity_bins=1), + num_velocity_bins=1)) + + for velocity_bin in range(1, 128): + self.assertEqual( + velocity_bin, + vocabularies.velocity_to_bin( + vocabularies.bin_to_velocity(velocity_bin, num_velocity_bins=127), + num_velocity_bins=127)) + + def test_encode_decode(self): + vocab = vocabularies.GenericTokenVocabulary(32) + input_tokens = [1, 2, 3] + expected_encoded = [4, 5, 6] + + # Encode + self.assertSequenceEqual(vocab.encode(input_tokens), expected_encoded) + np.testing.assert_array_equal( + vocab.encode_tf(tf.convert_to_tensor(input_tokens)).numpy(), + expected_encoded) + + # Decode + self.assertSequenceEqual(vocab.decode(expected_encoded), input_tokens) + np.testing.assert_array_equal( + vocab.decode_tf(tf.convert_to_tensor(expected_encoded)).numpy(), + input_tokens) + + def test_decode_invalid_ids(self): + vocab = vocabularies.GenericTokenVocabulary(32, extra_ids=4) + encoded = [0, 2, 3, 4, 34, 35] + expected_decoded = [-2, -2, 0, 1, 31, -2] + self.assertSequenceEqual(vocab.decode(encoded), expected_decoded) + np.testing.assert_array_equal( + vocab.decode_tf(tf.convert_to_tensor(encoded)).numpy(), + expected_decoded) + + def test_decode_eos(self): + vocab = vocabularies.GenericTokenVocabulary(32) + encoded = [0, 2, 3, 4, 1, 0, 1, 0] + # Python decode function truncates everything after first EOS. + expected_decoded = [-2, -2, 0, 1, -1] + self.assertSequenceEqual(vocab.decode(encoded), expected_decoded) + # TF decode function preserves array length. + expected_decoded_tf = [-2, -2, 0, 1, -1, -1, -1, -1] + np.testing.assert_array_equal( + vocab.decode_tf(tf.convert_to_tensor(encoded)).numpy(), + expected_decoded_tf) + + def test_encode_invalid_id(self): + vocab = vocabularies.GenericTokenVocabulary(32) + inputs = [0, 15, 31] + # No exception expected. 
+ vocab.encode(inputs) + vocab.encode_tf(tf.convert_to_tensor(inputs)) + + inputs_too_low = [-1, 15, 31] + with self.assertRaises(ValueError): + vocab.encode(inputs_too_low) + with self.assertRaises(tf.errors.InvalidArgumentError): + vocab.encode_tf(tf.convert_to_tensor(inputs_too_low)) + + inputs_too_high = [0, 15, 32] + with self.assertRaises(ValueError): + vocab.encode(inputs_too_high) + with self.assertRaises(tf.errors.InvalidArgumentError): + vocab.encode_tf(tf.convert_to_tensor(inputs_too_high)) + + def test_encode_dtypes(self): + vocab = vocabularies.GenericTokenVocabulary(32) + inputs = [0, 15, 31] + encoded32 = vocab.encode_tf(tf.convert_to_tensor(inputs, tf.int32)) + self.assertEqual(tf.int32, encoded32.dtype) + encoded64 = vocab.encode_tf(tf.convert_to_tensor(inputs, tf.int64)) + self.assertEqual(tf.int64, encoded64.dtype) + + +if __name__ == '__main__': + absltest.main() diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000000000000000000000000000000000000..5f1cd9f4e5ef281a47dc69180674da15e6b28f56 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +python_files = *_test.py +log_level = INFO \ No newline at end of file diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000000000000000000000000000000000000..9af7e6f11bb01f7306f796faf7bfbe3e2955cd94 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[aliases] +test=pytest \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..eee54d2b815d0a07f7d257c722abaa1a55632d40 --- /dev/null +++ b/setup.py @@ -0,0 +1,67 @@ +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Install mt3.""" + +import os +import sys +import setuptools + +# To enable importing version.py directly, we add its path to sys.path. 
+version_path = os.path.join(os.path.dirname(__file__), 'mt3') +sys.path.append(version_path) +from version import __version__ # pylint: disable=g-import-not-at-top + +setuptools.setup( + name='mt3', + version=__version__, + description='Multi-Task Multitrack Music Transcription', + author='Google Inc.', + author_email='no-reply@google.com', + url='http://github.com/magenta/mt3', + license='Apache 2.0', + packages=setuptools.find_packages(), + package_data={ + '': ['*.gin'], + }, + scripts=[], + install_requires=[ + 'absl-py == 1.1.0', + 'ddsp == 3.4.4', + 'flax == 0.5.2', + 'gin-config == 0.5.0', + 'immutabledict == 2.2.1', + 'librosa == 0.9.2', + 'mir_eval == 0.7', + 'note_seq == 0.0.3', + 'numpy == 1.21.6', + 'pretty_midi == 0.2.9', + 'scikit-learn == 1.0.2', + 'scipy == 1.7.3', + 'seqio == 0.0.8', + 't5 == 0.9.3', + 'tensorflow', + 'tensorflow-datasets == 4.6.0', + ], + classifiers=[ + 'Development Status :: 4 - Beta', + 'Intended Audience :: Developers', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: Apache Software License', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + ], + tests_require=['pytest'], + setup_requires=['pytest-runner'], + keywords='music transcription machinelearning audio', +) diff --git a/t5x/__init__.py b/t5x/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..31cc54f5e709ffd6aa72364b24940dc7f12783f3 --- /dev/null +++ b/t5x/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Import API modules.""" + +import t5x.adafactor +import t5x.checkpoints +import t5x.decoding +import t5x.gin_utils +import t5x.losses +import t5x.models +import t5x.partitioning +import t5x.state_utils +import t5x.train_state +import t5x.trainer +import t5x.utils + +# Version number. +from t5x.version import __version__ + +# TODO(adarob): Move clients to t5x.checkpointing and rename +# checkpoints.py to checkpointing.py +checkpointing = t5x.checkpoints diff --git a/t5x/adafactor.py b/t5x/adafactor.py new file mode 100644 index 0000000000000000000000000000000000000000..67700bf0e467545f9cdaf8847965a8b9e8ab6157 --- /dev/null +++ b/t5x/adafactor.py @@ -0,0 +1,608 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Adafactor Optimizer. + +Specialized Adafactor implementation for T5X with: + - custom factorization specification rules. + - support for stacked parameters from scanned layers and parameter fusions. 
+ +Why do we need custom factorization? In the Adafactor paper, scalar, vector and +matrix parameters are considered. This is sufficiently general because higher +dimensional parameters can be reshaped. In practice, there are situations where +higher dimensional parameters are desirable. For example, consider the +multi-headed attention. It has projection kernels. This is naturally +represented as 3-dimensional array [d_model, num_head, head_dim]. Keeping the +3-dimensional structure can be beneficial for performance optimization, e.g., by +giving compilers additional degree of freedom to do layout optimization. + +The default heuristic behavior for the second-moment estimator can lead to an +unexpected result because it assumes that the parameters are matrices (vectors +and scalars are not factored). The dimensions are sorted and the smaller +dimension is assigned to the row dim and the larger dim to the col dim (unless +the two largest dims have an equal size and then the original ordering of the +dimensions is used). Then `v_row` (i.e., the optimizer state for the row) is +obtained by removing the col dim. In other words, `rank(v_row) = rank(v) - 1`. +If the parameter is higher dimensional, v_row and v_col are higher dimensional. +Therefore, the outer product of v_row and v_col do not necessarily corresponds +to the row rank approximation that minimizes the generalized Kullback-Leibler +divergence (the original Adafactor formulation). + +This Adafactor implementation generalized the default behavior such that we +obtain the correct second moment estimator even for higher dimensional +parameters. + +""" +import enum +import re +from typing import Any, Mapping, Optional, Sequence, Tuple, Union + +from absl import logging +from flax import struct +from flax.core import freeze +from flax.core import FrozenDict +from flax.core import unfreeze +from flax.serialization import from_state_dict +from flax.serialization import to_state_dict +from flax.traverse_util import flatten_dict +from flax.traverse_util import unflatten_dict +import jax +import jax.numpy as jnp +import numpy as np +from t5x import utils +from t5x.optimizers import OptimizerDef +from t5x.optimizers import OptimizerState + +Dtype = Any + + +class FactorDim(enum.Enum): + # Don't factorize this dimension. + NONE = None + # A batch-like dimension that we should not average over. + BATCH = 1 + ROW = 2 + COLUMN = 3 + + +# Sentinel value signifying the legacy heuristic factorization rule. 
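# To make the factorization discussion above concrete: for a 3-D attention
# kernel with logical axes ('embed', 'heads', 'kv'), i.e. shape
# [d_model, num_heads, head_dim], the standard logical rules below mark 'embed'
# as ROW and 'heads'/'kv' as COLUMN, so under this generalized scheme the
# factored second-moment statistics become v_row of shape [d_model] and v_col
# of shape [num_heads, head_dim], whose outer product matches the parameter's
# full shape.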
+class HeuristicRule(enum.Enum): + token = 1 + + +HEURISTIC_RULE = HeuristicRule.token +FactorRule = Union[HeuristicRule, Tuple[FactorDim]] + + +def _restore(target, flat): + state_dict = unflatten_dict({tuple(k.split('/')): v for k, v in flat.items()}) + if isinstance(target, FrozenDict): + return freeze(state_dict) + else: + return state_dict + + +def _insert(tpl, idx, x): + tmp = list(tpl) + tmp.insert(idx, x) + return tuple(tmp) + + +def standard_logical_factor_rules(): + return freeze({ + 'vocab': FactorDim.COLUMN, + 'embed': FactorDim.ROW, + 'mlp': FactorDim.COLUMN, + 'heads': FactorDim.COLUMN, + 'kv': FactorDim.COLUMN, + 'joined_kv': FactorDim.COLUMN, + 'relpos_buckets': FactorDim.NONE, + 'layers': FactorDim.BATCH, # used in scanned layers + 'stack': FactorDim.BATCH, # used in stacked params + # 'batch', 'length' should not occur in parameters + 'q_wi_fused': FactorDim.COLUMN, + 'o_wo_fused': FactorDim.COLUMN, + 'multiquery_heads': FactorDim.COLUMN, + 'kv_fused': FactorDim.COLUMN, + 'layer_norm_scale': FactorDim.NONE, + 'mlp_activations': FactorDim.COLUMN, + }) + + +def factor_name_to_factordim(name): + if not isinstance(name, str): + return name + name = name.lower() + return { + 'row': FactorDim.ROW, + 'col': FactorDim.COLUMN, + 'column': FactorDim.COLUMN, + 'batch': FactorDim.BATCH, + 'none': FactorDim.NONE, + 'unfactorized': FactorDim.NONE + }[name] + + +class HParamMap: + """Maps parameter path names to hparams. + + Names of parameters nested in a PyTree (e.g., an Optimizer) are formed by + joining the names along the path to the parameter leaf with '/'. + """ + + def __init__(self, rules): + self._rules = [(re.compile(r), p) for r, p in rules] + + def __getitem__(self, key: str) -> Any: + for r, p in self._rules: + if r.search(key): + return p + raise KeyError(f'No factor rule found for parameter: {key}') + + def __call__(self, params): + """Returns a copy of the params with mapped hparams in leaves.""" + flat_state_dict = flatten_dict(to_state_dict(params)) + flat_rules_dict = {k: self['/'.join(k)] for k in flat_state_dict.keys()} + return from_state_dict(params, unflatten_dict(flat_rules_dict)) + + +@struct.dataclass +class _AdafactorHyperParams: + """Hparams for Adafactor optimizer.""" + learning_rate: Optional[float] + factored: bool + multiply_by_parameter_scale: Union[bool, HParamMap] + beta1: Optional[float] + decay_rate: float + step_offset: int + clipping_threshold: Optional[float] + weight_decay_rate: Optional[float] + min_dim_size_to_factor: int + epsilon1: float + epsilon2: float + factor_map: Optional[HParamMap] = None + logical_factor_rules: Any = None + weight_decay_rate_lr_exponent: Optional[float] = None + global_norm_clip_threshold: Optional[float] = None + max_parameter_scale: Optional[float] = None + skip_nan_updates: Optional[bool] = False + + +@struct.dataclass +class _AdafactorParamState: + v_row: np.ndarray # used in normal factored version + v_col: np.ndarray + v: np.ndarray # only used without factoring + m: np.ndarray # only used with momentum + + +class Adafactor(OptimizerDef): + """Adafactor optimizer. + + Adafactor is described in https://arxiv.org/abs/1804.04235. 
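+
+  Example (an illustrative sketch; the parameter name and rule are made up):
+
+    factor_map = HParamMap((
+        ('qkv_fused/kernel', (FactorDim.ROW, FactorDim.BATCH,
+                              FactorDim.COLUMN)),
+        ('.*', HEURISTIC_RULE),
+    ))
+    opt_def = Adafactor(learning_rate=1e-3, factor_map=factor_map)
+    optimizer = opt_def.create(params)
+    optimizer = optimizer.apply_gradient(grads)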
+ """ + + def __init__(self, + learning_rate: Optional[float] = None, + factored: bool = True, + multiply_by_parameter_scale: Union[bool, HParamMap] = True, + beta1: Optional[float] = None, + decay_rate: float = 0.8, + step_offset: int = 0, + clipping_threshold: Optional[float] = 1.0, + weight_decay_rate: Optional[float] = None, + min_dim_size_to_factor: int = 128, + epsilon1: float = 1e-30, + epsilon2: float = 1e-3, + dtype_momentum: Dtype = jnp.float32, + factor_map: Optional[HParamMap] = None, + logical_factor_rules: Optional[Mapping[str, FactorDim]] = None, + weight_decay_rate_lr_exponent: Optional[float] = None, + global_norm_clip_threshold: Optional[float] = None, + max_parameter_scale: Optional[float] = None, + skip_nan_updates: Optional[bool] = False): + """Constructor for the Adafactor optimizer. + + + Args: + learning_rate: float: learning rate. NB: the natural scale for adafactor + LR is markedly different from Adam, one doesn't use the 1/sqrt(hidden) + correction for this optimizer with attention-based models. + factored: boolean: whether to use factored second-moment estimator for 2d + variables. + multiply_by_parameter_scale: boolean: if True, then scale provided + learning_rate by parameter norm. if False, provided learning_rate is + absolute step size. + beta1: an optional float value between 0 and 1, enables momentum and uses + extra memory if non-None! None by default. + decay_rate: float: controls second-moment exponential decay schedule. + step_offset: for finetuning, one may optionally set this to the starting + step-number of the finetuning phase to reset the second moment + accumulators after pretraining. Does not affect the momentum even if it + was used during pretraining. + clipping_threshold: an optional float >= 1, if None no update clipping. + weight_decay_rate: optional rate at which to decay weights. + min_dim_size_to_factor: only factor accumulator if two array dimensions + are at least this size. + epsilon1: Regularization constant for squared gradient. + epsilon2: Regularization constant for parameter scale. + dtype_momentum: dtype of momentum buffers. + factor_map: hparam-map from key path to manual factorization rules. + logical_factor_rules: factorization rules provided as a set of mappings + from logical axis name to ROW, COLUMN, BATCH, or NONE. Supercedes + factor_map if `set_param_axes` is called. + weight_decay_rate_lr_exponent: If present, weight decay rate is computed + as (learning_rate ** weight_decay_rate_lr_exponent). If + weight_decay_rate is also present, then multiply by it. + global_norm_clip_threshold: If set, will clip gradients by global norm + before Adafactor stats are applied. + max_parameter_scale: If set, clips the parameter scale to a maximum value, + which helps prevent parameters from growing without bound. + skip_nan_updates: If set, any parameter that would have been updated by a + NaN value after a applying gradients will be kept with the earlier + value it had. + """ + if not factored and factor_map is not None: + raise ValueError('Adafactor factored is False but factorization rules ' + 'have been provided.') + if not isinstance(multiply_by_parameter_scale, (bool, HParamMap)): + raise TypeError( + '`multiply_by_parameter_scale` must be either bool or `HParamMap` ' + f'type. Got {type(multiply_by_parameter_scale)}') + + if not isinstance(factor_map, (type(None), HParamMap)): + raise TypeError( + '`factor_map` must be either None or `HParamMap` type. 
Got ' + f'{type(factor_map)}') + + hyper_params = _AdafactorHyperParams( + learning_rate, factored, multiply_by_parameter_scale, beta1, decay_rate, + step_offset, clipping_threshold, weight_decay_rate, + min_dim_size_to_factor, epsilon1, epsilon2, factor_map, + logical_factor_rules, weight_decay_rate_lr_exponent, + global_norm_clip_threshold, max_parameter_scale, skip_nan_updates) + self.dtype_momentum = jax.dtypes.canonicalize_dtype(dtype_momentum) + super().__init__(hyper_params) + + @staticmethod + def _decay_rate_pow(i: int, exponent: float = 0.8) -> float: + """Default Adafactor second-moment decay schedule.""" + t = jnp.array(i, jnp.float32) + 1.0 + return 1.0 - t**(-exponent) + + @staticmethod + def _parse_rule( + rule: Optional[FactorRule], + shape: Sequence[int], + path: str, + fallback_to_heuristics=True + ) -> Tuple[Tuple[int, ...], Optional[Union[HeuristicRule, Tuple[Tuple[ + int, ...], Tuple[int, ...]]]]]: + """Parses specification and return factored dims and dims for averaging. + + Adafactor needs to know the two largest dimensions to factorize along. + Traditionally it used a heuristic, but we want finer control over these + factorization dimensions. Additionally, there are situations where + parameters are batched together for e.g. scanned layers and QKV fusion, + and we want to ensure that the scale updates and clipping thresholds are + calculated _within_ each array and not across the entire batched array. + + Args: + rule: the rule is either None (default to heuristic behavior) or a tuple + of the same rank as the `param` array containing a FactorDim.ROW or + FactorDim.COLUMN to mark dimensions to factorize in two row and column + sets, and optionally dimensions marked FactorDim.BATCH to denote batched + dimensions that should not be averaged over. e.g. (BATCH, ROW, COLUMN, + COLUMN) + shape: shape of the variable + path: '/' joined parameter path. + fallback_to_heuristics: whether to fallback to heuristic factorization + rule. For most cases this should be set to `True`. + + Returns: + tuple of: tuple of dimensions to average over, 2-tuple of dimensions to + factorize over. + """ + param_ndim = len(shape) + + if rule is None: + # No factorization. + return tuple(np.arange(param_ndim)), None + + if rule is HEURISTIC_RULE: + if param_ndim > 2: + raise ValueError( + f'A parameter with rank strictly higher than 2 must have an ' + f'explicit factorization rule: {path}, {shape}') + # Even if no explicit rule is provided for the param, we still want to + # average over all the dimensions for computing the RMS scale. 
+ return tuple(np.arange(param_ndim)), HEURISTIC_RULE + + if len(rule) != param_ndim: + raise ValueError(f'Factorization rule {rule} has incorrect rank ' + f'for param of rank {param_ndim}: {path}, {shape}') + + row_dims = tuple(idx for idx, d in enumerate(rule) if d == FactorDim.ROW) + col_dims = tuple(idx for idx, d in enumerate(rule) if d == FactorDim.COLUMN) + batched_dims = tuple( + idx for idx, d in enumerate(rule) if d == FactorDim.BATCH) + averaging_dims = tuple(np.delete(np.arange(param_ndim), batched_dims)) + factor_dims = (row_dims, col_dims) + if factor_dims == ((), ()): + factor_dims = None + + if fallback_to_heuristics and param_ndim <= 2 and not batched_dims: + logging.warning( + 'Since rank of parameter %s %d is less than or equal to 2, the ' + 'factorization method falls back to heuristics and the provided ' + 'factor rule %s is ignored.', path, param_ndim, rule) + return tuple(np.arange(param_ndim)), HEURISTIC_RULE + + return averaging_dims, factor_dims + + def _factored_dims( + self, shape: Sequence[int]) -> Optional[Tuple[Tuple[int], Tuple[int]]]: + """Whether to use a factored second moment estimator. + + If there are not two dimensions of size >= min_dim_size_to_factor, then we + do not factor. If we do factor the accumulator, then this function returns a + tuple of the two largest axes to reduce over. + + Args: + shape: a Shape + + Returns: + None or a tuple of ints + """ + if not self.hyper_params.factored or len(shape) < 2: + return None + sorted_dims = np.argsort(shape) + if shape[sorted_dims[-2]] < self.hyper_params.min_dim_size_to_factor: + return None + return (int(sorted_dims[-2]),), (int(sorted_dims[-1]),) + + def init_param_state(self, param, path): + shape = param.shape + state = {k: jnp.zeros((1,)) for k in ['v_row', 'v_col', 'v', 'm']} + if self.hyper_params.factored: + factor_rule = ( + self.hyper_params.factor_map[path] + if self.hyper_params.factor_map else HEURISTIC_RULE) + else: + factor_rule = None + _, factored_dims = self._parse_rule(factor_rule, param.shape, path) + if factored_dims is HEURISTIC_RULE: + factored_dims = self._factored_dims(shape) + if factored_dims is not None: + d1, d0 = factored_dims + vr_shape = np.delete(shape, d0) + vc_shape = np.delete(shape, d1) + state['v_row'] = jnp.zeros(vr_shape, dtype=jnp.float32) + state['v_col'] = jnp.zeros(vc_shape, dtype=jnp.float32) + else: + state['v'] = jnp.zeros(param.shape, dtype=jnp.float32) + if self.hyper_params.beta1 is not None: + state['m'] = jnp.zeros(param.shape, dtype=self.dtype_momentum) + return _AdafactorParamState(**state) + + def init_state(self, params): + params_flat = utils.flatten_dict_string_keys(params) + param_states_flat = [ + self.init_param_state(param, path) + for path, param in params_flat.items() + ] + param_states_flat = { + k: v for k, v in zip(params_flat.keys(), param_states_flat) + } + param_states = _restore(params, param_states_flat) + state = OptimizerState(jnp.asarray(0, dtype=jnp.int32), param_states) + return state + + def apply_param_gradient(self, step, hyper_params, param, state, grad, path): + assert hyper_params.learning_rate is not None, 'no learning rate provided.' 
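+    # Sketch of the update performed below: (1) optionally rescale the step
+    # size by the parameter's RMS (multiply_by_parameter_scale), (2) update a
+    # decayed second-moment estimate of the squared gradient, factored into
+    # v_row/v_col whenever a factorization applies, (3) normalize the gradient
+    # by that estimate and clip the normalized update, then (4) optionally
+    # apply momentum (beta1) and weight decay before writing the new value.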
+ learning_rate = hyper_params.learning_rate + beta1 = hyper_params.beta1 + decay_rate = hyper_params.decay_rate + step_offset = hyper_params.step_offset + multiply_by_parameter_scale = hyper_params.multiply_by_parameter_scale + max_parameter_scale = hyper_params.max_parameter_scale + clipping_threshold = hyper_params.clipping_threshold + weight_decay_rate = hyper_params.weight_decay_rate + epsilon1 = hyper_params.epsilon1 + epsilon2 = hyper_params.epsilon2 + if hyper_params.weight_decay_rate_lr_exponent: + weight_decay_rate = ( + (weight_decay_rate or 1.0) * + learning_rate**hyper_params.weight_decay_rate_lr_exponent) + + if self.hyper_params.factored: + factor_rule = ( + self.hyper_params.factor_map[path] + if self.hyper_params.factor_map else HEURISTIC_RULE) + else: + factor_rule = None + averaging_dims, factored_dims = self._parse_rule(factor_rule, param.shape, + path) + + grad = grad.astype(jnp.float32) + + updates = {k: jnp.zeros((1,)) for k in ['v_row', 'v_col', 'v', 'm']} + decay_rate = self._decay_rate_pow(step - step_offset, exponent=decay_rate) + update_scale = learning_rate + + if isinstance(multiply_by_parameter_scale, HParamMap): + multiply_by_parameter_scale = multiply_by_parameter_scale[path] + if multiply_by_parameter_scale: + param_scale = jnp.sqrt( + jnp.mean(param * param, axis=averaging_dims, keepdims=True)) + # Clip param_scale to a minimum value of epsilon2. + param_scale = jnp.maximum(param_scale, epsilon2) + # Clip param_scale to a maximum value, if specified. + if max_parameter_scale is not None: + param_scale = jnp.minimum(param_scale, max_parameter_scale) + update_scale *= param_scale + mixing_rate = 1.0 - decay_rate + + grad_sqr = grad * grad + epsilon1 + if factored_dims is HEURISTIC_RULE: + factored_dims = self._factored_dims(param.shape) + if factored_dims is not None: + d1, d0 = factored_dims + new_v_row = ( + decay_rate * state.v_row + mixing_rate * jnp.mean(grad_sqr, axis=d0)) + new_v_col = ( + decay_rate * state.v_col + mixing_rate * jnp.mean(grad_sqr, axis=d1)) + updates['v_row'] = new_v_row + updates['v_col'] = new_v_col + reduced_d1 = tuple(d - len([e for e in d0 if e < d]) for d in d1) + + row_col_mean = jnp.mean(new_v_row, axis=reduced_d1, keepdims=True) + row_factor = (new_v_row / row_col_mean)**-0.5 + col_factor = (new_v_col)**-0.5 + y = ( + grad * jnp.expand_dims(row_factor, axis=d0) * + jnp.expand_dims(col_factor, axis=d1)) + else: + new_v = decay_rate * state.v + mixing_rate * grad_sqr + updates['v'] = new_v + y = grad * (new_v)**-0.5 + + if clipping_threshold is not None: + clipping_denom = ( + jnp.maximum( + 1.0, + jnp.sqrt(jnp.mean(y * y, axis=averaging_dims, keepdims=True)) / + clipping_threshold)) + y /= clipping_denom + + subtrahend = update_scale * y + if beta1 is not None: + new_m = beta1 * state.m + (1.0 - beta1) * subtrahend + subtrahend = new_m + updates['m'] = new_m.astype(self.dtype_momentum) + + if weight_decay_rate is not None: + new_param = (1.0 - weight_decay_rate) * param - subtrahend + else: + new_param = param - subtrahend + + if hyper_params.skip_nan_updates: + updates['v_row'] = jnp.where( + jnp.isnan(updates['v_row']), state.v_row, updates['v_row']) + updates['v_col'] = jnp.where( + jnp.isnan(updates['v_col']), state.v_col, updates['v_col']) + updates['v'] = jnp.where(jnp.isnan(updates['v']), state.v, updates['v']) + updates['m'] = jnp.where(jnp.isnan(updates['m']), state.m, updates['m']) + new_param = jnp.where(jnp.isnan(new_param), param, new_param) + new_state = _AdafactorParamState(**updates) + + return 
new_param.astype(param.dtype), new_state + + def apply_gradient(self, hyper_params, params, state, grads): + """Applies a gradient for a set of parameters. + + Args: + hyper_params: a named tuple of hyper parameters. + params: the parameters that should be updated. + state: a named tuple containing the state of the optimizer + grads: the gradient tensors for the parameters. + + Returns: + A tuple containing the new parameters and the new optimizer state. + """ + step = state.step + # We assume that params, param_states, and grads are all dict-like here. + params_flat_dict = utils.flatten_dict_string_keys(params) + params_paths = params_flat_dict.keys() + params_flat = params_flat_dict.values() + # extra paranoia to guarantee identical value ordering + states_flat = utils.flatten_dict_string_keys(state.param_states) + states_flat = [states_flat[k] for k in params_paths] + grads_flat = utils.flatten_dict_string_keys(grads) + grads_flat = [grads_flat[k] for k in params_paths] + + if hyper_params.global_norm_clip_threshold: + # Paper: http://proceedings.mlr.press/v28/pascanu13.pdf + # TF: https://www.tensorflow.org/api_docs/python/tf/clip_by_global_norm + squared_l2_norms = [jnp.sum(jnp.square(g)) for g in grads_flat] + global_norm = jnp.sqrt(jnp.sum(jnp.array(squared_l2_norms))) + scale = hyper_params.global_norm_clip_threshold * jnp.minimum( + 1.0 / hyper_params.global_norm_clip_threshold, 1.0 / global_norm) + grads_flat = [g * scale for g in grads_flat] + + out = [ + self.apply_param_gradient(step, hyper_params, param, state, grad, path) + for param, state, grad, path in zip(params_flat, states_flat, + grads_flat, params_paths) + ] + + new_params_flat, new_states_flat = list(zip(*out)) if out else ((), ()) + new_params_flat = {k: v for k, v in zip(params_paths, new_params_flat)} + new_states_flat = {k: v for k, v in zip(params_paths, new_states_flat)} + new_params = _restore(params, new_params_flat) + new_param_states = _restore(params, new_states_flat) + new_state = OptimizerState(step + 1, new_param_states) + + return new_params, new_state + + def set_param_axes(self, param_logical_axes): + """Sets Adafactor factorization map from logical axis names tree.""" + logical_factor_rules = self.hyper_params.logical_factor_rules + if logical_factor_rules is None: + return + + # pylint:disable=invalid-name + NONE = FactorDim.NONE + COLUMN = FactorDim.COLUMN + ROW = FactorDim.ROW + + # pylint:enable=invalid-name + + def apply_rules(axes): + # Partially factorized params are marked as unfactorized, preserving + # only BATCH axis annotations. We also check for incompletely factorized + # params that have ROW, COLUMN but also accidental NONE dimensions and + # raise an error in that case. 
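+      # For example (illustrative), with the standard rules the logical axes
+      # ('embed', 'mlp') map to (ROW, COLUMN) and stay factored, whereas an
+      # annotation containing ROW but no COLUMN (or vice versa) is demoted to
+      # NONE, i.e. that parameter is left unfactorized.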
+ axis_rules = tuple(logical_factor_rules[x] for x in axes) + axis_rules = tuple(factor_name_to_factordim(x) for x in axis_rules) + if ROW in axis_rules and COLUMN in axis_rules and NONE in axis_rules: + raise ValueError(f'Incomplete adafactor spec {axis_rules} for {axes}!') + if ROW not in axis_rules or COLUMN not in axis_rules: + axis_rules = tuple( + NONE if x in (ROW, COLUMN) else x for x in axis_rules) + return axis_rules + + factor_map = jax.tree_map(apply_rules, param_logical_axes) + factor_map = utils.flatten_dict_string_keys(factor_map) + + self.hyper_params = self.hyper_params.replace(factor_map=factor_map) + + def derive_logical_axes(self, optimizer_state, param_logical_axes): + """Derives optimizer logical partitioning from model logical partitions.""" + optimizer_logical_axes = jax.tree_map(lambda x: None, + optimizer_state.state_dict()) + optimizer_logical_axes['target'] = param_logical_axes + + def factor_rule(logical_axes, adafactor_leaf): + return dict( + v_row=None, + v_col=None, + v=logical_axes if adafactor_leaf['v'].shape != (1,) else None, + m=logical_axes if self.hyper_params.beta1 else None) + + optimizer_logical_axes['state']['param_states'] = jax.tree_map( + factor_rule, unfreeze(param_logical_axes), + optimizer_state.state_dict()['state']['param_states']) + + return optimizer_state.restore_state(unfreeze(optimizer_logical_axes)) diff --git a/t5x/adafactor_test.py b/t5x/adafactor_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b3941ddef05ba7391873d9c8ab28b86069e2a87a --- /dev/null +++ b/t5x/adafactor_test.py @@ -0,0 +1,527 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for t5x.adafactor.""" + +import functools +import operator +from typing import Sequence + +from absl.testing import absltest +from absl.testing import parameterized + +import flax +from flax import optim # used for equivalence testing only +from flax import traverse_util +import jax +from jax import numpy as jnp +from jax import random +import numpy as np + +from t5x import adafactor +from t5x import optimizers + +OptimizerState = optimizers.OptimizerState + +_AdafactorHyperParams = adafactor._AdafactorHyperParams +_AdafactorParamState = adafactor._AdafactorParamState + +_BATCH = adafactor.FactorDim.BATCH +_ROW = adafactor.FactorDim.ROW +_COL = adafactor.FactorDim.COLUMN + +# Testing helpers + + +def _assert_numpy_allclose(a, b, atol=None, rtol=None): + a, b = jnp.array(a), jnp.array(b) + a = a.astype(np.float32) if a.dtype == jnp.bfloat16 else a + b = b.astype(np.float32) if b.dtype == jnp.bfloat16 else b + kw = {} + if atol: + kw['atol'] = atol + if rtol: + kw['rtol'] = rtol + np.testing.assert_allclose(a, b, **kw) + + +def check_eq(xs, ys, atol=None, rtol=None): + xs_leaves, xs_tree = jax.tree_flatten(xs) + ys_leaves, ys_tree = jax.tree_flatten(ys) + assert xs_tree == ys_tree, f"Tree shapes don't match. 
\n{xs_tree}\n{ys_tree}" + assert jax.tree_util.tree_all( + jax.tree_multimap(lambda x, y: np.array(x).shape == np.array(y).shape, + xs_leaves, ys_leaves)), "Leaves' shapes don't match." + assert jax.tree_multimap( + functools.partial(_assert_numpy_allclose, atol=atol, rtol=rtol), + xs_leaves, ys_leaves) + + +def flattened_state_dict(x): + s = flax.serialization.to_state_dict(x) + return flax.traverse_util.flatten_dict(s, sep='/') + + +def tree_shape(x): + return jax.tree_map(jnp.shape, x) + + +def tree_equals(x, y): + return jax.tree_util.tree_all(jax.tree_multimap(operator.eq, x, y)) + + +def _get_multi_adafactor( + learning_rate: float, step_offset: int, + adafactor_exclude_from_parameter_scale: Sequence[str] +) -> optim.MultiOptimizer: + """Get adafactor with support for excluding some parameters from scaling.""" + + def _should_not_scale(path): + return any([s in path for s in adafactor_exclude_from_parameter_scale]) + + scaled_vars = traverse_util.ModelParamTraversal( + lambda path, _: not _should_not_scale(path)) + unscaled_vars = traverse_util.ModelParamTraversal( + lambda path, _: _should_not_scale(path)) + scaled_opt = optim.Adafactor( + learning_rate, decay_rate=0.8, step_offset=step_offset) + unscaled_opt = optim.Adafactor( + learning_rate, + decay_rate=0.8, + step_offset=step_offset, + multiply_by_parameter_scale=False) + return optim.MultiOptimizer((scaled_vars, scaled_opt), + (unscaled_vars, unscaled_opt)) + + +# Inline test data + +MODEL_SHAPE = { + 'decoder': { + 'decoder_norm': {'scale': [128]}, + 'layers_0': { + 'encoder_decoder_attention': { + 'key': {'kernel': [128, 256]}, + 'out': {'kernel': [256, 128]}, + 'query': {'kernel': [128, 256]}, + 'value': {'kernel': [128, 256]}}, + 'mlp': { + 'wi': {'kernel': [128, 512]}, + 'wo': {'kernel': [512, 128]}}, + 'pre_cross_attention_layer_norm': {'scale': [128]}, + 'pre_mlp_layer_norm': {'scale': [128]}, + 'pre_self_attention_layer_norm': {'scale': [128]}, + 'self_attention': { + 'key': {'kernel': [128, 256]}, + 'out': {'kernel': [256, 128]}, + 'query': {'kernel': [128, 256]}, + 'value': {'kernel': [128, 256]}}}, + 'layers_1': { + 'encoder_decoder_attention': { + 'key': {'kernel': [128, 128]}, + 'out': {'kernel': [128, 128]}, + 'query': {'kernel': [128, 128]}, + 'value': {'kernel': [128, 128]}}, + 'mlp': { + 'wi': {'kernel': [128, 512]}, + 'wo': {'kernel': [512, 128]}}, + 'pre_cross_attention_layer_norm': {'scale': [128]}, + 'pre_mlp_layer_norm': {'scale': [128]}, + 'pre_self_attention_layer_norm': {'scale': [128]}, + 'self_attention': { + 'key': {'kernel': [128, 256]}, + 'out': {'kernel': [256, 128]}, + 'query': {'kernel': [128, 256]}, + 'value': {'kernel': [128, 256]}}}, + 'relpos_bias': {'rel_embedding': [2, 32]}}, + 'encoder': { + 'encoder_norm': {'scale': [128]}, + 'layers_0': { + 'attention': { + 'key': {'kernel': [128, 256]}, + 'out': {'kernel': [256, 128]}, + 'query': {'kernel': [128, 256]}, + 'value': {'kernel': [128, 256]}}, + 'mlp': { + 'wi': {'kernel': [128, 512]}, + 'wo': {'kernel': [512, 128]}}, + 'pre_attention_layer_norm': {'scale': [128]}, + 'pre_mlp_layer_norm': {'scale': [128]}}, + 'layers_1': { + 'attention': { + 'key': {'kernel': [128, 256]}, + 'out': {'kernel': [256, 128]}, + 'query': {'kernel': [128, 256]}, + 'value': {'kernel': [128, 256]}}, + 'mlp': { + 'wi': {'kernel': [128, 512]}, + 'wo': {'kernel': [512, 128]}}, + 'pre_attention_layer_norm': {'scale': [128]}, + 'pre_mlp_layer_norm': {'scale': [128]}}, + 'relpos_bias': {'rel_embedding': [2, 32]}}, + 'token_embedder': {'embedding': [32128, 128]}} # 
pyformat: disable + + +class AdafactorTest(parameterized.TestCase): + + # Classic Adafactor Behavior Tests + + def test_2D_simple(self): + x = {'a': jnp.ones((24, 16))} + opt_def = adafactor.Adafactor(min_dim_size_to_factor=8) + optimizer = opt_def.create(x) + shapes = tree_shape(flattened_state_dict(optimizer.state.param_states)) + ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24,), 'a/v_row': (16,)} + self.assertTrue(tree_equals(shapes, ref)) + + def test_2D_simple_nofactor(self): + x = {'a': jnp.ones((24, 16))} + opt_def = adafactor.Adafactor(min_dim_size_to_factor=32) + optimizer = opt_def.create(x) + shapes = tree_shape(flattened_state_dict(optimizer.state.param_states)) + ref = {'a/m': (1,), 'a/v': (24, 16), 'a/v_col': (1,), 'a/v_row': (1,)} + self.assertTrue(tree_equals(shapes, ref)) + + def test_2D_simple_nofactor_momentum(self): + x = {'a': jnp.ones((24, 16))} + opt_def = adafactor.Adafactor(min_dim_size_to_factor=32, beta1=0.1) + optimizer = opt_def.create(x) + shapes = tree_shape(flattened_state_dict(optimizer.state.param_states)) + ref = {'a/m': (24, 16), 'a/v': (24, 16), 'a/v_col': (1,), 'a/v_row': (1,)} + self.assertTrue(tree_equals(shapes, ref)) + + def test_3D_simple(self): + x = {'a': jnp.ones((24, 4, 16))} + factor_map = adafactor.HParamMap((('a', (_COL, _BATCH, _ROW)),)) + opt_def = adafactor.Adafactor( + min_dim_size_to_factor=8, factor_map=factor_map) + optimizer = opt_def.create(x) + shapes = tree_shape(flattened_state_dict(optimizer.state.param_states)) + ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24, 4), 'a/v_row': (4, 16)} + self.assertTrue(tree_equals(shapes, ref)) + + def test_init_state(self): + params = {'x': np.zeros((3, 2))} + optimizer_def = adafactor.Adafactor( + learning_rate=0.1, decay_rate=0.8, beta1=None, min_dim_size_to_factor=0) + state = optimizer_def.init_state(params) + + expected_hyper_params = _AdafactorHyperParams(0.1, True, True, None, 0.8, 0, + 1.0, None, 0, 1e-30, 1e-3) + self.assertEqual(optimizer_def.hyper_params, expected_hyper_params) + expected_state = OptimizerState( + 0, { + 'x': + _AdafactorParamState( + np.zeros((2,)), np.zeros((3,)), np.zeros( + (1,)), np.zeros((1,))) + }) + check_eq(state, expected_state) + + # unfactorized + optimizer_def = adafactor.Adafactor( + learning_rate=0.1, decay_rate=0.8, beta1=0.0, min_dim_size_to_factor=32) + state = optimizer_def.init_state(params) + + expected_hyper_params = _AdafactorHyperParams(0.1, True, True, 0.0, 0.8, 0, + 1.0, None, 32, 1e-30, 1e-3) + self.assertEqual(optimizer_def.hyper_params, expected_hyper_params) + expected_state = OptimizerState( + 0, { + 'x': + _AdafactorParamState( + np.zeros((1,)), np.zeros((1,)), np.zeros( + (3, 2)), np.zeros((3, 2))) + }) + check_eq(state, expected_state) + + def test_apply_gradient(self): + optimizer_def = adafactor.Adafactor( + learning_rate=0.1, decay_rate=0.8, min_dim_size_to_factor=0) + params = {'x': np.ones((3, 2), np.float32)} + state = OptimizerState( + 1, { + 'x': + _AdafactorParamState( + np.array([0.9, 0.9]), np.array([0.1, 0.1, 0.1]), + np.zeros((1,)), np.zeros((1,))) + }) + grads = {'x': np.ones((3, 2), np.float32)} + new_params, new_state = optimizer_def.apply_gradient( + optimizer_def.hyper_params, params, state, grads) + expected_new_state = OptimizerState( + 2, { + 'x': + _AdafactorParamState( + np.array([0.9574349, 0.9574349]), + np.array([0.6169143, 0.6169143, 0.6169143]), np.zeros( + (1,)), np.zeros((1,))) + }) + expected_new_params = {'x': 0.9 * np.ones((3, 2))} + check_eq(new_params, expected_new_params) + check_eq(new_state, 
expected_new_state, rtol=1e-6) + + # unfactored w momentum + optimizer_def = adafactor.Adafactor( + learning_rate=0.1, beta1=0.0, decay_rate=0.8, min_dim_size_to_factor=32) + params = {'x': np.ones((3, 2), np.float32)} + state = OptimizerState( + 1, { + 'x': + _AdafactorParamState( + np.zeros(1,), np.zeros(1,), 0.5 * np.ones( + (3, 2)), np.zeros((3, 2))) + }) + grads = {'x': np.ones((3, 2), np.float32)} + new_params, new_state = optimizer_def.apply_gradient( + optimizer_def.hyper_params, params, state, grads) + expected_new_params = {'x': 0.9 * np.ones((3, 2))} + check_eq(new_params, expected_new_params) + expected_new_state = OptimizerState( + 2, { + 'x': + _AdafactorParamState( + np.array([0.0]), np.array([0.0]), 0.787174 * np.ones( + (3, 2)), 0.1 * np.ones((3, 2))) + }) + check_eq(new_state, expected_new_state, rtol=1e-6) + + def test_apply_gradient_with_global_norm_clipping(self): + optimizer_def = adafactor.Adafactor( + learning_rate=0.1, + decay_rate=0.8, + min_dim_size_to_factor=0, + global_norm_clip_threshold=1.0) + params = {'x': np.ones((3, 2), np.float32)} + state = OptimizerState( + 1, { + 'x': + _AdafactorParamState( + np.array([0.9, 0.9]), np.array([0.1, 0.1, 0.1]), + np.zeros((1,)), np.zeros((1,))) + }) + grads = {'x': np.ones((3, 2), np.float32)} + new_params, new_state = optimizer_def.apply_gradient( + optimizer_def.hyper_params, params, state, grads) + expected_new_state = OptimizerState( + 2, { + 'x': + _AdafactorParamState( + np.array([0.478811, 0.478811]), + np.array([0.13829, 0.13829, 0.13829]), np.zeros( + (1,)), np.zeros((1,))) + }) + expected_new_params = {'x': 0.9 * np.ones((3, 2))} + check_eq(new_params, expected_new_params) + check_eq(new_state, expected_new_state, rtol=1e-6) + + def test_factorizes(self): + params = {'x': np.zeros((64, 64))} + optimizer_def = adafactor.Adafactor( + learning_rate=0.1, + decay_rate=0.8, + beta1=None, + min_dim_size_to_factor=32) + state = optimizer_def.init_state(params) + self.assertEqual(state.param_states['x'].v.shape, (1,)) + self.assertEqual(state.param_states['x'].m.shape, (1,)) + self.assertEqual(state.param_states['x'].v_row.shape, (64,)) + self.assertEqual(state.param_states['x'].v_col.shape, (64,)) + + params = {'x': np.zeros((31, 64))} + optimizer_def = adafactor.Adafactor( + learning_rate=0.1, + decay_rate=0.8, + beta1=None, + min_dim_size_to_factor=32) + state = optimizer_def.init_state(params) + self.assertEqual(state.param_states['x'].v.shape, (31, 64)) + self.assertEqual(state.param_states['x'].m.shape, (1,)) + self.assertEqual(state.param_states['x'].v_row.shape, (1,)) + self.assertEqual(state.param_states['x'].v_col.shape, (1,)) + + # Manually specified factorization rules tests. + + @parameterized.parameters( + {'rule': (_ROW, _COL)}, + {'rule': (_COL, _ROW)}, + ) + def test_2D_ignore_specified_factor_rule(self, rule): + x = {'a': jnp.ones((24, 16))} + factor_map = adafactor.HParamMap((('a', rule),)) + opt_def = adafactor.Adafactor( + min_dim_size_to_factor=8, factor_map=factor_map) + optimizer = opt_def.create(x) + shapes = tree_shape(flattened_state_dict(optimizer.state.param_states)) + # Since param is 2D, the explicit factor rule should be ignored and falls + # back to heuristics where v_row corresponds to the smaller dim. 
+ ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24,), 'a/v_row': (16,)} + self.assertTrue(tree_equals(shapes, ref)) + + def test_3D_simple_manual_rules(self): + x = {'a': jnp.ones((24, 4, 16))} + + factor_map = adafactor.HParamMap((('a', (_COL, _BATCH, _ROW)),)) + opt_def = adafactor.Adafactor( + min_dim_size_to_factor=8, factor_map=factor_map) + optimizer = opt_def.create(x) + shapes = tree_shape(flattened_state_dict(optimizer.state.param_states)) + ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24, 4), 'a/v_row': (4, 16)} + self.assertTrue(tree_equals(shapes, ref)) + + factor_map = adafactor.HParamMap((('a', (_ROW, _BATCH, _COL)),)) + opt_def = adafactor.Adafactor( + min_dim_size_to_factor=8, factor_map=factor_map) + optimizer = opt_def.create(x) + shapes = tree_shape(flattened_state_dict(optimizer.state.param_states)) + ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (4, 16), 'a/v_row': (24, 4)} + self.assertTrue(tree_equals(shapes, ref)) + + factor_map = adafactor.HParamMap((('a', (_COL, _ROW, _ROW)),)) + opt_def = adafactor.Adafactor( + min_dim_size_to_factor=8, factor_map=factor_map) + optimizer = opt_def.create(x) + shapes = tree_shape(flattened_state_dict(optimizer.state.param_states)) + ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24,), 'a/v_row': (4, 16)} + self.assertTrue(tree_equals(shapes, ref)) + + factor_map = adafactor.HParamMap((('a', (_COL, _COL, _ROW)),)) + opt_def = adafactor.Adafactor( + min_dim_size_to_factor=8, factor_map=factor_map) + optimizer = opt_def.create(x) + shapes = tree_shape(flattened_state_dict(optimizer.state.param_states)) + ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24, 4), 'a/v_row': (16,)} + self.assertTrue(tree_equals(shapes, ref)) + + def test_standard_factor_rules(self): + # one-off test to double-check that we're following the previous + # heuristic convention for rows/columns. + def test_standard_factor_rules(): + token_embedding = (_COL, _ROW) + attn_qkv = (_ROW, _COL) + attn_out = (_COL, _ROW) + mlp_in = (_ROW, _COL) + mlp_out = (_COL, _ROW) + return ((r'_layer_norm/(bias|scale)', + None), (r'(encoder|decoder)_norm/(bias|scale)', None), + (r'(encoder_decoder_|self_|\b)attention/(query|key|value)/kernel', + attn_qkv), (r'(encoder_decoder_|self_|\b)attention/out/kernel', + attn_out), (r'mlp/DenseGeneral_\d+/bias', None), + (r'mlp/wi(_\d+)?/kernel', mlp_in), (r'mlp/wo/kernel', mlp_out), + (r'\brelpos_bias', None), (r'token_embedder', token_embedding), + (r'.*', adafactor.HEURISTIC_RULE)) + + # create fake model parameters + k = jax.random.PRNGKey(0) + params = jax.tree_map( + lambda shape: jax.random.uniform(k, shape), + MODEL_SHAPE, + is_leaf=lambda x: isinstance(x, list)) + # make traditional adafactor state with heuristic + factor_map1 = adafactor.HParamMap(((r'.*', adafactor.HEURISTIC_RULE),)) + optimizer_def1 = adafactor.Adafactor( + 0.1, + decay_rate=0.8, + step_offset=0, + multiply_by_parameter_scale=True, + factor_map=factor_map1) + optimizer1 = optimizer_def1.create(params) + # make traditional adafactor state with explicit rules + factor_map2 = adafactor.HParamMap(test_standard_factor_rules()) + optimizer_def2 = adafactor.Adafactor( + 0.1, + decay_rate=0.8, + step_offset=0, + multiply_by_parameter_scale=True, + factor_map=factor_map2) + optimizer2 = optimizer_def2.create(params) + # are they the same? 
+ check_eq(optimizer1.state.param_states, optimizer2.state.param_states) + + @parameterized.parameters( + {'shape': (64, 64)}, + {'shape': (64, 132)}, + {'shape': (132, 64)}, + {'shape': (132, 132)}, + {'shape': (132, 140)}, + {'shape': (140, 132)}, + ) + def test_no_factor_map_equivalence(self, shape): + k = random.PRNGKey(0) + k1, k2 = random.split(k) + p = {'a': random.uniform(k1, shape)} + g = {'a': random.uniform(k2, shape)} + + orig_opt = optim.Adafactor(0.1).create(p) + new_opt = adafactor.Adafactor(0.1, factor_map=None).create(p) + check_eq(orig_opt.state_dict(), new_opt.state_dict()) + + orig_opt1 = orig_opt.apply_gradient(g) + new_opt1 = new_opt.apply_gradient(g) + check_eq(orig_opt1.state_dict(), new_opt1.state_dict()) + + @parameterized.parameters({ + 'shape': (128, 128), + 'rule': (_ROW, _COL) + }, { + 'shape': (132, 128), + 'rule': (_COL, _ROW) + }, { + 'shape': (128, 132), + 'rule': (_ROW, _COL) + }) + def test_simple_equivalence(self, shape, rule): + k = random.PRNGKey(0) + k1, k2 = random.split(k) + k3, k4 = random.split(k1) + k5, k6 = random.split(k2) + + p = {'a': random.uniform(k3, shape), 'b': random.uniform(k4, shape)} + g = {'a': random.uniform(k5, shape), 'b': random.uniform(k6, shape)} + + orig_opt = optim.Adafactor(0.1).create(p) + factor_map = adafactor.HParamMap( + rules=((('a'), rule), ('.*', adafactor.HEURISTIC_RULE))) + new_opt = adafactor.Adafactor(0.1, factor_map=factor_map).create(p) + check_eq(orig_opt.state_dict(), new_opt.state_dict()) + + orig_opt1 = orig_opt.apply_gradient(g) + new_opt1 = new_opt.apply_gradient(g) + check_eq(orig_opt1.state_dict(), new_opt1.state_dict()) + + @parameterized.parameters({'shape': (64, 64)}, {'shape': (132, 132)}) + def test_multiply_by_parameter_scale_equivalence(self, shape): + # Use large parameter values to magnify the parameter scaling effect. + p = {'a': np.random.randn(*shape) * 100, 'b': np.random.randn(*shape) * 100} + g = {'a': np.random.randn(*shape), 'b': np.random.randn(*shape)} + orig_opt = _get_multi_adafactor( + 3.0, 0, adafactor_exclude_from_parameter_scale=('a',)).create(p) + scaling_map = adafactor.HParamMap([('a', False), ('.*', True)]) + new_opt = adafactor.Adafactor( + 3.0, multiply_by_parameter_scale=scaling_map).create(p) + check_eq(orig_opt.state_dict(), new_opt.state_dict()) + + orig_opt1 = orig_opt.apply_gradient(g) + new_opt1 = new_opt.apply_gradient(g) + check_eq(orig_opt1.state_dict(), new_opt1.state_dict()) + + def test_3d_without_factor_map(self): + x = {'a': jnp.ones((24, 4, 16))} + opt_def = adafactor.Adafactor(factor_map=None) + with self.assertRaises(ValueError): + _ = opt_def.create(x) + + +if __name__ == '__main__': + absltest.main() diff --git a/t5x/checkpoint_importer.py b/t5x/checkpoint_importer.py new file mode 100644 index 0000000000000000000000000000000000000000..81193a7328e89d4be621ccbe2d6ceee539d62681 --- /dev/null +++ b/t5x/checkpoint_importer.py @@ -0,0 +1,485 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""T5 Checkpoint Importer.""" + +import asyncio +from concurrent.futures import thread +import re +from typing import Any, Callable, Mapping, MutableMapping, Optional, Union + +from flax import traverse_util +import jax +from jax import numpy as jnp +import numpy as np +import orbax.checkpoint +import tensorflow as tf +import tensorstore as ts + +# TODO(b/233659813): Cleanup clients depending on t5x.checkpoint_importer for +# LazyArray. Reconcile divergence in subclass implementation when possible. +LazyArray = orbax.checkpoint.lazy_array.LazyArray + + +# TODO(brianlester): The choice between using a `LazyTreadPoolArray` or a +# `LazyAwaitableArray` is dependent on if the user provided `get_fn` is blocking +# or async respectively, if we can detect which it is, we can automatically +# proxy to the correct subclass. We cannot detect of `get_fn` is a lambda that +# wraps an async call so this isn't possible yet. Add this dispatch once we are +# able to detect that, python3.8+ can detect async for partial'ed functions but +# not lambdas. +class LazyThreadPoolArray(LazyArray): + """Lazily and asynchronously loads an array when the `get_fn` blocks.""" + + # Uses a global threadpool to enable asynchronous loading. + executor = thread.ThreadPoolExecutor() + + def get_async(self) -> asyncio.Future: + return asyncio.wrap_future(self.executor.submit(self.get)) + + def get(self) -> np.ndarray: + arr = self._get_fn() + if arr.dtype != self.dtype: + arr = arr.astype(self.dtype) + return arr + + +class LazyAwaitableArray(LazyArray): + """Lazily and asynchronously loads an array when the `get_fn` is async. + + Note: + The synchronous load method `.get` requires the asyncio event loop and + calling `.run_until_complete`. This is not supported when the event loop is + already running (for example, from inside another async function). + + Note: + Currently, this class has a few helper methods for creating a + LazyAwaitableArray when the input could be either an array, or a TensorStore + spec. Most people use async code when dealing with TensorStore so the + classmethods have been placed here. When someone eventually uses a blocking + function to read from TensorStore they can be moved to the LazyArray base + class. + """ + + def get_async(self) -> asyncio.Future: + + async def _get_and_cast(): + # Pytype has a false positive here, where it treats our _get_fn (_read_ts + # in this case) as having a return type of `np.ndarray` instead of + # wrapping it in an Awaitable. Related to this bug + # https://github.com/google/pytype/issues/527 + arr = await self._get_fn() # pytype: disable=bad-return-type + if arr.dtype != self.dtype: + arr = arr.astype(self.dtype) + return arr + + return asyncio.ensure_future(_get_and_cast()) + + def get(self) -> np.ndarray: + loop = asyncio.get_event_loop() + return loop.run_until_complete(self.get_async()) + + @classmethod + def from_tensor_store_spec( + cls, + ts_spec: ts.Spec, + get_fn: Callable[[], np.ndarray], + dtype: Optional[jnp.dtype] = None) -> 'LazyAwaitableArray': + """Create a LazyAwaitableArray based on a tensorstore.Spec.""" + ts_spec = ts_spec.to_json() + shape = ts_spec['metadata']['shape'] + if dtype is None: + dtype = jnp.dtype(ts_spec['dtype']) + else: + dtype = jnp.dtype(dtype) + # v2 T5X checkpoints use uint16 as the TensorStore datatype and then store + # the bfloat16 bytes as in in the 16 bytes uint16 has (no actual cast). 
When + # When reading the dtype from the TensorStore, if we keep the dtype of these + # v2 checkpoints as np.uint16 then the _get_fn (which has a possible cast to + # support the `restore_dtype` parameter for the checkpointer) will actually + # cast the bfloat16 values to uint16, generally resulting in an array of all + # zeros. This check avoid the actual cast to uint16 by replacing the dtype. + if dtype == np.uint16: + dtype = jnp.bfloat16 + return cls(shape, dtype, get_fn) + + @classmethod + def from_array(cls, + array: np.ndarray, + get_fn: Callable[[], np.ndarray], + dtype: Optional[jnp.dtype] = None) -> 'LazyAwaitableArray': + """Create a LazyAwaitableArray based on an array or python number.""" + if dtype is None: + dtype = array.dtype + else: + dtype = jnp.dtype(dtype) + return cls(array.shape, dtype, get_fn) + + @classmethod + def from_tensor_store_spec_or_array( + cls, + maybe_ts_spec: Union[ts.Spec, np.ndarray], + get_fn: Callable[[], np.ndarray], + dtype: Optional[jnp.dtype] = None) -> 'LazyAwaitableArray': + """Create a LazyAwaitableArray based on an array or a tensorstore.Spec.""" + if isinstance(maybe_ts_spec, ts.Spec): + return cls.from_tensor_store_spec(maybe_ts_spec, get_fn, dtype=dtype) + return cls.from_array(maybe_ts_spec, get_fn, dtype=dtype) + + +class CheckpointTranslator: + """Utility class for defining mapping rules from one flatdict to another. + + We assume a checkpoint is loaded as a dictionary with flattened keys of the + form: 'name0/name1/name2/.../nameN' + + A rule is added with the 'add' decorator, which takes a regex matching rule + and wraps a conversion function, feeding it (opts, key, val, **regex_groups) + where opts is a dict containing apply-time keyword options for use by the + conversion functions. + """ + + def __init__(self): + self.rules = [] + + def add(self, pattern): + """Adds a new keyval conversion rule. + + Args: + pattern: regex with capture groups for matching given sets of model + variables. We terminate all regexes with '$' to force complete matches. + + Returns: + Translation function decorator for associating with the provided + pattern. + """ + + def register_translation_fn_decorator(fn): + # We force a complete match by adding end-of-string match. + self.rules.append((re.compile(pattern + '$'), fn)) + return fn + + return register_translation_fn_decorator + + def apply(self, flatdict, **opts): + """Applies rules to a flattened dictionary. + + Args: + flatdict: flat-key dictionary of variables. + **opts: additional config options for translation rules supplied at + application time. + + Returns: + Checkpoint data with translated key/values in flat-key dict format. + """ + new_dict = {} + unmatched = {} + for k, v in flatdict.items(): + matched = False + for rule_pat, rule_fn in self.rules: + if rule_pat.match(k): + groups = rule_pat.match(k).groups() + new_k, new_v = rule_fn(opts, k, v, *groups) + if new_k is not None: + new_dict[new_k] = new_v + matched = True + break + if not matched: + unmatched[k] = v + + # We force every key-value pair in checkpoint to have a rule associated with + # it. + if unmatched: + raise ValueError('Unmapped tensor keys exist: %s' % unmatched) + + return new_dict + + +# Create a translation rule set for importing T5 & T5.1.1 model checkpoints. +# ----------------------------------------------------------------------------- +t5_importer = CheckpointTranslator() + +# Name mappings. 
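+# For example (illustrative), the TF variable 'shared/embedding_slot_vc' is
+# translated by the rules below to
+# 'state/param_states/token_embedder/embedding/v_col', while 'shared/embedding'
+# itself maps to 'target/token_embedder/embedding'.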
+SLOT_MAP = {'_slot_vc': 'v_col', '_slot_vr': 'v_row', '_slot_v': 'v'} +TOWER_MAP = {'transformer': 'decoder'} + + +@t5_importer.add(r'global_step') +def global_step(opts, key, val): + del opts, key + return 'state/step', val.astype(np.int32).get() if isinstance( + val, LazyArray) else val + + +@t5_importer.add(r'shared/embedding(\w*)') +def shared_embeddings(opts, key, val, slot): + del opts, key + prefix = 'state/param_states' if slot else 'target' + suffix = '/' + SLOT_MAP[slot] if slot else '' + newkey = f'{prefix}/token_embedder/embedding{suffix}' + return newkey, val + + +@t5_importer.add(r'(encoder|decoder|transformer)/embedding(\w*)') +def separate_embeddings(opts, key, val, encdec, slot): + del opts, key + prefix = 'state/param_states' if slot else 'target' + suffix = '/' + SLOT_MAP[slot] if slot else '' + encdec = TOWER_MAP.get(encdec, encdec) + newkey = f'{prefix}/{encdec}/token_embedder/embedding{suffix}' + return newkey, val + + +# In the Mesh TensorFlow T5 code, relative_attention_bias always occurs in layer +# 0 because SelfAttention precedes other sublayers within the same block. +@t5_importer.add( + r'(encoder|decoder|transformer)/block_(\d+)/layer_000/SelfAttention/relative_attention_bias(\w*)' +) +def rel_embeddings(opts, key, val, encdec, blocknum, slot): + """Process relpos bias assuming that they are not shared across layers.""" + del opts, key + prefix = 'state/param_states' if slot else 'target' + suffix = '/' + SLOT_MAP[slot] if slot else '' + blocknum = int(blocknum) + encdec = TOWER_MAP.get(encdec, encdec) + # At this point, we can't determine whether the relpos bias was shared across + # layers or not. We first assume that it was not shared. During post + # processing, we remove the layers_0 scope if it was shared. + newkey = f'{prefix}/{encdec}/layers_{blocknum}/relpos_bias/rel_embedding{suffix}' + return newkey, val + + +@t5_importer.add( + r'(encoder|decoder|transformer)/block_(\d+)/layer_\d+/(SelfAttention|EncDecAttention)/(q|k|v|o)(\w*)' +) +def attention_layers(opts, key, val, encdec, blocknum, attntype, qkvo, slot): + """Process attention layers.""" + del opts, key + prefix = 'state/param_states' if slot else 'target' + suffix = '/' + SLOT_MAP[slot] if slot else '' + blocknum = int(blocknum) + encdec = TOWER_MAP.get(encdec, encdec) + matrix = {'q': 'query', 'k': 'key', 'v': 'value', 'o': 'out'}[qkvo] + + if encdec == 'encoder': + attntype = 'attention' + else: + attntype = { + 'SelfAttention': 'self_attention', + 'EncDecAttention': 'encoder_decoder_attention' + }[attntype] + newkey = f'{prefix}/{encdec}/layers_{blocknum}/{attntype}/{matrix}/kernel{suffix}' + return newkey, val + + +@t5_importer.add( + r'(encoder|decoder|transformer)/block_(\d+)/layer_\d+/DenseReluDense/(wi|wo)(?:_(\d+))?/kernel(\w*)' +) +def mlpblock(opts, key, val, encdec, blocknum, io_name, io_num, slot): + """Process MLP blocks.""" + del opts, key + prefix = 'state/param_states' if slot else 'target' + suffix = '/' + SLOT_MAP[slot] if slot else '' + blocknum = int(blocknum) + encdec = TOWER_MAP.get(encdec, encdec) + io_num = f'_{io_num}' if io_num else '' + newkey = f'{prefix}/{encdec}/layers_{blocknum}/mlp/{io_name}{io_num}/kernel{suffix}' + return newkey, val + + +@t5_importer.add( + r'(encoder|decoder|transformer)/block_(\d+)/layer_(\d+)/(?:layer|rms)_norm/scale(\w*)' +) +def layernorms(opts, key, val, encdec, blocknum, lyrnum, slot): + """Process layer norms assuming that they are pre-layernorms.""" + del opts, key + prefix = 'state/param_states' if slot else 'target' + suffix = 
'/' + SLOT_MAP[slot] if slot else '' + lyrnum = int(lyrnum) + + if encdec == 'transformer': + layernorm_type = ['pre_self_attention_layer_norm', + 'pre_mlp_layer_norm'][lyrnum] + + elif encdec == 'encoder': + layernorm_type = ['pre_attention_layer_norm', 'pre_mlp_layer_norm'][lyrnum] + else: # decoder + layernorm_type = [ + 'pre_self_attention_layer_norm', 'pre_cross_attention_layer_norm', + 'pre_mlp_layer_norm' + ][lyrnum] + + encdec = TOWER_MAP.get(encdec, encdec) + newkey = f'{prefix}/{encdec}/layers_{int(blocknum)}/{layernorm_type}/scale{suffix}' + return newkey, val + + +@t5_importer.add( + r'(encoder|decoder|transformer)/(?:final_layer|rms)_norm/scale(\w*)') +def final_layernorms(opts, key, val, encdec, slot): + """Process final layer norms.""" + del opts, key + prefix = 'state/param_states' if slot else 'target' + suffix = '/' + SLOT_MAP[slot] if slot else '' + norm = { + 'encoder': 'encoder_norm', + 'decoder': 'decoder_norm', + 'transformer': 'decoder_norm' + }[encdec] + encdec = TOWER_MAP.get(encdec, encdec) + newkey = f'{prefix}/{encdec}/{norm}/scale{suffix}' + return newkey, val + + +@t5_importer.add(r'(?:decoder|transformer)/logits/kernel(\w*)') +def final_logits(opts, key, val, slot): + del opts, key + prefix = 'state/param_states' if slot else 'target' + suffix = '/' + SLOT_MAP[slot] if slot else '' + newkey = f'{prefix}/decoder/logits_dense/kernel{suffix}' + return newkey, val + + +def _add_missing_param_states(t5_data): + """Add dummy slots that Flax Adafactor requires but TF does not.""" + updates = {} + for k in t5_data: + if k.startswith('target'): + state_leaf = 'state/param_states' + k[len('target'):] + updates[state_leaf + '/m'] = np.zeros((1,), np.float32) + if state_leaf + '/v' in t5_data: + updates[state_leaf + '/v_row'] = np.zeros((1,), np.float32) + updates[state_leaf + '/v_col'] = np.zeros((1,), np.float32) + elif state_leaf + '/v_row' in t5_data: + updates[state_leaf + '/v'] = np.zeros((1,), np.float32) + t5_data.update(**updates) + return t5_data + + +def _maybe_correct_relpos_bias(t5_data): + """Correct the relpos_bias format if it is shared across layers.""" + max_layer_ind = 0 + for k, v in t5_data.items(): + match = re.search(r'layers_(\d+)/relpos_bias', k) + if match: + layer_ind = int(match.groups()[0]) + max_layer_ind = max(max_layer_ind, layer_ind) + + modified_dict = {} + if max_layer_ind == 0: + # Relative position biases are shared across layers + for k, v in t5_data.items(): + new_k = re.sub(r'layers_\d+/relpos_bias', 'relpos_bias', k) + modified_dict[new_k] = v + else: + # Relative position biases are unique in each layer. No more processing is + # necessary. + modified_dict = t5_data + + return modified_dict + + +# Load checkpoint, translate, and update flax optimizer and model. +# ----------------------------------------------------------------------------- +def load_tf_ckpt(path): + """Load a TF checkpoint as a flat dictionary of numpy arrays.""" + ckpt_reader = tf.train.load_checkpoint(path) + ckpt_shape_map = ckpt_reader.get_variable_to_shape_map() + ckpt_dtype_map = ckpt_reader.get_variable_to_dtype_map() + datamap = { # pylint: disable=g-complex-comprehension + k: LazyThreadPoolArray( + s, + jnp.dtype(ckpt_dtype_map[k].as_numpy_dtype), + lambda x=k: ckpt_reader.get_tensor(x)) + for k, s in ckpt_shape_map.items() + } + return datamap + + +def _update_state_dict(state_dict: Mapping[str, Any], + t5_data: MutableMapping[str, LazyArray], + strict: bool = True) -> Mapping[str, Any]: + """Update flax optimizer for T5 model. 
+ + Args: + state_dict: Optimizer to update with T5 parameters. + t5_data: T5 model parameters, typically loaded from a checkpoint. + strict: If True requires that optimizer and t5_data mappings contain the + same set of names (variables). If False, updating will succeed even if + t5_data contains variables not in the optimizer. If the optimizer has + variables not in t5_data, this function will still fail. + + Returns: + Updated optimizer. + """ + flat_state_dict = traverse_util.flatten_dict(state_dict, sep='/') + + # Remove parameters from the checkpoint not found in the optimizer (this + # allows us to load checkpoints that contain more parameters than our current + # model). + if not strict: + for k in list(t5_data): + if k not in flat_state_dict: + t5_data.pop(k) + + # Shape check. + for k, v in t5_data.items(): + if flat_state_dict[k].shape != v.shape: + raise ValueError( + f'Variable {k} has shape {v.shape} != {flat_state_dict[k].shape}') + flat_state_dict = t5_data + state_dict = traverse_util.unflatten_dict( + {tuple(k.split('/')): v for k, v in flat_state_dict.items()}) + return state_dict + + +def restore_from_t5_checkpoint( + state_dict: Mapping[str, Any], + path: str, + lazy_parameters: bool = False, + strict: bool = True, + translator: Optional[CheckpointTranslator] = None) -> Mapping[str, Any]: + """Load T5 checkpoint and update Adafactor optimizer and T5 model from it. + + We require that the final translated checkpoint structure exactly matches + that of the Flax Adafactor + Transformer data, up to shape agreement of + the leaves. + + Args: + state_dict: Flax Adafactor Optimizer for T5 transformer encoder-decoder. + path: a path to checkpoint file or directory. + lazy_parameters: whether to leave the parameters as LazyArrays to preserve + memory. + strict: If True requires that optimizer and t5_data mappings contain the + same set of names (variables). If False, updating will succeed even if + t5_data contains variables not in the optimizer. If the optimizer has + variables not in t5_data, this function will still fail. + translator: The mapping rules for conversion. If None, then default T5 + conversion rules will be used. + + Returns: + Adafactor optimizer updated with parameters and optimizer state from + T5 checkpoint. + """ + if translator is None: + translator = t5_importer + ckpt_data = load_tf_ckpt(path) + t5_data = translator.apply(ckpt_data) + t5_data = _add_missing_param_states(t5_data) + t5_data = _maybe_correct_relpos_bias(t5_data) + state_dict = _update_state_dict(state_dict, t5_data, strict=strict) + if not lazy_parameters: + state_dict = jax.tree_map( + lambda x: x.get() if isinstance(x, LazyArray) else x, state_dict) + return state_dict diff --git a/t5x/checkpoint_importer_test.py b/t5x/checkpoint_importer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2b40fb84d1ab1b6e240b247961f9287340398b1c --- /dev/null +++ b/t5x/checkpoint_importer_test.py @@ -0,0 +1,81 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for t5x.checkpoint_importer.""" + +import json +import os + +from absl import flags +from absl.testing import absltest +import jax +import numpy as np +from t5x import checkpoint_importer +import tensorflow as tf + + +class CheckpointImporterTest(absltest.TestCase): + + def test_rel_embeddings_shared_layers(self): + # This represents a ckpt where the Mesh TensorFlow's + # transformer_layers.SelfAttention.relative_attention_type = "bias_shared", + # i.e., the same relative attention parameters are shared by all layers + # within the (en|de)coder. + ckpt_data = { + 'encoder/block_000/layer_000/SelfAttention/relative_attention_bias': + 1, + 'decoder/block_000/layer_000/SelfAttention/relative_attention_bias': + 2, + 'decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v': + 3, + } + t5_data = checkpoint_importer.t5_importer.apply(ckpt_data) + t5_data = checkpoint_importer._maybe_correct_relpos_bias(t5_data) + expected = { + 'target/encoder/relpos_bias/rel_embedding': 1, + 'target/decoder/relpos_bias/rel_embedding': 2, + 'state/param_states/decoder/relpos_bias/rel_embedding/v': 3, + } + self.assertEqual(t5_data, expected) + + def test_rel_embeddings_per_layer(self): + # This represents a ckpt where the Mesh TensorFlow's + # transformer_layers.SelfAttention.relative_attention_type = "bias", i.e., + # each layer has its own relative attention parameters. + ckpt_data = { + 'encoder/block_000/layer_000/SelfAttention/relative_attention_bias': + 1, + 'encoder/block_001/layer_000/SelfAttention/relative_attention_bias': + 2, + 'decoder/block_000/layer_000/SelfAttention/relative_attention_bias': + 3, + 'decoder/block_000/layer_000/SelfAttention/relative_attention_bias_slot_v': + 4, + 'decoder/block_011/layer_000/SelfAttention/relative_attention_bias': + 5 + } + t5_data = checkpoint_importer.t5_importer.apply(ckpt_data) + t5_data = checkpoint_importer._maybe_correct_relpos_bias(t5_data) + expected = { + 'target/encoder/layers_0/relpos_bias/rel_embedding': 1, + 'target/encoder/layers_1/relpos_bias/rel_embedding': 2, + 'target/decoder/layers_0/relpos_bias/rel_embedding': 3, + 'state/param_states/decoder/layers_0/relpos_bias/rel_embedding/v': 4, + 'target/decoder/layers_11/relpos_bias/rel_embedding': 5, + } + self.assertEqual(t5_data, expected) + + +if __name__ == '__main__': + absltest.main() diff --git a/t5x/checkpoint_utils.py b/t5x/checkpoint_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7ef15ab7c10c511d454221abb87916e8e1677519 --- /dev/null +++ b/t5x/checkpoint_utils.py @@ -0,0 +1,91 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Checkpoint helper functions for managing checkpoints. + +Supports marking checkpoints as pinned to exclude them from the checkpointer +removal process. 
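+
+Example (an illustrative sketch; the path is hypothetical):
+
+  checkpoint_utils.pin_checkpoint('/ckpts/run1/checkpoint_10000')
+  # remove_checkpoint_dir now leaves this directory in place until
+  # unpin_checkpoint is called on it.
+  checkpoint_utils.remove_checkpoint_dir('/ckpts/run1/checkpoint_10000')
+  checkpoint_utils.unpin_checkpoint('/ckpts/run1/checkpoint_10000')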
+""" + +import os + +from absl import logging + +from tensorflow.io import gfile + +# PINNED file in the checkpoint directory indicates that the checkpoint should +# not be removed during the automatic pruning of old checkpoints. +_PINNED_CHECKPOINT_FILENAME = 'PINNED' + + +def pinned_checkpoint_filepath(ckpt_dir: str) -> str: + """Full path of the pinned checkpoint file.""" + return os.path.join(ckpt_dir, _PINNED_CHECKPOINT_FILENAME) + + +def is_pinned_checkpoint(ckpt_dir: str) -> bool: + """Returns whether the checkpoint is pinned, and should NOT be removed.""" + pinned_ckpt_file = pinned_checkpoint_filepath(ckpt_dir) + if gfile.exists(pinned_ckpt_file): + return True + return False + + +def pin_checkpoint(ckpt_dir: str, txt: str = '1') -> None: + """Pin a checkpoint so it does not get deleted by the normal pruning process. + + Creates a PINNED file in the checkpoint directory to indicate the checkpoint + should be excluded from the deletion of old checkpoints. + + Args: + ckpt_dir: The checkpoint step dir that is to be always kept. + txt: Text to be written into the checkpoints ALWAYS_KEEP me file. + """ + pinned_ckpt_file = pinned_checkpoint_filepath(ckpt_dir) + with gfile.GFile(pinned_ckpt_file, 'w') as f: + logging.debug('Write %s file : %s.', pinned_ckpt_file, txt) + f.write(txt) + + +def unpin_checkpoint(ckpt_dir: str) -> None: + """Removes the pinned status of the checkpoint so it is open for deletion.""" + if not is_pinned_checkpoint(ckpt_dir): + logging.debug('%s is not PINNED. Nothing to do here.', ckpt_dir) + return + try: + pinned_ckpt_file = pinned_checkpoint_filepath(ckpt_dir) + logging.debug('Remove %s file.', pinned_ckpt_file) + gfile.rmtree(pinned_ckpt_file) + except IOError: + logging.exception('Failed to unpin %s', ckpt_dir) + + +def remove_checkpoint_dir(ckpt_dir: str) -> None: + """Removes the checkpoint dir if it is not pinned.""" + if not is_pinned_checkpoint(ckpt_dir): + logging.info('Deleting checkpoint: %s', ckpt_dir) + gfile.rmtree(ckpt_dir) + else: + logging.info('Keeping pinned checkpoint: %s', ckpt_dir) + + +def remove_dataset_checkpoint(ckpt_dir: str, train_ds_prefix: str) -> None: + """Removes dataset checkpoints if the checkpoint is not pinned.""" + if not is_pinned_checkpoint(ckpt_dir): + train_ds_pattern = os.path.join(ckpt_dir, train_ds_prefix + '*') + logging.info('Deleting dataset checkpoint: %s', train_ds_pattern) + for file in gfile.glob(train_ds_pattern): + gfile.remove(file) + else: + logging.info('Keeping pinned checkpoint: %s', ckpt_dir) diff --git a/t5x/checkpoint_utils_test.py b/t5x/checkpoint_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..36d7583f3e1e439af31481c4d5fe613a5e8d3202 --- /dev/null +++ b/t5x/checkpoint_utils_test.py @@ -0,0 +1,149 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for t5x.checkpoint_utils.""" + +import os +import traceback + +from absl.testing import absltest +from t5x import checkpoint_utils +from tensorflow.io import gfile + +TESTDATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata") + + +class CheckpointsUtilsTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.checkpoints_dir = self.create_tempdir() + self.ckpt_dir_path = self.checkpoints_dir.full_path + self.pinned_ckpt_file = os.path.join(self.ckpt_dir_path, "PINNED") + self.checkpoints_dir.create_file("checkpoint") + # Create a `train_ds` file representing the dataset checkpoint. + train_ds_basename = "train_ds-00000-of-00001" + self.train_ds_file = os.path.join(self.ckpt_dir_path, train_ds_basename) + self.checkpoints_dir.create_file(train_ds_basename) + + def test_always_keep_checkpoint_file(self): + self.assertEqual( + "/path/to/ckpt/dir/PINNED", + checkpoint_utils.pinned_checkpoint_filepath("/path/to/ckpt/dir")) + + def test_is_pinned_checkpoint_false_by_default(self): + # Ensure regular checkpoint without PINNED file. + self.assertFalse(gfile.exists(os.path.join(self.ckpt_dir_path, "PINNED"))) + + # Validate checkpoints are not pinned by default. + self.assertFalse(checkpoint_utils.is_pinned_checkpoint(self.ckpt_dir_path)) + + def test_is_pinned_checkpoint(self): + # Ensure the checkpoint directory as pinned. + pinned_ckpt_testdata = os.path.join(TESTDATA, "pinned_ckpt_dir") + pinned_file = os.path.join(pinned_ckpt_testdata, "PINNED") + self.assertTrue(gfile.exists(pinned_file)) + + # Test and validate. + self.assertTrue(checkpoint_utils.is_pinned_checkpoint(pinned_ckpt_testdata)) + + def test_is_pinned_missing_ckpt(self): + self.assertFalse( + checkpoint_utils.is_pinned_checkpoint( + os.path.join(self.ckpt_dir_path, "ckpt_does_not_exist"))) + + def test_pin_checkpoint(self): + # Ensure directory isn't already pinned. + self.assertFalse(gfile.exists(self.pinned_ckpt_file)) + + # Test. + checkpoint_utils.pin_checkpoint(self.ckpt_dir_path) + + # Validate. + self.assertTrue(gfile.exists(self.pinned_ckpt_file)) + with open(self.pinned_ckpt_file) as f: + self.assertEqual("1", f.read()) + + def test_pin_checkpoint_txt(self): + checkpoint_utils.pin_checkpoint(self.ckpt_dir_path, "TEXT_IN_PINNED") + self.assertTrue(os.path.exists(os.path.join(self.ckpt_dir_path, "PINNED"))) + with open(self.pinned_ckpt_file) as f: + self.assertEqual("TEXT_IN_PINNED", f.read()) + + def test_unpin_checkpoint(self): + # Mark the checkpoint directory as pinned. + self.checkpoints_dir.create_file("PINNED") + self.assertTrue(checkpoint_utils.is_pinned_checkpoint(self.ckpt_dir_path)) + + # Test. + checkpoint_utils.unpin_checkpoint(self.ckpt_dir_path) + + # Validate the "PINNED" checkpoint file got removed. + self.assertFalse(gfile.exists(os.path.join(self.ckpt_dir_path, "PINNED"))) + + def test_unpin_checkpoint_does_not_exist(self): + missing_ckpt_path = os.path.join(self.ckpt_dir_path, "ckpt_does_not_exist") + self.assertFalse(gfile.exists(missing_ckpt_path)) + + # Test. Assert does not raise error. + try: + checkpoint_utils.unpin_checkpoint(missing_ckpt_path) + except IOError: + # TODO(b/172262005): Remove traceback.format_exc() from the error message. + self.fail("Unpin checkpoint failed with: %s" % traceback.format_exc()) + + def test_remove_checkpoint_dir(self): + # Ensure the checkpoint directory is setup. + assert gfile.exists(self.ckpt_dir_path) + + # Test. 
+ checkpoint_utils.remove_checkpoint_dir(self.ckpt_dir_path) + + # Validate the checkpoint directory got removed. + self.assertFalse(gfile.exists(self.ckpt_dir_path)) + + def test_remove_checkpoint_dir_pinned(self): + # Mark the checkpoint directory as pinned so it does not get removed. + self.checkpoints_dir.create_file("PINNED") + + # Test. + checkpoint_utils.remove_checkpoint_dir(self.ckpt_dir_path) + + # Validate the checkpoint directory still exists. + self.assertTrue(gfile.exists(self.ckpt_dir_path)) + + def test_remove_dataset_checkpoint(self): + # Ensure the checkpoint directory is setup. + assert gfile.exists(self.ckpt_dir_path) + + # Test. + checkpoint_utils.remove_dataset_checkpoint(self.ckpt_dir_path, "train_ds") + + # Validate the checkpoint directory got removed. + self.assertFalse(gfile.exists(self.train_ds_file)) + self.assertTrue(gfile.exists(self.ckpt_dir_path)) + + def test_remove_dataset_checkpoint_pinned(self): + # Mark the checkpoint directory as pinned so it does not get removed. + self.checkpoints_dir.create_file("PINNED") + + # Test. + checkpoint_utils.remove_dataset_checkpoint(self.ckpt_dir_path, "train_ds") + + # Validate the checkpoint directory still exists. + self.assertTrue(gfile.exists(self.train_ds_file)) + self.assertTrue(gfile.exists(self.ckpt_dir_path)) + +if __name__ == "__main__": + absltest.main() diff --git a/t5x/checkpoints.py b/t5x/checkpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..091900d839bfd2c8d3c0157d5bb893bd30610703 --- /dev/null +++ b/t5x/checkpoints.py @@ -0,0 +1,1678 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for reading and writing sharded checkpoints. + +The checkpointing utilities here can be used in two ways. The first is to use +the `Checkpointer` class. This requires having an optimizer and various +partitioning utilities setup, but allows for reading and writing of partitioned +parameters. It also allows different hosts to read different parameter +partitions in a multi-host setup, which results in much faster reads. This is +normally used during training where you have already created an optimizer based +on a config. + +The second way is to use the `load_t5x_checkpoint` function. This doesn't +require an optimizer to get given up front so it is useful for things like +debugging and analysis of learned weights. However, this means that we cannot do +partitioned reads so loading will be slower than that `Checkpointer` class. 
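+
+A minimal sketch of the second mode (the path is hypothetical and keyword
+arguments are omitted; see the definition of `load_t5x_checkpoint` later in
+this module for the full signature):
+
+  from t5x import checkpoints
+
+  # Reads the full, unpartitioned state, e.g. for debugging or analysis.
+  state_dict = checkpoints.load_t5x_checkpoint('/tmp/my_run/checkpoint_1000')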
+""" +import asyncio +import dataclasses +import functools +import os +import re +import subprocess +import time +from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple + +from absl import logging +from flax import serialization +from flax import traverse_util +import jax +import jax.config +from jax.experimental import global_device_array as gda_lib +from jax.experimental import multihost_utils +from jax.experimental.gda_serialization import serialization as gda_serialization +import jax.numpy as jnp +import numpy as np +import orbax.checkpoint +from t5x import checkpoint_importer +from t5x import checkpoint_utils +from t5x import optimizers +from t5x import partitioning +from t5x import state_utils +from t5x import train_state as train_state_lib +import tensorflow as tf +from tensorflow.io import gfile +import tensorstore as ts +import typing_extensions +from tensorboard.backend.event_processing import directory_watcher +from tensorboard.backend.event_processing import event_file_loader +from tensorboard.backend.event_processing import io_wrapper + +PartitionSpec = partitioning.PartitionSpec +PyTreeDef = type(jax.tree_structure(None)) +LazyArray = checkpoint_importer.LazyArray +LazyAwaitableArray = checkpoint_importer.LazyAwaitableArray +LazyThreadPoolArray = checkpoint_importer.LazyThreadPoolArray + +# Version 3 is used since 2021-06-10, compared to version 2 the only change is +# that `bfloat16` arrays are written in Tensorstore using its native `bfloat16` +# support instead of casting them to `uint16`. +VERSION = 3 +# Desired chunk size is 64MiB. +# This is large enough to keep CNS happy but small enough to support a wide +# range of partitionings. +_DESIRED_CHUNK_SIZE_BYTES = 64 * 1024 * 1024 +# TODO(levskaya, adarob): how should we handle stacked/fused variables?? +_TRAIN_DS_PREFIX = 'train_ds' + + +def _choose_chunk_shape(write_shape: Sequence[int], + target_elements: int) -> List[int]: + """Chooses a chunk shape that evenly divides write_shape. + + The chunk shape is chosen such that the total number of elements is less than + or equal to `target_elements`, but is otherwise as large as possible. + + This uses a greedy algorithm that attempts to split the largest dimensions + first. + + Args: + write_shape: Write shape for which to choose a chunk shape. + target_elements: Desired number of elements in chosen chunk shape. Must be + >= 1. + + Returns: + List of length `len(write_shape)` specifying the chosen chunk shape. + """ + assert target_elements >= 1 + rank = len(write_shape) + + # `dim_factors[i]` is the list of divisors of `write_shape[i]` + dim_factors = [ + [i for i in range(1, size + 1) if size % i == 0] for size in write_shape + ] + + # The current chunk shape is: + # [dim_factors[i][-1] for i in range(rank)] + + def get_total_elements(): + """Returns the number of elements in the current chunk shape.""" + total_elements = 1 + for i in range(rank): + total_elements *= dim_factors[i][-1] + return total_elements + + # Reduce the current chunk shape until the desired number of elements is + # reached. + while get_total_elements() > target_elements: + # Greedily reduce the largest dimension. This is not guaranteed to bring us + # the closest to `target_elements`, but is simple to implement and should + # work well enough. 
+ dim_to_reduce = -1 + dim_to_reduce_size = 1 + for i in range(rank): + size = dim_factors[i][-1] + if size > dim_to_reduce_size: + dim_to_reduce_size = size + dim_to_reduce = i + # Can only fail to choose `dim_to_reduce` if all dimensions have size of 1. + # But that cannot happen since `target_elements >= 1`. + assert dim_to_reduce_size > 1 + dim_factors[dim_to_reduce].pop() + return [dim_factors[i][-1] for i in range(rank)] + + +@dataclasses.dataclass +class _ParameterInfo: + """Information needed to read/write and slice a partitioned parameter.""" + # The unique parameter name. + name: str + # The shape of the parameter. + shape: Tuple[int] + # The TensoreStore Spec containing the minimal information for read/write. + ts_spec: Optional[ts.Spec] + # The LocalChunkInfo for the part of the parameter local to this host. + local_chunk_info: Optional[partitioning.LocalChunkInfo] + # PartitionSpec mesh axes + axes: Optional[partitioning.PartitionSpec] = None + + +orbax.checkpoint.utils.register_ts_spec_for_serialization() + + +def _run_future_tree(future_tree): + """Block until all futures are resolved on this host.""" + future_leaves, treedef = jax.tree_flatten(future_tree) + + # TODO(adarob): Use asyncio.run in py3.7+. + loop = asyncio.get_event_loop() + leaves = loop.run_until_complete(asyncio.gather(*future_leaves)) + return jax.tree_unflatten(treedef, leaves) + + +def all_steps(checkpoints_dir: str) -> Sequence[int]: + """Returns list of available step numbers in ascending order.""" + glob_pattern = os.path.join(checkpoints_dir, 'checkpoint_*', 'checkpoint') + checkpoint_paths = gfile.glob(glob_pattern) + re_pattern = re.compile(r'.*/checkpoint_(\d+)/checkpoint$') + matches = [re_pattern.match(ckpt) for ckpt in checkpoint_paths] + return sorted(int(match.group(1)) for match in matches if match) + + +def all_dataset_checkpoint_steps(checkpoints_dir: str) -> Sequence[int]: + """Returns available dataset checkpoint step numbers in ascending order.""" + glob_pattern = os.path.join(checkpoints_dir, 'checkpoint_*', + f'{_TRAIN_DS_PREFIX}-*') + train_ds_paths = gfile.glob(glob_pattern) + re_pattern = re.compile(r'.*/checkpoint_(\d+)/.*$') + matches = [re_pattern.match(path) for path in train_ds_paths] + return sorted(set(int(match.group(1)) for match in matches if match)) + + +def latest_step(checkpoints_dir: str) -> Optional[int]: + """Returns latest step number or None if no checkpoints exist.""" + steps = all_steps(checkpoints_dir) + if not steps: + return None + return steps[-1] + + +def _get_local_data(x): + if isinstance(x, gda_lib.GlobalDeviceArray): + return x.local_data(0) + else: + return x + + +def get_checkpoint_dir(checkpoints_dir: str, step: int) -> str: + """Returns path to a checkpoint dir given a parent directory and step.""" + return os.path.join(checkpoints_dir, f'checkpoint_{step}') + + +def _cast(target: PyTreeDef, dtype: jnp.dtype): + """Cast arrays in target to dtype.""" + + def maybe_cast(x): + if isinstance(x, (int, str)): + # Ignore common non-array types that shouldn't be cast. 
+ return x + elif x.dtype == dtype: + return x + elif isinstance(x, jax.ShapeDtypeStruct): + return jax.ShapeDtypeStruct(x.shape, dtype) + elif isinstance(x, gda_lib.GlobalDeviceArray): + raise ValueError('GDA cast not supported.') + else: + return x.astype(dtype) + + return jax.tree_map(maybe_cast, target) + + +def _update_ts_path_from_relative_to_absolute( + ckpt_dir: str, ts_spec_dict: MutableMapping[str, Any]): + """Update (in-place) the path and gcs bucket (if applicable) in a TS Spec.""" + + # Handle `gs://` paths. + m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_dir, re.DOTALL) + if m is not None: + if ts_spec_dict['kvstore']['driver'] != 'gcs': + raise ValueError(f'Incorrect TensorStore Spec. ' + f'Expects kvstore driver to be "gcs" for {ckpt_dir}. ' + f'Got {ts_spec_dict}') + bucket = m.group(1) + ckpt_dir = m.group(2) + ts_spec_dict['kvstore']['bucket'] = bucket + + # Update the path with `ckpt_dir` + + if 'path' in ts_spec_dict['kvstore']: + # tensorstore>=0.1.14 format + ts_spec_dict['kvstore']['path'] = os.path.join( + ckpt_dir, ts_spec_dict['kvstore']['path']) + elif 'path' in ts_spec_dict: + # tensorstore<0.1.14 format + ts_spec_dict['path'] = os.path.join(ckpt_dir, ts_spec_dict['path']) + else: + raise ValueError( + 'Incorrect TensorStore Spec. Expects "path" to be a key of spec or ' + f'`spec["kvstore"]`. Got {ts_spec_dict}') + + +def _maybe_update_ts_from_file_to_gcs(ckpt_contents): + """Updates the TensorStore driver from gfile to gcs.""" + + def _gfile_to_gcs_driver(arr_or_ts_spec_dict): + """Converts the ts.Spec dict using gfile driver to gcs driver.""" + if not isinstance(arr_or_ts_spec_dict, dict): + return arr_or_ts_spec_dict + + if arr_or_ts_spec_dict['kvstore']['driver'] in ('file', 'gfile'): + ts_spec_dict = arr_or_ts_spec_dict + path = ts_spec_dict['kvstore'].pop('path') + # This will be updated to the actual bucket in `_read_ts`. + ts_spec_dict['kvstore'] = { + 'bucket': 't5x-dummy-bucket', + 'driver': 'gcs', + 'path': path + } + else: + if arr_or_ts_spec_dict['kvstore']['driver'] != 'gcs': + raise ValueError('Unsupported TensoreStore driver. Got ' + f'{arr_or_ts_spec_dict["kvstore"]["driver"]}.') + ts_spec_dict = arr_or_ts_spec_dict + + return ts_spec_dict + + def _is_leaf(value): + return not isinstance( + value, dict) or set(value.keys()) >= {'driver', 'kvstore', 'metadata'} + + return jax.tree_map(_gfile_to_gcs_driver, ckpt_contents, is_leaf=_is_leaf) + + +def _maybe_update_ts_from_gcs_to_file(ckpt_contents): + """Updates the TensorStore driver to gfile or file if different.""" + + # if saved in gcs, change to file + def _gcs_to_file_driver(arr_or_ts_spec_dict): + if not isinstance(arr_or_ts_spec_dict, dict): + return arr_or_ts_spec_dict + + if arr_or_ts_spec_dict['kvstore']['driver'] == 'gcs': + ts_spec_dict = arr_or_ts_spec_dict + path = ts_spec_dict['kvstore'].pop('path') + driver = 'file' + ts_spec_dict['kvstore'] = {'path': path, 'driver': driver} + elif arr_or_ts_spec_dict['kvstore']['driver'] == 'gfile': + ts_spec_dict = arr_or_ts_spec_dict + driver = 'file' + ts_spec_dict['kvstore']['driver'] = driver + elif arr_or_ts_spec_dict['kvstore']['driver'] == 'file': + ts_spec_dict = arr_or_ts_spec_dict + else: + raise ValueError('Unsupported TensoreStore driver. 
Got ' + f'{arr_or_ts_spec_dict["kvstore"]["driver"]}.') + + return ts_spec_dict + + def _is_leaf(value): + return not isinstance( + value, dict) or set(value.keys()) >= {'driver', 'kvstore', 'metadata'} + + return jax.tree_map(_gcs_to_file_driver, ckpt_contents, is_leaf=_is_leaf) + + +class _BytesConditionVariable(object): + """Wraps a condition variable to control concurrency based on bytes.""" + + def __init__(self, num_bytes): + self._max_bytes = num_bytes + self._num_bytes = num_bytes + self._cv = asyncio.Condition(lock=asyncio.Lock()) + + async def wait_for_bytes(self, n_bytes): + async with self._cv: + await self._cv.wait_for(lambda: self._num_bytes > n_bytes) + self._num_bytes -= n_bytes + assert self._num_bytes >= 0 + + async def return_bytes(self, n_bytes): + async with self._cv: + self._num_bytes += n_bytes + assert self._num_bytes <= self._max_bytes + self._cv.notify_all() + + +class SaveStateTransformationFn(typing_extensions.Protocol): + + def __call__(self, state_dict: PyTreeDef, + parameter_infos: PyTreeDef) -> Tuple[PyTreeDef, PyTreeDef]: + """Transforms the state and param info, e.g., by remapping parameters. + + Args: + state_dict: State in the current model. + parameter_infos: PyTree containing `_ParameterInfo` objects. + + Returns: + A tuple whose first element is the result of transforming `state_dict` and + whose second element is the result of transforming `parameter_infos`. + """ + + +class RestoreStateTransformationFn(typing_extensions.Protocol): + + def __call__(self, + state_dict: PyTreeDef, + target_state_dict: PyTreeDef, + *, + is_resuming: bool = False) -> PyTreeDef: + """Transforms the given checkpoint state, e.g., by remapping parameters. + + Args: + state_dict: State to transform, which could be from a previous version of + the model. + target_state_dict: State in the current model. + is_resuming: `True` iff this restore call is due to a job resuming after + being temporarily stopped due to, for example, a preemption. This is + useful when there is restore logic that should run when restoring from + some pre-existing checkpoint, but that should not run again when + resuming from a newly-written checkpoint. + + Returns: + The result of transforming the `state_dict`. + """ + + +class Checkpointer(object): + """Handles saving and restoring potentially-sharded T5X checkpoints. + + Checkpoints are stored using a combination of msgpack (via flax.serialization) + and TensorStore. + + Parameters (and other objects) that are not partitioned are written to the + msgpack binary directly (by host 0). Partitioned parameters are each written + to their own TensorStore, with each host writing their portion to the same + TensorStore in parallel. If a partition is written on multiple hosts, the + partition is further sharded across these replicas to avoid additional + overhead. In place of the paramater, a `tensorstore.Spec` is written to the + msgpack (by host 0) as a reference to be used during restore. Note that the + path of the array being written is relative. This makes the checkpoints + portable. In other words, even if the checkpoint files are moved to a new + directory, they can still be loaded. Because the path is relative, the + checkpoint directory information has to be dynamically provided. This is done + by `_update_ts_path_from_relative_to_absolute`. + + For TensorStore driver using Google Cloud Storage (GCS) Key-Value Storage + Layer, the GCS bucket information is necessary. 
When a checkpoint is written + using the gcs driver, we don't want to hardcode the bucket information in the + resulting file in order to maintain the portability. Therefore, we use a dummy + bucket name of "t5x-dummy-bucket". When reading or writing the checkpoint, the + bucket information is parsed from the checkpoint directory and the bucket + information is dynamically updated. + + Attributes: + checkpoints_dir: a path to a directory to save checkpoints in and restore + them from. + keep: an optional maximum number of checkpoints to keep. If more than this + number of checkpoints exist after a save, the oldest ones will be + automatically deleted to save space. + restore_dtype: optional dtype to cast targets to after restoring. + save_dtype: dtype to cast targets to before saving. + keep_dataset_checkpoints: an optional maximum number of data iterators to + keep. If more than this number of data iterators exist after a save, the + oldest ones will be automatically deleted to save space. + """ + + def __init__(self, + train_state: train_state_lib.TrainState, + partitioner: partitioning.BasePartitioner, + checkpoints_dir: str, + dataset_iterator: Optional[tf.data.Iterator] = None, + *, + keep: Optional[int] = None, + save_dtype: jnp.dtype = np.float32, + restore_dtype: Optional[jnp.dtype] = None, + use_gda: Optional[bool] = False, + keep_dataset_checkpoints: Optional[int] = None): + """Checkpointer constructor. + + Args: + train_state: A train state to be used to determine the structure of the + parameter tree, and the *full* (non-partitioned) parameter shapes and + dtypes. Saved and restored train states must match this structure. + partitioner: the partitioner to use for determining the local chunks + mapping or to perform params partitioning on restore. + checkpoints_dir: a path to a directory to save checkpoints in and restore + them from. + dataset_iterator: an optional iterator to save/restore. + keep: an optional maximum number of checkpoints to keep. If more than this + number of checkpoints exist after a save, the oldest ones will be + automatically deleted to save space. + save_dtype: dtype to cast targets to before saving. + restore_dtype: optional dtype to cast targets to after restoring. If None, + no parameter casting is performed. + use_gda: if True, enabled gda_lib.GlobalDeviceArray. Note: this is + currently an experimental feature under development. + keep_dataset_checkpoints: an optional maximum number of data iterators to + keep. If more than this number of data iterators exist after a save, the + oldest ones will be automatically deleted to save space. 
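+
+    Example (illustrative sketch; `train_state` and `partitioner` are assumed
+    to have been constructed elsewhere, and the directory is hypothetical):
+
+      checkpointer = Checkpointer(
+          train_state, partitioner, '/tmp/my_run/checkpoints', keep=3)
+      checkpointer.save(train_state)
+      restored_state = checkpointer.restore()  # restores the latest step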
+ """ + self._train_state = train_state + self._partitioner = partitioner + self.checkpoints_dir = checkpoints_dir + self.keep = keep + self.keep_dataset_checkpoints = keep_dataset_checkpoints + # Immutable due to use in `_get_parameter_infos` + self._save_dtype = save_dtype + self.restore_dtype = restore_dtype + self._dataset_ckpt = ( + tf.train.Checkpoint(ds=dataset_iterator) if dataset_iterator else None) + self._use_gda = use_gda + if self._use_gda: + logging.info('Checkpointing using GDA format is enabled.') + + data_layout = partitioner.get_data_layout() + self._dataset_ckpt_name = ( + f'{_TRAIN_DS_PREFIX}-' + f'{data_layout.shard_id:03}-of-{data_layout.num_shards:03}') + self._should_write_dataset_ckpt = ( + dataset_iterator and data_layout.is_first_host_in_replica_set) + + self._parameter_infos = self._get_parameter_infos() + + asyncio.set_event_loop(asyncio.new_event_loop()) + + def _get_state_dict_for_save(self, + state_dict: Dict[str, Any], + lazy_load: bool = True) -> Mapping[str, Any]: + """Gets the optimizer state dict.""" + + def _lazy_load_device_array(arr): + if isinstance(arr, jax.xla.DeviceArray): + return LazyThreadPoolArray(arr.shape, arr.dtype, lambda: np.array(arr)) + return arr + + if lazy_load: + state_dict = jax.tree_map(_lazy_load_device_array, state_dict) + return state_dict + + def _get_parameter_infos(self): + """Generates the state dict of _ParameterInfos for the Optimizer. + + We generate a state dict (matching the shape of the optimizer state dict) + that stores a _ParameterInfo for each parameter array. + + The _ParameterInfo contains the TensorStore spec for the parameter array and + the LocalChunkInfo describing the slice of the array local to this host. + + Returns: + The state dict of _ParameterInfo objects. + """ + + def _get_param_info(name: str, arr: Any, axes: partitioning.PartitionSpec): + # If a node in your model is None it is probably a param_state that is not + # used because of a MultiOptimizer. We don't want to have any parameter + # info for it because it shouldn't be saved or restored. + if arr is None: + return None + # Pass-through empty dict leaves, which occur with optax EmptyState(). + if isinstance(arr, dict) and not arr: + return {} + + if axes is None: + return _ParameterInfo( + name=name, + shape=arr.shape, + ts_spec=None, + local_chunk_info=None, + axes=None) + + if self._use_gda and isinstance(arr, gda_lib.GlobalDeviceArray): + local_chunk_info = None + metadata = gda_serialization._get_metadata(arr) # pylint: disable=protected-access + del metadata['dtype'] + else: + local_chunk_info = self._partitioner.get_local_chunk_info( + arr.shape, axes) + write_shape = [ + si if sl == slice(None) else sl.stop - sl.start + for si, sl in zip(arr.shape, local_chunk_info.slice) + ] + # TODO(levskaya, adarob): how should we handle stacked/fused variables?? + chunk_shape = _choose_chunk_shape( + write_shape, + target_elements=_DESIRED_CHUNK_SIZE_BYTES / arr.dtype.itemsize) + + metadata = { + 'compressor': { + 'id': 'gzip' + }, + 'shape': arr.shape, + 'chunks': np.array(chunk_shape), + } + + if self.checkpoints_dir.startswith('gs://'): + spec = { + 'driver': 'zarr', + 'dtype': jnp.dtype(arr.dtype).name, + 'kvstore': { + 'driver': 'gcs', + # We always write with a dummy bucket and dynamically update the + # bucket information. This makes the checkpoint files portable + # and not bind to the bucket that it was originally written to. 
+ 'bucket': 't5x-dummy-bucket', + }, + 'path': name.replace('/', '.'), + 'metadata': metadata, + } + else: + spec = { + 'driver': 'zarr', + 'dtype': jnp.dtype(arr.dtype).name, + 'kvstore': { + 'driver': 'file', + 'path': name.replace('/', '.') + }, + 'metadata': metadata, + } + + return _ParameterInfo( + name, + shape=arr.shape, + ts_spec=ts.Spec(spec), + local_chunk_info=local_chunk_info, + axes=axes) + + # Create a tree of param names as the keys on the path to each leaf + # separated by "/". + param_names = traverse_util.unflatten_dict({ + k: '/'.join(k) for k in traverse_util.flatten_dict( + self._train_state.state_dict(), keep_empty_nodes=True) + }) + + return jax.tree_map( + _get_param_info, param_names, + self._get_state_dict_for_save(self._train_state.state_dict()), + self._partitioner.get_mesh_axes(self._train_state).state_dict()) + + def _get_checkpoint_dir(self, step: int) -> str: + return get_checkpoint_dir(self.checkpoints_dir, step) + + def all_steps(self) -> Sequence[int]: + """Returns list of available step numbers in ascending order.""" + return all_steps(self.checkpoints_dir) + + def all_dataset_checkpoint_steps(self) -> Sequence[int]: + """Returns list of available step numbers in ascending order.""" + return all_dataset_checkpoint_steps(self.checkpoints_dir) + + def latest_step(self) -> Optional[int]: + """Returns latest step number or None if no checkpoints exist.""" + return latest_step(self.checkpoints_dir) + + def _remove_old_dataset_checkpoints(self): + """Deletes old dataset checkpoints if there are more than allowed.""" + if self.keep_dataset_checkpoints: + existing_steps = self.all_dataset_checkpoint_steps() + to_remove = len(existing_steps) - self.keep_dataset_checkpoints + if to_remove > 0: + for step in existing_steps[:to_remove]: + checkpoint_utils.remove_dataset_checkpoint( + self._get_checkpoint_dir(step), _TRAIN_DS_PREFIX) + + def _remove_old_checkpoints(self): + """Deletes oldest checkpoints if there are more than keep_checkpoints.""" + if not self.keep: + return + existing_steps = self.all_steps() + to_remove = len(existing_steps) - self.keep + if to_remove <= 0: + return + + for step in existing_steps[:to_remove]: + checkpoint_utils.remove_checkpoint_dir(self._get_checkpoint_dir(step)) + + def save(self, + train_state: train_state_lib.TrainState, + state_transformation_fns: Sequence[SaveStateTransformationFn] = (), + *, + concurrent_gb: int = 128): + """Saves a checkpoint for the given train state. + + Args: + train_state: the train state to save. May contain a combination of + LazyArray objects and arrays (e.g., np.ndarray, jax.DeviceArray) + state_transformation_fns: Transformations to apply, in order, to the state + before writing. + concurrent_gb: the approximate number of gigabytes of partitionable + parameters to process in parallel. Useful to preserve RAM. + """ + step = train_state.step + step = step.get() if isinstance(step, LazyArray) else step + step = _get_local_data(step) + # Integer, to avoid side effects in the checkpoint path. + step = int(step) + + # Share a timestamp across devices. 
+ timestamp = multihost_utils.broadcast_one_to_all(np.int32(time.time())) + + final_dir = os.path.join(self.checkpoints_dir, f'checkpoint_{step}') + tmp_dir = final_dir + f'.tmp-{timestamp}' + + if gfile.exists(final_dir): + logging.info( + 'Skipping save checkpoint for step %d (directory %s already exists)', + step, final_dir) + return + + logging.info('Saving checkpoint for step %d to %s', step, tmp_dir) + + if jax.process_index() == 0: + gfile.makedirs(tmp_dir) + # Block all hosts until directory is ready. + multihost_utils.sync_global_devices(f'checkpointer:make_dir:{tmp_dir}') + + written_state_dict = self._write_state_to_tensorstore( + tmp_dir, train_state, concurrent_gb, state_transformation_fns) + + if self._should_write_dataset_ckpt: + logging.info("Writing dataset iterator state to '%s'.", + self._dataset_ckpt_name) + try: + self._dataset_ckpt.write(os.path.join(tmp_dir, self._dataset_ckpt_name)) + except tf.errors.FailedPreconditionError as e: + logging.error( + 'Input pipeline must be stateless in order to checkpoint. Cache ' + 'stateful steps offline or disable iterator checkpointing.') + raise e + + # Block until complete on all hosts. + multihost_utils.sync_global_devices( + f'checkpointer:tensorstore_write_complete:{tmp_dir}') + + if jax.process_index() == 0: + written_state_dict = jax.tree_map(_get_local_data, written_state_dict) + + # Write msgpack file in host 0 only + msgpack_bytes = serialization.to_bytes({ + 'version': VERSION, + 'optimizer': written_state_dict + }) + with gfile.GFile(os.path.join(tmp_dir, 'checkpoint'), 'wb') as fp: + fp.write(msgpack_bytes) + + # Finalize checkpoint directory. + if final_dir.startswith('gs://'): + subprocess.run(['gsutil', '-m', 'mv', tmp_dir, final_dir], + stdout=subprocess.DEVNULL, + check=True) + else: + gfile.rename(tmp_dir, final_dir) + logging.info('Saved checkpoint for step %d to %s', step, final_dir) + + # Remove old checkpoints, if necessary. + self._remove_old_checkpoints() + self._remove_old_dataset_checkpoints() + + # Block until complete on all hosts. + multihost_utils.sync_global_devices( + f'checkpointer:write_complete:{final_dir}') + + def _write_state_to_tensorstore( + self, + ckpt_dir: str, + train_state: train_state_lib.TrainState, + concurrent_gb: int, + state_transformation_fns: Sequence[SaveStateTransformationFn], + ) -> Mapping[str, Any]: + """Writes extracted state from train state to Tensorstore.""" + concurrent_bytes = concurrent_gb * 10**9 + bytes_cv = _BytesConditionVariable(concurrent_bytes) + + async def _write_array(maybe_arr: Any, + param_info: Optional[_ParameterInfo], + cast: bool = False): + """Maybe write to TensorStore, returning object to write to msgpack. + + Args: + maybe_arr: array or LazyArray to be written + param_info: ParameterInfo object. If None (or if param_info.ts_spec is + None), the array will be immediately returned without writing to + tensorstore. This is because array is None or is not partitioned, and + should be written separately. + cast: if True, performs cast operation using self._save_dtype. + + Returns: + Tensorstore spec corresponding to the written array. + """ + if param_info is None or param_info.ts_spec is None: + # Write to the msgpack file on host 0. + if isinstance(maybe_arr, LazyArray): + return await maybe_arr.get_async() + return maybe_arr + + # Only write each chunk of a parameter from one host + if self._use_gda or param_info.local_chunk_info.replica_id == 0: + arr = maybe_arr + + # Wait until memory is available. 
+ if isinstance(arr, gda_lib.GlobalDeviceArray): + n_bytes = sum([ + shard.data.nbytes + for shard in arr.local_shards + if shard.replica_id == 0 + ]) + else: + n_bytes = arr.nbytes + if n_bytes > concurrent_bytes: + logging.warning( + 'Temporarily increasing the concurrency limits from %d bytes to ' + '%d bytes to fit %s.', concurrent_bytes, n_bytes, param_info.name) + n_bytes = concurrent_bytes + await bytes_cv.wait_for_bytes(n_bytes) + + if isinstance(maybe_arr, LazyArray): + arr = await arr.get_async() + elif not isinstance(arr, np.ndarray) and not isinstance( + arr, gda_lib.GlobalDeviceArray): + # Cast jax.DeviceArray to np.ndarray. + arr = np.array(maybe_arr, dtype=maybe_arr.dtype) + + tmp_ts_spec_dict = param_info.ts_spec.to_json() + + if cast: + # Set desired destination dtype. + tmp_ts_spec_dict['dtype'] = jnp.dtype(self._save_dtype).name + + param_info.ts_spec = ts.Spec(tmp_ts_spec_dict) + + # Path and gcs bucket (if applicable) information is updated in-place. + _update_ts_path_from_relative_to_absolute(ckpt_dir, tmp_ts_spec_dict) + + if cast: + # Set up casting spec. + tmp_ts_spec_dict = { + 'base': tmp_ts_spec_dict, + 'driver': 'cast', + 'dtype': jnp.dtype(arr.dtype).name, # dtype before cast + } + + if self._use_gda: + await gda_serialization.async_serialize(arr, tmp_ts_spec_dict) + else: + t = await ts.open( + tmp_ts_spec_dict, + create=True, + open=True, + context=ts.Context({'file_io_concurrency': { + 'limit': 128 + }})) + await t[param_info.local_chunk_info.slice].write(arr) + + await bytes_cv.return_bytes(n_bytes) + + # N.B. we return the original ts_spec (before + # `_update_ts_path_from_relative_to_absolute` was called). This is because + # we'd like to keep the path as relative, i.e., it doesn't hardcode the + # directory that the checkpoint was originally written. This makes the + # checkpoints portable. + return param_info.ts_spec + + transformed_state_dict, transformed_parameter_infos = ( + self._transform_state_and_infos(train_state.state_dict(), + self._parameter_infos, + state_transformation_fns)) + + state_dict_for_save = self._get_state_dict_for_save(transformed_state_dict) + + def _cast_arr_if_not_partitioned(maybe_arr, param_info): + if param_info is None or param_info.ts_spec is None: + return _cast(maybe_arr, self._save_dtype) + return maybe_arr + + state_dict_for_save['target'] = jax.tree_multimap( + _cast_arr_if_not_partitioned, state_dict_for_save['target'], + transformed_parameter_infos['target']) + future_written_state = {} + for k in state_dict_for_save.keys(): + # ensure that only 'target' is cast + future_written_state[k] = jax.tree_multimap( + functools.partial(_write_array, cast=(k == 'target')), + state_dict_for_save[k], transformed_parameter_infos[k]) + + # Block until complete on this host. + written_state_dict = _run_future_tree(future_written_state) + + # Block until complete on all hosts. 
+ multihost_utils.sync_global_devices( + f'checkpointer:ts_write_complete:{ckpt_dir}') + + return written_state_dict + + def _transform_state_and_infos( + self, + state_dict: PyTreeDef, + parameter_infos: PyTreeDef, + state_transformation_fns: Sequence[SaveStateTransformationFn], + ) -> Tuple[PyTreeDef, PyTreeDef]: + """Applies transformations to the state dict and parameter infos PyTrees.""" + for fn in state_transformation_fns: + state_dict, parameter_infos = fn(state_dict, parameter_infos) + return state_dict, parameter_infos + + def restore( + self, + step: Optional[int] = None, + path: Optional[str] = None, + state_transformation_fns: Sequence[RestoreStateTransformationFn] = (), + fallback_state: Optional[Mapping[str, Any]] = None, + lazy_parameters: bool = False) -> train_state_lib.TrainState: + """Restores the host-specific parameters in an Optimizer. + + Either `step` or `path` can be specified, but not both. If neither are + specified, restores from the latest checkpoint in the checkpoints directory. + + Args: + step: the optional step number to restore from. + path: an optional absolute path to a checkpoint file to restore from. + state_transformation_fns: Transformations to apply, in order, to the state + after reading. + fallback_state: a state dict of an optimizer to fall back to for loading + params that do not exist in the checkpoint (after applying all + `state_transformation_fns`), but do exist in `Checkpointer.optimizer`. + The union of `fallback_state` and state loaded from the checkpoint must + match `Checkpointer.optimizer`. + lazy_parameters: whether to load the parameters as LazyArrays to preserve + memory. + + Returns: + The restored train state. + + Raises: + ValueError if both `step` and `path` are specified. + ValueError if checkpoint at `path` or `step` does not exist. + ValueError if `step` and `path` are not specified and no checkpoint is + found in the checkpoints directory. + """ + if lazy_parameters and self._partitioner.params_on_devices: + raise ValueError('Lazy Parameters cannot be copied to devices, please ' + 'set partitioner.params_on_devices=False.') + if step is not None and path is not None: + raise ValueError('At most one of `step` or `path` may be provided.') + if path: + ckpt_path = path + else: + if step is None: + step = self.latest_step() + if not step: + raise ValueError(f'No checkpoints found in {self.checkpoints_dir}.') + ckpt_path = self._get_checkpoint_dir(step) + + if gfile.isdir(ckpt_path): + ckpt_dir = ckpt_path + ckpt_path = os.path.join(ckpt_path, 'checkpoint') + else: + ckpt_dir = os.path.dirname(ckpt_path) + + if not gfile.exists(ckpt_path) or gfile.isdir(ckpt_path): + raise ValueError(f'Path is not a valid T5X checkpoint: {ckpt_path}') + + logging.info('Restoring from checkpoint: %s', ckpt_path) + + with gfile.GFile(ckpt_path, 'rb') as fp: + # TODO(adarob): Use threaded reading as in flax.checkpoints. + raw_contents = fp.read() + if raw_contents.startswith(b'model_checkpoint_path'): + raise ValueError( + 'Attempting to restore a TensorFlow checkpoint as a native T5X ' + 'checkpoint. Use `restore_from_tf_checkpoint` instead. Path: ' + + ckpt_path) + + # `ckpt_contents['optimizer']` is a pytree with a realized np.array for + # leaves (params or states) written as msgpack and a ts.Spec (in a dict) + # for leaves written by TensorStore. 
+ ckpt_contents = serialization.msgpack_restore(raw_contents) + + # If reading a ckpt that was written with gfile driver but the current + # session uses the gcs driver, convert the ckpt's driver to gcs. + if ckpt_dir.startswith('gs://'): + ckpt_contents = _maybe_update_ts_from_file_to_gcs(ckpt_contents) + # If a ckpt was saved in gcs and is being loaded locally, then convert the + # driver to file or gfile. If the ckpt was not saved in gcs, do not change. + else: + ckpt_contents = _maybe_update_ts_from_gcs_to_file(ckpt_contents) + + ckpt_state_dict = self._get_optimizer_state_dict(ckpt_contents, + state_transformation_fns) + + # The state dict may contain TensorStore specs that need to be read. + dummy_spec = ts.Spec({'driver': 'zarr', 'kvstore': {'driver': 'memory'}}) + + # `dummy_written_state_dict` is a pytree with a `dummy_spec` for leaves + # (params or states) written as msgpack and a ts.Spec (in a dict) for leaves + # written by TensorStore. + dummy_written_state_dict = jax.tree_map( + lambda x: x.ts_spec or dummy_spec, + self._parameter_infos, + ) + + if fallback_state is None: + restore_parameter_infos = self._parameter_infos + else: + # If `fallback_state` was specified, restore only the subset + # of parameters matched by `self._get_optimizer_state_dict`. The + # rest will be provided by `fallback_state`. + dummy_written_state_dict = state_utils.intersect_state( + dummy_written_state_dict, ckpt_state_dict) + restore_parameter_infos = state_utils.intersect_state( + self._parameter_infos, ckpt_state_dict) + + restore_parameter_infos_flat = state_utils.flatten_state_dict( + restore_parameter_infos) + for key in restore_parameter_infos_flat.keys(): + logging.info('Restoring key from ckpt: %s', key) + + # NB: `serialization.from_state_dict` doesn't check whether the shapes match + # at the leaf level. Non-partitioned leaves (e.g., optimizer states) can + # load arrays with inconsistent shapes. + # `written_state_dict` is a pytree with a realized np.array for leaves + # (params or states) written as msgpack and a `ts.Spec` for leaves written + # by TensorStore. + written_state_dict = serialization.from_state_dict(dummy_written_state_dict, + ckpt_state_dict) + state_dict = self._read_state_from_tensorstore( + ckpt_path, + written_state_dict, + restore_parameter_infos=restore_parameter_infos, + lazy_parameters=lazy_parameters) + + # If `fallback_state` was specified, then fill the missing parameters. 
+ if fallback_state is not None: + state_dict = state_utils.merge_state(state_dict, fallback_state) + + for key in state_utils.flatten_state_dict(state_dict).keys(): + if key not in restore_parameter_infos_flat: + logging.info('Not restoring key from ckpt: %s', key) + + if self._dataset_ckpt: + logging.info("Restoring dataset iterator from '%s'.", + self._dataset_ckpt_name) + self._dataset_ckpt.read(os.path.join( + ckpt_dir, self._dataset_ckpt_name)).assert_consumed() + + return self._restore_train_state(state_dict) + + def _restore_train_state( + self, + state_dict: optimizers.OptimizerStateType) -> train_state_lib.TrainState: + """Restores a TrainState from an Optimizer state_dict.""" + train_state = self._train_state.restore_state(state_dict) + + if not self._use_gda and self._partitioner.params_on_devices: + logging.info('Moving params to devices.') + train_state_axes = self._partitioner.get_mesh_axes(train_state) + train_state = self._partitioner.move_params_to_devices( + train_state, train_state_axes) + + return train_state + + def _create_lazy_awaitable_array( + self, param_info: _ParameterInfo, maybe_ts_spec: Any, ckpt_path: str, + restore_dtype: Optional[jnp.dtype]) -> LazyAwaitableArray: + """Creates LazyArray from tensorstore. + + Does not materialize the array immediately. + + Args: + param_info: Information about how to read the parameter, host based sliced + reads and the like. + maybe_ts_spec: The tensorstore spec to read the parameter or some other + object. If this is an array then we will do a host based sliced read on + it (provided the param_info says to). Anything else we just return. + ckpt_path: A base location to use when resolving the relative paths in the + tensorstore spec. + restore_dtype: type to restore as. None indicates that no cast is + requested. + + Returns: + LazyArray object. + """ + mesh = None + axes = None + if self._use_gda: + mesh = self._partitioner.mesh + axes = param_info.axes + get_fn = functools.partial( + _read_ts, + param_info, + maybe_ts_spec, + ckpt_path=ckpt_path, + restore_dtype=restore_dtype, + mesh=mesh, + axes=axes) + return LazyAwaitableArray.from_tensor_store_spec_or_array( + maybe_ts_spec, get_fn, dtype=restore_dtype) + + def _read_state_from_tensorstore( + self, + ckpt_path: str, + written_state_dict: Mapping[str, Any], + restore_parameter_infos: Optional[Mapping[str, Any]] = None, + lazy_parameters: bool = False, + ) -> Mapping[str, Any]: + """Sets up lazy reads from Tensorstore and returns them as a state_dict.""" + if restore_parameter_infos is None: + restore_parameter_infos = self._parameter_infos + + # Replace TensorStore Specs with the lazy array values. 
+ state_dict = {} + for k in written_state_dict.keys(): + # ensure that only 'target' is cast + restore_dtype = self.restore_dtype if k == 'target' else None + state_dict[k] = jax.tree_multimap( + functools.partial( + self._create_lazy_awaitable_array, + ckpt_path=ckpt_path, + restore_dtype=restore_dtype), restore_parameter_infos[k], + written_state_dict[k]) + + if not lazy_parameters: + future_state_dict = jax.tree_map(lambda x: x.get_async(), state_dict) + state_dict = _run_future_tree(future_state_dict) + + if self.restore_dtype is not None: + state_dict['target'] = _cast(state_dict['target'], self.restore_dtype) + + return state_dict + + def restore_from_tf_checkpoint( + self, + path_or_dir: str, + strict: bool = True, + translator: Optional[checkpoint_importer.CheckpointTranslator] = None + ) -> train_state_lib.TrainState: + """Restore from a TensorFlow-based T5 checkpoint.""" + full_state_dict = checkpoint_importer.restore_from_t5_checkpoint( + self._train_state.state_dict(), + path_or_dir, + lazy_parameters=False, + strict=strict, + translator=translator) + + def _partition_parameter(maybe_arr: Any, param_info: _ParameterInfo): + if isinstance(maybe_arr, np.ndarray) and param_info: + arr = maybe_arr + if param_info.shape is not None and arr.shape != param_info.shape: + raise ValueError( + f'Shape of `{param_info.name}` in checkpoint {arr.shape} does ' + f'not match expected {param_info.shape}.') + if param_info.local_chunk_info: + arr = arr[param_info.local_chunk_info.slice] + return arr + return maybe_arr + + state_dict = jax.tree_multimap(_partition_parameter, full_state_dict, + self._parameter_infos) + if self.restore_dtype is not None: + state_dict['target'] = _cast(state_dict['target'], self.restore_dtype) + + return self._restore_train_state(state_dict) + + def convert_from_tf_checkpoint( + self, + path_or_dir: str, + *, + state_transformation_fns: Sequence[SaveStateTransformationFn] = (), + concurrent_gb: int = 16, + translator: Optional[checkpoint_importer.CheckpointTranslator] = None): + """Convert from a TensorFlow-based T5 checkpoint.""" + full_state_dict = checkpoint_importer.restore_from_t5_checkpoint( + self._train_state.state_dict(), + path_or_dir, + lazy_parameters=True, + translator=translator) + train_state = self._train_state.restore_state(full_state_dict) + self.save( + train_state, + state_transformation_fns=state_transformation_fns, + concurrent_gb=concurrent_gb) + + def _get_optimizer_state_dict( + self, ckpt_contents: PyTreeDef, + state_transformation_fns: Sequence[RestoreStateTransformationFn]): + return _get_optimizer_state_dict(ckpt_contents, + self._train_state.state_dict(), + state_transformation_fns) + + +class CheckpointerConstructor(typing_extensions.Protocol): + """A function that returns a checkpoints.Checkpointer. + + This type annotation allows users to partially bind args to the constructors + of Checkpointer subclasses without triggering type errors. + """ + + def __call__(self, + train_state: train_state_lib.TrainState, + partitioner: partitioning.BasePartitioner, + checkpoints_dir: str, + dataset_iterator: Optional[tf.data.Iterator] = None, + *, + keep: Optional[int] = None, + save_dtype: jnp.dtype = np.float32, + restore_dtype: Optional[jnp.dtype] = None, + use_gda: Optional[bool] = False, + keep_dataset_checkpoints: Optional[int] = None) -> Checkpointer: + """Checkpointer constructor. 
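+
+    For example (illustrative only), a partially bound constructor such as
+    `functools.partial(SaveBestCheckpointer,
+    metric_name_to_monitor='train/accuracy')` still conforms to this protocol.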
+
+    Args:
+      train_state: A train state to be used to determine the structure of the
+        parameter tree, and the *full* (non-partitioned) parameter shapes and
+        dtypes. Saved and restored train states must match this structure.
+      partitioner: the partitioner to use for determining the local chunks
+        mapping or to perform params partitioning on restore.
+      checkpoints_dir: a path to a directory to save checkpoints in and restore
+        them from.
+      dataset_iterator: an optional iterator to save/restore.
+      keep: an optional maximum number of checkpoints to keep. If more than this
+        number of checkpoints exist after a save, the oldest ones will be
+        automatically deleted to save space.
+      save_dtype: dtype to cast targets to before saving.
+      restore_dtype: optional dtype to cast targets to after restoring. If None,
+        no parameter casting is performed.
+      use_gda: if True, enables gda_lib.GlobalDeviceArray. Note: this is
+        currently an experimental feature under development.
+      keep_dataset_checkpoints: an optional maximum number of data iterators to
+        keep. If more than this number of data iterators exist after a save, the
+        oldest ones will be automatically deleted to save space.
+    """
+    pass
+
+
+class SaveBestCheckpointer(Checkpointer):
+  """A Checkpointer class that keeps checkpoints based on 'best' metrics.
+
+  This extends the standard Checkpointer to garbage collect checkpoints based on
+  metric values, instead of step recency. It uses Tensorboard summary files to
+  determine best values for a given user-configured metric name. Events are read
+  and parsed using Tensorboard's event_processing packages.
+
+  The metric name must be of the form `{run_name}/{tag_name}`. For example,
+  'train/accuracy' or 'inference_eval/glue_cola_v002/eval/accuracy'.
+
+  A few important features of this checkpointer:
+
+  - Fallback behavior. It is not possible to verify whether metric names are
+    valid during initialization, since some metrics may get written out after
+    some time (e.g., during an evaluation). As such, when user-provided metric
+    names are not found, this checkpointer can be configured for two fallback
+    strategies: (1) if `keep_checkpoints_without_metrics` is False, we fall
+    back to the "most recent checkpoint" strategy from the standard
+    checkpointer, (2) if `keep_checkpoints_without_metrics` is True, we keep
+    all checkpoints until metrics become available (potentially indefinitely
+    if summary files have been deleted or corrupted).
+
+  - The number of checkpoints to keep is always increased by 1. Since it's
+    crucial to always keep the latest checkpoint (for recovery purposes), we
+    always store the latest checkpoint plus `keep` number of best checkpoints.
+
+  - It is assumed that Tensorboard summary (event) files share a common root
+    directory with `checkpoint_dir`, which is the directory passed to the
+    logdir crawler that searches for event files.
+
+  Attributes:
+    checkpoints_dir: a path to a directory to save checkpoints in and restore
+      them from.
+    keep: an optional maximum number of checkpoints to keep. If more than this
+      number of checkpoints exist after a save, the oldest ones will be
+      automatically deleted to save space.
+    restore_dtype: optional dtype to cast targets to after restoring.
+    save_dtype: dtype to cast targets to before saving.
+    metric_name_to_monitor: Name of metric to monitor. Must be in the format
+      {run_name}/{tag_name} (e.g., 'train/accuracy',
+      'inference_eval/glue_cola_v002/eval/accuracy').
+    metric_mode: Mode to use to compare metric values.
One of 'max' or 'min'. + keep_checkpoints_without_metrics: Whether to always keep (or delete) + checkpoints for which a metric value has not been found. + force_keep_period: When removing checkpoints, skip those who step is + divisible by force_keep_period (step % force_keep_period == 0). + use_gda: Enables GDA (see Checkpointer). + keep_dataset_checkpoints: an optional maximum number of data iterators to + keep. If more than this number of data iterators exist after a save, the + oldest ones will be automatically deleted to save space. + """ + + def __init__(self, + train_state: train_state_lib.TrainState, + partitioner: partitioning.BasePartitioner, + checkpoints_dir: str, + dataset_iterator: Optional[tf.data.Iterator] = None, + *, + keep: Optional[int] = None, + save_dtype: jnp.dtype = np.float32, + restore_dtype: Optional[jnp.dtype] = None, + metric_name_to_monitor: str = 'train/accuracy', + metric_mode: str = 'max', + keep_checkpoints_without_metrics: bool = True, + force_keep_period: Optional[int] = None, + use_gda: bool = False, + keep_dataset_checkpoints: Optional[int] = None): + super().__init__( + train_state, + partitioner, + checkpoints_dir, + dataset_iterator, + keep=keep, + save_dtype=save_dtype, + restore_dtype=restore_dtype, + use_gda=use_gda, + keep_dataset_checkpoints=keep_dataset_checkpoints) + if metric_mode not in ('max', 'min'): + raise ValueError('Unsupported `metric_mode`: %s' % metric_mode) + + # Metric run and tag names are derived from metric_name_to_monitor and are + # filled in _try_fill_metric_run_and_tag_names(). + self._metric_run: Optional[str] = None + self._metric_tag: Optional[str] = None + self._metric_name_to_monitor = metric_name_to_monitor + self._metric_mode = metric_mode + self._keep_checkpoints_without_metrics = keep_checkpoints_without_metrics + self._force_keep_period = force_keep_period + logging.info('Using SaveBestCheckpointer to keep %s best (%s) metric %s', + keep, metric_mode, metric_name_to_monitor) + + def _populate_metrics_for_steps(self, + steps: Iterable[int]) -> Mapping[int, float]: + """Iterate through summary event files and return metrics for `steps`.""" + metrics_by_step = {} + for subdir in io_wrapper.GetLogdirSubdirectories(self.checkpoints_dir): + rpath = os.path.relpath(subdir, self.checkpoints_dir) + # Skip runs that do not match user-specified metric. + if ((not self._metric_run and not self._try_fill_metric_run_and_tag_names( + (rpath,))) or self._metric_run != rpath): + logging.info('Skipping events in %s', subdir) + continue + + logging.info('Looking for events in %s', subdir) + loader = directory_watcher.DirectoryWatcher( + subdir, event_file_loader.EventFileLoader, + io_wrapper.IsTensorFlowEventsFile) + for event in loader.Load(): + # Skip metric collection of events for unavailable checkpoints or for + # unmonitored tags. + if (event.step not in steps or not event.summary.value or + event.summary.value[0].tag != self._metric_tag): + continue + metric_value = tf.make_ndarray(event.summary.value[0].tensor) + metrics_by_step[event.step] = metric_value + + return metrics_by_step + + def _try_fill_metric_run_and_tag_names(self, run_keys: Iterable[str]) -> bool: + """Extract metric run and tag names by matching one of the `run_keys`. + + This function tries to greedily split user-provided metric_name_to_monitor + into {run} and {tag} components. It does so by trying to match all available + {run}/{tag} names in the provided run_keys. If successful, populates + self._metric_run and self._metric_tag. 
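+
+    For example (illustrative), monitoring
+    'inference_eval/glue_cola_v002/eval/accuracy' against a run key of
+    'inference_eval/glue_cola_v002/eval' yields metric_run =
+    'inference_eval/glue_cola_v002/eval' and metric_tag = 'accuracy'.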
+
+    Args:
+      run_keys: Set of run keys to test for.
+
+    Returns:
+      Whether the metric name prefix matches one of the run keys; as a
+      side-effect, populates self._metric_run and self._metric_tag.
+    """
+    metric_run, metric_tag = None, None
+
+    # Query existing events for different runs and tags to match with the
+    # user-provided metric name.
+    m = self._metric_name_to_monitor.split('/')
+    possible_run_names = ['/'.join(m[:i]) for i in range(1, len(m))]
+    for key in run_keys:
+      for possible_run_name in possible_run_names:
+        if key == possible_run_name:
+          metric_run = possible_run_name
+          metric_tag = self._metric_name_to_monitor[len(metric_run) + 1:]
+          break
+
+    if metric_run and metric_tag:
+      self._metric_run, self._metric_tag = metric_run, metric_tag
+      return True
+    return False
+
+  def _filter_out_force_keep_period_steps(self, existing_steps):
+    """Filters out steps divisible by force_keep_period, excluding the last."""
+    if not existing_steps:
+      return existing_steps
+
+    # Don't filter out the last step.
+    last_step = existing_steps.pop()
+    existing_steps = [
+        s for s in existing_steps if s % self._force_keep_period != 0
+    ]
+    return existing_steps + [last_step]
+
+  def _remove_old_checkpoints(self):
+    """Deletes checkpoints if there are more than keep_checkpoints."""
+    if not self.keep:
+      return
+
+    existing_steps = self.all_steps()
+    if self._force_keep_period:
+      # Ignore checkpoints whose step is divisible by the keep period.
+      existing_steps = self._filter_out_force_keep_period_steps(existing_steps)
+
+    # Artificially add 1 to `keep` since we always keep the latest checkpoint.
+    if len(existing_steps) <= self.keep + 1:
+      return
+
+    # Synchronous fetch of new events for existing_steps.
+    metrics_by_step = self._populate_metrics_for_steps(existing_steps)
+    logging.info('SaveBestCheckpointer: collected metrics %s', metrics_by_step)
+
+    # Re-sort existing_steps by metric values while always keeping the latest
+    # checkpoint.
+    latest_checkpoint = existing_steps[-1]
+    existing_steps = existing_steps[:-1]
+
+    if self._keep_checkpoints_without_metrics:
+      existing_steps = list(
+          filter(lambda s: s in metrics_by_step, existing_steps))
+
+    to_remove = len(existing_steps) - self.keep
+    if to_remove <= 0:
+      return
+
+    # For any remaining steps without metrics, we assign a low/high value which
+    # will make them candidates for removal. If no metrics are found, this
+    # sorting should preserve the current order (oldest first).
+    not_found_value = float('-inf' if self._metric_mode == 'max' else 'inf')
+    existing_steps = sorted(
+        existing_steps,
+        key=lambda step: metrics_by_step.get(step, not_found_value),
+        reverse=(self._metric_mode != 'max'))
+    existing_steps.append(latest_checkpoint)
+
+    for step in existing_steps[:to_remove]:
+      checkpoint_utils.remove_checkpoint_dir(self._get_checkpoint_dir(step))
+
+
+def _get_optimizer_state_dict(
+    ckpt_contents: PyTreeDef, optimizer_state: Mapping[str, Any],
+    state_transformation_fns: Sequence[RestoreStateTransformationFn]):
+  """Extracts optimizer state dict contents and applies assignment map."""
+  version = ckpt_contents.get('version', 0)
+  if version == 0:
+    # This is a standard Flax checkpoint and may require remapping below.
+    ckpt_optimizer_state = ckpt_contents
+  else:
+    ckpt_optimizer_state = ckpt_contents['optimizer']
+
+  if version >= 2:
+    for fn in state_transformation_fns:
+      ckpt_optimizer_state = fn(ckpt_optimizer_state, optimizer_state)
+    return ckpt_optimizer_state
+  else:
+    raise ValueError('Checkpoint versions earlier than 2 are not supported. '  # pylint: disable=unreachable
+                     f'Got version: {version}')
+
+
+async def _read_ts(param_info: _ParameterInfo,
+                   maybe_tspec: Any,
+                   ckpt_path: str,
+                   restore_dtype: Optional[jnp.dtype] = None,
+                   mesh: Optional[gda_lib.Shape] = None,
+                   axes: Optional[gda_lib.MeshAxes] = None):
+  """Read from a tensorstore.
+
+  If both `mesh` and `axes` are provided, the method will attempt to restore
+  the array as a GlobalDeviceArray.
+
+  Note:
+    We use param_infos as the first argument because this function is only
+    used in `jax.tree_multimap` calls. In a tree multimap, if the leaf of the
+    first tree is `None` then it is ignored, even if the second tree has a
+    subtree at that point. This means that when we are using something like a
+    MultiOptimizer we can set the parameter info for a variable to `None` and
+    we can skip processing it, even if the checkpoint has a subtree with
+    things like optimizer state variables in it.
+
+  Args:
+    param_info: Information about how to read the parameter, host based sliced
+      reads and the like.
+    maybe_tspec: The tensorstore spec to read the parameter or some other
+      object. If this is an array then we will do a host based sliced read on
+      it (provided the param_info says to). Anything else we just return.
+    ckpt_path: A base location to use when resolving the relative paths in the
+      tensorstore spec.
+    restore_dtype: type to restore as. None indicates that no cast is
+      requested.
+    mesh: Mesh object for GDA restoration.
+    axes: MeshAxes object for GDA restoration.
+
+  Returns:
+    The array. Depending on the value of `maybe_tspec` it might be read from
+    tensorstore, or it might be returned as is. Depending on the values in
+    param_info (specifically the `local_chunk_info`) it might be the full
+    value or a specific slice.
+  """
+  # If saved as a numpy array, but a partitioned read is requested, return a
+  # slice of the array for that host. Otherwise, return the whole thing.
+  if isinstance(maybe_tspec, np.ndarray) and param_info:
+    if param_info.local_chunk_info:
+      arr = maybe_tspec
+      return arr[param_info.local_chunk_info.slice]
+    else:
+      return maybe_tspec
+  # If we have anything else that isn't a tensorstore spec just return it.
+  elif not isinstance(maybe_tspec, ts.Spec):
+    return maybe_tspec
+
+  tmp_ts_spec_dict = maybe_tspec.to_json()
+  # Remove non-required params so that we can open Tensorstore
+  # that was created with a different set of params.
+  del tmp_ts_spec_dict['metadata']['chunks']
+  del tmp_ts_spec_dict['metadata']['compressor']
+
+  # Convert the relative path in the spec to a path based on the checkpoint
+  # location. Path and gcs bucket (if applicable) information is updated
+  # in-place.
+  _update_ts_path_from_relative_to_absolute(
+      os.path.dirname(ckpt_path), tmp_ts_spec_dict)
+
+  if param_info.shape is not None:
+    ts_spec_arr_shape = tuple(tmp_ts_spec_dict['metadata']['shape'])
+    # Check that the shapes of the array on disk match the expected shape based
+    # on the optimizer that is being restored.
+    if ts_spec_arr_shape != param_info.shape:
+      raise ValueError(f'Shape of `{param_info.name}` in checkpoint '
+                       f'{ts_spec_arr_shape} does not match expected '
+                       f'{param_info.shape}.')
+
+  if ('dtype' in tmp_ts_spec_dict and tmp_ts_spec_dict['dtype']
+      == 'uint16') or ('dtype' in tmp_ts_spec_dict['metadata'] and
+                       tmp_ts_spec_dict['metadata']['dtype'] == ' Optional[_ParameterInfo]:
+  """Create _ParameterInfo that results in a full read."""
+  # tspec is only None for `param_states` where the associated variable
+  # is not updated by any optimizers.
+  # By setting the parameter info for such a variable to None, we can later
+  # short-circuit processing these subtrees during loading.
+  if maybe_tspec is None:
+    return None
+  local_chunk_info = None
+  tspec = None
+  if isinstance(maybe_tspec, ts.Spec):
+    tspec = maybe_tspec
+    local_chunk_info = partitioning.LocalChunkInfo(
+        slice=(slice(None, None),), replica_id=0)
+  return _ParameterInfo(
+      name='',  # We don't ever use the name.
+      shape=tuple(tspec.to_json()['metadata']['shape']) if tspec else None,
+      # We just believe the spec in the file.
+      ts_spec=tspec,
+      local_chunk_info=local_chunk_info,
+      axes=None)
+
+
+def find_checkpoint(path: str, step: Optional[int] = None) -> str:
+  """Find the checkpoint file based on paths and steps.
+
+  Args:
+    path: The location of the checkpoint. Can point to the `model_dir`, the
+      checkpoint dir with a step, or the actual checkpoint file.
+    step: The step to load. Only used if you are pointing to the `model_dir`.
+
+  Raises:
+    ValueError if the checkpoint file can't be found.
+
+  Returns:
+    The path to the checkpoint file.
+  """
+  # If you aren't pointing at the msgpack checkpoint file
+  if gfile.isdir(path):
+    # If you didn't specify a step
+    if step is None:
+      # Try to get the most recent step.
+      step = latest_step(path)
+      # If you found a step then you were pointing at model_dir; set the path
+      # to the checkpoint dir for that step.
+      if step:
+        path = get_checkpoint_dir(path, step)
+    # You gave a step, use it.
+    else:
+      path = get_checkpoint_dir(path, step)
+    # Whether you supplied a step, found a step, or were already pointing at
+    # the step, you are now pointing at a step directory, so point to the
+    # msgpack file inside it.
+    path = os.path.join(path, 'checkpoint')
+  # You weren't pointing at a dir, so you were pointing at the msgpack file.
+  # Check that we found a checkpoint file.
+  if not gfile.exists(path) or gfile.isdir(path):
+    raise ValueError(f'Path is not a valid checkpoint: {path}')
+  return path
+
+
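+# Illustrative sketch (hypothetical paths): `find_checkpoint` accepts a model
+# dir, a checkpoint_<step> dir, or the msgpack file itself, e.g.
+#   find_checkpoint('/tmp/model_dir')           # -> .../checkpoint_<latest>/checkpoint
+#   find_checkpoint('/tmp/model_dir', step=42)  # -> .../checkpoint_42/checkpoint
+# `load_t5x_checkpoint` below resolves its `path` argument the same way before
+# reading the msgpack index and any tensorstore-backed parameters it refers to.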
+def load_t5x_checkpoint(
+    path: str,
+    step: Optional[int] = None,
+    state_transformation_fns: Sequence[RestoreStateTransformationFn] = (),
+    remap: bool = True,
+    restore_dtype: Optional[jnp.dtype] = None,
+    lazy_parameters: bool = False) -> PyTreeDef:
+  """Load a T5X checkpoint without pre-defining the optimizer.
+
+  Note:
+    This only works for T5X checkpoints, not TF checkpoints.
+
+  Args:
+    path: The location of the checkpoint.
+    step: The checkpoint from which step should be loaded.
+    state_transformation_fns: Transformations to apply, in order, to the state
+      after reading.
+    remap: Whether to rename the checkpoint variables to the newest version.
+    restore_dtype: optional dtype to cast targets to after restoring. If None,
+      no parameter casting is performed.
+    lazy_parameters: whether to load the parameters as LazyArrays to preserve
+      memory.
+
+  Returns:
+    A nested dictionary of weights and parameter states from the checkpoint.
+  """
+  path = find_checkpoint(path, step)
+  logging.info('Restoring from checkpoint: %s', path)
+
+  # The msgpack file will have all the info we need about the parameter layout.
+  with gfile.GFile(path, 'rb') as fp:
+    ckpt_contents = serialization.msgpack_restore(fp.read())
+
+  # If reading a ckpt that was written with the gfile driver but the current
+  # session uses the gcs driver, convert the ckpt's driver to gcs.
+  if path.startswith('gs://'):
+    ckpt_contents = _maybe_update_ts_from_file_to_gcs(ckpt_contents)
+  # If a ckpt was saved in gcs and is being loaded locally, convert the driver
+  # to file or gfile. If the ckpt was not saved in gcs, do not change it.
+  else:
+    ckpt_contents = _maybe_update_ts_from_gcs_to_file(ckpt_contents)
+
+  # Remap the variable names to the most recent formatting.
+  if remap:
+    ckpt_optimizer_state = _get_optimizer_state_dict(ckpt_contents, {},
+                                                     state_transformation_fns)
+  # If we aren't remapping names, we at least need to index into the checkpoint
+  # file blob to make sure we are only dealing with the optimizer state.
+  else:
+    # Grab a subsection of the file depending on the version.
+    version = ckpt_contents.get('version', 0)
+    if version == 0:
+      ckpt_optimizer_state = ckpt_contents
+    else:
+      ckpt_optimizer_state = ckpt_contents['optimizer']
+
+  # Replace all dicts of tensorstore specs with actual `ts.Spec`s.
+  # When a checkpoint was trained using a MultiOptimizer, some of the parameter
+  # states may be set to `None` (when a parameter was untouched by any
+  # optimizer). We still need references to these in our state, so we keep
+  # empty nodes.
+  ckpt_optimizer_state_with_specs = (
+      state_utils.flatten_state_dict(
+          ckpt_optimizer_state, keep_empty_nodes=True))
+  ckpt_optimizer_state_with_specs = {
+      k: ts.Spec(v) if isinstance(v, dict) else v
+      for k, v in ckpt_optimizer_state_with_specs.items()
+  }
+
+  # Create fake parameter info that results in reading the whole variable.
+  param_infos = {
+      k: fake_param_info(v) for k, v in ckpt_optimizer_state_with_specs.items()
+  }
+
+  ckpt_optimizer_state_with_specs = traverse_util.unflatten_dict(
+      ckpt_optimizer_state_with_specs, sep='/')
+  param_infos = traverse_util.unflatten_dict(param_infos, sep='/')
+
+  def _create_lazy_awaitable_array(
+      param_info: _ParameterInfo, maybe_ts_spec: Any, ckpt_path: str,
+      restore_dtype: Optional[jnp.dtype]) -> LazyAwaitableArray:
+    get_fn = functools.partial(
+        _read_ts,
+        param_info,
+        maybe_ts_spec,
+        ckpt_path=ckpt_path,
+        restore_dtype=restore_dtype)
+    return LazyAwaitableArray.from_tensor_store_spec_or_array(
+        maybe_ts_spec, get_fn, dtype=restore_dtype)
+
+  state_dict = jax.tree_multimap(
+      functools.partial(
+          _create_lazy_awaitable_array,
+          ckpt_path=path,
+          restore_dtype=restore_dtype), param_infos,
+      ckpt_optimizer_state_with_specs)
+
+  if not lazy_parameters:
+    future_state_dict = jax.tree_map(lambda x: x.get_async(), state_dict)
+    state_dict = _run_future_tree(future_state_dict)
+
+  if restore_dtype is not None:
+    state_dict['target'] = _cast(state_dict['target'], restore_dtype)
+  return state_dict
diff --git a/t5x/checkpoints_test.py b/t5x/checkpoints_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..834ebcde49ac8501c3f2fb5c866cf53f6f6ef967
--- /dev/null
+++ b/t5x/checkpoints_test.py
@@ -0,0 +1,1744 @@
+# Copyright 2022 The T5X Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Tests for t5x.checkpoints.""" +import concurrent.futures +import functools +import itertools +import os +from typing import Any, Mapping + +from absl import flags +from absl.testing import absltest +from absl.testing import parameterized +from flax import serialization +from flax import traverse_util +from flax.metrics import tensorboard +import jax +import jax.numpy as jnp +import numpy as np +from t5x import checkpoints +from t5x import optimizers +from t5x import partitioning +from t5x import state_utils +from t5x import test_utils +from t5x import train_state as train_state_lib +from t5x import utils +import tensorflow as tf +from tensorflow.io import gfile +import tensorstore as ts + +# Parse absl flags test_srcdir and test_tmpdir. +jax.config.parse_flags_with_absl() + +mock = absltest.mock +PartitionSpec = partitioning.PartitionSpec +FLAGS = flags.FLAGS +LazyArray = checkpoints.LazyArray + +TESTDATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'testdata') + +FlaxOptimTrainState = train_state_lib.FlaxOptimTrainState + + +def make_train_state( + *, + step: int, + params: Mapping[str, Any], + param_states: Mapping[str, Any], + flax_optimizer_def: optimizers.OptimizerDefType = optimizers.sgd(0.1) +) -> FlaxOptimTrainState: + """Helper to construct a train state for testing.""" + optimizer = optimizers.Optimizer( + flax_optimizer_def, + state=optimizers.OptimizerState(step=step, param_states=param_states), + target=params) + return FlaxOptimTrainState(optimizer) + + +def make_train_state_multi_optimizer(params: Mapping[str, Any], + param_states: Mapping[str, Any], + step: int) -> FlaxOptimTrainState: + """Helper to construct a train state with multi optimizer for testing.""" + optimizer = optimizers.Optimizer( + optimizers.MultiOptimizer([ + (traverse_util.ModelParamTraversal( + lambda path, _: 'kernel' not in path), optimizers.sgd(0.1)), + ]), + state=optimizers.OptimizerState(step=step, param_states=param_states), + target=params) + return FlaxOptimTrainState(optimizer) + + +def update_train_state_step(train_state: FlaxOptimTrainState, + step: int) -> FlaxOptimTrainState: + """Helper to update the step inside TrainState.""" + state_dict = train_state.state_dict() + state_dict['state']['step'] = step + return train_state.restore_state(state_dict) + + +class CheckpointChunkShapeTest(absltest.TestCase): + + def test_simple(self): + self.assertEqual([4096, 4096], + checkpoints._choose_chunk_shape([4096, 4096], 4096 * 4096)) + + self.assertEqual([4096, 4096], + checkpoints._choose_chunk_shape([8192, 8192], 4096 * 4096)) + + self.assertEqual([4096, 2731], + checkpoints._choose_chunk_shape([8192, 8193], 4096 * 4096)) + + self.assertEqual([4096], checkpoints._choose_chunk_shape([8192], 4096)) + + self.assertEqual([2731], checkpoints._choose_chunk_shape([8193], 4096)) + + +class CheckpointsTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.train_state = make_train_state( + step=np.int32(42), + params={ + 'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)), + 'kernel': np.arange(32, dtype=np.float32).reshape((2, 16)) + }, + param_states={ + 'bias': np.int32(1), + 'kernel': np.array([1, 2], np.uint8) + }) + self.train_state_multi_optimizer = make_train_state_multi_optimizer( + step=np.int32(42), + params={ + 'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)), + 'kernel': np.arange(32, dtype=np.float32).reshape((2, 16)) + }, + param_states={ + 'bias': np.int32(1), + 'kernel': None + }) + self.default_mesh_axes = make_train_state( + step=None, 
+ params={ + 'bias': PartitionSpec('model', None), + 'kernel': PartitionSpec(None, 'model') + }, + param_states={ + 'bias': None, + 'kernel': None + }) + + self.ds = tf.data.Dataset.range(1024) + + self.checkpoints_dir = self.create_tempdir() + self.tmp_dir = self.checkpoints_dir.full_path + + fake_checkpoints = self.create_tempdir() + self.fake_checkpoints = fake_checkpoints.full_path + self.steps = (0, 100, 200) + for step in self.steps: + step_dir = fake_checkpoints.mkdir(f'checkpoint_{step}') + step_dir.create_file('checkpoint') + + @mock.patch('jax._src.lib.xla_bridge.process_index') + @mock.patch('jax.devices') + @mock.patch('jax.local_devices') + def get_partitioner(self, + process_index, + host_count, + num_partitions, + local_devices_fn, + devices_fn, + process_index_fn, + params_on_devices: bool = True, + mesh_axes=None): + host_count_to_layout = { + 1: (2, 2, 1, 2), + 2: (4, 2, 1, 2), + 4: (4, 4, 1, 2), + 8: (4, 8, 1, 2), + 16: (8, 8, 1, 2), + 32: (8, 16, 1, 2) + } + devices = test_utils.make_devices(*host_count_to_layout[host_count]) + devices_fn.return_value = devices + local_devices = [d for d in devices if d.process_index == 0] + local_devices_fn.return_value = local_devices + process_index_fn.return_value = process_index + num_partitions_to_mps = { + 1: (1, 1, 1, 1), + 2: (1, 1, 1, 2), + 4: (2, 1, 1, 2), + 16: (4, 2, 1, 2) + } + mesh = partitioning.get_mesh( + model_parallel_submesh=num_partitions_to_mps[num_partitions], + input_devices=devices, + input_local_devices=local_devices) + mesh_axes = mesh_axes or self.default_mesh_axes + local_chunker = partitioning.LocalChunker(mesh) + + class TestPartitioner(partitioning.BasePartitioner): + + def __init__(self): + self.move_params_to_devices_calls = 0 + super().__init__( + num_partitions, None, params_on_devices=params_on_devices) + + @property + def _local_chunker(self): + return local_chunker + + @property + def _mesh(self): + return mesh + + def partition(self, + fn, + in_axis_resources, + out_axis_resources, + static_argnums=(), + donate_argnums=()): + raise NotImplementedError + + def compile(self, partitioned_fn, *args): + raise NotImplementedError + + def move_params_to_devices(self, train_state, train_state_axes): + assert params_on_devices + return train_state + + def get_mesh_axes(self, train_state): + return mesh_axes + + return TestPartitioner() + + # pylint:disable=no-value-for-parameter + @mock.patch( + 'jax.experimental.multihost_utils.sync_global_devices', return_value=None) + @mock.patch('time.time', return_value=0) + @mock.patch('jax.host_count') + @mock.patch('jax.process_index') + def call_host_checkpointer(self, + process_index, + host_count, + partitioner, + fn, + save_dtype, + ds_iter, + mock_process_index, + mock_host_count, + unused_mock_host_time, + unused_mock_sync_devices, + restore_dtype=np.float32): + mock_process_index.return_value = process_index + mock_host_count.return_value = host_count + + checkpointer = checkpoints.Checkpointer( + self.train_state, + partitioner, + self.tmp_dir, + ds_iter, + save_dtype=save_dtype, + restore_dtype=restore_dtype) + return fn(checkpointer) + + # pylint:disable=no-value-for-parameter + @mock.patch( + 'jax.experimental.multihost_utils.sync_global_devices', return_value=None) + @mock.patch('time.time', return_value=0) + @mock.patch('jax.host_count') + @mock.patch('jax.process_index') + def call_host_multioptimizer_checkpointer(self, process_index, host_count, + partitioner, fn, save_dtype, + ds_iter, mock_process_index, + mock_host_count, + unused_mock_host_time, 
+ unused_mock_sync_devices): + mock_process_index.return_value = process_index + mock_host_count.return_value = host_count + + checkpointer = checkpoints.Checkpointer( + self.train_state_multi_optimizer, + partitioner, + self.tmp_dir, + ds_iter, + save_dtype=save_dtype) + return fn(checkpointer) + + def test_get_parameter_infos(self): + train_state = make_train_state( + params={ + 'bias': np.ones((8192, 8192), np.float32), + 'kernel': np.ones((2, 16), np.float32) + }, + param_states={ + 'bias': np.int32(1), + 'kernel': np.array([1, 2]) + }, + step=np.int32(42)) + # host 3 of a 4x4 with mesh 'model' dim == 16 + partitioner = self.get_partitioner(3, 4, 16) + checkpointer = checkpoints.Checkpointer(train_state, partitioner, + self.tmp_dir) + + expected_parameter_infos = { + 'state': { + 'step': + checkpoints._ParameterInfo( + name='state/step', shape=(), ts_spec=None, local_chunk_info=None, axes=None), + 'param_states': { + 'bias': + checkpoints._ParameterInfo( + name='state/param_states/bias', + shape=(), + ts_spec=None, + local_chunk_info=None, axes=None), + 'kernel': + checkpoints._ParameterInfo( + name='state/param_states/kernel', + shape=(2,), + ts_spec=None, + local_chunk_info=None, axes=None) + } + }, + 'target': { + 'bias': + checkpoints._ParameterInfo( + name='target/bias', + shape=(8192, 8192), + ts_spec=ts.Spec({ + 'driver': 'zarr', + 'dtype': 'float32', + 'kvstore': { # pylint:disable=duplicate-key + 'driver': 'file', + 'path': 'target.bias', + }, + 'metadata': { + 'chunks': [4096, 4096], + 'compressor': { + 'id': 'gzip' + }, + 'shape': [8192, 8192], + }, + }), + local_chunk_info=partitioning.LocalChunkInfo( + slice=(slice(4096, 8192, None), slice(None, None, + None)), + replica_id=1), axes=PartitionSpec('model', None)), + 'kernel': + checkpoints._ParameterInfo( + name='target/kernel', + shape=(2, 16), + ts_spec=ts.Spec({ + 'driver': 'zarr', + 'dtype': 'float32', + 'kvstore': { # pylint:disable=duplicate-key + 'driver': 'file', + 'path': 'target.kernel', + }, + 'metadata': { + 'chunks': [2, 8], + 'compressor': { + 'id': 'gzip' + }, + 'shape': [2, 16], + }, + }), + local_chunk_info=partitioning.LocalChunkInfo( + slice=(slice(None, None, None), slice(8, 16, None)), + replica_id=1), axes=PartitionSpec(None, 'model')) + } + } # pyformat: disable + jax.tree_multimap(self.assertEqual, checkpointer._get_parameter_infos(), + expected_parameter_infos) + + def test_get_multioptimizer_parameter_infos(self): + train_state = make_train_state( + step=np.int32(42), + params={ + 'bias': np.ones((8192, 8192), jnp.bfloat16), + 'kernel': np.ones((2, 16), np.float32) + }, + param_states={ + 'bias': np.int32(1), + # The parameter state for Kernel is `None` as if we have a + # multioptimizer that is not updating this parameter. 
+ 'kernel': None + }) + # host 3 of a 4x4 with mesh 'model' dim == 16 + partitioner = self.get_partitioner(3, 4, 16) + checkpointer = checkpoints.Checkpointer(train_state, partitioner, + self.tmp_dir) + kernel_state_info = ( + checkpointer._get_parameter_infos()['state']['param_states']['kernel']) + self.assertIsNone(kernel_state_info) + + def test_all_steps(self): + partitioner = self.get_partitioner(0, 1, 1) + checkpointer = self.call_host_checkpointer(0, 1, partitioner, lambda c: c, + np.float32, None) + + self.assertIsNone(checkpointer.latest_step()) + for step in ['0', '42', '10', '999.tmp-0', '100']: + d = os.path.join(checkpointer.checkpoints_dir, f'checkpoint_{step}') + gfile.makedirs(d) + ckpt = os.path.join(d, 'checkpoint') + with gfile.GFile(ckpt, 'w') as f: + f.write('') + self.assertSequenceEqual( + checkpoints.all_steps(checkpointer.checkpoints_dir + '/'), + [0, 10, 42, 100]) + + def test_all_latest_step(self): + partitioner = self.get_partitioner(0, 1, 1) + checkpointer = self.call_host_checkpointer(0, 1, partitioner, lambda c: c, + np.float32, None) + + self.assertIsNone(checkpointer.latest_step()) + + for step in ['0', '42', '10', '999.tmp-0', '100']: + d = os.path.join(checkpointer.checkpoints_dir, f'checkpoint_{step}') + gfile.makedirs(d) + ckpt = os.path.join(d, 'checkpoint') + with gfile.GFile(ckpt, 'w') as f: + f.write('') + + self.assertSequenceEqual(checkpointer.all_steps(), [0, 10, 42, 100]) + self.assertEqual(checkpointer.latest_step(), 100) + + # Remove checkpoint file for step 100 (but leave directory). + gfile.remove(ckpt) + self.assertSequenceEqual(checkpointer.all_steps(), [0, 10, 42]) + self.assertEqual(checkpointer.latest_step(), 42) + + def test_all_latest_step_public(self): + self.assertIsNone(checkpoints.latest_step(self.tmp_dir)) + + for step in ['0', '42', '10', '999.tmp-0', '100']: + d = os.path.join(self.tmp_dir, f'checkpoint_{step}') + gfile.makedirs(d) + ckpt = os.path.join(d, 'checkpoint') + with gfile.GFile(ckpt, 'w') as f: + f.write('') + + self.assertSequenceEqual( + checkpoints.all_steps(self.tmp_dir), [0, 10, 42, 100]) + self.assertEqual(checkpoints.latest_step(self.tmp_dir), 100) + + # Remove checkpoint file for step 100 (but leave directory). 
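+    # all_steps()/latest_step() should only count step directories that still
+    # contain a 'checkpoint' file, so step 100 should drop out below.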
+ gfile.remove(ckpt) + self.assertSequenceEqual(checkpoints.all_steps(self.tmp_dir), [0, 10, 42]) + self.assertEqual(checkpoints.latest_step(self.tmp_dir), 42) + + def validate_restore(self, + host_count, + num_partitions, + step=42, + checkpoint_dataset=False, + expected_restore_dtype=np.float32, + lazy_parameters=False, + disable_partitioning=False): + params = self.train_state.params + param_states = self.train_state.param_states + + for i in range(host_count): + partitioner = self.get_partitioner( + i, + host_count, + num_partitions, + params_on_devices=not lazy_parameters, + mesh_axes=jax.tree_map(lambda x: None, self.default_mesh_axes) + if disable_partitioning else None) + ds_shard_id = partitioner.get_data_layout().shard_id + + bias_slice = partitioner.get_local_chunk_info(params['bias'].shape, + ('model', None)).slice + kernel_slice = partitioner.get_local_chunk_info(params['kernel'].shape, + (None, 'model')).slice + + ds_iter = iter(self.ds) + + actual_train_state = self.call_host_checkpointer( + i, + host_count, + partitioner, + lambda c: c.restore( # pylint: disable=g-long-lambda + step=step, + lazy_parameters=lazy_parameters), + np.float32, + ds_iter if checkpoint_dataset else None, + restore_dtype=expected_restore_dtype) + if lazy_parameters: + actual_train_state = jax.tree_map(lambda x: x.get(), actual_train_state) + self.assertEqual(actual_train_state._optimizer.optimizer_def, + self.train_state._optimizer.optimizer_def) + + self.assertEqual(actual_train_state.step, step) + self.assertEqual(actual_train_state.step.dtype, np.int32) + self.assertEqual(actual_train_state._optimizer.state.step.dtype, np.int32) + jax.tree_multimap(np.testing.assert_array_equal, + actual_train_state.param_states, param_states) + self.assertEqual(actual_train_state.param_states['kernel'].dtype, + np.uint8) + self.assertSameElements(actual_train_state.params, ('bias', 'kernel')) + self.assertTrue( + all( + jax.tree_leaves( + jax.tree_map(lambda x: x.dtype == expected_restore_dtype, + actual_train_state.params)))) + np.testing.assert_equal(actual_train_state.params['bias'], + params['bias'][bias_slice]) + np.testing.assert_equal(actual_train_state.params['kernel'], + params['kernel'][kernel_slice]) + if checkpoint_dataset: + # The next value from the restored iterator should equal the + # replica set id. 
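+        # (validate_save advances the saved iterator by ds_shard_id elements
+        # before checkpointing, which is what makes this check meaningful.)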
+ self.assertEqual(next(ds_iter).numpy(), ds_shard_id) + + def validate_multioptimizer_restore(self, + host_count, + num_partitions, + step=42, + checkpoint_dataset=False, + expected_restore_dtype=np.float32): + params = self.train_state_multi_optimizer.params + param_states = self.train_state_multi_optimizer.param_states + + for i in range(host_count): + partitioner = self.get_partitioner(i, host_count, num_partitions) + ds_shard_id = partitioner.get_data_layout().shard_id + + bias_slice = partitioner.get_local_chunk_info(params['bias'].shape, + ('model', None)).slice + kernel_slice = partitioner.get_local_chunk_info(params['kernel'].shape, + (None, 'model')).slice + + ds_iter = iter(self.ds) + + actual_train_state = self.call_host_multioptimizer_checkpointer( + i, host_count, partitioner, lambda c: c.restore(step=step), + np.float32, ds_iter if checkpoint_dataset else None) + actual_optimizer = actual_train_state._optimizer # pylint: disable=protected-access + actual_step = actual_train_state.step + actual_params = actual_train_state.params + actual_param_states = actual_train_state.param_states + self.assertEqual( + actual_optimizer.optimizer_def, + self.train_state_multi_optimizer._optimizer.optimizer_def) + self.assertEqual(actual_optimizer.state.step.dtype, np.int32) + jax.tree_map(lambda x: self.assertEqual(x.dtype, expected_restore_dtype), + actual_optimizer.target) + self.assertEqual(actual_step, step) + self.assertEqual(actual_step.dtype, np.int32) + jax.tree_multimap(np.testing.assert_array_equal, actual_param_states, + param_states) + self.assertSameElements(actual_params, ('bias', 'kernel')) + self.assertTrue( + all( + jax.tree_leaves( + jax.tree_map(lambda x: x.dtype == expected_restore_dtype, + actual_params)))) + np.testing.assert_equal(actual_params['bias'], params['bias'][bias_slice]) + np.testing.assert_equal(actual_params['kernel'], + params['kernel'][kernel_slice]) + if checkpoint_dataset: + # The next value from the restored iterator should equal the + # replica set id. + self.assertEqual(next(ds_iter).numpy(), ds_shard_id) + + def validate_save(self, + host_count, + num_partitions, + step=42, + save_dtype=np.float32, + checkpoint_dataset=False, + multi_optimizer=False, + disable_partitioning=False): + if multi_optimizer: + params = self.train_state_multi_optimizer.params + param_states = self.train_state_multi_optimizer.param_states + optimizer_def = self.train_state_multi_optimizer._optimizer.optimizer_def + else: + params = self.train_state.params + param_states = self.train_state.param_states + optimizer_def = self.train_state._optimizer.optimizer_def + # Update these on each save. + step = np.int32(step) + expected_bias = np.zeros((4, 1), save_dtype) + expected_kernel = np.zeros((2, 16), save_dtype) + + bias_tspec = { + 'driver': 'zarr', + 'kvstore': { + 'driver': 'file', + 'path': f'{self.tmp_dir}/checkpoint_{step}.tmp-0/target.bias', + } + } + kernel_tspec = { + 'driver': 'zarr', + 'kvstore': { + 'driver': 'file', + 'path': f'{self.tmp_dir}/checkpoint_{step}.tmp-0/target.kernel', + } + } + + # Test save. + # Each host sets its partition to its host number + 1. + # Go in reverse since host 0 renames the directory. 
+ for i in reversed(range(host_count)): + partitioner = self.get_partitioner( + i, + host_count, + num_partitions, + mesh_axes=jax.tree_map(lambda x: None, self.default_mesh_axes) + if disable_partitioning else None) + data_layout = partitioner.get_data_layout() + num_ds_shards = data_layout.num_shards + ds_shard_id = data_layout.shard_id + chunk_id_for_shard = partitioner.get_local_chunk_info( + jnp.ones((num_ds_shards,)), ['data']).replica_id + + bias_chunk = partitioner.get_local_chunk_info(params['bias'].shape, + ('model', None)) + kernel_chunk = partitioner.get_local_chunk_info(params['kernel'].shape, + (None, 'model')) + + ds_iter = iter(self.ds) + + # pylint:disable=cell-var-from-loop + def _save_ckpt(checkpointer): + # Set the checkpoint so that the next value on restore will be the + # replica set id. + for _ in range(ds_shard_id): + next(ds_iter) + + train_state = make_train_state( + step=step, + params={ + 'bias': params['bias'][bias_chunk.slice], + 'kernel': params['kernel'][kernel_chunk.slice] + }, + param_states=param_states, + flax_optimizer_def=optimizer_def) + checkpointer.save(train_state) + + # pylint:enable=cell-var-from-loop + + self.call_host_checkpointer(i, host_count, partitioner, _save_ckpt, + save_dtype, + ds_iter if checkpoint_dataset else None) + + if disable_partitioning: + continue + + # Read the current TensorStore. + if i == 0: + # Host 0 moves the files. + bias_tspec['kvstore']['path'] = ( + bias_tspec['kvstore']['path'].replace('.tmp-0', '')) + kernel_tspec['kvstore']['path'] = ( + kernel_tspec['kvstore']['path'].replace('.tmp-0', '')) + + if checkpoint_dataset: + ckpt_dir = f'{self.tmp_dir}/checkpoint_{step}' + if i != 0: + ckpt_dir += '.tmp-0' + ds_ckpt_glob = gfile.glob(ckpt_dir + '/train_ds-' + + f'{ds_shard_id:03}-of-{num_ds_shards:03}*') + if chunk_id_for_shard == 0: + self.assertLen(ds_ckpt_glob, 2) + else: + self.assertEmpty(ds_ckpt_glob) + + # only replica_id=0 is saved for each array chunk + if bias_chunk.replica_id == 0: + current_bias = ts.open(bias_tspec).result().read().result().view( + save_dtype) + expected_bias[bias_chunk.slice] = (params['bias'][bias_chunk.slice]) + np.testing.assert_equal(current_bias, expected_bias) + + if kernel_chunk.replica_id == 0: + current_kernel = ts.open(kernel_tspec).result().read().result().view( + save_dtype) + expected_kernel[kernel_chunk.slice] = ( + params['kernel'][kernel_chunk.slice]) + np.testing.assert_equal(current_kernel, expected_kernel) + + with gfile.GFile(f'{self.tmp_dir}/checkpoint_{step}/checkpoint', 'rb') as f: + ckpt_contents = serialization.msgpack_restore(f.read()) + self.assertEqual(ckpt_contents['version'], checkpoints.VERSION) + jax.tree_multimap(np.testing.assert_allclose, + ckpt_contents['optimizer']['state']['param_states'], + param_states) + self.assertEqual(ckpt_contents['optimizer']['state']['step'].dtype, + np.int32) + if disable_partitioning: + # Parameters should also be in the msgpack checkpoint file. 
+ jax.tree_multimap( + np.testing.assert_allclose, ckpt_contents['optimizer']['target'], + jax.tree_map(lambda arr: arr.astype(save_dtype), params)) + + # Jax tree maps ignore Nones so actually check this value is None + if multi_optimizer: + self.assertIsNone( + ckpt_contents['optimizer']['state']['param_states']['kernel']) + + # (host_count, num_partitions) + TOPOLOGIES = [ + (1, 1), # 1 host, 1 partition + (1, 2), # 1 host, 2 partitions + (2, 1), # 2 hosts, 1 partition + (2, 2), # 2 hosts, 2 partitions + (4, 4), # 4 hosts, 4 partitions + (4, 1), # 4 hosts, 1 partition + (4, 2), # 4 hosts, 2 partitions + (8, 2), # 8 hosts, 2 partitions + ] + + DTYPES = [ + jnp.int32, jnp.float32, jnp.bfloat16, jnp.uint32, jnp.int64, jnp.float64 + ] + + @parameterized.parameters(itertools.product(TOPOLOGIES, TOPOLOGIES)) + def test_save_restore(self, save_topology, restore_topology): + self.validate_save(*save_topology) + self.validate_restore(*restore_topology) + + @parameterized.parameters(itertools.product(TOPOLOGIES, TOPOLOGIES)) + def test_save_restore_lazy(self, save_topology, restore_topology): + self.validate_save(*save_topology) + self.validate_restore(*restore_topology, lazy_parameters=True) + + @parameterized.parameters(itertools.product(TOPOLOGIES, TOPOLOGIES)) + def test_save_multioptimizer_restore(self, save_topology, restore_topology): + self.validate_save(*save_topology) + self.validate_multioptimizer_restore(*restore_topology) + + @parameterized.parameters(itertools.product(TOPOLOGIES, TOPOLOGIES)) + def test_multioptimizer_save_multioptimizer_restore(self, save_topology, + restore_topology): + self.validate_save(*save_topology, multi_optimizer=True) + self.validate_multioptimizer_restore(*restore_topology) + + def test_load_t5x_checkpoint(self): + self.validate_save(1, 1) + ckpt = checkpoints.load_t5x_checkpoint(self.tmp_dir) + jax.tree_multimap(np.testing.assert_array_equal, + self.train_state.state_dict(), ckpt) + + def test_load_t5x_checkpoint_of_multioptimizer(self): + self.validate_save(1, 1, multi_optimizer=True) + ckpt = checkpoints.load_t5x_checkpoint(self.tmp_dir) + jax.tree_multimap(np.testing.assert_array_equal, + self.train_state_multi_optimizer.state_dict(), ckpt) + # Jax tree maps ignore Nones so actually check this value is None + self.assertIsNone(ckpt['state']['param_states']['kernel']) + + def test_load_t5x_checkpoint_lazy(self): + self.validate_save(1, 1) + ckpt = checkpoints.load_t5x_checkpoint(self.tmp_dir) + lazy_ckpt = checkpoints.load_t5x_checkpoint( + self.tmp_dir, lazy_parameters=True) + lazy_loaded_ckpt = jax.tree_map(lambda x: x.get(), lazy_ckpt) + jax.tree_multimap(np.testing.assert_array_equal, ckpt, lazy_loaded_ckpt) + + def test_load_t5x_checkpoint_of_multioptimizer_lazy(self): + self.validate_save(1, 1, multi_optimizer=True) + ckpt = checkpoints.load_t5x_checkpoint(self.tmp_dir) + lazy_ckpt = checkpoints.load_t5x_checkpoint( + self.tmp_dir, lazy_parameters=True) + lazy_loaded_ckpt = jax.tree_map(lambda x: x.get(), lazy_ckpt) + jax.tree_multimap(np.testing.assert_array_equal, ckpt, lazy_loaded_ckpt) + # Jax tree maps ignore Nones so actually check this value is None + self.assertIsNone(lazy_loaded_ckpt['state']['param_states']['kernel']) + + @parameterized.parameters(TOPOLOGIES) + def test_save_restore_dataset(self, *topology): + # Note that we must use the same number of replica sets on save/restore. 
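+    # (Each replica set writes its own train_ds-XXX-of-YYY file, so the data
+    # shard layout has to line up between save and restore.)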
+ self.validate_save(*topology, checkpoint_dataset=True) + self.validate_restore(*topology, checkpoint_dataset=True) + + @parameterized.parameters(itertools.product(DTYPES, DTYPES)) + def test_save_as_type(self, save_dtype, restore_dtype): + self.validate_save(1, 1, save_dtype=save_dtype) + self.validate_restore(1, 1, expected_restore_dtype=restore_dtype) + + @parameterized.parameters(TOPOLOGIES) + def test_reload_wrong_shape(self, *restore_topology): + self.validate_save(1, 1) + self.train_state = make_train_state( + step=np.int32(42), + params={ + 'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)), + 'kernel': np.arange(32, dtype=np.float32).reshape((4, 8)) + }, + param_states={ + 'bias': np.int32(1), + 'kernel': np.array([1, 2]) + }) + with self.assertRaisesWithLiteralMatch( + ValueError, + 'Shape of `target/kernel` in checkpoint (2, 16) does not match ' + 'expected (4, 8).'): + self.validate_restore(*restore_topology) + + @parameterized.parameters(TOPOLOGIES) + def test_save_partitioned_restore_non_partitioned(self, *restore_topology): + # Save with default partitioning. + self.validate_save(2, 2) + # Restore without partitioning. + self.validate_restore(*restore_topology, disable_partitioning=True) + + @parameterized.parameters(TOPOLOGIES) + def test_save_non_partitioned_restore_partitioned(self, *restore_topology): + # Save without partitioning. + self.validate_save(2, 1, disable_partitioning=True) + # Restore with partitioning. + self.validate_restore(*restore_topology) + + @parameterized.parameters(TOPOLOGIES) + def test_save_non_partitioned_restore_non_partitioned(self, + *restore_topology): + # Save without partitioning. + self.validate_save(2, 1, disable_partitioning=True) + # Restore with partitioning. + self.validate_restore(*restore_topology, disable_partitioning=True) + + @mock.patch('time.time', return_value=0) + def test_keep(self, unused_mock_time): + no_partitions_partitioner = self.get_partitioner(0, 1, 1) + train_state = self.train_state + checkpointer = checkpoints.Checkpointer( + train_state, no_partitions_partitioner, self.tmp_dir, keep=2) + + checkpointer.save(update_train_state_step(train_state, 42)) + self.assertSequenceEqual(checkpointer.all_steps(), [42]) + + checkpointer.save(update_train_state_step(train_state, 43)) + self.assertSequenceEqual(checkpointer.all_steps(), [42, 43]) + + checkpointer.save(update_train_state_step(train_state, 44)) + self.assertSequenceEqual(checkpointer.all_steps(), [43, 44]) + + checkpointer.keep = 1 + checkpointer.save(update_train_state_step(train_state, 45)) + self.assertSequenceEqual(checkpointer.all_steps(), [45]) + + checkpointer.keep = 3 + checkpointer.save(update_train_state_step(train_state, 46)) + self.assertSequenceEqual(checkpointer.all_steps(), [45, 46]) + + @mock.patch('time.time', return_value=0) + def test_keep_pinned(self, unused_mock_time): + no_partitions_partitioner = self.get_partitioner(0, 1, 1) + train_state = self.train_state + checkpointer = checkpoints.Checkpointer( + train_state, no_partitions_partitioner, self.tmp_dir, keep=1) + + checkpointer.save(update_train_state_step(train_state, 42)) + self.assertSequenceEqual(checkpointer.all_steps(), [42]) + + # Mark the checkpoint as pinned by creating the ALWAYS KEEP file. + ckpt_dir = self.checkpoints_dir.mkdir(f'checkpoint_{42}') + ckpt_dir.create_file('PINNED') + + checkpointer.save(update_train_state_step(train_state, 43)) + + # Assert both the pinned and the most recent checkpoints are saved. 
+ self.assertSequenceEqual(checkpointer.all_steps(), [42, 43]) + + checkpointer.save(update_train_state_step(train_state, 44)) + + # Assert the non-pinned checkpoint gets deleted, but the pinned and the most + # recent one are still saved. + self.assertSequenceEqual(checkpointer.all_steps(), [42, 44]) + + @mock.patch('time.time', return_value=0) + def test_keep_dataset_checkpoints(self, unused_mock_time): + no_partitions_partitioner = self.get_partitioner(0, 1, 1) + train_state = self.train_state + dataset_iterator = iter(tf.data.Dataset.range(10)) + checkpointer = checkpoints.Checkpointer( + train_state, + no_partitions_partitioner, + self.tmp_dir, + dataset_iterator=dataset_iterator, + keep=2, + keep_dataset_checkpoints=1) + + checkpointer.save(update_train_state_step(train_state, 42)) + self.assertSequenceEqual(checkpointer.all_steps(), [42]) + self.assertSequenceEqual(checkpointer.all_dataset_checkpoint_steps(), [42]) + + checkpointer.save(update_train_state_step(train_state, 43)) + self.assertSequenceEqual(checkpointer.all_steps(), [42, 43]) + self.assertSequenceEqual(checkpointer.all_dataset_checkpoint_steps(), [43]) + + checkpointer.save(update_train_state_step(train_state, 44)) + self.assertSequenceEqual(checkpointer.all_steps(), [43, 44]) + self.assertSequenceEqual(checkpointer.all_dataset_checkpoint_steps(), [44]) + + checkpointer.keep = 1 + checkpointer.save(update_train_state_step(train_state, 45)) + self.assertSequenceEqual(checkpointer.all_steps(), [45]) + self.assertSequenceEqual(checkpointer.all_dataset_checkpoint_steps(), [45]) + + checkpointer.keep = 3 + checkpointer.save(update_train_state_step(train_state, 46)) + self.assertSequenceEqual(checkpointer.all_steps(), [45, 46]) + self.assertSequenceEqual(checkpointer.all_dataset_checkpoint_steps(), [46]) + + @mock.patch('time.time', return_value=0) + def test_keep_dataset_checkpoints_pinned(self, unused_mock_time): + no_partitions_partitioner = self.get_partitioner(0, 1, 1) + train_state = self.train_state + dataset_iterator = iter(tf.data.Dataset.range(10)) + checkpointer = checkpoints.Checkpointer( + train_state, + no_partitions_partitioner, + self.tmp_dir, + dataset_iterator=dataset_iterator, + keep=1, + keep_dataset_checkpoints=1) + + checkpointer.save(update_train_state_step(train_state, 42)) + self.assertSequenceEqual(checkpointer.all_steps(), [42]) + + # Mark the checkpoint as pinned by creating the ALWAYS KEEP file. + ckpt_dir = self.checkpoints_dir.mkdir(f'checkpoint_{42}') + ckpt_dir.create_file('PINNED') + + checkpointer.save(update_train_state_step(train_state, 43)) + + # Assert both the pinned and the most recent checkpoints are saved. + self.assertSequenceEqual(checkpointer.all_steps(), [42, 43]) + self.assertSequenceEqual(checkpointer.all_dataset_checkpoint_steps(), + [42, 43]) + + checkpointer.save(update_train_state_step(train_state, 44)) + + # Assert the non-pinned checkpoint gets deleted, but the pinned and the most + # recent one are still saved. 
+ self.assertSequenceEqual(checkpointer.all_steps(), [42, 44]) + self.assertSequenceEqual(checkpointer.all_dataset_checkpoint_steps(), + [42, 44]) + + @mock.patch('time.time', return_value=0) + def test_keep_with_save_best_checkpointer(self, unused_mock_time): + no_partitions_partitioner = self.get_partitioner(0, 1, 1) + train_state = self.train_state + + checkpointer = checkpoints.SaveBestCheckpointer( + train_state, + no_partitions_partitioner, + self.tmp_dir, + keep=2, + metric_name_to_monitor='train/accuracy', + metric_mode='max', + keep_checkpoints_without_metrics=False) + + # Test that without a valid set of metrics deletion falls back to oldest + # step (since keep_checkpoints_without_metrics is set to False). + checkpointer.save(update_train_state_step(train_state, 41)) + self.assertSequenceEqual(checkpointer.all_steps(), [41]) + checkpointer.save(update_train_state_step(train_state, 42)) + self.assertSequenceEqual(checkpointer.all_steps(), [41, 42]) + checkpointer.save(update_train_state_step(train_state, 43)) + self.assertSequenceEqual(checkpointer.all_steps(), [41, 42, 43]) + checkpointer.save(update_train_state_step(train_state, 44)) + self.assertSequenceEqual(checkpointer.all_steps(), [42, 43, 44]) + + # Now create some metrics for steps 42, 43 and 44. + summary_writer = tensorboard.SummaryWriter( + os.path.join(self.tmp_dir, 'train')) + summary_writer.scalar('accuracy', 0.9, 42) + summary_writer.scalar('accuracy', 0.8, 43) + summary_writer.scalar('accuracy', 0.7, 44) + + # Verify that both the newest (without a metrics) and best accuracy + # checkpoints are kept. + checkpointer.save(update_train_state_step(train_state, 45)) + self.assertSequenceEqual(checkpointer.all_steps(), [42, 43, 45]) + + # Change mode to `min` and check that the checkpoints with highest accuracy + # are removed. + checkpointer._metric_mode = 'min' + + # Add metrics to newly created checkpoint as well as a new checkpoint. + summary_writer.scalar('accuracy', 0.95, 45) + checkpointer.save(update_train_state_step(train_state, 46)) + summary_writer.scalar('accuracy', 0.99, 46) + checkpointer.save(update_train_state_step(train_state, 47)) + self.assertSequenceEqual(checkpointer.all_steps(), [42, 43, 47]) + + @mock.patch('time.time', return_value=0) + def test_keep_pinned_save_best_checkpointer(self, unused_mock_time): + no_partitions_partitioner = self.get_partitioner(0, 1, 1) + train_state = self.train_state + + checkpointer = checkpoints.SaveBestCheckpointer( + train_state, + no_partitions_partitioner, + self.tmp_dir, + keep=2, + metric_name_to_monitor='train/accuracy', + metric_mode='max', + keep_checkpoints_without_metrics=False) + + summary_writer = tensorboard.SummaryWriter( + os.path.join(self.tmp_dir, 'train')) + + checkpointer.save(update_train_state_step(train_state, 42)) + summary_writer.scalar('accuracy', 0.9, 42) + checkpointer.save(update_train_state_step(train_state, 43)) + summary_writer.scalar('accuracy', 0.7, 43) + checkpointer.save(update_train_state_step(train_state, 44)) + summary_writer.scalar('accuracy', 0.8, 44) + self.assertSequenceEqual(checkpointer.all_steps(), [42, 43, 44]) + + # Mark checkpoint 43 as always keep. + ckpt_dir = self.checkpoints_dir.mkdir(f'checkpoint_{43}') + always_keep_ckpt_43 = ckpt_dir.create_file('PINNED') + + # Verify that the pinned checkpoint 43 is always saved even though it does + # not have the best metrics, and keep = 2. 
+    checkpointer.save(update_train_state_step(train_state, 45))
+    self.assertSequenceEqual(checkpointer.all_steps(), [42, 43, 44, 45])
+    checkpointer.save(update_train_state_step(train_state, 46))
+    summary_writer.scalar('accuracy', 0.6, 46)
+
+    # Remove the ALWAYS KEEP file for checkpoint 43.
+    gfile.rmtree(always_keep_ckpt_43.full_path)
+
+    # Checkpoint 43 should get deleted in the next update since it is not
+    # pinned and does not have the best metrics.
+    checkpointer.save(update_train_state_step(train_state, 47))
+    self.assertSequenceEqual(checkpointer.all_steps(), [42, 44, 47])
+
+  @mock.patch('time.time', return_value=0)
+  def test_keep_pinned_save_best_checkpointer_missing_metrics(
+      self, unused_mock_time):
+    """Test for `keep_checkpoints_without_metrics` behavior."""
+    no_partitions_partitioner = self.get_partitioner(0, 1, 1)
+    train_state = self.train_state
+
+    # Use SaveBestCheckpointer with default keep_checkpoints_without_metrics.
+    checkpointer = checkpoints.SaveBestCheckpointer(
+        train_state,
+        no_partitions_partitioner,
+        self.tmp_dir,
+        keep=1,
+        metric_name_to_monitor='train/accuracy',
+        metric_mode='max')
+
+    # Pre-create metrics for only some of the steps.
+    summary_writer = tensorboard.SummaryWriter(
+        os.path.join(self.tmp_dir, 'train'))
+    summary_writer.scalar('accuracy', 0.5, 43)
+    summary_writer.scalar('accuracy', 0.4, 44)
+    summary_writer.scalar('accuracy', 0.8, 45)
+    summary_writer.scalar('accuracy', 0.3, 46)
+
+    # Verify that we keep checkpoints for 41 and 42 even without metrics.
+    checkpointer.save(update_train_state_step(train_state, 41))
+    checkpointer.save(update_train_state_step(train_state, 42))
+    checkpointer.save(update_train_state_step(train_state, 43))
+    self.assertSequenceEqual(checkpointer.all_steps(), [41, 42, 43])
+
+    # Mark 41 and 43 checkpoints as pinned / to not be removed.
+    ckpt_dir_41 = self.checkpoints_dir.mkdir(f'checkpoint_{41}')
+    ckpt_dir_41.create_file('PINNED')
+    ckpt_dir_43 = self.checkpoints_dir.mkdir(f'checkpoint_{43}')
+    ckpt_dir_43.create_file('PINNED')
+
+    # Checkpoints 41 and 43 should always be kept because they are pinned.
+    checkpointer.save(update_train_state_step(train_state, 44))
+    self.assertSequenceEqual(checkpointer.all_steps(), [41, 42, 43, 44])
+    # Checkpoint 44 should get deleted on the next save. 43 is saved in spite
+    # of its low accuracy because it is pinned.
+    checkpointer.save(update_train_state_step(train_state, 45))
+    self.assertSequenceEqual(checkpointer.all_steps(), [41, 42, 43, 45])
+
+  @mock.patch('time.time', return_value=0)
+  def test_save_best_checkpointer_from_restart(self, unused_mock_time):
+    """Emulate restart/preempt condition."""
+    no_partitions_partitioner = self.get_partitioner(0, 1, 1)
+    train_state = self.train_state
+
+    # First, create a checkpointer that saves all checkpoints.
+    checkpointer = checkpoints.Checkpointer(
+        train_state, no_partitions_partitioner, self.tmp_dir, keep=None)
+
+    # Create a series of checkpoints. Create many checkpoints to stress test
+    # event collection (some methods employ lossy/sampling collection).
+    for i in range(100):
+      checkpointer.save(update_train_state_step(train_state, i))
+    self.assertSequenceEqual(checkpointer.all_steps(), list(range(100)))
+
+    # Now create some metrics for all steps, with high metrics on specific
+    # steps.
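+    # (SaveBestCheckpointer crawls event files under checkpoints_dir, so the
+    # 'train' run directory written below is what 'train/accuracy' matches.)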
+ summary_writer = tensorboard.SummaryWriter( + os.path.join(self.tmp_dir, 'train')) + for i in range(100): + if i in (42, 53): + summary_writer.scalar('accuracy', i * 0.01, i) + else: + summary_writer.scalar('accuracy', i * 0.001, i) + + # Replace checkpointer with SaveBest variant. + checkpointer = checkpoints.SaveBestCheckpointer( + train_state, + no_partitions_partitioner, + self.tmp_dir, + keep=2, + metric_name_to_monitor='train/accuracy', + metric_mode='max') + + # Verify that pre-existing metrics are read and the appropriate checkpoints + # are deleted. + checkpointer.save(update_train_state_step(train_state, 101)) + self.assertSequenceEqual(checkpointer.all_steps(), [42, 53, 101]) + + def test_save_best_checkpointer_force_keep_period(self): + no_partitions_partitioner = self.get_partitioner(0, 1, 1) + train_state = self.train_state + + checkpointer = checkpoints.SaveBestCheckpointer( + train_state, + no_partitions_partitioner, + self.tmp_dir, + keep=2, + metric_name_to_monitor='train/accuracy', + metric_mode='max', + keep_checkpoints_without_metrics=False, + force_keep_period=3) + + summary_writer = tensorboard.SummaryWriter( + os.path.join(self.tmp_dir, 'train')) + + # save checkpoints 0..9 with increasing accuracy + dict_actual_steps = {} + for c in range(10): + checkpointer.save(update_train_state_step(train_state, c)) + summary_writer.scalar('accuracy', c / 100, c) + dict_actual_steps[c] = checkpointer.all_steps() + + # Check when the last step=8 is not divisible by the keep_period=3 + actual_steps_8 = dict_actual_steps[8] + expected_steps_8 = [0, 3, 5, 6, 7, 8] + self.assertSequenceEqual(actual_steps_8, expected_steps_8) + + # Check when the last step=9 is divisible by the keep_period=3 + actual_steps_9 = dict_actual_steps[9] + expected_steps_9 = [0, 3, 6, 7, 8, 9] + self.assertSequenceEqual(actual_steps_9, expected_steps_9) + + @mock.patch('time.time', return_value=0) + def test_save_best_checkpointer_missing_metrics(self, unused_mock_time): + """Test for `keep_checkpoints_without_metrics` behavior.""" + no_partitions_partitioner = self.get_partitioner(0, 1, 1) + train_state = self.train_state + + # Replace checkpointer with SaveBest variant. + checkpointer = checkpoints.SaveBestCheckpointer( + train_state, + no_partitions_partitioner, + self.tmp_dir, + keep=1, + metric_name_to_monitor='train/accuracy', + metric_mode='max') + + # Pre-create metrics for only some of the steps. + summary_writer = tensorboard.SummaryWriter( + os.path.join(self.tmp_dir, 'train')) + summary_writer.scalar('accuracy', 0.6, 43) + summary_writer.scalar('accuracy', 0.5, 44) + summary_writer.scalar('accuracy', 0.4, 45) + + # Verify that we always keep checkpoints for 41 and 42 (no metrics) and that + # number to keep applies to other checkpoints. + checkpointer.save(update_train_state_step(train_state, 41)) + self.assertSequenceEqual(checkpointer.all_steps(), [41]) + checkpointer.save(update_train_state_step(train_state, 42)) + self.assertSequenceEqual(checkpointer.all_steps(), [41, 42]) + checkpointer.save(update_train_state_step(train_state, 43)) + self.assertSequenceEqual(checkpointer.all_steps(), [41, 42, 43]) + checkpointer.save(update_train_state_step(train_state, 44)) + self.assertSequenceEqual(checkpointer.all_steps(), [41, 42, 43, 44]) + # Checkpoint 44 should get deleted on next save. 
+ checkpointer.save(update_train_state_step(train_state, 45)) + self.assertSequenceEqual(checkpointer.all_steps(), [41, 42, 43, 45]) + + # When switching keep_checkpoints_without_metrics to False, we should see + # checkpoints 41 and 42 also be deleted. + checkpointer._keep_checkpoints_without_metrics = False + checkpointer.save(update_train_state_step(train_state, 46)) + self.assertSequenceEqual(checkpointer.all_steps(), [43, 46]) + + def test_assignment_map(self): + self.validate_save(1, 1) + # Change optimizer + optimizer = optimizers.Optimizer( + optimizers.sgd(0.1), + state=optimizers.OptimizerState( + step=np.int32(42), + param_states={ + 'bias': np.int32(1), + 'kernel': np.array([1, 2], np.uint8) + }), + target={ + 'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)), + 'layer1': { + 'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)), + 'kernel': np.arange(32, dtype=np.float32).reshape((2, 16)) + }, + 'layer2': { + 'bias': np.arange(32, dtype=np.float32).reshape((2, 16)), + 'kernel': np.arange(32, dtype=np.float32).reshape((2, 16)) + } + }) + self.train_state = FlaxOptimTrainState(optimizer) + + actual_train_state = self.call_host_checkpointer( + 0, + 1, + self.get_partitioner( + 0, 1, 1, mesh_axes=jax.tree_map(lambda x: None, self.train_state)), + lambda c: c.restore( # pylint:disable=g-long-lambda + step=42, + state_transformation_fns=[ + functools.partial( + state_utils.apply_assignment_map, + assignment_map=[('target/layer2/bias', 'target/kernel'), + ('target/layer\\d/(.*)', 'target/\\1')]) + ]), + np.float32, + None) + self.assertEqual(actual_train_state.step, 42) + self.assertEqual(actual_train_state._optimizer.optimizer_def, + self.train_state._optimizer.optimizer_def) + jax.tree_multimap(np.testing.assert_array_equal, + actual_train_state.param_states, + self.train_state.param_states) + jax.tree_multimap(np.testing.assert_array_equal, actual_train_state.params, + self.train_state.params) + + def test_assignment_map_unused(self): + self.validate_save(1, 1) + with self.assertRaisesWithLiteralMatch( + ValueError, + "Unused patterns in `assignment_map`: {'target/layer\\d/(.*)'}"): + self.call_host_checkpointer( + 0, + 1, + self.get_partitioner(0, 1, 1), + lambda c: c.restore( # pylint:disable=g-long-lambda + step=42, + state_transformation_fns=[ + functools.partial( + state_utils.apply_assignment_map, + assignment_map=[('target/layer\\d/(.*)', 'target/\\1')]) + ]), + np.float32, + None) + + def test_assignment_map_noexists(self): + self.validate_save(1, 1) + with self.assertRaisesWithLiteralMatch( + ValueError, + "Parameter 'target/layer/bias' does not exist in restore checkpoint. 
" + "Must be one of: ['state/param_states/bias', " + "'state/param_states/kernel', 'state/step', 'target/bias', " + "'target/kernel']"): + self.call_host_checkpointer( + 0, + 1, + self.get_partitioner(0, 1, 1), + lambda c: c.restore( # pylint:disable=g-long-lambda + step=42, + state_transformation_fns=[ + functools.partial( + state_utils.apply_assignment_map, + assignment_map=[('target/(.*)', 'target/layer/\\1')]) + ]), + np.float32, + None) + + def test_assignment_map_partial_restore(self): + self.validate_save(1, 1) + # Change optimizer + optimizer = optimizers.Optimizer( + optimizers.sgd(0.1), + state=optimizers.OptimizerState( + step=np.int32(42), + param_states={ + 'bias': np.int32(1), + 'kernel': np.array([1, 2], np.uint8) + }), + target={ + 'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)), + 'layer1': { + 'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)), + 'kernel': np.arange(32, dtype=np.float32).reshape((2, 16)) + }, + 'layer2': { + 'bias': np.arange(32, dtype=np.float32).reshape((2, 16)), + 'kernel': np.arange(32, dtype=np.float32).reshape((2, 16)) + } + }) + self.train_state = FlaxOptimTrainState(optimizer) + + actual_train_state = self.call_host_checkpointer( + 0, + 1, + self.get_partitioner( + 0, 1, 1, mesh_axes=jax.tree_map(lambda x: None, self.train_state)), + lambda c: c.restore( # pylint:disable=g-long-lambda + step=42, + state_transformation_fns=[ + functools.partial( + state_utils.apply_assignment_map, + assignment_map=[ + # Restore only the target kernels. + (r'target/layer(\d+)/kernel', r'target/kernel'), + (r'target.*bias', None), + (r'state.*', None)]) + ], + fallback_state={ + # Initialize biases and optimizer state "from scratch" + 'target': { + 'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)), + 'layer1': { + 'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)), + }, + 'layer2': { + 'bias': np.arange(32, dtype=np.float32).reshape((2, 16)), + } + }, + 'state': { + 'step': 1337, # Note: original optimizer is step=42 + 'param_states': { + 'bias': 1, + 'kernel': np.array([1, 2], np.uint8) + } + } + }), + np.float32, + None) + self.assertEqual(actual_train_state._optimizer.optimizer_def, + self.train_state._optimizer.optimizer_def) + self.assertEqual(actual_train_state.step, 1337) # note: from-scratch + jax.tree_multimap(np.testing.assert_array_equal, + actual_train_state.param_states, + self.train_state.param_states) + jax.tree_multimap(np.testing.assert_array_equal, actual_train_state.params, + self.train_state.params) + + def verify_restore_checkpoint_from_path( + self, + path, + model, + decoder_only=False, + partitioner_class=partitioning.PjitPartitioner): + partitioner = partitioner_class(num_partitions=1) + + input_features = {'decoder_input_tokens': tf.zeros([2, 8])} + if not decoder_only: + input_features['encoder_input_tokens'] = tf.zeros([2, 8]) + train_ds = tf.data.Dataset.from_tensors(input_features) + + train_state_initializer = utils.TrainStateInitializer( + optimizer_def=model.optimizer_def, + init_fn=model.get_initial_variables, + input_shapes={k: v.shape for k, v in train_ds.element_spec.items()}, + partitioner=partitioner) + + restored = list( + train_state_initializer.from_checkpoints( + [utils.RestoreCheckpointConfig(mode='specific', path=path)])) + self.assertLen(restored, 1) + return restored[0] + + def test_checkpointer_in_threaded_env(self): + """Tests use of asyncio in checkpointer works with non-main threads.""" + executor = concurrent.futures.thread.ThreadPoolExecutor(max_workers=1) + save = 
executor.submit(self.validate_save, 1, 1) + save.result() + restore = executor.submit(self.validate_restore, 1, 1) + restore.result() + + def test_find_checkpoint(self): + # `model_dir` with no step + self.assertEqual( + checkpoints.find_checkpoint(self.fake_checkpoints), + os.path.join(self.fake_checkpoints, f'checkpoint_{self.steps[-1]}', + 'checkpoint')) + # `model_dir` with step + step = 100 + self.assertEqual( + checkpoints.find_checkpoint(self.fake_checkpoints, step), + os.path.join(self.fake_checkpoints, f'checkpoint_{step}', 'checkpoint')) + # checkpoint_dir + self.assertEqual( + checkpoints.find_checkpoint( + os.path.join(self.fake_checkpoints, f'checkpoint_{step}')), + os.path.join(self.fake_checkpoints, f'checkpoint_{step}', 'checkpoint')) + # checkpoint_dir with step + with self.assertRaises(ValueError): + _ = checkpoints.find_checkpoint( + os.path.join(self.fake_checkpoints, f'checkpoint_{step}'), 1000), + # checkpoint_file + path = os.path.join(self.fake_checkpoints, f'checkpoint_{step}', + 'checkpoint') + self.assertEqual(checkpoints.find_checkpoint(path), path) + # checkpoint_file with step + self.assertEqual(checkpoints.find_checkpoint(path, 1000), path) + # Error with step + with self.assertRaises(ValueError): + checkpoints.find_checkpoint(self.fake_checkpoints, 1000) + # Error + with self.assertRaises(ValueError): + checkpoints.find_checkpoint( + os.path.join(self.fake_checkpoints, 'checkpoint')) + + def test_restore_tf_as_t5x(self): + checkpoint_path = os.path.join(TESTDATA, 'mtf_tiny_t5') + partitioner = self.get_partitioner(0, 1, 1) + with self.assertRaisesRegex( + ValueError, + 'Attempting to restore a TensorFlow checkpoint as a native T5X ' + 'checkpoint. Use `restore_from_tf_checkpoint` instead. Path: .*'): + self.call_host_checkpointer(0, 1, partitioner, + lambda c: c.restore(path=checkpoint_path), + np.float32, None) + + def test_restore_from_invalid_path(self): + with self.assertRaisesRegex(ValueError, + r'Path is not a valid T5X checkpoint: .*'): + self.verify_restore_checkpoint_from_path(TESTDATA, + test_utils.get_t5_test_model()) + + with self.assertRaisesRegex(ValueError, + r'Path is not a valid T5X checkpoint: .*'): + self.verify_restore_checkpoint_from_path( + os.path.join(TESTDATA, 'checkpoint'), test_utils.get_t5_test_model()) + + def test_save_lazy_optimizer(self): + # Call save one to get the parameters onto disk + self.validate_save(1, 1) + # Load the parameters in a lazy way + partitioner = self.get_partitioner(0, 1, 1, params_on_devices=False) + step = 42 + train_state = self.call_host_checkpointer( + 0, + 1, + partitioner, + lambda c: c.restore( # pylint: disable=g-long-lambda + step=step, lazy_parameters=True), + np.float32, + None) + # Increment the step so we can save it + new_step = train_state.step.get() + 1 + state_dict = train_state.state_dict() + state_dict['state']['step'] = new_step + train_state = train_state.restore_state(state_dict) + + # Save the train state that is made of lazy parameters. + self.call_host_checkpointer( + 0, 1, partitioner, + lambda c: c.save(train_state=train_state, concurrent_gb=2), np.float32, + None) + + # Load what we just saved to inspect values + loaded_train_state = checkpoints.load_t5x_checkpoint( + self.tmp_dir, step=new_step) + # Make sure the parameters are the same. 
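+    # (restore(..., lazy_parameters=True) leaves LazyArray leaves in the train
+    # state, so each leaf is materialized with .get() before comparing against
+    # the values re-loaded from the checkpoint that was just saved.)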
+ train_state = jax.tree_map( + lambda x: x.get() # pylint: disable=g-long-lambda + if isinstance(x, LazyArray) else x, + train_state) + jax.tree_multimap(np.testing.assert_allclose, train_state.state_dict(), + loaded_train_state) + + def test_update_ts_from_gfile_to_gcs(self): + ckpt_contents = { + 'version': 3, + 'optimizer': { + 'target': { + 'unsharded_param': np.ones((5, 5), dtype=np.int32), + 'sharded_param': { + 'driver': 'zarr', + 'dtype': 'float32', + 'kvstore': { + 'driver': 'file', + 'path': 'target.sharded_param' + }, + 'metadata': { + 'chunks': [768, 768], + 'compressor': { + 'id': 'gzip', + 'level': 1 + }, + 'shape': [768, 768] + } + } + } + } + } + + expected = { + 'version': 3, + 'optimizer': { + 'target': { + # np.ndarray should not change + 'unsharded_param': np.ones((5, 5), dtype=np.int32), + 'sharded_param': { + 'driver': 'zarr', + 'dtype': 'float32', + 'kvstore': { + 'bucket': 't5x-dummy-bucket', + 'driver': 'gcs', + 'path': 'target.sharded_param' + }, + 'metadata': { + 'chunks': [768, 768], + 'compressor': { + 'id': 'gzip', + 'level': 1 + }, + 'shape': [768, 768] + } + } + } + } + } + actual = checkpoints._maybe_update_ts_from_file_to_gcs(ckpt_contents) + jax.tree_multimap(np.testing.assert_array_equal, actual, expected) + + def test_update_ts_from_gcs_to_file(self): + ckpt_contents = { + 'version': 3, + 'optimizer': { + 'target': { + # np.ndarray should not change + 'unsharded_param': np.ones((5, 5), dtype=np.int32), + 'sharded_param': { + 'driver': 'zarr', + 'dtype': 'float32', + 'kvstore': { + 'bucket': 't5x-dummy-bucket', + 'driver': 'gcs', + 'path': 'target.sharded_param' + }, + 'metadata': { + 'chunks': [768, 768], + 'compressor': { + 'id': 'gzip', + 'level': 1 + }, + 'shape': [768, 768] + }, + } + } + } + } + + driver = 'file' + expected = { + 'version': 3, + 'optimizer': { + 'target': { + 'unsharded_param': np.ones((5, 5), dtype=np.int32), + 'sharded_param': { + 'driver': 'zarr', + 'dtype': 'float32', + 'kvstore': { + 'driver': driver, + 'path': 'target.sharded_param' + }, + 'metadata': { + 'chunks': [768, 768], + 'compressor': { + 'id': 'gzip', + 'level': 1 + }, + 'shape': [768, 768] + } + } + } + } + } + + actual = checkpoints._maybe_update_ts_from_gcs_to_file(ckpt_contents) + jax.tree_multimap(np.testing.assert_array_equal, actual, expected) + + def assert_update_ts_path_from_relative_to_absolute(self, ts_spec_dict, + expected, ckpt_dir): + """Tests that `ts_spec_dict` gets updated with `ckpt_dir` to `expected`.""" + + # Test with normalization (corresponds to tensorstore>=0.1.14) + normalized_ts_spec_dict = ts.Spec(ts_spec_dict).to_json() + checkpoints._update_ts_path_from_relative_to_absolute( + ckpt_dir, normalized_ts_spec_dict) + normalized_ts_spec_dict = ts.Spec(normalized_ts_spec_dict).to_json() + normalized_expected = ts.Spec(expected).to_json() + jax.tree_multimap(np.testing.assert_array_equal, normalized_ts_spec_dict, + normalized_expected) + + # Test without normalization (corresponds to tensorstore<0.1.14) + checkpoints._update_ts_path_from_relative_to_absolute( + ckpt_dir, ts_spec_dict) + jax.tree_multimap(np.testing.assert_array_equal, ts_spec_dict, expected) + + def test_update_ts_path_from_relative_to_absolute_gfile(self): + ts_spec_dict = { + 'driver': 'zarr', + 'dtype': 'float32', + 'kvstore': { + 'driver': 'file', + 'path': 'target.encoder.layers_0.attention.query.kernel' + }, + 'metadata': { + 'chunks': [768, 768], + 'compressor': { + 'id': 'gzip', + 'level': 1 + }, + 'shape': [768, 768] + } + } + + expected = { + 'driver': 'zarr', + 
'dtype': 'float32', + 'kvstore': { + 'driver': 'file', + # Path becomes absolute. + 'path': '/dir1/dir2/target.encoder.layers_0.attention.query.kernel' + }, + 'metadata': { + 'chunks': [768, 768], + 'compressor': { + 'id': 'gzip', + 'level': 1 + }, + 'shape': [768, 768] + } + } + ckpt_dir = '/dir1/dir2' + + self.assert_update_ts_path_from_relative_to_absolute( + ts_spec_dict, expected, ckpt_dir) + + def test_update_ts_path_from_relative_to_absolute_gcs(self): + ts_spec_dict = { + 'driver': 'zarr', + 'dtype': 'float32', + 'kvstore': { + 'bucket': 't5x-dummy-bucket', + 'driver': 'gcs' + }, + 'metadata': { + 'chunks': [768, 768], + 'compressor': { + 'id': 'gzip', + 'level': 1 + }, + 'shape': [768, 768] + }, + 'path': 'target.encoder.layers_0.attention.query.kernel', + 'transform': { + 'input_exclusive_max': [[768], [768]], + 'input_inclusive_min': [0, 0] + } + } + + expected = { + 'driver': 'zarr', + 'dtype': 'float32', + 'kvstore': { + 'bucket': 'test-bucket', # bucket should be changed. + 'driver': 'gcs' + }, + 'metadata': { + 'chunks': [768, 768], + 'compressor': { + 'id': 'gzip', + 'level': 1 + }, + 'shape': [768, 768] + }, + # Path becomes absolute without the "gs://bucket" portion stripped. + 'path': 'dir1/dir2/target.encoder.layers_0.attention.query.kernel', + 'transform': { + 'input_exclusive_max': [[768], [768]], + 'input_inclusive_min': [0, 0] + } + } + + ckpt_dir = 'gs://test-bucket/dir1/dir2' + + self.assert_update_ts_path_from_relative_to_absolute( + ts_spec_dict, expected, ckpt_dir) + + def test_restore_tf_checkpoint(self): + self.verify_restore_checkpoint_from_path( + os.path.join(TESTDATA, 'mtf_tiny_t5/model.ckpt-0'), + test_utils.get_t5_test_model( + emb_dim=32, head_dim=64, num_heads=2, mlp_dim=64)) + + def test_restore_tf_checkpoint_wrong_config(self): + with self.assertRaisesRegex(ValueError, r'Variable .* has shape .* != .*'): + self.verify_restore_checkpoint_from_path( + os.path.join(TESTDATA, 'mtf_tiny_t5/model.ckpt-0'), + test_utils.get_t5_test_model()) + + def test_convert_tf_checkpoint(self): + checkpoint_path = os.path.join(TESTDATA, 'mtf_tiny_t5/model.ckpt-0') + + # Minimal setup to create an optimizer with the matching config. + model = test_utils.get_t5_test_model( + emb_dim=32, head_dim=64, num_heads=2, mlp_dim=64) + + partitioner = partitioning.PjitPartitioner(num_partitions=1) + + def initialize_params_fn(rng): + initial_variables = model.get_initial_variables( + rng=rng, + input_shapes={ + 'encoder_input_tokens': (2, 512), + 'decoder_input_tokens': (2, 114), + }) + return FlaxOptimTrainState.create(model.optimizer_def, initial_variables) + + train_state = jax.eval_shape(initialize_params_fn, jax.random.PRNGKey(0)) + checkpointer = checkpoints.Checkpointer(train_state, partitioner, + self.tmp_dir) + _ = checkpointer.convert_from_tf_checkpoint(checkpoint_path) + + def test_load_matched(self): + checkpoint = os.path.join(TESTDATA, 'test_t5_tiny.checkpoint_0') + train_state = self.verify_restore_checkpoint_from_path( + checkpoint, test_utils.get_t5_test_model()) + state_dict = train_state._optimizer.state_dict() + ckpt = checkpoints.load_t5x_checkpoint(checkpoint) + jax.tree_multimap(np.testing.assert_array_equal, state_dict, ckpt) + + + +if __name__ == '__main__': + absltest.main() diff --git a/t5x/configs/__init__.py b/t5x/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..85dd7a38f30639b377a504c2c0295e2b8955cea9 --- /dev/null +++ b/t5x/configs/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2022 The T5X Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This empty file is needed for loading the gin files in this directory.""" diff --git a/t5x/configs/runs/__init__.py b/t5x/configs/runs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da022c16301721a096a208e8bdb2a71bb87f9788 --- /dev/null +++ b/t5x/configs/runs/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This empty file is needed for loading the gin files in this directory. diff --git a/t5x/configs/runs/eval.gin b/t5x/configs/runs/eval.gin new file mode 100644 index 0000000000000000000000000000000000000000..278b92e7ca51d4a12785b4befb11d85aea400e2c --- /dev/null +++ b/t5x/configs/runs/eval.gin @@ -0,0 +1,68 @@ +# Defaults for eval.py. +# +# +# You must also include a binding for MODEL. +# +# Required to be set: +# +# - MIXTURE_OR_TASK_NAME: The SeqIO Task/Mixture to evaluate on +# - CHECKPOINT_PATH: The model checkpoint to evaluate +# - EVAL_OUTPUT_DIR: The dir to write results to. +# +# +# Commonly overridden options: +# +# - DatasetConfig.split +# - DatasetConfig.batch_size +# - DatasetConfig.use_cached +# - RestoreCheckpointConfig.mode +# - PjitPartitioner.num_partitions +from __gin__ import dynamic_registration + +import __main__ as eval_script +import seqio +from t5x import partitioning +from t5x import utils + + +# Must be overridden +MIXTURE_OR_TASK_NAME = %gin.REQUIRED +CHECKPOINT_PATH = %gin.REQUIRED +EVAL_OUTPUT_DIR = %gin.REQUIRED +TASK_FEATURE_LENGTHS = None # auto-computes the maximum features length to use. + +# DEPRECATED: Import the this module in your gin file. +MIXTURE_OR_TASK_MODULE = None + +eval_script.evaluate: + model = %MODEL # imported from separate gin file + dataset_cfg = @utils.DatasetConfig() + partitioner = @partitioning.PjitPartitioner() + restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() + output_dir = %EVAL_OUTPUT_DIR + inference_evaluator_cls = @seqio.Evaluator + +partitioning.PjitPartitioner: + num_partitions = 1 + logical_axis_rules = @partitioning.standard_logical_axis_rules() + +seqio.Evaluator: + logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] + num_examples = None # Use all examples in the dataset. 
+  use_memory_cache = True
+
+utils.DatasetConfig:
+  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
+  task_feature_lengths = %TASK_FEATURE_LENGTHS
+  split = 'test'
+  batch_size = 32
+  shuffle = False
+  seed = 42
+  use_cached = False
+  pack = False
+  use_custom_packing_ops = False
+  module = %MIXTURE_OR_TASK_MODULE
+
+utils.RestoreCheckpointConfig:
+  path = %CHECKPOINT_PATH
+  mode = 'specific'
diff --git a/t5x/configs/runs/finetune.gin b/t5x/configs/runs/finetune.gin
new file mode 100644
index 0000000000000000000000000000000000000000..f482a0cf8ff9ebad9d464bfef3bd6bfc3d886110
--- /dev/null
+++ b/t5x/configs/runs/finetune.gin
@@ -0,0 +1,149 @@
+# Defaults for finetuning with train.py.
+#
+#
+# You must also include a binding for MODEL.
+#
+# Required to be set:
+#
+# - MIXTURE_OR_TASK_NAME
+# - TASK_FEATURE_LENGTHS
+# - TRAIN_STEPS # includes pretrain steps
+# - MODEL_DIR # automatically set when using xm_launch
+# - INITIAL_CHECKPOINT_PATH
+#
+# When running locally, MODEL_DIR needs to be passed via the `--gin.MODEL_DIR` flag.
+#
+# `TRAIN_STEPS` should include pre-training steps, e.g., if pre-trained ckpt
+# has 1M steps, TRAIN_STEPS = 1.1M will perform 0.1M fine-tuning steps.
+#
+# Commonly overridden options:
+# - DROPOUT_RATE
+# - BATCH_SIZE
+# - PjitPartitioner.num_partitions
+# - Trainer.num_microbatches
+# - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess
+#   on the fly. Most common tasks are cached, hence this is set to True by
+#   default.
+
+from __gin__ import dynamic_registration
+
+import __main__ as train_script
+import seqio
+from t5x import gin_utils
+from t5x import partitioning
+from t5x import utils
+from t5x import trainer
+
+# Must be overridden
+MODEL_DIR = %gin.REQUIRED
+MIXTURE_OR_TASK_NAME = %gin.REQUIRED
+TASK_FEATURE_LENGTHS = %gin.REQUIRED
+MIXTURE_OR_TASK_MODULE = %gin.REQUIRED
+TRAIN_STEPS = %gin.REQUIRED
+INITIAL_CHECKPOINT_PATH = %gin.REQUIRED
+
+# Commonly overridden
+DROPOUT_RATE = 0.1
+USE_CACHED_TASKS = True
+BATCH_SIZE = 128
+
+# Sometimes overridden
+EVAL_STEPS = 20
+
+# Convenience overrides.
+EVALUATOR_USE_MEMORY_CACHE = True
+EVALUATOR_NUM_EXAMPLES = None # Use all examples in the infer_eval dataset.
+JSON_WRITE_N_RESULTS = None # Write all inferences.
+# HW RNG is faster than SW, but has limited determinism.
+# Most notably it is not deterministic across different
+# submeshes.
+USE_HARDWARE_RNG = False
+# None always uses faster, hardware RNG
+RANDOM_SEED = None
+
+# DEPRECATED: Import this module in your gin file.
+MIXTURE_OR_TASK_MODULE = None + +train_script.train: + model = %MODEL # imported from separate gin file + model_dir = %MODEL_DIR + train_dataset_cfg = @train/utils.DatasetConfig() + train_eval_dataset_cfg = @train_eval/utils.DatasetConfig() + infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig() + checkpoint_cfg = @utils.CheckpointConfig() + partitioner = @partitioning.PjitPartitioner() + trainer_cls = @trainer.Trainer + total_steps = %TRAIN_STEPS + eval_steps = %EVAL_STEPS + eval_period = 1000 + random_seed = %RANDOM_SEED + use_hardware_rng = %USE_HARDWARE_RNG + summarize_config_fn = @gin_utils.summarize_gin_config + inference_evaluator_cls = @seqio.Evaluator + +partitioning.PjitPartitioner: + num_partitions = 1 + model_parallel_submesh = None + logical_axis_rules = @partitioning.standard_logical_axis_rules() + +seqio.Evaluator: + logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] + num_examples = %EVALUATOR_NUM_EXAMPLES + use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE + +seqio.JSONLogger: + write_n_results = %JSON_WRITE_N_RESULTS + +train/utils.DatasetConfig: + mixture_or_task_name = %MIXTURE_OR_TASK_NAME + task_feature_lengths = %TASK_FEATURE_LENGTHS + split = 'train' + batch_size = %BATCH_SIZE + shuffle = True + seed = None # use a new seed each run/restart + use_cached = %USE_CACHED_TASKS + pack = True + module = %MIXTURE_OR_TASK_MODULE + +train_eval/utils.DatasetConfig: + mixture_or_task_name = %MIXTURE_OR_TASK_NAME + task_feature_lengths = %TASK_FEATURE_LENGTHS + split = 'validation' + batch_size = %BATCH_SIZE + shuffle = False + seed = 42 + use_cached = %USE_CACHED_TASKS + pack = True + module = %MIXTURE_OR_TASK_MODULE + +infer_eval/utils.DatasetConfig: + mixture_or_task_name = %MIXTURE_OR_TASK_NAME + task_feature_lengths = None # compute max + split = 'validation' + batch_size = %BATCH_SIZE + shuffle = False + seed = 42 + use_cached = %USE_CACHED_TASKS + pack = False + module = %MIXTURE_OR_TASK_MODULE + +utils.CheckpointConfig: + restore = @utils.RestoreCheckpointConfig() + save = @utils.SaveCheckpointConfig() +utils.RestoreCheckpointConfig: + path = %INITIAL_CHECKPOINT_PATH + mode = 'specific' + dtype = 'float32' +utils.SaveCheckpointConfig: + period = 5000 + dtype = 'float32' + keep = None # keep all checkpoints + save_dataset = False # don't checkpoint dataset state + +trainer.Trainer: + num_microbatches = None + learning_rate_fn = @utils.create_learning_rate_scheduler() +utils.create_learning_rate_scheduler: + factors = 'constant' + base_learning_rate = 0.001 + warmup_steps = 1000 diff --git a/t5x/configs/runs/infer.gin b/t5x/configs/runs/infer.gin new file mode 100644 index 0000000000000000000000000000000000000000..0918d2f4843d698cf27787c62f7e09cf81c1e835 --- /dev/null +++ b/t5x/configs/runs/infer.gin @@ -0,0 +1,71 @@ +# Defaults for infer.py. +# +# +# You must also include a binding for MODEL. +# +# Required to be set: +# +# - MIXTURE_OR_TASK_NAME: The SeqIO Task/Mixture to use for inference +# - TASK_FEATURE_LENGTHS: The lengths per key in the SeqIO Task to trim features +# to. +# - CHECKPOINT_PATH: The model checkpoint to use for inference +# - INFER_OUTPUT_DIR: The dir to write results to. 
+# +# +# Commonly overridden options: +# +# - infer.mode +# - infer.checkpoint_period +# - infer.shard_id +# - infer.num_shards +# - DatasetConfig.split +# - DatasetConfig.batch_size +# - DatasetConfig.use_cached +# - RestoreCheckpointConfig.is_tensorflow +# - RestoreCheckpointConfig.mode +# - PjitPartitioner.num_partitions +from __gin__ import dynamic_registration + +import __main__ as infer_script +from t5x import partitioning +from t5x import utils + +# Must be overridden +MIXTURE_OR_TASK_NAME = %gin.REQUIRED +TASK_FEATURE_LENGTHS = %gin.REQUIRED +CHECKPOINT_PATH = %gin.REQUIRED +INFER_OUTPUT_DIR = %gin.REQUIRED + +# DEPRECATED: Import the this module in your gin file. +MIXTURE_OR_TASK_MODULE = None + +infer_script.infer: + mode = 'predict' + model = %MODEL # imported from separate gin file + output_dir = %INFER_OUTPUT_DIR + dataset_cfg = @utils.DatasetConfig() + partitioner = @partitioning.PjitPartitioner() + restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() + checkpoint_period = 100 + shard_id = 0 + num_shards = 1 + +partitioning.PjitPartitioner: + num_partitions = 1 + logical_axis_rules = @partitioning.standard_logical_axis_rules() + +utils.DatasetConfig: + mixture_or_task_name = %MIXTURE_OR_TASK_NAME + module = %MIXTURE_OR_TASK_MODULE + task_feature_lengths = %TASK_FEATURE_LENGTHS + use_cached = False + split = 'test' + batch_size = 32 + shuffle = False + seed = 0 + pack = False + +utils.RestoreCheckpointConfig: + path = %CHECKPOINT_PATH + mode = 'specific' + dtype = 'bfloat16' diff --git a/t5x/configs/runs/infer_from_tfexample_file.gin b/t5x/configs/runs/infer_from_tfexample_file.gin new file mode 100644 index 0000000000000000000000000000000000000000..5d62b27555ecfef3cd801098fe640ac09eff744c --- /dev/null +++ b/t5x/configs/runs/infer_from_tfexample_file.gin @@ -0,0 +1,90 @@ +# Defaults for infer.py if using a TFExample file as input. +# +# +# The features from each TFExample are tokenized using the model's vocabulary. +# By default, the inputs feature is assumed to be keyed as 'inputs', but this +# can be overridden with `create_task_from_tfexample_file.inputs_key`. +# +# You must also include a binding for MODEL. +# +# Required to be set: +# +# - TF_EXAMPLE_FILE_PATHS: The path to read TF Examples from. +# - TF_EXAMPLE_FILE_TYPE: The type of file to read TF Examples from. Currently +# supported: 'tfrecord', 'recordio', 'sstable'. +# - FEATURE_LENGTHS: The maximum length per feature in the TF Examples. +# - CHECKPOINT_PATH: The model checkpoint to use for inference +# - INFER_OUTPUT_DIR: The dir to write results to. 
+#
+#
+# Commonly overridden options:
+#
+# - infer.mode
+# - infer.checkpoint_period
+# - infer.shard_id
+# - infer.num_shards
+# - create_task_from_tfexample_file.inputs_key
+# - create_task_from_tfexample_file.targets_key
+# - DatasetConfig.split
+# - DatasetConfig.batch_size
+# - RestoreCheckpointConfig.mode
+# - PjitPartitioner.num_partitions
+from __gin__ import dynamic_registration
+
+import __main__ as infer_script
+import seqio
+from t5x import models
+from t5x import partitioning
+from t5x import utils
+
+# Must be overridden
+TF_EXAMPLE_FILE_PATHS = %gin.REQUIRED
+TF_EXAMPLE_FILE_TYPE = %gin.REQUIRED
+FEATURE_LENGTHS = %gin.REQUIRED
+CHECKPOINT_PATH = %gin.REQUIRED
+INFER_OUTPUT_DIR = %gin.REQUIRED
+
+infer_script.infer:
+  mode = 'predict'
+  model = %MODEL # imported from separate gin file
+  output_dir = %INFER_OUTPUT_DIR
+  dataset_cfg = @utils.DatasetConfig()
+  partitioner = @partitioning.PjitPartitioner()
+  restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
+  checkpoint_period = 100
+  shard_id = 0
+  num_shards = 1
+
+partitioning.PjitPartitioner:
+  num_partitions = 1
+  logical_axis_rules = @partitioning.standard_logical_axis_rules()
+
+utils.DatasetConfig:
+  mixture_or_task_name = @infer_script.create_task_from_tfexample_file()
+  task_feature_lengths = %FEATURE_LENGTHS
+  split = 'infer'
+  batch_size = 32
+  shuffle = False
+  seed = 0
+  pack = False
+
+infer_script.create_task_from_tfexample_file:
+  paths = %TF_EXAMPLE_FILE_PATHS
+  file_type = %TF_EXAMPLE_FILE_TYPE
+  inputs_key = 'inputs'
+  targets_key = None
+  features = {'inputs': @inputs/seqio.Feature(), 'targets': @outputs/seqio.Feature()}
+
+# Plumbing to extract the vocabulary directly from MODEL. This is needed to
+# tokenize the features from the TFExample since we aren't provided with
+# vocabularies via a Task.
+inputs/seqio.Feature.vocabulary = @models.get_input_vocabulary()
+models.get_input_vocabulary.model = %MODEL
+outputs/seqio.Feature.vocabulary = @models.get_output_vocabulary()
+models.get_output_vocabulary.model = %MODEL
+
+utils.RestoreCheckpointConfig:
+  mode = 'specific'
+  path = %CHECKPOINT_PATH
+  dtype = 'bfloat16'
+
diff --git a/t5x/configs/runs/precompile.gin b/t5x/configs/runs/precompile.gin
new file mode 100644
index 0000000000000000000000000000000000000000..b53ac17d6e3b60da003ec2f5bb30625bd395091a
--- /dev/null
+++ b/t5x/configs/runs/precompile.gin
@@ -0,0 +1,58 @@
+# Defaults for precompile mode in main.py.
+#
+# You must also include a binding for MODEL.
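+# (As in the other run configs in this directory, MODEL is expected to come
+# from a separate model gin file; this config only references it as `%MODEL`
+# below.)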
+# +# Required to be set: +# +# - MIXTURE_OR_TASK_NAME +# - TASK_FEATURE_LENGTHS +# - TRAIN_STEPS +# - MODEL_DIR: # automatically set when using xm_launch +# +# Commonly overridden options: +# +# - USE_CACHED_TASKS +# - BATCH_SIZE +# - PjitPartitioner.num_partitions +from __gin__ import dynamic_registration + +import __main__ as train_script +import seqio +from t5x import gin_utils +from t5x import partitioning +from t5x import utils +from t5x import trainer + +MODEL_DIR = %gin.REQUIRED +MIXTURE_OR_TASK_NAME = %gin.REQUIRED +TASK_FEATURE_LENGTHS = %gin.REQUIRED + +# Commonly overridden +USE_CACHED_TASKS = True +BATCH_SIZE = 128 + +# None always uses faster, hardware RNG +RANDOM_SEED = None + +train_script.precompile: + model = %MODEL # imported from separate gin file + model_dir = %MODEL_DIR + train_dataset_cfg = @train/utils.DatasetConfig() + partitioner = @partitioning.PjitPartitioner() + random_seed = %RANDOM_SEED + +partitioning.PjitPartitioner: + num_partitions = 1 + model_parallel_submesh = None + backend = "tpu" + logical_axis_rules = @partitioning.standard_logical_axis_rules() + +train/utils.DatasetConfig: + mixture_or_task_name = %MIXTURE_OR_TASK_NAME + task_feature_lengths = %TASK_FEATURE_LENGTHS + split = 'train' + batch_size = %BATCH_SIZE + shuffle = True + seed = None # use a new seed each run/restart + use_cached = %USE_CACHED_TASKS + pack = True diff --git a/t5x/configs/runs/pretrain.gin b/t5x/configs/runs/pretrain.gin new file mode 100644 index 0000000000000000000000000000000000000000..de1286467d277237dd06102c2b07cbdd6859d4df --- /dev/null +++ b/t5x/configs/runs/pretrain.gin @@ -0,0 +1,108 @@ +# Defaults for pretraining with train.py. +# +# +# You must also include a binding for MODEL. +# +# Required to be set: +# +# - MIXTURE_OR_TASK_NAME +# - TASK_FEATURE_LENGTHS +# - TRAIN_STEPS +# - MODEL_DIR: # automatically set when using xm_launch +# +# Commonly overridden options: +# +# - train/DatasetConfig.batch_size +# - train_eval/DatasetConfig.batch_size +# - PjitPartitioner.num_partitions +# - Trainer.num_microbatches +# - DROPOUT_RATE +from __gin__ import dynamic_registration + +import __main__ as train_script +from t5x import gin_utils +from t5x import partitioning +from t5x import utils +from t5x import trainer + +MIXTURE_OR_TASK_NAME = %gin.REQUIRED +TASK_FEATURE_LENGTHS = %gin.REQUIRED +TRAIN_STEPS = %gin.REQUIRED +MODEL_DIR = %gin.REQUIRED +BATCH_SIZE = 128 +USE_CACHED_TASKS = True + +# DEPRECATED: Import the this module in your gin file. +MIXTURE_OR_TASK_MODULE = None +SHUFFLE_TRAIN_EXAMPLES = True + +# HW RNG is faster than SW, but has limited determinism. +# Most notably it is not deterministic across different +# submeshes. 
+USE_HARDWARE_RNG = False +# None always uses faster, hardware RNG +RANDOM_SEED = None + +# Can be overridden with `train.*`.` +train_script.train: + model = %MODEL # imported from separate gin file + model_dir = %MODEL_DIR + train_dataset_cfg = @train/utils.DatasetConfig() + train_eval_dataset_cfg = @train_eval/utils.DatasetConfig() + infer_eval_dataset_cfg = None + checkpoint_cfg = @utils.CheckpointConfig() + partitioner = @partitioning.PjitPartitioner() + trainer_cls = @trainer.Trainer + total_steps = %TRAIN_STEPS + eval_steps = 20 + eval_period = 1000 + random_seed = %RANDOM_SEED + use_hardware_rng = %USE_HARDWARE_RNG + summarize_config_fn = @gin_utils.summarize_gin_config + +partitioning.PjitPartitioner: + num_partitions = 1 + model_parallel_submesh = None + logical_axis_rules = @partitioning.standard_logical_axis_rules() + +train/utils.DatasetConfig: + mixture_or_task_name = %MIXTURE_OR_TASK_NAME + task_feature_lengths = %TASK_FEATURE_LENGTHS + split = 'train' + batch_size = %BATCH_SIZE + shuffle = %SHUFFLE_TRAIN_EXAMPLES + seed = None # use a new seed each run/restart + use_cached = %USE_CACHED_TASKS + pack = True + module = %MIXTURE_OR_TASK_MODULE + +train_eval/utils.DatasetConfig: + mixture_or_task_name = %MIXTURE_OR_TASK_NAME + task_feature_lengths = %TASK_FEATURE_LENGTHS + split = 'validation' + batch_size = %BATCH_SIZE + shuffle = False + seed = 42 + use_cached = %USE_CACHED_TASKS + pack = True + module = %MIXTURE_OR_TASK_MODULE + +utils.CheckpointConfig: + restore = @utils.RestoreCheckpointConfig() + save = @utils.SaveCheckpointConfig() +utils.RestoreCheckpointConfig: + path = [] # initialize from scratch +utils.SaveCheckpointConfig: + period = 1000 + dtype = 'float32' + keep = None # keep all checkpoints + save_dataset = False # don't checkpoint dataset state + +trainer.Trainer: + num_microbatches = None + learning_rate_fn = @utils.create_learning_rate_scheduler() + +utils.create_learning_rate_scheduler: + factors = 'constant * rsqrt_decay' + base_learning_rate = 1.0 + warmup_steps = 10000 # 10k to keep consistent with T5/MTF defaults. diff --git a/t5x/contrib/__init__.py b/t5x/contrib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0eda1ed07ac0093ac4430d87343dd3410d3da456 --- /dev/null +++ b/t5x/contrib/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This empty file is needed for packaging the contrib modules.""" diff --git a/t5x/contrib/moe/README.md b/t5x/contrib/moe/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6da7f167af44ef2ee8b7c663279cc881d4c28014 --- /dev/null +++ b/t5x/contrib/moe/README.md @@ -0,0 +1,46 @@ +# Mixture of Experts + + +This repo contains overrides and configs for training sparse Mixture of Experts +(MoE) models with T5X. The existing setups and examples all use [Flaxformer](https://github.com/google/flaxformer). 
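+
+The run configs under `configs/runs/` in this directory are meant to be
+combined with a model gin file that binds `MODEL`. As a rough, illustrative
+sketch (all paths, task names and values below are placeholders, not files
+shipped with this repo), a user-side pre-training config might look like:
+
+```gin
+# my_moe_pretrain.gin -- illustrative sketch only; every value is a placeholder.
+include 't5x/contrib/moe/configs/runs/pretrain.gin'
+include 'path/to/your/moe_model.gin'  # must provide the MODEL binding
+
+MODEL_DIR = 'gs://your-bucket/moe_run'
+MIXTURE_OR_TASK_NAME = 'your_seqio_task_or_mixture'
+TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 114}
+TRAIN_STEPS = 100000
+NUM_EXPERTS = 8
+NUM_MODEL_PARTITIONS = 1
+```
+
+The bindings above simply mirror the "Required to be set" list at the top of
+`configs/runs/pretrain.gin`.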
+
+## Training standard MoE architectures
+
+If you are looking to train a T5X variant of a popular Mesh Tensorflow MoE model
+(e.g. [Switch Transformer](https://arxiv.org/abs/2101.03961) or [Sparsely-Gated Mixture-of-Experts](https://arxiv.org/abs/1701.06538)) or adapt existing
+MoE models, then the easiest way to get started is to plug one of the
+[(Flaxformer) model gin configs](https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models)
+into the [T5X Quickstart guide](https://github.com/google-research/t5x). To customize the default MoE models, you can override aspects of the underlying [(Flaxformer) architecture gin config](https://github.com/google/flaxformer/blob/main/flaxformer/t5x/configs/moe/architectures/moe.gin).
+
+## Using MoE in your existing model
+
+Alternatively, if you already have your own existing T5X/Flaxformer model
+architecture and wish to add MoE layers, you can directly use the
+[Flaxformer MoeLayer](https://github.com/google/flaxformer/blob/b725bd2a51d70e866d819c92de166fbf24425e6a/flaxformer/architectures/moe/moe_layers.py#L67).
+Currently, the MoeLayer is constrained to use
+[Flaxformer MlpBlock(s)](https://github.com/google/flaxformer/blob/b725bd2a51d70e866d819c92de166fbf24425e6a/flaxformer/components/dense.py#L185)
+as experts. As a point of reference: MoeLayer(s) are integrated with the Flaxformer T5
+architecture through the
+[SparseEncoder](https://github.com/google/flaxformer/blob/b725bd2a51d70e866d819c92de166fbf24425e6a/flaxformer/architectures/moe/moe_architecture.py#L36)
+and
+[SparseDecoder](https://github.com/google/flaxformer/blob/b725bd2a51d70e866d819c92de166fbf24425e6a/flaxformer/architectures/moe/moe_architecture.py#L162).
+These classes allow us to interleave sparse MoE and dense MLP blocks through the
+`sparse_layout` attribute.
+
+## Expert routing mechanisms
+
+A number of routing mechanisms are supported:
+
+* Switch routing (or top-1 "tokens choose" routing) based on the
+  [Switch Transformer](https://arxiv.org/abs/2101.03961)
+* General Top-k "tokens choose" routing of the form used in
+  [Sparsely-Gated Mixture-of-Experts](https://arxiv.org/abs/1701.06538),
+  [Vision MoE](https://arxiv.org/abs/2106.05974),
+  [Designing Effective Sparse Expert Models](https://arxiv.org/abs/2202.08906)
+  and many other MoE works
+* "Experts choose" routing introduced in
+  [Mixture-of-Experts with Expert Choice Routing](https://arxiv.org/abs/2202.09368)
+
+See the
+[Flaxformer router codebase](https://github.com/google/flaxformer/blob/main/flaxformer/architectures/moe/routing.py) for details.
+
diff --git a/t5x/contrib/moe/__init__.py b/t5x/contrib/moe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e0f44439ef9e64b0f885e5bdff0dbd717c1f139
--- /dev/null
+++ b/t5x/contrib/moe/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2022 The T5X Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Import API modules.""" + +import t5x.contrib.moe.adafactor_utils +import t5x.contrib.moe.models +import t5x.contrib.moe.partitioning +import t5x.contrib.moe.trainer +import t5x.contrib.moe.training_utils + +# Version number. +from t5x.version import __version__ diff --git a/t5x/contrib/moe/adafactor_utils.py b/t5x/contrib/moe/adafactor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b8fd572828478ec6eaba505ae3c31cb288f88f1b --- /dev/null +++ b/t5x/contrib/moe/adafactor_utils.py @@ -0,0 +1,32 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Adafactor logical rules for Mixture of Experts models.""" + +from flax import core as flax_core +from t5x import adafactor + +FactorDim = adafactor.FactorDim +FrozenDict = flax_core.FrozenDict + + +def logical_factor_rules() -> FrozenDict: + """Logical factor rules for Mixture of Experts.""" + rules = flax_core.unfreeze(adafactor.standard_logical_factor_rules()) + rules.update({ + 'expert': FactorDim.BATCH, + 'expert_mlp': FactorDim.COLUMN, + 'unmodeled': FactorDim.NONE + }) + return flax_core.freeze(rules) diff --git a/t5x/contrib/moe/configs/__init__.py b/t5x/contrib/moe/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..85dd7a38f30639b377a504c2c0295e2b8955cea9 --- /dev/null +++ b/t5x/contrib/moe/configs/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This empty file is needed for loading the gin files in this directory.""" diff --git a/t5x/contrib/moe/configs/runs/__init__.py b/t5x/contrib/moe/configs/runs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..85dd7a38f30639b377a504c2c0295e2b8955cea9 --- /dev/null +++ b/t5x/contrib/moe/configs/runs/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""This empty file is needed for loading the gin files in this directory.""" diff --git a/t5x/contrib/moe/configs/runs/continue_pretrain.gin b/t5x/contrib/moe/configs/runs/continue_pretrain.gin new file mode 100644 index 0000000000000000000000000000000000000000..797b478dc8a061602afcc8c59d4bedc4724d20bc --- /dev/null +++ b/t5x/contrib/moe/configs/runs/continue_pretrain.gin @@ -0,0 +1,26 @@ +# Continue a Mixture of Experts pre-training run. +# +# See t5x/contrib/moe/configs/runs/pretrain.gin for instructions. +# +# You must also include a binding for MODEL. +# +# Required to be set: +# +# - NUM_EXPERTS +# - NUM_MODEL_PARTITIONS (1 if no model parallelism) +# - MIXTURE_OR_TASK_NAME +# - TASK_FEATURE_LENGTHS +# - TRAIN_STEPS +# - INITIAL_CHECKPOINT_PATH +# - MODEL_DIR + +from __gin__ import dynamic_registration + +from t5x import utils + +include 't5x/contrib/moe/configs/runs/pretrain.gin' + +utils.RestoreCheckpointConfig: + mode = 'specific' + path = %INITIAL_CHECKPOINT_PATH + dtype = 'float32' diff --git a/t5x/contrib/moe/configs/runs/eval.gin b/t5x/contrib/moe/configs/runs/eval.gin new file mode 100644 index 0000000000000000000000000000000000000000..50cde96f2b7a228026ff7ce67affd716216467e8 --- /dev/null +++ b/t5x/contrib/moe/configs/runs/eval.gin @@ -0,0 +1,41 @@ +# Evaluate a Mixture of Experts model. +# +# +# You must also include a binding for MODEL. +# +# Required to be set: +# +# - NUM_EXPERTS +# - NUM_MODEL_PARTITIONS (1 if no model parallelism) +# - MIXTURE_OR_TASK_NAME +# - CHECKPOINT_PATH +# - EVAL_OUTPUT_DIR +# +# Commonly overridden options (see also t5x/configs/runs/eval.gin): +# +# - DROPOUT_RATE +# - BATCH_SIZE + +from __gin__ import dynamic_registration + +import __main__ as eval_script + +from t5x.contrib.moe import partitioning as moe_partitioning +from t5x import utils + +include 't5x/configs/runs/eval.gin' + +NUM_EXPERTS = %gin.REQUIRED +NUM_MODEL_PARTITIONS = %gin.REQUIRED + +# We use the MoE partitioner. +eval_script.evaluate.partitioner = @moe_partitioning.MoePjitPartitioner() +moe_partitioning.MoePjitPartitioner: + num_experts = %NUM_EXPERTS + num_partitions = %NUM_MODEL_PARTITIONS + logical_axis_rules = @moe_partitioning.standard_logical_axis_rules() +moe_partitioning.standard_logical_axis_rules: + num_experts = %NUM_EXPERTS + num_partitions = %NUM_MODEL_PARTITIONS + +utils.DatasetConfig.batch_size = %BATCH_SIZE diff --git a/t5x/contrib/moe/configs/runs/finetune.gin b/t5x/contrib/moe/configs/runs/finetune.gin new file mode 100644 index 0000000000000000000000000000000000000000..aa89384ae0883bdd2c650f9cd4008e02e7f1b918 --- /dev/null +++ b/t5x/contrib/moe/configs/runs/finetune.gin @@ -0,0 +1,60 @@ +# Fine-tune a Mixture of Experts model. +# +# This file allows for fine-tuning with data, expert and model parallelism. To +# use model parallelism, set NUM_MODEL_PARTITIONS > 1. +# +# +# You must also include a binding for MODEL. 
+# +# Required to be set: +# +# - NUM_EXPERTS +# - NUM_MODEL_PARTITIONS (1 if no model parallelism) +# - MIXTURE_OR_TASK_NAME +# - TASK_FEATURE_LENGTHS +# - TRAIN_STEPS # includes pretrain steps +# - MODEL_DIR +# - INITIAL_CHECKPOINT_PATH +# +# Commonly overridden options (see also t5x/configs/runs/finetune.gin): +# +# - DROPOUT_RATE +# - BATCH_SIZE +# - Trainer.num_microbatches + +from __gin__ import dynamic_registration + +import __main__ as train_script + +from t5x.contrib.moe import partitioning as moe_partitioning +from t5x.contrib.moe import trainer as moe_trainer +from t5x import utils + +include 't5x/configs/runs/finetune.gin' + +NUM_EXPERTS = %gin.REQUIRED +NUM_MODEL_PARTITIONS = %gin.REQUIRED + +# We use the MoE partitioner. +train_script.train.partitioner = @moe_partitioning.MoePjitPartitioner() +moe_partitioning.MoePjitPartitioner: + num_experts = %NUM_EXPERTS + num_partitions = %NUM_MODEL_PARTITIONS + logical_axis_rules = @moe_partitioning.standard_logical_axis_rules() +moe_partitioning.standard_logical_axis_rules: + num_experts = %NUM_EXPERTS + num_partitions = %NUM_MODEL_PARTITIONS + +# And the MoE trainer. +train_script.train.trainer_cls = @moe_trainer.MoeTrainer +moe_trainer.MoeTrainer: + num_microbatches = None + learning_rate_fn = @utils.create_learning_rate_scheduler() + num_experts = %NUM_EXPERTS +utils.create_learning_rate_scheduler: + factors = 'constant' + base_learning_rate = 0.001 + warmup_steps = 1000 + +# Checkpoint slightly more often than fine-tuning defaults. +utils.SaveCheckpointConfig.period = 2000 diff --git a/t5x/contrib/moe/configs/runs/infer.gin b/t5x/contrib/moe/configs/runs/infer.gin new file mode 100644 index 0000000000000000000000000000000000000000..ffee10c2b5131437deb86fd9119dec32d7646189 --- /dev/null +++ b/t5x/contrib/moe/configs/runs/infer.gin @@ -0,0 +1,42 @@ +# Run inference with a Mixture of Experts model. +# +# +# You must also include a binding for MODEL. +# +# Required to be set: +# +# - NUM_EXPERTS +# - NUM_MODEL_PARTITIONS (1 if no model parallelism) +# - MIXTURE_OR_TASK_NAME +# - TASK_FEATURE_LENGTHS +# - CHECKPOINT_PATH +# - INFER_OUTPUT_DIR +# +# Commonly overridden options (see also t5x/configs/runs/infer.gin): +# +# - DROPOUT_RATE +# - BATCH_SIZE + +from __gin__ import dynamic_registration + +import __main__ as infer_script + +from t5x.contrib.moe import partitioning as moe_partitioning +from t5x import utils + +include 't5x/configs/runs/infer.gin' + +NUM_EXPERTS = %gin.REQUIRED +NUM_MODEL_PARTITIONS = %gin.REQUIRED + +# We use the MoE partitioner. +infer_script.infer.partitioner = @moe_partitioning.MoePjitPartitioner() +moe_partitioning.MoePjitPartitioner: + num_experts = %NUM_EXPERTS + num_partitions = %NUM_MODEL_PARTITIONS + logical_axis_rules = @moe_partitioning.standard_logical_axis_rules() +moe_partitioning.standard_logical_axis_rules: + num_experts = %NUM_EXPERTS + num_partitions = %NUM_MODEL_PARTITIONS + +utils.DatasetConfig.batch_size = %BATCH_SIZE diff --git a/t5x/contrib/moe/configs/runs/pretrain.gin b/t5x/contrib/moe/configs/runs/pretrain.gin new file mode 100644 index 0000000000000000000000000000000000000000..9b5917d02df0d5261b87a009da68c1a772b3c500 --- /dev/null +++ b/t5x/contrib/moe/configs/runs/pretrain.gin @@ -0,0 +1,60 @@ +# Pre-train a Mixture of Experts model. +# +# This file allows for pre-training with data, expert and model parallelism. To +# use model parallelism, set NUM_MODEL_PARTITIONS > 1. +# +# +# You must also include a binding for MODEL. 
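+# (MODEL is typically provided by a separate model gin file included alongside
+# this run config, e.g. one of the Flaxformer MoE model gin configs referenced
+# in the t5x/contrib/moe README; the included base config references it as
+# %MODEL.)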
+# +# Required to be set: +# +# - NUM_EXPERTS +# - NUM_MODEL_PARTITIONS (1 if no model parallelism) +# - MIXTURE_OR_TASK_NAME +# - TASK_FEATURE_LENGTHS +# - TRAIN_STEPS +# - MODEL_DIR +# +# Commonly overridden options (see also t5x/configs/runs/pretrain.gin): +# +# - BATCH_SIZE +# - Trainer.num_microbatches +# - DROPOUT_RATE + +from __gin__ import dynamic_registration + +import __main__ as train_script + +from t5x.contrib.moe import partitioning as moe_partitioning +from t5x.contrib.moe import trainer as moe_trainer +from t5x import utils + +include 't5x/configs/runs/pretrain.gin' + +NUM_EXPERTS = %gin.REQUIRED +NUM_MODEL_PARTITIONS = %gin.REQUIRED + +# We use the MoE partitioner. +train_script.train.partitioner = @moe_partitioning.MoePjitPartitioner() +moe_partitioning.MoePjitPartitioner: + num_experts = %NUM_EXPERTS + num_partitions = %NUM_MODEL_PARTITIONS + logical_axis_rules = @moe_partitioning.standard_logical_axis_rules() +moe_partitioning.standard_logical_axis_rules: + num_experts = %NUM_EXPERTS + num_partitions = %NUM_MODEL_PARTITIONS + +# And the MoE trainer. +train_script.train.trainer_cls = @moe_trainer.MoeTrainer +moe_trainer.MoeTrainer: + num_microbatches = None + learning_rate_fn = @utils.create_learning_rate_scheduler() + num_experts = %NUM_EXPERTS +utils.create_learning_rate_scheduler: + factors = 'constant * rsqrt_decay' + base_learning_rate = 1.0 + warmup_steps = 10000 # 10k to keep consistent with T5/MTF defaults. + +# Keep slightly fewer checkpoints than pre-training defaults. +utils.SaveCheckpointConfig.period = 5000 +utils.SaveCheckpointConfig.keep = 20 \ No newline at end of file diff --git a/t5x/contrib/moe/models.py b/t5x/contrib/moe/models.py new file mode 100644 index 0000000000000000000000000000000000000000..42f6b8597991686ebdbbc2d6f1a443ad8e954cf7 --- /dev/null +++ b/t5x/contrib/moe/models.py @@ -0,0 +1,251 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Provides model subclasses with Mixture of Experts support.""" + +import dataclasses +from typing import Callable, Mapping, Optional, Sequence, Tuple, Union + +import clu.metrics as clu_metrics +from flax import core as flax_core +from flax import linen as nn +from flax import optim +from flax import traverse_util +from flax.core import scope as flax_scope +import jax.numpy as jnp +import seqio +from t5x import decoding +from t5x import losses +from t5x import metrics as metrics_lib +from t5x import models + +AveragePerStep = metrics_lib.AveragePerStep +DecodeFnCallable = models.DecodeFnCallable +FrozenVariableDict = flax_scope.FrozenVariableDict +MetricsMap = metrics_lib.MetricsMap +PyTreeDef = models.PyTreeDef +Sum = metrics_lib.Sum + + +@dataclasses.dataclass() +class ExpertMetrics: + """Metrics for analyzing diversity among experts in mixture of experts models. + + Attributes: + auxiliary_loss: Auxiliary load balancing loss. + router_z_loss: Router z-loss. Encourages router logits to remain small in an + effort to improve stability. 
+ fraction_tokens_left_behind: Fraction of tokens NOT processed by any expert. + expert_usage: Fraction of total capacity, across all experts, used to + process tokens. Larger expert capacities or non-uniform token routing will + result in smaller expert usage values. + router_confidence: How confident the router is about the tokens that it has + routed. + """ + auxiliary_loss: float + router_z_loss: float + + fraction_tokens_left_behind: float + expert_usage: float + router_confidence: float + + +class MoeEncoderDecoderModel(models.EncoderDecoderModel): + """Subclass which propagates MoE auxiliary loss and metrics.""" + + def __init__( + self, + module: nn.Module, + input_vocabulary: seqio.Vocabulary, + output_vocabulary: seqio.Vocabulary, + optimizer_def: optim.OptimizerDef, + decode_fn: DecodeFnCallable = decoding.beam_search, + feature_converter_cls: Optional[Callable[..., + seqio.FeatureConverter]] = None, + label_smoothing: float = 0.0, + z_loss: float = 0.0, + loss_normalizing_factor: Optional[float] = None, + aux_loss_factor: float = 0., + router_z_loss_factor: float = 0.): + super().__init__( + module=module, + input_vocabulary=input_vocabulary, + output_vocabulary=output_vocabulary, + optimizer_def=optimizer_def, + decode_fn=decode_fn, + feature_converter_cls=feature_converter_cls, + label_smoothing=label_smoothing, + z_loss=z_loss, + loss_normalizing_factor=loss_normalizing_factor) + self.aux_loss_factor = aux_loss_factor + self.router_z_loss_factor = router_z_loss_factor + + def loss_fn( + self, params: models.PyTreeDef, batch: Mapping[str, jnp.ndarray], + dropout_rng: Optional[jnp.ndarray]) -> Tuple[jnp.ndarray, MetricsMap]: + """Cross-entropy loss function with auxiliary MoE load balancing loss. + + Args: + params: Model parameters. + batch: Batch of training examples. + dropout_rng: Random number generator key for dropout. + + Returns: + - Model loss. + - Metrics. + """ + logits, state = self._compute_logits( + params, batch, dropout_rng, mutable=['intermediates']) + loss_normalizing_factor: Optional[Union[ + float, int, str, losses.SpecialLossNormalizingFactor]] + (loss_normalizing_factor, + weights) = losses.get_loss_normalizing_factor_and_weights( + self._loss_normalizing_factor, batch) + + targets = batch['decoder_target_tokens'] + total_loss, z_loss, _ = losses.compute_weighted_cross_entropy( + logits, + targets=targets, + weights=weights, + label_smoothing=self._label_smoothing, + z_loss=self._z_loss, + loss_normalizing_factor=loss_normalizing_factor) + + # Extract and add MoE losses to total loss. + diversity_metrics = _extract_diversity_metrics(state) + aux_loss, router_z_loss = _expert_losses(diversity_metrics, + self.aux_loss_factor, + self.router_z_loss_factor) + total_loss += aux_loss + router_z_loss + + metrics = self._compute_metrics( + logits=logits, + targets=targets, + mask=weights, + loss=total_loss, + z_loss=z_loss) + metrics.update( + _expert_metrics( + diversity_metrics, + total_loss, + z_loss, + aux_loss, + router_z_loss, + num_tokens=targets.size)) + + return total_loss, metrics + + +def _extract_diversity_metrics( + state: flax_scope.FrozenVariableDict) -> Sequence[ExpertMetrics]: + """Extract expert diversity metrics from sown state intermediates. + + Args: + state: Model state holding sown intermediate metrics. + + Returns: + Single diversity metrics instance per MoE layer. + + Raises: + ValueError if unable to extract any diversity metrics from model state. 
+ """ + state_dict = traverse_util.flatten_dict(flax_core.unfreeze(state)) + diversity_metrics = [ + metric for path, metric in state_dict.items() + if path[-1] == 'diversity_metrics' + ] + if not diversity_metrics: + raise ValueError( + 'Unable to find any expert diversity metrics. Please check that MoE ' + 'metrics and losses are correctly sown.') + # Convert modeling library DiversityMetrics objects to local ExpertMetrics + # objects to avoid modeling library dependencies. + return [ + ExpertMetrics(metric.auxiliary_loss, metric.router_z_loss, + metric.fraction_tokens_left_behind, metric.expert_usage, + metric.router_confidence) for metric in diversity_metrics + ] + + +def _expert_losses(diversity_metrics: Sequence[ExpertMetrics], + auxiliary_loss_factor: float, + router_z_loss_factor: float) -> Tuple[float, float]: + """Summarizes per-layer MoE auxiliary losses. + + For auxiliary losses, we take the mean across MoE layers. + + Args: + diversity_metrics: Per-layer mixture of expert metrics. + auxiliary_loss_factor: Factor by which to scale auxiliary load balancing + loss for mixture of experts models. The raw auxiliary losses will be + summed and then scaled by this factor. + router_z_loss_factor: Factor by which to scale router z-loss for mixture of + experts models. + + Returns: + - Load balancing loss. + - Router z-loss. + """ + aux_loss = auxiliary_loss_factor * jnp.array( + [m.auxiliary_loss for m in diversity_metrics], dtype=jnp.float32).mean() + router_z_loss = router_z_loss_factor * jnp.array( + [m.router_z_loss for m in diversity_metrics], dtype=jnp.float32).mean() + return aux_loss, router_z_loss + + +def _expert_metrics(diversity_metrics: Sequence[ExpertMetrics], + total_loss: float, z_loss: float, auxiliary_loss: float, + router_z_loss: float, num_tokens: int) -> MetricsMap: + """Summarizes per-layer expert metrics for the entire model. + + The return metrics map will also contain overrides for the cross entropy loss + metrics to account for the MoE losses. + + Args: + diversity_metrics: Per-layer mixture of expert metrics. + total_loss: Total model loss. + z_loss: Output logits z-loss (not MoE specific). + auxiliary_loss: Auxiliary load balancing loss for MoE models. + router_z_loss: Router z-loss for MoE models. + num_tokens: Total number of target tokens. + + Returns: + Expert diversity metrics. + """ + cross_ent_loss = total_loss - z_loss - auxiliary_loss - router_z_loss + return { + 'experts/auxiliary_loss': + AveragePerStep.from_model_output(auxiliary_loss), + 'experts/router_z_loss': + AveragePerStep.from_model_output(router_z_loss), + 'experts/fraction_tokens_left_behind': + AveragePerStep.from_model_output( + jnp.array( + [m.fraction_tokens_left_behind for m in diversity_metrics], + dtype=jnp.float32).mean()), + 'experts/expert_usage': + AveragePerStep.from_model_output( + jnp.array([m.expert_usage for m in diversity_metrics], + dtype=jnp.float32).mean()), + 'experts/router_confidence': + AveragePerStep.from_model_output( + jnp.array([m.router_confidence for m in diversity_metrics], + dtype=jnp.float32).mean()), + # Override vanilla T5 cross entropy loss metrics with corrected loss that + # accounts for MoE losses. 
+ 'cross_ent_loss': + metrics_lib.AveragePerStep(total=cross_ent_loss), + 'cross_ent_loss_per_all_target_tokens': + clu_metrics.Average(total=jnp.sum(cross_ent_loss), count=num_tokens) + } diff --git a/t5x/contrib/moe/models_test.py b/t5x/contrib/moe/models_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f571c7443bae1fa9ddbded41fc785b2fb8f475ce --- /dev/null +++ b/t5x/contrib/moe/models_test.py @@ -0,0 +1,144 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for models.""" + +from unittest import mock + +from absl.testing import absltest +from clu import metrics as clu_metrics_lib +from flax import core as flax_core +import jax.numpy as jnp +import numpy as np +from t5x import metrics as metrics_lib + +from t5x.contrib.moe import models + +Accuracy = clu_metrics_lib.Accuracy +AveragePerStep = metrics_lib.AveragePerStep +ExpertMetrics = models.ExpertMetrics +FrozenDict = flax_core.frozen_dict.FrozenDict + + +class ModelsTest(absltest.TestCase): + + def test_expert_losses(self): + diversity_metrics = [ + ExpertMetrics( + auxiliary_loss=1., + router_z_loss=0., + fraction_tokens_left_behind=0.5, + expert_usage=0.5, + router_confidence=0.5), + ExpertMetrics( + auxiliary_loss=2., + router_z_loss=1., + fraction_tokens_left_behind=0.5, + expert_usage=0.5, + router_confidence=0.5) + ] + aux_loss, router_z_loss = models._expert_losses( + diversity_metrics, auxiliary_loss_factor=0.1, router_z_loss_factor=10) + + self.assertEqual(aux_loss, 0.15) + self.assertEqual(router_z_loss, 5.) + + def test_expert_metrics(self): + diversity_metrics = [ + ExpertMetrics( + auxiliary_loss=1., + router_z_loss=0., + fraction_tokens_left_behind=1., + expert_usage=0.7, + router_confidence=0.5), + ExpertMetrics( + auxiliary_loss=2., + router_z_loss=1., + fraction_tokens_left_behind=0.5, + expert_usage=0.5, + router_confidence=0.5) + ] + actual_metrics = models._expert_metrics( + diversity_metrics, + total_loss=100., + z_loss=1., + auxiliary_loss=3., + router_z_loss=7., + num_tokens=2) + actual_metrics = metrics_lib.set_step_metrics_num_steps(actual_metrics, 1) + actual_computed_metrics = { + k: v.compute() for k, v in actual_metrics.items() + } + + expected_metrics = { + 'cross_ent_loss': 89.0, + 'cross_ent_loss_per_all_target_tokens': 44.5, + 'experts/auxiliary_loss': 3., + 'experts/expert_usage': 0.6, + 'experts/fraction_tokens_left_behind': 0.75, + 'experts/router_confidence': 0.5, + 'experts/router_z_loss': 7. 
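+        # Expected values follow from _expert_metrics():
+        #   cross_ent_loss = 100. - 1. - 3. - 7. = 89.
+        #   cross_ent_loss_per_all_target_tokens = 89. / 2 = 44.5
+        #   expert_usage = mean(0.7, 0.5) = 0.6
+        #   fraction_tokens_left_behind = mean(1., 0.5) = 0.75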
+ } + self.assertEqual(actual_computed_metrics, expected_metrics) + + def test_extract_from_non_expert_model(self): + empty_state = FrozenDict({'intermediates': {}}) + with self.assertRaisesRegex(ValueError, + 'Unable to find any expert diversity metrics.'): + models._extract_diversity_metrics(empty_state) + + def test_model(self): + encoder_input_tokens = jnp.ones((2, 3)) + decoder_input_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]]) + decoder_target_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]]) + decoder_loss_weights = jnp.array([[1, 1, 1, 0], [0, 1, 0, 1]]) + logits = jnp.arange(0, 24).reshape((2, 4, 3)) + params = {'foo': jnp.zeros(3)} + + mock_transformer = mock.Mock() + mock_transformer.apply.return_value = logits + mock_transformer.dtype = jnp.float32 + + batch = { + 'encoder_input_tokens': encoder_input_tokens, + 'decoder_input_tokens': decoder_input_tokens, + 'decoder_target_tokens': decoder_target_tokens, + 'decoder_loss_weights': decoder_loss_weights + } + + def mock_init(self): + self.module = mock_transformer + + with mock.patch.object( + models.MoeEncoderDecoderModel, '__init__', new=mock_init): + model = models.MoeEncoderDecoderModel() + result = model.score_batch(params, batch) + + mock_transformer.apply.assert_called_with({'params': params}, + encoder_input_tokens, + decoder_input_tokens, + decoder_target_tokens, + encoder_segment_ids=None, + decoder_segment_ids=None, + encoder_positions=None, + decoder_positions=None, + decode=False, + enable_dropout=False, + rngs=None, + mutable=False) + np.testing.assert_allclose(result, [-3.2228181, -1.8152122], rtol=1e-5) + + +if __name__ == '__main__': + absltest.main() diff --git a/t5x/contrib/moe/partitioning.py b/t5x/contrib/moe/partitioning.py new file mode 100644 index 0000000000000000000000000000000000000000..f3da16dd653d5bf179fc0c566a663c691868cc7e --- /dev/null +++ b/t5x/contrib/moe/partitioning.py @@ -0,0 +1,446 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pjit partitioner with Mixture of Experts overrides.""" + +from typing import Any, Callable, Optional, Sequence, Union + +from absl import logging +from flax import core as flax_core +import jax +import numpy as np + +from t5x import adafactor +from t5x import partitioning as t5x_partitioning +from t5x import train_state as train_state_lib + +from t5x.contrib.moe import training_utils + +DataLayout = t5x_partitioning.DataLayout +FlaxOptimTrainState = train_state_lib.FlaxOptimTrainState +HardwareMesh = t5x_partitioning.HardwareMesh +InferenceState = train_state_lib.InferenceState +LogicalAxisRules = t5x_partitioning.LogicalAxisRules +PartitionSpec = t5x_partitioning.PartitionSpec +Pytree = Any +TrainState = train_state_lib.TrainState + + +class MoePjitPartitioner(t5x_partitioning.PjitPartitioner): + """Pjit partitioner with overrides for Mixture of Experts support. 
+ + This MoE partitioner has two overrides relative to the default partitioner: + (1) It prepends an 'expert' axis to all MoE optimizer state terms, so that + they are sharded along the 'expert' axis; see get_logical_axes(). + (2) In cases where model parallelism is used and the number of experts is less + than the number of devices, we treat the 'model' axis as a secondary data + axis. This allows us to decouple expert parallelism ('data' mesh axis) + from data parallelism ('data' and 'model' axes). + """ + + def __init__(self, + num_experts: int, + num_partitions: Optional[int] = None, + model_parallel_submesh: Optional[HardwareMesh] = None, + params_on_devices: bool = True, + logical_axis_rules: Optional[LogicalAxisRules] = None, + state_filter_fn: Optional[Callable[[str], bool]] = None): + """Configures the partitioner. + + Args: + num_experts: Total number of experts across all devices. + num_partitions: Specifies the size of the model parallel submesh to be + automatically selected for the current topology. See + `model_parallel_submesh` for details on how this submesh is used. + Mutually exclusive with `model_parallel_submesh`. + model_parallel_submesh: 4-tuple that specifies the `(x, y, z, c)` submesh + model-parallel device tile -- an axis of accelerator parallelism + orthogonal to data parallelism. See t5x/partitioning.py for details. + This argument is mutually exclusive with `num_partitions`. + params_on_devices: Whether to keep the params on devices. If False, params + stay in the host memory. + logical_axis_rules: A priority-ordered sequence of KV tuples that maps + logical axis names to either `None` (not sharded), 'model' (to shard + across the model-parallel submesh), or 'data' (to shard across the + data-parallel submesh). + state_filter_fn: Function to identify which optimizer state axis rules + should be overridden to be sharded along the 'expert' axis. If None + (default), Adafactor expert sharding overrides are used. + """ + # If True, treat 'model' axis as secondary data axis. + self.two_data_axes = _override_model_axis(num_experts, num_partitions, + model_parallel_submesh) + if self.two_data_axes: + # Override num_partitions to repurpose the 'model' axis as a secondary + # data axis, along which only the batch is sharded. Experts will be + # replicated along this secondary data axis. + num_partitions = jax.device_count() // num_experts + + # Override user specified model parallel submesh. Rely on T5X partitioning + # to determine new submesh from updated `num_partitions`. + logging.info( + 'Overriding user specified `model_parallel_submesh`=%s to support ' + 'expert parallelism for updated `num_partitions`=%d', + model_parallel_submesh, num_partitions) + model_parallel_submesh = None + + super().__init__( + num_partitions=num_partitions, + model_parallel_submesh=model_parallel_submesh, + params_on_devices=params_on_devices, + logical_axis_rules=logical_axis_rules) + + self._state_filter_fn = state_filter_fn + + def get_data_layout(self, + batch_size: Optional[int] = None, + host_index: Optional[int] = None) -> DataLayout: + """Returns filled `DataLayout` based on the partitioned model layout. + + Overrides default data layout in case were both mesh axes ('data' and + 'model') are treated as data axes. + + Args: + batch_size: If set, indicates the requested batch size. If not set, the + batch size is inferred from the layout. + host_index: Indicates the host index to use for the calculations, if not + set - use JAX-provided one. 
Should be in [0, num_hosts) interval and the + order should match the order of corresponding CPU devices in + `jax.devices()`. + + Returns: + Filled `DataLayout` structure. + """ + if self.two_data_axes: + if host_index is not None: + raise NotImplementedError('Explicit host_index is not yet implemented.') + mesh_size = self._local_chunker.global_mesh.shape[ + 'data'] * self._local_chunker.global_mesh.shape['model'] + batch_size = batch_size or mesh_size + if batch_size % mesh_size: + raise ValueError( + f'Batch size ({batch_size}) must be divisible by corresponding ' + f'mesh size ({mesh_size}).') + num_shards = self._local_chunker.num_chunks['data'] + if batch_size % num_shards: + raise ValueError( + f'Batch size ({batch_size}) must be divisible by number of ' + f'replicas ({num_shards}).') + replica_id = self._local_chunker.get_local_chunk_info( + (batch_size,), ('data', 'model')).replica_id + return DataLayout( + batch_size=batch_size, + shard_id=self._local_chunker.chunk_ids['data'], + num_shards=num_shards, + is_first_host_in_replica_set=(replica_id == 0)) + else: + return super().get_data_layout(batch_size, host_index) + + def get_logical_axes( + self, train_state: Union[FlaxOptimTrainState, InferenceState] + ) -> Union[FlaxOptimTrainState, InferenceState]: + """Returns a copy of TrainState with Optional[AxisNames] as leaves. + + Overrides the default logical axes by prepending the 'expert' axis to any + MoE optimizer state terms (identified by self._state_filter_fn) so they are + correctly sharded along the 'expert' axis. + + Args: + train_state: Object holding all relevant training of inference state. + + Returns: + State object matching structure of input train_state but with axis names + as leaves. + """ + logical_axes = train_state.as_logical_axes() + + if isinstance(logical_axes, InferenceState): + # InferenceState does not contain any optimizer state, so we skip all + # expert partitioning overrides. + return logical_axes + else: + train_state: FlaxOptimTrainState + + state_filter_fn = ( + self._state_filter_fn or _infer_state_filter_fn(train_state)) + if state_filter_fn is None: + # No state updates required. + return logical_axes + + prepend_expert = lambda x: PartitionSpec( # pylint: disable=g-long-lambda + 'expert',) + x if x else PartitionSpec('expert',) + optimizer_axes = logical_axes._optimizer # pylint: disable=protected-access + state_dict = flax_core.unfreeze(optimizer_axes.state_dict()) + state_dict['state']['param_states'] = training_utils.tree_map_with_names( + prepend_expert, state_dict['state']['param_states'], state_filter_fn) + + return train_state.restore_state(state_dict) + + def partition( + self, + fn: Callable, # pylint: disable=g-bare-generic + in_axis_resources: Pytree, + out_axis_resources: Pytree, + static_argnums: Union[int, Sequence[int]] = (), + donate_argnums: Union[int, Sequence[int]] = () + ) -> t5x_partitioning.PjittedFnWithContext: + """Partitions the computation using pjit. + + Overrides the default pjit partitioning in cases where expert and data axes + are decoupled -- wherein we treat the 'model' axis as a secondary data axis. + + Args: + fn: Function to partition. + in_axis_resources: Pytree of structure matching that of arguments to `fn`, + with all actual arguments replaced by resource assignment + specifications. + out_axis_resources: Like `in_axis_resources`, but specifies resource + assignment for function outputs. 
+ static_argnums: Specifies which positional arguments to treat as static + (compile-time constant) in the partitioned function. + donate_argnums: Specifies which argument buffers are "donated" to the + computation. + + Returns: + A partitioned version of the input function. + """ + if self.two_data_axes: + # Both axes are used for data parallelism in this case, so we override the + # partition specs. + in_axis_resources = _override_partition_specs(in_axis_resources) + out_axis_resources = _override_partition_specs(out_axis_resources) + + pjitted = t5x_partitioning.pjit( + fn, + in_axis_resources=in_axis_resources, + out_axis_resources=out_axis_resources, + static_argnums=static_argnums, + donate_argnums=donate_argnums, + backend=self._backend) + + return t5x_partitioning.PjittedFnWithContext(pjitted, self.mesh, + self._logical_axis_rules) + + +def standard_logical_axis_rules( + num_experts: int, + num_partitions: Optional[int] = None, + model_parallel_submesh: Optional[HardwareMesh] = None, + activation_partitioning_dims: int = 1, + parameter_partitioning_dims: int = 1, + additional_rules: Optional[LogicalAxisRules] = None): + """Returns partitioning rules for MoE models. + + The partitioning rules vary based on whether the expert and data axes need to + be decoupled; see also MoePjitPartitioner for details of when expert and data + axes need to be decouple. + + 2D parameter sharding (`parameter_partitioning_dims=2`) is not supported. + Sharding parameters along the 'data' axis will interfere with expert + parallelism, because experts are also partitioned along the 'data' axis. + + Args: + num_experts: Total number of experts across all devices. + num_partitions: Size of the model parallel submesh. Model parallelism is + only used if num_model_partitions > 1. Ignored if model_parallel_submesh + is specified. + model_parallel_submesh: 4-tuple that specifies the `(x, y, z, c)` submesh + model-parallel device tile -- an axis of accelerator parallelism + orthogonal to data parallelism. Model parallelism is only used if + np.prod(model_parallel_submesh) > 1. Mutually exclusive with + `num_partitions`. + activation_partitioning_dims: Enables 2-D activation sharding when set to 2. + parameter_partitioning_dims: Enables 2-D parameter sharding when set to 2. + additional_rules: Additional rules (a sequence of tuples) that will be + appended to the standard rules. + + Returns: + Sequence of logical axis rules. + + Raises: + ValueError if parameter_partitioning_dims=2. + """ + if parameter_partitioning_dims == 2: + raise ValueError('2D parameter sharding (`parameter_partitioning_dims=2`) ' + 'is not supported for MoE.') + + default_rules = t5x_partitioning.standard_logical_axis_rules( + activation_partitioning_dims, parameter_partitioning_dims) + moe_rules = [ + ('expert', 'data'), # Shard experts along the data axis + ('expert_mlp', 'model'), # Expert MLPs partitioned along model axis + ('expert_group', None), # Replicated axis for all-to-all constraints + ('expert_replicas', None), # Experts replicated along this axis + ('unmodeled', None), # Replicated weights + ] + standard_rules = list(default_rules) + moe_rules + if additional_rules: + standard_rules.extend(additional_rules) + + if _override_model_axis(num_experts, num_partitions, model_parallel_submesh): + overridden_rules = [] + for logical_axis, mesh_axis in standard_rules: + if logical_axis == 'batch': + # Because we now treat the 'model' axis as a second data axis, we want + # to shard batches across both axes. 
+ overridden_mesh_axis = ('data', 'model') + elif logical_axis == 'expert_replicas': + # "model" axis is repurposed as a second data axis, along which experts + # are replicated. + overridden_mesh_axis = 'model' + elif mesh_axis == 'model': + # Any weights ordinarily partitioned along the model axis, should be + # explicitly replicated. + overridden_mesh_axis = None + else: + overridden_mesh_axis = mesh_axis + overridden_rules.append((logical_axis, overridden_mesh_axis)) + + return overridden_rules + + else: + return standard_rules + + +def data_partition_spec(two_data_axes: bool) -> PartitionSpec: + """Returns data partitioning spec. + + Args: + two_data_axes: If True, use 'model' axis as secondary data axis. Otherwise, + only use 'data' axis for data sharding. + + Returns: + Mesh dependent partition spec. + """ + if two_data_axes: + # Use 'model' axis as secondary data axis. Shard batches across both axes. + return PartitionSpec(('data', 'model'),) + else: + return PartitionSpec('data',) + + +def _override_model_axis( + num_experts: int, num_partitions: Optional[int], + model_parallel_submesh: Optional[HardwareMesh]) -> bool: + """Returns true iff there is no model parallelism & num experts < num devices. + + Args: + num_experts: Total number of experts across all devices. + num_partitions: Size of the model parallel submesh. Model parallelism is + only used if num_model_partitions > 1. Mutually exclusive with + `model_parallel_submesh`. + model_parallel_submesh: 4-tuple that specifies the `(x, y, z, c)` submesh + model-parallel device tile -- an axis of accelerator parallelism + orthogonal to data parallelism. Model parallelism is only used if + np.prod(model_parallel_submesh) > 1. Mutually exclusive with + `num_partitions`. + + Returns: + True if there is no model parallelism & num experts < num devices; False + otherwise. + """ + if (num_partitions is None) == (model_parallel_submesh is None): + raise ValueError( + 'One, and only one, of {num_partitions, model_parallel_submesh} must ' + 'be specified. Received: %s and %s' % + (num_partitions, model_parallel_submesh)) + + if num_experts == 0 or jax.device_count() <= num_experts: + # No expert replication required. No need to override model mesh axis. + return False + + return ((num_partitions is not None and num_partitions <= 1) or + (model_parallel_submesh is not None and + np.prod(model_parallel_submesh) <= 1)) + + +def _override_partition_specs(resources: Pytree): + """Override axis resources for two data axes setup. + + In the two data axes setup, we treat the 'model' axis as a secondary data + axis. To this end, we override any hardcoded, raw partition specs: + - PartitionSpec('data',) -> PartitionSpec(('data', 'model'),) + - PartitionSpec('model',) -> None + There is no need to override any params or optimizer state as these will + inherit the correct specs from the logical axis rules; see + standard_logical_axis_rules(). + + Args: + resources: Axis resource assignment specifications. + + Returns: + Axis resources with partition specs overridden to use 'model' as secondary + 'data' axis. + """ + + def _maybe_overridde_spec(axis_resource: Pytree): + """Overrides "data" and "model" partition specs; leaves others unchanged.""" + if axis_resource == PartitionSpec('data',): + # Shard all batches across both axes. + return PartitionSpec(('data', 'model'),) + elif axis_resource == PartitionSpec('model',): + # No model parallelism. 
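+      # The 'model' axis now acts as a second data axis, so anything that was
+      # partitioned along it must instead be replicated.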
+ return None + else: + return axis_resource + + if resources is None: + return resources + elif not isinstance(resources, Sequence): + return _maybe_overridde_spec(resources) + else: + overridden_resources = [] + for resource in resources: + overridden_resources.append(_maybe_overridde_spec(resource)) + return tuple(overridden_resources) + + +def _infer_state_filter_fn( + train_state: FlaxOptimTrainState) -> Optional[Callable[[str], bool]]: + """Infers relevant regex matching sharded expert model state for optimizer. + + Only the Adafactor optimizer is currently supported. + + The model state generally inherits the correct partitioning specs from the + model parameters, except in cases where the kernel is factored (`v_col` and + `v_row` terms); see derive_logical_axes(): + https://github.com/google-research/t5x/blob/main/t5x/adafactor.py#L591. For + those cases, we use the state_filter_fn to identify the factored kernel terms + that need to be partitioned along the expert axis. + + Args: + train_state: Object holding optimizer and optimizer state (parameters). + + Returns: + Function to identify which model state is sharded along 'expert' axis. + + Raises: + ValueError if optimizer (on train state) is not an Adafactor optimizer. + """ + optimizer = train_state._optimizer # pylint: disable=protected-access + optimizer_def = optimizer.optimizer_def + + # TODO(jamesleethorp): Revisit once other T5X optimizers are available. + if not isinstance(optimizer_def, adafactor.Adafactor): + raise ValueError('Inferred MoE overrides are currently only available for ' + f'the Adafactor optimizer. Received: {optimizer_def}') + + if optimizer_def.hyper_params.factored: + # Factored kernel terms (`v_col` and `v_row`) need to be identified for + # expert sharding. + return training_utils.match_fn(r'.*expert.*/kernel/v_.*') + else: + # Non-factored kernel terms (`v`) inherit the correct specs, so no state + # updates will be required. + return None diff --git a/t5x/contrib/moe/partitioning_test.py b/t5x/contrib/moe/partitioning_test.py new file mode 100644 index 0000000000000000000000000000000000000000..14019b25713061735fa2763e85645b1fabb77c03 --- /dev/null +++ b/t5x/contrib/moe/partitioning_test.py @@ -0,0 +1,337 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
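The condition under which `MoePjitPartitioner` repurposes the 'model' mesh axis as a second data axis (see `_override_model_axis` above) can be summarized in a few lines. The sketch below is a simplified, illustrative paraphrase rather than the code from the diff; the helper name and the device/expert counts are made up:

```python
import jax


def replicate_experts_over_model_axis(num_experts: int,
                                      num_partitions: int) -> bool:
  """True iff model parallelism is off and there are fewer experts than devices."""
  if num_experts == 0 or jax.device_count() <= num_experts:
    return False  # No expert replication needed: at least one expert per device.
  return num_partitions <= 1  # No model parallelism requested.


# Illustration: 32 devices, 8 experts, no model parallelism. The partitioner
# would then pick num_partitions = 32 // 8 = 4, i.e. a 'model' axis of size 4
# that is used purely as an extra data axis, with the experts replicated
# 4 times along it.
```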
+ +"""Tests for partitioning.""" + +from typing import Any + +from absl.testing import absltest +from flax import core as flax_core +from flax import optim +from flax.linen import partitioning as flax_partitioning +import jax +import numpy as np +from t5x import train_state as train_state_lib + +from t5x.contrib.moe import partitioning as moe_partitioning +from t5x.contrib.moe import training_utils + +mock = absltest.mock + +AxisMetadata = flax_partitioning.AxisMetadata +DataLayout = moe_partitioning.DataLayout +FlaxOptimTrainState = train_state_lib.FlaxOptimTrainState +InferenceState = train_state_lib.InferenceState +PartitionSpec = moe_partitioning.PartitionSpec +PRNGKey = Any + + +class LogicalAdam(optim.Adam): + """Subclass of Adam optimizer with T5X logical axis partitioning support.""" + + def derive_logical_axes(self, optimizer_state, param_logical_axes): + """Derives optimizer logical partitioning from model logical partitions.""" + del param_logical_axes # Return fixed axes for test + optimizer_logical_axes = { + 'state': { + 'param_states': { + 'logits_dense': { + 'grad_ema': None, + 'grad_sq_ema': None + }, + 'mlp': { + 'wo': { + 'kernel': { + 'grad_ema': PartitionSpec('embed', 'mlp'), + 'grad_sq_ema': None + } + } + } + }, + 'step': None + }, + 'target': { + 'logits_dense': PartitionSpec('vocab', 'embed'), + 'mlp': { + 'wo': { + 'kernel': PartitionSpec('embed', 'mlp'), + }, + }, + } + } + return optimizer_state.restore_state(optimizer_logical_axes) + + +def create_optimizer(): + """Creates simple Adam optimizer.""" + target = { + 'logits_dense': np.ones((16, 16), np.float32), + 'mlp': { + 'wo': { + 'kernel': np.ones((32, 16), np.float32) + } + } + } + return LogicalAdam(learning_rate=1e-4).create(target) + + +class PartitioningTest(absltest.TestCase): + + def test_default_data_layout(self): + # No expert replication required. Use default data layout. + partitioner = moe_partitioning.MoePjitPartitioner( + num_experts=8, num_partitions=1) + self.assertFalse(partitioner.two_data_axes) + self.assertEqual( + partitioner.get_data_layout(batch_size=32), + DataLayout( + batch_size=32, + shard_id=0, + num_shards=1, + is_first_host_in_replica_set=True)) + + def test_two_data_axis_layout_override(self): + partitioner = moe_partitioning.MoePjitPartitioner( + num_experts=8, num_partitions=1) + # Force override case to check layout is valid. + partitioner.two_data_axes = True + partitioner._data_axis = ('data', 'model') + self.assertEqual( + partitioner.get_data_layout(batch_size=8), + DataLayout( + batch_size=8, + shard_id=0, + num_shards=1, + is_first_host_in_replica_set=True)) + + def test_logical_axes_for_moe_partitioner_no_overrides(self): + partitioner = moe_partitioning.MoePjitPartitioner( + num_experts=8, + num_partitions=1, + state_filter_fn=training_utils.match_fn(r'no_state_matching')) + + optimizer = create_optimizer() + train_state = FlaxOptimTrainState( + optimizer, + params_axes={ + 'logits_dense_axes': AxisMetadata(names=('vocab', 'embed')), + 'mlp': { + 'wo': { + 'kernel_axes': AxisMetadata(names=('embed', 'mlp')) + } + } + }) + + logical_axes = partitioner.get_logical_axes(train_state) + + # No updates to state. Should match what derive_logical_axes() returns. 
+ jax.tree_map(self.assertIsNone, logical_axes.param_states['logits_dense']) + self.assertEqual(logical_axes.param_states['mlp']['wo']['kernel'].grad_ema, + PartitionSpec('embed', 'mlp')) + self.assertIsNone( + logical_axes.param_states['mlp']['wo']['kernel'].grad_sq_ema) + + self.assertEqual( + logical_axes.params, { + 'logits_dense': PartitionSpec('vocab', 'embed'), + 'mlp': { + 'wo': { + 'kernel': PartitionSpec('embed', 'mlp') + } + } + }) + + def test_logical_axes_for_moe_partitioner_with_overrides(self): + partitioner = moe_partitioning.MoePjitPartitioner( + num_experts=8, + num_partitions=1, + state_filter_fn=training_utils.match_fn(r'.*mlp.*')) + + optimizer = create_optimizer() + train_state = FlaxOptimTrainState( + optimizer, + params_axes={ + 'logits_dense_axes': AxisMetadata(names=('vocab', 'embed')), + 'mlp': { + 'wo': { + 'kernel_axes': AxisMetadata(names=('embed', 'mlp')) + } + } + }) + + logical_axes = partitioner.get_logical_axes(train_state) + + jax.tree_map(self.assertIsNone, logical_axes.param_states['logits_dense']) + # 'mlp' params should be prepended with 'expert' spec because + # state_filter_fn matches '.*mlp.*'. + self.assertEqual(logical_axes.param_states['mlp']['wo']['kernel'].grad_ema, + PartitionSpec('expert', 'embed', 'mlp')) + self.assertEqual( + logical_axes.param_states['mlp']['wo']['kernel'].grad_sq_ema, + PartitionSpec('expert',)) + + self.assertEqual( + logical_axes.params, { + 'logits_dense': PartitionSpec('vocab', 'embed'), + 'mlp': { + 'wo': { + 'kernel': PartitionSpec('embed', 'mlp') + } + } + }) + + def test_inference_state_logical_axes(self): + partitioner = moe_partitioning.MoePjitPartitioner( + num_experts=8, num_partitions=1) + + model_variables = flax_core.freeze({ + 'params': { + 'dense': { + 'bias': np.zeros(4), + 'kernel': np.zeros((2, 4)) + } + }, + 'params_axes': { + 'dense': { + 'bias_axes': AxisMetadata(names=('embed',)), + 'kernel_axes': AxisMetadata(names=('vocab', 'embed')), + } + }, + }) + train_state = InferenceState.create(model_variables) + logical_axes = partitioner.get_logical_axes(train_state) + + # No expert axis overrides to InferenceState. Partition specs should match + # input axis metadata. + self.assertEqual( + logical_axes, + InferenceState( + step=None, + params=flax_core.FrozenDict({ + 'dense': { + 'bias': PartitionSpec('embed',), + 'kernel': PartitionSpec('vocab', 'embed'), + }, + }))) + + @mock.patch('jax.device_count') + def test_overridden_logical_axis_rules(self, device_count: int): + device_count.return_value = 4 + # Fewer experts than devices --> modified axis rules with two 'batch' axes. + self.assertEqual( + moe_partitioning.standard_logical_axis_rules( + num_experts=1, + num_partitions=1, + model_parallel_submesh=None, + additional_rules=[('additional', 'model'), + ('expert_magic', 'data')]), + [ + ('batch', ('data', 'model')), # Shard batch over entire mesh + # No sharding of weights over model axis. 
+ ('vocab', None), + ('embed', None), + ('mlp', None), + ('heads', None), + ('kv', None), + ('joined_kv', None), + ('relpos_buckets', None), + ('abspos_buckets', None), + ('length', None), + ('layers', None), + ('stack', None), + ('mlp_activations', None), + ('expert', 'data'), # Shard experts over "first" data axis only + ('expert_mlp', None), + ('expert_group', None), + # Experts replicated along "second" data axis + ('expert_replicas', 'model'), + ('unmodeled', None), + ('additional', None), + ('expert_magic', 'data'), + ]) + + def test_default_logical_axis(self): + # Model parallelism used --> default logical axis rules. + self.assertEqual( + moe_partitioning.standard_logical_axis_rules( + num_experts=1, + num_partitions=2, + model_parallel_submesh=None, + additional_rules=[('additional', 'model')]), + [ + ('batch', 'data'), # Shard batch over single data axis + # Default model annotations used. + ('vocab', 'model'), + ('embed', None), + ('mlp', 'model'), + ('heads', 'model'), + ('kv', None), + ('joined_kv', 'model'), + ('relpos_buckets', None), + ('abspos_buckets', None), + ('length', None), + ('layers', None), + ('stack', None), + ('mlp_activations', None), + ('expert', 'data'), # Shard experts along data axis + ('expert_mlp', 'model'), + ('expert_group', None), + ('expert_replicas', None), + ('unmodeled', None), + ('additional', 'model'), + ]) + + def test_2d_parameter_sharding_unsupported(self): + with self.assertRaisesRegex(ValueError, 'is not supported for MoE.'): + moe_partitioning.standard_logical_axis_rules( + num_experts=4, num_partitions=1, parameter_partitioning_dims=2) + + def test_data_partition_spec(self): + self.assertEqual( + moe_partitioning.data_partition_spec(two_data_axes=False), + PartitionSpec('data',)) + self.assertEqual( + moe_partitioning.data_partition_spec(two_data_axes=True), + PartitionSpec(('data', 'model'),)) + + @mock.patch('jax.device_count') + def test_when_to_override_model_axis(self, device_count: int): + device_count.return_value = 4 + + # More experts than devices. + self.assertFalse( + moe_partitioning._override_model_axis( + num_experts=8, num_partitions=1, model_parallel_submesh=None)) + + # Fewer experts than devices. + self.assertTrue( + moe_partitioning._override_model_axis( + num_experts=1, num_partitions=1, model_parallel_submesh=None)) + + # Model parallelism used. + self.assertFalse( + moe_partitioning._override_model_axis( + num_experts=1, num_partitions=2, model_parallel_submesh=None)) + + def test_axis_resource_overrides(self): + input_resources = (PartitionSpec('data'), PartitionSpec('model'), None, + PartitionSpec('unrecognized')) + overridden_resources = moe_partitioning._override_partition_specs( + input_resources) + # "data" -> ("data", "model"). "model" -> None. + self.assertEqual(overridden_resources, (PartitionSpec( + ('data', 'model'),), None, None, PartitionSpec('unrecognized',))) + +if __name__ == '__main__': + absltest.main() diff --git a/t5x/contrib/moe/trainer.py b/t5x/contrib/moe/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..930618473181726007637cb7377b390ee1ce1613 --- /dev/null +++ b/t5x/contrib/moe/trainer.py @@ -0,0 +1,138 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Trainer with Mixture of Experts support.""" + +from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING + +import cached_property +from t5x import models +from t5x import train_state as train_state_lib +from t5x import trainer +from t5x.contrib.moe import partitioning +from t5x.contrib.moe import training_utils + +BatchType = trainer.BatchType +LearningRateCallable = trainer.LearningRateCallable +MetricMapType = trainer.MetricMapType +PartitionSpec = partitioning.PartitionSpec +PartitionedTrainCallable = trainer.PartitionedTrainCallable +Rng = trainer.Rng + +if TYPE_CHECKING: # See b/163639353 + cached_property = property # pylint: disable=invalid-name +else: + cached_property = cached_property.cached_property + + +class MoeTrainer(trainer.Trainer): + """T5X trainer with overrides for Mixture of Experts support.""" + + def __init__( + self, + model: models.BaseModel, + train_state: train_state_lib.TrainState, + partitioner: partitioning.MoePjitPartitioner, + eval_names: Sequence[str], + summary_dir: Optional[str], + train_state_axes: Any, + rng: Rng, + learning_rate_fn: LearningRateCallable, + num_microbatches: Optional[int], + num_experts: int, + sharded_match_fn: Optional[Callable[ + [str], bool]] = training_utils.match_fn(r'.*expert.*'), + weight_metrics_computer: Optional[trainer.WeightMetricsComputer] = None): + """Trainer constructor. + + Args: + model: the instantiation of `BaseModel` to train. + train_state: a train state with parameters and optimizer state. + partitioner: the partitioner to use. + eval_names: names of evaluation datasets, which must match the keys of the + mapping passed to `eval`. + summary_dir: optional directory to write TensorBoard metrics to. + train_state_axes: partitioning info for the optimizer to be used. + rng: jax PRNGKey seed for random operations, to be combined with step + number for a deterministic RNG. + learning_rate_fn: returns the learning rate given the current step. + num_microbatches: the number of microbatches to use, or None for direct + training. + num_experts: Global number of experts. Used to scale sharded parameter + gradients. + sharded_match_fn: Filter function for distinguishing sharded (MoE) + parameters from replicated parameters. Used to identify the sharded + parameter gradients that need to be rescaled under pjit training. + weight_metrics_computer: A WeightMetricsComputer instance, or None, to + decide what metrics, if any, to log about weights and weight updates + during training. 
+ """ + super().__init__( + model=model, + train_state=train_state, + partitioner=partitioner, + eval_names=eval_names, + summary_dir=summary_dir, + train_state_axes=train_state_axes, + rng=rng, + learning_rate_fn=learning_rate_fn, + num_microbatches=num_microbatches, + weight_metrics_computer=weight_metrics_computer) + + self._num_experts = num_experts + self._sharded_match_fn = sharded_match_fn + self.data_partition_spec = partitioning.data_partition_spec( + partitioner.two_data_axes) + + @cached_property + def _partitioned_train_step(self) -> PartitionedTrainCallable: + """Same as a regular T5X train step, but scales expert parameter gradients. + + We must scale expert parameter gradients by the number of experts to account + for pjit's implicit averaging over partitioned parameter gradients. + + Returns: + Partitioned train step function. + """ + + def train_with_lr(train_state: train_state_lib.TrainState, + batch: BatchType): + grad_accum, metrics, flax_mutables = ( + trainer.accumulate_grads_microbatched( + self._model, + train_state, + batch, + self._get_step_rng(train_state.step), + self._num_microbatches, + data_partition_spec=self.data_partition_spec)) + + # Only difference between this train step and regular T5X train step: + scaled_grads = training_utils.scale_sharded_grads( + grad_accum, self._sharded_match_fn, scale_factor=self._num_experts) + + new_train_state, metrics = trainer.apply_grads( + train_state, + scaled_grads, + metrics, + self._learning_rate_fn(train_state.step), + self._weight_metrics_computer, + other_state_variables={'flax_mutables': flax_mutables} + if flax_mutables else None) + return new_train_state, metrics + + return self._partitioner.partition( + train_with_lr, + in_axis_resources=(self._train_state_axes, self.data_partition_spec), + out_axis_resources=(self._train_state_axes, None), + donate_argnums=(0,)) diff --git a/t5x/contrib/moe/trainer_test.py b/t5x/contrib/moe/trainer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f7dcc94c9305ecc10c9eaac1150b49effea055b6 --- /dev/null +++ b/t5x/contrib/moe/trainer_test.py @@ -0,0 +1,166 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trainer.""" + +import contextlib + +from absl.testing import absltest +from flax import optim +import jax +import numpy as np +from t5x import metrics as metrics_lib +from t5x import models as models_lib +from t5x import train_state as train_state_lib +from t5x.contrib.moe import partitioning +from t5x.contrib.moe import trainer as trainer_lib +import tensorflow as tf + +mock = absltest.mock +jax.config.parse_flags_with_absl() + + +# Make `log_elapsed_time` a no-op to simplify mocking of `time.time()`. 
+@contextlib.contextmanager +def fake_log_elapsed_time(_): + yield + + +jax._src.dispatch.log_elapsed_time = fake_log_elapsed_time + + +def fake_accum_grads(model, optimizer, batch, rng, num_microbatches, + data_partition_spec): + del model, num_microbatches, rng, data_partition_spec + # Add `i` to each optimzer value. + i = batch['i'].sum() + grad_accum = jax.tree_map(lambda x: i, optimizer) + # Add j to each metric. + j = batch['j'].sum() + metrics = { + 'loss': metrics_lib.Sum.from_model_output(j), + 'accuracy': metrics_lib.Sum.from_model_output(j) + } + return grad_accum, metrics, None + + +def fake_apply_grads(optimizer, + grad_accum, + metrics, + learning_rate, + weight_metrics_computer, + other_state_variables=None): + del weight_metrics_computer + del other_state_variables + metrics['learning_rate'] = metrics_lib.Sum.from_model_output(learning_rate) + optimizer = jax.tree_multimap(lambda x, g: x + g, optimizer, grad_accum) + return optimizer, metrics + + +class MoeTrainerTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.init_optimizer = optim.Optimizer( + optim.GradientDescent(), + state=optim.OptimizerState( + step=0, param_states={ + 'expert_bias': 0, + 'kernel': 0 + }), + target={ + 'expert_bias': np.zeros(4), + 'kernel': np.zeros((2, 4)) + }) + self.init_train_state = train_state_lib.FlaxOptimTrainState( + self.init_optimizer) + train_state_axes = jax.tree_map(lambda x: None, self.init_train_state) + model_dir = self.create_tempdir().full_path + + mapfn = lambda i: {'i': [tf.cast(i, tf.int32)], 'j': [tf.cast(1, tf.int32)]} + self.dataset = tf.data.Dataset.range(6).map(mapfn).batch( + 2, drop_remainder=True) + + num_experts = 10 + self.test_trainer = trainer_lib.MoeTrainer( + model=mock.create_autospec(models_lib.BaseModel, instance=True), + train_state=self.init_train_state, + partitioner=partitioning.MoePjitPartitioner( + num_experts=num_experts, num_partitions=1), + eval_names=['task1', 'task2'], + summary_dir=model_dir, + train_state_axes=train_state_axes, + rng=np.ones(2, np.uint32), + learning_rate_fn=lambda step: 2 * step, + num_microbatches=None, + num_experts=num_experts) + + @mock.patch('time.time') + @mock.patch('t5x.trainer.accumulate_grads_microbatched', fake_accum_grads) + @mock.patch('t5x.trainer.apply_grads', fake_apply_grads) + @mock.patch('absl.logging.log', lambda *_: None) # avoids time.time() calls + def _test_train(self, precompile, mock_time=None): + trainer = self.test_trainer + initial_rng = trainer._base_rng + + if precompile: + mock_time.side_effect = [0, 1] + trainer.compile_train(next(self.dataset.as_numpy_iterator())) + trainer._compiled_train_step = mock.Mock( + side_effect=trainer._compiled_train_step) + + trainer._partitioned_train_step = mock.Mock( + side_effect=trainer._partitioned_train_step) + + # train start, logging, train end, logging + mock_time.side_effect = [1, 5] + num_steps = 2 + trainer.train(self.dataset.as_numpy_iterator(), num_steps) + + # Base rng must remain the same. 
+ np.testing.assert_array_equal(trainer._base_rng, initial_rng) + + expected_optimizer = optim.Optimizer( + self.init_optimizer.optimizer_def, + state=optim.OptimizerState( + step=[6], + param_states={ + 'expert_bias': 60, # 10 * (0+1+2+3) = 60 + 'kernel': 6 # 0+1+2+3 = 6 + }), + target={ + 'expert_bias': 60 * np.ones(4), + 'kernel': 6 * np.ones((2, 4)) + }) + expected_train_state = train_state_lib.FlaxOptimTrainState( + expected_optimizer) + jax.tree_multimap(np.testing.assert_allclose, trainer.train_state, + expected_train_state) + + if precompile: + self.assertEqual(trainer._compiled_train_step.call_count, num_steps) + trainer._partitioned_train_step.assert_not_called() + else: + self.assertIsNone(trainer._compiled_train_step) + self.assertEqual(trainer._partitioned_train_step.call_count, num_steps) + + def test_train_noprecompile(self): + self._test_train(False) + + def test_train_precompile(self): + self._test_train(True) + + +if __name__ == '__main__': + absltest.main() diff --git a/t5x/contrib/moe/training_utils.py b/t5x/contrib/moe/training_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6a181c570e3cd17bec8bfebf905973a997f4932d --- /dev/null +++ b/t5x/contrib/moe/training_utils.py @@ -0,0 +1,138 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Extensions to Jax/Flax core functions for Mixture of Experts training. + +""" + +import dataclasses +import re +from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union + +import flax +import jax +import numpy as np +from t5x import train_state + +# Type Stubs +ParamTree = Any +PyTreeDef = Any +Gradients = Union[flax.core.FrozenDict, train_state.TrainState] + + +def match_fn(prefix: Optional[str]) -> Callable[[str], bool]: + """Creates a function returning true iff a string matches the prefix. + + Args: + prefix: Regex prefix to match. If none, then return match function will not + match any strings. + + Returns: + Prefix match function. + """ + if not prefix: + return lambda name: False + params_regex = re.compile(f'^{prefix}') + return lambda name: params_regex.match(name) is not None + + +def scale_sharded_grads(grads: Gradients, + sharded_match_fn: Optional[Callable[[str], bool]], + scale_factor: float) -> Gradients: + """Scales sharded grads, identified by sharded_match_fn, by scale_factor. + + Args: + grads: Parameter gradients. + sharded_match_fn: Filter function for distinguishing sharded parameters from + replicated parameters. + scale_factor: Amount by which to scale sharded parameter gradients. + + Returns: + Gradients matching input, expect with sharded parameter gradients rescaled. 
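+
+  Example (illustrative; the tree keys and scale factor are made up):
+
+    grads = {'moe_layer': {'expert_mlp': g_expert}, 'dense': {'kernel': g_d}}
+    scaled = scale_sharded_grads(grads, match_fn(r'.*expert.*'),
+                                 scale_factor=8.)
+    # g_expert is multiplied by 8, compensating for pjit's implicit averaging
+    # over the 8 expert shards (see MoeTrainer); g_d is returned unchanged.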
+ """ + if sharded_match_fn: + names_and_grads, tree_def = _tree_flatten_with_names(grads) + scaled_grads = [ + grad * scale_factor if sharded_match_fn(name) else grad + for name, grad in names_and_grads + ] + return tree_def.unflatten(scaled_grads) + else: + return grads + + +def tree_map_with_names(f, param_tree, match_name_fn=lambda name: True): + """Like jax.tree_map but with a filter on the leaf path name. + + Args: + f: The function to be applied to each parameter in `param_tree`. + param_tree: The tree of parameters `f` should be applied to. + match_name_fn: This function is called with each tree leave's path name, + which has a path-like format ('a/b/c'), and decides whether `f` should be + applied to that leaf or the leaf should be kept as-is. + + Returns: + A tree identical in structure to `param_tree` but with the leaves the + result of calling `f` on them in the cases where `match_name_fn` returns + True for that leaf's path name. + """ + names_and_vals, tree_def = _tree_flatten_with_names(param_tree) + vals = [f(v) if match_name_fn(name) else v for name, v in names_and_vals] + return tree_def.unflatten(vals) + + +def _tree_flatten_with_names( + tree: ParamTree) -> Tuple[Sequence[Tuple[str, Any]], PyTreeDef]: + """Like jax.tree_flatten but also fetches leaf names. + + Specialized to parameter trees of the form {'key0': {'subkey0': Any}, ...}. + + Args: + tree: Tree of parameters to flatten. + + Returns: + - A list of leaf name and value pairs: [(name, value), ...]. + - A tree definition object representing the structure of the flattened tree. + """ + # PyTrees don't treat None values as leaves, so we explicitly declare them as + # such. + vals, tree_def = jax.tree_flatten(tree, is_leaf=lambda x: x is None) + + # 'Fake' token tree that is use to track jax internal tree traversal and + # adjust our custom tree traversal to be compatible with it. + tokens = range(len(vals)) + token_tree = tree_def.unflatten(tokens) + val_names, perm = zip(*_traverse_with_names(token_tree)) + inv_perm = np.argsort(perm) + + # Custom traversal should visit the same number of leaves. + if len(val_names) != len(vals): + raise ValueError(f'Pytree traversal detected {len(val_names)} names, ' + f'but {len(vals)} leafs.\nTreeDef is:\n{tree_def}') + + return [(val_names[i], v) for i, v in zip(inv_perm, vals)], tree_def + + +def _traverse_with_names( + param_tree: ParamTree) -> Iterable[Tuple[str, ParamTree]]: + """Traverses nested dicts/dataclasses and emits (leaf_name, leaf_val).""" + if dataclasses.is_dataclass(param_tree): + param_tree = flax.serialization.to_state_dict(param_tree) + if isinstance(param_tree, (dict, flax.core.FrozenDict)): + keys = sorted(param_tree.keys()) + for key in keys: + for path, v in _traverse_with_names(param_tree[key]): + yield (key + '/' + path).rstrip('/'), v + else: + yield '', param_tree diff --git a/t5x/contrib/moe/training_utils_test.py b/t5x/contrib/moe/training_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..efff7c3a9c1445d0a986c26db046e90a27dea39c --- /dev/null +++ b/t5x/contrib/moe/training_utils_test.py @@ -0,0 +1,92 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for training_utils.""" + +import functools +import os +# Emulate 2 devices on CPU. Import before JAX. +os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2' + +from absl.testing import absltest # pylint: disable=g-import-not-at-top +from flax import core as flax_core +import jax +from jax import numpy as jnp +import numpy as np + +from t5x.contrib.moe import training_utils + + +class MatchFnTest(absltest.TestCase): + + def test_regex_prefix(self): + match_fn = training_utils.match_fn(r'.*test.*') + self.assertTrue(match_fn('/test/something')) + self.assertTrue(match_fn('to/test/or/not/')) + self.assertFalse(match_fn('no/match')) + + def test_empty_prefix(self): + match_fn = training_utils.match_fn(None) + self.assertFalse(match_fn('/test/something')) + self.assertFalse(match_fn('to/test/or/not/')) + + +class ScaleShardedGradsTest(absltest.TestCase): + + def test_scale_sharded_grads(self): + grads = flax_core.freeze({ + 'encoder': { + 'expert_layer': jnp.ones((2, 3)), + 'regular_layer': jnp.ones((1, 2)) + } + }) + sharded_match_fn = training_utils.match_fn(r'.*expert.*') + scaled_grads = training_utils.scale_sharded_grads( + grads, sharded_match_fn, scale_factor=100.) + + expected_grads = flax_core.freeze({ + 'encoder': { + 'expert_layer': 100. * jnp.ones((2, 3)), + 'regular_layer': jnp.ones((1, 2)) + } + }) + jax.tree_map( + functools.partial(np.testing.assert_allclose, rtol=3e-7), scaled_grads, + expected_grads) + + +class TreeTest(absltest.TestCase): + + def test_tree_flatten_with_names(self): + tree = {'ff_0': {'kernel': 0, 'bias': 1}, 'ff_1': {'kernel': 2, 'bias': 3}} + names_and_values, _ = training_utils._tree_flatten_with_names(tree) + + expected_names_and_values = [('ff_0/bias', 1), ('ff_0/kernel', 0), + ('ff_1/bias', 3), ('ff_1/kernel', 2)] + self.assertEqual(names_and_values, expected_names_and_values) + + # Check that values match regular JAX tree_flatten. + self.assertEqual([x for _, x in names_and_values], + jax.tree_flatten(tree)[0]) + + def test_tree_map_with_names(self): + tree = {'a': 1, 'b': 2} + mapped_tree = training_utils.tree_map_with_names( + f=lambda x: -x, param_tree=tree, match_name_fn=lambda name: name == 'b') + + self.assertEqual(mapped_tree, {'a': 1, 'b': -2}) + + +if __name__ == '__main__': + absltest.main() diff --git a/t5x/decoding.py b/t5x/decoding.py new file mode 100644 index 0000000000000000000000000000000000000000..4e5e60afbe7c5d2546621e0fa91545cbfec3151f --- /dev/null +++ b/t5x/decoding.py @@ -0,0 +1,1136 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
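The module that follows defines `temperature_sample` (and beam search further down). As a rough, untested sketch of how the sampling entry point is driven, with a toy `tokens_to_logits` that ignores the cache and made-up shapes and hyperparameters, a call might look like this:

```python
import jax
import jax.numpy as jnp

from t5x import decoding

VOCAB_SIZE = 8  # Toy vocabulary size, for illustration only.


def toy_tokens_to_logits(token_slice, cache):
  # A real model would run one decoder step here and return an updated cache;
  # this toy version emits flat logits and leaves the (empty) cache untouched.
  logits = jnp.zeros((token_slice.shape[0], VOCAB_SIZE))
  return logits, cache


inputs = jnp.zeros((2, 6), dtype=jnp.int32)  # Batch of 2, no forced prefix.
decodes, log_probs = decoding.temperature_sample(
    inputs=inputs,
    cache={},
    tokens_to_logits=toy_tokens_to_logits,
    eos_id=1,
    decode_rng=jax.random.PRNGKey(0),
    num_decodes=3,
    temperature=0.7,
    topk=4)
# decodes: [2, 3, 6]; log_probs: [2, 3], sorted in increasing order, so the
# most likely sample per batch element is decodes[:, -1, :].
```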
+ +"""Fast decoding routines for inference from a trained model.""" +import functools + +from typing import Any, Callable, Mapping, Optional, Tuple, Union +import flax +from flax import traverse_util +import jax +from jax import lax +from jax import random +import jax.numpy as jnp +import numpy as np + +PyTreeDef = type(jax.tree_structure(None)) +SamplingLoopState = Tuple[int, jnp.ndarray, Mapping[str, jnp.ndarray], + jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray] + +# Constants +# "Effective negative infinity" constant for masking in beam search. +NEG_INF = np.array(-1.0e7) + +# Temperatures lower than this are considered 0.0, which is handled specially +# with a conditional. This is to avoid numeric issues from exponentiating on +# 1.0/temperature when temperature is close to 0.0. +MIN_TEMPERATURE = np.array(1e-4) + +#------------------------------------------------------------------------------ +# Temperature Sampling +#------------------------------------------------------------------------------ +_dynamic_update_vector_slice_in_dim = jax.vmap( + lax.dynamic_update_slice_in_dim, in_axes=(0, 0, 0, None)) + + +def _is_tracer(value: Any): + return isinstance(value, jax.core.Tracer) + + +def temperature_sample( + inputs: jnp.ndarray, + cache: Mapping[str, jnp.ndarray], + tokens_to_logits: Callable[[jnp.ndarray, Mapping[str, jnp.ndarray]], + Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]], + eos_id: int, + decode_rng: Optional[jnp.ndarray] = None, + num_decodes: int = 1, + temperature: Union[float, jnp.ndarray] = 1.0, + topk: int = 1, + topp: float = 0.0, + cache_offset: int = 0, + initial_index: Optional[jnp.ndarray] = None, + max_decode_steps: Optional[Union[int, jnp.ndarray]] = None, + max_decode_steps_hard_limit: Optional[int] = None, + rescale_log_probs: bool = True, + state_callback_fn: Optional[Callable[[SamplingLoopState], + SamplingLoopState]] = None, + logit_callback_fn: Optional[Callable[[jnp.ndarray, SamplingLoopState], + jnp.ndarray]] = None +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Temperature sampling for language model generation. + + The temperature sampling is performed `num_decodes` times in a vectorized + manner by expanding the batch dimension. This is similar to how beam search + expands the batch dimension to process each batch element with multiple beams. + + This function dynamically updates the `inputs` array by sampling from the + model logits, which is provided by `tokens_to_logits` callable. The input + sequences are expanded at the end, populated and sliced by dropping the first + position. + + If `inputs` has non-zero entries, those values are not modified, i.e., + the sampled values for those positions are discarded. This simulates the + teacher forcing on the prefix positions. + + There are a few important observations related to this function. + + 1. The `inputs` is assumed to be a non-packed sequence. + + 2. If `initial_index=None`, then `inputs`[:, 0] is ignored. We will use 0 as a + BOS token to start the generation. This inherently assumes that `inputs` is + already shifted to the right by one position. If `initial_index=an_array`, + the token values at `inputs`[:, initial_index] are used as the token to + start the generation. + + 3. The loop index, i, is a vector of shape [batch_size]. When beginning + generation from scratch, each value will always have the same value. When + beginning with a partially filled cache, the loop index of different + elements can differ, via providing a value for `initial_index`. + + 3. 
Unless all batch elements generated the eos_id before reaching the end, we + always make `max_decode_len = inputs.shape[1]` number of calls to + `tokens_to_logits` when decoding from scratch and + `max_decode_len - jnp.minimum(initial_index)` number of calls when starting + from a partially filled cache. + + 4. Let `output` be the output sequences, i.e.,`sequences`[:, 1:]. Then + `output`[:, j] are the tokens generated when the while loop counter `i = + j`. Therefore, we generate the last token when `i = max_decode_len - 1` + and exit the while loop as all `i`s are incremented to `max_decode_len`. + + 5. Once `eos_id = 1` is generated, the subsequent predictions are all replaced + by padding token 0. + + 6. When using a partially filled cache, different batch elements can have + different lengths. This means an input that has a longer input will have + fewer steps until its `i` value reaches `max_decode_len` than an input with + a shorter input. We keep these longer examples alive, doing busy work + continually overwriting a new garbage token at the end of the sequence + until shorter examples finish. + + 7. When using a partially filled cache, providing a value for `initial_index`, + the attention cache index should be a vector of [batch_size]. + + We show three examples to illustrate how this function works. In addition to + input and output of the function, we also show two intermediate values: + `expanded_prompt_inputs` and `final_sequences`. Also for simplicity, the + examples are limited to `num_decodes = 1` usage and the `num_decodes` + dimension is omitted. + + ``` + Example 1: + inputs = [0, 5, 6, 1, 0] + expanded_prompt_inputs = [0, 5, 6, 1, 0, 0] + final_sequences = [0, 5, 6, 1, a, b] # before slicing. + output = [5, 6, 1, a, b] + where `a` is prediction while taking 1 as input and `b` is prediction while + taking `a` as input. + + Example 2 (early stopping): + inputs = [[0, 5, 1, 0, 0, 0, 0], + [0, 8, 0, 0, 0, 0, 0] + expanded_prompt_inputs = [[0, 5, 1, 0, 0, 0, 0, 0], + [0, 8, 0, 0, 0, 0, 0, 0] + final_sequences = [[0, 5, 1, a, b, c=1, 0, 0], + [0, 8, d, e, f=1, g=0, 0, 0]] + output = [[5, 1, a, b, c=1, 0, 0], + [8, d, e, f=1, g=0, 0, 0]] + + In this example, there are two sequences. Let's look at sequence 0. The + first generated token is `a`, which is in turn used to generate `b`. + Finally, `c = 1` is generated with the input `b`. Then the loop terminates + early because 1 is the `eos_id`. + + Now consider sequence 1. The when `f = 1` was generated, it is considered + done. Since sequence 0 is not done at this point, the next prediction, i.e., + `g` is zerod out. This continues until the end. 
+
+  Example 3 (prefilled cache):
+    inputs = [[0, 5, 2, 6, 1, 0],
+              [0, 8, 1, 0, 0, 0]]
+    expanded_prompt_inputs = [[0, 5, 2, 6, 1, 0, 0, 0],
+                              [0, 8, 1, 0, 0, 0, 0, 0]]
+    max_decode_len = 6
+    i = [4, 2]
+    input_tokens = [[1],
+                    [1]]
+    output_tokens = [[a],
+                     [b]]
+    expanded_prompt_inputs = [[0, 5, 2, 6, 1, a, 0, 0],
+                              [0, 8, 1, b, 0, 0, 0, 0]]
+    i = [5, 3]
+    input_tokens = [[a],
+                    [b]]
+    output_tokens = [[c],
+                     [d]]
+    expanded_prompt_inputs = [[0, 5, 2, 6, 1, a, c, 0],
+                              [0, 8, 1, b, d, 0, 0, 0]]
+    i = [6, 4]
+    input_tokens = [[c],
+                    [d]]
+    output_tokens = [[y],
+                     [e]]
+    expanded_prompt_inputs = [[0, 5, 2, 6, 1, a, c, y],
+                              [0, 8, 1, b, d, e, 0, 0]]
+    i = [6, 5]
+    input_tokens = [[y],
+                    [e]]
+    output_tokens = [[z],
+                     [f]]
+    expanded_prompt_inputs = [[0, 5, 2, 6, 1, a, c, z],
+                              [0, 8, 1, b, d, e, f, 0]]
+    i = [6, 6]
+    exit
+    outputs = [[5, 2, 6, 1, a, c],
+               [8, 1, b, d, e, f]]
+
+  In this example, there are two sequences with different input lengths. Thus
+  the two caches have been filled to different positions. As we decode, the
+  first sequence hits the max decode length before the second. In order to
+  avoid prematurely ending decoding for the second sequence, the first
+  sequence continually overwrites the final token.
+
+  Example 4 (prefilled cache and max decode steps):
+    inputs = [[0, 2, 0, 0, 0, 0, 0, 0],
+              [0, 3, 4, 0, 0, 0, 0, 0]]
+    expanded_prompt_inputs = [[0, 2, 0, 0, 0, 0, 0, 0, 0, 0],
+                              [0, 3, 4, 0, 0, 0, 0, 0, 0, 0]]
+    initial_index = [1, 2]
+    max_decode_steps = 2
+
+    Then `max_decode_len = [3, 4]`.
+    i = [1, 2]
+    input_tokens = [[2],
+                    [4]]
+    output_tokens = [[a],
+                     [b]]
+    expanded_prompt_inputs = [[0, 2, a, 0, 0, 0, 0, 0, 0, 0],
+                              [0, 3, 4, b, 0, 0, 0, 0, 0, 0]]
+    i = [2, 3]
+    input_tokens = [[a],
+                    [b]]
+    output_tokens = [[c],
+                     [d]]
+    expanded_prompt_inputs = [[0, 2, a, c, 0, 0, 0, 0, 0, 0],
+                              [0, 3, 4, b, d, 0, 0, 0, 0, 0]]
+    This is the last while loop iteration with i == max_decode_len - 1.
+    outputs = [[2, a, c, 0, 0, 0, 0, 0],
+               [3, 4, b, d, 0, 0, 0, 0]]
+  ```
+
+  Args:
+    inputs: array: [batch_size, max_decode_len] int32 sequence of tokens.
+    cache: flax attention cache.
+    tokens_to_logits: fast autoregressive decoder function taking single token
+      slices and cache and returning next-token logits and updated cache.
+    eos_id: int: end-of-sentence token for target vocabulary.
+    decode_rng: JAX PRNGKey.
+    num_decodes: number of decoded sequences to be returned.
+    temperature: float: sampling temperature factor. As it approaches zero this
+      becomes equivalent to greedy sampling.
+    topk: integer: if nonzero, only use the top-k logits to sample the next
+      token; if zero, don't use any cutoff and sample from the full logits over
+      the vocabulary.
+    topp: float: if nonzero, only use the smallest number of logits whose
+      cumulative sum of probs adds up to (at least) topp. Will raise ValueError
+      if it's nonzero when topk is nonzero.
+    cache_offset: axis offset for cache, arising from scanned layers.
+    initial_index: Optional[array]: [batch_size] int32 vector of loop indices
+      at which to start decoding.
+    max_decode_steps: int: an optional maximum number of decoding steps. If
+      None, it will decode until the full input shape `inputs.shape[1]` is
+      filled. max_decode_steps begins counting after the prompt, so it will
+      decode at most len(prompt) + max_decode_steps tokens.
+    max_decode_steps_hard_limit: int: an optional fixed hard limit on
+      max_decode_steps. If this is set (not None and > 0), and max_decode_steps
+      is also set, then max_decode_steps will be clipped to this limit.
The + value max_decode_steps can be an ndarray, but max_decode_steps_hard_limit + must be a Python integer or None. + rescale_log_probs: bool: whether to apply temperature, topp, and topk + rescaling to the log probs which are returned. If True, the log_probs will + include these transformations (for example, with topk=1, all log_probs + will be identically 0.0). If False, the log_probs will not be affected, + and topk/topp/temperature will not affect sequence probabilities. + state_callback_fn: Function that modifies the sampling loop state before + each step. This can be used to manipulate any part of the state either + on the accelerator or on the host using host callback. The function + should take a tuple of type SamplingLoopState as argument, and it + returns the updated state. See `decoding_test.py` for an example usage. + logit_callback_fn: Function that modifies the logits before each temperature + sampling step. The function should take arguments (logits, state) and it + should return the modified logits. See `decoding_test.py` for an example + usage. + + Returns: + A tuple (decodes, log_prob) where `decodes` is sampled sequences with shape + [batch_size, num_decodes, max_decode_len] sorted by `log_prob`, which is log + probability of each of the sampled sequences. + """ + if decode_rng is None: + decode_rng = jax.random.PRNGKey(0) + + if (max_decode_steps_hard_limit is not None and + max_decode_steps_hard_limit > 0 and max_decode_steps is not None): + max_decode_steps = jnp.minimum(max_decode_steps, + max_decode_steps_hard_limit) + + # [batch, len] -> [batch * num_decodes, len] + expanded_inputs = flat_batch_beam_expand(inputs, num_decodes) + expanded_cache = cache_map( + functools.partial( + flat_batch_beam_expand, beam_size=num_decodes, offset=cache_offset), + cache, + # When we start with a prefilled cache, the cache index is no longer a + # scalar that will broadcast across multiple decodes, it is a vector and + # needs to be updated to handle the multiple decodes. + apply_to_index=initial_index is not None) + if initial_index is not None: + initial_index = flat_batch_beam_expand(initial_index, num_decodes) + + # expanded_decodes: [batch * num_decodes, len] + # expanded_log_prob: [batch * num_decodes] + expanded_decodes, expanded_log_prob = _temperature_sample_single_trial( + expanded_inputs, + expanded_cache, + tokens_to_logits, + eos_id, + decode_rng, + temperature, + topk, + topp, + initial_index=initial_index, + max_decode_steps=max_decode_steps, + rescale_log_probs=rescale_log_probs, + state_callback_fn=state_callback_fn, + logit_callback_fn=logit_callback_fn) + + batch_size = inputs.shape[0] + # [batch * num_decodes, len] -> [batch, num_decodes, len] + decodes = unflatten_beam_dim(expanded_decodes, batch_size, num_decodes) + # [batch * num_decodes] -> [batch, num_decodes] + log_prob = unflatten_beam_dim(expanded_log_prob, batch_size, num_decodes) + + # Sort `decodes` and `log_prob` by increasing log probabilities of the sampled + # sequence. + # [batch, num_decodes, 1] + idxs = jnp.expand_dims(jnp.argsort(log_prob, axis=-1), axis=-1) + + # returns [batch, num_decodes, len], [batch, num_decodes] in sorted order. 
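+  # As a small, hypothetical illustration of the sort above (values invented
+  # for this comment): with num_decodes=3 and log_prob = [[-2.3, -1.3, -3.6]],
+  # jnp.argsort yields idxs = [[[2], [0], [1]]], so take_along_axis reorders
+  # the decodes into increasing log probability and the most likely sample
+  # ends up at decodes[:, -1].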
+ return jnp.take_along_axis( + decodes, idxs, axis=1), jnp.take_along_axis( + log_prob, jnp.squeeze(idxs, axis=-1), axis=-1) + + +def _temperature_sample_single_trial( + inputs: jnp.ndarray, + cache: Mapping[str, jnp.ndarray], + tokens_to_logits: Callable[[jnp.ndarray, Mapping[str, jnp.ndarray]], + Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]], + eos_id: int, + prng_key: jnp.ndarray, + temperature: Union[float, jnp.ndarray] = 1.0, + topk: int = 20, + topp: Union[float, jnp.ndarray] = 0.0, + initial_index: Optional[jnp.ndarray] = None, + max_decode_steps: Optional[Union[int, jnp.ndarray]] = None, + rescale_log_probs: bool = True, + state_callback_fn: Optional[Callable[[SamplingLoopState], + SamplingLoopState]] = None, + logit_callback_fn: Optional[Callable[[jnp.ndarray, SamplingLoopState], + jnp.ndarray]] = None +) -> jnp.ndarray: + """A helper function for `temperature_sample`.""" + + # We can check the values of topp and topk only if they are not dynamic. + if not _is_tracer(topp) and topp and topk: + raise ValueError('At most one of `topp` or `topk` may be non-zero.') + + batch_size, max_decode_len = inputs.shape + + if max_decode_steps is not None: + # We can check the max_decode_steps bounds only if it is not dynamic. + if not _is_tracer(max_decode_steps) and max_decode_steps > inputs.shape[1]: + raise ValueError('Cannot decode more steps than the sequence length.') + + # The number of decode steps required to process the prefix is the number + # of non-zero tokens, since inputs[0] == 0 is the BOS token. + # `max_decode_len[j]` is the number of non-padding tokens in the jth element + # of the returned sequences capped at `len(inputs)`, assuming that the + # early stop doesn't occur. This is true with or without + # `max_decode_steps`. + # When the while loop index `i` for the `j`th element `i[j] = + # max_decode_len[j] - 1`, the generated token populate sequences[i[j]+1]]. + # Since sequences[:, 0] is BOS token, the generated token is + # `max_decode_len[j]`th non-padding tokens and hence `j`th element is + # ended. + max_decode_len = jnp.sum(inputs != 0, axis=1) + max_decode_steps + max_decode_len = jnp.minimum(inputs.shape[1], max_decode_len) + + # In the case of starting generation from a non-zero index, it is possible for + # one batch element to reach `max_decode_len` number of decoding steps before + # another. In order to let the last element decoder all the way to + # `max_decode_len` number of steps, we add a final garbage token to the end of + # the sequences. Any element that has reached `max_decode_len` before the rest + # of the elements will continually overwrite this token until all elements + # finish. + # [batch, length+1] -> [batch, length+2] + expanded_prompt_inputs = jnp.append( + inputs, jnp.zeros((batch_size, 2), dtype=inputs.dtype), axis=1) + end_marker = jnp.array(eos_id) + + temperature = jnp.asarray(temperature) + + # Initialize sampling loop state. + # initial loop PRNGKey + rng0 = prng_key + # the per batch-item holding current token in loop. + if initial_index is None: + # the per batch-item loop position counter. + i0 = jnp.zeros((batch_size), dtype=jnp.int32) + # the per batch-item holding current token in loop. + token0 = jnp.zeros((batch_size, 1), dtype=jnp.int32) + else: + # the per batch-item loop position counter. + i0 = initial_index + # the per batch-item holding current token in loop. + # Select the token that the initial index is pointing to. 
+ token0 = jnp.take_along_axis( + expanded_prompt_inputs, jnp.expand_dims(i0, axis=1), axis=1) + # per batch-item state bit indicating if sentence has finished. + ended0 = jnp.zeros((batch_size, 1), dtype=jnp.bool_) + # (batch, length+2) array containing prefix prompt tokens for sampling loop + # as well as the generated output of newly sampled tokens. + sequences0 = expanded_prompt_inputs + log_prob0 = jnp.zeros((batch_size,), dtype=jnp.float32) + # Sampling loop state is stored in a simple tuple. + sampling_loop_init_state = (i0, sequences0, cache, token0, ended0, rng0, + log_prob0) + # Initial eos count to be used to determine whether eos is "generated". Many + # inputs follow the format bos, inputs..., eos, targets..., eos. By counting + # the number of eos tokens we can detect when a new one is added, instead of + # just finding the one that probably ends the inputs. + # [batch, 1] + initial_eos_count = jnp.sum(sequences0 == end_marker, axis=-1, keepdims=True) + + def sampling_loop_cond_fn(state: SamplingLoopState) -> bool: + """Sampling loop termination condition.""" + (_, _, _, _, ended, _, _) = state + + # Have all sampled sequences reached an end marker? + # Different elements in the batch can be at different loop indices, if any + # of our examples are not at the end, keep going. + all_sequences_ended = jnp.all(ended) + return ~all_sequences_ended + + def sampling_loop_body_fn(state: SamplingLoopState) -> SamplingLoopState: + """Sampling loop state update.""" + + if state_callback_fn is not None: + state = state_callback_fn(state) + + i, sequences, cache, cur_token, ended, rng, log_prob = state + # Split RNG for sampling. + rng1, rng2 = random.split(rng) + # Call fast-decoder model on current tokens to get next-position logits. + logits, new_cache = tokens_to_logits(cur_token, cache) + # Sample next token from logits. + + if logit_callback_fn is not None: + logits = logit_callback_fn(logits, state) + + def sample_logits_with_nonzero_temperature(logits): + scaled_logits = logits / jnp.maximum(temperature, MIN_TEMPERATURE) + if topk: + # Get top-k logits and their indices, sample within these top-k tokens. + topk_logits, _ = lax.top_k(scaled_logits, topk) + cutoff_logit = topk_logits[:, -1, None] + scaled_logits = jnp.where(scaled_logits < cutoff_logit, + jnp.full_like(scaled_logits, NEG_INF), + scaled_logits) + + # When topp is dynamic, we always use it since we cannot check + # non-zeroness (but it will have no effect if topp is 0.0). + if _is_tracer(topp) or topp: + logits_sorted = jnp.sort( + scaled_logits, axis=-1)[:, ::-1] # sort descending + sorted_cum_probs = jnp.cumsum( + jax.nn.softmax(logits_sorted, axis=-1), axis=-1) + cutoff_index = jnp.sum(sorted_cum_probs < topp, axis=-1, keepdims=True) + cutoff_logit = jnp.take_along_axis(logits_sorted, cutoff_index, axis=-1) + scaled_logits = jnp.where(scaled_logits < cutoff_logit, + jnp.full_like(scaled_logits, NEG_INF), + scaled_logits) + + # [batch] + next_token = random.categorical(rng1, scaled_logits).astype(jnp.int32) + + # log probability of the current token conditioned on the previously + # sampled and prefix tokens. 
+ # [batch, vocab] -> [batch, vocab] + if rescale_log_probs: + log_probs = jax.nn.log_softmax(scaled_logits) + else: + log_probs = jax.nn.log_softmax(logits) + # [batch, vocab] -> [batch] + next_log_prob = jnp.squeeze( + jnp.take_along_axis( + log_probs, jnp.expand_dims(next_token, axis=1), axis=-1), + axis=-1) + + return (next_token, next_log_prob) + + def sample_logits_with_zero_temperature(logits): + # For zero temperature, we always want the greedy output, regardless + # of the values of topk and topp. + + next_token = jnp.argmax(logits, -1).astype(jnp.int32) + + if rescale_log_probs: + next_log_prob = jnp.zeros_like(next_token, dtype=jnp.float32) + else: + log_probs = jax.nn.log_softmax(logits) + next_log_prob = jnp.squeeze( + jnp.take_along_axis( + log_probs, jnp.expand_dims(next_token, axis=1), axis=-1), + axis=-1) + + return (next_token, next_log_prob) + + # Perform sampling with temperature + (next_token, + next_log_prob) = lax.cond(temperature > MIN_TEMPERATURE, + sample_logits_with_nonzero_temperature, + sample_logits_with_zero_temperature, logits) + + # When different batch elements are at different points in the loop counter, + # it is possible that an element that started at a higher index will reach + # `max_decode_len` before other elements. When this happens we need to make + # sure this element continuous overwrites our new garbage collection index. + # Here we clamp `i` to `max_decode_len`. This will cause the a write to + # `max_decode_len + 1` which is the final index in `sequences`. Subsequent + # loop body executions will also get their value clamped causing continual + # overwriting of the final garbage position until all examples are finished. + i = jnp.minimum(i, max_decode_len) + + # Only use sampled tokens if we're past provided prefix tokens. + # Select the next token from sequences. + # [batch] + next_input_token = jnp.squeeze( + jnp.take_along_axis(sequences, jnp.expand_dims(i + 1, axis=1), axis=1), + axis=1) + # Check if the next token is padding (a target) or non-padding (an input). + # Mask will have `1` for targets and `0` for inputs. + out_of_prompt = (next_input_token == 0) + # Select the sampled next token for targets and the actual next token for + # inputs (teacher forcing). + # [batch] + next_token = ( + next_token * out_of_prompt + next_input_token * ~out_of_prompt) + + # only add probability if outside prefix region + # [batch] -> [batch] + next_log_prob = log_prob + (next_log_prob * out_of_prompt) * jnp.squeeze( + ~ended, axis=-1).astype(jnp.int32) + + # [batch] -> [batch, 1] + next_token = jnp.expand_dims(next_token, axis=-1) + + # If end-marker reached for batch item, only emit padding tokens. + # [batch, 1] * [batch, 1] -> [batch, 1] + next_token_or_endpad = next_token * ~ended + # Add current sampled tokens to recorded sequences. + one_hot = jax.nn.one_hot(i + 1, sequences.shape[1], dtype=sequences.dtype) + new_sequences = sequences * (1 - one_hot) + next_token_or_endpad * one_hot + # new_sequences = dynamic_update_vector_slice_in_dim(sequences, + # next_token_or_endpad, + # i + 1, + # 0) + # Count eos tokens in the sequences and compare to the initial count + # [batch, 1] + cur_eos_count = jnp.sum(new_sequences == end_marker, axis=-1, keepdims=True) + # [batch, 1] + + # Have we reached max decoding length? + # We generally index into sequences[:, i + 1], and sequences.shape[1] = + # max_decode_len + 2, therefore i == max_decode_len - 1 will write to + # sequences[-2] which is our last valid location. 
i == max_decode_len will + # write to sequences[-1] which is our garbage collection token. Thus `i` + # should be strictly less than max_decode_len. + has_additional_eos = cur_eos_count > initial_eos_count + ended |= has_additional_eos | jnp.expand_dims( + i >= max_decode_len - 1, axis=1) + + return (i + 1, new_sequences, new_cache, next_token_or_endpad, ended, rng2, + next_log_prob) + + # Run sampling loop and collect final state. + final_state = lax.while_loop(sampling_loop_cond_fn, sampling_loop_body_fn, + sampling_loop_init_state) + + # Pick part of the state corresponding to the sampled sequences. + final_sequences = final_state[1] + log_prob = final_state[-1] + # Drop the first position because they are dummy bos tokens. Drop the new + # garbage collection token at the end too. + return final_sequences[:, 1:-1], log_prob + + +#------------------------------------------------------------------------------ +# BEAM Sampling +#------------------------------------------------------------------------------ + + +def brevity_penalty(alpha: float, length: int) -> jnp.ndarray: + """Brevity penalty function for beam search penalizing short sequences. + + Args: + alpha: float: brevity-penalty scaling parameter. + length: int: length of considered sequence. + + Returns: + Brevity penalty score as jax scalar. + """ + return jnp.power(((5.0 + length) / 6.0), alpha) + + +# Beam handling utility functions: + + +def cache_map(fn, cache, apply_to_index: bool = False): + """Maps function over that caches, even multiple caches in various layers. + + Args: + fn: The function to apply. + cache: The cache to apply it to. + apply_to_index: Whether to apply the function to the cache index. + + Returns: + The result of applying `fn` to the cache. + """ + frozen = isinstance(cache, flax.core.FrozenDict) + if frozen: + cache = flax.core.unfreeze(cache) + flat_cache = traverse_util.flatten_dict(cache) + if apply_to_index: + keyvals = flat_cache + else: + keyvals = {k: v for k, v in flat_cache.items() if k[-1] != 'cache_index'} + # Exclude cached relative position bias from beam expansion, etc. + # Also excludes scalar index in absolute position embedder from expansion. + # TODO(levskaya): generalize cache_map to accept a list of leaf names to + # map over, instead of doing this ad-hoc. 
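+  # As an illustration (cache layout borrowed from decoding_test.py below), a
+  # flattened key looks like ('layers_1', 'self_attention', 'cached_key');
+  # only the last path component is compared against the names excluded here,
+  # so the filtering is purely by leaf name.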
+ exclusion_list = ['cached_bias', 'position_embedder_index'] + keyvals = {k: v for k, v in keyvals.items() if k[-1] not in exclusion_list} + + keyvals = jax.tree_map(fn, keyvals) + flat_cache.update(keyvals) + new_cache = traverse_util.unflatten_dict(flat_cache) + if frozen: + new_cache = flax.core.freeze(new_cache) + return new_cache + + +def add_beam_dim(x: jnp.ndarray, + beam_size: int, + offset: int = 0) -> jnp.ndarray: + """Creates new beam dimension in non-scalar array and tiles into it.""" + x = jnp.expand_dims(x, axis=offset + 1) + tile_dims = [1] * x.ndim + tile_dims[offset + 1] = beam_size + return jnp.tile(x, tile_dims) + + +def flatten_beam_dim(x: jnp.ndarray, offset: int = 0) -> jnp.ndarray: + """Flattens the first two dimensions of a non-scalar array.""" + xshape = list(x.shape) + b_sz = xshape.pop(offset) + xshape[offset] *= b_sz + return x.reshape(xshape) + + +def unflatten_beam_dim(x: jnp.ndarray, + batch_size: int, + beam_size: int, + offset: int = 0) -> jnp.ndarray: + """Unflattens the first, flat batch*beam dimension of a non-scalar array.""" + assert batch_size * beam_size == x.shape[offset] + xshape = list(x.shape) + newshape = xshape[:offset] + [batch_size, beam_size] + xshape[offset + 1:] + return x.reshape(newshape) + + +def flat_batch_beam_expand(x: jnp.ndarray, + beam_size: int, + offset: int = 0) -> jnp.ndarray: + """Expands the each batch item by beam_size in batch_dimension.""" + return flatten_beam_dim(add_beam_dim(x, beam_size, offset), offset) + + +def cache_gather_beams(nested: PyTreeDef, + beam_indices: jnp.ndarray, + batch_size: int, + old_beam_size: int, + new_beam_size: int, + one_hot: bool = True, + offset: int = 0) -> jnp.ndarray: + """Gathers the cache beam slices indexed by beam_indices into new beam array. + + Args: + nested: cache pytree. + beam_indices: array of beam_indices + batch_size: size of batch. + old_beam_size: size of _old_ beam dimension. + new_beam_size: size of _new_ beam dimension. + one_hot: whether to perform gathers by one-hot contraction or directly. + offset: cache axis offset from scanned layers. + + Returns: + New pytree with new beam arrays. + [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...] + """ + assert offset in (0, 1), 'general offsets not supported' + if one_hot: + # Gather via one-hot contraction, needed for SPMD partitioning. + oh_beam_indices = jax.nn.one_hot( + beam_indices, old_beam_size, dtype=jnp.int32) + if offset == 0: + + def gather_fn(x): + return jnp.einsum('beo,bo...->be...', oh_beam_indices, + x).astype(x.dtype) + else: + + def gather_fn(x): + return jnp.einsum('beo,lbo...->lbe...', oh_beam_indices, + x).astype(x.dtype) + + return cache_map(gather_fn, nested) + + else: + # True gather via fancy indexing. + batch_indices = jnp.reshape( + jnp.arange(batch_size * new_beam_size) // new_beam_size, + (batch_size, new_beam_size)) + if offset == 0: + + def gather_fn(x): + return x[batch_indices, beam_indices] + else: + + def gather_fn(x): + return x[:, batch_indices, beam_indices] + + return cache_map(gather_fn, nested) + + +def gather_beams(nested: PyTreeDef, + beam_indices: jnp.ndarray, + batch_size: int, + old_beam_size: int, + new_beam_size: int, + one_hot: bool = True) -> jnp.ndarray: + """Gathers the beam slices indexed by beam_indices into new beam array. + + Args: + nested: pytree of arrays or scalars (the latter ignored). + beam_indices: array of beam_indices + batch_size: size of batch. + old_beam_size: size of _old_ beam dimension. + new_beam_size: size of _new_ beam dimension. 
+ one_hot: whether to perform gathers by one-hot contraction or directly. + + Returns: + New pytree with new beam arrays. + [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...] + """ + if one_hot: + # Gather via one-hot contraction, needed for SPMD partitioning. + oh_beam_indices = jax.nn.one_hot( + beam_indices, old_beam_size, dtype=jnp.int32) + + def gather_fn(x): + return jnp.einsum('beo,bo...->be...', oh_beam_indices, x).astype(x.dtype) + + return jax.tree_map(gather_fn, nested) + else: + # True gather via fancy indexing. + batch_indices = jnp.reshape( + jnp.arange(batch_size * new_beam_size) // new_beam_size, + (batch_size, new_beam_size)) + + def gather_fn(x): + return x[batch_indices, beam_indices] + + return jax.tree_map(gather_fn, nested) + + +def top_k_two_stage(x, k): + """Wrapper around lax.top_k with low-batch optimization. + + Args: + x: tensor with shape f32[batch, num_samples]. + k: integer indicating how many top values to return. + + Returns: + Largest k values and indices with shape (f32[batch, k], s32[batch, k]). + """ + + batch, num_samples = x.shape + num_lanes = 128 + if (isinstance(batch, int) and batch <= 8 and + num_samples > 8 * num_lanes * k): + # At small batch, when num_samples is sufficiently large, optimize + # execution on TPU by doing TopK in two stages. Reshaping 'x' to fill + # lanes reduces tensor padding in TopK call. + if num_samples % num_lanes != 0: + # Pad input tensor to multiples of num_lanes. + num_samples_rounded_up = num_samples + ( + num_lanes - num_samples % num_lanes) + x = jnp.pad( + x, ((0, 0), (0, num_samples_rounded_up - num_samples)), + mode='constant', + constant_values=np.NINF) + num_samples = num_samples_rounded_up + # Reshape input tensor to fill lanes. + num_samples_sublanes = int(num_samples / num_lanes) + x_reshaped = jnp.reshape(x, (batch * num_lanes, num_samples_sublanes)) + # First stage top_k. + vals, indices = lax.top_k(x_reshaped, k) + indices = jnp.reshape(indices, (batch, num_lanes, k)) + index_offsets = jnp.reshape(num_samples_sublanes * jnp.arange(num_lanes), + (1, num_lanes, 1)) + indices = jnp.reshape( + jnp.add(index_offsets, indices), (batch, num_lanes * k)) + vals = jnp.reshape(vals, (batch, num_lanes * k)) + # Second stage top_k. + vals_s2, indices_s2 = lax.top_k(vals, k) + indices_s2 = jnp.take_along_axis(indices, indices_s2, axis=1) + return vals_s2, indices_s2 + else: + # Use default TopK implementation. + return lax.top_k(x, k) + + +def gather_topk_beams(nested: PyTreeDef, score_or_log_prob: jnp.ndarray, + batch_size: int, new_beam_size: int) -> jnp.ndarray: + """Gathers the top-k beam slices given by score_or_log_prob array. + + Args: + nested: pytree of arrays or scalars (the latter ignored). + score_or_log_prob: [batch_size, old_beam_size] array of values to sort by + for top-k selection of beam slices. + batch_size: int: size of batch. + new_beam_size: int: size of _new_ top-k selected beam dimension + + Returns: + New pytree with new beam arrays containing top k new_beam_size slices. + [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...] + """ + _, topk_indices = lax.top_k(score_or_log_prob, k=new_beam_size) + topk_indices = jnp.flip(topk_indices, axis=1) + return gather_beams(nested, topk_indices, batch_size, + score_or_log_prob.shape[1], new_beam_size) + + +# Beam search state: + + +@flax.struct.dataclass +class BeamState: + """Holds beam search state data.""" + # The position of the decoding loop in the length dimension. 
+ cur_index: jnp.DeviceArray # scalar int32: current decoded length index + # The active sequence log probabilities and finished sequence scores. + live_logprobs: jnp.DeviceArray # float32: [batch_size, beam_size] + finished_scores: jnp.DeviceArray # float32: [batch_size, beam_size] + # The current active-beam-searching and finished sequences. + live_seqs: jnp.DeviceArray # int32: [batch_size, beam_size, max_decode_len] + finished_seqs: jnp.DeviceArray # int32: [batch_size, beam_size, + # max_decode_len] + # Records which of the 'finished_seqs' is occupied and not a filler slot. + finished_flags: jnp.DeviceArray # bool: [batch_size, beam_size] + # The current state of the autoregressive decoding caches. + cache: PyTreeDef # Any pytree of arrays, e.g. flax attention Cache object + + +def beam_init(batch_size: int, + beam_size: int, + max_decode_len: int, + cache: Mapping[str, jnp.ndarray], + offset: int = 0) -> BeamState: + """Initializes the beam search state data structure.""" + cur_index0 = jnp.array(0) + live_logprobs0 = jnp.tile( + jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1]) + finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF + live_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32) + finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32) + finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_) + # add beam dimension to attention cache pytree elements + beam_cache0 = cache_map(lambda x: add_beam_dim(x, beam_size, offset), cache) + return BeamState( + cur_index=cur_index0, + live_logprobs=live_logprobs0, + finished_scores=finished_scores0, + live_seqs=live_seqs0, + finished_seqs=finished_seqs0, + finished_flags=finished_flags0, + cache=beam_cache0) + + +# Beam search routine: + + +def beam_search(inputs: jnp.ndarray, + cache: Mapping[str, jnp.ndarray], + tokens_to_logits: Callable[ + [jnp.ndarray, Mapping[str, jnp.ndarray]], + Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]], + eos_id: int, + num_decodes: int = 4, + alpha: float = 0.6, + max_decode_len: Optional[int] = None, + decode_rng: Optional[jnp.ndarray] = None, + cache_offset: int = 0) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Beam search for transformer machine translation. + + If `inputs` has non-zero entries, those values are not modified, i.e., + the sampled values for those positions are discarded. This simulates the + teacher forcing on the prefix positions. + + Args: + inputs: array: [batch_size, length] int32 sequence of tokens. + cache: flax attention cache. + tokens_to_logits: fast autoregressive decoder function taking single token + slices and cache and returning next-token logits and updated cache. + eos_id: int: id of end-of-sentence token for target vocabulary. + num_decodes: number of decoded sequences to be returned. This is equivalent + to the number of beams used in the beam search. + alpha: float: scaling factor for brevity penalty. + max_decode_len: int: an optional maximum length of decoded sequence. If + None, it uses `inputs.shape[1]` as `max_decode_len`. + decode_rng: Unused decoder RNG seed. + cache_offset: axis offset for cache, arising from scanned layers. + + Returns: + Tuple of: + [batch_size, beam_size, max_decode_len] top-scoring sequences + [batch_size, beam_size] beam-search scores. + """ + del decode_rng + # We liberally annotate shape information for clarity below. 
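+  # A minimal usage sketch (hypothetical caller, not part of this module):
+  #   seqs, scores = beam_search(inputs, cache, tokens_to_logits, eos_id=1,
+  #                              num_decodes=4, alpha=0.6)
+  # Beams come back sorted with the highest-scoring sequence last (the top-k
+  # indices are flipped in `gather_topk_beams`), so the best hypothesis is
+  # seqs[:, -1, :].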
+ + beam_size = num_decodes + + batch_size = inputs.shape[0] + end_marker = jnp.array(eos_id) + if max_decode_len is None: + max_decode_len = inputs.shape[1] + # We start with a dummy token in the beginning so extend the maximum length. + max_decode_len += 1 + + # initialize beam search state + beam_search_init_state = beam_init(batch_size, beam_size, max_decode_len, + cache, cache_offset) + + def beam_search_loop_cond_fn(state: BeamState) -> bool: + """Beam search loop termination condition.""" + # Have we reached max decoding length? + # Because we mutate the "i+1" position, we stop one token before the end. + not_at_end = (state.cur_index < max_decode_len - 1) + + # Is no further progress in the beam search possible? + # Get the best possible scores from alive sequences. + min_brevity_penalty = brevity_penalty(alpha, max_decode_len) + best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty + # Get the worst scores from finished sequences. + worst_finished_scores = jnp.min( + state.finished_scores, axis=1, keepdims=True) + # Mask out scores from slots without any actual finished sequences. + worst_finished_scores = jnp.where(state.finished_flags, + worst_finished_scores, NEG_INF) + # If no best possible live score is better than current worst finished + # scores, the search cannot improve the finished set further. + search_terminated = jnp.all(worst_finished_scores > best_live_scores) + + # If we're not at the max decode length, and the search hasn't terminated, + # continue looping. + return not_at_end & (~search_terminated) + + def beam_search_loop_body_fn(state: BeamState) -> BeamState: + """Beam search loop state update function.""" + # Collect the current position slice along length to feed the fast + # autoregressive decoder model. Flatten the beam dimension into batch + # dimension for feeding into the model. + # --> [batch * beam, 1] + flat_ids = flatten_beam_dim( + lax.dynamic_slice(state.live_seqs, (0, 0, state.cur_index), + (batch_size, beam_size, 1))) + # Flatten beam dimension into batch to be compatible with model. + # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...} + flat_cache = cache_map( + functools.partial(flatten_beam_dim, offset=cache_offset), state.cache) + + # Call fast-decoder model on current tokens to get next-position logits. + # --> [batch * beam, vocab] + flat_logits, new_flat_cache = tokens_to_logits(flat_ids, flat_cache) + + # unflatten beam dimension + # [batch * beam, vocab] --> [batch, beam, vocab] + logits = unflatten_beam_dim(flat_logits, batch_size, beam_size) + # Unflatten beam dimension in attention cache arrays + # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...} + new_cache = cache_map( + lambda x: unflatten_beam_dim(x, batch_size, beam_size, cache_offset), + new_flat_cache) + + # Gather log probabilities from logits + candidate_log_probs = jax.nn.log_softmax(logits) + # Add new logprobs to existing prefix logprobs. + # --> [batch, beam, vocab] + log_probs = ( + candidate_log_probs + jnp.expand_dims(state.live_logprobs, axis=2)) + + # We'll need the vocab size, gather it from the log probability dimension. + vocab_size = log_probs.shape[-1] + + # Each item in batch has beam_size * vocab_size candidate sequences. + # For each item, get the top 2*k candidates with the highest log- + # probabilities. We gather the top 2*K beams here so that even if the best + # K sequences reach EOS simultaneously, we have another K sequences + # remaining to continue the live beam search. 
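+    # For intuition (hypothetical numbers): with beam_size=2 and vocab_size=4
+    # the flattened scores have 8 entries per batch item; a flat index of 6
+    # corresponds to beam 6 // 4 = 1 and token id 6 % 4 = 2, which is how the
+    # beam and token indices are recovered from `topk_indices` below.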
+ beams_to_keep = 2 * beam_size + # Flatten beam and vocab dimensions. + flat_log_probs = log_probs.reshape((batch_size, beam_size * vocab_size)) + # Gather the top 2*K scores from _all_ beams. + # --> [batch, 2*beams], [batch, 2*beams] + topk_log_probs, topk_indices = top_k_two_stage( + flat_log_probs, k=beams_to_keep) + + # Append the most probable 2*K token IDs to the top 2*K sequences + # Recover token id by modulo division. + topk_ids = topk_indices % vocab_size + # Force decode `inputs` into topk_ids up until PAD. When `inputs` is all + # PADs this is a no-op. + next_input_token = jnp.expand_dims( + inputs, axis=1).astype(jnp.int32)[:, :, state.cur_index + 1] + out_of_prompt = (next_input_token == 0) + + # When forcing prompts, update log probabilities to `0` for the top of the + # beam and -INF for the rest, effectively keeping only one beam alive. + # --> [batch, 2*beams] + inside_prompt_log_probs = jnp.concatenate([ + jnp.zeros((batch_size, 1), dtype=topk_log_probs.dtype), + jnp.full_like(topk_log_probs[:, :beams_to_keep - 1], NEG_INF) + ], + axis=1) + topk_log_probs = ( + topk_log_probs * out_of_prompt + + inside_prompt_log_probs * ~out_of_prompt) + + topk_ids = topk_ids * out_of_prompt + next_input_token * ~out_of_prompt + + # Expand id array for broadcasting + # --> [batch, 2*beams, 1] + topk_ids = jnp.expand_dims(topk_ids, axis=2) + + # Recover the beam index by floor division. + topk_beam_indices = topk_indices // vocab_size + # Gather 2*k top beams. + # --> [batch, 2*beams, length] + topk_seq = gather_beams(state.live_seqs, topk_beam_indices, batch_size, + beam_size, beams_to_keep) + # Update sequences for the 2*K top-k new sequences. + # --> [batch, 2*beams, length] + topk_seq = lax.dynamic_update_slice(topk_seq, topk_ids, + (0, 0, state.cur_index + 1)) + + # Update LIVE (in-progress) sequences: + # Did any of these sequences reach an end marker? + # --> [batch, 2*beams] + newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker) + # To prevent these newly finished sequences from being added to the LIVE + # set of active beam search sequences, set their log probs to a very large + # negative value. + new_log_probs = topk_log_probs + newly_finished * NEG_INF + # Determine the top k beam indices (from top 2*k beams) from log probs. + # --> [batch, beams] + _, new_topk_indices = lax.top_k(new_log_probs, k=beam_size) + new_topk_indices = jnp.flip(new_topk_indices, axis=1) + # Gather the top k beams (from top 2*k beams). + # --> [batch, beams, length], [batch, beams] + top_alive_seq, top_alive_log_probs = gather_beams([topk_seq, new_log_probs], + new_topk_indices, + batch_size, 2 * beam_size, + beam_size) + + # Determine the top k beam indices from the original set of all beams. + # --> [batch, beams] + top_alive_indices = gather_beams(topk_beam_indices, new_topk_indices, + batch_size, 2 * beam_size, beam_size) + # With these, gather the top k beam-associated caches. + # --> {[batch, beams, ...], ...} + top_alive_cache = cache_gather_beams(new_cache, top_alive_indices, + batch_size, beam_size, beam_size, True, + cache_offset) + + # Update FINISHED (reached end of sentence) sequences: + # Calculate new seq scores from log probabilities. + new_scores = topk_log_probs / brevity_penalty(alpha, state.cur_index + 1) + # Mask out the still unfinished sequences by adding large negative value. 
+ # --> [batch, 2*beams] + new_scores += (~newly_finished) * NEG_INF + + # Combine sequences, scores, and flags along the beam dimension and compare + # new finished sequence scores to existing finished scores and select the + # best from the new set of beams. + finished_seqs = jnp.concatenate( # --> [batch, 3*beams, length] + [state.finished_seqs, topk_seq], + axis=1) + finished_scores = jnp.concatenate( # --> [batch, 3*beams] + [state.finished_scores, new_scores], axis=1) + finished_flags = jnp.concatenate( # --> [batch, 3*beams] + [state.finished_flags, newly_finished], axis=1) + # --> [batch, beams, length], [batch, beams], [batch, beams] + top_finished_seq, top_finished_scores, top_finished_flags = ( + gather_topk_beams([finished_seqs, finished_scores, finished_flags], + finished_scores, batch_size, beam_size)) + + return BeamState( + cur_index=state.cur_index + 1, + live_logprobs=top_alive_log_probs, + finished_scores=top_finished_scores, + live_seqs=top_alive_seq, + finished_seqs=top_finished_seq, + finished_flags=top_finished_flags, + cache=top_alive_cache) + + # Run while loop and get final beam search state. + final_state = lax.while_loop(beam_search_loop_cond_fn, + beam_search_loop_body_fn, beam_search_init_state) + + # Account for the edge-case where there are no finished sequences for a + # particular batch item. If so, return live sequences for that batch item. + # --> [batch] + none_finished = jnp.any(final_state.finished_flags, axis=1) + # --> [batch, beams, length] + finished_seqs = jnp.where(none_finished[:, None, None], + final_state.finished_seqs, final_state.live_seqs) + # --> [batch, beams] + finished_scores = jnp.where(none_finished[:, + None], final_state.finished_scores, + final_state.live_logprobs) + + # Drop the first dummy 0 token. + return finished_seqs[:, :, 1:], finished_scores diff --git a/t5x/decoding_test.py b/t5x/decoding_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f874504e048aa5bffe9958dffa7e3d41730aa749 --- /dev/null +++ b/t5x/decoding_test.py @@ -0,0 +1,943 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for t5x.decoding.""" + +import functools +from typing import Mapping, Tuple +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import api +from jax.experimental import host_callback as hcb +import jax.numpy as jnp +import numpy as np +from t5x import decoding + +EOS_ID = 1 +NEG_INF = decoding.NEG_INF + + +class DecodeTest(parameterized.TestCase): + + def test_temperature_sample_uneven_prefix(self): + + def token_to_logits(ids, cache): + del ids + del cache + # Always sample id 2 for batch element 0 and id 3 for element 1. 
+ logits = np.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]], + dtype=np.float32) + return logits, {} + + inputs = np.array([[0, 5, 7, 1, 0, 0], [0, 6, 1, 0, 0, 0]]) + sampled_sequences, _ = decoding._temperature_sample_single_trial( + inputs, {}, + token_to_logits, + EOS_ID, + jax.random.PRNGKey(0), + topk=0, + initial_index=np.array([3, 2])) + expected = np.array([[5, 7, 1, 2, 2, 2], [6, 1, 3, 3, 3, 3]]) + np.testing.assert_array_equal(expected, sampled_sequences) + + def test_temperature_sample_no_prefix(self): + batch, max_decode_len = 2, 3 + + def token_to_logits(ids, cache): # pylint: disable=unused-argument + # Always sample id 2 for batch element 0 and id 3 for element 1. + logits = np.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]], + dtype=np.float32) + return logits, {} + + inputs = np.zeros((batch, max_decode_len), dtype=np.int32) + sampled_sequences, _ = decoding._temperature_sample_single_trial( + inputs, {}, token_to_logits, EOS_ID, jax.random.PRNGKey(0), topk=0) + + expected = [[2, 2, 2], [3, 3, 3]] + np.testing.assert_array_equal(expected, sampled_sequences) + + def test_temperature_sample_prefix(self): + + def token_to_logits(ids, cache): # pylint: disable=unused-argument + # Always sample id 2 for batch element 0 and id 3 for element 1. + logits = np.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]], + dtype=np.float32) + return logits, {} + + # batch element 0 has length 3 prefix and element 1 has length 2. + inputs = np.array([[0, 5, 6, 7, 0], [0, 8, 9, 0, 0]], dtype=np.int32) + sampled_sequences, _ = decoding._temperature_sample_single_trial( + inputs, {}, token_to_logits, EOS_ID, jax.random.PRNGKey(0), topk=0) + + expected = [[5, 6, 7, 2, 2], [8, 9, 3, 3, 3]] + np.testing.assert_array_equal(expected, sampled_sequences) + + def test_temperature_sample_with_zero_temperature(self): + batch, max_decode_len = 2, 3 + + def token_to_logits(ids, cache): # pylint: disable=unused-argument + # Use very large logits that are close to one another. + logits = np.array( + [[1700.47, 1700.48, 1700.51, 1700.45], [3.2, 4.8, -5.3, 5.6]], + dtype=np.float32) + return logits, {} + + inputs = np.zeros((batch, max_decode_len), dtype=np.int32) + sampled_sequences, _ = decoding._temperature_sample_single_trial( + inputs, {}, + token_to_logits, + EOS_ID, + jax.random.PRNGKey(0), + topk=4, + temperature=0.0) + + expected = [[2, 2, 2], [3, 3, 3]] + np.testing.assert_array_equal(expected, sampled_sequences) + + def test_temperature_sample_prefix_ending_with_eos(self): + + def token_to_logits(ids, cache): # pylint: disable=unused-argument + # Always sample id 2 for batch element 0 and id 3 for element 1. + logits = np.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]], + dtype=np.float32) + return logits, {} + + # batch element 0 has length 4 prefix (including the initial dummy token and + # the last eos) and element 1 has length 3. 
+ inputs = np.array([[0, 5, 6, 1, 0], [0, 8, 1, 0, 0]], dtype=np.int32) + sampled_sequences, _ = decoding._temperature_sample_single_trial( + inputs, {}, token_to_logits, EOS_ID, jax.random.PRNGKey(0), topk=1) + + expected = [[5, 6, 1, 2, 2], [8, 1, 3, 3, 3]] + np.testing.assert_array_equal(expected, sampled_sequences) + + def test_temperature_sample_with_state_callback(self): + + def token_to_logits(ids, cache): # pylint: disable=unused-argument + # A distribution with roughly all probability mass in sample id 3 + logits = np.array([[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]], + dtype=np.float32) + return logits, {} + + def state_callback_fn(state): + i, sequences, cache, cur_token, ended, rng, log_prob = state + + def callback_fn(current_index_and_sequences): + """Add EOS token after first time token id 3 has been sampled.""" + current_index, sequences = current_index_and_sequences + sequences = np.array(sequences) + for i in range(len(current_index)): + if sequences[i, current_index[i]] == 3: + sequences[i, current_index[i] + 1] = EOS_ID + return sequences + + sequences = hcb.call( + callback_fn, (i, sequences), + result_shape=api.ShapeDtypeStruct(sequences.shape, sequences.dtype)) + return i, sequences, cache, cur_token, ended, rng, log_prob + + inputs = np.array([[0, 5, 6, 7, 0], [0, 8, 9, 0, 0]], dtype=np.int32) + sampled_sequences, _ = decoding._temperature_sample_single_trial( + inputs, {}, + token_to_logits, + EOS_ID, + jax.random.PRNGKey(0), + topk=0, + temperature=0.0, + state_callback_fn=state_callback_fn) + + expected = [[5, 6, 7, 3, EOS_ID], [8, 9, 3, EOS_ID, 0]] + np.testing.assert_array_equal(expected, sampled_sequences) + + def test_temperature_sample_with_logit_callback(self): + + def token_to_logits(ids, cache): # pylint: disable=unused-argument + # uniform distribution over targets from model + logits = np.array([[-1e7, -1e7, -1e7, -1e7], [-1e7, -1e7, -1e7, -1e7]], + dtype=np.float32) + return logits, {} + + def logit_callback_fn(logits, state): + del state # unused + # Rewrite logits to always sample id 2 for batch element 0 and + # id 3 for element 1. + logits[0, 2] = 0 + logits[1, 3] = 0 + return logits + + # batch element 0 has length 3 prefix and element 1 has length 2. + inputs = np.array([[0, 5, 6, 7, 0], [0, 8, 9, 0, 0]], dtype=np.int32) + sampled_sequences, _ = decoding._temperature_sample_single_trial( + inputs, {}, + token_to_logits, + EOS_ID, + jax.random.PRNGKey(0), + topk=0, + temperature=0.0, + logit_callback_fn=logit_callback_fn) + + expected = [[5, 6, 7, 2, 2], [8, 9, 3, 3, 3]] + np.testing.assert_array_equal(expected, sampled_sequences) + + def test_temperature_sample_prefix_ending_with_eos_early_stop(self): + batch, max_decode_len = 2, 7 + rng0 = jax.random.PRNGKey(0) + + ret = [np.array([2, 3]) for _ in range(max_decode_len)] + # Sequence 1 outputs EOS=1 when i = 3 where `i` is the while loop counter of + # `decoding._temperature_sample_single_trial`. + ret[3] = np.array([2, 1]) + # Sequence 0 outputs EOS=1 when i = 4. + ret[4] = np.array([1, 3]) + ret = jax.numpy.array(ret) + + def mocked_categorical(rng_input, logits): # pylint: disable=unused-argument + """Ignores logit and returns only based on the rng_input.""" + rng = rng0 + k = 0 + # Mimic the rng split done in `decoding.sample_loop_body_fn`. + for j in range(max_decode_len): + rng1, rng = jax.random.split(rng) + # We want to sift out `j` for which rng1 == rng_input + # rngs are a pair of ints. So sum the bool and divide by 2. 
+ k += j * (rng1 == rng_input).sum() // 2 + # `k` at this point is equal to the while loop variable `i` of the caller. + return ret[k] + + def token_to_logits(ids, cache): # pylint: disable=unused-argument + # These values are not used in this test because random.categorical is + # directly mocked. + dummy_logits = np.zeros((batch, 4), dtype=np.float32) + return dummy_logits, {} + + inputs = np.array([[0, 5, 1, 0, 0, 0, 0], [0, 8, 0, 0, 0, 0, 0]], + dtype=np.int32) + with mock.patch.object(jax.random, 'categorical', new=mocked_categorical): + sampled_sequences, _ = decoding._temperature_sample_single_trial( + inputs, {}, token_to_logits, EOS_ID, rng0, topk=0) + + expected = [[5, 1, 2, 2, 1, 0, 0], [8, 3, 3, 1, 0, 0, 0]] + np.testing.assert_array_equal(expected, sampled_sequences) + + def test_greedy_decoding_topk_sample_log_probs(self): + + def token_to_logits(ids, cache): # pylint: disable=unused-argument + # Sample [2, 3] with probability [0.6, 0.4]. + logits = np.array([[-1e7, -1e7, -0.510825624, -0.916290732]], + dtype=np.float32) + return logits, {} + + inputs = np.array([[0, 2, 2, 2, 0]], dtype=np.int32) + sampled_sequences, sampled_log_probs = decoding._temperature_sample_single_trial( + inputs, {}, + token_to_logits, + EOS_ID, + jax.random.PRNGKey(0), + topk=1, + rescale_log_probs=True) + + expected_sequence = [[2, 2, 2, 2, 2]] + expected_log_probs = [0.0] + np.testing.assert_array_equal(expected_sequence, sampled_sequences) + np.testing.assert_array_almost_equal(expected_log_probs, sampled_log_probs) + + inputs = np.array([[0, 2, 2, 3, 0]], dtype=np.int32) + sampled_sequences, sampled_log_probs = decoding._temperature_sample_single_trial( + inputs, {}, + token_to_logits, + EOS_ID, + jax.random.PRNGKey(0), + topk=1, + rescale_log_probs=False) + + expected_sequence = [[2, 2, 3, 2, 2]] + expected_log_probs = [-1.02165125] + np.testing.assert_array_equal(expected_sequence, sampled_sequences) + np.testing.assert_array_almost_equal(expected_log_probs, sampled_log_probs) + + def test_temperature_sample_log_prob(self): + batch, max_decode_len = 2, 7 + rng0 = jax.random.PRNGKey(0) + + ret = [np.array([2, 3]) for _ in range(max_decode_len)] + # Sequence 1 outputs EOS=1 when i = 3 where `i` is the while loop counter of + # `decoding._temperature_sample_single_trial`. + ret[3] = np.array([2, 1]) + # Sequence 0 outputs EOS=1 when i = 4. + ret[4] = np.array([1, 3]) + ret = jax.numpy.array(ret) + + # TODO(hwchung): refactor this. + def mocked_categorical(rng_input, logits): # pylint: disable=unused-argument + """Ignores logit and returns only based on the rng_input.""" + rng = rng0 + k = 0 + # Mimic the rng split done in `decoding.sample_loop_body_fn`. + for j in range(max_decode_len): + rng1, rng = jax.random.split(rng) + # We want to sift out `j` for which rng1 == rng_input + # rngs are a pair of ints. So sum the bool and divide by 2. + k += j * (rng1 == rng_input).sum() // 2 + # `k` at this point is equal to the while loop variable `i` of the caller. 
+ return ret[k] + + logits = np.random.randn(batch, 4) + token_to_logits = lambda ids, cache: (logits, {}) + inputs = np.array([[0, 5, 1, 0, 0, 0, 0], [0, 8, 0, 0, 0, 0, 0]], + dtype=np.int32) + with mock.patch.object(jax.random, 'categorical', new=mocked_categorical): + sampled_sequences, log_prob = decoding._temperature_sample_single_trial( + inputs, {}, token_to_logits, EOS_ID, rng0, topk=0) + + log_probs = jax.nn.log_softmax(logits) + expected = [[5, 1, 2, 2, 1, 0, 0], [8, 3, 3, 1, 0, 0, 0]] + expected_log_prob = [ + log_probs[0, 2] + log_probs[0, 2] + log_probs[0, 1], + log_probs[1, 3] + log_probs[1, 3] + log_probs[1, 1] + ] + expected_log_prob = np.array(expected_log_prob) + np.testing.assert_array_equal(expected, sampled_sequences) + np.testing.assert_allclose(expected_log_prob, log_prob, atol=1e-5) + + def test_temperature_sample_num_decodes(self): + num_decodes = 3 + rng0 = jax.random.PRNGKey(0) + inputs = np.array([[0, 5, 1, 0], [0, 8, 7, 0]], dtype=np.int32) + + with mock.patch.object(decoding, + '_temperature_sample_single_trial') as mocked: + # expanded_decodes: [batch * num_decodes, max_decode_len] + expanded_decodes = np.array([[5, 1, 4, 4], [5, 1, 5, 5], [5, 1, 3, 3], + [8, 7, 5, 5], [8, 7, 3, 3], [8, 7, 4, 4]]) + # expanded_log_prob: [batch * num_decodes] + expanded_log_prob = np.array([-2.3, -1.3, -3.6, -0.5, -2.5, -1.9]) + mocked.return_value = expanded_decodes, expanded_log_prob + + decodes, scores = decoding.temperature_sample( + inputs, {}, mock.Mock(), EOS_ID, rng0, num_decodes=num_decodes) + + expanded_inputs = jnp.array([[0, 5, 1, 0], [0, 5, 1, 0], [0, 5, 1, 0], + [0, 8, 7, 0], [0, 8, 7, 0], [0, 8, 7, 0]]) + # Test that the actual decode function is called with the expanded values. + np.testing.assert_array_equal(mocked.call_args[0][0], expanded_inputs) + + np.testing.assert_array_equal(decodes, + [[[5, 1, 3, 3], [5, 1, 4, 4], [5, 1, 5, 5]], + [[8, 7, 3, 3], [8, 7, 4, 4], [8, 7, 5, 5]]]) + np.testing.assert_allclose(scores, [[-3.6, -2.3, -1.3], [-2.5, -1.9, -0.5]]) + + def test_temperature_sample_num_decodes_with_initial_index(self): + num_decodes = 3 + rng0 = jax.random.PRNGKey(0) + inputs = np.array([[0, 5, 1, 0], [0, 8, 7, 0]], dtype=np.int32) + initial_index = np.array([1, 2], dtype=np.int32) + + with mock.patch.object(decoding, + '_temperature_sample_single_trial') as mocked: + with mock.patch.object(decoding, 'cache_map') as mocked_cache_map: + # expanded_decodes: [batch * num_decodes, max_decode_len] + expanded_decodes = np.array([[5, 1, 4, 4], [5, 1, 5, 5], [5, 1, 3, 3], + [8, 7, 5, 5], [8, 7, 3, 3], [8, 7, 4, 4]]) + # expanded_log_prob: [batch * num_decodes] + expanded_log_prob = np.array([-2.3, -1.3, -3.6, -0.5, -2.5, -1.9]) + mocked.return_value = expanded_decodes, expanded_log_prob + + decodes, scores = decoding.temperature_sample( + inputs, {}, + mock.Mock(), + EOS_ID, + rng0, + num_decodes=num_decodes, + initial_index=initial_index) + + expanded_inputs = jnp.array([[0, 5, 1, 0], [0, 5, 1, 0], [0, 5, 1, 0], + [0, 8, 7, 0], [0, 8, 7, 0], [0, 8, 7, 0]]) + expanded_initial_index = np.array([1, 1, 1, 2, 2, 2], dtype=np.int32) + # Test that the actual decode function is called with the expanded + # values. 
+ np.testing.assert_array_equal(mocked.call_args[0][0], expanded_inputs) + np.testing.assert_array_equal(mocked.call_args[1]['initial_index'], + expanded_initial_index) + # Test that the function was applied to the index in the cache map + self.assertTrue(mocked_cache_map.call_args[1]['apply_to_index']) + + np.testing.assert_array_equal(decodes, + [[[5, 1, 3, 3], [5, 1, 4, 4], [5, 1, 5, 5]], + [[8, 7, 3, 3], [8, 7, 4, 4], [8, 7, 5, 5]]]) + np.testing.assert_allclose(scores, [[-3.6, -2.3, -1.3], [-2.5, -1.9, -0.5]]) + + @parameterized.named_parameters( + dict( + testcase_name='no_initial_index', + initial_index=None, + expected_calls=6, + ), + dict( + testcase_name='initial_index', + initial_index=np.array([1, 2], dtype=np.int32), + expected_calls=4, + ), + dict( + testcase_name='lower_initial_index', + initial_index=np.array([1, 1], dtype=np.int32), + expected_calls=5, # we decode 4 tokens out of the prompt + ), + ) + def test_temperature_sample_max_decode_steps_with_initial_index( + self, initial_index, expected_calls): + max_decode_steps = 4 + rng0 = jax.random.PRNGKey(0) + inputs = np.array([[0, 2, 0, 0, 0, 0, 0, 0], [0, 2, 2, 0, 0, 0, 0, 0]], + dtype=np.int32) + + token_to_logits = mock.Mock() + token_to_logits.return_value = (np.array( + [[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]], dtype=np.float32), {}) + + # to unroll while loop + with jax.disable_jit(): + decodes, scores = decoding.temperature_sample( + inputs, {}, + token_to_logits, + EOS_ID, + rng0, + initial_index=initial_index, + topk=4, + max_decode_steps=max_decode_steps) + + self.assertLen(token_to_logits.call_args_list, expected_calls) + + expected_output = np.array([[2, 3, 3, 3, 3, 0, 0, 0], + [2, 2, 3, 3, 3, 3, 0, 0]]) + expected_output = jnp.expand_dims(expected_output, 1) + + np.testing.assert_array_equal(decodes, expected_output) + np.testing.assert_array_equal(scores, [[0.], [0.]]) + + def test_temperature_sample_max_decode_steps_endpad(self): + max_decode_steps = 4 + rng0 = jax.random.PRNGKey(0) + inputs = np.array([[0, 2, 0, 0, 0, 0, 0, 0], [0, 2, 2, 2, 2, 2, 2, 0], + [0, 2, 2, 2, 0, 0, 0, 0]], + dtype=np.int32) + initial_index = np.array([1, 6, 0]) + + token_to_logits = mock.Mock() + token_to_logits.return_value = (np.array( + [[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]], + dtype=np.float32), {}) + + # to unroll while loop + with jax.disable_jit(): + decodes, scores = decoding.temperature_sample( + inputs, {}, + token_to_logits, + EOS_ID, + rng0, + initial_index=initial_index, + topk=4, + max_decode_steps=max_decode_steps) + + # `inputs[2]` starts from index 0. So it requires 3 calls to + # `token_to_logits` to exit the prompt (these generated tokens are + # overridden) and 4 more calls to fill the rest. `inputs[0]` only need 4 + # calls. In the last 3 calls, it generates but MUST NOT populate the + # sequences because it is already ended. 
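+    # (That is 3 + 4 = 7 calls in total for the slowest element, which is what
+    # the assertion below checks.)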
+ self.assertLen(token_to_logits.call_args_list, 7) + expected_output = np.array( + [[2, 3, 3, 3, 3, 0, 0, 0], [2, 2, 2, 2, 2, 2, 3, 3], + [2, 2, 2, 3, 3, 3, 3, 0]], + dtype=np.int32) + expected_output = jnp.expand_dims(expected_output, 1) + + np.testing.assert_array_equal(decodes, expected_output) + np.testing.assert_allclose(scores, [[0.], [0.], [0.]]) + + def test_temperature_sample_max_decode_steps_docstring_ex4(self): + max_decode_steps = 2 + rng0 = jax.random.PRNGKey(0) + inputs = np.array([[0, 2, 0, 0, 0, 0, 0, 0], [0, 3, 4, 0, 0, 0, 0, 0]], + dtype=np.int32) + initial_index = np.array([1, 2]) + + token_to_logits = mock.Mock() + token_to_logits.return_value = (np.array( + [[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]], dtype=np.float32), {}) + + # to unroll while loop + with jax.disable_jit(): + decodes, _ = decoding.temperature_sample( + inputs, {}, + token_to_logits, + EOS_ID, + rng0, + initial_index=initial_index, + topk=4, + max_decode_steps=max_decode_steps) + self.assertLen(token_to_logits.call_args_list, 2) + expected_output = np.array( + [[2, 2, 2, 0, 0, 0, 0, 0], [3, 4, 3, 3, 0, 0, 0, 0]], dtype=np.int32) + expected_output = jnp.expand_dims(expected_output, 1) + + np.testing.assert_array_equal(decodes, expected_output) + + def test_temperature_sample_max_decode_steps_hard_limit(self): + max_decode_steps = 10 + max_decode_steps_hard_limit = 4 + rng0 = jax.random.PRNGKey(0) + inputs = np.array([[0, 2, 0, 0, 0, 0, 0, 0], [0, 2, 2, 0, 0, 0, 0, 0]], + dtype=np.int32) + + token_to_logits = mock.Mock() + token_to_logits.return_value = (np.array( + [[-1e7, -1e7, -1e7, 0], [-1e7, -1e7, -1e7, 0]], dtype=np.float32), {}) + + # to unroll while loop + with jax.disable_jit(): + decodes, scores = decoding.temperature_sample( + inputs, {}, + token_to_logits, + EOS_ID, + rng0, + topk=4, + max_decode_steps=max_decode_steps, + max_decode_steps_hard_limit=max_decode_steps_hard_limit) + + expected_output = np.array([[2, 3, 3, 3, 3, 0, 0, 0], + [2, 2, 3, 3, 3, 3, 0, 0]]) + expected_output = jnp.expand_dims(expected_output, 1) + + np.testing.assert_array_equal(decodes, expected_output) + np.testing.assert_array_equal(scores, [[0.], [0.]]) + + def test_temperature_sample_topp(self): + rng0 = jax.random.PRNGKey(0) + inputs = np.zeros((1, 20), dtype=np.int32) + + token_to_logits = mock.Mock() + + # logits correspond to (0.3, 0, 0.1, 0.6) + token_to_logits.return_value = (np.array([[-1.2, -1e7, -2.3, -0.51]], + dtype=np.float32), {}) + + decodes, scores = decoding.temperature_sample( + inputs, {}, token_to_logits, EOS_ID, rng0, topp=0.55, + topk=0) # anything under 0.6 will trigger deterministic decoding. + + expected_output = np.array([[3] * 20]) + expected_output = jnp.expand_dims(expected_output, 1) + + np.testing.assert_array_equal(decodes, expected_output) + np.testing.assert_array_equal(scores, [[0.]]) + + # temperature is applied first, so the distribution becomes + # (0.27, 0, 0.069, 0.65), so if topp is 0.63, it should become greedy. 
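+    # (Worked out: softmax([-1.2, -1e7, -2.3, -0.51] / 0.8) is approximately
+    # [0.276, 0.0, 0.070, 0.654]; the top probability alone already exceeds
+    # topp=0.63, so only token id 3 survives the nucleus cutoff.)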
+ decodes, scores = decoding.temperature_sample( + inputs, {}, + token_to_logits, + EOS_ID, + rng0, + temperature=0.8, + topp=0.63, + topk=0) + + expected_output = np.array([[3] * 20]) + expected_output = jnp.expand_dims(expected_output, 1) + + np.testing.assert_array_equal(decodes, expected_output) + np.testing.assert_array_equal(scores, [[0.]]) + + def test_dynamic_topp_max_decode_steps(self): + rng0 = jax.random.PRNGKey(0) + inputs = np.zeros((1, 20), dtype=np.int32) + + token_to_logits = mock.Mock() + + # logits correspond to (0.3, 0, 0.1, 0.6) + token_to_logits.return_value = (np.array([[-1.2, -1e7, -2.3, -0.51]], + dtype=np.float32), {}) + + def dynamic_decode_fn(inputs, temperature, topp, max_decode_steps): + return decoding.temperature_sample( + inputs, {}, + token_to_logits, + EOS_ID, + rng0, + temperature=temperature, + topp=topp, + topk=0, + max_decode_steps=max_decode_steps) + + dynamic_decode_fn_jit = jax.jit(dynamic_decode_fn) + + decodes, scores = dynamic_decode_fn_jit(inputs, 0.8, 0.63, 10) + + expected_output = np.array([[3] * 10 + [0] * 10]) + expected_output = jnp.expand_dims(expected_output, 1) + + np.testing.assert_array_equal(decodes, expected_output) + np.testing.assert_array_equal(scores, [[0.]]) + + def test_topp_log_probs(self): + rng0 = jax.random.PRNGKey(0) + inputs = np.zeros((1, 1), dtype=np.int32) + + token_to_logits = mock.Mock() + + # logits correspond to (0.3, 0, 0.1, 0.6) + token_to_logits.return_value = (np.array([[-1.2, NEG_INF, -2.3, -0.51]], + dtype=np.float32), {}) + + with jax.disable_jit(): + # this lets us see logits after topp and topk are applied + with mock.patch.object(jax.random, 'categorical') as mocked: + mocked.return_value = jnp.array([0], dtype=jnp.int32) + decodes, _ = decoding.temperature_sample( + inputs, {}, + token_to_logits, + EOS_ID, + rng0, + temperature=1.4, + topp=0.7, + topk=0) + + self.assertLen(token_to_logits.call_args_list, 1) + np.testing.assert_array_equal(decodes, jnp.asarray([[[0]]])) + + np.testing.assert_array_almost_equal( + mocked.call_args_list[0][0][1], + jnp.asarray([[-0.85714293, NEG_INF, NEG_INF, -0.36428571]])) + + def test_add_beam_dim(self): + x = np.array([[0, 5, 1, 0], [0, 8, 6, 9]], dtype=np.int32) + y = decoding.add_beam_dim(x, beam_size=3) + self.assertEqual(y.shape, (2, 3, 4)) + np.testing.assert_array_equal([[[0, 5, 1, 0], [0, 5, 1, 0], [0, 5, 1, 0]], + [[0, 8, 6, 9], [0, 8, 6, 9], [0, 8, 6, 9]]], + y) + + def test_flat_batch_beam_expand(self): + x = np.array([[0, 5, 1, 0], [0, 8, 6, 9]], dtype=np.int32) + np.testing.assert_array_equal( + [[0, 5, 1, 0], [0, 5, 1, 0], [0, 8, 6, 9], [0, 8, 6, 9]], + decoding.flat_batch_beam_expand(x, beam_size=2)) + + def test_top_k_two_stage(self): + + def _test_top_k(batch_size, k): + # Pick sufficiently large seq_len. + seq_len = 2047 * k * batch_size + seq = np.arange(seq_len) + np.random.shuffle(seq) + x = jnp.reshape(seq, (batch_size, int(seq_len / batch_size))).astype( + jnp.float32) + np.testing.assert_almost_equal( + decoding.top_k_two_stage(x, k), jax.lax.top_k(x, k), decimal=5) + + # Test small batch cases (batch={1,8}, k=16). + _test_top_k(1, 16) + _test_top_k(8, 16) + # Test large batch cases (batch={9,32}, k=11). 
+ _test_top_k(9, 11) + _test_top_k(32, 11) + + def test_cache_map(self): + cache = { + 'layers_0': { + 'cached_key': jnp.ones([3, 6]), + 'cached_values': jnp.ones([3, 6]), + 'cache_index': jnp.ones([ + 3, + ]), + }, + 'layers_1': { + 'self_attention': { + 'cached_key': jnp.ones([2, 7]), + 'cached_values': jnp.ones([5, 8]), + 'cache_index': jnp.array(1), + }, + 'encoder_decoder_attention': { + 'cached_key': jnp.ones([10, 12, 2]), + 'cached_values': jnp.ones([4, 7, 2]), + 'cache_index': jnp.ones([4, 5, 6]), + } + }, + } + + fn = functools.partial(jnp.add, 4) + + gold_cache = { + 'layers_0': { + 'cached_key': fn(jnp.ones([3, 6])), + 'cached_values': fn(jnp.ones([3, 6])), + 'cache_index': jnp.ones([ + 3, + ]), + }, + 'layers_1': { + 'self_attention': { + 'cached_key': fn(jnp.ones([2, 7])), + 'cached_values': fn(jnp.ones([5, 8])), + 'cache_index': jnp.array(1), + }, + 'encoder_decoder_attention': { + 'cached_key': fn(jnp.ones([10, 12, 2])), + 'cached_values': fn(jnp.ones([4, 7, 2])), + 'cache_index': jnp.ones([4, 5, 6]), + } + } + } + + jax.tree_multimap(np.testing.assert_array_equal, + decoding.cache_map(fn, cache), gold_cache) + + def test_cache_map_with_index(self): + cache = { + 'layers_0': { + 'cached_key': jnp.ones([3, 6]), + 'cached_values': jnp.ones([3, 6]), + 'cache_index': jnp.ones([ + 3, + ]), + }, + 'layers_1': { + 'relpos_bias': { + 'cached_bias': jnp.ones([1, 5, 3]), + }, + 'self_attention': { + 'cached_key': jnp.ones([2, 7]), + 'cached_values': jnp.ones([5, 8]), + 'cache_index': jnp.array(1), + }, + 'encoder_decoder_attention': { + 'cached_key': jnp.ones([10, 12, 2]), + 'cached_values': jnp.ones([4, 7, 2]), + 'cache_index': jnp.ones([4, 5, 6]), + } + }, + 'position_embedder': { + 'position_embedder_index': jnp.array(-1), + }, + } + + fn = functools.partial(jnp.add, 8) + + gold_cache = { + 'layers_0': { + 'cached_key': fn(jnp.ones([3, 6])), + 'cached_values': fn(jnp.ones([3, 6])), + 'cache_index': fn(jnp.ones([ + 3, + ])), + }, + 'layers_1': { + 'relpos_bias': { + 'cached_bias': jnp.ones([1, 5, 3]), + }, + 'self_attention': { + 'cached_key': fn(jnp.ones([2, 7])), + 'cached_values': fn(jnp.ones([5, 8])), + 'cache_index': fn(jnp.array(1)), + }, + 'encoder_decoder_attention': { + 'cached_key': fn(jnp.ones([10, 12, 2])), + 'cached_values': fn(jnp.ones([4, 7, 2])), + 'cache_index': fn(jnp.ones([4, 5, 6])), + } + }, + 'position_embedder': { + 'position_embedder_index': jnp.array(-1), + }, + } + + jax.tree_multimap(np.testing.assert_array_equal, + decoding.cache_map(fn, cache, apply_to_index=True), + gold_cache) + + def test_beam_search(self): + # Toy problem, we have 4 states, A, B, START, END, (plus PAD). + # Scores are given by a first-order Markov model. + batch_size = 2 + beam_size = 2 + # PAD doesn't matter for this test, but part of the contract for beam_search + # is giving the PAD token id 0. 
+ states = ['PAD', 'A', 'B', 'START-', '-END'] + num_states = len(states) + decode_length = 7 + + # Edge potentials (written inside edges for diagonals): + # 1 -1 1 -1 + # A ---- A ---- A ---- A ---- A + # 0 \ -1 \ 1 \ -1 \ 1 0 + # START X X X X END + # 0 / -1 / 1 / -1 / 1 0 + # B ---- B ---- B ---- B ---- B + # 1 -1 1 -1 + + # put the above edge potentials in a 3-tensor + ab_edge_potentials = np.asarray([[[1, -1], [-1, 1]], [[-1, 1], [1, -1]], + [[1, -1], [-1, 1]], [[-1, 1], [1, -1]]]) + # now we have to add on the START, END states + # and PAD at 0 + edge_potentials = np.ones([6, 5, 5]) * NEG_INF + edge_potentials[1:5, 1:3, 1:3] = ab_edge_potentials + # START can go to either A or B for free at t0 + edge_potentials[0, 3, 1] = 0 + edge_potentials[0, 3, 2] = 0 + # either A or B can go to END for free at t5 + edge_potentials[5, 1, 4] = 0 + edge_potentials[5, 2, 4] = 0 + # PAD can go to anything for free (doesn't matter for this test) + edge_potentials[:, 0, :] = 0 + + edge_potentials = jnp.asarray(edge_potentials) + + # at time 0, we start with state=START=3 + logits0 = jnp.asarray([NEG_INF, NEG_INF, NEG_INF, 0, NEG_INF]) + + # add dummy flattened batch x beam dim for broadcasting + logits0 = jnp.expand_dims(logits0, axis=0) + edge_potentials = jnp.expand_dims(edge_potentials, axis=0) + + def tokens_to_logits( + token_indices: jnp.ndarray, state_cache: Mapping[str, jnp.ndarray] + ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: + cur_iter = state_cache['cur_iter'] + # grab edge potentials for the current timestep + cur_edge_potentials = jnp.take_along_axis( + edge_potentials, + jnp.reshape( + jnp.maximum(0, cur_iter[:, 0].astype(jnp.int32) - 1), + (batch_size * beam_size, 1, 1, 1)), + axis=1) + cur_edge_potentials = jnp.squeeze(cur_edge_potentials, axis=1) + # get "logits" from edge potentials for requested tokens (except at t0) + cur_logits = jnp.matmul( + jnp.reshape( + jax.nn.one_hot(token_indices, num_states, axis=1), + (batch_size * beam_size, 1, num_states)), cur_edge_potentials) + cur_logits = jnp.squeeze(cur_logits, axis=1) + # use our START-only logits for t0, otherwise use the edge potentials + logits_for_tokens = jnp.where(cur_iter == 0, logits0, cur_logits) + # update state in the cache + new_cache = state_cache.copy() + new_cache['cur_iter'] = cur_iter + 1 + return logits_for_tokens, new_cache + + init_cache = {} + init_cache['cur_iter'] = jnp.zeros((batch_size, 1)) + + top_scoring, _ = decoding.beam_search( + inputs=np.zeros([batch_size, decode_length]), + cache=init_cache, + tokens_to_logits=tokens_to_logits, + eos_id=4, + num_decodes=beam_size, + alpha=0.0, + max_decode_len=decode_length) + + # The two top scoring sequences should be a tie between + # START-AABBA-END + # and + # START-BBAAB-END + # (and greedy beam search will find both these with just two beams) + + top_scoring_strings = [ + ''.join(states[tok] + for tok in top_scoring[0, i, :]) + for i in range(beam_size) + ] + + expected = ['START-AABBA-END', 'START-BBAAB-END'] + np.testing.assert_array_equal(expected, top_scoring_strings) + + def test_beam_search_force_decode_prefix(self): + beam_size = 2 + + def token_to_logits(ids, cache): # pylint: disable=unused-argument + # Use id 2 then 3 for batch element 0 and id 3 then 2 for element 1. 
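+      # The rows below are rough log-probabilities: element 0 prefers id 2
+      # (logit -0.1) over id 3 (-0.9), element 1 the reverse, and all other
+      # ids are effectively masked. The array is tiled over the beam dimension
+      # and flattened to [batch * beam, vocab], the shape `beam_search`
+      # expects from `tokens_to_logits`.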
+ logits = np.repeat( + np.expand_dims( + np.array([[-1e7, -1e10, -0.1, -0.9, -1e4, -1e4, -1e4, -1e4], + [-1e7, -1e10, -0.9, -0.1, -1e4, -1e4, -1e4, -1e4]], + dtype=np.float32), + axis=1), [beam_size], + axis=1) + logits = decoding.flatten_beam_dim(logits) + return logits, {} + + # batch element 0 has length 1 and element 1 has length 2. + inputs = np.array([[0, 7, 0, 0, 0], [0, 4, 5, 0, 0]], dtype=np.int32) + rolled_inputs = np.array([[7, 0, 0, 0, 0], [4, 5, 0, 0, 0]], dtype=np.int32) + beam_search_sequences, decoding_scores = decoding.beam_search( + inputs, {}, token_to_logits, EOS_ID, num_decodes=beam_size, alpha=0) + + # Prefixes are forced depending on inputs. + # Beam search sequences and corresponding scores are in reverse order. + self.assertTrue(np.all(np.diff(decoding_scores) >= 0)) + expected = np.array([[[7, 3, 2, 2, 2], [7, 2, 2, 2, 2]], + [[4, 5, 2, 3, 3], [4, 5, 3, 3, 3]]]) + np.testing.assert_array_equal(expected, beam_search_sequences) + + expected_scores = [] + batch_logits = np.array([[-1e7, -1e10, -0.1, -0.9, -1e4, -1e4, -1e4, -1e4], + [-1e7, -1e10, -0.9, -0.1, -1e4, -1e4, -1e4, -1e4]], + dtype=np.float32) + for batch, logits, prompt in zip(expected, batch_logits, rolled_inputs): + beam_expected_scores = [] + for beam in batch: + log_probs = jax.nn.log_softmax(logits) + # Add them directly since they are static. + beam_scores = [] + for token, prompt_token in zip(beam, prompt): + if prompt_token != 0: + beam_scores.append(0) + else: + beam_scores.append(log_probs[token]) + beam_expected_scores.append(sum(beam_scores)) + expected_scores.append(beam_expected_scores) + np.testing.assert_allclose(expected_scores, decoding_scores, atol=1e-5) + + def test_beam_search_force_decode_no_prefix(self): + beam_size = 2 + + def token_to_logits(ids, cache): # pylint: disable=unused-argument + # Use id 2 then 3 for batch element 0 and id 3 then 2 for element 1. + logits = np.repeat( + np.expand_dims( + np.array([[-1e7, -1e10, -0.1, -0.9], [-1e7, -1e10, -0.9, -0.1]], + dtype=np.float32), + axis=1), [beam_size], + axis=1) + logits = decoding.flatten_beam_dim(logits) + return logits, {} + + # No prefix is passed. + inputs = np.array([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], dtype=np.int32) + beam_search_sequences, decoding_scores = decoding.beam_search( + inputs, {}, token_to_logits, EOS_ID, num_decodes=beam_size) + + # Prefixes are forced depending on inputs. + # Beam search sequences and corresponding scores are in reverse order. + self.assertTrue(np.all(np.diff(decoding_scores) >= 0)) + expected = np.array([[[3, 2, 2, 2, 2], [2, 2, 2, 2, 2]], + [[2, 3, 3, 3, 3], [3, 3, 3, 3, 3]]]) + np.testing.assert_array_equal(expected, beam_search_sequences) + + +if __name__ == '__main__': + absltest.main() diff --git a/t5x/eval.py b/t5x/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..ad6f85a6230202a6fa83fdbe39bf6e27cac420c6 --- /dev/null +++ b/t5x/eval.py @@ -0,0 +1,247 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# pylint:disable=line-too-long +# pyformat: disable +r"""This script runs inference-evaluation on a T5X-compatible model. + +""" +# pyformat: enable +# pylint:enable=line-too-long + +import functools +import os +from typing import Optional, Sequence, Type + +# pylint:disable=g-import-not-at-top +# TODO(adarob): Re-enable once users are notified and tests are updated. +os.environ['FLAX_LAZY_RNG'] = 'no' +from absl import logging +from clu import metric_writers +import jax +from jax.experimental import multihost_utils +import seqio +from t5x import gin_utils +from t5x import models +from t5x import partitioning +from t5x import utils +from typing_extensions import Protocol + +# Automatically search for gin files relative to the T5X package. +_DEFAULT_GIN_SEARCH_PATHS = [ + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +] + + +class SummarizeConfigFn(Protocol): + + def __call__(self, model_dir: str, + summary_writer: Optional[metric_writers.SummaryWriter], + step: int) -> None: + ... + + +def evaluate( + *, + model: models.BaseTransformerModel, + dataset_cfg: utils.DatasetConfig, + restore_checkpoint_cfg: utils.RestoreCheckpointConfig, + partitioner: partitioning.BasePartitioner, + output_dir: str, + inference_evaluator_cls: Type[seqio.Evaluator] = seqio.Evaluator, + summarize_config_fn: SummarizeConfigFn = gin_utils.summarize_gin_config, + fallback_init_rng: Optional[int] = None): + """Evaluation function. + + Args: + model: The model object to use for inference. + dataset_cfg: Specification for the dataset to infer based on. + restore_checkpoint_cfg: Specification for the model parameter checkpoint to + load. + partitioner: Partitioner for the model parameters and data across devices. + output_dir: Path to directory to write temporary files and final results. + inference_evaluator_cls: seqio.Evaluator class to use for inference + evaluation, potentially with bound configuration args. + summarize_config_fn: A function that takes in the model directory, an + optional SummaryWriter, and the step number, and writes a summary of the + configuration. SummaryWriter will be None in most cases. + fallback_init_rng: A random seed used for parameter initialization during + model re-loading when utils.RestoreCheckpointConfig.fallback_to_scratch is + set to True. If None, parameter initialization is not allowed during model + loading and having fallback_to_scratch enabled will result in an error. 
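+
+  This function is normally not called directly: the `__main__` block below
+  wraps it with `gin.configurable`, so a run is driven by gin bindings such
+  as (binding shown for illustration only):
+
+    evaluate.output_dir = '/tmp/eval_output'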
+ """ + logging.info('Process ID: %d', jax.process_index()) + if dataset_cfg.module: + utils.import_module(dataset_cfg.module) + batch_size = dataset_cfg.batch_size + + summarize_config_fn(model_dir=output_dir, summary_writer=None, step=0) + + ds_vocabs = utils.get_vocabulary(dataset_cfg) + if (ds_vocabs[0] != model.input_vocabulary or + ds_vocabs[1] != model.output_vocabulary): + raise ValueError(f'Model and Task vocabularies do not match:\n' + f' task={dataset_cfg.mixture_or_task_name}\n' + f' ds_vocabs=({ds_vocabs[0]}, {ds_vocabs[1]})\n' + f' model.input_vocabulary={model.input_vocabulary}\n' + f' model.output_vocabulary={model.output_vocabulary}\n') + + # ---------------------------------------------------------------------------- + # SeqIO (inference-based) evaluation setup + # ---------------------------------------------------------------------------- + # Init evaluator to set up cached datasets + evaluator = inference_evaluator_cls( + mixture_or_task_name=dataset_cfg.mixture_or_task_name, + feature_converter=model.FEATURE_CONVERTER_CLS(pack=False), + eval_split=dataset_cfg.split, + use_cached=dataset_cfg.use_cached, + seed=dataset_cfg.seed, + sequence_length=dataset_cfg.task_feature_lengths, + log_dir=os.path.join(output_dir, 'inference_eval')) + if not evaluator.eval_tasks: + raise ValueError( + f"'{dataset_cfg.mixture_or_task_name}' has no metrics for evaluation.") + + # ---------------------------------------------------------------------------- + # T5X model loading. + # ---------------------------------------------------------------------------- + + # Initialize optimizer from the existing checkpoint. + input_shapes = { + k: (batch_size,) + s for k, s in evaluator.model_feature_shapes.items() + } + + train_state_initializer = utils.TrainStateInitializer( + optimizer_def=None, # Do not load optimizer state. + init_fn=model.get_initial_variables, + input_shapes=input_shapes, + partitioner=partitioner) + train_state_axes = train_state_initializer.train_state_axes + # Log the variable shapes information and write to a file. + log_file = os.path.join(output_dir, 'model-info.txt') + utils.log_model_info(log_file, + train_state_initializer.global_train_state_shape, + partitioner) + + predict_fn = None + score_fn = None + + # Disable strictness since we are dropping the optimizer state. + restore_checkpoint_cfg.strict = False + + if fallback_init_rng is not None: + fallback_init_rng = jax.random.PRNGKey(fallback_init_rng) + for train_state in train_state_initializer.from_checkpoints( + [restore_checkpoint_cfg], init_rng=fallback_init_rng): + + # Compile the model only once. + if not predict_fn: + predict_fn = utils.get_infer_fn( + infer_step=model.predict_batch, + batch_size=batch_size, + train_state_axes=train_state_axes, + partitioner=partitioner) + + predict_with_aux_fn = utils.get_infer_fn( + infer_step=model.predict_batch_with_aux, + batch_size=batch_size, + train_state_axes=train_state_axes, + partitioner=partitioner) + + score_fn = utils.get_infer_fn( + infer_step=model.score_batch, + batch_size=batch_size, + train_state_axes=train_state_axes, + partitioner=partitioner) + + # ---------------------------------------------------------------------------- + # Main training loop + # ---------------------------------------------------------------------------- + + # Run final evaluation (with decoding) on the full eval dataset. 
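+    # This is an evaluation-only pass: each checkpoint restored above is
+    # scored once with the SeqIO evaluator, only process 0 aggregates the
+    # metrics (see `compute_metrics` below), and no parameters are updated.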
+ all_metrics, _, _ = evaluator.evaluate( + compute_metrics=jax.process_index() == 0, + step=int(train_state.step), + predict_fn=functools.partial( + predict_fn, train_state=train_state, rng=jax.random.PRNGKey(0)), + score_fn=functools.partial(score_fn, train_state=train_state), + predict_with_aux_fn=functools.partial( + predict_with_aux_fn, + train_state=train_state, + rng=jax.random.PRNGKey(0))) + all_metrics.result() # Ensure metrics are finished being computed. + # Wait until computations are done before continuing. + multihost_utils.sync_global_devices(f'step_{train_state.step}:complete') + + logging.info('Finished.') + + +if __name__ == '__main__': + from absl import app + from absl import flags + import gin + + FLAGS = flags.FLAGS + + jax.config.parse_flags_with_absl() + + flags.DEFINE_multi_string( + 'gin_file', + default=None, + help='Path to gin configuration file. Multiple paths may be passed and ' + 'will be imported in the given order, with later configurations ' + 'overriding earlier ones.') + + flags.DEFINE_multi_string( + 'gin_bindings', default=[], help='Individual gin bindings.') + + flags.DEFINE_list( + 'gin_search_paths', + default=['.'], + help='Comma-separated list of gin config path prefixes to be prepended ' + 'to suffixes given via `--gin_file`. If a file appears in. Only the ' + 'first prefix that produces a valid path for each suffix will be ' + 'used.') + + flags.DEFINE_string( + 'tfds_data_dir', None, + 'If set, this directory will be used to store datasets prepared by ' + 'TensorFlow Datasets that are not available in the public TFDS GCS ' + 'bucket. Note that this flag overrides the `tfds_data_dir` attribute of ' + 'all `Task`s.') + + + def main(argv: Sequence[str]): + """Wrapper for pdb post mortems.""" + _main(argv) + + def _main(argv: Sequence[str]): + """True main function.""" + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + + if FLAGS.tfds_data_dir: + seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir) + + # Create gin-configurable version of `eval`. + evaluate_using_gin = gin.configurable(evaluate) + + gin_utils.parse_gin_flags( + # User-provided gin paths take precedence if relative paths conflict. + FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS, + FLAGS.gin_file, + FLAGS.gin_bindings) + evaluate_using_gin() + + gin_utils.run(main) diff --git a/t5x/examples/__init__.py b/t5x/examples/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2ac5693550488d38623ec8e5b56e3fc3de148d40 --- /dev/null +++ b/t5x/examples/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""This empty file is needed to be recognized as a package by the setuptools.""" diff --git a/t5x/examples/decoder_only/layers.py b/t5x/examples/decoder_only/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..8037981df2e15f6a80810d98627217c3da2bf655 --- /dev/null +++ b/t5x/examples/decoder_only/layers.py @@ -0,0 +1,1074 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dense attention classes and mask/weighting functions.""" + +# pylint: disable=attribute-defined-outside-init,g-bare-generic + +import dataclasses +import functools +import operator +from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union + +from flax import linen as nn +import flax.core.variables as variables +from flax.linen import partitioning as nn_partitioning +from flax.training import common_utils +import jax +from jax import lax +from jax import random +import jax.numpy as jnp +import numpy as np + + +# from flax.linen.partitioning import param_with_axes, with_sharding_constraint +param_with_axes = nn_partitioning.param_with_axes +with_sharding_constraint = nn_partitioning.with_sharding_constraint + + +# Type annotations +Array = jnp.ndarray +DType = jnp.dtype +PRNGKey = jnp.ndarray +Shape = Iterable[int] +Activation = Callable[..., Array] +# Parameter initializers. +Initializer = Callable[[PRNGKey, Shape, DType], Array] + +default_embed_init = nn.initializers.variance_scaling( + 1.0, 'fan_in', 'normal', out_axis=0) + + +def dot_product_attention(query: Array, + key: Array, + value: Array, + bias: Optional[Array] = None, + dropout_rng: Optional[PRNGKey] = None, + dropout_rate: float = 0., + deterministic: bool = False, + dtype: DType = jnp.float32, + float32_logits: bool = False): + """Computes dot-product attention given query, key, and value. + + This is the core function for applying attention based on + https://arxiv.org/abs/1706.03762. It calculates the attention weights given + query and key and combines the values using the attention weights. + + Args: + query: queries for calculating attention with shape of `[batch, q_length, + num_heads, qk_depth_per_head]`. + key: keys for calculating attention with shape of `[batch, kv_length, + num_heads, qk_depth_per_head]`. + value: values to be used in attention with shape of `[batch, kv_length, + num_heads, v_depth_per_head]`. + bias: bias for the attention weights. This should be broadcastable to the + shape `[batch, num_heads, q_length, kv_length]` This can be used for + incorporating causal masks, padding masks, proximity bias, etc. + dropout_rng: JAX PRNGKey: to be used for dropout + dropout_rate: dropout rate + deterministic: bool, deterministic or not (to apply dropout) + dtype: the dtype of the computation (default: float32) + float32_logits: bool, if True then compute logits in float32 to avoid + numerical issues with bfloat16. + + Returns: + Output of shape `[batch, length, num_heads, v_depth_per_head]`. 
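+
+  Example (illustrative shapes only):
+
+    query = jnp.ones((2, 5, 8, 16))   # [batch, q_length, num_heads, depth]
+    keyval = jnp.ones((2, 7, 8, 16))  # [batch, kv_length, num_heads, depth]
+    out = dot_product_attention(query, keyval, keyval)
+    # out.shape == (2, 5, 8, 16)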
+ """ + assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' + assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], ( + 'q, k, v batch dims must match.') + assert query.shape[-2] == key.shape[-2] == value.shape[-2], ( + 'q, k, v num_heads must match.') + assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' + assert query.shape[-1] == key.shape[-1], 'q, k depths must match.' + + # Casting logits and softmax computation for float32 for model stability. + if float32_logits: + query = query.astype(jnp.float32) + key = key.astype(jnp.float32) + + # `attn_weights`: [batch, num_heads, q_length, kv_length] + attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key) + + # Apply attention bias: masking, dropout, proximity bias, etc. + if bias is not None: + attn_weights = attn_weights + bias.astype(attn_weights.dtype) + + # Normalize the attention weights across `kv_length` dimension. + attn_weights = jax.nn.softmax(attn_weights).astype(dtype) + + # Apply attention dropout. + if not deterministic and dropout_rate > 0.: + keep_prob = 1.0 - dropout_rate + # T5 broadcasts along the "length" dim, but unclear which one that + # corresponds to in positional dimensions here, assuming query dim. + dropout_shape = list(attn_weights.shape) + dropout_shape[-2] = 1 + keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) + keep = jnp.broadcast_to(keep, attn_weights.shape) + multiplier = ( + keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype)) + attn_weights = attn_weights * multiplier + + # Take the linear combination of `value`. + return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value) + + +class MultiHeadDotProductAttention(nn.Module): + """Multi-head dot-product attention. + + Attributes: + num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) + should be divisible by the number of heads. + head_dim: dimension of each head. + dtype: the dtype of the computation. + dropout_rate: dropout rate + kernel_init: initializer for the kernel of the Dense layers. + float32_logits: bool, if True then compute logits in float32 to avoid + numerical issues with bfloat16. + """ + num_heads: int + head_dim: int + dtype: DType = jnp.float32 + dropout_rate: float = 0. + kernel_init: Initializer = nn.initializers.variance_scaling( + 1.0, 'fan_in', 'normal') + float32_logits: bool = False + + def update_cache_prefill( + self, key: Array, value: Array, cached_key: variables.Variable, + cached_value: variables.Variable, cache_index: variables.Variable, + prefill_lengths: Array + ) -> Tuple[Array, Array, Array, Array, Array, Array]: + """Update the autoregressive cache for multiple timesteps at once. + + This is useful for things like a prefix-lm where the encoder section of the + input is visible bidirectionally. The key and value for this section need to + be computed in a single shot, as a step by step approach would result in + causal attention. + + Args: + key: The calculated key used in attention. [batch..., length, num_heads, + features_per_head] + value: The calculated value used in attention. [batch..., length, + num_heads, features_per_head] + cached_key: The cache of previous keys. [batch..., num_heads, + features_per_head, length] + cached_value: The cache of previous values. [batch..., num_heads, + features_per_head, length] + cache_index: The timestep that we are currently calculating the key and + value for. [batch] + prefill_lengths: The number of timesteps we should fill in the cache. 
+ [batch] + + Returns: + The key, value, and the last timestep we just filled in the cache. + We also return the new cache values for now because assigning to a + variable inside of a method doesn't work. These returns will be removed + eventually. + """ + # Make a reference to the data underlaying the variable for ease of + # use. + cache_index.value = prefill_lengths + # Note, the cache index is now a vector of batch size so that each example + # can start just after its prefix, which can be different lengths for + # different examples. + cur_index = cache_index.value + # Move the sequence dimension to the end to match the cache shapes. + key_cached = jnp.moveaxis(key, -3, -1) + value_cached = jnp.moveaxis(value, -3, -1) + # Reshape the index so the batch is at the beginning. The default + # broadcasting behavior is to add singleton dims to the front, but we need + # them at the end. + batch_first_index = jnp.reshape( + cur_index, (-1,) + tuple(1 for _ in range(cached_key.value.ndim - 1))) + # Calculate a mask that will set any position past the prefix to zero + # when applied to the key. + key_mask = ( + lax.broadcasted_iota(jnp.int32, cached_key.value.shape, + cached_key.value.ndim - 1) < batch_first_index) + value_mask = ( + lax.broadcasted_iota(jnp.int32, cached_value.value.shape, + cached_value.value.ndim - 1) < batch_first_index) + # Set the caches with the calculated key and values but hide anything + # past the prefix. + cached_key_value = key_cached * key_mask + cached_value_value = value_cached * value_mask + # TODO(hwchung): remove the return values once direct assignment to + # variables inside a method is possible. + return (key, value, cur_index, cached_key_value, cached_value_value, + prefill_lengths) + + def update_cache_decode( + self, key: Array, value: Array, cached_key: variables.Variable, + cached_value: variables.Variable, cache_index: variables.Variable + ) -> Tuple[Array, Array, Array, Array, Array, Array]: + """Update the next timestep in the autoregressive cache. + + This is used during step by step decoding where each key and value we get + are a single (the next) timestep. + + Args: + key: The calculated key used in attention. [batch..., 1, num_heads, + features_per_head] + value: The calculated value used in attention. [batch..., 1, num_heads, + features_per_head] + cached_key: The cache of previous keys. [batch..., num_heads, + features_per_head, length] + cached_value: The cache of previous values. [batch..., num_heads, + features_per_head, length] + cache_index: The timestep that we are currently calculating the key and + value for. [batch] if we are decoding after doing a prefill or [1] if we + are starting with step-by-step decoding. + + Returns: + The key, value, and the last timestep we just filled in the cache. Note: + this index is the last timestep we just fill, the actual value of the + `cache_index` is already increased to point to the next timestep to fill. + We also return the new cache values for now because assigning to a + variable inside of a method doesn't work. These returns will be removed + eventually. + """ + cache_length = cached_key.value.shape[-1] + # Create a OHE of the current index. NOTE: the index is increased + # below. + # Note: We reshape the index into a column vector so that it will work + # if the index is a scalar or a vector with different cache positions + # from different elements in a batch. 
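+    # Tiny illustration of the one-hot update performed below (hypothetical
+    # numbers): with cache_length = 4 and cur_index = [1],
+    #   one_hot_indices == [[0., 1., 0., 0.]]
+    # so `cached_key.value + one_token_key * one_hot_indices` writes the new
+    # key into slot 1 (and likewise for the value) while leaving every other
+    # slot unchanged.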
+ cur_index = jnp.reshape(cache_index.value, (-1,)) + one_hot_indices = jax.nn.one_hot(cur_index, cache_length, dtype=key.dtype) + # In order to update the key, value caches with the current key and + # value, we move the length axis to the back, similar to what we did + # for the cached ones above. + # Note these are currently the key and value of a single position, + # since we feed one position at a time. + one_token_key = jnp.moveaxis(key, -3, -1) + one_token_value = jnp.moveaxis(value, -3, -1) + # The one hot indices are now either [1, length] for a scalar index or + # [batch size, length] for examples where there are different lengths + # of prefixes. We need to add dims for num_heads and num_features as + # broadcasting doesn't work for the batched version. + one_hot_indices = jnp.expand_dims( + jnp.expand_dims(one_hot_indices, axis=1), axis=1) + # Update key, value caches with our new 1d spatial slices. + # We implement an efficient scatter into the cache via one-hot + # broadcast and addition. + # Key/Value have seq lengths of 1 while one_hot has a seq_length + # of length. key/value will broadcast their value to each timestep + # and the onehot will mask all but the correct timesteps. + key = cached_key.value + one_token_key * one_hot_indices + value = cached_value.value + one_token_value * one_hot_indices + cached_key_value = key + cached_value_value = value + cache_index_value = cache_index.value + 1 + # Move the keys and values back to their original shapes. + key = jnp.moveaxis(key, -1, -3) + value = jnp.moveaxis(value, -1, -3) + # TODO(hwchung): remove the return values once direct assignment to + # variables inside a method is possible. + return (key, value, cur_index, cached_key_value, cached_value_value, + cache_index_value) + + @nn.compact + def __call__(self, + inputs_q: Array, + inputs_kv: Array, + mask: Optional[Array] = None, + bias: Optional[Array] = None, + *, + decode: bool = False, + deterministic: bool = False, + prefill: bool = False, + prefill_lengths: Optional[Array] = None) -> Array: + """Applies multi-head dot product attention on the input data. + + Projects the inputs into multi-headed query, key, and value vectors, + applies dot-product attention and project the results to an output vector. + + There are two modes: decoding and non-decoding (e.g., training). The mode is + determined by `decode`. + + During decoding mode, this method is called twice, by `init` and + `apply`. In the former, inputs_q: `[batch..., length, qkv_features]` and + inputs_kv: `[batch..., length, qkv_features]`. + + During apply, query, key and value all have the shape: `[batch * beam, 1, + qkv_features]` where the batch dimension is added to include multiple beams. + Note that the batch dimension is different during the `init` and `apply` + calls. This is because the cached variables are directly passed-in during + `apply` method. In other words, the cache variables such as `cached_key` are + initialized with `batch` dim, expanded by tiling in the beam search function + to `batch * beam` dimension, and passed to the `apply` method as part of a + variable dict. + + Args: + inputs_q: input queries of shape `[batch, q_length, embed]`. + inputs_kv: key/values of shape `[batch, kv_length, embed]`. + mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`. + bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`. + decode: whether to prepare and use an autoregressive cache. 
+ deterministic: whether deterministic or not (to apply dropout) + prefill: whether to run a partial sequence to prefill the cache. + prefill_lengths: an array of shape [batch] denoting the length of each + partial sequence we are filling in the cache. + + Returns: + output of shape `[batch, q_length, embed]`. + """ + projection = functools.partial( + DenseGeneral, + axis=-1, + features=(self.num_heads, self.head_dim), + kernel_axes=('embed', 'joined_kv'), + dtype=self.dtype) + + # NOTE: T5 does not explicitly rescale the attention logits by + # 1/sqrt(depth_kq)! This is folded into the initializers of the + # linear transformations, which is equivalent under Adafactor. + depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) + query_init = lambda *args: self.kernel_init(*args) / depth_scaling + + # Project inputs_q to multi-headed q/k/v + # dimensions are then [batch, length, num_heads, head_dim] + query = projection(kernel_init=query_init, name='query')(inputs_q) + key = projection(kernel_init=self.kernel_init, name='key')(inputs_kv) + value = projection(kernel_init=self.kernel_init, name='value')(inputs_kv) + + query = with_sharding_constraint(query, ('batch', 'length', 'heads', 'kv')) + key = with_sharding_constraint(key, ('batch', 'length', 'heads', 'kv')) + value = with_sharding_constraint(value, ('batch', 'length', 'heads', 'kv')) + + if prefill and decode: + raise ValueError('prefill and decode cannot both be true at the same' + 'time. If you are using a prefix LM with bidirectional ' + 'attention on the inputs, please make a call with ' + 'prefill=True that includes an attention mask that ' + 'covers your inputs first and then make your decoding ' + 'calls.') + if prefill or decode: + # Detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable('cache', 'cached_key') + # The key and value have dimension + # [batch..., length, num_heads, features_per_head], but we cache them as + # [batch..., num_heads, features_per_head, length] as a TPU fusion + # optimization. This also enable the "scatter via one-hot broadcast" + # trick, which means we do a one-hot broadcast instead of a scatter/gather + # operations, which gives a 3-4x speedup in practice. + swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3]) + cached_key = self.variable('cache', 'cached_key', jnp.zeros, + swap_dims(key.shape), key.dtype) + cached_value = self.variable('cache', 'cached_value', jnp.zeros, + swap_dims(value.shape), value.dtype) + cache_index = self.variable('cache', 'cache_index', + lambda: jnp.array(0, dtype=jnp.int32)) + if is_initialized: + # Here we are in "apply()". + *batch_dims, num_heads, features_per_head, length = ( + cached_key.value.shape) + if prefill: + if prefill_lengths is None: + # Figure out how far each element in the batch fills the cache based + # on the mask. We index each element in the batch, the first head + # dim (because this is always set to one), and the first query + # vector. If there is any prefix at all, the first element in the + # prefix would be part of it. + prefill_lengths = jnp.sum( + mask[:, 0, 0, :], axis=-1).astype(cache_index.value.dtype) + (key, value, cur_index, cached_key_value, cached_value_value, + cache_index_value) = self.update_cache_prefill( + key, value, cached_key, cached_value, cache_index, + prefill_lengths) + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. 
+ elif decode: + # Check the shape of the cached key against the input query. + expected_shape = tuple(batch_dims) + (1, num_heads, features_per_head) + if expected_shape != query.shape: + raise ValueError('Autoregressive cache shape error, ' + 'expected query shape %s instead got %s.' % + (expected_shape, query.shape)) + (key, value, cur_index, cached_key_value, cached_value_value, + cache_index_value) = self.update_cache_decode( + key, value, cached_key, cached_value, cache_index) + # Enforcing the Causal mask over previous positions and selecting only + # the bias value for the current index is only needed during decode + # mode where a single example is feed at a time. In prefill mode we + # uses these as provided, that same way it is done in a normal forward + # pass, like when computing logits during training. + + # Causal mask for cached decoder self-attention: our single query + # position should only attend to those key positions that have already + # been generated and cached, not the remaining zero elements. + + # (1, 1, length) represent (head dim, query length, key length) + # query length is 1 because during decoding we deal with one + # index. + # The same mask is applied to all batch elements and heads. + # + # Add trailing dims to the current index so it can either + # broadcast over the batch dim or it can just be batch size. + mask = combine_masks( + mask, + jnp.broadcast_to( + jnp.arange(length), + tuple(batch_dims) + + (1, 1, length)) <= jnp.reshape(cur_index, (-1, 1, 1, 1))) + # Grab the correct relative attention bias during decoding. This is + # only required during single step decoding. + if bias is not None: + # The bias is a full attention matrix, but during decoding we only + # have to take a slice of it. + # This is equivalent to `bias[..., cur_index:cur_index+1, :]`. If + # we are doing prefix decoding where `cur_index` is a vector the + # result will be `[batch, heads, 1, :]`. If `cur_index` is a scalar + # like in encdec decoding, the result will be `[1, heads, 1, :]`. + # We use a one-hot einsum rather than a slice to avoid introducing a + # Gather op that is currently lowered poorly by SPMD passes, adding + # expensive all-reduce and all-gather operations. + + bias = jnp.einsum( + 'bq, bhqk->bhk', + common_utils.onehot(cur_index, num_classes=length), bias) + bias = jnp.expand_dims(bias, 2) + + # Currently, updating a variable inside of a method is not handled + # in flax, so we return the actual values and assign them in the main + # compacted call for now. + # TODO(brianlester,levskaya): Move variable assignment inside of the + # cache update functions once variable references are tracked across + # transform boundaries. + cache_index.value = cache_index_value + cached_key.value = cached_key_value + cached_value.value = cached_value_value + + # Convert the boolean attention mask to an attention bias. + if mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + mask > 0, + jnp.full(mask.shape, 0.).astype(self.dtype), + jnp.full(mask.shape, -1e10).astype(self.dtype)) + else: + attention_bias = None + + # Add provided bias term (e.g. relative position embedding). + if bias is not None: + attention_bias = combine_biases(attention_bias, bias) + + dropout_rng = None + if not deterministic and self.dropout_rate > 0.: + dropout_rng = self.make_rng('dropout') + + # Apply attention. 
+ x = dot_product_attention( + query, + key, + value, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout_rate, + deterministic=deterministic, + dtype=self.dtype, + float32_logits=self.float32_logits) + + # Back to the original inputs dimensions. + out = DenseGeneral( + features=inputs_q.shape[-1], # output dim is set to the input dim. + axis=(-2, -1), + kernel_init=self.kernel_init, + kernel_axes=('joined_kv', 'embed'), + dtype=self.dtype, + name='out')( + x) + return out + + +def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: + # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. + return tuple([ax if ax >= 0 else ndim + ax for ax in axes]) + + +def _canonicalize_tuple(x): + if isinstance(x, Iterable): + return tuple(x) + else: + return (x,) + + +#------------------------------------------------------------------------------ +# DenseGeneral for attention layers. +#------------------------------------------------------------------------------ +class DenseGeneral(nn.Module): + """A linear transformation (without bias) with flexible axes. + + Attributes: + features: tuple with numbers of output features. + axis: tuple with axes to apply the transformation on. + dtype: the dtype of the computation (default: float32). + kernel_init: initializer function for the weight matrix. + """ + features: Union[Iterable[int], int] + axis: Union[Iterable[int], int] = -1 + dtype: DType = jnp.float32 + kernel_init: Initializer = nn.initializers.variance_scaling( + 1.0, 'fan_in', 'truncated_normal') + kernel_axes: Tuple[str, ...] = () + + @nn.compact + def __call__(self, inputs: Array) -> Array: + """Applies a linear transformation to the inputs along multiple dimensions. + + Args: + inputs: The nd-array to be transformed. + + Returns: + The transformed input. + """ + features = _canonicalize_tuple(self.features) + axis = _canonicalize_tuple(self.axis) + + inputs = jnp.asarray(inputs, self.dtype) + axis = _normalize_axes(axis, inputs.ndim) + + kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features + kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), + np.prod(features)) + kernel = param_with_axes( + 'kernel', + self.kernel_init, + kernel_param_shape, + jnp.float32, + axes=self.kernel_axes) + kernel = jnp.asarray(kernel, self.dtype) + kernel = jnp.reshape(kernel, kernel_shape) + + contract_ind = tuple(range(0, len(axis))) + return lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ()))) + + +def _convert_to_activation_function( + fn_or_string: Union[str, Callable]) -> Callable: + """Convert a string to an activation function.""" + if fn_or_string == 'linear': + return lambda x: x + elif isinstance(fn_or_string, str): + return getattr(nn, fn_or_string) + elif callable(fn_or_string): + return fn_or_string + else: + raise ValueError("don't know how to convert %s to an activation function" % + (fn_or_string,)) + + +class MlpBlock(nn.Module): + """Transformer MLP / feed-forward block. + + Attributes: + intermediate_dim: Shared dimension of hidden layers. + activations: Type of activations for each layer. Each element is either + 'linear', a string function name in flax.linen, or a function. + kernel_init: Kernel function, passed to the dense layers. + deterministic: Whether the dropout layers should be deterministic. + intermediate_dropout_rate: Dropout rate used after the intermediate layers. + dtype: Type for the dense layer. 
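+
+  For example, `activations=('gelu', 'linear')` yields a gated-GELU block: two
+  input projections (named `wi_0` and `wi_1`) whose outputs are multiplied
+  elementwise before the final `wo` projection back to the model dimension.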
+ """ + intermediate_dim: int = 2048 + activations: Sequence[Union[str, Callable]] = ('relu',) + kernel_init: Initializer = nn.initializers.variance_scaling( + 1.0, 'fan_in', 'truncated_normal') + intermediate_dropout_rate: float = 0.1 + dtype: Any = jnp.float32 + + @nn.compact + def __call__(self, inputs, decode: bool = False, deterministic: bool = False): + """Applies Transformer MlpBlock module.""" + # Iterate over specified MLP input activation functions. + # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu. + activations = [] + for idx, act_fn in enumerate(self.activations): + dense_name = 'wi' if len(self.activations) == 1 else f'wi_{idx}' + x = DenseGeneral( + self.intermediate_dim, + dtype=self.dtype, + kernel_init=self.kernel_init, + kernel_axes=('embed', 'mlp'), + name=dense_name)( + inputs) + x = _convert_to_activation_function(act_fn)(x) + activations.append(x) + + # Take elementwise product of above intermediate activations. + x = functools.reduce(operator.mul, activations) + # Apply dropout and final dense output projection. + x = nn.Dropout( + rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))( + x, deterministic=deterministic) # Broadcast along length. + x = with_sharding_constraint(x, ('batch', 'length', 'mlp')) + output = DenseGeneral( + inputs.shape[-1], + dtype=self.dtype, + kernel_init=self.kernel_init, + kernel_axes=('mlp', 'embed'), + name='wo')( + x) + return output + + +class Embed(nn.Module): + """A parameterized function from integers [0, n) to d-dimensional vectors. + + Attributes: + num_embeddings: number of embeddings. + features: number of feature dimensions for each embedding. + dtype: the dtype of the embedding vectors (default: float32). + embedding_init: embedding initializer. + one_hot: performs the gather with a one-hot contraction rather than a true + gather. This is currently needed for SPMD partitioning. + """ + num_embeddings: int + features: int + cast_input_dtype: Optional[DType] = None + dtype: DType = jnp.float32 + attend_dtype: Optional[DType] = None + embedding_init: Initializer = default_embed_init + one_hot: bool = False + embedding: Array = dataclasses.field(init=False) + + def setup(self): + self.embedding = param_with_axes( + 'embedding', + self.embedding_init, (self.num_embeddings, self.features), + jnp.float32, + axes=('vocab', 'embed')) + + def __call__(self, inputs: Array) -> Array: + """Embeds the inputs along the last dimension. + + Args: + inputs: input data, all dimensions are considered batch dimensions. + + Returns: + Output which is embedded input data. The output shape follows the input, + with an additional `features` dimension appended. + """ + if self.cast_input_dtype: + inputs = inputs.astype(self.cast_input_dtype) + if not jnp.issubdtype(inputs.dtype, jnp.integer): + raise ValueError('Input type must be an integer or unsigned integer.') + if self.one_hot: + iota = lax.iota(jnp.int32, self.num_embeddings) + one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype) + output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype)) + else: + output = jnp.asarray(self.embedding, self.dtype)[inputs] + output = with_sharding_constraint(output, ('batch', 'length', 'embed')) + return output + + def attend(self, query: Array) -> Array: + """Attend over the embedding using a query array. + + Args: + query: array with last dimension equal the feature depth `features` of the + embedding. 
+ + Returns: + An array with final dim `num_embeddings` corresponding to the batched + inner-product of the array of query vectors against each embedding. + Commonly used for weight-sharing between embeddings and logit transform + in NLP models. + """ + dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype + return jnp.dot(query, jnp.asarray(self.embedding, dtype).T) + + +class RelativePositionBiases(nn.Module): + """Adds T5-style relative positional embeddings to the attention logits. + + Attributes: + num_buckets: Number of buckets to bucket distances between key and query + positions into. + max_distance: Maximum distance before everything is lumped into the last + distance bucket. + num_heads: Number of heads in the attention layer. Each head will get a + different relative position weighting. + dtype: Type of arrays through this module. + embedding_init: initializer for relative embedding table. + """ + num_buckets: int + max_distance: int + num_heads: int + dtype: Any + embedding_init: Callable[..., Array] = nn.linear.default_embed_init + + @staticmethod + def _relative_position_bucket(relative_position, + bidirectional=True, + num_buckets=32, + max_distance=128): + """Translate relative position to a bucket number for relative attention. + + The relative position is defined as memory_position - query_position, i.e. + the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are + invalid. + We use smaller buckets for small absolute relative_position and larger + buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative + positions <=-max_distance map to the same bucket. This should allow for + more graceful generalization to longer sequences than the model has been + trained on. + + Args: + relative_position: an int32 array + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 + values in the range [0, num_buckets) + """ + ret = 0 + n = -relative_position + if bidirectional: + num_buckets //= 2 + ret += (n < 0).astype(np.int32) * num_buckets + n = np.abs(n) + else: + n = np.maximum(n, 0) + # now n is in the range [0, inf) + max_exact = num_buckets // 2 + is_small = (n < max_exact) + val_if_large = max_exact + ( + np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps) / + np.log(max_distance / max_exact) * + (num_buckets - max_exact)).astype(np.int32) + val_if_large = np.minimum(val_if_large, num_buckets - 1) + ret += np.where(is_small, n, val_if_large) + return ret + + @nn.compact + def __call__(self, qlen, klen, bidirectional=True, decode=False): + """Produce relative position embedding attention biases. + + Args: + qlen: attention query length. + klen: attention key length. + bidirectional: whether to allow positive memory-query relative position + embeddings. + decode: whether to cache relative position bias during autoregressive + decoding. + + Returns: + output: `(1, num_heads, q_len, k_len)` attention bias + """ + # bidirectional embeddings don't make sense when decoding (and break cache). + if decode and bidirectional: + raise ValueError( + 'bidirectional RelativePositionBiases are not supported when ' + '`decode=True`.') + + # We only cache the bias if the model was already initialized, i.e. 
if this + # module is called with `model.apply` and `decode = True`. We raise an error + # if called with `model.init` and `decode = True`, since this can cache + # incorrect positional embeddings produced by random parameters. + is_initialized = self.has_variable('params', 'rel_embedding') + if decode and not is_initialized: + raise ValueError( + 'decode-mode cannot be enabled during init. use model.apply to ' + 'initialize the decoding cache.') + + # Return pre-computed relative position bias in cache during decode steps. + if decode and self.has_variable('cache', 'cached_bias'): + cached_bias = self.get_variable('cache', 'cached_bias') + expected_bias_shape = (1, self.num_heads, qlen, klen) + if cached_bias.shape != expected_bias_shape: + raise ValueError(f'The cached relative position attention bias was ' + f'expected to have shape {expected_bias_shape} but ' + f'instead has the shape {cached_bias.shape}.') + return cached_bias + + # TODO(levskaya): should we be computing this w. numpy as a program + # constant? + context_position = np.arange(qlen, dtype=jnp.int32)[:, None] + memory_position = np.arange(klen, dtype=jnp.int32)[None, :] + relative_position = memory_position - context_position # shape (qlen, klen) + rp_bucket = self._relative_position_bucket( + relative_position, + bidirectional=bidirectional, + num_buckets=self.num_buckets, + max_distance=self.max_distance) + relative_attention_bias = param_with_axes( + 'rel_embedding', + self.embedding_init, (self.num_heads, self.num_buckets), + jnp.float32, + axes=('heads', 'relpos_buckets')) + + relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype) + # Instead of using a slow gather, we create a leading-dimension one-hot + # array from rp_bucket and use it to perform the gather-equivalent via a + # contraction, i.e.: + # (num_head, num_buckets) x (num_buckets one-hot, qlen, klen). + # This is equivalent to relative_attention_bias[:, rp_bucket] + bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0) + rp_bucket_one_hot = jnp.array( + rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype) + # --> shape (qlen, klen, num_heads) + values = lax.dot_general( + relative_attention_bias, + rp_bucket_one_hot, + ( + ((1,), (0,)), # rhs, lhs contracting dims + ((), ()))) # no batched dims + # Add a singleton batch dimension. + # --> shape (1, num_heads, qlen, klen) + out = values[jnp.newaxis, ...] + + # Store computed relative position bias in cache after first calculation. + if decode: + _ = self.variable('cache', 'cached_bias', lambda: out) + + return out + + +#------------------------------------------------------------------------------ +# T5 Layernorm - no subtraction of mean or bias. 
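+# (Equivalently an RMSNorm: y = scale * x / sqrt(mean(x**2) + epsilon),
+# computed in float32 below before casting back to the requested dtype.)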
+#------------------------------------------------------------------------------ +class LayerNorm(nn.Module): + """T5 Layer normalization operating on the last axis of the input data.""" + epsilon: float = 1e-6 + dtype: Any = jnp.float32 + scale_init: Initializer = nn.initializers.ones + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + """Applies layer normalization on the input.""" + x = jnp.asarray(x, jnp.float32) + features = x.shape[-1] + mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) + y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype) + scale = param_with_axes( + 'scale', self.scale_init, (features,), jnp.float32, axes=('embed',)) + + scale = jnp.asarray(scale, self.dtype) + return y * scale + + +#------------------------------------------------------------------------------ +# Mask-making utility functions. +#------------------------------------------------------------------------------ +def make_attention_mask(query_input: Array, + key_input: Array, + pairwise_fn: Callable = jnp.multiply, + extra_batch_dims: int = 0, + dtype: DType = jnp.float32) -> Array: + """Mask-making helper for attention weights. + + In case of 1d inputs (i.e., `[batch, len_q]`, `[batch, len_kv]`, the + attention weights will be `[batch, heads, len_q, len_kv]` and this + function will produce `[batch, 1, len_q, len_kv]`. + + Args: + query_input: a batched, flat input of query_length size + key_input: a batched, flat input of key_length size + pairwise_fn: broadcasting elementwise comparison function + extra_batch_dims: number of extra batch dims to add singleton axes for, none + by default + dtype: mask return dtype + + Returns: + A `[batch, 1, len_q, len_kv]` shaped mask for 1d attention. + """ + # [batch, len_q, len_kv] + mask = pairwise_fn( + # [batch, len_q] -> [batch, len_q, 1] + jnp.expand_dims(query_input, axis=-1), + # [batch, len_q] -> [batch, 1, len_kv] + jnp.expand_dims(key_input, axis=-2)) + + # [batch, 1, len_q, len_kv]. This creates the head dim. + mask = jnp.expand_dims(mask, axis=-3) + mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims))) + return mask.astype(dtype) + + +def make_causal_mask(x: Array, + extra_batch_dims: int = 0, + dtype: DType = jnp.float32) -> Array: + """Make a causal mask for self-attention. + + In case of 1d inputs (i.e., `[batch, len]`, the self-attention weights + will be `[batch, heads, len, len]` and this function will produce a + causal mask of shape `[batch, 1, len, len]`. + + Note that a causal mask does not depend on the values of x; it only depends on + the shape. If x has padding elements, they will not be treated in a special + manner. + + Args: + x: input array of shape `[batch, len]` + extra_batch_dims: number of batch dims to add singleton axes for, none by + default + dtype: mask return dtype + + Returns: + A `[batch, 1, len, len]` shaped causal mask for 1d attention. + """ + idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape) + return make_attention_mask( + idxs, + idxs, + jnp.greater_equal, + extra_batch_dims=extra_batch_dims, + dtype=dtype) + + +def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32): + """Combine attention masks. + + Args: + *masks: set of attention mask arguments to combine, some can be None. + dtype: final mask dtype + + Returns: + Combined mask, reduced by logical and, returns None if no masks given. 
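+
+  For example, `combine_masks(causal_mask, padding_mask)` returns the
+  elementwise logical-and of the two masks cast to `dtype`, while
+  `combine_masks(None, None)` returns `None`.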
+ """ + masks = [m for m in masks if m is not None] + if not masks: + return None + assert all(map(lambda x: x.ndim == masks[0].ndim, masks)), ( + f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}') + mask, *other_masks = masks + for other_mask in other_masks: + mask = jnp.logical_and(mask, other_mask) + return mask.astype(dtype) + + +def combine_biases(*masks: Optional[Array]): + """Combine attention biases. + + Args: + *masks: set of attention bias arguments to combine, some can be None. + + Returns: + Combined mask, reduced by summation, returns None if no masks given. + """ + masks = [m for m in masks if m is not None] + if not masks: + return None + assert all(map(lambda x: x.ndim == masks[0].ndim, masks)), ( + f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}') + mask, *other_masks = masks + for other_mask in other_masks: + mask = mask + other_mask + return mask + + +def make_decoder_mask(decoder_target_tokens: Array, + dtype: DType, + decoder_causal_attention: Optional[Array] = None, + decoder_segment_ids: Optional[Array] = None) -> Array: + """Compute the self-attention mask for a decoder. + + Decoder mask is formed by combining a causal mask, a padding mask and an + optional packing mask. If decoder_causal_attention is passed, it makes the + masking non-causal for positions that have value of 1. + + A prefix LM is applied to a dataset which has a notion of "inputs" and + "targets", e.g., a machine translation task. The inputs and targets are + concatenated to form a new target. `decoder_target_tokens` is the concatenated + decoder output tokens. + + The "inputs" portion of the concatenated sequence can attend to other "inputs" + tokens even for those at a later time steps. In order to control this + behavior, `decoder_causal_attention` is necessary. This is a binary mask with + a value of 1 indicating that the position belonged to "inputs" portion of the + original dataset. + + Example: + + Suppose we have a dataset with two examples. + + ds = [{"inputs": [6, 7], "targets": [8]}, + {"inputs": [3, 4], "targets": [5]}] + + After the data preprocessing with packing, the two examples are packed into + one example with the following three fields (some fields are skipped for + simplicity). + + decoder_target_tokens = [[6, 7, 8, 3, 4, 5, 0]] + decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]] + decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]] + + where each array has [batch, length] shape with batch size being 1. Then, + this function computes the following mask. + + mask = [[[[1, 1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0]]]] + + mask[b, 1, :, :] represents the mask for the example `b` in the batch. + Because mask is for a self-attention layer, the mask's shape is a square of + shape [query length, key length]. + + mask[b, 1, i, j] = 1 means that the query token at position i can attend to + the key token at position j. + + Args: + decoder_target_tokens: decoder output tokens. [batch, length] + dtype: dtype of the output mask. + decoder_causal_attention: a binary mask indicating which position should + only attend to earlier positions in the sequence. Others will attend + bidirectionally. [batch, length] + decoder_segment_ids: decoder segmentation info for packed examples. [batch, + length] + + Returns: + the combined decoder mask. + """ + masks = [] + # The same mask is applied to all attention heads. 
So the head dimension is 1,
+  # i.e., the mask will be broadcast along the heads dim.
+  # [batch, 1, length, length]
+  causal_mask = make_causal_mask(decoder_target_tokens, dtype=dtype)
+
+  # Positions with value 1 in `decoder_causal_attention` can attend
+  # bidirectionally.
+  if decoder_causal_attention is not None:
+    # [batch, 1, length, length]
+    inputs_mask = make_attention_mask(
+        decoder_causal_attention,
+        decoder_causal_attention,
+        jnp.logical_and,
+        dtype=dtype)
+    masks.append(jnp.logical_or(causal_mask, inputs_mask).astype(dtype))
+  else:
+    masks.append(causal_mask)
+
+  # Padding mask.
+  masks.append(
+      make_attention_mask(
+          decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=dtype))
+
+  # Packing mask
+  if decoder_segment_ids is not None:
+    masks.append(
+        make_attention_mask(
+            decoder_segment_ids, decoder_segment_ids, jnp.equal, dtype=dtype))
+
+  return combine_masks(*masks, dtype=dtype)
diff --git a/t5x/examples/decoder_only/layers_test.py b/t5x/examples/decoder_only/layers_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..b571d9d28afea97bd0e125399395b30e24f7ee43
--- /dev/null
+++ b/t5x/examples/decoder_only/layers_test.py
@@ -0,0 +1,756 @@
+# Copyright 2022 The T5X Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for attention classes."""
+
+import dataclasses
+from typing import Optional
+from unittest import mock
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from flax import linen as nn
+from flax.core import freeze
+from flax.linen import partitioning as nn_partitioning
+import jax
+from jax import random
+from jax.nn import initializers
+import jax.numpy as jnp
+import numpy as np
+from t5x.examples.decoder_only import layers
+
+# Parse absl flags test_srcdir and test_tmpdir.
+jax.config.parse_flags_with_absl()
+
+Array = jnp.ndarray
+AxisMetadata = nn_partitioning.AxisMetadata  # pylint: disable=invalid-name
+
+
+class SelfAttention(layers.MultiHeadDotProductAttention):
+  """Self-attention special case of multi-head dot-product attention."""
+
+  @nn.compact
+  def __call__(self,
+               inputs_q: Array,
+               mask: Optional[Array] = None,
+               bias: Optional[Array] = None,
+               deterministic: bool = False):
+    return super().__call__(
+        inputs_q, inputs_q, mask, bias, deterministic=deterministic)
+
+
+@dataclasses.dataclass(frozen=True)
+class SelfAttentionArgs:
+  num_heads: int = 1
+  batch_size: int = 2
+  # qkv_features: int = 3
+  head_dim: int = 3
+  # out_features: int = 4
+  q_len: int = 5
+  features: int = 6
+  dropout_rate: float = 0.1
+  deterministic: bool = False
+  decode: bool = False
+  float32_logits: bool = False
+
+  def __post_init__(self):
+    # If we are doing decoding, the query length should be 1, because we are
+    # doing autoregressive decoding where we feed one position at a time.
+ assert not self.decode or self.q_len == 1 + + def init_args(self): + return dict( + num_heads=self.num_heads, + head_dim=self.head_dim, + dropout_rate=self.dropout_rate, + float32_logits=self.float32_logits) + + def apply_args(self): + inputs_q = jnp.ones((self.batch_size, self.q_len, self.features)) + mask = jnp.ones((self.batch_size, self.num_heads, self.q_len, self.q_len)) + bias = jnp.ones((self.batch_size, self.num_heads, self.q_len, self.q_len)) + return { + 'inputs_q': inputs_q, + 'mask': mask, + 'bias': bias, + 'deterministic': self.deterministic + } + + +class AttentionTest(parameterized.TestCase): + + def test_dot_product_attention_shape(self): + # This test only checks for shape but tries to make sure all code paths are + # reached. + dropout_rng = random.PRNGKey(0) + batch_size, num_heads, q_len, kv_len, qk_depth, v_depth = 1, 2, 3, 4, 5, 6 + + query = jnp.ones((batch_size, q_len, num_heads, qk_depth)) + key = jnp.ones((batch_size, kv_len, num_heads, qk_depth)) + value = jnp.ones((batch_size, kv_len, num_heads, v_depth)) + bias = jnp.ones((batch_size, num_heads, q_len, kv_len)) + + args = dict( + query=query, + key=key, + value=value, + bias=bias, + dropout_rng=dropout_rng, + dropout_rate=0.5, + deterministic=False, + ) + + output = layers.dot_product_attention(**args) + self.assertEqual(output.shape, (batch_size, q_len, num_heads, v_depth)) + + def test_make_attention_mask_multiply_pairwise_fn(self): + decoder_target_tokens = jnp.array([[7, 0, 0], [8, 5, 0]]) + attention_mask = layers.make_attention_mask( + decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=jnp.int32) + expected0 = jnp.array([[1, 0, 0], [0, 0, 0], [0, 0, 0]]) + expected1 = jnp.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]]) + self.assertEqual(attention_mask.shape, (2, 1, 3, 3)) + np.testing.assert_array_equal(attention_mask[0, 0], expected0) + np.testing.assert_array_equal(attention_mask[1, 0], expected1) + + def test_make_attention_mask_equal_pairwise_fn(self): + segment_ids = jnp.array([[1, 1, 2, 2, 2, 0], [1, 1, 1, 2, 0, 0]]) + attention_mask = layers.make_attention_mask( + segment_ids, segment_ids, pairwise_fn=jnp.equal, dtype=jnp.int32) + # Padding is not treated in a special way. So they need to be zeroed out + # separately. + expected0 = jnp.array([[1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0], [0, 0, 1, 1, 1, 0], + [0, 0, 1, 1, 1, 0], [0, 0, 0, 0, 0, 1]]) + expected1 = jnp.array([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 1, 1]]) + self.assertEqual(attention_mask.shape, (2, 1, 6, 6)) + np.testing.assert_array_equal(attention_mask[0, 0], expected0) + np.testing.assert_array_equal(attention_mask[1, 0], expected1) + + def test_make_causal_mask_with_padding(self): + x = jnp.array([[7, 0, 0], [8, 5, 0]]) + y = layers.make_causal_mask(x) + self.assertEqual(y.shape, (2, 1, 3, 3)) + # Padding is not treated in a special way. So they need to be zeroed out + # separately. 
+ expected_y = jnp.array([[[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]]], + jnp.float32) + np.testing.assert_allclose(y[0], expected_y) + np.testing.assert_allclose(y[1], expected_y) + + def test_make_causal_mask_extra_batch_dims(self): + x = jnp.ones((3, 3, 5)) + y = layers.make_causal_mask(x, extra_batch_dims=2) + self.assertEqual(y.shape, (1, 1, 3, 3, 1, 5, 5)) + + def test_make_causal_mask(self): + x = jnp.ones((1, 3)) + y = layers.make_causal_mask(x) + self.assertEqual(y.shape, (1, 1, 3, 3)) + expected_y = jnp.array([[[[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]]]], + jnp.float32) + np.testing.assert_allclose(y, expected_y) + + def test_combine_masks(self): + masks = [ + jnp.array([0, 1, 0, 1], jnp.float32), None, + jnp.array([1, 1, 1, 1], jnp.float32), + jnp.array([1, 1, 1, 0], jnp.float32) + ] + y = layers.combine_masks(*masks) + np.testing.assert_allclose(y, jnp.array([0, 1, 0, 0], jnp.float32)) + + def test_combine_biases(self): + masks = [ + jnp.array([0, 1, 0, 1], jnp.float32), None, + jnp.array([0, 1, 1, 1], jnp.float32), + jnp.array([0, 1, 1, 0], jnp.float32) + ] + y = layers.combine_biases(*masks) + np.testing.assert_allclose(y, jnp.array([0, 3, 2, 2], jnp.float32)) + + def test_make_decoder_mask_lm_unpacked(self): + decoder_target_tokens = jnp.array([6, 7, 3, 0]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, dtype=jnp.float32) + expected_mask = jnp.array([[[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], + [0, 0, 0, 0]]]) + np.testing.assert_array_equal(mask, expected_mask) + + def test_make_decoder_mask_lm_packed(self): + decoder_target_tokens = jnp.array([[6, 7, 3, 4, 5, 0]]) + decoder_segment_ids = jnp.array([[1, 1, 1, 2, 2, 0]]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=jnp.float32, + decoder_segment_ids=decoder_segment_ids) + expected_mask = jnp.array([[[[1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 0], [0, 0, 0, 0, 0, 0]]]]) + np.testing.assert_array_equal(mask, expected_mask) + + def test_make_decoder_mask_prefix_lm_unpacked(self): + decoder_target_tokens = jnp.array([[5, 6, 7, 3, 4, 0]]) + decoder_causal_attention = jnp.array([[1, 1, 1, 0, 0, 0]]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=jnp.float32, + decoder_causal_attention=decoder_causal_attention) + expected_mask = jnp.array( + [[[[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 0], [0, 0, 0, 0, 0, 0]]]], + dtype=jnp.float32) + np.testing.assert_array_equal(mask, expected_mask) + + def test_make_decoder_mask_prefix_lm_packed(self): + decoder_target_tokens = jnp.array([[5, 6, 7, 8, 3, 4, 0]]) + decoder_segment_ids = jnp.array([[1, 1, 1, 2, 2, 2, 0]]) + decoder_causal_attention = jnp.array([[1, 1, 0, 1, 1, 0, 0]]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=jnp.float32, + decoder_causal_attention=decoder_causal_attention, + decoder_segment_ids=decoder_segment_ids) + expected_mask = jnp.array([[[[1, 1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0], [0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 0, 0], [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0]]]]) + np.testing.assert_array_equal(mask, expected_mask) + + def test_make_decoder_mask_prefix_lm_unpacked_multiple_elements(self): + decoder_target_tokens = jnp.array([[6, 7, 3, 0], [4, 5, 0, 0]]) + decoder_causal_attention = jnp.array([[1, 1, 0, 0], [1, 0, 0, 0]]) + mask = 
layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=jnp.float32, + decoder_causal_attention=decoder_causal_attention) + expected_mask0 = jnp.array([[1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], + [0, 0, 0, 0]]) + expected_mask1 = jnp.array([[1, 0, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0], + [0, 0, 0, 0]]) + self.assertEqual(mask.shape, (2, 1, 4, 4)) + np.testing.assert_array_equal(mask[0, 0], expected_mask0) + np.testing.assert_array_equal(mask[1, 0], expected_mask1) + + def test_make_decoder_mask_composite_causal_attention(self): + decoder_target_tokens = jnp.array([[6, 7, 3, 4, 8, 9, 0]]) + decoder_causal_attention = jnp.array([[1, 1, 0, 0, 1, 1, 0]]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=jnp.float32, + decoder_causal_attention=decoder_causal_attention) + expected_mask0 = jnp.array([[1, 1, 0, 0, 1, 1, 0], [1, 1, 0, 0, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0]]) + + self.assertEqual(mask.shape, (1, 1, 7, 7)) + np.testing.assert_array_equal(mask[0, 0], expected_mask0) + + def test_make_decoder_mask_composite_causal_attention_packed(self): + decoder_target_tokens = jnp.array([[6, 7, 3, 4, 8, 9, 2, 3, 4]]) + decoder_segment_ids = jnp.array([[1, 1, 1, 1, 1, 1, 2, 2, 2]]) + decoder_causal_attention = jnp.array([[1, 1, 0, 0, 1, 1, 1, 1, 0]]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=jnp.float32, + decoder_causal_attention=decoder_causal_attention, + decoder_segment_ids=decoder_segment_ids) + expected_mask0 = jnp.array([[1, 1, 0, 0, 1, 1, 0, 0, 0], + [1, 1, 0, 0, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1]]) + + self.assertEqual(mask.shape, (1, 1, 9, 9)) + np.testing.assert_array_equal(mask[0, 0], expected_mask0) + + @parameterized.parameters({'f': 20}, {'f': 22}) + def test_multihead_dot_product_attention(self, f): + # b: batch, f: emb_dim, q: q_len, k: kv_len, h: num_head, d: head_dim + b, q, h, d, k = 2, 3, 4, 5, 6 + + base_args = SelfAttentionArgs(num_heads=h, head_dim=d, dropout_rate=0) + args = base_args.init_args() + + np.random.seed(0) + inputs_q = np.random.randn(b, q, f) + inputs_kv = np.random.randn(b, k, f) + + # Projection: [b, q, f] -> [b, q, h, d] + # So the kernels have to be [f, h, d] + query_kernel = np.random.randn(f, h, d) + key_kernel = np.random.randn(f, h, d) + value_kernel = np.random.randn(f, h, d) + # `out` calculation: [b, q, h, d] -> [b, q, f] + # So kernel has to be [h, d, f] + out_kernel = np.random.randn(h, d, f) + + params = { + 'query': { + 'kernel': query_kernel.reshape(f, -1) + }, + 'key': { + 'kernel': key_kernel.reshape(f, -1) + }, + 'value': { + 'kernel': value_kernel.reshape(f, -1) + }, + 'out': { + 'kernel': out_kernel.reshape(-1, f) + } + } + y = layers.MultiHeadDotProductAttention(**args).apply( + {'params': freeze(params)}, inputs_q, inputs_kv) + + query = np.einsum('bqf,fhd->bqhd', inputs_q, query_kernel) + key = np.einsum('bkf,fhd->bkhd', inputs_kv, key_kernel) + value = np.einsum('bkf,fhd->bkhd', inputs_kv, value_kernel) + logits = np.einsum('bqhd,bkhd->bhqk', query, key) + weights = nn.softmax(logits, axis=-1) + combined_value = np.einsum('bhqk,bkhd->bqhd', weights, value) + y_expected = np.einsum('bqhd,hdf->bqf', combined_value, out_kernel) + 
np.testing.assert_allclose(y, y_expected, rtol=1e-5, atol=1e-5) + + def test_multihead_dot_product_attention_caching(self): + # b: batch, f: qkv_features, k: kv_len, h: num_head, d: head_dim + b, h, d, k = 2, 3, 4, 5 + f = h * d + + base_args = SelfAttentionArgs(num_heads=h, head_dim=d, dropout_rate=0) + args = base_args.init_args() + + cache = { + 'cached_key': np.zeros((b, h, d, k)), + 'cached_value': np.zeros((b, h, d, k)), + 'cache_index': np.array(0) + } + inputs_q = np.random.randn(b, 1, f) + inputs_kv = np.random.randn(b, 1, f) + + # Mock dense general such that q, k, v projections are replaced by simple + # reshaping. + def mock_dense_general(self, x, **kwargs): # pylint: disable=unused-argument + return x.reshape(b, -1, h, d) + + with mock.patch.object( + layers.DenseGeneral, '__call__', new=mock_dense_general): + _, mutated = layers.MultiHeadDotProductAttention(**args).apply( + {'cache': freeze(cache)}, + inputs_q, + inputs_kv, + decode=True, + mutable=['cache']) + updated_cache = mutated['cache'] + + # Perform the same mocked projection to generate the expected cache. + # (key|value): [b, 1, h, d] + key = mock_dense_general(None, inputs_kv) + value = mock_dense_general(None, inputs_kv) + + # cached_(key|value): [b, h, d, k] + cache['cached_key'][:, :, :, 0] = key[:, 0, :, :] + cache['cached_value'][:, :, :, 0] = value[:, 0, :, :] + cache['cache_index'] = np.array(1) + for name, array in cache.items(): + np.testing.assert_allclose(array, updated_cache[name]) + + def test_dot_product_attention(self): + # b: batch, f: emb_dim, q: q_len, k: kv_len, h: num_head, d: head_dim + b, q, h, d, k = 2, 3, 4, 5, 6 + np.random.seed(0) + query = np.random.randn(b, q, h, d) + key = np.random.randn(b, k, h, d) + value = np.random.randn(b, k, h, d) + bias = np.random.randn(b, h, q, k) + attn_out = layers.dot_product_attention(query, key, value, bias=bias) + logits = np.einsum('bqhd,bkhd->bhqk', query, key) + weights = jax.nn.softmax(logits + bias, axis=-1) + expected = np.einsum('bhqk,bkhd->bqhd', weights, value) + np.testing.assert_allclose(attn_out, expected, atol=1e-6) + + def test_multihead_dot_product_attention_prefill_caching(self): + # b: batch, f: qkv_features, k: kv_len, h: num_head, d: head_dim + b, h, d, k = 2, 3, 4, 5 + f = h * d + prefill_lengths = np.array([3, 1]) + + base_args = SelfAttentionArgs( + num_heads=h, head_dim=d, dropout_rate=0) + args = base_args.init_args() + + cache = { + 'cached_key': np.zeros((b, h, d, k)), + 'cached_value': np.zeros((b, h, d, k)), + 'cache_index': np.array([0, 0]) + } + inputs_q = np.random.randn(b, k, f) + inputs_kv = np.random.randn(b, k, f) + + # Mock dense general such that q, k, v projections are replaced by simple + # reshaping. + def mock_dense_general(self, x, **kwargs): # pylint: disable=unused-argument + return x.reshape(b, -1, h, d) + + with mock.patch.object( + layers.DenseGeneral, '__call__', new=mock_dense_general): + _, mutated = layers.MultiHeadDotProductAttention(**args).apply( + {'cache': freeze(cache)}, + inputs_q, + inputs_kv, + decode=False, + prefill=True, + prefill_lengths=prefill_lengths, + mutable=['cache']) + updated_cache = mutated['cache'] + + # Perform the same mocked projection to generate the expected cache. + # (key|value): [b, 1, h, d] + key = mock_dense_general(None, inputs_kv) + value = mock_dense_general(None, inputs_kv) + + # cached_(key|value): [b, h, d, k] + # Update the our gold cache with the key and values that are part of the + # prefix that we are prefilling the cache with. 
Explicit loops here avoid a + # confusing transpose. + for b, prefill_length in enumerate(prefill_lengths): + for i in range(prefill_length): + cache['cached_key'][b, :, :, i] = key[b, i, :, :] + cache['cached_value'][b, :, :, i] = value[b, i, :, :] + cache['cache_index'][b] = prefill_length + for name, array in cache.items(): + np.testing.assert_allclose(array, updated_cache[name]) + + +class EmbeddingTest(parameterized.TestCase): + + def test_embedder_raises_exception_for_incorrect_input_type(self): + """Tests that inputs are integers and that an exception is raised if not.""" + embed = layers.Embed(num_embeddings=10, features=5) + inputs = np.expand_dims(np.arange(5, dtype=np.int64), 1) + variables = embed.init(jax.random.PRNGKey(0), inputs) + bad_inputs = inputs.astype(np.float32) + with self.assertRaisesRegex( + ValueError, 'Input type must be an integer or unsigned integer.'): + _ = embed.apply(variables, bad_inputs) + + @parameterized.named_parameters( + { + 'testcase_name': 'with_ones', + 'init_fn': jax.nn.initializers.ones, + 'num_embeddings': 10, + 'features': 5, + 'matrix_sum': 5 * 10, + }, { + 'testcase_name': 'with_zeros', + 'init_fn': jax.nn.initializers.zeros, + 'num_embeddings': 10, + 'features': 5, + 'matrix_sum': 0, + }) + def test_embedding_initializes_correctly(self, init_fn, num_embeddings, + features, matrix_sum): + """Tests if the Embed class initializes with the requested initializer.""" + embed = layers.Embed( + num_embeddings=num_embeddings, + features=features, + embedding_init=init_fn) + inputs = np.expand_dims(np.arange(5, dtype=np.int64), 1) + variables = embed.init(jax.random.PRNGKey(0), inputs) + embedding_matrix = variables['params']['embedding'] + self.assertEqual(int(np.sum(embedding_matrix)), matrix_sum) + + def test_embedding_matrix_shape(self): + """Tests that the embedding matrix has the right shape.""" + num_embeddings = 10 + features = 5 + embed = layers.Embed(num_embeddings=num_embeddings, features=features) + inputs = np.expand_dims(np.arange(features, dtype=np.int64), 1) + variables = embed.init(jax.random.PRNGKey(0), inputs) + embedding_matrix = variables['params']['embedding'] + self.assertEqual((num_embeddings, features), embedding_matrix.shape) + + def test_embedding_attend(self): + """Tests that attending with ones returns sum of embedding vectors.""" + features = 5 + embed = layers.Embed(num_embeddings=10, features=features) + inputs = np.array([[1]], dtype=np.int64) + variables = embed.init(jax.random.PRNGKey(0), inputs) + query = np.ones(features, dtype=np.float32) + result = embed.apply(variables, query, method=embed.attend) + expected = np.sum(variables['params']['embedding'], -1) + np.testing.assert_array_almost_equal(result, expected) + + +class DenseTest(parameterized.TestCase): + + def test_dense_general_no_bias(self): + rng = random.PRNGKey(0) + x = jnp.ones((1, 3)) + model = layers.DenseGeneral( + features=4, + kernel_init=initializers.ones, + ) + y, _ = model.init_with_output(rng, x) + self.assertEqual(y.shape, (1, 4)) + np.testing.assert_allclose(y, np.full((1, 4), 3.)) + + def test_dense_general_two_features(self): + rng = random.PRNGKey(0) + x = jnp.ones((1, 3)) + model = layers.DenseGeneral( + features=(2, 2), + kernel_init=initializers.ones, + ) + y, _ = model.init_with_output(rng, x) + # We transform the last input dimension to two output dimensions (2, 2). 
+ np.testing.assert_allclose(y, np.full((1, 2, 2), 3.)) + + def test_dense_general_two_axes(self): + rng = random.PRNGKey(0) + x = jnp.ones((1, 2, 2)) + model = layers.DenseGeneral( + features=3, + axis=(-2, 2), # Note: this is the same as (1, 2). + kernel_init=initializers.ones, + ) + y, _ = model.init_with_output(rng, x) + # We transform the last two input dimensions (2, 2) to one output dimension. + np.testing.assert_allclose(y, np.full((1, 3), 4.)) + + def test_mlp_same_out_dim(self): + module = layers.MlpBlock( + intermediate_dim=4, + activations=('relu',), + kernel_init=nn.initializers.xavier_uniform(), + dtype=jnp.float32, + ) + inputs = np.array( + [ + # Batch 1. + [[1, 1], [1, 1], [1, 2]], + # Batch 2. + [[2, 2], [3, 1], [2, 2]], + ], + dtype=np.float32) + params = module.init(random.PRNGKey(0), inputs, deterministic=True) + self.assertEqual( + jax.tree_map(lambda a: a.tolist(), params), { + 'params': { + 'wi': { + 'kernel': [[ + -0.8675811290740967, 0.08417510986328125, + 0.022586345672607422, -0.9124102592468262 + ], + [ + -0.19464373588562012, 0.49809837341308594, + 0.7808468341827393, 0.9267289638519287 + ]], + }, + 'wo': { + 'kernel': [[0.01154780387878418, 0.1397249698638916], + [0.974980354309082, 0.5903260707855225], + [-0.05997943878173828, 0.616570234298706], + [0.2934272289276123, 0.8181164264678955]], + }, + }, + 'params_axes': { + 'wi': { + 'kernel_axes': AxisMetadata(names=('embed', 'mlp')), + }, + 'wo': { + 'kernel_axes': AxisMetadata(names=('mlp', 'embed')), + }, + }, + }) + result = module.apply(params, inputs, deterministic=True) + np.testing.assert_allclose( + result.tolist(), + [[[0.5237172245979309, 0.8508185744285583], + [0.5237172245979309, 0.8508185744285583], + [1.2344461679458618, 2.3844780921936035]], + [[1.0474344491958618, 1.7016371488571167], + [0.6809444427490234, 0.9663378596305847], + [1.0474344491958618, 1.7016371488571167]]], + rtol=1e-6, + ) + + +class RelativePositionBiasesTest(absltest.TestCase): + + def setUp(self): + self.num_heads = 3 + self.query_len = 5 + self.key_len = 7 + self.relative_attention = layers.RelativePositionBiases( + num_buckets=12, + max_distance=10, + num_heads=3, + dtype=jnp.float32, + ) + super(RelativePositionBiasesTest, self).setUp() + + def test_relative_attention_bidirectional_params(self): + """Tests that bidirectional relative position biases have expected params.""" + params = self.relative_attention.init( + random.PRNGKey(0), self.query_len, self.key_len, bidirectional=True) + param_shapes = jax.tree_map(lambda x: x.shape, params) + self.assertEqual( + param_shapes, { + 'params': { + 'rel_embedding': (3, 12), + }, + 'params_axes': { + 'rel_embedding_axes': + AxisMetadata(names=('heads', 'relpos_buckets')), + } + }) + + def test_regression_relative_attention_bidirectional_values(self): + """Tests that bidirectional relative position biases match expected values. + + See top docstring note on matching T5X behavior for these regression tests. 
+ """ + outputs, unused_params = self.relative_attention.init_with_output( + random.PRNGKey(0), self.query_len, self.key_len, bidirectional=True) + self.assertEqual(outputs.shape, + (1, self.num_heads, self.query_len, self.key_len)) + self.assertAlmostEqual(outputs[0, 0, 0, 0], 0.55764728, places=5) + self.assertAlmostEqual(outputs[0, 1, 2, 1], -0.10935841, places=5) + self.assertAlmostEqual(outputs[0, 1, 4, 6], 0.14510104, places=5) + self.assertAlmostEqual(outputs[0, 2, 4, 6], -0.36783996, places=5) + + def test_relative_attention_unidirectional_params(self): + """Tests that unidirectional relative position biases have expected params.""" + params = self.relative_attention.init( + random.PRNGKey(0), self.query_len, self.key_len, bidirectional=False) + param_shapes = jax.tree_map(lambda x: x.shape, params) + self.assertEqual( + param_shapes, { + 'params': { + 'rel_embedding': (3, 12), + }, + 'params_axes': { + 'rel_embedding_axes': + AxisMetadata(names=('heads', 'relpos_buckets')), + } + }) + + def test_regression_relative_attention_unidirectional_values(self): + """Tests that unidirectional relative position biases match expected values. + + See top docstring note on matching T5X behavior for these regression tests. + """ + outputs, unused_params = self.relative_attention.init_with_output( + random.PRNGKey(0), self.query_len, self.key_len, bidirectional=False) + self.assertEqual(outputs.shape, + (1, self.num_heads, self.query_len, self.key_len)) + self.assertAlmostEqual(outputs[0, 0, 0, 0], 0.55764728, places=5) + self.assertAlmostEqual(outputs[0, 1, 2, 1], -0.10935841, places=5) + self.assertAlmostEqual(outputs[0, 1, 4, 6], -0.13101986, places=5) + self.assertAlmostEqual(outputs[0, 2, 4, 6], 0.39296466, places=5) + + def test_relative_attention_decode_cache_error_with_init(self): + """Tests that relative embedding init fails with decode == True.""" + with self.assertRaisesRegex( + ValueError, + 'decode-mode cannot be enabled during init. use model.apply to ' + 'initialize the decoding cache.'): + self.relative_attention.init( + jax.random.PRNGKey(0), + self.query_len, + self.key_len, + bidirectional=False, + decode=True) + + def test_relative_attention_decode_cache_errror_with_bidirectional(self): + """Tests that bidirectional relative embeddings fails when decoding.""" + params = self.relative_attention.init( + jax.random.PRNGKey(0), + self.query_len, + self.key_len, + bidirectional=False, + decode=False) + + with self.assertRaisesRegex( + ValueError, + 'bidirectional RelativePositionBiases are not supported when ' + '`decode=True`.'): + self.relative_attention.apply( + params, + self.query_len, + self.key_len, + bidirectional=True, + decode=True, + mutable=['cache']) + + def test_relative_attention_decode_cache(self): + """Tests that relative embeddings are correctly cached when decode=True.""" + + params = self.relative_attention.init( + jax.random.PRNGKey(0), + self.query_len, + self.key_len, + bidirectional=False, + decode=False) + + # during init, cache is not actually initialized. 
+ self.assertNotIn('cache', params) + + outputs, state = self.relative_attention.apply( + params, + self.query_len, + self.key_len, + bidirectional=False, + decode=True, + mutable=['cache']) + + self.assertEqual(outputs.shape, + (1, self.num_heads, self.query_len, self.key_len)) + + self.assertIn('cached_bias', state['cache']) + + cached_bias = state['cache']['cached_bias'] + + self.assertAlmostEqual(cached_bias[0, 0, 0, 0], 0.55764728, places=5) + self.assertAlmostEqual(cached_bias[0, 1, 2, 1], -0.10935841, places=5) + self.assertAlmostEqual(cached_bias[0, 1, 4, 6], -0.13101986, places=5) + self.assertAlmostEqual(cached_bias[0, 2, 4, 6], 0.39296466, places=5) + + np.testing.assert_array_equal(outputs, state['cache']['cached_bias']) + + params_with_cache = { + **params, + **state, + } + + outputs, state = self.relative_attention.apply( + params_with_cache, + self.query_len, + self.key_len, + bidirectional=False, + decode=True, + mutable=['cache']) + + np.testing.assert_array_equal(cached_bias, state['cache']['cached_bias']) + + +if __name__ == '__main__': + absltest.main() diff --git a/t5x/examples/decoder_only/models/base.gin b/t5x/examples/decoder_only/models/base.gin new file mode 100644 index 0000000000000000000000000000000000000000..d0bed734241f03e1066b357882d96d72d162ee48 --- /dev/null +++ b/t5x/examples/decoder_only/models/base.gin @@ -0,0 +1,59 @@ +# Decoder-only model (Base) with 134307072 parameters. +from __gin__ import dynamic_registration + +import seqio +from t5x import adafactor +from t5x import decoding +from t5x import models +from t5x.examples.decoder_only import network + +# ------------------- Loss HParam ---------------------------------------------- +Z_LOSS = 0.0001 +LABEL_SMOOTHING = 0.0 +# NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF) +# the loss normalizing factor should be set to pretraining batch_size * +# target_token_length. +LOSS_NORMALIZING_FACTOR = None +# Dropout should be specified in the "run" files +DROPOUT_RATE = %gin.REQUIRED + +# Vocabulary (shared by encoder and decoder) +VOCABULARY = @seqio.SentencePieceVocabulary() +seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" + +# ------------------- Optimizer ------------------------------------------------ +# `learning_rate` is set by `Trainer.learning_rate_fn`. 
+OPTIMIZER = @adafactor.Adafactor() +adafactor.Adafactor: + decay_rate = 0.8 + step_offset = 0 + logical_factor_rules = @adafactor.standard_logical_factor_rules() + +# ------------------- Model ---------------------------------------------------- +MODEL = @models.DecoderOnlyModel() +models.DecoderOnlyModel: + module = @network.DecoderWrapper() + vocabulary = %VOCABULARY + optimizer_def = %OPTIMIZER + decode_fn = @decoding.temperature_sample + z_loss = %Z_LOSS + label_smoothing = %LABEL_SMOOTHING + loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR + +decoding.temperature_sample: + temperature = 1.0 + topk = 40 + +# ------------------- Network specification ------------------------------------ +network.DecoderWrapper.config = @network.TransformerConfig() +network.TransformerConfig: + vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency + dtype = 'bfloat16' + emb_dim = 768 + num_heads = 12 + num_layers = 12 + head_dim = 64 + mlp_dim = 2048 + mlp_activations = ('gelu', 'linear') + dropout_rate = %DROPOUT_RATE + logits_via_embedding = True diff --git a/t5x/examples/decoder_only/models/xxl.gin b/t5x/examples/decoder_only/models/xxl.gin new file mode 100644 index 0000000000000000000000000000000000000000..bd4a6b5c541ddd89d7f80297c047bd9784a3cb59 --- /dev/null +++ b/t5x/examples/decoder_only/models/xxl.gin @@ -0,0 +1,11 @@ +# Decoder-only model (XXL) with 4762357760 parameters. + +include 't5x/examples/decoder_only/models/base.gin' # imports vocab, optimizer and model. + +# ------------------- Network specification overrides -------------------------- +network.TransformerConfig: + emb_dim = 4096 + num_heads = 64 + num_layers = 24 + head_dim = 64 + mlp_dim = 10240 diff --git a/t5x/examples/decoder_only/network.py b/t5x/examples/decoder_only/network.py new file mode 100644 index 0000000000000000000000000000000000000000..c68164c38a0345147ac80c1260acc679b08aabff --- /dev/null +++ b/t5x/examples/decoder_only/network.py @@ -0,0 +1,241 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Minimal decoder-only Transformer model.""" + +from typing import Any, Optional, Sequence + +from flax import linen as nn +from flax import struct +import jax.numpy as jnp +from t5x.examples.decoder_only import layers + + +@struct.dataclass +class TransformerConfig: + """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + vocab_size: int + # Activation dtypes. + dtype: Any = jnp.float32 + emb_dim: int = 512 + num_heads: int = 8 + num_layers: int = 6 + head_dim: int = 64 + mlp_dim: int = 2048 + # Activation functions are retrieved from Flax. + mlp_activations: Sequence[str] = ('relu',) + dropout_rate: float = 0.1 + # If `True`, the embedding weights are used in the decoder output layer. 
+ logits_via_embedding: bool = False + + +class DecoderLayer(nn.Module): + """Transformer decoder layer.""" + config: TransformerConfig + + @nn.compact + def __call__(self, + inputs: jnp.ndarray, + decoder_mask: Optional[jnp.ndarray] = None, + deterministic: bool = False, + decode: bool = False, + max_decode_length: Optional[int] = None, + prefill: bool = False, + prefill_lengths: Optional[jnp.ndarray] = None): + """Applies decoder block module.""" + cfg = self.config + + # Relative position embedding as attention biases. + l = max_decode_length if decode and max_decode_length else inputs.shape[-2] + + # During decoding, this module will be called with `decode=True` first to + # initialize the decoder cache, including a cached relpos bias. The prefill + # codepath will call this once again with `decode=False`, which is slightly + # wasteful but generally harmless. During subsequent decode steps, this will + # be called with `decode=True` and will reuse the cached bias. This + # significantly improves performance during decoding with many decode steps. + decoder_bias = layers.RelativePositionBiases( + num_buckets=32, + max_distance=128, + num_heads=cfg.num_heads, + dtype=cfg.dtype, + embedding_init=nn.initializers.variance_scaling(1.0, 'fan_avg', + 'uniform'), + name='relpos_bias')( + l, l, False, decode=decode) + + # `inputs` is layer input with a shape [batch, length, emb_dim]. + x = layers.LayerNorm( + dtype=cfg.dtype, name='pre_self_attention_layer_norm')( + inputs) + + # Self-attention block + x = layers.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + head_dim=cfg.head_dim, + dropout_rate=cfg.dropout_rate, + name='self_attention')( + x, + x, + decoder_mask, + decoder_bias, + deterministic=deterministic, + decode=decode, + prefill=prefill, + prefill_lengths=prefill_lengths) + x = nn.Dropout( + rate=cfg.dropout_rate, + broadcast_dims=(-2,), + name='post_self_attention_dropout')( + x, deterministic=deterministic) + x = x + inputs + + # MLP block. + y = layers.LayerNorm(dtype=cfg.dtype, name='pre_mlp_layer_norm')(x) + y = layers.MlpBlock( + intermediate_dim=cfg.mlp_dim, + activations=cfg.mlp_activations, + intermediate_dropout_rate=cfg.dropout_rate, + dtype=cfg.dtype, + name='mlp', + )(y, deterministic=deterministic) + y = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,), name='post_mlp_dropout')( + y, deterministic=deterministic) + y = y + x + + return y + + +class Decoder(nn.Module): + """A stack of decoder layers.""" + config: TransformerConfig + + @nn.compact + def __call__(self, + decoder_input_tokens: jnp.ndarray, + decoder_target_tokens: jnp.ndarray, + decoder_segment_ids: Optional[jnp.ndarray] = None, + decoder_positions: Optional[jnp.ndarray] = None, + decoder_causal_attention: Optional[jnp.ndarray] = None, + *, + enable_dropout: bool = True, + decode: bool = False, + max_decode_length: Optional[int] = None, + prefill: Optional[bool] = None, + prefill_lengths: Optional[jnp.ndarray] = None): + """Applies LanguageModel on the inputs. + + For a decoder-only architecture with the notion of "prefix", e.g., a prefix + LM where the prefix corresponds to the "inputs" of a supervised dataset, we + perform the "prefill" operation to fill the autoregressive cache + corresponding to the prefix region in one go. Then the autoregressive + decoding starts after the prefix. This makes the decoding process more + efficient. In addition, it gives an option to use bidirectional attention in + the prefix region because the cache is filled simultaneously. 
+ + Args: + decoder_input_tokens: input token to the decoder. + decoder_target_tokens: target token to the decoder. + decoder_segment_ids: decoder segmentation info for packed examples. + decoder_positions: decoder subsequence positions for packed examples. + decoder_causal_attention: a binary mask indicating the portion of the + sequence to apply bidirectional attention to instead of causal. As an + example, useful to specify the "inputs" portion of a concatenated + sequence for a prefix LM. + enable_dropout: enables dropout if set to True. + decode: whether to prepare and use an autoregressive cache as opposed to + using teacher-forcing. + max_decode_length: maximum sequence length to be decoded. + prefill: whether to run a partial sequence to prefill the cache. + prefill_lengths: an array of shape [batch] denoting the length of each + partial sequence we are filling in the cache. + + Returns: + logits array. + """ + cfg = self.config + deterministic = not enable_dropout + assert decoder_input_tokens.ndim == 2 # [batch, len] + + if decode: + decoder_mask = None + else: + decoder_mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=cfg.dtype, + decoder_causal_attention=decoder_causal_attention, + decoder_segment_ids=decoder_segment_ids) + + embedding = layers.Embed( + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, + dtype=cfg.dtype, + attend_dtype=jnp.float32, # for logit training stability + embedding_init=nn.initializers.normal(stddev=1.0), + one_hot=True, + name='token_embedder') + y = embedding(decoder_input_tokens.astype('int32')) + + y = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,), name='input_dropout')( + y, deterministic=deterministic) + y = y.astype(cfg.dtype) + + for lyr in range(cfg.num_layers): + # [batch, length, emb_dim] -> [batch, length, emb_dim] + y = DecoderLayer( + config=cfg, name=f'layers_{lyr}')( + y, + decoder_mask=decoder_mask, + deterministic=deterministic, + decode=decode, + max_decode_length=max_decode_length, + prefill=prefill, + prefill_lengths=prefill_lengths) + + y = layers.LayerNorm(dtype=cfg.dtype, name='decoder_norm')(y) + y = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,), name='output_dropout')( + y, deterministic=deterministic) + + # [batch, length, emb_dim] -> [batch, length, vocab_size] + if cfg.logits_via_embedding: + # Use the transpose of embedding matrix for the logit transform. + logits = embedding.attend(y) + # Correctly normalize pre-softmax logits for this shared case. + logits = logits / jnp.sqrt(y.shape[-1]) + else: + # Use a separate dense layer for the logit transform. + logits = layers.DenseGeneral( + cfg.vocab_size, + dtype=jnp.float32, # Use float32 for stabiliity. + kernel_axes=('embed', 'vocab'), + name='logits_dense')( + y) + return logits + + +# TODO(hwchung): remove this after figuring out the name scope issue. +class DecoderWrapper(nn.Module): + """Thin wrapper for the outer "decoder/" name scope.""" + + config: TransformerConfig + + def setup(self): + self.decoder = Decoder(self.config, name='decoder') + + def __call__(self, *args, **kwargs): + return self.decoder(*args, **kwargs) diff --git a/t5x/examples/decoder_only/network_test.py b/t5x/examples/decoder_only/network_test.py new file mode 100644 index 0000000000000000000000000000000000000000..abad8406971f42135c978182601b2e412f1aea4a --- /dev/null +++ b/t5x/examples/decoder_only/network_test.py @@ -0,0 +1,34 @@ +# Copyright 2022 The T5X Authors. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for network."""
+
+import os
+
+from absl import flags
+from absl.testing import absltest
+from absl.testing import parameterized
+
+import jax
+import numpy as np
+from t5x import test_utils
+
+# Parse absl flags test_srcdir and test_tmpdir.
+jax.config.parse_flags_with_absl()
+
+FLAGS = flags.FLAGS
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/t5x/examples/scalable_t5/README.md b/t5x/examples/scalable_t5/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6b457b9b71664d596644851dc0bc81127e1b0874
--- /dev/null
+++ b/t5x/examples/scalable_t5/README.md
@@ -0,0 +1,53 @@
+# Scalable T5
+
+NB: This particular example is still WIP. We're investigating a slight training
+regression compared to the "vanilla" T5 example.
+
+This directory is very similar to the vanilla T5X "T5" example, but demonstrates
+a host of techniques needed to scale model training to giant models running on
+large TPU or GPU clusters using XLA's SPMD capabilities. See the notes for the
+main "t5" example for general details on setup and execution.
+
+__Note__: many of the APIs built on top of `pjit` by Flax and T5X for easier
+model-parallel programming are still experimental, and may change.
+
+## Intermediate variable annotations
+
+In larger models, with multi-axis model parallelism, it is typically necessary
+to provide additional constraint annotations beyond those for the inputs and
+outputs of a function. We do this using a special version of the `pjit`
+annotation function `with_sharding_constraint` that uses _logical_ axis
+names instead of raw mesh axes. This allows us to avoid tightly coupling a
+specific partitioning plan to the model code itself. Instead, we merely need
+to annotate the axis names used in the model in a coherent scheme, and later
+map these logical axes to the physical mesh axes using a small set of rules.
+Example usage can be seen in `network.py`.
+
+## Scan over layers
+
+One challenge with giant models is the increasing amount of compilation time
+required to handle extremely large layer stacks in XLA. At the size of a full
+TPU pod this compile time cost can become quite extreme. To remedy this,
+instead of handing the compiler a huge stack of unrolled layers, we can use
+native XLA control flow constructs to simplify the computational graph that
+JAX hands to the compiler. For giant models this can drop the compile time
+from hour(s) to minutes, and even at base scale it can be roughly 5x faster.
+
+In this case, we want to use the [XLA While Op][xla-while] via JAX's
+[scan][jax-scan] control flow construct to express the idea that we're looping
+over identically-defined layers when using a deep transformer network. We do
+this via a custom Flax version of scan called `scan_with_axes` that also handles
+the parameter logical axis name metadata needed for partitioning.
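+
+As a rough sketch of the pattern, the snippet below uses core Flax `nn.scan`
+rather than the `scan_with_axes` wrapper used in `network.py`; the `Block` and
+`ScannedStack` modules and all shapes here are purely illustrative assumptions,
+not the actual layer definitions from this example.
+
+```python
+import jax
+import jax.numpy as jnp
+import flax.linen as nn
+
+
+class Block(nn.Module):
+  """Stand-in for one decoder layer; a scanned module maps (carry, x) -> (carry, y)."""
+  features: int  # must equal the input's last dimension for the residual add
+
+  @nn.compact
+  def __call__(self, x, mask):
+    del mask  # a real layer would use the (shared) attention mask here
+    h = nn.Dense(self.features)(x)
+    return x + nn.relu(h), None
+
+
+class ScannedStack(nn.Module):
+  """Applies `num_layers` identical Blocks inside a single XLA While loop."""
+  num_layers: int
+  features: int
+
+  @nn.compact
+  def __call__(self, x, mask=None):
+    ScanBlock = nn.scan(
+        Block,
+        variable_axes={'params': 0},  # stack each layer's params on a new leading axis
+        split_rngs={'params': True},  # give every layer its own init RNG
+        in_axes=nn.broadcast,         # the mask is shared by all layers, not scanned over
+        length=self.num_layers)
+    x, _ = ScanBlock(self.features, name='layers')(x, mask)
+    return x
+
+
+model = ScannedStack(num_layers=4, features=8)
+x = jnp.ones((2, 5, 8))
+variables = model.init(jax.random.PRNGKey(0), x)
+y = model.apply(variables, x)  # params under 'layers' gain a leading axis of size 4
+```
+
+With `scan_with_axes`, the same pattern additionally records a logical axis
+name (e.g. `'layers'`) for that new leading parameter dimension, so the
+partitioning rules can shard it like any other annotated axis.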
+ +## Rematerialization / Checkpointing + +"Rematerialization" or "checkpointing" is a technique for trading off compute +time for lower peak memory utilization when performing reverse-mode automatic +differentiation. JAX offers several different default rematerialization +"policies" that dictate which kinds of intermediate values are preserved from +the forward-pass to the backwards-pass calculation, and which are discarded to +be recomputed anew in the backwards-pass. + + +[xla-while]: https://www.tensorflow.org/xla/operation_semantics#while +[jax-scan]: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html diff --git a/t5x/examples/scalable_t5/__init__.py b/t5x/examples/scalable_t5/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da022c16301721a096a208e8bdb2a71bb87f9788 --- /dev/null +++ b/t5x/examples/scalable_t5/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This empty file is needed for loading the gin files in this directory. diff --git a/t5x/examples/scalable_t5/layers.py b/t5x/examples/scalable_t5/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..f1ba950821ec4d32db931465bad1d967653a63aa --- /dev/null +++ b/t5x/examples/scalable_t5/layers.py @@ -0,0 +1,931 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dense attention classes and mask/weighting functions.""" + +# pylint: disable=attribute-defined-outside-init,g-bare-generic + +import dataclasses +import functools +import operator +from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union + +from flax import linen as nn +from flax.linen import partitioning as nn_partitioning +import jax +from jax import lax +from jax import random +import jax.numpy as jnp +import numpy as np + + +# from flax.linen.partitioning import param_with_axes, with_sharding_constraint +param_with_axes = nn_partitioning.param_with_axes +with_sharding_constraint = nn_partitioning.with_sharding_constraint + + +# Type annotations +Array = jnp.ndarray +DType = jnp.dtype +PRNGKey = jnp.ndarray +Shape = Iterable[int] +Activation = Callable[..., Array] +# Parameter initializers. 
+Initializer = Callable[[PRNGKey, Shape, DType], Array] +InitializerAxis = Union[int, Tuple[int, ...]] +NdInitializer = Callable[ + [PRNGKey, Shape, DType, InitializerAxis, InitializerAxis], Array] + +default_embed_init = nn.initializers.variance_scaling( + 1.0, 'fan_in', 'normal', out_axis=0) + + +# ------------------------------------------------------------------------------ +# Temporary inlined JAX N-d initializer code +# TODO(levskaya): remove once new JAX release is out. +# ------------------------------------------------------------------------------ +def _compute_fans(shape: jax.core.NamedShape, in_axis=-2, out_axis=-1): + """Inlined JAX `nn.initializer._compute_fans`.""" + if isinstance(in_axis, int): + in_size = shape[in_axis] + else: + in_size = int(np.prod([shape[i] for i in in_axis])) + if isinstance(out_axis, int): + out_size = shape[out_axis] + else: + out_size = int(np.prod([shape[i] for i in out_axis])) + receptive_field_size = shape.total / in_size / out_size + fan_in = in_size * receptive_field_size + fan_out = out_size * receptive_field_size + return fan_in, fan_out + + +def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, + dtype=jnp.float_): + """Inlined JAX `nn.initializer.variance_scaling`.""" + + def init(key, shape, dtype=dtype): + dtype = jax.dtypes.canonicalize_dtype(dtype) + shape = jax.core.as_named_shape(shape) + fan_in, fan_out = _compute_fans(shape, in_axis, out_axis) + if mode == 'fan_in': + denominator = fan_in + elif mode == 'fan_out': + denominator = fan_out + elif mode == 'fan_avg': + denominator = (fan_in + fan_out) / 2 + else: + raise ValueError( + 'invalid mode for variance scaling initializer: {}'.format(mode)) + variance = jnp.array(scale / denominator, dtype=dtype) + + if distribution == 'truncated_normal': + # constant is stddev of standard normal truncated to (-2, 2) + stddev = jnp.sqrt(variance) / jnp.array(.87962566103423978, dtype) + return random.truncated_normal(key, -2, 2, shape, dtype) * stddev + elif distribution == 'normal': + return random.normal(key, shape, dtype) * jnp.sqrt(variance) + elif distribution == 'uniform': + return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance) + else: + raise ValueError('invalid distribution for variance scaling ' + 'initializer: {}'.format(distribution)) + return init +# ------------------------------------------------------------------------------ + + +def nd_dense_init(scale, mode, distribution): + """Initializer with in_axis, out_axis set at call time.""" + def init_fn(key, shape, dtype, in_axis, out_axis): + fn = variance_scaling( + scale, mode, distribution, in_axis, out_axis) + return fn(key, shape, dtype) + return init_fn + + +def dot_product_attention(query: Array, + key: Array, + value: Array, + bias: Optional[Array] = None, + dropout_rng: Optional[PRNGKey] = None, + dropout_rate: float = 0., + deterministic: bool = False, + dtype: DType = jnp.float32, + float32_logits: bool = False): + """Computes dot-product attention given query, key, and value. + + This is the core function for applying attention based on + https://arxiv.org/abs/1706.03762. It calculates the attention weights given + query and key and combines the values using the attention weights. + + Args: + query: queries for calculating attention with shape of `[batch, q_length, + num_heads, qk_depth_per_head]`. + key: keys for calculating attention with shape of `[batch, kv_length, + num_heads, qk_depth_per_head]`. 
+ value: values to be used in attention with shape of `[batch, kv_length, + num_heads, v_depth_per_head]`. + bias: bias for the attention weights. This should be broadcastable to the + shape `[batch, num_heads, q_length, kv_length]` This can be used for + incorporating causal masks, padding masks, proximity bias, etc. + dropout_rng: JAX PRNGKey: to be used for dropout + dropout_rate: dropout rate + deterministic: bool, deterministic or not (to apply dropout) + dtype: the dtype of the computation (default: float32) + float32_logits: bool, if True then compute logits in float32 to avoid + numerical issues with bfloat16. + + Returns: + Output of shape `[batch, length, num_heads, v_depth_per_head]`. + """ + assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' + assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], ( + 'q, k, v batch dims must match.') + assert query.shape[-2] == key.shape[-2] == value.shape[-2], ( + 'q, k, v num_heads must match.') + assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' + assert query.shape[-1] == key.shape[-1], 'q, k depths must match.' + + # Casting logits and softmax computation for float32 for model stability. + if float32_logits: + query = query.astype(jnp.float32) + key = key.astype(jnp.float32) + + # `attn_weights`: [batch, num_heads, q_length, kv_length] + attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key) + + # Apply attention bias: masking, dropout, proximity bias, etc. + if bias is not None: + attn_weights = attn_weights + bias.astype(attn_weights.dtype) + + # Normalize the attention weights across `kv_length` dimension. + attn_weights = jax.nn.softmax(attn_weights).astype(dtype) + + # Apply attention dropout. + if not deterministic and dropout_rate > 0.: + keep_prob = 1.0 - dropout_rate + # T5 broadcasts along the "length" dim, but unclear which one that + # corresponds to in positional dimensions here, assuming query dim. + dropout_shape = list(attn_weights.shape) + dropout_shape[-2] = 1 + keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) + keep = jnp.broadcast_to(keep, attn_weights.shape) + multiplier = ( + keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype)) + attn_weights = attn_weights * multiplier + + # Take the linear combination of `value`. + return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value) + + +dynamic_vector_slice_in_dim = jax.vmap( + lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) + + +class MultiHeadDotProductAttention(nn.Module): + """Multi-head dot-product attention. + + Attributes: + num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) + should be divisible by the number of heads. + head_dim: dimension of each head. + dtype: the dtype of the computation. + dropout_rate: dropout rate + kernel_init: initializer for the kernel of the Dense layers. + float32_logits: bool, if True then compute logits in float32 to avoid + numerical issues with bfloat16. + """ + + num_heads: int + head_dim: int + dtype: DType = jnp.float32 + dropout_rate: float = 0. + kernel_init: NdInitializer = nd_dense_init(1.0, 'fan_in', 'normal') + float32_logits: bool = False # computes logits in float32 for stability. + + @nn.compact + def __call__(self, + inputs_q: Array, + inputs_kv: Array, + mask: Optional[Array] = None, + bias: Optional[Array] = None, + *, + decode: bool = False, + deterministic: bool = False) -> Array: + """Applies multi-head dot product attention on the input data. 
+ + Projects the inputs into multi-headed query, key, and value vectors, + applies dot-product attention and project the results to an output vector. + + There are two modes: decoding and non-decoding (e.g., training). The mode is + determined by `decode` argument. For decoding, this method is called twice, + first to initialize the cache and then for an actual decoding process. The + two calls are differentiated by the presence of 'cached_key' in the variable + dict. In the cache initialization stage, the cache variables are initialized + as zeros and will be filled in the subsequent decoding process. + + In the cache initialization call, `inputs_q` has a shape [batch, length, + q_features] and `inputs_kv`: [batch, length, kv_features]. During the + incremental decoding stage, query, key and value all have the shape [batch, + 1, qkv_features] corresponding to a single step. + + Args: + inputs_q: input queries of shape `[batch, q_length, q_features]`. + inputs_kv: key/values of shape `[batch, kv_length, kv_features]`. + mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`. + bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`. + decode: Whether to prepare and use an autoregressive cache. + deterministic: Disables dropout if set to True. + + Returns: + output of shape `[batch, length, q_features]`. + """ + projection = functools.partial( + DenseGeneral, + axis=-1, + features=(self.num_heads, self.head_dim), + kernel_axes=('embed', 'heads', 'kv'), + dtype=self.dtype) + + # NOTE: T5 does not explicitly rescale the attention logits by + # 1/sqrt(depth_kq)! This is folded into the initializers of the + # linear transformations, which is equivalent under Adafactor. + depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) + query_init = lambda *args: self.kernel_init(*args) / depth_scaling + + # Project inputs_q to multi-headed q/k/v + # dimensions are then [batch, length, num_heads, head_dim] + query = projection(kernel_init=query_init, name='query')(inputs_q) + key = projection(kernel_init=self.kernel_init, name='key')(inputs_kv) + value = projection(kernel_init=self.kernel_init, name='value')(inputs_kv) + + query = with_sharding_constraint(query, ('batch', 'length', 'heads', 'kv')) + key = with_sharding_constraint(key, ('batch', 'length', 'heads', 'kv')) + value = with_sharding_constraint(value, ('batch', 'length', 'heads', 'kv')) + + if decode: + # Detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable('cache', 'cached_key') + # The key and value have dimension [batch, length, num_heads, head_dim], + # but we cache them as [batch, num_heads, head_dim, length] as a TPU + # fusion optimization. This also enables the "scatter via one-hot + # broadcast" trick, which means we do a one-hot broadcast instead of a + # scatter/gather operations, resulting in a 3-4x speedup in practice. + swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3]) + cached_key = self.variable('cache', 'cached_key', jnp.zeros, + swap_dims(key.shape), key.dtype) + cached_value = self.variable('cache', 'cached_value', jnp.zeros, + swap_dims(value.shape), value.dtype) + cache_index = self.variable('cache', 'cache_index', + lambda: jnp.array(0, dtype=jnp.int32)) + if is_initialized: + batch, num_heads, head_dim, length = (cached_key.value.shape) + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + # Sanity shape check of cached key against input query. 
+ expected_shape = (batch, 1, num_heads, head_dim) + if expected_shape != query.shape: + raise ValueError('Autoregressive cache shape error, ' + 'expected query shape %s instead got %s.' % + (expected_shape, query.shape)) + + # Create a OHE of the current index. NOTE: the index is increased below. + cur_index = cache_index.value + one_hot_indices = jax.nn.one_hot(cur_index, length, dtype=key.dtype) + # In order to update the key, value caches with the current key and + # value, we move the length axis to the back, similar to what we did for + # the cached ones above. + # Note these are currently the key and value of a single position, since + # we feed one position at a time. + one_token_key = jnp.moveaxis(key, -3, -1) + one_token_value = jnp.moveaxis(value, -3, -1) + # Update key, value caches with our new 1d spatial slices. + # We implement an efficient scatter into the cache via one-hot + # broadcast and addition. + key = cached_key.value + one_token_key * one_hot_indices + value = cached_value.value + one_token_value * one_hot_indices + cached_key.value = key + cached_value.value = value + cache_index.value = cache_index.value + 1 + # Move the keys and values back to their original shapes. + key = jnp.moveaxis(key, -1, -3) + value = jnp.moveaxis(value, -1, -3) + + # Causal mask for cached decoder self-attention: our single query + # position should only attend to those key positions that have already + # been generated and cached, not the remaining zero elements. + mask = combine_masks( + mask, + jnp.broadcast_to( + jnp.arange(length) <= cur_index, + # (1, 1, length) represent (head dim, query length, key length) + # query length is 1 because during decoding we deal with one + # index. + # The same mask is applied to all batch elements and heads. + (batch, 1, 1, length))) + + # Grab the correct relative attention bias during decoding. This is + # only required during single step decoding. + if bias is not None: + # The bias is a full attention matrix, but during decoding we only + # have to take a slice of it. + # This is equivalent to bias[..., cur_index:cur_index+1, :]. + bias = dynamic_vector_slice_in_dim( + jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2) + + # Convert the boolean attention mask to an attention bias. + if mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + mask > 0, + jnp.full(mask.shape, 0.).astype(self.dtype), + jnp.full(mask.shape, -1e10).astype(self.dtype)) + else: + attention_bias = None + + # Add provided bias term (e.g. relative position embedding). + if bias is not None: + attention_bias = combine_biases(attention_bias, bias) + + dropout_rng = None + if not deterministic and self.dropout_rate > 0.: + dropout_rng = self.make_rng('dropout') + + # Apply attention. + x = dot_product_attention( + query, + key, + value, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout_rate, + deterministic=deterministic, + dtype=self.dtype, + float32_logits=self.float32_logits) + + # Back to the original inputs dimensions. + out = DenseGeneral( + features=inputs_q.shape[-1], # output dim is set to the input dim. + axis=(-2, -1), + kernel_init=self.kernel_init, + kernel_axes=('heads', 'kv', 'embed'), + dtype=self.dtype, + name='out')( + x) + return out + + +def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: + # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. 
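+  # For example, _normalize_axes((-2, -1), ndim=4) == (2, 3).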
+ return tuple([ax if ax >= 0 else ndim + ax for ax in axes]) + + +def _canonicalize_tuple(x): + if isinstance(x, Iterable): + return tuple(x) + else: + return (x,) + + +#------------------------------------------------------------------------------ +# DenseGeneral for attention layers. +#------------------------------------------------------------------------------ +class DenseGeneral(nn.Module): + """A linear transformation (without bias) with flexible axes. + + Attributes: + features: tuple with numbers of output features. + axis: tuple with axes to apply the transformation on. + dtype: the dtype of the computation (default: float32). + kernel_init: initializer function for the weight matrix. + """ + features: Union[Iterable[int], int] + axis: Union[Iterable[int], int] = -1 + dtype: DType = jnp.float32 + kernel_init: NdInitializer = nd_dense_init(1.0, 'fan_in', 'truncated_normal') + kernel_axes: Tuple[str, ...] = () + + @nn.compact + def __call__(self, inputs: Array) -> Array: + """Applies a linear transformation to the inputs along multiple dimensions. + + Args: + inputs: The nd-array to be transformed. + + Returns: + The transformed input. + """ + features = _canonicalize_tuple(self.features) + axis = _canonicalize_tuple(self.axis) + + inputs = jnp.asarray(inputs, self.dtype) + axis = _normalize_axes(axis, inputs.ndim) + + kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features + kernel_in_axis = np.arange(len(axis)) + kernel_out_axis = np.arange(len(axis), len(axis) + len(features)) + kernel = param_with_axes( + 'kernel', + self.kernel_init, + kernel_shape, + jnp.float32, + kernel_in_axis, + kernel_out_axis, + axes=self.kernel_axes) + kernel = jnp.asarray(kernel, self.dtype) + + contract_ind = tuple(range(0, len(axis))) + return lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ()))) + + +def _convert_to_activation_function( + fn_or_string: Union[str, Callable]) -> Callable: + """Convert a string to an activation function.""" + if fn_or_string == 'linear': + return lambda x: x + elif isinstance(fn_or_string, str): + return getattr(nn, fn_or_string) + elif callable(fn_or_string): + return fn_or_string + else: + raise ValueError("don't know how to convert %s to an activation function" % + (fn_or_string,)) + + +class MlpBlock(nn.Module): + """Transformer MLP / feed-forward block. + + Attributes: + intermediate_dim: Shared dimension of hidden layers. + activations: Type of activations for each layer. Each element is either + 'linear', a string function name in flax.linen, or a function. + kernel_init: Kernel function, passed to the dense layers. + deterministic: Whether the dropout layers should be deterministic. + intermediate_dropout_rate: Dropout rate used after the intermediate layers. + dtype: Type for the dense layer. + """ + intermediate_dim: int = 2048 + activations: Sequence[Union[str, Callable]] = ('relu',) + kernel_init: NdInitializer = nd_dense_init(1.0, 'fan_in', 'truncated_normal') + intermediate_dropout_rate: float = 0.1 + dtype: Any = jnp.float32 + + @nn.compact + def __call__(self, inputs, decode: bool = False, deterministic: bool = False): + """Applies Transformer MlpBlock module.""" + # Iterate over specified MLP input activation functions. + # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu. 
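+    # For example, with activations=('gelu', 'linear') two parallel
+    # projections named wi_0 and wi_1 are created below and combined as
+    # gelu(x @ wi_0) * (x @ wi_1) -- the gated-GELU feed-forward used by
+    # T5.1.1 -- before the final 'wo' projection.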
+ activations = [] + for idx, act_fn in enumerate(self.activations): + dense_name = 'wi' if len(self.activations) == 1 else f'wi_{idx}' + x = DenseGeneral( + self.intermediate_dim, + dtype=self.dtype, + kernel_init=self.kernel_init, + kernel_axes=('embed', 'mlp'), + name=dense_name)( + inputs) + x = _convert_to_activation_function(act_fn)(x) + activations.append(x) + + # Take elementwise product of above intermediate activations. + x = functools.reduce(operator.mul, activations) + # Apply dropout and final dense output projection. + x = nn.Dropout( + rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))( + x, deterministic=deterministic) # Broadcast along length. + x = with_sharding_constraint(x, ('batch', 'length', 'mlp')) + output = DenseGeneral( + inputs.shape[-1], + dtype=self.dtype, + kernel_init=self.kernel_init, + kernel_axes=('mlp', 'embed'), + name='wo')( + x) + return output + + +class Embed(nn.Module): + """A parameterized function from integers [0, n) to d-dimensional vectors. + + Attributes: + num_embeddings: number of embeddings. + features: number of feature dimensions for each embedding. + dtype: the dtype of the embedding vectors (default: float32). + embedding_init: embedding initializer. + one_hot: performs the gather with a one-hot contraction rather than a true + gather. This is currently needed for SPMD partitioning. + """ + num_embeddings: int + features: int + cast_input_dtype: Optional[DType] = None + dtype: DType = jnp.float32 + attend_dtype: Optional[DType] = None + embedding_init: Initializer = default_embed_init + one_hot: bool = False + embedding: Array = dataclasses.field(init=False) + + def setup(self): + self.embedding = param_with_axes( + 'embedding', + self.embedding_init, (self.num_embeddings, self.features), + jnp.float32, + axes=('vocab', 'embed')) + + def __call__(self, inputs: Array) -> Array: + """Embeds the inputs along the last dimension. + + Args: + inputs: input data, all dimensions are considered batch dimensions. + + Returns: + Output which is embedded input data. The output shape follows the input, + with an additional `features` dimension appended. + """ + if self.cast_input_dtype: + inputs = inputs.astype(self.cast_input_dtype) + if not jnp.issubdtype(inputs.dtype, jnp.integer): + raise ValueError('Input type must be an integer or unsigned integer.') + if self.one_hot: + iota = lax.iota(jnp.int32, self.num_embeddings) + one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype) + output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype)) + else: + output = jnp.asarray(self.embedding, self.dtype)[inputs] + output = with_sharding_constraint(output, ('batch', 'length', 'embed')) + return output + + def attend(self, query: Array) -> Array: + """Attend over the embedding using a query array. + + Args: + query: array with last dimension equal the feature depth `features` of the + embedding. + + Returns: + An array with final dim `num_embeddings` corresponding to the batched + inner-product of the array of query vectors against each embedding. + Commonly used for weight-sharing between embeddings and logit transform + in NLP models. + """ + dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype + return jnp.dot(query, jnp.asarray(self.embedding, dtype).T) + + +class RelativePositionBiases(nn.Module): + """Adds T5-style relative positional embeddings to the attention logits. + + Attributes: + num_buckets: Number of buckets to bucket distances between key and query + positions into. 
+ max_distance: Maximum distance before everything is lumped into the last + distance bucket. + num_heads: Number of heads in the attention layer. Each head will get a + different relative position weighting. + dtype: Type of arrays through this module. + embedding_init: initializer for relative embedding table. + """ + num_buckets: int + max_distance: int + num_heads: int + dtype: Any + embedding_init: Callable[..., Array] = nn.linear.default_embed_init + + @staticmethod + def _relative_position_bucket(relative_position, + bidirectional=True, + num_buckets=32, + max_distance=128): + """Translate relative position to a bucket number for relative attention. + + The relative position is defined as memory_position - query_position, i.e. + the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are + invalid. + We use smaller buckets for small absolute relative_position and larger + buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative + positions <=-max_distance map to the same bucket. This should allow for + more graceful generalization to longer sequences than the model has been + trained on. + + Args: + relative_position: an int32 array + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 + values in the range [0, num_buckets) + """ + ret = 0 + n = -relative_position + if bidirectional: + num_buckets //= 2 + ret += (n < 0).astype(np.int32) * num_buckets + n = np.abs(n) + else: + n = np.maximum(n, 0) + # now n is in the range [0, inf) + max_exact = num_buckets // 2 + is_small = (n < max_exact) + val_if_large = max_exact + ( + np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps) / + np.log(max_distance / max_exact) * + (num_buckets - max_exact)).astype(np.int32) + val_if_large = np.minimum(val_if_large, num_buckets - 1) + ret += np.where(is_small, n, val_if_large) + return ret + + @nn.compact + def __call__(self, qlen, klen, bidirectional=True): + """Produce relative position embedding attention biases. + + Args: + qlen: attention query length. + klen: attention key length. + bidirectional: whether to allow positive memory-query relative position + embeddings. + + Returns: + output: `(1, len, q_len, k_len)` attention bias + """ + # TODO(levskaya): should we be computing this w. numpy as a program + # constant? + context_position = np.arange(qlen, dtype=jnp.int32)[:, None] + memory_position = np.arange(klen, dtype=jnp.int32)[None, :] + relative_position = memory_position - context_position # shape (qlen, klen) + rp_bucket = self._relative_position_bucket( + relative_position, + bidirectional=bidirectional, + num_buckets=self.num_buckets, + max_distance=self.max_distance) + relative_attention_bias = param_with_axes( + 'rel_embedding', + self.embedding_init, (self.num_heads, self.num_buckets), + jnp.float32, + axes=('heads', 'relpos_buckets')) + + relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype) + # Instead of using a slow gather, we create a leading-dimension one-hot + # array from rp_bucket and use it to perform the gather-equivalent via a + # contraction, i.e.: + # (num_head, num_buckets) x (num_buckets one-hot, qlen, klen). 
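+    # For instance, with num_buckets=32 (the value used by the encoder and
+    # decoder layers in this change) and qlen=klen=5, rp_bucket is (5, 5) and
+    # rp_bucket_one_hot below is (32, 5, 5); the contraction then sums over
+    # the 32 bucket entries for every (q, k) pair.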
+ # This is equivalent to relative_attention_bias[:, rp_bucket] + bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0) + rp_bucket_one_hot = jnp.array( + rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype) + # --> shape (qlen, klen, num_heads) + values = lax.dot_general( + relative_attention_bias, + rp_bucket_one_hot, + ( + ((1,), (0,)), # rhs, lhs contracting dims + ((), ()))) # no batched dims + # Add a singleton batch dimension. + # --> shape (1, num_heads, qlen, klen) + return values[jnp.newaxis, ...] + + +#------------------------------------------------------------------------------ +# T5 Layernorm - no subtraction of mean or bias. +#------------------------------------------------------------------------------ +class LayerNorm(nn.Module): + """T5 Layer normalization operating on the last axis of the input data.""" + epsilon: float = 1e-6 + dtype: Any = jnp.float32 + scale_init: Initializer = nn.initializers.ones + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + """Applies layer normalization on the input.""" + x = jnp.asarray(x, jnp.float32) + features = x.shape[-1] + mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) + y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype) + scale = param_with_axes( + 'scale', self.scale_init, (features,), jnp.float32, axes=('embed',)) + + scale = jnp.asarray(scale, self.dtype) + return y * scale + + +#------------------------------------------------------------------------------ +# Mask-making utility functions. +#------------------------------------------------------------------------------ +def make_attention_mask(query_input: Array, + key_input: Array, + pairwise_fn: Callable = jnp.multiply, + extra_batch_dims: int = 0, + dtype: DType = jnp.float32) -> Array: + """Mask-making helper for attention weights. + + In case of 1d inputs (i.e., `[batch, len_q]`, `[batch, len_kv]`, the + attention weights will be `[batch, heads, len_q, len_kv]` and this + function will produce `[batch, 1, len_q, len_kv]`. + + Args: + query_input: a batched, flat input of query_length size + key_input: a batched, flat input of key_length size + pairwise_fn: broadcasting elementwise comparison function + extra_batch_dims: number of extra batch dims to add singleton axes for, none + by default + dtype: mask return dtype + + Returns: + A `[batch, 1, len_q, len_kv]` shaped mask for 1d attention. + """ + # [batch, len_q, len_kv] + mask = pairwise_fn( + # [batch, len_q] -> [batch, len_q, 1] + jnp.expand_dims(query_input, axis=-1), + # [batch, len_q] -> [batch, 1, len_kv] + jnp.expand_dims(key_input, axis=-2)) + + # [batch, 1, len_q, len_kv]. This creates the head dim. + mask = jnp.expand_dims(mask, axis=-3) + mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims))) + return mask.astype(dtype) + + +def make_causal_mask(x: Array, + extra_batch_dims: int = 0, + dtype: DType = jnp.float32) -> Array: + """Make a causal mask for self-attention. + + In case of 1d inputs (i.e., `[batch, len]`, the self-attention weights + will be `[batch, heads, len, len]` and this function will produce a + causal mask of shape `[batch, 1, len, len]`. + + Note that a causal mask does not depend on the values of x; it only depends on + the shape. If x has padding elements, they will not be treated in a special + manner. 
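+
+  A minimal sketch of the expected behaviour (cf. the unit tests):
+
+    x = jnp.ones((1, 3))
+    make_causal_mask(x)
+    # -> [[[[1., 0., 0.],
+    #       [1., 1., 0.],
+    #       [1., 1., 1.]]]]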
+ + Args: + x: input array of shape `[batch, len]` + extra_batch_dims: number of batch dims to add singleton axes for, none by + default + dtype: mask return dtype + + Returns: + A `[batch, 1, len, len]` shaped causal mask for 1d attention. + """ + idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape) + return make_attention_mask( + idxs, + idxs, + jnp.greater_equal, + extra_batch_dims=extra_batch_dims, + dtype=dtype) + + +def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32): + """Combine attention masks. + + Args: + *masks: set of attention mask arguments to combine, some can be None. + dtype: final mask dtype + + Returns: + Combined mask, reduced by logical and, returns None if no masks given. + """ + masks = [m for m in masks if m is not None] + if not masks: + return None + assert all(map(lambda x: x.ndim == masks[0].ndim, masks)), ( + f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}') + mask, *other_masks = masks + for other_mask in other_masks: + mask = jnp.logical_and(mask, other_mask) + return mask.astype(dtype) + + +def combine_biases(*masks: Optional[Array]): + """Combine attention biases. + + Args: + *masks: set of attention bias arguments to combine, some can be None. + + Returns: + Combined mask, reduced by summation, returns None if no masks given. + """ + masks = [m for m in masks if m is not None] + if not masks: + return None + assert all(map(lambda x: x.ndim == masks[0].ndim, masks)), ( + f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}') + mask, *other_masks = masks + for other_mask in other_masks: + mask = mask + other_mask + return mask + + +def make_decoder_mask(decoder_target_tokens: Array, + dtype: DType, + decoder_causal_attention: Optional[Array] = None, + decoder_segment_ids: Optional[Array] = None) -> Array: + """Compute the self-attention mask for a decoder. + + Decoder mask is formed by combining a causal mask, a padding mask and an + optional packing mask. If decoder_causal_attention is passed, it makes the + masking non-causal for positions that have value of 1. + + A prefix LM is applied to a dataset which has a notion of "inputs" and + "targets", e.g., a machine translation task. The inputs and targets are + concatenated to form a new target. `decoder_target_tokens` is the concatenated + decoder output tokens. + + The "inputs" portion of the concatenated sequence can attend to other "inputs" + tokens even for those at a later time steps. In order to control this + behavior, `decoder_causal_attention` is necessary. This is a binary mask with + a value of 1 indicating that the position belonged to "inputs" portion of the + original dataset. + + Example: + + Suppose we have a dataset with two examples. + + ds = [{"inputs": [6, 7], "targets": [8]}, + {"inputs": [3, 4], "targets": [5]}] + + After the data preprocessing with packing, the two examples are packed into + one example with the following three fields (some fields are skipped for + simplicity). + + decoder_target_tokens = [[6, 7, 8, 3, 4, 5, 0]] + decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]] + decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]] + + where each array has [batch, length] shape with batch size being 1. Then, + this function computes the following mask. + + mask = [[[[1, 1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0]]]] + + mask[b, 1, :, :] represents the mask for the example `b` in the batch. 
+  Because the mask is for a self-attention layer, its last two dimensions form
+  a square of shape [query length, key length].
+
+  mask[b, 1, i, j] = 1 means that the query token at position i can attend to
+  the key token at position j.
+
+  Args:
+    decoder_target_tokens: decoder output tokens. [batch, length]
+    dtype: dtype of the output mask.
+    decoder_causal_attention: a binary mask where a value of 1 marks the
+      "inputs" positions, which may attend to each other bidirectionally;
+      all other positions attend only causally (to earlier positions).
+      [batch, length]
+    decoder_segment_ids: decoder segmentation info for packed examples. [batch,
+      length]
+
+  Returns:
+    the combined decoder mask.
+  """
+  masks = []
+  # The same mask is applied to all attention heads. So the head dimension is 1,
+  # i.e., the mask will be broadcast along the heads dim.
+  # [batch, 1, length, length]
+  causal_mask = make_causal_mask(decoder_target_tokens, dtype=dtype)
+
+  # Positions with value 1 in `decoder_causal_attention` can attend
+  # bidirectionally.
+  if decoder_causal_attention is not None:
+    # [batch, 1, length, length]
+    inputs_mask = make_attention_mask(
+        decoder_causal_attention,
+        decoder_causal_attention,
+        jnp.logical_and,
+        dtype=dtype)
+    masks.append(jnp.logical_or(causal_mask, inputs_mask).astype(dtype))
+  else:
+    masks.append(causal_mask)
+
+  # Padding mask.
+  masks.append(
+      make_attention_mask(
+          decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=dtype))
+
+  # Packing mask
+  if decoder_segment_ids is not None:
+    masks.append(
+        make_attention_mask(
+            decoder_segment_ids, decoder_segment_ids, jnp.equal, dtype=dtype))
+
+  return combine_masks(*masks, dtype=dtype)
diff --git a/t5x/examples/scalable_t5/layers_test.py b/t5x/examples/scalable_t5/layers_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f80f2c782f39bb18bd7147f3f18c276c52ba176
--- /dev/null
+++ b/t5x/examples/scalable_t5/layers_test.py
@@ -0,0 +1,620 @@
+# Copyright 2022 The T5X Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for attention classes."""
+
+import dataclasses
+from typing import Optional
+from unittest import mock
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from flax import linen as nn
+from flax.core import freeze
+from flax.linen import partitioning as nn_partitioning
+import jax
+from jax import random
+from jax.nn import initializers
+import jax.numpy as jnp
+import numpy as np
+from t5x.examples.scalable_t5 import layers
+
+# Parse absl flags test_srcdir and test_tmpdir.
+jax.config.parse_flags_with_absl() + +Array = jnp.ndarray +AxisMetadata = nn_partitioning.AxisMetadata # pylint: disable=invalid-name + + +class SelfAttention(layers.MultiHeadDotProductAttention): + """Self-attention special case of multi-head dot-product attention.""" + + @nn.compact + def __call__(self, + inputs_q: Array, + mask: Optional[Array] = None, + bias: Optional[Array] = None, + deterministic: bool = False): + return super().__call__( + inputs_q, inputs_q, mask, bias, deterministic=deterministic) + + +@dataclasses.dataclass(frozen=True) +class SelfAttentionArgs: + num_heads: int = 1 + batch_size: int = 2 + # qkv_features: int = 3 + head_dim: int = 3 + # out_features: int = 4 + q_len: int = 5 + features: int = 6 + dropout_rate: float = 0.1 + deterministic: bool = False + decode: bool = False + float32_logits: bool = False + + def __post_init__(self): + # If we are doing decoding, the query length should be 1, because are doing + # autoregressive decoding where we feed one position at a time. + assert not self.decode or self.q_len == 1 + + def init_args(self): + return dict( + num_heads=self.num_heads, + head_dim=self.head_dim, + dropout_rate=self.dropout_rate, + float32_logits=self.float32_logits) + + def apply_args(self): + inputs_q = jnp.ones((self.batch_size, self.q_len, self.features)) + mask = jnp.ones((self.batch_size, self.num_heads, self.q_len, self.q_len)) + bias = jnp.ones((self.batch_size, self.num_heads, self.q_len, self.q_len)) + return { + 'inputs_q': inputs_q, + 'mask': mask, + 'bias': bias, + 'deterministic': self.deterministic + } + + +class AttentionTest(parameterized.TestCase): + + def test_dot_product_attention_shape(self): + # This test only checks for shape but tries to make sure all code paths are + # reached. + dropout_rng = random.PRNGKey(0) + batch_size, num_heads, q_len, kv_len, qk_depth, v_depth = 1, 2, 3, 4, 5, 6 + + query = jnp.ones((batch_size, q_len, num_heads, qk_depth)) + key = jnp.ones((batch_size, kv_len, num_heads, qk_depth)) + value = jnp.ones((batch_size, kv_len, num_heads, v_depth)) + bias = jnp.ones((batch_size, num_heads, q_len, kv_len)) + + args = dict( + query=query, + key=key, + value=value, + bias=bias, + dropout_rng=dropout_rng, + dropout_rate=0.5, + deterministic=False, + ) + + output = layers.dot_product_attention(**args) + self.assertEqual(output.shape, (batch_size, q_len, num_heads, v_depth)) + + def test_make_attention_mask_multiply_pairwise_fn(self): + decoder_target_tokens = jnp.array([[7, 0, 0], [8, 5, 0]]) + attention_mask = layers.make_attention_mask( + decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=jnp.int32) + expected0 = jnp.array([[1, 0, 0], [0, 0, 0], [0, 0, 0]]) + expected1 = jnp.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]]) + self.assertEqual(attention_mask.shape, (2, 1, 3, 3)) + np.testing.assert_array_equal(attention_mask[0, 0], expected0) + np.testing.assert_array_equal(attention_mask[1, 0], expected1) + + def test_make_attention_mask_equal_pairwise_fn(self): + segment_ids = jnp.array([[1, 1, 2, 2, 2, 0], [1, 1, 1, 2, 0, 0]]) + attention_mask = layers.make_attention_mask( + segment_ids, segment_ids, pairwise_fn=jnp.equal, dtype=jnp.int32) + # Padding is not treated in a special way. So they need to be zeroed out + # separately. 
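+    # (layers.make_decoder_mask does that zeroing by additionally combining a
+    # padding mask, e.g., with hypothetical `segment_mask`/`ids` arrays:
+    #   layers.combine_masks(
+    #       segment_mask,
+    #       layers.make_attention_mask(ids > 0, ids > 0, dtype=jnp.int32)).)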
+ expected0 = jnp.array([[1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0], [0, 0, 1, 1, 1, 0], + [0, 0, 1, 1, 1, 0], [0, 0, 0, 0, 0, 1]]) + expected1 = jnp.array([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 1, 1]]) + self.assertEqual(attention_mask.shape, (2, 1, 6, 6)) + np.testing.assert_array_equal(attention_mask[0, 0], expected0) + np.testing.assert_array_equal(attention_mask[1, 0], expected1) + + def test_make_causal_mask_with_padding(self): + x = jnp.array([[7, 0, 0], [8, 5, 0]]) + y = layers.make_causal_mask(x) + self.assertEqual(y.shape, (2, 1, 3, 3)) + # Padding is not treated in a special way. So they need to be zeroed out + # separately. + expected_y = jnp.array([[[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]]], + jnp.float32) + np.testing.assert_allclose(y[0], expected_y) + np.testing.assert_allclose(y[1], expected_y) + + def test_make_causal_mask_extra_batch_dims(self): + x = jnp.ones((3, 3, 5)) + y = layers.make_causal_mask(x, extra_batch_dims=2) + self.assertEqual(y.shape, (1, 1, 3, 3, 1, 5, 5)) + + def test_make_causal_mask(self): + x = jnp.ones((1, 3)) + y = layers.make_causal_mask(x) + self.assertEqual(y.shape, (1, 1, 3, 3)) + expected_y = jnp.array([[[[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]]]], + jnp.float32) + np.testing.assert_allclose(y, expected_y) + + def test_combine_masks(self): + masks = [ + jnp.array([0, 1, 0, 1], jnp.float32), None, + jnp.array([1, 1, 1, 1], jnp.float32), + jnp.array([1, 1, 1, 0], jnp.float32) + ] + y = layers.combine_masks(*masks) + np.testing.assert_allclose(y, jnp.array([0, 1, 0, 0], jnp.float32)) + + def test_combine_biases(self): + masks = [ + jnp.array([0, 1, 0, 1], jnp.float32), None, + jnp.array([0, 1, 1, 1], jnp.float32), + jnp.array([0, 1, 1, 0], jnp.float32) + ] + y = layers.combine_biases(*masks) + np.testing.assert_allclose(y, jnp.array([0, 3, 2, 2], jnp.float32)) + + def test_make_decoder_mask_lm_unpacked(self): + decoder_target_tokens = jnp.array([6, 7, 3, 0]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, dtype=jnp.float32) + expected_mask = jnp.array([[[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], + [0, 0, 0, 0]]]) + np.testing.assert_array_equal(mask, expected_mask) + + def test_make_decoder_mask_lm_packed(self): + decoder_target_tokens = jnp.array([[6, 7, 3, 4, 5, 0]]) + decoder_segment_ids = jnp.array([[1, 1, 1, 2, 2, 0]]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=jnp.float32, + decoder_segment_ids=decoder_segment_ids) + expected_mask = jnp.array([[[[1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 0], [0, 0, 0, 0, 0, 0]]]]) + np.testing.assert_array_equal(mask, expected_mask) + + def test_make_decoder_mask_prefix_lm_unpacked(self): + decoder_target_tokens = jnp.array([[5, 6, 7, 3, 4, 0]]) + decoder_causal_attention = jnp.array([[1, 1, 1, 0, 0, 0]]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=jnp.float32, + decoder_causal_attention=decoder_causal_attention) + expected_mask = jnp.array( + [[[[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 0], [0, 0, 0, 0, 0, 0]]]], + dtype=jnp.float32) + np.testing.assert_array_equal(mask, expected_mask) + + def test_make_decoder_mask_prefix_lm_packed(self): + decoder_target_tokens = jnp.array([[5, 6, 7, 8, 3, 4, 0]]) + decoder_segment_ids = jnp.array([[1, 1, 1, 2, 2, 2, 0]]) + 
decoder_causal_attention = jnp.array([[1, 1, 0, 1, 1, 0, 0]]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=jnp.float32, + decoder_causal_attention=decoder_causal_attention, + decoder_segment_ids=decoder_segment_ids) + expected_mask = jnp.array([[[[1, 1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0], [0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 0, 0], [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0]]]]) + np.testing.assert_array_equal(mask, expected_mask) + + def test_make_decoder_mask_prefix_lm_unpacked_multiple_elements(self): + decoder_target_tokens = jnp.array([[6, 7, 3, 0], [4, 5, 0, 0]]) + decoder_causal_attention = jnp.array([[1, 1, 0, 0], [1, 0, 0, 0]]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=jnp.float32, + decoder_causal_attention=decoder_causal_attention) + expected_mask0 = jnp.array([[1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], + [0, 0, 0, 0]]) + expected_mask1 = jnp.array([[1, 0, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0], + [0, 0, 0, 0]]) + self.assertEqual(mask.shape, (2, 1, 4, 4)) + np.testing.assert_array_equal(mask[0, 0], expected_mask0) + np.testing.assert_array_equal(mask[1, 0], expected_mask1) + + def test_make_decoder_mask_composite_causal_attention(self): + decoder_target_tokens = jnp.array([[6, 7, 3, 4, 8, 9, 0]]) + decoder_causal_attention = jnp.array([[1, 1, 0, 0, 1, 1, 0]]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=jnp.float32, + decoder_causal_attention=decoder_causal_attention) + expected_mask0 = jnp.array([[1, 1, 0, 0, 1, 1, 0], [1, 1, 0, 0, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0]]) + + self.assertEqual(mask.shape, (1, 1, 7, 7)) + np.testing.assert_array_equal(mask[0, 0], expected_mask0) + + def test_make_decoder_mask_composite_causal_attention_packed(self): + decoder_target_tokens = jnp.array([[6, 7, 3, 4, 8, 9, 2, 3, 4]]) + decoder_segment_ids = jnp.array([[1, 1, 1, 1, 1, 1, 2, 2, 2]]) + decoder_causal_attention = jnp.array([[1, 1, 0, 0, 1, 1, 1, 1, 0]]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=jnp.float32, + decoder_causal_attention=decoder_causal_attention, + decoder_segment_ids=decoder_segment_ids) + expected_mask0 = jnp.array([[1, 1, 0, 0, 1, 1, 0, 0, 0], + [1, 1, 0, 0, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1]]) + + self.assertEqual(mask.shape, (1, 1, 9, 9)) + np.testing.assert_array_equal(mask[0, 0], expected_mask0) + + @parameterized.parameters({'f': 20}, {'f': 22}) + def test_multihead_dot_product_attention(self, f): + # b: batch, f: emb_dim, q: q_len, k: kv_len, h: num_head, d: head_dim + b, q, h, d, k = 2, 3, 4, 5, 6 + + base_args = SelfAttentionArgs(num_heads=h, head_dim=d, dropout_rate=0) + args = base_args.init_args() + + np.random.seed(0) + inputs_q = np.random.randn(b, q, f) + inputs_kv = np.random.randn(b, k, f) + + # Projection: [b, q, f] -> [b, q, h, d] + # So the kernels have to be [f, h, d] + query_kernel = np.random.randn(f, h, d) + key_kernel = np.random.randn(f, h, d) + value_kernel = np.random.randn(f, h, d) + # `out` calculation: [b, q, h, d] -> [b, q, f] + # So kernel has to be [h, d, f] + out_kernel = np.random.randn(h, d, f) + + params = { + 'query': { + 'kernel': 
query_kernel + }, + 'key': { + 'kernel': key_kernel + }, + 'value': { + 'kernel': value_kernel + }, + 'out': { + 'kernel': out_kernel + } + } + y = layers.MultiHeadDotProductAttention(**args).apply( + {'params': freeze(params)}, inputs_q, inputs_kv) + + query = np.einsum('bqf,fhd->bqhd', inputs_q, query_kernel) + key = np.einsum('bkf,fhd->bkhd', inputs_kv, key_kernel) + value = np.einsum('bkf,fhd->bkhd', inputs_kv, value_kernel) + logits = np.einsum('bqhd,bkhd->bhqk', query, key) + weights = nn.softmax(logits, axis=-1) + combined_value = np.einsum('bhqk,bkhd->bqhd', weights, value) + y_expected = np.einsum('bqhd,hdf->bqf', combined_value, out_kernel) + np.testing.assert_allclose(y, y_expected, rtol=1e-5, atol=1e-5) + + def test_multihead_dot_product_attention_caching(self): + # b: batch, f: qkv_features, k: kv_len, h: num_head, d: head_dim + b, h, d, k = 2, 3, 4, 5 + f = h * d + + base_args = SelfAttentionArgs(num_heads=h, head_dim=d, dropout_rate=0) + args = base_args.init_args() + + cache = { + 'cached_key': np.zeros((b, h, d, k)), + 'cached_value': np.zeros((b, h, d, k)), + 'cache_index': np.array(0) + } + inputs_q = np.random.randn(b, 1, f) + inputs_kv = np.random.randn(b, 1, f) + + # Mock dense general such that q, k, v projections are replaced by simple + # reshaping. + def mock_dense_general(self, x, **kwargs): # pylint: disable=unused-argument + return x.reshape(b, -1, h, d) + + with mock.patch.object( + layers.DenseGeneral, '__call__', new=mock_dense_general): + _, mutated = layers.MultiHeadDotProductAttention(**args).apply( + {'cache': freeze(cache)}, + inputs_q, + inputs_kv, + decode=True, + mutable=['cache']) + updated_cache = mutated['cache'] + + # Perform the same mocked projection to generate the expected cache. + # (key|value): [b, 1, h, d] + key = mock_dense_general(None, inputs_kv) + value = mock_dense_general(None, inputs_kv) + + # cached_(key|value): [b, h, d, k] + cache['cached_key'][:, :, :, 0] = key[:, 0, :, :] + cache['cached_value'][:, :, :, 0] = value[:, 0, :, :] + cache['cache_index'] = np.array(1) + for name, array in cache.items(): + np.testing.assert_allclose(array, updated_cache[name]) + + def test_dot_product_attention(self): + # b: batch, f: emb_dim, q: q_len, k: kv_len, h: num_head, d: head_dim + b, q, h, d, k = 2, 3, 4, 5, 6 + np.random.seed(0) + query = np.random.randn(b, q, h, d) + key = np.random.randn(b, k, h, d) + value = np.random.randn(b, k, h, d) + bias = np.random.randn(b, h, q, k) + attn_out = layers.dot_product_attention(query, key, value, bias=bias) + logits = np.einsum('bqhd,bkhd->bhqk', query, key) + weights = jax.nn.softmax(logits + bias, axis=-1) + expected = np.einsum('bhqk,bkhd->bqhd', weights, value) + np.testing.assert_allclose(attn_out, expected, atol=1e-6) + + +class EmbeddingTest(parameterized.TestCase): + + def test_embedder_raises_exception_for_incorrect_input_type(self): + """Tests that inputs are integers and that an exception is raised if not.""" + embed = layers.Embed(num_embeddings=10, features=5) + inputs = np.expand_dims(np.arange(5, dtype=np.int64), 1) + variables = embed.init(jax.random.PRNGKey(0), inputs) + bad_inputs = inputs.astype(np.float32) + with self.assertRaisesRegex( + ValueError, 'Input type must be an integer or unsigned integer.'): + _ = embed.apply(variables, bad_inputs) + + @parameterized.named_parameters( + { + 'testcase_name': 'with_ones', + 'init_fn': jax.nn.initializers.ones, + 'num_embeddings': 10, + 'features': 5, + 'matrix_sum': 5 * 10, + }, { + 'testcase_name': 'with_zeros', + 'init_fn': 
jax.nn.initializers.zeros, + 'num_embeddings': 10, + 'features': 5, + 'matrix_sum': 0, + }) + def test_embedding_initializes_correctly(self, init_fn, num_embeddings, + features, matrix_sum): + """Tests if the Embed class initializes with the requested initializer.""" + embed = layers.Embed( + num_embeddings=num_embeddings, + features=features, + embedding_init=init_fn) + inputs = np.expand_dims(np.arange(5, dtype=np.int64), 1) + variables = embed.init(jax.random.PRNGKey(0), inputs) + embedding_matrix = variables['params']['embedding'] + self.assertEqual(int(np.sum(embedding_matrix)), matrix_sum) + + def test_embedding_matrix_shape(self): + """Tests that the embedding matrix has the right shape.""" + num_embeddings = 10 + features = 5 + embed = layers.Embed(num_embeddings=num_embeddings, features=features) + inputs = np.expand_dims(np.arange(features, dtype=np.int64), 1) + variables = embed.init(jax.random.PRNGKey(0), inputs) + embedding_matrix = variables['params']['embedding'] + self.assertEqual((num_embeddings, features), embedding_matrix.shape) + + def test_embedding_attend(self): + """Tests that attending with ones returns sum of embedding vectors.""" + features = 5 + embed = layers.Embed(num_embeddings=10, features=features) + inputs = np.array([[1]], dtype=np.int64) + variables = embed.init(jax.random.PRNGKey(0), inputs) + query = np.ones(features, dtype=np.float32) + result = embed.apply(variables, query, method=embed.attend) + expected = np.sum(variables['params']['embedding'], -1) + np.testing.assert_array_almost_equal(result, expected) + + +class DenseTest(parameterized.TestCase): + + def test_dense_general_no_bias(self): + rng = random.PRNGKey(0) + x = jnp.ones((1, 3)) + model = layers.DenseGeneral( + features=4, + kernel_init=lambda k, s, d, ai, ao: initializers.ones(k, s, d), + ) + y, _ = model.init_with_output(rng, x) + self.assertEqual(y.shape, (1, 4)) + np.testing.assert_allclose(y, np.full((1, 4), 3.)) + + def test_dense_general_two_features(self): + rng = random.PRNGKey(0) + x = jnp.ones((1, 3)) + model = layers.DenseGeneral( + features=(2, 2), + kernel_init=lambda k, s, d, ai, ao: initializers.ones(k, s, d), + ) + y, _ = model.init_with_output(rng, x) + # We transform the last input dimension to two output dimensions (2, 2). + np.testing.assert_allclose(y, np.full((1, 2, 2), 3.)) + + def test_dense_general_two_axes(self): + rng = random.PRNGKey(0) + x = jnp.ones((1, 2, 2)) + model = layers.DenseGeneral( + features=3, + axis=(-2, 2), # Note: this is the same as (1, 2). + kernel_init=lambda k, s, d, ai, ao: initializers.ones(k, s, d), + ) + y, _ = model.init_with_output(rng, x) + # We transform the last two input dimensions (2, 2) to one output dimension. + np.testing.assert_allclose(y, np.full((1, 3), 4.)) + + def test_mlp_same_out_dim(self): + module = layers.MlpBlock( + intermediate_dim=4, + activations=('relu',), + kernel_init=layers.nd_dense_init(1.0, 'fan_avg', 'uniform'), + dtype=jnp.float32, + ) + inputs = np.array( + [ + # Batch 1. + [[1, 1], [1, 1], [1, 2]], + # Batch 2. 
+ [[2, 2], [3, 1], [2, 2]], + ], + dtype=np.float32) + params = module.init(random.PRNGKey(0), inputs, deterministic=True) + self.assertEqual( + jax.tree_map(lambda a: a.tolist(), params), { + 'params': { + 'wi': { + 'kernel': [[ + -0.8675811290740967, 0.08417510986328125, + 0.022586345672607422, -0.9124102592468262 + ], + [ + -0.19464373588562012, 0.49809837341308594, + 0.7808468341827393, 0.9267289638519287 + ]], + }, + 'wo': { + 'kernel': [[0.01154780387878418, 0.1397249698638916], + [0.974980354309082, 0.5903260707855225], + [-0.05997943878173828, 0.616570234298706], + [0.2934272289276123, 0.8181164264678955]], + }, + }, + 'params_axes': { + 'wi': { + 'kernel_axes': AxisMetadata(names=('embed', 'mlp')), + }, + 'wo': { + 'kernel_axes': AxisMetadata(names=('mlp', 'embed')), + }, + }, + }) + result = module.apply(params, inputs, deterministic=True) + np.testing.assert_allclose( + result.tolist(), + [[[0.5237172245979309, 0.8508185744285583], + [0.5237172245979309, 0.8508185744285583], + [1.2344461679458618, 2.3844780921936035]], + [[1.0474344491958618, 1.7016371488571167], + [0.6809444427490234, 0.9663378596305847], + [1.0474344491958618, 1.7016371488571167]]], + rtol=1e-6, + ) + + +class RelativePositionBiasesTest(absltest.TestCase): + + def setUp(self): + self.num_heads = 3 + self.query_len = 5 + self.key_len = 7 + self.relative_attention = layers.RelativePositionBiases( + num_buckets=12, + max_distance=10, + num_heads=3, + dtype=jnp.float32, + ) + super(RelativePositionBiasesTest, self).setUp() + + def test_relative_attention_bidirectional_params(self): + """Tests that bidirectional relative position biases have expected params.""" + params = self.relative_attention.init( + random.PRNGKey(0), self.query_len, self.key_len, bidirectional=True) + param_shapes = jax.tree_map(lambda x: x.shape, params) + self.assertEqual( + param_shapes, { + 'params': { + 'rel_embedding': (3, 12), + }, + 'params_axes': { + 'rel_embedding_axes': + AxisMetadata(names=('heads', 'relpos_buckets')), + } + }) + + def test_regression_relative_attention_bidirectional_values(self): + """Tests that bidirectional relative position biases match expected values. + + See top docstring note on matching T5X behavior for these regression tests. + """ + outputs, unused_params = self.relative_attention.init_with_output( + random.PRNGKey(0), self.query_len, self.key_len, bidirectional=True) + self.assertEqual(outputs.shape, + (1, self.num_heads, self.query_len, self.key_len)) + self.assertAlmostEqual(outputs[0, 0, 0, 0], 0.55764728, places=5) + self.assertAlmostEqual(outputs[0, 1, 2, 1], -0.10935841, places=5) + self.assertAlmostEqual(outputs[0, 1, 4, 6], 0.14510104, places=5) + self.assertAlmostEqual(outputs[0, 2, 4, 6], -0.36783996, places=5) + + def test_relative_attention_unidirectional_params(self): + """Tests that unidirectional relative position biases have expected params.""" + params = self.relative_attention.init( + random.PRNGKey(0), self.query_len, self.key_len, bidirectional=False) + param_shapes = jax.tree_map(lambda x: x.shape, params) + self.assertEqual( + param_shapes, { + 'params': { + 'rel_embedding': (3, 12), + }, + 'params_axes': { + 'rel_embedding_axes': + AxisMetadata(names=('heads', 'relpos_buckets')), + } + }) + + def test_regression_relative_attention_unidirectional_values(self): + """Tests that unidirectional relative position biases match expected values. + + See top docstring note on matching T5X behavior for these regression tests. 
+ """ + outputs, unused_params = self.relative_attention.init_with_output( + random.PRNGKey(0), self.query_len, self.key_len, bidirectional=False) + self.assertEqual(outputs.shape, + (1, self.num_heads, self.query_len, self.key_len)) + self.assertAlmostEqual(outputs[0, 0, 0, 0], 0.55764728, places=5) + self.assertAlmostEqual(outputs[0, 1, 2, 1], -0.10935841, places=5) + self.assertAlmostEqual(outputs[0, 1, 4, 6], -0.13101986, places=5) + self.assertAlmostEqual(outputs[0, 2, 4, 6], 0.39296466, places=5) + + +if __name__ == '__main__': + absltest.main() diff --git a/t5x/examples/scalable_t5/local_tiny.gin b/t5x/examples/scalable_t5/local_tiny.gin new file mode 100644 index 0000000000000000000000000000000000000000..3d7b28429a920a5a595afdf897f9525a8f9e1487 --- /dev/null +++ b/t5x/examples/scalable_t5/local_tiny.gin @@ -0,0 +1,70 @@ +# A gin file to make the Transformer models tiny for faster local testing. +# +# When testing locally with CPU, there are a few things that we need. +# - tiny model size +# - small enough batch size +# - small sequence length +# - determinstic dataset pipeline +# +# This gin file adds such configs. To use this gin file, add it on top of the +# existing full-scale gin files. The ordering of the gin file matters. So this +# should be added after all the other files are added to override the same +# configurables. + +from __gin__ import dynamic_registration + +from t5x import partitioning +from t5x import trainer +from t5x import utils +from t5x.examples.t5 import network + +import __main__ as train_script + +train_script.train.random_seed = 42 # dropout seed +train/utils.DatasetConfig.seed = 42 # dataset seed + +TASK_FEATURE_LENGTHS = {"inputs": 8, "targets": 7} +LABEL_SMOOTHING = 0.0 + +# Network specification overrides +network.Transformer.config = @network.T5Config() +network.T5Config: + vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency + dtype = 'bfloat16' + emb_dim = 8 + num_heads = 4 + num_encoder_layers = 2 + num_decoder_layers = 2 + head_dim = 3 + mlp_dim = 16 + mlp_activations = ('gelu', 'linear') + dropout_rate = 0.0 + logits_via_embedding = False + scan_layers = True + remat_policy = 'minimal' + +TRAIN_STEPS = 3 + +train/utils.DatasetConfig: + batch_size = 8 + shuffle = False + +train_eval/utils.DatasetConfig.batch_size = 8 + +train_script.train: + eval_period = 3 + eval_steps = 3 + +trainer.Trainer.num_microbatches = 0 +partitioning.PjitPartitioner: + num_partitions = 1 + model_parallel_submesh = None + +utils.CheckpointConfig: + restore = None + +infer_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS + + +# DISABLE INFERENCE EVAL +# train_script.train.infer_eval_dataset_cfg = None diff --git a/t5x/examples/scalable_t5/network.py b/t5x/examples/scalable_t5/network.py new file mode 100644 index 0000000000000000000000000000000000000000..a15cb0462e13234b7d5c44530589b1bd030774a8 --- /dev/null +++ b/t5x/examples/scalable_t5/network.py @@ -0,0 +1,520 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""T5.1.1 Transformer model.""" + +from typing import Any, Sequence + +from flax import linen as nn +from flax import struct +from flax.linen import partitioning as nn_partitioning +import jax +import jax.numpy as jnp +from t5x.examples.scalable_t5 import layers + + +with_sharding_constraint = nn_partitioning.with_sharding_constraint +scan_with_axes = nn_partitioning.scan_with_axes +remat = nn_partitioning.remat +ScanIn = nn_partitioning.ScanIn + + +@struct.dataclass +class T5Config: + """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + vocab_size: int + # Activation dtypes. + dtype: Any = jnp.float32 + emb_dim: int = 512 + num_heads: int = 8 + num_encoder_layers: int = 6 + num_decoder_layers: int = 6 + head_dim: int = 64 + mlp_dim: int = 2048 + # Activation functions are retrieved from Flax. + mlp_activations: Sequence[str] = ('relu',) + dropout_rate: float = 0.1 + # If `True`, the embedding weights are used in the decoder output layer. + logits_via_embedding: bool = False + # minimal, full, or none + remat_policy: str = 'none' + scan_layers: bool = True + param_scan_axis: int = 1 + + +class EncoderLayer(nn.Module): + """Transformer encoder layer.""" + config: T5Config + + @nn.compact + def __call__(self, inputs, encoder_mask=None, deterministic=False): + cfg = self.config + + # Relative position embedding as attention biases. + encoder_bias = layers.RelativePositionBiases( + num_buckets=32, + max_distance=128, + num_heads=cfg.num_heads, + dtype=cfg.dtype, + embedding_init=nn.initializers.variance_scaling( + 1.0, 'fan_avg', 'uniform'), + name='relative_posemb')(inputs.shape[-2], inputs.shape[-2], True) + + # Attention block. + assert inputs.ndim == 3 + inputs = with_sharding_constraint(inputs, ('batch', 'length', 'embed')) + x = layers.LayerNorm( + dtype=cfg.dtype, name='pre_attention_layer_norm')( + inputs) + x = with_sharding_constraint(x, ('batch', 'length', 'embed')) + # [batch, length, emb_dim] -> [batch, length, emb_dim] + x = layers.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + head_dim=cfg.head_dim, + dropout_rate=cfg.dropout_rate, + name='attention')( + x, x, encoder_mask, encoder_bias, deterministic=deterministic) + x = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + x, deterministic=deterministic) + x = x + inputs + x = with_sharding_constraint(x, ('batch', 'length', 'embed')) + + # MLP block. + y = layers.LayerNorm(dtype=cfg.dtype, name='pre_mlp_layer_norm')(x) + y = with_sharding_constraint(y, ('batch', 'length', 'embed')) + # [batch, length, emb_dim] -> [batch, length, emb_dim] + y = layers.MlpBlock( + intermediate_dim=cfg.mlp_dim, + activations=cfg.mlp_activations, + intermediate_dropout_rate=cfg.dropout_rate, + dtype=cfg.dtype, + name='mlp', + )(y, deterministic=deterministic) + y = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + y, deterministic=deterministic) + y = y + x + y = with_sharding_constraint(y, ('batch', 'length', 'embed')) + + if cfg.scan_layers: + return y, None + else: + return y + + +class DecoderLayer(nn.Module): + """Transformer decoder layer that attends to the encoder.""" + config: T5Config + + @nn.compact + def __call__(self, + inputs, + encoded, + decoder_mask=None, + encoder_decoder_mask=None, + deterministic=False, + decode=False, + max_decode_length=None): + cfg = self.config + + # Relative position embedding as attention biases. 
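+    # During incremental decoding the layer is fed a single position at a
+    # time, so the bias is materialized for the full `max_decode_length`
+    # (e.g. a (1, num_heads, 32, 32) table for an assumed max_decode_length
+    # of 32) and the row for the current step is sliced out inside
+    # MultiHeadDotProductAttention.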
+ l = max_decode_length if decode and max_decode_length else inputs.shape[-2] + decoder_bias = layers.RelativePositionBiases( + num_buckets=32, + max_distance=128, + num_heads=cfg.num_heads, + dtype=cfg.dtype, + embedding_init=nn.initializers.variance_scaling( + 1.0, 'fan_avg', 'uniform'), + name='relative_posemb')(l, l, False) + + inputs = with_sharding_constraint(inputs, ('batch', 'length', 'embed')) + + # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] + x = layers.LayerNorm( + dtype=cfg.dtype, name='pre_self_attention_layer_norm')( + inputs) + x = with_sharding_constraint(x, ('batch', 'length', 'embed')) + + # Self-attention block + x = layers.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + head_dim=cfg.head_dim, + dropout_rate=cfg.dropout_rate, + name='self_attention')( + x, + x, + decoder_mask, + decoder_bias, + deterministic=deterministic, + decode=decode) + x = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + x, deterministic=deterministic) + x = x + inputs + x = with_sharding_constraint(x, ('batch', 'length', 'embed')) + + # Encoder-Decoder block. + y = layers.LayerNorm( + dtype=cfg.dtype, name='pre_cross_attention_layer_norm')( + x) + y = with_sharding_constraint(y, ('batch', 'length', 'embed')) + y = layers.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + head_dim=cfg.head_dim, + dropout_rate=cfg.dropout_rate, + name='encoder_decoder_attention')( + y, encoded, encoder_decoder_mask, deterministic=deterministic) + y = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + y, deterministic=deterministic) + y = y + x + y = with_sharding_constraint(y, ('batch', 'length', 'embed')) + + # MLP block. + z = layers.LayerNorm(dtype=cfg.dtype, name='pre_mlp_layer_norm')(y) + z = with_sharding_constraint(z, ('batch', 'length', 'embed')) + z = layers.MlpBlock( + intermediate_dim=cfg.mlp_dim, + activations=cfg.mlp_activations, + intermediate_dropout_rate=cfg.dropout_rate, + dtype=cfg.dtype, + name='mlp', + )(z, deterministic=deterministic) + z = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + z, deterministic=deterministic) + z = z + y + z = with_sharding_constraint(z, ('batch', 'length', 'embed')) + + if cfg.scan_layers: + return z, None + else: + return z + + +class Encoder(nn.Module): + """A stack of encoder layers.""" + config: T5Config + shared_embedding: nn.Module + + @nn.compact + def __call__(self, + encoder_input_tokens, + encoder_mask=None, + deterministic=False): + cfg = self.config + assert encoder_input_tokens.ndim == 2 # [batch, length] + + # [batch, length] -> [batch, length, emb_dim] + x = self.shared_embedding(encoder_input_tokens.astype('int32')) + x = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + x, deterministic=deterministic) + x = x.astype(cfg.dtype) + + BlockLayer = EncoderLayer + + if cfg.remat_policy not in (None, 'none'): + if cfg.remat_policy == 'minimal': + policy = jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + else: + policy = None + BlockLayer = remat( # pylint: disable=invalid-name + BlockLayer, + prevent_cse=not cfg.scan_layers, + policy=policy, + static_argnums=(2,)) + + if cfg.scan_layers: + initializing = self.is_mutable_collection('params') + params_spec = ( + cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis)) + cache_spec = 0 + x, _ = scan_with_axes( + BlockLayer, + variable_axes={ + 'params': params_spec, + 'cache': cache_spec, + }, + split_rngs={ + 'params': True, + 'dropout': True + }, + 
in_axes=(nn.broadcast, nn.broadcast), + length=cfg.num_encoder_layers, + axis_name='layers')( + config=cfg, name='layers')(x, encoder_mask, deterministic) + else: + for lyr in range(cfg.num_encoder_layers): + # [batch, length, emb_dim] -> [batch, length, emb_dim] + x = BlockLayer( + config=cfg, name=f'layers_{lyr}')(x, encoder_mask, deterministic) + + x = layers.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x) + return nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic) + + +class Decoder(nn.Module): + """A stack of decoder layers as a part of an encoder-decoder architecture.""" + config: T5Config + shared_embedding: nn.Module + + @nn.compact + def __call__(self, + encoded, + decoder_input_tokens, + decoder_positions=None, + decoder_mask=None, + encoder_decoder_mask=None, + deterministic=False, + decode=False, + max_decode_length=None): + cfg = self.config + assert decoder_input_tokens.ndim == 2 # [batch, len] + + # [batch, length] -> [batch, length, emb_dim] + y = self.shared_embedding(decoder_input_tokens.astype('int32')) + y = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + y, deterministic=deterministic) + y = y.astype(cfg.dtype) + + BlockLayer = DecoderLayer + + if cfg.remat_policy not in (None, 'none'): + if cfg.remat_policy == 'minimal': + policy = jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + else: + policy = None + BlockLayer = remat( # pylint: disable=invalid-name + BlockLayer, + prevent_cse=not cfg.scan_layers, + policy=policy, + static_argnums=(4, 5, 6)) + if cfg.scan_layers: + initializing = self.is_mutable_collection('params') + params_spec = ( + cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis)) + cache_spec = 0 + y, _ = scan_with_axes( + BlockLayer, + variable_axes={ + 'params': params_spec, + 'cache': cache_spec + }, + split_rngs={ + 'params': True, + 'dropout': True + }, + in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, + nn.broadcast, nn.broadcast), + length=cfg.num_decoder_layers, + axis_name='layers')( + config=cfg, name='layers')( + y, encoded, decoder_mask, encoder_decoder_mask, + deterministic, decode, max_decode_length) + else: + for lyr in range(cfg.num_decoder_layers): + # [batch, length, emb_dim] -> [batch, length, emb_dim] + y = BlockLayer( + config=cfg, name=f'layers_{lyr}')( + y, + encoded, + decoder_mask=decoder_mask, + encoder_decoder_mask=encoder_decoder_mask, + deterministic=deterministic, + decode=decode, + max_decode_length=max_decode_length) + + y = layers.LayerNorm(dtype=cfg.dtype, name='decoder_norm')(y) + y = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + y, deterministic=deterministic) + + # [batch, length, emb_dim] -> [batch, length, vocab_size] + if cfg.logits_via_embedding: + # Use the transpose of embedding matrix for logit transform. + logits = self.shared_embedding.attend(y) + # Correctly normalize pre-softmax logits for this shared case. + logits = logits / jnp.sqrt(y.shape[-1]) + else: + logits = layers.DenseGeneral( + cfg.vocab_size, + dtype=jnp.float32, # Use float32 for stabiliity. + kernel_axes=('embed', 'vocab'), + name='logits_dense')( + y) + return logits + + +class Transformer(nn.Module): + """An encoder-decoder Transformer model.""" + config: T5Config + # needed only for janky models.py scan_layers detection. + scan_layers: bool = struct.field(init=False) + + def __post_init__(self): + super().__post_init__() + # needed only for janky models.py scan_layers detection. 
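+    # Attribute assignment on nn.Module subclasses is restricted outside of
+    # setup, so the derived field is written via object.__setattr__.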
+ object.__setattr__(self, 'scan_layers', + object.__getattribute__(self, 'config').scan_layers) + + def setup(self): + cfg = self.config + self.shared_embedding = layers.Embed( + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, + dtype=cfg.dtype, + attend_dtype=jnp.float32, # for logit training stability + embedding_init=nn.initializers.normal(stddev=1.0), + one_hot=True, + name='token_embedder') + + self.encoder = Encoder(config=cfg, shared_embedding=self.shared_embedding) + self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding) + + def encode(self, + encoder_input_tokens, + encoder_segment_ids=None, + enable_dropout=True): + """Applies Transformer encoder-branch on the inputs.""" + cfg = self.config + assert encoder_input_tokens.ndim == 2 # (batch, len) + + # Make padding attention mask. + encoder_mask = layers.make_attention_mask( + encoder_input_tokens > 0, encoder_input_tokens > 0, dtype=cfg.dtype) + # Add segmentation block-diagonal attention mask if using segmented data. + if encoder_segment_ids is not None: + encoder_mask = layers.combine_masks( + encoder_mask, + layers.make_attention_mask( + encoder_segment_ids, + encoder_segment_ids, + jnp.equal, + dtype=cfg.dtype)) + + return self.encoder( + encoder_input_tokens, encoder_mask, deterministic=not enable_dropout) + + def decode( + self, + encoded, + encoder_input_tokens, # only needed for masks + decoder_input_tokens, + decoder_target_tokens, + encoder_segment_ids=None, + decoder_segment_ids=None, + decoder_positions=None, + enable_dropout=True, + decode=False, + max_decode_length=None): + """Applies Transformer decoder-branch on encoded-input and target.""" + cfg = self.config + + # Make padding attention masks. + if decode: + # Do not mask decoder attention based on targets padding at + # decoding/inference time. + decoder_mask = None + encoder_decoder_mask = layers.make_attention_mask( + jnp.ones_like(decoder_target_tokens), + encoder_input_tokens > 0, + dtype=cfg.dtype) + else: + decoder_mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=cfg.dtype, + decoder_segment_ids=decoder_segment_ids) + encoder_decoder_mask = layers.make_attention_mask( + decoder_target_tokens > 0, encoder_input_tokens > 0, dtype=cfg.dtype) + + # Add segmentation block-diagonal attention masks if using segmented data. + if encoder_segment_ids is not None: + if decode: + raise ValueError( + 'During decoding, packing should not be used but ' + '`encoder_segment_ids` was passed to `Transformer.decode`.') + + encoder_decoder_mask = layers.combine_masks( + encoder_decoder_mask, + layers.make_attention_mask( + decoder_segment_ids, + encoder_segment_ids, + jnp.equal, + dtype=cfg.dtype)) + + logits = self.decoder( + encoded, + decoder_input_tokens=decoder_input_tokens, + decoder_positions=decoder_positions, + decoder_mask=decoder_mask, + encoder_decoder_mask=encoder_decoder_mask, + deterministic=not enable_dropout, + decode=decode, + max_decode_length=max_decode_length) + return logits + + def __call__(self, + encoder_input_tokens, + decoder_input_tokens, + decoder_target_tokens, + encoder_segment_ids=None, + decoder_segment_ids=None, + encoder_positions=None, + decoder_positions=None, + *, + enable_dropout: bool = True, + decode: bool = False): + """Applies Transformer model on the inputs. + + This method requires both decoder_target_tokens and decoder_input_tokens, + which is a shifted version of the former. For a packed dataset, it usually + has additional processing applied. 
For example, the first element of each + sequence has id 0 instead of the shifted EOS id from the previous sequence. + + Args: + encoder_input_tokens: input data to the encoder. + decoder_input_tokens: input token to the decoder. + decoder_target_tokens: target token to the decoder. + encoder_segment_ids: encoder segmentation info for packed examples. + decoder_segment_ids: decoder segmentation info for packed examples. + encoder_positions: encoder subsequence positions for packed examples. + decoder_positions: decoder subsequence positions for packed examples. + enable_dropout: Ensables dropout if set to True. + decode: Whether to prepare and use an autoregressive cache. + + Returns: + logits array from full transformer. + """ + encoded = self.encode( + encoder_input_tokens, + encoder_segment_ids=encoder_segment_ids, + enable_dropout=enable_dropout) + + return self.decode( + encoded, + encoder_input_tokens, # only used for masks + decoder_input_tokens, + decoder_target_tokens, + encoder_segment_ids=encoder_segment_ids, + decoder_segment_ids=decoder_segment_ids, + decoder_positions=decoder_positions, + enable_dropout=enable_dropout, + decode=decode) diff --git a/t5x/examples/scalable_t5/t5_1_1/__init__.py b/t5x/examples/scalable_t5/t5_1_1/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da022c16301721a096a208e8bdb2a71bb87f9788 --- /dev/null +++ b/t5x/examples/scalable_t5/t5_1_1/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This empty file is needed for loading the gin files in this directory. diff --git a/t5x/examples/scalable_t5/t5_1_1/base.gin b/t5x/examples/scalable_t5/t5_1_1/base.gin new file mode 100644 index 0000000000000000000000000000000000000000..ebab93a6792375c3a58daff9cb0a27deff4ea1bb --- /dev/null +++ b/t5x/examples/scalable_t5/t5_1_1/base.gin @@ -0,0 +1,57 @@ +# T5.1.1 Base model. +from __gin__ import dynamic_registration + +import seqio +from t5x import adafactor +from t5x import models +from t5x.examples.scalable_t5 import network + +# ------------------- Loss HParam ---------------------------------------------- +Z_LOSS = 0.0001 +LABEL_SMOOTHING = 0.0 +# NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF) +# the loss normalizing factor should be set to pretraining batch_size * +# target_token_length. +LOSS_NORMALIZING_FACTOR = None +# Dropout should be specified in the "run" files +DROPOUT_RATE = %gin.REQUIRED + +# Vocabulary (shared by encoder and decoder) +VOCABULARY = @seqio.SentencePieceVocabulary() +seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" + +# ------------------- Optimizer ------------------------------------------------ +# `learning_rate` is set by `Trainer.learning_rate_fn`. 
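# Rough Python equivalent of the optimizer wiring below (a sketch only; the
# keyword names mirror the gin bindings given here):
#   optimizer_def = adafactor.Adafactor(
#       decay_rate=0.8, step_offset=0,
#       logical_factor_rules=adafactor.standard_logical_factor_rules())
# In gin, `@foo()` yields a configured instance and `%NAME` expands a macro.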
+OPTIMIZER = @adafactor.Adafactor() +adafactor.Adafactor: + decay_rate = 0.8 + step_offset = 0 + logical_factor_rules = @adafactor.standard_logical_factor_rules() + +# ------------------- Model ---------------------------------------------------- +MODEL = @models.EncoderDecoderModel() +models.EncoderDecoderModel: + module = @network.Transformer() + input_vocabulary = %VOCABULARY + output_vocabulary = %VOCABULARY + optimizer_def = %OPTIMIZER + z_loss = %Z_LOSS + label_smoothing = %LABEL_SMOOTHING + loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR + +# ------------------- Network specification ------------------------------------ +network.Transformer.config = @network.T5Config() +network.T5Config: + vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency + dtype = 'bfloat16' + emb_dim = 768 + num_heads = 12 + num_encoder_layers = 12 + num_decoder_layers = 12 + head_dim = 64 + mlp_dim = 2048 + mlp_activations = ('gelu', 'linear') + dropout_rate = %DROPOUT_RATE + logits_via_embedding = False + scan_layers = True + remat_policy = 'minimal' diff --git a/t5x/examples/scalable_t5/t5_1_1/examples/__init__.py b/t5x/examples/scalable_t5/t5_1_1/examples/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da022c16301721a096a208e8bdb2a71bb87f9788 --- /dev/null +++ b/t5x/examples/scalable_t5/t5_1_1/examples/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This empty file is needed for loading the gin files in this directory. diff --git a/t5x/examples/scalable_t5/t5_1_1/examples/wmt19_ende_from_scratch.gin b/t5x/examples/scalable_t5/t5_1_1/examples/wmt19_ende_from_scratch.gin new file mode 100644 index 0000000000000000000000000000000000000000..1d75be863781c66a324b47c311ac8ac04c205da7 --- /dev/null +++ b/t5x/examples/scalable_t5/t5_1_1/examples/wmt19_ende_from_scratch.gin @@ -0,0 +1,62 @@ +from __gin__ import dynamic_registration + +import __main__ as train_script +from t5x import adafactor +from t5x import models +from t5x import partitioning +from t5x import trainer +from t5x import utils +from t5x.examples.scalable_t5 import network + +include "t5x/examples/scalable_t5/t5_1_1/base.gin" +include "t5x/configs/runs/finetune.gin" + +MIXTURE_OR_TASK_NAME = "wmt19_ende_v003" +MIXTURE_OR_TASK_MODULE = "t5.data.mixtures" +TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512} +TRAIN_STEPS = 25000 +LABEL_SMOOTHING = 0.1 +INITIAL_CHECKPOINT_PATH = None +# Note that `DROPOUT_RATE = 0.1` is specified in the finetune.gin but we just +# repeat to make it explicit. 
+DROPOUT_RATE = 0.1 + +train/utils.DatasetConfig: + batch_size = 128 + use_cached = False + pack = True + use_custom_packing_ops = False + seed = 0 + +train_eval/utils.DatasetConfig: + batch_size = 128 + use_cached = False + pack = False + use_custom_packing_ops = False + seed = 0 + +infer_eval/utils.DatasetConfig: + use_cached = False + +train_script.train: + eval_period = 250 + eval_steps = 20 + random_seed = 0 + use_hardware_rng = True + +utils.CheckpointConfig.restore = None +utils.SaveCheckpointConfig: + period = 500 # checkpoint frequency + keep = 1 + +# Decoder overrides +models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4 + +trainer.Trainer.num_microbatches = 2 +utils.create_learning_rate_scheduler.warmup_steps = 1000 + +partitioning.PjitPartitioner: + model_parallel_submesh = (1, 1, 1, 2) + +adafactor.Adafactor: + logical_factor_rules = @adafactor.standard_logical_factor_rules() diff --git a/t5x/examples/scalable_t5/t5_1_1/large.gin b/t5x/examples/scalable_t5/t5_1_1/large.gin new file mode 100644 index 0000000000000000000000000000000000000000..b01f319967d7a1c39f58beaa45f012e2b65de9db --- /dev/null +++ b/t5x/examples/scalable_t5/t5_1_1/large.gin @@ -0,0 +1,13 @@ +# T5.1.1 Large model. + +include 't5x/examples/scalable_t5/t5_1_1/base.gin' # imports vocab, optimizer and model. + +# ------------------- Network specification overrides -------------------------- +network.Transformer.config = @network.T5Config() +network.T5Config: + emb_dim = 1024 + num_heads = 16 + num_encoder_layers = 24 + num_decoder_layers = 24 + head_dim = 64 + mlp_dim = 2816 diff --git a/t5x/examples/scalable_t5/t5_1_1/small.gin b/t5x/examples/scalable_t5/t5_1_1/small.gin new file mode 100644 index 0000000000000000000000000000000000000000..d1a8005c66994f2951c67fa65c91c1a77d86e576 --- /dev/null +++ b/t5x/examples/scalable_t5/t5_1_1/small.gin @@ -0,0 +1,13 @@ +# T5.1.1 Small model. + +include 't5x/examples/scalable_t5/t5_1_1/base.gin' # imports vocab, optimizer and model. + +# ------------------- Network specification overrides -------------------------- +network.Transformer.config = @network.T5Config() +network.T5Config: + emb_dim = 512 + num_heads = 6 + num_encoder_layers = 8 + num_decoder_layers = 8 + head_dim = 64 + mlp_dim = 1024 diff --git a/t5x/examples/scalable_t5/t5_1_1/xl.gin b/t5x/examples/scalable_t5/t5_1_1/xl.gin new file mode 100644 index 0000000000000000000000000000000000000000..d8d98b4d55eee083b17852042296233bf8c6bbc5 --- /dev/null +++ b/t5x/examples/scalable_t5/t5_1_1/xl.gin @@ -0,0 +1,13 @@ +# T5.1.1 XL model. + +include 't5x/examples/scalable_t5/t5_1_1/base.gin' # imports vocab, optimizer and model. + +# ------------------- Network specification overrides -------------------------- +network.Transformer.config = @network.T5Config() +network.T5Config: + emb_dim = 2048 + num_heads = 32 + num_encoder_layers = 24 + num_decoder_layers = 24 + head_dim = 64 + mlp_dim = 5120 diff --git a/t5x/examples/scalable_t5/t5_1_1/xxl.gin b/t5x/examples/scalable_t5/t5_1_1/xxl.gin new file mode 100644 index 0000000000000000000000000000000000000000..8ed37fe9209349dca2fe146675d05fb4a1f8eb8e --- /dev/null +++ b/t5x/examples/scalable_t5/t5_1_1/xxl.gin @@ -0,0 +1,13 @@ +# T5.1.1 XXL model. + +include 't5x/examples/scalable_t5/t5_1_1/base.gin' # imports vocab, optimizer and model. 
+ +# ------------------- Network specification overrides -------------------------- +network.Transformer.config = @network.T5Config() +network.T5Config: + emb_dim = 4096 + num_heads = 64 + num_encoder_layers = 24 + num_decoder_layers = 24 + head_dim = 64 + mlp_dim = 10240 diff --git a/t5x/examples/t5/README.md b/t5x/examples/t5/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bcabd31410b413909d05e5f0d8bd5f26d020e29c --- /dev/null +++ b/t5x/examples/t5/README.md @@ -0,0 +1,6 @@ +This directory contains model implementations for the T5-variants (T5.1.1, +T5.1.0, mT5, ByT5). All variants share the neural network implementation in +`network.py`, which has a minimal set of configurables in `TransformerConfig`. + +Refer to the [main +README](https://github.com/google-research/t5x/blob/main/README.md) for the example usages. diff --git a/t5x/examples/t5/__init__.py b/t5x/examples/t5/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da022c16301721a096a208e8bdb2a71bb87f9788 --- /dev/null +++ b/t5x/examples/t5/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This empty file is needed for loading the gin files in this directory. diff --git a/t5x/examples/t5/byt5/__init__.py b/t5x/examples/t5/byt5/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da022c16301721a096a208e8bdb2a71bb87f9788 --- /dev/null +++ b/t5x/examples/t5/byt5/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This empty file is needed for loading the gin files in this directory. diff --git a/t5x/examples/t5/byt5/base.gin b/t5x/examples/t5/byt5/base.gin new file mode 100644 index 0000000000000000000000000000000000000000..4e2122392c0c0aa57682bae062b1395ef451349d --- /dev/null +++ b/t5x/examples/t5/byt5/base.gin @@ -0,0 +1,54 @@ +# ByT5 Base model. +from __gin__ import dynamic_registration + +import seqio +from t5x import adafactor +from t5x import models +from t5x.examples.t5 import network + +# ------------------- Loss HParam ---------------------------------------------- +Z_LOSS = 0.0001 +LABEL_SMOOTHING = 0.0 +# NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF) +# the loss normalizing factor should be set to pretraining batch_size * +# target_token_length. 
+LOSS_NORMALIZING_FACTOR = None +# Dropout should be specified in the "run" files +DROPOUT_RATE = %gin.REQUIRED + +# Vocabulary (shared by encoder and decoder) +VOCABULARY = @seqio.ByteVocabulary() + +# ------------------- Optimizer ------------------------------------------------ +# `learning_rate` is set by `Trainer.learning_rate_fn`. +OPTIMIZER = @adafactor.Adafactor() +adafactor.Adafactor: + decay_rate = 0.8 + step_offset = 0 + logical_factor_rules = @adafactor.standard_logical_factor_rules() + +# ------------------- Model ---------------------------------------------------- +MODEL = @models.EncoderDecoderModel() +models.EncoderDecoderModel: + module = @network.Transformer() + input_vocabulary = %VOCABULARY + output_vocabulary = %VOCABULARY + optimizer_def = %OPTIMIZER + z_loss = %Z_LOSS + label_smoothing = %LABEL_SMOOTHING + loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR + +# ------------------- Network specification ------------------------------------ +network.Transformer.config = @network.T5Config() +network.T5Config: + vocab_size = 384 # vocab size rounded to a multiple of 128 for TPU efficiency + dtype = 'bfloat16' + emb_dim = 1536 + num_heads = 12 + num_encoder_layers = 18 + num_decoder_layers = 6 + head_dim = 64 + mlp_dim = 3968 + mlp_activations = ('gelu', 'linear') + dropout_rate = %DROPOUT_RATE + logits_via_embedding = False diff --git a/t5x/examples/t5/byt5/large.gin b/t5x/examples/t5/byt5/large.gin new file mode 100644 index 0000000000000000000000000000000000000000..d4b8aaa3b42103877eb0bdb0fdb97d1b87c6f47a --- /dev/null +++ b/t5x/examples/t5/byt5/large.gin @@ -0,0 +1,13 @@ +# ByT5 Large model. + +include 't5x/examples/t5/byt5/base.gin' # imports vocab, optimizer and model. + +# ------------------- Network specification overrides -------------------------- +network.Transformer.config = @network.T5Config() +network.T5Config: + emb_dim = 1536 + num_heads = 16 + num_encoder_layers = 36 + num_decoder_layers = 12 + head_dim = 64 + mlp_dim = 3840 diff --git a/t5x/examples/t5/byt5/small.gin b/t5x/examples/t5/byt5/small.gin new file mode 100644 index 0000000000000000000000000000000000000000..11eeff1ab8a9caeca663432651ff2fdb9c8de7a4 --- /dev/null +++ b/t5x/examples/t5/byt5/small.gin @@ -0,0 +1,13 @@ +# ByT5 Small model. + +include 't5x/examples/t5/byt5/base.gin' # imports vocab, optimizer and model. + +# ------------------- Network specification overrides -------------------------- +network.Transformer.config = @network.T5Config() +network.T5Config: + emb_dim = 1472 + num_heads = 6 + num_encoder_layers = 12 + num_decoder_layers = 4 + head_dim = 64 + mlp_dim = 3584 diff --git a/t5x/examples/t5/byt5/tiny.gin b/t5x/examples/t5/byt5/tiny.gin new file mode 100644 index 0000000000000000000000000000000000000000..ed83eecd0b229ffd8b50561241e268d9cfc3ecfb --- /dev/null +++ b/t5x/examples/t5/byt5/tiny.gin @@ -0,0 +1,13 @@ +# T5.1.1 tiny model. + +include 't5x/examples/t5/t5_1_1/base.gin' # imports vocab, optimizer and model. + +# ------------------- Network specification overrides -------------------------- +network.Transformer.config = @network.T5Config() +network.T5Config: + emb_dim = 8 + num_heads = 4 + num_encoder_layers = 2 + num_decoder_layers = 2 + head_dim = 3 + mlp_dim = 16 diff --git a/t5x/examples/t5/byt5/xl.gin b/t5x/examples/t5/byt5/xl.gin new file mode 100644 index 0000000000000000000000000000000000000000..cbf38aaf51f525f0ac7e3870902fd43ed95a2574 --- /dev/null +++ b/t5x/examples/t5/byt5/xl.gin @@ -0,0 +1,13 @@ +# ByT5 XL model. 
+ +include 't5x/examples/t5/byt5/base.gin' # imports vocab, optimizer and model. + +# ------------------- Network specification overrides -------------------------- +network.Transformer.config = @network.T5Config() +network.T5Config: + emb_dim = 2560 + num_heads = 32 + num_encoder_layers = 36 + num_decoder_layers = 12 + head_dim = 64 + mlp_dim = 6720 diff --git a/t5x/examples/t5/byt5/xxl.gin b/t5x/examples/t5/byt5/xxl.gin new file mode 100644 index 0000000000000000000000000000000000000000..24fa418f6664c84d3c27e376b68b384d9baace90 --- /dev/null +++ b/t5x/examples/t5/byt5/xxl.gin @@ -0,0 +1,13 @@ +# ByT5 XXL model. + +include 't5x/examples/t5/byt5/base.gin' # imports vocab, optimizer and model. + +# ------------------- Network specification overrides -------------------------- +network.Transformer.config = @network.T5Config() +network.T5Config: + emb_dim = 4672 + num_heads = 64 + num_encoder_layers = 36 + num_decoder_layers = 12 + head_dim = 64 + mlp_dim = 12352 diff --git a/t5x/examples/t5/layers.py b/t5x/examples/t5/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..d9c202ad54a7e314c3ff6e48693c2bc646c1131e --- /dev/null +++ b/t5x/examples/t5/layers.py @@ -0,0 +1,867 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dense attention classes and mask/weighting functions.""" + +# pylint: disable=attribute-defined-outside-init,g-bare-generic + +import dataclasses +import functools +import operator +from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union + +from flax import linen as nn +from flax.linen import partitioning as nn_partitioning +import jax +from jax import lax +from jax import random +import jax.numpy as jnp +import numpy as np + + +# from flax.linen.partitioning import param_with_axes, with_sharding_constraint +param_with_axes = nn_partitioning.param_with_axes +with_sharding_constraint = nn_partitioning.with_sharding_constraint + + +# Type annotations +Array = jnp.ndarray +DType = jnp.dtype +PRNGKey = jnp.ndarray +Shape = Iterable[int] +Activation = Callable[..., Array] +# Parameter initializers. +Initializer = Callable[[PRNGKey, Shape, DType], Array] + +default_embed_init = nn.initializers.variance_scaling( + 1.0, 'fan_in', 'normal', out_axis=0) + + +def dot_product_attention(query: Array, + key: Array, + value: Array, + bias: Optional[Array] = None, + dropout_rng: Optional[PRNGKey] = None, + dropout_rate: float = 0., + deterministic: bool = False, + dtype: DType = jnp.float32, + float32_logits: bool = False): + """Computes dot-product attention given query, key, and value. + + This is the core function for applying attention based on + https://arxiv.org/abs/1706.03762. It calculates the attention weights given + query and key and combines the values using the attention weights. + + Args: + query: queries for calculating attention with shape of `[batch, q_length, + num_heads, qk_depth_per_head]`. 
+ key: keys for calculating attention with shape of `[batch, kv_length, + num_heads, qk_depth_per_head]`. + value: values to be used in attention with shape of `[batch, kv_length, + num_heads, v_depth_per_head]`. + bias: bias for the attention weights. This should be broadcastable to the + shape `[batch, num_heads, q_length, kv_length]` This can be used for + incorporating causal masks, padding masks, proximity bias, etc. + dropout_rng: JAX PRNGKey: to be used for dropout + dropout_rate: dropout rate + deterministic: bool, deterministic or not (to apply dropout) + dtype: the dtype of the computation (default: float32) + float32_logits: bool, if True then compute logits in float32 to avoid + numerical issues with bfloat16. + + Returns: + Output of shape `[batch, length, num_heads, v_depth_per_head]`. + """ + assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' + assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], ( + 'q, k, v batch dims must match.') + assert query.shape[-2] == key.shape[-2] == value.shape[-2], ( + 'q, k, v num_heads must match.') + assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' + assert query.shape[-1] == key.shape[-1], 'q, k depths must match.' + + # Casting logits and softmax computation for float32 for model stability. + if float32_logits: + query = query.astype(jnp.float32) + key = key.astype(jnp.float32) + + # `attn_weights`: [batch, num_heads, q_length, kv_length] + attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key) + + # Apply attention bias: masking, dropout, proximity bias, etc. + if bias is not None: + attn_weights = attn_weights + bias.astype(attn_weights.dtype) + + # Normalize the attention weights across `kv_length` dimension. + attn_weights = jax.nn.softmax(attn_weights).astype(dtype) + + # Apply attention dropout. + if not deterministic and dropout_rate > 0.: + keep_prob = 1.0 - dropout_rate + # T5 broadcasts along the "length" dim, but unclear which one that + # corresponds to in positional dimensions here, assuming query dim. + dropout_shape = list(attn_weights.shape) + dropout_shape[-2] = 1 + keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) + keep = jnp.broadcast_to(keep, attn_weights.shape) + multiplier = ( + keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype)) + attn_weights = attn_weights * multiplier + + # Take the linear combination of `value`. + return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value) + + +dynamic_vector_slice_in_dim = jax.vmap( + lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) + + +class MultiHeadDotProductAttention(nn.Module): + """Multi-head dot-product attention. + + Attributes: + num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) + should be divisible by the number of heads. + head_dim: dimension of each head. + dtype: the dtype of the computation. + dropout_rate: dropout rate + kernel_init: initializer for the kernel of the Dense layers. + float32_logits: bool, if True then compute logits in float32 to avoid + numerical issues with bfloat16. + """ + + num_heads: int + head_dim: int + dtype: DType = jnp.float32 + dropout_rate: float = 0. + kernel_init: Initializer = nn.initializers.variance_scaling( + 1.0, 'fan_in', 'normal') + float32_logits: bool = False # computes logits in float32 for stability. 
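  # Minimal usage sketch (hypothetical shapes, not taken from this repo's
  # tests); for self-attention, pass the same array as queries and key/values:
  #   attn = MultiHeadDotProductAttention(num_heads=8, head_dim=64)
  #   x = jnp.ones((2, 16, 512))                         # [batch, length, emb]
  #   variables = attn.init(jax.random.PRNGKey(0), x, x)
  #   y = attn.apply(variables, x, x, deterministic=True)  # same shape as x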
+ + @nn.compact + def __call__(self, + inputs_q: Array, + inputs_kv: Array, + mask: Optional[Array] = None, + bias: Optional[Array] = None, + *, + decode: bool = False, + deterministic: bool = False) -> Array: + """Applies multi-head dot product attention on the input data. + + Projects the inputs into multi-headed query, key, and value vectors, + applies dot-product attention and project the results to an output vector. + + There are two modes: decoding and non-decoding (e.g., training). The mode is + determined by `decode` argument. For decoding, this method is called twice, + first to initialize the cache and then for an actual decoding process. The + two calls are differentiated by the presence of 'cached_key' in the variable + dict. In the cache initialization stage, the cache variables are initialized + as zeros and will be filled in the subsequent decoding process. + + In the cache initialization call, `inputs_q` has a shape [batch, length, + q_features] and `inputs_kv`: [batch, length, kv_features]. During the + incremental decoding stage, query, key and value all have the shape [batch, + 1, qkv_features] corresponding to a single step. + + Args: + inputs_q: input queries of shape `[batch, q_length, q_features]`. + inputs_kv: key/values of shape `[batch, kv_length, kv_features]`. + mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`. + bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`. + decode: Whether to prepare and use an autoregressive cache. + deterministic: Disables dropout if set to True. + + Returns: + output of shape `[batch, length, q_features]`. + """ + projection = functools.partial( + DenseGeneral, + axis=-1, + features=(self.num_heads, self.head_dim), + kernel_axes=('embed', 'joined_kv'), + dtype=self.dtype) + + # NOTE: T5 does not explicitly rescale the attention logits by + # 1/sqrt(depth_kq)! This is folded into the initializers of the + # linear transformations, which is equivalent under Adafactor. + depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) + query_init = lambda *args: self.kernel_init(*args) / depth_scaling + + # Project inputs_q to multi-headed q/k/v + # dimensions are then [batch, length, num_heads, head_dim] + query = projection(kernel_init=query_init, name='query')(inputs_q) + key = projection(kernel_init=self.kernel_init, name='key')(inputs_kv) + value = projection(kernel_init=self.kernel_init, name='value')(inputs_kv) + + query = with_sharding_constraint(query, ('batch', 'length', 'heads', 'kv')) + key = with_sharding_constraint(key, ('batch', 'length', 'heads', 'kv')) + value = with_sharding_constraint(value, ('batch', 'length', 'heads', 'kv')) + + if decode: + # Detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable('cache', 'cached_key') + # The key and value have dimension [batch, length, num_heads, head_dim], + # but we cache them as [batch, num_heads, head_dim, length] as a TPU + # fusion optimization. This also enables the "scatter via one-hot + # broadcast" trick, which means we do a one-hot broadcast instead of a + # scatter/gather operations, resulting in a 3-4x speedup in practice. 
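      # Worked example with hypothetical sizes: for batch=2, length=5,
      # num_heads=8, head_dim=64,
      #   swap_dims((2, 5, 8, 64)) == (2, 8, 64, 5)
      # so the cache holds keys/values as [batch, heads, head_dim, length].
      # Each decode step then forms a length-sized one-hot vector at the
      # current index and writes the new [batch, heads, head_dim, 1] slice
      # into that single column via broadcasted multiply-and-add.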
+ swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3]) + cached_key = self.variable('cache', 'cached_key', jnp.zeros, + swap_dims(key.shape), key.dtype) + cached_value = self.variable('cache', 'cached_value', jnp.zeros, + swap_dims(value.shape), value.dtype) + cache_index = self.variable('cache', 'cache_index', + lambda: jnp.array(0, dtype=jnp.int32)) + if is_initialized: + batch, num_heads, head_dim, length = (cached_key.value.shape) + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + # Sanity shape check of cached key against input query. + expected_shape = (batch, 1, num_heads, head_dim) + if expected_shape != query.shape: + raise ValueError('Autoregressive cache shape error, ' + 'expected query shape %s instead got %s.' % + (expected_shape, query.shape)) + + # Create a OHE of the current index. NOTE: the index is increased below. + cur_index = cache_index.value + one_hot_indices = jax.nn.one_hot(cur_index, length, dtype=key.dtype) + # In order to update the key, value caches with the current key and + # value, we move the length axis to the back, similar to what we did for + # the cached ones above. + # Note these are currently the key and value of a single position, since + # we feed one position at a time. + one_token_key = jnp.moveaxis(key, -3, -1) + one_token_value = jnp.moveaxis(value, -3, -1) + # Update key, value caches with our new 1d spatial slices. + # We implement an efficient scatter into the cache via one-hot + # broadcast and addition. + key = cached_key.value + one_token_key * one_hot_indices + value = cached_value.value + one_token_value * one_hot_indices + cached_key.value = key + cached_value.value = value + cache_index.value = cache_index.value + 1 + # Move the keys and values back to their original shapes. + key = jnp.moveaxis(key, -1, -3) + value = jnp.moveaxis(value, -1, -3) + + # Causal mask for cached decoder self-attention: our single query + # position should only attend to those key positions that have already + # been generated and cached, not the remaining zero elements. + mask = combine_masks( + mask, + jnp.broadcast_to( + jnp.arange(length) <= cur_index, + # (1, 1, length) represent (head dim, query length, key length) + # query length is 1 because during decoding we deal with one + # index. + # The same mask is applied to all batch elements and heads. + (batch, 1, 1, length))) + + # Grab the correct relative attention bias during decoding. This is + # only required during single step decoding. + if bias is not None: + # The bias is a full attention matrix, but during decoding we only + # have to take a slice of it. + # This is equivalent to bias[..., cur_index:cur_index+1, :]. + bias = dynamic_vector_slice_in_dim( + jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2) + + # Convert the boolean attention mask to an attention bias. + if mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + mask > 0, + jnp.full(mask.shape, 0.).astype(self.dtype), + jnp.full(mask.shape, -1e10).astype(self.dtype)) + else: + attention_bias = None + + # Add provided bias term (e.g. relative position embedding). + if bias is not None: + attention_bias = combine_biases(attention_bias, bias) + + dropout_rng = None + if not deterministic and self.dropout_rate > 0.: + dropout_rng = self.make_rng('dropout') + + # Apply attention. 
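    # At this point query/key/value are [batch, q_length, num_heads, head_dim]
    # (q_length is 1 during incremental decoding) and attention_bias, if
    # present, broadcasts against [batch, num_heads, q_length, kv_length]; the
    # attention output below has the same shape as query.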
+ x = dot_product_attention( + query, + key, + value, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout_rate, + deterministic=deterministic, + dtype=self.dtype, + float32_logits=self.float32_logits) + + # Back to the original inputs dimensions. + out = DenseGeneral( + features=inputs_q.shape[-1], # output dim is set to the input dim. + axis=(-2, -1), + kernel_init=self.kernel_init, + kernel_axes=('joined_kv', 'embed'), + dtype=self.dtype, + name='out')( + x) + return out + + +def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: + # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. + return tuple([ax if ax >= 0 else ndim + ax for ax in axes]) + + +def _canonicalize_tuple(x): + if isinstance(x, Iterable): + return tuple(x) + else: + return (x,) + + +#------------------------------------------------------------------------------ +# DenseGeneral for attention layers. +#------------------------------------------------------------------------------ +class DenseGeneral(nn.Module): + """A linear transformation (without bias) with flexible axes. + + Attributes: + features: tuple with numbers of output features. + axis: tuple with axes to apply the transformation on. + dtype: the dtype of the computation (default: float32). + kernel_init: initializer function for the weight matrix. + """ + features: Union[Iterable[int], int] + axis: Union[Iterable[int], int] = -1 + dtype: DType = jnp.float32 + kernel_init: Initializer = nn.initializers.variance_scaling( + 1.0, 'fan_in', 'truncated_normal') + kernel_axes: Tuple[str, ...] = () + + @nn.compact + def __call__(self, inputs: Array) -> Array: + """Applies a linear transformation to the inputs along multiple dimensions. + + Args: + inputs: The nd-array to be transformed. + + Returns: + The transformed input. + """ + features = _canonicalize_tuple(self.features) + axis = _canonicalize_tuple(self.axis) + + inputs = jnp.asarray(inputs, self.dtype) + axis = _normalize_axes(axis, inputs.ndim) + + kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features + kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), + np.prod(features)) + kernel = param_with_axes( + 'kernel', + self.kernel_init, + kernel_param_shape, + jnp.float32, + axes=self.kernel_axes) + kernel = jnp.asarray(kernel, self.dtype) + kernel = jnp.reshape(kernel, kernel_shape) + + contract_ind = tuple(range(0, len(axis))) + return lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ()))) + + +def _convert_to_activation_function( + fn_or_string: Union[str, Callable]) -> Callable: + """Convert a string to an activation function.""" + if fn_or_string == 'linear': + return lambda x: x + elif isinstance(fn_or_string, str): + return getattr(nn, fn_or_string) + elif callable(fn_or_string): + return fn_or_string + else: + raise ValueError("don't know how to convert %s to an activation function" % + (fn_or_string,)) + + +class MlpBlock(nn.Module): + """Transformer MLP / feed-forward block. + + Attributes: + intermediate_dim: Shared dimension of hidden layers. + activations: Type of activations for each layer. Each element is either + 'linear', a string function name in flax.linen, or a function. + kernel_init: Kernel function, passed to the dense layers. + deterministic: Whether the dropout layers should be deterministic. + intermediate_dropout_rate: Dropout rate used after the intermediate layers. + dtype: Type for the dense layer. 
+ """ + intermediate_dim: int = 2048 + activations: Sequence[Union[str, Callable]] = ('relu',) + kernel_init: Initializer = nn.initializers.variance_scaling( + 1.0, 'fan_in', 'truncated_normal') + intermediate_dropout_rate: float = 0.1 + dtype: Any = jnp.float32 + + @nn.compact + def __call__(self, inputs, decode: bool = False, deterministic: bool = False): + """Applies Transformer MlpBlock module.""" + # Iterate over specified MLP input activation functions. + # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu. + activations = [] + for idx, act_fn in enumerate(self.activations): + dense_name = 'wi' if len(self.activations) == 1 else f'wi_{idx}' + x = DenseGeneral( + self.intermediate_dim, + dtype=self.dtype, + kernel_init=self.kernel_init, + kernel_axes=('embed', 'mlp'), + name=dense_name)( + inputs) + x = _convert_to_activation_function(act_fn)(x) + activations.append(x) + + # Take elementwise product of above intermediate activations. + x = functools.reduce(operator.mul, activations) + # Apply dropout and final dense output projection. + x = nn.Dropout( + rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))( + x, deterministic=deterministic) # Broadcast along length. + x = with_sharding_constraint(x, ('batch', 'length', 'mlp')) + output = DenseGeneral( + inputs.shape[-1], + dtype=self.dtype, + kernel_init=self.kernel_init, + kernel_axes=('mlp', 'embed'), + name='wo')( + x) + return output + + +class Embed(nn.Module): + """A parameterized function from integers [0, n) to d-dimensional vectors. + + Attributes: + num_embeddings: number of embeddings. + features: number of feature dimensions for each embedding. + dtype: the dtype of the embedding vectors (default: float32). + embedding_init: embedding initializer. + one_hot: performs the gather with a one-hot contraction rather than a true + gather. This is currently needed for SPMD partitioning. + """ + num_embeddings: int + features: int + cast_input_dtype: Optional[DType] = None + dtype: DType = jnp.float32 + attend_dtype: Optional[DType] = None + embedding_init: Initializer = default_embed_init + one_hot: bool = False + embedding: Array = dataclasses.field(init=False) + + def setup(self): + self.embedding = param_with_axes( + 'embedding', + self.embedding_init, (self.num_embeddings, self.features), + jnp.float32, + axes=('vocab', 'embed')) + + def __call__(self, inputs: Array) -> Array: + """Embeds the inputs along the last dimension. + + Args: + inputs: input data, all dimensions are considered batch dimensions. + + Returns: + Output which is embedded input data. The output shape follows the input, + with an additional `features` dimension appended. + """ + if self.cast_input_dtype: + inputs = inputs.astype(self.cast_input_dtype) + if not jnp.issubdtype(inputs.dtype, jnp.integer): + raise ValueError('Input type must be an integer or unsigned integer.') + if self.one_hot: + iota = lax.iota(jnp.int32, self.num_embeddings) + one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype) + output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype)) + else: + output = jnp.asarray(self.embedding, self.dtype)[inputs] + output = with_sharding_constraint(output, ('batch', 'length', 'embed')) + return output + + def attend(self, query: Array) -> Array: + """Attend over the embedding using a query array. + + Args: + query: array with last dimension equal the feature depth `features` of the + embedding. 
+ + Returns: + An array with final dim `num_embeddings` corresponding to the batched + inner-product of the array of query vectors against each embedding. + Commonly used for weight-sharing between embeddings and logit transform + in NLP models. + """ + dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype + return jnp.dot(query, jnp.asarray(self.embedding, dtype).T) + + +class RelativePositionBiases(nn.Module): + """Adds T5-style relative positional embeddings to the attention logits. + + Attributes: + num_buckets: Number of buckets to bucket distances between key and query + positions into. + max_distance: Maximum distance before everything is lumped into the last + distance bucket. + num_heads: Number of heads in the attention layer. Each head will get a + different relative position weighting. + dtype: Type of arrays through this module. + embedding_init: initializer for relative embedding table. + """ + num_buckets: int + max_distance: int + num_heads: int + dtype: Any + embedding_init: Callable[..., Array] = nn.linear.default_embed_init + + @staticmethod + def _relative_position_bucket(relative_position, + bidirectional=True, + num_buckets=32, + max_distance=128): + """Translate relative position to a bucket number for relative attention. + + The relative position is defined as memory_position - query_position, i.e. + the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are + invalid. + We use smaller buckets for small absolute relative_position and larger + buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative + positions <=-max_distance map to the same bucket. This should allow for + more graceful generalization to longer sequences than the model has been + trained on. + + Args: + relative_position: an int32 array + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 + values in the range [0, num_buckets) + """ + ret = 0 + n = -relative_position + if bidirectional: + num_buckets //= 2 + ret += (n < 0).astype(np.int32) * num_buckets + n = np.abs(n) + else: + n = np.maximum(n, 0) + # now n is in the range [0, inf) + max_exact = num_buckets // 2 + is_small = (n < max_exact) + val_if_large = max_exact + ( + np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps) / + np.log(max_distance / max_exact) * + (num_buckets - max_exact)).astype(np.int32) + val_if_large = np.minimum(val_if_large, num_buckets - 1) + ret += np.where(is_small, n, val_if_large) + return ret + + @nn.compact + def __call__(self, qlen, klen, bidirectional=True): + """Produce relative position embedding attention biases. + + Args: + qlen: attention query length. + klen: attention key length. + bidirectional: whether to allow positive memory-query relative position + embeddings. + + Returns: + output: `(1, len, q_len, k_len)` attention bias + """ + # TODO(levskaya): should we be computing this w. numpy as a program + # constant? 
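    # For intuition, a small worked example: relative_position is
    # memory_position - query_position, so for qlen = klen = 3 it is
    #   [[ 0,  1,  2],
    #    [-1,  0,  1],
    #    [-2, -1,  0]]
    # _relative_position_bucket then maps each entry to a bucket id: exact
    # buckets for small distances, logarithmically growing buckets up to
    # max_distance (everything beyond shares the last bucket), and, when
    # bidirectional, a separate bucket range for memory positions that come
    # after the query.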
+ context_position = np.arange(qlen, dtype=jnp.int32)[:, None] + memory_position = np.arange(klen, dtype=jnp.int32)[None, :] + relative_position = memory_position - context_position # shape (qlen, klen) + rp_bucket = self._relative_position_bucket( + relative_position, + bidirectional=bidirectional, + num_buckets=self.num_buckets, + max_distance=self.max_distance) + relative_attention_bias = param_with_axes( + 'rel_embedding', + self.embedding_init, (self.num_heads, self.num_buckets), + jnp.float32, + axes=('heads', 'relpos_buckets')) + + relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype) + # Instead of using a slow gather, we create a leading-dimension one-hot + # array from rp_bucket and use it to perform the gather-equivalent via a + # contraction, i.e.: + # (num_head, num_buckets) x (num_buckets one-hot, qlen, klen). + # This is equivalent to relative_attention_bias[:, rp_bucket] + bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0) + rp_bucket_one_hot = jnp.array( + rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype) + # --> shape (qlen, klen, num_heads) + values = lax.dot_general( + relative_attention_bias, + rp_bucket_one_hot, + ( + ((1,), (0,)), # rhs, lhs contracting dims + ((), ()))) # no batched dims + # Add a singleton batch dimension. + # --> shape (1, num_heads, qlen, klen) + return values[jnp.newaxis, ...] + + +#------------------------------------------------------------------------------ +# T5 Layernorm - no subtraction of mean or bias. +#------------------------------------------------------------------------------ +class LayerNorm(nn.Module): + """T5 Layer normalization operating on the last axis of the input data.""" + epsilon: float = 1e-6 + dtype: Any = jnp.float32 + scale_init: Initializer = nn.initializers.ones + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + """Applies layer normalization on the input.""" + x = jnp.asarray(x, jnp.float32) + features = x.shape[-1] + mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) + y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype) + scale = param_with_axes( + 'scale', self.scale_init, (features,), jnp.float32, axes=('embed',)) + + scale = jnp.asarray(scale, self.dtype) + return y * scale + + +#------------------------------------------------------------------------------ +# Mask-making utility functions. +#------------------------------------------------------------------------------ +def make_attention_mask(query_input: Array, + key_input: Array, + pairwise_fn: Callable = jnp.multiply, + extra_batch_dims: int = 0, + dtype: DType = jnp.float32) -> Array: + """Mask-making helper for attention weights. + + In case of 1d inputs (i.e., `[batch, len_q]`, `[batch, len_kv]`, the + attention weights will be `[batch, heads, len_q, len_kv]` and this + function will produce `[batch, 1, len_q, len_kv]`. + + Args: + query_input: a batched, flat input of query_length size + key_input: a batched, flat input of key_length size + pairwise_fn: broadcasting elementwise comparison function + extra_batch_dims: number of extra batch dims to add singleton axes for, none + by default + dtype: mask return dtype + + Returns: + A `[batch, 1, len_q, len_kv]` shaped mask for 1d attention. + """ + # [batch, len_q, len_kv] + mask = pairwise_fn( + # [batch, len_q] -> [batch, len_q, 1] + jnp.expand_dims(query_input, axis=-1), + # [batch, len_q] -> [batch, 1, len_kv] + jnp.expand_dims(key_input, axis=-2)) + + # [batch, 1, len_q, len_kv]. This creates the head dim. 
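  # Example with hypothetical padding masks: query_input = [[1, 1, 0]] and
  # key_input = [[1, 0, 0]] produce a [1, 1, 3, 3] mask
  #   [[[[1, 0, 0],
  #      [1, 0, 0],
  #      [0, 0, 0]]]]
  # i.e. only non-padding query positions may attend to non-padding keys.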
+ mask = jnp.expand_dims(mask, axis=-3) + mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims))) + return mask.astype(dtype) + + +def make_causal_mask(x: Array, + extra_batch_dims: int = 0, + dtype: DType = jnp.float32) -> Array: + """Make a causal mask for self-attention. + + In case of 1d inputs (i.e., `[batch, len]`, the self-attention weights + will be `[batch, heads, len, len]` and this function will produce a + causal mask of shape `[batch, 1, len, len]`. + + Note that a causal mask does not depend on the values of x; it only depends on + the shape. If x has padding elements, they will not be treated in a special + manner. + + Args: + x: input array of shape `[batch, len]` + extra_batch_dims: number of batch dims to add singleton axes for, none by + default + dtype: mask return dtype + + Returns: + A `[batch, 1, len, len]` shaped causal mask for 1d attention. + """ + idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape) + return make_attention_mask( + idxs, + idxs, + jnp.greater_equal, + extra_batch_dims=extra_batch_dims, + dtype=dtype) + + +def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32): + """Combine attention masks. + + Args: + *masks: set of attention mask arguments to combine, some can be None. + dtype: final mask dtype + + Returns: + Combined mask, reduced by logical and, returns None if no masks given. + """ + masks = [m for m in masks if m is not None] + if not masks: + return None + assert all(map(lambda x: x.ndim == masks[0].ndim, masks)), ( + f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}') + mask, *other_masks = masks + for other_mask in other_masks: + mask = jnp.logical_and(mask, other_mask) + return mask.astype(dtype) + + +def combine_biases(*masks: Optional[Array]): + """Combine attention biases. + + Args: + *masks: set of attention bias arguments to combine, some can be None. + + Returns: + Combined mask, reduced by summation, returns None if no masks given. + """ + masks = [m for m in masks if m is not None] + if not masks: + return None + assert all(map(lambda x: x.ndim == masks[0].ndim, masks)), ( + f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}') + mask, *other_masks = masks + for other_mask in other_masks: + mask = mask + other_mask + return mask + + +def make_decoder_mask(decoder_target_tokens: Array, + dtype: DType, + decoder_causal_attention: Optional[Array] = None, + decoder_segment_ids: Optional[Array] = None) -> Array: + """Compute the self-attention mask for a decoder. + + Decoder mask is formed by combining a causal mask, a padding mask and an + optional packing mask. If decoder_causal_attention is passed, it makes the + masking non-causal for positions that have value of 1. + + A prefix LM is applied to a dataset which has a notion of "inputs" and + "targets", e.g., a machine translation task. The inputs and targets are + concatenated to form a new target. `decoder_target_tokens` is the concatenated + decoder output tokens. + + The "inputs" portion of the concatenated sequence can attend to other "inputs" + tokens even for those at a later time steps. In order to control this + behavior, `decoder_causal_attention` is necessary. This is a binary mask with + a value of 1 indicating that the position belonged to "inputs" portion of the + original dataset. + + Example: + + Suppose we have a dataset with two examples. 
+ + ds = [{"inputs": [6, 7], "targets": [8]}, + {"inputs": [3, 4], "targets": [5]}] + + After the data preprocessing with packing, the two examples are packed into + one example with the following three fields (some fields are skipped for + simplicity). + + decoder_target_tokens = [[6, 7, 8, 3, 4, 5, 0]] + decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]] + decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]] + + where each array has [batch, length] shape with batch size being 1. Then, + this function computes the following mask. + + mask = [[[[1, 1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0]]]] + + mask[b, 1, :, :] represents the mask for the example `b` in the batch. + Because mask is for a self-attention layer, the mask's shape is a square of + shape [query length, key length]. + + mask[b, 1, i, j] = 1 means that the query token at position i can attend to + the key token at position j. + + Args: + decoder_target_tokens: decoder output tokens. [batch, length] + dtype: dtype of the output mask. + decoder_causal_attention: a binary mask indicating which position should + only attend to earlier positions in the sequence. Others will attend + bidirectionally. [batch, length] + decoder_segment_ids: decoder segmentation info for packed examples. [batch, + length] + + Returns: + the combined decoder mask. + """ + masks = [] + # The same mask is applied to all attention heads. So the head dimension is 1, + # i.e., the mask will be broadcast along the heads dim. + # [batch, 1, length, length] + causal_mask = make_causal_mask(decoder_target_tokens, dtype=dtype) + + # Positions with value 1 in `decoder_causal_attneition` can attend + # bidirectionally. + if decoder_causal_attention is not None: + # [batch, 1, length, length] + inputs_mask = make_attention_mask( + decoder_causal_attention, + decoder_causal_attention, + jnp.logical_and, + dtype=dtype) + masks.append(jnp.logical_or(causal_mask, inputs_mask).astype(dtype)) + else: + masks.append(causal_mask) + + # Padding mask. + masks.append( + make_attention_mask( + decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=dtype)) + + # Packing mask + if decoder_segment_ids is not None: + masks.append( + make_attention_mask( + decoder_segment_ids, decoder_segment_ids, jnp.equal, dtype=dtype)) + + return combine_masks(*masks, dtype=dtype) diff --git a/t5x/examples/t5/layers_test.py b/t5x/examples/t5/layers_test.py new file mode 100644 index 0000000000000000000000000000000000000000..57f097ae53842982d5af6486a2bc2dab5dd21643 --- /dev/null +++ b/t5x/examples/t5/layers_test.py @@ -0,0 +1,620 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
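# Assuming the t5x package and its test dependencies are installed, these
# tests can typically be run with, e.g.,
# `python -m pytest t5x/examples/t5/layers_test.py`.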
+ +"""Tests for attention classes.""" + +import dataclasses +from typing import Optional +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized +from flax import linen as nn +from flax.core import freeze +from flax.linen import partitioning as nn_partitioning +import jax +from jax import random +from jax.nn import initializers +import jax.numpy as jnp +import numpy as np +from t5x.examples.t5 import layers + +# Parse absl flags test_srcdir and test_tmpdir. +jax.config.parse_flags_with_absl() + +Array = jnp.ndarray +AxisMetadata = nn_partitioning.AxisMetadata # pylint: disable=invalid-name + + +class SelfAttention(layers.MultiHeadDotProductAttention): + """Self-attention special case of multi-head dot-product attention.""" + + @nn.compact + def __call__(self, + inputs_q: Array, + mask: Optional[Array] = None, + bias: Optional[Array] = None, + deterministic: bool = False): + return super().__call__( + inputs_q, inputs_q, mask, bias, deterministic=deterministic) + + +@dataclasses.dataclass(frozen=True) +class SelfAttentionArgs: + num_heads: int = 1 + batch_size: int = 2 + # qkv_features: int = 3 + head_dim: int = 3 + # out_features: int = 4 + q_len: int = 5 + features: int = 6 + dropout_rate: float = 0.1 + deterministic: bool = False + decode: bool = False + float32_logits: bool = False + + def __post_init__(self): + # If we are doing decoding, the query length should be 1, because are doing + # autoregressive decoding where we feed one position at a time. + assert not self.decode or self.q_len == 1 + + def init_args(self): + return dict( + num_heads=self.num_heads, + head_dim=self.head_dim, + dropout_rate=self.dropout_rate, + float32_logits=self.float32_logits) + + def apply_args(self): + inputs_q = jnp.ones((self.batch_size, self.q_len, self.features)) + mask = jnp.ones((self.batch_size, self.num_heads, self.q_len, self.q_len)) + bias = jnp.ones((self.batch_size, self.num_heads, self.q_len, self.q_len)) + return { + 'inputs_q': inputs_q, + 'mask': mask, + 'bias': bias, + 'deterministic': self.deterministic + } + + +class AttentionTest(parameterized.TestCase): + + def test_dot_product_attention_shape(self): + # This test only checks for shape but tries to make sure all code paths are + # reached. 
+ dropout_rng = random.PRNGKey(0) + batch_size, num_heads, q_len, kv_len, qk_depth, v_depth = 1, 2, 3, 4, 5, 6 + + query = jnp.ones((batch_size, q_len, num_heads, qk_depth)) + key = jnp.ones((batch_size, kv_len, num_heads, qk_depth)) + value = jnp.ones((batch_size, kv_len, num_heads, v_depth)) + bias = jnp.ones((batch_size, num_heads, q_len, kv_len)) + + args = dict( + query=query, + key=key, + value=value, + bias=bias, + dropout_rng=dropout_rng, + dropout_rate=0.5, + deterministic=False, + ) + + output = layers.dot_product_attention(**args) + self.assertEqual(output.shape, (batch_size, q_len, num_heads, v_depth)) + + def test_make_attention_mask_multiply_pairwise_fn(self): + decoder_target_tokens = jnp.array([[7, 0, 0], [8, 5, 0]]) + attention_mask = layers.make_attention_mask( + decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=jnp.int32) + expected0 = jnp.array([[1, 0, 0], [0, 0, 0], [0, 0, 0]]) + expected1 = jnp.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]]) + self.assertEqual(attention_mask.shape, (2, 1, 3, 3)) + np.testing.assert_array_equal(attention_mask[0, 0], expected0) + np.testing.assert_array_equal(attention_mask[1, 0], expected1) + + def test_make_attention_mask_equal_pairwise_fn(self): + segment_ids = jnp.array([[1, 1, 2, 2, 2, 0], [1, 1, 1, 2, 0, 0]]) + attention_mask = layers.make_attention_mask( + segment_ids, segment_ids, pairwise_fn=jnp.equal, dtype=jnp.int32) + # Padding is not treated in a special way. So they need to be zeroed out + # separately. + expected0 = jnp.array([[1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0], [0, 0, 1, 1, 1, 0], + [0, 0, 1, 1, 1, 0], [0, 0, 0, 0, 0, 1]]) + expected1 = jnp.array([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 1, 1]]) + self.assertEqual(attention_mask.shape, (2, 1, 6, 6)) + np.testing.assert_array_equal(attention_mask[0, 0], expected0) + np.testing.assert_array_equal(attention_mask[1, 0], expected1) + + def test_make_causal_mask_with_padding(self): + x = jnp.array([[7, 0, 0], [8, 5, 0]]) + y = layers.make_causal_mask(x) + self.assertEqual(y.shape, (2, 1, 3, 3)) + # Padding is not treated in a special way. So they need to be zeroed out + # separately. 
+ expected_y = jnp.array([[[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]]], + jnp.float32) + np.testing.assert_allclose(y[0], expected_y) + np.testing.assert_allclose(y[1], expected_y) + + def test_make_causal_mask_extra_batch_dims(self): + x = jnp.ones((3, 3, 5)) + y = layers.make_causal_mask(x, extra_batch_dims=2) + self.assertEqual(y.shape, (1, 1, 3, 3, 1, 5, 5)) + + def test_make_causal_mask(self): + x = jnp.ones((1, 3)) + y = layers.make_causal_mask(x) + self.assertEqual(y.shape, (1, 1, 3, 3)) + expected_y = jnp.array([[[[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]]]], + jnp.float32) + np.testing.assert_allclose(y, expected_y) + + def test_combine_masks(self): + masks = [ + jnp.array([0, 1, 0, 1], jnp.float32), None, + jnp.array([1, 1, 1, 1], jnp.float32), + jnp.array([1, 1, 1, 0], jnp.float32) + ] + y = layers.combine_masks(*masks) + np.testing.assert_allclose(y, jnp.array([0, 1, 0, 0], jnp.float32)) + + def test_combine_biases(self): + masks = [ + jnp.array([0, 1, 0, 1], jnp.float32), None, + jnp.array([0, 1, 1, 1], jnp.float32), + jnp.array([0, 1, 1, 0], jnp.float32) + ] + y = layers.combine_biases(*masks) + np.testing.assert_allclose(y, jnp.array([0, 3, 2, 2], jnp.float32)) + + def test_make_decoder_mask_lm_unpacked(self): + decoder_target_tokens = jnp.array([6, 7, 3, 0]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, dtype=jnp.float32) + expected_mask = jnp.array([[[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], + [0, 0, 0, 0]]]) + np.testing.assert_array_equal(mask, expected_mask) + + def test_make_decoder_mask_lm_packed(self): + decoder_target_tokens = jnp.array([[6, 7, 3, 4, 5, 0]]) + decoder_segment_ids = jnp.array([[1, 1, 1, 2, 2, 0]]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=jnp.float32, + decoder_segment_ids=decoder_segment_ids) + expected_mask = jnp.array([[[[1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 0], [0, 0, 0, 0, 0, 0]]]]) + np.testing.assert_array_equal(mask, expected_mask) + + def test_make_decoder_mask_prefix_lm_unpacked(self): + decoder_target_tokens = jnp.array([[5, 6, 7, 3, 4, 0]]) + decoder_causal_attention = jnp.array([[1, 1, 1, 0, 0, 0]]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=jnp.float32, + decoder_causal_attention=decoder_causal_attention) + expected_mask = jnp.array( + [[[[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 0], [0, 0, 0, 0, 0, 0]]]], + dtype=jnp.float32) + np.testing.assert_array_equal(mask, expected_mask) + + def test_make_decoder_mask_prefix_lm_packed(self): + decoder_target_tokens = jnp.array([[5, 6, 7, 8, 3, 4, 0]]) + decoder_segment_ids = jnp.array([[1, 1, 1, 2, 2, 2, 0]]) + decoder_causal_attention = jnp.array([[1, 1, 0, 1, 1, 0, 0]]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=jnp.float32, + decoder_causal_attention=decoder_causal_attention, + decoder_segment_ids=decoder_segment_ids) + expected_mask = jnp.array([[[[1, 1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0], [0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 0, 0], [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0]]]]) + np.testing.assert_array_equal(mask, expected_mask) + + def test_make_decoder_mask_prefix_lm_unpacked_multiple_elements(self): + decoder_target_tokens = jnp.array([[6, 7, 3, 0], [4, 5, 0, 0]]) + decoder_causal_attention = jnp.array([[1, 1, 0, 0], [1, 0, 0, 0]]) + mask = 
layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=jnp.float32, + decoder_causal_attention=decoder_causal_attention) + expected_mask0 = jnp.array([[1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], + [0, 0, 0, 0]]) + expected_mask1 = jnp.array([[1, 0, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0], + [0, 0, 0, 0]]) + self.assertEqual(mask.shape, (2, 1, 4, 4)) + np.testing.assert_array_equal(mask[0, 0], expected_mask0) + np.testing.assert_array_equal(mask[1, 0], expected_mask1) + + def test_make_decoder_mask_composite_causal_attention(self): + decoder_target_tokens = jnp.array([[6, 7, 3, 4, 8, 9, 0]]) + decoder_causal_attention = jnp.array([[1, 1, 0, 0, 1, 1, 0]]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=jnp.float32, + decoder_causal_attention=decoder_causal_attention) + expected_mask0 = jnp.array([[1, 1, 0, 0, 1, 1, 0], [1, 1, 0, 0, 1, 1, 0], + [1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0]]) + + self.assertEqual(mask.shape, (1, 1, 7, 7)) + np.testing.assert_array_equal(mask[0, 0], expected_mask0) + + def test_make_decoder_mask_composite_causal_attention_packed(self): + decoder_target_tokens = jnp.array([[6, 7, 3, 4, 8, 9, 2, 3, 4]]) + decoder_segment_ids = jnp.array([[1, 1, 1, 1, 1, 1, 2, 2, 2]]) + decoder_causal_attention = jnp.array([[1, 1, 0, 0, 1, 1, 1, 1, 0]]) + mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=jnp.float32, + decoder_causal_attention=decoder_causal_attention, + decoder_segment_ids=decoder_segment_ids) + expected_mask0 = jnp.array([[1, 1, 0, 0, 1, 1, 0, 0, 0], + [1, 1, 0, 0, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1]]) + + self.assertEqual(mask.shape, (1, 1, 9, 9)) + np.testing.assert_array_equal(mask[0, 0], expected_mask0) + + @parameterized.parameters({'f': 20}, {'f': 22}) + def test_multihead_dot_product_attention(self, f): + # b: batch, f: emb_dim, q: q_len, k: kv_len, h: num_head, d: head_dim + b, q, h, d, k = 2, 3, 4, 5, 6 + + base_args = SelfAttentionArgs(num_heads=h, head_dim=d, dropout_rate=0) + args = base_args.init_args() + + np.random.seed(0) + inputs_q = np.random.randn(b, q, f) + inputs_kv = np.random.randn(b, k, f) + + # Projection: [b, q, f] -> [b, q, h, d] + # So the kernels have to be [f, h, d] + query_kernel = np.random.randn(f, h, d) + key_kernel = np.random.randn(f, h, d) + value_kernel = np.random.randn(f, h, d) + # `out` calculation: [b, q, h, d] -> [b, q, f] + # So kernel has to be [h, d, f] + out_kernel = np.random.randn(h, d, f) + + params = { + 'query': { + 'kernel': query_kernel.reshape(f, -1) + }, + 'key': { + 'kernel': key_kernel.reshape(f, -1) + }, + 'value': { + 'kernel': value_kernel.reshape(f, -1) + }, + 'out': { + 'kernel': out_kernel.reshape(-1, f) + } + } + y = layers.MultiHeadDotProductAttention(**args).apply( + {'params': freeze(params)}, inputs_q, inputs_kv) + + query = np.einsum('bqf,fhd->bqhd', inputs_q, query_kernel) + key = np.einsum('bkf,fhd->bkhd', inputs_kv, key_kernel) + value = np.einsum('bkf,fhd->bkhd', inputs_kv, value_kernel) + logits = np.einsum('bqhd,bkhd->bhqk', query, key) + weights = nn.softmax(logits, axis=-1) + combined_value = np.einsum('bhqk,bkhd->bqhd', weights, value) + y_expected = np.einsum('bqhd,hdf->bqf', combined_value, out_kernel) + 
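A side note on the flattened kernels built above (an illustrative identity check, not part of the original test): projecting with an [f, h, d] kernel is equivalent to a 2-D matmul against the same kernel reshaped to [f, h * d], which is why query_kernel.reshape(f, -1) and friends can be handed to the layer while the reference computation keeps the unflattened einsum.

import numpy as np

b, q, f, h, d = 2, 3, 20, 4, 5
x = np.random.randn(b, q, f)
kernel = np.random.randn(f, h, d)
via_einsum = np.einsum('bqf,fhd->bqhd', x, kernel)
via_matmul = (x @ kernel.reshape(f, -1)).reshape(b, q, h, d)
np.testing.assert_allclose(via_einsum, via_matmul, rtol=1e-6)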
np.testing.assert_allclose(y, y_expected, rtol=1e-5, atol=1e-5) + + def test_multihead_dot_product_attention_caching(self): + # b: batch, f: qkv_features, k: kv_len, h: num_head, d: head_dim + b, h, d, k = 2, 3, 4, 5 + f = h * d + + base_args = SelfAttentionArgs(num_heads=h, head_dim=d, dropout_rate=0) + args = base_args.init_args() + + cache = { + 'cached_key': np.zeros((b, h, d, k)), + 'cached_value': np.zeros((b, h, d, k)), + 'cache_index': np.array(0) + } + inputs_q = np.random.randn(b, 1, f) + inputs_kv = np.random.randn(b, 1, f) + + # Mock dense general such that q, k, v projections are replaced by simple + # reshaping. + def mock_dense_general(self, x, **kwargs): # pylint: disable=unused-argument + return x.reshape(b, -1, h, d) + + with mock.patch.object( + layers.DenseGeneral, '__call__', new=mock_dense_general): + _, mutated = layers.MultiHeadDotProductAttention(**args).apply( + {'cache': freeze(cache)}, + inputs_q, + inputs_kv, + decode=True, + mutable=['cache']) + updated_cache = mutated['cache'] + + # Perform the same mocked projection to generate the expected cache. + # (key|value): [b, 1, h, d] + key = mock_dense_general(None, inputs_kv) + value = mock_dense_general(None, inputs_kv) + + # cached_(key|value): [b, h, d, k] + cache['cached_key'][:, :, :, 0] = key[:, 0, :, :] + cache['cached_value'][:, :, :, 0] = value[:, 0, :, :] + cache['cache_index'] = np.array(1) + for name, array in cache.items(): + np.testing.assert_allclose(array, updated_cache[name]) + + def test_dot_product_attention(self): + # b: batch, f: emb_dim, q: q_len, k: kv_len, h: num_head, d: head_dim + b, q, h, d, k = 2, 3, 4, 5, 6 + np.random.seed(0) + query = np.random.randn(b, q, h, d) + key = np.random.randn(b, k, h, d) + value = np.random.randn(b, k, h, d) + bias = np.random.randn(b, h, q, k) + attn_out = layers.dot_product_attention(query, key, value, bias=bias) + logits = np.einsum('bqhd,bkhd->bhqk', query, key) + weights = jax.nn.softmax(logits + bias, axis=-1) + expected = np.einsum('bhqk,bkhd->bqhd', weights, value) + np.testing.assert_allclose(attn_out, expected, atol=1e-6) + + +class EmbeddingTest(parameterized.TestCase): + + def test_embedder_raises_exception_for_incorrect_input_type(self): + """Tests that inputs are integers and that an exception is raised if not.""" + embed = layers.Embed(num_embeddings=10, features=5) + inputs = np.expand_dims(np.arange(5, dtype=np.int64), 1) + variables = embed.init(jax.random.PRNGKey(0), inputs) + bad_inputs = inputs.astype(np.float32) + with self.assertRaisesRegex( + ValueError, 'Input type must be an integer or unsigned integer.'): + _ = embed.apply(variables, bad_inputs) + + @parameterized.named_parameters( + { + 'testcase_name': 'with_ones', + 'init_fn': jax.nn.initializers.ones, + 'num_embeddings': 10, + 'features': 5, + 'matrix_sum': 5 * 10, + }, { + 'testcase_name': 'with_zeros', + 'init_fn': jax.nn.initializers.zeros, + 'num_embeddings': 10, + 'features': 5, + 'matrix_sum': 0, + }) + def test_embedding_initializes_correctly(self, init_fn, num_embeddings, + features, matrix_sum): + """Tests if the Embed class initializes with the requested initializer.""" + embed = layers.Embed( + num_embeddings=num_embeddings, + features=features, + embedding_init=init_fn) + inputs = np.expand_dims(np.arange(5, dtype=np.int64), 1) + variables = embed.init(jax.random.PRNGKey(0), inputs) + embedding_matrix = variables['params']['embedding'] + self.assertEqual(int(np.sum(embedding_matrix)), matrix_sum) + + def test_embedding_matrix_shape(self): + """Tests that the 
embedding matrix has the right shape.""" + num_embeddings = 10 + features = 5 + embed = layers.Embed(num_embeddings=num_embeddings, features=features) + inputs = np.expand_dims(np.arange(features, dtype=np.int64), 1) + variables = embed.init(jax.random.PRNGKey(0), inputs) + embedding_matrix = variables['params']['embedding'] + self.assertEqual((num_embeddings, features), embedding_matrix.shape) + + def test_embedding_attend(self): + """Tests that attending with ones returns sum of embedding vectors.""" + features = 5 + embed = layers.Embed(num_embeddings=10, features=features) + inputs = np.array([[1]], dtype=np.int64) + variables = embed.init(jax.random.PRNGKey(0), inputs) + query = np.ones(features, dtype=np.float32) + result = embed.apply(variables, query, method=embed.attend) + expected = np.sum(variables['params']['embedding'], -1) + np.testing.assert_array_almost_equal(result, expected) + + +class DenseTest(parameterized.TestCase): + + def test_dense_general_no_bias(self): + rng = random.PRNGKey(0) + x = jnp.ones((1, 3)) + model = layers.DenseGeneral( + features=4, + kernel_init=initializers.ones, + ) + y, _ = model.init_with_output(rng, x) + self.assertEqual(y.shape, (1, 4)) + np.testing.assert_allclose(y, np.full((1, 4), 3.)) + + def test_dense_general_two_features(self): + rng = random.PRNGKey(0) + x = jnp.ones((1, 3)) + model = layers.DenseGeneral( + features=(2, 2), + kernel_init=initializers.ones, + ) + y, _ = model.init_with_output(rng, x) + # We transform the last input dimension to two output dimensions (2, 2). + np.testing.assert_allclose(y, np.full((1, 2, 2), 3.)) + + def test_dense_general_two_axes(self): + rng = random.PRNGKey(0) + x = jnp.ones((1, 2, 2)) + model = layers.DenseGeneral( + features=3, + axis=(-2, 2), # Note: this is the same as (1, 2). + kernel_init=initializers.ones, + ) + y, _ = model.init_with_output(rng, x) + # We transform the last two input dimensions (2, 2) to one output dimension. + np.testing.assert_allclose(y, np.full((1, 3), 4.)) + + def test_mlp_same_out_dim(self): + module = layers.MlpBlock( + intermediate_dim=4, + activations=('relu',), + kernel_init=nn.initializers.xavier_uniform(), + dtype=jnp.float32, + ) + inputs = np.array( + [ + # Batch 1. + [[1, 1], [1, 1], [1, 2]], + # Batch 2. 
+ [[2, 2], [3, 1], [2, 2]], + ], + dtype=np.float32) + params = module.init(random.PRNGKey(0), inputs, deterministic=True) + self.assertEqual( + jax.tree_map(lambda a: a.tolist(), params), { + 'params': { + 'wi': { + 'kernel': [[ + -0.8675811290740967, 0.08417510986328125, + 0.022586345672607422, -0.9124102592468262 + ], + [ + -0.19464373588562012, 0.49809837341308594, + 0.7808468341827393, 0.9267289638519287 + ]], + }, + 'wo': { + 'kernel': [[0.01154780387878418, 0.1397249698638916], + [0.974980354309082, 0.5903260707855225], + [-0.05997943878173828, 0.616570234298706], + [0.2934272289276123, 0.8181164264678955]], + }, + }, + 'params_axes': { + 'wi': { + 'kernel_axes': AxisMetadata(names=('embed', 'mlp')), + }, + 'wo': { + 'kernel_axes': AxisMetadata(names=('mlp', 'embed')), + }, + }, + }) + result = module.apply(params, inputs, deterministic=True) + np.testing.assert_allclose( + result.tolist(), + [[[0.5237172245979309, 0.8508185744285583], + [0.5237172245979309, 0.8508185744285583], + [1.2344461679458618, 2.3844780921936035]], + [[1.0474344491958618, 1.7016371488571167], + [0.6809444427490234, 0.9663378596305847], + [1.0474344491958618, 1.7016371488571167]]], + rtol=1e-6, + ) + + +class RelativePositionBiasesTest(absltest.TestCase): + + def setUp(self): + self.num_heads = 3 + self.query_len = 5 + self.key_len = 7 + self.relative_attention = layers.RelativePositionBiases( + num_buckets=12, + max_distance=10, + num_heads=3, + dtype=jnp.float32, + ) + super(RelativePositionBiasesTest, self).setUp() + + def test_relative_attention_bidirectional_params(self): + """Tests that bidirectional relative position biases have expected params.""" + params = self.relative_attention.init( + random.PRNGKey(0), self.query_len, self.key_len, bidirectional=True) + param_shapes = jax.tree_map(lambda x: x.shape, params) + self.assertEqual( + param_shapes, { + 'params': { + 'rel_embedding': (3, 12), + }, + 'params_axes': { + 'rel_embedding_axes': + AxisMetadata(names=('heads', 'relpos_buckets')), + } + }) + + def test_regression_relative_attention_bidirectional_values(self): + """Tests that bidirectional relative position biases match expected values. + + See top docstring note on matching T5X behavior for these regression tests. + """ + outputs, unused_params = self.relative_attention.init_with_output( + random.PRNGKey(0), self.query_len, self.key_len, bidirectional=True) + self.assertEqual(outputs.shape, + (1, self.num_heads, self.query_len, self.key_len)) + self.assertAlmostEqual(outputs[0, 0, 0, 0], 0.55764728, places=5) + self.assertAlmostEqual(outputs[0, 1, 2, 1], -0.10935841, places=5) + self.assertAlmostEqual(outputs[0, 1, 4, 6], 0.14510104, places=5) + self.assertAlmostEqual(outputs[0, 2, 4, 6], -0.36783996, places=5) + + def test_relative_attention_unidirectional_params(self): + """Tests that unidirectional relative position biases have expected params.""" + params = self.relative_attention.init( + random.PRNGKey(0), self.query_len, self.key_len, bidirectional=False) + param_shapes = jax.tree_map(lambda x: x.shape, params) + self.assertEqual( + param_shapes, { + 'params': { + 'rel_embedding': (3, 12), + }, + 'params_axes': { + 'rel_embedding_axes': + AxisMetadata(names=('heads', 'relpos_buckets')), + } + }) + + def test_regression_relative_attention_unidirectional_values(self): + """Tests that unidirectional relative position biases match expected values. + + See top docstring note on matching T5X behavior for these regression tests. 
+ """ + outputs, unused_params = self.relative_attention.init_with_output( + random.PRNGKey(0), self.query_len, self.key_len, bidirectional=False) + self.assertEqual(outputs.shape, + (1, self.num_heads, self.query_len, self.key_len)) + self.assertAlmostEqual(outputs[0, 0, 0, 0], 0.55764728, places=5) + self.assertAlmostEqual(outputs[0, 1, 2, 1], -0.10935841, places=5) + self.assertAlmostEqual(outputs[0, 1, 4, 6], -0.13101986, places=5) + self.assertAlmostEqual(outputs[0, 2, 4, 6], 0.39296466, places=5) + + +if __name__ == '__main__': + absltest.main() diff --git a/t5x/examples/t5/local_tiny.gin b/t5x/examples/t5/local_tiny.gin new file mode 100644 index 0000000000000000000000000000000000000000..20844cfb39b220df80d936d147da7ba51cfde55e --- /dev/null +++ b/t5x/examples/t5/local_tiny.gin @@ -0,0 +1,68 @@ +# A gin file to make the Transformer models tiny for faster local testing. +# +# When testing locally with CPU, there are a few things that we need. +# - tiny model size +# - small enough batch size +# - small sequence length +# - determinstic dataset pipeline +# +# This gin file adds such configs. To use this gin file, add it on top of the +# existing full-scale gin files. The ordering of the gin file matters. So this +# should be added after all the other files are added to override the same +# configurables. + +from __gin__ import dynamic_registration + +from t5x import partitioning +from t5x import trainer +from t5x import utils +from t5x.examples.t5 import network + +import __main__ as train_script + +train_script.train.random_seed = 42 # dropout seed +train/utils.DatasetConfig.seed = 42 # dataset seed + +TASK_FEATURE_LENGTHS = {"inputs": 8, "targets": 7} +LABEL_SMOOTHING = 0.0 + +# Network specification overrides +network.Transformer.config = @network.T5Config() +network.T5Config: + vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency + dtype = 'bfloat16' + emb_dim = 8 + num_heads = 4 + num_encoder_layers = 2 + num_decoder_layers = 2 + head_dim = 3 + mlp_dim = 16 + mlp_activations = ('gelu', 'linear') + dropout_rate = 0.0 + logits_via_embedding = False + +TRAIN_STEPS = 3 + +train/utils.DatasetConfig: + batch_size = 8 + shuffle = False + +train_eval/utils.DatasetConfig.batch_size = 8 + +train_script.train: + eval_period = 3 + eval_steps = 3 + +trainer.Trainer.num_microbatches = 0 +partitioning.PjitPartitioner: + num_partitions = 1 + model_parallel_submesh = None + +utils.CheckpointConfig: + restore = None + +infer_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS + + +# DISABLE INFERENCE EVAL +# train_script.train.infer_eval_dataset_cfg = None diff --git a/t5x/examples/t5/mt5/__init__.py b/t5x/examples/t5/mt5/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da022c16301721a096a208e8bdb2a71bb87f9788 --- /dev/null +++ b/t5x/examples/t5/mt5/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
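Returning to local_tiny.gin above: a usage sketch of how the tiny overrides are meant to be layered (the surrounding file list, the DROPOUT_RATE binding, and the required environment setup such as importing the referenced t5x modules are assumptions, not taken from this change). Because later gin files override earlier ones, local_tiny.gin must be parsed last.

import gin

gin.parse_config_files_and_bindings(
    config_files=[
        't5x/examples/t5/t5_1_1/base.gin',   # full-scale model definition
        't5x/configs/runs/pretrain.gin',     # run configuration
        't5x/examples/t5/local_tiny.gin',    # tiny overrides, applied last
    ],
    bindings=['DROPOUT_RATE = 0.0'],         # base.gin leaves this %gin.REQUIRED
)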
+ +# This empty file is needed for loading the gin files in this directory. diff --git a/t5x/examples/t5/mt5/base.gin b/t5x/examples/t5/mt5/base.gin new file mode 100644 index 0000000000000000000000000000000000000000..6dabd2c05d7a5256c7837c0bcd0a73581d01d2e9 --- /dev/null +++ b/t5x/examples/t5/mt5/base.gin @@ -0,0 +1,55 @@ +# mT5 Base model. +from __gin__ import dynamic_registration + +import seqio +from t5x import adafactor +from t5x import models +from t5x.examples.t5 import network + +# ------------------- Loss HParam ---------------------------------------------- +Z_LOSS = 0.0001 +LABEL_SMOOTHING = 0.0 +# NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF) +# the loss normalizing factor should be set to pretraining batch_size * +# target_token_length. +LOSS_NORMALIZING_FACTOR = None +# Dropout should be specified in the "run" files +DROPOUT_RATE = %gin.REQUIRED + +# Vocabulary (shared by encoder and decoder) +VOCABULARY = @seqio.SentencePieceVocabulary() +seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model" + +# ------------------- Optimizer ------------------------------------------------ +# `learning_rate` is set by `Trainer.learning_rate_fn`. +OPTIMIZER = @adafactor.Adafactor() +adafactor.Adafactor: + decay_rate = 0.8 + step_offset = 0 + logical_factor_rules = @adafactor.standard_logical_factor_rules() + +# ------------------- Model ---------------------------------------------------- +MODEL = @models.EncoderDecoderModel() +models.EncoderDecoderModel: + module = @network.Transformer() + input_vocabulary = %VOCABULARY + output_vocabulary = %VOCABULARY + optimizer_def = %OPTIMIZER + z_loss = %Z_LOSS + label_smoothing = %LABEL_SMOOTHING + loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR + +# ------------------- Network specification ------------------------------------ +network.Transformer.config = @network.T5Config() +network.T5Config: + vocab_size = 250112 # vocab size rounded to a multiple of 128 for TPU efficiency + dtype = 'bfloat16' + emb_dim = 768 + num_heads = 12 + num_encoder_layers = 12 + num_decoder_layers = 12 + head_dim = 64 + mlp_dim = 2048 + mlp_activations = ('gelu', 'linear') + dropout_rate = %DROPOUT_RATE + logits_via_embedding = False diff --git a/t5x/examples/t5/mt5/large.gin b/t5x/examples/t5/mt5/large.gin new file mode 100644 index 0000000000000000000000000000000000000000..5b0ea1cd9243f3e9b072267a4530f501e2c3c06f --- /dev/null +++ b/t5x/examples/t5/mt5/large.gin @@ -0,0 +1,13 @@ +# mT5 Large model. + +include 't5x/examples/t5/mt5/base.gin' # imports vocab, optimizer and model. + +# ------------------- Network specification overrides -------------------------- +network.Transformer.config = @network.T5Config() +network.T5Config: + emb_dim = 1024 + num_heads = 16 + num_encoder_layers = 24 + num_decoder_layers = 24 + head_dim = 64 + mlp_dim = 2816 diff --git a/t5x/examples/t5/mt5/small.gin b/t5x/examples/t5/mt5/small.gin new file mode 100644 index 0000000000000000000000000000000000000000..e3f8192cab2016b1cd41b8c87ca8a78cb8b4cb64 --- /dev/null +++ b/t5x/examples/t5/mt5/small.gin @@ -0,0 +1,13 @@ +# mT5 Small model. + +include 't5x/examples/t5/mt5/base.gin' # imports vocab, optimizer and model. 
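For reference, the network specification in mt5/base.gin above corresponds to constructing the T5Config dataclass defined in t5x/examples/t5/network.py later in this change; size-specific files such as small.gin below override only a subset of these fields. A sketch (the dropout value here is a placeholder, since DROPOUT_RATE is %gin.REQUIRED and comes from the run gin):

import jax.numpy as jnp
from t5x.examples.t5 import network

mt5_base_config = network.T5Config(
    vocab_size=250112,
    dtype=jnp.bfloat16,      # the gin file passes the string 'bfloat16'
    emb_dim=768,
    num_heads=12,
    num_encoder_layers=12,
    num_decoder_layers=12,
    head_dim=64,
    mlp_dim=2048,
    mlp_activations=('gelu', 'linear'),
    dropout_rate=0.0,        # placeholder for %DROPOUT_RATE
    logits_via_embedding=False,
)
module = network.Transformer(config=mt5_base_config)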
+ +# ------------------- Network specification overrides -------------------------- +network.Transformer.config = @network.T5Config() +network.T5Config: + emb_dim = 512 + num_heads = 6 + num_encoder_layers = 8 + num_decoder_layers = 8 + head_dim = 64 + mlp_dim = 1024 diff --git a/t5x/examples/t5/mt5/tiny.gin b/t5x/examples/t5/mt5/tiny.gin new file mode 100644 index 0000000000000000000000000000000000000000..ed83eecd0b229ffd8b50561241e268d9cfc3ecfb --- /dev/null +++ b/t5x/examples/t5/mt5/tiny.gin @@ -0,0 +1,13 @@ +# T5.1.1 tiny model. + +include 't5x/examples/t5/t5_1_1/base.gin' # imports vocab, optimizer and model. + +# ------------------- Network specification overrides -------------------------- +network.Transformer.config = @network.T5Config() +network.T5Config: + emb_dim = 8 + num_heads = 4 + num_encoder_layers = 2 + num_decoder_layers = 2 + head_dim = 3 + mlp_dim = 16 diff --git a/t5x/examples/t5/mt5/xl.gin b/t5x/examples/t5/mt5/xl.gin new file mode 100644 index 0000000000000000000000000000000000000000..63178f5fc804346ac206797be6b6c5b0bf9a53c8 --- /dev/null +++ b/t5x/examples/t5/mt5/xl.gin @@ -0,0 +1,13 @@ +# mT5 XL model. + +include 't5x/examples/t5/mt5/base.gin' # imports vocab, optimizer and model. + +# ------------------- Network specification overrides -------------------------- +network.Transformer.config = @network.T5Config() +network.T5Config: + emb_dim = 2048 + num_heads = 32 + num_encoder_layers = 24 + num_decoder_layers = 24 + head_dim = 64 + mlp_dim = 5120 diff --git a/t5x/examples/t5/mt5/xxl.gin b/t5x/examples/t5/mt5/xxl.gin new file mode 100644 index 0000000000000000000000000000000000000000..e61a443d60cb92c6eb897f64cd5af1669a925129 --- /dev/null +++ b/t5x/examples/t5/mt5/xxl.gin @@ -0,0 +1,13 @@ +# mT5 XXL model. + +include 't5x/examples/t5/mt5/base.gin' # imports vocab, optimizer and model. + +# ------------------- Network specification overrides -------------------------- +network.Transformer.config = @network.T5Config() +network.T5Config: + emb_dim = 4096 + num_heads = 64 + num_encoder_layers = 24 + num_decoder_layers = 24 + head_dim = 64 + mlp_dim = 10240 diff --git a/t5x/examples/t5/network.py b/t5x/examples/t5/network.py new file mode 100644 index 0000000000000000000000000000000000000000..28bcbf17b55912b43f19443712a7aa83cccaaf11 --- /dev/null +++ b/t5x/examples/t5/network.py @@ -0,0 +1,424 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""T5.1.1 Transformer model.""" + +from typing import Any, Sequence + +from flax import linen as nn +from flax import struct +import jax.numpy as jnp +from t5x.examples.t5 import layers + + +@struct.dataclass +class T5Config: + """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + vocab_size: int + # Activation dtypes. + dtype: Any = jnp.float32 + emb_dim: int = 512 + num_heads: int = 8 + num_encoder_layers: int = 6 + num_decoder_layers: int = 6 + head_dim: int = 64 + mlp_dim: int = 2048 + # Activation functions are retrieved from Flax. 
+ mlp_activations: Sequence[str] = ('relu',) + dropout_rate: float = 0.1 + # If `True`, the embedding weights are used in the decoder output layer. + logits_via_embedding: bool = False + # Whether to accumulate attention logits in float32 regardless of dtype. + float32_attention_logits: bool = False + + +class EncoderLayer(nn.Module): + """Transformer encoder layer.""" + config: T5Config + relative_embedding: nn.Module + + @nn.compact + def __call__(self, inputs, encoder_mask=None, deterministic=False): + cfg = self.config + + # Relative position embedding as attention biases. + encoder_bias = self.relative_embedding(inputs.shape[-2], inputs.shape[-2], + True) + + # Attention block. + assert inputs.ndim == 3 + x = layers.LayerNorm( + dtype=cfg.dtype, name='pre_attention_layer_norm')( + inputs) + # [batch, length, emb_dim] -> [batch, length, emb_dim] + x = layers.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + head_dim=cfg.head_dim, + dropout_rate=cfg.dropout_rate, + float32_logits=cfg.float32_attention_logits, + name='attention')( + x, x, encoder_mask, encoder_bias, deterministic=deterministic) + x = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + x, deterministic=deterministic) + x = x + inputs + + # MLP block. + y = layers.LayerNorm(dtype=cfg.dtype, name='pre_mlp_layer_norm')(x) + # [batch, length, emb_dim] -> [batch, length, emb_dim] + y = layers.MlpBlock( + intermediate_dim=cfg.mlp_dim, + activations=cfg.mlp_activations, + intermediate_dropout_rate=cfg.dropout_rate, + dtype=cfg.dtype, + name='mlp', + )(y, deterministic=deterministic) + y = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + y, deterministic=deterministic) + y = y + x + + return y + + +class DecoderLayer(nn.Module): + """Transformer decoder layer that attends to the encoder.""" + config: T5Config + relative_embedding: nn.Module + + @nn.compact + def __call__(self, + inputs, + encoded, + decoder_mask=None, + encoder_decoder_mask=None, + deterministic=False, + decode=False, + max_decode_length=None): + cfg = self.config + + # Relative position embedding as attention biases. + l = max_decode_length if decode and max_decode_length else inputs.shape[-2] + decoder_bias = self.relative_embedding(l, l, False) + + # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] + x = layers.LayerNorm( + dtype=cfg.dtype, name='pre_self_attention_layer_norm')( + inputs) + + # Self-attention block + x = layers.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + head_dim=cfg.head_dim, + dropout_rate=cfg.dropout_rate, + float32_logits=cfg.float32_attention_logits, + name='self_attention')( + x, + x, + decoder_mask, + decoder_bias, + deterministic=deterministic, + decode=decode) + x = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + x, deterministic=deterministic) + x = x + inputs + + # Encoder-Decoder block. + y = layers.LayerNorm( + dtype=cfg.dtype, name='pre_cross_attention_layer_norm')( + x) + y = layers.MultiHeadDotProductAttention( + num_heads=cfg.num_heads, + dtype=cfg.dtype, + head_dim=cfg.head_dim, + dropout_rate=cfg.dropout_rate, + float32_logits=cfg.float32_attention_logits, + name='encoder_decoder_attention')( + y, encoded, encoder_decoder_mask, deterministic=deterministic) + y = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + y, deterministic=deterministic) + y = y + x + + # MLP block. 
+ z = layers.LayerNorm(dtype=cfg.dtype, name='pre_mlp_layer_norm')(y) + z = layers.MlpBlock( + intermediate_dim=cfg.mlp_dim, + activations=cfg.mlp_activations, + intermediate_dropout_rate=cfg.dropout_rate, + dtype=cfg.dtype, + name='mlp', + )(z, deterministic=deterministic) + z = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + z, deterministic=deterministic) + z = z + y + + return z + + +class Encoder(nn.Module): + """A stack of encoder layers.""" + config: T5Config + shared_embedding: nn.Module + + @nn.compact + def __call__(self, + encoder_input_tokens, + encoder_mask=None, + deterministic=False): + cfg = self.config + assert encoder_input_tokens.ndim == 2 # [batch, length] + rel_emb = layers.RelativePositionBiases( + num_buckets=32, + max_distance=128, + num_heads=cfg.num_heads, + dtype=cfg.dtype, + embedding_init=nn.initializers.variance_scaling(1.0, 'fan_avg', + 'uniform'), + name='relpos_bias') + + # [batch, length] -> [batch, length, emb_dim] + x = self.shared_embedding(encoder_input_tokens.astype('int32')) + x = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + x, deterministic=deterministic) + x = x.astype(cfg.dtype) + + for lyr in range(cfg.num_encoder_layers): + # [batch, length, emb_dim] -> [batch, length, emb_dim] + x = EncoderLayer( + config=cfg, relative_embedding=rel_emb, + name=f'layers_{lyr}')(x, encoder_mask, deterministic) + + x = layers.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x) + return nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic) + + +class Decoder(nn.Module): + """A stack of decoder layers as a part of an encoder-decoder architecture.""" + config: T5Config + shared_embedding: nn.Module + + @nn.compact + def __call__(self, + encoded, + decoder_input_tokens, + decoder_positions=None, + decoder_mask=None, + encoder_decoder_mask=None, + deterministic=False, + decode=False, + max_decode_length=None): + cfg = self.config + assert decoder_input_tokens.ndim == 2 # [batch, len] + rel_emb = layers.RelativePositionBiases( + num_buckets=32, + max_distance=128, + num_heads=cfg.num_heads, + dtype=cfg.dtype, + embedding_init=nn.initializers.variance_scaling(1.0, 'fan_avg', + 'uniform'), + name='relpos_bias') + + # [batch, length] -> [batch, length, emb_dim] + y = self.shared_embedding(decoder_input_tokens.astype('int32')) + y = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + y, deterministic=deterministic) + y = y.astype(cfg.dtype) + + for lyr in range(cfg.num_decoder_layers): + # [batch, length, emb_dim] -> [batch, length, emb_dim] + y = DecoderLayer( + config=cfg, relative_embedding=rel_emb, name=f'layers_{lyr}')( + y, + encoded, + decoder_mask=decoder_mask, + encoder_decoder_mask=encoder_decoder_mask, + deterministic=deterministic, + decode=decode, + max_decode_length=max_decode_length) + + y = layers.LayerNorm(dtype=cfg.dtype, name='decoder_norm')(y) + y = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + y, deterministic=deterministic) + + # [batch, length, emb_dim] -> [batch, length, vocab_size] + if cfg.logits_via_embedding: + # Use the transpose of embedding matrix for logit transform. + logits = self.shared_embedding.attend(y) + # Correctly normalize pre-softmax logits for this shared case. + logits = logits / jnp.sqrt(y.shape[-1]) + else: + logits = layers.DenseGeneral( + cfg.vocab_size, + dtype=jnp.float32, # Use float32 for stabiliity. 
+ kernel_axes=('embed', 'vocab'), + name='logits_dense')( + y) + return logits + + +class Transformer(nn.Module): + """An encoder-decoder Transformer model.""" + config: T5Config + + def setup(self): + cfg = self.config + self.shared_embedding = layers.Embed( + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, + dtype=cfg.dtype, + attend_dtype=jnp.float32, # for logit training stability + embedding_init=nn.initializers.normal(stddev=1.0), + one_hot=True, + name='token_embedder') + + self.encoder = Encoder(config=cfg, shared_embedding=self.shared_embedding) + self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding) + + def encode(self, + encoder_input_tokens, + encoder_segment_ids=None, + enable_dropout=True): + """Applies Transformer encoder-branch on the inputs.""" + cfg = self.config + assert encoder_input_tokens.ndim == 2 # (batch, len) + + # Make padding attention mask. + encoder_mask = layers.make_attention_mask( + encoder_input_tokens > 0, encoder_input_tokens > 0, dtype=cfg.dtype) + # Add segmentation block-diagonal attention mask if using segmented data. + if encoder_segment_ids is not None: + encoder_mask = layers.combine_masks( + encoder_mask, + layers.make_attention_mask( + encoder_segment_ids, + encoder_segment_ids, + jnp.equal, + dtype=cfg.dtype)) + + return self.encoder( + encoder_input_tokens, encoder_mask, deterministic=not enable_dropout) + + def decode( + self, + encoded, + encoder_input_tokens, # only needed for masks + decoder_input_tokens, + decoder_target_tokens, + encoder_segment_ids=None, + decoder_segment_ids=None, + decoder_positions=None, + enable_dropout=True, + decode=False, + max_decode_length=None): + """Applies Transformer decoder-branch on encoded-input and target.""" + cfg = self.config + + # Make padding attention masks. + if decode: + # Do not mask decoder attention based on targets padding at + # decoding/inference time. + decoder_mask = None + encoder_decoder_mask = layers.make_attention_mask( + jnp.ones_like(decoder_target_tokens), + encoder_input_tokens > 0, + dtype=cfg.dtype) + else: + decoder_mask = layers.make_decoder_mask( + decoder_target_tokens=decoder_target_tokens, + dtype=cfg.dtype, + decoder_segment_ids=decoder_segment_ids) + encoder_decoder_mask = layers.make_attention_mask( + decoder_target_tokens > 0, encoder_input_tokens > 0, dtype=cfg.dtype) + + # Add segmentation block-diagonal attention masks if using segmented data. + if encoder_segment_ids is not None: + if decode: + raise ValueError( + 'During decoding, packing should not be used but ' + '`encoder_segment_ids` was passed to `Transformer.decode`.') + + encoder_decoder_mask = layers.combine_masks( + encoder_decoder_mask, + layers.make_attention_mask( + decoder_segment_ids, + encoder_segment_ids, + jnp.equal, + dtype=cfg.dtype)) + + logits = self.decoder( + encoded, + decoder_input_tokens=decoder_input_tokens, + decoder_positions=decoder_positions, + decoder_mask=decoder_mask, + encoder_decoder_mask=encoder_decoder_mask, + deterministic=not enable_dropout, + decode=decode, + max_decode_length=max_decode_length) + return logits + + def __call__(self, + encoder_input_tokens, + decoder_input_tokens, + decoder_target_tokens, + encoder_segment_ids=None, + decoder_segment_ids=None, + encoder_positions=None, + decoder_positions=None, + *, + enable_dropout: bool = True, + decode: bool = False): + """Applies Transformer model on the inputs. + + This method requires both decoder_target_tokens and decoder_input_tokens, + which is a shifted version of the former. 
For a packed dataset, it usually
+    has additional processing applied. For example, the first element of each
+    sequence has id 0 instead of the shifted EOS id from the previous sequence.
+
+    Args:
+      encoder_input_tokens: input data to the encoder.
+      decoder_input_tokens: input token to the decoder.
+      decoder_target_tokens: target token to the decoder.
+      encoder_segment_ids: encoder segmentation info for packed examples.
+      decoder_segment_ids: decoder segmentation info for packed examples.
+      encoder_positions: encoder subsequence positions for packed examples.
+      decoder_positions: decoder subsequence positions for packed examples.
+      enable_dropout: Enables dropout if set to True.
+      decode: Whether to prepare and use an autoregressive cache.
+
+    Returns:
+      logits array from full transformer.
+    """
+    encoded = self.encode(
+        encoder_input_tokens,
+        encoder_segment_ids=encoder_segment_ids,
+        enable_dropout=enable_dropout)
+
+    return self.decode(
+        encoded,
+        encoder_input_tokens,  # only used for masks
+        decoder_input_tokens,
+        decoder_target_tokens,
+        encoder_segment_ids=encoder_segment_ids,
+        decoder_segment_ids=decoder_segment_ids,
+        decoder_positions=decoder_positions,
+        enable_dropout=enable_dropout,
+        decode=decode)
diff --git a/t5x/examples/t5/network_test.py b/t5x/examples/t5/network_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c7a25a4189db5c180f3834e62f6884a218ad919
--- /dev/null
+++ b/t5x/examples/t5/network_test.py
@@ -0,0 +1,111 @@
+# Copyright 2022 The T5X Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for network."""
+
+import os
+
+from absl import flags
+from absl.testing import absltest
+from absl.testing import parameterized
+import jax
+import numpy as np
+import seqio
+from t5x import adafactor
+from t5x import models
+from t5x import test_utils
+from t5x.examples.t5 import network
+
+# Parse absl flags test_srcdir and test_tmpdir.
+jax.config.parse_flags_with_absl() + +FLAGS = flags.FLAGS + + +def get_test_model(emb_dim, + head_dim, + num_heads, + mlp_dim, + dtype='float32', + vocab_size=32128, + num_encoder_layers=2, + num_decoder_layers=2): + config = network.T5Config( + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + vocab_size=vocab_size, + dropout_rate=0, + emb_dim=emb_dim, + num_heads=num_heads, + head_dim=head_dim, + mlp_dim=mlp_dim, + dtype=dtype, + mlp_activations=('gelu', 'linear')) + module = network.Transformer(config=config) + vocab = seqio.test_utils.sentencepiece_vocab() + optimizer_def = adafactor.Adafactor() + return models.EncoderDecoderModel( + module, vocab, vocab, optimizer_def=optimizer_def) + + +class NetworkTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + batch_size, max_decode_len, input_len = 2, 3, 4 + self.input_shapes = { + 'encoder_input_tokens': (batch_size, input_len), + 'decoder_input_tokens': (batch_size, max_decode_len) + } + np.random.seed(42) + self.batch = { + 'encoder_input_tokens': + np.random.randint(3, 10, size=(batch_size, input_len)), + 'decoder_input_tokens': + np.random.randint(3, 10, size=(batch_size, max_decode_len)), + 'decoder_target_tokens': + np.random.randint(3, 10, size=(batch_size, max_decode_len)) + } + + def test_t5_1_1_regression(self): + np.random.seed(0) + batch_size, max_decode_len, input_len = 2, 3, 4 + batch = { + 'encoder_input_tokens': + np.random.randint(3, 10, size=(batch_size, input_len)), + 'decoder_input_tokens': + np.random.randint(3, 10, size=(batch_size, max_decode_len)), + 'decoder_target_tokens': + np.random.randint(3, 10, size=(batch_size, max_decode_len)) + } + model = get_test_model( + emb_dim=13, + head_dim=64, + num_heads=8, + mlp_dim=2048, + vocab_size=10, + num_encoder_layers=3) + params = model.get_initial_variables( + jax.random.PRNGKey(42), self.input_shapes)['params'] + loss, _ = jax.jit(model.loss_fn)(params, batch, jax.random.PRNGKey(1)) + self.assertAlmostEqual(loss, 18.088945, delta=0.05) + + predicted, scores = model.predict_batch_with_aux(params, batch) + np.testing.assert_array_equal(predicted, [[7, 1, 0], [1, 0, 0]]) + np.testing.assert_allclose( + scores['scores'], [-3.0401115, -1.9265753], rtol=1e-3) + + +if __name__ == '__main__': + absltest.main() diff --git a/t5x/examples/t5/t5_1_0/11B.gin b/t5x/examples/t5/t5_1_0/11B.gin new file mode 100644 index 0000000000000000000000000000000000000000..003f659429befe3334ada2736a1f872f2fa440e7 --- /dev/null +++ b/t5x/examples/t5/t5_1_0/11B.gin @@ -0,0 +1,13 @@ +# T5.1.0 11B model. + +include 't5x/examples/t5/t5_1_0/base.gin' # imports vocab, optimizer and model. + +# ------------------- Network specification overrides -------------------------- +network.Transformer.config = @network.T5Config() +network.T5Config: + emb_dim = 1024 + num_heads = 128 + num_encoder_layers = 24 + num_decoder_layers = 24 + head_dim = 128 + mlp_dim = 65536 diff --git a/t5x/examples/t5/t5_1_0/3B.gin b/t5x/examples/t5/t5_1_0/3B.gin new file mode 100644 index 0000000000000000000000000000000000000000..ccfcbd88e227fdf20143a80404523b0c15337417 --- /dev/null +++ b/t5x/examples/t5/t5_1_0/3B.gin @@ -0,0 +1,13 @@ +# T5.1.0 3B model. + +include 't5x/examples/t5/t5_1_0/base.gin' # imports vocab, optimizer and model. 
+ +# ------------------- Network specification overrides -------------------------- +network.Transformer.config = @network.T5Config() +network.T5Config: + emb_dim = 1024 + num_heads = 32 + num_encoder_layers = 24 + num_decoder_layers = 24 + head_dim = 128 + mlp_dim = 16384 diff --git a/t5x/examples/t5/t5_1_0/__init__.py b/t5x/examples/t5/t5_1_0/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da022c16301721a096a208e8bdb2a71bb87f9788 --- /dev/null +++ b/t5x/examples/t5/t5_1_0/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This empty file is needed for loading the gin files in this directory. diff --git a/t5x/examples/t5/t5_1_0/base.gin b/t5x/examples/t5/t5_1_0/base.gin new file mode 100644 index 0000000000000000000000000000000000000000..5b7d1e34481004753ad21483df6106358ff67f06 --- /dev/null +++ b/t5x/examples/t5/t5_1_0/base.gin @@ -0,0 +1,55 @@ +# T5.1.0 Base model. +from __gin__ import dynamic_registration + +import seqio +from t5x import adafactor +from t5x import models +from t5x.examples.t5 import network + +# ------------------- Loss HParam ---------------------------------------------- +Z_LOSS = 0.0001 +LABEL_SMOOTHING = 0.0 +# NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF) +# the loss normalizing factor should be set to pretraining batch_size * +# target_token_length. +LOSS_NORMALIZING_FACTOR = None +# Dropout should be specified in the "run" files +DROPOUT_RATE = %gin.REQUIRED + +# Vocabulary (shared by encoder and decoder) +VOCABULARY = @seqio.SentencePieceVocabulary() +seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" + +# ------------------- Optimizer ------------------------------------------------ +# `learning_rate` is set by `Trainer.learning_rate_fn`. 
+OPTIMIZER = @adafactor.Adafactor() +adafactor.Adafactor: + decay_rate = 0.8 + step_offset = 0 + logical_factor_rules = @adafactor.standard_logical_factor_rules() + +# ------------------- Model ---------------------------------------------------- +MODEL = @models.EncoderDecoderModel() +models.EncoderDecoderModel: + module = @network.Transformer() + input_vocabulary = %VOCABULARY + output_vocabulary = %VOCABULARY + optimizer_def = %OPTIMIZER + z_loss = %Z_LOSS + label_smoothing = %LABEL_SMOOTHING + loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR + +# ------------------- Network specification ------------------------------------ +network.Transformer.config = @network.T5Config() +network.T5Config: + vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency + dtype = 'bfloat16' + emb_dim = 768 + num_heads = 12 + num_encoder_layers = 12 + num_decoder_layers = 12 + head_dim = 64 + mlp_dim = 3072 + mlp_activations = ('relu',) + dropout_rate = %DROPOUT_RATE + logits_via_embedding = True diff --git a/t5x/examples/t5/t5_1_0/large.gin b/t5x/examples/t5/t5_1_0/large.gin new file mode 100644 index 0000000000000000000000000000000000000000..07d1b8eeb32f6948cda159c9a0233ef1fdaa5303 --- /dev/null +++ b/t5x/examples/t5/t5_1_0/large.gin @@ -0,0 +1,13 @@ +# T5.1.0 Large model. + +include 't5x/examples/t5/t5_1_0/base.gin' # imports vocab, optimizer and model. + +# ------------------- Network specification overrides -------------------------- +network.Transformer.config = @network.T5Config() +network.T5Config: + emb_dim = 1024 + num_heads = 16 + num_encoder_layers = 24 + num_decoder_layers = 24 + head_dim = 64 + mlp_dim = 4096 diff --git a/t5x/examples/t5/t5_1_0/small.gin b/t5x/examples/t5/t5_1_0/small.gin new file mode 100644 index 0000000000000000000000000000000000000000..3c86b02a2dcfcfe18da1ee78abf763b3209dfdaf --- /dev/null +++ b/t5x/examples/t5/t5_1_0/small.gin @@ -0,0 +1,13 @@ +# T5.1.1 Small model. + +include 't5x/examples/t5/t5_1_0/base.gin' # imports vocab, optimizer and model. + +# ------------------- Network specification overrides -------------------------- +network.Transformer.config = @network.T5Config() +network.T5Config: + emb_dim = 512 + num_heads = 8 + num_encoder_layers = 6 + num_decoder_layers = 6 + head_dim = 64 + mlp_dim = 2048 diff --git a/t5x/examples/t5/t5_1_0/tiny.gin b/t5x/examples/t5/t5_1_0/tiny.gin new file mode 100644 index 0000000000000000000000000000000000000000..ed83eecd0b229ffd8b50561241e268d9cfc3ecfb --- /dev/null +++ b/t5x/examples/t5/t5_1_0/tiny.gin @@ -0,0 +1,13 @@ +# T5.1.1 tiny model. + +include 't5x/examples/t5/t5_1_1/base.gin' # imports vocab, optimizer and model. + +# ------------------- Network specification overrides -------------------------- +network.Transformer.config = @network.T5Config() +network.T5Config: + emb_dim = 8 + num_heads = 4 + num_encoder_layers = 2 + num_decoder_layers = 2 + head_dim = 3 + mlp_dim = 16 diff --git a/t5x/examples/t5/t5_1_1/__init__.py b/t5x/examples/t5/t5_1_1/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da022c16301721a096a208e8bdb2a71bb87f9788 --- /dev/null +++ b/t5x/examples/t5/t5_1_1/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This empty file is needed for loading the gin files in this directory. diff --git a/t5x/examples/t5/t5_1_1/base.gin b/t5x/examples/t5/t5_1_1/base.gin new file mode 100644 index 0000000000000000000000000000000000000000..0dbc43566b75815a991c5f2d9351cac257c56e66 --- /dev/null +++ b/t5x/examples/t5/t5_1_1/base.gin @@ -0,0 +1,55 @@ +# T5.1.1 Base model. +from __gin__ import dynamic_registration + +import seqio +from t5x import adafactor +from t5x import models +from t5x.examples.t5 import network + +# ------------------- Loss HParam ---------------------------------------------- +Z_LOSS = 0.0001 +LABEL_SMOOTHING = 0.0 +# NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF) +# the loss normalizing factor should be set to pretraining batch_size * +# target_token_length. +LOSS_NORMALIZING_FACTOR = None +# Dropout should be specified in the "run" files +DROPOUT_RATE = %gin.REQUIRED + +# Vocabulary (shared by encoder and decoder) +VOCABULARY = @seqio.SentencePieceVocabulary() +seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" + +# ------------------- Optimizer ------------------------------------------------ +# `learning_rate` is set by `Trainer.learning_rate_fn`. +OPTIMIZER = @adafactor.Adafactor() +adafactor.Adafactor: + decay_rate = 0.8 + step_offset = 0 + logical_factor_rules = @adafactor.standard_logical_factor_rules() + +# ------------------- Model ---------------------------------------------------- +MODEL = @models.EncoderDecoderModel() +models.EncoderDecoderModel: + module = @network.Transformer() + input_vocabulary = %VOCABULARY + output_vocabulary = %VOCABULARY + optimizer_def = %OPTIMIZER + z_loss = %Z_LOSS + label_smoothing = %LABEL_SMOOTHING + loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR + +# ------------------- Network specification ------------------------------------ +network.Transformer.config = @network.T5Config() +network.T5Config: + vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency + dtype = 'bfloat16' + emb_dim = 768 + num_heads = 12 + num_encoder_layers = 12 + num_decoder_layers = 12 + head_dim = 64 + mlp_dim = 2048 + mlp_activations = ('gelu', 'linear') + dropout_rate = %DROPOUT_RATE + logits_via_embedding = False diff --git a/t5x/examples/t5/t5_1_1/examples/__init__.py b/t5x/examples/t5/t5_1_1/examples/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da022c16301721a096a208e8bdb2a71bb87f9788 --- /dev/null +++ b/t5x/examples/t5/t5_1_1/examples/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# This empty file is needed for loading the gin files in this directory. diff --git a/t5x/examples/t5/t5_1_1/examples/base_c4_pretrain.gin b/t5x/examples/t5/t5_1_1/examples/base_c4_pretrain.gin new file mode 100644 index 0000000000000000000000000000000000000000..8f211f918750f84156202c1a66e00e4c40a476b6 --- /dev/null +++ b/t5x/examples/t5/t5_1_1/examples/base_c4_pretrain.gin @@ -0,0 +1,19 @@ +# Register necessary SeqIO Tasks/Mixtures. +from __gin__ import dynamic_registration +import t5.data.mixtures +import __main__ as train_script + + +include 't5x/examples/t5/t5_1_1/base.gin' +include 't5x/configs/runs/pretrain.gin' + + +MIXTURE_OR_TASK_NAME = "c4_v220_span_corruption" +TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 114} +TRAIN_STEPS = 100000 +DROPOUT_RATE = 0.0 +BATCH_SIZE = 256 + + +train_script.train: + eval_period = 2000 diff --git a/t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_eval.gin b/t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_eval.gin new file mode 100644 index 0000000000000000000000000000000000000000..ca551ab2ed0348ce3024a2ef20f1464cd15d73ae --- /dev/null +++ b/t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_eval.gin @@ -0,0 +1,46 @@ +from __gin__ import dynamic_registration + +import __main__ as eval_script +import seqio +from t5.data import mixtures +from t5x import partitioning +from t5x import utils +from t5x import models + +include "t5x/examples/t5/t5_1_1/base.gin" # defines %MODEL. + +INITIAL_CHECKPOINT_PATH = %gin.REQUIRED +EVAL_OUTPUT_DIR = %gin.REQUIRED + +DROPOUT_RATE = 0.0 # unused boilerplate + + +eval_script.evaluate: + model = %MODEL # imported from separate gin file + dataset_cfg = @utils.DatasetConfig() + partitioner = @partitioning.PjitPartitioner() + restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() + output_dir = %EVAL_OUTPUT_DIR + inference_evaluator_cls = @seqio.Evaluator + + +seqio.Evaluator: + logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] + num_examples = None # Use all examples in the dataset. + use_memory_cache = True + + +utils.DatasetConfig: + mixture_or_task_name = %MIXTURE_OR_TASK_NAME + task_feature_lengths = None # Auto-computes the max feature lengths. + split = 'test' + batch_size = 32 + shuffle = False + seed = 42 + +partitioning.PjitPartitioner.num_partitions = 1 +models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4 + +utils.RestoreCheckpointConfig: + path = %INITIAL_CHECKPOINT_PATH + mode = 'specific' diff --git a/t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_finetune.gin b/t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_finetune.gin new file mode 100644 index 0000000000000000000000000000000000000000..588b264b57c8644723a39a0af70a215f96fac2bc --- /dev/null +++ b/t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_finetune.gin @@ -0,0 +1,51 @@ +from __gin__ import dynamic_registration + +import __main__ as train_script +import seqio +import t5.data.mixtures +from t5x import utils +from t5x import models + + +include 't5x/configs/runs/finetune.gin' +include 't5x/examples/t5/t5_1_1/base.gin' + +BATCH_SIZE = 128 +MIXTURE_OR_TASK_NAME = "wmt14_enfr_v003" +TASK_FEATURE_LENGTHS = {'inputs': 256, 'targets': 256} +DROPOUT_RATE = 0.1 +TRAIN_STEPS = 1_020_000 # 1000000 pre-trained steps + 20000 fine-tuning steps. 
+ +INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_base/checkpoint_1000000" + +# `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained +# using Mesh Tensorflow (e.g. the public T5 / mT5 / ByT5 models), this should be +# set to `pretraining batch_size` * `target_token_length`. For T5 and T5.1.1: +# `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`. +LOSS_NORMALIZING_FACTOR = 233472 + +train_script.train: + eval_period = 100 + +train_script.train: + train_dataset_cfg = @train/utils.DatasetConfig() + train_eval_dataset_cfg = @train_eval/utils.DatasetConfig() + infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig() + +models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4 + +infer_eval/utils.DatasetConfig: + mixture_or_task_name = %MIXTURE_OR_TASK_NAME + task_feature_lengths = %TASK_FEATURE_LENGTHS + split = 'validation' + batch_size = 64 + shuffle = False + seed = 42 + use_cached = %USE_CACHED_TASKS + pack = False + module = %MIXTURE_OR_TASK_MODULE + +seqio.Evaluator: + logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] + num_examples = None # Use all examples in the dataset. + use_memory_cache = True diff --git a/t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_train.gin b/t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_train.gin new file mode 100644 index 0000000000000000000000000000000000000000..82094888d9640d43ae2fcbfc1cfd2947c45321d6 --- /dev/null +++ b/t5x/examples/t5/t5_1_1/examples/base_wmt14enfr_train.gin @@ -0,0 +1,18 @@ +from __gin__ import dynamic_registration +import t5.data.mixtures +import __main__ as train_script + + +include 't5x/configs/runs/pretrain.gin' +include 't5x/examples/t5/t5_1_1/base.gin' + + +TRAIN_STEPS = 100000 +BATCH_SIZE = 128 +MIXTURE_OR_TASK_NAME = "wmt14_enfr_v003" +TASK_FEATURE_LENGTHS = {'inputs': 256, 'targets': 256} +DROPOUT_RATE = 0.1 +INITIAL_CHECKPOINT_PATH = %gin.REQUIRED + +train_script.train: + eval_period = 2000 diff --git a/t5x/examples/t5/t5_1_1/examples/base_wmt19_ende_train.gin b/t5x/examples/t5/t5_1_1/examples/base_wmt19_ende_train.gin new file mode 100644 index 0000000000000000000000000000000000000000..bd583c75bf22bda578478634a4c65231ca41fe84 --- /dev/null +++ b/t5x/examples/t5/t5_1_1/examples/base_wmt19_ende_train.gin @@ -0,0 +1,62 @@ +from __gin__ import dynamic_registration + +import __main__ as train_script +from t5x import adafactor +from t5x import models +from t5x import partitioning +from t5x import trainer +from t5x import utils +from t5x.examples.t5 import network + +include "t5x/examples/t5/t5_1_1/base.gin" +include "t5x/configs/runs/finetune.gin" + +MIXTURE_OR_TASK_NAME = "wmt19_ende_v003" +MIXTURE_OR_TASK_MODULE = "t5.data.mixtures" +TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512} +TRAIN_STEPS = 5000 +LABEL_SMOOTHING = 0.1 +INITIAL_CHECKPOINT_PATH = None +# Note that `DROPOUT_RATE = 0.1` is specified in the finetune.gin but we just +# repeat to make it explicit. 
+DROPOUT_RATE = 0.1 + +train/utils.DatasetConfig: + batch_size = 128 + use_cached = False + pack = True + use_custom_packing_ops = False + seed = 0 + +train_eval/utils.DatasetConfig: + batch_size = 128 + use_cached = False + pack = False + use_custom_packing_ops = False + seed = 0 + +infer_eval/utils.DatasetConfig: + use_cached = False + +train_script.train: + eval_period = 250 + eval_steps = 20 + random_seed = 0 + use_hardware_rng = True + +utils.CheckpointConfig.restore = None +utils.SaveCheckpointConfig: + period = 500 # checkpoint frequency + keep = 1 + +# Decoder overrides +models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4 + +trainer.Trainer.num_microbatches = 2 +utils.create_learning_rate_scheduler.warmup_steps = 1000 + +partitioning.PjitPartitioner: + model_parallel_submesh = (1, 1, 1, 2) + +adafactor.Adafactor: + logical_factor_rules = @adafactor.standard_logical_factor_rules() diff --git a/t5x/examples/t5/t5_1_1/examples/base_wmt_eval.gin b/t5x/examples/t5/t5_1_1/examples/base_wmt_eval.gin new file mode 100644 index 0000000000000000000000000000000000000000..9f8bf0ab7d4d6f0ca78ff47df8371d365602f5c0 --- /dev/null +++ b/t5x/examples/t5/t5_1_1/examples/base_wmt_eval.gin @@ -0,0 +1,34 @@ +from __gin__ import dynamic_registration + +import __main__ as eval_script +from t5.data import mixtures +from t5x import partitioning +from t5x import utils + +include "t5x/examples/t5/t5_1_1/base.gin" # defines %MODEL. + +CHECKPOINT_PATH = %gin.REQUIRED # passed via commandline +EVAL_OUTPUT_DIR = %gin.REQUIRED # passed via commandline + +DROPOUT_RATE = 0.0 # unused boilerplate +MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003" + +eval_script.evaluate: + model = %MODEL # imported from separate gin file + dataset_cfg = @utils.DatasetConfig() + restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() + output_dir = %EVAL_OUTPUT_DIR + +utils.DatasetConfig: + mixture_or_task_name = %MIXTURE_OR_TASK_NAME + task_feature_lengths = None # Auto-computes the max feature lengths. 
+ split = 'test' + batch_size = 32 + shuffle = False + seed = 42 + +partitioning.PjitPartitioner.num_partitions = 2 + +utils.RestoreCheckpointConfig: + path = %CHECKPOINT_PATH + mode = 'specific' diff --git a/t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin b/t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin new file mode 100644 index 0000000000000000000000000000000000000000..a981a46092ec8e74db7869092003c3ba9dada953 --- /dev/null +++ b/t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin @@ -0,0 +1,63 @@ +from __gin__ import dynamic_registration + +import __main__ as train_script +import seqio +from t5.data import mixtures +from t5x import models +from t5x import partitioning +from t5x import utils + +include "t5x/examples/t5/t5_1_1/base.gin" +include "t5x/configs/runs/pretrain.gin" + +MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003" +TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 256} +TRAIN_STEPS = 50000 +DROPOUT_RATE = 0.0 + +train/utils.DatasetConfig: + batch_size = 128 + use_cached = False + pack = True + seed = 0 + +train_eval/utils.DatasetConfig: + batch_size = 128 + use_cached = False + pack = True + seed = 0 + +infer_eval/utils.DatasetConfig: + mixture_or_task_name = %MIXTURE_OR_TASK_NAME + task_feature_lengths = None # compute max + split = "validation" + seed = 0 + batch_size = 128 + shuffle = False + use_cached = False + +train_script.train: + eval_period = 500 + eval_steps = 20 + random_seed = 0 + use_hardware_rng = True + infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig() + inference_evaluator_cls = @seqio.Evaluator + +seqio.Evaluator: + logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] + num_examples = None # Use all examples in the infer_eval dataset. + use_memory_cache = True + +utils.SaveCheckpointConfig: + period = 5000 # checkpoint frequency + +# `num_decodes` is equivalent to a beam size in a beam search decoding. +models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4 + +partitioning.PjitPartitioner.num_partitions = 2 + +utils.create_learning_rate_scheduler: + factors = 'constant * rsqrt_decay' + base_learning_rate = 1.0 + warmup_steps = 10000 diff --git a/t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch_adamw.gin b/t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch_adamw.gin new file mode 100644 index 0000000000000000000000000000000000000000..5c4ae4468349439a9fabde65c25255af66c94756 --- /dev/null +++ b/t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch_adamw.gin @@ -0,0 +1,51 @@ +# This gin file is to show how to switch to an optimizer other than +# Adafactor. Gin configuration makes it easy by simply importing any available +# optimizer in t5x/optimizers module. Note the optimizers in t5x/optimizers are +# wrapped version of optimizers implemented in optax. + +from __gin__ import dynamic_registration + +from t5x import optimizers +from t5x import utils +import optax + +include "t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" + +# In this case, we choose to switch to the AdamW optimizer with gradient clip. +OPTIMIZER = @optimizers.chain() + +optimizers.chain: + transformations = [@optax.clip(), @optax.adamw()] + +optax.clip: + max_delta = 1.0 + +optax.adamw: + # Unlike Adafactor, most optimizers require to specify + # `learning_rate`. `learning_rate` accepts a float number (e.g., 1e-4) or + # a schedule function, which should take an argument `step` and output + # a learning rate for that step. 
+ # As for choices of schedule functions, we can either use T5x + # learning rate scheduler, i.e., utils.create_learning_rate_scheduler, or + # optax's native schedule functions, e.g., warmup_cosine_decay_schedule. + learning_rate = @optax.warmup_cosine_decay_schedule() + +optax.warmup_cosine_decay_schedule: + init_value = 0.0 + peak_value = 1e-4 + warmup_steps = 1000 + decay_steps = %TRAIN_STEPS + end_value = 0.0 + + +# Below is an example of using the T5X's schedule functions. +# Feel free to uncomment to try. +# optax.adamw: +# learning_rate = @utils.create_learning_rate_scheduler() + +# utils.create_learning_rate_scheduler: +# factors = 'constant * linear_warmup * rsqrt_decay' +# base_learning_rate = 0.01 +# warmup_steps = 10000 + + diff --git a/t5x/examples/t5/t5_1_1/examples/base_wmt_infer.gin b/t5x/examples/t5/t5_1_1/examples/base_wmt_infer.gin new file mode 100644 index 0000000000000000000000000000000000000000..73898092b78ae06ba1fec42634b3fe8e3092a6d7 --- /dev/null +++ b/t5x/examples/t5/t5_1_1/examples/base_wmt_infer.gin @@ -0,0 +1,19 @@ +from __gin__ import dynamic_registration + +import __main__ as infer_script +from t5.data import mixtures +from t5x import partitioning +from t5x import utils + +include "t5x/examples/t5/t5_1_1/base.gin" +include "t5x/configs/runs/infer.gin" + +DROPOUT_RATE = 0.0 # unused but needs to be specified +MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003" +TASK_FEATURE_LENGTHS = {"inputs": 64, "targets": 64} + +partitioning.PjitPartitioner.num_partitions = 1 + +utils.DatasetConfig: + split = "test" + batch_size = 32 diff --git a/t5x/examples/t5/t5_1_1/examples/small_c4_pretrain.gin b/t5x/examples/t5/t5_1_1/examples/small_c4_pretrain.gin new file mode 100644 index 0000000000000000000000000000000000000000..4a4ccbeadbd5f118d6cfe2c0d0385a3f5d233d8e --- /dev/null +++ b/t5x/examples/t5/t5_1_1/examples/small_c4_pretrain.gin @@ -0,0 +1,11 @@ +include 't5x/examples/t5/t5_1_1/small.gin' +include 't5x/configs/runs/pretrain.gin' + +# Register necessary SeqIO Tasks/Mixtures. +import t5.data.mixtures + +MIXTURE_OR_TASK_NAME = "c4_v220_span_corruption" +TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 114} +TRAIN_STEPS = 10000 +DROPOUT_RATE = 0.0 +BATCH_SIZE = 256 diff --git a/t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin b/t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin new file mode 100644 index 0000000000000000000000000000000000000000..f948cd9f18883516f00f9939c2d0aac069b8ee7c --- /dev/null +++ b/t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin @@ -0,0 +1,21 @@ +from __gin__ import dynamic_registration + +import __main__ as train_script +from t5.data import mixtures +from t5x import models +from t5x import partitioning +from t5x import utils + +include "t5x/examples/t5/t5_1_1/small.gin" +include "t5x/configs/runs/finetune.gin" + +MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003" +TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 256} +TRAIN_STEPS = 1_020_000 # 1000000 pre-trained steps + 20000 fine-tuning steps. +DROPOUT_RATE = 0.0 +INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000" +# `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained +# using Mesh Tensorflow (e.g. the public T5 / mT5 / ByT5 models), this should be +# set to `pretraining batch_size` * `target_token_length`. For T5 and T5.1.1: +# `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`. 
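+# Worked out: 2048 * 114 = 233472, which is the value set below; the mT5 and
+# ByT5 factors quoted above would be 1024 * 229 = 234496 and
+# 1024 * 189 = 193536, respectively.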
+LOSS_NORMALIZING_FACTOR = 233472 diff --git a/t5x/examples/t5/t5_1_1/examples/test_train_t5_tiny.gin b/t5x/examples/t5/t5_1_1/examples/test_train_t5_tiny.gin new file mode 100644 index 0000000000000000000000000000000000000000..9006ad4dcf107a0a9ae12c7c7113bb772d532cfc --- /dev/null +++ b/t5x/examples/t5/t5_1_1/examples/test_train_t5_tiny.gin @@ -0,0 +1,56 @@ +# Test config to exercise train.py with model-based pjit partitioning. + +from __gin__ import dynamic_registration + +import __main__ as train_script +from t5x import adafactor +from t5x import models +from t5x import partitioning +from t5x import trainer +from t5x import utils + +include 't5x/configs/runs/pretrain.gin' +include 't5x/examples/t5/t5_1_1/tiny.gin' + +MODEL_DIR = "/tmp" # Will be overridden in test. + +TRAIN_STEPS = 3 +MIXTURE_OR_TASK_MODULE = "t5.data.mixtures" +MIXTURE_OR_TASK_NAME = "wmt19_ende_v003" +TASK_FEATURE_LENGTHS = {"inputs": 32, "targets": 32} +DROPOUT_RATE = 0.0 + +models.EncoderDecoderModel: + z_loss = 0.0 + label_smoothing = 0.0 + loss_normalizing_factor = None + + +train/utils.DatasetConfig: + pack = False + seed = 0 + shuffle = False + use_cached = False + batch_size = 8 + +train_eval/utils.DatasetConfig: + pack = False + seed = 0 + shuffle = False + use_cached = False + batch_size = 8 + +train_script.train: + random_seed = 0 + eval_steps = 2 + actions={'TRAIN_EVAL': [@trainer.TerminateOnNanAction()]} + +trainer.TerminateOnNanAction: + task = %MIXTURE_OR_TASK_NAME + +partitioning.PjitPartitioner.num_partitions = 2 +utils.SaveCheckpointConfig.period = 4 + +# Overriding from pretrain.gin to keep magic constants in tests. +utils.create_learning_rate_scheduler: + warmup_steps = 1000 diff --git a/t5x/examples/t5/t5_1_1/large.gin b/t5x/examples/t5/t5_1_1/large.gin new file mode 100644 index 0000000000000000000000000000000000000000..6d92ef41984399ff6cc87b869e45b55fb1860a42 --- /dev/null +++ b/t5x/examples/t5/t5_1_1/large.gin @@ -0,0 +1,13 @@ +# T5.1.1 Large model. + +include 't5x/examples/t5/t5_1_1/base.gin' # imports vocab, optimizer and model. + +# ------------------- Network specification overrides -------------------------- +network.Transformer.config = @network.T5Config() +network.T5Config: + emb_dim = 1024 + num_heads = 16 + num_encoder_layers = 24 + num_decoder_layers = 24 + head_dim = 64 + mlp_dim = 2816 diff --git a/t5x/examples/t5/t5_1_1/small.gin b/t5x/examples/t5/t5_1_1/small.gin new file mode 100644 index 0000000000000000000000000000000000000000..1c4f9d0dc6f89fc0bbd88d7116fedf508de8cc03 --- /dev/null +++ b/t5x/examples/t5/t5_1_1/small.gin @@ -0,0 +1,13 @@ +# T5.1.1 Small model. + +include 't5x/examples/t5/t5_1_1/base.gin' # imports vocab, optimizer and model. + +# ------------------- Network specification overrides -------------------------- +network.Transformer.config = @network.T5Config() +network.T5Config: + emb_dim = 512 + num_heads = 6 + num_encoder_layers = 8 + num_decoder_layers = 8 + head_dim = 64 + mlp_dim = 1024 diff --git a/t5x/examples/t5/t5_1_1/tiny.gin b/t5x/examples/t5/t5_1_1/tiny.gin new file mode 100644 index 0000000000000000000000000000000000000000..ed83eecd0b229ffd8b50561241e268d9cfc3ecfb --- /dev/null +++ b/t5x/examples/t5/t5_1_1/tiny.gin @@ -0,0 +1,13 @@ +# T5.1.1 tiny model. + +include 't5x/examples/t5/t5_1_1/base.gin' # imports vocab, optimizer and model. 
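+# (Like the other size variants in this directory, this file only overrides
+# the network.T5Config hyperparameters; the vocabulary, optimizer and model
+# wiring all come from the included base.gin.)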
+ +# ------------------- Network specification overrides -------------------------- +network.Transformer.config = @network.T5Config() +network.T5Config: + emb_dim = 8 + num_heads = 4 + num_encoder_layers = 2 + num_decoder_layers = 2 + head_dim = 3 + mlp_dim = 16 diff --git a/t5x/examples/t5/t5_1_1/xl.gin b/t5x/examples/t5/t5_1_1/xl.gin new file mode 100644 index 0000000000000000000000000000000000000000..34f8cd6f312729454480a83822c1d8ff8920c242 --- /dev/null +++ b/t5x/examples/t5/t5_1_1/xl.gin @@ -0,0 +1,13 @@ +# T5.1.1 XL model. + +include 't5x/examples/t5/t5_1_1/base.gin' # imports vocab, optimizer and model. + +# ------------------- Network specification overrides -------------------------- +network.Transformer.config = @network.T5Config() +network.T5Config: + emb_dim = 2048 + num_heads = 32 + num_encoder_layers = 24 + num_decoder_layers = 24 + head_dim = 64 + mlp_dim = 5120 diff --git a/t5x/examples/t5/t5_1_1/xxl.gin b/t5x/examples/t5/t5_1_1/xxl.gin new file mode 100644 index 0000000000000000000000000000000000000000..1d4828687bfd79c78e47977bd2ff520efe3f9d1a --- /dev/null +++ b/t5x/examples/t5/t5_1_1/xxl.gin @@ -0,0 +1,13 @@ +# T5.1.1 XXL model. + +include 't5x/examples/t5/t5_1_1/base.gin' # imports vocab, optimizer and model. + +# ------------------- Network specification overrides -------------------------- +network.Transformer.config = @network.T5Config() +network.T5Config: + emb_dim = 4096 + num_heads = 64 + num_encoder_layers = 24 + num_decoder_layers = 24 + head_dim = 64 + mlp_dim = 10240 diff --git a/t5x/gin_utils.py b/t5x/gin_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5d9b98c7cc0839e47b34071cde6114e5c7912f7b --- /dev/null +++ b/t5x/gin_utils.py @@ -0,0 +1,122 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for using gin configurations with T5X binaries.""" +import os +from typing import Optional, Sequence + +from absl import app +from absl import logging +from clu import metric_writers +import gin +import jax +import tensorflow as tf + + + +def parse_gin_flags(gin_search_paths: Sequence[str], + gin_files: Sequence[str], + gin_bindings: Sequence[str], + skip_unknown: bool = False, + finalize_config: bool = True): + """Parses provided gin files override params. + + Args: + gin_search_paths: paths that will be searched for gin files. + gin_files: paths to gin config files to be parsed. Files will be parsed in + order with conflicting settings being overriden by later files. Paths may + be relative to paths in `gin_search_paths`. + gin_bindings: individual gin bindings to be applied after the gin files are + parsed. Will be applied in order with conflicting settings being overriden + by later oens. + skip_unknown: whether to ignore unknown bindings or raise an error (default + behavior). + finalize_config: whether to finalize the config so that it cannot be + modified (default behavior). + """ + # We import t5.data here since it includes gin configurable functions commonly + # used by task modules. 
+ # TODO(adarob): Strip gin from t5.data and remove this import. + import t5.data # pylint:disable=unused-import,g-import-not-at-top + # Register .gin file search paths with gin + for gin_file_path in gin_search_paths: + gin.add_config_file_search_path(gin_file_path) + + + # Parse config files and bindings passed via flag. + gin.parse_config_files_and_bindings( + gin_files, + gin_bindings, + skip_unknown=skip_unknown, + finalize_config=finalize_config) + logging.info('Gin Configuration:\n%s', gin.config_str()) + + +def rewrite_gin_args(args: Sequence[str]) -> Sequence[str]: + """Rewrite `--gin.NAME=VALUE` flags to `--gin_bindings=NAME=VALUE`.""" + + def _rewrite_gin_arg(arg): + if not arg.startswith('--gin.'): + return arg + if '=' not in arg: + raise ValueError( + "Gin bindings must be of the form '--gin.=', got: " + + arg) + # Strip '--gin.' + arg = arg[6:] + name, value = arg.split('=', maxsplit=1) + r_arg = f'--gin_bindings={name} = {value}' + print(f'Rewritten gin arg: {r_arg}') + return r_arg + + return [_rewrite_gin_arg(arg) for arg in args] + + +@gin.register +def summarize_gin_config(model_dir: str, + summary_writer: Optional[metric_writers.MetricWriter], + step: int): + """Writes gin config to the model dir and TensorBoard summary.""" + if jax.process_index() == 0: + config_str = gin.config_str() + tf.io.gfile.makedirs(model_dir) + # Write the config as JSON. + with tf.io.gfile.GFile(os.path.join(model_dir, 'config.gin'), 'w') as f: + f.write(config_str) + # Include a raw dump of the json as a text summary. + if summary_writer is not None: + summary_writer.write_texts(step, {'config': gin.markdown(config_str)}) + summary_writer.flush() + + +def run(main): + """Wrapper for app.run that rewrites gin args before parsing.""" + app.run( + main, + flags_parser=lambda a: app.parse_flags_with_usage(rewrite_gin_args(a))) + + +# ====================== Configurable Utility Functions ====================== + + +@gin.configurable +def sum_fn(var1=gin.REQUIRED, var2=gin.REQUIRED): + """sum function to use inside gin files.""" + return var1 + var2 + + +@gin.configurable +def bool_fn(var1=gin.REQUIRED): + """bool function to use inside gin files.""" + return bool(var1) diff --git a/t5x/gin_utils_test.py b/t5x/gin_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e69b73991c64a4e450199476946244422b62d94e --- /dev/null +++ b/t5x/gin_utils_test.py @@ -0,0 +1,59 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
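A short usage sketch of the flag rewriting implemented in gin_utils.py above (illustrative only, assuming the t5x package is importable; the expected output mirrors the test cases below):

    from t5x import gin_utils

    # `--gin.NAME=VALUE` flags become `--gin_bindings=NAME = VALUE`;
    # all other flags pass through unchanged.
    args = ['--gin_file=path/to/file', '--gin.DROPOUT_RATE=0.1']
    print(gin_utils.rewrite_gin_args(args))
    # -> ['--gin_file=path/to/file', '--gin_bindings=DROPOUT_RATE = 0.1']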
+ +"""Tests for gin_utils.""" + +from absl.testing import absltest +from t5x import gin_utils + + +class GinUtilsTest(absltest.TestCase): + + def test_rewrite_gin_args(self): + test_args = [ + '--gin_file=path/to/file', + 'gin.value=3', + '--gin.value=3', + '--gin.value="3"', + '--gin.value=\'3\'', + '--gin.tricky="key = value"', + '--gin.dict={"foo": 4, "bar": "four"}', + '--gin.gin=bar', + '--gin.scope/foo=bar', + ] + expected_args = [ + '--gin_file=path/to/file', + 'gin.value=3', + '--gin_bindings=value = 3', + '--gin_bindings=value = "3"', + '--gin_bindings=value = \'3\'', + '--gin_bindings=tricky = "key = value"', + '--gin_bindings=dict = {"foo": 4, "bar": "four"}', + '--gin_bindings=gin = bar', + '--gin_bindings=scope/foo = bar', + ] + self.assertSequenceEqual( + gin_utils.rewrite_gin_args(test_args), expected_args) + + def test_rewrite_gin_args_malformed(self): + test_args = ['--gin.value=3', '--gin.test'] + with self.assertRaisesWithLiteralMatch( + ValueError, + "Gin bindings must be of the form '--gin.=', got: " + '--gin.test'): + gin_utils.rewrite_gin_args(test_args) + + +if __name__ == '__main__': + absltest.main() diff --git a/t5x/infer.py b/t5x/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..9ce18f7e1e818db4482d829265d26cb30a6c8974 --- /dev/null +++ b/t5x/infer.py @@ -0,0 +1,732 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint:disable=line-too-long +# pyformat: disable +r"""This script runs inference on a T5X-compatible model. + +""" +# pyformat: enable +# pylint:enable=line-too-long + +import concurrent.futures +import functools +import hashlib +import json +import os +import re +import shutil +import time +from typing import Any, Callable, Iterator, List, Mapping, Optional, Sequence, Tuple, Type + +# TODO(adarob): Re-enable once users are notified and tests are updated. +# Must be set before flax imports. +# pylint:disable=g-import-not-at-top +os.environ['FLAX_LAZY_RNG'] = 'no' +from absl import logging +from clu import metric_writers +import jax +from jax.experimental import multihost_utils +import jax.numpy as jnp +import numpy as np +import seqio +from t5x import gin_utils +from t5x import models +from t5x import partitioning +from t5x import utils +import tensorflow as tf +from tensorflow.io import gfile +from typing_extensions import Protocol + +# Automatically search for gin files relative to the T5X package. +_DEFAULT_GIN_SEARCH_PATHS = [ + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +] + +AUTOTUNE = tf.data.experimental.AUTOTUNE + + +class SummarizeConfigFn(Protocol): + + def __call__(self, model_dir: str, + summary_writer: Optional[metric_writers.SummaryWriter], + step: int) -> None: + ... + + +class FailFastThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor): + """Wrapper for ThreadPoolExecutor that crashes main thread on exceptions. + + NOTE: this class should be used only from the main thread. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._incomplete_futures: List[concurrent.futures.Future] = [] + + def check_for_exceptions(self, wait: bool = False): + """Raises any exceptions from complete futures on the main thread.""" + still_incomplete_futures = [] + for future in self._incomplete_futures: + try: + exception = future.exception(timeout=0 if wait else None) + except concurrent.futures.TimeoutError: + still_incomplete_futures.append(future) + if exception is not None: + raise exception + + self._incomplete_futures = still_incomplete_futures + + def submit(self, *args, **kwargs) -> concurrent.futures.Future: + """Submit function to threadpool, capturing the returned future.""" + future = super().submit(*args, **kwargs) + self._incomplete_futures.append(future) + self.check_for_exceptions(wait=False) + return future + + def shutdown(self, *args, wait: bool = False, **kwargs): + self.check_for_exceptions(wait=wait) + super().shutdown(*args, **kwargs) + + +def create_task_from_tfexample_file(paths: Sequence[str], + file_type: str, + inputs_key: str, + targets_key: Optional[str], + features: Mapping[str, seqio.Feature], + task_id: Optional[str] = None) -> str: + """Registers ad-hoc Task for file-based dataset of TFExamples. + + Args: + paths: Input file paths; all files should have type `file_type` and contain + binary-serialized TFExample protos. + file_type: Input file type; e.g., 'tfrecord', 'recordio', 'sstable'. For + keyed formats like 'sstable', we ignore the keys and use only the values. + inputs_key: Name of TFExample feature containing the input text for T5X. The + value of this feature should be a UTF8-encoded string. + targets_key: Optional name of a TFExample feature containing the target text + (relevant only in scoring mode). The value of this feature should be a + UTF8-encoded string. + features: Should have entries for keys 'inputs' and (if targets_key is not + None) 'targets', mapping to `seqio.Feature` objects that specify + attributes like vocabulary, add_eos, etc. These attributes are used for + preprocessing and featurizing the input text. + task_id: Task name identifier. By default, it is set to a unique and + deterministic hash id. Overrideable via this argument. + + Returns: + Name of the newly-registered Task. This Task has a split named 'infer' that + contains the preprocessed and featurized input dataset. + """ + # tf.io.gfile.glob supports lists, in contrast to gfile.glob. + files = tf.io.gfile.glob(paths) + if files: + logging.info('Using tfexample files %s', files) + else: + # Fail early if there's something wrong with the input file pattern. + raise ValueError('Missing or invalid paths: %s' % paths) + reader = { + 'tfrecord': + tf.data.TFRecordDataset, + }[file_type] + + feature_description = {inputs_key: tf.io.FixedLenFeature([], tf.string)} + if targets_key: + feature_description[targets_key] = tf.io.FixedLenFeature([], tf.string) + + # Create a unique, deterministic task name. 
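+  # (When `task_id` is not provided, it is derived below from the first 10 hex
+  # digits of an MD5 over the colon-joined input paths and feature keys, so
+  # re-running with identical inputs registers the same Task name.)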
+ if task_id is None: + task_id = hashlib.md5( + ':'.join(list(paths) + + [inputs_key, targets_key or '']).encode()).hexdigest()[:10] + + task = seqio.TaskRegistry.add( + name=f'infer_{task_id}', + source=seqio.TFExampleDataSource({'infer': paths}, + feature_description=feature_description, + reader_cls=reader), + preprocessors=[ + functools.partial( + seqio.preprocessors.rekey, + key_map={ + 'inputs': inputs_key, + 'targets': targets_key + }), seqio.preprocessors.tokenize_and_append_eos + ], + output_features=features) + + return task.name + + +def merge_chunks_to_file( + output_dir: str, + output_fname: str, + tmp_dir: str, + step: Optional[int], +) -> None: + """Merge the predictions from different chunks into a unified file.""" + logging.info('Merging chunk results.') + # Merge chunks into single file. + chunk_paths = sorted( + gfile.glob(os.path.join(tmp_dir, f'{output_fname}-chunk?????'))) + + if not chunk_paths: + raise FileNotFoundError( + 'No chunk results found! One possible explanation is that your ' + 'input did not contain any examples') + + assert int(chunk_paths[-1][-5:]) + 1 == len(chunk_paths), ( + f'Expecting {int(chunk_paths[-1][-5:])} chunk paths, found ' + f'{len(chunk_paths)}') + output_path = os.path.join(output_dir, output_fname) + del step + with gfile.GFile(output_path, 'wb') as merged: + for chunk_path in chunk_paths: + with gfile.GFile(chunk_path, 'rb') as ef: + shutil.copyfileobj(ef, merged) + logging.info('Results written to %s.', output_path) + + +_Inferences = Tuple[Sequence[Any], Mapping[str, Any]] + + +def write_inferences_to_file( + path: str, + inferences: _Inferences, + task_ds: tf.data.Dataset, + mode: str, + vocabulary: Optional[seqio.Vocabulary] = None, + json_encoder_cls: Type[json.JSONEncoder] = seqio.TensorAndNumpyEncoder, + include_all_inputs: bool = False, + input_fields_to_include: Optional[Sequence[str]] = None, + output_ids: bool = False) -> None: + """Write model predictions, along with pretokenized inputs, to JSONL file. + + Args: + path: File path to write to. + inferences: A tuple containing (predictions, aux_values). If mode is + 'predict' then the `predictions` will be token IDs. If it's + 'scores' then it'll be a collection of scores. `aux_values` will be an + empty dictionary unless mode is 'predict_with_aux', in which case it'll + contain the model's auxiliary outputs. + task_ds: Original task dataset. Features from task with suffix + `_pretokenized` are added to the outputs. + mode: Prediction mode, either 'predict', 'score' or 'predict_with_aux'. + vocabulary: Task output vocabulary. Only used in `predict` mode in order to + decode predicted outputs into string. + json_encoder_cls: a JSON encoder class used to customize JSON serialization + via json.dumps. + include_all_inputs: if True, will include all model inputs in the output + JSONL file (including raw tokens) in addition to the pretokenized inputs. + input_fields_to_include: List of input fields to include in the output JSONL + file. This list should be None if `include_all_inputs` is set to True. + output_ids: if True, will output the token ID sequence for the output, in + addition to the decoded text. 
+ """ + all_predictions, all_aux_values = inferences + + if mode in ('predict', 'predict_with_aux') and vocabulary is None: + raise ValueError('The `vocabulary` parameter is required in `predict` and ' + '`predict_with_aux` modes') + + def _json_compat(value): + if isinstance(value, bytes): + return value.decode('utf-8') + elif isinstance(value, (jnp.bfloat16, jnp.floating)): + return float(value) + elif isinstance(value, jnp.integer): + return float(value) + elif isinstance(value, (jnp.ndarray, np.ndarray)): + # Flatten array features. + return value.tolist() + else: + return value + + if include_all_inputs and input_fields_to_include is not None: + raise ValueError( + 'include_all_inputs and input_fields_to_include should not be set' + ' simultaneously.') + with gfile.GFile(path, 'w') as f: + for i, inp in task_ds.enumerate().as_numpy_iterator(): + predictions = all_predictions[i] + aux_values = {aux_field: v[i] for aux_field, v in all_aux_values.items()} + + if include_all_inputs: + inputs = inp + elif input_fields_to_include is not None: + inputs = { + k: v for k, v in inp.items() if k in input_fields_to_include or + (k.endswith('_pretokenized') and + k[:-len('_pretokenized')] in input_fields_to_include) + } + else: + inputs = {k: v for k, v in inp.items() if k.endswith('_pretokenized')} + + json_dict = {} + json_dict['inputs'] = {k: _json_compat(v) for k, v in inputs.items()} + + if mode == 'predict': + assert vocabulary is not None + json_dict['prediction'] = _json_compat( + vocabulary.decode_tf(tf.constant(predictions)).numpy()) + if output_ids: + pred = _json_compat(tf.constant(predictions).numpy()) + # Truncate padding tokens. + assert isinstance(pred, list) + pred = pred[:pred.index(0)] if 0 in pred else pred + json_dict['prediction_tokens'] = pred + elif mode == 'score': + json_dict['score'] = _json_compat(predictions) + elif mode == 'predict_with_aux': + assert vocabulary is not None + json_dict['prediction'] = _json_compat( + vocabulary.decode_tf(tf.constant(predictions)).numpy()) + if output_ids: + pred = _json_compat(tf.constant(predictions).numpy()) + # Truncate padding tokens. 
+ pred = pred[:pred.index(0)] if 0 in pred else pred + json_dict['prediction_tokens'] = pred + json_dict['aux'] = jax.tree_map(_json_compat, aux_values) + else: + raise ValueError(f'Invalid mode: {mode}') + json_str = json.dumps(json_dict, cls=json_encoder_cls) + f.write(json_str + '\n') + + +WriteFn = Callable[[ + str, + _Inferences, + tf.data.Dataset, + str, + Optional[seqio.Vocabulary], +], None] + +MergeFn = Callable[[str, str, str, Optional[int]], None] + + +def _extract_tokens_and_aux_values(inference_fn_outputs) -> _Inferences: + """Extracts tokens and aux scores from a cached dataset.""" + all_aux_values = {} + if isinstance(inference_fn_outputs, tuple): + indices_and_tokens, all_aux_values = inference_fn_outputs + indices, tokens = zip(*indices_and_tokens) + + permutation = np.argsort(indices) + + tokens = [tokens[permutation[i]] for i in range(len(permutation))] + for aux_keys, aux_values in all_aux_values.items(): + all_aux_values[aux_keys] = [ + aux_values[permutation[i]] for i in range(len(permutation)) + ] + + else: + indices_and_tokens = inference_fn_outputs + _, tokens = zip(*sorted(indices_and_tokens, key=lambda x: x[0])) + + return tokens, all_aux_values + + +def infer( + *, + mode: str, + model: models.BaseTransformerModel, + dataset_cfg: utils.DatasetConfig, + restore_checkpoint_cfg: utils.RestoreCheckpointConfig, + partitioner: partitioning.BasePartitioner, + output_dir: str, + checkpoint_period: int, + shard_id: int = 0, + num_shards: int = 1, + merge_chunked_results: bool = True, + write_fn: WriteFn = write_inferences_to_file, + checkpoint_ds_iter: bool = True, + fallback_init_rng: Optional[int] = None, + merge_fn: MergeFn = merge_chunks_to_file, + summarize_config_fn: SummarizeConfigFn = gin_utils.summarize_gin_config, +): + """Infer function. + + Args: + mode: Either 'predict' to decode targets, 'score' to compute the log + likelihood of given targets, or 'predict_with_aux' for both. + model: The model object to use for inference. + dataset_cfg: Specification for the dataset to infer based on. + restore_checkpoint_cfg: Specification for the model parameter checkpoint to + load. + partitioner: Partitioner for model parameters and data across devices. + output_dir: Path to directory to write temporary files and final results. + checkpoint_period: The intermediate results and dataset iterator will be + checkpointed on each multiple of this number of batches to enable + continuation after a failure. + shard_id: Index of dataset shard for this instance to use if splitting the + work across multiple jobs. + num_shards: Total number of dataset shards to split dataset across. + merge_chunked_results: Whether to merge results of all chunks into a single + json file. + write_fn: Callable function used to serialized and write inferences out to + files. + checkpoint_ds_iter: if True, will checkpoint the dataset iterator every + `checkpoint_period` to enable faster restore. This must be disabled for + certain datasets, for example since stateful iterators (e.g. from + seqio.FunctionTask) cannot be checkpointed. + fallback_init_rng: A random seed used for parameter initialization during + model re-loading when utils.RestoreCheckpointConfig.fallback_to_scratch is + set to True. If None, parameter initialization is not allowed during model + loading and having fallback_to_scratch enabled will result in an error. + merge_fn: Callable function used to merge inferences from multiple files. 
+ summarize_config_fn: A function that takes in the model directory, an + optional SummaryWriter, and the step number, and writes a summary of the + configuration. SummaryWriter will be None in most cases. + """ + logging.info('Process ID: %d', jax.process_index()) + + summarize_config_fn(model_dir=output_dir, summary_writer=None, step=0) + + if mode not in ('predict', 'score', 'predict_with_aux'): + raise ValueError( + "`mode` must be one of 'predict', 'score' or 'predict_with_aux'. " + f"Got '{mode}'") + + # Remove double-slashes in directory path to avoid inconsistencies. + output_dir = re.sub(r'(? 1: + raise app.UsageError('Too many command-line arguments.') + + if FLAGS.tfds_data_dir: + seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir) + + + # Create gin-configurable version of `infer`. + infer_using_gin = gin.configurable(infer) + + gin_utils.parse_gin_flags( + # User-provided gin paths take precedence if relative paths conflict. + FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS, + FLAGS.gin_file, + FLAGS.gin_bindings) + + # See http://yaqs/7882016229479677952 for further gin-config discussion. + def _get_gin_parameter(key: str) -> Any: + value = gin.query_parameter(key) + if isinstance(value, gin.config.ConfigurableReference): + if value.evaluate: + return value.scoped_configurable_fn() + return value.scoped_configurable_fn + return value + + shard_id = ( + FLAGS.shard_id + if FLAGS.shard_id is not None else _get_gin_parameter('infer.shard_id')) + if shard_id == 0: + gin_utils.summarize_gin_config( + model_dir=_get_gin_parameter('infer.output_dir'), + summary_writer=None, + step=0) + if FLAGS.shard_id is not None: + # We fall back to this flag since XM does not support sweeps over flags + # with '.' in them (it treats them like nested dictionaries). + # TODO(adarob): Figure out a workaround so we can deprecate this flag. + infer_using_gin(shard_id=FLAGS.shard_id) + else: + infer_using_gin() + + + gin_utils.run(main) diff --git a/t5x/losses.py b/t5x/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..5f421162811f8a708b191872bba05440e74b78a6 --- /dev/null +++ b/t5x/losses.py @@ -0,0 +1,264 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Loss functions.""" +import enum +from typing import Tuple, Mapping, Optional, Union + +from flax.training import common_utils +import jax +import jax.numpy as jnp +import numpy as np + + +@jax.custom_vjp +def cross_entropy_with_logits(logits: jnp.ndarray, targets: jnp.ndarray, + z_loss: float) -> jnp.ndarray: + """Computes cross entropy loss with stable custom gradient. + + Computes a stabilized-gradient version of: + -jnp.sum(targets * nn.log_softmax(logits), axis=-1) + + If z_loss > 0, then an auxiliary loss equal to z_loss*log(z)^2 + will be added to the cross entropy loss (z = softmax normalization constant). + The two uses of z_loss are: + 1. To keep the logits from drifting too far from zero, which can cause + unacceptable roundoff errors in bfloat16. + 2. 
To encourage the logits to be normalized log-probabilities. + + Args: + logits: [batch, length, num_classes] float array. + targets: categorical one-hot targets [batch, length, num_classes] float + array. + z_loss: coefficient for auxilliary z-loss loss term. + + Returns: + tuple with the total loss and the z_loss, both + float arrays with shape [batch, length]. + """ + logits_sum = jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True) + log_softmax = logits - logits_sum + loss = -jnp.sum(targets * log_softmax, axis=-1) + # Add auxilliary z-loss term. + log_z = jnp.squeeze(logits_sum, axis=-1) + total_z_loss = z_loss * jax.lax.square(log_z) + loss += total_z_loss + return loss, total_z_loss + + +def _cross_entropy_with_logits_fwd( + logits: jnp.ndarray, + targets: jnp.ndarray, + z_loss: float = 0.0 +) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp + .ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]]: + """Forward-mode of `cross_entropy_with_logits`.""" + max_logit = logits.max(axis=-1, keepdims=True) + shifted = logits - max_logit + exp_shifted = jnp.exp(shifted) + sum_exp = jnp.sum(exp_shifted, axis=-1, keepdims=True) + log_softmax = shifted - jnp.log(sum_exp) + loss = -jnp.sum(targets * log_softmax, axis=-1) + # Add auxilliary z-loss term. + log_z = jnp.squeeze(jnp.log(sum_exp) + max_logit, axis=-1) + total_z_loss = z_loss * jax.lax.square(log_z) + loss += total_z_loss + return (loss, total_z_loss), (logits, targets, z_loss, exp_shifted, sum_exp, + log_softmax, log_z) + + +def _cross_entropy_with_logits_bwd( + res: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, + jnp.ndarray, jnp.ndarray], g: Tuple[jnp.ndarray, jnp.ndarray] +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Backward-mode of `cross_entropy_with_logits`.""" + g = g[0] # Ignore z_loss component as that is only used for logging. + logits, targets, z_loss, exp_shifted, sum_exp, log_softmax, log_z = res + # z-loss term adds the (2 * z_loss * log_z) factor. + deriv = ( + jnp.expand_dims(1 + 2 * z_loss * log_z, -1) * exp_shifted / sum_exp - + targets) + g_logits = jnp.expand_dims(g, axis=-1) * deriv + g_targets = -jnp.expand_dims(g, axis=-1) * log_softmax + return (jnp.asarray(g_logits, + logits.dtype), jnp.asarray(g_targets, targets.dtype), + jnp.array(0.0)) # sets z-loss coeff gradient to 0 + + +cross_entropy_with_logits.defvjp(_cross_entropy_with_logits_fwd, + _cross_entropy_with_logits_bwd) + + +def compute_weighted_cross_entropy( + logits: jnp.ndarray, + targets: jnp.ndarray, + weights: Optional[jnp.ndarray] = None, + label_smoothing: float = 0.0, + z_loss: float = 0.0, + loss_normalizing_factor: Optional[float] = None +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Compute weighted cross entropy and entropy for log probs and targets. + + Args: + logits: [batch, length, num_classes] float array. + targets: categorical targets [batch, length] int array. + weights: None or array of shape [batch, length]. + label_smoothing: label smoothing constant, used to determine the on and off + values. + z_loss: coefficient for auxiliary z-loss loss term. + loss_normalizing_factor: Constant to divide loss by. If not specified, loss + will not be normalized. Intended for backward compatibility with T5-MTF + training. Should not normally be used. + + Returns: + Tuple of scalar loss, z_loss, and weight sum. + """ + if logits.ndim != targets.ndim + 1: + raise ValueError('Incorrect shapes. 
Got shape %s logits and %s targets' % + (str(logits.shape), str(targets.shape))) + vocab_size = logits.shape[-1] + confidence = 1.0 - label_smoothing + low_confidence = (1.0 - confidence) / (vocab_size - 1) + normalizing_constant = -( + confidence * jnp.log(confidence) + + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)) + soft_targets = common_utils.onehot( + targets, vocab_size, on_value=confidence, off_value=low_confidence) + total_loss, total_z_loss = cross_entropy_with_logits( + logits, soft_targets, z_loss=z_loss) + total_loss = total_loss - normalizing_constant + + weight_sum = np.prod(targets.shape) + if weights is not None: + total_loss = total_loss * weights + total_z_loss = total_z_loss * weights + weight_sum = jnp.sum(weights) + + # By default, we do not normalize loss based on anything. + # We don't normalize based on batch size because the optimizers we use are + # pretty much scale invariant, so this simplifies things. + # We don't normalize based on number of non-padding tokens in order to treat + # each token as equally important regardless of sequence length. + if loss_normalizing_factor is not None: + total_loss /= loss_normalizing_factor + total_z_loss /= loss_normalizing_factor + return jnp.sum(total_loss), jnp.sum(total_z_loss), weight_sum + + +@enum.unique +class SpecialLossNormalizingFactor(enum.Enum): + """Specially calcualted loss_normalizing_factors, that are not a constant. + + Attributes: + NUM_REAL_TARGET_TOKENS: Whether to divide the loss by the number of real + (non-padding) tokens in the current target batch. If + 'decoder_loss_weights' are specified, it will be the sum of the weights. + Otherwise it will be the number of non-zero 'decoder_target_tokens'. + NUM_TOTAL_TARGET_TOKENS: Whether to divide the loss by the total number of + target tokens, i.e., batch_size * target_seq_length (including padding). + AVERAGE_PER_SEQUENCE: This will first compute the per-sequence loss + (averaged over the number of real target tokens in the sequence), and then + compute the average of that over the sequences. This can be preferable to + NUM_REAL_TARGET_TOKENS for finetuning, because it will weigh all examples + equally, regardless of sequence length (which can be especially important + for multi-task finetuning). + """ + NUM_REAL_TARGET_TOKENS = 1 + NUM_TOTAL_TARGET_TOKENS = 2 + AVERAGE_PER_SEQUENCE = 3 + + +def convert_special_loss_normalizing_factor_to_enum( + x: str) -> SpecialLossNormalizingFactor: + """Converts stringified version of LNF to an enum. + + This is useful because gin dynamic registration does not (currently) + have support for enum. + + Args: + x: stringified version of SpecialLossNormalizingFactor enum. + + Returns: + SpecialLossNormalizingFactor enum instance. + """ + x = x.upper() + if x == 'NUM_REAL_TARGET_TOKENS': + return SpecialLossNormalizingFactor.NUM_REAL_TARGET_TOKENS + if x == 'NUM_TOTAL_TARGET_TOKENS': + return SpecialLossNormalizingFactor.NUM_TOTAL_TARGET_TOKENS + if x == 'AVERAGE_PER_SEQUENCE': + return SpecialLossNormalizingFactor.AVERAGE_PER_SEQUENCE + raise ValueError( + 'Could not convert string \"%s\" to SpecialLossNormalizingFactor' % x) + + +def get_loss_normalizing_factor_and_weights( + loss_normalizing_factor: Optional[Union[float, int, str, + SpecialLossNormalizingFactor]], + batch: Mapping[str, jnp.ndarray]): + """Get the float loss_normalizing_factor and loss weights. + + If loss_normalizing_factor is float or None, this will simply return the + input loss_normalizing_factor and batch. 
+ + If loss_normalizing_factor is a SpecialLossNormalizingFactor, it will + return a float loss_normalizing_factor and loss weights corresponding to + the special LNF. See SpecialLossNormalizingFactor for more details. + + Args: + loss_normalizing_factor: The input LNF, which may be a float, None, or + SpecialLossNormalizingFactor (or a stringified SLNF). + batch: Input data batch. + + Returns: + Tuple of (output_loss_normalizing_factor, loss_weights). + 'output_loss_normalizing_factor' is a scalar float (Python float + or jnp float). + 'loss_weights' is the per token loss weight JNP array. + """ + + loss_weights = batch.get('decoder_loss_weights', None) + if (loss_normalizing_factor is None or + not isinstance(loss_normalizing_factor, + (str, SpecialLossNormalizingFactor))): + return (loss_normalizing_factor, loss_weights) + + if isinstance(loss_normalizing_factor, str): + loss_normalizing_factor = convert_special_loss_normalizing_factor_to_enum( + loss_normalizing_factor) + + # If `loss_weights` are not provided, we assume that the padding id is 0 and + # that non-padding tokens in the decoder all correspond to the positions + # where loss should be taken. If more fine-grained behavior (e.g., taking + # loss on subset of 'decoder_target_tokens') is desired, provide + # `loss_weights` that account for this. + if loss_weights is None: + loss_weights = jnp.asarray(batch['decoder_target_tokens'] > 0, jnp.float32) + + output_normalizing_factor = None + if (loss_normalizing_factor == + SpecialLossNormalizingFactor.NUM_REAL_TARGET_TOKENS): + output_normalizing_factor = jnp.sum(loss_weights) + elif (loss_normalizing_factor == + SpecialLossNormalizingFactor.NUM_TOTAL_TARGET_TOKENS): + output_normalizing_factor = np.prod(batch['decoder_target_tokens'].shape) + elif (loss_normalizing_factor == + SpecialLossNormalizingFactor.AVERAGE_PER_SEQUENCE): + loss_weights /= jnp.sum(loss_weights, axis=-1, keepdims=True) + 1e-3 + output_normalizing_factor = jnp.sum(loss_weights) + else: + raise ValueError('Unsupported value of loss_normalizing_factor: %s' % + str(loss_normalizing_factor)) + + return (output_normalizing_factor, loss_weights) diff --git a/t5x/losses_test.py b/t5x/losses_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f7287cbbee5f2c74133a1acb51112b32a015c354 --- /dev/null +++ b/t5x/losses_test.py @@ -0,0 +1,136 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
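A brief usage sketch of the loss-normalization helper defined in losses.py above (illustrative only, assuming t5x and jax are importable; the numbers match the first test case below):

    import jax.numpy as jnp
    from t5x import losses

    # With the default padding id of 0, only real target tokens receive weight.
    batch = {
        'decoder_target_tokens': jnp.asarray([[1, 2, 3, 4, 0],
                                              [5, 6, 0, 0, 0]], jnp.int32)
    }
    factor, weights = losses.get_loss_normalizing_factor_and_weights(
        losses.SpecialLossNormalizingFactor.NUM_REAL_TARGET_TOKENS, batch)
    # factor == 6.0 (six non-padding tokens); `weights` is 1.0 at those
    # positions and 0.0 at the padded ones.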
+ +"""Tests for t5x.losses.""" + +from absl.testing import absltest +import jax +import jax.numpy as jnp +import numpy as np +from t5x import losses + + +class LossTest(absltest.TestCase): + + def test_xent(self): + + def lossfn(logits, targets, weights): + loss, z_loss, weight_sum = losses.compute_weighted_cross_entropy( + logits, + targets, + weights, + label_smoothing=0.1, + z_loss=0.1, + loss_normalizing_factor=0.1) + return loss, (z_loss, weight_sum) + + batch_size = 2 + length = 4 + vocab_size = 8 + logits = np.random.normal(size=(batch_size, length, + vocab_size)).astype(np.float32) + targets = np.random.randint(0, vocab_size, size=(batch_size, length)) + weights = np.ones_like(targets) + out = jax.jit(jax.value_and_grad(lossfn, has_aux=True))(logits, targets, + weights) + (loss, (z_loss, weight_sum)), dlogits = out + # Just a smoke test for now + # TODO(t5x): Expand test + print(jax.device_get(((loss, (z_loss, weight_sum)), dlogits))) + + +class SpecialLossNormalizingFactorTest(absltest.TestCase): + + def test_num_real_target_tokens(self): + batch = { + 'decoder_target_tokens': + jnp.asarray([[1, 2, 3, 4, 0], [5, 6, 0, 0, 0]], jnp.int32) + } + + (output_lnf, + output_loss_weights) = losses.get_loss_normalizing_factor_and_weights( + loss_normalizing_factor=losses.SpecialLossNormalizingFactor + .NUM_REAL_TARGET_TOKENS, + batch=batch) + + np.testing.assert_allclose(output_lnf, 6.0, rtol=1e-3) + np.testing.assert_allclose( + output_loss_weights, + np.array([[1.0, 1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0, 0.0]], + dtype=np.float32), + rtol=1e-3) + + def test_num_total_target_tokens(self): + batch = { + 'decoder_target_tokens': + jnp.asarray([[1, 2, 3, 4, 0], [5, 6, 0, 0, 0]], jnp.int32) + } + + (output_lnf, + output_loss_weights) = losses.get_loss_normalizing_factor_and_weights( + loss_normalizing_factor=losses.SpecialLossNormalizingFactor + .NUM_TOTAL_TARGET_TOKENS, + batch=batch) + + np.testing.assert_allclose(output_lnf, 10.0, rtol=1e-3) + np.testing.assert_allclose( + output_loss_weights, + np.array([[1.0, 1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0, 0.0]], + dtype=np.float32), + rtol=1e-3) + + def test_average_per_sequence(self): + batch = { + 'decoder_target_tokens': + jnp.asarray([[1, 2, 3, 4, 0], [5, 6, 0, 0, 0]], jnp.int32) + } + + (output_lnf, + output_loss_weights) = losses.get_loss_normalizing_factor_and_weights( + loss_normalizing_factor=losses.SpecialLossNormalizingFactor + .AVERAGE_PER_SEQUENCE, + batch=batch) + + np.testing.assert_allclose(output_lnf, 2.0, rtol=1e-3) + np.testing.assert_allclose( + output_loss_weights, + jnp.asarray([[0.25, 0.25, 0.25, 0.25, 0.0], [0.5, 0.5, 0.0, 0.0, 0.0]], + jnp.float32), + rtol=1e-3) + + def test_average_per_sequence_with_weights(self): + batch = { + 'decoder_target_tokens': + jnp.asarray([[1, 2, 3, 4, 0], [5, 6, 0, 0, 0]], jnp.int32), + 'decoder_loss_weights': + jnp.asarray([[0.5, 1.0, 0.25, 2.0, 0.0], [1.0, 1.0, 0.0, 0.0, 0.0]], + jnp.float32) + } + + (output_lnf, + output_loss_weights) = losses.get_loss_normalizing_factor_and_weights( + loss_normalizing_factor=losses.SpecialLossNormalizingFactor + .AVERAGE_PER_SEQUENCE, + batch=batch) + + np.testing.assert_allclose(output_lnf, 2.0, rtol=1e-3) + np.testing.assert_allclose( + output_loss_weights, + jnp.asarray( + [[0.1333, 0.2666, 0.0666, 0.5333, 0.0], [0.5, 0.5, 0.0, 0.0, 0.0]], + jnp.float32), + rtol=1e-3) + +if __name__ == '__main__': + absltest.main() diff --git a/t5x/main.py b/t5x/main.py new file mode 100644 index 
0000000000000000000000000000000000000000..47d8441297c58a5162e1eef51a013a88dbcf831e --- /dev/null +++ b/t5x/main.py @@ -0,0 +1,165 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""The main entrance for running any of the T5X supported binaries. + +Currently this includes train/infer/eval/precompile. + +Example Local (CPU) Pretrain Gin usage + +python -m t5x.main \ + --gin_file=t5x/examples/t5/t5_1_1/tiny.gin \ + --gin_file=t5x/configs/runs/pretrain.gin \ + --gin.MODEL_DIR=\"/tmp/t5x_pretrain\" \ + --gin.TRAIN_STEPS=10 \ + --gin.MIXTURE_OR_TASK_NAME=\"c4_v220_span_corruption\" \ + --gin.MIXTURE_OR_TASK_MODULE=\"t5.data.mixtures\" \ + --gin.TASK_FEATURE_LENGTHS="{'inputs': 128, 'targets': 30}" \ + --gin.DROPOUT_RATE=0.1 \ + --run_mode=train \ + --logtostderr +""" +import concurrent.futures # pylint:disable=unused-import +import enum +import os +from typing import Optional, Sequence + +from absl import app +from absl import flags +from absl import logging + +import gin +import jax +import seqio + +from t5x import eval as eval_lib +from t5x import gin_utils +from t5x import infer as infer_lib +from t5x import precompile as precompile_lib +from t5x import train as train_lib +from t5x import utils + + +@enum.unique +class RunMode(enum.Enum): + """All the running mode possible in T5X.""" + TRAIN = 'train' + EVAL = 'eval' + INFER = 'infer' + PRECOMPILE = 'precompile' + + +_GIN_FILE = flags.DEFINE_multi_string( + 'gin_file', + default=None, + help='Path to gin configuration file. Multiple paths may be passed and ' + 'will be imported in the given order, with later configurations ' + 'overriding earlier ones.') + +_GIN_BINDINGS = flags.DEFINE_multi_string( + 'gin_bindings', default=[], help='Individual gin bindings.') + +_GIN_SEARCH_PATHS = flags.DEFINE_list( + 'gin_search_paths', + default=['.'], + help='Comma-separated list of gin config path prefixes to be prepended ' + 'to suffixes given via `--gin_file`. If a file appears in. Only the ' + 'first prefix that produces a valid path for each suffix will be ' + 'used.') + +_RUN_MODE = flags.DEFINE_enum_class( + 'run_mode', + default=None, + enum_class=RunMode, + help='The mode to run T5X under') + +_TFDS_DATA_DIR = flags.DEFINE_string( + 'tfds_data_dir', None, + 'If set, this directory will be used to store datasets prepared by ' + 'TensorFlow Datasets that are not available in the public TFDS GCS ' + 'bucket. Note that this flag overrides the `tfds_data_dir` attribute of ' + 'all `Task`s.') + +_DRY_RUN = flags.DEFINE_bool( + 'dry_run', False, + 'If set, does not start the function but stil loads and logs the config.') + + +FLAGS = flags.FLAGS + +# Automatically search for gin files relative to the T5X package. 
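+# (This default resolves to the directory that contains the `t5x` package, so
+# configs can be passed as repo-relative paths such as
+# "t5x/configs/runs/pretrain.gin" without extra --gin_search_paths entries.)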
+_DEFAULT_GIN_SEARCH_PATHS = [ + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +] + +train = train_lib.train +evaluate = eval_lib.evaluate +infer = infer_lib.infer +precompile = precompile_lib.precompile + +_FUNC_MAP = { + RunMode.TRAIN: train, + RunMode.EVAL: evaluate, + RunMode.INFER: infer, + RunMode.PRECOMPILE: precompile, +} + + +def main(argv: Sequence[str]): + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + + + if _TFDS_DATA_DIR.value is not None: + seqio.set_tfds_data_dir_override(_TFDS_DATA_DIR.value) + + + # Register function explicitly under __main__ module, to maintain backward + # compatability of existing '__main__' module references. + gin.register(_FUNC_MAP[_RUN_MODE.value], '__main__') + if _GIN_SEARCH_PATHS.value != ['.']: + logging.warning( + 'Using absolute paths for the gin files is strongly recommended.') + + # User-provided gin paths take precedence if relative paths conflict. + gin_utils.parse_gin_flags(_GIN_SEARCH_PATHS.value + _DEFAULT_GIN_SEARCH_PATHS, + _GIN_FILE.value, _GIN_BINDINGS.value) + + if _DRY_RUN.value: + return + + run_with_gin = gin.get_configurable(_FUNC_MAP[_RUN_MODE.value]) + + run_with_gin() + + + +def _flags_parser(args: Sequence[str]) -> Sequence[str]: + """Flag parser. + + See absl.app.parse_flags_with_usage and absl.app.main(..., flags_parser). + + Args: + args: All command line arguments. + + Returns: + [str], a non-empty list of remaining command line arguments after parsing + flags, including program name. + """ + return app.parse_flags_with_usage(list(gin_utils.rewrite_gin_args(args))) + + +if __name__ == '__main__': + jax.config.parse_flags_with_absl() + app.run(main, flags_parser=_flags_parser) diff --git a/t5x/metrics.py b/t5x/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..93d9a50f9f7d58d66b4a882b1219d3adb39392ac --- /dev/null +++ b/t5x/metrics.py @@ -0,0 +1,323 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""T5X Metrics. + +Defines Metric objects and collections used by T5X models. These objects use the +CLU metrics library +""" + +import dataclasses +from typing import MutableMapping, Optional, Union + +from clu import metrics as clu_metrics +import flax # Only used for flax.struct.dataclass. +import jax +from jax.experimental.global_device_array import GlobalDeviceArray +import jax.numpy as jnp +import numpy as np + +MetricsMap = MutableMapping[str, clu_metrics.Metric] +Scalar = Union[int, float, np.number, np.ndarray, jnp.ndarray] + + +def _check_param(value, *, ndim=None, dtype=jnp.float32): + """Raises a `ValueError` if `value` does not match ndim/dtype. + + Args: + value: Value to be tested. + ndim: Expected dimensions. + dtype: Expected dtype. + + Raises: + A `ValueError` if `value` does not match `ndim` or `dtype`, or if `value` + is not an instance of `jnp.ndarray`. 
+ """ + if ndim is not None and value.ndim != ndim: + raise ValueError(f"Expected ndim={ndim}, got ndim={value.ndim}") + if dtype is not None and value.dtype != dtype: + raise ValueError(f"Expected dtype={dtype}, got dtype={value.dtype}") + + +@flax.struct.dataclass +class Sum(clu_metrics.Metric): + """Computes the sum of a scalar or a batch of tensors. + + See also documentation of `Metric`. + """ + + total: Scalar + + @classmethod + def from_model_output(cls, values: Scalar, **_) -> clu_metrics.Metric: + """Initializes a Sum Metric from array (or singular) values. + + Args: + values: array of values to sum (or a single value). + + Returns: + A Sum object. + """ + values = jnp.asarray(values) + if values.ndim == 0: + values = values[None] + return cls(total=values.sum()) + + def merge(self, other: "Sum") -> "Sum": + return type(self)(total=self.total + other.total) + + def compute(self) -> Scalar: + return self.total + + +@flax.struct.dataclass +class Step(clu_metrics.Metric): + """Abstract class representing a per-step or step-per metric. + + Tracks number of steps. Must be set manually using replace_steps, since the + use of microbatches may otherwise cause the computation to be incorrect. + + See also documentation of `Metric`. + """ + steps: Optional[int] = 1 + + def replace_steps(self, steps) -> "Step": + return self.replace(steps=steps) + + def compute(self) -> Scalar: + if self.steps is None: + raise ValueError( + "`steps` must be set by calling `replace_steps` before computing metric." + ) + return self.steps + + +@flax.struct.dataclass +class AveragePerStep(Step): + """Represents per-step average (total divided by number of steps). + + See also documentation of `Step`. + """ + total: Optional[Scalar] = None + + @classmethod + def from_model_output(cls, + values: Scalar, + steps: Optional[int] = 1, + **_) -> clu_metrics.Metric: + """Initializes an AveragePerStep Metric from array (or singular) values. + + Args: + values: array of values to sum (or a single value). + steps: number of steps, defaults to 1. + + Returns: + AveragePerStep object. + """ + values = jnp.asarray(values) + if values.ndim == 0: + values = values[None] + return cls(total=values.sum(), steps=steps) + + def merge(self, other: "AveragePerStep") -> "AveragePerStep": + assert type(self) is type(other) + return type(self)( + total=self.total + other.total, steps=self.steps + other.steps) + + def compute(self) -> Scalar: + steps = super().compute() + if self.total is None: + raise ValueError("`AveragePerStep` `total` cannot be None.") + return self.total / steps + + +@flax.struct.dataclass +class Time(clu_metrics.Metric): + """Computes the sum of a float-valued metric over a period of time. + + Duration (the denominator) must be set manually. This is because JAX does not + properly support time functions inside compiled functions. Calling time.time() + inside a compiled function results in the stored time being the compilation + time, not the run time. + + See also documentation of `Metric`. + """ + duration: Optional[Scalar] = None + + def merge(self, other: "Time") -> "Time": + return self + + def compute(self) -> Scalar: + if self.duration is None: + raise ValueError( + "`Time` `duration` must be set by calling `replace_duration` before computing." + ) + return self.duration + + def replace_duration(self, duration: Scalar) -> "Time": + """Replaces duration with the given value. + + Should be used outside a compiled function to set the duration of the + metric. 
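To make the accumulate/merge/compute lifecycle of these metric dataclasses concrete, here is a small usage sketch against the `Sum` metric defined above (the import path and input values are assumptions for illustration):

```python
# Hypothetical usage of the Sum metric above: one object per batch,
# combined across batches with merge(), read out with compute().
import jax.numpy as jnp
from t5x import metrics  # assumed import path for the module shown above

batch1 = metrics.Sum.from_model_output(jnp.array([1.0, 2.0]))  # total = 3.0
batch2 = metrics.Sum.from_model_output(4.0)                    # total = 4.0
print(batch1.merge(batch2).compute())                          # 7.0
```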
+ + Args: + duration: metric duration + + Returns: + A new Time object. + """ + return self.replace(duration=duration) + + +@flax.struct.dataclass +class TimeRate(Time): + """Computes the sum of a float-valued metric over a period of time. + + Duration (the denominator) must be set using replace_duration. This is because + JAX does not properly support time functions inside compiled functions. + Calling time.time() inside a compiled function results in the stored time + being the compilation time, not the run time. + + See also documentation of `Time` and `Metric`. + """ + + numerator: Optional[jnp.ndarray] = None + + @classmethod + def from_model_output(cls, numerator: float, **_) -> clu_metrics.Metric: + """Initializes a TimeRate Metric from a float value (the numerator). + + Args: + numerator: a float (numerator of the metric) + + Returns: + A TimeRate object. + """ + return cls(numerator=numerator) + + def merge(self, other: "TimeRate") -> "TimeRate": + assert_msg = "Merging with non-None durations is currently not supported." + assert self.duration is None and other.duration is None, assert_msg + return type(self)(numerator=self.numerator + other.numerator) + + def compute(self) -> Scalar: + duration = super().compute() + return self.numerator / duration + + def replace_duration(self, duration: Scalar) -> "Time": + if not (isinstance(self.numerator, np.ndarray) or + isinstance(self.numerator, GlobalDeviceArray)): + raise ValueError( + "Expected numerator to be of type np.ndarray or GlobalDeviceArray " + "since method should be called outside of a compiled function. " + "Got ", type(self.numerator)) + return super().replace_duration(duration) + + +@flax.struct.dataclass +class StepsPerTime(Step, Time): + """Represents a metric computed as number of steps per time. + + See also documentation of `Step`. + """ + + @classmethod + def from_model_output(cls, + steps: Optional[int] = 1, + **_) -> clu_metrics.Metric: + """Initializes an StepsPerTime Metric. + + Args: + steps: number of steps, defaults to 1. + + Returns: + StepsPerTime object. + """ + return cls(steps=steps) + + def merge(self, other: "StepsPerTime") -> "StepsPerTime": + assert type(self) is type(other) + return type(self)(steps=self.steps + other.steps) + + def compute(self) -> Scalar: + steps = Step.compute(self) + duration = Time.compute(self) + return steps / duration + + +def is_metric_obj(obj): + return isinstance(obj, clu_metrics.Metric) + + +def is_time_metric(obj): + return isinstance(obj, Time) + + +def create_metrics_dict(float_metrics_dict): + """Input: dict{str: float} | Output: dict{str: Metric}.""" + return {k: Sum(v) for k, v in float_metrics_dict.items()} + + +def shape_obj_to_defined_obj(obj: clu_metrics.Metric): + """Converts shapes in Metric to zero arrays. + + obj should be a Metric object subclass where each member variable is a + ShapeDtypeStruct (from jax.eval_shape). A new object of the same class where + each member variable is an array of zeros with the same shape and type as + the corresponding variable defined by ShapeDtypeStruct. + + Args: + obj: a clu.metrics.Metric object where each member variable is a + ShapeDtypeStruct (from jax.eval_shape) + + Returns: + A Metric object with class variables initialized as zero arrays. 
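The "set the denominator outside compiled code" contract described above typically looks like the following in a training loop (a sketch; the numerator value and import path are assumptions):

```python
# Sketch: accumulate a numerator during the step, then attach wall-clock time
# outside of any jit-compiled code, as TimeRate above requires.
import time
import numpy as np
from t5x import metrics  # assumed import path for the module shown above

tokens_rate = metrics.TimeRate.from_model_output(np.array([1024.0]))  # e.g. tokens processed
start = time.time()
# ... run the (compiled) train step here ...
tokens_rate = tokens_rate.replace_duration(time.time() - start)
print(tokens_rate.compute())  # tokens per second over the measured interval
```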
+ """ + + def class_attr_shape(a): + attr = getattr(obj, a.name) + if isinstance(attr, clu_metrics.Metric): + return shape_obj_to_defined_obj(attr) + else: + if hasattr(attr, "shape"): + return jnp.zeros(shape=attr.shape, dtype=attr.dtype) + else: + return attr + + return obj.__class__( + **{a.name: class_attr_shape(a) for a in dataclasses.fields(obj)}) + + +def set_time_metrics_duration(metrics, duration): + """Sets duration for TimeRate objects in metrics pytree.""" + + def fn(o): + if isinstance(o, Time): + return o.replace_duration(duration) + else: + return o + + return jax.tree_map(fn, metrics, is_leaf=lambda obj: isinstance(obj, Time)) + + +def set_step_metrics_num_steps(metrics, num_steps): + """Sets steps for Step objects in metrics pytree.""" + + def fn(o): + if isinstance(o, Step): + return o.replace_steps(num_steps) + else: + return o + + return jax.tree_map(fn, metrics, is_leaf=is_metric_obj) diff --git a/t5x/metrics_test.py b/t5x/metrics_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4fd66c5a441650abfdde699e2c8a173fa8fec6cc --- /dev/null +++ b/t5x/metrics_test.py @@ -0,0 +1,96 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for clu.metrics.""" + +from absl.testing import absltest +from absl.testing import parameterized +import jax +import jax.numpy as jnp +import numpy as np +from t5x import metrics + + +class MetricsTest(parameterized.TestCase): + + @parameterized.named_parameters( + ("0d_values", 2., 2.), ("1d_values", [1, 2, 3], 6.), + ("2d_values", [[1, 2], [2, 3], [3, 4]], 15.), + ("3d_values", [[[1, 2], [2, 3]], [[2, 1], [3, 4]], [[3, 1], [4, 1]]], 27.) + ) + def test_sum(self, values, expected_result): + self.assertAlmostEqual( + metrics.Sum.from_model_output(values).compute(), expected_result) + + def test_time_rate(self): + value = np.array([3.]) + duration = 2. + metric = metrics.TimeRate.from_model_output(value).replace_duration( + duration) + self.assertAlmostEqual(metric.compute(), value / duration) + + def test_time_rate_unset_duration(self): + value = jnp.array([3.]) + metric = metrics.TimeRate.from_model_output(value) + with self.assertRaises(ValueError): + metric.compute() + + def test_time_rate_sets_duration_inside_jitted_fn(self): + + @jax.jit + def fn(): + value = jnp.array([3.]) + duration = 2. + metric = metrics.TimeRate.from_model_output(value).replace_duration( + duration) + return metric + + with self.assertRaises(ValueError): + fn() + + def test_time(self): + duration = 2. 
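Since these helpers operate on whole pytrees of metric objects, a typical post-step flow might look like the sketch below (the metric names and values are made up; only the helpers and classes defined above are used):

```python
# After a train step returns a pytree of metric objects, step counts and
# wall-clock durations are filled in outside the compiled function.
import numpy as np
from t5x import metrics  # assumed import path for the module shown above

step_metrics = {
    'loss': metrics.AveragePerStep.from_model_output(np.array([0.5, 0.7])),
    'timing/seqs_per_second': metrics.TimeRate.from_model_output(np.array([64.0])),
}
step_metrics = metrics.set_step_metrics_num_steps(step_metrics, 2)
step_metrics = metrics.set_time_metrics_duration(step_metrics, 2.0)
print(step_metrics['loss'].compute())                    # ~0.6
print(step_metrics['timing/seqs_per_second'].compute())  # ~[32.]
```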
+ metric = metrics.Time().replace_duration(duration) + self.assertAlmostEqual(metric.compute(), duration) + + def test_time_unset_duration(self): + metric = metrics.Time() + with self.assertRaises(ValueError): + metric.compute() + + @parameterized.named_parameters( + ("0d_values", 2., 2.), + ("1d_values", [1, 2, 3], 6.), + ) + def test_average_per_step(self, values, expected_result): + a = metrics.AveragePerStep.from_model_output(values) + m = metrics.set_step_metrics_num_steps({"a": a}, 1) + self.assertAlmostEqual(m["a"].compute(), expected_result) + + steps = 5 + b = metrics.AveragePerStep.from_model_output(values, steps=steps) + m = metrics.set_step_metrics_num_steps({"b": b}, steps) + self.assertAlmostEqual(m["b"].compute(), expected_result / steps) + + def test_steps_per_time(self): + steps = 8. + duration = 2. + metric = metrics.StepsPerTime.from_model_output( + steps=steps).replace_duration(duration) + metrics_dict = metrics.set_step_metrics_num_steps({"metric": metric}, steps) + self.assertAlmostEqual(metrics_dict["metric"].compute(), steps / duration) + + +if __name__ == "__main__": + absltest.main() diff --git a/t5x/models.py b/t5x/models.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ec78aee68c57612576b2282a41f8eddf3bd28d --- /dev/null +++ b/t5x/models.py @@ -0,0 +1,1178 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""T5X Models. + +This module uses layers.py to build a higher-level model structure and define +methods for the loss computation as well as a train, prediction, and evaluation +steps. +""" + +import abc +import functools +from typing import Any, Callable, Mapping, MutableMapping, Optional, Tuple, Union + +import clu.metrics as clu_metrics +from flax import core as flax_core +from flax import linen as nn +from flax.core import scope as flax_scope +from flax.training import common_utils +import jax +import jax.numpy as jnp +import numpy as np +import seqio +from t5x import decoding +from t5x import losses +from t5x import metrics as metrics_lib +from t5x import optimizers +import tensorflow as tf +import typing_extensions + +Array = Union[np.ndarray, jnp.ndarray, jax.pxla.ShardedDeviceArray, tf.Tensor] +MetricsMap = metrics_lib.MetricsMap +PyTreeDef = type(jax.tree_structure(None)) + + +class TokensIdsToLogitsCallable(typing_extensions.Protocol): + """Token ids to logits mapping call signature.""" + + def __call__( + self, token_ids: jnp.ndarray, cache: Mapping[str, jnp.ndarray] + ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: + """Performs forward pass to convert token ids to logits. + + Args: + token_ids: [batch_size, 1] int32 tokens for single position used during + incremental decoding. Non-0 prefix tokens to be used as a forced prompt. + cache: flax attention cache. + + Returns: + a tuple of logits with a shape [batch_size, vocab_size] and an updated + cache. + """ + ... 
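For readers implementing a custom decoder, the `TokensIdsToLogitsCallable` protocol above is just a function with this shape contract; a deliberately trivial example (the vocabulary size and the all-zero logits are placeholders):

```python
# Toy function matching the TokensIdsToLogitsCallable protocol: maps a
# [batch, 1] slice of token ids (plus a cache) to [batch, vocab] logits.
import jax.numpy as jnp

VOCAB_SIZE = 8  # placeholder vocabulary size


def uniform_tokens_to_logits(token_ids, cache):
  batch_size = token_ids.shape[0]
  logits = jnp.zeros((batch_size, VOCAB_SIZE))  # uniform scores over the vocab
  return logits, cache  # a real implementation would also update the cache


logits, cache = uniform_tokens_to_logits(jnp.zeros((2, 1), jnp.int32), {})
assert logits.shape == (2, VOCAB_SIZE)
```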
+ + +class DecodeFnCallable(typing_extensions.Protocol): + """Decoding function call signature.""" + + def __call__(self, *, inputs: jnp.ndarray, cache: Mapping[str, jnp.ndarray], + tokens_to_logits: TokensIdsToLogitsCallable, eos_id: int, + num_decodes: int, decode_rng: Optional[jax.random.KeyArray], + cache_offset: int, **kwargs) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Decoding function interface. + + Args: + inputs: [batch_size, max_decode_len] int32 sequence of tokens, with non-0 + prefix tokens to be used as a forced prompt. + cache: flax attention cache. + tokens_to_logits: fast autoregressive decoder function taking single token + slices and cache and returning next-token logits and updated cache. + eos_id: end-of-sentence token for target vocabulary. + num_decodes: number of decoded sequences to be returned. + decode_rng: an optional JAX PRNG Key for stochastic sampling routines. + cache_offset: axis offset for cache, arising from scanned layers. + **kwargs: an optional kwargs. One common usecase of this is passing + decoding parameters at the callsite. + + Returns: + decodes: Array of sequences: [batch_size, num_decodes, max_decode_len]. + The `num_decodes` dimension is expected to be sorted by the `scores`, + i.e., `decodes[:, -1, :] has the highest scores among `num_decodes` + decoded sequences. + scores: Array of log likelihood scores: [batch_size, num_decodes] + """ + ... + + +class BaseModel(abc.ABC): + """Abstract base class for models. + + Wraps a flax module to provide a basic interface for computing loss, + evaluation metrics, prediction, and scoring. + + Subclasses must implement the abstract methods. Any additional arguments added + to these methods must have defaults or be bound at run time to fit the + interface expected by the standard training, inference, and evaluation + functions. + """ + + FEATURE_CONVERTER_CLS: Callable[..., seqio.FeatureConverter] + + def __init__(self, optimizer_def: optimizers.OptimizerDefType): + # TODO(jbulian): Move the optimizer out of the model and make it a training + # parameter. + self.optimizer_def = optimizer_def + + @abc.abstractmethod + def loss_fn( + self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + dropout_rng: Optional[jax.random.KeyArray], + ) -> Tuple[jnp.ndarray, MetricsMap]: + """Computes loss and metrics. + + Args: + params: model parameters. + batch: a batch of inputs. + dropout_rng: rng to use for dropout, or None for deterministic mode. + + Returns: + loss: the loss computed for the given inputs and parameters. + aux: + weight_sum: sum of the per-token weights applied to the loss. + metrics: a mapping of metrics computed for this batch. + """ + pass + + def eval_fn( + self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + ) -> Tuple[jnp.ndarray, MetricsMap]: + """Computes loss and metrics during the evaluation. + + Args: + params: model parameters. + batch: a batch of inputs. + + Returns: + loss: the loss computed for the given inputs and parameters. + aux: + weight_sum: sum of the per-token weights applied to the loss. + metrics: a mapping of metrics computed for this batch. + """ + return self.loss_fn( + params=params, + batch=batch, + dropout_rng=None, + ) + + def predict_batch(self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + rng: Optional[jax.random.KeyArray] = None) -> jnp.ndarray: + """Predicts a batch of outputs from the model. + + Args: + params: model parameters. + batch: a batch of inputs. + rng: an optional RNG to use during prediction (e.g., for decoding). 
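Similarly, a custom decoding routine only needs to honour the `DecodeFnCallable` signature above; the stub below returns all-zero decodes and scores with the documented shapes (purely illustrative, not a usable decoder):

```python
# Stub with the DecodeFnCallable calling convention: returns
# [batch, num_decodes, max_decode_len] decodes and [batch, num_decodes] scores.
import jax.numpy as jnp


def zero_decode_fn(*, inputs, cache, tokens_to_logits, eos_id, num_decodes=1,
                   decode_rng=None, cache_offset=0, **kwargs):
  batch_size, max_decode_len = inputs.shape
  decodes = jnp.zeros((batch_size, num_decodes, max_decode_len), jnp.int32)
  scores = jnp.zeros((batch_size, num_decodes))
  return decodes, scores
```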
+ + Returns: + The model predictions. + """ + return self.predict_batch_with_aux(params=params, batch=batch, rng=rng)[0] + + @abc.abstractmethod + def predict_batch_with_aux( + self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + rng: Optional[jax.random.KeyArray] = None, + ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: + """Predict a batch from the modelwith auxiliary outputs. + + Args: + params: model parameters. + batch: a batch of inputs. + rng: an optional RNG key to use during prediction (e.g., for decoding). + + Returns: + predictions: the model predictions + aux: auxiliary data + """ + pass + + @abc.abstractmethod + def score_batch(self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + return_intermediates: bool = False) -> jnp.ndarray: + """Computes scores for batch.""" + pass + + @abc.abstractmethod + def get_initial_variables( + self, + rng: jax.random.KeyArray, + input_shapes: Mapping[str, Array], + input_types: Optional[Mapping[str, jnp.dtype]] = None + ) -> flax_scope.FrozenVariableDict: + """Returns the initial variables of the model.""" + pass + + +class BaseTransformerModel(BaseModel): + """Abstract base class for Transformer models. + + Subclasses must implement `predict_batch_with_aux`, `score_batch`, + `get_initial_variables` from `BaseModel` as well as `_compute_logits`. + """ + + def __init__( + self, + module: nn.Module, + input_vocabulary: seqio.Vocabulary, + output_vocabulary: seqio.Vocabulary, + optimizer_def: optimizers.OptimizerDefType, + decode_fn: Optional[DecodeFnCallable] = None, + label_smoothing: float = 0.0, + z_loss: float = 0.0, + loss_normalizing_factor: Optional[Union[ + float, int, str, losses.SpecialLossNormalizingFactor]] = None, + ): + self.module = module + self._input_vocabulary = input_vocabulary + self._output_vocabulary = output_vocabulary + self._decode_fn = decode_fn + self._label_smoothing = label_smoothing + self._z_loss = z_loss + self._loss_normalizing_factor = loss_normalizing_factor + + super().__init__(optimizer_def=optimizer_def) + + @property + def input_vocabulary(self): + return self._input_vocabulary + + @property + def output_vocabulary(self): + return self._output_vocabulary + + @property + def decode_fn(self): + return self._decode_fn + + @abc.abstractmethod + def _compute_logits( + self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + dropout_rng: Optional[jax.random.KeyArray] = None) -> jnp.ndarray: + """Computes logits via a forward pass of the model.""" + pass + + def loss_fn( + self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + dropout_rng: Optional[jax.random.KeyArray], + ) -> Tuple[jnp.ndarray, MetricsMap]: + """Loss function used for training with a cross-entropy loss.""" + logits = self._compute_logits(params, batch, dropout_rng) + + loss_normalizing_factor: Optional[Union[ + float, int, str, losses.SpecialLossNormalizingFactor]] + (loss_normalizing_factor, + weights) = losses.get_loss_normalizing_factor_and_weights( + self._loss_normalizing_factor, batch) + + loss, z_loss, _ = losses.compute_weighted_cross_entropy( + logits, + targets=batch['decoder_target_tokens'], + weights=weights, + label_smoothing=self._label_smoothing, + z_loss=self._z_loss, + loss_normalizing_factor=loss_normalizing_factor) + metrics = self._compute_metrics( + logits=logits, + targets=batch['decoder_target_tokens'], + mask=weights, + loss=loss, + z_loss=z_loss) + return loss, metrics + + def _compute_metrics( + self, + logits: jnp.ndarray, + targets: jnp.ndarray, + mask: jnp.ndarray, + loss: 
jnp.ndarray, + z_loss: Optional[jnp.ndarray] = None, + ) -> MetricsMap: + return compute_base_metrics( + logits=logits, targets=targets, mask=mask, loss=loss, z_loss=z_loss) + + +class EncoderDecoderModel(BaseTransformerModel): + """Wrapper class for the models.Transformer nn.module.""" + + FEATURE_CONVERTER_CLS = seqio.EncDecFeatureConverter + + def __init__( + self, + module: nn.Module, + input_vocabulary: seqio.Vocabulary, + output_vocabulary: seqio.Vocabulary, + optimizer_def: optimizers.OptimizerDefType, + decode_fn: DecodeFnCallable = decoding.beam_search, + feature_converter_cls: Optional[Callable[..., + seqio.FeatureConverter]] = None, + label_smoothing: float = 0.0, + z_loss: float = 0.0, + loss_normalizing_factor: Optional[float] = None, + ): + if feature_converter_cls is not None: + self.FEATURE_CONVERTER_CLS = feature_converter_cls # pylint: disable=invalid-name + super().__init__( + module=module, + input_vocabulary=input_vocabulary, + output_vocabulary=output_vocabulary, + optimizer_def=optimizer_def, + decode_fn=decode_fn, + label_smoothing=label_smoothing, + z_loss=z_loss, + loss_normalizing_factor=loss_normalizing_factor, + ) + + def get_initial_variables( + self, + rng: jax.random.KeyArray, + input_shapes: Mapping[str, Array], + input_types: Optional[Mapping[str, jnp.dtype]] = None + ) -> flax_scope.FrozenVariableDict: + """Get the initial variables for an encoder-decoder model.""" + input_types = {} if input_types is None else input_types + encoder_shape = input_shapes['encoder_input_tokens'] + encoder_type = input_types.get('encoder_input_tokens', jnp.float32) + decoder_shape = input_shapes['decoder_input_tokens'] + decoder_type = input_types.get('decoder_input_tokens', jnp.float32) + if 'encoder_positions' in input_shapes: + encoder_positions = jnp.ones( + input_shapes['encoder_positions'], + input_types.get('encoder_positions', jnp.int32)) + else: + encoder_positions = None + if 'decoder_positions' in input_shapes: + decoder_positions = jnp.ones( + input_shapes['decoder_positions'], + input_types.get('decoder_positions', jnp.int32)) + else: + decoder_positions = None + if 'encoder_segment_ids' in input_shapes: + encoder_segment_ids = jnp.ones( + input_shapes['encoder_segment_ids'], + input_types.get('encoder_segment_ids', jnp.int32)) + else: + encoder_segment_ids = None + if 'decoder_segment_ids' in input_shapes: + decoder_segment_ids = jnp.ones( + input_shapes['decoder_segment_ids'], + input_types.get('decoder_segment_ids', jnp.int32)) + else: + decoder_segment_ids = None + initial_variables = self.module.init( + rng, + jnp.ones(encoder_shape, encoder_type), + jnp.ones(decoder_shape, decoder_type), + jnp.ones(decoder_shape, decoder_type), + encoder_positions=encoder_positions, + decoder_positions=decoder_positions, + encoder_segment_ids=encoder_segment_ids, + decoder_segment_ids=decoder_segment_ids, + decode=False, + enable_dropout=False) + return initial_variables + + def _compute_logits( + self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + dropout_rng: Optional[jax.random.KeyArray] = None, + mutable: flax_scope.CollectionFilter = False, + other_variables: Optional[PyTreeDef] = None, + ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, flax_scope.FrozenVariableDict]]: + """Computes logits via a forward pass of `self.module_cls`.""" + # Dropout is provided only for the training mode. 
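The `input_shapes`/`input_types` mappings consumed by `get_initial_variables` above are plain dicts keyed by feature name; a small sketch of how they translate into the dummy arrays passed to `module.init` (the sizes are arbitrary, and the feature names mirror the code above):

```python
# How shape/type specs become dummy inputs for module.init, mirroring
# get_initial_variables above.
import jax.numpy as jnp

input_shapes = {'encoder_input_tokens': (2, 16), 'decoder_input_tokens': (2, 8)}
input_types = {'encoder_input_tokens': jnp.int32}  # unspecified types default to float32

encoder_dummy = jnp.ones(input_shapes['encoder_input_tokens'],
                         input_types.get('encoder_input_tokens', jnp.float32))
decoder_dummy = jnp.ones(input_shapes['decoder_input_tokens'],
                         input_types.get('decoder_input_tokens', jnp.float32))
print(encoder_dummy.dtype, decoder_dummy.dtype)  # int32 float32
```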
+ rngs = {'dropout': dropout_rng} if dropout_rng is not None else None + if other_variables is None: + other_variables = {} + return self.module.apply( + { + 'params': params, + **other_variables + }, + batch['encoder_input_tokens'], + batch['decoder_input_tokens'], + batch['decoder_target_tokens'], + encoder_segment_ids=batch.get('encoder_segment_ids', None), + decoder_segment_ids=batch.get('decoder_segment_ids', None), + encoder_positions=batch.get('encoder_positions', None), + decoder_positions=batch.get('decoder_positions', None), + decode=False, + enable_dropout=rngs is not None, + rngs=rngs, + mutable=mutable) + + def _compute_logits_from_slice( + self, flat_ids: jnp.ndarray, flat_cache: Mapping[str, jnp.ndarray], + params: PyTreeDef, encoded_inputs: jnp.ndarray, raw_inputs: jnp.ndarray, + max_decode_length: int) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: + """Token slice to logits from decoder model.""" + # flat_ids: [batch * beam, seq_len=1] + # cache is expanded inside beam_search to become flat_cache + # flat_cache: [batch * beam, num_heads, depth_per_head, max_decode_len] + # flat_logits: [batch * beam, seq_len=1, vocab] + flat_logits, new_vars = self.module.apply( + { + 'params': params, + 'cache': flat_cache + }, + encoded_inputs, + raw_inputs, # only needed for encoder padding mask + flat_ids, + flat_ids, + enable_dropout=False, + decode=True, + max_decode_length=max_decode_length, + mutable=['cache'], + method=self.module.decode) + # Remove sequence length dimension since it's always 1 during decoding. + flat_logits = jnp.squeeze(flat_logits, axis=1) + new_flat_cache = new_vars['cache'] + return flat_logits, new_flat_cache + + def predict_batch_with_aux( + self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + rng: Optional[jax.random.KeyArray] = None, + decoder_params: Optional[MutableMapping[str, Any]] = None, + return_all_decodes: bool = False, + num_decodes: int = 1, + prompt_with_targets: bool = False + ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: + """Predict with fast decoding beam search on a batch. + + Here we refer to "parameters" for values that can be compiled into the + model dynamically, as opposed to static configuration settings that require + a recompile. For example, the model weights and the decoder brevity-penalty + are parameters and can be modified without requiring a recompile. The number + of layers, the batch size and the decoder beam size are configuration + options that require recompilation if changed. + + This method can be used with a customizable decoding function as long as it + follows the signature of `DecodeFnCallable`. In order to provide a unified + interface for the decoding functions, we use a generic names. For example, a + beam size is a concept unique to beam search. Conceptually, it corresponds + to the number of sequences returned by the beam search. Therefore, the + generic argument `num_decodes` corresponds to the beam size if + `self._decode_fn` is a beam search. For temperature sampling, `num_decodes` + corresponds to the number of independent sequences to be sampled. Typically + `num_decodes = 1` is used for temperature sampling. + + If `return_all_decodes = True`, the return tuple contains the predictions + with a shape [batch, num_decodes, max_decode_len] and the scores (i.e., log + probability of the generated sequence) with a shape [batch, num_decodes]. 
+ + If `return_all_decodes = False`, the return tuple contains the predictions + with a shape [batch, max_decode_len] and the scores with a shape [batch]. + + `decoder_params` can be used to pass dynamic configurations to + `self.decode_fn`. An example usage is to pass different random seed (i.e., + `jax.random.PRNGKey(seed)` with different `seed` value). This can be done by + setting `decoder_params['decode_rng'] = jax.random.PRNGKey(seed)`. + + If `prompt_with_targets = True`, then `decoder_prompt_inputs` is initialized + from the batch's `decoder_input_tokens`. The EOS is stripped to avoid + decoding to stop after the prompt by matching to `output_vocabulary.eos_id`. + + Args: + params: model parameters. + batch: a batch of inputs. + rng: an optional RNG key to use during prediction, which is passed as + 'decode_rng' to the decoding function. + decoder_params: additional (model-independent) parameters for the decoder. + return_all_decodes: whether to return the entire beam or just the top-1. + num_decodes: the number of beams to use in beam search. + prompt_with_targets: Whether the force decode decoder_inputs. + + Returns: + A tuple containing: + the batch of predictions, with the entire beam if requested + an auxiliary dictionary of decoder scores + """ + # Prepare zeroed-out autoregressive cache. + # [batch, input_len] + inputs = batch['encoder_input_tokens'] + # [batch, target_len] + target_shape = batch['decoder_input_tokens'].shape + target_type = batch['decoder_input_tokens'].dtype + _, variables_with_cache = self.module.apply( + {'params': params}, + jnp.ones(inputs.shape, inputs.dtype), + jnp.ones(target_shape, target_type), + jnp.ones(target_shape, target_type), + decode=True, + enable_dropout=False, + mutable=['cache']) + + cache = variables_with_cache['cache'] + + # Prepare transformer fast-decoder call for beam search: for beam search, we + # need to set up our decoder model to handle a batch size equal to + # batch_size * num_decodes, where each batch item's data is expanded + # in-place rather than tiled. + # i.e. if we denote each batch element subtensor as el[n]: + # [el0, el1, el2] --> beamsize=2 --> [el0,el0,el1,el1,el2,el2] + # [batch * num_decodes, input_len, emb_dim] + encoded_inputs = decoding.flat_batch_beam_expand( + self.module.apply({'params': params}, + inputs, + enable_dropout=False, + method=self.module.encode), num_decodes) + + # [batch * num_decodes, input_len] + raw_inputs = decoding.flat_batch_beam_expand(inputs, num_decodes) + + tokens_ids_to_logits = functools.partial( + self._compute_logits_from_slice, + params=params, + encoded_inputs=encoded_inputs, + raw_inputs=raw_inputs, + max_decode_length=target_shape[1]) + + if decoder_params is None: + decoder_params = {} + if rng is not None: + if decoder_params.get('decode_rng') is not None: + raise ValueError( + f'Got RNG both from the `rng` argument ({rng}) and ' + f"`decoder_params['decode_rng']` ({decoder_params['decode_rng']}). " + 'Please specify one or the other.') + decoder_params['decode_rng'] = rng + + # `decoder_prompt_inputs` is initialized from the batch's + # `decoder_input_tokens`. The EOS is stripped to avoid decoding to stop + # after the prompt by matching to `output_vocabulary.eos_id`. + # These inputs are ignored by the beam search decode fn. 
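The in-place beam expansion described in the comments above ("[el0, el1, el2] --> beamsize=2 --> [el0,el0,el1,el1,el2,el2]") is a repeat along the batch axis; a minimal sketch of that behaviour, using `jnp.repeat` directly rather than `decoding.flat_batch_beam_expand` (which is assumed to expand the leading axis in the same way):

```python
# Batch expansion for beam search as described above, sketched with jnp.repeat.
import jax.numpy as jnp

batch = jnp.array([[10], [11], [12]])           # [el0, el1, el2], shape [3, 1]
expanded = jnp.repeat(batch, repeats=2, axis=0)
# expanded == [[10], [10], [11], [11], [12], [12]]  (shape [6, 1])
```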
+ if prompt_with_targets: + decoder_prompt_inputs = batch['decoder_input_tokens'] + decoder_prompt_inputs = decoder_prompt_inputs * ( + decoder_prompt_inputs != self.output_vocabulary.eos_id) + else: + decoder_prompt_inputs = jnp.zeros_like(batch['decoder_input_tokens']) + + # TODO(hwchung): rename the returned value names to more generic ones. + # Using the above-defined single-step decoder function, run a + # beam search over possible sequences given input encoding. + # decodes: [batch, num_decodes, max_decode_len + 1] + # scores: [batch, num_decodes] + scanned = hasattr(self.module, 'scan_layers') and self.module.scan_layers + decodes, scores = self._decode_fn( + inputs=decoder_prompt_inputs, + cache=cache, + tokens_to_logits=tokens_ids_to_logits, + eos_id=self.output_vocabulary.eos_id, + num_decodes=num_decodes, + cache_offset=1 if scanned else 0, + **decoder_params) + + # Beam search returns [n_batch, n_beam, n_length] with beam dimension sorted + # in increasing order of log-probability. + # Return the highest scoring beam sequence. + if return_all_decodes: + return decodes, {'scores': scores} + else: + return decodes[:, -1, :], {'scores': scores[:, -1]} + + def score_batch( + self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + return_intermediates: bool = False, + ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Mapping[str, Any]]]: + """Compute log likelihood score on a batch.""" + weights = batch['decoder_loss_weights'] + target_tokens = batch['decoder_target_tokens'] + + if return_intermediates: + logits, modified_variables = self._compute_logits( + params=params, batch=batch, mutable=['intermediates']) + + # Inside self.module, we called nn.Module.sow to track various + # intermediate values. We extract them here. + intermediates = flax_core.unfreeze( + modified_variables.get('intermediates', {})) + + # Track per-token labels and loss weights as well. These are not + # intermediate values of logit computation, so we manually add them here. + intermediates.setdefault('decoder', {}) + intermediates['decoder']['target_tokens'] = (target_tokens,) + intermediates['decoder']['loss_weights'] = (weights,) + # Note that the values are singleton tuples. This is because values inside + # `intermediates` should be tuples tracking all instantiations of a value. + # These values each have just one instantiation, hence singletons. + else: + logits = self._compute_logits(params, batch) # type: jnp.ndarray + + # Purposefully don't use config.z_loss because that term is for training + # stability and shouldn't affect our reported scores. + token_scores = -losses.cross_entropy_with_logits( + logits, + common_utils.onehot( + target_tokens, logits.shape[-1], on_value=1, off_value=0), + z_loss=0.0)[0] * weights + + sequence_scores = token_scores.sum(-1) + + if return_intermediates: + return sequence_scores, intermediates + + return sequence_scores + + +class DecoderOnlyModel(BaseTransformerModel): + """Model class for the decoder-only modules. + + It accepts inputs made out of only 'targets' or both 'inputs' + and 'targets'. If both 'inputs' and 'targets' are present, the loss will + be computed only on 'targets'. + + By default the self-attention is fully causal and a given position only + attends to the time steps before and itself. If + `inputs_bidirectional_attention = True`, the attention in the "inputs" region + is bidirectional. This architecture was referred to as "Prefix LM" in Raffel + et al. 2019 (https://arxiv.org/abs/1910.10683). 
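The EOS-stripping step above is a simple elementwise mask; worked through on made-up token ids (with `eos_id` assumed to be 1 for this sketch):

```python
# EOS stripping as done above when prompting with targets: zero out the EOS
# token so decoding does not terminate right after the forced prompt.
import jax.numpy as jnp

eos_id = 1  # assumed EOS id for this sketch
decoder_input_tokens = jnp.array([[0, 5, 7, 1, 0, 0]])
decoder_prompt_inputs = decoder_input_tokens * (decoder_input_tokens != eos_id)
# decoder_prompt_inputs == [[0, 5, 7, 0, 0, 0]]
```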
+ """ + + FEATURE_CONVERTER_CLS = seqio.DecoderFeatureConverter + + def __init__( + self, + module: nn.Module, + vocabulary: seqio.Vocabulary, + optimizer_def: optimizers.OptimizerDefType, + decode_fn: DecodeFnCallable = decoding.temperature_sample, + inputs_bidirectional_attention: bool = False, + feature_converter_cls: Optional[Callable[..., + seqio.FeatureConverter]] = None, + label_smoothing: float = 0.0, + z_loss: float = 0.0, + loss_normalizing_factor: Optional[float] = None, + ): + if feature_converter_cls is not None: + self.FEATURE_CONVERTER_CLS = feature_converter_cls # pylint: disable=invalid-name + self._inputs_bidirectional_attention = inputs_bidirectional_attention + super().__init__( + module, + input_vocabulary=vocabulary, + output_vocabulary=vocabulary, + optimizer_def=optimizer_def, + decode_fn=decode_fn, + label_smoothing=label_smoothing, + z_loss=z_loss, + loss_normalizing_factor=loss_normalizing_factor, + ) + + def get_initial_variables( + self, + rng: jax.random.KeyArray, + input_shapes: Mapping[str, Array], + input_types: Optional[Mapping[str, jnp.dtype]] = None + ) -> flax_scope.FrozenVariableDict: + """Get the initial variables.""" + input_types = {} if input_types is None else input_types + decoder_shape = input_shapes['decoder_input_tokens'] + decoder_type = input_types.get('decoder_input_tokens', jnp.float32) + initial_variables = self.module.init( + rng, + jnp.ones(decoder_shape, decoder_type), + jnp.ones(decoder_shape, decoder_type), + enable_dropout=False) + return initial_variables + + def _get_decoder_causal_attention(self, batch): + """Returns decoder causal attention from the batch or None.""" + if self._inputs_bidirectional_attention: + if 'decoder_causal_attention' not in batch: + raise ValueError('`inputs_bidirectional_attention` mode requires ' + '"decoder_causal_attention" feature in the batch') + decoder_causal_attention = batch['decoder_causal_attention'] + else: + decoder_causal_attention = None + + return decoder_causal_attention + + def _compute_logits( + self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + dropout_rng: Optional[jax.random.KeyArray] = None, + mutable: flax_scope.CollectionFilter = False) -> jnp.ndarray: + """Computes logits via a forward pass of `self.module`.""" + rngs = {'dropout': dropout_rng} if dropout_rng is not None else None + decoder_causal_attention = self._get_decoder_causal_attention(batch) + + return self.module.apply( + {'params': params}, + batch['decoder_input_tokens'], + batch['decoder_target_tokens'], + decoder_segment_ids=batch.get('decoder_segment_ids', None), + decoder_positions=batch.get('decoder_positions', None), + decoder_causal_attention=decoder_causal_attention, + rngs=rngs, + decode=False, + enable_dropout=rngs is not None, + mutable=mutable) + + def _compute_logits_from_slice( + self, + flat_ids: jnp.ndarray, + flat_cache: Mapping[str, jnp.ndarray], + params: PyTreeDef, + max_decode_length: int, + ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: + """Token slice to logits from decoder model.""" + # flat_ids: [batch, seq_len=1] + # flat_cache['cached_(keys|values)']: + # [batch, num_heads, depth_per_head, max_decode_length] + # flat_cache['cache_index']: [batch] + # flat_logits: [batch, seq_len=1, vocab] + flat_logits, new_vars = self.module.apply( + { + 'params': params, + 'cache': flat_cache + }, + flat_ids, + flat_ids, + enable_dropout=False, + decode=True, + max_decode_length=max_decode_length, + mutable=['cache']) + # Remove sequence length dimension since it's always 1 during 
decoding. + flat_logits = jnp.squeeze(flat_logits, axis=1) + new_flat_cache = new_vars['cache'] + return flat_logits, new_flat_cache + + def score_batch(self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + return_intermediates: bool = False) -> jnp.ndarray: + """Compute log likelihood score on a batch.""" + + decoder_target_tokens = batch['decoder_target_tokens'] + weights = batch['decoder_loss_weights'] + + if return_intermediates: + logits, modified_variables = self._compute_logits( + params=params, + batch=batch, + dropout_rng=None, + mutable=['intermediates']) + + # Inside self.module, we called nn.Module.sow to track various + # intermediate values. We extract them here. + intermediates = flax_core.unfreeze( + modified_variables.get('intermediates', {})) + + # Track per-token labels and loss weights as well. These are not + # intermediate values of logit computation, so we manually add them here. + intermediates.setdefault('decoder', {}) + intermediates['decoder']['target_tokens'] = (decoder_target_tokens,) + intermediates['decoder']['loss_weights'] = (weights,) + # Note that the values are singleton tuples. This is because values inside + # `intermediates` should be tuples tracking all instantiations of a value. + # These values each have just one instantiation, hence singletons. + else: + logits = self._compute_logits( + params=params, batch=batch, dropout_rng=None) + + token_scores = -losses.cross_entropy_with_logits( + logits, + common_utils.onehot( + decoder_target_tokens, logits.shape[-1], on_value=1, off_value=0), + z_loss=0.0)[0] * weights + sequence_scores = token_scores.sum(-1) + + if return_intermediates: + return sequence_scores, intermediates + + return sequence_scores + + def _compute_kv_cache( + self, + params: PyTreeDef, + inputs: jnp.ndarray, + inputs_lengths: jnp.ndarray, + decoder_causal_attention: jnp.ndarray, + ) -> PyTreeDef: + """Compute the key/value cache on the input prefix.""" + _, variables_with_cache = self.module.apply({'params': params}, + jnp.ones_like(inputs), + jnp.ones_like(inputs), + enable_dropout=False, + decode=True, + mutable=['cache']) + cache = variables_with_cache['cache'] + + # Prefill our cache with all the inputs. `inputs_lengths` is the index of + # the last input token. The cache will be filled for all the input + # positions, save the last input token. The cache index will point to the + # index of this last input token which is considered during prefilling but + # not cached. This re-computation is required as the logits for this + # position are required for selecting the first output token. + # + # The cache is still `[B, ..., max_decode_len]` but any position less than + # the `inputs_length` will be non-zero, that is + # `cached_key[b, ..., i < inputs_lengths[b]] != 0`. + # + # The cache index is now a vector of size [B] = input_lengths + + # If `self._inputs_bidirectional_attention = False`, we should not pass + # batch['decoder_causal_attention'] to `module.apply` during cache prefill + # and pass None instead. 
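The per-token scoring in `score_batch` above is the negative cross entropy of each target token under the model's logits, masked by the loss weights; a small numeric sketch that uses a plain log-softmax in place of `losses.cross_entropy_with_logits` (the logits, targets, and weights are made up):

```python
# Token-level log-likelihood scoring, mirroring score_batch above.
import jax
import jax.numpy as jnp

logits = jnp.array([[[2.0, 0.0, 0.0],
                     [0.0, 2.0, 0.0]]])   # [batch=1, length=2, vocab=3]
targets = jnp.array([[0, 1]])             # decoder_target_tokens
weights = jnp.array([[1.0, 1.0]])         # decoder_loss_weights

log_probs = jax.nn.log_softmax(logits, axis=-1)
token_scores = jnp.take_along_axis(
    log_probs, targets[..., None], axis=-1)[..., 0] * weights
sequence_scores = token_scores.sum(-1)    # log-likelihood per sequence
```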
+ maybe_decoder_causal_attention = self._get_decoder_causal_attention( + {'decoder_causal_attention': decoder_causal_attention}) + + _, variables_with_cache = self.module.apply( + { + 'params': params, + 'cache': cache + }, + decoder_input_tokens=inputs, + # Use the `decoder_causal_attention`, which has 1 for all input + # positions, including the BOS token, as the targets so when the + # decoder attention mask is built, it will correctly cover the whole + # input, Using something like the inputs will cause the first input + # token (the 0 for BOS) will not be included in the mask. This also + # restricts the mask to not include any target positions like it would + # if you used `decoder_target_tokens`. + decoder_target_tokens=decoder_causal_attention, + decoder_causal_attention=maybe_decoder_causal_attention, + mutable=['cache'], + enable_dropout=False, + prefill=True, + prefill_lengths=inputs_lengths) + return variables_with_cache['cache'] + + def predict_batch_with_aux( + self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + rng: Optional[jax.random.KeyArray] = None, + *, + return_all_decodes: bool = False, + num_decodes: int = 1, + decoder_params: Optional[MutableMapping[str, Any]] = None, + ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: + """Predict with prefix. + + `decoder_params` can be used to pass dynamic configurations to + `self.decode_fn`. An example usage is to pass different random seed (i.e., + `jax.random.PRNGKey(seed)` with different `seed` value). This can be done by + setting `decoder_params['decode_rng'] = jax.random.PRNGKey(seed)`. + + Although this method is short, there are a few subtle points that. We use a + running example to make these points clear. + + ``` + Example + inputs = [9, 4, 6, 1] + targets = [3, 9, 1] + + seqio.DecoderFeatureConverter will generate these set of features + + decoder_target_tokens = [9, 4, 6, 1, 3, 9, 1, 0, 0] + decoder_input_tokens = [0, 9, 4, 6, 1, 3, 9, 1, 0] + decoder_causal_attention = [1, 1, 1, 1, 1, 0, 0, 0, 0] + + The output of this function is (a` through `e` are the sampled token ids): + + sampled_sequences = [9, 4, 6, 1, a, b, c, d, e]. + ``` + + Given these set of features, we make a few important observation. + + 1) When a decoder-only model is used for a supervised learning with "inputs" + and "targets", one way to handle this is to concatenate the "inputs" and + "targets". For training, we use teacher forcing for the entire + concatenated sequence. For inference, on the other hand, we don't have + the targets. This requires that we use teacher forcing on the "inputs" + portion while using the generated token as the input token for the next + decoding step. For evaluation, we do have "targets" but we only want to + use them for computing metrics, i.e., by comparing to the sequence + generated by the model. + + This function is currently used for evaluation mode, but by ignoring + "targets", it can be extended for the inference mode. + + 2) During evaluation mode, the targets portion is zeroed out and they are + filled with the sampled token ids. The inputs portion is kept intact. + + 3) Note that `decoder_causal_attention` has an additional 1 after the final + "inputs" token. This is because the position where the last "inputs" + token (in this case 1) is input and the output is the first "target" + token (in this case 3) can be included in the non-causal attention + region. 
+ + This results in an alignment between `decoder_input_tokens` and + `decoder_causal_attention` because the former is shifted to the right by + one position. So we use `decoder_causal_attention` as a binary mask to + zero out the target tokens in `decoder_input_tokens`. + + Note: + In order to use a custom self._decode_fn with this model it must support: + + 1) Decoding from a partially decoded state by accepting a vector of + `initial_indices` that specify where in the input to start decoding + from. + 2) Using a vector as the loop counter to support different examples being + a different number of steps into their decoding loop. + 3) Be able to handle one batch element reaching `max_decode_length` + before the others without it causing the model to prematurely stop + decoding. + + Args: + params: model parameters. + batch: batch element with the model features specified in + seqio.DecoderFeatureConverter. + rng: an optional RNG key to use during prediction, which is passed as + 'decode_rng' to the decoding function. + return_all_decodes: if True, will return all batch_size * num_decodes + samples from the model as an array of shape [batch_size, num_decodes, + sequence_length]. Otherwise returns only the most likely samples as an + array of shape [batch_size, sequence_length]. + num_decodes: number of decoded sequences to be returned. + decoder_params: additional (model-independent) parameters for the decoder. + + Returns: + sampled_sequences: an array of shape [batch, max_decode_length]. + """ + if 'decoder_causal_attention' not in batch: + raise ValueError( + 'Batch does not have the right format for text generation: probably ' + 'because `task_feature_lengths` passed to the feature converter does ' + 'not have both `inputs` and `targets`.') + # We can use the decoder causal attention mask to tell how long the inputs + # are. The causal mask has a 1 for all the input tokens (and one more to + # cover the original BOS token, created by shifting the inputs one to the + # right) so we need to delete one. + inputs_lengths = jnp.sum(batch['decoder_causal_attention'], axis=1) - 1 + + # since decoder_input_tokens is shifted to the right and + # `decoder_causal_attention` has one more 1 than the number of inputs + # tokens, this masks out targets portion of the decoder_input_tokens. + inputs = batch['decoder_input_tokens'] * batch['decoder_causal_attention'] + + prefilled_cache = self._compute_kv_cache(params, inputs, inputs_lengths, + batch['decoder_causal_attention']) + + target_shape = batch['decoder_input_tokens'].shape + max_decode_length = target_shape[1] + + tokens_ids_to_logits = functools.partial( + self._compute_logits_from_slice, + params=params, + max_decode_length=max_decode_length) + + if decoder_params is None: + decoder_params = {} + if rng is not None: + if decoder_params.get('decode_rng') is not None: + raise ValueError( + f'Got RNG both from the `rng` argument ({rng}) and ' + f"`decoder_params['decode_rng']` ({decoder_params['decode_rng']}). " + 'Please specify one or the other.') + decoder_params['decode_rng'] = rng + + # Using the above-defined single-step decoder function, run temperature + # sampling with the prefix. 
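Putting the docstring's running example together with the masking code above (the token values are copied from the docstring; only the arithmetic is added here):

```python
# Worked version of the running example above: recover the input length and
# mask out the targets portion of decoder_input_tokens.
import jax.numpy as jnp

decoder_input_tokens = jnp.array([[0, 9, 4, 6, 1, 3, 9, 1, 0]])
decoder_causal_attention = jnp.array([[1, 1, 1, 1, 1, 0, 0, 0, 0]])

inputs_lengths = jnp.sum(decoder_causal_attention, axis=1) - 1
# inputs_lengths == [4]  (the original inputs were [9, 4, 6, 1])

inputs = decoder_input_tokens * decoder_causal_attention
# inputs == [[0, 9, 4, 6, 1, 0, 0, 0, 0]]: targets are zeroed, prompt is kept.
```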
+ # [batch, max_decode_length] + scanned = hasattr(self.module, 'scan_layers') and self.module.scan_layers + decoded_sequences, scores = self._decode_fn( + inputs=inputs, + cache=prefilled_cache, + tokens_to_logits=tokens_ids_to_logits, + eos_id=self.output_vocabulary.eos_id, + num_decodes=num_decodes, + initial_index=inputs_lengths, + cache_offset=1 if scanned else 0, + **decoder_params) + + if not return_all_decodes: + # Search returns [n_batch, n_beam/decodes, n_length] with the beam/decode + # dimension sorted in increasing order of log-probability. + # `scores` is [batch, beam/decode_size] + # We take the highest scoring sequence (-1) and its score + decoded_sequences = decoded_sequences[:, -1, :] + # Beam search returns [] + aux = {'scores': scores[:, -1]} + else: + # We return all samples and scores, rather than just the top ones. + aux = {'scores': scores} + + return remove_prefix(decoded_sequences, inputs_lengths), aux + + +@jax.vmap +def remove_prefix(sequence: jnp.ndarray, + prefix_length: jnp.ndarray) -> jnp.ndarray: + """Remove the prefix portion and shift to the left by the prefix length. + + The example below uses non-decorated function definition, i.e., arrays do not + have batch dimension. `jax.vmap` internally inserts the batch dimension at + axis=0. The shape annotations do not include the batch dimension either. + + Example: + ```python + sequence = [1, 2, 3, 4, 5, 6, 7, 0] + prefix_length = 2 + remove_prefix(sequence, prefix_length) = [3, 4, 5, 6, 7, 0, 0, 0] + ``` + + Note that this function assumes that the padding token has an id of 0. + + Args: + sequence: [length] array. + prefix_length: scalar, i.e., rank 0 array. + + Returns: + [length] array with the prefix removed and the suffix shifted. + """ + length = sequence.shape[-1] + # A binary mask with 1 at inputs. + inputs_mask = (jnp.arange(length) < prefix_length) + # A binary mask with 1 at the targets and padding positions. + targets_and_padding_mask = jnp.logical_not(inputs_mask).astype(sequence.dtype) + # Since padding id = 0, the padding mask is zeroed out. + targets = sequence * targets_and_padding_mask + # Shift to the left by prefix length. Wrapped elements are already zeroed. + return jnp.roll(targets, -prefix_length, axis=-1) + + +# TODO(cpgaffney) Remove this method when dependencies no longer use - rely on +# WeightedAccuracy Metric instead. +def compute_weighted_accuracy( + logits: jnp.ndarray, + targets: jnp.ndarray, + weights: Optional[jnp.ndarray] = None) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Compute weighted accuracy for log probs and targets. + + Args: + logits: [batch, length, num_classes] float array. + targets: categorical targets [batch, length] int array of categories. + weights: None or array of shape [batch, length] + + Returns: + Scalar accuracy. + """ + if logits.ndim != targets.ndim + 1: + raise ValueError('Incorrect shapes. 
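`remove_prefix` above is vmapped over the batch dimension; reusing the example from its docstring, a quick check of its behaviour (the import path is an assumption, and a leading batch dimension is added because of the vmap):

```python
# Usage of models.remove_prefix as documented above; the values come from its
# docstring example.
import jax.numpy as jnp
from t5x import models  # assumed import path for the module shown above

sequence = jnp.array([[1, 2, 3, 4, 5, 6, 7, 0]])
prefix_length = jnp.array([2])
print(models.remove_prefix(sequence, prefix_length))
# [[3 4 5 6 7 0 0 0]]
```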
Got shape %s logits and %s targets' % + (str(logits.shape), str(targets.shape))) + accuracy = jnp.equal(jnp.argmax(logits, axis=-1), targets) + if weights is not None: + accuracy = accuracy * weights + + return jnp.sum(accuracy) + + +# TODO(cpgaffney) remove when users rely on compute_base_metrics +def compute_metrics(logits: jnp.ndarray, targets: jnp.ndarray, + weights: jnp.ndarray, loss: jnp.ndarray, + weight_sum: jnp.ndarray, + additional_metrics: MetricsMap) -> MetricsMap: + """Compute summary metrics.""" + accuracy = compute_weighted_accuracy(logits, targets, weights) + metrics = { + 'loss': loss, + 'accuracy': accuracy, + 'weight_sum': weight_sum, + 'num_examples': targets.shape[0], + 'num_tokens': targets.size + } + metrics = metrics_lib.create_metrics_dict(metrics) + metrics.update(additional_metrics) + return metrics + + +def compute_base_metrics( + logits: jnp.ndarray, + targets: jnp.ndarray, + mask: jnp.ndarray, + loss: jnp.ndarray, + z_loss: Optional[jnp.ndarray] = None, +) -> MetricsMap: + """Compute summary metrics. + + Args: + logits: [batch, length, num_classes] float array. + targets: categorical targets [batch, length] int array of categories. + mask: None or array of shape [batch, length]. Note: must consist of boolean + values (float-valued weights not supported). + loss: loss (float) + z_loss: z_loss (float) + + Returns: + Dict of metrics. + """ + num_examples = targets.shape[0] + num_tokens = targets.size + num_devices = jax.device_count() + assert num_devices, 'JAX is reporting no devices, but it should.' + # Note: apply mask again even though mask has already been applied to loss. + # This is needed to divide by mask sum, but should not affect correctness of + # the numerator. + nonpadding_tokens = jnp.sum(mask) if mask is not None else targets.size + metrics = { + 'accuracy': + clu_metrics.Accuracy.from_model_output( + logits=logits, labels=targets.astype(jnp.int32), mask=mask), + 'loss': + metrics_lib.AveragePerStep(total=loss), + 'loss_per_nonpadding_target_token': + clu_metrics.Average(total=loss, count=nonpadding_tokens), + 'loss_per_all_target_tokens': + clu_metrics.Average(total=loss, count=num_tokens), + 'timing/seqs_per_second': + metrics_lib.TimeRate.from_model_output(numerator=num_examples), + 'timing/steps_per_second': + metrics_lib.StepsPerTime.from_model_output(), + 'timing/seconds': + metrics_lib.Time(), + 'timing/seqs': + metrics_lib.Sum(num_examples), + 'timing/seqs_per_second_per_core': + metrics_lib.TimeRate.from_model_output(numerator=num_examples / + num_devices), + 'timing/target_tokens_per_second': + metrics_lib.TimeRate.from_model_output(numerator=num_tokens), + 'timing/target_tokens_per_second_per_core': + metrics_lib.TimeRate.from_model_output(numerator=num_tokens / + num_devices), + 'nonpadding_fraction': + clu_metrics.Average(total=nonpadding_tokens, count=num_tokens), + } + if z_loss is not None: + metrics.update({ + 'z_loss': + metrics_lib.AveragePerStep(total=z_loss), + 'z_loss_per_all_target_tokens': + clu_metrics.Average(total=z_loss, count=num_tokens), + 'cross_ent_loss': + metrics_lib.AveragePerStep(total=loss - z_loss), + 'cross_ent_loss_per_all_target_tokens': + clu_metrics.Average(total=jnp.sum(loss - z_loss), count=num_tokens) + }) + return metrics + + +def get_input_vocabulary(model: BaseTransformerModel) -> seqio.Vocabulary: + return model.input_vocabulary + + +def get_output_vocabulary(model: BaseTransformerModel) -> seqio.Vocabulary: + return model.output_vocabulary diff --git a/t5x/models_test.py b/t5x/models_test.py new 
file mode 100644 index 0000000000000000000000000000000000000000..b2f3d77bce55dd463852962369e62f50b1acb846 --- /dev/null +++ b/t5x/models_test.py @@ -0,0 +1,985 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for t5x.models.""" + +import functools +from unittest import mock + +from absl import logging +from absl.testing import absltest +from absl.testing import parameterized +import flax +from flax import traverse_util +import jax +import jax.numpy as jnp +import numpy as np +import t5.data.tasks # pylint:disable=unused-import +from t5x import decoding +from t5x import models +from t5x import partitioning +from t5x import test_utils +from t5x import trainer as trainer_lib +from t5x import utils +import tensorflow as tf + +# Parse absl flags test_srcdir and test_tmpdir. +jax.config.parse_flags_with_absl() + +PartitionSpec = partitioning.PartitionSpec + + +class ModelsTest(parameterized.TestCase): + + def test_remove_prefix(self): + sequences = np.array([[1, 2, 3, 4, 5, 6, 7, 0], [6, 7, 8, 9, 10, 11, 0, 0]]) + prefix_lengths = np.array([2, 4]) + expected = [[3, 4, 5, 6, 7, 0, 0, 0], [10, 11, 0, 0, 0, 0, 0, 0]] + remove_prefix = jax.jit(models.remove_prefix) + actual = remove_prefix(sequences, prefix_lengths) + np.testing.assert_array_equal(actual, expected) + + def test_remove_prefix_zero_len_prefix(self): + sequences = np.array([[1, 2, 3, 4, 5, 6, 7, 0], [6, 7, 8, 9, 10, 11, 0, 0]]) + prefix_lengths = np.array([0, 0]) + remove_prefix = jax.jit(models.remove_prefix) + actual = remove_prefix(sequences, prefix_lengths) + # The expected output is the original sequences. 
+ np.testing.assert_array_equal(actual, sequences) + + +BATCH_SIZE, ENCODER_LEN, MAX_DECODE_LEN, EMBED_DIM = 2, 3, 4, 5 + + +class EncoderDecoderModelTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict( + testcase_name='no_types', + shapes={ + 'encoder_input_tokens': [1, 512], + 'decoder_input_tokens': [1, 62] + }, + types=None), + dict( + testcase_name='int32', + shapes={ + 'encoder_input_tokens': [1, 512], + 'decoder_input_tokens': [1, 62] + }, + types={ + 'encoder_input_tokens': jnp.int32, + 'decoder_input_tokens': jnp.int32 + }), + dict( + testcase_name='float32', + shapes={ + 'encoder_input_tokens': [1, 512], + 'decoder_input_tokens': [1, 62], + 'encoder_positions': [1, 512], + 'decoder_positions': [1, 62], + }, + types={ + 'encoder_input_tokens': jnp.int32, + 'decoder_input_tokens': jnp.int32, + 'encoder_positions': jnp.int32, + 'decoder_positions': jnp.int32 + }), + dict( + testcase_name='float32_segment_ids', + shapes={ + 'encoder_input_tokens': [1, 512], + 'decoder_input_tokens': [1, 62], + 'encoder_segment_ids': [1, 512], + 'decoder_segment_ids': [1, 62], + }, + types={ + 'encoder_input_tokens': jnp.int32, + 'decoder_input_tokens': jnp.int32, + 'encoder_segment_ids': jnp.int32, + 'decoder_segment_ids': jnp.int32 + }), + ) + def test_get_initial_variables_shapes_and_types(self, shapes, types): + mock_transformer = mock.Mock() + mock_transformer.init.return_value = {'params': {}} + mock_optimizer_def = mock.Mock() + rng = mock.Mock() + + def mock_init(self): + self.module = mock_transformer + self.optimizer_def = mock_optimizer_def + + with mock.patch.object( + models.EncoderDecoderModel, '__init__', new=mock_init): + model = models.EncoderDecoderModel() + model.get_initial_variables(rng, shapes, types) + + if types is None: + encoder_input = jnp.ones( + shapes['encoder_input_tokens'], dtype=jnp.float32) + decoder_input = jnp.ones( + shapes['decoder_input_tokens'], dtype=jnp.float32) + else: + encoder_input = jnp.ones( + shapes['encoder_input_tokens'], dtype=types['encoder_input_tokens']) + decoder_input = jnp.ones( + shapes['decoder_input_tokens'], dtype=types['decoder_input_tokens']) + + # Using `.assert_called_once_with` doesn't work because the simple + # comparison it does for the array arguments fail (truth value of an array + # is ambiguous). 
+ called_with = mock_transformer.init.call_args + self.assertEqual(called_with[0][0], rng) + np.testing.assert_allclose(called_with[0][1], encoder_input) + np.testing.assert_allclose(called_with[0][2], decoder_input) + np.testing.assert_allclose(called_with[0][3], decoder_input) + + if 'encoder_positions' in shapes: + encoder_positions = jnp.ones( + shapes['encoder_positions'], dtype=types['encoder_positions']) + np.testing.assert_allclose(called_with[1]['encoder_positions'], + encoder_positions) + else: + self.assertIsNone(called_with[1]['encoder_positions']) + if 'decoder_positions' in shapes: + decoder_positions = jnp.ones( + shapes['decoder_positions'], dtype=types['decoder_positions']) + np.testing.assert_allclose(called_with[1]['decoder_positions'], + decoder_positions) + else: + self.assertIsNone(called_with[1]['decoder_positions']) + + if 'encoder_segment_ids' in shapes: + encoder_positions = jnp.ones( + shapes['encoder_segment_ids'], dtype=types['encoder_segment_ids']) + np.testing.assert_allclose(called_with[1]['encoder_segment_ids'], + encoder_positions) + else: + self.assertIsNone(called_with[1]['encoder_segment_ids']) + if 'decoder_segment_ids' in shapes: + decoder_segment_ids = jnp.ones( + shapes['decoder_segment_ids'], dtype=types['decoder_segment_ids']) + np.testing.assert_allclose(called_with[1]['decoder_segment_ids'], + decoder_segment_ids) + else: + self.assertIsNone(called_with[1]['decoder_segment_ids']) + + self.assertFalse(called_with[1]['decode']) + self.assertFalse(called_with[1]['enable_dropout']) + + @parameterized.named_parameters( + dict(testcase_name='no_force_decoding', prompt_with_targets=False), + dict(testcase_name='force_decoding', prompt_with_targets=True), + ) + def test_prompt_with_targets(self, prompt_with_targets): + batch_size, encoder_len, max_decode_len, emb_dim = 2, 3, 4, 5 + batch = { + 'encoder_input_tokens': + np.zeros((batch_size, encoder_len), dtype=np.int32), + 'decoder_input_tokens': + np.full([batch_size, max_decode_len], 2, dtype=np.int32) + } + + # These dummy logits represent the probability distribution where all the + # probability mass is in one item (i.e., degenerate distribution). For + # batch element 0, it is vocabulary index 3. + # We test `_predict_step` to avoid having to define a task and its + # vocabulary. 
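+ # The logits below have shape [batch=2, 1, vocab=5]; argmax picks index 3
+ # for batch element 0 and index 4 for batch element 1.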
+ dummy_logits = jnp.expand_dims( + jnp.array([[-1e7, -1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, -1e7, 0]]), + axis=1) + + mock_decode_fn = mock.Mock() + mock_decode_fn.return_value = (np.full([batch_size, max_decode_len, 1], + 3, + dtype=np.int32), + np.full([batch_size, 1], + 1.0, + dtype=np.float32)) + + class MockModule: + + def __init__(self): + self.dtype = jnp.float32 + + def apply(self, *args, method=None, **kwargs): + del args, kwargs + if method is None: # use for module.`__call__` + return (dummy_logits, {'cache': {}}) + else: + return method() + + def encode(self): + return jnp.zeros((batch_size, encoder_len, emb_dim)) + + def decode(self): + return (dummy_logits, {'cache': {}}) + + def mock_init(self): + self.module = MockModule() + self.module.scan_layers = False + self._input_vocabulary = mock.Mock(eos_id=1) + self._output_vocabulary = mock.Mock(eos_id=1) + self._decode_fn = mock_decode_fn + + with mock.patch.object( + models.EncoderDecoderModel, '__init__', new=mock_init): + model = models.EncoderDecoderModel() + + model.predict_batch_with_aux({}, + batch, + prompt_with_targets=prompt_with_targets) + + if prompt_with_targets: + expected_inputs = batch['decoder_input_tokens'] + else: + expected_inputs = np.zeros([batch_size, max_decode_len], dtype=np.int32) + + assert mock_decode_fn.call_count == 1 + # Look at the kwargs call list for inputs, assert_called_with doesn't + # work well with np.array comparison. + np.testing.assert_array_equal(mock_decode_fn.mock_calls[0][2]['inputs'], + expected_inputs) + + def test_predict_batch_loop_and_caches_are_equal(self): + vocab_size = 50 + lengths = np.array([[2], [3]]) + batch_size, beam_size, encoder_len, max_decode_len = 2, 2, 3, 7 + batch = { + 'encoder_input_tokens': + np.zeros((batch_size, encoder_len), dtype=np.int32), + 'decoder_target_tokens': + np.zeros((batch_size, encoder_len), dtype=np.int32), + 'decoder_input_tokens': + np.concatenate( + [ + np.expand_dims( + np.concatenate( + [[0], + np.arange(9, 9 + lengths[0][0], dtype=np.int32), + np.zeros((max_decode_len - lengths[0][0] - 1), + dtype=np.int32)]), + axis=0), # First element + np.expand_dims( + np.concatenate( + [[0], + np.arange(3, 3 + lengths[1][0], dtype=np.int32), + np.zeros((max_decode_len - lengths[1][0] - 1), + dtype=np.int32)]), + axis=0) # Second element + ], + axis=0), + } + + model = test_utils.get_t5_test_model(vocab_size=50) + module = model.module + params = module.init( + jax.random.PRNGKey(0), + jnp.ones((batch_size, encoder_len)), + jnp.ones((batch_size, max_decode_len)), + jnp.ones((batch_size, max_decode_len)), + enable_dropout=False)['params'] + + def mock_init(self): + self.module = module + # Set the EOS token to be larger then the vocabulary size. This forces the + # model to decode all the way to `max_decode_length`, allowing us to test + # behavior when one element reaches the end before the others. + self._output_vocabulary = mock.Mock(eos_id=vocab_size + 12) + self._decode_fn = decoding.beam_search + + with mock.patch.object( + models.EncoderDecoderModel, '__init__', new=mock_init): + model = models.EncoderDecoderModel() + + with mock.patch.object( + model, '_compute_logits_from_slice', + autospec=True) as tokens_to_logits_mock: + # Make the side effect of the mock, call the method on the class, with the + # instance partialed in as `self`. 
This lets us call the actual code, + # while recording the inputs, without an infinite loop you would get + # calling `instance.method` + tokens_to_logits_mock.side_effect = functools.partial( + models.EncoderDecoderModel._compute_logits_from_slice, model) + # Disable jit, so that the `lax.while_loop` isn't traced, as the + # collection of tracers in the mock call_args would generally trigger a + # tracer leak error. + with jax.disable_jit(): + _ = model.predict_batch_with_aux( + params, batch, prompt_with_targets=True, num_decodes=2) + + # Collect all the input tokens to our tokens_to_logits function + all_inputs = [] + all_cache_keys = [] # Collect all the cache keys + all_cache_values = [] # Collect all the cache values + # Currently force decoding generates logits at every step. We should have + # `max_decode_length` calls to our tokens -> logits func. + self.assertLen(tokens_to_logits_mock.call_args_list, max_decode_len) + for tokens_call in tokens_to_logits_mock.call_args_list: + # Inputs: [B * Be, 1] + inputs, cache = tokens_call[0] + cache = flax.core.unfreeze(cache) + # Cache: [B * Be, 1] * #Layers + cache_keys = [ + v for k, v in traverse_util.flatten_dict(cache).items() + if k[-1] == 'cached_key' + ] + cache_values = [ + v for k, v in traverse_util.flatten_dict(cache).items() + if k[-1] == 'cached_value' + ] + all_inputs.append(inputs) + all_cache_keys.append(cache_keys) + all_cache_values.append(cache_values) + # Convert inputs to a single block [B, DL, Be] + all_inputs = np.concatenate(all_inputs, axis=1) + # Convert caches into a single block per layer [B * Be, DL] * L + all_cache_keys = [np.stack(c, axis=1) for c in zip(*all_cache_keys)] + all_cache_values = [np.stack(c, axis=1) for c in zip(*all_cache_values)] + + # Make sure that for each batch, the cache for each beam is identical when + # prompt is being forced. + for b in range(batch_size): + for i, input_token in enumerate(all_inputs[b * beam_size]): + if i < lengths[b]: + self.assertEqual(input_token, batch['decoder_input_tokens'][b][i]) + # For all layers. + for cache_keys in all_cache_keys: + np.testing.assert_array_equal(cache_keys[b * beam_size][i], + cache_keys[b * beam_size + 1][i]) + for cache_values in all_cache_values: + np.testing.assert_array_equal(cache_values[b * beam_size][i], + cache_values[b * beam_size + 1][i]) + + def test_score_batch(self): + encoder_input_tokens = jnp.ones((2, 3)) + # For this test, decoder input and target tokens are dummy values. 
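+ # The expected scores asserted at the end of this test correspond to the sum
+ # of the target tokens' log-probabilities (under the dummy logits below) at
+ # the positions where `decoder_loss_weights` is 1.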
+ decoder_input_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]]) + decoder_target_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]]) + decoder_loss_weights = jnp.array([[1, 1, 1, 0], [0, 1, 0, 1]]) + logits = jnp.arange(0, 24).reshape((2, 4, 3)) + params = {'foo': jnp.zeros(3)} + + mock_transformer = mock.Mock() + mock_transformer.apply.return_value = logits + mock_transformer.dtype = jnp.float32 + + batch = { + 'encoder_input_tokens': encoder_input_tokens, + 'decoder_input_tokens': decoder_input_tokens, + 'decoder_target_tokens': decoder_target_tokens, + 'decoder_loss_weights': decoder_loss_weights + } + + def mock_init(self): + self.module = mock_transformer + + with mock.patch.object( + models.EncoderDecoderModel, '__init__', new=mock_init): + model = models.EncoderDecoderModel() + res = model.score_batch(params, batch) + + mock_transformer.apply.assert_called_with({'params': params}, + encoder_input_tokens, + decoder_input_tokens, + decoder_target_tokens, + encoder_segment_ids=None, + decoder_segment_ids=None, + encoder_positions=None, + decoder_positions=None, + decode=False, + enable_dropout=False, + rngs=None, + mutable=False) + np.testing.assert_allclose(res, [-3.222973, -1.815315], rtol=1e-4) + + def test_score_batch_can_return_intermediates(self): + encoder_input_tokens = jnp.ones((2, 3)) + # For this test, decoder input and target tokens are dummy values. + decoder_input_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]]) + decoder_target_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]]) + decoder_loss_weights = jnp.array([[1, 1, 1, 0], [0, 1, 0, 1]]) + logits = jnp.arange(0, 24).reshape((2, 4, 3)) + modified_variables = {'intermediates': {'bar': jnp.ones(5)}} + params = {'foo': jnp.zeros(3)} + + mock_transformer = mock.Mock() + mock_transformer.apply.return_value = (logits, modified_variables) + mock_transformer.dtype = jnp.float32 + + batch = { + 'encoder_input_tokens': encoder_input_tokens, + 'decoder_input_tokens': decoder_input_tokens, + 'decoder_target_tokens': decoder_target_tokens, + 'decoder_loss_weights': decoder_loss_weights + } + + def mock_init(self): + self.module = mock_transformer + + with mock.patch.object( + models.EncoderDecoderModel, '__init__', new=mock_init): + model = models.EncoderDecoderModel() + scores, intermediates = model.score_batch( + params, batch, return_intermediates=True) + + mock_transformer.apply.assert_called_with({'params': params}, + encoder_input_tokens, + decoder_input_tokens, + decoder_target_tokens, + encoder_segment_ids=None, + decoder_segment_ids=None, + encoder_positions=None, + decoder_positions=None, + decode=False, + enable_dropout=False, + rngs=None, + mutable=['intermediates']) + np.testing.assert_allclose(scores, [-3.222973, -1.815315], rtol=1e-4) + # Incumbent intermediates are passed out unchanged. 
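+ # (`bar` comes from the mocked `modified_variables` returned by the module's
+ # `apply` above.)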
+ np.testing.assert_allclose(intermediates['bar'], jnp.ones(5)) + # A new collection of decoder intermediates are inserted by score_batch() + np.testing.assert_allclose(intermediates['decoder']['loss_weights'][0], + decoder_loss_weights) + np.testing.assert_allclose(intermediates['decoder']['target_tokens'][0], + decoder_target_tokens) + + def test_train_transformer_wmt(self): + # Dummy input data + input_shape = (16, 8) + encoder_input_tokens = np.ones(shape=input_shape, dtype=np.float32) + decoder_input_tokens = 5 * np.ones(shape=input_shape, dtype=np.float32) + decoder_target_tokens = 5 * np.ones(input_shape, dtype=np.float32) + # input_data = {'inputs': inputs, 'targets': targets} + input_data = { + 'encoder_input_tokens': encoder_input_tokens, + 'decoder_input_tokens': decoder_input_tokens, + 'decoder_target_tokens': decoder_target_tokens + } + + partitioner = partitioning.PjitPartitioner(num_partitions=1) + + model = test_utils.get_t5_test_model() + + ds_iter = tf.data.Dataset.from_tensors(input_data).as_numpy_iterator() + input_shapes = {k: input_shape for k in input_data} + + train_state_initializer = utils.TrainStateInitializer( + optimizer_def=model.optimizer_def, + init_fn=model.get_initial_variables, + input_shapes=input_shapes, + partitioner=partitioner) + train_state_axes = train_state_initializer.train_state_axes + train_state = train_state_initializer.from_scratch(jax.random.PRNGKey(0)) + + trainer = trainer_lib.Trainer( + model, + train_state=train_state, + partitioner=partitioner, + eval_names=[], + summary_dir=None, + train_state_axes=train_state_axes, + rng=jax.random.PRNGKey(0), + learning_rate_fn=lambda x: 0.001, + num_microbatches=1) + + trainer.train(ds_iter, 1) + logging.info('optimizer after first step %s', train_state.params) + + + @parameterized.parameters( + {'decode_fn': decoding.beam_search}, + {'decode_fn': functools.partial(decoding.temperature_sample, topk=4)}) + def test_predict_batch(self, decode_fn): + batch_size, encoder_len, max_decode_len, emb_dim = 2, 3, 4, 5 + batch = { + 'encoder_input_tokens': + np.zeros((batch_size, encoder_len), dtype=np.int32), + 'decoder_input_tokens': + np.zeros((batch_size, max_decode_len), dtype=np.int32) + } + + # These dummy logits represent the probability distribution where all the + # probability mass is in one item (i.e., degenerate distribution). For + # batch element 0, it is vocabulary index 2. + # We test `_predict_step` to avoid having to define a task and its + # vocabulary. + dummy_logits = jnp.expand_dims( + jnp.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]]), axis=1) + + class MockModule: + + def __init__(self): + self.dtype = jnp.float32 + + def apply(self, *args, method=None, **kwargs): + del args, kwargs + if method is None: # use for module.`__call__` + return (dummy_logits, {'cache': {}}) + else: + return method() + + def encode(self): + return jnp.zeros((batch_size, encoder_len, emb_dim)) + + def decode(self): + return (dummy_logits, {'cache': {}}) + + def mock_init(self): + self.module = MockModule() + self.module.scan_layers = False + self._input_vocabulary = mock.Mock(eos_id=1) + self._output_vocabulary = mock.Mock(eos_id=1) + self._decode_fn = decode_fn + + with mock.patch.object( + models.EncoderDecoderModel, '__init__', new=mock_init): + model = models.EncoderDecoderModel() + + actual = model.predict_batch({}, batch) + # The predicted token for the first batch element is always 2 and it is 3 + # for the second batch element. 
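+ # Both rows are decoded for the full `max_decode_len` because the mocked
+ # vocabularies use EOS id 1, which the dummy logits never produce.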
+ expected = [[2] * max_decode_len, [3] * max_decode_len] + np.testing.assert_array_equal(actual, expected) + + def test_predict_batch_rng(self): + batch = { + 'encoder_input_tokens': np.zeros((2, 1), dtype=np.int32), + 'decoder_input_tokens': np.zeros((2, 2), dtype=np.int32) + } + + decode_fn_mock = mock.Mock( + return_value=(np.zeros((2, 2, 3)), np.zeros((2, 2)))) + + def mock_init(self): + self.module = mock.Mock( + apply=mock.Mock(side_effect=lambda *_, **kwargs: ( # pylint:disable=g-long-lambda,g-long-ternary + np.zeros((2, 2)), { + 'cache': None + }) if 'mutable' in kwargs else np.zeros((2, 2)))) + self._output_vocabulary = mock.Mock(eos_id=1) + self._decode_fn = decode_fn_mock + + with mock.patch.object( + models.EncoderDecoderModel, '__init__', new=mock_init): + model = models.EncoderDecoderModel() + + # No RNG + model.predict_batch({}, batch) + _, decode_fn_kwargs = decode_fn_mock.call_args + self.assertNotIn('decode_rng', decode_fn_kwargs) + + # No RNG (w/ aux) + model.predict_batch_with_aux({}, batch) + _, decode_fn_kwargs = decode_fn_mock.call_args + self.assertNotIn('decode_rng', decode_fn_kwargs) + + # decoder_params RNG + model.predict_batch_with_aux({}, batch, decoder_params={'decode_rng': 3}) + _, decode_fn_kwargs = decode_fn_mock.call_args + self.assertEqual(decode_fn_kwargs['decode_rng'], 3) + + # rng RNG + model.predict_batch({}, batch, rng=4) + _, decode_fn_kwargs = decode_fn_mock.call_args + self.assertEqual(decode_fn_kwargs['decode_rng'], 4) + + # rng RNG (w/ aux) + model.predict_batch_with_aux({}, batch, rng=4) + _, decode_fn_kwargs = decode_fn_mock.call_args + self.assertEqual(decode_fn_kwargs['decode_rng'], 4) + + # Both + with self.assertRaisesWithLiteralMatch( + ValueError, 'Got RNG both from the `rng` argument (4) and ' + "`decoder_params['decode_rng']` (3). Please specify one or the other."): + model.predict_batch_with_aux({}, + batch, + rng=4, + decoder_params={'decode_rng': 3}) + + @parameterized.named_parameters( + dict( + testcase_name='int32', + batch={ + 'encoder_input_tokens': + np.zeros((BATCH_SIZE, ENCODER_LEN), dtype=np.int32), + 'decoder_input_tokens': + np.zeros((BATCH_SIZE, MAX_DECODE_LEN), dtype=np.int32) + }), + dict( + testcase_name='float32', + batch={ + 'encoder_input_tokens': + np.zeros((BATCH_SIZE, ENCODER_LEN), dtype=np.float32), + 'decoder_input_tokens': + np.zeros((BATCH_SIZE, MAX_DECODE_LEN), dtype=np.float32) + })) + def test_predict_batch_fake_input_shapes_and_types(self, batch): + + # These dummy logits represent the probability distribution where all the + # probability mass is in one item (i.e., degenerate distribution). For + # batch element 0, it is vocabulary index 2. + # We test `_predict_step` to avoid having to define a task and its + # vocabulary. 
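+ # (In this shapes/dtypes test the logits are simply all ones; only the
+ # arguments of the cache-initialization call are checked, not the decoded
+ # ids.)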
+ dummy_logits = jnp.ones((2, 1, 4), jnp.float32) + + class MockModule: + + def __init__(self): + self.dtype = jnp.float32 + self.call_args_list = [] + + def apply(self, *args, method=None, **kwargs): + # Not sure why this isn't a real Mock so just record the args/kwargs + self.call_args_list.append({'args': args, 'kwargs': kwargs}) + del args, kwargs + if method is None: # use for module.`__call__` + return (dummy_logits, {'cache': {}}) + else: + return method() + + def encode(self): + return jnp.zeros((BATCH_SIZE, ENCODER_LEN, EMBED_DIM)) + + def decode(self): + return (dummy_logits, {'cache': {}}) + + def mock_init(self): + self.module = MockModule() + self.module.scan_layers = False + self._input_vocabulary = mock.Mock(eos_id=1) + self._output_vocabulary = mock.Mock(eos_id=1) + self._decode_fn = decoding.beam_search + self._inputs_bidirectional_attention = False + + with mock.patch.object( + models.EncoderDecoderModel, '__init__', new=mock_init): + model = models.EncoderDecoderModel() + model.predict_batch({}, batch) + + fake_inputs = jnp.ones_like(batch['encoder_input_tokens']) + fake_target = jnp.ones_like(batch['decoder_input_tokens']) + + cache_init_call = model.module.call_args_list[0] + self.assertEqual(cache_init_call['args'][0], {'params': {}}) + np.testing.assert_allclose(cache_init_call['args'][1], fake_inputs) + np.testing.assert_allclose(cache_init_call['args'][2], fake_target) + np.testing.assert_allclose(cache_init_call['args'][3], fake_target) + self.assertEqual(cache_init_call['kwargs'], { + 'decode': True, + 'enable_dropout': False, + 'mutable': ['cache'] + }) + + +class DecoderOnlyModelTest(parameterized.TestCase): + + + + def test_predict_batch_visible_in_prefill(self): + batch_size = 2 + seq_len = 10 + lengths = np.array([[6], [3]]) + batch = { + 'decoder_input_tokens': + np.tile( + np.expand_dims(np.arange(seq_len, dtype=np.int32), axis=0), + (batch_size, 1)), + 'decoder_causal_attention': + (lengths > np.arange(seq_len)).astype(np.int32) + } + + dummy_logits = jnp.expand_dims( + jnp.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]]), axis=1) + + mock_module = mock.Mock() + mock_module.apply.return_value = (dummy_logits, {'cache': {}}) + mock_module.dtype = jnp.float32 + + def mock_init(self): + self.module = mock_module + self._output_vocabulary = mock.Mock(eos_id=1) + self._decode_fn = functools.partial(decoding.temperature_sample, topk=4) + self._inputs_bidirectional_attention = False + + with mock.patch.object(models.DecoderOnlyModel, '__init__', new=mock_init): + model = models.DecoderOnlyModel() + + model.predict_batch({}, batch) + prefill_call = mock_module.apply.call_args_list[1] + kwargs = prefill_call[1] + inputs = prefill_call[1]['decoder_input_tokens'] + # Note that, for the prefill call, we use 'decoder_causal_attention' as + # 'decoder_target_tokens'. + targets = prefill_call[1]['decoder_target_tokens'] + self.assertTrue(kwargs['prefill']) + np.testing.assert_array_equal(kwargs['prefill_lengths'], + np.squeeze(lengths - 1, axis=-1)) + # Test that the non padding values of the "targets" cover all of the input, + # you it will all be considered in the attention mask. + np.testing.assert_array_equal(inputs * targets, inputs) + # Check that the first value of the target is 1, the first value of the + # inputs is always 0 so the masking check wouldn't catch it if the target + # had a 0 in the first location. + np.testing.assert_array_equal(targets[:, 0], np.ones_like(targets[:, 0])) + # Test that the targets are properly removed. 
Our input is a sequence from 0 + # onward, so our largest value (the last input) should be equal by it's + # position (which is 1 - length). If we didn't mask the target correctly, + # we would expect a larger value in the max. + np.testing.assert_array_equal( + np.max(inputs, axis=1), np.squeeze(lengths - 1, axis=-1)) + + + def test_predict_batch(self): + batch = { + 'decoder_input_tokens': + np.array([[0, 3, 4, 5, 6, 0, 0], [0, 7, 8, 9, 0, 0, 0]]), + 'decoder_causal_attention': + np.array([[1, 1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0]]) + } + + # These dummy logits represent the probability distribution where all the + # probability mass is in one item (i.e., degenerate distribution). For + # batch element 0, it is vocabulary index 2. + # We test `_predict_step` to avoid having to define a task and its + # vocabulary. + dummy_logits = jnp.expand_dims( + jnp.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]]), axis=1) + + mock_module = mock.Mock() + mock_module.apply.return_value = (dummy_logits, {'cache': {}}) + mock_module.dtype = jnp.float32 + + def mock_init(self): + self.module = mock_module + self._output_vocabulary = mock.Mock(eos_id=1) + self._decode_fn = functools.partial(decoding.temperature_sample, topk=4) + self._inputs_bidirectional_attention = False + + with mock.patch.object(models.DecoderOnlyModel, '__init__', new=mock_init): + model = models.DecoderOnlyModel() + + actual = model.predict_batch({}, batch) + + expected = [[2, 2, 2, 2, 2, 0, 0], [3, 3, 3, 3, 3, 3, 0]] + + # The expected progression of the first element of 'decoder_input_tokens': + # [0, 3, 4, 5, 6, 0, 0] -> [0, 3, 4, 0, 0, 0, 0] -> + # [3, 4, 2, 2, 2, 2, 2] -> [2, 2, 2, 2, 2, 0, 0] + + # The expected progression of the second element of 'decoder_input_tokens': + # [0, 7, 8, 9, 0, 0, 0] -> [0, 7, 0, 0, 0, 0, 0] -> + # [7, 3, 3, 3, 3, 3, 3] -> [3, 3, 3, 3, 3, 3, 0] + + np.testing.assert_array_equal(actual, expected) + + def test_predict_batch_rng(self): + batch = { + 'decoder_input_tokens': np.zeros((2, 2), dtype=np.int32), + 'decoder_causal_attention': np.zeros((2, 2), dtype=np.int32) + } + + decode_fn_mock = mock.Mock( + return_value=(np.zeros((2, 2, 3)), np.zeros((2, 2)))) + + def mock_init(self): + self.module = mock.Mock( + apply=mock.Mock(side_effect=lambda *_, **kwargs: ( # pylint:disable=g-long-lambda,g-long-ternary + np.zeros((2, 2)), { + 'cache': None + }) if 'mutable' in kwargs else np.zeros((2, 2)))) + self._output_vocabulary = mock.Mock(eos_id=1) + self._decode_fn = decode_fn_mock + self._inputs_bidirectional_attention = False + + with mock.patch.object(models.DecoderOnlyModel, '__init__', new=mock_init): + model = models.DecoderOnlyModel() + + # No RNG + model.predict_batch({}, batch) + _, decode_fn_kwargs = decode_fn_mock.call_args + self.assertNotIn('decode_rng', decode_fn_kwargs) + + # No RNG (w/ aux) + model.predict_batch_with_aux({}, batch) + _, decode_fn_kwargs = decode_fn_mock.call_args + self.assertNotIn('decode_rng', decode_fn_kwargs) + + # decoder_params RNG + model.predict_batch_with_aux({}, batch, decoder_params={'decode_rng': 3}) + _, decode_fn_kwargs = decode_fn_mock.call_args + self.assertEqual(decode_fn_kwargs['decode_rng'], 3) + + # rng RNG + model.predict_batch({}, batch, rng=4) + _, decode_fn_kwargs = decode_fn_mock.call_args + self.assertEqual(decode_fn_kwargs['decode_rng'], 4) + + # rng RNG (w/ aux) + model.predict_batch_with_aux({}, batch, rng=4) + _, decode_fn_kwargs = decode_fn_mock.call_args + self.assertEqual(decode_fn_kwargs['decode_rng'], 4) + + # Both + with 
self.assertRaisesWithLiteralMatch( + ValueError, 'Got RNG both from the `rng` argument (4) and ' + "`decoder_params['decode_rng']` (3). Please specify one or the other."): + model.predict_batch_with_aux({}, + batch, + rng=4, + decoder_params={'decode_rng': 3}) + + def test_predict_batch_num_decodes_temperature_sample(self): + batch = { + 'decoder_input_tokens': np.array([ + [0, 3, 4, 5, 6, 0, 0], + ]), + 'decoder_causal_attention': np.array([ + [1, 1, 1, 0, 0, 0, 0], + ]) + } + + # These dummy logits represent the probability distribution where all the + # probability mass is in one item (i.e., degenerate distribution). For + # batch element 0, it is vocabulary index 2. We have two samples. + # Technically these should be identical since the prompts are the same, but + # this makes testing easier. + dummy_logits = jnp.expand_dims( + jnp.array([[-1e7, -1e7, 0, -1e7], [-1e7, -1e7, -1e7, 0]]), axis=1) + + mock_module = mock.Mock() + mock_module.apply.return_value = (dummy_logits, {'cache': {}}) + mock_module.dtype = jnp.float32 + + def mock_init(self): + self.module = mock_module + self._output_vocabulary = mock.Mock(eos_id=1) + self._decode_fn = functools.partial(decoding.temperature_sample, topk=4) + self._inputs_bidirectional_attention = False + + with mock.patch.object(models.DecoderOnlyModel, '__init__', new=mock_init): + model = models.DecoderOnlyModel() + + actual_output, aux = model.predict_batch_with_aux({}, + batch, + num_decodes=2, + return_all_decodes=True) + + expected_output = [[[2, 2, 2, 2, 2, 0, 0], [3, 3, 3, 3, 3, 0, 0]]] + expected_scores = [[0., 0.]] + + # The expected progression of the first element of 'decoder_input_tokens': + # [0, 3, 4, 5, 6, 0, 0] -> [0, 3, 4, 0, 0, 0, 0] -> + # [3, 4, 2, 2, 2, 2, 2] -> [2, 2, 2, 2, 2, 0, 0] + + # The expected progression of the second element of 'decoder_input_tokens': + # [0, 7, 8, 9, 0, 0, 0] -> [0, 7, 0, 0, 0, 0, 0] -> + # [7, 3, 3, 3, 3, 3, 3] -> [3, 3, 3, 3, 3, 3, 0] + + np.testing.assert_array_equal(actual_output, expected_output) + np.testing.assert_array_equal(aux['scores'], expected_scores) + + def test_predict_batch_fake_input_shapes_and_types(self): + # The input and causal attention actually have to be int32 for this test, + # even though the cache init should work with any types the `inputs` that + # is created from multiplying the causal attention and the input tokens + # needs to be an int or the decoding will fail. 
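+ # (That product, `decoder_causal_attention * decoder_input_tokens`, is used
+ # as the prompt during decoding, so both arrays are created as int32 below.)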
+ batch = { + 'decoder_input_tokens': + np.array([[0, 3, 4, 5, 6, 0, 0], [0, 7, 8, 9, 0, 0, 0]], + dtype=np.int32), + 'decoder_causal_attention': + np.array([[1, 1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0]], + dtype=np.int32) + } + + dummy_logits = jnp.ones((2, 1, 5), jnp.float32) + + mock_module = mock.Mock() + mock_module.apply.return_value = (dummy_logits, {'cache': {}}) + mock_module.dtype = jnp.float32 + + def mock_init(self): + self.module = mock_module + self._output_vocabulary = mock.Mock(eos_id=1) + self._decode_fn = functools.partial(decoding.temperature_sample, topk=4) + self._inputs_bidirectional_attention = False + + with mock.patch.object(models.DecoderOnlyModel, '__init__', new=mock_init): + model = models.DecoderOnlyModel() + + model.predict_batch({}, batch) + + fake_target = jnp.ones_like(batch['decoder_input_tokens']) + + cache_init_call = mock_module.apply.call_args_list[0] + + self.assertEqual(cache_init_call[0][0], {'params': {}}) + np.testing.assert_allclose(cache_init_call[0][1], fake_target) + np.testing.assert_allclose(cache_init_call[0][2], fake_target) + self.assertEqual(cache_init_call[1], { + 'decode': True, + 'enable_dropout': False, + 'mutable': ['cache'] + }) + + @parameterized.named_parameters( + dict( + testcase_name='no_types', + shapes={'decoder_input_tokens': [1, 62]}, + types=None), + dict( + testcase_name='int32', + shapes={'decoder_input_tokens': [1, 62]}, + types={'decoder_input_tokens': jnp.int32}), + dict( + testcase_name='float32', + shapes={'decoder_input_tokens': [1, 62]}, + types={'decoder_input_tokens': jnp.int32}), + ) + def test_get_initial_variables_shapes_and_types(self, shapes, types): + mock_lm = mock.Mock() + mock_lm.init.return_value = {'params': {}} + mock_optimizer_def = mock.Mock() + rng = mock.Mock() + + def mock_init(self): + self.module = mock_lm + self.optimizer_def = mock_optimizer_def + + with mock.patch.object(models.DecoderOnlyModel, '__init__', new=mock_init): + model = models.DecoderOnlyModel() + model.get_initial_variables(rng, shapes, types) + + if types is None: + decoder_input = jnp.ones( + shapes['decoder_input_tokens'], dtype=jnp.float32) + else: + decoder_input = jnp.ones( + shapes['decoder_input_tokens'], dtype=types['decoder_input_tokens']) + + # Using `.assert_called_once_with` doesn't work because the simple + # comparison it does for the array arguments fail (truth value of an array + # is ambiguous). + called_with = mock_lm.init.call_args + self.assertEqual(called_with[0][0], rng) + np.testing.assert_allclose(called_with[0][1], decoder_input) + np.testing.assert_allclose(called_with[0][2], decoder_input) + self.assertEqual(mock_lm.init.call_args[1], {'enable_dropout': False}) + + +if __name__ == '__main__': + absltest.main() diff --git a/t5x/optimizers.py b/t5x/optimizers.py new file mode 100644 index 0000000000000000000000000000000000000000..5ec7778e346495a07f413667f8dff00a02725ecf --- /dev/null +++ b/t5x/optimizers.py @@ -0,0 +1,706 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""T5X Optimizer Support. + +Tools for wrapping Optax optimizers and handling SPMD annotations for use with +pjit. + +Additional support for the legacy Adafactor implementation. +""" + +import functools +from typing import Any, Optional, Union, Sequence, Tuple + +import flax +from flax import optim # just used for transitional type definitions +from flax import serialization +from flax import struct +from flax import traverse_util +from flax.core import frozen_dict +from flax.serialization import from_state_dict +from flax.serialization import to_state_dict +import jax +import jax.numpy as jnp +import optax + +freeze = flax.core.frozen_dict.freeze +unfreeze = flax.core.frozen_dict.unfreeze + +Dtype = Any + + +@struct.dataclass +class OptimizerState: + step: jnp.ndarray + param_states: Any + + +class OptimizerDef: + """Base class for an optimizer definition.""" + + def __init__(self, hyper_params): + self.hyper_params = hyper_params + + def apply_gradient(self, hyper_params, params, state, grads): + """Applies a gradient for a set of parameters.""" + raise NotImplementedError() + + def init_state(self, params): + raise NotImplementedError() + + def update_hyper_params(self, **hyper_param_overrides): + """Updates the hyper parameters with a set of overrides. + + Args: + **hyper_param_overrides: the hyper parameters updates will override the + defaults specified in the `OptimizerDef`. Pass `hyper_params=...` to + replace all hyper parameters. + + Returns: + The new hyper parameters. + """ + hp = hyper_param_overrides.pop('hyper_params', self.hyper_params) + if hyper_param_overrides: + hp = hp.replace(**hyper_param_overrides) + return hp + + def create(self, target): + """Creates a new optimizer for the given target. + + Args: + target: the object to be optimized. This is typically a variable dict + returned by `flax.linen.Module.init()`, but it can also be a container + of variables dicts, e.g. `(v1, v2)` and `('var1': v1, 'var2': v2)` are + valid inputs as well. + + Returns: + An instance of `Optimizer`. + """ + opt_def = self + state = opt_def.init_state(target) + return Optimizer(opt_def, state, target) + + def state_dict(self, target, state): + return to_state_dict({ + 'target': to_state_dict(target), + 'state': to_state_dict(state) + }) + + def restore_state(self, opt_target, opt_state, state_dict): + """Restore the optimizer target and state from the state dict. + + Args: + opt_target: the optimizer target. + opt_state: the optimizer state. + state_dict: the state dict containing the desired new state of the + optimizer. + + Returns: + a tuple of the optimizer target and state with the restored values from + the state dict. + """ + + opt_target = from_state_dict(opt_target, state_dict['target']) + opt_state = from_state_dict(opt_state, state_dict['state']) + return opt_target, opt_state + + +class Optimizer(struct.PyTreeNode): + """Legacy flax optimizer class. + + Optimizer carries the target and optimizer state. The optimizer is updated + using the method apply_gradient. + + Attributes: + optimizer_def: The optimizer definition. + state: The initial state of the optimizer. + target: The target to optimizer. + """ + + optimizer_def: OptimizerDef = struct.field(pytree_node=False) + state: Any = struct.field(pytree_node=True) + target: Any = struct.field(pytree_node=True) + + def apply_gradient(self, grads, **hyper_param_overrides): + """Applies a pytree of gradients to the target. 
+ + Args: + grads: A pytree of gradients. + **hyper_param_overrides: the hyper parameters passed to apply_gradient + will override the defaults specified in the `OptimizerDef`. Pass + `hyper_params=...` to replace all hyper parameters. + + Returns: + A new optimizer with the updated target and state. + """ + hyper_params = self.optimizer_def.update_hyper_params( + **hyper_param_overrides) + new_target, new_state = self.optimizer_def.apply_gradient( + hyper_params, self.target, self.state, grads) + return self.replace(target=new_target, state=new_state) + + def state_dict(self): + return self.optimizer_def.state_dict(self.target, self.state) + + def restore_state(self, state): + target, state = self.optimizer_def.restore_state(self.target, self.state, + state) + return self.replace(target=target, state=state) + + +# Transitional Type Definitions + +OptimizerType = Union[optim.Optimizer, Optimizer] +OptimizerStateType = Union[optim.OptimizerState, OptimizerState] +OptimizerDefType = Union[optim.OptimizerDef, OptimizerDef] + +# Optax Elementwise Wrapper + + +class OptaxStatePartitionRules: + """Collection of rules to partition optax states. + + These rules work for optimizers whose states are simply replications of + params, e.g., Adam. Optimizers that aim to save memory by factoring states, + e.g., Adafactor, SM3, are not supported currently. + """ + + # Rules mapping a particular optax state to a callable returning the state + # with arrays replaced by t5x PartitionSpec or None. + # + # NOTE(levskaya): This is not an entirely exhaustive list, add to this list + # to support additional optimizers / transformations. + # + # pylint: disable=g-long-lambda + + _RULES = { + + # Leaf Optax States: + optax.AddNoiseState: + lambda state, params_axes: optax.AddNoiseState( + count=None, rng_key=None), + optax.DifferentiallyPrivateAggregateState: + lambda state, params_axes: optax.DifferentiallyPrivateAggregateState( + rng_key=None), + optax.EmaState: + lambda state, params_axes: optax.EmaState( + count=None, ema=params_axes), + optax.EmptyState: + lambda state, params_axes: optax.EmptyState(), + optax.TraceState: + lambda state, params_axes: optax.TraceState(trace=params_axes), + optax.ScaleByAdamState: + lambda state, params_axes: optax.ScaleByAdamState( + count=None, mu=params_axes, nu=params_axes), + optax.ScaleByBeliefState: + lambda state, params_axes: optax.ScaleByBeliefState( + count=None, mu=params_axes, nu=params_axes), + optax.ScaleByRssState: + lambda state, params_axes: optax.ScaleByRssState( + sum_of_squares=params_axes), + optax.ScaleByRmsState: + lambda state, params_axes: optax.ScaleByRmsState(nu=params_axes), + optax.ScaleByRStdDevState: + lambda state, params_axes: optax.ScaleByRStdDevState( + mu=params_axes, nu=params_axes), + optax.ScaleBySM3State: + lambda state, params_axes: optax.ScaleBySM3State( + mu=params_axes, nu=params_axes), + optax.ScaleByTrustRatioState: + lambda state, params_axes: optax.ScaleByTrustRatioState(), + optax.ScaleByScheduleState: + lambda state, params_axes: optax.ScaleByScheduleState(count=None), + optax.ScaleByFromageState: + lambda state, params_axes: optax.ScaleByFromageState(count=None), + optax.ZeroNansState: + lambda state, params_axes: optax.ZeroNansState(found_nan=None), + # FactoredState + + # Recursive, Combinator Optax States: + + # MaskedState + optax.MaskedState: + lambda state, params_axes: optax.MaskedState( + inner_state=OptaxStatePartitionRules.derive_optax_logical_axes( + state.inner_state, params_axes)), + optax.InjectHyperparamsState: + 
lambda state, params_axes: optax.InjectHyperparamsState( + count=None, + hyperparams=jax.tree_map(lambda x: None, state.hyperparams), + inner_state=OptaxStatePartitionRules.derive_optax_logical_axes( + state.inner_state, params_axes)), + optax.MultiStepsState: + lambda state, params_axes: optax.MultiStepsState( + mini_step=None, + gradient_step=None, + inner_opt_state=OptaxStatePartitionRules. + derive_optax_logical_axes( # pylint: disable=line-too-long + state.inner_opt_state, params_axes), + acc_grads=params_axes), + optax.ApplyIfFiniteState: + lambda state, params_axes: optax.ApplyIfFiniteState( + notfinite_count=None, + last_finite=None, + total_notfinite=None, + inner_state=OptaxStatePartitionRules.derive_optax_logical_axes( + state.inner_state, params_axes)), + optax.MaybeUpdateState: + lambda state, params_axes: optax.MaybeUpdateState( + inner_state=OptaxStatePartitionRules.derive_optax_logical_axes( + state.inner_state, params_axes), + step=None), + optax.MultiTransformState: + lambda state, params_axes: optax.MultiTransformState( + inner_states=OptaxStatePartitionRules.derive_optax_logical_axes( + state.inner_states, params_axes)), + # LookaheadState + # SplitRealAndImaginaryState + } + # pylint: enable=g-long-lambda + + @classmethod + def _is_optax_state(cls, x): + """Returns true if an object is an optax state. + + Note that in optax states are simply derived from NamedTuple, so we have to + do some hacky name matching. + + Args: + x: object. + + Returns: + True if x is an optax state. + """ + # A solution from stack overflow. Note that isinstance(x, NamedTuple) would + # not work. + is_named_tuple = ( + isinstance(x, tuple) and hasattr(x, '_asdict') and + hasattr(x, '_fields')) + result = is_named_tuple and type(x).__name__.endswith('State') + return result + + @classmethod + def derive_optax_logical_axes(cls, optax_state, params_axes): + """Derived logical axes for optax state.""" + # Flatten the optax state but do not go into the registered states. + flattened_state, tree_def = jax.tree_flatten( + optax_state, is_leaf=cls._is_optax_state) + + def derive_fn(x): + if type(x) not in cls._RULES: + if cls._is_optax_state(x): + raise ValueError( + f'Encountered unregistered optax state type {type(x).__name__}') + return None + return cls._RULES[type(x)](x, params_axes) + + flattened_axes = [derive_fn(x) for x in flattened_state] + derived_axes = jax.tree_unflatten(tree_def, flattened_axes) + return derived_axes + + +@struct.dataclass +class _OptaxWrapperHyperParams: + """Dummy hyper params struct, not used.""" + # Required by t5x trainer. Unused as learning rate scheduling is done using + # optax.Schedule. + learning_rate: Optional[float] = None + + +class OptaxWrapper(OptimizerDef): + """Wrapper to make optax optimizer compatible with T5X.""" + + def __init__(self, optax_optimizer: optax.GradientTransformation): + """Initializer. + + Args: + optax_optimizer: An optax optimizer. + """ + self.optax_optimizer = optax_optimizer + super().__init__(hyper_params=_OptaxWrapperHyperParams()) + + def init_state(self, params): + """Create initial state based on the params to optimize. + + Args: + params: PyTree of parameters to optimize. + + Returns: + Initial optimizer state. + """ + state = OptimizerState( + step=0, param_states=self.optax_optimizer.init(params)) + return state + + def apply_gradient(self, hyper_params, params, state, grads): + """Applies gradient. + + Args: + hyper_params: Unused hyper parameters. + params: PyTree of the parameters. 
+ state: A named tuple containing the state of the optimizer. + grads: PyTree of the gradients for the parameters. + + Returns: + A tuple containing the new parameters and the new optimizer state. + """ + del hyper_params + + updates, new_optax_state = self.optax_optimizer.update( + grads, state.param_states, params) + new_params = optax.apply_updates(params, updates) + return new_params, OptimizerState( + step=state.step + 1, param_states=new_optax_state) + + def derive_logical_axes(self, optimizer, param_logical_axes): + """Derives optimizer state logical axes from params logical axes. + + Args: + optimizer: `optimizers.Optimizer` instance. + param_logical_axes: A PyTree where each leaf is a t5x PartitionSpec. + + Returns: + An `optimizers.Optimizer` instance, with all the leafs replaced by t5x + PartitionSpec or None (no partition). + """ + optimizer_logical_axes = jax.tree_map(lambda x: None, + optimizer.state_dict()) + optimizer_logical_axes['target'] = param_logical_axes + + optax_state_axes = OptaxStatePartitionRules.derive_optax_logical_axes( + optimizer.state.param_states, param_logical_axes) + + optimizer_logical_axes['state']['param_states'] = ( + serialization.to_state_dict(optax_state_axes)) + + return optimizer.restore_state(frozen_dict.unfreeze(optimizer_logical_axes)) + + def state_dict(self, target, state): + """Override state dict function. + + We need to override this function because many optax transformations use + `optax.EmptyState`, which produces empty dict in the state dict. This causes + the T5 training loop to fail in multiple places. As a remedy, we will + filter out the generated state dict so that there are no empty dict in the + output. + + The restore_state function is also overridden to reconstruct those empty + dict. + + Args: + target: Pytree of target variables. + state: Pytree of optimizer state. + + Returns: + A nested state. + """ + state_dict = to_state_dict(state) + + # This step removes any empty dict (recursively) in the state dict. + state_dict = traverse_util.unflatten_dict( + traverse_util.flatten_dict(state_dict, sep='/'), sep='/') + + return to_state_dict({ + 'target': to_state_dict(target), + 'state': state_dict, + }) + + def restore_state(self, opt_target, opt_state, state_dict): + """Override to restore empty dicts corresponding to `optax.EmptyState`. + + Args: + opt_target: the optimizer target. + opt_state: the optimizer state. + state_dict: the state dict containing the desired new state of the + optimizer. + + Returns: + a tuple of the optimizer target and state with the restored values from + the state dict. + """ + opt_target = from_state_dict(opt_target, state_dict['target']) + + # Get all the possible keys in the reference optimizer state. + flat_ref_opt_state_dict = traverse_util.flatten_dict( + to_state_dict(opt_state), keep_empty_nodes=True, sep='/') + + flat_src_opt_state_dict = dict( + traverse_util.flatten_dict(state_dict['state'], sep='/')) + # Adding the empty paths back to flat_src_opt_state_dict. + for k, v in flat_ref_opt_state_dict.items(): + if k in flat_src_opt_state_dict: + continue + # The key is not in the input state dict, presumably because it + # corresponds to an empty dict. + if v != traverse_util.empty_node: + raise ValueError( + f'Failed to restore optimizer state, path {k} is not present ' + 'in the input optimizer state dict.') + flat_src_opt_state_dict[k] = v + + # Restore state from the enhanced state dict. 
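+ # At this point `flat_src_opt_state_dict` also contains the empty-dict
+ # placeholders that `state_dict()` filtered out, so `from_state_dict` sees
+ # the full reference structure.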
+ opt_state = from_state_dict( + opt_state, + traverse_util.unflatten_dict(flat_src_opt_state_dict, sep='/')) + return opt_target, opt_state + + +# Optax wrapper and elementary wrapped optax optimizers. + + +def wrap_optax_optimizer(optax_optimizer): + """Converts optax optimizer constructor to a wrapped T5X-compatible optimizer. + + Args: + optax_optimizer: an optax optimizer creation function that returns an optax + GradientTransformation. + + Returns: + A function that takes the same arguments as the original optax creation + function but instead returns a wrapped OptimizerDef-compatible interface for + using the optimizer with T5X. + """ + + @functools.wraps(optax_optimizer) + def wrapped_optimizer(*args, **kwargs) -> OptimizerDef: + return OptaxWrapper(optax_optimizer(*args, **kwargs)) + + return wrapped_optimizer + + +def chain( + transformations: Sequence[optax.GradientTransformation] +) -> optax.GradientTransformation: + return optax.chain(*transformations) + + +chain = wrap_optax_optimizer(chain) +adabelief = wrap_optax_optimizer(optax.adabelief) +adagrad = wrap_optax_optimizer(optax.adagrad) +adam = wrap_optax_optimizer(optax.adam) +adamw = wrap_optax_optimizer(optax.adamw) +fromage = wrap_optax_optimizer(optax.fromage) +lars = wrap_optax_optimizer(optax.lars) +lamb = wrap_optax_optimizer(optax.lamb) +noisy_sgd = wrap_optax_optimizer(optax.noisy_sgd) +radam = wrap_optax_optimizer(optax.radam) +rmsprop = wrap_optax_optimizer(optax.rmsprop) +sgd = wrap_optax_optimizer(optax.sgd) +yogi = wrap_optax_optimizer(optax.yogi) +dpsgd = wrap_optax_optimizer(optax.dpsgd) + +# Excluded optimizers: +# TODO(levskaya): add shampoo, sm3 +# We use our own generalized adafactor implementations. +# adafactor = wrap_optax_optimizer(optax.adafactor) +# We may use a more complete quantized implementation of SM3 +# sm3 = wrap_optax_optimizer(optax.sm3) + +# Inlined Legacy Generalized Multioptimizer + + +class _Marker: + """Used to mark unoptimized leaves.""" + + def __init__(self): + self._indices = [] + + +def _tree_of_paths(tree): + """Converts a (frozen) nested dictionary into a (frozen) dict of paths.""" + is_frozen = isinstance(tree, flax.core.frozen_dict.FrozenDict) + flat_tree = traverse_util.flatten_dict(unfreeze(tree)) + path_tree = traverse_util.unflatten_dict( + {k: '/'.join(k) for k in flat_tree.keys()}) + if is_frozen: + path_tree = freeze(path_tree) + return path_tree + + +def _subtree_from_traversal(traversal, tree): + """Creates a (frozen) tree subset given a traversal.""" + is_frozen = isinstance(tree, flax.core.frozen_dict.FrozenDict) + flat_tree = {} + for path, leaf in zip( + traversal.iterate(_tree_of_paths(tree)), traversal.iterate(tree)): + flat_tree[path] = leaf + new_tree = traverse_util.unflatten_dict( + {tuple(k.split('/')): v for k, v in flat_tree.items()}) + if is_frozen: + new_tree = freeze(new_tree) + return new_tree + + +def _update_subtree_of_traversal(traversal, tree, update): + """Updates a (frozen) tree's subset given a traversal and update subtree.""" + is_frozen = isinstance(tree, flax.core.frozen_dict.FrozenDict) + flat_tree = traverse_util.flatten_dict(unfreeze(tree)) + flat_tree = {'/'.join(k): v for k, v in flat_tree.items()} + for path, leaf in zip( + traversal.iterate(_tree_of_paths(update)), traversal.iterate(update)): + flat_tree[path] = leaf + nested_d = traverse_util.unflatten_dict( + {tuple(k.split('/')): v for k, v in flat_tree.items()}) + if is_frozen: + nested_d = freeze(nested_d) + return nested_d + + +class MultiOptimizer(OptimizerDef): + 
"""Generalized Multioptimizer. + + NB: Although this is provided for legacy support, it is still quite general + and should work fine with wrapped optax optimizers. But do note that the more + canonical way of mixing multiple optimizers inside optax uses optax.masked or + optax.multi_transform instead. + + A MultiOptimizer is subclass of :class:`OptimizerDef` and useful for applying + separate optimizer algorithms to various subsets of the model parameters. + + The example below creates two optimizers using + :class:`flax.traverse_util.ModelParamTraversal`: + one to optimize ``kernel`` parameters and to optimize ``bias`` parameters. + Note each optimizer is created with a different learning rate:: + + kernels = traverse_util.ModelParamTraversal( + lambda path, _: 'kernel' in path) + biases = traverse_util.ModelParamTraversal(lambda path, _: 'bias' in path) + kernel_opt = optimizers.adam(learning_rate=0.01) + bias_opt = optimizers.adam(learning_rate=0.1) + opt_def = MultiOptimizer((kernels, kernel_opt), (biases, bias_opt)) + optimizer = opt_def.create(model) + + In order to train only a subset of the parameters, you can simply use a single + :class:`flax.traverse_util.ModelParamTraversal` instance. + + If you want to update the learning rates of both optimizers online with + different learning rate schedules, you should update the learning rates when + applying the gradient. In the following example, the second optimizer is not + doing any optimization during the first 1000 steps:: + + hparams = optimizer.optimizer_def.hyper_params + new_optimizer = optimizer.apply_gradient( + grads, + hyper_params=[ + hparams[0].replace(learning_rate=0.2), + hparams[1].replace(learning_rate=jnp.where(step < 1000, 0., lr)), + ]) + """ + + def __init__( + self, traversals_and_optimizers: Sequence[Tuple[traverse_util.Traversal, + OptimizerDef]]): + """Create a new MultiOptimizer. + + See docstring of :class:`MultiOptimizer` for more details. + + Args: + traversals_and_optimizers: pairs of flax.traverse_util.Traversal and + `optimizers.OptimizerDef` instances. + """ + traversals, sub_optimizers = zip(*traversals_and_optimizers) + hyper_params = [opt.hyper_params for opt in sub_optimizers] + super().__init__(hyper_params) + self.traversals = traversals + self.sub_optimizers = sub_optimizers + + def init_state(self, params): + param_states = jax.tree_map(lambda x: _Marker(), params) + overlap = False + for idx, traversal in enumerate(self.traversals): + for match in traversal.iterate(param_states): + match._indices.append(idx) # pylint: disable=protected-access + overlap |= len(match._indices) > 1 # pylint: disable=protected-access + if overlap: + raise ValueError( + 'Multiple optimizers match the same leaves : ' + + str(jax.tree_map(lambda match: match._indices, param_states))) # pylint: disable=protected-access + + param_states = jax.tree_map(lambda x: _Marker(), params) + for focus, opt_def in zip(self.traversals, self.sub_optimizers): + ps = _subtree_from_traversal(focus, params) + ss = opt_def.init_state(ps) + param_states = _update_subtree_of_traversal(focus, param_states, + ss.param_states) + # Update state to None when param is not optimized by any sub optimizer. 
+ param_states = jax.tree_map( + lambda x: (None if isinstance(x, _Marker) else x), param_states) + return OptimizerState(jnp.asarray(0, dtype=jnp.int32), param_states) + + def apply_gradient(self, hyper_params, params, state, grads): + new_params = params + it = zip(self.traversals, self.sub_optimizers, hyper_params) + new_param_states = jax.tree_map(lambda x: _Marker(), params) + for focus, opt_def, hp in it: + ps = _subtree_from_traversal(focus, params) + gs = _subtree_from_traversal(focus, grads) + ss = _subtree_from_traversal(focus, state.param_states) + prev_ss = OptimizerState(state.step, ss) + new_ps, new_ss = opt_def.apply_gradient(hp, ps, prev_ss, gs) + new_params = _update_subtree_of_traversal(focus, new_params, new_ps) + new_param_states = _update_subtree_of_traversal(focus, new_param_states, + new_ss.param_states) + # Update state to None when param is not optimized by any sub optimizer. + new_param_states = jax.tree_map( + lambda x: (None if isinstance(x, _Marker) else x), new_param_states) + return new_params, OptimizerState(state.step + 1, new_param_states) + + def update_hyper_params(self, **hyper_param_overrides): + """Updates the hyper parameters with a set of overrides. + + This method is called from :meth:`Optimizer.apply_gradient` to create the + hyper parameters for a specific optimization step. + MultiOptimizer will apply the overrides for each sub optimizer. + + Args: + **hyper_param_overrides: the hyper parameters updates will override the + defaults specified in the `OptimizerDef`. Pass `hyper_params=...` to + replace all hyper parameters. + + Returns: + The new hyper parameters. + """ + hps = hyper_param_overrides.pop('hyper_params', self.hyper_params) + if hyper_param_overrides: + hps = [hp.replace(**hyper_param_overrides) for hp in hps] + return hps + + def set_param_axes(self, param_logical_axes): + """Derives factorization rules from model parameter logical axes.""" + for focus, opt_def in zip(self.traversals, self.sub_optimizers): + pla_subtree = _subtree_from_traversal(focus, param_logical_axes) + if hasattr(opt_def, 'set_param_axes'): + opt_def.set_param_axes(pla_subtree) + + def derive_logical_axes(self, optimizer, param_logical_axes): + """Derives optimizer logical partitioning from model logical partitions.""" + param_states = jax.tree_map(lambda x: _Marker(), + optimizer.state.param_states) + for focus, opt_def in zip(self.traversals, self.sub_optimizers): + if hasattr(opt_def, 'derive_logical_axes'): + ps = _subtree_from_traversal(focus, param_logical_axes) + ss = _subtree_from_traversal(focus, optimizer.state.param_states) + new_opt = opt_def.derive_logical_axes( + Optimizer(opt_def, OptimizerState(None, ss), ps), ps) + param_states = _update_subtree_of_traversal(focus, param_states, + new_opt.state.param_states) + # Update axes to None when param is not optimized by any sub optimizer. + param_states = jax.tree_map( + lambda x: (None if isinstance(x, _Marker) else x), param_states) + return Optimizer(optimizer.optimizer_def, + OptimizerState(None, param_states), param_logical_axes) + + # TODO(levskaya): add traversal handling for state_dict / restore_state + # this is required to make this work w. optax optimizers... diff --git a/t5x/optimizers_test.py b/t5x/optimizers_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e7559b6e19536025cf2fead7b68f44ccff903ab2 --- /dev/null +++ b/t5x/optimizers_test.py @@ -0,0 +1,317 @@ +# Copyright 2022 The T5X Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for t5x.optimizers.""" + +import dataclasses +import functools +import operator + +from absl.testing import absltest +from absl.testing import parameterized +import chex +import flax +from flax.core import frozen_dict +import jax +import jax.numpy as jnp +import numpy as np +import optax +import seqio +from t5x import models +from t5x import optimizers +from t5x import partitioning +from t5x import test_utils +from t5x import trainer +from t5x import utils +from t5x.examples.t5 import network + + +def _assert_numpy_allclose(a, b, atol=None, rtol=None): + a, b = jnp.array(a), jnp.array(b) + a = a.astype(np.float32) if a.dtype == jnp.bfloat16 else a + b = b.astype(np.float32) if b.dtype == jnp.bfloat16 else b + kw = {} + if atol: + kw['atol'] = atol + if rtol: + kw['rtol'] = rtol + np.testing.assert_allclose(a, b, **kw) + + +def check_eq(xs, ys, atol=None, rtol=None): + xs_leaves, xs_tree = jax.tree_flatten(xs) + ys_leaves, ys_tree = jax.tree_flatten(ys) + assert xs_tree == ys_tree, f"Tree shapes don't match. \n{xs_tree}\n{ys_tree}" + assert jax.tree_util.tree_all( + jax.tree_multimap(lambda x, y: np.array(x).shape == np.array(y).shape, + xs_leaves, ys_leaves)), "Leaves' shapes don't match." + assert jax.tree_multimap( + functools.partial(_assert_numpy_allclose, atol=atol, rtol=rtol), + xs_leaves, ys_leaves) + + +def flattened_state_dict(x): + s = flax.serialization.to_state_dict(x) + return flax.traverse_util.flatten_dict(s, sep='/') + + +def tree_shape(x): + return jax.tree_map(jnp.shape, x) + + +def tree_equals(x, y): + return jax.tree_util.tree_all(jax.tree_multimap(operator.eq, x, y)) + + +def get_fake_tokenized_dataset_no_pretokenized(*_, split='validation', **__): + return test_utils.get_fake_tokenized_dataset(split=split).map( + lambda x: {k: v for k, v in x.items() if not k.endswith('_pretokenized')}) + + +def get_t5_test_model(optimizer_def, + **config_overrides) -> models.EncoderDecoderModel: + """Returns a tiny T5 1.1 model to use for testing.""" + tiny_config = network.T5Config( + vocab_size=128, + dtype='bfloat16', + emb_dim=8, + num_heads=4, + num_encoder_layers=2, + num_decoder_layers=2, + head_dim=3, + mlp_dim=16, + mlp_activations=('gelu', 'linear'), + dropout_rate=0.0, + logits_via_embedding=False, + ) + tiny_config = dataclasses.replace(tiny_config, **config_overrides) + vocabulary = test_utils.get_fake_vocab() + return models.EncoderDecoderModel( + module=network.Transformer(tiny_config), + input_vocabulary=vocabulary, + output_vocabulary=vocabulary, + optimizer_def=optimizer_def) + + +class BasicTest(chex.TestCase): + + @classmethod + def get_params(cls): + return frozen_dict.FrozenDict({ + 'forward': { + 'input_layer': { + 'embedding': jnp.zeros([16, 8], dtype=jnp.float32), + }, + 'output_layer': { + 'layer_norm': { + 'scale': jnp.zeros([8], dtype=jnp.float32), + }, + 'proj': { + 'bias': jnp.zeros([1], dtype=jnp.float32), + 'kernel': jnp.zeros([8, 1], dtype=jnp.float32), + }, + }, + }, + 'loss': { + 'loss_fn': { + 
'loss_biases': jnp.zeros([2], dtype=jnp.float32), + }, + }, + }) + + @classmethod + def get_params_shapes(cls): + return jax.tree_map(jnp.shape, cls.get_params()) + + @classmethod + def get_param_logical_axes(cls): + return frozen_dict.FrozenDict({ + 'forward': { + 'input_layer': { + 'embedding': partitioning.PartitionSpec('vocab', 'embed'), + }, + 'output_layer': { + 'layer_norm': { + 'scale': partitioning.PartitionSpec('embed',), + }, + 'proj': { + 'bias': + partitioning.PartitionSpec('output_head',), + 'kernel': + partitioning.PartitionSpec('embed', 'output_head'), + }, + }, + }, + 'loss': { + 'loss_fn': { + 'loss_biases': partitioning.PartitionSpec('unmodeled',), + }, + }, + }) + + def test_logical_axes_adamw(self): + opt = optax.adamw(0.001, weight_decay=0.001) + wrapper = optimizers.OptaxWrapper(opt) + optimizer = wrapper.create(self.get_params()) + got = wrapper.derive_logical_axes(optimizer, self.get_param_logical_axes()) + want = optimizers.Optimizer( + optimizer_def=wrapper, + state=optimizers.OptimizerState( + step=None, + param_states=( + optax.ScaleByAdamState( + count=None, + mu=self.get_param_logical_axes(), + nu=self.get_param_logical_axes()), + optax.EmptyState(), + optax.EmptyState(), + )), + target=self.get_param_logical_axes()) + chex.assert_trees_all_equal(got, want) + + @parameterized.parameters( + ('sgd', lambda: optax.sgd(1e-2, 0.0)), + ('adam', lambda: optax.adam(1e-1)), + ('adamw', lambda: optax.adamw(1e-1)), + ('lamb', lambda: optax.adamw(1e-1)), + ('rmsprop', lambda: optax.rmsprop(1e-1)), + ('rmsprop_momentum', lambda: optax.rmsprop(5e-2, momentum=0.9)), + ('fromage', lambda: optax.fromage(1e-2)), + ('adabelief', lambda: optax.adabelief(1e-1)), + ('radam', lambda: optax.radam(1e-1)), + ('yogi', lambda: optax.yogi(1.0)), + ) + def test_sanity_check_logical_axes(self, opt_name, opt_fn): + opt = opt_fn() + + wrapper = optimizers.OptaxWrapper(opt) + optimizer = wrapper.create(self.get_params()) + _ = wrapper.derive_logical_axes(optimizer, self.get_param_logical_axes()) + + # TODO(rosun): basic sanity check, we just want to make sure if a param + # name, e.g., `loss_biases` appear in the tree, the corresponding value is + # always a PartitionSpec. + + def test_adamw_state_serialization(self): + opt = optax.adamw(0.001, weight_decay=0.001) + wrapper = optimizers.OptaxWrapper(opt) + optimizer = wrapper.create(self.get_params()) + + state_dict = optimizer.state_dict() + + chex.assert_trees_all_equal( + frozen_dict.FrozenDict(jax.tree_map(jnp.shape, state_dict)), + frozen_dict.FrozenDict({ + 'target': self.get_params_shapes(), + 'state': { + 'step': (), + 'param_states': { + '0': { + 'count': (), + 'mu': self.get_params_shapes(), + 'nu': self.get_params_shapes(), + }, + # NB: We eliminate empty tuple leaves from EmptyState() in + # OptaxWrapper to avoid having the rest of T5X have to + # correctly handle this detail. e.g. we omit these: + # '1': {}, + # '2': {}, + }, + } + })) + + new_optimizer = optimizer.restore_state(state_dict) + + chex.assert_trees_all_equal(optimizer, new_optimizer) + + +class OptaxWrapperTest(chex.TestCase): + + def run_train_loop(self, optimizer_def): + # Construct input data. 
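+    # The block below builds a tiny but realistic T5X input pipeline: a fake
+    # tokenized seqio dataset is run through EncDecFeatureConverter with
+    # inputs/targets lengths of 8, then repeated and batched to size 8 so the
+    # trainer can draw as many steps as it needs.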
+ + ds = get_fake_tokenized_dataset_no_pretokenized(split='validation') + ds = seqio.EncDecFeatureConverter()( + ds, task_feature_lengths={ + 'inputs': 8, + 'targets': 8 + }) + ds = ds.repeat().batch(8) + ds_iter = ds.as_numpy_iterator() + first_batch = next(ds_iter) + + model = get_t5_test_model(optimizer_def, vocab_size=128) + + learning_rate_fn = utils.create_learning_rate_scheduler() + + input_shapes = jax.tree_map(jnp.shape, first_batch) + input_types = jax.tree_map(lambda x: jnp.dtype(x.dtype), first_batch) + + partitioner = partitioning.PjitPartitioner( + num_partitions=2, + logical_axis_rules=partitioning.standard_logical_axis_rules()) + + train_state_initializer = utils.TrainStateInitializer( + optimizer_def=model.optimizer_def, + init_fn=model.get_initial_variables, + input_shapes=input_shapes, + input_types=input_types, + partitioner=partitioner) + + train_state_axes = train_state_initializer.train_state_axes + train_state = train_state_initializer.from_scratch(jax.random.PRNGKey(0)) + + trainer_instance = trainer.Trainer( + model, + train_state=train_state, + partitioner=partitioner, + eval_names=[], + summary_dir=None, + train_state_axes=train_state_axes, + rng=jax.random.PRNGKey(0), + learning_rate_fn=learning_rate_fn, + num_microbatches=1) + + chex.assert_tree_all_finite(train_state.params) + for _ in range(2): + trainer_instance.train(ds_iter, 1) + chex.assert_tree_all_finite(train_state.params) + + # check save/restore structural equality + restored_instance = trainer_instance.train_state.restore_state( + trainer_instance.train_state.state_dict()) + chex.assert_tree_all_equal_structs(trainer_instance.train_state, + restored_instance) + + # NOTE(levskaya): these are surprisingly slow tests on CPU. + @parameterized.parameters( + ('sgd', lambda: optax.sgd(1e-2, 0.0)), + ('adam', lambda: optax.adam(1e-1)), + ('adamw', lambda: optax.adamw(1e-1)), + ('lamb', lambda: optax.adamw(1e-1)), + # ('rmsprop', lambda: optax.rmsprop(1e-1)), + # ('rmsprop_momentum', lambda: optax.rmsprop(5e-2, momentum=0.9)), + # ('fromage', lambda: optax.fromage(1e-2)), + ('adabelief', lambda: optax.adabelief(1e-1)), + # ('radam', lambda: optax.radam(1e-1)), + ('yogi', lambda: optax.yogi(1.0)), + ) + def test_optimizer(self, opt_name, opt_fn): + opt = opt_fn() + optimizer_def = optimizers.OptaxWrapper(opt) + self.run_train_loop(optimizer_def) + + +if __name__ == '__main__': + absltest.main() diff --git a/t5x/partitioning.py b/t5x/partitioning.py new file mode 100644 index 0000000000000000000000000000000000000000..a0e9c3d46c9c1ef4142b554eb577d3821fa89e1d --- /dev/null +++ b/t5x/partitioning.py @@ -0,0 +1,902 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
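+
+# Illustrative usage sketch of this module, mirroring how the rest of T5X
+# drives it (step_fn, train_state_axes and batch are placeholders):
+#
+#   partitioner = PjitPartitioner(
+#       num_partitions=2,
+#       logical_axis_rules=standard_logical_axis_rules())
+#   partitioned_step = partitioner.partition(
+#       step_fn,
+#       in_axis_resources=(train_state_axes, PartitionSpec('data',)),
+#       out_axis_resources=(train_state_axes, None),
+#       donate_argnums=(0,))
+#   compiled_step = partitioner.compile(partitioned_step, train_state, batch)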
+ +"""Utilities for partitioning.""" + +import abc +import collections +import dataclasses +import typing +from typing import Any, Callable, Optional, Sequence, Tuple, Union + +from absl import logging +import cached_property +from flax import traverse_util +from flax.linen import partitioning as flax_partitioning +import jax +from jax import numpy as jnp +from jax import random +from jax.experimental import PartitionSpec +from jax.experimental.maps import Mesh +from jax.experimental.pjit import pjit as jax_pjit +import numpy as np +from t5x import train_state as train_state_lib + +JaxDevice = jax.lib.xla_client.Device +TpuMesh = Tuple[int, int, int, int] # (x, y, z, num_cores). +OtherMesh = Tuple[int, int] +HardwareMesh = Union[TpuMesh, OtherMesh] +PyTreeDef = type(jax.tree_structure(None)) +TrainState = train_state_lib.TrainState +LogicalAxisRules = Sequence[Tuple[str, Optional[str]]] + +if typing.TYPE_CHECKING: # See b/163639353 + cached_property = property # pylint: disable=invalid-name +else: + cached_property = cached_property.cached_property + + +class AxisNames(tuple): + """Tuple of strings specifying name for each axis. + + We create a separate class for this so JAX's pytree utilities can distinguish + it from a tuple that should be treated as a pytree, instead treating it as a + leaf. + """ + + def __new__(cls, *names): + return tuple.__new__(AxisNames, names) + + def __repr__(self): + return 'AxisNames%s' % tuple.__repr__(self) + + +# pjit wrappers for cpu fallback. +# ----------------------------------------------------------------------------- +# TODO(levskaya): upstream this fallback behavior to jax pjit. +def pjit( + fun: Callable, # pylint: disable=g-bare-generic + in_axis_resources, + out_axis_resources, + static_argnums: Union[int, Sequence[int]] = (), + donate_argnums: Union[int, Sequence[int]] = (), + backend: Optional[str] = None): + """Wrapper for pjit that calls normal jit on cpu.""" + if jax.devices(backend)[0].platform == 'cpu': + return jax.jit( + fun, static_argnums=static_argnums, donate_argnums=donate_argnums) + else: + return jax_pjit( + fun, + in_axis_resources, + out_axis_resources, + static_argnums=static_argnums, + donate_argnums=donate_argnums) + + +def with_sharding_constraint(x, axis_resources): + """Wrapper for pjit with_sharding_constraint, no-op on cpu or outside pjit.""" + if jax.devices()[0].platform == 'cpu' or not global_mesh_defined(): + return x + else: + return jax.experimental.pjit.with_sharding_constraint(x, axis_resources) + + +# pjit Mesh creation functions. +# ----------------------------------------------------------------------------- +def bounds_from_last_device( + last_device: jax.lib.xla_client.Device) -> HardwareMesh: + """Get the bound from the given last device.""" + # Must be passed the device at the highest-coordinate corner of the + # relevant mesh, which is a requirement we know is satisfied by the last + # device in jax.devices(). + if hasattr(last_device, 'coords'): + x, y, z = last_device.coords + return x + 1, y + 1, z + 1, last_device.core_on_chip + 1 + else: + # On non-TPU platforms, the "mesh" is hosts x devices per host in order + # to take advantage of faster within-host interconnect. 
+ return jax.host_count(), jax.local_device_count() + + +def get_coords(device: jax.lib.xla_client.Device) -> HardwareMesh: + """Returns the coordinates of the given device.""" + if hasattr(device, 'coords'): + return (*device.coords, device.core_on_chip) + return (device.process_index, device.id % jax.local_device_count()) + + +def global_mesh_defined(): + """Checks if global xmap/pjit mesh resource environment is defined.""" + maps_env = jax.experimental.maps.thread_resources.env + return maps_env.physical_mesh.devices.shape != () # pylint: disable=g-explicit-bool-comparison + + +def get_mesh(model_parallel_submesh: HardwareMesh, + input_devices: Sequence[JaxDevice] = (), + input_local_devices: Sequence[JaxDevice] = (), + tile_by_host_if_needed: bool = True, + backend: Optional[str] = None) -> Mesh: + """Construct an xmap/pjit Mesh for the given model-parallel submesh. + + The resulting mesh has two resource axes: 'model', with the provided submesh + shape, and 'data', which covers the rest of the mesh. + + Args: + model_parallel_submesh: a HardwareMesh spec, namely (x,y,z,core) on TPU for + a single model-parallel replica's "tile" in the physical device mesh. The + first three elements (`x`, `y`, and `z`) should be factors of the pod + slice; e.g., if you are using df_4x8, then `x` should be a factor of 4 + (one of 1, 2, 4), `y` should be a factor of 8 (one of 1, 2, 4, 8), and `z` + must be 1, because TPU v3 slices are only 2D. `z` can be >1 for TPU v4 + (and maybe later TPUs) that allow 3D slices. `core` is the number of cores + to use from each TPU node. As communication is usually fastest inside the + same node, if you need a tile of more than 1 core, then + you should first increase `core`: e.g., for TPU v3, (1,1,1,2) is better + than (2,1,1,1). To pick a good spec, try a few possible values until you + get high TPU utilization. + input_devices: the devices to use, will use jax.devices() if this is not + set. + input_local_devices: the local devices to use, will use jax.local_devices() + if this is not set. + tile_by_host_if_needed: JAX currently requires that the parts of any sharded + array that are located on one host's local devices form a single + contiguous slice. A best effort will be made to achieve this without + "tiling" the device assignment over hosts (which can reduce XLA collective + performance). If this flag is True, then the device assignment will be + tiled over hosts if necessary to satisfy this constraint and create a + buildable mesh; if false, mesh construction will fail instead. + backend: get devices from the pinned backend, if specified. This is + useful for explicitly specifying the devices other than relying on + jax_platform_name. + + Returns: + A xmap / pjit Mesh containing the virtual device mesh with data, model axes. 
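+
+  For example (illustrative), on a 32-core TPU v3 slice (4x4x1x2 bounds), a
+  model_parallel_submesh of (2, 2, 1, 2) gives a 'model' axis of size 8 and a
+  'data' axis of size 32 // 8 = 4.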
+ """ + input_devices = input_devices or jax.devices(backend) + input_local_devices = input_local_devices or jax.local_devices(0, backend) + last_device = input_devices[-1] + global_hardware_mesh = bounds_from_last_device(last_device) + mesh_ndim = len(global_hardware_mesh) + local_hardware_mesh = bounds_from_last_device(input_local_devices[-1]) + mesh_err = ( + f'each dimension of the model parallel submesh {model_parallel_submesh} ' + 'must be a factor of the corresponding dimension of the global device ' + f'mesh {global_hardware_mesh}') + assert not any( + g % m + for g, m in zip(global_hardware_mesh, model_parallel_submesh)), mesh_err + assert not any( + g % l for g, l in zip(global_hardware_mesh, local_hardware_mesh)) + devices = np.empty(global_hardware_mesh, dtype=np.object) + for device in input_devices: + device_coords = get_coords(device) + devices[device_coords] = device + tile_by_host = tile_by_host_if_needed + if len(global_hardware_mesh) == 4: + # enable contiguous local chunks without host tiling by making Z major + global_hardware_mesh = typing.cast(Tuple[int, int, int, int], + global_hardware_mesh) + model_parallel_submesh = typing.cast(Tuple[int, int, int, int], + model_parallel_submesh) + gx, gy, gz, gc = global_hardware_mesh + mx, my, mz, mc = model_parallel_submesh + if (mx == gx > 1 and my == mz == 1) or (mx == 1 and my == gy > 1 and + mz == gz > 1): + logging.info('ensuring YZ plane has a Z-major device order') + # YZ should be ZY + assert mc == gc, (mc, gc) + global_hardware_mesh = gx, gz, gy, gc + model_parallel_submesh = mx, mz, my, mc + devices = devices.swapaxes(1, 2) + tile_by_host = False + if (my == gy > 1 and mx == mz == 1) or (my == 1 and mx == gx > 1 and + mz == gz > 1): + logging.info('ensuring XZ plane has a Z-major device order') + # XZ should be ZX + assert mc == gc, (mc, gc) + global_hardware_mesh = gz, gy, gx, gc + model_parallel_submesh = mz, my, mx, mc + devices = devices.swapaxes(0, 2) + tile_by_host = False + if tile_by_host: + logging.warning( + 'Tiling device assignment mesh by hosts, which may lead to ' + 'reduced XLA collective performance. To avoid this, modify ' + 'the model parallel submesh or run with more tasks per host.') + tile_err = ( + 'to tile the mesh by hosts, each dimension of the model parallel ' + 'submesh must be either a factor or a multiple of the corresponding ' + 'dimension of the per-host submesh') + + def dh_dd_mh_md(g: int, m: int, l: int) -> Tuple[int, int, int, int]: + """Split a global mesh dimension into four tiling components. + + Args: + g: global mesh bounds dimension size + m: model-parallel submesh bounds dimension size + l: local submesh bounds dimension size + + Returns: + The resulting tuple divides the dimension into the hosts component of + the data-parallel submesh, the devices component of the data-parallel + submesh, the hosts component of the model-parallel submesh, and the + devices component of the model-parallel submesh. + """ + d = g // m + if m >= l: + assert not m % l, tile_err + return (d, 1, m // l, l) + else: + assert not l % m, tile_err + return (d // (l // m), l // m, 1, m) + + # e.g. [(x_data_hosts, x_data_devs, x_model_hosts, x_model_devs), ...] + dh_dd_mh_md_tups = map(dh_dd_mh_md, global_hardware_mesh, + model_parallel_submesh, local_hardware_mesh) + # reshape to e.g. (x_dh, x_dd, x_mh, x_md, y_dh, ...) 
+ devices = devices.reshape(*(s for t in dh_dd_mh_md_tups for s in t)) # pylint: disable=g-complex-comprehension + # TODO(jekbradbury): reorder local subgroups for ring locality + # Transpose to [data_host], [data_device], [model_host], [model_device] + # block ordering e.g. (x_dh, y_dh, ..., x_dd, y_dd, ...) + devices = devices.transpose(*(4 * i for i in range(mesh_ndim)), + *(4 * i + 1 for i in range(mesh_ndim)), + *(4 * i + 2 for i in range(mesh_ndim)), + *(4 * i + 3 for i in range(mesh_ndim))) + else: + # e.g. [(x_data, x_model), (y_data, y_model), ...] + model_data_tups = [ + (g // m, m) + for g, m in zip(global_hardware_mesh, model_parallel_submesh) + ] + # reshape to e.g. (x_data, x_model, y_data, y_model...) + devices = devices.reshape(*(s for t in model_data_tups for s in t)) # pylint: disable=g-complex-comprehension + # TODO(jekbradbury): reorder small subgroups for ring locality + # transpose to e.g. (x_data, y_data, ..., x_model, ...) + devices = devices.transpose(*(2 * i for i in range(mesh_ndim)), + *(2 * i + 1 for i in range(mesh_ndim))) + # reshape to (data, model) + devices = devices.reshape(-1, np.prod(model_parallel_submesh)) + global_mesh = Mesh(devices, ['data', 'model']) + logging.info('global_mesh axes_names: %s', global_mesh.axis_names) + logging.info('global_mesh devices: %s', global_mesh.devices) + return global_mesh + + +def get_cpu_mesh() -> Mesh: + """Trivial mesh for CPU Testing.""" + devices = np.empty((jax.host_count(), jax.local_device_count()), + dtype=np.object) + for device in jax.devices(): + devices[device.process_index, device.id % jax.local_device_count()] = device + return Mesh(devices, ['data', 'model']) + + +def get_gpu_mesh() -> Mesh: + """Simple mesh for GPUs.""" + devices = np.empty((jax.host_count(), jax.local_device_count()), + dtype=np.object) + for device in jax.devices(): + devices[device.process_index, device.id % jax.local_device_count()] = device + return Mesh(devices, ['data', 'model']) + + +def default_mesh(num_partitions: int, + model_parallel_submesh: Optional[HardwareMesh] = None, + backend: Optional[str] = None) -> Mesh: + """Attempt to return a default mesh for simple cases. + + Args: + num_partitions: number of partitions to use, will be ignored if + model_parallel_submesh is provided. + model_parallel_submesh: 4-tuple that specifies the x,y,z,c submesh to use as + the model-parallel device tile. + backend: get devices from the pinned backend, if specified. This is useful + for explicitly specifying the devices other than relying on + jax_platform_name. + + Returns: + xmap/pjit 2D Mesh with 'data', 'model' mesh axes. 
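+
+  For example (illustrative), `num_partitions=4` on a 32-core TPU v3 slice
+  selects the (2, 1, 1, 2) submesh, i.e. a mesh of shape
+  {'data': 8, 'model': 4}.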
+ """ + last_device = jax.devices(backend)[-1] + platform = last_device.platform + device_kind = last_device.device_kind + bounds = bounds_from_last_device(last_device) + + if model_parallel_submesh: + return get_mesh(model_parallel_submesh, backend=backend) + + if platform == 'cpu': + return get_cpu_mesh() + elif platform == 'gpu': + return get_gpu_mesh() + + mps = None + if device_kind in ('TPU v2', 'TPU v3'): + if num_partitions == 1: + mps = (1, 1, 1, 1) + elif num_partitions == 2: + mps = (1, 1, 1, 2) + elif num_partitions == 4: + mps = (2, 1, 1, 2) + elif num_partitions == 8: + mps = (2, 2, 1, 2) + elif num_partitions == 16: + mps = (4, 2, 1, 2) + # assume the use of megacore on TPU v4 + elif device_kind == 'TPU v4' and bounds[3] == 1: + if num_partitions == 1: + mps = (1, 1, 1, 1) + elif num_partitions == 2: + mps = (1, 2, 1, 1) + elif num_partitions == 4: + if bounds[0] >= 4: + mps = (4, 1, 1, 1) + else: + mps = (2, 2, 1, 1) + elif num_partitions == 8: + if bounds[2] >= 8: + mps = (1, 1, 8, 1) + else: + mps = (4, 2, 1, 1) + elif num_partitions == 16: + if bounds[2] >= 16: + mps = (1, 1, 16, 1) + elif bounds[0] >= 8: + mps = (8, 2, 1, 1) + else: + mps = (4, 4, 1, 1) + + if mps is None: + raise ValueError('No default mesh for this configuration: specify ' + 'config.model_parallel_submesh explicitly.') + return get_mesh(mps, backend=backend) + + +# Data chunking helper. +# ----------------------------------------------------------------------------- +@dataclasses.dataclass +class LocalChunkInfo: + # The logical slice of an array located on this host's local devices. + slice: Tuple[slice, ...] + # A unique index for this host/local chunk among chunks with the same slice. + replica_id: int + + +class LocalChunker: + """Utility class to aid chunking of sharded arrays in multihost settings.""" + + def __init__(self, global_mesh: Mesh): + self.global_mesh = global_mesh + local_mesh = global_mesh.local_mesh + first_local_device = local_mesh.devices.reshape(-1)[0] + host_location = collections.OrderedDict( + zip( + global_mesh.shape.keys(), + list(zip(*np.nonzero( + global_mesh.devices == first_local_device)))[0])) + self.num_chunks = collections.OrderedDict() + self.chunk_ids = collections.OrderedDict() + self.mesh_axes = list(global_mesh.shape.keys()) + for mesh_axis in self.mesh_axes: + num_devices_per_chunk = local_mesh.shape[mesh_axis] + self.num_chunks[mesh_axis] = ( + global_mesh.shape[mesh_axis] // num_devices_per_chunk) + self.chunk_ids[mesh_axis] = ( + host_location[mesh_axis] // num_devices_per_chunk) + + def get_local_chunk_info( + self, global_shape: Tuple[int, ...], + mesh_axes: Sequence[Optional[str]]) -> LocalChunkInfo: + """Get the local chunk info for a given array shape and sharded axes. + + Args: + global_shape: the global, unsharded shape of the array to chunk. + mesh_axes: a sequence of names (or None) of equal rank to `global_shape` + that specifies which mesh dimensions the array is sharded along. + + Returns: + LocalChunkInfo containing the logical slices of the array found on this + host's local devices, as well as the replica index for this chunk among + chunks with the same slice. The latter is used to determine which + host should write this chunk during checkpointing. 
+ """ + local_slice = [slice(None) for dim in global_shape] + sharded_mesh_axes = set() + for i, (mesh_axis, size) in enumerate(zip(mesh_axes, global_shape)): + if not mesh_axis: + continue + sharded_mesh_axes.add(mesh_axis) + if not isinstance(mesh_axis, str): + raise NotImplementedError('TODO(jekbradbury)') + chunk_id = self.chunk_ids[mesh_axis] + chunk_size = size // self.num_chunks[mesh_axis] + local_slice[i] = slice(chunk_id * chunk_size, (chunk_id + 1) * chunk_size) + + replicated_mesh_axes = [ + mesh_axis for mesh_axis in self.mesh_axes + if mesh_axis not in sharded_mesh_axes + ] + replica_id = 0 + for mesh_axis in replicated_mesh_axes: + chunk_id = self.chunk_ids[mesh_axis] + replica_id = replica_id * self.num_chunks[mesh_axis] + chunk_id + + return LocalChunkInfo(tuple(local_slice), replica_id) + + +def standard_logical_axis_rules( + activation_partitioning_dims: int = 1, + parameter_partitioning_dims: int = 1, + additional_rules: Optional[LogicalAxisRules] = None) -> LogicalAxisRules: + """Default sharding rules for T5X model in terms of logical axis names. + + Args: + activation_partitioning_dims: enables 2-D activation sharding when set to 2. + parameter_partitioning_dims: enables 2-D parameter sharding when set to 2. + additional_rules: additional rules (a sequence of tuples) that will be + appended to the standard rules. + + Returns: + Sequence of logical axis rules + """ + logging.info( + '`activation_partitioning_dims` = %d, `parameter_partitioning_dims` = %d', + activation_partitioning_dims, parameter_partitioning_dims) + + if activation_partitioning_dims == 1 and parameter_partitioning_dims == 1: + rules = [ + ('batch', 'data'), + ('vocab', 'model'), + ('embed', None), + ('mlp', 'model'), + ('heads', 'model'), + ('kv', None), + ('joined_kv', 'model'), # joined heads+kv dim in 2D attn param layouts + ] + elif activation_partitioning_dims == 2 and parameter_partitioning_dims == 1: + rules = [ + ('batch', 'data'), + ('vocab', 'model'), + ('mlp', 'model'), + ('heads', 'model'), + ('kv', None), + ('joined_kv', 'model'), + ('embed', 'model'), + ] + elif activation_partitioning_dims == 1 and parameter_partitioning_dims == 2: + rules = [ + ('batch', 'data'), + ('vocab', 'model'), + ('mlp', 'model'), + ('heads', 'model'), + ('kv', None), + ('joined_kv', 'model'), + ('embed', 'data'), + ] + elif activation_partitioning_dims == 2 and parameter_partitioning_dims == 2: + rules = [ + ('batch', 'data'), + ('vocab', 'model'), + ('mlp', 'model'), + ('heads', 'model'), + ('kv', None), + ('joined_kv', 'model'), + ('embed', 'model'), + ('embed', 'data'), + ] + else: + raise ValueError( + f'`activation_partitioning_dims` = {activation_partitioning_dims} ' + f'`parameter_partitioning_dims` = {parameter_partitioning_dims} ' + 'is not supported.') + + # Add the common rules for the replicated logical axes names. + replicated_rules = [ + ('relpos_buckets', None), + ('abspos_buckets', None), + ('length', None), + ('layers', None), + ('stack', None), + ('mlp_activations', None), + ] + rules.extend(replicated_rules) + + if additional_rules: + rules.extend(additional_rules) + + return rules + + +# NB: This needs to be top-level for the jax compilation cache. +def _id_fn(x, ix): + """Identity function for copying parameters to the devices, sharded.""" + # A pure identity such as `lambda x, *: x` can get optimized away, so we + # include a random.split as a cheap function that cannot be optimized away. 
+ return x, random.split(jnp.array([ix, ix], dtype=jnp.uint32)) + + +@dataclasses.dataclass +class DataLayout: + """Represents data layout for the partitioned model.""" + batch_size: int + shard_id: int + num_shards: int + is_first_host_in_replica_set: bool + + +PartitionedCallable = Callable[..., Any] +CompiledPartitionedCallable = Callable[..., Any] + + +class BasePartitioner(metaclass=abc.ABCMeta): + """Interface for partitioning computations across hardware devices.""" + + def __init__(self, + num_partitions: Optional[int] = None, + model_parallel_submesh: Optional[HardwareMesh] = None, + params_on_devices: bool = True, + backend: Optional[str] = None): + """Configures the partitioner. + + Args: + num_partitions: the number of partitions to use. Ignored if + `model_parallel_submesh` is provided. + model_parallel_submesh: 4-tuple that specifies the x,y,z,c submesh to use + as the model-parallel device tile. This submesh is used for the larger + of the two parameter dimensions, and, if 2-D activation sharding is + enabled, for the model dimension of activations. The rest of the mesh is + used for data parallelism and, if 2-D parameter sharding is enabled, the + other parameter dimension. + params_on_devices: whether to keep the params on devices, if False - + params stay in the host memory. Note that some partitioners might ignore + this setting, for example if they don't support storing all params on + device memory. + backend: get devices from the pinned backend, if specified. This is useful + for explicitly specifying the devices other than relying on + jax_platform_name. + """ + + if not num_partitions and not model_parallel_submesh: + raise ValueError('At least one of `num_partitions` or ' + '`model_parallel_submesh` must be set.') + + if model_parallel_submesh is not None and len(model_parallel_submesh) != 4: + logging.error( + '`model_parallel_submesh` must be either None or a 4-tuple. Got ' + 'Got `num_partitions=%s`. A ValueError will be raised beginning ' + 'March 1, 2022.', model_parallel_submesh) + + if bool(num_partitions) and bool(model_parallel_submesh): + logging.error( + 'At most one of `num_partitions` or `model_parallel_submesh` can be ' + 'set. Got `num_partitions=%s` and `model_parallel_submesh`=%s. A ' + 'ValueError will be raised beginning March 21, 2022.', num_partitions, + model_parallel_submesh) + + self._num_partitions = num_partitions + self._model_parallel_submesh = model_parallel_submesh + self._params_on_devices = params_on_devices + self._data_axis = 'data' + self._backend = backend + + @property + def mesh(self) -> Mesh: + raise NotImplementedError + + @property + def data_partition_spec(self) -> PartitionSpec: + return PartitionSpec(self._data_axis) + + def get_data_layout(self, + batch_size: Optional[int] = None, + host_index: Optional[int] = None) -> DataLayout: + """Returns filled `DataLayout` based on the partitioned model layout. + + Args: + batch_size: if set, indicates the requested batch size. The exception will + be raised if this batch size is not compatible with the layout. If not + set, the batch size is inferred from the layout. + host_index: indicates the host index to use for the calculations, if not + set - use JAX-provided one. Should be in [0, num_hosts) interval and the + order should match the order of corresponding CPU devices in + `jax.devices()`. + + Returns: + Filled `DataLayout` structure. 
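+
+      For example (illustrative), with a 'data' mesh axis of size 8 split into
+      4 per-host chunks, `batch_size=16` yields `num_shards=4`, a `shard_id`
+      in [0, 4), and 16 // 4 = 4 examples per shard.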
+ """ + if host_index is not None: + raise NotImplementedError('Explicit host_index is not yet implemented.') + if self._data_axis is None: + return DataLayout( + batch_size=batch_size, + shard_id=0, + num_shards=1, + is_first_host_in_replica_set=(jax.process_index() == 0)) + mesh_size = self._local_chunker.global_mesh.shape[self._data_axis] + batch_size = batch_size or mesh_size + if batch_size % mesh_size: + raise ValueError( + f'Batch size ({batch_size}) must be divisible by corresponding ' + f'mesh size ({mesh_size}).') + num_shards = self._local_chunker.num_chunks[self._data_axis] + if batch_size % num_shards: + raise ValueError( + f'Batch size ({batch_size}) must be divisible by number of ' + f'replicas ({num_shards}).') + replica_id = self._local_chunker.get_local_chunk_info( + (batch_size,), [self._data_axis]).replica_id + return DataLayout( + batch_size=batch_size, + shard_id=self._local_chunker.chunk_ids[self._data_axis], + num_shards=num_shards, + is_first_host_in_replica_set=(replica_id == 0)) + + def get_local_chunk_info( + self, global_shape: Tuple[int, ...], + mesh_axes: Sequence[Optional[str]]) -> LocalChunkInfo: + """Returns the local chunk info for a given array shape and sharded axes.""" + return self._local_chunker.get_local_chunk_info(global_shape, mesh_axes) + + @property + def params_on_devices(self): + return self._params_on_devices + + def move_params_to_devices(self, train_state: TrainState, + train_state_axes: TrainState) -> TrainState: + """Moves the optimizer parameters to devices.""" + p_id_fn = self.partition( + _id_fn, + in_axis_resources=(train_state_axes, None), + out_axis_resources=(train_state_axes, None), + donate_argnums=(0,)) + train_state, _ = p_id_fn(train_state, jnp.ones((), dtype=jnp.uint32)) + return train_state + + @property + @abc.abstractmethod + def _local_chunker(self): + """Returns the chunker that matches the parameters of this partitioner.""" + raise NotImplementedError + + def get_logical_axes(self, train_state: TrainState) -> TrainState: + """Returns a copy of TrainState with Optional[AxisNames] as leaves.""" + # By default, return None for the logical axes. + return train_state.restore_state( + jax.tree_map(lambda x: None, train_state.state_dict())) + + def get_mesh_axes(self, train_state: TrainState) -> TrainState: + """Returns a copy of TrainState with Optional[PartitionSpecs] as leaves.""" + raise NotImplementedError + + @abc.abstractmethod + def partition( + self, + fn: Callable, # pylint: disable=g-bare-generic + in_axis_resources, + out_axis_resources, + static_argnums: Union[int, Sequence[int]] = (), + donate_argnums: Union[int, Sequence[int]] = () + ) -> PartitionedCallable: + """Partitions the computation using partitioner-specific implementation. + + Args: + fn: the function to partition. + in_axis_resources: Pytree of structure matching that of arguments to `fn`, + with all actual arguments replaced by resource assignment + specifications. It is also valid to specify a pytree prefix (e.g. one + value in place of a whole subtree), in which case the leaves get + broadcast to all values in that subtree. + The valid resource assignment specifications are: + `None`: in which case the value will be replicated on all devices + `PartitionSpec`: a tuple of length at most equal to the rank of the + partitioned value. Each element can be a `None`, a mesh axis or a + tuple of mesh axes, and specifies the set of resources assigned to + partition the value's dimension matching its position in the spec. 
+ out_axis_resources: Like `in_axis_resources`, but specifies resource + assignment for function outputs. + static_argnums: an optional int or collection of ints that specify which + positional arguments to treat as static (compile-time constant) in the + partitioned function. + donate_argnums: an optional int or collection of ints that specify which + argument buffers are "donated" to the computation. It is safe to donate + argument buffers if you no longer need them once the computation has + finished. + + Returns: + A partitioned version of the input function. + """ + raise NotImplementedError + + @abc.abstractmethod + def compile(self, partitioned_fn: PartitionedCallable, + *args) -> CompiledPartitionedCallable: + """Compiles and returns the partitioned function, or the original. + + Args: + partitioned_fn: The partitioned function. + *args: Sample arguments to the partitioned function matching the input + shapes that will be passed to the compiled function. + + Returns: + The compiled function, or the original if this partitioner does not + support compilation. + """ + raise NotImplementedError + + +class PjittedFnWithContext(PartitionedCallable): + """Wraps pjitted function to apply the appropriate contexts.""" + + def __init__(self, + pjitted_fn, + partition_mesh: Mesh, + logical_axis_rules: flax_partitioning.LogicalRules = ()): + self._pjitted_fn = pjitted_fn + self._mesh = partition_mesh + self._logical_axis_rules = logical_axis_rules + + def __call__(self, *args): + with Mesh(self._mesh.devices, + self._mesh.axis_names), flax_partitioning.axis_rules( + self._logical_axis_rules): + return self._pjitted_fn(*args) + + def lower(self, *args): + with Mesh(self._mesh.devices, + self._mesh.axis_names), flax_partitioning.axis_rules( + self._logical_axis_rules): + return self._pjitted_fn.lower(*args) + + +class BasePjitPartitioner(BasePartitioner): + """Partitioner that uses T5X version of jax.pjit.""" + + @cached_property + def _local_chunker(self) -> LocalChunker: + return LocalChunker(self.mesh) + + @cached_property + def mesh(self) -> Mesh: + return default_mesh(self._num_partitions, self._model_parallel_submesh, + self._backend) + + def partition( + self, + fn: Callable, # pylint: disable=g-bare-generic + in_axis_resources, + out_axis_resources, + static_argnums: Union[int, Sequence[int]] = (), + donate_argnums: Union[int, Sequence[int]] = () + ) -> PjittedFnWithContext: + pjitted = pjit( + fn, + in_axis_resources=in_axis_resources, + out_axis_resources=out_axis_resources, + static_argnums=static_argnums, + donate_argnums=donate_argnums, + backend=self._backend) + + return PjittedFnWithContext(pjitted, self.mesh) + + def compile(self, partitioned_fn: PjittedFnWithContext, + *args) -> CompiledPartitionedCallable: + return partitioned_fn.lower(*args).compile() + + +class PjitPartitioner(BasePjitPartitioner): + """Partitioner that uses named axes and jax.pjit.""" + + def __init__(self, + num_partitions: Optional[int] = None, + model_parallel_submesh: Optional[HardwareMesh] = None, + params_on_devices: bool = True, + backend: Optional[str] = None, + logical_axis_rules: Optional[LogicalAxisRules] = None): + """PjitPartitioner constructor. + + See https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx/usage/partitioning for details. + + Args: + num_partitions: an integer that specifies the size of the model parallel + submesh to be automatically selected for the current topology. See + `model_parallel_submesh` for details on how this submesh is used. 
+ Mutually exlusive with `model_parallel_submesh`. + model_parallel_submesh: is a 4-tuple that specifies the `(x, y, z, c)` + submesh model-parallel device tile, an axis of accelerator parallelism + orthogonal to data parallelism. Array axes in a model's parameters or + activations can be sharded over this submesh using axis rules (see + `logical_axis_rules`) that map them to 'model'. The effective number of + model sub-partitions is equal to `np.prod(model_parallel_submesh)` and + must evenly divide the total number of devices (i.e., + `jax.device_count() % np.prod(model_parallel_submesh) == 0`). The rest + of the TPU mesh is the data parallel submesh, providing + `jax.device_count() // np.prod(model_parallel_submesh)` partitions. It + is used for data (batch) parallelism and to shard other array axes that + are mapped to 'data'. This argument is mutually exclusive with + `num_partitions`. + params_on_devices: whether to keep the params on devices, if False - + params stay in the host memory. Note that some partitioners might ignore + this setting, for example if they don't support storing all params on + device memory. + backend: get devices from the pinned backend, if specified. This is + useful for explicitly specifying the devices other than relying on + jax_platform_name. + logical_axis_rules: a priority-ordered sequence of KV tuples that maps + logical axis names to either `None` (not sharded), 'model' (to shard + across the model-parallel submesh), or 'data' (to shard across the + data-parallel submesh). + """ + super().__init__( + num_partitions=num_partitions, + model_parallel_submesh=model_parallel_submesh, + params_on_devices=params_on_devices, + backend=backend) + if logical_axis_rules is None: + logical_axis_rules = standard_logical_axis_rules() + self._logical_axis_rules = tuple(logical_axis_rules) + self._data_axis, = flax_partitioning.logical_to_mesh_axes( + ['batch'], logical_axis_rules) + + def partition( + self, + fn: Callable, # pylint: disable=g-bare-generic + in_axis_resources, + out_axis_resources, + static_argnums: Union[int, Sequence[int]] = (), + donate_argnums: Union[int, Sequence[int]] = () + ) -> PjittedFnWithContext: + """Partitions the function using jax.pjit.""" + pjitted = pjit( + fn, + in_axis_resources=in_axis_resources, + out_axis_resources=out_axis_resources, + static_argnums=static_argnums, + donate_argnums=donate_argnums, + backend=self._backend) + + return PjittedFnWithContext(pjitted, self.mesh, self._logical_axis_rules) + + @property + def logical_axis_rules(self): + """Returns the logical axis rules.""" + return self._logical_axis_rules + + def get_logical_axes(self, train_state: TrainState) -> TrainState: + """Returns a copy of TrainState with Optional[AxisNames] as leaves.""" + return train_state.as_logical_axes() + + def get_mesh_axes(self, train_state: TrainState) -> TrainState: + """Returns a copy of TrainState with Optional[PartitionSpecs] as leaves.""" + logical_axes = self.get_logical_axes(train_state) + + def _logical_to_mesh_axes(param_name, logical_axes): + if logical_axes is None: + return None + elif logical_axes is traverse_util.empty_node: + return traverse_util.empty_node + try: + return flax_partitioning.logical_to_mesh_axes(logical_axes, + self._logical_axis_rules) + except ValueError as e: + raise ValueError(f'Failed to map logical axes for {param_name}') from e + + flat_logical_axes = traverse_util.flatten_dict( + logical_axes.state_dict(), keep_empty_nodes=True, sep='/') + flat_mesh_axes = { + k: _logical_to_mesh_axes(k, v) 
for k, v in flat_logical_axes.items() + } + + return logical_axes.restore_state( + traverse_util.unflatten_dict(flat_mesh_axes, sep='/')) diff --git a/t5x/partitioning_test.py b/t5x/partitioning_test.py new file mode 100644 index 0000000000000000000000000000000000000000..68594d1ec842341c074d7d87534d8bb46ee25237 --- /dev/null +++ b/t5x/partitioning_test.py @@ -0,0 +1,272 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for t5x.partitioning.""" + +import collections + +from absl.testing import absltest +from absl.testing import parameterized +import flax.core +from flax.linen import partitioning as nn_partitioning +import jax +import numpy as np +from t5x import adafactor +from t5x import optimizers +from t5x import partitioning +from t5x import test_utils as ptu +from t5x import train_state + +jax.config.parse_flags_with_absl() + +mock = absltest.mock +TpuDevice = ptu.TpuDevice +TPUV3_32 = ptu.make_devices(4, 4, 1, 2, kind='TPU v3') +AxisMetadata = nn_partitioning.AxisMetadata +PartitionSpec = partitioning.PartitionSpec + + +class PartitioningTest(absltest.TestCase): + + @mock.patch('jax.host_count') + @mock.patch('jax.local_device_count') + def test_bounds_from_last_device(self, local_device_count, host_count): + last_device = mock.Mock(coords=(3, 3, 3), core_on_chip=1) + tpu_bounds = partitioning.bounds_from_last_device(last_device) + self.assertEqual(tpu_bounds, (4, 4, 4, 2)) + + last_device = mock.Mock(spec=[]) + host_count.return_value = 1 + local_device_count.return_value = 4 + non_tpu_bounds = partitioning.bounds_from_last_device(last_device) + self.assertEqual(non_tpu_bounds, (1, 4)) + + @mock.patch('jax.local_device_count') + def test_get_coords(self, local_device_count): + device = mock.Mock(coords=(1, 0, 1), core_on_chip=1) + coords = partitioning.get_coords(device) + self.assertEqual(coords, (1, 0, 1, 1)) + + device = mock.Mock(spec=['process_index', 'id']) + device.process_index = 1 + device.id = 9 + local_device_count.return_value = 8 + coords = partitioning.get_coords(device) + self.assertEqual(coords, (1, 1)) + + @mock.patch('jax.local_devices') + @mock.patch('jax.devices') + @mock.patch('jax._src.lib.xla_bridge.process_index') + def test_default_mesh(self, process_index_fn, devices_fn, local_devices_fn): + devices_fn.return_value = TPUV3_32 + local_devices_fn.return_value = [ + d for d in TPUV3_32 if d.process_index == 0 + ] + process_index_fn.return_value = 0 + + global_mesh = partitioning.default_mesh(4) + self.assertEqual(global_mesh.axis_names, ('data', 'model')) + self.assertEqual(global_mesh.shape, + collections.OrderedDict((('data', 8), ('model', 4)))) + self.assertEqual(global_mesh.size, 32) + + for process_index in (0, 1, 2, 3): + process_index_fn.return_value = process_index + local_mesh = global_mesh.local_mesh + self.assertEqual(local_mesh.axis_names, ('data', 'model')) + self.assertEqual(local_mesh.shape, + collections.OrderedDict((('data', 2), ('model', 4)))) + self.assertEqual(local_mesh.size, 8) + + 
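+    # Spot-check host 0's local device layout: its 8 local devices should form
+    # a contiguous 2 (data) x 4 (model) block of the global mesh.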
process_index_fn.return_value = 0 + local_mesh = global_mesh.local_mesh + lds = np.array([ + [ + TpuDevice(id=0, process_index=0, coords=(0, 0, 0), core_on_chip=0), + TpuDevice(id=1, process_index=0, coords=(0, 0, 0), core_on_chip=1), + TpuDevice(id=2, process_index=0, coords=(1, 0, 0), core_on_chip=0), + TpuDevice(id=3, process_index=0, coords=(1, 0, 0), core_on_chip=1) + ], + [ + TpuDevice(id=8, process_index=0, coords=(0, 1, 0), core_on_chip=0), + TpuDevice(id=9, process_index=0, coords=(0, 1, 0), core_on_chip=1), + TpuDevice(id=10, process_index=0, coords=(1, 1, 0), core_on_chip=0), + TpuDevice(id=11, process_index=0, coords=(1, 1, 0), core_on_chip=1) + ] + ], + dtype=object) + np.testing.assert_array_equal(local_mesh.devices, lds) + + @mock.patch('jax.local_devices') + @mock.patch('jax.devices') + @mock.patch('jax._src.lib.xla_bridge.process_index') + def test_local_chunker(self, process_index_fn, devices_fn, local_devices_fn): + devices_fn.return_value = TPUV3_32 + local_devices_fn.return_value = [ + d for d in TPUV3_32 if d.process_index == 0 + ] + process_index_fn.return_value = 0 + global_mesh = partitioning.default_mesh(4) + local_chunker = partitioning.LocalChunker(global_mesh) + self.assertEqual(local_chunker.num_chunks['data'], 4) + self.assertEqual(local_chunker.num_chunks['model'], 1) + + # Derive the chunk order along the first 'data' dim for testing. + host_ordering = [] + for d in global_mesh.devices[:, 0]: + if d.process_index not in host_ordering: + host_ordering.append(d.process_index) + process_index_to_data_pos = { + process_index: idx for idx, process_index in enumerate(host_ordering) + } + + for process_indexx in (0, 1, 2, 3): + process_index_fn.return_value = process_indexx + global_mesh = partitioning.default_mesh(4) + local_chunker = partitioning.LocalChunker(global_mesh) + # get expected chunk for 'data' axis. + expected_chunk = process_index_to_data_pos[process_indexx] + self.assertEqual(local_chunker.chunk_ids['data'], expected_chunk) + self.assertEqual(local_chunker.chunk_ids['model'], 0) + # Sharded along both axes. + local_chunk_info = local_chunker.get_local_chunk_info((128, 16), + ['data', 'model']) + self.assertEqual(local_chunk_info.replica_id, 0) + self.assertEqual(local_chunk_info.slice, + (slice(32 * expected_chunk, 32 * + (expected_chunk + 1)), slice(0, 16))) + # Replicated across first axis. 
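+      # With the 'data' axis unsharded for this array, the host's data-chunk
+      # position becomes the replica_id, so only replica 0 would write this
+      # chunk at checkpoint time.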
+ local_chunk_info = local_chunker.get_local_chunk_info((128, 16), + [None, 'model']) + self.assertEqual(local_chunk_info.replica_id, expected_chunk) + self.assertEqual(local_chunk_info.slice, (slice(None), slice(0, 16))) + + +class ModelBasedPartitionerTest(parameterized.TestCase): + + def get_axes_spec(self, partitioner, factored, momentum): + opt_def = adafactor.Adafactor( + learning_rate=0.1, + factored=factored, + min_dim_size_to_factor=8, + beta1=0.1 if momentum else None, + logical_factor_rules={ + 'batch': adafactor.FactorDim.NONE, + 'embed': adafactor.FactorDim.ROW, + 'vocab': adafactor.FactorDim.COLUMN, + 'mlp': adafactor.FactorDim.COLUMN, + }) + state = train_state.FlaxOptimTrainState.create( + opt_def, + flax.core.freeze({ + 'params': { + 'logits_dense': np.ones((16, 16), np.float32), + 'mlp': { + 'wo': { + 'kernel': np.ones((32, 16), np.float32) + } + } + }, + 'params_axes': { + 'logits_dense_axes': AxisMetadata(names=('vocab', 'embed')), + 'mlp': { + 'wo': { + 'kernel_axes': AxisMetadata(names=('embed', 'mlp')) + } + } + } + })) + return partitioner.get_mesh_axes(state).state_dict() + + def get_expected_axes_spec(self, + spec_0, + spec_1, + kernel_spec=PartitionSpec(None, 'model')): + return train_state.FlaxOptimTrainState( + optimizers.Optimizer( + # opt_def, + adafactor.Adafactor(0.1), # opt_def not compared. + state=optimizers.OptimizerState( + step=None, + param_states={ + 'logits_dense': spec_0, + 'mlp': { + 'wo': { + 'kernel': spec_1 + } + } + }), + target={ + 'logits_dense': PartitionSpec('model', None), + 'mlp': { + 'wo': { + 'kernel': kernel_spec + } + } + })).state_dict() + + def test_get_mesh_axes(self): + partitioner = partitioning.PjitPartitioner( + num_partitions=1, + logical_axis_rules=(('batch', 'data'), ('embed', None), + ('vocab', 'model'), ('mlp', 'model'))) + + p0_spec = PartitionSpec('model', None) + p1_spec = PartitionSpec(None, 'model') + + # Test quadrant of conditions: factored or not / momentum or not. 
+ axes_spec = self.get_axes_spec(partitioner, factored=True, momentum=False) + expected_axes_spec = self.get_expected_axes_spec( + adafactor._AdafactorParamState(m=None, v=None, v_col=None, v_row=None), + adafactor._AdafactorParamState(m=None, v=None, v_col=None, v_row=None)) + jax.tree_multimap(self.assertEqual, axes_spec, expected_axes_spec) + + axes_spec = self.get_axes_spec(partitioner, factored=True, momentum=True) + expected_axes_spec = self.get_expected_axes_spec( + adafactor._AdafactorParamState( + m=p0_spec, v=None, v_col=None, v_row=None), + adafactor._AdafactorParamState( + m=p1_spec, v=None, v_col=None, v_row=None)) + jax.tree_multimap(self.assertEqual, axes_spec, expected_axes_spec) + + axes_spec = self.get_axes_spec(partitioner, factored=False, momentum=True) + expected_axes_spec = self.get_expected_axes_spec( + adafactor._AdafactorParamState( + m=p0_spec, v=p0_spec, v_col=None, v_row=None), + adafactor._AdafactorParamState( + m=p1_spec, v=p1_spec, v_col=None, v_row=None)) + jax.tree_multimap(self.assertEqual, axes_spec, expected_axes_spec) + + axes_spec = self.get_axes_spec(partitioner, factored=False, momentum=False) + expected_axes_spec = self.get_expected_axes_spec( + adafactor._AdafactorParamState( + m=None, v=p0_spec, v_col=None, v_row=None), + adafactor._AdafactorParamState( + m=None, v=p1_spec, v_col=None, v_row=None)) + jax.tree_multimap(self.assertEqual, axes_spec, expected_axes_spec) + + @parameterized.product(activation_dims=(1, 2), param_dims=(1, 2)) + def test_standard_logical_axis_rules(self, activation_dims, param_dims): + default_rules = partitioning.standard_logical_axis_rules( + activation_dims, param_dims, additional_rules=None) + custom_rules = (('my-new-axis', 'data'), ('another-axis', None), + ('another-one', 'model')) + new_rules = partitioning.standard_logical_axis_rules( + activation_dims, param_dims, additional_rules=custom_rules) + self.assertEqual(new_rules[:len(default_rules)], default_rules) + self.assertEqual(new_rules[len(default_rules):], list(custom_rules)) + + +if __name__ == '__main__': + absltest.main() diff --git a/t5x/precompile.py b/t5x/precompile.py new file mode 100644 index 0000000000000000000000000000000000000000..0d8cf3ce677627baf71bfb818dc7b2db4492fc28 --- /dev/null +++ b/t5x/precompile.py @@ -0,0 +1,132 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Precompile and generates HLO from TPU metadata backend. + +TPU Metadata backend is a TPU backend without real TPU devices while supporting +any TPU topologies, to allow work that doesn't require real TPUs to run as if +it is, e.g., compiling/lowering a HLO graph with the backend. + +Ideally, the precompile defaults to cpu backend for default device array +placement since metadata backend does not have memory allocation. + +The pjit function is pinned to use available TPU Metadata backend, for getting +a proper lowering under TPU mesh. 
+ +""" +import os + +from typing import Iterator, Optional + +import jax +from jax import random +import numpy as np +import t5.data.mixtures # pylint:disable=unused-import +from t5x import models +from t5x import partitioning +from t5x import trainer as trainer_lib +from t5x import utils + +import tensorflow as tf + + +def precompile(*, + model: models.BaseTransformerModel, + train_dataset_cfg: utils.DatasetConfig, + partitioner: partitioning.BasePartitioner, + model_dir: str, + random_seed: Optional[int], + get_dataset_fn: utils.GetDatasetCallable = utils.get_dataset): + """Compiles and dump the HLO to model dir, with HLO text dumps.""" + rng = random.PRNGKey(random_seed or 42) + _, trainer_rng = random.split(rng, 2) + + # TODO(hthu): Find a better way of getting dataset shapes instead of actually + # reading database and iterate on it. + data_layout = partitioner.get_data_layout(train_dataset_cfg.batch_size) + ds_shard_id = data_layout.shard_id + num_ds_shards = data_layout.num_shards + + def _verify_matching_vocabs(cfg: utils.DatasetConfig): + ds_vocabs = utils.get_vocabulary(cfg) + if (ds_vocabs[0] != model.input_vocabulary or + ds_vocabs[1] != model.output_vocabulary): + raise ValueError(f'Model and Task vocabularies do not match:\n' + f' task={cfg.mixture_or_task_name}\n' + f' ds_vocabs=({ds_vocabs[0]}, {ds_vocabs[1]})\n' + f' model.input_vocabulary={model.input_vocabulary}\n' + f' model.output_vocabulary={model.output_vocabulary}\n') + + _verify_matching_vocabs(train_dataset_cfg) + + train_ds = get_dataset_fn(train_dataset_cfg, ds_shard_id, num_ds_shards, + model.FEATURE_CONVERTER_CLS) + + # Need to use full batch size. + input_shapes = { + k: (data_layout.batch_size, *v.shape[1:]) + for k, v in train_ds.element_spec.items() + } + input_types = { + k: v.dtype.as_numpy_dtype() for k, v in train_ds.element_spec.items() + } + + checkpointable_train_iter = iter(train_ds) + train_iter: Iterator[trainer_lib.BatchType] = map( + lambda x: jax.tree_map(np.array, x), checkpointable_train_iter) + batch = next(train_iter) + + # Compiling does not care about loading real weights. + train_state_initializer = utils.TrainStateInitializer( + optimizer_def=model.optimizer_def, + init_fn=model.get_initial_variables, + input_shapes=input_shapes, + input_types=input_types, + partitioner=partitioner) + train_state_shape = train_state_initializer.global_train_state_shape + train_state_axes = train_state_initializer.train_state_axes + + def train_step(train_state, batch): + return trainer_lib.train_with_lr( + train_state, + batch, + learning_rate=1e-3, + dropout_rng=trainer_rng, + model=model, + num_microbatches=None, + weight_metrics_computer=None) + + partitioned_step = partitioner.partition( + train_step, + in_axis_resources=(train_state_axes, partitioning.PartitionSpec('data',)), + out_axis_resources=(train_state_axes, None), + donate_argnums=(0,)) + + # PartitionedTrainCallable has lower() defined but isn't exposed in pytype. + # TODO(hthu): Explicitly expose the lower() interface. + # pytype: disable=attribute-error + lowered = partitioned_step.lower(train_state_shape, batch) + # pytype: enable=attribute-error + + # TODO(hthu): Make this a proper library without writing files by default. 
+  tf.io.gfile.makedirs(model_dir)
+  with tf.io.gfile.GFile(
+      os.path.join(model_dir, 'lowered_hlo_pre_optimization'), 'w') as f:
+    f.write(lowered.compiler_ir(dialect='hlo').as_serialized_hlo_module_proto())
+  compiled = lowered.compile()
+  output_path = os.path.join(model_dir, 'lowered_hlo_post_optimization')
+  with tf.io.gfile.GFile(output_path, 'w') as f:
+    f.write(compiled.compiler_ir()[0].as_serialized_hlo_module_proto())
+  with tf.io.gfile.GFile(os.path.join(model_dir, 'assignment'), 'wb') as f:
+    np.save(f, partitioner.mesh.device_ids)
diff --git a/t5x/scripts/__init__.py b/t5x/scripts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ac5693550488d38623ec8e5b56e3fc3de148d40
--- /dev/null
+++ b/t5x/scripts/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2022 The T5X Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This empty file is needed for setuptools to recognize this directory as a package."""
diff --git a/t5x/scripts/convert_tf_checkpoint.py b/t5x/scripts/convert_tf_checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..64f23941d3275cfc9ad4175cdb9ec9e570784368
--- /dev/null
+++ b/t5x/scripts/convert_tf_checkpoint.py
@@ -0,0 +1,115 @@
+# Copyright 2022 The T5X Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""Tool to convert a T5/MeshTF checkpoint to T5X.

+While T5X can load these checkpoints on the fly, the process can be slow
+for very large checkpoints. For frequently used checkpoints, it's recommended to
+convert them once to a T5X checkpoint.
+
+Example usage:
+
+CUDA_VISIBLE_DEVICES=""
+python -m t5x.scripts.convert_tf_checkpoint \
+  --gin_file=t5x/examples/t5/t5_1_0/small.gin\
+  --gin.convert_checkpoint.model=%MODEL\
+  --gin.convert_checkpoint.tf_checkpoint_path=\
+\"gs://t5-data/pretrained_models/small/model.ckpt-1000000\"\
+  --gin.convert_checkpoint.output_dir=\"/tmp/t5x_checkpoints/t5_small\"\
+  --logtostderr
+"""
+import jax
+import jax.numpy as jnp
+from t5x import checkpoints
+from t5x import models
+from t5x import partitioning
+from t5x import train_state as train_state_lib
+
+
+def convert_checkpoint(model: models.BaseModel,
+                       tf_checkpoint_path: str,
+                       output_dir: str,
+                       save_dtype: jnp.dtype = jnp.float32,
+                       concurrent_gb: int = 16):
+  """Converts a TensorFlow checkpoint to a T5X checkpoint.
+
+  Args:
+    model: The T5X model whose parameter structure matches the TF checkpoint.
+    tf_checkpoint_path: Path to a TensorFlow checkpoint to convert.
+    output_dir: Path to a directory to write the converted checkpoint.
+ save_dtype: What dtype to store the target parameters as. + concurrent_gb: Number of gigabtes of parameters to convert in parallel. + Actual RAM usage may be 4X this number. + """ + + def initialize_train_state(rng): + initial_variables = model.get_initial_variables( + rng=rng, + input_shapes={ + 'encoder_input_tokens': (1, 1), + 'decoder_input_tokens': (1, 1) + }) + return train_state_lib.FlaxOptimTrainState.create(model.optimizer_def, + initial_variables) + + train_state = jax.eval_shape(initialize_train_state, jax.random.PRNGKey(0)) + + partitioner = partitioning.PjitPartitioner(1) + + checkpointer = checkpoints.Checkpointer( + train_state, partitioner, output_dir, save_dtype=jnp.dtype(save_dtype)) + + checkpointer.convert_from_tf_checkpoint( + tf_checkpoint_path, concurrent_gb=concurrent_gb) + + +if __name__ == '__main__': + # pylint:disable=g-import-not-at-top + from absl import flags + import gin + from t5x import gin_utils + # pylint:disable=g-import-not-at-top + + FLAGS = flags.FLAGS + + jax.config.parse_flags_with_absl() + + flags.DEFINE_multi_string( + 'gin_file', + default=None, + help='Path to gin configuration file. Multiple paths may be passed and ' + 'will be imported in the given order, with later configurations ' + 'overriding earlier ones.') + + flags.DEFINE_multi_string( + 'gin_bindings', default=[], help='Individual gin bindings') + + flags.DEFINE_list( + 'gin_search_paths', + default=['t5x/configs'], + help='Comma-separated list of gin config path prefixes to be prepended ' + 'to suffixes given via `--gin_file`. If a file appears in. Only the ' + 'first prefix that produces a valid path for each suffix will be ' + 'used.') + + def main(_): + """True main function.""" + convert_checkpoint_using_gin = gin.configurable(convert_checkpoint) + + gin_utils.parse_gin_flags(FLAGS.gin_search_paths, FLAGS.gin_file, + FLAGS.gin_bindings) + # Get gin-configured version of `convert_checkpoint`. + convert_checkpoint_using_gin() + + gin_utils.run(main) diff --git a/t5x/scripts/xm_launch.py b/t5x/scripts/xm_launch.py new file mode 100644 index 0000000000000000000000000000000000000000..00867f33254bd2f5a2653ab0ac2271f67eaf8239 --- /dev/null +++ b/t5x/scripts/xm_launch.py @@ -0,0 +1,213 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""XManager launcher for t5x. 
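+
+This script packages the local t5x sources (or a fresh clone from GitHub)
+into a Docker image and runs `t5x/main.py` on Vertex AI with the requested
+number of TPU cores.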
+ +Read about XManager: +https://github.com/deepmind/xmanager + +Usage: +xmanager xm_launch.py -- \ + --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \ + --model_dir=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/$(date +%Y%m%d) \ + --tfds_data_dir=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/data +""" + +import collections +import os +import shutil +import sys +import tempfile +from typing import Any, Dict + +from absl import app +from absl import flags +from xmanager import xm +from xmanager import xm_local +from xmanager.contrib import copybara + +_NAME = flags.DEFINE_string( + 'name', + 't5x', + 'Name of the experiment.', +) +_RUN_MODE = flags.DEFINE_enum( + 'run_mode', + 'train', + ['train', 'eval', 'infer'], + 'The mode to run T5X under', +) +_CLONE_GITHUB = flags.DEFINE_bool( + 'clone_github', + False, + 'If True, clone t5x/ from GitHub. Otherwise, use the local version.', +) +_COPYBARA_CONFIG = flags.DEFINE_string( + 'copybara_config', + None, + 'Copybara config to use. See https://github.com/google/copybara ' + 'If None, the local t5x directory will be copied with no modifications.', +) +_COPYBARA_WORKFLOW = flags.DEFINE_string( + 'copybara_workflow', + 'local', + 'Copybara workflow to apply with --copybara_config', +) +_COPYBARA_ORIGIN = flags.DEFINE_string( + 'copybara_origin', + '..', + 'Copybara origin folder to apply with --copybara_config', +) + +_TPU_CORES = flags.DEFINE_integer( + 'tpu_cores', + 8, + 'Number of TPU cores to run. There will be a new worker every 8 cores. ' + 'TPU types: https://cloud.google.com/tpu/docs/types-zones#types', +) +_MODEL_DIR = flags.DEFINE_string( + 'model_dir', + None, + 'Model dir to save logs, ckpts, etc. in "gs://model_dir" format.', +) +_TFDS_DATA_DIR = flags.DEFINE_string( + 'tfds_data_dir', + None, + 'Data dir to save the processed dataset in "gs://data_dir" format.', +) +_SEQIO_CACHE_DIRS = flags.DEFINE_list( + 'seqio_additional_cache_dirs', + [], + 'Comma separated directories in "gs://cache_dir" format to search for cached Tasks in addition to defaults.', +) +_PROJECT_DIRS = flags.DEFINE_list( + 'project_dirs', + None, + 'Project dir with custom components.', +) +_PIP_INSTALL = flags.DEFINE_list( + 'pip_install', + None, + 'Extra pip packages to install.', +) + + +@xm.run_in_asyncio_loop +async def main(_, gin_args: Dict[str, Any]): + name = 't5x' + async with xm_local.create_experiment(experiment_title=name) as experiment: + # TODO(chenandrew) Vertex Tensorboard is not supported for TPUs. + # https://github.com/deepmind/xmanager/issues/11 + # vertex = xm_local.vertex_client() + # tensorboard_name = await vertex.get_or_create_tensorboard(name) + # tensorboard = xm_local.TensorboardCapability( + # name=tensorboard_name, + # base_output_directory=_MODEL_DIR.value) + tensorboard = None + executor = xm_local.Vertex( + requirements=xm.JobRequirements(tpu_v2=_TPU_CORES.value), + tensorboard=tensorboard, + ) + + staging = os.path.join(tempfile.mkdtemp(), _NAME.value) + # The t5x/ root directory. 
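+    # (__file__ is t5x/scripts/xm_launch.py, so three directory levels up is
+    # the repository root that contains the t5x/ package.)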
+ t5x_path = os.path.abspath(os.path.join(__file__, '..', '..', '..')) + t5x_destination = os.path.join(staging, 't5x') + if _COPYBARA_CONFIG.value: + t5x_path = copybara.run_workflow(_COPYBARA_CONFIG.value, + _COPYBARA_WORKFLOW.value, + _COPYBARA_ORIGIN.value, t5x_destination) + + if _CLONE_GITHUB.value: + copy_t5x = [ + 'RUN git clone --branch=main https://github.com/google-research/t5x', + ] + else: + if t5x_path != t5x_destination: + shutil.copytree(t5x_path, t5x_destination) + staging_t5x_path = os.path.join(os.path.basename(staging), 't5x') + copy_t5x = [f'COPY {staging_t5x_path}/ t5x'] + + copy_projects = [] + if _PROJECT_DIRS.value: + for project_dir in _PROJECT_DIRS.value: + project_name = os.path.basename(project_dir) + shutil.copytree(project_dir, os.path.join(staging, project_name)) + staging_project_dir = os.path.join( + os.path.basename(staging), project_name) + copy_projects.append(f'COPY {staging_project_dir}/ {project_name}') + + pip_install = [] + if _PIP_INSTALL.value: + pip_install = [ + 'RUN python3 -m pip install ' + ' '.join(_PIP_INSTALL.value) + ] + + [executable] = experiment.package([ + xm.python_container( + executor.Spec(), + path=staging, + base_image='gcr.io/deeplearning-platform-release/base-cpu', + docker_instructions=[ + *copy_t5x, + 'WORKDIR t5x', + 'RUN python3 -m pip install -e ".[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html', + # TODO(chenandrew): Remove the below command. + # TFDS 4.5.2 is missing SplitInfo fields. + 'RUN python3 -m pip install --force-reinstall tfds-nightly', + *pip_install, + *copy_projects, + ], + entrypoint=xm.CommandList([ + f'export MODEL_DIR=\'"{_MODEL_DIR.value}/logs"\'', + f'export TFDS_DATA_DIR={_TFDS_DATA_DIR.value}', + 'export SEQIO_CACHE_DIRS={}'.format(','.join( + _SEQIO_CACHE_DIRS.value)), + 'export T5X_DIR=.', + ('python3 ${T5X_DIR}/t5x/main.py ' + f'--run_mode={_RUN_MODE.value} ' + '--gin.MODEL_DIR=${MODEL_DIR} ' + '--tfds_data_dir=${TFDS_DATA_DIR} ' + '--undefok=seqio_additional_cache_dirs ' + '--seqio_additional_cache_dirs=${SEQIO_CACHE_DIRS} '), + ]), + ), + ]) + args = [] + for k, l in gin_args.items(): + for v in l: + if '\'' or '"' in v: + args.append(xm.ShellSafeArg(f'--{k}={v}')) + else: + args.append(f'--{k}={v}') + + experiment.add(xm.Job(executable=executable, executor=executor, args=args)) + + +def _split_gin_args(argv, prefix='--gin'): + """Separates absl and gin args into separate lists.""" + other_args = [argv[0]] + gin_args = collections.defaultdict(list) + for arg in argv[1:]: + if arg.startswith(prefix): + k, v = arg[len('--'):].split('=', maxsplit=1) + gin_args[k].append(v) + else: + other_args.append(arg) + return other_args, gin_args + + +if __name__ == '__main__': + _other_args, _gin_args = _split_gin_args(sys.argv) + app.run(lambda argv: main(argv, _gin_args), _other_args) diff --git a/t5x/state_utils.py b/t5x/state_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..33c6f73bdfa647f4de3af93d6d12a25d61d4907e --- /dev/null +++ b/t5x/state_utils.py @@ -0,0 +1,215 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for processing optimizer states.""" + +import re +from typing import Any, Mapping, Optional, Sequence, Tuple + +from absl import logging +from flax import traverse_util + + +def tensorstore_leaf(_, value): + """Detect if the node is a serialized tensorstore spec. + + Args: + _: The unused name of the current item. + value: The value of the possible leaf. + + Returns: + True if the value represents a tensorstore spec, False otherwise. + """ + # It is a tensorstore leaf if it at least has `driver`, `kvstore` and + # `metadata` in its keys, sometime they have additional ones like `dtype` or + # `transform`. + return set(value.keys()) >= {"driver", "kvstore", "metadata"} + + +def flatten_state_dict(state_dict, keep_empty_nodes: bool = False): + """Flatten a dictionary until an array or tensorstore is reached. + + Args: + state_dict: Optimizer state as nested dictionary. + keep_empty_nodes: Whether to keep empty node, for example, empty param + states from simple optimizers or non-touched parameter states in a + multioptimizer. + + Returns: + Flattened dictionary, though keeping tensor store state unflattened. + """ + return traverse_util.flatten_dict( + state_dict, + is_leaf=tensorstore_leaf, + keep_empty_nodes=keep_empty_nodes, + sep="/") + + +def get_name_tree(state_dict, keep_empty_nodes: bool = False): + """Returns new state_dict with leaves as joined path keys separated by "/".""" + return traverse_util.unflatten_dict({ + k: "/".join(k) for k in traverse_util.flatten_dict( + state_dict, keep_empty_nodes=keep_empty_nodes) + }) + + +def intersect_state( + state_dict: Mapping[str, Any], + intersect_state_dict: Mapping[str, Any]) -> Mapping[str, Any]: + """Drops non-matching entries from `state_dict`. + + Args: + state_dict: nested dict of optimizer state + intersect_state_dict: nested dict of entries to keep + + Returns: + nested dict like `state_dict` but with extra keys removed + """ + state_dict_flat = flatten_state_dict(state_dict) + intersect_state_dict_flat = flatten_state_dict(intersect_state_dict) + + for k in list(state_dict_flat): + if k not in intersect_state_dict_flat: + state_dict_flat.pop(k) + logging.warning("Ignoring param=%s from checkpoint", k) + + state_dict = traverse_util.unflatten_dict(state_dict_flat, sep="/") + + return state_dict + + +def merge_state(state_dict: Mapping[str, Any], + from_scratch_state: Mapping[str, Any]) -> Mapping[str, Any]: + """Inserts new entries into `state_dict`. + + Args: + state_dict: nested dict of optimizer state + from_scratch_state: nested dict of entries to insert + + Returns: + a nested dict like `state_dict` but with extra entries from + `from_scratch_state` inserted + """ + state_dict_flat = flatten_state_dict(state_dict) + from_scratch_state_flat = flatten_state_dict(from_scratch_state) + + for k in from_scratch_state_flat: + if k not in state_dict_flat: + logging.warning("Initializing param=%s from scratch", k) + state_dict_flat[k] = from_scratch_state_flat[k] + + state_dict = traverse_util.unflatten_dict(state_dict_flat, sep="/") + + return state_dict + + +def apply_assignment_map(ckpt_optimizer_state, + optimizer_state, + assignment_map: Sequence[Tuple[str, Optional[str]]], + require_all_rules_match: bool = True, + *, + is_resuming: bool = False): + """Applies an assignment map to a checkpoint optimizer state. 
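+
+  For example (illustrative), with
+  `assignment_map=[(r'(.*)new_name(.*)', r'\1old_name\2')]` a parameter key
+  `target/new_name/kernel` in the current optimizer state is initialized from
+  the checkpoint entry `target/old_name/kernel`.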
+ + In contrast to previous implementations, this has a switch whether to require + that all rules match, and has somewhat-custom-but-sensible replacement rules: + + 1. old keys that are matched are removed. + 2. old keys that don't match are retained. + 3. if two new keys map to the same old key, they both get assigned to its + value. + 4. if a new key isn't mapped but is in the checkpoint, it is copied over. + 5. new keys with None-valued replacement patterns are removed. + + Args: + ckpt_optimizer_state: Optimizer state in the checkpoint (usually, previous + model). + optimizer_state: optimizer state in the current model. + assignment_map: List of tuples (matcher, replacement) where matcher is a + regex, and replacement is a string replacement (possibly with + regex-compatible group match codes) or None if the matching state should + be dropped. + require_all_rules_match: Whether to require that all rules match. + is_resuming: Whether we are resuming a training run (True) or initializing a + new one (False). + + Returns: + New, remapped optimizer state. + """ + if is_resuming: + # Do not apply the transformation when resuming after a temporary stop. + # This ensures that the transformation will only happen once. + return ckpt_optimizer_state + + flat_ckpt = flatten_state_dict(ckpt_optimizer_state) + unmapped_old_keys = flat_ckpt.copy() + result = {} + explicitly_skipped_keys = set() + flat_opt = flatten_state_dict(optimizer_state) + + used_patterns = set() + for k in flat_opt: + for pattern, repl in assignment_map: + p_match = re.fullmatch(pattern, k) + if p_match: + # Skip initialization if the replacement pattern for this key is None. + if repl is None: + explicitly_skipped_keys.add(k) + used_patterns.add(pattern) + logging.info( + "Skipping optimizer param=%s, which had a None " + "replacement using pattern=%s in the assignment map.", k, pattern) + break + + old_k = p_match.expand(repl) + used_patterns.add(pattern) + + # Remove the old key, but read the value from the original dict since + # it's OK if it was referenced twice. + unmapped_old_keys.pop(old_k, None) + try: + result[k] = flat_ckpt[old_k] + logging.info( + "Assigning checkpoint param=%s to optimizer param=%s " + "using pattern=%s", old_k, k, pattern) + except KeyError: + raise ValueError( + f"Parameter '{old_k}' does not exist in restore checkpoint. " + f"Must be one of: {sorted(flat_ckpt.keys())}") + break + + # Now re-add the unmapped keys. This is a 2-step process so that the `pop()` + # call above doesn't mis-fire if the assignment map "rotates" a chain of keys. + for key, v in unmapped_old_keys.items(): + if key not in explicitly_skipped_keys: + result[key] = v + + # If any new keys weren't mapped, but are in the old checkpoint, copy those. 
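+  # (This implements rule (4) from the docstring above; keys whose rule mapped
+  # them to None stay skipped.)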
+ for key in set(flat_opt) - set(result): + if key in explicitly_skipped_keys: + pass + elif key in flat_ckpt: + result[key] = flat_ckpt[key] + else: + logging.warning( + "Skipping key=%s which did not match assignment map or checkpoint.", + key) + + if require_all_rules_match and len(assignment_map) != len(used_patterns): + unused_patterns = set(p for p, _ in assignment_map) - used_patterns + unused_patterns_str = ", ".join(f"'{p}'" for p in unused_patterns) + raise ValueError("Unused patterns in `assignment_map`: {" + + unused_patterns_str + "}") + + return traverse_util.unflatten_dict(result, sep="/") diff --git a/t5x/state_utils_test.py b/t5x/state_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c30cb8c3c2346b19354ad3fc81d8518896f004e7 --- /dev/null +++ b/t5x/state_utils_test.py @@ -0,0 +1,318 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for state_utils.""" + +import re + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np +from t5x import state_utils + + +class StateUtilsTest(parameterized.TestCase): + + @parameterized.parameters( + dict( + state_dict={"a": { + "b": 2, + "c": 3 + }}, + intersect_state_dict={ + "a": { + "b": 4 + }, + "d": 5 + }, + expect_state={"a": { + "b": 2 + }})) + def test_intersect_state(self, state_dict, intersect_state_dict, + expect_state): + actual_state = state_utils.intersect_state(state_dict, intersect_state_dict) + self.assertEqual(actual_state, expect_state) + + @parameterized.parameters( + dict( + state_dict={"a": { + "b": 2, + "c": 3 + }}, + merge_state_dict={ + "a": { + "b": 4 + }, + "d": 5 + }, + expect_state={ + "a": { + "b": 2, + "c": 3 + }, + "d": 5 + })) + def test_merge_state(self, state_dict, merge_state_dict, expect_state): + actual_state = state_utils.merge_state(state_dict, merge_state_dict) + self.assertEqual(actual_state, expect_state) + + def test_tensorstore_leaf(self): + leaf = { + "driver": "zarr", + "kvstore": { + "driver": "gfile", + "path": "target.bias" + }, + "metadata": { + "chunks": [4, 1], + "compressor": { + "id": "gzip", + "level": 1 + }, + "dtype": " int: + """Convert grid coordinates to linear index given a dimension ordering. + + Args: + coords: coordinates in minor to major ordering. + bounds: coordinate grid bonuds in SAME minor to major ordering as above. + + Returns: + Linear index for grid point. + """ + # Calculate stride multipliers. + strides = tuple(itertools.accumulate((1,) + bounds[:-1], operator.mul)) + # Sum linear index from strides and coords + return sum(jax.tree_multimap(lambda x, y: x * y, coords, strides)) + + +def make_devices(nx: int, + ny: int, + nz: int, + nc: int = 2, + host_layout: Tuple[int, ...] 
= (2, 2, 1, 2), + kind='TPU v3'): + """Create mock TPU devices.""" + devices = [] + device_bounds = (nx, ny, nz, nc) + hnx, hny, hnz, hnc = jax.tree_multimap(lambda a, b: a // b, device_bounds, + host_layout) + for x, y, z, c in itertools.product(*map(range, device_bounds)): + hx, hy, hz, hc = jax.tree_multimap(lambda a, b: a // b, (x, y, z, c), + host_layout) + # TODO(levskaya, jekbradbury): verify this id/host ordering on TPU v4 + device_id = coords_to_idx((c, x, y, z), (nc, nx, ny, nz)) # pytype: disable=wrong-arg-types + process_index = coords_to_idx((hc, hx, hy, hz), (hnc, hnx, hny, hnz)) # pytype: disable=wrong-arg-types + devices.append( + TpuDevice( + id=device_id, + process_index=process_index, + coords=(x, y, z), + core_on_chip=c, + platform='tpu', + device_kind=kind)) + return devices + + +def get_t5_test_model(**config_overrides) -> models.EncoderDecoderModel: + """Returns a tiny T5 1.1 model to use for testing.""" + tiny_config = network.T5Config( + vocab_size=32128, + dtype='bfloat16', + emb_dim=8, + num_heads=4, + num_encoder_layers=2, + num_decoder_layers=2, + head_dim=3, + mlp_dim=16, + mlp_activations=('gelu', 'linear'), + dropout_rate=0.0, + logits_via_embedding=False, + ) + + tiny_config = dataclasses.replace(tiny_config, **config_overrides) + sentencepiece_model_file = 'gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model' + vocabulary = seqio.SentencePieceVocabulary(sentencepiece_model_file) + return models.EncoderDecoderModel( + module=network.Transformer(tiny_config), + input_vocabulary=vocabulary, + output_vocabulary=vocabulary, + optimizer_def=adafactor.Adafactor( + decay_rate=0.8, + step_offset=0, + logical_factor_rules=adafactor.standard_logical_factor_rules())) + +# -------------------- Mesh parametrization helpers -------------------- +# Adapted from jax.test_util +MeshSpec = List[Tuple[str, int]] + + +@contextlib.contextmanager +def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]: + """Test utility for setting up meshes given mesh data from `schedules`.""" + axis_names, shape = zip(*named_shape) + size = np.prod(shape) + local_devices = list(jax.local_devices()) + if len(local_devices) < size: + raise unittest.SkipTest(f'Test requires {size} local devices') + mesh_devices = np.array(local_devices[:size]).reshape(shape) + with Mesh(mesh_devices, axis_names): + yield + + +def create_global_mesh(mesh_shape, axis_names): + size = np.prod(mesh_shape) + if len(jax.devices()) < size: + raise unittest.SkipTest(f'Test requires {size} global devices.') + devices = sorted(jax.devices(), key=lambda d: d.id) + mesh_devices = np.array(devices[:size]).reshape(mesh_shape) + global_mesh = Mesh(mesh_devices, axis_names) + return global_mesh + + +def get_fake_vocab(): + """Creates fake vocabulary compatible with `get_fake_tokenized_dataset`.""" + + @dataclasses.dataclass + class DummyVocab: + vocab_size: int = 128 + eos_id: int = 1 + + vocab = DummyVocab() + return (vocab, vocab) + + +# Text preprocessed and tokenized. 
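+# Each example below carries both token ids ('inputs'/'targets') and the
+# original text ('inputs_pretokenized'/'targets_pretokenized'), roughly
+# mirroring the features a SeqIO tokenizing preprocessor would produce.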
+_FAKE_TOKENIZED_DATASET = { + 'train': [ + { + 'inputs': (3, 13, 7, 14, 15, 9, 4, 16), + 'inputs_pretokenized': 'complete: this', + 'targets': (3, 8, 6, 3, 5, 10), + 'targets_pretokenized': 'is a test' + }, + { + 'inputs': (3, 13, 7, 14, 15, 9, 4, 16), + 'inputs_pretokenized': 'complete: that', + 'targets': (17, 5, 6, 3, 5, 10), + 'targets_pretokenized': 'was a test' + }, + { + 'inputs': (3, 13, 7, 14, 15, 9, 4, 16), + 'inputs_pretokenized': 'complete: those', + 'targets': (17, 4, 23, 4, 10, 6), + 'targets_pretokenized': 'were tests' + }, + ], + # Notice that we repeat consecutively each examples 4 times, + # this needed for tests like infer_tests to validate determinism. + 'validation': [{ + 'inputs': (3, 13, 7, 14, 15, 9, 4, 16), + 'inputs_pretokenized': 'complete: this', + 'targets': (3, 8, 6, 3, 5, 3, 25, 5), + 'targets_pretokenized': 'is a validation', + }] * 4 + [{ + 'inputs': (3, 13, 7, 14, 15, 9, 4, 17), + 'inputs_pretokenized': 'complete: that', + 'targets': (17, 5, 6, 3, 5, 22, 7, 24), + 'targets_pretokenized': 'was another validation', + }] * 4 +} + + +def get_fake_tokenized_dataset(*_, split='validation', **__): + """Creates fake dataset compatible with T5X models inputs.""" + + if split == 'test': + split = 'validation' + output_types = { + 'inputs': tf.int32, + 'targets': tf.int32, + 'inputs_pretokenized': tf.string, + 'targets_pretokenized': tf.string + } + output_shapes = { + 'inputs': [None], + 'targets': [None], + 'inputs_pretokenized': [], + 'targets_pretokenized': [] + } + ds = tf.data.Dataset.from_generator(lambda: _FAKE_TOKENIZED_DATASET[split], + output_types, output_shapes) + if split == 'train': + ds = ds.repeat(None) + return ds + + +def assert_same(tree_a, tree_b): + """Asserts that both trees are the same.""" + tree_a, tree_b = jax.device_get((tree_a, tree_b)) + jax.tree_multimap(np.testing.assert_array_equal, tree_a, tree_b) + + +def get_train_state_from_variables(variables, + optimizer_def=adafactor.Adafactor(0.0)): + """Returns a default Train State with Adafactor optimizer.""" + optimizer = optimizer_def.create(variables['params']) + return train_state_lib.FlaxOptimTrainState(optimizer) diff --git a/t5x/train.py b/t5x/train.py new file mode 100644 index 0000000000000000000000000000000000000000..162c1f6f7cf6d260f2ce3d848a41a6073cf448b2 --- /dev/null +++ b/t5x/train.py @@ -0,0 +1,680 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Script to pretrain or finetune in JAX using a SeqIO pipeline. + +""" +import functools +import math +import os +import time +from typing import Callable, Sequence, Mapping, Tuple, Type, Optional + +# Set Linen to add profiling information when constructing Modules. +# Must be set before flax imports. +# pylint:disable=g-import-not-at-top +os.environ['FLAX_PROFILE'] = 'true' +# TODO(adarob): Re-enable once users are notified and tests are updated. 
+os.environ['FLAX_LAZY_RNG'] = 'no' +from absl import logging +from clu import metric_writers +import clu.data +import jax +from jax import random +from jax.experimental import multihost_utils +import jax.numpy as jnp +import numpy as np +import seqio +from t5x import models +from t5x import partitioning +from t5x import train_state as train_state_lib +from t5x import trainer as trainer_lib +from t5x import utils +import tensorflow as tf + + +# Automatically search for gin files relative to the T5X package. +_DEFAULT_GIN_SEARCH_PATHS = [ + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +] +PyTreeDef = type(jax.tree_structure(None)) +P = partitioning.PartitionSpec +# Special key that used to distinguish train metrics. +TRAIN_METRIC_KEY = 'train' +# String keys that is acceptable from config. +_ACTION_KEYS = frozenset(trainer_lib.ActionMode.__members__.keys()) + + +def run_actions( + mode: trainer_lib.ActionMode, actions: trainer_lib.ActionMapType, + train_state: train_state_lib.TrainState, + metrics_by_task: Mapping[str, trainer_lib.MetricValueMapType]) -> bool: + """Invokes all actions on the given mode on host 0, then broadcasts to all. + + Args: + mode: The mode to run the actions. e.g., if mode is `train`, only actions + configured to run with `train` mode will be invoked. + actions: A mapping of actions that runs after train, eval or infer_eval, to + inspect the model and perform useful operations, e.g., early stopping. + train_state: The current train_state of the trainer. + metrics_by_task: A map of metrics keyed by task name. + + Returns: + A bool indicating whether training should be halted. + + Raises: + RuntimeError: When the metrics processed on host 0 is None. + """ + stop_training = False + if jax.process_index() == 0: + if not metrics_by_task: + raise RuntimeError('Metric is unexpectedly empty on process 0') + for action in actions.get(mode, []): + stop_training |= action.run(train_state, metrics_by_task=metrics_by_task) + # Broadcast result from host 0 to others. + return bool(multihost_utils.broadcast_one_to_all(jnp.array(stop_training))) + + +def train( + *, + model: models.BaseTransformerModel, + train_dataset_cfg: utils.DatasetConfig, + train_eval_dataset_cfg: Optional[utils.DatasetConfig], + infer_eval_dataset_cfg: Optional[utils.DatasetConfig], + checkpoint_cfg: utils.CheckpointConfig, + partitioner: partitioning.BasePartitioner, + trainer_cls: Type[trainer_lib.BaseTrainer], + model_dir: str, + total_steps: int, + eval_steps: int, + eval_period: int, + stats_period: Optional[int] = None, + random_seed: Optional[int], + use_hardware_rng: bool = False, + summarize_config_fn: Callable[[str, metric_writers.MetricWriter, int], + None], + inference_evaluator_cls: Type[seqio.Evaluator] = seqio.Evaluator, + get_dataset_fn: utils.GetDatasetCallable = utils.get_dataset, + concurrent_metrics: bool = True, + actions: Optional[Mapping[str, Sequence[trainer_lib.BaseAction]]] = None, + train_eval_get_dataset_fn: Optional[utils.GetDatasetCallable] = None, + run_eval_before_training: bool = False, + use_gda: bool = False) -> Tuple[int, train_state_lib.TrainState]: + """Train function. + + Args: + model: The model object to use for training. + train_dataset_cfg: Specification for the dataset to train with. + train_eval_dataset_cfg: Specification for the dataset to evaluate with using + the train metrics and no inference (e.g., uses teacher forcing). If None, + train eval is disabled. 
+ infer_eval_dataset_cfg: Specification for the dataset to evaluate with using + the inference metrics (e.g., uses sampled decoding). If None, inference + eval is disabled. + checkpoint_cfg: Specification for saving and restoring model parameters and + dataset state to/from checkpoints. + partitioner: Partitioner for model parameters and data across devices. + trainer_cls: An implementation of BaseTrainer. + model_dir: Path of directory to store checkpoints and metric summaries. + total_steps: The step number to stop training after. The number of actual + steps trained in this run will be this number minus the starting step from + the checkpoint. + eval_steps: The number of batches to process for each train-eval loop. + eval_period: The number of train steps between each evaluation (both + train-eval and infer-eval). + stats_period: The number of train steps between writing scalar stats. If + None, defaults to eval_period. + random_seed: A random seed to use for dropout and initialization. If None, a + fast, non-deterministic hardware-based RNG is used. + use_hardware_rng: Whether to force using the RngBitGenerator based hardware + rng, which takes seeds and acts similarly to software PRNG in that it + should be seed-deterministic. The new RngBitGenerator custom PRNG system + should be reproducible for a given sharding, but the numbers will change + for different shardings of the same model. + summarize_config_fn: A function that takes in the model directory, a + SummaryWriter, and the step number, and writes a summary of the + inference_evaluator_cls: seqio.Evaluator class to use for inference + evaluation, potentially with bound configuration args. + get_dataset_fn: The callable use to get the train and train-eval datasets + based on the DatasetConfig and shard information. + concurrent_metrics: If True, allow metrics computation and logging to + overlap with training. Will likely result in additional TPU memory usage. + actions: A mapping of actions that runs after train, eval or infer_eval, to + inspect the model and perform useful operations, e.g., early stopping. The + key must have a 1:1 mapping to ActionMode enum. For EVAL actions to + actually work, this requires `concurrent_metrics` to be turned off, since + chaining futures and mutating states concurrently might be error-prone. + train_eval_get_dataset_fn: Optional callable use to get the train-eval + datasets based on the DatasetConfig and shard information. If missing, it + defaults to `get_dataset_fn`. + run_eval_before_training: If True, calculate training eval and inference + eval metrics before training begins. + use_gda: if True, uses GlobalDeviceArray. Experimental feature. + + Returns: + The tuple of (last_step, last_train_state). + """ + logging.info('Process ID: %d', jax.process_index()) + tf.io.gfile.makedirs(model_dir) + + jax.config.update('jax_parallel_functions_output_gda', use_gda) + + # Each "epoch" of the training loop should be the min of the eval period, + # checkpoint period or the full training. + # We compute here to ensure that the eval period and checkpoint period are + # divisible by this number, otherwise we fail. 
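+  # For example (illustrative numbers): with eval_period=1000 and
+  # checkpoint_period=500, steps_per_epoch below is 500 and the divisibility
+  # check passes (1000 % 500 == 0); eval_period=1000 with
+  # checkpoint_period=300 would raise a ValueError.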
+ eval_enabled = (train_eval_dataset_cfg or infer_eval_dataset_cfg) + eval_period = eval_period if eval_enabled else 0 + checkpoint_period = checkpoint_cfg.save.period if checkpoint_cfg.save else 0 + if eval_period or checkpoint_period: + steps_per_epoch = min(eval_period or np.inf, checkpoint_period or np.inf) + else: + steps_per_epoch = total_steps + stats_period = stats_period or steps_per_epoch + if (eval_period and eval_period % steps_per_epoch or + checkpoint_period and checkpoint_period % steps_per_epoch): + raise ValueError( + f'Checkpoint period ({checkpoint_period}) must evenly divide eval ' + f'period ({eval_period}), or vice-versa.') + + if use_hardware_rng or random_seed is None: + logging.info( + 'Using fast RngBitGenerator PRNG for initialization and dropout.') + + if random_seed is None: + random_seed = multihost_utils.broadcast_one_to_all(np.int32(time.time())) + logging.info('Random seed not provided, using RNG seed %s', random_seed) + else: + logging.warning( + 'When using hardware RNG with a fixed seed, repeatability is only ' + 'guaranteed for fixed hardware and partitioning schemes and for a ' + 'fixed version of this code and its dependencies.') + utils.set_hardware_rng_ops() + rng = random.PRNGKey(random_seed) + else: + logging.info('Using seed for initialization and dropout RNG: %d', + random_seed) + rng = random.PRNGKey(random_seed) + + init_rng, trainer_rng = random.split(rng, 2) + + # --------------------------------------------------------------------------- + # Initialize datasets + # --------------------------------------------------------------------------- + + if (train_dataset_cfg.seed and + not (checkpoint_cfg.save or checkpoint_cfg.save.save_dataset)): + logging.warning( + 'Providing a random seed for the train dataset with ' + '`checkpoint_train_ds=False` is dangerous since each ' + 'preemption/restart will cause the dataset to deterministically replay ' + 'from the beginning.') + + data_layout = partitioner.get_data_layout(train_dataset_cfg.batch_size) + ds_shard_id = data_layout.shard_id + num_ds_shards = data_layout.num_shards + + def _verify_matching_vocabs(cfg: utils.DatasetConfig): + ds_vocabs = utils.get_vocabulary(cfg) + if (ds_vocabs[0] != model.input_vocabulary or + ds_vocabs[1] != model.output_vocabulary): + raise ValueError(f'Model and Task vocabularies do not match:\n' + f' task={cfg.mixture_or_task_name}\n' + f' ds_vocabs=({ds_vocabs[0]}, {ds_vocabs[1]})\n' + f' model.input_vocabulary={model.input_vocabulary}\n' + f' model.output_vocabulary={model.output_vocabulary}\n') + + _verify_matching_vocabs(train_dataset_cfg) + + train_ds = get_dataset_fn(train_dataset_cfg, ds_shard_id, num_ds_shards, + model.FEATURE_CONVERTER_CLS) + if isinstance(train_ds, tf.data.Dataset): + train_iter = clu.data.TfDatasetIterator(train_ds) + elif isinstance(train_ds, clu.data.DatasetIterator): + train_iter = train_ds + else: + raise ValueError( + f'get_dataset_fn returned unsupported type {type(train_ds)}.') + + if train_eval_dataset_cfg: + _verify_matching_vocabs(train_eval_dataset_cfg) + train_eval_datasets = utils.get_training_eval_datasets( + train_eval_dataset_cfg, + ds_shard_id, + num_ds_shards, + eval_steps, + model.FEATURE_CONVERTER_CLS, + get_dataset_fn=train_eval_get_dataset_fn if train_eval_get_dataset_fn + is not None else get_dataset_fn) # type: Mapping[str, tf.data.Dataset] + if not train_eval_datasets: + logging.warning( + 'No train_eval datasets loaded from config `train_eval_dataset_cfg`: ' + '%s', train_eval_dataset_cfg) + else: + 
train_eval_datasets = {} + + # The manner in which parameters are initialized follows this order of + # preference: + # 1. From a T5X checkpoint in `model_dir`, if one exists. + # 2. From a T5X or TF checkpoint specified by `cfg.path`, if set. + # 3. From scratch using `init_fn`. + + # 1. From a T5X checkpoint in `model_dir`, if one exists. + if checkpoint_cfg.restore is not None: + state_transforms_for_restore = [ + functools.partial(fn, is_resuming=True) + for fn in checkpoint_cfg.restore.state_transformation_fns + ] + else: + state_transforms_for_restore = [] + restore_cfgs = [ + utils.RestoreCheckpointConfig( + path=model_dir, + mode='latest', + dtype=checkpoint_cfg.save.dtype, + checkpointer_cls=checkpoint_cfg.save.checkpointer_cls, + # Restore dataset state if it is being saved. + restore_dataset=(checkpoint_cfg.save and + checkpoint_cfg.save.save_dataset), + state_transformation_fns=state_transforms_for_restore) + ] + # 2. From a checkpoint specified by `checkpoint_cfg.restore.path`, if set. + if checkpoint_cfg.restore: + if checkpoint_cfg.restore.mode == 'all': + raise ValueError( + "Restore checkpoint mode 'all' is not supported in training.") + + # TODO(dhgarrette): Split "restore" behavior into separate configurations + # for the initial restoration for a new run, vs resuming a stopped run. + if isinstance(checkpoint_cfg.restore.path, str): + restore_cfgs.append(checkpoint_cfg.restore) + elif not checkpoint_cfg.restore.path: + # `path` is an empty (non-`str`) sequence, so there is nothing to restore. + pass + else: + raise ValueError( + 'Restore checkpoint config may only have a single path in training.') + + # Need to use full batch size. + input_shapes = { + k: (data_layout.batch_size, *v.shape[1:]) + for k, v in train_ds.element_spec.items() + } + input_types = { + k: v.dtype.as_numpy_dtype() for k, v in train_ds.element_spec.items() + } + init_or_restore_tick = time.time() + train_state_initializer = utils.TrainStateInitializer( + optimizer_def=model.optimizer_def, + init_fn=model.get_initial_variables, + input_shapes=input_shapes, + input_types=input_types, + partitioner=partitioner) + + # May be None, empty + valid_restore_cfg, restore_paths = utils.get_first_valid_restore_config_and_paths( + restore_cfgs) + if len(restore_paths) > 1: + raise ValueError('Multiple restore paths not permitted in training.') + checkpointable_train_iter = ( + train_iter.iterator + if isinstance(train_iter, clu.data.TfDatasetIterator) else None) + checkpoint_manager = utils.LegacyCheckpointManager( + checkpoint_cfg.save, + valid_restore_cfg, + train_state_initializer.global_train_state_shape, + partitioner, + ds_iter=checkpointable_train_iter, + model_dir=model_dir, + use_gda=use_gda) + + train_state = checkpoint_manager.restore( + restore_paths, valid_restore_cfg, + utils.get_fallback_state( + valid_restore_cfg, + lambda rng: train_state_initializer.from_scratch(rng).state_dict(), + init_rng)) + + # 3. If no checkpoint to restore, init from scratch. + train_state = train_state or train_state_initializer.from_scratch(init_rng) + train_state_axes = train_state_initializer.train_state_axes + init_or_restore_secs = time.time() - init_or_restore_tick + logging.info('Initialize/restore complete (%.2f seconds).', + init_or_restore_secs) + + # Log the variable shapes information and write to a file. 
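+  # (Only parameter shapes and their partitioning are recorded here, not
+  # parameter values.)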
+ log_file = os.path.join(model_dir, 'model-info.txt') + utils.log_model_info(log_file, + train_state_initializer.global_train_state_shape, + partitioner) + + # Restore step from last checkpoint or set to 0 if training from scratch. + host_step = int(utils.get_local_data(train_state.step)) # pytype: disable=attribute-error + + # --------------------------------------------------------------------------- + # Trainer + # --------------------------------------------------------------------------- + + trainer: trainer_lib.BaseTrainer = trainer_cls( + model=model, + train_state=train_state, + partitioner=partitioner, + train_state_axes=train_state_axes, + eval_names=train_eval_datasets.keys(), + summary_dir=model_dir, + rng=trainer_rng) + del train_state + + train_metrics = trainer.train_metrics_manager + summarize_config_fn(model_dir, train_metrics.summary_writer, host_step) + + train_metrics.write_scalar('timing/init_or_restore_seconds', + init_or_restore_secs, host_step) + + # ---------------------------------------------------------------------------- + # SeqIO (inference-based) evaluation setup + # ---------------------------------------------------------------------------- + # Init evaluator to set up cached datasets + evaluator = None + if infer_eval_dataset_cfg is not None: + _verify_matching_vocabs(infer_eval_dataset_cfg) + evaluator = inference_evaluator_cls( + log_dir=os.path.join(model_dir, 'inference_eval'), + mixture_or_task_name=infer_eval_dataset_cfg.mixture_or_task_name, + feature_converter=model.FEATURE_CONVERTER_CLS(pack=False), + eval_split=infer_eval_dataset_cfg.split, + use_cached=infer_eval_dataset_cfg.use_cached, + seed=infer_eval_dataset_cfg.seed, + sequence_length=infer_eval_dataset_cfg.task_feature_lengths, + use_memory_cache=infer_eval_dataset_cfg.use_memory_cache) + if not evaluator.eval_tasks: + # Skip evaluaton. + evaluator = None + + if evaluator is not None: + predict_fn = utils.get_infer_fn( + infer_step=model.predict_batch, + batch_size=infer_eval_dataset_cfg.batch_size, + train_state_axes=train_state_axes, + partitioner=partitioner) + + predict_with_aux_fn = utils.get_infer_fn( + infer_step=model.predict_batch_with_aux, + batch_size=infer_eval_dataset_cfg.batch_size, + train_state_axes=train_state_axes, + partitioner=partitioner) + + score_fn = utils.get_infer_fn( + infer_step=model.score_batch, + batch_size=infer_eval_dataset_cfg.batch_size, + train_state_axes=train_state_axes, + partitioner=partitioner) + + if actions is None: + actions = {} + + if set(actions.keys()).difference(_ACTION_KEYS): + raise ValueError(f'actions keys must be one of {_ACTION_KEYS}, but got : ' + f'{actions.keys()}') + + # Transform the string key into proper ActionMode enum. 
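+  # e.g. (illustrative) {'TRAIN_EVAL': [my_early_stopping_action]} becomes
+  # {trainer_lib.ActionMode.TRAIN_EVAL: [my_early_stopping_action]}.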
+ actions = {trainer_lib.ActionMode[k]: v for k, v in actions.items()} + + if concurrent_metrics and actions.get(trainer_lib.ActionMode.INFER_EVAL, + None) is not None: + logging.warning('Actions for INFER_EVAL will not be triggered when async ' + 'metrics computation is enabled') + if concurrent_metrics and actions.get(trainer_lib.ActionMode.TRAIN, + None) is not None: + logging.warning('Actions for TRAIN will not be triggered when async ' + 'metrics computation is enabled') + + # ---------------------------------------------------------------------------- + # Setup Eval Utility Functions + # ---------------------------------------------------------------------------- + def _run_training_eval(first_run: bool = False): + if first_run: + logging.info('Compiling training eval loop.') + trainer.compile_eval({ + task: utils.get_zeros_batch_like_dataset(ds) + for task, ds in train_eval_datasets.items() + }) + logging.info('Computing training evaluation metrics.') + eval_batch_iters = { + task: ds.as_numpy_iterator() + for task, ds in train_eval_datasets.items() + } + eval_summaries = trainer.eval(eval_batch_iters) + trainer.stop_training = run_actions(trainer_lib.ActionMode.TRAIN_EVAL, + actions, trainer.train_state, + eval_summaries) + + def _run_inference_eval(): + """Run prediction based inference eval.""" + if evaluator is None: + return + logging.info('Running inference evaluation.') + evaluate_tick = time.time() + all_metrics, _, _ = evaluator.evaluate( + compute_metrics=jax.process_index() == 0, + step=host_step, + predict_fn=functools.partial( + predict_fn, + train_state=trainer.train_state, + rng=jax.random.PRNGKey(0)), + score_fn=functools.partial(score_fn, train_state=trainer.train_state), + predict_with_aux_fn=functools.partial( + predict_with_aux_fn, + train_state=trainer.train_state, + rng=jax.random.PRNGKey(0)), + ) + if not concurrent_metrics: + # Ensure metrics are finished being computed. + all_metrics_done = all_metrics.result() or {} + trainer.stop_training = run_actions(trainer_lib.ActionMode.INFER_EVAL, + actions, trainer.train_state, + all_metrics_done) + train_metrics.write_scalar('timing/evaluate_seconds', + time.time() - evaluate_tick, host_step) + + # Optionally run teacher-forcing training eval and SeqIO inference-base eval + # before training. Useful for testing how much a model knows before any + # finetuning. 
+ if run_eval_before_training: + if train_eval_datasets: + logging.info('Running training eval before training.') + _run_training_eval(first_run=True) + if evaluator is not None: + logging.info('Running inference eval before training.') + _run_inference_eval() + + # ---------------------------------------------------------------------------- + # Main training loop + # ---------------------------------------------------------------------------- + logging.info('Starting training loop.') + + first_step = host_step + + if total_steps < first_step: + raise ValueError( + f'Unexpected total_steps ({total_steps}) < checkpoint step ' + f' ({first_step}).') + + logging.info('Starting main loop over steps %d-%d', first_step, total_steps) + + steps_per_epoch = min(steps_per_epoch, total_steps) + first_epoch = first_step // steps_per_epoch + num_epochs = first_epoch + math.ceil( + (total_steps - first_step) / steps_per_epoch) + logging.info('Training with artificial "epochs" of %d steps.', + steps_per_epoch) + + logging.info('Compiling train loop.') + logging.flush() + dummy_batch = { + k: np.ones(v.shape, v.dtype) for k, v in train_iter.element_spec.items() + } + trainer.compile_train(dummy_batch) + + # Main Loop over "epochs". + for epoch in range(first_epoch, num_epochs): + final_epoch = epoch == num_epochs - 1 + logging.info('Epoch %d of %d', epoch, num_epochs) + + # `stop_training` is requested, break out the main loop immediately. + if trainer.stop_training: + break + + logging.info('BEGIN Train loop.') + try: + # Until the last epoch, `num_steps = steps_per_epoch` + num_steps = min(total_steps - host_step, steps_per_epoch) + epoch_end_step = host_step + num_steps + logging.info('Training for %d steps.', num_steps) + while host_step < epoch_end_step: + if trainer.stop_training: + logging.info('Saving a checkpoint before early stopping...') + checkpoint_manager.save(trainer.train_state, + checkpoint_cfg.save.state_transformation_fns) + logging.info('Stopping training loop early since `stop_training` is ' + 'requested.') + break + + inner_num_steps = min(epoch_end_step - host_step, stats_period) + train_summary = trainer.train( + train_iter, inner_num_steps, start_step=host_step) + if not concurrent_metrics: + # Note that we always pass the dictionary of `tasks` -> summary so + # that the actions can be performed without special casing. The only + # caveat is that train would need its own special `key` given no + # `task` will be applied. + trainer.stop_training = run_actions( + trainer_lib.ActionMode.TRAIN, actions, trainer.train_state, + {TRAIN_METRIC_KEY: train_summary.result()}) + + host_step += inner_num_steps + logging.info('END Train loop.') + except trainer_lib.PreemptionError as e: + logging.info('Saving emergency checkpoint.') + checkpoint_manager.save(trainer.train_state, + checkpoint_cfg.save.state_transformation_fns) + logging.info('Saving emergency checkpoint done.') + raise e + + step_offset = host_step - first_step + + # Maybe save a checkpoint. + if checkpoint_period and (final_epoch or + step_offset % checkpoint_period == 0): + # Make sure last train step has completed before starting the clock. 
+ train_summary.result() + logging.info('Saving checkpoint.') + checkpoint_tick = time.time() + checkpoint_manager.save(trainer.train_state, + checkpoint_cfg.save.state_transformation_fns) + checkpoint_tock = time.time() + train_metrics.write_scalar('timing/checkpoint_seconds', + checkpoint_tock - checkpoint_tick, host_step) + + is_eval_epoch = eval_period and (final_epoch or + step_offset % eval_period == 0) + + # Training Evaluation (i.e., with teacher forcing). + if is_eval_epoch and train_eval_datasets: + # Maybe less if final step < period. + first_run = step_offset // eval_period <= 1 + _run_training_eval(first_run and not run_eval_before_training) + + # Inference Evaluation (i.e., with decoding or scoring). + if evaluator is not None: + _run_inference_eval() + + # Wait until computations are done before exiting + logging.info('Finished.') + trainer.close() + if evaluator: + evaluator.close() + multihost_utils.sync_global_devices('complete') + + return host_step, trainer.train_state + + +if __name__ == '__main__': + # pylint: disable=g-import-not-at-top + from absl import app + from absl import flags + import gin + from t5x import gin_utils + # pylint: enable=g-import-not-at-top + + FLAGS = flags.FLAGS + + jax.config.parse_flags_with_absl() + + flags.DEFINE_multi_string( + 'gin_file', + default=None, + help='Path to gin configuration file. Multiple paths may be passed and ' + 'will be imported in the given order, with later configurations ' + 'overriding earlier ones.') + + flags.DEFINE_multi_string( + 'gin_bindings', default=[], help='Individual gin bindings.') + + flags.DEFINE_list( + 'gin_search_paths', + default=['.'], + help='Comma-separated list of gin config path prefixes to be prepended ' + 'to suffixes given via `--gin_file`. If a file appears in. Only the ' + 'first prefix that produces a valid path for each suffix will be ' + 'used.') + + flags.DEFINE_string( + 'tfds_data_dir', None, + 'If set, this directory will be used to store datasets prepared by ' + 'TensorFlow Datasets that are not available in the public TFDS GCS ' + 'bucket. Note that this flag overrides the `tfds_data_dir` attribute of ' + 'all `Task`s.') + + flags.DEFINE_list( + 'seqio_additional_cache_dirs', [], + 'Directories to search for cached Tasks in addition to defaults.') + + + + def main(argv: Sequence[str]): + """Wrapper for pdb post mortems.""" + _main(argv) + + def _main(argv: Sequence[str]): + """True main function.""" + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + + if FLAGS.tfds_data_dir: + seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir) + + seqio.add_global_cache_dirs(FLAGS.seqio_additional_cache_dirs) + + # Create gin-configurable version of `train`. + train_using_gin = gin.configurable(train) + + gin_utils.parse_gin_flags( + # User-provided gin paths take precedence if relative paths conflict. + FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS, + FLAGS.gin_file, + FLAGS.gin_bindings) + train_using_gin() + + gin_utils.run(main) diff --git a/t5x/train_state.py b/t5x/train_state.py new file mode 100644 index 0000000000000000000000000000000000000000..2d1baee4acf75077a9e35d17b52c6d2f85aebefe --- /dev/null +++ b/t5x/train_state.py @@ -0,0 +1,278 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Train state for passing around objects during training.""" + +from typing import Any, Mapping, MutableMapping, Optional, Tuple + +from flax import traverse_util +import flax.core +from flax.core import scope as flax_scope +from flax.linen import partitioning as flax_partitioning +import flax.serialization +import flax.struct +import jax.numpy as jnp +from t5x import optimizers + +import typing_extensions + +EMPTY_DICT = flax.core.freeze({}) +FrozenDict = flax_scope.FrozenDict +FrozenVariableDict = flax_scope.FrozenVariableDict +MutableVariableDict = flax_scope.MutableVariableDict +VariableDict = flax_scope.VariableDict + + +class TrainState(typing_extensions.Protocol): + """TrainState interface.""" + + @property + def step(self) -> jnp.ndarray: + """The current training step as an integer scalar.""" + ... + + @property + def params(self) -> FrozenVariableDict: + """The parameters of the model as a PyTree matching the Flax module.""" + ... + + @property + def param_states(self) -> FrozenVariableDict: + """The optimizer states of the parameters as a PyTree.""" + ... + + @property + def flax_mutables(self) -> FrozenVariableDict: + """Flax mutable collection.""" + ... + + def state_dict(self) -> MutableVariableDict: + """Returns a mutable representation of the state for checkpointing.""" + ... + + def restore_state(self, state_dict: Mapping[str, Any]) -> 'TrainState': + """Restores the object state from a state dict.""" + ... + + def replace_params(self, params: VariableDict) -> 'TrainState': + ... + + def replace_step(self, step: jnp.ndarray) -> 'TrainState': + ... + + def apply_gradient(self, + grads, + learning_rate, + flax_mutables=EMPTY_DICT) -> 'TrainState': + """Applies gradient, increments step, and returns an updated TrainState.""" + ... + + def as_logical_axes(self) -> 'TrainState': + """Replaces `param` and `param-states` with their logical axis names.""" + ... + + +def _validate_params_axes(params_axes, params): + axis_names = flax_partitioning.get_axis_names(params_axes) + missing_params_axes = ( + set(traverse_util.flatten_dict(params, sep='/')) - + set(traverse_util.flatten_dict(axis_names, sep='/'))) + if missing_params_axes: + raise ValueError( + f'Missing axis names for parameters: {missing_params_axes}') + + +def _split_variables_and_axes( + variables_and_axes: FrozenVariableDict +) -> Tuple[FrozenVariableDict, FrozenVariableDict]: + """Splits `variables_and_axes` into two separate dicts with the same keys.""" + # For each `key`, `key_axes` (if any) are its axes in `variables_and_axes`. + variables = {} + axes = {} + for k, v in variables_and_axes.items(): + if k.endswith('_axes'): + axes[k[:-5]] = v # k without "_axes". + _validate_params_axes(v, variables_and_axes[k[:-5]]) # k without "_axes". + else: + variables[k] = v + return flax.core.freeze(variables), flax.core.freeze(axes) + + +class FlaxOptimTrainState(flax.struct.PyTreeNode): + """Simple train state for holding parameters, step, optimizer state.""" + _optimizer: optimizers.OptimizerType + # Contains axis metadata (e.g., names) matching parameter tree. 
+ params_axes: Optional[FrozenVariableDict] = None + # Flax mutable fields. + flax_mutables: FrozenDict = EMPTY_DICT + # Contains axis metadata (e.g., names) matching flax_mutables tree. + flax_mutables_axes: Optional[FrozenVariableDict] = EMPTY_DICT + + @classmethod + def create(cls, optimizer_def: optimizers.OptimizerDefType, + model_variables: FrozenVariableDict) -> 'FlaxOptimTrainState': + other_variables, params = model_variables.pop('params') + if 'params_axes' in other_variables: + other_variables, params_axes = other_variables.pop('params_axes') + _validate_params_axes(params_axes, params) + else: + params_axes = None + + # Split other_variables into mutables and their corresponding axes. + flax_mutables, flax_mutables_axes = _split_variables_and_axes( + other_variables) + + # If the optimizer supports `set_param_axes`, then assume that the model + # code is emitting these axes and use it. + if hasattr(optimizer_def, 'set_param_axes'): + if params_axes is None: + raise ValueError('The optimizer supports params_axes for model-based ' + 'partitioning, but the model is not emitting them.') + # `get_axis_names` removes "_axes" suffix in the leaf name and replaces + # `AxisMetadata` with `PartitionSpec`. + axis_names = flax_partitioning.get_axis_names(params_axes) + optimizer_def.set_param_axes(axis_names) + + optimizer = optimizer_def.create(params) + return FlaxOptimTrainState( + optimizer, + params_axes=params_axes, + flax_mutables=flax_mutables, + flax_mutables_axes=flax_mutables_axes) + + @property + def step(self) -> jnp.ndarray: + return self._optimizer.state.step + + @property + def params(self) -> FrozenVariableDict: + return self._optimizer.target + + @property + def param_states(self) -> FrozenVariableDict: + return self._optimizer.state.param_states + + def state_dict(self) -> MutableVariableDict: + state_dict = self._optimizer.state_dict() + if self.flax_mutables: + state_dict['flax_mutables'] = flax.core.unfreeze(self.flax_mutables) + return state_dict + + def apply_gradient(self, + grads, + learning_rate, + flax_mutables=EMPTY_DICT) -> 'FlaxOptimTrainState': + new_optimizer = self._optimizer.apply_gradient( + grads, learning_rate=learning_rate) + return self.replace(_optimizer=new_optimizer, flax_mutables=flax_mutables) + + def replace_params(self, params: VariableDict) -> 'FlaxOptimTrainState': + return self.replace(_optimizer=self._optimizer.replace(target=params)) + + def replace_step(self, step: jnp.ndarray) -> 'FlaxOptimTrainState': + state_dict = self.state_dict() + state_dict['state']['step'] = step + return self.restore_state(state_dict) + + def restore_state(self, state_dict: VariableDict) -> 'FlaxOptimTrainState': + new_optimizer = self._optimizer.restore_state(state_dict) + return self.replace( + _optimizer=new_optimizer, + flax_mutables=flax.core.freeze(state_dict['flax_mutables']) + if 'flax_mutables' in state_dict else EMPTY_DICT) + + def as_logical_axes(self) -> 'FlaxOptimTrainState': + if not hasattr(self._optimizer.optimizer_def, 'derive_logical_axes'): + raise ValueError( + f"Optimizer '{self._optimizer.optimizer_def.__class__.__name__}' " + 'requires a `derive_logical_axes` method to be used with named axis ' + 'partitioning.') + return FlaxOptimTrainState( + _optimizer=self._optimizer.optimizer_def.derive_logical_axes( + self._optimizer, + flax_partitioning.get_axis_names(self.params_axes)), + flax_mutables=flax_partitioning.get_axis_names(self.flax_mutables_axes)) + + +class InferenceState(flax.struct.PyTreeNode): + """State compatible with 
FlaxOptimTrainState without optimizer state.""" + + step: jnp.ndarray + params: flax_scope.FrozenVariableDict + params_axes: Optional[flax_scope.FrozenVariableDict] = None + flax_mutables: flax_scope.FrozenDict = EMPTY_DICT + flax_mutables_axes: Optional[flax_scope.FrozenVariableDict] = None + + @classmethod + def create(cls, model_variables: FrozenVariableDict) -> 'InferenceState': + other_variables, params = model_variables.pop('params') + if 'params_axes' in other_variables: + other_variables, params_axes = other_variables.pop('params_axes') + _validate_params_axes(params_axes, params) + else: + params_axes = None + + # Split other_variables into mutables and their corresponding axes. + flax_mutables, flax_mutables_axes = _split_variables_and_axes( + other_variables) + + return InferenceState( + step=jnp.array(0), + params=params, + params_axes=params_axes, + flax_mutables=flax_mutables, + flax_mutables_axes=flax_mutables_axes) + + @property + def param_states(self) -> FrozenVariableDict: + """The optimizer states of the parameters as a PyTree.""" + raise NotImplementedError('InferenceState has no optimizer states.') + + def apply_gradient(self, *args, **kwargs) -> 'InferenceState': + raise NotImplementedError( + 'InferenceState does not support `apply_gradient`.') + + def state_dict(self) -> MutableMapping[str, Any]: + state_dict = { + 'target': flax.core.unfreeze(self.params), + 'state': { + 'step': self.step + } + } + if self.flax_mutables: + state_dict['flax_mutables'] = flax.core.unfreeze(self.flax_mutables) + return state_dict + + def replace_step(self, step: jnp.ndarray) -> 'InferenceState': + return self.replace(step=step) + + def replace_params(self, params: FrozenVariableDict) -> 'InferenceState': + return self.replace(params=params) + + def restore_state(self, state_dict: Mapping[str, Any]) -> 'InferenceState': + return self.replace( + params=flax.core.freeze(state_dict['target']), + step=state_dict['state']['step'], + flax_mutables=flax.core.freeze(state_dict['flax_mutables']) + if 'flax_mutables' in state_dict else EMPTY_DICT) + + def as_logical_axes(self) -> 'InferenceState': + # Set step to None so that when the logical axes are processed by the + # flax.partitioning.logical_to_mesh_axes function, it will be skipped + # because jax.tree_map will short circut and never call the function on the + # step. + return InferenceState( + step=None, + params=flax_partitioning.get_axis_names(self.params_axes), + flax_mutables=flax_partitioning.get_axis_names(self.flax_mutables_axes)) diff --git a/t5x/train_state_test.py b/t5x/train_state_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c9fa6e1d92290976bad5046eaa343b835792c875 --- /dev/null +++ b/t5x/train_state_test.py @@ -0,0 +1,628 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for train_state.""" +from absl.testing import absltest +from flax import linen as nn +from flax import optim +import flax.core +from flax.linen import partitioning as flax_partitioning +import jax +import numpy as np +from t5x import adafactor +from t5x import optimizers +from t5x import partitioning +from t5x import train_state as train_state_lib + +mock = absltest.mock +AxisMetadata = flax_partitioning.AxisMetadata +FactorDim = adafactor.FactorDim + + +class FlaxOptimTrainStateTest(absltest.TestCase): + + def test_init(self): + model = nn.Dense(10) + inputs = np.ones([2, 3], dtype=np.float32) + params = model.init(jax.random.PRNGKey(0), inputs)['params'] + optimizer_def = optimizers.adam(0.1) + optimizer = optimizer_def.create(params) + flax_mutables = flax.core.freeze({'flax_mutable1': np.ones(10)}) + state = train_state_lib.FlaxOptimTrainState( + optimizer, flax_mutables=flax_mutables) + self.assertEqual(state.step, 0) + self.assertIsInstance(state._optimizer, optimizers.Optimizer) + self.assertEqual(state.state_dict()['flax_mutables'], + flax.core.unfreeze(flax_mutables)) + jax.tree_multimap(np.testing.assert_array_equal, params, state.params) + jax.tree_multimap(np.testing.assert_array_equal, + optimizer.state.param_states, state.param_states) + + def test_create(self): + model_variables = flax.core.freeze({ + 'params': { + 'dense': { + 'bias': np.zeros(4), + 'kernel': np.zeros((2, 4)) + } + }, + 'mutables': np.ones(3) + }) + optimizer_def = optimizers.sgd(0.42) + state = train_state_lib.FlaxOptimTrainState.create(optimizer_def, + model_variables) + self.assertEqual(state.step, 0) + self.assertIsInstance(state._optimizer, optimizers.Optimizer) + self.assertEqual(state._optimizer.optimizer_def, optimizer_def) + jax.tree_multimap(np.testing.assert_array_equal, state.flax_mutables, + flax.core.freeze({'mutables': np.ones(3)})) + jax.tree_multimap(np.testing.assert_array_equal, state.params, + model_variables['params']) + self.assertIsNone(state.params_axes) + + def test_create_with_params_axes(self): + model_variables = flax.core.freeze({ + 'params': { + 'dense': { + 'bias': np.zeros(4), + 'kernel': np.zeros((2, 4)) + } + }, + 'params_axes': { + 'dense': { + 'bias_axes': AxisMetadata(names=('embed',)), + 'kernel_axes': AxisMetadata(names=('vocab', 'embed')), + } + }, + }) + optimizer_def = adafactor.Adafactor( + 0.42, + logical_factor_rules={ + 'vocab': FactorDim.COLUMN, + 'embed': FactorDim.ROW + }) + state = train_state_lib.FlaxOptimTrainState.create(optimizer_def, + model_variables) + self.assertEqual(state.step, 0) + self.assertIsInstance(state._optimizer, optimizers.Optimizer) + self.assertEqual(state._optimizer.optimizer_def, optimizer_def) + self.assertDictEqual( + state._optimizer.optimizer_def.hyper_params.factor_map, { + 'dense/bias': (FactorDim.NONE,), + 'dense/kernel': (FactorDim.COLUMN, FactorDim.ROW) + }) + self.assertEqual(state.flax_mutables, flax.core.freeze({})) + jax.tree_multimap(np.testing.assert_array_equal, model_variables['params'], + state.params) + jax.tree_multimap(np.testing.assert_array_equal, + model_variables['params_axes'], state.params_axes) + + def test_create_with_flax_mutables_axes(self): + model_variables = flax.core.freeze({ + 'params': { + 'dense': { + 'bias': np.zeros(4), + 'kernel': np.zeros((2, 4)) + } + }, + 'params_axes': { + 'dense': { + 'bias_axes': AxisMetadata(names=('embed',)), + 'kernel_axes': AxisMetadata(names=('vocab', 'embed')), + } + }, + 'grads': { + 'dense': { + 'output_grad': np.zeros(4), + } + }, + 'grads_axes': { + 
'dense': { + 'output_grad': AxisMetadata(names=('embed',)), + } + }, + }) + optmizer_def = adafactor.Adafactor( + 0.42, + logical_factor_rules={ + 'vocab': FactorDim.COLUMN, + 'embed': FactorDim.ROW + }) + state = train_state_lib.FlaxOptimTrainState.create(optmizer_def, + model_variables) + self.assertEqual(state.step, 0) + self.assertIsInstance(state._optimizer, optimizers.Optimizer) + self.assertEqual(state._optimizer.optimizer_def, optmizer_def) + self.assertDictEqual( + state._optimizer.optimizer_def.hyper_params.factor_map, { + 'dense/bias': (FactorDim.NONE,), + 'dense/kernel': (FactorDim.COLUMN, FactorDim.ROW) + }) + self.assertEqual(state.flax_mutables, + flax.core.freeze({'grads': model_variables['grads']})) + jax.tree_multimap(np.testing.assert_array_equal, model_variables['params'], + state.params) + jax.tree_multimap(np.testing.assert_array_equal, + model_variables['params_axes'], state.params_axes) + jax.tree_multimap(np.testing.assert_array_equal, + model_variables['grads_axes'], + state.flax_mutables_axes['grads']) + + def test_create_missing_params_axes(self): + model_variables = flax.core.freeze({ + 'params': { + 'dense': { + 'bias': np.zeros(4), + 'kernel': np.zeros((2, 4)) + } + }, + 'mutables': np.ones(3) + }) + with self.assertRaisesWithLiteralMatch( + ValueError, + 'The optimizer supports params_axes for model-based partitioning, but ' + 'the model is not emitting them.'): + train_state_lib.FlaxOptimTrainState.create(adafactor.Adafactor(), + model_variables) + + def test_create_mismatched_params_axes(self): + model_variables = flax.core.freeze({ + 'params': { + 'dense': { + 'bias': np.zeros(4), + 'kernel': np.zeros((2, 4)) + } + }, + 'params_axes': { + 'dense': { + 'bias_axes': AxisMetadata(names=('embed',)), + } + }, + 'mutables': np.ones(3) + }) + with self.assertRaisesWithLiteralMatch( + ValueError, "Missing axis names for parameters: {'dense/kernel'}"): + train_state_lib.FlaxOptimTrainState.create(adafactor.Adafactor(), + model_variables) + + def test_replace_params(self): + optimizer_def = optimizers.sgd(0.1) + optimizer = optimizer_def.create({'test': np.ones(10)}) + state = train_state_lib.FlaxOptimTrainState(optimizer) + + new_params = {'test': np.zeros(10)} + new_state = state.replace_params(new_params) + jax.tree_multimap(np.testing.assert_array_equal, new_params, + new_state.params) + expected_state_dict = state.state_dict() + expected_state_dict['target'] = new_params + jax.tree_multimap(np.testing.assert_array_equal, expected_state_dict, + new_state.state_dict()) + + def test_replace_step(self): + optimizer_def = optimizers.adam(0.1) + optimizer = optimizer_def.create({'test': np.ones(10)}) + state = train_state_lib.FlaxOptimTrainState(optimizer) + + self.assertEqual(state.step, 0) + self.assertEqual(state.replace_step(jax.numpy.array(1)).step, 1) + + def test_apply_gradient(self): + updated_optimizer = object() + optimizer = mock.Mock( + apply_gradient=mock.Mock(return_value=updated_optimizer)) + state = train_state_lib.FlaxOptimTrainState(optimizer) + + new_flax_mutables = {'test': 44} + new_state = state.apply_gradient( + grads=42, learning_rate=43, flax_mutables={'test': 44}) + + optimizer.apply_gradient.assert_called_once_with(42, learning_rate=43) + + self.assertEqual(new_state._optimizer, updated_optimizer) + self.assertEqual( + new_state, + train_state_lib.FlaxOptimTrainState( + updated_optimizer, flax_mutables=new_flax_mutables)) + + def test_as_logical_axes(self): + model_variables = flax.core.freeze({ + 'params': { + 'dense': { + 'bias': 
np.zeros(4), + 'kernel': np.zeros((2, 4)) + } + }, + 'params_axes': { + 'dense': { + 'bias_axes': AxisMetadata(names=('embed',)), + 'kernel_axes': AxisMetadata(names=('vocab', 'embed')), + } + }, + }) + optimizer_def = adafactor.Adafactor( + 0.42, + logical_factor_rules={ + 'vocab': FactorDim.COLUMN, + 'embed': FactorDim.ROW + }) + state = train_state_lib.FlaxOptimTrainState.create(optimizer_def, + model_variables) + axes_state = state.as_logical_axes() + self.assertIsNone(axes_state.params_axes) + jax.tree_multimap( + np.testing.assert_array_equal, axes_state.params, + flax.core.freeze({ + 'dense': { + 'bias': partitioning.PartitionSpec('embed'), + 'kernel': partitioning.PartitionSpec('vocab', 'embed'), + } + })) + + def test_as_logical_axes_with_flax_mutables(self): + model_variables = flax.core.freeze({ + 'params': { + 'dense': { + 'bias': np.zeros(4), + 'kernel': np.zeros((2, 4)) + } + }, + 'params_axes': { + 'dense': { + 'bias_axes': AxisMetadata(names=('embed',)), + 'kernel_axes': AxisMetadata(names=('vocab', 'embed')), + } + }, + 'grads': { + 'dense': { + 'output_grad': np.zeros(4), + } + }, + 'grads_axes': { + 'dense': { + 'output_grad': AxisMetadata(names=('embed',)), + } + }, + }) + optmizer_def = adafactor.Adafactor( + 0.42, + logical_factor_rules={ + 'vocab': FactorDim.COLUMN, + 'embed': FactorDim.ROW + }) + state = train_state_lib.FlaxOptimTrainState.create(optmizer_def, + model_variables) + axes_state = state.as_logical_axes() + self.assertIsNone(axes_state.params_axes) + jax.tree_multimap( + np.testing.assert_array_equal, axes_state.flax_mutables, + flax.core.freeze({ + 'grads': { + 'dense': { + 'output_grad': partitioning.PartitionSpec('embed'), + } + } + })) + + def test_as_logical_axes_unsupported_optimizer(self): + model_variables = flax.core.freeze({ + 'params': { + 'dense': { + 'bias': np.zeros(4), + 'kernel': np.zeros((2, 4)) + } + }, + 'params_axes': { + 'dense': { + 'bias_axes': AxisMetadata(names=('embed',)), + 'kernel_axes': AxisMetadata(names=('vocab', 'embed')), + } + }, + }) + optimizer_def = optim.GradientDescent(0.42) + state = train_state_lib.FlaxOptimTrainState.create(optimizer_def, + model_variables) + with self.assertRaisesWithLiteralMatch( + ValueError, + "Optimizer 'GradientDescent' requires a `derive_logical_axes` method " + 'to be used with named axis partitioning.'): + state.as_logical_axes() + + def test_to_state_dict(self): + model_variables = flax.core.freeze({ + 'params': { + 'kernel': np.zeros((2, 4)) + }, + 'params_axes': { + 'kernel_axes': AxisMetadata(names=('vocab', 'embed')), + }, + 'mutables': np.ones(3) + }) + optimizer_def = adafactor.Adafactor( + 0.42, + logical_factor_rules={ + 'vocab': FactorDim.COLUMN, + 'embed': FactorDim.ROW + }) + state = train_state_lib.FlaxOptimTrainState.create(optimizer_def, + model_variables) + jax.tree_multimap( + np.testing.assert_array_equal, state.state_dict(), { + 'state': { + 'step': np.array(0), + 'param_states': { + 'kernel': { + 'm': np.zeros(1), + 'v': np.zeros((2, 4)), + 'v_col': np.zeros(1), + 'v_row': np.zeros(1) + }, + } + }, + 'target': { + 'kernel': np.zeros((2, 4)) + }, + 'flax_mutables': { + 'mutables': np.ones(3) + } + }) + + def test_restore_state(self): + model_variables = flax.core.freeze({ + 'params': { + 'kernel': np.zeros((2, 4)) + }, + 'params_axes': { + 'kernel_axes': AxisMetadata(names=('vocab', 'embed')), + }, + 'mutables': np.ones(3) + }) + optimizer_def = adafactor.Adafactor( + 0.42, + logical_factor_rules={ + 'vocab': FactorDim.COLUMN, + 'embed': FactorDim.ROW + }) + state = 
train_state_lib.FlaxOptimTrainState.create(optimizer_def, + model_variables) + restored = state.restore_state({ + 'state': { + 'step': np.array(1), + 'param_states': { + 'kernel': { + 'm': np.ones(1), + 'v': np.ones((2, 4)), + 'v_col': np.ones(1), + 'v_row': np.ones(1) + }, + } + }, + 'target': { + 'kernel': np.ones((2, 4)) + }, + 'flax_mutables': { + 'mutables': np.zeros(3) + } + }) + + self.assertEqual(restored.step, 1) + self.assertIsInstance(restored._optimizer, optimizers.Optimizer) + self.assertEqual(restored._optimizer.optimizer_def, optimizer_def) + jax.tree_multimap(np.testing.assert_array_equal, restored.flax_mutables, + flax.core.freeze({'mutables': np.zeros(3)})) + jax.tree_multimap(np.testing.assert_array_equal, restored.params, + flax.core.freeze({'kernel': np.ones((2, 4))})) + jax.tree_multimap( + np.testing.assert_array_equal, restored.param_states, + flax.core.freeze({ + 'kernel': + adafactor._AdafactorParamState( + np.ones(1), np.ones(1), np.ones((2, 4)), np.ones(1)) + })) + jax.tree_multimap(np.testing.assert_array_equal, restored.params_axes, + model_variables['params_axes']) + + +class InferenceStateTest(absltest.TestCase): + + def test_init(self): + model = nn.Dense(10) + inputs = np.ones([2, 3], dtype=np.float32) + params = model.init(jax.random.PRNGKey(0), inputs)['params'] + flax_mutables = flax.core.freeze({'flax_mutable1': np.ones(10)}) + state = train_state_lib.InferenceState( + step=jax.numpy.array(1), params=params, flax_mutables=flax_mutables) + self.assertEqual(state.step, 1) + self.assertEqual(state.flax_mutables, flax.core.unfreeze(flax_mutables)) + jax.tree_multimap(np.testing.assert_array_equal, params, state.params) + self.assertIsNone(state.params_axes) + + def test_create(self): + model_variables = flax.core.freeze({ + 'params': { + 'dense': { + 'bias': np.zeros(4), + 'kernel': np.zeros((2, 4)) + } + }, + 'params_axes': { + 'dense': { + 'bias_axes': AxisMetadata(names=('embed',)), + 'kernel_axes': AxisMetadata(names=('vocab', 'embed')), + } + }, + 'mutables': np.ones(3) + }) + state = train_state_lib.InferenceState.create(model_variables) + self.assertEqual(state.step, 0) + jax.tree_multimap(np.testing.assert_array_equal, state.flax_mutables, + flax.core.freeze({'mutables': np.ones(3)})) + jax.tree_multimap(np.testing.assert_array_equal, state.params, + model_variables['params']) + jax.tree_multimap(np.testing.assert_array_equal, state.params_axes, + model_variables['params_axes']) + + def test_create_mismatched_params_axes(self): + model_variables = flax.core.freeze({ + 'params': { + 'dense': { + 'bias': np.zeros(4), + 'kernel': np.zeros((2, 4)) + } + }, + 'params_axes': { + 'dense': { + 'bias_axes': AxisMetadata(names=('embed',)), + } + }, + 'mutables': np.ones(3) + }) + with self.assertRaisesWithLiteralMatch( + ValueError, "Missing axis names for parameters: {'dense/kernel'}"): + train_state_lib.InferenceState.create(model_variables) + + def test_replace_params(self): + model_variables = flax.core.freeze({'params': {'test': np.ones(10)}}) + state = train_state_lib.InferenceState.create(model_variables) + + new_params = {'test': np.zeros(10)} + new_state = state.replace_params(new_params) + jax.tree_multimap(np.testing.assert_array_equal, new_params, + new_state.params) + + def test_replace_step(self): + model_variables = flax.core.freeze({'params': {'test': np.ones(10)}}) + state = train_state_lib.InferenceState.create(model_variables) + + self.assertEqual(state.step, 0) + self.assertEqual(state.replace_step(jax.numpy.array(1)).step, 1) + + def 
test_as_logical_axes(self): + model_variables = flax.core.freeze({ + 'params': { + 'dense': { + 'bias': np.zeros(4), + 'kernel': np.zeros((2, 4)) + } + }, + 'params_axes': { + 'dense': { + 'bias_axes': AxisMetadata(names=('embed',)), + 'kernel_axes': AxisMetadata(names=('vocab', 'embed')), + } + }, + }) + state = train_state_lib.InferenceState.create(model_variables) + axes_state = state.as_logical_axes() + self.assertIsNone(axes_state.params_axes) + jax.tree_multimap( + np.testing.assert_array_equal, axes_state.params, + flax.core.freeze({ + 'dense': { + 'bias': partitioning.PartitionSpec('embed'), + 'kernel': partitioning.PartitionSpec('vocab', 'embed'), + } + })) + + def test_to_state_dict(self): + model_variables = flax.core.freeze({ + 'params': { + 'bias': np.zeros(4), + }, + 'params_axes': { + 'bias_axes': AxisMetadata(names=('embed',)), + }, + 'mutables': np.ones(3) + }) + state = train_state_lib.InferenceState.create(model_variables) + jax.tree_multimap( + np.testing.assert_array_equal, state.state_dict(), { + 'state': { + 'step': np.array(0) + }, + 'target': { + 'bias': np.zeros(4), + }, + 'flax_mutables': { + 'mutables': np.ones(3) + } + }) + + def test_to_state_dict_no_mutables(self): + model_variables = flax.core.freeze({ + 'params': { + 'bias': np.zeros(4), + }, + 'params_axes': { + 'bias_axes': AxisMetadata(names=('embed',)), + }, + }) + state = train_state_lib.InferenceState.create(model_variables) + jax.tree_multimap(np.testing.assert_array_equal, state.state_dict(), { + 'state': { + 'step': np.array(0) + }, + 'target': { + 'bias': np.zeros(4), + }, + }) + + def test_restore_state(self): + state = train_state_lib.InferenceState( + np.array(0), {'bias': np.zeros(4)}, + {'bias_axes': AxisMetadata(names=('embed',))}) + + state_dict = { + 'state': { + 'step': np.array(10) + }, + 'target': { + 'bias': np.ones(4), + }, + 'flax_mutables': { + 'mutables': np.ones(3) + } + } + restored = state.restore_state(state_dict) + + self.assertEqual(restored.step, 10) + jax.tree_multimap(np.testing.assert_array_equal, restored.flax_mutables, + flax.core.freeze(state_dict['flax_mutables'])) + jax.tree_multimap(np.testing.assert_array_equal, restored.params, + flax.core.freeze(state_dict['target'])) + self.assertEqual(restored.params_axes, + {'bias_axes': AxisMetadata(names=('embed',))}) + + def test_restore_state_no_mutables_no_axes(self): + state = train_state_lib.InferenceState(np.array(0), {}) + + state_dict = { + 'state': { + 'step': np.array(10) + }, + 'target': { + 'bias': np.zeros(4), + }, + } + restored = state.restore_state(state_dict) + + self.assertEqual(restored.step, 10) + self.assertEqual(restored.flax_mutables, train_state_lib.EMPTY_DICT) + jax.tree_multimap(np.testing.assert_array_equal, restored.params, + flax.core.freeze(state_dict['target'])) + self.assertIsNone(restored.params_axes) + + +if __name__ == '__main__': + absltest.main() diff --git a/t5x/trainer.py b/t5x/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..6537a6127625dac826c126004246b64235aa46e7 --- /dev/null +++ b/t5x/trainer.py @@ -0,0 +1,1055 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Trainer and MetricsManager classes for use in train loop. + +To create a custom trainer, subclass `BaseTrainer` and implement +`_partitioned_train_step` and `_partitioned_eval_step` methods, +possibly by re-using the utility functions provided in this module. +""" +import abc +import enum +import os +import threading +import time +from typing import Any, Dict, Iterator, Mapping, MutableMapping, Optional, Sequence, TYPE_CHECKING, Tuple, Union + +from absl import logging +import cached_property +from clu import asynclib +from clu import metric_writers +import clu.data +import clu.metrics +import clu.values +from flax.core import FrozenDict +from jax.experimental import multihost_utils +import jax.lax +import jax.numpy as jnp +import jax.random +import numpy as np +from t5x import metrics as metrics_lib +from t5x import models +from t5x import partitioning +from t5x import train_state as train_state_lib +from t5x import utils +import typing_extensions + + +Array = Union[np.ndarray, jnp.ndarray] +BatchSpec = Mapping[str, jax.ShapeDtypeStruct] +BatchType = Mapping[str, np.ndarray] +FlaxMutables = FrozenDict +Rng = jnp.ndarray +MetricMapType = MutableMapping[str, clu.metrics.Metric] +MetricMapSpec = Mapping[str, jax.ShapeDtypeStruct] +MetricValueMapType = Mapping[str, clu.values.Value] +ModelWeights = Any +MutableMetricMapType = Dict[str, clu.metrics.Metric] +PyTreeDef = type(jax.tree_structure(None)) +PartitionSpec = partitioning.PartitionSpec + +if TYPE_CHECKING: # See b/163639353 + cached_property = property # pylint: disable=invalid-name +else: + cached_property = cached_property.cached_property + + +@jax.jit +def _merge_metrics(a, b): + return jax.tree_multimap( + lambda a, b: a.merge(b), a, b, is_leaf=metrics_lib.is_metric_obj) + + +# Merges two metrics pytrees (mapping of metric_name (str) to clu.Metric object) +def merge_metrics(a, b): + a, b = jax.tree_map(utils.get_local_data, (a, b)) + return _merge_metrics(a, b) + + +class ArrayMapFuture(typing_extensions.Protocol): + + def result(self) -> Mapping[str, Array]: + ... + + +class MetricValueMapFuture(typing_extensions.Protocol): + + def result(self) -> Mapping[str, clu.values.Value]: + ... + + +class TimeFuture(typing_extensions.Protocol): + + def result(self) -> float: + ... + + +class LearningRateCallable(typing_extensions.Protocol): + + def __call__( + self, + step: jnp.ndarray, + ) -> jnp.ndarray: + ... + + +class SummarizeMetricsCallable(typing_extensions.Protocol): + """PyType template for a metrics summary function.""" + + def __call__(self, metrics: MetricMapType, duration: float, + num_steps: int) -> Mapping[str, jnp.ndarray]: + """Summarizes metrics accumulated across multiple steps. + + Args: + metrics: Metrics accumulated across multiple steps. + duration: The duration of the run being summarized. + num_steps: The number of steps the metrics are accumulated across. + + Returns: + Summarized metrics. + """ + ... 
+
+
+class PartitionedTrainCallable(typing_extensions.Protocol):
+  """Protocol for a partitioned train step."""
+
+  def __call__(
+      self, train_state: train_state_lib.TrainState,
+      batch: BatchType) -> Tuple[train_state_lib.TrainState, MetricMapType]:
+    ...
+
+
+class PartitionedEvalCallable(typing_extensions.Protocol):
+  """Protocol for a partitioned eval step."""
+
+  def __call__(self, train_state: train_state_lib.TrainState,
+               batch: jnp.ndarray) -> MetricMapType:
+    ...
+
+
+class WeightMetricsComputer(object):
+  """Decides which weight metrics to compute during training."""
+
+  _WEIGHT_METRICS = [
+      "weight_rms", "weight_gradient_rms", "weight_update_rms", "weight_max"
+  ]
+
+  @staticmethod
+  def _make_rms_metrics(name, tree):
+    """Calculates the root-mean-square metric for a pytree."""
+    return {
+        f"{name}/{k}": metrics_lib.AveragePerStep.from_model_output(
+            jnp.sqrt(jnp.mean(jnp.square(v))))
+        for k, v in utils.flatten_dict_string_keys(tree).items()
+    }
+
+  @staticmethod
+  def _make_max_metrics(name, tree):
+    """Calculates the L-inf norm for a pytree."""
+    return {
+        f"{name}/{k}":
+            metrics_lib.AveragePerStep.from_model_output(jnp.max(jnp.abs(v)))
+        for k, v in utils.flatten_dict_string_keys(tree).items()
+    }
+
+  def compute_metrics(
+      self, gradients: ModelWeights,
+      old_train_state: train_state_lib.TrainState,
+      new_train_state: train_state_lib.TrainState) -> MutableMetricMapType:
+    """Computes metrics about the weights after they have been updated.
+
+    Args:
+      gradients: The gradients of all weights.
+      old_train_state: The training state before applying the gradients.
+      new_train_state: The training state after applying the gradients.
+
+    Returns:
+      A dictionary of Metrics, where the keys are either metric names or of
+        the form metric_name/parameter_name, depending on whether they are
+        global to the model or specific to each model parameter.
+    """
+    # TODO(reinerp): Extend weight stats logging with support for non-reduced
+    # axes of tensors. For example, for stacked layers (QKV stacking or layer
+    # stacking), we might not want to reduce over the stacking dimension, in
+    # order to provide more localization in the logged stats.
+    metrics = {}
+    metrics.update(self._make_rms_metrics("weight_rms", new_train_state.params))
+    metrics.update(self._make_rms_metrics("weight_gradient_rms", gradients))
+    grad_norm = jnp.sqrt(
+        jnp.sum(
+            jnp.array([jnp.vdot(x, x) for x in jax.tree_leaves(gradients)])))
+    metrics.update({
+        "weight_gradient_norm":
+            metrics_lib.AveragePerStep.from_model_output(grad_norm)
+    })
+    metrics.update(
+        self._make_rms_metrics(
+            "weight_update_rms",
+            jax.tree_multimap(jnp.subtract, new_train_state.params,
+                              old_train_state.params)))
+    metrics.update(self._make_max_metrics("weight_max", new_train_state.params))
+
+    return metrics
+
+
+class _AsyncTimer(object):
+  """A timer that computes durations between async jax operations.
+
+  You should call close() to wait for threads started by this class to finish.
+  """
+
+  def __init__(self):
+    # We use a thread pool with a single worker to ensure that calls to the
+    # function are run in order (but in a background thread).
+ self._pool = asynclib.Pool(thread_name_prefix="AsyncTimer", max_workers=1) + self._start_future = None + + def close(self): + self._pool.close() + + def __del__(self): + self.close() + + def _get_completion_future(self, block_on: PyTreeDef = ()) -> TimeFuture: + """Returns Future containing time when `block_on` is ready.""" + + def _get_completion_time(): + try: + jax.block_until_ready(block_on) + except RuntimeError as e: + # If the buffer no longer exists, we assume it was completed. + if (str(e) != + "INVALID_ARGUMENT: BlockHostUntilReady() called on deleted or " + "donated buffer"): + raise + return time.time() + + return self._pool(_get_completion_time)() + + def start(self, block_on: PyTreeDef = ()): + """Starts timer after `block_on` is ready.""" + self._start_future = self._get_completion_future(block_on) + + def stop(self, block_on: PyTreeDef = ()) -> TimeFuture: + """Stops timer after `block_on` is ready, returning the duration.""" + if not self._start_future: + raise ValueError("The timer hasn't been started.") + + start_future = self._start_future + self._start_future = None + stop_future = self._get_completion_future(block_on) + return self._pool(lambda: stop_future.result() - start_future.result())() + + +class MetricsManager(object): + """Manages a set of distributed metrics and their logging. + + Logging is disabled on all but host 0. + + Logs to: + * TensorBoard + * ABSL + + You should call close() to wait for threads started by this class to finish. + """ + + def __init__(self, name: str, summary_dir: Optional[str] = None): + """MetricsManager constructor. + + Constructs an empty MetricWriter on all but host 0. + + Args: + name: an identifier of the metrics to use when logging (e.g., 'train'). + summary_dir: the summary directory. If provided, TensorBoard summaries + will be written to a `name` subdirectory. + """ + self._name = name + if jax.process_index() == 0: + self._writer = metric_writers.create_default_writer( + summary_dir, + collection=name, + asynchronous=True) + else: + self._writer = metric_writers.MultiWriter([]) + self.summary_dir = os.path.join(summary_dir, name) if summary_dir else None + self._writer_lock = threading.Lock() + # We use a thread pool with a single worker to ensure that calls to the + # function are run in order (but in a background thread). + self._summary_pool = asynclib.Pool( + thread_name_prefix="MetricsManager", max_workers=1) + # Times the duration between steps. + self._duration_timer = _AsyncTimer() + + def __del__(self): + self.close() + + def close(self): + try: + self._summary_pool.close() + finally: + try: + self._duration_timer.close() + finally: + if self._writer: + self._writer.close() + self._writer = None + + @property + def summary_writer(self) -> metric_writers.MetricWriter: + """Returns the MetricWriter used by this class.""" + # TODO(adarob): Make returned writer threadsafe. 
+    return self._writer
+
+  def write_scalar(self, key: str, val: metric_writers.interface.Scalar,
+                   step: int):
+    """Writes scalar value to metric writers in a threadsafe manner."""
+    step = int(utils.get_local_data(step))
+    self.write_scalars(step, {key: val})
+
+  def write_scalars(self, step: int,
+                    scalars: Mapping[str, metric_writers.interface.Scalar]):
+    """Writes scalar values to metric writers in a threadsafe manner."""
+    step = utils.get_local_data(step)
+    with self._writer_lock:
+      self._writer.write_scalars(step, scalars)
+
+  def start_duration_timer(self, block_on: PyTreeDef = ()):
+    """Starts the duration timer."""
+    self._duration_timer.start(block_on=block_on)
+
+  def write_metrics_summary(self, metrics: MetricMapType, step: int,
+                            num_steps: int) -> MetricValueMapFuture:
+    """Writes summary based on accumulated metrics in a background thread.
+
+    Duration is automatically computed as the interval between completion of
+    metrics fetching. This closely approximates the duration of `num_steps`,
+    as the steps must be computed sequentially, and it is more accurate than
+    computing the time since the call to the step function, since the actual
+    execution occurs asynchronously on the TPU/GPU device.
+
+    Args:
+      metrics: accumulated metric values.
+      step: the current train step.
+      num_steps: the number of steps the metrics are accumulated across.
+
+    Returns:
+      A mapping of name -> scalar value of the written summary. Only host 0
+      returns the real scalar values; other hosts return None.
+    """
+    step = utils.get_local_data(step)
+
+    # Must be called in the main thread to avoid race condition.
+    duration_future = self._duration_timer.stop(block_on=metrics)
+
+    def _summarize_and_write():
+      # For thread safety we first copy the metrics to host.
+      fetched_metrics = jax.tree_map(jax.device_get, metrics)
+
+      duration = duration_future.result()
+      # We set the duration on time-related metrics.
+      final_metrics = metrics_lib.set_time_metrics_duration(
+          fetched_metrics, duration)
+      # Set num_steps for Step metrics (AveragePerStep, StepsPerTime, ...)
+      final_metrics = metrics_lib.set_step_metrics_num_steps(
+          final_metrics, num_steps)
+
+      # Ensure the metrics are not on device, which could lead to a deadlock.
+      def _ensure_not_on_device(x):
+        assert not isinstance(x, jax.numpy.DeviceArray)
+
+      jax.tree_map(_ensure_not_on_device, final_metrics)
+      final_metrics = jax.tree_map(utils.get_local_data, final_metrics)
+
+      summary = {k: v.compute_value() for k, v in final_metrics.items()}
+      with self._writer_lock:
+        metric_writers.write_values(self._writer, int(step), summary)
+
+      return summary
+
+    return self._summary_pool(_summarize_and_write)()
+
+  def flush(self):
+    try:
+      self._summary_pool.join()
+    finally:
+      self._writer.flush()
+
+
+class PreemptionError(Exception):
+  """Training has been interrupted and needs an emergency checkpoint."""
+
+
+class BaseTrainer(abc.ABC):
+  """Abstract base trainer class.
+
+  Internally this uses MetricsManagers that start threads. You should
+  use the trainer as a context manager, or call close() directly in
+  order to wait for these threads to finish after training is done.
+  """
+
+  def __init__(self, model: models.BaseModel,
+               train_state: train_state_lib.TrainState,
+               partitioner: partitioning.BasePartitioner,
+               eval_names: Sequence[str], summary_dir: Optional[str],
+               train_state_axes: Any, rng: Rng):
+    """Trainer constructor.
+
+    Args:
+      model: the instantiation of `BaseModel` to train.
+ train_state: A train state with model parameters and optimizer state. + partitioner: the partitioner to use. + eval_names: names of evaluation datasets, which must match the keys of the + mapping passed to `eval`. + summary_dir: optional directory to write TensorBoard metrics to. + train_state_axes: partitioning info for the train state to be used. + rng: jax PRNGKey seed for random operations, to be combined with step + number for a deterministic RNG. + """ + self._model = model + self._train_state_axes = train_state_axes + self._base_rng = rng + self._partitioner = partitioner + self._compiled_train_step: Optional[PartitionedTrainCallable] = None + self._compiled_eval_steps: MutableMapping[str, PartitionedEvalCallable] = {} + self._compiled_eval_step_cache: MutableMapping[ + BatchSpec, PartitionedEvalCallable] = {} + + self._train_state_mutex = threading.RLock() + self._train_state = train_state + + self.stop_training = False + + # The training metrics combine metrics added by the Model (e.g., loss and + # accuracy) and Trainer (e.g., learning rate). + self.train_metrics_manager = MetricsManager( + "train", summary_dir=summary_dir) + + # The eval metrics only include metrics added by the Model. + self.eval_metrics_managers = { # pylint:disable=g-complex-comprehension + n: MetricsManager(f"training_eval/{n}", summary_dir=summary_dir) + for n in eval_names + } + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def close(self): + """Stops all train metric managers threads.""" + self.train_metrics_manager.close() + for mm in self.eval_metrics_managers.values(): + mm.close() + + def _get_step_rng(self, step: int) -> Rng: + return jax.random.fold_in(self._base_rng, step) + + @property + def train_state(self): + with self._train_state_mutex: + return self._train_state + + @train_state.setter + def train_state(self, train_state: PyTreeDef): + with self._train_state_mutex: + self._train_state = train_state + + def train(self, + batch_iter: Union[Iterator[BatchType], clu.data.DatasetIterator], + num_steps: int, + start_step: Optional[int] = None) -> ArrayMapFuture: + """Runs the train loop for the given number of steps.""" + metrics = None + # Use pre-compiled step, if available. + train_step_fn = self._compiled_train_step or self._partitioned_train_step + + # We lock `train_state` access during the loop to avoid race conditions. + with self._train_state_mutex: + train_state = self.train_state + # Compute step number on host to avoid communication overhead. + start_step = int( + start_step if start_step is not None else train_state.step) + self.train_metrics_manager.start_duration_timer(block_on=train_state) + for step_num in range(start_step, start_step + num_steps): + logging.log_every_n_seconds(logging.INFO, "Training: step %d", 10, + step_num) + with jax.profiler.StepTraceAnnotation("train", step_num=step_num): + batch = next(batch_iter) + train_state, metrics_update = train_step_fn(train_state, batch) + if metrics: + metrics = merge_metrics(metrics, metrics_update) + else: + metrics = metrics_update + + self.train_state = train_state + + return self.train_metrics_manager.write_metrics_summary( + metrics, start_step + num_steps, num_steps) + + def compile_train(self, batch: BatchType) -> None: + """Pre-compiles train step (if not yet compiled). + + Not required. + + If not called before `train`, compilation will occur automatically on the + first step and JAX's "jit cache" will be used to avoid recompilation for + future steps. 
+ + Args: + batch: A sample batch that may contain dummy values, but with correct + shapes and dtypes. + """ + tick = time.time() + self._compiled_train_step = self._partitioner.compile( + self._partitioned_train_step, self.train_state, batch) + tock = time.time() + self.train_metrics_manager.write_scalar("timing/compilation_seconds", + tock - tick, self.train_state.step) + + def eval( + self, batch_iters: Mapping[str, + Iterator[BatchType]]) -> Mapping[str, Array]: + """Runs evaluation loop over the iterator and writes summary.""" + eval_summaries = {} + train_state = self.train_state + for iter_name, batch_iter in batch_iters.items(): + logging.info("Evaluating: %s.", iter_name) + metrics = None + # Use a pre-compiled step function, if available. + eval_step_fn = self._compiled_eval_steps.get(iter_name, + self._partitioned_eval_step) + mm = self.eval_metrics_managers[iter_name] + + num_steps = 0 + mm.start_duration_timer(block_on=train_state) + for batch in batch_iter: + num_steps += 1 + multihost_utils.assert_equal( + jnp.array(num_steps), + "Eval step mismatch across hosts. Check for empty dataset shard.") + metrics_update = eval_step_fn(train_state, batch) + if metrics: + metrics = merge_metrics(metrics, metrics_update) + else: + metrics = metrics_update + multihost_utils.assert_equal( + jnp.array(-1), + "Eval step mismatch across hosts. Check for empty dataset shard.") + + eval_summaries[iter_name] = mm.write_metrics_summary( + metrics, train_state.step, num_steps) + + # TODO(adarob): Return futures. + return {k: v.result() for k, v in eval_summaries.items()} + + def compile_eval(self, batches: Mapping[str, BatchType]) -> None: + """Pre-compiles eval step (if not yet compiled). + + Not required. + + Pre-compiles the evaluation step for each evaluation dataset, reusing cached + compilations where possible. In other words, if multiple evaluation datasets + have equivalent shapes/dtypes for the batch and initial metrics, + recompilation will be avoided. + + If not called before `eval`, compilation will occur automatically on the + first step and JAX's "jit cache" will be used to avoid recompilation for + future steps. + + Args: + batches: a mapping from evaluation dataset name to a sample batch. The + batch may contain dummy values, but the shapes and dtypes must be + correct. 
+    """
+    for eval_name, batch in batches.items():
+      tick = time.time()
+      cache_key: BatchSpec = FrozenDict(jax.eval_shape(lambda: batch))  # pylint:disable=cell-var-from-loop
+      if cache_key not in self._compiled_eval_step_cache:
+        self._compiled_eval_step_cache[cache_key] = self._partitioner.compile(
+            self._partitioned_eval_step, self.train_state, batch)
+      self._compiled_eval_steps[eval_name] = self._compiled_eval_step_cache[
+          cache_key]
+      tock = time.time()
+      self.eval_metrics_managers[eval_name].write_scalar(
+          "timing/compilation_seconds", tock - tick, self.train_state.step)
+
+  @property
+  @abc.abstractmethod
+  def _partitioned_train_step(self) -> PartitionedTrainCallable:
+    """Partitioned train step."""
+    raise NotImplementedError
+
+  @property
+  @abc.abstractmethod
+  def _partitioned_eval_step(self) -> PartitionedEvalCallable:
+    """Partitioned eval step."""
+    raise NotImplementedError
+
+
+def accumulate_grads_microbatched(
+    model: models.BaseModel,
+    train_state: train_state_lib.TrainState,
+    batch: BatchType,
+    dropout_rng: Rng,
+    num_microbatches: Optional[int],
+    data_partition_spec: PartitionSpec = PartitionSpec("data"),
+) -> Tuple[ModelWeights, MutableMetricMapType,
+           Optional[FlaxMutables]]:
+  """Implements optional microbatched gradient accumulation.
+
+  Args:
+    model: the instantiation of `BaseModel` to train.
+    train_state: A train state with model parameters and optimizer state.
+    batch: input batch consisting of either:
+      - simply-padded batched features: 'encoder_input_tokens',
+        'decoder_input_tokens', 'decoder_target_tokens', and
+        'decoder_loss_weights'; or
+      - packed, batched features with the additional
+        "(encoder|decoder)_segment_id" and "(encoder|decoder)_position"
+        features.
+    dropout_rng: jax PRNGKey for dropout.
+    num_microbatches: the number of microbatches to use, or None for direct
+      training.
+    data_partition_spec: the PartitionSpec to use for partitioning annotations
+      on the batch.
+
+  Returns:
+    Accumulated gradients, incremental metrics, and the updated flax mutables
+    (or None if the train state has no flax mutables).
+  """
+  batch_size = next(iter(batch.values())).shape[0]
+
+  grad_fn = jax.value_and_grad(model.loss_fn, has_aux=True)
+
+  # We assume that the model loss_fn supports flax mutables if and only if
+  # the train state includes non-empty flax mutables.
+  # Note: Default t5x models don't support flax_mutables. One needs to subclass
+  # them and return flax_mutables from `get_initial_variables` and `loss_fn`.
+ + initial_flax_mutables = train_state.flax_mutables if train_state.flax_mutables else None + + if num_microbatches is None or num_microbatches <= 1: + + if initial_flax_mutables is None: + (_, metrics), grad_accum = grad_fn(train_state.params, batch, dropout_rng) + flax_mutables = None + else: + (_, metrics, flax_mutables), grad_accum = grad_fn(train_state.params, + batch, dropout_rng, + initial_flax_mutables) + else: + assert batch_size % num_microbatches == 0, ( + "Batch size isn't divided evenly by num_microbatches.") + microbatch_size = batch_size // num_microbatches + logging.info("using microbatches: %d microbatches, %d size", + num_microbatches, microbatch_size) + + def get_microbatch(batch: BatchType, idx: int) -> Mapping[str, jnp.ndarray]: + """Fetch microbatch slice from possibly-packed input data.""" + offset = idx * microbatch_size + length = microbatch_size + starts = {k: [offset] + [0] * (b.ndim - 1) for k, b in batch.items()} + limits = {k: [length] + list(b.shape[1:]) for k, b in batch.items()} + return { + k: jax.lax.dynamic_slice(b, starts[k], limits[k]) + for k, b in batch.items() + } + + def metrics_and_grad(loop_cnt, dropout_rng, flax_mutables=None): + dropout_rng, sub_dropout_rng = jax.random.split(dropout_rng) + mbatch = get_microbatch(batch, loop_cnt) + # We need to annotate the microbatch sharding as we would a batch. + mbatch = jax.tree_map( + lambda x: partitioning.with_sharding_constraint( # pylint: disable=g-long-lambda + x, data_partition_spec), + mbatch) + if flax_mutables is None: + (_, metrics), grad = grad_fn(train_state.params, mbatch, + sub_dropout_rng) + else: + (_, metrics, flax_mutables), grad = grad_fn(train_state.params, mbatch, + sub_dropout_rng, + flax_mutables) + return metrics, grad, flax_mutables + + def per_microbatch_train_step( + loop_cnt: int, state: Tuple[jnp.ndarray, jnp.ndarray, + Mapping[str, jnp.ndarray], + Optional[FlaxMutables]] + ) -> Tuple[jnp.ndarray, jnp.ndarray, Mapping[str, jnp.ndarray], + Optional[FlaxMutables]]: + (dropout_rng, grad_accum, prev_metrics, flax_mutables) = state + metrics, grad, flax_mutables = metrics_and_grad(loop_cnt, dropout_rng, + flax_mutables) + + grad_accum = jax.tree_multimap(jnp.add, grad_accum, grad) + metrics = jax.lax.cond(loop_cnt == 0, lambda _: metrics, + lambda _: merge_metrics(prev_metrics, metrics), + None) + return dropout_rng, grad_accum, metrics, flax_mutables + + # Initialize gradient accumulation loop state. + accum_dtype = jnp.float32 + grad_accum_init = jax.tree_map(lambda x: jnp.zeros(x.shape, accum_dtype), + train_state.params) + initial_metrics_shape, _, _ = jax.eval_shape( + metrics_and_grad, loop_cnt=0, dropout_rng=dropout_rng) + + initial_metrics = { + k: metrics_lib.shape_obj_to_defined_obj(v) + for k, v in initial_metrics_shape.items() + } + loop_init = (dropout_rng, grad_accum_init, initial_metrics, + initial_flax_mutables) + new_dropout_rng, grad_accum, metrics, flax_mutables = jax.lax.fori_loop( + 0, num_microbatches, per_microbatch_train_step, loop_init) + + del new_dropout_rng + + return grad_accum, metrics, flax_mutables + + +def apply_grads( + train_state: train_state_lib.TrainState, + grad_accum: ModelWeights, + metrics: MutableMetricMapType, + learning_rate: jnp.ndarray, + weight_metrics_computer: Optional[WeightMetricsComputer], + other_state_variables: Optional[Mapping[str, Any]] = None +) -> Tuple[train_state_lib.TrainState, MetricMapType]: + """Applies gradients to the optimizer. + + Args: + train_state: A train state that contains model and optimizer params. 
+ grad_accum: results of `accumulate_grads` call. + metrics: incremental metrics from `accumulate_grads` call. + learning_rate: the learning rate to use for this step. + weight_metrics_computer: A WeightMetricsComputer instance, or None, to + decide what metrics, if any, to log about weights and weight updates + during training. + other_state_variables: other variables to update the state with. + + Returns: + The updated train state, metrics. + """ + if other_state_variables is None: + other_state_variables = {} + # Update optimizer using accumulated gradient. + new_train_state = train_state.apply_gradient( + grad_accum, learning_rate=learning_rate, **other_state_variables) + metrics["learning_rate"] = clu.metrics.Average.from_model_output( + jnp.asarray([learning_rate])) + metrics["learning_rate/current"] = clu.metrics.LastValue.from_model_output( + jnp.asarray([learning_rate])) + if weight_metrics_computer is not None: + metrics.update( + weight_metrics_computer.compute_metrics(grad_accum, train_state, + new_train_state)) + return new_train_state, metrics + + +def eval_step(model: models.BaseModel, train_state: train_state_lib.TrainState, + batch: jnp.ndarray) -> MetricMapType: + """Default evaluation step.""" + _, metrics = model.eval_fn(train_state.params, batch) + return metrics + + +def train_with_lr( + train_state: train_state_lib.TrainState, + batch: BatchType, + learning_rate: jnp.ndarray, + dropout_rng: Rng, + model: models.BaseModel, + num_microbatches: Optional[int], + weight_metrics_computer: Optional[WeightMetricsComputer] = None, + data_partition_spec: PartitionSpec = PartitionSpec("data")): + """Main training function with LR schedule.""" + grad_accum, metrics, flax_mutables = ( + accumulate_grads_microbatched(model, train_state, batch, dropout_rng, + num_microbatches, data_partition_spec)) + new_train_state, metrics = apply_grads( + train_state, + grad_accum, + metrics, + learning_rate, + weight_metrics_computer, + other_state_variables={"flax_mutables": flax_mutables} + if flax_mutables else None) + + return new_train_state, metrics + + +class Trainer(BaseTrainer): + """Training loop with optional microbatches.""" + + def __init__(self, + model: models.BaseModel, + train_state: train_state_lib.TrainState, + partitioner: partitioning.BasePartitioner, + eval_names: Sequence[str], + summary_dir: Optional[str], + train_state_axes: Any, + rng: Rng, + learning_rate_fn: LearningRateCallable, + num_microbatches: Optional[int], + weight_metrics_computer: Optional[WeightMetricsComputer] = None): + """Trainer constructor. + + Args: + model: the instantiation of `BaseModel` to train. + train_state: a train state with parameters and optimizer state. + partitioner: the partitioner to use. + eval_names: names of evaluation datasets, which must match the keys of the + mapping passed to `eval`. + summary_dir: optional directory to write TensorBoard metrics to. + train_state_axes: partitioning info for the optimizer to be used. + rng: jax PRNGKey seed for random operations, to be combined with step + number for a deterministic RNG. + learning_rate_fn: returns the learning rate given the current step. + num_microbatches: the number of microbatches to use, or None for direct + training. + weight_metrics_computer: A WeightMetricsComputer instance, or None, to + decide what metrics, if any, to log about weights and weight updates + during training. 
+ """ + self._learning_rate_fn = learning_rate_fn + self._num_microbatches = num_microbatches + self._weight_metrics_computer = weight_metrics_computer + + super().__init__( + model=model, + train_state=train_state, + partitioner=partitioner, + eval_names=eval_names, + summary_dir=summary_dir, + train_state_axes=train_state_axes, + rng=rng) + + @cached_property + def _partitioned_train_step(self) -> PartitionedTrainCallable: + + def train_step(train_state: train_state_lib.TrainState, batch: BatchType): + return train_with_lr( + train_state, + batch, + learning_rate=self._learning_rate_fn(train_state.step), + dropout_rng=self._get_step_rng(train_state.step), + model=self._model, + num_microbatches=self._num_microbatches, + weight_metrics_computer=self._weight_metrics_computer, + data_partition_spec=self._partitioner.data_partition_spec) + + return self._partitioner.partition( + train_step, + in_axis_resources=(self._train_state_axes, + self._partitioner.data_partition_spec), + out_axis_resources=(self._train_state_axes, None), + donate_argnums=(0,)) + + @cached_property + def _partitioned_eval_step(self) -> PartitionedEvalCallable: + return self._partitioner.partition( + lambda *args, **kwargs: eval_step(self._model, *args, **kwargs), + in_axis_resources=(self._train_state_axes, + self._partitioner.data_partition_spec), + out_axis_resources=None) + + +def _warn_action_not_run(action, task, metric): + logging.warning( + "The action: %s that tracks metric: %s for task: %s is not run", action, + metric, task) + + +# TODO(b/200701930): Support dynamic registration for enum. +@enum.unique +class ActionMode(enum.Enum): + """Defines when to run a action. + + For example, TRAIN means to run an action after a TRAIN loop is done. + """ + TRAIN = 1 + TRAIN_EVAL = 2 + INFER_EVAL = 3 + + +class BaseAction(abc.ABC): + """Base Action class for override. The action itself does nothing.""" + + @abc.abstractmethod + def run(self, train_state: train_state_lib.TrainState, + metrics_by_task: Mapping[str, MetricValueMapType]) -> bool: + """Runs an action for the given train_state and metrics. + + Args: + train_state: The current train_state in the training loop. + metrics_by_task: A map of metrics that is grouped by each task. + + Returns: + A bool indicating whether training should be halted. + """ + raise NotImplementedError("Action must define its run method.") + + +ActionMapType = Mapping[ActionMode, Sequence[BaseAction]] + + +class EarlyStoppingAction(BaseAction): + """Terminates training when the specified metric is not improving. + + Checks whether the monitored metrics are decreasing after every `train` or + `eval`, or `both`. If the loss is no longer decreasing for `patience` times, + terminate the training process. + """ + + def __init__(self, + metric: Tuple[str, str], + mode: str, + patience: int = 3, + atol: float = 0., + rtol: float = 0.): + """Constructs the EarlyStoppingAction. + + Args: + metric: A metric to monitor when invoking the action. When the metric does + not improve for a number of times (specified in patience), stop the + training. The tuple takes 2 strings, whereas the first string defines + the task to track, and the second defines the metric of the task to + track. e.g.,: ('mt5_xnli_dev_test.all_langs', 'accuracy') would monitor + the 'accuracy' for `mt5_xnli_dev_test.all_langs`. + mode: One of `{"min", "max"}`. 
In `min` mode, training will stop when the + quantity monitored has stopped decreasing; in `"max"` mode it will stop + when the quantity monitored has stopped increasing; + patience: The threshold of stopping criteria. Usually this is measured by + number of steps. + atol: Absolute tolerance in the monitored quantity to qualify as an + improvement, i.e. a change of less than `atol`, will count as no + improvement. + rtol: Relative tolerance in the monitoried quantity to qualify as an + improvement. This combined with `atol` defines whether a change is + considered improvement. The total change is calculated as following: + `delta = (atol + rtol * previous)` See `numpy.allclose` for detailed + information. + """ + self._task, self._metric = metric + if mode not in ["min", "max"]: + raise ValueError('mode must be in ["min", "max"]') + self._mode = mode + + if atol < 0: + raise ValueError("atol must be greater equal than 0") + self._atol = atol + + if rtol < 0: + raise ValueError("rtol must be greater equal than 0") + self._rtol = rtol + + self._patience = patience + self._metric_history = [] + + def _compare_fn(self, current, previous): + compare_fn = jnp.greater_equal if self._mode == "min" else jnp.less_equal + delta = self._atol + self._rtol * abs(previous) + if self._mode == "max": + delta *= -1 + return compare_fn(current, previous - delta) + + def run(self, train_state: train_state_lib.TrainState, + metrics_by_task: Mapping[str, MetricValueMapType]) -> bool: + if self._task not in metrics_by_task.keys(): + logging.warning( + "Monitoring task: %s does not exist in all task metrics. " + "Available tasks are : %s", self._task, metrics_by_task.keys()) + _warn_action_not_run(type(self), self._task, self._metric) + return False + if self._metric not in metrics_by_task[self._task].keys(): + logging.warning("Metric : %s does not exist in metrics for task : %s", + self._metric, self._task) + _warn_action_not_run(type(self), self._task, self._metric) + return False + + m = metrics_by_task[self._task][self._metric] + + if not isinstance(m, clu.values.Scalar): + logging.warning("Metric %s does not have Scalar type. Found %s.", + self._metric, type(m)) + _warn_action_not_run(type(self), self._task, self._metric) + return False + + self._metric_history.append(m.value) + + # Not enough history. + if len(self._metric_history) < self._patience: + return False + + if all( + self._compare_fn(self._metric_history[i + 1], self._metric_history[i]) + for i in range(self._patience - 1)): + logging.warning( + "Requested `stop_training` in training loop (Details below).\n " + "Metric: %s for Task: %s has not improved for %s iterations, detail " + "history of the metric: %s", self._metric, self._task, self._patience, + self._metric_history) + return True + # Remove extra histories that we don't need to keep. + self._metric_history.pop(0) + return False + + +class TerminateOnNanAction(BaseAction): + """Terminates training when NaN loss is detected. + + Checks whether the loss metric for the given task is NaN or Inf and terminates + training if so. + """ + + def __init__(self, task: str, metric: str = "loss"): + """Constructs the TerminateOnNanAction. + + Args: + task: Defines the task from which to track the given metric. + metric: Defines a metric to track for NaN values (defaults to "loss"). 
+ """ + self._task = task + self._metric = metric + + def run(self, train_state: train_state_lib.TrainState, + metrics_by_task: Mapping[str, MetricValueMapType]) -> bool: + if self._task not in metrics_by_task.keys(): + logging.warning( + "Monitoring task: %s does not exist in all task metrics. " + "Available tasks are : %s", self._task, metrics_by_task.keys()) + _warn_action_not_run(type(self), self._task, self._metric) + return False + if self._metric not in metrics_by_task[self._task].keys(): + logging.warning("Metric : %s does not exist in metrics for task : %s", + self._metric, self._task) + _warn_action_not_run(type(self), self._task, self._metric) + return False + + metric = metrics_by_task[self._task][self._metric] + + if not isinstance(metric, clu.values.Scalar): + logging.warning("Metric %s does not have Scalar type. Found %s.", + self._metric, type(metric)) + _warn_action_not_run(type(self), self._task, self._metric) + return False + + value = metric.value + if np.isnan(value) or np.isinf(value): + logging.warning( + "Requested `stop_training` in training loop (Details below).\n " + "NaN encountered in metric for task : %s", self._task) + return True + + return False diff --git a/t5x/trainer_test.py b/t5x/trainer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..26912e4425b5fac23b2d6341b05890cf564e2c11 --- /dev/null +++ b/t5x/trainer_test.py @@ -0,0 +1,983 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for t5x.trainer_lib.""" +import collections +import contextlib +import os + +from absl.testing import absltest +from absl.testing import parameterized +import chex +from clu import metric_writers +import clu.metrics +import clu.values +import flax +import jax +import jax.numpy as jnp +import numpy as np +from t5x import metrics as metrics_lib +from t5x import models as models_lib +from t5x import optimizers +from t5x import partitioning +from t5x import test_utils +from t5x import train_state as train_state_lib +from t5x import trainer as trainer_lib +import tensorflow as tf +from tensorflow.io import gfile + +mock = absltest.mock +jax.config.parse_flags_with_absl() + + +# Make `log_elapsed_time` a no-op to simplify mocking of `time.time()`. 
+@contextlib.contextmanager +def fake_log_elapsed_time(_): + yield + + +jax._src.dispatch.log_elapsed_time = fake_log_elapsed_time + + +def _validate_events(test_case, summary_dir, expected_metrics, steps): + summaries = gfile.listdir(summary_dir) + test_case.assertLen(summaries, 1) + summary_path = os.path.join(summary_dir, summaries[0]) + event_file = os.path.join(summary_path) + events = list(tf.compat.v1.train.summary_iterator(event_file)) + actual_events = {} + # First event is boilerplate + test_case.assertLen(events, len(steps) + 1) + for step, event in zip(steps, events[1:]): + test_case.assertEqual(event.step, step) + test_case.assertLen(event.summary.value, 1) + tensor = event.summary.value[0].tensor + if tensor.string_val: + actual_events[event.summary.value[0].tag] = tensor.string_val[0].decode() + else: + actual_events[event.summary.value[0].tag] = float(tf.make_ndarray(tensor)) + + jax.tree_multimap(test_case.assertAlmostEqual, actual_events, + expected_metrics) + + +class MetricsManagerTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.model_dir = self.create_tempdir().full_path + + def test_summary_dir(self): + # All hosts have the summary dir. + with mock.patch('jax.process_index', return_value=0): + mm = trainer_lib.MetricsManager('eval', self.model_dir) + self.assertEqual(mm.summary_dir, os.path.join(self.model_dir, 'eval')) + mm.close() + + with mock.patch('jax.process_index', return_value=1): + mm = trainer_lib.MetricsManager('eval', self.model_dir) + self.assertEqual(mm.summary_dir, os.path.join(self.model_dir, 'eval')) + mm.close() + + def test_summary_writer(self): + # Only host 0 creates a non-empty summary writer. + with mock.patch('jax.process_index', return_value=1): + mm = trainer_lib.MetricsManager('eval', self.model_dir) + self.assertFalse(gfile.exists(mm.summary_dir)) + mm.close() + + with mock.patch('jax.process_index', return_value=0): + mm = trainer_lib.MetricsManager('eval', self.model_dir) + self.assertIsInstance(mm.summary_writer, metric_writers.MetricWriter) + self.assertTrue(gfile.exists(mm.summary_dir)) + mm.close() + + def test_write_scalar(self): + gfile.makedirs(os.path.join(self.model_dir, 'eval')) + + # tag, value, step + scalars = [('loss', 1.0, 1), ('accuracy', 100.0, 2)] + + # Only host 0 has actually writes summaries. 
+ with mock.patch('jax.process_index', return_value=1): + mm = trainer_lib.MetricsManager('eval', self.model_dir) + for s in scalars: + mm.write_scalar(*s) + self.assertEmpty(gfile.listdir(mm.summary_dir)) + mm.close() + + with mock.patch('jax.process_index', return_value=0): + mm = trainer_lib.MetricsManager('eval', self.model_dir) + for s in scalars: + mm.write_scalar(*s) + mm.flush() + + summaries = gfile.listdir(mm.summary_dir) + self.assertLen(summaries, 1) + + event_file = os.path.join(mm.summary_dir, summaries[0]) + events = list(tf.compat.v1.train.summary_iterator(event_file)) + # First event is boilerplate + self.assertLen(events, 3) + for event, (tag, value, step) in zip(events[1:], scalars): + self.assertEqual(event.step, step) + self.assertLen(event.summary.value, 1) + self.assertEqual(event.summary.value[0].tag, tag) + self.assertEqual(tf.make_ndarray(event.summary.value[0].tensor), value) + mm.close() + + def test_write_metrics_summary(self): + gfile.makedirs(os.path.join(self.model_dir, 'eval')) + + @flax.struct.dataclass + class MockTextMetric(clu.metrics.Metric): + + def compute_value(self): + return clu.values.Text('test metric') + + accumulated_metrics = { + 'loss': metrics_lib.Sum(40.0), + 'accuracy': metrics_lib.AveragePerStep.from_model_output(20.0), + 'steps_per_second': metrics_lib.StepsPerTime(), + 'text': MockTextMetric() + } + expected_values = { + 'loss': clu.values.Scalar(40.0), + 'accuracy': clu.values.Scalar(10.0), + 'steps_per_second': clu.values.Scalar(0.05), + 'text': clu.values.Text('test metric') + } + with mock.patch( + 'jax.process_index', return_value=0), mock.patch( + 'time.time', + side_effect=[0, 40] # start_time, end_time + ), mock.patch('absl.logging.log'): # avoids hidden calls to time.time() + mm = trainer_lib.MetricsManager('eval', summary_dir=self.model_dir) + mm.start_duration_timer() + summary = mm.write_metrics_summary( + accumulated_metrics, step=4, num_steps=2) + mm.flush() + + self.assertDictEqual(summary.result(), expected_values) + _validate_events( + self, + mm.summary_dir, {k: v.value for k, v in expected_values.items()}, + steps=[4, 4, 4, 4]) + + mm.close() + + def test_timer_blocking_on_donated_buffer(self): + mm = trainer_lib.MetricsManager('train', summary_dir=None) + x = jnp.zeros(1) + + # Not deleted. + mm.start_duration_timer(block_on=x) + mm._duration_timer._start_future.result() + + # Deleted/donated. + x.device_buffer.delete() + mm.start_duration_timer(block_on=x) + mm._duration_timer._start_future.result() + + def test_timer_concurrency(self): + mm = trainer_lib.MetricsManager('train') + + n = 10 + with mock.patch( + 'time.time', + side_effect=range(2 * n) # start_time, end_time + ), mock.patch('absl.logging.log'): # avoids hidden calls to time.time() + for _ in range(n): + mm.start_duration_timer() + summary = mm.write_metrics_summary({'time': metrics_lib.Time()}, 0, 1) + self.assertEqual(1, summary.result()['time'].value) + mm.flush() + + +def fake_accum_grads(model, optimizer, batch, rng, num_microbatches, + data_partition_spec): + del model, num_microbatches, rng, data_partition_spec + # Add `i` to each optimzer value. + i = batch['i'].sum() + grad_accum = jax.tree_map(lambda x: i, optimizer) + # Add j to each metric. 
+ j = batch['j'].sum() + metrics = {'loss': metrics_lib.Sum(j), 'accuracy': metrics_lib.Sum(j)} + return grad_accum, metrics, None + + +def fake_apply_grads(optimizer, + grad_accum, + metrics, + learning_rate, + weight_metrics_computer, + other_state_variables=None): + del weight_metrics_computer + del other_state_variables + metrics['learning_rate'] = clu.metrics.Average(learning_rate, count=1) + optimizer = jax.tree_multimap(lambda x, g: x + g, optimizer, grad_accum) + return optimizer, metrics + + +def fake_eval_step(model, optimizer, batch): + del model, optimizer + # Add `i` to each metric. + i = batch['i'].sum() + + return {'loss': metrics_lib.Sum(i), 'accuracy': metrics_lib.Sum(i)} + + +def fake_eval_fn_without_weight_sum(params, batch): + del params + # Add `i` to each metric. + i = batch['i'].sum() + + loss = metrics_lib.Sum(i) + return loss, {'loss': loss, 'accuracy': metrics_lib.Sum(i)} + + +def fake_value_and_grad_fn_without_weight_sum(callable_fn, has_aux=False): + del callable_fn, has_aux + + def fake_grad_fn_without_weight_sum(train_state_params, + batch, + dropout_rng, + flax_mutables=None): + del dropout_rng, train_state_params, flax_mutables + # Add `i` to each optimzer value. + i = batch['i'].sum() + optimizer = optimizers.Optimizer( + optimizers.sgd(0.1), + state=optimizers.OptimizerState( + step=0, param_states={ + 'bias': 0, + 'kernel': 0 + }), + target={ + 'bias': np.zeros(4), + 'kernel': np.zeros((2, 4)) + }) + train_state = train_state_lib.FlaxOptimTrainState(optimizer) + grad_accum = jax.tree_map(lambda x: i, train_state) + # Add j to each metric. + j = batch['j'].sum() + metrics = {'loss': metrics_lib.Sum(j), 'accuracy': metrics_lib.Sum(j)} + return (None, metrics), grad_accum.params + + return fake_grad_fn_without_weight_sum + + +class TrainerTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.init_optimizer = optimizers.Optimizer( + optimizers.sgd(0.1), + state=optimizers.OptimizerState( + step=0, param_states={ + 'bias': 0, + 'kernel': 0 + }), + target={ + 'bias': np.zeros(4), + 'kernel': np.zeros((2, 4)) + }) + self.init_train_state = train_state_lib.FlaxOptimTrainState( + self.init_optimizer) + train_state_axes = jax.tree_map(lambda x: None, self.init_train_state) + model_dir = self.create_tempdir().full_path + + mapfn = lambda i: {'i': [tf.cast(i, tf.int32)], 'j': [tf.cast(1, tf.int32)]} + self.dataset = tf.data.Dataset.range(6).map(mapfn).batch( + 2, drop_remainder=True) + + self.test_trainer = trainer_lib.Trainer( + mock.create_autospec(models_lib.BaseModel, instance=True), + self.init_train_state, + partitioning.PjitPartitioner(num_partitions=1), + eval_names=['task1', 'task2'], + summary_dir=model_dir, + train_state_axes=train_state_axes, + rng=np.ones(2, np.uint32), + learning_rate_fn=lambda step: 2 * step, + num_microbatches=None) + + def tearDown(self) -> None: + self.test_trainer.close() + return super().tearDown() + + @mock.patch('t5x.trainer.accumulate_grads_microbatched', fake_accum_grads) + @mock.patch('t5x.trainer.apply_grads', fake_apply_grads) + def _test_train(self, precompile): + trainer = self.test_trainer + initial_rng = trainer._base_rng + + if precompile: + with mock.patch( + 'time.time', + side_effect=[0, 1] # compile start, end + ), mock.patch('absl.logging.log'): # avoids hidden calls to time.time() + trainer.compile_train(next(self.dataset.as_numpy_iterator())) + trainer._compiled_train_step = mock.Mock( + side_effect=trainer._compiled_train_step) + + trainer._partitioned_train_step = mock.Mock( + 
side_effect=trainer._partitioned_train_step) + + num_steps = 2 + with mock.patch( + 'time.time', + side_effect=[1, 5] # start_time, end_time + ), mock.patch('absl.logging.log'): # avoids hidden calls to time.time() + trainer.train(self.dataset.as_numpy_iterator(), num_steps).result() + + initial_metrics = { + 'loss': 0., + 'accuracy': 0., + } + expected_metrics = { + k: (v + 2 * num_steps) for k, v in initial_metrics.items() + } + # (0 + 2) / 2 = 1 + expected_metrics['learning_rate'] = 1 + # 0+1+2+3 = 6 + expected_train_state = jax.tree_map(lambda x: np.array(x + 6), + self.init_train_state) + + # Base rng must remain the same + np.testing.assert_array_equal(trainer._base_rng, initial_rng) + jax.tree_multimap(np.testing.assert_equal, trainer.train_state, + expected_train_state) + # Expected step is 6 since we increment it along with the other optimizer + # values. + steps = [2, 2, 2] + if precompile: + steps = [0] + steps + expected_metrics['timing/compilation_seconds'] = 1 + self.assertEqual(trainer._compiled_train_step.call_count, num_steps) + trainer._partitioned_train_step.assert_not_called() + else: + self.assertIsNone(trainer._compiled_train_step) + self.assertEqual(trainer._partitioned_train_step.call_count, num_steps) + trainer.train_metrics_manager.flush() + _validate_events( + self, + trainer.train_metrics_manager.summary_dir, + expected_metrics, + steps=steps) + + def test_train_noprecompile(self): + self._test_train(False) + + def test_train_precompile(self): + self._test_train(True) + + @mock.patch('t5x.trainer.eval_step', fake_eval_step) + def _test_eval(self, precompile): + trainer = self.test_trainer + initial_rng = trainer._base_rng + + task_datasets = { + 'task1': self.dataset.take(2), + 'task2': self.dataset.repeat().take(5) + } + + if precompile: + # [task1 start, task1 end, task2 start, task2 end] + with mock.patch( + 'time.time', + side_effect=[0, 1, 2, 3] # [t1 start, t1 end, t2 start, t2 end] + ), mock.patch('absl.logging.log'): # avoids hidden calls to time.time() + trainer.compile_eval({ + task: next(ds.as_numpy_iterator()) + for task, ds in task_datasets.items() + }) + trainer._compiled_eval_steps = { + task: mock.Mock(side_effect=trainer._compiled_eval_steps[task]) + for task in task_datasets + } + + trainer._partitioned_eval_step = mock.Mock( + side_effect=trainer._partitioned_eval_step) + + with mock.patch( + 'time.time', + side_effect=[1, 5, 5, 8] # t1 start, t1 end, t2 start, t2 end] + ), mock.patch('absl.logging.log'): # avoids hidden calls to time.time() + trainer.eval( + {task: ds.as_numpy_iterator() for task, ds in task_datasets.items()}) + + all_expected_metrics = { + # 0+1+2+3 = 6 + 'task1': { + 'loss': 6, + 'accuracy': 6, + }, + # 0+1+2+3+4+5+0+1+2+3 = 21 + 'task2': { + 'loss': 21, + 'accuracy': 21, + }, + } + + np.testing.assert_array_equal(trainer._base_rng, initial_rng) + for task_name, expected_metrics in all_expected_metrics.items(): + steps = [0, 0] + if precompile: + steps = [0] + steps + expected_metrics['timing/compilation_seconds'] = 1 + self.assertEqual( # pylint:disable=g-generic-assert + trainer._compiled_eval_steps[task_name].call_count, + len(task_datasets[task_name])) + trainer._partitioned_eval_step.assert_not_called() + else: + self.assertEmpty(trainer._compiled_eval_steps) + self.assertEqual(trainer._partitioned_eval_step.call_count, + sum(len(ds) for ds in task_datasets.values())) + mm = trainer.eval_metrics_managers[task_name] + mm.flush() + _validate_events(self, mm.summary_dir, expected_metrics, steps=steps) + + def 
test_eval_noprecompile(self): + self._test_eval(False) + + def test_eval_precompile(self): + self._test_eval(True) + + @parameterized.named_parameters([ + { + 'testcase_name': 'max_no_increase', + 'mode': 'max', + 'metrics': [1, 1, 1], + 'atol': 0., + 'rtol': 0., + 'stop_training': True, + }, + { + 'testcase_name': 'max_no_atol', + 'mode': 'max', + 'metrics': [1, 0.9, 0.8], + 'atol': 0., + 'rtol': 0., + 'stop_training': True, + }, + { + 'testcase_name': 'max_not_enough_atol', + 'mode': 'max', + 'metrics': [1, 1.09, 1.18], + 'atol': 0.1, + 'rtol': 0., + 'stop_training': True, + }, + { + 'testcase_name': 'max_enough_atol', + 'mode': 'max', + 'metrics': [1, 1.2, 1.4], + 'atol': 0.1, + 'rtol': 0., + 'stop_training': False, + }, + { + 'testcase_name': 'max_enough_atol_rtol', + 'mode': 'max', + # first delta = 0.1 + 1* 0.08 = 0.18 + # second delta = 0.1 + 1.2 * 0.08 = 0.196 + 'metrics': [1, 1.2, 1.4], + 'atol': 0.1, + 'rtol': 0.08, + 'stop_training': False, + }, + { + 'testcase_name': 'max_not_enough_rtol', + 'mode': 'max', + 'metrics': [1, 1.2, 1.4], + 'atol': 0., + 'rtol': 0.2, + 'stop_training': True, + }, + { + 'testcase_name': 'min_no_decrease', + 'mode': 'min', + 'metrics': [1, 1, 1], + 'atol': 0., + 'rtol': 0., + 'stop_training': True, + }, + { + 'testcase_name': 'min_no_atol', + 'mode': 'min', + 'metrics': [1, 1, 1], + 'atol': 0., + 'rtol': 0., + 'stop_training': True, + }, + { + 'testcase_name': 'min_not_enough_atol', + 'mode': 'min', + 'metrics': [1, 0.9, 0.71], + 'atol': 0.2, + 'rtol': 0., + 'stop_training': True, + }, + { + 'testcase_name': 'min_enough_atol', + 'mode': 'min', + 'metrics': [1, 0.8, 0.6], + 'atol': 0.15, + 'rtol': 0., + 'stop_training': False, + }, + { + 'testcase_name': 'min_enough_atol_rtol', + 'mode': 'min', + # first delta = 0.1 + 1* 0.09 = 0.19 + # second delta = 0.1 + 0.8 * 0.09 = 0.172 + 'metrics': [1, 0.8, 0.6], + 'atol': 0.1, + 'rtol': 0.09, + 'stop_training': False, + }, + { + 'testcase_name': 'min_not_enough_rtol', + 'mode': 'min', + 'metrics': [1, 0.8, 0.6], + 'atol': 0.0, + 'rtol': 0.3, + 'stop_training': True, + }, + { + 'testcase_name': 'longer_history', + 'mode': 'min', + 'metrics': [1, 0.8, 0.7, 0.6], + 'atol': 0.15, + 'rtol': 0., + 'stop_training': True, + } + ]) + def test_early_stopping_action(self, mode, metrics, atol, rtol, + stop_training): + trainer = self.test_trainer + metrics = [clu.values.Scalar(metric) for metric in metrics] + hook = trainer_lib.EarlyStoppingAction(('test_task', 'metric'), + mode=mode, + patience=3, + atol=atol, + rtol=rtol) + + for metric in metrics: + trainer_stop_training = hook.run(trainer.train_state, + {'test_task': { + 'metric': metric + }}) + + self.assertEqual(trainer_stop_training, stop_training) + + @parameterized.named_parameters([ + { + 'testcase_name': 'invalid_task', + 'task': 'wrong_task', + 'metric': 'metric', + 'value': clu.values.Scalar(np.nan), + }, + { + 'testcase_name': 'invalid_metric_name', + 'task': 'task', + 'metric': 'wrong_metric_name', + 'value': clu.values.Scalar(np.nan), + }, + { + 'testcase_name': 'invalid_value', + 'task': 'task', + 'metric': 'metric', + 'value': 1.0, + }, + ]) + def test_early_stopping_action_error(self, task, metric, value): + trainer = self.test_trainer + hook = trainer_lib.EarlyStoppingAction((task, metric), + mode='min', + patience=5, + atol=1, + rtol=1) + + trainer_stop_training = hook.run(trainer.train_state, + {task: { + metric: value + }}) + + self.assertFalse(trainer_stop_training) + + @parameterized.named_parameters([{ + 'testcase_name': 'valid_loss', + 'metric': 
'loss', + 'value': 1.0, + 'stop_training': False, + }, { + 'testcase_name': 'nan', + 'metric': 'loss', + 'value': np.nan, + 'stop_training': True, + }, { + 'testcase_name': 'inf', + 'metric': 'loss', + 'value': np.inf, + 'stop_training': True, + }, { + 'testcase_name': 'other_metric', + 'metric': 'some_metric', + 'value': np.inf, + 'stop_training': True, + }]) + def test_terminate_on_nan_action(self, metric, value, stop_training): + trainer = self.test_trainer + value = clu.values.Scalar(value) + hook = trainer_lib.TerminateOnNanAction(task='test_task', metric=metric) + + trainer_stop_training = hook.run(trainer.train_state, + {'test_task': { + metric: value + }}) + + self.assertEqual(trainer_stop_training, stop_training) + + @parameterized.named_parameters([ + { + 'testcase_name': 'invalid_task', + 'task': 'wrong_task', + 'metric': 'metric', + 'value': clu.values.Scalar(np.nan), + }, + { + 'testcase_name': 'invalid_metric_name', + 'task': 'task', + 'metric': 'wrong_metric_name', + 'value': clu.values.Scalar(np.nan), + }, + { + 'testcase_name': 'invalid_value', + 'task': 'task', + 'metric': 'metric', + 'value': 1.0, + }, + ]) + def test_terminate_on_nan_action_error(self, task, metric, value): + trainer = self.test_trainer + hook = trainer_lib.TerminateOnNanAction(task=task, metric=metric) + + trainer_stop_training = hook.run(trainer.train_state, + {'task': { + 'metric': value + }}) + + self.assertFalse(trainer_stop_training) + + def test_compile_train(self): + trainer = self.test_trainer + trainer._partitioned_train_step = mock.Mock() + trainer.train_metrics_manager = mock.Mock() + + batch = { + 'i': np.arange(10, dtype=np.int32).reshape((2, 5)), + 'j': np.ones((), dtype=np.float32) + } + # compile start, compile end + with mock.patch('time.time', side_effect=[1, 5]): + trainer.compile_train(batch) + + trainer.train_metrics_manager.write_scalar.assert_called_with( + 'timing/compilation_seconds', 4, trainer.train_state.step) + trainer._partitioned_train_step.lower.assert_called_once() + train_step_args = trainer._partitioned_train_step.lower.call_args[0] + self.assertLen(train_step_args, 2) + self.assertEqual(train_step_args[0], trainer.train_state) + test_utils.assert_same(train_step_args[1], batch) + + def test_compile_eval(self): + trainer = self.test_trainer + trainer._partitioned_eval_step = mock.Mock() + trainer.eval_metrics_managers = { + 'eval1': mock.Mock(), + 'eval2': mock.Mock(), + 'eval3': mock.Mock(), + 'eval4': mock.Mock() + } + trainer._partitioned_eval_step.lower().compile.side_effect = [ + 'compiled1', 'compiled2', 'compiled3' + ] + + batches = { + 'eval1': { + 'i': np.zeros((2, 5), dtype=np.int32) + }, + 'eval2': { + 'j': np.zeros((), dtype=np.float32) + }, + 'eval3': { + 'j': np.zeros((), dtype=np.float32) + }, + 'eval4': { + 'k': np.zeros((4), dtype=np.float32) + }, + } + + # eval1 start/end, eval2 start/end, eval3 start/end, eval 4 start/end + with mock.patch('time.time', side_effect=[1, 5, 6, 9, 10, 11, 12, 13]): + trainer.compile_eval(collections.OrderedDict(sorted(batches.items()))) + + trainer.eval_metrics_managers['eval1'].write_scalar.assert_called_with( + 'timing/compilation_seconds', 4, trainer.train_state.step) + trainer.eval_metrics_managers['eval2'].write_scalar.assert_called_with( + 'timing/compilation_seconds', 3, trainer.train_state.step) + trainer.eval_metrics_managers['eval3'].write_scalar.assert_called_with( + 'timing/compilation_seconds', 1, trainer.train_state.step) + trainer.eval_metrics_managers['eval4'].write_scalar.assert_called_with( + 
'timing/compilation_seconds', 1, trainer.train_state.step) + eval_step_args = trainer._partitioned_eval_step.lower.call_args_list[1:] + self.assertLen(eval_step_args, 3) + + eval1_call_args = eval_step_args[0][0] + self.assertLen(eval1_call_args, 2) + self.assertEqual(eval1_call_args[0], trainer.train_state) + test_utils.assert_same(eval1_call_args[1], { + 'i': np.zeros((2, 5), dtype=np.int32), + }) + + eval2_call_args = eval_step_args[1][0] + self.assertLen(eval2_call_args, 2) + self.assertEqual(eval2_call_args[0], trainer.train_state) + test_utils.assert_same(eval2_call_args[1], { + 'j': np.zeros((), dtype=np.float32), + }) + + eval3_call_args = eval_step_args[2][0] + self.assertLen(eval3_call_args, 2) + self.assertEqual(eval3_call_args[0], trainer.train_state) + test_utils.assert_same(eval3_call_args[1], { + 'k': np.zeros((4), dtype=np.float32), + }) + + self.assertDictEqual( + trainer._compiled_eval_steps, { + 'eval1': 'compiled1', + 'eval2': 'compiled2', + 'eval3': 'compiled2', + 'eval4': 'compiled3' + }) + + @mock.patch('jax.value_and_grad', fake_value_and_grad_fn_without_weight_sum) + def test_accumulate_grads_microbatched_without_weight_sum_single_batch(self): + batch_iter = self.dataset.as_numpy_iterator() + batch = next(batch_iter) + num_microbatches = 1 + grad_accum, metrics, flax_mutables = trainer_lib.accumulate_grads_microbatched( + self.test_trainer._model, self.init_train_state, batch, + self.test_trainer._base_rng, num_microbatches) + + i = batch['i'].sum() + expected_grad_accum = jax.tree_map(lambda x: i, + self.init_train_state).params + self.assertEqual(expected_grad_accum, grad_accum) + self.assertEqual(metrics['loss'].compute(), 2) + self.assertEqual(metrics['accuracy'].compute(), 2) + self.assertIsNone(flax_mutables) + + @mock.patch('jax.value_and_grad', fake_value_and_grad_fn_without_weight_sum) + def test_accumulate_grads_microbatched_without_weight_sum_multiple_batches( + self): + batch_iter = self.dataset.as_numpy_iterator() + batch = next(batch_iter) + num_micro_batches = 2 + grad_accum, metrics, flax_mutables = trainer_lib.accumulate_grads_microbatched( + self.test_trainer._model, self.init_train_state, batch, + self.test_trainer._base_rng, num_micro_batches) + + expected_grad_accum = {'bias': jnp.ones(4), 'kernel': jnp.ones((2, 4))} + chex.assert_trees_all_equal(expected_grad_accum, grad_accum) + self.assertEqual(metrics['loss'].compute(), 2) + self.assertEqual(metrics['accuracy'].compute(), 2) + self.assertIsNone(flax_mutables) + + def test_eval_step_without_weight_sum(self): + batch_iter = self.dataset.as_numpy_iterator() + batch = next(batch_iter) + self.test_trainer._model.eval_fn = fake_eval_fn_without_weight_sum + metrics = trainer_lib.eval_step(self.test_trainer._model, + self.init_train_state, batch) + + self.assertEqual(metrics['loss'].compute(), 1) + self.assertEqual(metrics['accuracy'].compute(), 1) + + +class TrainerRngDeterminismTest(parameterized.TestCase): + + def create_trainer(self, step, random_seed): + init_optimizer = optimizers.Optimizer( + optimizers.sgd(0.1), + state=optimizers.OptimizerState( + step=step, param_states={ + 'bias': 0, + 'kernel': 0 + }), + target={ + 'bias': np.zeros(4), + 'kernel': np.zeros((2, 4)) + }) + init_train_state = train_state_lib.FlaxOptimTrainState(init_optimizer) + train_state_axes = jax.tree_map(lambda x: None, init_train_state) + + test_trainer = trainer_lib.Trainer( + mock.create_autospec(models_lib.BaseModel, instance=True), + init_train_state, + partitioning.PjitPartitioner(num_partitions=1), + 
eval_names=['task1', 'task2'], + summary_dir=None, + train_state_axes=train_state_axes, + rng=jax.random.PRNGKey(random_seed), + learning_rate_fn=lambda step: 2 * step, + num_microbatches=None) + return test_trainer + + @mock.patch('t5x.trainer.accumulate_grads_microbatched') + @mock.patch('t5x.trainer.apply_grads', fake_apply_grads) + def test_rng_determinism(self, mock_accum_grads): + + def fake_accum_grads_rng(model, optimizer, batch, rng, num_microbatches, + data_partition_spec): + del model, batch, num_microbatches, data_partition_spec + # Add 1, which will increment the step as a side effect. + grad_accum = jax.tree_map(lambda x: 1, optimizer) + m = {'rng': metrics_lib.Sum(jnp.sum(rng))} + return grad_accum, m, None + + mock_accum_grads.side_effect = fake_accum_grads_rng + # Create a trainer at a given step (53) with a given random seed (23), + # train up to a given train step (100), check the sum of the rngs from the + # metrics. + start_step = 47 + end_step = 100 + random_seed = 23 + trainer = self.create_trainer(step=start_step, random_seed=random_seed) + # 500 batches of size 2 + ds = [np.zeros(2)] * 500 + + metrics = trainer.train(iter(ds), num_steps=end_step - start_step) + base_rng = jax.random.PRNGKey(random_seed) + expected_rng_sum = np.sum( + [jax.random.fold_in(base_rng, i) for i in range(start_step, end_step)], + dtype=np.uint32) + np.testing.assert_array_equal(metrics.result()['rng'].value, + expected_rng_sum) + + +def fake_mut_accum_grads(model, optimizer, batch, rng, num_microbatches, + data_partition_spec): + del model, num_microbatches, rng, data_partition_spec + # Add `i` to each optimzer value. + i = batch['i'].sum() + grad_accum = jax.tree_map(lambda x: i, optimizer) + # Add j to each metric. + j = batch['j'].sum() + metrics = { + 'loss': metrics_lib.Sum.from_model_output(j), + 'accuracy': metrics_lib.Sum.from_model_output(j) + } + return grad_accum, metrics, {'mutables': 0} + + +def fake_mut_apply_grads(optimizer, grad_accum, metrics, learning_rate, + weight_metrics_computer, other_state_variables): + del weight_metrics_computer, other_state_variables + metrics['learning_rate'] = clu.metrics.Average.from_model_output( + learning_rate) + optimizer = jax.tree_multimap(lambda x, g: x + g, optimizer, grad_accum) + return optimizer, metrics + + +class MutableTrainerTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.init_optimizer = optimizers.Optimizer( + optimizers.sgd(0.1), + state=optimizers.OptimizerState( + step=0, param_states={ + 'bias': 0, + 'kernel': 0 + }), + target={ + 'bias': np.zeros(4), + 'kernel': np.zeros((2, 4)) + }) + self.init_train_state = train_state_lib.FlaxOptimTrainState( + self.init_optimizer) + train_state_axes = jax.tree_map(lambda x: None, self.init_train_state) + model_dir = self.create_tempdir().full_path + + mapfn = lambda i: {'i': [tf.cast(i, tf.int32)], 'j': [tf.cast(1, tf.int32)]} + self.dataset = tf.data.Dataset.range(6).map(mapfn).batch( + 2, drop_remainder=True) + self.dataset1 = tf.data.Dataset.range(6).map(mapfn).batch( + 2, drop_remainder=True) + + self.test_trainer = trainer_lib.Trainer( + mock.create_autospec(models_lib.BaseModel, instance=True), + self.init_train_state, + partitioning.PjitPartitioner(num_partitions=1), + eval_names=['task1', 'task2'], + summary_dir=model_dir, + train_state_axes=train_state_axes, + rng=np.ones(2, np.uint32), + learning_rate_fn=lambda step: 2 * (step + 1), + num_microbatches=None) + + @mock.patch('time.time') + @mock.patch('t5x.trainer.accumulate_grads_microbatched', 
fake_mut_accum_grads) + @mock.patch('t5x.trainer.apply_grads', fake_mut_apply_grads) + # avoids calls time.time() during logging + @mock.patch('absl.logging.info', lambda *_: None) + @mock.patch('absl.logging.log_every_n_seconds', lambda *_: None) + def test_train(self, mock_time=None): + trainer = self.test_trainer + initial_rng = trainer._base_rng + + trainer._partitioned_train_step = mock.Mock( + side_effect=trainer._partitioned_train_step) + + # train start, logging, train end, logging + mock_time.side_effect = [1, 5, 5, 5] + num_steps = 1 + ds_iter = self.dataset.as_numpy_iterator() + batch = next(ds_iter) + train_state, _ = trainer._partitioned_train_step(trainer.train_state, batch) + + expected_train_state = jax.tree_map(lambda x: np.array(x + 1), + self.init_train_state) + # Base rng must remain the same + np.testing.assert_array_equal(trainer._base_rng, initial_rng) + jax.tree_multimap(np.testing.assert_equal, train_state, + expected_train_state) + + self.assertIsNone(trainer._compiled_train_step) + self.assertEqual(trainer._partitioned_train_step.call_count, num_steps) + + def tearDown(self) -> None: + # Manually close managers to avoid phantom threads crossing test cases. + self.test_trainer.train_metrics_manager.close() + for mm in self.test_trainer.eval_metrics_managers.values(): + mm.close() + return super().tearDown() + + +if __name__ == '__main__': + absltest.main() diff --git a/t5x/utils.py b/t5x/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9b69c7e6d1e69d1d898247a827472e2ea588839b --- /dev/null +++ b/t5x/utils.py @@ -0,0 +1,1380 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""General utility functions for t5x.""" +import collections.abc +from concurrent.futures import thread +import contextlib +import dataclasses +import functools +import importlib +import inspect +import os +import re +import time +import typing +from typing import Any, Callable, Iterable, Mapping, Optional, Sequence, Tuple, Type, Union +import warnings + +from absl import logging +import clu.data +from flax import traverse_util +import flax.core +from flax.core import scope as flax_scope +from flax.linen import partitioning as flax_partitioning +import jax +from jax import prng +from jax import pxla +from jax.experimental import multihost_utils +from jax.experimental.global_device_array import GlobalDeviceArray +import jax.numpy as jnp +import numpy as np +import orbax.checkpoint +import seqio +from t5x import checkpoints +from t5x import optimizers +from t5x import partitioning +from t5x import state_utils +from t5x import train_state as train_state_lib +import tensorflow as tf +from tensorflow.io import gfile +import typing_extensions + + +Array = Union[np.ndarray, jnp.ndarray, jax.pxla.ShardedDeviceArray, tf.Tensor] +PyTreeDef = type(jax.tree_structure(None)) +PartitionSpec = partitioning.PartitionSpec +DType = Union[np.dtype, type(jnp.bfloat16)] +Shape = Tuple[int, ...] 
+ + +# TODO(adarob): Remove namespace mapping after client gin files are updated. +TensorBoardLogger = seqio.TensorBoardLogger + +# ----------------------------------------------------------------------------- +# Configurations +# ----------------------------------------------------------------------------- + + +@dataclasses.dataclass +class SaveCheckpointConfig: + """Configuration for saving model checkpoints.""" + # The dtype to save ('float32' or 'bfloat16'). + dtype: str = 'float32' + # Number of steps between writing checkpoints. + period: Optional[int] = None + # Number of most recent checkpoints to keep, or None to keep them all. + keep: Optional[int] = None + # Number of dataset checkpoints to keep, or None to keep them all. + # Note: Dataset checkpoints are also affected by `keep`. + keep_dataset_checkpoints: Optional[int] = None + # Whether to save dataset checkpoints. + save_dataset: bool = False + # The checkpointer class to use. + checkpointer_cls: checkpoints.CheckpointerConstructor = checkpoints.Checkpointer + # Transformations to apply, in order, to the state before writing. + state_transformation_fns: Sequence[checkpoints.SaveStateTransformationFn] = ( + dataclasses.field(default_factory=list)) + + def __post_init__(self): + if self.dtype not in ('float32', 'bfloat16'): + raise ValueError( + "`SaveCheckpointConfig.dtype` must be one of 'float32' or " + f"'bfloat16'. Got {self.dtype}.") + + +@dataclasses.dataclass +class RestoreCheckpointConfig: + """Configuration for restoring model from checkpoint.""" + # Path(s) to checkpoint to restore from or directory (depending on `mode`). + path: Union[str, Sequence[str]] + # One of 'specific', 'latest', or 'all'. + # specific: load the checkpoint specified by `path`. + # latest: load most recent checkpoint in the directory specified by `path`. + # all: sequentially load all of checkpoints in the directory `path`. + mode: str = 'latest' + # An optional sequence of (pattern, replacement) regex pairs. The pattern + # matches parameters in the model and the replacement matches the checkpoint + # (after substitutions). The replacement may be None, in which case the + # parameter can be dropped. Use `fallback_to_scratch` to fill them in with + # newly initialized values. + assignment_map: Optional[Sequence[Tuple[str, Optional[str]]]] = None + # Whether to restore all optimizer parameters from the checkpoint. + strict: bool = True + # Whether to initialize parameters that are in the model being restored but + # are missing from the checkpoint (after `assignment_map` is applied). + fallback_to_scratch: bool = False + # The dtype to restore ('float32' or 'bfloat16'), or None to load as saved. + dtype: Optional[str] = None + # Whether to restore the dataset checkpoint. Fails if checkpoint not present. + restore_dataset: bool = False + # The checkpointer class to use. + checkpointer_cls: checkpoints.CheckpointerConstructor = checkpoints.Checkpointer + # Transformations to apply, in order, to the state after reading. These will + # be applied after the `assignment_map` transformations. + state_transformation_fns: Sequence[ + checkpoints.RestoreStateTransformationFn] = () + + def __post_init__(self): + if self.mode not in ('specific', 'latest', 'all'): + raise ValueError( + "`RestoreCheckpointConfig.mode` must be one of 'specific', 'latest', " + f"or 'all'. Got {self.mode}.") + if self.dtype not in (None, 'float32', 'bfloat16'): + raise ValueError( + "`RestoreCheckpointConfig.dtype` must be one of `None`, 'float32', " + f"or 'bfloat16'. 
Got {self.dtype}.") + if self.assignment_map is not None: + # Turns `assignment_map` into a transformation function. + assignment_map_fn = functools.partial( + state_utils.apply_assignment_map, assignment_map=self.assignment_map) + # Prepends the `assignment_map` transformation to the front of the list. + self.state_transformation_fns = (assignment_map_fn, + *self.state_transformation_fns) + + +@dataclasses.dataclass +class CheckpointConfig: + """Configuration for checkpointing of model and dataset.""" + save: Optional[SaveCheckpointConfig] = None + restore: Optional[RestoreCheckpointConfig] = None + + +class LegacyCheckpointer(orbax.checkpoint.Checkpointer): + """Implementation of Checkpointer interface for T5X. + + Relies on underlying save_checkpointer and restore_checkpointer, which are + t5x.checkpoints.Checkpointer objects. + """ + + def __init__(self, + save_checkpointer: checkpoints.Checkpointer, + restore_checkpointer: checkpoints.Checkpointer, + *, + strict: Optional[bool] = False): + self._save_checkpointer = save_checkpointer + self._restore_checkpointer = restore_checkpointer + self._strict = strict + + async def async_save(self, path: str, item: Any): + raise NotImplementedError + + async def async_restore(self, path: str, item: Optional[Any] = None) -> Any: + raise NotImplementedError + + def save(self, + path: str, + item: train_state_lib.TrainState, + state_transformation_fns: Sequence[ + checkpoints.SaveStateTransformationFn] = (), + *, + concurrent_gb: int = 128): + """Performs save operation using save_checkpointer. + + Args: + path: path to save item to. + item: a TrainState PyTree to save. + state_transformation_fns: Transformations to apply, in order, to the state + before writing. + concurrent_gb: the approximate number of gigabytes of partitionable + parameters to process in parallel. Useful to preserve RAM. + """ + train_state = item + del path # stored in save_checkpointer + # dataset_iterator is also saved, but is provided in checkpointer init + self._save_checkpointer.save( + train_state, state_transformation_fns, concurrent_gb=concurrent_gb) + + def restore(self, + path: str, + item: Optional[train_state_lib.TrainState], + state_transformation_fns: Sequence[ + checkpoints.RestoreStateTransformationFn] = (), + fallback_state: Optional[Mapping[str, Any]] = None, + lazy_parameters: bool = False) -> train_state_lib.TrainState: + """Performs restore operation using restore_checkpointer. + + Determines whether the indicated path is a Tensorflow checkpoint. + + Args: + path: the string path to restore from. + item: a TrainState PyTree to restore. Unused. + state_transformation_fns: Transformations to apply, in order, to the state + before writing. + fallback_state: a state dict of an optimizer to fall back to for loading + params that do not exist in the checkpoint (after applying all + `state_transformation_fns`), but do exist in `Checkpointer.optimizer`. + The union of `fallback_state` and state loaded from the checkpoint must + match `Checkpointer.optimizer`. + lazy_parameters: whether to load the parameters as LazyArrays to preserve + memory. + + Returns: + The restored train state. 
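+
+    Example (an illustrative sketch; the checkpoint path is hypothetical):
+
+      ckptr = LegacyCheckpointer(save_checkpointer, restore_checkpointer)
+      # A native T5X checkpoint is restored via the restore_checkpointer;
+      # a TensorFlow checkpoint (detected by a `.index` file next to `path`)
+      # is routed to `restore_from_tf_checkpoint` instead.
+      restored_state = ckptr.restore(
+          path='/tmp/model_dir/checkpoint_1000', item=None)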
+ """ + del item # not needed for restore in T5X + from_tensorflow = gfile.exists(path + '.index') + if from_tensorflow and state_transformation_fns: + raise ValueError('Cannot initialize from a TensorFlow checkpoint using ' + '`state_transformation_fns`.') + if from_tensorflow: + logging.info('Initializing parameters from TensorFlow checkpoint %s', + path) + return self._restore_checkpointer.restore_from_tf_checkpoint( + path, strict=self._strict) + return self._restore_checkpointer.restore( + path=path, + state_transformation_fns=state_transformation_fns, + fallback_state=fallback_state, + lazy_parameters=lazy_parameters) + + +class LegacyCheckpointManager(orbax.checkpoint.CheckpointManager): + """Implementation of CheckpointManager interface for T5X. + + Uses underlying LegacyCheckpointer to handle save/restore for Dataset and + TrainState. + """ + + def __init__(self, + save_cfg: SaveCheckpointConfig, + restore_cfg: RestoreCheckpointConfig, + train_state_shape: train_state_lib.TrainState, + partitioner: partitioning.BasePartitioner, + ds_iter: Optional[tf.data.Iterator] = None, + model_dir: Optional[str] = None, + use_gda: Optional[bool] = False): + if save_cfg.save_dataset: + assert ds_iter is not None + save_checkpointer = save_cfg.checkpointer_cls( + train_state=train_state_shape, + partitioner=partitioner, + checkpoints_dir=model_dir, + dataset_iterator=ds_iter if save_cfg.save_dataset else None, + save_dtype=save_cfg.dtype, + keep=save_cfg.keep, + use_gda=use_gda, + keep_dataset_checkpoints=save_cfg.keep_dataset_checkpoints) + + if restore_cfg: + restore_checkpointer = restore_cfg.checkpointer_cls( + train_state=train_state_shape, + partitioner=partitioner, + checkpoints_dir='', # unused for restore + dataset_iterator=ds_iter if restore_cfg.restore_dataset else None, + restore_dtype=jnp.dtype(restore_cfg.dtype) + if restore_cfg.dtype else None) + strict = restore_cfg.strict + else: + restore_checkpointer = None + strict = False + + self._checkpointer = LegacyCheckpointer( + save_checkpointer, restore_checkpointer, strict=strict) + + def save(self, + train_state: train_state_lib.TrainState, + state_transformation_fns: Sequence[ + checkpoints.SaveStateTransformationFn] = ()): + """Performs save operation. + + Args: + train_state: a TrainState PyTree to save. + state_transformation_fns: Transformations to apply, in order, to the state + before writing. + """ + self._checkpointer.save( + path='', # not used + item=train_state, + state_transformation_fns=state_transformation_fns) + + def restore( + self, + paths: Sequence[str], + restore_cfg: RestoreCheckpointConfig, + fallback_state: Optional[Mapping[str, Any]] = None + ) -> Union[train_state_lib.TrainState, Sequence[train_state_lib.TrainState]]: + """Performs restore operation using restore_checkpointer. + + Determines whether the indicated path is a Tensorflow checkpoint. + + Args: + paths: A sequence of paths to restore from. + restore_cfg: RestoreCheckpointConfig specifying restoration information. + fallback_state: a state dict of an optimizer to fall back to for loading + params that do not exist in the checkpoint (after applying all + `state_transformation_fns`), but do exist in `Checkpointer.optimizer`. + The union of `fallback_state` and state loaded from the checkpoint must + match `Checkpointer.optimizer`. + + Returns: + The restored TrainState if only one TrainState can be restored from the + given paths, otherwise a sequence of TrainStates. 
+ """ + if restore_cfg is None or paths is None: + return None + + restored = [] + for path in paths: + logging.info('Initializing parameters from specific T5X checkpoint %s', + path) + restored.append( + self._checkpointer.restore( + path=path, + item=None, # not used + state_transformation_fns=restore_cfg.state_transformation_fns, + fallback_state=fallback_state)) + + if len(restored) == 1: + restored = restored[0] + return restored + + +@dataclasses.dataclass +class DatasetConfig: + """Configuration for loading a dataset from a SeqIO Task or Mixture.""" + mixture_or_task_name: str + task_feature_lengths: Mapping[str, int] + split: str + batch_size: int + shuffle: bool + seed: Optional[int] + # Whether to use a precomputed version of the dataset from a cache dir. + use_cached: bool = False + pack: bool = False + # Whether to use tensor2tensor custom ops for more efficient packing. + use_custom_packing_ops: bool = False + # An optional module to import for registering the referenced Mixture or Task. + # DEPRECATED. + module: Optional[str] = None + # Whether to cache the dataset in memory (only applies to evaluation data). + use_memory_cache: bool = True + + +#------------------------------------------------------------------------------ +# Fast *nondeterministic* hardware RNG for faster Dropout +#------------------------------------------------------------------------------ +def _hardware_uniform( + rng_key: Array, + shape: Shape, + dtype: jnp.dtype = np.float32, + minval: Array = np.float32(0), + maxval: Array = np.float32(1) +) -> Array: + """Random uniform method that uses non-deterministic accelerator hardware.""" + del rng_key # non-deterministic prng. + minval = jax.lax.convert_element_type(minval, dtype) + maxval = jax.lax.convert_element_type(maxval, dtype) + return jax.lax.rng_uniform(minval, maxval, shape) + + +# For dropout-only hardware rng. +def _hardware_bernoulli( + rng_key: Array, p: np.ndarray = np.float32(0.5), + shape: Shape = ()) -> Array: + del rng_key # non-deterministic prng. + return jax.lax.rng_uniform(0.0, 1.0, shape) < p + + +def set_hardware_rng_ops(): + """Enable JAX Custom PRNG extension.""" + jax.config.update('jax_enable_custom_prng', True) + # Use only fast TPU hardware PRNG with iterated-hash "split" substitute. + # Expected to be deterministic for a fixed partitioning. + # Monkey-patch JAX PRNGKey to use unsafe_rbg_prng_impl + # TODO(levskaya): replace with jax global config option once we debug it. + rbg_prng_key = functools.partial(prng.seed_with_impl, + prng.unsafe_rbg_prng_impl) + jax.random.PRNGKey = rbg_prng_key + jax._src.random.PRNGKey = rbg_prng_key # pylint: disable=protected-access + + +# ----------------------------------------------------------------------------- +# Training utility functions. 
+# ----------------------------------------------------------------------------- + + +def get_zeros_batch_like_spec( + batch_spec: Mapping[str, + jax.ShapeDtypeStruct]) -> Mapping[str, jnp.ndarray]: + return {k: jnp.zeros(t.shape, t.dtype) for k, t in batch_spec.items()} + + +def get_zeros_batch_like_dataset(dataset: tf.data.Dataset, + batch_size=None) -> Mapping[str, jnp.ndarray]: + reshape = lambda s: (batch_size,) + s[1:] if batch_size else tuple(s) + batch_spec = { + k: jax.ShapeDtypeStruct(reshape(t.shape), t.dtype.as_numpy_dtype) + for k, t in dataset.element_spec.items() + } + return get_zeros_batch_like_spec(batch_spec) + + +class InitFnCallable(typing_extensions.Protocol): + """A callable that initializes model variables.""" + + def __call__( + self, rng: Array, input_shapes: Mapping[str, Array], + input_types: Optional[Mapping[str, + DType]]) -> flax_scope.FrozenVariableDict: + ... + + +class LearningRateCallable(typing_extensions.Protocol): + + def __call__(self, step: jnp.ndarray) -> jnp.ndarray: + ... + + +def create_learning_rate_scheduler( + factors: str = 'constant * linear_warmup * rsqrt_decay', + base_learning_rate: float = 0.5, + warmup_steps: int = 1000, + decay_factor: float = 0.5, + steps_per_decay: int = 20000, + steps_per_cycle: int = 100000, + step_offset: int = 0, + min_learning_rate: float = 1e-8) -> LearningRateCallable: + """Creates learning rate schedule. + + Interprets factors in the factors string which can consist of: + * constant: interpreted as the constant value, + * linear_warmup: interpreted as linear warmup until warmup_steps, + * linear_decay: linear decay from warmup_steps with decay_factor slope. Note + this option implies 'constant * linear_warmup', and should not be used in + in conjunction with `constant` or `linear_warmup` factors. + * rsqrt_decay: divide by square root of max(step, warmup_steps) + * rsqrt_normalized_decay: divide by square root of max(step/warmup_steps, 1) + * decay_every: Every k steps decay the learning rate by decay_factor. + * cosine_decay: Cyclic cosine decay, uses steps_per_cycle parameter. + + Args: + factors: string, factors separated by '*' that defines the schedule. + base_learning_rate: float, the starting constant for the lr schedule. + warmup_steps: int, how many steps to warm up for in the warmup schedule. + decay_factor: float, the amount to decay the learning rate by. + steps_per_decay: int, how often to decay the learning rate. + steps_per_cycle: int, steps per cycle when using cosine decay. + step_offset: int, an offset that the step parameters to this function are + relative to. + min_learning_rate: float, minimum learning rate to output. Useful for cases + when a decay function is (mis)configured to decay to non-positive values. + + Returns: + a function learning_rate(step): float -> {'learning_rate': float}, the + step-dependent lr. 
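+
+  Example (an illustrative sketch using the default arguments above):
+
+    # The default factors 'constant * linear_warmup * rsqrt_decay' give
+    #   lr(step) = base_learning_rate * min(1, step / warmup_steps)
+    #              / sqrt(max(step, warmup_steps))
+    # so with base_learning_rate=0.5 and warmup_steps=1000, the peak value at
+    # step 1000 is roughly 0.5 / sqrt(1000) ~= 0.0158.
+    lr_fn = create_learning_rate_scheduler()
+    lr_at_1000 = lr_fn(jnp.array(1000))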
+ """ + factors = [n.strip() for n in factors.split('*')] + + def step_fn(step: jnp.ndarray) -> jnp.ndarray: + """Step to learning rate function.""" + step = jnp.maximum(0, step - step_offset) + ret = 1.0 + for name in factors: + if name == 'constant': + ret *= base_learning_rate + elif name == 'linear_warmup': + ret *= jnp.minimum(1.0, step / warmup_steps) + elif name == 'linear_decay': + ret *= base_learning_rate * jnp.minimum( + step / warmup_steps, 1.0 + decay_factor * (warmup_steps - step)) + elif name == 'rsqrt_decay': + ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) + elif name == 'rsqrt_normalized_decay': + ret *= jnp.sqrt(warmup_steps) + ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) + elif name == 'decay_every': + ret *= (decay_factor**(step // steps_per_decay)) + elif name == 'cosine_decay': + progress = jnp.maximum(0.0, + (step - warmup_steps) / float(steps_per_cycle)) + ret *= jnp.maximum(0.0, + 0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0)))) + else: + raise ValueError('Unknown factor %s.' % name) + ret = jnp.maximum(ret, min_learning_rate) + return jnp.asarray(ret, dtype=jnp.float32) + + return step_fn + + +def get_first_valid_restore_config_and_paths( + restore_cfgs: Sequence[RestoreCheckpointConfig] +) -> Tuple[Optional[RestoreCheckpointConfig], Sequence[str]]: + """Returns first valid restore_cfg and the paths to restore. + + Args: + restore_cfgs: a sequence of RestoreCheckpointConfig objects, which should be + filtered to determine the first valid object. + + Returns: + Tuple of valid RestoreCheckpointConfig and a sequence of paths. + If the first config encountered has mode 'specfic', it is immediately + returned, along with its specified paths. + If the mode is 'all' or 'latest', checks to ensure that there are valid + checkpoints at each of the provided paths and filters the returned paths + accordingly. + """ + for restore_cfg in restore_cfgs: + paths = ([restore_cfg.path] + if isinstance(restore_cfg.path, str) else restore_cfg.path) + if restore_cfg.mode == 'specific': + return restore_cfg, paths + elif restore_cfg.mode in ('all', 'latest'): + for ckpt_dir in paths: + if not gfile.isdir(ckpt_dir): + raise ValueError( + 'Checkpoint path(s) must be valid directories when using ' + "restore mode 'all' or 'latest'.") + # Check if this is a TensorFlow checkpoint dir. 
+ tf_ckpt_state = tf.train.get_checkpoint_state(ckpt_dir) + + if tf_ckpt_state: + ckpt_paths = tf_ckpt_state.all_model_checkpoint_paths + else: + ckpt_paths = [ + os.path.join(ckpt_dir, f'checkpoint_{step}') + for step in checkpoints.all_steps(ckpt_dir) + ] + if not ckpt_paths: + logging.info('No checkpoints found in specified directory: %s', + ckpt_dir) + continue + if restore_cfg.mode == 'latest': + logging.info('Using latest T5X checkpoint.') + ckpt_paths = ckpt_paths[-1:] + return restore_cfg, ckpt_paths + else: + logging.error('Unsupported checkpoint restore mode: %s', restore_cfg.mode) + return None, [] + + +def get_fallback_state(restore_cfg: RestoreCheckpointConfig, + init_fn: Callable[[jnp.ndarray], Mapping[str, Any]], + init_rng: jnp.ndarray) -> Optional[Mapping[str, Any]]: + """Returns the fallback_state that can be used in restore().""" + if restore_cfg is None: + return + if restore_cfg.fallback_to_scratch: + if not restore_cfg.state_transformation_fns: + raise ValueError('`state_transformation_fns` must be provided with ' + '`fallback_to_scratch`') + if init_rng is None: + raise ValueError('An `init_rng` must be provided with ' + '`fallback_to_scratch`') + fallback_state = init_fn(init_rng) + else: + fallback_state = None + return fallback_state + + +class TrainStateInitializer: + """Helper for initializing partitioned TrainState from checkpoints or scratch. + + Common use cases: + + * To restore from a single checkpoint, use `from_checkpoint`. + * To iterate over multiple checkpoints without recompiling the model, + use `from_checkpoints`. + * To initialize from scratch, use `from_scratch`. + * To restore from a checkpoint with a fallback to initializing from scratch, + use `from_checkpoint_or_scratch`. + + Attributes: + global_train_state_shape: a TrainState containing the global (unpartitioned) + shape (in `jax.ShapeDtypeStruct`) of each parameter instead of its value. + train_state_axes: a TrainState object containing a PartitionSpec (or None) + for each parameter, in place of the parameter itself. + """ + + # TODO(adarob): Replace input_shapes and input_types with sample batch. + def __init__(self, + optimizer_def: Optional[optimizers.OptimizerDefType], + init_fn: InitFnCallable, + input_shapes: Mapping[str, Array], + partitioner: partitioning.BasePartitioner, + input_types: Optional[Mapping[str, DType]] = None): + """TrainStateInitializer constructor. + + Args: + optimizer_def: Optimizer def to be initialized, or None to create a + `InferenceState` without an optimizer. + init_fn: callable that initializes model variables from a PRNGKey and the + input shapes. + input_shapes: a mapping from key to array shape for each feature in the + global (unsharded) input batch. + partitioner: the partitioner to use. + input_types: a mapping from key to array type for each feature in the + global (unshared) input batch. If not provided, the type is assumed to + be `jnp.float32`. 
+ """ + + def initialize_train_state(rng: Array): + initial_variables = init_fn( + rng=rng, input_shapes=input_shapes, input_types=input_types) + if optimizer_def: + return train_state_lib.FlaxOptimTrainState.create( + optimizer_def, initial_variables) + return train_state_lib.InferenceState.create(initial_variables) + + self._partitioner = partitioner + self.global_train_state_shape = jax.eval_shape( + initialize_train_state, rng=jax.random.PRNGKey(0)) + self.train_state_axes = partitioner.get_mesh_axes( + self.global_train_state_shape) + self._initialize_train_state = initialize_train_state + + # Currently scanned layers require passing annotations through to the + # point of the scan transformation to resolve an XLA SPMD issue. + + # init_fn is always(?) equal to model.get_initial_variables, fetch the model + # instance from the bound method. + model = init_fn.__self__ # pytype: disable=attribute-error + if (hasattr(model, 'module') and hasattr(model.module, 'scan_layers') and + model.module.scan_layers): + if hasattr(model.module, 'spmd_annotations'): + # update top-level module with spmd annotations. + model.module = model.module.clone( + parent=None, spmd_annotations=self.train_state_axes.params) + + def from_scratch(self, init_rng: Array) -> train_state_lib.TrainState: + """Initializes the partitioned Optimizer from scratch.""" + logging.info('Initializing parameters from scratch.') + + # If pretraining and no checkpoint imported, we jit the (sharded-) init + # function to minimize fragmentation. We use the same partition + # setup as the training step/loop to initialize everything "in-place" and + # avoid communication or OOM. + p_initialize_train_state_fn = self._partitioner.partition( + self._initialize_train_state, + in_axis_resources=None, + out_axis_resources=self.train_state_axes) + return p_initialize_train_state_fn(init_rng) + + # TODO(b/216650048) deprecate this function and use orbax. + def from_checkpoints( + self, + restore_cfgs: Sequence[RestoreCheckpointConfig], + ds_iter: Optional[tf.data.Iterator] = None, + init_rng: Optional[jnp.ndarray] = None, + ) -> Iterable[train_state_lib.TrainState]: + """Yields 0 or more restored partitioned Optimizers, and maybe datasets. + + The manner in which parameters are initialized depends on `restore_cfgs` and + `restore_cfgs` is iterated over and the first config that matches one or + more existing checkpoints is used to generate restored optimizers from the + checkpoint(s). Any remaining configs are ignored. + + Args: + restore_cfgs: ordered sequence of configurations specifying checkpoint(s) + to restore from. The first config to match a checkpoint will be used. + ds_iter: a tf.data.Iterator for the input data, or None. If provided, the + referenced iterator's state may be silently restored (depending on the + config's `restore_dataset` value) along with the optimizer. + init_rng: for initializing parameters from scratch when they are not + available in the checkpoint and `fallback_to_scratch` is True + + Yields: + TrainState with initialized optimizer, with parameters copied to devices. 
+ """ + + def _restore_path(path, cfg): + restore_checkpointer = cfg.checkpointer_cls( + train_state=self.global_train_state_shape, + partitioner=self._partitioner, + checkpoints_dir='', # unused for restore + dataset_iterator=ds_iter if cfg.restore_dataset else None, + restore_dtype=jnp.dtype(cfg.dtype) if cfg.dtype else None) + + from_tensorflow = gfile.exists(path + '.index') + if from_tensorflow and cfg.state_transformation_fns: + raise ValueError('Cannot initialize from a TensorFlow checkpoint using ' + '`state_transformation_fns`.') + if from_tensorflow: + logging.info('Initializing parameters from TensorFlow checkpoint %s', + path) + return restore_checkpointer.restore_from_tf_checkpoint( + path, strict=cfg.strict) + + else: + fallback_state = get_fallback_state( + cfg, lambda rng: self.from_scratch(rng).state_dict(), init_rng) + + logging.info('Initializing parameters from specific T5X checkpoint %s', + path) + return restore_checkpointer.restore( + path=path, + state_transformation_fns=cfg.state_transformation_fns, + fallback_state=fallback_state) + + restore_cfg, paths = get_first_valid_restore_config_and_paths(restore_cfgs) + for path in paths: + yield _restore_path(path, restore_cfg) + + def from_checkpoint( + self, + ckpt_cfgs: Sequence[RestoreCheckpointConfig], + *, + ds_iter: Optional[tf.data.Iterator] = None, + init_rng: Optional[jnp.ndarray] = None + ) -> Optional[train_state_lib.TrainState]: + """Restores (at most) 1 checkpoint using `from_checkpoints`, or dies.""" + train_states = list( + self.from_checkpoints(ckpt_cfgs, ds_iter=ds_iter, init_rng=init_rng)) + if len(train_states) > 1: + raise ValueError( + f'Expected at most 1 checkpoint but got {len(train_states)} for ' + f'config(s): {ckpt_cfgs}') + return (train_states[0]) if train_states else None + + def from_checkpoint_or_scratch( + self, + ckpt_cfgs: Sequence[RestoreCheckpointConfig], + *, + init_rng: Array, + ds_iter: Optional[tf.data.Iterator] = None) -> train_state_lib.TrainState: + """Initializes from checkpoint, if found, or from scratch.""" + return (self.from_checkpoint(ckpt_cfgs, ds_iter=ds_iter, init_rng=init_rng) + or self.from_scratch(init_rng)) + + +# ----------------------------------------------------------------------------- +# Logging utility functions +# ----------------------------------------------------------------------------- + + +def log_model_info(log_file: Optional[str], + full_train_state: train_state_lib.TrainState, + partitioner: partitioning.BasePartitioner): + """Log the variable shapes information and optionally write it to a file.""" + # Only write logs on host 0. 
+ if jax.process_index() != 0: + return + + state_dict = full_train_state.state_dict() + total_num_params = jax.tree_util.tree_reduce( + np.add, jax.tree_map(np.size, state_dict['target'])) + + logical_axes = partitioner.get_logical_axes(full_train_state).state_dict() + + mesh_axes = jax.tree_map( + lambda x: tuple(x) if x is not None else None, + partitioner.get_mesh_axes(full_train_state).state_dict()) + + def _log_info_and_write_to_file(writer, format_str, *args): + logging.info(format_str, *args) + if writer is not None: + writer.write(format_str % args + '\n') + + with contextlib.ExitStack() as stack: + writer = stack.enter_context(gfile.GFile( + log_file, 'w')) if log_file is not None else None + + # Log params + def _log_variable(name: str, arr: Optional[np.ndarray], + logical_axes: Optional[partitioning.AxisNames], + mesh_axes: Optional[partitioning.PartitionSpec]): + # Log nothing on empty dict leaves, which occur with optax EmptyState(). + if isinstance(arr, dict) and not arr: + return + if arr is None: + _log_info_and_write_to_file(writer, 'Variable %-80s None', name) + return + if logical_axes is None or len(logical_axes) != len(arr.shape): + shape_str = str(arr.shape) + else: + shape_str = '({})'.format(', '.join( + f'{name}={dimension}' + for name, dimension in zip(logical_axes, arr.shape))) + _log_info_and_write_to_file( + writer, 'Variable %-80s size %-12s shape %-40s partition spec %s', + name, arr.size, shape_str, mesh_axes) + + jax.tree_map( + _log_variable, + state_utils.get_name_tree(state_dict['target'], keep_empty_nodes=True), + state_dict['target'], logical_axes['target'], mesh_axes['target']) + + _log_info_and_write_to_file(writer, 'Total number of parameters: %d', + total_num_params) + + # Add a blank line between params and states. + _log_info_and_write_to_file(writer, '') + + jax.tree_map( + _log_variable, + state_utils.get_name_tree(state_dict['state'], keep_empty_nodes=True), + state_dict['state'], logical_axes['state'], mesh_axes['state']) + + +# ----------------------------------------------------------------------------- +# Utility functions for prediction and evaluation. +# ----------------------------------------------------------------------------- + + +class InferStepWithRngCallable(typing_extensions.Protocol): + + def __call__(self, + params: Mapping[str, Any], + batch: Mapping[str, jnp.ndarray], + rng: jnp.ndarray = None) -> PyTreeDef: + """Runs an inference step returning a prediction or score.""" + ... + + +class InferStepWithoutRngCallable(typing_extensions.Protocol): + + def __call__(self, params: Mapping[str, Any], + batch: Mapping[str, jnp.ndarray]) -> PyTreeDef: + """Runs an inference step returning a prediction or score.""" + ... + + +InferStepCallable = Union[InferStepWithRngCallable, InferStepWithoutRngCallable] + +# NOTE: We're not more prescriptive than PyTreeDef because that's what +# InferStepCallable expects. +_InferFnResult = Sequence[Tuple[int, PyTreeDef]] +_InferFnWithAuxResult = Tuple[_InferFnResult, Mapping[str, Sequence[Any]]] + + +class InferFnCallable(typing_extensions.Protocol): + + def __call__( + self, + ds: tf.data.Dataset, + train_state: train_state_lib.TrainState, + rng: Optional[jnp.ndarray] = None + ) -> Union[_InferFnResult, _InferFnWithAuxResult]: + """Runs inference on the dataset.""" + ... + + +def _remove_padding(all_inferences, all_indices): + """Remove padded examples. + + Args: + all_inferences: PyTree[total_examples + padding_count, ...]. + all_indices: [total_examples + padding_count]. 
+
+  Returns:
+    all_inferences in shape PyTree[total_examples, ...].
+    all_indices in shape [total_examples].
+  """
+  non_pad_idxs = np.where(all_indices >= 0)
+  all_indices = all_indices[non_pad_idxs]
+  all_inferences = jax.tree_map(lambda x: x[non_pad_idxs], all_inferences)
+  return all_inferences, all_indices
+
+
+def get_infer_fn(infer_step: InferStepCallable, batch_size: int,
+                 train_state_axes: train_state_lib.TrainState,
+                 partitioner: partitioning.BasePartitioner) -> InferFnCallable:
+  """Get prediction function for the SeqIO evaluator.
+
+  The returned prediction function should take in an enumerated dataset, make
+  predictions, and return them in an enumerated form with the original indices
+  and outputs zipped together. This ensures that the predictions are compared
+  to the targets in the correct order even if the dataset is sharded across
+  multiple hosts and gathered in a nondeterministic way.
+
+  jax.process_index == 0 is used as a "main host", i.e., it gathers all
+  inference results and returns.
+
+  Shape notation:
+    Per replica set num replicas: R
+    Per replica set batch size: B
+    Number of replica sets: H
+    Length: L
+
+  Some transformations are annotated with their shape transformation, e.g.,
+  [B, L] -> [R, B/R, L].
+
+  Args:
+    infer_step: a callable that executes one prediction step. Should not yet be
+      partitioned or pmapped.
+    batch_size: the global infer batch size.
+    train_state_axes: Partitioning info for the train state object.
+    partitioner: partitioner to use.
+
+  Returns:
+    predict_fn: a callable which takes in the enumerated infer dataset and an
+      optimizer and runs the prediction.
+  """
+
+  def infer_step_with_indices(params, batch, rng, indices):
+    if 'rng' in inspect.signature(infer_step).parameters:
+      res = typing.cast(InferStepWithRngCallable, infer_step)(params, batch,
+                                                              rng)
+    else:
+      res = typing.cast(InferStepWithoutRngCallable, infer_step)(params, batch)
+    return indices, res
+
+  partitioned_infer_step = partitioner.partition(
+      infer_step_with_indices,
+      in_axis_resources=(train_state_axes.params,
+                         partitioner.data_partition_spec, None,
+                         partitioner.data_partition_spec),
+      out_axis_resources=(None, None))
+
+  data_layout = partitioner.get_data_layout(batch_size)
+  shard_id = data_layout.shard_id
+  num_shards = data_layout.num_shards
+
+  per_shard_batch_size = batch_size // num_shards
+
+  def infer_fn(ds: tf.data.Dataset,
+               train_state: train_state_lib.TrainState,
+               rng: Optional[jnp.ndarray] = None):
+    ds_shapes = jax.tree_map(lambda x: jnp.array(x.shape), ds.element_spec)
+    multihost_utils.assert_equal(
+        ds_shapes, 'Dataset element shapes do not agree across hosts. '
+        'This could be an indication that the dataset is nondeterministic.')
+    try:
+      original_ds_length = len(ds)
+      dataset_remainder = original_ds_length % batch_size  # pytype:disable=wrong-arg-types
+      logging.info('length of dataset = %s', len(ds))
+    except TypeError as e:
+      if str(e) == 'dataset length is unknown.':
+        logging.warning(
+            'The following error is likely due to the use of TensorFlow v1 in '
+            'your dataset pipeline. Verify you are not importing from '
+            '`tf.compat.v1` as part of your pipeline.')
+      raise e
+
+    if dataset_remainder:
+      dataset_pad_amt = batch_size - dataset_remainder
+      logging.info(
+          'Padding infer dataset with %d examples for even per-replica shards.',
+          dataset_pad_amt)
+      # Pad with the first example using an index of -1 so seqio will ignore.
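+      # For illustration: with batch_size=8 and a 21-example dataset,
+      # dataset_remainder is 5, so dataset_pad_amt is 3; the three padded
+      # copies carry index -1 and are stripped out later by _remove_padding.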
+ pad_ds = ds.take(1).map(lambda i, x: (np.int64(-1), x)).repeat( + dataset_pad_amt) + ds = ds.concatenate(pad_ds) + + # Shard the infer dataset across replica sets. + sharded_ds = ds.shard(num_shards, shard_id).batch( + per_shard_batch_size, drop_remainder=True) + multihost_utils.assert_equal( + jnp.array(len(sharded_ds)), + 'Dataset lengths do not agree across hosts.') + + logging.info( + 'The infer dataset is sharded into %d shards with per-shard ' + 'batch size of %d', num_shards, per_shard_batch_size) + + # Run inference for each replica set. + batched_results, all_indices = [], [] + for index, infer_batch in sharded_ds.as_numpy_iterator(): + if rng is None: + step_rng = None + else: + step_rng, rng = jax.random.split(rng) + # Run fast inference on batch. + # [B, ...] -> [B * shard_count, ...] + # partitioned_infer_step executes infer_step on sharded batched data, and + # returns de-sharded batched indices and result replicated on all hosts. + batch_indices, batch_result = partitioned_infer_step( + train_state.params, infer_batch, step_rng, index) + logging.info('Inference of batch %s done.', index) + + # Issue asynchronous copy request which serves as prefetching to the host. + def _copy_to_host_async(x): + if isinstance(x, GlobalDeviceArray): + x.local_data(0).copy_to_host_async() # GDA is fully replicated + return x.local_data(0) + else: + x.copy_to_host_async() + return x + + try: + batch_result = jax.tree_map(_copy_to_host_async, batch_result) + batch_indices = jax.tree_map(_copy_to_host_async, batch_indices) + except AttributeError: + # Similar to jax.device_get, we skip transfers for non DeviceArrays. + pass + + batched_results.append(batch_result) + all_indices.append(batch_indices) + + logging.info('Inference of all batches done.') + all_inferences = batched_results + + # List[B * shard_count, ...] -> [B * shard_count * batch_count, ...] + all_inferences = jax.tree_multimap(lambda *args: np.concatenate(args), + *all_inferences) + all_indices = np.concatenate(all_indices) + + all_inferences, all_indices = _remove_padding(all_inferences, all_indices) + + # Results are returned from infer_step out of order due to shard operation. + # Note: remove padding first, as -1 indices would mess up this operation. + # Note: all_inferences may be a PyTree, not just an array, e.g. if + # `infer_step` is `model.predict_batch_with_aux`. + all_inferences = jax.tree_map(lambda x: x[all_indices], all_inferences) + all_indices = all_indices[all_indices] + + # aux_values is supposed to be a dictionary that maps strings to a set of + # auxiliary values. + # + # We don't want to flatten/unflatten the aux values. We want to preserve the + # unflattened values with the type List[Mapping[str, Sequence[Any]]]. We do + # this as a memory optimization to avoid lots of redundant keys if we'd + # instead had List[Mapping[str, Any]]. + # + # It has shape Mapping[str, [B * shard_count * batch_count, ...]]. That is, + # the first dimension of each of the values in aux_values is equal to + # len(all_inferences). + aux_values = None + if (isinstance(all_inferences, tuple) and len(all_inferences) == 2 and + isinstance(all_inferences[1], Mapping)): + all_inferences, aux_values = all_inferences + + # Translate to List[...] by flattening inferences making sure to + # preserve structure of individual elements (inferences are not assumed to + # be simple np.array). Finally, zip inferences with corresponding indices + # and convert leaf np.arrays into lists. 
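+    # Illustrative example (hypothetical shapes): if `infer_step` returned a
+    # tuple (tokens, scores) with tokens.shape == (N, L) and scores.shape ==
+    # (N,), the flatten/zip/unflatten below produces N per-example tuples
+    # (tokens_i, score_i), which are then paired with their dataset indices.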
+ all_inferences, struct = jax.tree_flatten(all_inferences) + all_inferences = map( + functools.partial(jax.tree_unflatten, struct), zip(*all_inferences)) + indices_and_outputs = list(zip(all_indices, all_inferences)) + indices_and_outputs = jax.tree_map(lambda x: np.array(x).tolist(), + indices_and_outputs) + if len(indices_and_outputs) != original_ds_length: + raise ValueError( + 'Size of indices_and_outputs does not match length of original ' + 'dataset: %d versus %d' % + (len(indices_and_outputs), original_ds_length)) + + if aux_values is None: + return indices_and_outputs + else: + aux_values = jax.tree_map(lambda x: np.array(x).tolist(), aux_values) + return indices_and_outputs, aux_values + + return infer_fn + + +# ----------------------------------------------------------------------------- +# SeqIO utility functions. +# ----------------------------------------------------------------------------- + + +def import_module(module: str): + """Imports the given module at runtime.""" + logging.info('Importing %s.', module) + try: + importlib.import_module(module) + except RuntimeError as e: + if (str(e) == + 'Attempted to add a new configurable after the config was locked.'): + raise RuntimeError( + 'Your Task/Mixture module contains gin configurables that must be ' + 'loaded before gin flag parsing. One fix is to add ' + f"'import {module}' in your gin file.") + raise e + + +def get_vocabulary( + cfg: DatasetConfig) -> Tuple[seqio.Vocabulary, seqio.Vocabulary]: + """Returns `seqio.Vocabulary` objects associated with the `Mixture`/`Task`. + + Args: + cfg: the DatasetConfig specifying which mixture or task to get the + vocabularies for. + + Returns: + A tuple of seqio.Vocabulary for inputs and targets. + + Raises: + ValueError: if inputs and targets are not both present and vocabularies + are different. + """ + if cfg.module: + warnings.warn( + 'The use of `DatasetConfig.module` and `MIXTURE_OR_TASK_MODULE` is ' + 'deprecated in favor of importing the module directly or via gin.', + DeprecationWarning) + import_module(cfg.module) + + provider = seqio.get_mixture_or_task(cfg.mixture_or_task_name) + features = provider.output_features + + if 'inputs' in features and 'targets' in features: + return (features['inputs'].vocabulary, features['targets'].vocabulary) + + # If a mix of PassThroughVocabularies and other Vocabularies are specified, + # use the non-PassThroughVocabularies. + # TODO(b/185912004): Remove this once a more general solution is implemented. + vocabularies = list( + f.vocabulary + for f in features.values() + if not isinstance(f.vocabulary, seqio.PassThroughVocabulary)) + + # Otherwise, if all of the vocabs are PassThroughVocabularies, use those. + if not vocabularies: + vocabularies = list(f.vocabulary for f in features.values()) + + # If there still aren't any vocabularies, raise an error. 
+ if not vocabularies: + raise ValueError('"inputs" and "targets" are not both present, and ' + 'no vocabularies were set for any features.') + + first_vocab = vocabularies[0] + for vocab in vocabularies[1:]: + if vocab != first_vocab: + raise ValueError('"inputs" and "targets" are not both present, and ' + 'vocabularies are different.') + return (first_vocab, first_vocab) + + + + +def get_dataset(cfg: DatasetConfig, + shard_id: int, + num_shards: int, + feature_converter_cls: Type[seqio.FeatureConverter], + num_epochs: Optional[int] = None, + continue_from_last_checkpoint: bool = False) -> tf.data.Dataset: + """Returns a dataset from SeqIO based on a `DatasetConfig`.""" + if continue_from_last_checkpoint: + raise ValueError( + '`continue_from_last_checkpoint` must be set to False as this is not ' + 'supported by this dataset fn.') + del continue_from_last_checkpoint + + if cfg.module: + import_module(cfg.module) + + if cfg.batch_size % num_shards: + raise ValueError( + f'Batch size ({cfg.batch_size}) must be divisible by number of ' + f'shards ({num_shards}).') + + + shard_info = seqio.ShardInfo(index=shard_id, num_shards=num_shards) + + if cfg.seed is None: + # Use a shared timestamp across devices as the seed. + seed = multihost_utils.broadcast_one_to_all(np.int32(time.time())) + else: + seed = cfg.seed + + return get_dataset_inner(cfg, shard_info, feature_converter_cls, seed, + num_epochs) + + +def get_dataset_inner(cfg: DatasetConfig, + shard_info: seqio.ShardInfo, + feature_converter_cls: Type[seqio.FeatureConverter], + seed: Optional[int] = None, + num_epochs: Optional[int] = None): + """Internal fn to load a dataset from SeqIO based on a `DatasetConfig`.""" + batch_size = cfg.batch_size // shard_info.num_shards + if seed is not None: + multihost_utils.assert_equal( + np.array(seed), + f'`seed` is not same across hosts; {jax.process_index} has a seed of ' + f'{seed}') + logging.info( + "Initializing dataset for task '%s' with a replica batch size of %d and " + 'a seed of %d', cfg.mixture_or_task_name, batch_size, seed) + + ds = seqio.get_dataset( + mixture_or_task_name=cfg.mixture_or_task_name, + task_feature_lengths=cfg.task_feature_lengths, + dataset_split=cfg.split, + shuffle=cfg.shuffle, + num_epochs=num_epochs, + feature_converter=feature_converter_cls( + pack=cfg.pack, use_custom_packing_ops=cfg.use_custom_packing_ops), # pytype: disable=not-instantiable + shard_info=shard_info, + use_cached=cfg.use_cached, + seed=seed) + ds = ds.batch(batch_size, drop_remainder=True) + return ds + + +class GetDatasetCallable(typing_extensions.Protocol): + """Interface for a function returning a dataset (iterator).""" + + def __call__( + self, + cfg: DatasetConfig, + shard_id: int, + num_shards: int, + feature_converter_cls: Callable[..., seqio.FeatureConverter], + num_epochs: Optional[int] = None, + continue_from_last_checkpoint: bool = True + ) -> Union[clu.data.DatasetIterator, tf.data.Dataset]: + ... 
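+
+
+# Illustrative sketch of how the dataset helpers above are typically wired
+# together. The task name, feature lengths, and batch size are hypothetical;
+# only the keyword arguments shown are taken from the definitions in this file.
+#
+#   cfg = DatasetConfig(
+#       mixture_or_task_name='my_task',          # hypothetical SeqIO task
+#       task_feature_lengths={'inputs': 512, 'targets': 128},
+#       split='train',
+#       batch_size=128,
+#       shuffle=True,
+#       seed=42)
+#   train_ds = get_dataset(
+#       cfg,
+#       shard_id=jax.process_index(),
+#       num_shards=jax.process_count(),
+#       feature_converter_cls=seqio.EncDecFeatureConverter)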
+ + +def get_training_eval_datasets( + cfg: DatasetConfig, + shard_id: int, + num_shards: int, + eval_steps: int, + feature_converter_cls: Callable[..., seqio.FeatureConverter], + get_dataset_fn: GetDatasetCallable = get_dataset, +) -> Mapping[str, tf.data.Dataset]: + """Returns a mapping from eval task name to its dataset.""" + mixture_or_task = seqio.get_mixture_or_task(cfg.mixture_or_task_name) + datasets = {} + + if cfg.batch_size % num_shards: + raise ValueError( + f'Batch size ({cfg.batch_size}) must be divisible by number of ' + f'shards ({num_shards}).') + + def _repeat_shard_batch_take_cache(ds: tf.data.Dataset): + # We shard and batch the full, repeated dataset to avoid issues with uneven + # file shards. + if not isinstance(ds, tf.data.Dataset): + raise ValueError('Only tf.data.Dataset objects supported.') + return ds.unbatch().repeat().shard(num_shards, shard_id).batch( + cfg.batch_size // num_shards, + drop_remainder=True).take(eval_steps).cache() + + for task in seqio.get_subtasks(mixture_or_task): + if cfg.split not in task.splits: + logging.info("Task %s has no '%s' split; skipping training evaluation.", + task.name, cfg.split) + continue + logging.info('Loading task %s for training evaluation.', task.name) + task_cfg = dataclasses.replace( + cfg, mixture_or_task_name=task.name, batch_size=1) + # We set `num_epochs` to be finite to avoid infinite loops on shards that + # have input examples that are all filtered. + datasets[task.name] = _repeat_shard_batch_take_cache( + get_dataset_fn( + task_cfg, + shard_id=0, + num_shards=1, + feature_converter_cls=feature_converter_cls, + num_epochs=eval_steps * cfg.batch_size, + continue_from_last_checkpoint=False)) + + if isinstance(mixture_or_task, seqio.Mixture): + datasets[mixture_or_task.name] = _repeat_shard_batch_take_cache( + get_dataset_fn( + dataclasses.replace(cfg, batch_size=1), + shard_id=0, + num_shards=1, + feature_converter_cls=feature_converter_cls, + num_epochs=eval_steps * cfg.batch_size, + continue_from_last_checkpoint=False)) + + return datasets + + +def round_vocab_size_to_multiple(vocabulary: seqio.Vocabulary, + divisor: int = 128): + """Round up vocabulary size for improved TPU performance.""" + size = vocabulary.vocab_size + return size + -size % divisor + + +def flatten_dict_string_keys(x): + """Flattens a nested dictionary to have string keys and '/' separators.""" + return traverse_util.flatten_dict(flax.core.unfreeze(x), sep='/') + + +class _RegexMap(collections.abc.Mapping): + """Ordered mapping from regexes to values requiring a full match.""" + + def __init__(self, kvs: Sequence[Tuple[str, Any]]): + self._kvs = [(re.compile(k), v) for k, v in kvs] + + def __getitem__(self, key: str) -> Any: + for pattern, v in self._kvs: + if pattern.fullmatch(key): + return v + raise KeyError(f'No pattern matching key: {key}') + + def __len__(self) -> int: + return len(self._kvs) + + def __iter__(self) -> Iterable[Tuple[re.Pattern, Any]]: + return iter(self._kvs) + + +def override_params_axes_names( + model_variables: flax_scope.FrozenVariableDict, + params_axes_names_override: Sequence[Tuple[str, Tuple[str, ...]]] = () +) -> flax_scope.FrozenVariableDict: + """Applies parameter axis names overrides to axes variables. + + Args: + model_variables: the original model variables containing the 'params_axes' + collection. + params_axes_names_override: a priority-ordered mapping from regex patterns + (fully matching parameter names) to tuples containing string logical axis + names to replace model-derived names. 
+ + Returns: + an updated set of model variables with the overrides applied to the + 'params_axes' collection. + """ + params_axes_names_override_map = _RegexMap(params_axes_names_override) + + if 'params_axes' not in model_variables: + raise ValueError( + "Model variables do not contain a 'params_axes' collection to apply an " + 'override to.') + model_variables = model_variables.unfreeze() + flat_params = traverse_util.flatten_dict(model_variables['params']) + flat_params_axes = traverse_util.flatten_dict(model_variables['params_axes']) + + for key, param in flat_params.items(): + param_name = '/'.join(key) + override = params_axes_names_override_map.get(param_name) + if override is None: + continue + + param_axes_key = key[:-1] + (f'{key[-1]}_axes',) + + curr_metadata = flat_params_axes.get(param_axes_key) + + if curr_metadata is None: + logging.info('Adding axis names for %s: %s', param_name, override) + else: + assert isinstance(curr_metadata, flax_partitioning.AxisMetadata) + logging.info('Replacing axis names for %s (%s) with %s.', param_name, + curr_metadata.names, override) + + if param.ndim != len(override): + raise ValueError( + f'Provided axis name override for {param_name} does not match ' + f'param rank ({param.ndim}): {override}') + flat_params_axes[param_axes_key] = flax_partitioning.AxisMetadata( + names=override) + + model_variables['params_axes'] = traverse_util.unflatten_dict( + flat_params_axes) + return flax.core.freeze(model_variables) + + + + +def get_local_data(x): + if isinstance(x, GlobalDeviceArray): + return x.local_data(0) + elif isinstance(x, pxla.ShardedDeviceArray): + val = x.device_buffers[0] + if val.aval is None: + val.aval = jax.ShapedArray(val.shape, val.dtype) + return val + else: + return x diff --git a/t5x/utils_test.py b/t5x/utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6a33b819474665e723e4807f82275f19632fc603 --- /dev/null +++ b/t5x/utils_test.py @@ -0,0 +1,604 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for t5x.utils.""" + +import dataclasses +import os +import re +from typing import Optional + +from absl import flags +from absl.testing import absltest +from absl.testing import parameterized +import flax.core +from flax.linen import partitioning as flax_partitioning +import jax +import numpy as np +import seqio +from t5x import checkpoints +from t5x import partitioning +from t5x import test_utils +from t5x import train_state as train_state_lib +from t5x import utils +import tensorflow as tf + +mock = absltest.mock +Evaluator = seqio.Evaluator +PartitionSpec = partitioning.PartitionSpec +AxisMetadata = flax_partitioning.AxisMetadata + +# Parse absl flags test_srcdir and test_tmpdir. 
+jax.config.parse_flags_with_absl() + +FLAGS = flags.FLAGS + + +def get_mock_train_state(params, param_states=None, step=0): + """Returns a mock TrainState.""" + step = np.array(step) if step is not None else None + state = mock.Mock(param_states=param_states, step=step) + state_dict = dict( + target=params, state=dict(param_states=param_states, step=step)) + return mock.Mock( + params=params, + param_states=param_states, + step=step, + state_dict=lambda: state_dict, + optimizer=mock.Mock( + target=params, state=state, state_dict=lambda: state_dict), + ) + + +class UtilsTest(parameterized.TestCase): + + def round_vocab_size_to_multiple(self): + self.assertEqual(utils.round_vocab_size_to_multiple(1), 128) + self.assertEqual(utils.round_vocab_size_to_multiple(128), 128) + self.assertEqual(utils.round_vocab_size_to_multiple(129), 256) + self.assertEqual(utils.round_vocab_size_to_multiple(129), 256) + self.assertEqual( + utils.round_vocab_size_to_multiple(25600, divisor=384), 256128) + + def test_get_zeros_batch_like_spec(self): + test_utils.assert_same( + utils.get_zeros_batch_like_spec({ + "i": jax.ShapeDtypeStruct((2, 5), dtype=np.int32), + "j": jax.ShapeDtypeStruct((1,), dtype=np.float32), + }), { + "i": np.zeros((2, 5), dtype=np.int32), + "j": np.zeros((1,), dtype=np.float32) + }) + + def test_get_zeros_batch_like_dataset(self): + ds = tf.data.Dataset.from_tensors({ + "i": np.arange(10, dtype=np.int32).reshape((2, 5)), + "j": np.ones((1,), dtype=np.float32) + }) + + test_utils.assert_same( + utils.get_zeros_batch_like_dataset(ds), { + "i": np.zeros((2, 5), dtype=np.int32), + "j": np.zeros((1,), dtype=np.float32) + }) + + test_utils.assert_same( + utils.get_zeros_batch_like_dataset(ds, batch_size=4), { + "i": np.zeros((4, 5), dtype=np.int32), + "j": np.zeros((4,), dtype=np.float32) + }) + + @parameterized.named_parameters( + dict(testcase_name="write_to_file", write_to_log_file=True), + dict(testcase_name="do_not_write_to_file", write_to_log_file=False), + ) + def test_log_model_info(self, write_to_log_file): + log_file = self.create_tempfile() if write_to_log_file else None + + mock_train_state = get_mock_train_state( + params={ + "a": { + "aa": jax.ShapeDtypeStruct(shape=(2, 3), dtype=np.int32) + }, + "c": jax.ShapeDtypeStruct(shape=(7, 8), dtype=np.int32) + }, + param_states={ + "a": { + "aa": { + "v_row": jax.ShapeDtypeStruct(shape=(2,), dtype=np.int32), + "v_col": jax.ShapeDtypeStruct(shape=(3,), dtype=np.int32) + } + }, + "c": { + "v_row": jax.ShapeDtypeStruct(shape=(2, 4), dtype=np.int32), + "v_col": None + } + }) + + mock_logical_axes = get_mock_train_state( + params={ + "a": { + "aa": partitioning.AxisNames("a1", None) + }, + "c": partitioning.AxisNames(None, "a1") + }, + param_states={ + "a": { + "aa": { + "v_row": partitioning.AxisNames(None,), + "v_col": partitioning.AxisNames(None,) + } + }, + "c": { + "v_row": partitioning.AxisNames("a1",), + "v_col": partitioning.AxisNames("a2",) + } + }, + step=None) + + mock_mesh_axes = get_mock_train_state( + params={ + "a": { + "aa": PartitionSpec("b1", None) + }, + "c": PartitionSpec(None, "b1") + }, + param_states={ + "a": { + "aa": { + "v_row": partitioning.AxisNames(None,), + "v_col": partitioning.AxisNames(None,) + } + }, + "c": { + "v_row": partitioning.AxisNames("b1",), + "v_col": partitioning.AxisNames("b2",) + } + }, + step=None) + + partitioner = mock.Mock( + get_logical_axes=lambda _: mock_logical_axes, + get_mesh_axes=lambda _: mock_mesh_axes) + + with self.assertLogs(level="INFO") as logs: + utils.log_model_info(log_file and 
log_file.full_path, mock_train_state, + partitioner) + + relevant_logs = [ + re.sub(r"\s+", " ", output) + for record, output in zip(logs.records, logs.output) + if "t5x/utils.py" in record.pathname + ] + self.assertLen(relevant_logs, 9) + self.assertIn( + "Variable a/aa size 6 shape (a1=2, None=3) partition spec ('b1', None)", + relevant_logs[0]) + self.assertIn( + "Variable c size 56 shape (None=7, a1=8) partition spec (None, 'b1')", + relevant_logs[1]) + + if write_to_log_file: + self.assertEqual( + re.sub(r"\s+", " ", log_file.read_text()), + "Variable a/aa size 6 shape (a1=2, None=3) partition spec ('b1', None) " + "Variable c size 56 shape (None=7, a1=8) partition spec (None, 'b1') " + "Total number of parameters: 62 " + "Variable param_states/a/aa/v_col size 3 shape (None=3) partition spec (None,) " + "Variable param_states/a/aa/v_row size 2 shape (None=2) partition spec (None,) " + "Variable param_states/c/v_col None " + "Variable param_states/c/v_row size 8 shape (2, 4) partition spec ('b1',) " + "Variable step size 1 shape () partition spec None ") + + + def test_get_training_eval_datasets_task(self): + task = mock.create_autospec(seqio.Task, instance=True) + task.name = "mock_task" + task.splits = set(["train", "test"]) + seqio.TaskRegistry.add_provider("mock_task", task) + + mock_get_dataset_fn = mock.Mock( + return_value=tf.data.Dataset.range(10).batch(1)) + mock_fc_cls = mock.Mock() + + cfg = utils.DatasetConfig( + mixture_or_task_name="mock_task", + task_feature_lengths={}, + split="test", + batch_size=4, + shuffle=False, + seed=None) + + # Single shard. + ds = utils.get_training_eval_datasets( + cfg, + shard_id=0, + num_shards=1, + eval_steps=3, + feature_converter_cls=mock_fc_cls, + get_dataset_fn=mock_get_dataset_fn) + + mock_get_dataset_fn.assert_called_once_with( + dataclasses.replace(cfg, batch_size=1), + shard_id=0, + num_shards=1, + feature_converter_cls=mock_fc_cls, + num_epochs=12, + continue_from_last_checkpoint=False) + + self.assertSameElements(ds.keys(), ["mock_task"]) + jax.tree_map(np.testing.assert_equal, list(ds["mock_task"]), [ + np.array([0, 1, 2, 3]), + np.array([4, 5, 6, 7]), + np.array([8, 9, 0, 1]), + ]) + + # 2 shards, shard 0 + mock_get_dataset_fn.reset_mock() + ds = utils.get_training_eval_datasets( + cfg, + shard_id=0, + num_shards=2, + eval_steps=3, + feature_converter_cls=mock_fc_cls, + get_dataset_fn=mock_get_dataset_fn) + + # Call the underlying function loading all shards since the fn shards at the + # example level. + mock_get_dataset_fn.assert_called_once_with( + dataclasses.replace(cfg, batch_size=1), + shard_id=0, + num_shards=1, + feature_converter_cls=mock_fc_cls, + num_epochs=12, + continue_from_last_checkpoint=False) + + self.assertSameElements(ds.keys(), ["mock_task"]) + jax.tree_map(np.testing.assert_equal, list(ds["mock_task"]), [ + np.array([0, 2]), + np.array([4, 6]), + np.array([8, 0]), + ]) + + # 2 shards, shard 1 + mock_get_dataset_fn.reset_mock() + ds = utils.get_training_eval_datasets( + cfg, + shard_id=1, + num_shards=2, + eval_steps=3, + feature_converter_cls=mock_fc_cls, + get_dataset_fn=mock_get_dataset_fn) + + # Call the underlying function loading all shards since the fn shards at the + # example level. 
+ mock_get_dataset_fn.assert_called_once_with( + dataclasses.replace(cfg, batch_size=1), + shard_id=0, + num_shards=1, + feature_converter_cls=mock_fc_cls, + num_epochs=12, + continue_from_last_checkpoint=False) + + self.assertSameElements(ds.keys(), ["mock_task"]) + jax.tree_map(np.testing.assert_equal, list(ds["mock_task"]), [ + np.array([1, 3]), + np.array([5, 7]), + np.array([9, 1]), + ]) + + # 3 shards + with self.assertRaisesWithLiteralMatch( + ValueError, + "Batch size (4) must be divisible by number of shards (3)."): + _ = utils.get_training_eval_datasets( + cfg, + shard_id=0, + num_shards=3, + eval_steps=3, + feature_converter_cls=mock_fc_cls, + get_dataset_fn=mock_get_dataset_fn) + + def test_get_training_eval_datasets_mixture(self): + # Register a mock SeqIO mixture. + task1 = mock.create_autospec(seqio.Task, instance=True) + task1.name = "mock_task1" + task1.splits = set(["train", "test"]) + task2 = mock.create_autospec(seqio.Task, instance=True) + task2.name = "mock_task2" + task2.splits = set(["train", "test"]) + seqio.TaskRegistry.add_provider("mock_task1", task1) + seqio.TaskRegistry.add_provider("mock_task2", task2) + mixture = seqio.Mixture( + "mock_mix", ["mock_task1", "mock_task2"], default_rate=1.0) + seqio.MixtureRegistry.add_provider("mock_mix", mixture) + + mock_get_dataset = mock.Mock( + return_value=tf.data.Dataset.range(10).batch(1)) + + # Verify calls to utils.get_dataset + cfg = utils.DatasetConfig( + mixture_or_task_name="mock_mix", + task_feature_lengths={}, + split="test", + batch_size=4, + shuffle=False, + seed=23) + + res = utils.get_training_eval_datasets( + cfg, + shard_id=0, + num_shards=2, + eval_steps=3, + feature_converter_cls=seqio.FeatureConverter, + get_dataset_fn=mock_get_dataset) + + expected_calls = [ + mock.call( + dataclasses.replace( + cfg, mixture_or_task_name="mock_task1", batch_size=1), + shard_id=0, + num_shards=1, + feature_converter_cls=seqio.FeatureConverter, + continue_from_last_checkpoint=False, + num_epochs=12), + mock.call( + dataclasses.replace( + cfg, mixture_or_task_name="mock_task2", batch_size=1), + shard_id=0, + num_shards=1, + feature_converter_cls=seqio.FeatureConverter, + continue_from_last_checkpoint=False, + num_epochs=12), + mock.call( + dataclasses.replace( + cfg, mixture_or_task_name="mock_mix", batch_size=1), + shard_id=0, + num_shards=1, + feature_converter_cls=seqio.FeatureConverter, + continue_from_last_checkpoint=False, + num_epochs=12) + ] + mock_get_dataset.assert_has_calls(expected_calls) + + self.assertSameElements(res.keys(), + ["mock_task1", "mock_task2", "mock_mix"]) + for ds in res.values(): + jax.tree_map(np.testing.assert_equal, list(ds), [ + np.array([0, 2]), + np.array([4, 6]), + np.array([8, 0]), + ]) + + def test_override_params_axes_names(self): + model_variables = flax.core.freeze({ + "params": { + "logits_dense": np.zeros((2, 4)), + "mlp": { + "wo": { + "kernel": np.zeros((4, 6)), + "bias": np.zeros(6), + } + } + }, + "params_axes": { + "logits_dense_axes": AxisMetadata(names=("vocab", "embed")), + "mlp": { + "wo": { + "kernel_axes": AxisMetadata(names=("embed", "mlp")) + } + } + } + }) + + with self.assertRaisesWithLiteralMatch( + ValueError, + "Model variables do not contain a 'params_axes' collection to apply an " + "override to."): + utils.override_params_axes_names({"params": model_variables["params"]}, + [("mlp/wo/kernel", ("embed",))]) + + with self.assertRaisesWithLiteralMatch( + ValueError, + "Provided axis name override for mlp/wo/kernel does not match param " + "rank (2): 
('embed',)"): + utils.override_params_axes_names(model_variables, + [("mlp/wo/kernel", ("embed",))]) + + overridden_variables = utils.override_params_axes_names( + model_variables, + [ + ("wo/kernel", ("batch",)), # unused since not a full match + (".*/wo/kernel", ("batch", "embed")), # this one is used + ("mlp/wo/kernel", ("embed",)), # unused since already matched + ("mlp/wo/bias", ("embed",)), # used + ]) + + jax.tree_multimap( + np.testing.assert_equal, overridden_variables, + flax.core.freeze({ + "params": { + "logits_dense": np.zeros((2, 4)), + "mlp": { + "wo": { + "kernel": np.zeros((4, 6)), + "bias": np.zeros(6), + } + } + }, + "params_axes": { + "logits_dense_axes": AxisMetadata(names=("vocab", "embed")), + "mlp": { + "wo": { + "kernel_axes": AxisMetadata(names=("batch", "embed")), + "bias_axes": AxisMetadata(names=("embed",)), + } + } + } + })) + + +@dataclasses.dataclass +class MockTrainState: + path: Optional[str] = None + from_scratch: Optional[bool] = None + + +class MockCheckpointer(checkpoints.Checkpointer): + + def __init__(self, *args, **kwargs): + pass + + # restore should return TrainState, but we force it to return Mock with path + # for simplicity. + def restore(self, path, *args, **kwargs): + return MockTrainState(path=path, from_scratch=False) + + +class TrainStateInitializerTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + + def _partition(train_state, in_axis_resources, out_axis_resources): + del train_state, in_axis_resources, out_axis_resources + partitioned_fn = lambda _: MockTrainState(from_scratch=True) + return partitioned_fn + + partitioner = mock.Mock(get_mesh_axes=lambda _: None, partition=_partition) + mock_inference_state_create = self.enter_context( + mock.patch.object(train_state_lib.InferenceState, "create")) + mock_inference_state_create.return_value = None + + shapes = { + "ones": (1, 1), + "twos": (2, 2), + "threes": (3, 3), + } + types = { + "ones": int, + "twos": float, + "threes": int, + } + + def _init_fn(rng, input_shapes, input_types): + del rng + return { + "ones": + np.ones(input_shapes["ones"], dtype=input_types["ones"]), + "twos": + np.ones(input_shapes["twos"], dtype=input_types["twos"]) * 2, + "threes": + np.ones(input_shapes["threes"], dtype=input_types["threes"]) * 3 + } + + init_fn = mock.Mock() + init_fn.__call__ = _init_fn + init_fn.__self__ = None + + self.train_state_init = utils.TrainStateInitializer(None, init_fn, shapes, + partitioner, types) + + self.ckptdir = self.create_tempdir(name="primary_checkpoints") + steps = (2, 3) + self.paths = [] + for s in steps: + step_dir = self.ckptdir.mkdir(f"checkpoint_{s}") + step_dir.create_file("checkpoint") + self.paths += [step_dir.full_path] + + def test_from_checkpoints_specific(self): + # multiple paths + ckpt_cfg = utils.RestoreCheckpointConfig( + path=self.paths, mode="specific", checkpointer_cls=MockCheckpointer) + restored = self.train_state_init.from_checkpoints([ckpt_cfg]) + self.assertSequenceEqual(self.paths, [state.path for state in restored]) + with self.assertRaisesRegex(ValueError, r"^Expected at most 1 checkpoint"): + self.train_state_init.from_checkpoint([ckpt_cfg]) + + def test_from_checkpoints_latest(self): + # only restore single latest + ckpt_cfg = utils.RestoreCheckpointConfig( + path=self.ckptdir.full_path, + mode="latest", + checkpointer_cls=MockCheckpointer) + restored = list(self.train_state_init.from_checkpoints([ckpt_cfg])) + assert len(restored) == 1 + self.assertEqual(self.paths[-1], restored[0].path) + restored = 
self.train_state_init.from_checkpoint([ckpt_cfg]) + self.assertEqual(self.paths[-1], restored.path) + + def test_from_checkpoints_multiple_configs(self): + # uses first checkpoint with files present. + ckpt_cfg = utils.RestoreCheckpointConfig( + path=self.ckptdir.full_path, + mode="latest", + checkpointer_cls=MockCheckpointer) + secondary_ckptdir = self.create_tempdir(name="secondary_checkpoints") + for s in (4, 5): + step_dir = secondary_ckptdir.mkdir(f"checkpoint_{s}") + step_dir.create_file("checkpoint") + secondary_ckpt_cfg = utils.RestoreCheckpointConfig( + path=secondary_ckptdir.full_path, + mode="latest", + checkpointer_cls=MockCheckpointer) + restored = self.train_state_init.from_checkpoint( + [ckpt_cfg, secondary_ckpt_cfg]) + self.assertEqual(self.paths[-1], restored.path) + + def test_from_checkpoints_multiple_configs_one_empty(self): + # skips empty_checkpoints directory with no checkpoints present. + ckpt_cfg = utils.RestoreCheckpointConfig( + path=self.ckptdir.full_path, + mode="latest", + checkpointer_cls=MockCheckpointer) + empty_ckptdir = self.create_tempdir(name="empty_checkpoints") + empty_ckpt_cfg = utils.RestoreCheckpointConfig( + path=empty_ckptdir.full_path, + mode="latest", + checkpointer_cls=MockCheckpointer) + restored = self.train_state_init.from_checkpoint([empty_ckpt_cfg, ckpt_cfg]) + self.assertEqual(self.paths[-1], restored.path) + + def test_from_scratch(self): + self.assertTrue( + self.train_state_init.from_scratch(jax.random.PRNGKey(13)).from_scratch) + + def test_from_checkpoint_or_scratch(self): + ckpt_cfg = utils.RestoreCheckpointConfig( + path=self.ckptdir.full_path, + mode="latest", + checkpointer_cls=MockCheckpointer) + empty_ckptdir = self.create_tempdir(name="empty_checkpoints") + empty_ckpt_cfg = utils.RestoreCheckpointConfig( + path=empty_ckptdir.full_path, + mode="latest", + checkpointer_cls=MockCheckpointer) + + init_rng = jax.random.PRNGKey(13) + + # ckpt_cfg has checkpoints, restore from there + restored = self.train_state_init.from_checkpoint_or_scratch( + [empty_ckpt_cfg, ckpt_cfg], init_rng=init_rng) + self.assertEqual(self.paths[-1], restored.path) + self.assertFalse(restored.from_scratch) + + # no checkpoints available, init from scratch + initialized = self.train_state_init.from_checkpoint_or_scratch( + [empty_ckpt_cfg], init_rng=init_rng) + self.assertTrue(initialized.from_scratch) + + +if __name__ == "__main__": + absltest.main() diff --git a/t5x/version.py b/t5x/version.py new file mode 100644 index 0000000000000000000000000000000000000000..d0646ecf047d43b6602bba1ef88d4c7aa78ce282 --- /dev/null +++ b/t5x/version.py @@ -0,0 +1,20 @@ +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Separate file for storing the current version of T5X. + +Stored in a separate file so that setup.py can reference the version without +pulling in all the dependencies in __init__.py. +""" +__version__ = '0.0.0'