#!/usr/bin/env python
# Source: Space "voice_clone_v3" (paused) -- transformers/examples/research_projects/seq2seq-distillation/_test_bash_script.py
import argparse | |
import os | |
import sys | |
from unittest.mock import patch | |
import pytorch_lightning as pl | |
import timeout_decorator | |
import torch | |
from distillation import SummarizationDistiller, distill_main | |
from finetune import SummarizationModule, main | |
from transformers import MarianMTModel | |
from transformers.file_utils import cached_path | |
from transformers.testing_utils import TestCasePlus, require_torch_gpu, slow | |
from utils import load_json | |
# Hub id loaded through MarianMTModel and substituted for "facebook/mbart-large-cc25"
# in the arguments taken from the training shell script.
MARIAN_MODEL = "sshleifer/mar_enro_6_3_student"
class TestMbartCc25Enro(TestCasePlus):
    """End-to-end run of the ``train_mbart_cc25_enro.sh`` recipe through ``finetune.main``.

    The arguments that follow ``finetune.py`` in the shell script are reused, with
    ``facebook/mbart-large-cc25`` replaced by ``MARIAN_MODEL`` and training cut to one epoch.
    """

    def setUp(self):
        """Resolve the en-ro archive via ``cached_path`` (extraction requested) and store the data directory."""
        super().setUp()

        data_cached = cached_path(
            "https://cdn-datasets.huggingface.co/translation/wmt_en_ro-tr40k-va0.5k-te0.5k.tar.gz",
            extract_compressed_file=True,
        )
        self.data_dir = f"{data_cached}/wmt_en_ro-tr40k-va0.5k-te0.5k"

    # Only useful as a warm-up for the GPU-only training test below, so it carries the same guards.
    @slow
    @require_torch_gpu
    def test_model_download(self):
        """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
        MarianMTModel.from_pretrained(MARIAN_MODEL)

    # @timeout_decorator.timeout(1200)
    # Guarded with the decorators this module already imports: the run passes --gpus 1
    # and trains for a full epoch on 40000 examples.
    @slow
    @require_torch_gpu
    def test_train_mbart_cc25_enro_script(self):
        """Fine-tune for one epoch with the script's arguments, then check metrics, checkpoint and predictions."""
        env_vars_to_replace = {
            "$MAX_LEN": 64,
            "$BS": 64,
            "$GAS": 1,
            "$ENRO_DIR": self.data_dir,
            "facebook/mbart-large-cc25": MARIAN_MODEL,
            # "val_check_interval=0.25": "val_check_interval=1.0",
            "--learning_rate=3e-5": "--learning_rate 3e-4",
            "--num_train_epochs 6": "--num_train_epochs 1",
        }

        # Clean up bash script: keep what follows "finetune.py", join continuation lines, drop "$@".
        # read_text() closes the file; the former open().read() chain left the handle dangling.
        script_path = self.test_file_dir / "train_mbart_cc25_enro.sh"
        bash_script = script_path.read_text().split("finetune.py")[1].strip()
        bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
        for k, v in env_vars_to_replace.items():
            bash_script = bash_script.replace(k, str(v))
        output_dir = self.get_auto_remove_tmp_dir()

        # bash_script = bash_script.replace("--fp16 ", "")
        # Appended after the script's own arguments.
        extra_args = f"""
            --output_dir {output_dir}
            --tokenizer_name Helsinki-NLP/opus-mt-en-ro
            --sortish_sampler
            --do_predict
            --gpus 1
            --freeze_encoder
            --n_train 40000
            --n_val 500
            --n_test 500
            --fp16_opt_level O1
            --num_sanity_val_steps 0
            --eval_beams 2
        """.split()
        # XXX: args.gpus > 1 : handle multi_gpu in the future

        testargs = ["finetune.py"] + bash_script.split() + extra_args
        with patch.object(sys, "argv", testargs):
            parser = argparse.ArgumentParser()
            parser = pl.Trainer.add_argparse_args(parser)
            parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
            args = parser.parse_args()
            model = main(args)

        # Check metrics
        metrics = load_json(model.metrics_save_path)
        first_step_stats = metrics["val"][0]
        last_step_stats = metrics["val"][-1]
        self.assertEqual(len(metrics["val"]), (args.max_epochs / args.val_check_interval))
        self.assertIsInstance(last_step_stats[f"val_avg_{model.val_metric}"], float)

        self.assertGreater(last_step_stats["val_avg_gen_time"], 0.01)
        # model hanging on generate. Maybe bad config was saved. (XXX: old comment/assert?)
        self.assertLessEqual(last_step_stats["val_avg_gen_time"], 1.0)

        # test learning requirements:

        # 1. BLEU improves over the course of training by more than 2 pts
        self.assertGreater(last_step_stats["val_avg_bleu"] - first_step_stats["val_avg_bleu"], 2)

        # 2. BLEU finishes above 17
        self.assertGreater(last_step_stats["val_avg_bleu"], 17)

        # 3. test BLEU and val BLEU within ~1.1 pt.
        self.assertLess(abs(metrics["val"][-1]["val_avg_bleu"] - metrics["test"][-1]["test_avg_bleu"]), 1.1)

        # check lightning ckpt can be loaded and has a reasonable statedict
        contents = os.listdir(output_dir)
        ckpt_files = [x for x in contents if x.endswith(".ckpt")]
        # Fail with a readable message instead of an IndexError when no checkpoint was written.
        self.assertTrue(ckpt_files, f"no .ckpt file found in {output_dir}: {contents}")
        full_path = os.path.join(args.output_dir, ckpt_files[0])
        ckpt = torch.load(full_path, map_location="cpu")
        expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
        self.assertIn(expected_key, ckpt["state_dict"])
        self.assertEqual(ckpt["state_dict"][expected_key].dtype, torch.float32)

        # TODO: turn on args.do_predict when PL bug fixed.
        if args.do_predict:
            contents = {os.path.basename(p) for p in contents}
            self.assertIn("test_generations.txt", contents)
            self.assertIn("test_results.txt", contents)
            # assert len(metrics["val"]) == desired_n_evals
            self.assertEqual(len(metrics["test"]), 1)
