# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from pathlib import Path

import numpy as np
import pytest
from pack_dataset import pack_data_dir
from parameterized import parameterized
from save_len_file import save_len_file
from torch.utils.data import DataLoader

from transformers import AutoTokenizer
from transformers.models.mbart.modeling_mbart import shift_tokens_right
from transformers.testing_utils import TestCasePlus, slow
from utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset


BERT_BASE_CASED = "bert-base-cased"
PEGASUS_XSUM = "google/pegasus-xsum"
ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
T5_TINY = "patrickvonplaten/t5-tiny-random"
BART_TINY = "sshleifer/bart-tiny-random"
MBART_TINY = "sshleifer/tiny-mbart"
MARIAN_TINY = "sshleifer/tiny-marian-en-de"


def _dump_articles(path: Path, articles: list):
    """Write ``articles`` to ``path``, one per line, with no trailing newline.

    Args:
        path: Destination file; anything ``pathlib.Path`` accepts (callers in
            this file pass ``os.path.join`` strings).
        articles: Strings to join with ``"\\n"`` and write.
    """
    content = "\n".join(articles)
    # The context manager guarantees the handle is flushed and closed before
    # any test reads the file back.
    with Path(path).open("w") as f:
        f.write(content)


def make_test_data_dir(tmp_dir):
    """Populate ``tmp_dir`` with ``{train,val,test}.{source,target}`` fixtures.

    Source files contain ``ARTICLES`` and target files contain ``SUMMARIES``.

    Returns:
        ``tmp_dir`` unchanged, so the call can be used inline as an argument.
    """
    for split in ["train", "val", "test"]:
        _dump_articles(os.path.join(tmp_dir, f"{split}.source"), ARTICLES)
        _dump_articles(os.path.join(tmp_dir, f"{split}.target"), SUMMARIES)
    return tmp_dir


