|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import tempfile |
|
|
|
import numpy as np |
|
import pytest |
|
import torch |
|
from omegaconf import DictConfig, OmegaConf |
|
|
|
from nemo.collections.asr.models import EncDecCTCModel |
|
|
|
# The Export File Format (EFF) library is optional: when it cannot be imported,
# the encrypted-archive tests are skipped instead of failing.
try:
    from eff.cookbooks import NeMoCookbook
except ImportError:
    _EFF_PRESENT_ = False
else:
    _EFF_PRESENT_ = True

# Decorator that skips a test unless the EFF import above succeeded.
requires_eff = pytest.mark.skipif(
    not _EFF_PRESENT_,
    reason="Export File Format library required to run test",
)
|
|
|
|
|
@pytest.fixture()
def asr_model():
    """Build a small, randomly initialized ``EncDecCTCModel`` for the file-I/O tests.

    The model consists of an ``AudioToMelSpectrogramPreprocessor`` with empty params,
    a ``ConvASREncoder`` with a single separable 1024-filter block (64 input features),
    and a ``ConvASRDecoder`` over a 28-symbol character vocabulary.

    Returns:
        A fresh ``EncDecCTCModel`` instance (one per test, function-scoped fixture).
    """
    # Space, the 26 lowercase ASCII letters and an apostrophe -> 28 symbols.
    vocabulary = list(" abcdefghijklmnopqrstuvwxyz'")

    preprocessor = {'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', 'params': {}}
    encoder = {
        'cls': 'nemo.collections.asr.modules.ConvASREncoder',
        'params': {
            'feat_in': 64,
            'activation': 'relu',
            'conv_mask': True,
            'jasper': [
                {
                    'filters': 1024,
                    'repeat': 1,
                    'kernel': [1],
                    'stride': [1],
                    'dilation': [1],
                    'dropout': 0.0,
                    'residual': False,
                    'separable': True,
                    'se': True,
                    'se_context_size': -1,
                }
            ],
        },
    }

    decoder = {
        'cls': 'nemo.collections.asr.modules.ConvASRDecoder',
        'params': {
            # Must match the encoder's 'filters' above.
            'feat_in': 1024,
            # Derived from the vocabulary so the class count and symbol list cannot drift apart.
            'num_classes': len(vocabulary),
            'vocabulary': vocabulary,
        },
    }

    model_config = DictConfig(
        {'preprocessor': DictConfig(preprocessor), 'encoder': DictConfig(encoder), 'decoder': DictConfig(decoder)}
    )

    return EncDecCTCModel(cfg=model_config)
