diff --git a/.gitattributes b/.gitattributes index 4be18050261ccaa46c79adbc7a0add5da290706a..34cd0b7d9ed3676541c1cfaa61e9c4e7257e7b42 100644 --- a/.gitattributes +++ b/.gitattributes @@ -42,3 +42,6 @@ fairseq/examples/hubert/tests/6313-76958-0021.flac filter=lfs diff=lfs merge=lfs fairseq/examples/textless_nlp/speech-resynth/img/fig.png filter=lfs diff=lfs merge=lfs -text fairseq/fairseq/libbase.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text fairseq/fairseq/libnat.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +fairseq/fairseq/ngram_repeat_block_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +fairseq/fairseq/libnat_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +fairseq/fairseq/data/data_utils_fast.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text diff --git a/fairseq/fairseq/data/__pycache__/__init__.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3f79e390b40761694b429cc7b7e5aab72dd6f3b Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/add_target_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/add_target_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0193961835ccdcee50bf89844e9c489a8c19315 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/add_target_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/append_token_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/append_token_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9efbb8d303d8f0f42212648c2a19bc6556e2f81 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/append_token_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/backtranslation_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/backtranslation_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e99c6ecaa6912191f55ac11f05b636372f8c9f90 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/backtranslation_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/base_wrapper_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/base_wrapper_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34a930ef86408243443fe0eb8e3af58f8e78691a Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/base_wrapper_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/bucket_pad_length_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/bucket_pad_length_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e1033d19b243dd9349421944fa71947e2f7eaf2 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/bucket_pad_length_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/codedataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/codedataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32f71e6aa568e9282950bb0b65d0b8d8ba6f269b Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/codedataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/colorize_dataset.cpython-310.pyc 
b/fairseq/fairseq/data/__pycache__/colorize_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff88a067615e613d5e574c86a365dc6283d4e0f4 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/colorize_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/concat_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/concat_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..825cb4cffde567aac7fe80a1b192202eafd56ea1 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/concat_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/concat_sentences_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/concat_sentences_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18edb3821e9735c4962c3e6fd8154e2fe30fba49 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/concat_sentences_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/data_utils.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/data_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0b7c27ce9f76f47968f796351c5ca2f6ec21115 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/data_utils.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/denoising_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/denoising_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e27aa94a0db24f2d61889ac3e14ddad04d8bd6ec Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/denoising_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/dictionary.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/dictionary.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5deaca93207816e04764f6e5e7a7e530557b299a Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/dictionary.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/fairseq_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/fairseq_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c1fcf622c07f6c0a1f0e7f0d551cc59d9ecbd01 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/fairseq_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/fasta_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/fasta_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b73da5c06e5402ee011f1ce2e2fc707d1fb6440e Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/fasta_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/id_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/id_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39df62a735de8d04539061cd65532f94a18dcfb4 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/id_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/indexed_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/indexed_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28af08664a96b6134c7ca775b40c49495355d5eb Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/indexed_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/language_pair_dataset.cpython-310.pyc 
b/fairseq/fairseq/data/__pycache__/language_pair_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6708004f45a467c0fc717308b6df3b0d1a4a3879 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/language_pair_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/list_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/list_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83954067f5ce4f90a178e262648dcc7550e813b7 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/list_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/lm_context_window_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/lm_context_window_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4c3e9f772f636c9bc578e181475c5bec1c4656c Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/lm_context_window_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/lru_cache_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/lru_cache_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39ba35c8bf575abdd944111dfd425316d820cfb5 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/lru_cache_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/mask_tokens_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/mask_tokens_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d09ba0baa8874a98ea0abbaac114ce51b83ce961 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/mask_tokens_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/monolingual_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/monolingual_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4245cf8251e5cc174217c58f558854865ba2e8f Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/monolingual_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/multi_corpus_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/multi_corpus_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2609c85b0c2e6f547f24522f9a0498b378cbd5a8 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/multi_corpus_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/multi_corpus_sampled_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/multi_corpus_sampled_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d45117e5ac65addb2c4fea9f183813fa59e5693 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/multi_corpus_sampled_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/nested_dictionary_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/nested_dictionary_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9618fcfa6f2f7a8df52c59a5d2dac31703c40444 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/nested_dictionary_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/noising.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/noising.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3368672907d3f4bcf19f676da106053b741bfbc5 Binary files /dev/null and 
b/fairseq/fairseq/data/__pycache__/noising.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/num_samples_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/num_samples_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ebb7fcf3c8ded52c417b69f7180c2b6c76be306 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/num_samples_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/padding_mask_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/padding_mask_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e66fb3128ae04766e42ad66b3d47309ea4b28e1 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/padding_mask_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/plasma_utils.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/plasma_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a58dc46778269054f992a75a25a6495eaeceadc0 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/plasma_utils.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/prepend_token_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/prepend_token_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..900411750874d4b7c0d82a89b638934cd7a371c0 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/prepend_token_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/raw_label_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/raw_label_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06a5c9cb8cdf0e2b8019f744c393974ba84098e3 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/raw_label_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/replace_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/replace_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3462fa66c8897772673546b2d9de67b79b29ce3 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/replace_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/resampling_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/resampling_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81484571de4803e069e0335efa426fb22a98f7a7 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/resampling_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/roll_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/roll_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ef0d82d31a6c956174193da3b5fdd76750906a1 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/roll_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/sort_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/sort_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..735a5a0be8f16633fce9fd6b2874059976f1174f Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/sort_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/span_mask_tokens_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/span_mask_tokens_dataset.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..afa65c98a965bd5afa1fdd9caa4a4c7214a660a7 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/span_mask_tokens_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/speech_dlm_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/speech_dlm_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfe6357fb4751494229079fcafeaa70509acd5c1 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/speech_dlm_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/strip_token_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/strip_token_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99d9e15d2f7ff849f9659db9e97a48b98b4b8683 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/strip_token_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/subsample_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/subsample_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a796a1360badcd4532abee5e50cc830c0c7f2ec Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/subsample_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/text_compressor.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/text_compressor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe691aa804d2a26aeff8e90d0450933edaaa4695 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/text_compressor.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/token_block_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/token_block_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..188380cdbbce1617f6f509f0e3b1f2d429c73dc4 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/token_block_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/transform_eos_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/transform_eos_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..094b9eb385cf6c5f08f87ad7e572c2ca3cb05770 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/transform_eos_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/data_utils_fast.cpython-310-x86_64-linux-gnu.so b/fairseq/fairseq/data/data_utils_fast.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..762839f2f20143c2d201dc8ec82def1464147bb4 --- /dev/null +++ b/fairseq/fairseq/data/data_utils_fast.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:382bd1943d43d4d607db5910ce6671e50cbbb6a6fdd8a38b922c9d8fa379efc9 +size 267952 diff --git a/fairseq/fairseq/data/encoders/__pycache__/gpt2_bpe_utils.cpython-310.pyc b/fairseq/fairseq/data/encoders/__pycache__/gpt2_bpe_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b27817fbeeec0c1a00628c36f2f13173f9b08316 Binary files /dev/null and b/fairseq/fairseq/data/encoders/__pycache__/gpt2_bpe_utils.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/encoders/__pycache__/hf_byte_bpe.cpython-310.pyc b/fairseq/fairseq/data/encoders/__pycache__/hf_byte_bpe.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c27e0dec9cd96b8f3492b117e49e3634b6daecf Binary files 
/dev/null and b/fairseq/fairseq/data/encoders/__pycache__/hf_byte_bpe.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/encoders/__pycache__/nltk_tokenizer.cpython-310.pyc b/fairseq/fairseq/data/encoders/__pycache__/nltk_tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70a2ea4c11b75d23c607ac6e70c3f8999f2fa572 Binary files /dev/null and b/fairseq/fairseq/data/encoders/__pycache__/nltk_tokenizer.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/encoders/__pycache__/space_tokenizer.cpython-310.pyc b/fairseq/fairseq/data/encoders/__pycache__/space_tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19c0e4e7b7a1f0550da8e76d8846bcd3db8628fd Binary files /dev/null and b/fairseq/fairseq/data/encoders/__pycache__/space_tokenizer.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/encoders/__pycache__/subword_nmt_bpe.cpython-310.pyc b/fairseq/fairseq/data/encoders/__pycache__/subword_nmt_bpe.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..affafe0cc58ef2249fe37c41939fc8908fa9d2c6 Binary files /dev/null and b/fairseq/fairseq/data/encoders/__pycache__/subword_nmt_bpe.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/encoders/__pycache__/utils.cpython-310.pyc b/fairseq/fairseq/data/encoders/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b889a92387d1a98d70a06be059e48d984dc14c2 Binary files /dev/null and b/fairseq/fairseq/data/encoders/__pycache__/utils.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/huffman/__init__.py b/fairseq/fairseq/data/huffman/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9b61fafadba28f65fe78a28b2099368b83cfcf41 --- /dev/null +++ b/fairseq/fairseq/data/huffman/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+
+from .huffman_coder import HuffmanCodeBuilder, HuffmanCoder
+from .huffman_mmap_indexed_dataset import (
+    HuffmanMMapIndex,
+    HuffmanMMapIndexedDataset,
+    HuffmanMMapIndexedDatasetBuilder,
+    vocab_file_path,
+)
+
+__all__ = [
+    "HuffmanCoder",
+    "HuffmanCodeBuilder",
+    "HuffmanMMapIndexedDatasetBuilder",
+    "HuffmanMMapIndexedDataset",
+    "HuffmanMMapIndex",
+    "vocab_file_path",
+]
diff --git a/fairseq/fairseq/data/huffman/__pycache__/__init__.cpython-310.pyc b/fairseq/fairseq/data/huffman/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a17437713c05ff2184b755e12c5dd39a4b2d4591
Binary files /dev/null and b/fairseq/fairseq/data/huffman/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/huffman/__pycache__/huffman_coder.cpython-310.pyc b/fairseq/fairseq/data/huffman/__pycache__/huffman_coder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..67cfb0c92fc1591c370cfd9be5ef1fa3a0ead460
Binary files /dev/null and b/fairseq/fairseq/data/huffman/__pycache__/huffman_coder.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/huffman/__pycache__/huffman_mmap_indexed_dataset.cpython-310.pyc b/fairseq/fairseq/data/huffman/__pycache__/huffman_mmap_indexed_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c8b550e7481bbf5bd637262e872ee6c1d06736c8
Binary files /dev/null and b/fairseq/fairseq/data/huffman/__pycache__/huffman_mmap_indexed_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/huffman/huffman_coder.py b/fairseq/fairseq/data/huffman/huffman_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c04f84564e6a22209439c67fed3cac31f010c6e9
--- /dev/null
+++ b/fairseq/fairseq/data/huffman/huffman_coder.py
@@ -0,0 +1,267 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import re
+import typing as tp
+from collections import Counter, deque
+from dataclasses import dataclass
+
+from bitarray import bitarray, util
+from fairseq.data import Dictionary
+
+# basically we have to write to addressable bytes for the memory mapped
+# dataset loader. Sentences that get encoded to a length that is not a
+# multiple of BLOCKSIZE (a byte) will be padded to fit. (see _pad in the coder)
+BLOCKSIZE = 8
+
+
+class HuffmanCoder:
+    def __init__(
+        self, root: "HuffmanNode", bos="<s>", pad="<pad>", eos="</s>", unk="<unk>"
+    ):
+        self.root = root
+        self.table = root.code_table()
+        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
+
+    def _pad(self, a: bitarray) -> bitarray:
+        """
+        bitpadding, 1 then 0.
+
+        If the array is already a multiple of blocksize, we add a full block.
+        """
+        pad_len = BLOCKSIZE - (len(a) % BLOCKSIZE) - 1
+        padding = bitarray("1" + "0" * pad_len)
+        return a + padding
+
+    def _unpad(self, a: bitarray) -> bitarray:
+        """
+        remove the bitpadding.
+
+        There will be a set of 0s preceded by a 1 at the end of the bitarray, we remove that
+        """
+        # count the 0 padding at the end until we find the first 1
+        # we want to remove the one too
+        remove_cnt = util.rindex(a, 1)
+        return a[:remove_cnt]
+
+    def encode(self, iter: tp.List[str]) -> bytes:
+        """
+        encode a list of tokens and return bytes. We use bitpadding to make
+        sure the encoded bits fit in bytes.
+        """
+        a = bitarray()
+        for token in iter:
+            code = self.get_code(token)
+            if code is None:
+                if self.unk_word is None:
+                    raise Exception(f"unknown token {token} cannot be encoded.")
+                else:
+                    token = self.unk_word
+            a = a + self.get_code(token)
+        return self._pad(a).tobytes()
+
+    def decode(self, bits: bytes) -> tp.Iterator["HuffmanNode"]:
+        """
+        take bitpadded bytes and decode it to a set of leaves. You can then use each node to find the symbol/id
+        """
+        a = bitarray()
+        a.frombytes(bits)
+        return self.root.decode(self._unpad(a))
+
+    def get_code(self, symbol: str) -> tp.Optional[bitarray]:
+        node = self.get_node(symbol)
+        return None if node is None else node.code
+
+    def get_node(self, symbol: str) -> "HuffmanNode":
+        return self.table.get(symbol)
+
+    @classmethod
+    def from_file(
+        cls,
+        filename: str,
+        bos="<s>",
+        pad="<pad>",
+        eos="</s>",
+        unk="<unk>",
+    ) -> "HuffmanCoder":
+        builder = HuffmanCodeBuilder.from_file(filename)
+        return builder.build_code(bos=bos, pad=pad, eos=eos, unk=unk)
+
+    def to_file(self, filename, sep="\t"):
+        nodes = list(self.table.values())
+        nodes.sort(key=lambda n: n.id)
+        with open(filename, "w", encoding="utf-8") as output:
+            for n in nodes:
+                output.write(f"{n.symbol}{sep}{n.count}\n")
+
+    def __iter__(self):
+        for n in self.table.values():
+            yield n
+
+    def merge(self, other_coder: "HuffmanCoder") -> "HuffmanCoder":
+        builder = HuffmanCodeBuilder()
+        for n in self:
+            builder.increment(n.symbol, n.count)
+        for n in other_coder:
+            builder.increment(n.symbol, n.count)
+        return builder.build_code()
+
+    def __eq__(self, other: "HuffmanCoder") -> bool:
+        return self.table == other.table
+
+    def __len__(self) -> int:
+        return len(self.table)
+
+    def __contains__(self, sym: str) -> bool:
+        return sym in self.table
+
+    def to_dictionary(self) -> Dictionary:
+        dictionary = Dictionary(
+            bos=self.bos_word, unk=self.unk_word, pad=self.pad_word, eos=self.eos_word
+        )
+        for n in self:
+            dictionary.add_symbol(n.symbol, n=n.count)
+        dictionary.finalize()
+        return dictionary
+
+
+@dataclass
+class HuffmanNode:
+    """
+    a node in a Huffman tree
+    """
+
+    id: int
+    count: int
+    symbol: tp.Optional[str] = None
+    left: tp.Optional["HuffmanNode"] = None
+    right: tp.Optional["HuffmanNode"] = None
+    code: tp.Optional[bitarray] = None
+
+    def is_leaf(self) -> bool:
+        return self.left is None and self.right is None
+
+    def code_table(
+        self, prefix: tp.Optional[bitarray] = None
+    ) -> tp.Dict[str, "HuffmanNode"]:
+        defaulted_prefix = prefix if prefix is not None else bitarray()
+        if self.is_leaf():
+            self.code = (
+                defaulted_prefix if len(defaulted_prefix) > 0 else bitarray("0")
+            )  # leaf could be the root if there is only one symbol
+            return {self.symbol: self}
+
+        codes_right = self.right.code_table(defaulted_prefix + bitarray([0]))
+        codes_left = self.left.code_table(defaulted_prefix + bitarray([1]))
+        return {**codes_left, **codes_right}
+
+    def decode(self, bits: bitarray) -> tp.Iterator["HuffmanNode"]:
+        current_node = self
+        for bit in bits:
+            if bit == 0:  # go right
+                current_node = current_node.right
+            else:  # go left
+                current_node = current_node.left
+            if current_node is None:
+                # we shouldn't be on a leaf here
+                raise Exception("fell off a leaf")
+            if current_node.is_leaf():
+                yield current_node
+                current_node = self
+        if current_node != self:
+            raise Exception("couldn't decode all the bits")
+
+
+class HuffmanCodeBuilder:
+    """
+    build a dictionary with occurrence count and then build the Huffman code for it.
+    """
+
+    def __init__(self):
+        self.symbols = Counter()
+
+    def add_symbols(self, *syms) -> None:
+        self.symbols.update(syms)
+
+    def increment(self, symbol: str, cnt: int) -> None:
+        self.symbols[symbol] += cnt
+
+    @classmethod
+    def from_file(cls, filename):
+        c = cls()
+        with open(filename, "r", encoding="utf-8") as input:
+            for line in input:
+                split = re.split(r"[\s]+", line)
+                c.increment(split[0], int(split[1]))
+        return c
+
+    def to_file(self, filename, sep="\t"):
+        with open(filename, "w", encoding="utf-8") as output:
+            for (tok, cnt) in self.symbols.most_common():
+                output.write(f"{tok}{sep}{cnt}\n")
+
+    def _smallest(self, q1: deque, q2: deque) -> HuffmanNode:
+        if len(q1) == 0:
+            return q2.pop()
+
+        if len(q2) == 0:
+            return q1.pop()
+
+        if q1[-1].count < q2[-1].count:
+            return q1.pop()
+
+        return q2.pop()
+
+    def __add__(self, c: "HuffmanCodeBuilder") -> "HuffmanCodeBuilder":
+        new_c = self.symbols + c.symbols
+        new_b = HuffmanCodeBuilder()
+        new_b.symbols = new_c
+        return new_b
+
+    def build_code(
+        self,
+        bos="<s>",
+        pad="<pad>",
+        eos="</s>",
+        unk="<unk>",
+    ) -> HuffmanCoder:
+        assert len(self.symbols) > 0, "cannot build code from empty list of symbols"
+
+        if self.symbols[bos] == 0:
+            self.add_symbols(bos)
+        if self.symbols[pad] == 0:
+            self.add_symbols(pad)
+        if self.symbols[eos] == 0:
+            self.add_symbols(eos)
+        if self.symbols[unk] == 0:
+            self.add_symbols(unk)
+
+        node_id = 0
+        leaves_queue = deque(
+            [
+                HuffmanNode(symbol=symbol, count=count, id=idx)
+                for idx, (symbol, count) in enumerate(self.symbols.most_common())
+            ]
+        )  # left are the most common, right are the least common
+
+        if len(leaves_queue) == 1:
+            root = leaves_queue.pop()
+            root.id = 0
+            return HuffmanCoder(root)
+
+        nodes_queue = deque()
+
+        while len(leaves_queue) > 0 or len(nodes_queue) != 1:
+            # get the lowest two nodes at the head of each queue
+            node1 = self._smallest(leaves_queue, nodes_queue)
+            node2 = self._smallest(leaves_queue, nodes_queue)
+
+            # add new node
+            nodes_queue.appendleft(
+                HuffmanNode(
+                    count=node1.count + node2.count, left=node1, right=node2, id=node_id
+                )
+            )
+            node_id += 1
+
+        # we are left with the root
+        return HuffmanCoder(nodes_queue.pop(), bos=bos, pad=pad, eos=eos, unk=unk)
+ """ + + _HDR_MAGIC = b"HUFFIDX\x00\x00" + _VERSION = 1 + + @classmethod + def writer(cls, path: str, data_len: int): + class _Writer: + def __enter__(self): + self._file = open(path, "wb") + + # write header (magic + version) + self._file.write(cls._HDR_MAGIC) + self._file.write(struct.pack(" None: + self._path_prefix = path_prefix + self._coder = coder + self._sizes = [] + self._ptrs = [] + self._data_len = 0 + + def open(self): + self._coder.to_file(vocab_file_path(self._path_prefix)) + self._data_file = open(indexed_dataset.data_file_path(self._path_prefix), "wb") + + def __enter__(self) -> "HuffmanMMapIndexedDatasetBuilder": + self.open() + return self + + def add_item(self, tokens: tp.List[str]) -> None: + """ + add a list of tokens to the dataset, they will compressed with the + provided coder before being written to file. + """ + encoded = self._coder.encode(tokens) + code_len = len(encoded) + last_ptr = 0 + if len(self._ptrs) > 0: + last_ptr = self._ptrs[-1] + self._sizes.append(len(tokens)) + self._ptrs.append(last_ptr + code_len) + self._data_len += code_len + self._data_file.write(encoded) + + def append(self, other_dataset_path_prefix: str) -> None: + """ + append an existing dataset. + Beware, if it wasn't built with the same coder, you are in trouble. + """ + other_index = HuffmanMMapIndex( + indexed_dataset.index_file_path(other_dataset_path_prefix) + ) + for (ptr, size) in other_index: + self._ptrs.append(ptr + self._data_len) + self._sizes.append(size) + + # Concatenate data + with open(indexed_dataset.data_file_path(other_dataset_path_prefix), "rb") as f: + shutil.copyfileobj(f, self._data_file) + + self._data_len += other_index.data_len + + def close(self): + self._data_file.close() + with HuffmanMMapIndex.writer( + indexed_dataset.index_file_path(self._path_prefix), self._data_len + ) as index: + index.write(self._sizes, self._ptrs) + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.close() diff --git a/fairseq/fairseq/data/legacy/__init__.py b/fairseq/fairseq/data/legacy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9bd5c72b5e9d7f67fb7e4ef10808d7ec08967ff4 --- /dev/null +++ b/fairseq/fairseq/data/legacy/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
diff --git a/fairseq/fairseq/data/legacy/__init__.py b/fairseq/fairseq/data/legacy/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bd5c72b5e9d7f67fb7e4ef10808d7ec08967ff4
--- /dev/null
+++ b/fairseq/fairseq/data/legacy/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .block_pair_dataset import BlockPairDataset
+from .masked_lm_dataset import MaskedLMDataset
+from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary
+
+
+__all__ = [
+    "BertDictionary",
+    "BlockPairDataset",
+    "MaskedLMDataset",
+    "MaskedLMDictionary",
+]
diff --git a/fairseq/fairseq/data/legacy/__pycache__/__init__.cpython-310.pyc b/fairseq/fairseq/data/legacy/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc3f5e0cb0064c1ce99929776f805f5d0967ee31
Binary files /dev/null and b/fairseq/fairseq/data/legacy/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/legacy/__pycache__/block_pair_dataset.cpython-310.pyc b/fairseq/fairseq/data/legacy/__pycache__/block_pair_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d14c88f1547179a8001c76524198afa320792b9b
Binary files /dev/null and b/fairseq/fairseq/data/legacy/__pycache__/block_pair_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/legacy/__pycache__/masked_lm_dataset.cpython-310.pyc b/fairseq/fairseq/data/legacy/__pycache__/masked_lm_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e2fe87d8a0e16727be3115f18c2104464f50a4fa
Binary files /dev/null and b/fairseq/fairseq/data/legacy/__pycache__/masked_lm_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/legacy/__pycache__/masked_lm_dictionary.cpython-310.pyc b/fairseq/fairseq/data/legacy/__pycache__/masked_lm_dictionary.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7bf6fc398c40e1e2f9628682dff89cca8f738f04
Binary files /dev/null and b/fairseq/fairseq/data/legacy/__pycache__/masked_lm_dictionary.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/legacy/block_pair_dataset.py b/fairseq/fairseq/data/legacy/block_pair_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba069b46052286c531b4f9706d96788732cd2ad2
--- /dev/null
+++ b/fairseq/fairseq/data/legacy/block_pair_dataset.py
@@ -0,0 +1,311 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import numpy as np
+import torch
+from fairseq.data import FairseqDataset
+
+
+class BlockPairDataset(FairseqDataset):
+    """Break a Dataset of tokens into sentence pair blocks for next sentence
+    prediction as well as masked language model.
+
+    High-level logics are:
+    1. break input tensor to tensor blocks
+    2. pair the blocks with 50% next sentence and 50% random sentence
+    3. return paired blocks as well as related segment labels
+
+    Args:
+        dataset (~torch.utils.data.Dataset): dataset to break into blocks
+        sizes: array of sentence lengths
+        dictionary: dictionary for the task
+        block_size: maximum block size
+        break_mode: mode for breaking the corpus into block pairs. currently we
+            support 2 modes
+            doc: respect document boundaries and each part of the pair should belong to one document
+            none: don't respect any boundary and cut tokens evenly
+        short_seq_prob: probability for generating shorter block pairs
+        doc_break_size: Size for empty line separating documents. Typically 1 if
+            the sentences have eos, 0 otherwise.
+    """
+
+    def __init__(
+        self,
+        dataset,
+        dictionary,
+        sizes,
+        block_size,
+        break_mode="doc",
+        short_seq_prob=0.1,
+        doc_break_size=1,
+    ):
+        super().__init__()
+        self.dataset = dataset
+        self.pad = dictionary.pad()
+        self.eos = dictionary.eos()
+        self.cls = dictionary.cls()
+        self.mask = dictionary.mask()
+        self.sep = dictionary.sep()
+        self.break_mode = break_mode
+        self.dictionary = dictionary
+        self.short_seq_prob = short_seq_prob
+        self.block_indices = []
+
+        assert len(dataset) == len(sizes)
+
+        if break_mode == "doc":
+            cur_doc = []
+            for sent_id, sz in enumerate(sizes):
+                assert doc_break_size == 0 or sz != 0, (
+                    "when doc_break_size is non-zero, we expect documents to be "
+                    "separated by a blank line with a single eos."
+                )
+                # empty line as document separator
+                if sz == doc_break_size:
+                    if len(cur_doc) == 0:
+                        continue
+                    self.block_indices.append(cur_doc)
+                    cur_doc = []
+                else:
+                    cur_doc.append(sent_id)
+            max_num_tokens = block_size - 3  # Account for [CLS], [SEP], [SEP]
+            self.sent_pairs = []
+            self.sizes = []
+            for doc_id, doc in enumerate(self.block_indices):
+                self._generate_sentence_pair(doc, doc_id, max_num_tokens, sizes)
+        elif break_mode is None or break_mode == "none":
+            # each block should have half of the block size since we are constructing block pair
+            sent_length = (block_size - 3) // 2
+            total_len = sum(dataset.sizes)
+            length = math.ceil(total_len / sent_length)
+
+            def block_at(i):
+                start = i * sent_length
+                end = min(start + sent_length, total_len)
+                return (start, end)
+
+            sent_indices = np.array([block_at(i) for i in range(length)])
+            sent_sizes = np.array([e - s for s, e in sent_indices])
+            dataset_index = self._sent_to_dataset_index(sent_sizes)
+
+            # pair sentences
+            self._pair_sentences(dataset_index)
+        else:
+            raise ValueError("Invalid break_mode: " + break_mode)
+
+    def _pair_sentences(self, dataset_index):
+        """
+        Given a list of evenly cut blocks/sentences, pair these sentences with 50%
+        consecutive sentences and 50% random sentences.
+        This is used for the "none" break mode
+        """
+        # pair sentences
+        for sent_id, sent in enumerate(dataset_index):
+            next_sent_label = (
+                1 if np.random.rand() > 0.5 and sent_id != len(dataset_index) - 1 else 0
+            )
+            if next_sent_label:
+                next_sent = dataset_index[sent_id + 1]
+            else:
+                next_sent = dataset_index[
+                    self._skip_sampling(len(dataset_index), [sent_id, sent_id + 1])
+                ]
+            self.sent_pairs.append((sent, next_sent, next_sent_label))
+
+            # The current blocks don't include the special tokens but the
+            # sizes already account for this
+            self.sizes.append(3 + sent[3] + next_sent[3])
+
+    def _sent_to_dataset_index(self, sent_sizes):
+        """
+        Build index mapping block indices to the underlying dataset indices
+        """
+        dataset_index = []
+        ds_idx, ds_remaining = -1, 0
+        for to_consume in sent_sizes:
+            sent_size = to_consume
+            if ds_remaining == 0:
+                ds_idx += 1
+                ds_remaining = sent_sizes[ds_idx]
+            start_ds_idx = ds_idx
+            start_offset = sent_sizes[ds_idx] - ds_remaining
+            while to_consume > ds_remaining:
+                to_consume -= ds_remaining
+                ds_idx += 1
+                ds_remaining = sent_sizes[ds_idx]
+            ds_remaining -= to_consume
+            dataset_index.append(
+                (
+                    start_ds_idx,  # starting index in dataset
+                    start_offset,  # starting offset within starting index
+                    ds_idx,  # ending index in dataset
+                    sent_size,  # sentence length
+                )
+            )
+        assert ds_remaining == 0
+        assert ds_idx == len(self.dataset) - 1
+        return dataset_index
+
+    def _generate_sentence_pair(self, doc, doc_id, max_num_tokens, sizes):
+        """
+        Go through a single document and generate sentence pairs from it
+        """
+        current_chunk = []
+        current_length = 0
+        curr = 0
+        # To provide more randomness, we decrease target seq length for parts of
+        # samples (10% by default). Note that max_num_tokens is the hard threshold
+        # for batching and will never be changed.
+        target_seq_length = max_num_tokens
+        if np.random.random() < self.short_seq_prob:
+            target_seq_length = np.random.randint(2, max_num_tokens)
+        # loop through all sentences in document
+        while curr < len(doc):
+            sent_id = doc[curr]
+            current_chunk.append(sent_id)
+            current_length = sum(sizes[current_chunk])
+            # split chunk and generate pair when we exceed target_seq_length or
+            # finish the loop
+            if curr == len(doc) - 1 or current_length >= target_seq_length:
+                # split the chunk into 2 parts
+                a_end = 1
+                if len(current_chunk) > 2:
+                    a_end = np.random.randint(1, len(current_chunk) - 1)
+                sent_a = current_chunk[:a_end]
+                len_a = sum(sizes[sent_a])
+                # generate next sentence label, note that if there is only 1 sentence
+                # in current chunk, label is always 0
+                next_sent_label = (
+                    1 if np.random.rand() > 0.5 and len(current_chunk) != 1 else 0
+                )
+                if not next_sent_label:
+                    # if next sentence label is 0, sample sent_b from a random doc
+                    target_b_length = target_seq_length - len_a
+                    rand_doc_id = self._skip_sampling(len(self.block_indices), [doc_id])
+                    random_doc = self.block_indices[rand_doc_id]
+                    random_start = np.random.randint(0, len(random_doc))
+                    sent_b = []
+                    len_b = 0
+                    for j in range(random_start, len(random_doc)):
+                        sent_b.append(random_doc[j])
+                        len_b = sum(sizes[sent_b])
+                        if len_b >= target_b_length:
+                            break
+                    # return the second part of the chunk since it's not used
+                    num_unused_segments = len(current_chunk) - a_end
+                    curr -= num_unused_segments
+                else:
+                    # if next sentence label is 1, use the second part of chunk as sent_b
+                    sent_b = current_chunk[a_end:]
+                    len_b = sum(sizes[sent_b])
+                # currently sent_a and sent_b may be longer than max_num_tokens,
+                # truncate them and return block idx and offsets for them
+                sent_a, sent_b = self._truncate_sentences(
+                    sent_a, sent_b, max_num_tokens
+                )
+                self.sent_pairs.append((sent_a, sent_b, next_sent_label))
+                self.sizes.append(3 + sent_a[3] + sent_b[3])
+                current_chunk = []
+            curr += 1
+
+    def _skip_sampling(self, total, skip_ids):
+        """
+        Generate a random integer which is not in skip_ids. Sample range is [0, total)
+        TODO: ids in skip_ids should be consecutive, we can extend it to a more generic version later
+        """
+        rand_id = np.random.randint(total - len(skip_ids))
+        return rand_id if rand_id < min(skip_ids) else rand_id + len(skip_ids)
+
+    def _truncate_sentences(self, sent_a, sent_b, max_num_tokens):
+        """
+        Truncate a pair of sentences to limit total length under max_num_tokens.
+        Logic:
+        1. Truncate the longer sentence
+        2. Tokens to be truncated could be at the beginning or the end of the sentence
+        Returns:
+            Truncated sentences represented by dataset idx
+        """
+        len_a, len_b = sum(self.dataset.sizes[sent_a]), sum(self.dataset.sizes[sent_b])
+        front_cut_a = front_cut_b = end_cut_a = end_cut_b = 0
+
+        while True:
+            total_length = (
+                len_a + len_b - front_cut_a - front_cut_b - end_cut_a - end_cut_b
+            )
+            if total_length <= max_num_tokens:
+                break
+
+            if len_a - front_cut_a - end_cut_a > len_b - front_cut_b - end_cut_b:
+                if np.random.rand() < 0.5:
+                    front_cut_a += 1
+                else:
+                    end_cut_a += 1
+            else:
+                if np.random.rand() < 0.5:
+                    front_cut_b += 1
+                else:
+                    end_cut_b += 1
+
+        # calculate ds indices as well as offsets and return
+        truncated_sent_a = self._cut_sentence(sent_a, front_cut_a, end_cut_a)
+        truncated_sent_b = self._cut_sentence(sent_b, front_cut_b, end_cut_b)
+        return truncated_sent_a, truncated_sent_b
+
+    def _cut_sentence(self, sent, front_cut, end_cut):
+        """
+        Cut a sentence based on the numbers of tokens to be cut from beginning and end
+        Represent the sentence as dataset idx and return
+        """
+        start_ds_idx, end_ds_idx, offset = sent[0], sent[-1], 0
+        target_len = sum(self.dataset.sizes[sent]) - front_cut - end_cut
+        while front_cut > 0:
+            if self.dataset.sizes[start_ds_idx] > front_cut:
+                offset += front_cut
+                break
+            else:
+                front_cut -= self.dataset.sizes[start_ds_idx]
+                start_ds_idx += 1
+        while end_cut > 0:
+            if self.dataset.sizes[end_ds_idx] > end_cut:
+                break
+            else:
+                end_cut -= self.dataset.sizes[end_ds_idx]
+                end_ds_idx -= 1
+        return start_ds_idx, offset, end_ds_idx, target_len
+
+    def _fetch_block(self, start_ds_idx, offset, end_ds_idx, length):
+        """
+        Fetch a block of tokens based on its dataset idx
+        """
+        buffer = torch.cat(
+            [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)]
+        )
+        s, e = offset, offset + length
+        return buffer[s:e]
+
+    def __getitem__(self, index):
+        block1, block2, next_sent_label = self.sent_pairs[index]
+        block1 = self._fetch_block(*block1)
+        block2 = self._fetch_block(*block2)
+        return block1, block2, next_sent_label
+
+    def __len__(self):
+        return len(self.sizes)
+
+    @property
+    def supports_prefetch(self):
+        return getattr(self.dataset, "supports_prefetch", False)
+
+    def prefetch(self, indices):
+        prefetch_idx = set()
+        for index in indices:
+            for block1, block2, _ in [self.sent_pairs[index]]:
+                for ds_idx in range(block1[0], block1[2] + 1):
+                    prefetch_idx.add(ds_idx)
+                for ds_idx in range(block2[0], block2[2] + 1):
+                    prefetch_idx.add(ds_idx)
+        self.dataset.prefetch(prefetch_idx)
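The `_skip_sampling` helper above is the one non-obvious piece of the pairing logic: it samples uniformly from a shrunken range and shifts the draw past the skipped ids (which must be consecutive, as the TODO notes). The same arithmetic, extracted as a standalone sketch for illustration:

    import numpy as np

    def skip_sampling(total: int, skip_ids: list) -> int:
        # draw from [0, total - len(skip_ids)), then shift the draw past the
        # skipped block so every remaining id keeps equal probability
        rand_id = np.random.randint(total - len(skip_ids))
        return rand_id if rand_id < min(skip_ids) else rand_id + len(skip_ids)

    assert all(skip_sampling(10, [3, 4]) not in (3, 4) for _ in range(100))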
diff --git a/fairseq/fairseq/data/legacy/masked_lm_dataset.py b/fairseq/fairseq/data/legacy/masked_lm_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd8ea2c60aff306ab3a756223a298a28d41a4991
--- /dev/null
+++ b/fairseq/fairseq/data/legacy/masked_lm_dataset.py
@@ -0,0 +1,303 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Dict, List, Tuple
+
+import numpy as np
+import torch
+from fairseq.data import Dictionary, FairseqDataset, data_utils
+from fairseq.data.concat_dataset import ConcatDataset
+from fairseq.data.legacy.block_pair_dataset import BlockPairDataset
+from fairseq.data.token_block_dataset import TokenBlockDataset
+
+
+class MaskedLMDataset(FairseqDataset):
+    """
+    A wrapper Dataset for masked language modelling. The dataset
+    wraps around TokenBlockDataset or BlockPairDataset and creates a batch
+    where the input blocks are masked according to the specified masking
+    probability. Additionally the batch can also contain sentence level targets
+    if this is specified.
+
+    Args:
+        dataset: Dataset which generates blocks of data. Only BlockPairDataset
+            and TokenBlockDataset are supported.
+        sizes: Sentence lengths
+        vocab: Dictionary with the vocabulary and special tokens.
+        pad_idx: Id of padding token in dictionary
+        mask_idx: Id of mask token in dictionary
+        classif_token_idx: Id of classification token in dictionary. This is the
+            token associated with the sentence embedding (Eg: CLS for BERT)
+        sep_token_idx: Id of separator token in dictionary
+            (Eg: SEP in BERT)
+        seed: Seed for random number generator for reproducibility.
+        shuffle: Shuffle the elements before batching.
+        has_pairs: Specifies whether the underlying dataset
+            generates a pair of blocks along with a sentence_target or not.
+            Setting it to True assumes that the underlying dataset generates a
+            label for the pair of sentences which is surfaced as
+            sentence_target. The default value assumes a single block with no
+            sentence target.
+        segment_id: An optional segment id for filling in the segment labels
+            when we are in the single block setting (Eg: XLM). Default is 0.
+        masking_ratio: specifies what percentage of the blocks should be masked.
+        masking_prob: specifies the probability of a given token being
+            replaced with the "MASK" token.
+        random_token_prob: specifies the probability of a given token being
+            replaced by a random token from the vocabulary.
+    """
+
+    def __init__(
+        self,
+        dataset: FairseqDataset,
+        sizes: np.ndarray,
+        vocab: Dictionary,
+        pad_idx: int,
+        mask_idx: int,
+        classif_token_idx: int,
+        sep_token_idx: int,
+        seed: int = 1,
+        shuffle: bool = True,
+        has_pairs: bool = True,
+        segment_id: int = 0,
+        masking_ratio: float = 0.15,
+        masking_prob: float = 0.8,
+        random_token_prob: float = 0.1,
+    ):
+        # Make sure the input datasets are the ones supported
+        assert (
+            isinstance(dataset, TokenBlockDataset)
+            or isinstance(dataset, BlockPairDataset)
+            or isinstance(dataset, ConcatDataset)
+        ), (
+            "MaskedLMDataset only wraps TokenBlockDataset or BlockPairDataset or "
+            "ConcatDataset"
+        )
+
+        self.dataset = dataset
+        self.sizes = np.array(sizes)
+        self.vocab = vocab
+        self.pad_idx = pad_idx
+        self.mask_idx = mask_idx
+        self.classif_token_idx = classif_token_idx
+        self.sep_token_idx = sep_token_idx
+        self.shuffle = shuffle
+        self.seed = seed
+        self.has_pairs = has_pairs
+        self.segment_id = segment_id
+        self.masking_ratio = masking_ratio
+        self.masking_prob = masking_prob
+        self.random_token_prob = random_token_prob
+
+        # If we have only one block then sizes needs to be updated to include
+        # the classification token
+        if not has_pairs:
+            self.sizes = self.sizes + 1
+
+    def __getitem__(self, index: int):
+        # if has_pairs, then expect 2 blocks and a sentence target
+        if self.has_pairs:
+            (block_one, block_two, sentence_target) = self.dataset[index]
+        else:
+            block_one = self.dataset[index]
+
+        return {
+            "id": index,
+            "block_one": block_one,
+            "block_two": block_two if self.has_pairs else None,
+            "sentence_target": sentence_target if self.has_pairs else None,
+        }
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def _mask_block(
+        self,
+        sentence: np.ndarray,
+        mask_idx: int,
+        pad_idx: int,
+        dictionary_token_range: Tuple,
+    ):
+        """
+        Mask tokens for Masked Language Model training.
+        Samples a masking_ratio fraction of tokens that will be predicted by the LM.
+
+        Note: This function may not be efficient enough since we have multiple
+        conversions between np and torch; we can replace them with torch
+        operators later.
+
+        Args:
+            sentence: 1d tensor to be masked
+            mask_idx: index to use for masking the sentence
+            pad_idx: index to use for masking the target for tokens we aren't
+                predicting
+            dictionary_token_range: range of indices in dictionary which can
+                be used for random word replacement
+                (e.g. without special characters)
+        Returns:
+            masked_sent: masked sentence
+            target: target with words which we are not predicting replaced
+                by pad_idx
+        """
+        masked_sent = np.copy(sentence)
+        sent_length = len(sentence)
+        mask_num = math.ceil(sent_length * self.masking_ratio)
+        mask = np.random.choice(sent_length, mask_num, replace=False)
+        target = np.copy(sentence)
+
+        for i in range(sent_length):
+            if i in mask:
+                rand = np.random.random()
+
+                # replace with mask if probability is less than masking_prob
+                # (Eg: 0.8)
+                if rand < self.masking_prob:
+                    masked_sent[i] = mask_idx
+
+                # replace with random token if probability is less than
+                # masking_prob + random_token_prob (Eg: 0.9)
+                elif rand < (self.masking_prob + self.random_token_prob):
+                    # sample random token from dictionary
+                    masked_sent[i] = np.random.randint(
+                        dictionary_token_range[0], dictionary_token_range[1]
+                    )
+            else:
+                target[i] = pad_idx
+
+        return masked_sent, target
+
+    def _collate(self, samples: List[Dict], pad_idx: int, eos_idx: int):
+        """
+        Does the heavy lifting for creating a batch from the input list of
+        examples. The logic is as follows:
+        1. Mask the input blocks. In case has_pair is True then we have 2
+           blocks to mask.
+        2. Prepend the first masked block tensor with the special token
+           used as sentence embedding. Eg: CLS in BERT. This happens
+           irrespective of the value of has_pair.
+        3. If has_pair is True, then append the first masked block with the
+           special separator token (eg: SEP for BERT) and compute segment
+           label accordingly. In this case, also append the second masked
+           block with this special separator token and compute its segment
+           label.
+        4. For the targets tensor, prepend and append with padding index
+           accordingly.
+        5. Concatenate all tensors.
+        """
+        if len(samples) == 0:
+            return {}
+        # To ensure determinism, we reset the state of the PRNG after every
+        # batch based on the seed and the first id of the batch. This ensures
+        # that across epochs we get the same mask for the same example. This
+        # is needed for reproducibility and is how BERT does masking
+        # TODO: Can we add determinism without this constraint?
+        with data_utils.numpy_seed(self.seed + samples[0]["id"]):
+            for s in samples:
+
+                # token range is needed for replacing with random token during
+                # masking
+                token_range = (self.vocab.nspecial, len(self.vocab))
+
+                # mask according to specified probabilities.
+                masked_blk_one, masked_tgt_one = self._mask_block(
+                    s["block_one"],
+                    self.mask_idx,
+                    self.pad_idx,
+                    token_range,
+                )
+
+                tokens = np.concatenate([[self.classif_token_idx], masked_blk_one])
+                targets = np.concatenate([[self.pad_idx], masked_tgt_one])
+                segments = np.ones(len(tokens)) * self.segment_id
+
+                # if has_pairs is True then we need to add the SEP token to both
+                # the blocks after masking and re-compute segments based on the new
+                # lengths.
+                if self.has_pairs:
+                    tokens_one = np.concatenate([tokens, [self.sep_token_idx]])
+                    targets_one = np.concatenate([targets, [self.pad_idx]])
+
+                    masked_blk_two, masked_tgt_two = self._mask_block(
+                        s["block_two"], self.mask_idx, self.pad_idx, token_range
+                    )
+                    tokens_two = np.concatenate([masked_blk_two, [self.sep_token_idx]])
+                    targets_two = np.concatenate([masked_tgt_two, [self.pad_idx]])
+
+                    # block + 1 sep + 1 special (CLS)
+                    segments_one = np.zeros(len(tokens_one))
+                    # block + 1 sep
+                    segments_two = np.ones(len(tokens_two))
+
+                    tokens = np.concatenate([tokens_one, tokens_two])
+                    targets = np.concatenate([targets_one, targets_two])
+                    segments = np.concatenate([segments_one, segments_two])
+
+                s["source"] = torch.LongTensor(tokens)
+                s["segment_labels"] = torch.LongTensor(segments)
+                s["lm_target"] = torch.LongTensor(targets)
+
+        def merge(key):
+            return data_utils.collate_tokens(
+                [s[key] for s in samples], pad_idx, eos_idx, left_pad=False
+            )
+
+        return {
+            "id": torch.LongTensor([s["id"] for s in samples]),
+            "ntokens": sum(len(s["source"]) for s in samples),
+            "net_input": {
+                "src_tokens": merge("source"),
+                "segment_labels": merge("segment_labels"),
+            },
+            "lm_target": merge("lm_target"),
+            "sentence_target": torch.LongTensor(
+                [s["sentence_target"] for s in samples]
+            )
+            if self.has_pairs
+            else None,
+            "nsentences": len(samples),
+        }
+
+    def collater(self, samples: List[Dict]):
+        """Merge a list of samples to form a mini-batch.
+
+        Args:
+            samples (List[dict]): samples to collate
+
+        Returns:
+            dict: a mini-batch of data
+        """
+        return self._collate(samples, self.vocab.pad(), self.vocab.eos())
+
+    def num_tokens(self, index: int):
+        """
+        Return the number of tokens in a sample. This value is used to
+        enforce max-tokens during batching.
+        """
+        return self.sizes[index]
+
+    def size(self, index: int):
+        """
+        Return an example's size as a float or tuple. This value is used when
+        filtering a dataset with max-positions.
+        """
+        return self.sizes[index]
+
+    def ordered_indices(self):
+        """
+        Return an ordered list of indices. Batches will be constructed based
+        on this order.
+        """
+        if self.shuffle:
+            return np.random.permutation(len(self))
+        else:
+            order = [np.arange(len(self))]
+            order.append(self.sizes)
+            return np.lexsort(order)
+
+    @property
+    def supports_prefetch(self):
+        return getattr(self.dataset, "supports_prefetch", False)
+
+    def prefetch(self, indices):
+        self.dataset.prefetch(indices)
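A standalone sketch of the per-token corruption rule that `_mask_block` implements above, using the defaults from this hunk (masking_prob=0.8, random_token_prob=0.1); the mask id and vocabulary range are hypothetical:

    import numpy as np

    MASK_IDX = 4             # hypothetical mask id
    VOCAB_RANGE = (5, 1000)  # hypothetical (nspecial, len(vocab))

    def corrupt(token: int, rand: float) -> int:
        if rand < 0.8:                  # masking_prob: replace with the mask token
            return MASK_IDX
        elif rand < 0.8 + 0.1:          # random_token_prob: replace with a random vocab token
            return np.random.randint(*VOCAB_RANGE)
        return token                    # remaining 10%: keep the original token

    corrupted = [corrupt(t, np.random.random()) for t in [10, 11, 12]]

Note that this rule only applies to the masking_ratio fraction of positions sampled for prediction; all other positions keep their token but get pad_idx in the target, so the loss covers only the sampled positions.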
diff --git a/fairseq/fairseq/data/legacy/masked_lm_dictionary.py b/fairseq/fairseq/data/legacy/masked_lm_dictionary.py
new file mode 100644
index 0000000000000000000000000000000000000000..dee88f7a3ed72ea465ea4e8ffe7b1c01ff6f57f1
--- /dev/null
+++ b/fairseq/fairseq/data/legacy/masked_lm_dictionary.py
@@ -0,0 +1,60 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from fairseq.data import Dictionary
+
+
+class MaskedLMDictionary(Dictionary):
+    """
+    Dictionary for Masked Language Modelling tasks. This extends Dictionary by
+    adding the mask symbol.
+    """
+
+    def __init__(
+        self,
+        pad="<pad>",
+        eos="</s>",
+        unk="<unk>",
+        mask="<mask>",
+    ):
+        super().__init__(pad=pad, eos=eos, unk=unk)
+        self.mask_word = mask
+        self.mask_index = self.add_symbol(mask)
+        self.nspecial = len(self.symbols)
+
+    def mask(self):
+        """Helper to get index of mask symbol"""
+        return self.mask_index
+
+
+class BertDictionary(MaskedLMDictionary):
+    """
+    Dictionary for BERT task. This extends MaskedLMDictionary by adding support
+    for cls and sep symbols.
+    """
+
+    def __init__(
+        self,
+        pad="<pad>",
+        eos="</s>",
+        unk="<unk>",
+        mask="<mask>",
+        cls="<cls>",
+        sep="<sep>",
+    ):
+        super().__init__(pad=pad, eos=eos, unk=unk, mask=mask)
+        self.cls_word = cls
+        self.sep_word = sep
+        self.cls_index = self.add_symbol(cls)
+        self.sep_index = self.add_symbol(sep)
+        self.nspecial = len(self.symbols)
+
+    def cls(self):
+        """Helper to get index of cls symbol"""
+        return self.cls_index
+
+    def sep(self):
+        """Helper to get index of sep symbol"""
+        return self.sep_index
diff --git a/fairseq/fairseq/data/multilingual/__init__.py b/fairseq/fairseq/data/multilingual/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6264236915a7269a4d920ee8213004374dd86a9a
--- /dev/null
+++ b/fairseq/fairseq/data/multilingual/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/fairseq/fairseq/data/multilingual/__pycache__/__init__.cpython-310.pyc b/fairseq/fairseq/data/multilingual/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1be7366c6c45f31707b7a0e0adf8b450b524b1df
Binary files /dev/null and b/fairseq/fairseq/data/multilingual/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/multilingual/__pycache__/multilingual_data_manager.cpython-310.pyc b/fairseq/fairseq/data/multilingual/__pycache__/multilingual_data_manager.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fd2c94daf519ebb0ecf9586de72a97d631891526
Binary files /dev/null and b/fairseq/fairseq/data/multilingual/__pycache__/multilingual_data_manager.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/multilingual/__pycache__/multilingual_utils.cpython-310.pyc b/fairseq/fairseq/data/multilingual/__pycache__/multilingual_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..352f3adffa8dda9690c0b378197cbbf634c06ede
Binary files /dev/null and b/fairseq/fairseq/data/multilingual/__pycache__/multilingual_utils.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/multilingual/__pycache__/sampled_multi_dataset.cpython-310.pyc b/fairseq/fairseq/data/multilingual/__pycache__/sampled_multi_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fe84145d7d743929c38a5880c81848ae189f410c
Binary files /dev/null and b/fairseq/fairseq/data/multilingual/__pycache__/sampled_multi_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/multilingual/__pycache__/sampled_multi_epoch_dataset.cpython-310.pyc b/fairseq/fairseq/data/multilingual/__pycache__/sampled_multi_epoch_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d79df9622bd2fb0b9eb10eb908fbc591567d9d0
Binary files /dev/null and b/fairseq/fairseq/data/multilingual/__pycache__/sampled_multi_epoch_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/multilingual/__pycache__/sampling_method.cpython-310.pyc b/fairseq/fairseq/data/multilingual/__pycache__/sampling_method.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ad6f2b8046d9ce468ba62361aba5caa6b335e44b
Binary files /dev/null and b/fairseq/fairseq/data/multilingual/__pycache__/sampling_method.cpython-310.pyc differ
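A short usage sketch for the dictionaries above; the special symbols are the defaults restored in this hunk:

    from fairseq.data.legacy import BertDictionary

    d = BertDictionary()
    # <mask>, <cls> and <sep> are registered on top of the base
    # <s>/<pad>/</s>/<unk> symbols; the helpers return their indices
    print(d.mask(), d.cls(), d.sep())
    assert d.nspecial == len(d)  # right after construction, all symbols are special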
a/fairseq/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/fairseq/data/multilingual/multilingual_data_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..876dfcec36e4cf9236c21e440e9657a68036a278 --- /dev/null +++ b/fairseq/fairseq/data/multilingual/multilingual_data_manager.py @@ -0,0 +1,1156 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +import json +import logging +import math +import os +from collections import OrderedDict, defaultdict +from argparse import ArgumentError + +from fairseq import utils +from fairseq.data import ( + AppendTokenDataset, + ConcatDataset, + Dictionary, + LanguagePairDataset, + PrependTokenDataset, + SampledMultiDataset, + SampledMultiEpochDataset, + StripTokenDataset, + TransformEosLangPairDataset, + TruncateDataset, + data_utils, + indexed_dataset, +) +from fairseq.data.multilingual.multilingual_utils import ( + EncoderLangtok, + LangTokSpec, + LangTokStyle, + augment_dictionary, + get_lang_tok, +) +from fairseq.data.multilingual.sampled_multi_dataset import CollateFormat +from fairseq.file_io import PathManager +from fairseq.utils import FileContentsAction, csv_str_list, eval_str_dict + + +logger = logging.getLogger(__name__) + +SRC_DICT_NAME = "src" +TGT_DICT_NAME = "tgt" + + +def _lang_id(dic: Dictionary, lang: str): + """Return language ID index.""" + idx = dic.index(lang) + assert idx != dic.unk_index, "cannot find language ID for lang {}".format(lang) + return idx + + +def load_sampling_weights(from_file): + with open(from_file) as f: + weights = json.load(f) + return weights + + +class MultilingualDatasetManager(object): + def __init__(self, args, lang_pairs, langs, dicts, sampling_method): + super().__init__() + self.args = args + self.seed = args.seed + self.lang_pairs = lang_pairs + self.extra_lang_pairs = ( + list({p for _, v in args.extra_lang_pairs.items() for p in v.split(",")}) + if args.extra_lang_pairs + else [] + ) + self.src_langs = { + p.split("-")[0] for p in args.lang_pairs + self.extra_lang_pairs + } + self.tgt_langs = { + p.split("-")[1] for p in args.lang_pairs + self.extra_lang_pairs + } + self.langs = langs + self.dicts = dicts + self.lang_dict = self.create_lang_dictionary(self.langs) + self.sampling_method = sampling_method + self.sampling_scheduler = None + self._has_sharded_data = False + self._num_shards_dict = {} + self._training_data_sizes = defaultdict(lambda: {}) + + @classmethod + def setup_data_manager(cls, args, lang_pairs, langs, dicts, sampling_method): + return MultilingualDatasetManager( + args, lang_pairs, langs, dicts, sampling_method + ) + + @staticmethod + def add_args(parser): + parser.add_argument( + "data", + help="colon separated list of data directory paths, \ + which will be iterated upon during epochs in round-robin manner", + action=FileContentsAction, + ) + parser.add_argument( + "--langs", + default=None, + type=csv_str_list, + help="a list of comma separated languages which can appear in lang-pairs; " + "note that the ordering determines language token IDs", + ) + parser.add_argument( + "--lang-dict", + default=None, + type=str, + help="an external file which contains a list of " + "languages which can appear in lang-pairs; " + "note that the ordering determines language token IDs; " + "--langs and --lang-dict are mutually exclusive options", + ) + parser.add_argument( + "--source-dict", + default=None,
+ type=str, + help="path to source dictionary; if specified it will override per language dictionary loading", + ) + parser.add_argument( + "--target-dict", + default=None, + type=str, + help="path to target dictionary; if specified it will override per language dictionary loading", + ) + parser.add_argument( + "--lang-tok-style", + default=LangTokStyle.multilingual.value, + type=str, + choices=[LangTokStyle.multilingual.value, LangTokStyle.mbart.value], + help="language token styles", + ) + + parser.add_argument( + "--load-alignments", + action="store_true", + help="load the binarized alignments", + ) + parser.add_argument( + "--left-pad-source", + default="True", + type=str, + metavar="BOOL", + help="pad the source on the left", + ) + parser.add_argument( + "--left-pad-target", + default="False", + type=str, + metavar="BOOL", + help="pad the target on the left", + ) + try: + parser.add_argument( + "--max-source-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the source sequence", + ) + parser.add_argument( + "--max-target-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the target sequence", + ) + except ArgumentError: + # this might have already been defined. Once we transition this to hydra it should be fine to add it here. + pass + parser.add_argument( + "--upsample-primary", + default=1, + type=int, + help="amount to upsample primary dataset", + ) + parser.add_argument( + "--truncate-source", + action="store_true", + default=False, + help="truncate source to max-source-positions", + ) + parser.add_argument( + "--encoder-langtok", + default=None, + type=str, + choices=[EncoderLangtok.src.value, EncoderLangtok.tgt.value], + metavar="SRCTGT", + help="prepend to the beginning of source sentence the source or target " + "language token. (src/tgt)", + ) + parser.add_argument( + "--decoder-langtok", + action="store_true", + help="prepend to the beginning of target sentence the target language token", + ) + parser.add_argument( + "--lang-tok-replacing-bos-eos", action="store_true", default=False + ) + parser.add_argument( + "--enable-lang-ids", + default=False, + action="store_true", + help="whether to include language IDs in samples", + ) + parser.add_argument( + "--enable-reservsed-directions-shared-datasets", + default=False, + action="store_true", + help="whether to allow datasets to be used in reversed directions", + ) + + parser.add_argument( + "--extra-data", + help='a dictionary of data name to its path, \ + e.g. {"mined": path_to_mined_data, "denoised": path_to_denoised_data}', + type=lambda uf: eval_str_dict(uf, type=str), + default=None, + ) + parser.add_argument( + "--extra-lang-pairs", + help='a dictionary of data name to the language pairs they serve, \ + e.g. {"mined": comma-separated-lang-pairs, "denoised": comma-separated-lang-pairs}', + type=lambda uf: eval_str_dict(uf, type=str), + default=None, + ) + parser.add_argument( + "--fixed-dictionary", + help="Fixed dictionary to use with model path", + default=None, + type=str, + ) + parser.add_argument( + "--langtoks-specs", + help='a list of comma separated data types for which a set of language tokens is specialized, \ + e.g. "main,dae,mined". There will be a set of language tokens added to the vocab to \ + distinguish languages in different training data types.
If not specified, default language \ + tokens per language will be added', + default=LangTokSpec.main.value, + type=csv_str_list, + ) + parser.add_argument( + "--langtoks", + help='a dictionary of how to add language tokens, \ + e.g. {"mined": (None, "tgt"), "mono_dae": ("src.dae", "tgt"), "main": \ + ("src", "tgt")}, or {"mined": ("src.mined", "tgt")}', + default=None, + type=lambda uf: eval_str_dict(uf, type=str), + ) + parser.add_argument( + "--sampling-weights-from-file", + help='a file containing a python dictionary of how to sample data sets, \ + e.g. { "main:en_XX-es_XX": 0.2, "mined:en_XX-pt_XX": 0.5, \ + "mono_dae:es_XX-es_XX": 0.3, "main:en_XX-fr_XX": 0.8 }', + default=None, + type=str, + ) + parser.add_argument( + "--sampling-weights", + help='a dictionary of how to sample data sets, \ + e.g. { "main:en_XX-es_XX": 0.2, "mined:en_XX-pt_XX": 0.5, \ + "mono_dae:es_XX-es_XX": 0.3, "main:en_XX-fr_XX": 0.8 }', + default=None, + type=lambda uf: eval_str_dict(uf, type=str), + ) + parser.add_argument( + "--virtual-epoch-size", + default=None, + type=int, + help="virtual epoch size to speed up data loading", + ) + parser.add_argument( + "--virtual-data-size", + default=None, + type=int, + help="virtual data size of the whole joint dataset to speed " + "up data loading and have a specific dynamic sampling strategy interval", + ) + + @classmethod + def load_langs(cls, args, **kwargs): + if args.lang_dict and args.langs: + raise ValueError("--langs and --lang-dict cannot both be specified") + if args.lang_dict is None and args.langs is None: + logger.warning( + "External language dictionary is not provided; " + "use lang-pairs to infer the set of supported languages. " + "The language ordering is not stable, which might cause " + "misalignment in pretraining and finetuning."
+ ) + # infer from lang_pairs as it is + langs = list( + {x for lang_pair in args.lang_pairs for x in lang_pair.split("-")} + ) + langs = sorted(langs) + logger.info(f"inferred language list: {langs}") + elif args.lang_dict: + with open( + PathManager.get_local_path(args.lang_dict), "r", encoding="utf-8" + ) as f: + langs = [lang.strip() for lang in f.readlines() if lang.strip()] + logger.info( + f"loaded language list from {args.lang_dict} as they are ordered in the file" + ) + elif args.langs: + langs = args.langs + logger.info( + f"parsed the language list as they are ordered in the option: {langs}" + ) + return langs + + def has_sharded_data(self, split): + return self._has_sharded_data and split == getattr( + self.args, "train_subset", None + ) + + def _shared_collater(self): + return not (self.args.extra_data and "mono_dae" in self.args.extra_data) and ( + not self.args.lang_tok_replacing_bos_eos + ) + + def estimate_global_pass_epoch(self, epoch): + if self.args.virtual_epoch_size is None or self.args.virtual_data_size is None: + return None + # one epoch more for remaining data in each shard + virtual_epochs_per_shard = math.ceil( + self.args.virtual_data_size / self.args.virtual_epoch_size + ) + # note that fairseq epoch / shard_epoch starts from 1 + shard_epoch = (epoch - 1) // virtual_epochs_per_shard + 1 + return shard_epoch + + @classmethod + def prepare(cls, load_dictionary, args, **kargs): + args.left_pad_source = utils.eval_bool(args.left_pad_source) + args.left_pad_target = utils.eval_bool(args.left_pad_target) + + if not hasattr(args, "shuffle_instance"): + args.shuffle_instance = False + if args.langtoks is None: + args.langtoks = {} + if "main" not in args.langtoks: + src_langtok_spec = args.encoder_langtok if args.encoder_langtok else None + tgt_langtok_spec = "tgt" if args.decoder_langtok else None + args.langtoks["main"] = (src_langtok_spec, tgt_langtok_spec) + + def check_langs(langs, pairs): + messages = [] + for src, tgt in pairs: + if src not in langs or tgt not in langs: + messages.append( + f"language pair {src}-{tgt} contains languages " + "that are not in the language dictionary" + ) + if len(messages) > 0: + raise ValueError(" ".join(messages) + f"; langs: {langs}") + + if args.lang_pairs is None: + raise ValueError( + "--lang-pairs is required. List all the language pairs in the training objective."
+ ) + if isinstance(args.lang_pairs, str): + args.lang_pairs = args.lang_pairs.split(",") + if args.source_lang is not None or args.target_lang is not None: + training = False + else: + training = True + language_list = cls.load_langs(args, **kargs) + check_langs( + language_list, + ( + [p.split("-") for p in args.lang_pairs] + if training + else [(args.source_lang, args.target_lang)] + ), + ) + + def load_dictionary_and_postproc(path): + d = load_dictionary(path) + augment_dictionary( + dictionary=d, + language_list=language_list, + lang_tok_style=args.lang_tok_style, + langtoks_specs=args.langtoks_specs, + extra_data=args.extra_data, + ) + return d + + dicts = cls.load_all_dictionaries( + args, language_list, load_dictionary_and_postproc, training + ) + return language_list, dicts, training + + @classmethod + def load_all_dictionaries(cls, args, language_list, load_dictionary, training): + dicts = OrderedDict() + if args.source_dict is not None: + dicts[SRC_DICT_NAME] = load_dictionary(args.source_dict) + if args.target_dict is not None: + dicts[TGT_DICT_NAME] = load_dictionary(args.target_dict) + + if training: + extra_lang_pairs = ( + list( + {p for _, v in args.extra_lang_pairs.items() for p in v.split(",")} + ) + if args.extra_lang_pairs + else [] + ) + src_langs_to_load_dicts = sorted( + {p.split("-")[0] for p in (args.lang_pairs + extra_lang_pairs)} + ) + tgt_langs_to_load_dicts = sorted( + {p.split("-")[1] for p in (args.lang_pairs + extra_lang_pairs)} + ) + else: + src_langs_to_load_dicts = [args.source_lang] + tgt_langs_to_load_dicts = [args.target_lang] + + paths = utils.split_paths(args.data) + assert len(paths) > 0 + + def load_dicts(langs_to_load_dicts): + for lang in langs_to_load_dicts: + dicts[lang] = load_dictionary( + os.path.join(paths[0], "dict.{}.txt".format(lang)) + ) + if len(dicts) > 0: + dict0 = next(iter(dicts.values())) + assert dicts[lang].pad() == dict0.pad() + assert dicts[lang].eos() == dict0.eos() + assert dicts[lang].unk() == dict0.unk() + logger.info("[{}] dictionary: {} types".format(lang, len(dicts[lang]))) + + if args.fixed_dictionary is not None: + fixed_dict = load_dictionary(args.fixed_dictionary) + dicts = { + lang: fixed_dict + for lang in src_langs_to_load_dicts + tgt_langs_to_load_dicts + } + else: + if args.source_dict is None: + load_dicts(src_langs_to_load_dicts) + if args.target_dict is None: + load_dicts(tgt_langs_to_load_dicts) + return dicts + + def get_source_dictionary(self, lang): + if self.args.source_dict is not None: + return self.dicts[SRC_DICT_NAME] + else: + return self.dicts[lang] + + def get_target_dictionary(self, lang): + if self.args.target_dict is not None: + return self.dicts[TGT_DICT_NAME] + else: + return self.dicts[lang] + + @classmethod + def create_lang_dictionary(cls, langs): + unk = "<unk>" + # hack to remove symbols other than unk as they are not needed by lang dict + lang_dict = Dictionary(pad=unk, eos=unk, unk=unk, bos=unk) + for lang in langs: + lang_dict.add_symbol(lang) + return lang_dict + + @classmethod + def get_langtok_index(cls, lang_tok, dic): + idx = dic.index(lang_tok) + assert ( + idx != dic.unk_index + ), "cannot find language token {} in the dictionary".format(lang_tok) + return idx + + def get_encoder_langtok(self, src_lang, tgt_lang, spec=None): + if spec is None: + return None + if spec and spec.startswith("src"): + if src_lang is None: + return None + langtok = get_lang_tok( + lang=src_lang, lang_tok_style=self.args.lang_tok_style, spec=spec + ) + else: + if tgt_lang is None: + return None + langtok
= get_lang_tok( + lang=tgt_lang, lang_tok_style=self.args.lang_tok_style, spec=spec + ) + return self.get_langtok_index( + langtok, + self.get_source_dictionary(src_lang) + if src_lang + else self.get_target_dictionary(tgt_lang), + ) + + def get_decoder_langtok(self, tgt_lang, spec=None): + if spec is None: + return None + langtok = get_lang_tok( + lang=tgt_lang, lang_tok_style=self.args.lang_tok_style, spec=spec + ) + return self.get_langtok_index(langtok, self.get_target_dictionary(tgt_lang)) + + @classmethod + def load_data(cls, path, vdict, impl): + dataset = data_utils.load_indexed_dataset(path, vdict, impl) + return dataset + + @classmethod + def split_exists(cls, split, src, tgt, lang, data_path, dataset_impl): + filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)) + return indexed_dataset.dataset_exists(filename, impl=dataset_impl) + + def load_lang_dataset( + self, + data_path, + split, + src, + src_dict, + tgt, + tgt_dict, + combine, + dataset_impl, + upsample_primary, + max_source_positions, + prepend_bos=False, + load_alignments=False, + truncate_source=False, + ): + + src_datasets = [] + tgt_datasets = [] + + for k in itertools.count(): + split_k = split + (str(k) if k > 0 else "") + + # infer langcode + if self.split_exists(split_k, src, tgt, src, data_path, dataset_impl): + prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt)) + elif self.split_exists(split_k, tgt, src, src, data_path, dataset_impl): + prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src)) + else: + if k > 0: + break + else: + logger.error( + f"Dataset not found: {data_path}, {split_k}, {src}, {tgt}" + ) + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, data_path) + ) + + src_dataset = self.load_data(prefix + src, src_dict, dataset_impl) + if truncate_source: + src_dataset = AppendTokenDataset( + TruncateDataset( + StripTokenDataset(src_dataset, src_dict.eos()), + max_source_positions - 1, + ), + src_dict.eos(), + ) + src_datasets.append(src_dataset) + tgt_datasets.append(self.load_data(prefix + tgt, tgt_dict, dataset_impl)) + + logger.info( + "{} {} {}-{} {} examples".format( + data_path, split_k, src, tgt, len(src_datasets[-1]) + ) + ) + + if not combine: + break + + assert len(src_datasets) == len(tgt_datasets) + + if len(src_datasets) == 1: + src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0] + else: + sample_ratios = [1] * len(src_datasets) + sample_ratios[0] = upsample_primary + src_dataset = ConcatDataset(src_datasets, sample_ratios) + tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) + + if prepend_bos: + assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index") + src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) + tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) + + align_dataset = None + if load_alignments: + align_path = os.path.join( + data_path, "{}.align.{}-{}".format(split, src, tgt) + ) + if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): + align_dataset = data_utils.load_indexed_dataset( + align_path, None, dataset_impl + ) + + return src_dataset, tgt_dataset, align_dataset + + def load_langpair_dataset( + self, + data_path, + split, + src, + src_dict, + tgt, + tgt_dict, + combine, + dataset_impl, + upsample_primary, + left_pad_source, + left_pad_target, + max_source_positions, + max_target_positions, + prepend_bos=False, + load_alignments=False, + truncate_source=False, + src_dataset_transform_func=lambda dataset: dataset, + 
tgt_dataset_transform_func=lambda dataset: dataset, + src_lang_id=None, + tgt_lang_id=None, + langpairs_sharing_datasets=None, + ): + norm_direction = "-".join(sorted([src, tgt])) + if langpairs_sharing_datasets is not None: + src_dataset = langpairs_sharing_datasets.get( + (data_path, split, norm_direction, src), "NotInCache" + ) + tgt_dataset = langpairs_sharing_datasets.get( + (data_path, split, norm_direction, tgt), "NotInCache" + ) + align_dataset = langpairs_sharing_datasets.get( + (data_path, split, norm_direction, src, tgt), "NotInCache" + ) + + # a hack: if any one is not in the cache, we need to reload them all + if ( + langpairs_sharing_datasets is None + or src_dataset == "NotInCache" + or tgt_dataset == "NotInCache" + or align_dataset == "NotInCache" + or split != getattr(self.args, "train_subset", None) + ): + # source and target datasets can be reused in reversed directions to save memory + # reversed directions of valid and test data will not share source and target datasets + src_dataset, tgt_dataset, align_dataset = self.load_lang_dataset( + data_path, + split, + src, + src_dict, + tgt, + tgt_dict, + combine, + dataset_impl, + upsample_primary, + max_source_positions=max_source_positions, + prepend_bos=prepend_bos, + load_alignments=load_alignments, + truncate_source=truncate_source, + ) + src_dataset = src_dataset_transform_func(src_dataset) + tgt_dataset = tgt_dataset_transform_func(tgt_dataset) + if langpairs_sharing_datasets is not None: + langpairs_sharing_datasets[ + (data_path, split, norm_direction, src) + ] = src_dataset + langpairs_sharing_datasets[ + (data_path, split, norm_direction, tgt) + ] = tgt_dataset + langpairs_sharing_datasets[ + (data_path, split, norm_direction, src, tgt) + ] = align_dataset + if align_dataset is None: + # no align data so flag the reverse direction as well in sharing + langpairs_sharing_datasets[ + (data_path, split, norm_direction, tgt, src) + ] = align_dataset + else: + logger.info( + f"Reusing source and target datasets of [{split}] {tgt}-{src} for reversed direction: " + f"[{split}] {src}-{tgt}: src length={len(src_dataset)}; tgt length={len(tgt_dataset)}" + ) + + return LanguagePairDataset( + src_dataset, + src_dataset.sizes, + src_dict, + tgt_dataset, + tgt_dataset.sizes if tgt_dataset is not None else None, + tgt_dict, + left_pad_source=left_pad_source, + left_pad_target=left_pad_target, + align_dataset=align_dataset, + src_lang_id=src_lang_id, + tgt_lang_id=tgt_lang_id, + ) + + def src_dataset_tranform_func(self, src_lang, tgt_lang, dataset, spec=None): + if self.args.lang_tok_replacing_bos_eos: + # it is handled by self.alter_dataset_langtok + # TODO: Unify with alter_dataset_langtok + return dataset + if spec is None: + return dataset + tok = self.get_encoder_langtok(src_lang, tgt_lang, spec) + if tok: + return PrependTokenDataset(dataset, tok) + return dataset + + def tgt_dataset_tranform_func(self, source_lang, target_lang, dataset, spec=None): + if dataset is None: + # note that target dataset can be None during inference time + return None + if self.args.lang_tok_replacing_bos_eos: + # TODO: Unify with alter_dataset_langtok + # It is handled by self.alter_dataset_langtok. + # The complication in self.alter_dataset_langtok + # makes a unified framework difficult.
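+ # Editor's illustration (not upstream logic): outside of the + # bos/eos-replacing mode, this function prepends the target language + # token, e.g. a target sample [y1, y2, eos] becomes + # [__de_DE__, y1, y2, eos] with lang_tok_style "multilingual"; + # "de_DE" is only an example language code.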
+ return dataset + # if not self.args.decoder_langtok: + if not spec: + return dataset + tok = self.get_decoder_langtok(target_lang, spec) + if tok: + return PrependTokenDataset(dataset, tok) + return dataset + + def alter_dataset_langtok( + self, + lang_pair_dataset, + src_eos=None, + src_lang=None, + tgt_eos=None, + tgt_lang=None, + src_langtok_spec=None, + tgt_langtok_spec=None, + ): + if src_langtok_spec is None and tgt_langtok_spec is None: + return lang_pair_dataset + + new_src_eos = None + if ( + src_langtok_spec is not None + and src_eos is not None + and (src_lang is not None or tgt_lang is not None) + ): + new_src_eos = self.get_encoder_langtok(src_lang, tgt_lang, src_langtok_spec) + else: + src_eos = None + + new_tgt_bos = None + if tgt_langtok_spec and tgt_eos is not None and tgt_lang is not None: + new_tgt_bos = self.get_decoder_langtok(tgt_lang, tgt_langtok_spec) + else: + tgt_eos = None + + return TransformEosLangPairDataset( + lang_pair_dataset, + src_eos=src_eos, + new_src_eos=new_src_eos, + tgt_bos=tgt_eos, + new_tgt_bos=new_tgt_bos, + ) + + def load_a_dataset( + self, + split, + data_path, + src, + src_dict, + tgt, + tgt_dict, + combine, + prepend_bos=False, + langpairs_sharing_datasets=None, + data_category=None, + **extra_kwargs, + ): + dataset_impl = self.args.dataset_impl + upsample_primary = self.args.upsample_primary + left_pad_source = self.args.left_pad_source + left_pad_target = self.args.left_pad_target + max_source_positions = self.args.max_source_positions + max_target_positions = self.args.max_target_positions + load_alignments = self.args.load_alignments + truncate_source = self.args.truncate_source + src_dataset_transform_func = self.src_dataset_tranform_func + tgt_dataset_transform_func = self.tgt_dataset_tranform_func + enable_lang_ids = self.args.enable_lang_ids + lang_dictionary = self.lang_dict + src_langtok_spec, tgt_langtok_spec = extra_kwargs["langtok_spec"] + + src_langtok = self.get_encoder_langtok(src, tgt, src_langtok_spec) + tgt_langtok = self.get_decoder_langtok(tgt, tgt_langtok_spec) + logger.info( + f"{data_category}:{src}-{tgt} src_langtok: {src_langtok}; tgt_langtok: {tgt_langtok}" + ) + + langpair_ds = self.load_langpair_dataset( + data_path, + split, + src, + src_dict, + tgt, + tgt_dict, + combine, + dataset_impl, + upsample_primary, + left_pad_source, + left_pad_target, + max_source_positions, + max_target_positions, + prepend_bos, + load_alignments, + truncate_source, + src_dataset_transform_func=lambda dataset: src_dataset_transform_func( + src, tgt, dataset, src_langtok_spec + ), + tgt_dataset_transform_func=lambda dataset: tgt_dataset_transform_func( + src, tgt, dataset, tgt_langtok_spec + ), + src_lang_id=_lang_id(lang_dictionary, src) + if enable_lang_ids and lang_dictionary is not None + else None, + tgt_lang_id=_lang_id(lang_dictionary, tgt) + if enable_lang_ids and lang_dictionary is not None + else None, + langpairs_sharing_datasets=langpairs_sharing_datasets, + ) + # TODO: handle modified lang toks for mined data and dae data + if self.args.lang_tok_replacing_bos_eos: + ds = self.alter_dataset_langtok( + langpair_ds, + src_eos=self.get_source_dictionary(src).eos() + if src + else self.get_target_dictionary(tgt).eos(), + src_lang=src, + tgt_eos=self.get_target_dictionary(tgt).eos(), + tgt_lang=tgt, + src_langtok_spec=src_langtok_spec, + tgt_langtok_spec=tgt_langtok_spec, + ) + else: + ds = langpair_ds + return ds + + def load_split_langpair_datasets(self, split, data_param_list): + datasets = [] + langpairs_sharing_datasets = 
( + {} if self.args.enable_reservsed_directions_shared_datasets else None + ) + for param in data_param_list: + ds = self.load_a_dataset( + split=split, + langpairs_sharing_datasets=langpairs_sharing_datasets, + **param, + ) + datasets.append(ds) + return datasets + + def get_data_paths_and_lang_pairs(self, split): + datapaths = {"main": self.args.data} + lang_pairs = {"main": self.lang_pairs} + if split == getattr(self.args, "train_subset", None): + # only training data can have extra data and extra language pairs + if self.args.extra_data: + extra_datapaths = self.args.extra_data + datapaths.update(extra_datapaths) + if self.args.extra_lang_pairs: + extra_lang_pairs = { + k: v.split(",") for k, v in self.args.extra_lang_pairs.items() + } + lang_pairs.update(extra_lang_pairs) + return datapaths, lang_pairs + + @classmethod + def get_dataset_key(cls, data_category, src, tgt): + return f"{data_category}:{src}-{tgt}" + + @classmethod + def _get_shard_num_dict(cls, split, paths): + shards = defaultdict(int) + for path in paths: + files = PathManager.ls(path) + directions = set() + for f in files: + if f.startswith(split) and f.endswith(".idx"): + # idx files of the form "{split}.{src}-{tgt}.{lang}.idx" + direction = f.split(".")[-3] + directions.add(direction) + for direction in directions: + shards[direction] += 1 + return shards + + def get_split_num_data_shards(self, split): + if split in self._num_shards_dict: + return self._num_shards_dict[split] + num_shards_dict = {} + data_paths, lang_pairs = self.get_data_paths_and_lang_pairs(split) + + for data_category, paths in data_paths.items(): + if data_category not in lang_pairs: + continue + paths = utils.split_paths(paths) + shards_dict = self._get_shard_num_dict(split, paths) + lang_dirs = [ + lang_pair.split("-") for lang_pair in lang_pairs[data_category] + ] + lang_dirs = [x if len(x) > 1 else (x[0], x[0]) for x in lang_dirs] + for src, tgt in lang_dirs: + key = self.get_dataset_key(data_category, src, tgt) + if "mono_" in data_category: + # monolingual data requires tgt only + assert src is None or src == tgt, ( + f"error: src={src}, " + f"tgt={tgt} for data_category={data_category}" + ) + num_shards_dict[key] = shards_dict[tgt] + else: + if f"{src}-{tgt}" in shards_dict: + num_shards_dict[key] = shards_dict[f"{src}-{tgt}"] + elif f"{tgt}-{src}" in shards_dict: + # follow the fairseq tradition to use reversed direction data if it is not available + num_shards_dict[key] = shards_dict[f"{tgt}-{src}"] + self._num_shards_dict[split] = num_shards_dict + logger.info(f"[{split}] num of shards: {num_shards_dict}") + return num_shards_dict + + @classmethod + def get_shard_id(cls, num_shards, epoch, shard_epoch=None): + shard = epoch if shard_epoch is None else shard_epoch + shard = (shard - 1) % num_shards + return shard + + def get_split_data_path(self, paths, epoch, shard_epoch, num_shards): + path = paths[self.get_shard_id(num_shards, epoch, shard_epoch)] + return path + + def get_split_data_param_list(self, split, epoch, shard_epoch=None): + # TODO: to extend with extra datasets and keys and loop over different shard data paths + param_list = [] + data_paths, lang_pairs = self.get_data_paths_and_lang_pairs(split) + logger.info(f"langtoks settings: {self.args.langtoks}") + split_num_shards_dict = self.get_split_num_data_shards(split) + for data_category, paths in data_paths.items(): + if data_category not in lang_pairs: + continue + paths = utils.split_paths(paths) + assert len(paths) > 0 + if len(paths) > 1: + self._has_sharded_data = True + 
if split != getattr(self.args, "train_subset", None): + # if not the training dataset, use the first shard for valid and test + paths = paths[:1] + + if data_category in self.args.langtoks: + lang_tok_spec = self.args.langtoks[data_category] + else: + # default to None + lang_tok_spec = (None, None) + + # infer langcode + lang_dirs = [ + lang_pair.split("-") for lang_pair in lang_pairs[data_category] + ] + lang_dirs = [x if len(x) > 1 else (x[0], x[0]) for x in lang_dirs] + for src, tgt in lang_dirs: + assert src is not None or data_category == "mono_dae", ( + f"error: src={src}, " f"tgt={tgt} for data_category={data_category}" + ) + # logger.info(f"preparing param for {data_category}: {src} - {tgt}") + key = self.get_dataset_key(data_category, src, tgt) + data_path = self.get_split_data_path( + paths, epoch, shard_epoch, split_num_shards_dict[key] + ) + param_list.append( + { + "key": key, + "data_path": data_path, + "split": split, + "src": src, + "src_dict": self.get_source_dictionary(src) + if src and data_category != "mono_dae" + else None, + "tgt": tgt, + "tgt_dict": self.get_target_dictionary(tgt), + "data_category": data_category, + "langtok_spec": lang_tok_spec, + } + ) + return param_list + + def get_train_dataset_sizes( + self, data_param_list, datasets, epoch, shard_epoch=None + ): + num_shards = [ + self.get_split_num_data_shards(param["split"])[param["key"]] + for param in data_param_list + ] + data_sizes = [] + for (key, d), num_shard in zip(datasets, num_shards): + my_data_sizes = self._training_data_sizes[key] + shard_ind = self.get_shard_id(num_shard, epoch, shard_epoch) + if shard_ind not in my_data_sizes: + my_data_sizes[shard_ind] = len(d) + known_size = max(my_data_sizes.values()) + data_sizes.append( + # If we don't know the data size of the shard yet, + # use the max known data size to approximate. + # Note that we preprocess shards by a designated shard size + # and put any remaining data at the end into the last shard, so + # the max shard size approximation is almost correct before loading + # the last shard; after loading the last shard, we have the + # exact sizes of the whole dataset. + (key, sum(my_data_sizes.get(i, known_size) for i in range(num_shard))) + ) + logger.info( + f"estimated total data sizes of all shards used in sampling ratios: {data_sizes}. " + "Note that if a shard's data has not been loaded yet, the max known data size is used as an approximation" + ) + return [s for _, s in data_sizes] + + def get_train_sampling_ratios( + self, data_param_list, datasets, epoch=1, shard_epoch=None + ): + data_sizes = self.get_train_dataset_sizes( + data_param_list, datasets, epoch, shard_epoch + ) + sampling_func = self.sampling_method.sampling_method_selector() + sample_ratios = sampling_func(data_sizes) if sampling_func is not None else None + return sample_ratios + + def get_sampling_ratios(self, data_param_list, datasets, epoch, shard_epoch=None): + if self.args.sampling_weights_from_file: + weights = load_sampling_weights(self.args.sampling_weights_from_file) + sample_ratios = [weights[k] for k, _ in datasets] + logger.info( + "| ignoring --sampling-weights when loading sampling weights " + f"from file {self.args.sampling_weights_from_file}" + ) + elif self.args.sampling_weights: + sample_ratios = [self.args.sampling_weights[k] for k, _ in datasets] + else: + sample_ratios = self.get_train_sampling_ratios( + data_param_list, datasets, epoch, shard_epoch + ) + + if sample_ratios is not None: + logger.info( + "| Upsample ratios: {}".format( + list(zip(map(lambda x: x["key"], data_param_list), sample_ratios)) + ) + ) + assert len(sample_ratios) == len(datasets) + return sample_ratios + + def load_split_datasets( + self, split, training, epoch=1, combine=False, shard_epoch=None, **kwargs + ): + data_param_list = self.get_split_data_param_list( + split, epoch, shard_epoch=shard_epoch + ) + langpairs_sharing_datasets = ( + {} if self.args.enable_reservsed_directions_shared_datasets else None + ) + datasets = [ + ( + param["key"], + self.load_a_dataset( + combine=combine, + langpairs_sharing_datasets=langpairs_sharing_datasets, + **param, + ), + ) + for param in data_param_list + ] + return datasets, data_param_list + + def load_into_concat_dataset(self, split, datasets, data_param_list): + if self.args.lang_tok_replacing_bos_eos: + # TODO: to investigate why TransformEosLangPairDataset doesn't work with ConcatDataset + return SampledMultiDataset( + OrderedDict(datasets), + sampling_ratios=None, + eval_key=None, + collate_format=CollateFormat.single, + virtual_size=None, + split=split, + ) + return ConcatDataset([d for _, d in datasets]) + + def load_sampled_multi_epoch_dataset( + self, split, training, epoch=0, combine=False, shard_epoch=None, **kwargs + ): + datasets, data_param_list = self.load_split_datasets( + split, training, epoch, combine, shard_epoch=shard_epoch, **kwargs + ) + if training and split == getattr(self.args, "train_subset", None): + sample_ratios = self.get_sampling_ratios(data_param_list, datasets, epoch) + return SampledMultiEpochDataset( + OrderedDict(datasets), + epoch=epoch, + shard_epoch=shard_epoch, + # valid and test datasets will degenerate to concatenated datasets: + sampling_ratios=sample_ratios, + eval_key=None, + collate_format=CollateFormat.single, + virtual_size=self.args.virtual_data_size, + split=split, + virtual_epoch_size=self.args.virtual_epoch_size, + # if not using lang_tok altering, simplified to use the same collater + shared_collater=self._shared_collater(), + ) + else: + return self.load_into_concat_dataset(split, datasets, data_param_list) + + def load_sampled_multi_dataset( + self, split, training, epoch=0, combine=False, shard_epoch=None, **kwargs + ): + datasets, data_param_list = self.load_split_datasets( + split, training, epoch, combine, shard_epoch=shard_epoch, **kwargs + ) + if training and
split == getattr(self.args, "train_subset", None): + sample_ratios = self.get_sampling_ratios(data_param_list, datasets, epoch) + return SampledMultiDataset( + OrderedDict(datasets), + epoch=epoch, + # valid and test datasets will degenerate to concatenated datasets: + sampling_ratios=sample_ratios, + eval_key=None, + collate_format=CollateFormat.single, + virtual_size=self.args.virtual_data_size, + split=split, + # if not using lang_tok altering, simplified to use the same collater + shared_collater=self._shared_collater(), + ) + else: + return self.load_into_concat_dataset(split, datasets, data_param_list) + + def load_dataset( + self, split, training, epoch=0, combine=False, shard_epoch=None, **kwargs + ): + if self.args.virtual_epoch_size is None: + return self.load_sampled_multi_dataset( + split, training, epoch, combine, shard_epoch, **kwargs + ) + else: + return self.load_sampled_multi_epoch_dataset( + split, training, epoch, combine, shard_epoch, **kwargs + ) diff --git a/fairseq/fairseq/data/multilingual/multilingual_utils.py b/fairseq/fairseq/data/multilingual/multilingual_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b4e0f9828cabfdbe375d05d9152b58bdbd6de7dc --- /dev/null +++ b/fairseq/fairseq/data/multilingual/multilingual_utils.py @@ -0,0 +1,63 @@ +from enum import Enum +from typing import Dict, List, Optional, Sequence + +import torch +from fairseq.data import Dictionary + + +class EncoderLangtok(Enum): + """ + Prepend to the beginning of source sentence either the + source or target language token. (src/tgt). + """ + + src = "src" + tgt = "tgt" + + +class LangTokSpec(Enum): + main = "main" + mono_dae = "mono_dae" + + +class LangTokStyle(Enum): + multilingual = "multilingual" + mbart = "mbart" + + +@torch.jit.export +def get_lang_tok( + lang: str, lang_tok_style: str, spec: str = LangTokSpec.main.value +) -> str: + # TOKEN_STYLES can't be defined outside this fn since it needs to be + # TorchScriptable. + TOKEN_STYLES: Dict[str, str] = { + LangTokStyle.mbart.value: "[{}]", + LangTokStyle.multilingual.value: "__{}__", + } + + if spec.endswith("dae"): + lang = f"{lang}_dae" + elif spec.endswith("mined"): + lang = f"{lang}_mined" + style = TOKEN_STYLES[lang_tok_style] + return style.format(lang) + + +def augment_dictionary( + dictionary: Dictionary, + language_list: List[str], + lang_tok_style: str, + langtoks_specs: Sequence[str] = (LangTokSpec.main.value,), + extra_data: Optional[Dict[str, str]] = None, +) -> None: + for spec in langtoks_specs: + for language in language_list: + dictionary.add_symbol( + get_lang_tok(lang=language, lang_tok_style=lang_tok_style, spec=spec) + ) + + if lang_tok_style == LangTokStyle.mbart.value or ( + extra_data is not None and LangTokSpec.mono_dae.value in extra_data + ): + dictionary.add_symbol("<mask>") diff --git a/fairseq/fairseq/data/multilingual/sampled_multi_dataset.py b/fairseq/fairseq/data/multilingual/sampled_multi_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ece9a9721e453112553d7f41755133b1c937e14e --- /dev/null +++ b/fairseq/fairseq/data/multilingual/sampled_multi_dataset.py @@ -0,0 +1,468 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree.
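+# Usage sketch (editor's illustrative note, not part of the upstream module): +# MultilingualDatasetManager above builds this dataset roughly as +# SampledMultiDataset(OrderedDict(datasets), sampling_ratios=ratios, +# virtual_size=args.virtual_data_size, split=split), +# where `datasets` maps "data_category:src-tgt" keys to LanguagePairDatasets.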
+ +import datetime +import hashlib +import logging +import time +from bisect import bisect_right +from collections import OrderedDict, defaultdict +from enum import Enum +from typing import List + +import numpy as np +import torch + +from fairseq.data import FairseqDataset, data_utils +from fairseq.distributed import utils as distributed_utils + + +def get_time_gap(s, e): + return ( + datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s) + ).__str__() + + +logger = logging.getLogger(__name__) + + +def default_virtual_size_func(datasets, ratios, max_scale_up=1.5): + sizes = [len(d) for d in datasets] + if ratios is None: + return sum(sizes) + largest_idx = np.argmax(sizes) + largest_r = ratios[largest_idx] + largest_s = sizes[largest_idx] + # set virtual sizes relative to the largest dataset + virtual_sizes = [(r / largest_r) * largest_s for r in ratios] + vsize = sum(virtual_sizes) + max_size = sum(sizes) * max_scale_up + return int(vsize if vsize < max_size else max_size) + + +class CollateFormat(Enum): + single = 1 + ordered_dict = 2 + + +class SampledMultiDataset(FairseqDataset): + """Samples from multiple sub-datasets according to given sampling ratios. + Args: + datasets ( + List[~torch.utils.data.Dataset] + or OrderedDict[str, ~torch.utils.data.Dataset] + ): datasets + sampling_ratios (List[float]): list of the sampling probability of each dataset + (default: None, which corresponds to concatenating all datasets together). + seed (int): RNG seed to use (default: 2). + epoch (int): starting epoch number (default: 1). + eval_key (str, optional): a key used at evaluation time that causes + this instance to pass-through batches from *datasets[eval_key]*. + collate_format (CollateFormat): collater output format, either CollateFormat.ordered_dict or + CollateFormat.single (default: CollateFormat.single) where CollateFormat.single configures + the collater to output batches of data mixed from all sub-datasets, + and CollateFormat.ordered_dict configures the collater to output a dictionary of batches indexed by keys + of sub-datasets. + Note that not all sub-datasets will be present in a single batch in both formats. + virtual_size (int, or callable): the expected virtual size of the dataset (default: default_virtual_size_func). + split (str): the split of the data, e.g. 'train', 'valid' or 'test'. + shared_collater (bool): whether or not all sub-datasets share the same collater. + shuffle (bool): whether or not to shuffle data (default: True).
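+ + Example (editor's illustration; en_de and en_fr stand in for any + FairseqDataset instances): + >>> ds = SampledMultiDataset( + ... OrderedDict([("en-de", en_de), ("en-fr", en_fr)]), + ... sampling_ratios=[1.0, 2.0], + ... ) + >>> len(ds) # the virtual size from default_virtual_size_func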
+ """ + + def __init__( + self, + datasets, + sampling_ratios=None, + seed=2, + epoch=1, + eval_key=None, + collate_format=CollateFormat.single, + virtual_size=default_virtual_size_func, + split="", + shared_collater=False, + shuffle=True, + ): + super().__init__() + self.shared_collater = shared_collater + self.shuffle = shuffle + + if isinstance(datasets, OrderedDict): + self.keys = list(datasets.keys()) + datasets = list(datasets.values()) + elif isinstance(datasets, List): + self.keys = list(range(len(datasets))) + else: + raise AssertionError() + self.datasets = datasets + self.split = split + + self.eval_key = eval_key + if self.eval_key is not None: + self.collate_format = CollateFormat.single + else: + self.collate_format = collate_format + + self.seed = seed + self._cur_epoch = None + + self.cumulated_sizes = None + # self.datasets[k][self._cur_indices[i]] is the data item i in this sampled dataset + # namely, data item i is sampled from the kth sub-dataset self.datasets[k] + # where self.cumulated_sizes[k-1] <= i < self.cumulated_sizes[k] + self._cur_indices = None + + self._sizes = None + self.virtual_size_per_dataset = None + # caching properties + self._reset_cached_properties() + self.setup_sampling(sampling_ratios, virtual_size) + self.set_epoch(epoch) + + def _clean_if_not_none(self, var_list): + for v in var_list: + if v is not None: + del v + + def _reset_cached_properties(self): + self._clean_if_not_none([self._sizes, self._cur_indices]) + self._sizes = None + self._cur_indices = None + + def setup_sampling(self, sample_ratios, virtual_size): + sizes = [len(d) for d in self.datasets] + if sample_ratios is None: + # default back to concatenating datasets + self.sample_ratios = None + self.virtual_size = sum(sizes) + else: + if not isinstance(sample_ratios, np.ndarray): + sample_ratios = np.array(sample_ratios) + self.sample_ratios = sample_ratios + virtual_size = ( + default_virtual_size_func if virtual_size is None else virtual_size + ) + self.virtual_size = ( + virtual_size(self.datasets, self.sample_ratios) + if callable(virtual_size) + else virtual_size + ) + + def adjust_sampling(self, epoch, sampling_ratios, virtual_size): + if sampling_ratios is not None: + sampling_ratios = self._sync_sample_ratios(sampling_ratios) + self.setup_sampling(sampling_ratios, virtual_size) + + def _sync_sample_ratios(self, ratios): + # in case the ratios are not precisely the same across processes + # and to ensure every process updates the ratios at the same pace + ratios = torch.DoubleTensor(ratios) + if torch.distributed.is_initialized(): + if torch.cuda.is_available(): + distributed_utils.all_reduce( + ratios.cuda(), group=distributed_utils.get_data_parallel_group() + ) + else: + distributed_utils.all_reduce( + ratios, group=distributed_utils.get_data_parallel_group() + ) + ret = ratios.cpu() + ret = ret.numpy() + return ret + + def random_choice_in_dataset(self, rng, dataset, choice_size): + if hasattr(dataset, "random_choice_in_dataset"): + return dataset.random_choice_in_dataset(rng, choice_size) + dataset_size = len(dataset) + return rng.choice( + dataset_size, choice_size, replace=(choice_size > dataset_size) + ) + + def get_virtual_indices(self, rng, datasets, sample_ratios, virtual_size): + def get_counts(sample_ratios): + counts = np.array([virtual_size * r for r in sample_ratios], dtype=np.int64) + diff = virtual_size - counts.sum() + assert diff >= 0 + # due to round-offs, the size might not match the desired sizes + if diff > 0: + dataset_indices = rng.choice( +
len(sample_ratios), size=diff, p=sample_ratios + ) + for i in dataset_indices: + counts[i] += 1 + return counts + + def get_in_dataset_indices(datasets, sizes, sample_ratios): + counts = get_counts(sample_ratios) + # uniformly sample desired counts for each dataset + # if the desired counts are large, sample with replacement: + indices = [ + self.random_choice_in_dataset(rng, d, c) + for c, d in zip(counts, datasets) + ] + return indices + + sizes = [len(d) for d in datasets] + if sample_ratios is None: + # default back to concatenating datasets + in_dataset_indices = [list(range(s)) for s in sizes] + virtual_sizes_per_dataset = sizes + else: + ratios = sample_ratios / sample_ratios.sum() + in_dataset_indices = get_in_dataset_indices(datasets, sizes, ratios) + virtual_sizes_per_dataset = [len(d) for d in in_dataset_indices] + virtual_sizes_per_dataset = np.array(virtual_sizes_per_dataset, np.int64) + cumulative_sizes = np.cumsum(virtual_sizes_per_dataset) + assert sum(virtual_sizes_per_dataset) == virtual_size + assert cumulative_sizes[-1] == virtual_size + if virtual_size < sum(sizes): + logger.warning( + f"virtual data size ({virtual_size}) is less than real data size ({sum(sizes)})." + " If virtual size << real data size, there could be a data coverage issue." + ) + in_dataset_indices = np.hstack(in_dataset_indices) + return in_dataset_indices, cumulative_sizes, virtual_sizes_per_dataset + + def _get_dataset_and_index(self, index): + i = bisect_right(self.cumulated_sizes, index) + return i, self._cur_indices[index] + + def __getitem__(self, index): + # self.__getitem__(index) returns self.datasets[k][self._cur_indices[index]] + # where k satisfies self.cumulated_sizes[k - 1] <= index < self.cumulated_sizes[k] + ds_idx, ds_sample_idx = self._get_dataset_and_index(index) + ret = (ds_idx, self.datasets[ds_idx][ds_sample_idx]) + return ret + + def num_tokens(self, index): + return self.sizes[index].max() + + def num_tokens_vec(self, indices): + sizes_vec = self.sizes[np.array(indices)] + # max across all dimensions but first one + return np.amax(sizes_vec, axis=tuple(range(1, len(sizes_vec.shape)))) + + def size(self, index): + return self.sizes[index] + + def __len__(self): + return self.virtual_size + + def collater(self, samples, **extra_args): + """Merge a list of samples to form a mini-batch.""" + if len(samples) == 0: + return None + if self.collate_format == "ordered_dict": + collect_samples = [[] for _ in range(len(self.datasets))] + for (i, sample) in samples: + collect_samples[i].append(sample) + batch = OrderedDict( + [ + (self.keys[i], dataset.collater(collect_samples[i])) + for i, (key, dataset) in enumerate(zip(self.keys, self.datasets)) + if len(collect_samples[i]) > 0 + ] + ) + elif self.shared_collater: + batch = self.datasets[0].collater([s for _, s in samples]) + else: + samples_dict = defaultdict(list) + pad_to_length = ( + defaultdict(int) + if "pad_to_length" not in extra_args + else extra_args["pad_to_length"] + ) + for ds_idx, s in samples: + pad_to_length["source"] = max( + pad_to_length["source"], s["source"].size(0) + ) + if s["target"] is not None: + pad_to_length["target"] = max( + pad_to_length["target"], s["target"].size(0) + ) + samples_dict[ds_idx].append(s) + batches = [ + self.datasets[i].collater(samples_dict[i], pad_to_length=pad_to_length) + for i in range(len(self.datasets)) + if len(samples_dict[i]) > 0 + ] + + def straight_data(tensors): + batch = torch.cat(tensors, dim=0) + return batch + + src_lengths = straight_data( + [b["net_input"]["src_lengths"] for b
in batches] + ) + src_lengths, sort_order = src_lengths.sort(descending=True) + + def straight_order(tensors): + batch = straight_data(tensors) + return batch.index_select(0, sort_order) + + batch = { + "id": straight_order([b["id"] for b in batches]), + "nsentences": sum(b["nsentences"] for b in batches), + "ntokens": sum(b["ntokens"] for b in batches), + "net_input": { + "src_tokens": straight_order( + [b["net_input"]["src_tokens"] for b in batches] + ), + "src_lengths": src_lengths, + }, + "target": straight_order([b["target"] for b in batches]) + if batches[0]["target"] is not None + else None, + } + if "prev_output_tokens" in batches[0]["net_input"]: + batch["net_input"]["prev_output_tokens"] = straight_order( + [b["net_input"]["prev_output_tokens"] for b in batches] + ) + if "src_lang_id" in batches[0]["net_input"]: + batch["net_input"]["src_lang_id"] = straight_order( + [b["net_input"]["src_lang_id"] for b in batches] + ) + if "tgt_lang_id" in batches[0]: + batch["tgt_lang_id"] = straight_order( + [b["tgt_lang_id"] for b in batches] + ) + return batch + + @property + def sizes(self): + if self._sizes is not None: + return self._sizes + start_time = time.time() + in_sub_dataset_indices = [ + self._cur_indices[ + 0 if i == 0 else self.cumulated_sizes[i - 1] : self.cumulated_sizes[i] + ] + for i in range(len(self.datasets)) + ] + sub_dataset_sizes = [ + d.sizes[indices] + for d, indices in zip(self.datasets, in_sub_dataset_indices) + ] + self._sizes = np.vstack(sub_dataset_sizes) + logger.info(f"sizes() calling time: {get_time_gap(start_time, time.time())}") + return self._sizes + + def ordered_indices(self): + if self.shuffle: + indices = np.random.permutation(len(self)) + else: + indices = np.arange(len(self)) + + sizes = self.sizes + tgt_sizes = sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None + src_sizes = ( + sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes + ) + + # sort by target length, then source length + if tgt_sizes is not None: + indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")] + sort_indices = indices[np.argsort(src_sizes[indices], kind="mergesort")] + return sort_indices + + def prefetch(self, indices): + prefetch_indices = [[] for _ in range(len(self.datasets))] + for i in indices: + ds_idx, ds_sample_idx = self._get_dataset_and_index(i) + prefetch_indices[ds_idx].append(ds_sample_idx) + for i in range(len(prefetch_indices)): + self.datasets[i].prefetch(prefetch_indices[i]) + + @property + def can_reuse_epoch_itr_across_epochs(self): + return False + + def set_epoch(self, epoch): + super().set_epoch(epoch) + if epoch == self._cur_epoch: + # re-entering the same epoch, so return + return + for d in self.datasets: + if hasattr(d, "set_epoch"): + d.set_epoch(epoch) + self._cur_epoch = epoch + self._establish_virtual_datasets() + + def _establish_virtual_datasets(self): + if self.sample_ratios is None and self._cur_indices is not None: + # not a sampling dataset, no need to resample if indices are already established + return + self._reset_cached_properties() + + start_time = time.time() + # Generate a weighted sample of indices as a function of the + # random seed and the current epoch.
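+ # The seed sequence below mixes a hash of the class name, the global + # seed, and the current epoch, so every worker derives the same + # permutation for a given epoch while different epochs resample + # independently.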
+ rng = np.random.RandomState( + [ + int( + hashlib.sha1( + str(self.__class__.__name__).encode("utf-8") + ).hexdigest(), + 16, + ) + % (2**32), + self.seed % (2**32), # global seed + self._cur_epoch, # epoch index, + ] + ) + self._clean_if_not_none( + [self.cumulated_sizes, self.virtual_size_per_dataset, self._sizes] + ) + self._sizes = None + + indices, cumulated_sizes, virtual_size_per_dataset = self.get_virtual_indices( + rng, self.datasets, self.sample_ratios, self.virtual_size + ) + self._cur_indices = indices + self.cumulated_sizes = cumulated_sizes + self.virtual_size_per_dataset = virtual_size_per_dataset + + raw_sizes = [len(d) for d in self.datasets] + sampled_sizes = self.virtual_size_per_dataset + logger.info( + f"[{self.split}] Raw sizes: {str(dict(zip(self.keys, raw_sizes)))}; " + f"raw total size: {sum(raw_sizes)}" + ) + logger.info( + f"[{self.split}] Resampled sizes: {str(dict(zip(self.keys, sampled_sizes)))}; " + f"resampled total size: {sum(sampled_sizes)}" + ) + if self.sample_ratios is not None: + logger.info( + f"[{self.split}] Upsampling ratios: {str(dict(zip(self.keys, self.sample_ratios)))}" + ) + else: + logger.info(f"[{self.split}] A concat dataset") + logger.info( + f"[{self.split}] virtual dataset established time: {get_time_gap(start_time, time.time())}" + ) + + def filter_indices_by_size(self, indices, max_sizes): + """Filter a list of sample indices. Remove those that are longer + than specified in max_sizes. + + Args: + indices (np.array): original array of sample indices + max_sizes (int or list[int] or tuple[int]): max sample size, + can be defined separately for src and tgt (then list or tuple) + + Returns: + np.array: filtered sample array + list: list of removed indices + """ + sizes = self.sizes + tgt_sizes = sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None + src_sizes = ( + sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes + ) + + return data_utils.filter_paired_dataset_indices_by_size( + src_sizes, tgt_sizes, indices, max_sizes + ) diff --git a/fairseq/fairseq/data/multilingual/sampled_multi_epoch_dataset.py b/fairseq/fairseq/data/multilingual/sampled_multi_epoch_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..bb187a8dc28c7647fe93cd4ba3d26f5a892ca7fd --- /dev/null +++ b/fairseq/fairseq/data/multilingual/sampled_multi_epoch_dataset.py @@ -0,0 +1,199 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import hashlib +import logging +import math + +import numpy as np + +from fairseq.data import SampledMultiDataset + +from .sampled_multi_dataset import CollateFormat, default_virtual_size_func + +logger = logging.getLogger(__name__) + + +class SampledMultiEpochDataset(SampledMultiDataset): + """Samples from multiple sub-datasets according to sampling ratios + using virtual epoch sizes to speed up dataloading. + Args: + datasets ( + List[~torch.utils.data.Dataset] + or OrderedDict[str, ~torch.utils.data.Dataset] + ): datasets + sampling_ratios (List[float]): list of the sampling probability of each dataset + (default: None, which corresponds to concatenating all datasets together). + seed (int): RNG seed to use (default: 2). + epoch (int): starting epoch number (default: 1). + eval_key (str, optional): a key used at evaluation time that causes + this instance to pass-through batches from *datasets[eval_key]*.
+ collate_format (CollateFormat): collater output format, either CollateFormat.ordered_dict or + CollateFormat.single (default: CollateFormat.single) where CollateFormat.single configures + the collater to output batches of data mixed from all sub-datasets, + and CollateFormat.ordered_dict configures the collater to output a dictionary of batches indexed by keys + of sub-datasets. + Note that not all sub-datasets will be present in a single batch in both formats. + virtual_size (int, or callable): the expected virtual size of the dataset (default: default_virtual_size_func). + split (str): the split of the data, e.g. 'train', 'valid' or 'test'. + virtual_epoch_size (int): virtual epoch size; the dataset iterates through the data in chunks of + this virtual epoch size to speed up data loading, e.g. indexing and filtering + can be performed whenever a virtual epoch is loaded, without waiting for the whole dataset to be loaded. + shared_collater (bool): whether or not all sub-datasets share the same collater. + shard_epoch (int): the real epoch number for shard selection. + shuffle (bool): whether or not to shuffle data (default: True). + """ + + def __init__( + self, + datasets, + sampling_ratios=None, + seed=2, + epoch=1, + eval_key=None, + collate_format=CollateFormat.single, + virtual_size=default_virtual_size_func, + split="", + virtual_epoch_size=None, + shared_collater=False, + shard_epoch=1, + shuffle=True, + ): + self.virtual_epoch_size = virtual_epoch_size + self._current_epoch_start_index = None + self._random_global_indices = None + self.shard_epoch = shard_epoch if shard_epoch is not None else 1 + self.load_next_shard = None + self._epoch_sizes = None + super().__init__( + datasets=datasets, + sampling_ratios=sampling_ratios, + seed=seed, + epoch=epoch, + eval_key=eval_key, + collate_format=collate_format, + virtual_size=virtual_size, + split=split, + shared_collater=shared_collater, + shuffle=shuffle, + ) + + def _setup(self, epoch): + self.virtual_epoch_size = ( + self.virtual_epoch_size + if self.virtual_epoch_size is not None + else self.virtual_size + ) + if self.virtual_epoch_size > self.virtual_size: + logger.warning( + f"virtual epoch size {self.virtual_epoch_size} " + f"is greater than virtual dataset size {self.virtual_size}" + ) + self.virtual_epoch_size = self.virtual_size + self.num_virtual_epochs = math.ceil(self.virtual_size / self.virtual_epoch_size) + self._current_epoch_start_index = self._get_epoch_start_index(epoch) + logger.info( + f"virtual epoch size {self.virtual_epoch_size}; virtual dataset size {self.virtual_size}" + ) + + def _map_epoch_index_to_global(self, index): + index = self._current_epoch_start_index + index + # add randomness + return self._random_global_indices[index] + + @property + def sizes(self): + if self._epoch_sizes is not None: + return self._epoch_sizes + _sizes = super().sizes + indices = self._random_global_indices[ + self._current_epoch_start_index : self._current_epoch_start_index + + len(self) + ] + self._epoch_sizes = _sizes[indices] + # del super()._sizes to save memory + del self._sizes + self._sizes = None + return self._epoch_sizes + + def _get_dataset_and_index(self, index): + i = self._map_epoch_index_to_global(index) + return super()._get_dataset_and_index(i) + + def __len__(self): + return ( + self.virtual_epoch_size + if self._current_epoch_start_index + self.virtual_epoch_size + < self.virtual_size + else self.virtual_size - self._current_epoch_start_index + ) + + def set_epoch(self, epoch): + + if
+    def set_epoch(self, epoch):
+        if self._current_epoch_start_index is None:
+            # initializing the epoch indices of a virtual dataset
+            self._setup(epoch)
+            self._next_virtual_epoch(epoch)
+        else:
+            # working on already initialized epoch indices
+            if epoch == self._cur_epoch:
+                # re-entered with the same epoch, so just return
+                return
+            self._next_virtual_epoch(epoch)
+
+    def _get_epoch_start_index(self, epoch):
+        assert epoch >= 1  # fairseq uses 1-based epochs everywhere
+        return ((epoch - 1) % self.num_virtual_epochs) * self.virtual_epoch_size
+
+    def _next_global_indices(self, epoch):
+        rng = np.random.RandomState(
+            [
+                int(
+                    hashlib.sha1(
+                        str(self.__class__.__name__).encode("utf-8")
+                    ).hexdigest(),
+                    16,
+                )
+                % (2**32),
+                self.seed % (2**32),  # global seed
+                epoch,  # epoch index
+            ]
+        )
+        del self._random_global_indices
+        self._random_global_indices = rng.choice(
+            self.virtual_size, self.virtual_size, replace=False
+        )
+        if self.load_next_shard is None:
+            self.load_next_shard = False
+        else:
+            # increase the shard epoch for the next loading
+            self.shard_epoch += 1
+            self.load_next_shard = True
+            logger.info(
+                "to load next epoch/shard in next load_dataset: "
+                f"epoch={epoch}/shard_epoch={self.shard_epoch}"
+            )
+
+    def _next_virtual_epoch(self, epoch):
+        index = self._get_epoch_start_index(epoch)
+        if index == 0 or self._random_global_indices is None:
+            # need to start from the beginning,
+            # so call super().set_epoch(epoch) to establish the global virtual indices
+            logger.info(
+                "establishing a new set of global virtual indices for "
+                f"epoch={epoch}/shard_epoch={self.shard_epoch}"
+            )
+            super().set_epoch(epoch)
+            self._next_global_indices(epoch)
+        else:
+            self._cur_epoch = epoch
+
+        # reset cached sizes and ordered_indices after moving to a new epoch
+        self._clean_if_not_none(
+            [
+                self._epoch_sizes,
+            ]
+        )
+        self._epoch_sizes = None
+        self._current_epoch_start_index = index
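A standalone sketch of the reshuffling recipe in `_next_global_indices` above: seeding `RandomState` with a list built from a hash of the class name, the global seed, and the epoch index means every process derives the same full permutation of the virtual indices for a given epoch. The wrapper function and values below are illustrative only.

```python
import hashlib
import numpy as np

def epoch_permutation(class_name, seed, epoch, virtual_size):
    # same seeding recipe as in the file: class-name hash, global seed, epoch
    rng = np.random.RandomState(
        [
            int(hashlib.sha1(class_name.encode("utf-8")).hexdigest(), 16) % (2**32),
            seed % (2**32),
            epoch,
        ]
    )
    # a full random permutation of [0, virtual_size)
    return rng.choice(virtual_size, virtual_size, replace=False)

p1 = epoch_permutation("SampledMultiEpochDataset", seed=2, epoch=1, virtual_size=8)
p2 = epoch_permutation("SampledMultiEpochDataset", seed=2, epoch=1, virtual_size=8)
assert (p1 == p2).all()  # deterministic: same seed inputs, same permutation
```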
diff --git a/fairseq/fairseq/data/multilingual/sampling_method.py b/fairseq/fairseq/data/multilingual/sampling_method.py
new file mode 100644
index 0000000000000000000000000000000000000000..140c68f01d60e902ef88f11f30f8813dc15fc681
--- /dev/null
+++ b/fairseq/fairseq/data/multilingual/sampling_method.py
@@ -0,0 +1,78 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from typing import List
+
+
+logger = logging.getLogger(__name__)
+
+
+def uniform(dataset_sizes: List[int]):
+    return [1.0] * len(dataset_sizes)
+
+
+def temperature_sampling(dataset_sizes, temp):
+    total_size = sum(dataset_sizes)
+    return [(size / total_size) ** (1.0 / temp) for size in dataset_sizes]
+
+
+def make_temperature_sampling(temp=1.0):
+    def sampling_func(dataset_sizes):
+        return temperature_sampling(dataset_sizes, temp)
+
+    return sampling_func
+
+
+def make_ratio_sampling(ratios):
+    def sampling_func(dataset_sizes):
+        return ratios
+
+    return sampling_func
+
+
+class SamplingMethod:
+    @staticmethod
+    def add_arguments(parser):
+        parser.add_argument(
+            "--sampling-method",
+            choices=[
+                "uniform",
+                "temperature",
+                "concat",
+                "RoundRobin",
+            ],
+            type=str,
+            default="concat",
+            help="the method to sample data per language pair",
+        )
+        parser.add_argument(
+            "--sampling-temperature",
+            default=1.5,
+            type=float,
+            help="only works with --sampling-method temperature",
+        )
+
+    @staticmethod
+    def build_sampler(args, task):
+        return SamplingMethod(args, task)
+
+    def __init__(self, args, task):
+        self.args = args
+        self.task = task
+
+    def is_adaptive(self):
+        return False
+
+    def sampling_method_selector(self):
+        args = self.args
+        logger.info(f"selected sampler: {args.sampling_method}")
+        if args.sampling_method == "uniform":
+            return uniform
+        elif args.sampling_method == "temperature" or self.is_adaptive():
+            return make_temperature_sampling(float(args.sampling_temperature))
+        else:
+            # default to concatenating all datasets together
+            return None
diff --git a/fairseq/fairseq/libnat_cuda.cpython-310-x86_64-linux-gnu.so b/fairseq/fairseq/libnat_cuda.cpython-310-x86_64-linux-gnu.so
new file mode 100644
index 0000000000000000000000000000000000000000..29681fa71a77b40638f051ae8d77bd4a9845853b
--- /dev/null
+++ b/fairseq/fairseq/libnat_cuda.cpython-310-x86_64-linux-gnu.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bf5c5a79b2d76f0e492935b1574f8d570b2b7a8c607415172b0d326478cc1c18
+size 1559392
diff --git a/fairseq/fairseq/ngram_repeat_block_cuda.cpython-310-x86_64-linux-gnu.so b/fairseq/fairseq/ngram_repeat_block_cuda.cpython-310-x86_64-linux-gnu.so
new file mode 100644
index 0000000000000000000000000000000000000000..2e3c85fe3f5701fa4625f60ce719639b12a1e352
--- /dev/null
+++ b/fairseq/fairseq/ngram_repeat_block_cuda.cpython-310-x86_64-linux-gnu.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd79aa55743c6a6c664461ccc6704f7474fb9f79a43abb5283e57f0083cbb277
+size 1023360
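A worked example of `temperature_sampling` from sampling_method.py above, with made-up dataset sizes: at T=1 the (unnormalized) weights stay proportional to the data, and raising T flattens the distribution toward uniform, upsampling low-resource pairs. Normalization is left to the caller.

```python
def temperature_sampling(dataset_sizes, temp):
    total_size = sum(dataset_sizes)
    return [(size / total_size) ** (1.0 / temp) for size in dataset_sizes]

sizes = [900, 100]                 # a high- and a low-resource language pair
for t in (1.0, 1.5, 5.0):
    w = temperature_sampling(sizes, t)
    norm = [x / sum(w) for x in w]  # normalize for comparison
    print(t, [round(x, 3) for x in norm])
# 1.0 -> [0.9, 0.1]      proportional to data size
# 1.5 -> [0.812, 0.188]  low-resource pair upsampled (the CLI default)
# 5.0 -> [0.608, 0.392]  closer to uniform
```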