# sakharamg's picture
# Uploading all files
# 158b61b
import copy
import unittest
import torch
import onmt
import onmt.inputters
import onmt.opts
from onmt.model_builder import build_embeddings, \
build_encoder, build_decoder
from onmt.utils.parse import ArgumentParser
# Build one option namespace at import time from the model options plus the
# general training options; generated tests deep-copy it before overriding.
parser = ArgumentParser(description='train.py')
onmt.opts.model_opts(parser)
onmt.opts._add_train_general_opts(parser)
# ``-data`` must be supplied for parsing to succeed, but these tests never
# read it, so a dummy value is passed. ``parse_known_args`` returns
# ``(namespace, leftover_args)``; only the namespace is kept.
opt, _unparsed = parser.parse_known_args(['-data', 'dummy'])
class TestModel(unittest.TestCase):
    """Shape/type smoke tests for OpenNMT embeddings, encoders and models.

    The ``*_forward`` helpers below are not collected by unittest on their
    own (their names do not start with ``test``); ``_add_test`` attaches
    generated ``test_*`` methods that call them with an option namespace.
    """

    def __init__(self, *args, **kwargs):
        super(TestModel, self).__init__(*args, **kwargs)
        # Shared module-level option namespace; generated tests deep-copy
        # it before applying their overrides.
        self.opt = opt

    def get_field(self):
        """Return the ``src`` field of a text dataset.

        The vocabulary is built from no examples, so presumably it holds
        only special tokens (TODO confirm against ``build_vocab``).
        """
        src = onmt.inputters.get_fields("text", 0, 0)["src"]
        src.base_field.build_vocab([])
        return src

    def get_batch(self, source_l=3, bsize=1):
        """Return dummy ``(src, tgt, lengths)`` tensors.

        Args:
            source_l: sequence length of ``src`` and ``tgt``.
            bsize: batch size.

        Returns:
            ``src`` and ``tgt``: all-ones LongTensors of shape
            ``(source_l, bsize, 1)`` (len x batch x nfeat);
            ``lengths``: LongTensor of shape ``(bsize,)`` filled with
            ``source_l``.
        """
        test_src = torch.ones(source_l, bsize, 1).long()
        test_tgt = torch.ones(source_l, bsize, 1).long()
        test_length = torch.ones(bsize).fill_(source_l).long()
        return test_src, test_tgt, test_length

    def embeddings_forward(self, opt, source_l=3, bsize=1):
        """Check that the embeddings produce outputs of the expected size.

        Args:
            opt: option namespace used to build the embeddings.
            source_l: length of the generated input sentence.
            bsize: batch size of the generated input.
        """
        word_field = self.get_field()
        emb = build_embeddings(opt, word_field)
        test_src, _, _ = self.get_batch(source_l=source_l, bsize=bsize)
        if opt.decoder_type == 'transformer':
            # The transformer variant feeds the source twice along the
            # time axis, doubling the sequence length.
            emb_input = torch.cat([test_src, test_src], 0)
            expected_len = source_l * 2
        else:
            emb_input = test_src
            expected_len = source_l
        res = emb(emb_input)
        self.assertEqual(
            res.size(),
            torch.Size([expected_len, bsize, opt.src_word_vec_size]))

    def encoder_forward(self, opt, source_l=3, bsize=1):
        """Check the sizes and types of an encoder's outputs.

        Args:
            opt: option namespace used to build the encoder.
            source_l: length of the generated input sentence.
            bsize: batch size of the generated input.
        """
        if opt.rnn_size > 0:
            opt.enc_rnn_size = opt.rnn_size
        word_field = self.get_field()
        embeddings = build_embeddings(opt, word_field)
        enc = build_encoder(opt, embeddings)

        test_src, _, test_length = self.get_batch(source_l=source_l,
                                                  bsize=bsize)

        hidden_t, outputs, _ = enc(test_src, test_length)

        # Expected sizes come from the per-test ``opt`` (not the shared
        # ``self.opt``) so that overridden layer counts are honoured. The
        # memory bank is produced by the encoder, so it is sized by
        # ``enc_rnn_size``.
        expected_hid = torch.Size([opt.enc_layers, bsize, opt.enc_rnn_size])
        expected_out = torch.Size([source_l, bsize, opt.enc_rnn_size])

        # ``hidden_t`` is indexed as a pair here (e.g. an LSTM ``(h, c)``);
        # check BOTH members. Passing the second size as a third positional
        # argument to ``assertEqual`` would only treat it as the failure
        # message and never compare it.
        self.assertEqual(hidden_t[0].size(), expected_hid)
        self.assertEqual(hidden_t[1].size(), expected_hid)
        self.assertEqual(outputs.size(), expected_out)
        self.assertEqual(type(outputs), torch.Tensor)

    def nmtmodel_forward(self, opt, source_l=3, bsize=1):
        """Build an NMTModel from ``opt``, forward a batch, check the output.

        Args:
            opt: option namespace with model options.
            source_l: length of the input sequence.
            bsize: batch size.
        """
        if opt.rnn_size > 0:
            opt.enc_rnn_size = opt.rnn_size
            opt.dec_rnn_size = opt.rnn_size
        word_field = self.get_field()

        enc_embeddings = build_embeddings(opt, word_field)
        enc = build_encoder(opt, enc_embeddings)
        dec_embeddings = build_embeddings(opt, word_field, for_encoder=False)
        dec = build_decoder(opt, dec_embeddings)
        model = onmt.models.model.NMTModel(enc, dec)

        test_src, test_tgt, test_length = self.get_batch(source_l=source_l,
                                                         bsize=bsize)
        outputs, _attn = model(test_src, test_tgt, test_length)

        # One step shorter than the target: presumably NMTModel feeds the
        # decoder ``tgt[:-1]`` (TODO confirm against NMTModel.forward).
        expected = torch.Size([source_l - 1, bsize, opt.dec_rnn_size])
        self.assertEqual(outputs.size(), expected)
        self.assertEqual(type(outputs), torch.Tensor)
