# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
from pathlib import Path
from unittest.mock import patch

from parameterized import parameterized
from run_eval import run_generate
from run_eval_search import run_search

from transformers.testing_utils import CaptureStdout, TestCasePlus, slow
from utils import ROUGE_KEYS


logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()


def _dump_articles(path: Path, articles: list) -> None:
    """Write ``articles`` to ``path``, one article per line.

    The articles are joined with ``"\\n"``; no trailing newline is appended.
    An existing file at ``path`` is overwritten.

    Args:
        path: Destination file; a ``str`` is accepted as well as a ``Path``.
        articles: The strings to write.
    """
    content = "\n".join(articles)
    # Close the handle explicitly so the data is flushed to disk before the
    # code under test reads the file, rather than relying on garbage collection.
    with Path(path).open("w") as handle:
        handle.write(content)


T5_TINY = "patrickvonplaten/t5-tiny-random"
BART_TINY = "sshleifer/bart-tiny-random"
MBART_TINY = "sshleifer/tiny-mbart"

stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
logging.disable(logging.CRITICAL)  # remove noisy download output from tracebacks


class TestTheRest(TestCasePlus):
    """Tests that drive ``run_generate`` and ``run_search`` through a patched ``sys.argv``."""

    def run_eval_tester(self, model):
        """Call ``run_generate`` for ``model`` on a one-article input and check the output file exists.

        Args:
            model: Model identifier placed on the simulated command line. ``T5_TINY``
                selects the ``translation_en_to_de`` task; anything else selects
                ``summarization``.
        """
        input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
        output_file_name = input_file_name.parent / "utest_output.txt"
        assert not output_file_name.exists()
        articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
        _dump_articles(input_file_name, articles)

        score_path = str(Path(self.get_auto_remove_tmp_dir()) / "scores.json")
        task = "translation_en_to_de" if model == T5_TINY else "summarization"
        # The layout inside the literal is irrelevant: ``split()`` tokenises on any whitespace.
        testargs = f"""
            run_eval_search.py
            {model}
            {input_file_name}
            {output_file_name}
            --score_path {score_path}
            --task {task}
            --num_beams 2
            --length_penalty 2.0
        """.split()

        with patch.object(sys, "argv", testargs):
            run_generate()
            assert output_file_name.exists()

    # test one model to quickly (no-@slow) catch simple problems and do an
    # extensive testing of functionality with multiple models as @slow separately
    def test_run_eval(self):
        self.run_eval_tester(T5_TINY)

    # any extra models should go into the list here - can be slow
    @parameterized.expand([BART_TINY, MBART_TINY])
    @slow
    def test_run_eval_slow(self, model):
        self.run_eval_tester(model)

    # testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart)
    @parameterized.expand([T5_TINY, MBART_TINY])
    @slow
    def test_run_eval_search(self, model):
        input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
        output_file_name = input_file_name.parent / "utest_output.txt"
        assert not output_file_name.exists()

        text = {
            "en": ["Machine learning is great, isn't it?", "I like to eat bananas", "Tomorrow is another great day!"],
            "de": [
                "Maschinelles Lernen ist großartig, oder?",
                "Ich esse gerne Bananen",
                "Morgen ist wieder ein toller Tag!",
            ],
        }

        tmp_dir = Path(self.get_auto_remove_tmp_dir())
        score_path = str(tmp_dir / "scores.json")
        reference_path = str(tmp_dir / "val.target")
        _dump_articles(input_file_name, text["en"])
        _dump_articles(reference_path, text["de"])
        task = "translation_en_to_de" if model == T5_TINY else "summarization"
        testargs = f"""
            run_eval_search.py
            {model}
            {input_file_name}
            {output_file_name}
            --score_path {score_path}
            --reference_path {reference_path}
            --task {task}
        """.split()
        # Appended separately: the search spec contains a space and must stay a single argv entry.
        testargs.extend(["--search", "num_beams=1:2 length_penalty=0.9:1.0"])

        with patch.object(sys, "argv", testargs):
            with CaptureStdout() as cs:
                run_search()

            # Strings that must / must not show up in the captured stdout.
            expected_strings = [" num_beams | length_penalty", model, "Best score args"]
            un_expected_strings = ["Info"]
            if "translation" in task:
                expected_strings.append("bleu")
            else:
                expected_strings.extend(ROUGE_KEYS)

            for w in expected_strings:
                assert w in cs.out
            for w in un_expected_strings:
                assert w not in cs.out

            assert output_file_name.exists()
            os.remove(output_file_name)