import copy
import unittest

import torch

import onmt
import onmt.inputters
import onmt.opts
from onmt.model_builder import build_embeddings, \
    build_encoder, build_decoder
from onmt.utils.parse import ArgumentParser

parser = ArgumentParser(description='train.py')
onmt.opts.model_opts(parser)
onmt.opts._add_train_general_opts(parser)

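# '-data' is required by the option parser but never read by these tests,
# so a dummy value is enough to satisfy it.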
opt = parser.parse_known_args(['-data', 'dummy'])[0]


class TestModel(unittest.TestCase):

    def __init__(self, *args, **kwargs):
        super(TestModel, self).__init__(*args, **kwargs)
        self.opt = opt

    def get_field(self):
        src = onmt.inputters.get_fields("text", 0, 0)["src"]
        src.base_field.build_vocab([])
        return src

    def get_batch(self, source_l=3, bsize=1):
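        # Dummy input: all-ones token ids shaped (source_l, bsize, nfeats=1).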
        test_src = torch.ones(source_l, bsize, 1).long()
        test_tgt = torch.ones(source_l, bsize, 1).long()
        test_length = torch.ones(bsize).fill_(source_l).long()
        return test_src, test_tgt, test_length

    def embeddings_forward(self, opt, source_l=3, bsize=1):
        '''
        Tests if the embeddings work as expected.

        Args:
            opt: set of options
            source_l: length of the generated input sentence
            bsize: batch size of the generated input
        '''
        word_field = self.get_field()
        emb = build_embeddings(opt, word_field)
        test_src, _, __ = self.get_batch(source_l=source_l, bsize=bsize)
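        # For transformer decoders the input is doubled in length,
        # presumably to exercise position encodings beyond source_l.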
        if opt.decoder_type == 'transformer':
            doubled_src = torch.cat([test_src, test_src], 0)
            res = emb(doubled_src)
            compare_to = torch.zeros(source_l * 2, bsize,
                                     opt.src_word_vec_size)
        else:
            res = emb(test_src)
            compare_to = torch.zeros(source_l, bsize, opt.src_word_vec_size)

        self.assertEqual(res.size(), compare_to.size())

    def encoder_forward(self, opt, source_l=3, bsize=1):
        '''
        Tests if the encoder works as expected.

        Args:
            opt: set of options
            source_l: length of the generated input sentence
            bsize: batch size of the generated input
        '''
        if opt.rnn_size > 0:
            opt.enc_rnn_size = opt.rnn_size
        word_field = self.get_field()
        embeddings = build_embeddings(opt, word_field)
        enc = build_encoder(opt, embeddings)

        test_src, test_tgt, test_length = self.get_batch(source_l=source_l,
                                                         bsize=bsize)

        hidden_t, outputs, test_length = enc(test_src, test_length)

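        # Reference tensors, used only to compare sizes against.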
        test_hid = torch.zeros(self.opt.enc_layers, bsize, opt.enc_rnn_size)
        test_out = torch.zeros(source_l, bsize, opt.dec_rnn_size)

        # assertEqual treats a third positional argument as the failure
        # message, so each hidden state is compared in its own assertion.
        self.assertEqual(test_hid.size(), hidden_t[0].size())
        self.assertEqual(test_hid.size(), hidden_t[1].size())
        self.assertEqual(test_out.size(), outputs.size())
        self.assertEqual(type(outputs), torch.Tensor)

    def nmtmodel_forward(self, opt, source_l=3, bsize=1):
        """
        Creates an NMTModel from a custom set of options,
        forwards a test batch, and checks the output size.

        Args:
            opt: Namespace with options
            source_l: length of input sequence
            bsize: batch size
        """
        if opt.rnn_size > 0:
            opt.enc_rnn_size = opt.rnn_size
            opt.dec_rnn_size = opt.rnn_size
        word_field = self.get_field()

        embeddings = build_embeddings(opt, word_field)
        enc = build_encoder(opt, embeddings)

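        # Separate embeddings for the decoder; the src field doubles as
        # the tgt field in this test.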
        embeddings = build_embeddings(opt, word_field, for_encoder=False)
        dec = build_decoder(opt, embeddings)

        model = onmt.models.model.NMTModel(enc, dec)

        test_src, test_tgt, test_length = self.get_batch(source_l=source_l,
                                                         bsize=bsize)
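        # NMTModel feeds tgt[:-1] to the decoder, so the output sequence is
        # one step shorter than the target.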
        outputs, attn = model(test_src, test_tgt, test_length)
        outputsize = torch.zeros(source_l - 1, bsize, opt.dec_rnn_size)

        self.assertEqual(outputs.size(), outputsize.size())
        self.assertEqual(type(outputs), torch.Tensor)


def _add_test(param_setting, methodname):
    """
    Adds a test to TestModel according to the given settings.

    Args:
        param_setting: list of (param, setting) tuples
        methodname: name of the method that gets called
    """

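    # Each generated test deep-copies the base options, applies the
    # overrides, and re-derives the dependent model options.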
    def test_method(self):
        opt = copy.deepcopy(self.opt)
        if param_setting:
            for param, setting in param_setting:
                setattr(opt, param, setting)
        ArgumentParser.update_model_opts(opt)
        getattr(self, methodname)(opt)
    if param_setting:
        name = 'test_' + methodname + "_" + "_".join(
            str(param_setting).split())
    else:
        name = 'test_' + methodname + '_standard'
    setattr(TestModel, name, test_method)
    test_method.__name__ = name


'''
TEST PARAMETERS
'''
opt.brnn = False

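# Each entry below is one set of option overrides; an empty list generates
# the "standard" test with default options.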
test_embeddings = [[],
                   [('decoder_type', 'transformer')]
                   ]

for p in test_embeddings:
    _add_test(p, 'embeddings_forward')

tests_encoder = [[],
                 [('encoder_type', 'mean')]
                 ]

for p in tests_encoder:
    _add_test(p, 'encoder_forward')

tests_nmtmodel = [[('rnn_type', 'GRU')],
                  [('layers', 10)],
                  [('input_feed', 0)],
                  [('decoder_type', 'transformer'),
                   ('encoder_type', 'transformer'),
                   ('src_word_vec_size', 16),
                   ('tgt_word_vec_size', 16),
                   ('rnn_size', 16)],
                  [('decoder_type', 'transformer'),
                   ('encoder_type', 'transformer'),
                   ('src_word_vec_size', 16),
                   ('tgt_word_vec_size', 16),
                   ('rnn_size', 16),
                   ('position_encoding', True)],
                  [('coverage_attn', True)],
                  [('copy_attn', True)],
                  [('global_attention', 'mlp')],
                  [('context_gate', 'both')],
                  [('context_gate', 'target')],
                  [('context_gate', 'source')],
                  [('encoder_type', "brnn"),
                   ('brnn_merge', 'sum')],
                  [('encoder_type', "brnn")],
                  [('decoder_type', 'cnn'),
                   ('encoder_type', 'cnn')],
                  [('encoder_type', 'rnn'),
                   ('global_attention', None)],
                  [('encoder_type', 'rnn'),
                   ('global_attention', None),
                   ('copy_attn', True),
                   ('copy_attn_type', 'general')],
                  [('encoder_type', 'rnn'),
                   ('global_attention', 'mlp'),
                   ('copy_attn', True),
                   ('copy_attn_type', 'general')],
                  [],
                  ]

if onmt.models.sru.check_sru_requirement():
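    # SRU does not support input feeding, so it is disabled for this test.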
    tests_nmtmodel.append([('rnn_type', 'SRU'), ('input_feed', 0)])

for p in tests_nmtmodel:
    _add_test(p, 'nmtmodel_forward')
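

# The generated tests are attached to TestModel, so standard unittest
# discovery finds them; this guard also lets the file be run directly.
if __name__ == '__main__':
    unittest.main()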