class TestDistilMarianNoTeacher(TestCasePlus):
    """End-to-end run of the ``distil_marian_no_teacher.sh`` recipe through ``distill_main``."""

    # Guarded with the decorators this module already imports: the run passes --gpus=1
    # and trains for several epochs.
    @slow
    @require_torch_gpu
    def test_opus_mt_distill_script(self):
        """Train with the script's arguments (fp16 flags removed), then check metrics, checkpoint and predictions."""
        data_dir = f"{self.test_file_dir_str}/test_data/wmt_en_ro"
        env_vars_to_replace = {
            "--fp16_opt_level=O1": "",
            "$MAX_LEN": 128,
            "$BS": 16,
            "$GAS": 1,
            "$ENRO_DIR": data_dir,
            "$m": "sshleifer/student_marian_en_ro_6_1",
            "val_check_interval=0.25": "val_check_interval=1.0",
        }

        # Clean up bash script: keep what follows "distillation.py", join continuation lines, drop "$@".
        # read_text() closes the file; the former open().read() chain left the handle dangling.
        script_path = self.test_file_dir / "distil_marian_no_teacher.sh"
        bash_script = script_path.read_text().split("distillation.py")[1].strip()
        bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
        bash_script = bash_script.replace("--fp16 ", " ")

        for k, v in env_vars_to_replace.items():
            bash_script = bash_script.replace(k, str(v))
        output_dir = self.get_auto_remove_tmp_dir()
        # Must stay after the loop above: removing a bare "--fp16" before "--fp16_opt_level=O1"
        # has been substituted would leave a stray "_opt_level=O1" token behind.
        bash_script = bash_script.replace("--fp16", "")
        epochs = 6
        testargs = (
            ["distillation.py"]
            + bash_script.split()
            + [
                f"--output_dir={output_dir}",
                "--gpus=1",
                "--learning_rate=1e-3",
                f"--num_train_epochs={epochs}",
                "--warmup_steps=10",
                "--val_check_interval=1.0",
                "--do_predict",
            ]
        )
        with patch.object(sys, "argv", testargs):
            parser = argparse.ArgumentParser()
            parser = pl.Trainer.add_argparse_args(parser)
            parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd())
            args = parser.parse_args()
            # assert args.gpus == gpus THIS BREAKS for multi_gpu

            model = distill_main(args)

        # Check metrics
        metrics = load_json(model.metrics_save_path)
        first_step_stats = metrics["val"][0]
        last_step_stats = metrics["val"][-1]
        # ">=" rather than "==": presumably a sanity-check validation pass can add an extra entry.
        self.assertGreaterEqual(len(metrics["val"]), args.max_epochs / args.val_check_interval)

        self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)

        self.assertLess(
            first_step_stats["val_avg_bleu"],
            last_step_stats["val_avg_bleu"],
            "model learned nothing",
        )
        self.assertLessEqual(
            last_step_stats["val_avg_gen_time"],
            1.0,
            "model hanging on generate. Maybe bad config was saved.",
        )
        self.assertIsInstance(last_step_stats[f"val_avg_{model.val_metric}"], float)

        # check lightning ckpt can be loaded and has a reasonable statedict
        contents = os.listdir(output_dir)
        ckpt_files = [x for x in contents if x.endswith(".ckpt")]
        # Fail with a readable message instead of an IndexError when no checkpoint was written.
        self.assertTrue(ckpt_files, f"no .ckpt file found in {output_dir}: {contents}")
        full_path = os.path.join(args.output_dir, ckpt_files[0])
        ckpt = torch.load(full_path, map_location="cpu")
        expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
        self.assertIn(expected_key, ckpt["state_dict"])
        self.assertEqual(ckpt["state_dict"][expected_key].dtype, torch.float32)

        # TODO: turn on args.do_predict when PL bug fixed.
        if args.do_predict:
            contents = {os.path.basename(p) for p in contents}
            self.assertIn("test_generations.txt", contents)
            self.assertIn("test_results.txt", contents)
            # assert len(metrics["val"]) == desired_n_evals
            self.assertEqual(len(metrics["test"]), 1)