# ReactSeq/onmt/tests/test_models.py
import copy
import unittest
import torch
import pyonmttok
from onmt.constants import DefaultTokens
from collections import Counter
import onmt
import onmt.inputters
import onmt.opts
from onmt.model_builder import build_embeddings, build_encoder, build_decoder
from onmt.utils.parse import ArgumentParser
parser = ArgumentParser(description="train.py")
onmt.opts.model_opts(parser)
onmt.opts.distributed_opts(parser)
onmt.opts._add_train_general_opts(parser)
# -data option is required, but not used in this test, so dummy.
opt = parser.parse_known_args(["-data", "dummy"])[0]
class TestModel(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super(TestModel, self).__init__(*args, **kwargs)
        self.opt = opt
    def get_vocabs(self):
        src_vocab = pyonmttok.build_vocab_from_tokens(
            Counter(),
            maximum_size=0,
            minimum_frequency=1,
            special_tokens=[
                DefaultTokens.UNK,
                DefaultTokens.PAD,
                DefaultTokens.BOS,
                DefaultTokens.EOS,
            ],
        )
        tgt_vocab = pyonmttok.build_vocab_from_tokens(
            Counter(),
            maximum_size=0,
            minimum_frequency=1,
            special_tokens=[
                DefaultTokens.UNK,
                DefaultTokens.PAD,
                DefaultTokens.BOS,
                DefaultTokens.EOS,
            ],
        )
        vocabs = {"src": src_vocab, "tgt": tgt_vocab}
        return vocabs
    def get_batch(self, source_l=3, bsize=1):
        # batch x len x nfeat
        test_src = torch.ones(bsize, source_l, 1).long()
        test_tgt = torch.ones(bsize, source_l, 1).long()
        test_length = torch.ones(bsize).fill_(source_l).long()
        return test_src, test_tgt, test_length
    def embeddings_forward(self, opt, source_l=3, bsize=1):
        """
        Tests if the embeddings work as expected.
        args:
            opt: set of options
            source_l: length of the generated input sentence
            bsize: batch size of the generated input
        """
        vocabs = self.get_vocabs()
        emb = build_embeddings(opt, vocabs)
        test_src, _, __ = self.get_batch(source_l=source_l, bsize=bsize)
        if opt.decoder_type == "transformer":
            input = torch.cat([test_src, test_src], 1)
            res = emb(input)
            compare_to = torch.zeros(bsize, source_l * 2, opt.src_word_vec_size)
        else:
            res = emb(test_src)
            compare_to = torch.zeros(bsize, source_l, opt.src_word_vec_size)
        self.assertEqual(res.size(), compare_to.size())
    def encoder_forward(self, opt, source_l=3, bsize=1):
        """
        Tests if the encoder works as expected.
        args:
            opt: set of options
            source_l: length of the generated input sentence
            bsize: batch size of the generated input
        """
        if opt.hidden_size > 0:
            opt.enc_hid_size = opt.hidden_size
        vocabs = self.get_vocabs()
        embeddings = build_embeddings(opt, vocabs)
        enc = build_encoder(opt, embeddings)
        test_src, test_tgt, test_length = self.get_batch(source_l=source_l, bsize=bsize)
        enc_out, hidden_t, test_length = enc(test_src, test_length)
        # Initialize vectors to compare size with
        test_hid = torch.zeros(self.opt.enc_layers, bsize, opt.enc_hid_size)
        test_out = torch.zeros(bsize, source_l, opt.dec_hid_size)
        # Ensure correct sizes and types
        self.assertEqual(test_hid.size(), hidden_t[0].size())
        self.assertEqual(test_hid.size(), hidden_t[1].size())
        self.assertEqual(test_out.size(), enc_out.size())
        self.assertEqual(type(enc_out), torch.Tensor)
    def nmtmodel_forward(self, opt, source_l=3, bsize=1):
        """
        Creates an NMTModel with a custom opt function.
        Forwards a test batch and checks the output size.
        Args:
            opt: Namespace with options
            source_l: length of input sequence
            bsize: batch size
        """
        if opt.hidden_size > 0:
            opt.enc_hid_size = opt.hidden_size
            opt.dec_hid_size = opt.hidden_size
        vocabs = self.get_vocabs()
        embeddings = build_embeddings(opt, vocabs)
        enc = build_encoder(opt, embeddings)
        embeddings = build_embeddings(opt, vocabs, for_encoder=False)
        dec = build_decoder(opt, embeddings)
        model = onmt.models.model.NMTModel(enc, dec)
        test_src, test_tgt, test_length = self.get_batch(source_l=source_l, bsize=bsize)
        output, attn = model(test_src, test_tgt, test_length)
        outputsize = torch.zeros(bsize, source_l - 1, opt.dec_hid_size)
        # Make sure that output has the correct size and type
        self.assertEqual(output.size(), outputsize.size())
        self.assertEqual(type(output), torch.Tensor)
def _add_test(param_setting, methodname):
    """
    Adds a test to TestModel according to the given settings.
    Args:
        param_setting: list of tuples of (param, setting)
        methodname: name of the method that gets called
    """

    def test_method(self):
        opt = copy.deepcopy(self.opt)
        if param_setting:
            for param, setting in param_setting:
                setattr(opt, param, setting)
        ArgumentParser.update_model_opts(opt)
        getattr(self, methodname)(opt)

    if param_setting:
        name = "test_" + methodname + "_" + "_".join(str(param_setting).split())
    else:
        name = "test_" + methodname + "_standard"
    setattr(TestModel, name, test_method)
    test_method.__name__ = name
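# Example: _add_test([("rnn_type", "GRU")], "nmtmodel_forward") attaches a method
# named "test_nmtmodel_forward_[('rnn_type',_'GRU')]" to TestModel, so unittest
# discovery runs one generated test per parameter combination listed below.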
"""
TEST PARAMETERS
"""
opt.brnn = False
test_embeddings = [[], [("decoder_type", "transformer")]]
for p in test_embeddings:
    _add_test(p, "embeddings_forward")
tests_encoder = [
    [],
    [("encoder_type", "mean")],
    # [('encoder_type', 'transformer'),
    # ('word_vec_size', 16), ('hidden_size', 16)],
    [],
]
for p in tests_encoder:
    _add_test(p, "encoder_forward")
tests_nmtmodel = [
    [("rnn_type", "GRU")],
    [("layers", 10)],
    [("input_feed", 0)],
    [
        ("decoder_type", "transformer"),
        ("encoder_type", "transformer"),
        ("src_word_vec_size", 16),
        ("tgt_word_vec_size", 16),
        ("hidden_size", 16),
    ],
    [
        ("decoder_type", "transformer"),
        ("encoder_type", "transformer"),
        ("src_word_vec_size", 16),
        ("tgt_word_vec_size", 16),
        ("hidden_size", 16),
        ("position_encoding", True),
    ],
    [("coverage_attn", True)],
    [("copy_attn", True)],
    [("global_attention", "mlp")],
    [("context_gate", "both")],
    [("context_gate", "target")],
    [("context_gate", "source")],
    [("encoder_type", "brnn"), ("brnn_merge", "sum")],
    [("encoder_type", "brnn")],
    [("decoder_type", "cnn"), ("encoder_type", "cnn")],
    [("encoder_type", "rnn"), ("global_attention", None)],
    [
        ("encoder_type", "rnn"),
        ("global_attention", None),
        ("copy_attn", True),
        ("copy_attn_type", "general"),
    ],
    [
        ("encoder_type", "rnn"),
        ("global_attention", "mlp"),
        ("copy_attn", True),
        ("copy_attn_type", "general"),
    ],
    [],
]
if onmt.modules.sru.check_sru_requirement():
    # Only run the SRU test if the requirement is satisfied.
    # SRU doesn't support input_feed.
    tests_nmtmodel.append([("rnn_type", "SRU"), ("input_feed", 0)])

for p in tests_nmtmodel:
    _add_test(p, "nmtmodel_forward")
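

# Minimal entry point sketch (an assumption added for illustration): it lets the
# generated tests run directly with `python test_models.py`, in addition to the
# usual pytest/unittest discovery.
if __name__ == "__main__":
    unittest.main()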