# Source: OFA-OCR-dedao-demo001 / fairseq / tests / test_inference_dropout.py
# (uploaded by JustinLin610; commit ee21b96, "first commit"; 3.31 kB)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import unittest
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models.transformer import TransformerModel
from tests.test_sequence_generator import get_dummy_task_and_parser
class TestInferenceDropout(unittest.TestCase):
    """Tests for how the dropout modules of a ``TransformerModel`` react to
    ``train()``/``eval()`` and to the ``retain_dropout`` /
    ``retain_dropout_modules`` options applied by ``prepare_for_inference_``.
    """

    def setUp(self):
        self.task, self.parser = get_dummy_task_and_parser()
        TransformerModel.add_args(self.parser)
        self.args = self.parser.parse_args([])
        # A small model (2 encoder layers, 1 decoder layer) is enough here.
        self.args.encoder_layers = 2
        self.args.decoder_layers = 1
        # Suppress all log output while the tests run; re-enabled in tearDown.
        logging.disable(logging.CRITICAL)

    def tearDown(self):
        logging.disable(logging.NOTSET)

    def _build_model(self, prepare_for_inference=False):
        """Build a fresh ``TransformerModel`` from the current ``self.args``.

        Args:
            prepare_for_inference: when true, convert ``self.args`` to an
                OmegaConf config and pass it to ``prepare_for_inference_`` so
                the retain-dropout settings in ``self.args`` take effect.

        Returns:
            The built (and optionally prepared) model.
        """
        model = TransformerModel.build_model(self.args, self.task)
        if prepare_for_inference:
            cfg = convert_namespace_to_omegaconf(self.args)
            model.prepare_for_inference_(cfg)
        return model

    def _check_dropout(self, modules, attr, expected):
        """Check ``module.dropout_module.<attr>`` for every module in *modules*.

        Args:
            modules: iterable of modules that each own a ``dropout_module``.
            attr: name of the attribute to inspect on each ``dropout_module``
                (e.g. ``"apply_during_inference"`` or ``"training"``).
            expected: if true, the attribute must be truthy; otherwise falsy.
        """
        # unittest assertions (unlike bare ``assert``) survive ``python -O``
        # and report which module failed.
        check = self.assertTrue if expected else self.assertFalse
        for module in modules:
            check(
                getattr(module.dropout_module, attr),
                msg=(
                    f"{type(module).__name__}.dropout_module.{attr} "
                    f"expected to be {bool(expected)}"
                ),
            )

    def test_sets_inference_dropout_to_true(self):
        """With ``retain_dropout`` and no module filter, every dropout module
        (encoder, decoder and all of their layers) keeps dropout at inference."""
        self.args.retain_dropout = True
        model = self._build_model(prepare_for_inference=True)
        self._check_dropout([model.encoder, model.decoder], "apply_during_inference", True)
        self._check_dropout(model.encoder.layers, "apply_during_inference", True)
        self._check_dropout(model.decoder.layers, "apply_during_inference", True)

    def test_inference_dropout_false_by_default(self):
        """Without ``retain_dropout`` no dropout module applies at inference."""
        model = self._build_model(prepare_for_inference=True)
        self._check_dropout([model.encoder, model.decoder], "apply_during_inference", False)
        self._check_dropout(model.encoder.layers, "apply_during_inference", False)
        self._check_dropout(model.decoder.layers, "apply_during_inference", False)

    def test_applies_training_mode(self):
        """A fresh model's dropout modules are in training mode, and ``eval()``
        switches every one of them (encoder and decoder) out of it."""
        model = self._build_model()
        self._check_dropout([model.encoder, model.decoder], "training", True)
        self._check_dropout(model.encoder.layers, "training", True)
        self._check_dropout(model.decoder.layers, "training", True)

        model.eval()
        self._check_dropout([model.encoder, model.decoder], "training", False)
        self._check_dropout(model.encoder.layers, "training", False)
        self._check_dropout(model.decoder.layers, "training", False)

    def test_retain_modules(self):
        """``retain_dropout_modules`` restricts retained dropout to the named
        module classes: the listed encoder classes keep it, the decoder does not."""
        self.args.retain_dropout = True
        self.args.retain_dropout_modules = [
            "TransformerEncoder",
            "TransformerEncoderLayer",
        ]
        model = self._build_model(prepare_for_inference=True)
        self._check_dropout([model.encoder], "apply_during_inference", True)
        self._check_dropout(model.encoder.layers, "apply_during_inference", True)
        self._check_dropout([model.decoder], "apply_during_inference", False)
        self._check_dropout(model.decoder.layers, "apply_during_inference", False)