class TestAll(TestCasePlus):
    """Tests for the seq2seq dataset classes, samplers and dataset packing."""

    @parameterized.expand(
        [
            MBART_TINY,
            MARIAN_TINY,
            T5_TINY,
            BART_TINY,
            PEGASUS_XSUM,
        ],
    )
    @slow
    def test_seq2seq_dataset_truncation(self, tok_name):
        """Batches from ``Seq2SeqDataset`` are cut to the requested max lengths."""
        tokenizer = AutoTokenizer.from_pretrained(tok_name)
        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
        max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
        max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
        max_src_len = 4
        max_tgt_len = 8
        # Preconditions: the raw fixtures must be longer than the limits,
        # otherwise the shape checks below would not demonstrate truncation.
        assert max_len_target > max_tgt_len  # Will be truncated
        assert max_len_source > max_src_len  # Will be truncated
        src_lang, tgt_lang = "ro_RO", "de_DE"  # ignored for all but mbart, but never causes error.
        train_dataset = Seq2SeqDataset(
            tokenizer,
            data_dir=tmp_dir,
            type_path="train",
            max_source_length=max_src_len,
            max_target_length=max_tgt_len,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
        )
        dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
        for batch in dataloader:
            assert isinstance(batch, dict)
            assert batch["attention_mask"].shape == batch["input_ids"].shape
            # show that articles were trimmed.
            assert batch["input_ids"].shape[1] == max_src_len
            # show that targets are the same len
            assert batch["labels"].shape[1] == max_tgt_len
            if tok_name != MBART_TINY:
                continue
            # check language codes in correct place
            batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tokenizer.pad_token_id)
            assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
            assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
            assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
            assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]

            break  # No need to test every batch

    @parameterized.expand([BART_TINY, BERT_BASE_CASED])
    def test_legacy_dataset_truncation(self, tok):
        """``LegacySeq2SeqDataset`` trims sources to the batch max and truncates targets."""
        tokenizer = AutoTokenizer.from_pretrained(tok)
        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
        max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
        max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
        trunc_target = 4
        train_dataset = LegacySeq2SeqDataset(
            tokenizer,
            data_dir=tmp_dir,
            type_path="train",
            max_source_length=20,
            max_target_length=trunc_target,
        )
        dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
        for batch in dataloader:
            assert batch["attention_mask"].shape == batch["input_ids"].shape
            # show that articles were trimmed.
            assert batch["input_ids"].shape[1] == max_len_source
            assert 20 >= batch["input_ids"].shape[1]  # trimmed significantly
            # show that targets were truncated
            assert batch["labels"].shape[1] == trunc_target  # Truncated
            assert max_len_target > trunc_target  # Truncated
            break  # No need to test every batch

    def test_pack_dataset(self):
        """``pack_data_dir`` merges the two short examples into a single line."""
        tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")

        tmp_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
        with tmp_dir.joinpath("train.source").open() as f:
            orig_examples = f.readlines()
        save_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
        pack_data_dir(tokenizer, tmp_dir, 128, save_dir)
        orig_paths = {x.name for x in tmp_dir.iterdir()}
        new_paths = {x.name for x in save_dir.iterdir()}
        with save_dir.joinpath("train.source").open() as f:
            packed_examples = f.readlines()
        # orig: [' Sam ate lunch today.\n', 'Sams lunch ingredients.']
        # desired_packed: [' Sam ate lunch today.\n Sams lunch ingredients.']
        assert len(packed_examples) < len(orig_examples)
        assert len(packed_examples) == 1
        assert len(packed_examples[0]) == sum(len(x) for x in orig_examples)
        assert orig_paths == new_paths

    @pytest.mark.skipif(not FAIRSEQ_AVAILABLE, reason="This test requires fairseq")
    def test_dynamic_batch_size(self):
        """The dynamic sampler yields varying batch sizes within the token budget."""
        # Kept in addition to the skipif mark so the test is also a no-op
        # under runners that do not honor pytest marks.
        if not FAIRSEQ_AVAILABLE:
            return
        ds, max_tokens, _ = self._get_dataset(max_len=64)
        required_batch_size_multiple = 64
        batch_sampler = ds.make_dynamic_sampler(max_tokens, required_batch_size_multiple=required_batch_size_multiple)
        batch_sizes = [len(x) for x in batch_sampler]
        assert len(set(batch_sizes)) > 1  # it's not dynamic batch size if every batch is the same length
        assert sum(batch_sizes) == len(ds)  # no dropped or added examples
        data_loader = DataLoader(ds, batch_sampler=batch_sampler, collate_fn=ds.collate_fn, num_workers=2)
        failures = []
        num_src_per_batch = []
        for batch in data_loader:
            src_shape = batch["input_ids"].shape
            bs = src_shape[0]
            assert bs % required_batch_size_multiple == 0 or bs < required_batch_size_multiple
            # np.prod rather than the np.product alias, which NumPy 2.0 removed.
            num_src_tokens = np.prod(src_shape)
            num_src_per_batch.append(num_src_tokens)
            # Allow 10% slack over the nominal token budget.
            if num_src_tokens > (max_tokens * 1.1):
                failures.append(num_src_tokens)
        assert num_src_per_batch[0] == max(num_src_per_batch)
        if failures:
            raise AssertionError(f"too many tokens in {len(failures)} batches")

    def test_sortish_sampler_reduces_padding(self):
        """Sortish sampling yields fewer pad tokens than naive sequential batching."""
        ds, _, tokenizer = self._get_dataset(max_len=512)
        bs = 2
        sortish_sampler = ds.make_sortish_sampler(bs, shuffle=False)

        naive_dl = DataLoader(ds, batch_size=bs, collate_fn=ds.collate_fn, num_workers=2)
        sortish_dl = DataLoader(ds, batch_size=bs, collate_fn=ds.collate_fn, num_workers=2, sampler=sortish_sampler)

        pad = tokenizer.pad_token_id

        def count_pad_tokens(data_loader, k="input_ids"):
            """Return the number of pad tokens under key ``k`` for each batch."""
            return [batch[k].eq(pad).sum().item() for batch in data_loader]

        assert sum(count_pad_tokens(sortish_dl, k="labels")) < sum(count_pad_tokens(naive_dl, k="labels"))
        assert sum(count_pad_tokens(sortish_dl)) < sum(count_pad_tokens(naive_dl))
        assert len(sortish_dl) == len(naive_dl)

    def _get_dataset(self, n_obs=1000, max_len=128):
        """Build a ``Seq2SeqDataset`` over the wmt_en_ro data with the Marian tokenizer.

        If the ``USE_REAL_DATA`` environment variable is set to any non-empty
        value, the full ``examples/seq2seq/wmt_en_ro`` directory is used and
        the length file is only generated when ``train.len`` is missing.
        Otherwise the bundled test data is used and the length file is
        regenerated on every call.

        Args:
            n_obs: Forwarded to ``Seq2SeqDataset`` as ``n_obs``.
            max_len: Used for both the source and the target max length.

        Returns:
            Tuple ``(ds, max_tokens, tokenizer)``.
        """
        if os.getenv("USE_REAL_DATA", False):
            data_dir = "examples/seq2seq/wmt_en_ro"
            max_tokens = max_len * 2 * 64
            if not Path(data_dir).joinpath("train.len").exists():
                save_len_file(MARIAN_TINY, data_dir)
        else:
            data_dir = "examples/seq2seq/test_data/wmt_en_ro"
            max_tokens = max_len * 4
            save_len_file(MARIAN_TINY, data_dir)

        tokenizer = AutoTokenizer.from_pretrained(MARIAN_TINY)
        ds = Seq2SeqDataset(
            tokenizer,
            data_dir=data_dir,
            type_path="train",
            max_source_length=max_len,
            max_target_length=max_len,
            n_obs=n_obs,
        )
        return ds, max_tokens, tokenizer

    def test_distributed_sortish_sampler_splits_indices_between_procs(self):
        """Two replicas of ``DistributedSortishSampler`` yield disjoint index sets."""
        ds, _, _ = self._get_dataset()
        ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False))
        ids2 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=1, add_extra_examples=False))
        assert ids1.intersection(ids2) == set()

    @parameterized.expand(
        [
            MBART_TINY,
            MARIAN_TINY,
            T5_TINY,
            BART_TINY,
            PEGASUS_XSUM,
        ],
    )
    def test_dataset_kwargs(self, tok_name):
        """``Seq2SeqDataset.dataset_kwargs`` holds only the tokenizer-specific extras."""
        tokenizer = AutoTokenizer.from_pretrained(tok_name, use_fast=False)
        common_args = {
            "data_dir": make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
            "type_path": "train",
            "max_source_length": 4,
            "max_target_length": 8,
        }
        if tok_name == MBART_TINY:
            train_dataset = Seq2SeqDataset(tokenizer, src_lang="EN", tgt_lang="FR", **common_args)
            kwargs = train_dataset.dataset_kwargs
            assert "src_lang" in kwargs and "tgt_lang" in kwargs
            return

        train_dataset = Seq2SeqDataset(tokenizer, **common_args)
        kwargs = train_dataset.dataset_kwargs
        if tok_name == BART_TINY:
            # For the BART tokenizer, add_prefix_space is expected to be the only extra.
            assert "add_prefix_space" in kwargs
            assert len(kwargs) == 1
        else:
            assert "add_prefix_space" not in kwargs
            assert len(kwargs) == 0