|
|
|
|
|
class TestFileIO:
    """Save/restore round-trip tests for ``EncDecCTCModel``: YAML configs, .nemo archives and PT checkpoints."""

    @staticmethod
    def _first_conv_weight(model):
        """Return the weights of the first encoder convolution of ``model`` as a NumPy array.

        Note: for a model on CPU, ``Tensor.numpy()`` shares memory with the parameter, so
        re-extract the array after any in-place change to the weights.
        """
        return model.encoder.encoder[0].mconv[0].conv.weight.data.detach().cpu().numpy()

    @pytest.mark.unit
    def test_to_from_config_file(self, asr_model):
        """Test that a second instance created from the same configuration (BUT NOT checkpoint)
        has different weights."""
        with tempfile.NamedTemporaryFile() as fp:
            yaml_filename = fp.name
            asr_model.to_config_file(path2yaml_file=yaml_filename)
            next_instance = EncDecCTCModel.from_config_file(path2yaml_file=yaml_filename)

            assert isinstance(next_instance, EncDecCTCModel)

            assert len(next_instance.decoder.vocabulary) == 28
            assert asr_model.num_weights == next_instance.num_weights

            w1 = self._first_conv_weight(asr_model)
            w2 = self._first_conv_weight(next_instance)

            # Same architecture, but independently initialized -> weights differ.
            assert not np.array_equal(w1, w2)

    @pytest.mark.unit
    def test_save_restore_from_nemo_file(self, asr_model):
        """Test that a second instance created from the same configuration AND checkpoint
        has the same weights."""
        with tempfile.NamedTemporaryFile() as fp:
            filename = fp.name

            # Save the model together with a (dummy) artifact.
            with tempfile.NamedTemporaryFile() as artifact:
                asr_model.register_artifact(config_path="abc", src=artifact.name)
                asr_model.save_to(save_path=filename)

            # Restore the model from the .nemo archive.
            asr_model2 = EncDecCTCModel.restore_from(restore_path=filename)

            assert len(asr_model.decoder.vocabulary) == len(asr_model2.decoder.vocabulary)
            assert asr_model.num_weights == asr_model2.num_weights

            w1 = self._first_conv_weight(asr_model)
            w2 = self._first_conv_weight(asr_model2)

            assert np.array_equal(w1, w2)

    @requires_eff
    @pytest.mark.unit
    def test_eff_save_restore_from_nemo_file_encrypted(self, asr_model):
        """Test that an encrypted archive cannot be restored without the key, and that restoring
        it with the key yields a model with the same number of weights."""
        with tempfile.NamedTemporaryFile() as fp:
            filename = fp.name

            try:
                # Set the key so that the checkpoint gets encrypted on save.
                NeMoCookbook.set_encryption_key("test_key")

                # Save the model together with a (dummy) artifact.
                with tempfile.NamedTemporaryFile() as artifact:
                    asr_model.register_artifact(config_path="abc", src=artifact.name)
                    asr_model.save_to(save_path=filename)

                # Restoring the encrypted archive without the key must be refused.
                NeMoCookbook.set_encryption_key(None)
                with pytest.raises(PermissionError):
                    EncDecCTCModel.restore_from(restore_path=filename)

                # With the key set again, the restore succeeds.
                NeMoCookbook.set_encryption_key("test_key")
                asr_model3 = EncDecCTCModel.restore_from(restore_path=filename)
            finally:
                # Always reset the key (even if the test fails) so it cannot leak into
                # other save/restore tests.
                NeMoCookbook.set_encryption_key(None)

            assert asr_model.num_weights == asr_model3.num_weights

    @pytest.mark.unit
    def test_save_restore_from_nemo_file_with_override(self, asr_model, tmpdir):
        """Test that restoring a .nemo archive with an override config keeps the saved weights
        and applies the overridden config value.

        Args:
            tmpdir: fixture providing a temporary directory unique to the test invocation.
        """
        # Name of the archive in the temporary folder.
        filename = os.path.join(tmpdir, "eff.nemo")

        # Create a dummy artifact inside tmpdir so it is cleaned up with it
        # (a bare delete=False temp file would otherwise be left behind).
        with tempfile.NamedTemporaryFile(mode="w", delete=False, dir=str(tmpdir)) as artifact:
            artifact.write("magic content 42")

        asr_model.register_artifact(config_path="abc", src=artifact.name)

        # Save the model (with the artifact) BEFORE modifying its config below.
        asr_model.save_to(save_path=filename)

        with tempfile.NamedTemporaryFile(mode='a+') as conf_fp:
            # Write a slightly modified config to use as the override.
            cfg = asr_model.cfg
            cfg.encoder.activation = 'swish'
            conf_fp.write(OmegaConf.to_yaml(cfg))
            # Flush so the YAML is visible to readers opening conf_fp.name.
            conf_fp.flush()

            asr_model2 = EncDecCTCModel.restore_from(restore_path=filename, override_config_path=conf_fp.name)

        assert len(asr_model.decoder.vocabulary) == len(asr_model2.decoder.vocabulary)
        assert asr_model.num_weights == asr_model2.num_weights

        w1 = self._first_conv_weight(asr_model)
        w2 = self._first_conv_weight(asr_model2)

        assert np.array_equal(w1, w2)

        assert asr_model2.cfg.encoder.activation == 'swish'

    @pytest.mark.unit
    def test_save_model_level_pt_ckpt(self, asr_model):
        """Test that a model-level PT checkpoint extracted from a .nemo file restores the weights."""
        with tempfile.TemporaryDirectory() as ckpt_dir:
            nemo_file = os.path.join(ckpt_dir, 'asr.nemo')
            asr_model.save_to(nemo_file)

            # Extract a single model-level state-dict checkpoint.
            asr_model.extract_state_dict_from(nemo_file, ckpt_dir)
            ckpt_path = os.path.join(ckpt_dir, 'model_weights.ckpt')

            assert os.path.exists(ckpt_path)

            asr_model2 = EncDecCTCModel.restore_from(restore_path=nemo_file)

            assert len(asr_model.decoder.vocabulary) == len(asr_model2.decoder.vocabulary)
            assert asr_model.num_weights == asr_model2.num_weights

            # Perturb the restored weights so the checkpoint load below is observable.
            asr_model2.encoder.encoder[0].mconv[0].conv.weight.data += 1.0

            w1 = self._first_conv_weight(asr_model)
            w2 = self._first_conv_weight(asr_model2)

            assert not np.array_equal(w1, w2)

            # Loading the extracted checkpoint must bring the original weights back.
            asr_model2.load_state_dict(torch.load(ckpt_path))

            w1 = self._first_conv_weight(asr_model)
            w2 = self._first_conv_weight(asr_model2)

            assert np.array_equal(w1, w2)

    @pytest.mark.unit
    def test_save_module_level_pt_ckpt(self, asr_model):
        """Test that per-module PT checkpoints extracted from a .nemo file restore the encoder weights."""
        with tempfile.TemporaryDirectory() as ckpt_dir:
            nemo_file = os.path.join(ckpt_dir, 'asr.nemo')
            asr_model.save_to(nemo_file)

            # Extract one state-dict checkpoint per top-level module.
            asr_model.extract_state_dict_from(nemo_file, ckpt_dir, split_by_module=True)
            encoder_path = os.path.join(ckpt_dir, 'encoder.ckpt')
            decoder_path = os.path.join(ckpt_dir, 'decoder.ckpt')
            preprocessor_path = os.path.join(ckpt_dir, 'preprocessor.ckpt')

            assert os.path.exists(encoder_path)
            assert os.path.exists(decoder_path)
            assert os.path.exists(preprocessor_path)

            asr_model2 = EncDecCTCModel.restore_from(restore_path=nemo_file)

            assert len(asr_model.decoder.vocabulary) == len(asr_model2.decoder.vocabulary)
            assert asr_model.num_weights == asr_model2.num_weights

            # Perturb the restored weights so the checkpoint load below is observable.
            asr_model2.encoder.encoder[0].mconv[0].conv.weight.data += 1.0

            w1 = self._first_conv_weight(asr_model)
            w2 = self._first_conv_weight(asr_model2)

            assert not np.array_equal(w1, w2)

            # Loading only the encoder checkpoint must bring the original encoder weights back.
            asr_model2.encoder.load_state_dict(torch.load(encoder_path))

            w1 = self._first_conv_weight(asr_model)
            w2 = self._first_conv_weight(asr_model2)

            assert np.array_equal(w1, w2)
|
|