OFA-OCR / fairseq /tests /test_reproducibility.py
JustinLin610's picture
first commit
ee21b96
raw
history blame
No virus
4.91 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import json
import os
import tempfile
import unittest
from io import StringIO
import torch
from . import test_binaries
class TestReproducibility(unittest.TestCase):
def _test_reproducibility(
self,
name,
extra_flags=None,
delta=0.0001,
resume_checkpoint="checkpoint1.pt",
max_epoch=3,
):
def get_last_log_stats_containing_string(log_records, search_string):
for log_record in logs.records[::-1]:
if isinstance(log_record.msg, str) and search_string in log_record.msg:
return json.loads(log_record.msg)
if extra_flags is None:
extra_flags = []
with tempfile.TemporaryDirectory(name) as data_dir:
with self.assertLogs() as logs:
test_binaries.create_dummy_data(data_dir)
test_binaries.preprocess_translation_data(data_dir)
# train epochs 1 and 2 together
with self.assertLogs() as logs:
test_binaries.train_translation_model(
data_dir,
"fconv_iwslt_de_en",
[
"--dropout",
"0.0",
"--log-format",
"json",
"--log-interval",
"1",
"--max-epoch",
str(max_epoch),
]
+ extra_flags,
)
train_log = get_last_log_stats_containing_string(logs.records, "train_loss")
valid_log = get_last_log_stats_containing_string(logs.records, "valid_loss")
# train epoch 2, resuming from previous checkpoint 1
os.rename(
os.path.join(data_dir, resume_checkpoint),
os.path.join(data_dir, "checkpoint_last.pt"),
)
with self.assertLogs() as logs:
test_binaries.train_translation_model(
data_dir,
"fconv_iwslt_de_en",
[
"--dropout",
"0.0",
"--log-format",
"json",
"--log-interval",
"1",
"--max-epoch",
str(max_epoch),
]
+ extra_flags,
)
train_res_log = get_last_log_stats_containing_string(
logs.records, "train_loss"
)
valid_res_log = get_last_log_stats_containing_string(
logs.records, "valid_loss"
)
for k in ["train_loss", "train_ppl", "train_num_updates", "train_gnorm"]:
self.assertAlmostEqual(
float(train_log[k]), float(train_res_log[k]), delta=delta
)
for k in [
"valid_loss",
"valid_ppl",
"valid_num_updates",
"valid_best_loss",
]:
self.assertAlmostEqual(
float(valid_log[k]), float(valid_res_log[k]), delta=delta
)
def test_reproducibility(self):
self._test_reproducibility("test_reproducibility")
@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
def test_reproducibility_fp16(self):
self._test_reproducibility(
"test_reproducibility_fp16",
[
"--fp16",
"--fp16-init-scale",
"4096",
],
delta=0.011,
)
@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
def test_reproducibility_memory_efficient_fp16(self):
self._test_reproducibility(
"test_reproducibility_memory_efficient_fp16",
[
"--memory-efficient-fp16",
"--fp16-init-scale",
"4096",
],
)
@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
def test_reproducibility_amp(self):
self._test_reproducibility(
"test_reproducibility_amp",
[
"--amp",
"--fp16-init-scale",
"4096",
],
delta=0.011,
)
def test_mid_epoch_reproducibility(self):
self._test_reproducibility(
"test_mid_epoch_reproducibility",
["--save-interval-updates", "3"],
resume_checkpoint="checkpoint_1_3.pt",
max_epoch=1,
)
if __name__ == "__main__":
unittest.main()