def _add_test(param_setting, methodname):
    """Attach one generated test method to ``TestModel``.

    The generated test deep-copies ``self.opt``, applies every
    ``(param, setting)`` override with ``setattr``, passes the copy through
    ``ArgumentParser.update_model_opts`` and then calls
    ``getattr(self, methodname)(opt)``.

    Args:
        param_setting: list of ``(param, setting)`` tuples. An empty list
            registers the ``test_<methodname>_standard`` variant.
        methodname: name of the ``TestModel`` method the test invokes.
    """
    # The test name is derived from the settings themselves, with all
    # whitespace in their repr collapsed to underscores.
    if param_setting:
        suffix = "_".join(str(param_setting).split())
    else:
        suffix = 'standard'
    name = 'test_' + methodname + '_' + suffix

    def test_method(self):
        opt = copy.deepcopy(self.opt)
        for param, setting in param_setting or ():
            setattr(opt, param, setting)
        ArgumentParser.update_model_opts(opt)
        getattr(self, methodname)(opt)

    test_method.__name__ = name
    setattr(TestModel, name, test_method)
# ---------------------------------------------------------------------------
# TEST PARAMETERS
#
# Each entry below is a list of ``(option, value)`` overrides; ``_add_test``
# turns every entry into one ``test_<method>_<settings>`` method on
# ``TestModel``. An empty list registers the ``..._standard`` variant that
# runs with the unmodified options. Because the test name is derived from
# the settings, a repeated entry would only overwrite the earlier
# registration, so each entry should be unique.
# ---------------------------------------------------------------------------

# Set on the shared namespace before any test deep-copies it, so every
# generated test sees ``brnn`` (presumably the bidirectional-RNN flag) off.
opt.brnn = False

test_embeddings = [[],
                   [('decoder_type', 'transformer')]
                   ]
for p in test_embeddings:
    _add_test(p, 'embeddings_forward')

tests_encoder = [[],
                 [('encoder_type', 'mean')],
                 # Disabled transformer-encoder variant (reason not
                 # recorded here):
                 # [('encoder_type', 'transformer'),
                 #  ('word_vec_size', 16), ('rnn_size', 16)],
                 ]
for p in tests_encoder:
    _add_test(p, 'encoder_forward')

tests_nmtmodel = [[('rnn_type', 'GRU')],
                  [('layers', 10)],
                  [('input_feed', 0)],
                  [('decoder_type', 'transformer'),
                   ('encoder_type', 'transformer'),
                   ('src_word_vec_size', 16),
                   ('tgt_word_vec_size', 16),
                   ('rnn_size', 16)],
                  [('decoder_type', 'transformer'),
                   ('encoder_type', 'transformer'),
                   ('src_word_vec_size', 16),
                   ('tgt_word_vec_size', 16),
                   ('rnn_size', 16),
                   ('position_encoding', True)],
                  [('coverage_attn', True)],
                  [('copy_attn', True)],
                  [('global_attention', 'mlp')],
                  [('context_gate', 'both')],
                  [('context_gate', 'target')],
                  [('context_gate', 'source')],
                  [('encoder_type', "brnn"),
                   ('brnn_merge', 'sum')],
                  [('encoder_type', "brnn")],
                  [('decoder_type', 'cnn'),
                   ('encoder_type', 'cnn')],
                  [('encoder_type', 'rnn'),
                   ('global_attention', None)],
                  [('encoder_type', 'rnn'),
                   ('global_attention', None),
                   ('copy_attn', True),
                   ('copy_attn_type', 'general')],
                  [('encoder_type', 'rnn'),
                   ('global_attention', 'mlp'),
                   ('copy_attn', True),
                   ('copy_attn_type', 'general')],
                  [],
                  ]

if onmt.models.sru.check_sru_requirement():
    # Only run the SRU test if its requirements are satisfied.
    # SRU doesn't support input_feed, so it is switched off for this case.
    tests_nmtmodel.append([('rnn_type', 'SRU'), ('input_feed', 0)])

for p in tests_nmtmodel:
    _add_test(p, 'nmtmodel_forward')