diff --git a/.gitattributes b/.gitattributes
index 4be18050261ccaa46c79adbc7a0add5da290706a..34cd0b7d9ed3676541c1cfaa61e9c4e7257e7b42 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -42,3 +42,6 @@ fairseq/examples/hubert/tests/6313-76958-0021.flac filter=lfs diff=lfs merge=lfs
fairseq/examples/textless_nlp/speech-resynth/img/fig.png filter=lfs diff=lfs merge=lfs -text
fairseq/fairseq/libbase.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
fairseq/fairseq/libnat.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+fairseq/fairseq/ngram_repeat_block_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+fairseq/fairseq/libnat_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+fairseq/fairseq/data/data_utils_fast.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
diff --git a/fairseq/fairseq/data/__pycache__/__init__.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d3f79e390b40761694b429cc7b7e5aab72dd6f3b
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/add_target_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/add_target_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c0193961835ccdcee50bf89844e9c489a8c19315
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/add_target_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/append_token_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/append_token_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c9efbb8d303d8f0f42212648c2a19bc6556e2f81
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/append_token_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/backtranslation_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/backtranslation_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e99c6ecaa6912191f55ac11f05b636372f8c9f90
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/backtranslation_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/base_wrapper_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/base_wrapper_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..34a930ef86408243443fe0eb8e3af58f8e78691a
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/base_wrapper_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/bucket_pad_length_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/bucket_pad_length_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7e1033d19b243dd9349421944fa71947e2f7eaf2
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/bucket_pad_length_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/codedataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/codedataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..32f71e6aa568e9282950bb0b65d0b8d8ba6f269b
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/codedataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/colorize_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/colorize_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ff88a067615e613d5e574c86a365dc6283d4e0f4
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/colorize_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/concat_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/concat_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..825cb4cffde567aac7fe80a1b192202eafd56ea1
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/concat_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/concat_sentences_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/concat_sentences_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..18edb3821e9735c4962c3e6fd8154e2fe30fba49
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/concat_sentences_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/data_utils.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/data_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d0b7c27ce9f76f47968f796351c5ca2f6ec21115
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/data_utils.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/denoising_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/denoising_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e27aa94a0db24f2d61889ac3e14ddad04d8bd6ec
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/denoising_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/dictionary.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/dictionary.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5deaca93207816e04764f6e5e7a7e530557b299a
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/dictionary.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/fairseq_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/fairseq_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1c1fcf622c07f6c0a1f0e7f0d551cc59d9ecbd01
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/fairseq_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/fasta_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/fasta_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b73da5c06e5402ee011f1ce2e2fc707d1fb6440e
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/fasta_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/id_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/id_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..39df62a735de8d04539061cd65532f94a18dcfb4
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/id_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/indexed_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/indexed_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..28af08664a96b6134c7ca775b40c49495355d5eb
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/indexed_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/language_pair_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/language_pair_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6708004f45a467c0fc717308b6df3b0d1a4a3879
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/language_pair_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/list_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/list_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..83954067f5ce4f90a178e262648dcc7550e813b7
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/list_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/lm_context_window_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/lm_context_window_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a4c3e9f772f636c9bc578e181475c5bec1c4656c
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/lm_context_window_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/lru_cache_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/lru_cache_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..39ba35c8bf575abdd944111dfd425316d820cfb5
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/lru_cache_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/mask_tokens_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/mask_tokens_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d09ba0baa8874a98ea0abbaac114ce51b83ce961
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/mask_tokens_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/monolingual_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/monolingual_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b4245cf8251e5cc174217c58f558854865ba2e8f
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/monolingual_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/multi_corpus_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/multi_corpus_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2609c85b0c2e6f547f24522f9a0498b378cbd5a8
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/multi_corpus_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/multi_corpus_sampled_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/multi_corpus_sampled_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2d45117e5ac65addb2c4fea9f183813fa59e5693
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/multi_corpus_sampled_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/nested_dictionary_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/nested_dictionary_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9618fcfa6f2f7a8df52c59a5d2dac31703c40444
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/nested_dictionary_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/noising.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/noising.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3368672907d3f4bcf19f676da106053b741bfbc5
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/noising.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/num_samples_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/num_samples_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9ebb7fcf3c8ded52c417b69f7180c2b6c76be306
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/num_samples_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/padding_mask_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/padding_mask_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e66fb3128ae04766e42ad66b3d47309ea4b28e1
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/padding_mask_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/plasma_utils.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/plasma_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a58dc46778269054f992a75a25a6495eaeceadc0
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/plasma_utils.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/prepend_token_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/prepend_token_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..900411750874d4b7c0d82a89b638934cd7a371c0
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/prepend_token_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/raw_label_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/raw_label_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..06a5c9cb8cdf0e2b8019f744c393974ba84098e3
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/raw_label_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/replace_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/replace_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3462fa66c8897772673546b2d9de67b79b29ce3
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/replace_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/resampling_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/resampling_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..81484571de4803e069e0335efa426fb22a98f7a7
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/resampling_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/roll_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/roll_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7ef0d82d31a6c956174193da3b5fdd76750906a1
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/roll_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/sort_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/sort_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..735a5a0be8f16633fce9fd6b2874059976f1174f
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/sort_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/span_mask_tokens_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/span_mask_tokens_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..afa65c98a965bd5afa1fdd9caa4a4c7214a660a7
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/span_mask_tokens_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/speech_dlm_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/speech_dlm_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bfe6357fb4751494229079fcafeaa70509acd5c1
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/speech_dlm_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/strip_token_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/strip_token_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..99d9e15d2f7ff849f9659db9e97a48b98b4b8683
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/strip_token_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/subsample_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/subsample_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6a796a1360badcd4532abee5e50cc830c0c7f2ec
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/subsample_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/text_compressor.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/text_compressor.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fe691aa804d2a26aeff8e90d0450933edaaa4695
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/text_compressor.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/token_block_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/token_block_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..188380cdbbce1617f6f509f0e3b1f2d429c73dc4
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/token_block_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/__pycache__/transform_eos_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/transform_eos_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..094b9eb385cf6c5f08f87ad7e572c2ca3cb05770
Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/transform_eos_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/data_utils_fast.cpython-310-x86_64-linux-gnu.so b/fairseq/fairseq/data/data_utils_fast.cpython-310-x86_64-linux-gnu.so
new file mode 100644
index 0000000000000000000000000000000000000000..762839f2f20143c2d201dc8ec82def1464147bb4
--- /dev/null
+++ b/fairseq/fairseq/data/data_utils_fast.cpython-310-x86_64-linux-gnu.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:382bd1943d43d4d607db5910ce6671e50cbbb6a6fdd8a38b922c9d8fa379efc9
+size 267952
diff --git a/fairseq/fairseq/data/encoders/__pycache__/gpt2_bpe_utils.cpython-310.pyc b/fairseq/fairseq/data/encoders/__pycache__/gpt2_bpe_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b27817fbeeec0c1a00628c36f2f13173f9b08316
Binary files /dev/null and b/fairseq/fairseq/data/encoders/__pycache__/gpt2_bpe_utils.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/encoders/__pycache__/hf_byte_bpe.cpython-310.pyc b/fairseq/fairseq/data/encoders/__pycache__/hf_byte_bpe.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5c27e0dec9cd96b8f3492b117e49e3634b6daecf
Binary files /dev/null and b/fairseq/fairseq/data/encoders/__pycache__/hf_byte_bpe.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/encoders/__pycache__/nltk_tokenizer.cpython-310.pyc b/fairseq/fairseq/data/encoders/__pycache__/nltk_tokenizer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..70a2ea4c11b75d23c607ac6e70c3f8999f2fa572
Binary files /dev/null and b/fairseq/fairseq/data/encoders/__pycache__/nltk_tokenizer.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/encoders/__pycache__/space_tokenizer.cpython-310.pyc b/fairseq/fairseq/data/encoders/__pycache__/space_tokenizer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..19c0e4e7b7a1f0550da8e76d8846bcd3db8628fd
Binary files /dev/null and b/fairseq/fairseq/data/encoders/__pycache__/space_tokenizer.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/encoders/__pycache__/subword_nmt_bpe.cpython-310.pyc b/fairseq/fairseq/data/encoders/__pycache__/subword_nmt_bpe.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..affafe0cc58ef2249fe37c41939fc8908fa9d2c6
Binary files /dev/null and b/fairseq/fairseq/data/encoders/__pycache__/subword_nmt_bpe.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/encoders/__pycache__/utils.cpython-310.pyc b/fairseq/fairseq/data/encoders/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b889a92387d1a98d70a06be059e48d984dc14c2
Binary files /dev/null and b/fairseq/fairseq/data/encoders/__pycache__/utils.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/huffman/__init__.py b/fairseq/fairseq/data/huffman/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b61fafadba28f65fe78a28b2099368b83cfcf41
--- /dev/null
+++ b/fairseq/fairseq/data/huffman/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .huffman_coder import HuffmanCodeBuilder, HuffmanCoder
+from .huffman_mmap_indexed_dataset import (
+ HuffmanMMapIndex,
+ HuffmanMMapIndexedDataset,
+ HuffmanMMapIndexedDatasetBuilder,
+ vocab_file_path,
+)
+
+__all__ = [
+ "HuffmanCoder",
+ "HuffmanCodeBuilder",
+ "HuffmanMMapIndexedDatasetBuilder",
+ "HuffmanMMapIndexedDataset",
+ "HuffmanMMapIndex",
+ "vocab_file_path",
+]
diff --git a/fairseq/fairseq/data/huffman/__pycache__/__init__.cpython-310.pyc b/fairseq/fairseq/data/huffman/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a17437713c05ff2184b755e12c5dd39a4b2d4591
Binary files /dev/null and b/fairseq/fairseq/data/huffman/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/huffman/__pycache__/huffman_coder.cpython-310.pyc b/fairseq/fairseq/data/huffman/__pycache__/huffman_coder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..67cfb0c92fc1591c370cfd9be5ef1fa3a0ead460
Binary files /dev/null and b/fairseq/fairseq/data/huffman/__pycache__/huffman_coder.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/huffman/__pycache__/huffman_mmap_indexed_dataset.cpython-310.pyc b/fairseq/fairseq/data/huffman/__pycache__/huffman_mmap_indexed_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c8b550e7481bbf5bd637262e872ee6c1d06736c8
Binary files /dev/null and b/fairseq/fairseq/data/huffman/__pycache__/huffman_mmap_indexed_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/huffman/huffman_coder.py b/fairseq/fairseq/data/huffman/huffman_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c04f84564e6a22209439c67fed3cac31f010c6e9
--- /dev/null
+++ b/fairseq/fairseq/data/huffman/huffman_coder.py
@@ -0,0 +1,267 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import re
+import typing as tp
+from collections import Counter, deque
+from dataclasses import dataclass
+
+from bitarray import bitarray, util
+from fairseq.data import Dictionary
+
+# basically we have to write to addressable bytes for the memory mapped
+# dataset loader. Sentences that get encoded to a length that is not a
+# multiple of BLOCKSIZE (a byte) will be padded to fit. (see _pad in the coder)
+BLOCKSIZE = 8
+
+
+class HuffmanCoder:
+ def __init__(
+        self, root: "HuffmanNode", bos="<s>", pad="<pad>", eos="</s>", unk="<unk>"
+ ):
+ self.root = root
+ self.table = root.code_table()
+ self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
+
+ def _pad(self, a: bitarray) -> bitarray:
+ """
+ bitpadding, 1 then 0.
+
+ If the array is already a multiple of blocksize, we add a full block.
+ """
+ pad_len = BLOCKSIZE - (len(a) % BLOCKSIZE) - 1
+ padding = bitarray("1" + "0" * pad_len)
+ return a + padding
+
+ def _unpad(self, a: bitarray) -> bitarray:
+ """
+ remove the bitpadding.
+
+        There will be a set of 0s preceded by a 1 at the end of the bitarray; we remove that padding.
+ """
+ # count the 0 padding at the end until we find the first 1
+ # we want to remove the one too
+ remove_cnt = util.rindex(a, 1)
+ return a[:remove_cnt]
+
+ def encode(self, iter: tp.List[str]) -> bytes:
+ """
+        encode a list of tokens and return bytes. We use bitpadding to make sure the encoded bits fit in bytes.
+ """
+ a = bitarray()
+ for token in iter:
+ code = self.get_code(token)
+ if code is None:
+ if self.unk_word is None:
+ raise Exception(f"unknown token {token} cannot be encoded.")
+ else:
+ token = self.unk_word
+ a = a + self.get_code(token)
+ return self._pad(a).tobytes()
+
+ def decode(self, bits: bytes) -> tp.Iterator["HuffmanNode"]:
+ """
+        take bitpadded bytes and decode them to a sequence of leaves. You can then use each node to find the symbol/id
+ """
+ a = bitarray()
+ a.frombytes(bits)
+ return self.root.decode(self._unpad(a))
+
+ def get_code(self, symbol: str) -> tp.Optional[bitarray]:
+ node = self.get_node(symbol)
+ return None if node is None else node.code
+
+ def get_node(self, symbol: str) -> "HuffmanNode":
+ return self.table.get(symbol)
+
+ @classmethod
+ def from_file(
+ cls,
+ filename: str,
+ bos="",
+ pad="",
+ eos="",
+ unk="",
+ ) -> "HuffmanCoder":
+ builder = HuffmanCodeBuilder.from_file(filename)
+ return builder.build_code(bos=bos, pad=pad, eos=eos, unk=unk)
+
+ def to_file(self, filename, sep="\t"):
+ nodes = list(self.table.values())
+ nodes.sort(key=lambda n: n.id)
+ with open(filename, "w", encoding="utf-8") as output:
+ for n in nodes:
+ output.write(f"{n.symbol}{sep}{n.count}\n")
+
+ def __iter__(self):
+ for n in self.table.values():
+ yield n
+
+ def merge(self, other_coder: "HuffmanCoder") -> "HuffmanCoder":
+ builder = HuffmanCodeBuilder()
+ for n in self:
+ builder.increment(n.symbol, n.count)
+ for n in other_coder:
+ builder.increment(n.symbol, n.count)
+ return builder.build_code()
+
+ def __eq__(self, other: "HuffmanCoder") -> bool:
+ return self.table == other.table
+
+ def __len__(self) -> int:
+ return len(self.table)
+
+ def __contains__(self, sym: str) -> bool:
+ return sym in self.table
+
+ def to_dictionary(self) -> Dictionary:
+        dictionary = Dictionary(
+            bos=self.bos_word, unk=self.unk_word, pad=self.pad_word, eos=self.eos_word
+        )
+ for n in self:
+ dictionary.add_symbol(n.symbol, n=n.count)
+ dictionary.finalize()
+ return dictionary
+
+
+@dataclass
+class HuffmanNode:
+ """
+ a node in a Huffman tree
+ """
+
+ id: int
+ count: int
+ symbol: tp.Optional[str] = None
+ left: tp.Optional["HuffmanNode"] = None
+ right: tp.Optional["HuffmanNode"] = None
+ code: tp.Optional[bitarray] = None
+
+ def is_leaf(self) -> bool:
+ return self.left is None and self.right is None
+
+ def code_table(
+ self, prefix: tp.Optional[bitarray] = None
+ ) -> tp.Dict[str, "HuffmanNode"]:
+ defaulted_prefix = prefix if prefix is not None else bitarray()
+ if self.is_leaf():
+ self.code = (
+ defaulted_prefix if len(defaulted_prefix) > 0 else bitarray("0")
+ ) # leaf could be the root if there is only one symbol
+ return {self.symbol: self}
+
+ codes_right = self.right.code_table(defaulted_prefix + bitarray([0]))
+ codes_left = self.left.code_table(defaulted_prefix + bitarray([1]))
+ return {**codes_left, **codes_right}
+
+ def decode(self, bits: bitarray) -> tp.Iterator["HuffmanNode"]:
+ current_node = self
+ for bit in bits:
+ if bit == 0: # go right
+ current_node = current_node.right
+ else: # go left
+ current_node = current_node.left
+ if current_node is None:
+ # we shouldn't be on a leaf here
+ raise Exception("fell off a leaf")
+ if current_node.is_leaf():
+ yield current_node
+ current_node = self
+ if current_node != self:
+ raise Exception("couldn't decode all the bits")
+
+
+class HuffmanCodeBuilder:
+ """
+    build a dictionary with occurrence counts and then build the Huffman code for it.
+ """
+
+ def __init__(self):
+ self.symbols = Counter()
+
+ def add_symbols(self, *syms) -> None:
+ self.symbols.update(syms)
+
+ def increment(self, symbol: str, cnt: int) -> None:
+ self.symbols[symbol] += cnt
+
+ @classmethod
+ def from_file(cls, filename):
+ c = cls()
+ with open(filename, "r", encoding="utf-8") as input:
+ for line in input:
+ split = re.split(r"[\s]+", line)
+ c.increment(split[0], int(split[1]))
+ return c
+
+ def to_file(self, filename, sep="\t"):
+ with open(filename, "w", encoding="utf-8") as output:
+ for (tok, cnt) in self.symbols.most_common():
+ output.write(f"{tok}{sep}{cnt}\n")
+
+ def _smallest(self, q1: deque, q2: deque) -> HuffmanNode:
+ if len(q1) == 0:
+ return q2.pop()
+
+ if len(q2) == 0:
+ return q1.pop()
+
+ if q1[-1].count < q2[-1].count:
+ return q1.pop()
+
+ return q2.pop()
+
+ def __add__(self, c: "HuffmanCodeBuilder") -> "HuffmanCodeBuilder":
+ new_c = self.symbols + c.symbols
+ new_b = HuffmanCodeBuilder()
+ new_b.symbols = new_c
+ return new_b
+
+ def build_code(
+ self,
+ bos="",
+ pad="",
+ eos="",
+ unk="",
+ ) -> HuffmanCoder:
+ assert len(self.symbols) > 0, "cannot build code from empty list of symbols"
+
+ if self.symbols[bos] == 0:
+ self.add_symbols(bos)
+ if self.symbols[pad] == 0:
+ self.add_symbols(pad)
+ if self.symbols[eos] == 0:
+ self.add_symbols(eos)
+ if self.symbols[unk] == 0:
+ self.add_symbols(unk)
+
+ node_id = 0
+ leaves_queue = deque(
+ [
+ HuffmanNode(symbol=symbol, count=count, id=idx)
+ for idx, (symbol, count) in enumerate(self.symbols.most_common())
+ ]
+ ) # left are the most common, right are the least common
+
+ if len(leaves_queue) == 1:
+ root = leaves_queue.pop()
+ root.id = 0
+ return HuffmanCoder(root)
+
+ nodes_queue = deque()
+
+ while len(leaves_queue) > 0 or len(nodes_queue) != 1:
+ # get the lowest two nodes at the head of each queue
+ node1 = self._smallest(leaves_queue, nodes_queue)
+ node2 = self._smallest(leaves_queue, nodes_queue)
+
+ # add new node
+ nodes_queue.appendleft(
+ HuffmanNode(
+ count=node1.count + node2.count, left=node1, right=node2, id=node_id
+ )
+ )
+ node_id += 1
+
+ # we are left with the root
+ return HuffmanCoder(nodes_queue.pop(), bos=bos, pad=pad, eos=eos, unk=unk)
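A minimal usage sketch of the coder added above, assuming `bitarray` and fairseq are installed; the toy symbol counts are illustrative only:

```python
# Illustrative sketch: build a Huffman code from symbol counts, then round-trip
# a token sequence through encode()/decode(). Toy symbols are assumptions.
from fairseq.data.huffman import HuffmanCodeBuilder

builder = HuffmanCodeBuilder()
builder.add_symbols("a", "b", "b", "c", "c", "c")  # counts: a=1, b=2, c=3
coder = builder.build_code()                       # also adds <s>, <pad>, </s>, <unk>

encoded = coder.encode(["a", "b", "c", "c"])       # bit-packed, byte-padded bytes
decoded = [node.symbol for node in coder.decode(encoded)]
assert decoded == ["a", "b", "c", "c"]
```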
diff --git a/fairseq/fairseq/data/huffman/huffman_mmap_indexed_dataset.py b/fairseq/fairseq/data/huffman/huffman_mmap_indexed_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b098f2c2be32ef65525dd773a6664d7823ada38
--- /dev/null
+++ b/fairseq/fairseq/data/huffman/huffman_mmap_indexed_dataset.py
@@ -0,0 +1,287 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import mmap
+import os
+import shutil
+import struct
+import typing as tp
+from functools import lru_cache
+
+import numpy as np
+import torch
+from fairseq.data import indexed_dataset
+from fairseq.data.huffman import HuffmanCoder
+from fairseq.file_io import PathManager
+
+
+class HuffmanMMapIndex:
+ """
+ keep an index of the offsets in the huffman binary file.
+ First a header, then the list of sizes (num tokens) for each instance and finally
+ the addresses of each instance.
+ """
+
+ _HDR_MAGIC = b"HUFFIDX\x00\x00"
+ _VERSION = 1
+
+ @classmethod
+ def writer(cls, path: str, data_len: int):
+ class _Writer:
+ def __enter__(self):
+ self._file = open(path, "wb")
+
+ # write header (magic + version)
+ self._file.write(cls._HDR_MAGIC)
+ self._file.write(struct.pack(" None:
+ self._path_prefix = path_prefix
+ self._coder = coder
+ self._sizes = []
+ self._ptrs = []
+ self._data_len = 0
+
+ def open(self):
+ self._coder.to_file(vocab_file_path(self._path_prefix))
+ self._data_file = open(indexed_dataset.data_file_path(self._path_prefix), "wb")
+
+ def __enter__(self) -> "HuffmanMMapIndexedDatasetBuilder":
+ self.open()
+ return self
+
+ def add_item(self, tokens: tp.List[str]) -> None:
+ """
+        add a list of tokens to the dataset; they will be compressed with the
+ provided coder before being written to file.
+ """
+ encoded = self._coder.encode(tokens)
+ code_len = len(encoded)
+ last_ptr = 0
+ if len(self._ptrs) > 0:
+ last_ptr = self._ptrs[-1]
+ self._sizes.append(len(tokens))
+ self._ptrs.append(last_ptr + code_len)
+ self._data_len += code_len
+ self._data_file.write(encoded)
+
+ def append(self, other_dataset_path_prefix: str) -> None:
+ """
+ append an existing dataset.
+ Beware, if it wasn't built with the same coder, you are in trouble.
+ """
+ other_index = HuffmanMMapIndex(
+ indexed_dataset.index_file_path(other_dataset_path_prefix)
+ )
+ for (ptr, size) in other_index:
+ self._ptrs.append(ptr + self._data_len)
+ self._sizes.append(size)
+
+ # Concatenate data
+ with open(indexed_dataset.data_file_path(other_dataset_path_prefix), "rb") as f:
+ shutil.copyfileobj(f, self._data_file)
+
+ self._data_len += other_index.data_len
+
+ def close(self):
+ self._data_file.close()
+ with HuffmanMMapIndex.writer(
+ indexed_dataset.index_file_path(self._path_prefix), self._data_len
+ ) as index:
+ index.write(self._sizes, self._ptrs)
+
+ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+ self.close()
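A hedged sketch of how the builder above is typically driven end to end; the `example` path prefix and the toy sentences are assumptions for illustration:

```python
# Illustrative sketch: write a Huffman-compressed mmap dataset with the builder
# added above. The "example" path prefix and toy sentences are assumptions.
from fairseq.data.huffman import HuffmanCodeBuilder, HuffmanMMapIndexedDatasetBuilder

sentences = [["hello", "world"], ["hello", "again"]]

code_builder = HuffmanCodeBuilder()
for sent in sentences:
    code_builder.add_symbols(*sent)
coder = code_builder.build_code()

# Used as a context manager: open() writes the vocab file and opens the data
# file; close() (via __exit__) writes the index of sizes and pointers.
with HuffmanMMapIndexedDatasetBuilder("example", coder) as builder:
    for sent in sentences:
        builder.add_item(sent)
```

Reading the result back goes through `HuffmanMMapIndexedDataset` from the same module.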
diff --git a/fairseq/fairseq/data/legacy/__init__.py b/fairseq/fairseq/data/legacy/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bd5c72b5e9d7f67fb7e4ef10808d7ec08967ff4
--- /dev/null
+++ b/fairseq/fairseq/data/legacy/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .block_pair_dataset import BlockPairDataset
+from .masked_lm_dataset import MaskedLMDataset
+from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary
+
+
+__all__ = [
+ "BertDictionary",
+ "BlockPairDataset",
+ "MaskedLMDataset",
+ "MaskedLMDictionary",
+]
diff --git a/fairseq/fairseq/data/legacy/__pycache__/__init__.cpython-310.pyc b/fairseq/fairseq/data/legacy/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc3f5e0cb0064c1ce99929776f805f5d0967ee31
Binary files /dev/null and b/fairseq/fairseq/data/legacy/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/legacy/__pycache__/block_pair_dataset.cpython-310.pyc b/fairseq/fairseq/data/legacy/__pycache__/block_pair_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d14c88f1547179a8001c76524198afa320792b9b
Binary files /dev/null and b/fairseq/fairseq/data/legacy/__pycache__/block_pair_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/legacy/__pycache__/masked_lm_dataset.cpython-310.pyc b/fairseq/fairseq/data/legacy/__pycache__/masked_lm_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e2fe87d8a0e16727be3115f18c2104464f50a4fa
Binary files /dev/null and b/fairseq/fairseq/data/legacy/__pycache__/masked_lm_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/legacy/__pycache__/masked_lm_dictionary.cpython-310.pyc b/fairseq/fairseq/data/legacy/__pycache__/masked_lm_dictionary.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7bf6fc398c40e1e2f9628682dff89cca8f738f04
Binary files /dev/null and b/fairseq/fairseq/data/legacy/__pycache__/masked_lm_dictionary.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/legacy/block_pair_dataset.py b/fairseq/fairseq/data/legacy/block_pair_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba069b46052286c531b4f9706d96788732cd2ad2
--- /dev/null
+++ b/fairseq/fairseq/data/legacy/block_pair_dataset.py
@@ -0,0 +1,311 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import numpy as np
+import torch
+from fairseq.data import FairseqDataset
+
+
+class BlockPairDataset(FairseqDataset):
+ """Break a Dataset of tokens into sentence pair blocks for next sentence
+    prediction as well as masked language modeling.
+
+    The high-level logic is:
+ 1. break input tensor to tensor blocks
+ 2. pair the blocks with 50% next sentence and 50% random sentence
+ 3. return paired blocks as well as related segment labels
+
+ Args:
+ dataset (~torch.utils.data.Dataset): dataset to break into blocks
+ sizes: array of sentence lengths
+ dictionary: dictionary for the task
+ block_size: maximum block size
+        break_mode: mode for breaking the corpus into block pairs. Currently we support
+            2 modes
+            doc: respect document boundaries and each part of the pair should belong to one document
+ none: don't respect any boundary and cut tokens evenly
+ short_seq_prob: probability for generating shorter block pairs
+ doc_break_size: Size for empty line separating documents. Typically 1 if
+ the sentences have eos, 0 otherwise.
+ """
+
+ def __init__(
+ self,
+ dataset,
+ dictionary,
+ sizes,
+ block_size,
+ break_mode="doc",
+ short_seq_prob=0.1,
+ doc_break_size=1,
+ ):
+ super().__init__()
+ self.dataset = dataset
+ self.pad = dictionary.pad()
+ self.eos = dictionary.eos()
+ self.cls = dictionary.cls()
+ self.mask = dictionary.mask()
+ self.sep = dictionary.sep()
+ self.break_mode = break_mode
+ self.dictionary = dictionary
+ self.short_seq_prob = short_seq_prob
+ self.block_indices = []
+
+ assert len(dataset) == len(sizes)
+
+ if break_mode == "doc":
+ cur_doc = []
+ for sent_id, sz in enumerate(sizes):
+ assert doc_break_size == 0 or sz != 0, (
+ "when doc_break_size is non-zero, we expect documents to be"
+ "separated by a blank line with a single eos."
+ )
+ # empty line as document separator
+ if sz == doc_break_size:
+ if len(cur_doc) == 0:
+ continue
+ self.block_indices.append(cur_doc)
+ cur_doc = []
+ else:
+ cur_doc.append(sent_id)
+ max_num_tokens = block_size - 3 # Account for [CLS], [SEP], [SEP]
+ self.sent_pairs = []
+ self.sizes = []
+ for doc_id, doc in enumerate(self.block_indices):
+ self._generate_sentence_pair(doc, doc_id, max_num_tokens, sizes)
+ elif break_mode is None or break_mode == "none":
+ # each block should have half of the block size since we are constructing block pair
+ sent_length = (block_size - 3) // 2
+ total_len = sum(dataset.sizes)
+ length = math.ceil(total_len / sent_length)
+
+ def block_at(i):
+ start = i * sent_length
+ end = min(start + sent_length, total_len)
+ return (start, end)
+
+ sent_indices = np.array([block_at(i) for i in range(length)])
+ sent_sizes = np.array([e - s for s, e in sent_indices])
+ dataset_index = self._sent_to_dataset_index(sent_sizes)
+
+ # pair sentences
+ self._pair_sentences(dataset_index)
+ else:
+ raise ValueError("Invalid break_mode: " + break_mode)
+
+ def _pair_sentences(self, dataset_index):
+ """
+        Given a list of evenly cut blocks/sentences, pair these sentences with 50%
+        consecutive sentences and 50% random sentences.
+        This is used for the "none" break mode.
+ """
+ # pair sentences
+ for sent_id, sent in enumerate(dataset_index):
+ next_sent_label = (
+ 1 if np.random.rand() > 0.5 and sent_id != len(dataset_index) - 1 else 0
+ )
+ if next_sent_label:
+ next_sent = dataset_index[sent_id + 1]
+ else:
+ next_sent = dataset_index[
+ self._skip_sampling(len(dataset_index), [sent_id, sent_id + 1])
+ ]
+ self.sent_pairs.append((sent, next_sent, next_sent_label))
+
+ # The current blocks don't include the special tokens but the
+ # sizes already account for this
+ self.sizes.append(3 + sent[3] + next_sent[3])
+
+ def _sent_to_dataset_index(self, sent_sizes):
+ """
+ Build index mapping block indices to the underlying dataset indices
+ """
+ dataset_index = []
+ ds_idx, ds_remaining = -1, 0
+ for to_consume in sent_sizes:
+ sent_size = to_consume
+ if ds_remaining == 0:
+ ds_idx += 1
+ ds_remaining = sent_sizes[ds_idx]
+ start_ds_idx = ds_idx
+ start_offset = sent_sizes[ds_idx] - ds_remaining
+ while to_consume > ds_remaining:
+ to_consume -= ds_remaining
+ ds_idx += 1
+ ds_remaining = sent_sizes[ds_idx]
+ ds_remaining -= to_consume
+ dataset_index.append(
+ (
+ start_ds_idx, # starting index in dataset
+ start_offset, # starting offset within starting index
+ ds_idx, # ending index in dataset
+ sent_size, # sentence length
+ )
+ )
+ assert ds_remaining == 0
+ assert ds_idx == len(self.dataset) - 1
+ return dataset_index
+
+ def _generate_sentence_pair(self, doc, doc_id, max_num_tokens, sizes):
+ """
+        Go through a single document and generate sentence pairs from it
+ """
+ current_chunk = []
+ current_length = 0
+ curr = 0
+ # To provide more randomness, we decrease target seq length for parts of
+ # samples (10% by default). Note that max_num_tokens is the hard threshold
+ # for batching and will never be changed.
+ target_seq_length = max_num_tokens
+ if np.random.random() < self.short_seq_prob:
+ target_seq_length = np.random.randint(2, max_num_tokens)
+ # loop through all sentences in document
+ while curr < len(doc):
+ sent_id = doc[curr]
+ current_chunk.append(sent_id)
+ current_length = sum(sizes[current_chunk])
+ # split chunk and generate pair when exceed target_seq_length or
+ # finish the loop
+ if curr == len(doc) - 1 or current_length >= target_seq_length:
+ # split the chunk into 2 parts
+ a_end = 1
+ if len(current_chunk) > 2:
+ a_end = np.random.randint(1, len(current_chunk) - 1)
+ sent_a = current_chunk[:a_end]
+ len_a = sum(sizes[sent_a])
+ # generate next sentence label, note that if there is only 1 sentence
+ # in current chunk, label is always 0
+ next_sent_label = (
+ 1 if np.random.rand() > 0.5 and len(current_chunk) != 1 else 0
+ )
+ if not next_sent_label:
+ # if next sentence label is 0, sample sent_b from a random doc
+ target_b_length = target_seq_length - len_a
+ rand_doc_id = self._skip_sampling(len(self.block_indices), [doc_id])
+ random_doc = self.block_indices[rand_doc_id]
+ random_start = np.random.randint(0, len(random_doc))
+ sent_b = []
+ len_b = 0
+ for j in range(random_start, len(random_doc)):
+ sent_b.append(random_doc[j])
+ len_b = sum(sizes[sent_b])
+ if len_b >= target_b_length:
+ break
+ # return the second part of the chunk since it's not used
+ num_unused_segments = len(current_chunk) - a_end
+ curr -= num_unused_segments
+ else:
+                    # if next sentence label is 1, use the second part of the chunk as sent_b
+ sent_b = current_chunk[a_end:]
+ len_b = sum(sizes[sent_b])
+                # currently sent_a and sent_b may be longer than max_num_tokens,
+ # truncate them and return block idx and offsets for them
+ sent_a, sent_b = self._truncate_sentences(
+ sent_a, sent_b, max_num_tokens
+ )
+ self.sent_pairs.append((sent_a, sent_b, next_sent_label))
+ self.sizes.append(3 + sent_a[3] + sent_b[3])
+ current_chunk = []
+ curr += 1
+
+ def _skip_sampling(self, total, skip_ids):
+ """
+ Generate a random integer which is not in skip_ids. Sample range is [0, total)
+ TODO: ids in skip_ids should be consecutive, we can extend it to more generic version later
+ """
+ rand_id = np.random.randint(total - len(skip_ids))
+ return rand_id if rand_id < min(skip_ids) else rand_id + len(skip_ids)
+
+ def _truncate_sentences(self, sent_a, sent_b, max_num_tokens):
+ """
+        Truncate a pair of sentences to limit the total length under max_num_tokens
+        Logic:
+        1. Truncate the longer sentence
+        2. Tokens to be truncated could be at the beginning or the end of the sentence
+ Returns:
+ Truncated sentences represented by dataset idx
+ """
+ len_a, len_b = sum(self.dataset.sizes[sent_a]), sum(self.dataset.sizes[sent_b])
+ front_cut_a = front_cut_b = end_cut_a = end_cut_b = 0
+
+ while True:
+ total_length = (
+ len_a + len_b - front_cut_a - front_cut_b - end_cut_a - end_cut_b
+ )
+ if total_length <= max_num_tokens:
+ break
+
+ if len_a - front_cut_a - end_cut_a > len_b - front_cut_b - end_cut_b:
+ if np.random.rand() < 0.5:
+ front_cut_a += 1
+ else:
+ end_cut_a += 1
+ else:
+ if np.random.rand() < 0.5:
+ front_cut_b += 1
+ else:
+ end_cut_b += 1
+
+ # calculate ds indices as well as offsets and return
+ truncated_sent_a = self._cut_sentence(sent_a, front_cut_a, end_cut_a)
+ truncated_sent_b = self._cut_sentence(sent_b, front_cut_b, end_cut_b)
+ return truncated_sent_a, truncated_sent_b
+
+ def _cut_sentence(self, sent, front_cut, end_cut):
+ """
+        Cut a sentence based on the number of tokens to be cut from the beginning and end.
+        Represent the sentence as dataset idx and return it.
+ """
+ start_ds_idx, end_ds_idx, offset = sent[0], sent[-1], 0
+ target_len = sum(self.dataset.sizes[sent]) - front_cut - end_cut
+ while front_cut > 0:
+ if self.dataset.sizes[start_ds_idx] > front_cut:
+ offset += front_cut
+ break
+ else:
+ front_cut -= self.dataset.sizes[start_ds_idx]
+ start_ds_idx += 1
+ while end_cut > 0:
+ if self.dataset.sizes[end_ds_idx] > end_cut:
+ break
+ else:
+ end_cut -= self.dataset.sizes[end_ds_idx]
+ end_ds_idx -= 1
+ return start_ds_idx, offset, end_ds_idx, target_len
+
+ def _fetch_block(self, start_ds_idx, offset, end_ds_idx, length):
+ """
+ Fetch a block of tokens based on its dataset idx
+ """
+ buffer = torch.cat(
+ [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)]
+ )
+ s, e = offset, offset + length
+ return buffer[s:e]
+
+ def __getitem__(self, index):
+ block1, block2, next_sent_label = self.sent_pairs[index]
+ block1 = self._fetch_block(*block1)
+ block2 = self._fetch_block(*block2)
+ return block1, block2, next_sent_label
+
+ def __len__(self):
+ return len(self.sizes)
+
+ @property
+ def supports_prefetch(self):
+ return getattr(self.dataset, "supports_prefetch", False)
+
+ def prefetch(self, indices):
+ prefetch_idx = set()
+ for index in indices:
+ for block1, block2, _ in [self.sent_pairs[index]]:
+ for ds_idx in range(block1[0], block1[2] + 1):
+ prefetch_idx.add(ds_idx)
+ for ds_idx in range(block2[0], block2[2] + 1):
+ prefetch_idx.add(ds_idx)
+ self.dataset.prefetch(prefetch_idx)
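The `_skip_sampling` helper above draws a random index while excluding a consecutive run of ids; here is a standalone sketch of the same trick (toy values assumed):

```python
# Standalone illustration of the _skip_sampling trick used above: sample an id
# in [0, total) while excluding a consecutive run of skip_ids. Toy values only.
import numpy as np

def skip_sampling(total, skip_ids):
    rand_id = np.random.randint(total - len(skip_ids))
    return rand_id if rand_id < min(skip_ids) else rand_id + len(skip_ids)

np.random.seed(0)
samples = {skip_sampling(10, [3, 4]) for _ in range(1000)}
assert 3 not in samples and 4 not in samples  # excluded ids are never drawn
```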
diff --git a/fairseq/fairseq/data/legacy/masked_lm_dataset.py b/fairseq/fairseq/data/legacy/masked_lm_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd8ea2c60aff306ab3a756223a298a28d41a4991
--- /dev/null
+++ b/fairseq/fairseq/data/legacy/masked_lm_dataset.py
@@ -0,0 +1,303 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Dict, List, Tuple
+
+import numpy as np
+import torch
+from fairseq.data import Dictionary, FairseqDataset, data_utils
+from fairseq.data.concat_dataset import ConcatDataset
+from fairseq.data.legacy.block_pair_dataset import BlockPairDataset
+from fairseq.data.token_block_dataset import TokenBlockDataset
+
+
+class MaskedLMDataset(FairseqDataset):
+ """
+ A wrapper Dataset for masked language modelling. The dataset
+    wraps around TokenBlockDataset or BlockPairDataset and creates a batch
+ where the input blocks are masked according to the specified masking
+ probability. Additionally the batch can also contain sentence level targets
+ if this is specified.
+
+ Args:
+ dataset: Dataset which generates blocks of data. Only BlockPairDataset
+ and TokenBlockDataset are supported.
+ sizes: Sentence lengths
+ vocab: Dictionary with the vocabulary and special tokens.
+ pad_idx: Id of padding token in dictionary
+ mask_idx: Id of mask token in dictionary
+ classif_token_idx: Id of classification token in dictionary. This is the
+ token associated with the sentence embedding (Eg: CLS for BERT)
+ sep_token_idx: Id of separator token in dictionary
+ (Eg: SEP in BERT)
+ seed: Seed for random number generator for reproducibility.
+ shuffle: Shuffle the elements before batching.
+ has_pairs: Specifies whether the underlying dataset
+ generates a pair of blocks along with a sentence_target or not.
+ Setting it to True assumes that the underlying dataset generates a
+ label for the pair of sentences which is surfaced as
+ sentence_target. The default value assumes a single block with no
+ sentence target.
+ segment_id: An optional segment id for filling in the segment labels
+ when we are in the single block setting (Eg: XLM). Default is 0.
+ masking_ratio: specifies what percentage of the blocks should be masked.
+ masking_prob: specifies the probability of a given token being
+ replaced with the "MASK" token.
+ random_token_prob: specifies the probability of a given token being
+ replaced by a random token from the vocabulary.
+ """
+
+ def __init__(
+ self,
+ dataset: FairseqDataset,
+ sizes: np.ndarray,
+ vocab: Dictionary,
+ pad_idx: int,
+ mask_idx: int,
+ classif_token_idx: int,
+ sep_token_idx: int,
+ seed: int = 1,
+ shuffle: bool = True,
+ has_pairs: bool = True,
+ segment_id: int = 0,
+ masking_ratio: float = 0.15,
+ masking_prob: float = 0.8,
+ random_token_prob: float = 0.1,
+ ):
+ # Make sure the input datasets are the ones supported
+ assert (
+ isinstance(dataset, TokenBlockDataset)
+ or isinstance(dataset, BlockPairDataset)
+ or isinstance(dataset, ConcatDataset)
+ ), (
+ "MaskedLMDataset only wraps TokenBlockDataset or BlockPairDataset or "
+ "ConcatDataset"
+ )
+
+ self.dataset = dataset
+ self.sizes = np.array(sizes)
+ self.vocab = vocab
+ self.pad_idx = pad_idx
+ self.mask_idx = mask_idx
+ self.classif_token_idx = classif_token_idx
+ self.sep_token_idx = sep_token_idx
+ self.shuffle = shuffle
+ self.seed = seed
+ self.has_pairs = has_pairs
+ self.segment_id = segment_id
+ self.masking_ratio = masking_ratio
+ self.masking_prob = masking_prob
+ self.random_token_prob = random_token_prob
+
+ # If we have only one block then sizes needs to be updated to include
+ # the classification token
+ if not has_pairs:
+ self.sizes = self.sizes + 1
+
+ def __getitem__(self, index: int):
+ # if has_pairs, then expect 2 blocks and a sentence target
+ if self.has_pairs:
+ (block_one, block_two, sentence_target) = self.dataset[index]
+ else:
+ block_one = self.dataset[index]
+
+ return {
+ "id": index,
+ "block_one": block_one,
+ "block_two": block_two if self.has_pairs else None,
+ "sentence_target": sentence_target if self.has_pairs else None,
+ }
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def _mask_block(
+ self,
+ sentence: np.ndarray,
+ mask_idx: int,
+ pad_idx: int,
+ dictionary_token_range: Tuple,
+ ):
+ """
+ Mask tokens for Masked Language Model training
+ Samples mask_ratio tokens that will be predicted by LM.
+
+        Note: This function may not be efficient enough since we have multiple
+        conversions between np and torch; we can replace them with torch
+        operators later.
+
+ Args:
+ sentence: 1d tensor to be masked
+ mask_idx: index to use for masking the sentence
+ pad_idx: index to use for masking the target for tokens we aren't
+ predicting
+ dictionary_token_range: range of indices in dictionary which can
+ be used for random word replacement
+ (e.g. without special characters)
+ Return:
+ masked_sent: masked sentence
+ target: target with words which we are not predicting replaced
+ by pad_idx
+ """
+ masked_sent = np.copy(sentence)
+ sent_length = len(sentence)
+ mask_num = math.ceil(sent_length * self.masking_ratio)
+ mask = np.random.choice(sent_length, mask_num, replace=False)
+ target = np.copy(sentence)
+
+ for i in range(sent_length):
+ if i in mask:
+ rand = np.random.random()
+
+ # replace with mask if probability is less than masking_prob
+ # (Eg: 0.8)
+ if rand < self.masking_prob:
+ masked_sent[i] = mask_idx
+
+ # replace with random token if probability is less than
+ # masking_prob + random_token_prob (Eg: 0.9)
+ elif rand < (self.masking_prob + self.random_token_prob):
+ # sample random token from dictionary
+ masked_sent[i] = np.random.randint(
+ dictionary_token_range[0], dictionary_token_range[1]
+ )
+            else:
+                target[i] = pad_idx
+
+ return masked_sent, target
+
+ def _collate(self, samples: List[Dict], pad_idx: int, eos_idx: int):
+ """
+ Does the heavy lifting for creating a batch from the input list of
+ examples. The logic is as follows:
+ 1. Mask the input blocks. In case has_pair is True then we have 2
+ blocks to mask.
+ 2. Prepend the first masked block tensor with the special token
+ used as sentence embedding. Eg: CLS in BERT. This happens
+ irrespective of the value of has_pair.
+ 3. If has_pair is True, then append the first masked block with the
+ special separator token (eg: SEP for BERT) and compute segment
+ label accordingly. In this case, also append the second masked
+ block with this special separator token and compute its segment
+ label.
+ 4. For the targets tensor, prepend and append with padding index
+ accordingly.
+ 5. Concatenate all tensors.
+ """
+ if len(samples) == 0:
+ return {}
+ # To ensure determinism, we reset the state of the PRNG after every
+ # batch based on the seed and the first id of the batch. This ensures
+ # that across epochs we get the same mask for the same example. This
+ # is needed for reproducibility and is how BERT does masking
+        # TODO: Can we add determinism without this constraint?
+ with data_utils.numpy_seed(self.seed + samples[0]["id"]):
+ for s in samples:
+
+ # token range is needed for replacing with random token during
+ # masking
+ token_range = (self.vocab.nspecial, len(self.vocab))
+
+ # mask according to specified probabilities.
+ masked_blk_one, masked_tgt_one = self._mask_block(
+ s["block_one"],
+ self.mask_idx,
+ self.pad_idx,
+ token_range,
+ )
+
+ tokens = np.concatenate([[self.classif_token_idx], masked_blk_one])
+ targets = np.concatenate([[self.pad_idx], masked_tgt_one])
+ segments = np.ones(len(tokens)) * self.segment_id
+
+ # if has_pairs is True then we need to add the SEP token to both
+ # the blocks after masking and re-compute segments based on the new
+ # lengths.
+ if self.has_pairs:
+ tokens_one = np.concatenate([tokens, [self.sep_token_idx]])
+ targets_one = np.concatenate([targets, [self.pad_idx]])
+
+ masked_blk_two, masked_tgt_two = self._mask_block(
+ s["block_two"], self.mask_idx, self.pad_idx, token_range
+ )
+ tokens_two = np.concatenate([masked_blk_two, [self.sep_token_idx]])
+ targets_two = np.concatenate([masked_tgt_two, [self.pad_idx]])
+
+ # block + 1 sep + 1 special (CLS)
+ segments_one = np.zeros(len(tokens_one))
+ # block + 1 sep
+ segments_two = np.ones(len(tokens_two))
+
+ tokens = np.concatenate([tokens_one, tokens_two])
+ targets = np.concatenate([targets_one, targets_two])
+ segments = np.concatenate([segments_one, segments_two])
+
+ s["source"] = torch.LongTensor(tokens)
+ s["segment_labels"] = torch.LongTensor(segments)
+ s["lm_target"] = torch.LongTensor(targets)
+
+ def merge(key):
+ return data_utils.collate_tokens(
+ [s[key] for s in samples], pad_idx, eos_idx, left_pad=False
+ )
+
+ return {
+ "id": torch.LongTensor([s["id"] for s in samples]),
+ "ntokens": sum(len(s["source"]) for s in samples),
+ "net_input": {
+ "src_tokens": merge("source"),
+ "segment_labels": merge("segment_labels"),
+ },
+ "lm_target": merge("lm_target"),
+ "sentence_target": torch.LongTensor([s["sentence_target"] for s in samples])
+ if self.has_pairs
+ else None,
+ "nsentences": len(samples),
+ }
+
+ def collater(self, samples: List[Dict]):
+ """Merge a list of samples to form a mini-batch.
+
+ Args:
+ samples (List[dict]): samples to collate
+
+ Returns:
+ dict: a mini-batch of data
+ """
+ return self._collate(samples, self.vocab.pad(), self.vocab.eos())
+
+ def num_tokens(self, index: int):
+ """
+ Return the number of tokens in a sample. This value is used to
+ enforce max-tokens during batching.
+ """
+ return self.sizes[index]
+
+ def size(self, index: int):
+ """
+ Return an example's size as a float or tuple. This value is used when
+ filtering a dataset with max-positions.
+ """
+ return self.sizes[index]
+
+ def ordered_indices(self):
+ """
+ Return an ordered list of indices. Batches will be constructed based
+ on this order.
+ """
+ if self.shuffle:
+ return np.random.permutation(len(self))
+ else:
+ order = [np.arange(len(self))]
+ order.append(self.sizes)
+ return np.lexsort(order)
+
+ @property
+ def supports_prefetch(self):
+ return getattr(self.dataset, "supports_prefetch", False)
+
+ def prefetch(self, indices):
+ self.dataset.prefetch(indices)
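A minimal standalone sketch of the masking rule implemented by `_mask_block` above: select `masking_ratio` of the positions; of those, roughly 80% become the mask token, 10% a random token, and 10% stay unchanged, while unselected positions get a padded target. The token ids, ranges, and ratios below are toy assumptions:

```python
# Minimal sketch of the 80/10/10 masking rule from MaskedLMDataset._mask_block.
# Token ids, ranges and ratios below are toy assumptions for illustration.
import math
import numpy as np

def mask_block(sentence, mask_idx, pad_idx, token_range,
               masking_ratio=0.15, masking_prob=0.8, random_token_prob=0.1):
    masked = np.copy(sentence)
    target = np.copy(sentence)
    positions = np.random.choice(
        len(sentence), math.ceil(len(sentence) * masking_ratio), replace=False
    )
    for i in range(len(sentence)):
        if i in positions:
            r = np.random.random()
            if r < masking_prob:                        # ~80%: replace with mask token
                masked[i] = mask_idx
            elif r < masking_prob + random_token_prob:  # ~10%: random vocab token
                masked[i] = np.random.randint(*token_range)
            # remaining ~10%: keep the original token
        else:
            target[i] = pad_idx                         # not predicted -> pad the target
    return masked, target

np.random.seed(0)
print(mask_block(np.arange(20) + 4, mask_idx=3, pad_idx=1, token_range=(4, 100)))
```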
diff --git a/fairseq/fairseq/data/legacy/masked_lm_dictionary.py b/fairseq/fairseq/data/legacy/masked_lm_dictionary.py
new file mode 100644
index 0000000000000000000000000000000000000000..dee88f7a3ed72ea465ea4e8ffe7b1c01ff6f57f1
--- /dev/null
+++ b/fairseq/fairseq/data/legacy/masked_lm_dictionary.py
@@ -0,0 +1,60 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from fairseq.data import Dictionary
+
+
+class MaskedLMDictionary(Dictionary):
+ """
+ Dictionary for Masked Language Modelling tasks. This extends Dictionary by
+ adding the mask symbol.
+ """
+
+ def __init__(
+ self,
+ pad="",
+ eos="",
+ unk="",
+ mask="",
+ ):
+ super().__init__(pad=pad, eos=eos, unk=unk)
+ self.mask_word = mask
+ self.mask_index = self.add_symbol(mask)
+ self.nspecial = len(self.symbols)
+
+ def mask(self):
+ """Helper to get index of mask symbol"""
+ return self.mask_index
+
+
+class BertDictionary(MaskedLMDictionary):
+ """
+ Dictionary for BERT task. This extends MaskedLMDictionary by adding support
+ for cls and sep symbols.
+ """
+
+ def __init__(
+ self,
+ pad="",
+ eos="",
+ unk="",
+ mask="",
+ cls="",
+ sep="",
+ ):
+ super().__init__(pad=pad, eos=eos, unk=unk, mask=mask)
+ self.cls_word = cls
+ self.sep_word = sep
+ self.cls_index = self.add_symbol(cls)
+ self.sep_index = self.add_symbol(sep)
+ self.nspecial = len(self.symbols)
+
+ def cls(self):
+ """Helper to get index of cls symbol"""
+ return self.cls_index
+
+ def sep(self):
+ """Helper to get index of sep symbol"""
+ return self.sep_index
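A quick sketch of the dictionaries added above; `BertDictionary` layers `<cls>`/`<sep>` on top of the `<mask>` symbol added by `MaskedLMDictionary` (the "hello" symbol is an assumption for illustration):

```python
# Illustrative sketch: the special-symbol helpers exposed by the dictionaries
# added above. The "hello" symbol is an assumption for illustration.
from fairseq.data.legacy import BertDictionary, MaskedLMDictionary

d = MaskedLMDictionary()
print(d.pad(), d.eos(), d.unk(), d.mask())  # indices of <pad>, </s>, <unk>, <mask>

bd = BertDictionary()
bd.add_symbol("hello")
print(bd.mask(), bd.cls(), bd.sep())        # helper accessors for the special symbols
print(bd.index("hello"))                    # regular symbols come after the specials
```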
diff --git a/fairseq/fairseq/data/multilingual/__init__.py b/fairseq/fairseq/data/multilingual/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6264236915a7269a4d920ee8213004374dd86a9a
--- /dev/null
+++ b/fairseq/fairseq/data/multilingual/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/fairseq/fairseq/data/multilingual/__pycache__/__init__.cpython-310.pyc b/fairseq/fairseq/data/multilingual/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1be7366c6c45f31707b7a0e0adf8b450b524b1df
Binary files /dev/null and b/fairseq/fairseq/data/multilingual/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/multilingual/__pycache__/multilingual_data_manager.cpython-310.pyc b/fairseq/fairseq/data/multilingual/__pycache__/multilingual_data_manager.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fd2c94daf519ebb0ecf9586de72a97d631891526
Binary files /dev/null and b/fairseq/fairseq/data/multilingual/__pycache__/multilingual_data_manager.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/multilingual/__pycache__/multilingual_utils.cpython-310.pyc b/fairseq/fairseq/data/multilingual/__pycache__/multilingual_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..352f3adffa8dda9690c0b378197cbbf634c06ede
Binary files /dev/null and b/fairseq/fairseq/data/multilingual/__pycache__/multilingual_utils.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/multilingual/__pycache__/sampled_multi_dataset.cpython-310.pyc b/fairseq/fairseq/data/multilingual/__pycache__/sampled_multi_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fe84145d7d743929c38a5880c81848ae189f410c
Binary files /dev/null and b/fairseq/fairseq/data/multilingual/__pycache__/sampled_multi_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/multilingual/__pycache__/sampled_multi_epoch_dataset.cpython-310.pyc b/fairseq/fairseq/data/multilingual/__pycache__/sampled_multi_epoch_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d79df9622bd2fb0b9eb10eb908fbc591567d9d0
Binary files /dev/null and b/fairseq/fairseq/data/multilingual/__pycache__/sampled_multi_epoch_dataset.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/multilingual/__pycache__/sampling_method.cpython-310.pyc b/fairseq/fairseq/data/multilingual/__pycache__/sampling_method.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ad6f2b8046d9ce468ba62361aba5caa6b335e44b
Binary files /dev/null and b/fairseq/fairseq/data/multilingual/__pycache__/sampling_method.cpython-310.pyc differ
diff --git a/fairseq/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/fairseq/data/multilingual/multilingual_data_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..876dfcec36e4cf9236c21e440e9657a68036a278
--- /dev/null
+++ b/fairseq/fairseq/data/multilingual/multilingual_data_manager.py
@@ -0,0 +1,1156 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import itertools
+import json
+import logging
+import math
+import os
+from collections import OrderedDict, defaultdict
+from argparse import ArgumentError
+
+from fairseq import utils
+from fairseq.data import (
+ AppendTokenDataset,
+ ConcatDataset,
+ Dictionary,
+ LanguagePairDataset,
+ PrependTokenDataset,
+ SampledMultiDataset,
+ SampledMultiEpochDataset,
+ StripTokenDataset,
+ TransformEosLangPairDataset,
+ TruncateDataset,
+ data_utils,
+ indexed_dataset,
+)
+from fairseq.data.multilingual.multilingual_utils import (
+ EncoderLangtok,
+ LangTokSpec,
+ LangTokStyle,
+ augment_dictionary,
+ get_lang_tok,
+)
+from fairseq.data.multilingual.sampled_multi_dataset import CollateFormat
+from fairseq.file_io import PathManager
+from fairseq.utils import FileContentsAction, csv_str_list, eval_str_dict
+
+
+logger = logging.getLogger(__name__)
+
+SRC_DICT_NAME = "src"
+TGT_DICT_NAME = "tgt"
+
+
+def _lang_id(dic: Dictionary, lang: str):
+ """Return language ID index."""
+ idx = dic.index(lang)
+ assert idx != dic.unk_index, "cannot find language ID for lang {}".format(lang)
+ return idx
+
+
+def load_sampling_weights(from_file):
+ with open(from_file) as f:
+ weights = json.load(f)
+ return weights
+
+
+class MultilingualDatasetManager(object):
+ def __init__(self, args, lang_pairs, langs, dicts, sampling_method):
+ super().__init__()
+ self.args = args
+ self.seed = args.seed
+ self.lang_pairs = lang_pairs
+ self.extra_lang_pairs = (
+ list({p for _, v in args.extra_lang_pairs.items() for p in v.split(",")})
+ if args.extra_lang_pairs
+ else []
+ )
+ self.src_langs = {
+ p.split("-")[0] for p in args.lang_pairs + self.extra_lang_pairs
+ }
+ self.tgt_langs = {
+ p.split("-")[1] for p in args.lang_pairs + self.extra_lang_pairs
+ }
+ self.langs = langs
+ self.dicts = dicts
+ self.lang_dict = self.create_lang_dictionary(self.langs)
+ self.sampling_method = sampling_method
+ self.sampling_scheduler = None
+ self._has_sharded_data = False
+ self._num_shards_dict = {}
+ self._training_data_sizes = defaultdict(lambda: {})
+
+ @classmethod
+ def setup_data_manager(cls, args, lang_pairs, langs, dicts, sampling_method):
+ return MultilingualDatasetManager(
+ args, lang_pairs, langs, dicts, sampling_method
+ )
+
+ @staticmethod
+ def add_args(parser):
+ parser.add_argument(
+ "data",
+ help="colon separated path to data directories list, \
+ will be iterated upon during epochs in round-robin manner",
+ action=FileContentsAction,
+ )
+ parser.add_argument(
+ "--langs",
+ default=None,
+ type=csv_str_list,
+ help="a list of languages comma sperated languages which can appear in lang-pairs; "
+ "note that the ordering determines language token IDs",
+ )
+ parser.add_argument(
+ "--lang-dict",
+ default=None,
+ type=str,
+ help="an external file which contains a list of "
+ "languages which can appear in lang-pairs; "
+ "note that the ordering determines language token IDs; "
+ "--langs and --lang-dict are two exclusive options",
+ )
+ parser.add_argument(
+ "--source-dict",
+ default=None,
+ type=str,
+ help="path to source dictionary; if specified it will override per language dictionary loading",
+ )
+ parser.add_argument(
+ "--target-dict",
+ default=None,
+ type=str,
+ help="path to target dictionary; if specified it will override per language dictionary loading",
+ )
+ parser.add_argument(
+ "--lang-tok-style",
+ default=LangTokStyle.multilingual.value,
+ type=str,
+ choices=[LangTokStyle.multilingual.value, LangTokStyle.mbart.value],
+ help="language token styles",
+ )
+
+ parser.add_argument(
+ "--load-alignments",
+ action="store_true",
+ help="load the binarized alignments",
+ )
+ parser.add_argument(
+ "--left-pad-source",
+ default="True",
+ type=str,
+ metavar="BOOL",
+ help="pad the source on the left",
+ )
+ parser.add_argument(
+ "--left-pad-target",
+ default="False",
+ type=str,
+ metavar="BOOL",
+ help="pad the target on the left",
+ )
+ try:
+ parser.add_argument(
+ "--max-source-positions",
+ default=1024,
+ type=int,
+ metavar="N",
+ help="max number of tokens in the source sequence",
+ )
+ parser.add_argument(
+ "--max-target-positions",
+ default=1024,
+ type=int,
+ metavar="N",
+ help="max number of tokens in the target sequence",
+ )
+ except ArgumentError:
+ # this might have already been defined. Once we transition this to hydra it should be fine to add it here.
+ pass
+ parser.add_argument(
+ "--upsample-primary",
+ default=1,
+ type=int,
+ help="amount to upsample primary dataset",
+ )
+ parser.add_argument(
+ "--truncate-source",
+ action="store_true",
+ default=False,
+ help="truncate source to max-source-positions",
+ )
+ parser.add_argument(
+ "--encoder-langtok",
+ default=None,
+ type=str,
+ choices=[EncoderLangtok.src.value, EncoderLangtok.tgt.value],
+ metavar="SRCTGT",
+ help="prepend to the beginning of source sentence the source or target "
+ "language token. (src/tgt)",
+ )
+ parser.add_argument(
+ "--decoder-langtok",
+ action="store_true",
+ help="prepend to the beginning of target sentence the target language token",
+ )
+ parser.add_argument(
+ "--lang-tok-replacing-bos-eos", action="store_true", default=False
+ )
+ parser.add_argument(
+ "--enable-lang-ids",
+ default=False,
+ action="store_true",
+ help="whether to include language IDs in samples",
+ )
+ parser.add_argument(
+ "--enable-reservsed-directions-shared-datasets",
+ default=False,
+ action="store_true",
+ help="whether to allow datasets be used in reversed directions",
+ )
+
+ parser.add_argument(
+ "--extra-data",
+        help='a dictionary mapping data name to its path, \
+            e.g. {"mined": path_to_mined_data, "denoised": path_to_denoised_data}',
+ type=lambda uf: eval_str_dict(uf, type=str),
+ default=None,
+ )
+ parser.add_argument(
+ "--extra-lang-pairs",
+        help='a dictionary mapping data name to the language pairs it serves, \
+ e.g. {"mined": comma-separated-lang-pairs, "denoised": comma-separated-lang-pairs}',
+ type=lambda uf: eval_str_dict(uf, type=str),
+ default=None,
+ )
+ parser.add_argument(
+ "--fixed-dictionary",
+ help="Fixed dictionary to use with model path",
+ default=None,
+ type=str,
+ )
+ parser.add_argument(
+ "--langtoks-specs",
+        help='a comma-separated list of data types for which sets of language tokens will be specialized, \
+            e.g. "main,dae,mined". A set of language tokens is added to the vocab to \
+            distinguish languages across different training data types. If not specified, default language \
+            tokens per language will be added',
+ default=LangTokSpec.main.value,
+ type=csv_str_list,
+ )
+ parser.add_argument(
+ "--langtoks",
+ help='a dictionary of how to add language tokens, \
+ e.g. {"mined": (None, "tgt"), "mono_dae": ("src.dae", "tgt"), "main": \
+ ("src", "tgt")}, or {"mined": ("src.mined", "tgt")}',
+ default=None,
+ type=lambda uf: eval_str_dict(uf, type=str),
+ )
+ parser.add_argument(
+ "--sampling-weights-from-file",
+        help='a file containing a python dictionary of how to sample data sets, \
+                e.g. { "main:en_XX-es_XX": 0.2, "mined:en_XX-pt_XX": 0.5, \
+                    "mono_dae:es_XX-es_XX": 0.3, "main:en_XX-fr_XX": 0.8 }',
+ default=None,
+ type=str,
+ )
+ parser.add_argument(
+ "--sampling-weights",
+ help='a dictionary of how to sample data sets, \
+ e.g. { "main:en_XX-es_XX": 0.2, "mined:en_XX-pt_XX": 0.5, \
+ "mono_dae:es_XX-es_XX: 0.3, "main:en_xx-fr_XX": 0.8 }',
+ default=None,
+ type=lambda uf: eval_str_dict(uf, type=str),
+ )
+ parser.add_argument(
+ "--virtual-epoch-size",
+ default=None,
+ type=int,
+ help="virtual epoch size to speed up data loading",
+ )
+ parser.add_argument(
+ "--virtual-data-size",
+ default=None,
+ type=int,
+ help="virtual data size of the whole joint dataset to speed"
+ "up data loading and have specific dynamic sampling strategy interval",
+ )
+
+ @classmethod
+ def load_langs(cls, args, **kwargs):
+ if args.lang_dict and args.langs:
+ raise ValueError("--langs and --lang-dict can not both be specified")
+ if args.lang_dict is None and args.langs is None:
+ logger.warning(
+ "External language dictionary is not provided; "
+ "use lang-pairs to infer the set of supported languages. "
+ "The language ordering is not stable which might cause "
+ "misalignment in pretraining and finetuning."
+ )
+ # infer from lang_pairs as it is
+ langs = list(
+ {x for lang_pair in args.lang_pairs for x in lang_pair.split("-")}
+ )
+ langs = sorted(langs)
+ logger.info(f"inferred language list: {langs}")
+ elif args.lang_dict:
+ with open(
+ PathManager.get_local_path(args.lang_dict), "r", encoding="utf-8"
+ ) as f:
+ langs = [lang.strip() for lang in f.readlines() if lang.strip()]
+ logger.info(
+ f"loaded language list from {args.lang_dict} as they are ordered in file"
+ )
+ elif args.langs:
+ langs = args.langs
+ logger.info(
+ f"parsed the language list as they are ordered in the option: {langs}"
+ )
+ return langs
+
+ def has_sharded_data(self, split):
+ return self._has_sharded_data and split == getattr(
+ self.args, "train_subset", None
+ )
+
+ def _shared_collater(self):
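+        # a shared collater is only safe when no "mono_dae" extra data is used
+        # and language tokens do not replace bos/eos; otherwise sub-datasets
+        # need their own collation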
+ return not (self.args.extra_data and "mono_dae" in self.args.extra_data) and (
+ not self.args.lang_tok_replacing_bos_eos
+ )
+
+ def estimate_global_pass_epoch(self, epoch):
+ if self.args.virtual_epoch_size is None or self.args.virtual_data_size is None:
+ return None
+ # one epoch more for remaining data in each shard
+ virtual_epochs_per_shard = math.ceil(
+ self.args.virtual_data_size / self.args.virtual_epoch_size
+ )
+ # note that fairseq epoch / shard_epoch starts from 1
+ shard_epoch = (epoch - 1) // virtual_epochs_per_shard + 1
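+        # e.g. virtual_data_size=1000 and virtual_epoch_size=300 give 4 virtual
+        # epochs per shard, so epochs 1-4 map to shard_epoch 1, epochs 5-8 to
+        # shard_epoch 2, and so on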
+ return shard_epoch
+
+ @classmethod
+ def prepare(cls, load_dictionary, args, **kargs):
+ args.left_pad_source = utils.eval_bool(args.left_pad_source)
+ args.left_pad_target = utils.eval_bool(args.left_pad_target)
+
+ if not hasattr(args, "shuffle_instance"):
+ args.shuffle_instance = False
+ if args.langtoks is None:
+ args.langtoks = {}
+ if "main" not in args.langtoks:
+ src_langtok_spec = args.encoder_langtok if args.encoder_langtok else None
+ tgt_langtok_spec = "tgt" if args.decoder_langtok else None
+ args.langtoks["main"] = (src_langtok_spec, tgt_langtok_spec)
+
+ def check_langs(langs, pairs):
+ messages = []
+ for src, tgt in pairs:
+ if src not in langs or tgt not in langs:
+ messages.append(
+ f"language pair {src}-{tgt} contains languages "
+ "that are not in the language dictionary"
+ )
+ if len(messages) > 0:
+ raise ValueError(" ".join(messages) + f"; langs: {langs}")
+
+ if args.lang_pairs is None:
+ raise ValueError(
+ "--lang-pairs is required. List all the language pairs in the training objective."
+ )
+ if isinstance(args.lang_pairs, str):
+ args.lang_pairs = args.lang_pairs.split(",")
+ if args.source_lang is not None or args.target_lang is not None:
+ training = False
+ else:
+ training = True
+ language_list = cls.load_langs(args, **kargs)
+ check_langs(
+ language_list,
+ (
+ [p.split("-") for p in args.lang_pairs]
+ if training
+ else [(args.source_lang, args.target_lang)]
+ ),
+ )
+
+ def load_dictionary_and_postproc(path):
+ d = load_dictionary(path)
+ augment_dictionary(
+ dictionary=d,
+ language_list=language_list,
+ lang_tok_style=args.lang_tok_style,
+ langtoks_specs=args.langtoks_specs,
+ extra_data=args.extra_data,
+ )
+ return d
+
+ dicts = cls.load_all_dictionaries(
+ args, language_list, load_dictionary_and_postproc, training
+ )
+ return language_list, dicts, training
+
+ @classmethod
+ def load_all_dictionaries(cls, args, language_list, load_dictionary, training):
+ dicts = OrderedDict()
+ if args.source_dict is not None:
+ dicts[SRC_DICT_NAME] = load_dictionary(args.source_dict)
+ if args.target_dict is not None:
+ dicts[TGT_DICT_NAME] = load_dictionary(args.target_dict)
+
+ if training:
+ extra_lang_pairs = (
+ list(
+ {p for _, v in args.extra_lang_pairs.items() for p in v.split(",")}
+ )
+ if args.extra_lang_pairs
+ else []
+ )
+ src_langs_to_load_dicts = sorted(
+ {p.split("-")[0] for p in (args.lang_pairs + extra_lang_pairs)}
+ )
+ tgt_langs_to_load_dicts = sorted(
+ {p.split("-")[1] for p in (args.lang_pairs + extra_lang_pairs)}
+ )
+ else:
+ src_langs_to_load_dicts = [args.source_lang]
+ tgt_langs_to_load_dicts = [args.target_lang]
+
+ paths = utils.split_paths(args.data)
+ assert len(paths) > 0
+
+ def load_dicts(langs_to_load_dicts):
+ for lang in langs_to_load_dicts:
+ dicts[lang] = load_dictionary(
+ os.path.join(paths[0], "dict.{}.txt".format(lang))
+ )
+ if len(dicts) > 0:
+ dict0 = next(iter(dicts.values()))
+ assert dicts[lang].pad() == dict0.pad()
+ assert dicts[lang].eos() == dict0.eos()
+ assert dicts[lang].unk() == dict0.unk()
+ logger.info("[{}] dictionary: {} types".format(lang, len(dicts[lang])))
+
+ if args.fixed_dictionary is not None:
+ fixed_dict = load_dictionary(args.fixed_dictionary)
+ dicts = {
+ lang: fixed_dict
+ for lang in src_langs_to_load_dicts + tgt_langs_to_load_dicts
+ }
+ else:
+ if args.source_dict is None:
+ load_dicts(src_langs_to_load_dicts)
+ if args.target_dict is None:
+ load_dicts(tgt_langs_to_load_dicts)
+ return dicts
+
+ def get_source_dictionary(self, lang):
+ if self.args.source_dict is not None:
+ return self.dicts[SRC_DICT_NAME]
+ else:
+ return self.dicts[lang]
+
+ def get_target_dictionary(self, lang):
+ if self.args.target_dict is not None:
+ return self.dicts[TGT_DICT_NAME]
+ else:
+ return self.dicts[lang]
+
+ @classmethod
+ def create_lang_dictionary(cls, langs):
+ unk = ""
+ # hack to remove symbols other than unk as they are not needed by lang dict
+ lang_dict = Dictionary(pad=unk, eos=unk, unk=unk, bos=unk)
+ for lang in langs:
+ lang_dict.add_symbol(lang)
+ return lang_dict
+
+ @classmethod
+ def get_langtok_index(cls, lang_tok, dic):
+ idx = dic.index(lang_tok)
+ assert (
+ idx != dic.unk_index
+ ), "cannot find language token {} in the dictionary".format(lang_tok)
+ return idx
+
+ def get_encoder_langtok(self, src_lang, tgt_lang, spec=None):
+ if spec is None:
+ return None
+ if spec and spec.startswith("src"):
+ if src_lang is None:
+ return None
+ langtok = get_lang_tok(
+ lang=src_lang, lang_tok_style=self.args.lang_tok_style, spec=spec
+ )
+ else:
+ if tgt_lang is None:
+ return None
+ langtok = get_lang_tok(
+ lang=tgt_lang, lang_tok_style=self.args.lang_tok_style, spec=spec
+ )
+ return self.get_langtok_index(
+ langtok,
+ self.get_source_dictionary(src_lang)
+ if src_lang
+ else self.get_target_dictionary(tgt_lang),
+ )
+
+ def get_decoder_langtok(self, tgt_lang, spec=None):
+ if spec is None:
+ return None
+ langtok = get_lang_tok(
+ lang=tgt_lang, lang_tok_style=self.args.lang_tok_style, spec=spec
+ )
+ return self.get_langtok_index(langtok, self.get_target_dictionary(tgt_lang))
+
+ @classmethod
+ def load_data(cls, path, vdict, impl):
+ dataset = data_utils.load_indexed_dataset(path, vdict, impl)
+ return dataset
+
+ @classmethod
+ def split_exists(cls, split, src, tgt, lang, data_path, dataset_impl):
+ filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang))
+ return indexed_dataset.dataset_exists(filename, impl=dataset_impl)
+
+ def load_lang_dataset(
+ self,
+ data_path,
+ split,
+ src,
+ src_dict,
+ tgt,
+ tgt_dict,
+ combine,
+ dataset_impl,
+ upsample_primary,
+ max_source_positions,
+ prepend_bos=False,
+ load_alignments=False,
+ truncate_source=False,
+ ):
+
+ src_datasets = []
+ tgt_datasets = []
+
+ for k in itertools.count():
+ split_k = split + (str(k) if k > 0 else "")
+
+ # infer langcode
+ if self.split_exists(split_k, src, tgt, src, data_path, dataset_impl):
+ prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
+ elif self.split_exists(split_k, tgt, src, src, data_path, dataset_impl):
+ prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src))
+ else:
+ if k > 0:
+ break
+ else:
+ logger.error(
+ f"Dataset not found: {data_path}, {split_k}, {src}, {tgt}"
+ )
+ raise FileNotFoundError(
+ "Dataset not found: {} ({})".format(split, data_path)
+ )
+
+ src_dataset = self.load_data(prefix + src, src_dict, dataset_impl)
+ if truncate_source:
+ src_dataset = AppendTokenDataset(
+ TruncateDataset(
+ StripTokenDataset(src_dataset, src_dict.eos()),
+ max_source_positions - 1,
+ ),
+ src_dict.eos(),
+ )
+ src_datasets.append(src_dataset)
+ tgt_datasets.append(self.load_data(prefix + tgt, tgt_dict, dataset_impl))
+
+ logger.info(
+ "{} {} {}-{} {} examples".format(
+ data_path, split_k, src, tgt, len(src_datasets[-1])
+ )
+ )
+
+ if not combine:
+ break
+
+ assert len(src_datasets) == len(tgt_datasets)
+
+ if len(src_datasets) == 1:
+ src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
+ else:
+ sample_ratios = [1] * len(src_datasets)
+ sample_ratios[0] = upsample_primary
+ src_dataset = ConcatDataset(src_datasets, sample_ratios)
+ tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
+
+ if prepend_bos:
+ assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
+ src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
+ tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
+
+ align_dataset = None
+ if load_alignments:
+ align_path = os.path.join(
+ data_path, "{}.align.{}-{}".format(split, src, tgt)
+ )
+ if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
+ align_dataset = data_utils.load_indexed_dataset(
+ align_path, None, dataset_impl
+ )
+
+ return src_dataset, tgt_dataset, align_dataset
+
+ def load_langpair_dataset(
+ self,
+ data_path,
+ split,
+ src,
+ src_dict,
+ tgt,
+ tgt_dict,
+ combine,
+ dataset_impl,
+ upsample_primary,
+ left_pad_source,
+ left_pad_target,
+ max_source_positions,
+ max_target_positions,
+ prepend_bos=False,
+ load_alignments=False,
+ truncate_source=False,
+ src_dataset_transform_func=lambda dataset: dataset,
+ tgt_dataset_transform_func=lambda dataset: dataset,
+ src_lang_id=None,
+ tgt_lang_id=None,
+ langpairs_sharing_datasets=None,
+ ):
+ norm_direction = "-".join(sorted([src, tgt]))
+ if langpairs_sharing_datasets is not None:
+ src_dataset = langpairs_sharing_datasets.get(
+ (data_path, split, norm_direction, src), "NotInCache"
+ )
+ tgt_dataset = langpairs_sharing_datasets.get(
+ (data_path, split, norm_direction, tgt), "NotInCache"
+ )
+ align_dataset = langpairs_sharing_datasets.get(
+ (data_path, split, norm_direction, src, tgt), "NotInCache"
+ )
+
+        # a hack: if any one of them is not in the cache, we need to reload them all
+ if (
+ langpairs_sharing_datasets is None
+ or src_dataset == "NotInCache"
+ or tgt_dataset == "NotInCache"
+ or align_dataset == "NotInCache"
+ or split != getattr(self.args, "train_subset", None)
+ ):
+ # source and target datasets can be reused in reversed directions to save memory
+ # reversed directions of valid and test data will not share source and target datasets
+ src_dataset, tgt_dataset, align_dataset = self.load_lang_dataset(
+ data_path,
+ split,
+ src,
+ src_dict,
+ tgt,
+ tgt_dict,
+ combine,
+ dataset_impl,
+ upsample_primary,
+ max_source_positions=max_source_positions,
+ prepend_bos=prepend_bos,
+ load_alignments=load_alignments,
+ truncate_source=truncate_source,
+ )
+ src_dataset = src_dataset_transform_func(src_dataset)
+ tgt_dataset = tgt_dataset_transform_func(tgt_dataset)
+ if langpairs_sharing_datasets is not None:
+ langpairs_sharing_datasets[
+ (data_path, split, norm_direction, src)
+ ] = src_dataset
+ langpairs_sharing_datasets[
+ (data_path, split, norm_direction, tgt)
+ ] = tgt_dataset
+ langpairs_sharing_datasets[
+ (data_path, split, norm_direction, src, tgt)
+ ] = align_dataset
+ if align_dataset is None:
+                    # no alignment data, so also flag the reverse direction in the sharing cache
+ langpairs_sharing_datasets[
+ (data_path, split, norm_direction, tgt, src)
+ ] = align_dataset
+ else:
+ logger.info(
+ f"Reusing source and target datasets of [{split}] {tgt}-{src} for reversed direction: "
+ f"[{split}] {src}-{tgt}: src length={len(src_dataset)}; tgt length={len(tgt_dataset)}"
+ )
+
+ return LanguagePairDataset(
+ src_dataset,
+ src_dataset.sizes,
+ src_dict,
+ tgt_dataset,
+ tgt_dataset.sizes if tgt_dataset is not None else None,
+ tgt_dict,
+ left_pad_source=left_pad_source,
+ left_pad_target=left_pad_target,
+ align_dataset=align_dataset,
+ src_lang_id=src_lang_id,
+ tgt_lang_id=tgt_lang_id,
+ )
+
+ def src_dataset_tranform_func(self, src_lang, tgt_lang, dataset, spec=None):
+ if self.args.lang_tok_replacing_bos_eos:
+ # it is handled by self.alter_dataset_langtok
+            # TODO: Unify with alter_dataset_langtok
+ return dataset
+ if spec is None:
+ return dataset
+ tok = self.get_encoder_langtok(src_lang, tgt_lang, spec)
+ if tok:
+ return PrependTokenDataset(dataset, tok)
+ return dataset
+
+ def tgt_dataset_tranform_func(self, source_lang, target_lang, dataset, spec=None):
+ if dataset is None:
+ # note that target dataset can be None during inference time
+ return None
+ if self.args.lang_tok_replacing_bos_eos:
+            # TODO: Unify with alter_dataset_langtok
+ # It is handled by self.alter_dataset_langtok.
+ # The complication in self.alter_dataset_langtok
+ # makes a unified framework difficult.
+ return dataset
+ # if not self.args.decoder_langtok:
+ if not spec:
+ return dataset
+ tok = self.get_decoder_langtok(target_lang, spec)
+ if tok:
+ return PrependTokenDataset(dataset, tok)
+ return dataset
+
+ def alter_dataset_langtok(
+ self,
+ lang_pair_dataset,
+ src_eos=None,
+ src_lang=None,
+ tgt_eos=None,
+ tgt_lang=None,
+ src_langtok_spec=None,
+ tgt_langtok_spec=None,
+ ):
+ if src_langtok_spec is None and tgt_langtok_spec is None:
+ return lang_pair_dataset
+
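+        # swap the source EOS for the encoder language token and the target BOS
+        # (in prev_output_tokens) for the decoder language token, according to
+        # the langtok specs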
+ new_src_eos = None
+ if (
+ src_langtok_spec is not None
+ and src_eos is not None
+ and (src_lang is not None or tgt_lang is not None)
+ ):
+ new_src_eos = self.get_encoder_langtok(src_lang, tgt_lang, src_langtok_spec)
+ else:
+ src_eos = None
+
+ new_tgt_bos = None
+ if tgt_langtok_spec and tgt_eos is not None and tgt_lang is not None:
+ new_tgt_bos = self.get_decoder_langtok(tgt_lang, tgt_langtok_spec)
+ else:
+ tgt_eos = None
+
+ return TransformEosLangPairDataset(
+ lang_pair_dataset,
+ src_eos=src_eos,
+ new_src_eos=new_src_eos,
+ tgt_bos=tgt_eos,
+ new_tgt_bos=new_tgt_bos,
+ )
+
+ def load_a_dataset(
+ self,
+ split,
+ data_path,
+ src,
+ src_dict,
+ tgt,
+ tgt_dict,
+ combine,
+ prepend_bos=False,
+ langpairs_sharing_datasets=None,
+ data_category=None,
+ **extra_kwargs,
+ ):
+ dataset_impl = self.args.dataset_impl
+ upsample_primary = self.args.upsample_primary
+ left_pad_source = self.args.left_pad_source
+ left_pad_target = self.args.left_pad_target
+ max_source_positions = self.args.max_source_positions
+ max_target_positions = self.args.max_target_positions
+ load_alignments = self.args.load_alignments
+ truncate_source = self.args.truncate_source
+ src_dataset_transform_func = self.src_dataset_tranform_func
+ tgt_dataset_transform_func = self.tgt_dataset_tranform_func
+ enable_lang_ids = self.args.enable_lang_ids
+ lang_dictionary = self.lang_dict
+ src_langtok_spec, tgt_langtok_spec = extra_kwargs["langtok_spec"]
+
+ src_langtok = self.get_encoder_langtok(src, tgt, src_langtok_spec)
+ tgt_langtok = self.get_decoder_langtok(tgt, tgt_langtok_spec)
+ logger.info(
+ f"{data_category}:{src}-{tgt} src_langtok: {src_langtok}; tgt_langtok: {tgt_langtok}"
+ )
+
+ langpair_ds = self.load_langpair_dataset(
+ data_path,
+ split,
+ src,
+ src_dict,
+ tgt,
+ tgt_dict,
+ combine,
+ dataset_impl,
+ upsample_primary,
+ left_pad_source,
+ left_pad_target,
+ max_source_positions,
+ max_target_positions,
+ prepend_bos,
+ load_alignments,
+ truncate_source,
+ src_dataset_transform_func=lambda dataset: src_dataset_transform_func(
+ src, tgt, dataset, src_langtok_spec
+ ),
+ tgt_dataset_transform_func=lambda dataset: tgt_dataset_transform_func(
+ src, tgt, dataset, tgt_langtok_spec
+ ),
+ src_lang_id=_lang_id(lang_dictionary, src)
+ if enable_lang_ids and lang_dictionary is not None
+ else None,
+ tgt_lang_id=_lang_id(lang_dictionary, tgt)
+ if enable_lang_ids and lang_dictionary is not None
+ else None,
+ langpairs_sharing_datasets=langpairs_sharing_datasets,
+ )
+ # TODO: handle modified lang toks for mined data and dae data
+ if self.args.lang_tok_replacing_bos_eos:
+ ds = self.alter_dataset_langtok(
+ langpair_ds,
+ src_eos=self.get_source_dictionary(src).eos()
+ if src
+ else self.get_target_dictionary(tgt).eos(),
+ src_lang=src,
+ tgt_eos=self.get_target_dictionary(tgt).eos(),
+ tgt_lang=tgt,
+ src_langtok_spec=src_langtok_spec,
+ tgt_langtok_spec=tgt_langtok_spec,
+ )
+ else:
+ ds = langpair_ds
+ return ds
+
+ def load_split_langpair_datasets(self, split, data_param_list):
+ datasets = []
+ langpairs_sharing_datasets = (
+ {} if self.args.enable_reservsed_directions_shared_datasets else None
+ )
+ for param in data_param_list:
+ ds = self.load_a_dataset(
+ split=split,
+ langpairs_sharing_datasets=langpairs_sharing_datasets,
+ **param,
+ )
+ datasets.append(ds)
+ return datasets
+
+ def get_data_paths_and_lang_pairs(self, split):
+ datapaths = {"main": self.args.data}
+ lang_pairs = {"main": self.lang_pairs}
+ if split == getattr(self.args, "train_subset", None):
+ # only training data can have extra data and extra language pairs
+ if self.args.extra_data:
+ extra_datapaths = self.args.extra_data
+ datapaths.update(extra_datapaths)
+ if self.args.extra_lang_pairs:
+ extra_lang_pairs = {
+ k: v.split(",") for k, v in self.args.extra_lang_pairs.items()
+ }
+ lang_pairs.update(extra_lang_pairs)
+ return datapaths, lang_pairs
+
+ @classmethod
+ def get_dataset_key(cls, data_category, src, tgt):
+ return f"{data_category}:{src}-{tgt}"
+
+ @classmethod
+ def _get_shard_num_dict(cls, split, paths):
+ shards = defaultdict(int)
+ for path in paths:
+ files = PathManager.ls(path)
+ directions = set()
+ for f in files:
+ if f.startswith(split) and f.endswith(".idx"):
+ # idx files of the form "{split}.{src}-{tgt}.{lang}.idx"
+ direction = f.split(".")[-3]
+ directions.add(direction)
+ for direction in directions:
+ shards[direction] += 1
+ return shards
+
+ def get_split_num_data_shards(self, split):
+ if split in self._num_shards_dict:
+ return self._num_shards_dict[split]
+ num_shards_dict = {}
+ data_paths, lang_pairs = self.get_data_paths_and_lang_pairs(split)
+
+ for data_category, paths in data_paths.items():
+ if data_category not in lang_pairs:
+ continue
+ paths = utils.split_paths(paths)
+ shards_dict = self._get_shard_num_dict(split, paths)
+ lang_dirs = [
+ lang_pair.split("-") for lang_pair in lang_pairs[data_category]
+ ]
+ lang_dirs = [x if len(x) > 1 else (x[0], x[0]) for x in lang_dirs]
+ for src, tgt in lang_dirs:
+ key = self.get_dataset_key(data_category, src, tgt)
+ if "mono_" in data_category:
+ # monolingual data requires tgt only
+ assert src is None or src == tgt, (
+ f"error: src={src}, "
+ f"tgt={tgt} for data_category={data_category}"
+ )
+ num_shards_dict[key] = shards_dict[tgt]
+ else:
+ if f"{src}-{tgt}" in shards_dict:
+ num_shards_dict[key] = shards_dict[f"{src}-{tgt}"]
+ elif f"{tgt}-{src}" in shards_dict:
+                        # follow the fairseq convention of using reversed-direction data when the direct direction is not available
+ num_shards_dict[key] = shards_dict[f"{tgt}-{src}"]
+ self._num_shards_dict[split] = num_shards_dict
+ logger.info(f"[{split}] num of shards: {num_shards_dict}")
+ return num_shards_dict
+
+ @classmethod
+ def get_shard_id(cls, num_shards, epoch, shard_epoch=None):
+ shard = epoch if shard_epoch is None else shard_epoch
+ shard = (shard - 1) % num_shards
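+        # shards are picked round-robin over epochs, e.g. with num_shards=3,
+        # epochs 1, 2, 3, 4 select shards 0, 1, 2, 0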
+ return shard
+
+ def get_split_data_path(self, paths, epoch, shard_epoch, num_shards):
+ path = paths[self.get_shard_id(num_shards, epoch, shard_epoch)]
+ return path
+
+ def get_split_data_param_list(self, split, epoch, shard_epoch=None):
+ # TODO: to extend with extra datasets and keys and loop over different shard data paths
+ param_list = []
+ data_paths, lang_pairs = self.get_data_paths_and_lang_pairs(split)
+ logger.info(f"langtoks settings: {self.args.langtoks}")
+ split_num_shards_dict = self.get_split_num_data_shards(split)
+ for data_category, paths in data_paths.items():
+ if data_category not in lang_pairs:
+ continue
+ paths = utils.split_paths(paths)
+ assert len(paths) > 0
+ if len(paths) > 1:
+ self._has_sharded_data = True
+ if split != getattr(self.args, "train_subset", None):
+                # if this is not the training set, use only the first shard for valid and test
+ paths = paths[:1]
+
+ if data_category in self.args.langtoks:
+ lang_tok_spec = self.args.langtoks[data_category]
+ else:
+ # default to None
+ lang_tok_spec = (None, None)
+
+ # infer langcode
+ lang_dirs = [
+ lang_pair.split("-") for lang_pair in lang_pairs[data_category]
+ ]
+ lang_dirs = [x if len(x) > 1 else (x[0], x[0]) for x in lang_dirs]
+ for src, tgt in lang_dirs:
+ assert src is not None or data_category == "mono_dae", (
+ f"error: src={src}, " f"tgt={tgt} for data_category={data_category}"
+ )
+ # logger.info(f"preparing param for {data_category}: {src} - {tgt}")
+ key = self.get_dataset_key(data_category, src, tgt)
+ data_path = self.get_split_data_path(
+ paths, epoch, shard_epoch, split_num_shards_dict[key]
+ )
+ param_list.append(
+ {
+ "key": key,
+ "data_path": data_path,
+ "split": split,
+ "src": src,
+ "src_dict": self.get_source_dictionary(src)
+ if src and data_category != "mono_dae"
+ else None,
+ "tgt": tgt,
+ "tgt_dict": self.get_target_dictionary(tgt),
+ "data_category": data_category,
+ "langtok_spec": lang_tok_spec,
+ }
+ )
+ return param_list
+
+ def get_train_dataset_sizes(
+ self, data_param_list, datasets, epoch, shard_epoch=None
+ ):
+ num_shards = [
+ self.get_split_num_data_shards(param["split"])[param["key"]]
+ for param in data_param_list
+ ]
+ data_sizes = []
+ for (key, d), num_shard in zip(datasets, num_shards):
+ my_data_sizes = self._training_data_sizes[key]
+ shard_ind = self.get_shard_id(num_shard, epoch, shard_epoch)
+ if shard_ind not in my_data_sizes:
+ my_data_sizes[shard_ind] = len(d)
+ known_size = max(my_data_sizes.values())
+ data_sizes.append(
+                # If we don't know the data size of the shard yet,
+                # use the max known data size to approximate.
+                # Note that we preprocess shards by a designated shard size
+                # and put any remaining data at the end into the last shard, so
+                # the max shard size approximation is almost correct before loading
+                # the last shard; after loading the last shard, we have the
+                # exact data sizes of the whole dataset.
+ (key, sum(my_data_sizes.get(i, known_size) for i in range(num_shard)))
+ )
+ logger.info(
+ f"estimated total data sizes of all shards used in sampling ratios: {data_sizes}. "
+ "Note that if the data a shard has not been loaded yet, use the max known data size to approximate"
+ )
+ return [s for _, s in data_sizes]
+
+ def get_train_sampling_ratios(
+ self, data_param_list, datasets, epoch=1, shard_epoch=None
+ ):
+ data_sizes = self.get_train_dataset_sizes(
+ data_param_list, datasets, epoch, shard_epoch
+ )
+ sampling_func = self.sampling_method.sampling_method_selector()
+ sample_ratios = sampling_func(data_sizes) if sampling_func is not None else None
+ return sample_ratios
+
+ def get_sampling_ratios(self, data_param_list, datasets, epoch, shard_epoch=None):
+ if self.args.sampling_weights_from_file:
+ weights = load_sampling_weights(self.args.sampling_weights_from_file)
+ sample_ratios = [weights[k] for k, _ in datasets]
+ logger.info(
+ "| ignoring --sampling-weights when loadding sampling weights "
+ f"from file {self.args.sampling_weights_from_file}"
+ )
+ elif self.args.sampling_weights:
+ sample_ratios = [self.args.sampling_weights[k] for k, _ in datasets]
+ else:
+ sample_ratios = self.get_train_sampling_ratios(
+ data_param_list, datasets, epoch, shard_epoch
+ )
+
+ if sample_ratios is not None:
+ logger.info(
+ "| Upsample ratios: {}".format(
+ list(zip(map(lambda x: x["key"], data_param_list), sample_ratios))
+ )
+ )
+ assert len(sample_ratios) == len(datasets)
+ return sample_ratios
+
+ def load_split_datasets(
+ self, split, training, epoch=1, combine=False, shard_epoch=None, **kwargs
+ ):
+ data_param_list = self.get_split_data_param_list(
+ split, epoch, shard_epoch=shard_epoch
+ )
+ langpairs_sharing_datasets = (
+ {} if self.args.enable_reservsed_directions_shared_datasets else None
+ )
+ datasets = [
+ (
+ param["key"],
+ self.load_a_dataset(
+ combine=combine,
+ langpairs_sharing_datasets=langpairs_sharing_datasets,
+ **param,
+ ),
+ )
+ for param in data_param_list
+ ]
+ return datasets, data_param_list
+
+ def load_into_concat_dataset(self, split, datasets, data_param_list):
+ if self.args.lang_tok_replacing_bos_eos:
+ # TODO: to investigate why TransformEosLangPairDataset doesn't work with ConcatDataset
+ return SampledMultiDataset(
+ OrderedDict(datasets),
+ sampling_ratios=None,
+ eval_key=None,
+ collate_format=CollateFormat.single,
+ virtual_size=None,
+ split=split,
+ )
+ return ConcatDataset([d for _, d in datasets])
+
+ def load_sampled_multi_epoch_dataset(
+ self, split, training, epoch=0, combine=False, shard_epoch=None, **kwargs
+ ):
+ datasets, data_param_list = self.load_split_datasets(
+ split, training, epoch, combine, shard_epoch=shard_epoch, **kwargs
+ )
+ if training and split == getattr(self.args, "train_subset", None):
+ sample_ratios = self.get_sampling_ratios(data_param_list, datasets, epoch)
+ return SampledMultiEpochDataset(
+ OrderedDict(datasets),
+ epoch=epoch,
+ shard_epoch=shard_epoch,
+                # valid and test datasets degenerate to concatenated datasets:
+ sampling_ratios=sample_ratios,
+ eval_key=None,
+ collate_format=CollateFormat.single,
+ virtual_size=self.args.virtual_data_size,
+ split=split,
+ virtual_epoch_size=self.args.virtual_epoch_size,
+                # if not altering lang_tok, simplify by using the same collater
+ shared_collater=self._shared_collater(),
+ )
+ else:
+ return self.load_into_concat_dataset(split, datasets, data_param_list)
+
+ def load_sampled_multi_dataset(
+ self, split, training, epoch=0, combine=False, shard_epoch=None, **kwargs
+ ):
+ datasets, data_param_list = self.load_split_datasets(
+ split, training, epoch, combine, shard_epoch=shard_epoch, **kwargs
+ )
+ if training and split == getattr(self.args, "train_subset", None):
+ sample_ratios = self.get_sampling_ratios(data_param_list, datasets, epoch)
+ return SampledMultiDataset(
+ OrderedDict(datasets),
+ epoch=epoch,
+                # valid and test datasets degenerate to concatenated datasets:
+ sampling_ratios=sample_ratios,
+ eval_key=None,
+ collate_format=CollateFormat.single,
+ virtual_size=self.args.virtual_data_size,
+ split=split,
+                # if not altering lang_tok, simplify by using the same collater
+ shared_collater=self._shared_collater(),
+ )
+ else:
+ return self.load_into_concat_dataset(split, datasets, data_param_list)
+
+ def load_dataset(
+ self, split, training, epoch=0, combine=False, shard_epoch=None, **kwargs
+ ):
+ if self.args.virtual_epoch_size is None:
+ return self.load_sampled_multi_dataset(
+ split, training, epoch, combine, shard_epoch, **kwargs
+ )
+ else:
+ return self.load_sampled_multi_epoch_dataset(
+ split, training, epoch, combine, shard_epoch, **kwargs
+ )
diff --git a/fairseq/fairseq/data/multilingual/multilingual_utils.py b/fairseq/fairseq/data/multilingual/multilingual_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4e0f9828cabfdbe375d05d9152b58bdbd6de7dc
--- /dev/null
+++ b/fairseq/fairseq/data/multilingual/multilingual_utils.py
@@ -0,0 +1,63 @@
+from enum import Enum
+from typing import Dict, List, Optional, Sequence
+
+import torch
+from fairseq.data import Dictionary
+
+
+class EncoderLangtok(Enum):
+ """
+ Prepend to the beginning of source sentence either the
+ source or target language token. (src/tgt).
+ """
+
+ src = "src"
+ tgt = "tgt"
+
+
+class LangTokSpec(Enum):
+ main = "main"
+ mono_dae = "mono_dae"
+
+
+class LangTokStyle(Enum):
+ multilingual = "multilingual"
+ mbart = "mbart"
+
+
+@torch.jit.export
+def get_lang_tok(
+ lang: str, lang_tok_style: str, spec: str = LangTokSpec.main.value
+) -> str:
+ # TOKEN_STYLES can't be defined outside this fn since it needs to be
+ # TorchScriptable.
+ TOKEN_STYLES: Dict[str, str] = {
+ LangTokStyle.mbart.value: "[{}]",
+ LangTokStyle.multilingual.value: "__{}__",
+ }
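+    # e.g. get_lang_tok("en_XX", LangTokStyle.multilingual.value) -> "__en_XX__",
+    # while LangTokStyle.mbart.value yields "[en_XX]"; a spec ending in
+    # "dae"/"mined" appends that suffix to the language code first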
+
+ if spec.endswith("dae"):
+ lang = f"{lang}_dae"
+ elif spec.endswith("mined"):
+ lang = f"{lang}_mined"
+ style = TOKEN_STYLES[lang_tok_style]
+ return style.format(lang)
+
+
+def augment_dictionary(
+ dictionary: Dictionary,
+ language_list: List[str],
+ lang_tok_style: str,
+ langtoks_specs: Sequence[str] = (LangTokSpec.main.value,),
+ extra_data: Optional[Dict[str, str]] = None,
+) -> None:
+ for spec in langtoks_specs:
+ for language in language_list:
+ dictionary.add_symbol(
+ get_lang_tok(lang=language, lang_tok_style=lang_tok_style, spec=spec)
+ )
+
+ if lang_tok_style == LangTokStyle.mbart.value or (
+ extra_data is not None and LangTokSpec.mono_dae.value in extra_data
+ ):
+ dictionary.add_symbol("")
diff --git a/fairseq/fairseq/data/multilingual/sampled_multi_dataset.py b/fairseq/fairseq/data/multilingual/sampled_multi_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ece9a9721e453112553d7f41755133b1c937e14e
--- /dev/null
+++ b/fairseq/fairseq/data/multilingual/sampled_multi_dataset.py
@@ -0,0 +1,468 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import datetime
+import hashlib
+import logging
+import time
+from bisect import bisect_right
+from collections import OrderedDict, defaultdict
+from enum import Enum
+from typing import List
+
+import numpy as np
+import torch
+
+from fairseq.data import FairseqDataset, data_utils
+from fairseq.distributed import utils as distributed_utils
+
+
+def get_time_gap(s, e):
+ return (
+ datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s)
+ ).__str__()
+
+
+logger = logging.getLogger(__name__)
+
+
+def default_virtual_size_func(datasets, ratios, max_scale_up=1.5):
+ sizes = [len(d) for d in datasets]
+ if ratios is None:
+ return sum(sizes)
+ largest_idx = np.argmax(sizes)
+ largest_r = ratios[largest_idx]
+ largest_s = sizes[largest_idx]
+ # set virtual sizes relative to the largest dataset
+ virtual_sizes = [(r / largest_r) * largest_s for r in ratios]
+ vsize = sum(virtual_sizes)
+ max_size = sum(sizes) * max_scale_up
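+    # e.g. sizes=[100, 10] with ratios=[0.6, 0.4]: the largest dataset keeps its
+    # size (100), the other is scaled to (0.4 / 0.6) * 100 ~= 67, and the total
+    # of ~167 is capped at max_scale_up * 110 = 165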
+ return int(vsize if vsize < max_size else max_size)
+
+
+class CollateFormat(Enum):
+ single = 1
+ ordered_dict = 2
+
+
+class SampledMultiDataset(FairseqDataset):
+ """Samples from multiple sub-datasets according to given sampling ratios.
+ Args:
+ datasets (
+ List[~torch.utils.data.Dataset]
+ or OrderedDict[str, ~torch.utils.data.Dataset]
+ ): datasets
+ sampling_ratios (List[float]): list of probability of each dataset to be sampled
+ (default: None, which corresponds to concatenating all dataset together).
+ seed (int): RNG seed to use (default: 2).
+ epoch (int): starting epoch number (default: 1).
+ eval_key (str, optional): a key used at evaluation time that causes
+ this instance to pass-through batches from *datasets[eval_key]*.
+ collate_format (CollateFormat): collater output format, either CollateFormat.ordered_dict or
+ CollateFormat.single (default: CollateFormat.single) where CollateFormat.single configures
+ the collater to output batches of data mixed from all sub-datasets,
+ and CollateFormat.ordered_dict configures the collater to output a dictionary of batches indexed by keys
+ of sub-datasets.
+            Note that in both formats not all sub-datasets will necessarily be present in a single batch.
+ virtual_size (int, or callable): the expected virtual size of the dataset (default: default_virtual_size_func).
+ split (str): the split of the data, e.g. 'train', 'valid' or 'test'.
+        shared_collater (bool): whether or not all sub-datasets use the same collater.
+ shuffle (bool): whether or not to shuffle data (default: True).
+ """
+
+ def __init__(
+ self,
+ datasets,
+ sampling_ratios=None,
+ seed=2,
+ epoch=1,
+ eval_key=None,
+ collate_format=CollateFormat.single,
+ virtual_size=default_virtual_size_func,
+ split="",
+ shared_collater=False,
+ shuffle=True,
+ ):
+ super().__init__()
+ self.shared_collater = shared_collater
+ self.shuffle = shuffle
+
+ if isinstance(datasets, OrderedDict):
+ self.keys = list(datasets.keys())
+ datasets = list(datasets.values())
+ elif isinstance(datasets, List):
+ self.keys = list(range(len(datasets)))
+ else:
+ raise AssertionError()
+ self.datasets = datasets
+ self.split = split
+
+ self.eval_key = eval_key
+ if self.eval_key is not None:
+ self.collate_format = CollateFormat.single
+ else:
+ self.collate_format = collate_format
+
+ self.seed = seed
+ self._cur_epoch = None
+
+ self.cumulated_sizes = None
+ # self.datasets[k][self._cur_indices[i]] is the data item i in this sampled dataset
+ # namely, data item i is sampled from the kth sub-dataset self.datasets[k]
+ # where self.cumulated_sizes[k-1] <= i < self.cumulated_sizes[k]
+ self._cur_indices = None
+
+ self._sizes = None
+ self.virtual_size_per_dataset = None
+ # caching properties
+ self._reset_cached_properties()
+ self.setup_sampling(sampling_ratios, virtual_size)
+ self.set_epoch(epoch)
+
+ def _clean_if_not_none(self, var_list):
+ for v in var_list:
+ if v is not None:
+ del v
+
+ def _reset_cached_properties(self):
+ self._clean_if_not_none([self._sizes, self._cur_indices])
+ self._sizes = None
+ self._cur_indices = None
+
+ def setup_sampling(self, sample_ratios, virtual_size):
+ sizes = [len(d) for d in self.datasets]
+ if sample_ratios is None:
+            # default back to concatenating datasets
+ self.sample_ratios = None
+ self.virtual_size = sum(sizes)
+ else:
+ if not isinstance(sample_ratios, np.ndarray):
+ sample_ratios = np.array(sample_ratios)
+ self.sample_ratios = sample_ratios
+ virtual_size = (
+ default_virtual_size_func if virtual_size is None else virtual_size
+ )
+ self.virtual_size = (
+ virtual_size(self.datasets, self.sample_ratios)
+ if callable(virtual_size)
+ else virtual_size
+ )
+
+ def adjust_sampling(self, epoch, sampling_ratios, virtual_size):
+ if sampling_ratios is not None:
+ sampling_ratios = self._sync_sample_ratios(sampling_ratios)
+ self.setup_sampling(sampling_ratios, virtual_size)
+
+ def _sync_sample_ratios(self, ratios):
+ # in case the ratios are not precisely the same across processes
+        # and to ensure every process updates the ratios at the same pace
+ ratios = torch.DoubleTensor(ratios)
+ if torch.distributed.is_initialized():
+ if torch.cuda.is_available():
+ distributed_utils.all_reduce(
+ ratios.cuda(), group=distributed_utils.get_data_parallel_group()
+ )
+ else:
+ distributed_utils.all_reduce(
+ ratios, group=distributed_utils.get_data_parallel_group()
+ )
+ ret = ratios.cpu()
+ ret = ret.numpy()
+ return ret
+
+ def random_choice_in_dataset(self, rng, dataset, choice_size):
+ if hasattr(dataset, "random_choice_in_dataset"):
+ return dataset.random_choice_in_dataset(rng, choice_size)
+ dataset_size = len(dataset)
+ return rng.choice(
+ dataset_size, choice_size, replace=(choice_size > dataset_size)
+ )
+
+ def get_virtual_indices(self, rng, datasets, sample_ratios, virtual_size):
+ def get_counts(sample_ratios):
+ counts = np.array([virtual_size * r for r in sample_ratios], dtype=np.int64)
+ diff = virtual_size - counts.sum()
+ assert diff >= 0
+            # due to rounding, the counts might not add up to the desired virtual size
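+            # e.g. virtual_size=10 and ratios=[0.55, 0.45] give truncated counts
+            # [5, 4]; the remaining item is assigned by sampling a dataset index
+            # with probability proportional to the ratios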
+ if diff > 0:
+ dataset_indices = rng.choice(
+ len(sample_ratios), size=diff, p=sample_ratios
+ )
+ for i in dataset_indices:
+ counts[i] += 1
+ return counts
+
+ def get_in_dataset_indices(datasets, sizes, sample_ratios):
+ counts = get_counts(sample_ratios)
+            # uniformly sample the desired counts for each dataset;
+            # if a desired count exceeds the dataset size, sample with replacement:
+ indices = [
+ self.random_choice_in_dataset(rng, d, c)
+ for c, d in zip(counts, datasets)
+ ]
+ return indices
+
+ sizes = [len(d) for d in datasets]
+ if sample_ratios is None:
+            # default back to concatenating datasets
+ in_dataset_indices = [list(range(s)) for s in sizes]
+ virtual_sizes_per_dataset = sizes
+ else:
+ ratios = sample_ratios / sample_ratios.sum()
+ in_dataset_indices = get_in_dataset_indices(datasets, sizes, ratios)
+ virtual_sizes_per_dataset = [len(d) for d in in_dataset_indices]
+ virtual_sizes_per_dataset = np.array(virtual_sizes_per_dataset, np.int64)
+ cumulative_sizes = np.cumsum(virtual_sizes_per_dataset)
+ assert sum(virtual_sizes_per_dataset) == virtual_size
+ assert cumulative_sizes[-1] == virtual_size
+ if virtual_size < sum(sizes):
+ logger.warning(
+ f"virtual data size ({virtual_size}) is less than real data size ({sum(sizes)})."
+ " If virtual size << real data size, there could be data coverage issue."
+ )
+ in_dataset_indices = np.hstack(in_dataset_indices)
+ return in_dataset_indices, cumulative_sizes, virtual_sizes_per_dataset
+
+ def _get_dataset_and_index(self, index):
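+        # bisect over the cumulative virtual sizes to find which sub-dataset the
+        # virtual index falls into; _cur_indices holds the sample index within
+        # that sub-dataset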
+ i = bisect_right(self.cumulated_sizes, index)
+ return i, self._cur_indices[index]
+
+ def __getitem__(self, index):
+ # self.__getitem__(index) returns self.datasets[k][self._cur_indices[index]]
+        # where k satisfies self.cumulated_sizes[k - 1] <= index < self.cumulated_sizes[k]
+ ds_idx, ds_sample_idx = self._get_dataset_and_index(index)
+ ret = (ds_idx, self.datasets[ds_idx][ds_sample_idx])
+ return ret
+
+ def num_tokens(self, index):
+ return self.sizes[index].max()
+
+ def num_tokens_vec(self, indices):
+ sizes_vec = self.sizes[np.array(indices)]
+ # max across all dimensions but first one
+ return np.amax(sizes_vec, axis=tuple(range(1, len(sizes_vec.shape))))
+
+ def size(self, index):
+ return self.sizes[index]
+
+ def __len__(self):
+ return self.virtual_size
+
+ def collater(self, samples, **extra_args):
+ """Merge a list of samples to form a mini-batch."""
+ if len(samples) == 0:
+ return None
+ if self.collate_format == "ordered_dict":
+ collect_samples = [[] for _ in range(len(self.datasets))]
+ for (i, sample) in samples:
+ collect_samples[i].append(sample)
+ batch = OrderedDict(
+ [
+ (self.keys[i], dataset.collater(collect_samples[i]))
+ for i, (key, dataset) in enumerate(zip(self.keys, self.datasets))
+ if len(collect_samples[i]) > 0
+ ]
+ )
+ elif self.shared_collater:
+ batch = self.datasets[0].collater([s for _, s in samples])
+ else:
+ samples_dict = defaultdict(list)
+ pad_to_length = (
+ defaultdict(int)
+ if "pad_to_length" not in extra_args
+ else extra_args["pad_to_length"]
+ )
+ for ds_idx, s in samples:
+ pad_to_length["source"] = max(
+ pad_to_length["source"], s["source"].size(0)
+ )
+ if s["target"] is not None:
+ pad_to_length["target"] = max(
+ pad_to_length["target"], s["target"].size(0)
+ )
+ samples_dict[ds_idx].append(s)
+ batches = [
+ self.datasets[i].collater(samples_dict[i], pad_to_length=pad_to_length)
+ for i in range(len(self.datasets))
+ if len(samples_dict[i]) > 0
+ ]
+
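+            # merge the per-sub-dataset batches into a single batch: tensor
+            # fields are concatenated and re-sorted by descending source length,
+            # while counts like ntokens and nsentences are summed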
+ def straight_data(tensors):
+ batch = torch.cat(tensors, dim=0)
+ return batch
+
+ src_lengths = straight_data(
+ [b["net_input"]["src_lengths"] for b in batches]
+ )
+ src_lengths, sort_order = src_lengths.sort(descending=True)
+
+ def straight_order(tensors):
+ batch = straight_data(tensors)
+ return batch.index_select(0, sort_order)
+
+ batch = {
+ "id": straight_order([b["id"] for b in batches]),
+ "nsentences": sum(b["nsentences"] for b in batches),
+ "ntokens": sum(b["ntokens"] for b in batches),
+ "net_input": {
+ "src_tokens": straight_order(
+ [b["net_input"]["src_tokens"] for b in batches]
+ ),
+ "src_lengths": src_lengths,
+ },
+ "target": straight_order([b["target"] for b in batches])
+ if batches[0]["target"] is not None
+ else None,
+ }
+ if "prev_output_tokens" in batches[0]["net_input"]:
+ batch["net_input"]["prev_output_tokens"] = straight_order(
+ [b["net_input"]["prev_output_tokens"] for b in batches]
+ )
+ if "src_lang_id" in batches[0]["net_input"]:
+ batch["net_input"]["src_lang_id"] = straight_order(
+ [b["net_input"]["src_lang_id"] for b in batches]
+ )
+ if "tgt_lang_id" in batches[0]:
+ batch["tgt_lang_id"] = straight_order(
+ [b["tgt_lang_id"] for b in batches]
+ )
+ return batch
+
+ @property
+ def sizes(self):
+ if self._sizes is not None:
+ return self._sizes
+ start_time = time.time()
+ in_sub_dataset_indices = [
+ self._cur_indices[
+ 0 if i == 0 else self.cumulated_sizes[i - 1] : self.cumulated_sizes[i]
+ ]
+ for i in range(len(self.datasets))
+ ]
+ sub_dataset_sizes = [
+ d.sizes[indices]
+ for d, indices in zip(self.datasets, in_sub_dataset_indices)
+ ]
+ self._sizes = np.vstack(sub_dataset_sizes)
+ logger.info(f"sizes() calling time: {get_time_gap(start_time, time.time())}")
+ return self._sizes
+
+ def ordered_indices(self):
+ if self.shuffle:
+ indices = np.random.permutation(len(self))
+ else:
+ indices = np.arange(len(self))
+
+ sizes = self.sizes
+ tgt_sizes = sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
+ src_sizes = (
+ sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
+ )
+
+ # sort by target length, then source length
+ if tgt_sizes is not None:
+ indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
+ sort_indices = indices[np.argsort(src_sizes[indices], kind="mergesort")]
+ return sort_indices
+
+ def prefetch(self, indices):
+ prefetch_indices = [[] for _ in range(len(self.datasets))]
+ for i in indices:
+ ds_idx, ds_sample_idx = self._get_dataset_and_index(i)
+ prefetch_indices[ds_idx].append(ds_sample_idx)
+ for i in range(len(prefetch_indices)):
+ self.datasets[i].prefetch(prefetch_indices[i])
+
+ @property
+ def can_reuse_epoch_itr_across_epochs(self):
+ return False
+
+ def set_epoch(self, epoch):
+ super().set_epoch(epoch)
+ if epoch == self._cur_epoch:
+ # re-enter so return
+ return
+ for d in self.datasets:
+ if hasattr(d, "set_epoch"):
+ d.set_epoch(epoch)
+ self._cur_epoch = epoch
+ self._establish_virtual_datasets()
+
+ def _establish_virtual_datasets(self):
+ if self.sample_ratios is None and self._cur_indices is not None:
+            # not a sampling dataset, no need to resample if indices are already established
+ return
+ self._reset_cached_properties()
+
+ start_time = time.time()
+ # Generate a weighted sample of indices as a function of the
+ # random seed and the current epoch.
+ rng = np.random.RandomState(
+ [
+ int(
+ hashlib.sha1(
+ str(self.__class__.__name__).encode("utf-8")
+ ).hexdigest(),
+ 16,
+ )
+ % (2**32),
+ self.seed % (2**32), # global seed
+ self._cur_epoch, # epoch index,
+ ]
+ )
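+        # seeding on (class-name hash, global seed, epoch) makes the per-epoch
+        # resampling deterministic while keeping this RNG stream separate from
+        # other consumers of the global seed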
+ self._clean_if_not_none(
+ [self.cumulated_sizes, self.virtual_size_per_dataset, self._sizes]
+ )
+ self._sizes = None
+
+ indices, cumulated_sizes, virtual_size_per_dataset = self.get_virtual_indices(
+ rng, self.datasets, self.sample_ratios, self.virtual_size
+ )
+ self._cur_indices = indices
+ self.cumulated_sizes = cumulated_sizes
+ self.virtual_size_per_dataset = virtual_size_per_dataset
+
+ raw_sizes = [len(d) for d in self.datasets]
+ sampled_sizes = self.virtual_size_per_dataset
+ logger.info(
+ f"[{self.split}] Raw sizes: {str(dict(zip(self.keys, raw_sizes)))}; "
+ f"raw total size: {sum(raw_sizes)}"
+ )
+ logger.info(
+ f"[{self.split}] Resampled sizes: {str(dict(zip(self.keys, sampled_sizes)))}; "
+ f"resampled total size: {sum(sampled_sizes)}"
+ )
+ if self.sample_ratios is not None:
+ logger.info(
+ f"[{self.split}] Upsampling ratios: {str(dict(zip(self.keys, self.sample_ratios)))}"
+ )
+ else:
+ logger.info(f"[{self.split}] A concat dataset")
+ logger.info(
+ f"[{self.split}] virtual dataset established time: {get_time_gap(start_time, time.time())}"
+ )
+
+ def filter_indices_by_size(self, indices, max_sizes):
+ """Filter a list of sample indices. Remove those that are longer
+ than specified in max_sizes.
+
+ Args:
+ indices (np.array): original array of sample indices
+ max_sizes (int or list[int] or tuple[int]): max sample size,
+ can be defined separately for src and tgt (then list or tuple)
+
+ Returns:
+ np.array: filtered sample array
+ list: list of removed indices
+ """
+ sizes = self.sizes
+ tgt_sizes = sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
+ src_sizes = (
+ sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
+ )
+
+ return data_utils.filter_paired_dataset_indices_by_size(
+ src_sizes, tgt_sizes, indices, max_sizes
+ )
diff --git a/fairseq/fairseq/data/multilingual/sampled_multi_epoch_dataset.py b/fairseq/fairseq/data/multilingual/sampled_multi_epoch_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb187a8dc28c7647fe93cd4ba3d26f5a892ca7fd
--- /dev/null
+++ b/fairseq/fairseq/data/multilingual/sampled_multi_epoch_dataset.py
@@ -0,0 +1,199 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import hashlib
+import logging
+import math
+
+import numpy as np
+
+from fairseq.data import SampledMultiDataset
+
+from .sampled_multi_dataset import CollateFormat, default_virtual_size_func
+
+logger = logging.getLogger(__name__)
+
+
+class SampledMultiEpochDataset(SampledMultiDataset):
+ """Samples from multiple sub-datasets according to sampling ratios
+ using virtual epoch sizes to speed up dataloading.
+ Args:
+ datasets (
+ List[~torch.utils.data.Dataset]
+ or OrderedDict[str, ~torch.utils.data.Dataset]
+ ): datasets
+ sampling_ratios (List[float]): list of probability of each dataset to be sampled
+            (default: None, which corresponds to concatenating all datasets together).
+ seed (int): RNG seed to use (default: 2).
+ epoch (int): starting epoch number (default: 1).
+ eval_key (str, optional): a key used at evaluation time that causes
+ this instance to pass-through batches from *datasets[eval_key]*.
+        collate_format (CollateFormat): collater output format, either CollateFormat.ordered_dict or
+            CollateFormat.single (default: CollateFormat.single). CollateFormat.single configures
+            the collater to output batches of data mixed from all sub-datasets,
+            while CollateFormat.ordered_dict configures it to output a dictionary of batches
+            indexed by the keys of the sub-datasets.
+            Note that not all sub-datasets will be present in a single batch in either format.
+ virtual_size (int, or callable): the expected virtual size of the dataset (default: default_virtual_size_func).
+ split (str): the split of the data, e.g. 'train', 'valid' or 'test'.
+        virtual_epoch_size (int): virtual epoch size; the dataset iterates through the data in chunks of
+            this size to speed up data loading, e.g. indexing and filtering can be performed on each
+            virtual epoch as soon as it is loaded, without waiting for the whole dataset to be loaded.
+        shared_collater (bool): whether or not all sub-datasets share the same collater.
+ shard_epoch (int): the real epoch number for shard selection.
+ shuffle (bool): whether or not to shuffle data (default: True).
+ """
+
+ def __init__(
+ self,
+ datasets,
+ sampling_ratios=None,
+ seed=2,
+ epoch=1,
+ eval_key=None,
+ collate_format=CollateFormat.single,
+ virtual_size=default_virtual_size_func,
+ split="",
+ virtual_epoch_size=None,
+ shared_collater=False,
+ shard_epoch=1,
+ shuffle=True,
+ ):
+ self.virtual_epoch_size = virtual_epoch_size
+ self._current_epoch_start_index = None
+ self._random_global_indices = None
+ self.shard_epoch = shard_epoch if shard_epoch is not None else 1
+ self.load_next_shard = None
+ self._epoch_sizes = None
+ super().__init__(
+ datasets=datasets,
+ sampling_ratios=sampling_ratios,
+ seed=seed,
+ epoch=epoch,
+ eval_key=eval_key,
+ collate_format=collate_format,
+ virtual_size=virtual_size,
+ split=split,
+ shared_collater=shared_collater,
+ shuffle=shuffle,
+ )
+
+ def _setup(self, epoch):
+ self.virtual_epoch_size = (
+ self.virtual_epoch_size
+ if self.virtual_epoch_size is not None
+ else self.virtual_size
+ )
+ if self.virtual_epoch_size > self.virtual_size:
+ logger.warning(
+ f"virtual epoch size {self.virtual_epoch_size} "
+ f"is greater than virtual dataset size {self.virtual_size}"
+ )
+ self.virtual_epoch_size = self.virtual_size
+ self.num_virtual_epochs = math.ceil(self.virtual_size / self.virtual_epoch_size)
+ self._current_epoch_start_index = self._get_epoch_start_index(epoch)
+ logger.info(
+ f"virtual epoch size {self.virtual_epoch_size}; virtual dataset size {self.virtual_size}"
+ )
+
+ def _map_epoch_index_to_global(self, index):
+ index = self._current_epoch_start_index + index
+ # add randomness
+ return self._random_global_indices[index]
+
+ @property
+ def sizes(self):
+ if self._epoch_sizes is not None:
+ return self._epoch_sizes
+ _sizes = super().sizes
+ indices = self._random_global_indices[
+ self._current_epoch_start_index : self._current_epoch_start_index
+ + len(self)
+ ]
+ self._epoch_sizes = _sizes[indices]
+        # drop the parent class's cached _sizes to save memory
+ del self._sizes
+ self._sizes = None
+ return self._epoch_sizes
+
+ def _get_dataset_and_index(self, index):
+ i = self._map_epoch_index_to_global(index)
+ return super()._get_dataset_and_index(i)
+
+ def __len__(self):
+ return (
+ self.virtual_epoch_size
+ if self._current_epoch_start_index + self.virtual_epoch_size
+ < self.virtual_size
+ else self.virtual_size - self._current_epoch_start_index
+ )
+
+ def set_epoch(self, epoch):
+ if self._current_epoch_start_index is None:
+            # initializing epoch indices of a virtual dataset
+ self._setup(epoch)
+ self._next_virtual_epoch(epoch)
+ else:
+            # working on already initialized epoch indices
+            if epoch == self._cur_epoch:
+                # re-entering the same epoch; nothing to do
+ return
+ self._next_virtual_epoch(epoch)
+
+ def _get_epoch_start_index(self, epoch):
+ assert epoch >= 1 # fairseq is using 1-based epoch everywhere
+ return ((epoch - 1) % self.num_virtual_epochs) * self.virtual_epoch_size
+
+ def _next_global_indices(self, epoch):
+ rng = np.random.RandomState(
+ [
+ int(
+ hashlib.sha1(
+ str(self.__class__.__name__).encode("utf-8")
+ ).hexdigest(),
+ 16,
+ )
+ % (2**32),
+ self.seed % (2**32), # global seed
+ epoch, # epoch index,
+ ]
+ )
+ del self._random_global_indices
+ self._random_global_indices = rng.choice(
+ self.virtual_size, self.virtual_size, replace=False
+ )
+ if self.load_next_shard is None:
+ self.load_next_shard = False
+ else:
+ # increase shard epoch for next loading
+ self.shard_epoch += 1
+ self.load_next_shard = True
+ logger.info(
+ "to load next epoch/shard in next load_dataset: "
+ f"epoch={epoch}/shard_epoch={self.shard_epoch}"
+ )
+
+ def _next_virtual_epoch(self, epoch):
+ index = self._get_epoch_start_index(epoch)
+ if index == 0 or self._random_global_indices is None:
+ # need to start from the beginning,
+ # so call super().set_epoch(epoch) to establish the global virtual indices
+ logger.info(
+ "establishing a new set of global virtual indices for "
+ f"epoch={epoch}/shard_epoch={self.shard_epoch}"
+ )
+ super().set_epoch(epoch)
+ self._next_global_indices(epoch)
+ else:
+ self._cur_epoch = epoch
+
+        # reset cached sizes and ordered_indices after moving to a new epoch
+ self._clean_if_not_none(
+ [
+ self._epoch_sizes,
+ ]
+ )
+ self._epoch_sizes = None
+ self._current_epoch_start_index = index
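The key mechanism in `SampledMultiEpochDataset` is mapping a per-epoch index into a shuffled permutation of the full virtual dataset: each virtual epoch serves one contiguous window of `self._random_global_indices`, wrapping around every `num_virtual_epochs` epochs. A self-contained sketch of that index arithmetic, with made-up sizes chosen only for illustration:

```python
import math
import numpy as np

virtual_size = 10          # total virtual dataset size
virtual_epoch_size = 4     # samples served per virtual epoch
num_virtual_epochs = math.ceil(virtual_size / virtual_epoch_size)  # 3

rng = np.random.RandomState(0)
random_global_indices = rng.choice(virtual_size, virtual_size, replace=False)

def epoch_start_index(epoch):
    # fairseq uses 1-based epochs; wrap around after num_virtual_epochs
    return ((epoch - 1) % num_virtual_epochs) * virtual_epoch_size

def epoch_length(epoch):
    # the last window may be shorter than virtual_epoch_size
    return min(virtual_epoch_size, virtual_size - epoch_start_index(epoch))

def map_epoch_index_to_global(epoch, index):
    # per-epoch index -> position in the shuffled global permutation
    return random_global_indices[epoch_start_index(epoch) + index]

# Epoch 3 is the last (shorter) window: only 10 - 8 = 2 samples
print(epoch_length(3))                                               # 2
print([map_epoch_index_to_global(3, i) for i in range(epoch_length(3))])
```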
diff --git a/fairseq/fairseq/data/multilingual/sampling_method.py b/fairseq/fairseq/data/multilingual/sampling_method.py
new file mode 100644
index 0000000000000000000000000000000000000000..140c68f01d60e902ef88f11f30f8813dc15fc681
--- /dev/null
+++ b/fairseq/fairseq/data/multilingual/sampling_method.py
@@ -0,0 +1,78 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from typing import List
+
+
+logger = logging.getLogger(__name__)
+
+
+def uniform(dataset_sizes: List[int]):
+ return [1.0] * len(dataset_sizes)
+
+
+def temperature_sampling(dataset_sizes, temp):
+ total_size = sum(dataset_sizes)
+ return [(size / total_size) ** (1.0 / temp) for size in dataset_sizes]
+
+
+def make_temperature_sampling(temp=1.0):
+ def sampling_func(dataset_sizes):
+ return temperature_sampling(dataset_sizes, temp)
+
+ return sampling_func
+
+
+def make_ratio_sampling(ratios):
+ def sampling_func(dataset_sizes):
+ return ratios
+
+ return sampling_func
+
+
+class SamplingMethod:
+ @staticmethod
+ def add_arguments(parser):
+ parser.add_argument(
+ "--sampling-method",
+ choices=[
+ "uniform",
+ "temperature",
+ "concat",
+ "RoundRobin",
+ ],
+ type=str,
+ default="concat",
+            help="the method used to sample data across language pairs",
+ )
+ parser.add_argument(
+ "--sampling-temperature",
+ default=1.5,
+ type=float,
+            help="only works with --sampling-method temperature",
+ )
+
+ @staticmethod
+ def build_sampler(args, task):
+ return SamplingMethod(args, task)
+
+ def __init__(self, args, task):
+ self.args = args
+ self.task = task
+
+ def is_adaptive(self):
+ return False
+
+ def sampling_method_selector(self):
+ args = self.args
+ logger.info(f"selected sampler: {args.sampling_method}")
+ if args.sampling_method == "uniform":
+ return uniform
+ elif args.sampling_method == "temperature" or self.is_adaptive():
+ return make_temperature_sampling(float(args.sampling_temperature))
+ else:
+            # default to concatenating all datasets together
+ return None
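`temperature_sampling` above implements standard temperature-based resampling: each dataset receives a weight proportional to `(size / total) ** (1 / temp)`, so `temp=1.0` preserves the raw data proportions while larger temperatures flatten the distribution toward uniform, upsampling low-resource datasets. A quick numerical illustration (the corpus sizes are made up):

```python
def temperature_sampling(dataset_sizes, temp):
    # same formula as in sampling_method.py above
    total_size = sum(dataset_sizes)
    return [(size / total_size) ** (1.0 / temp) for size in dataset_sizes]

# Two hypothetical corpora: 900k and 100k sentences
sizes = [900_000, 100_000]
for temp in (1.0, 1.5, 5.0):
    weights = temperature_sampling(sizes, temp)
    norm = sum(weights)
    print(temp, [round(w / norm, 3) for w in weights])
# temp=1.0 keeps the raw proportions (0.9 / 0.1); temp=5.0 pulls them toward uniform (~0.61 / 0.39)
```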
diff --git a/fairseq/fairseq/libnat_cuda.cpython-310-x86_64-linux-gnu.so b/fairseq/fairseq/libnat_cuda.cpython-310-x86_64-linux-gnu.so
new file mode 100644
index 0000000000000000000000000000000000000000..29681fa71a77b40638f051ae8d77bd4a9845853b
--- /dev/null
+++ b/fairseq/fairseq/libnat_cuda.cpython-310-x86_64-linux-gnu.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bf5c5a79b2d76f0e492935b1574f8d570b2b7a8c607415172b0d326478cc1c18
+size 1559392
diff --git a/fairseq/fairseq/ngram_repeat_block_cuda.cpython-310-x86_64-linux-gnu.so b/fairseq/fairseq/ngram_repeat_block_cuda.cpython-310-x86_64-linux-gnu.so
new file mode 100644
index 0000000000000000000000000000000000000000..2e3c85fe3f5701fa4625f60ce719639b12a1e352
--- /dev/null
+++ b/fairseq/fairseq/ngram_repeat_block_cuda.cpython-310-x86_64-linux-gnu.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd79aa55743c6a6c664461ccc6704f7474fb9f79a43abb5283e57f0083cbb277
+size 1023360