# -*- coding: utf-8 -*-
# File   : unittest.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import unittest

import numpy as np
from torch.autograd import Variable


def as_numpy(v):
    """Return a NumPy copy of a torch tensor, moved to the CPU first.

    If ``v`` is a ``Variable``, its ``.data`` tensor is unwrapped before
    conversion, so tensors that track gradients can be converted too.
    """
    if isinstance(v, Variable):
        v = v.data
    return v.cpu().numpy()


class TorchTestCase(unittest.TestCase):
    """``unittest.TestCase`` with an assertion for comparing torch tensors."""

    def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
        """Assert that tensors ``a`` and ``b`` are element-wise close.

        Closeness follows ``np.allclose`` semantics:
        ``|a - b| <= atol + rtol * |b|`` must hold for every element.

        Args:
            a: Tensor under test (anything ``as_numpy`` accepts).
            b: Reference tensor.
            atol: Absolute tolerance.
            rtol: Relative tolerance.

        Raises:
            AssertionError: If any element is outside tolerance. The message
                shows both tensors plus the maximum absolute and relative
                differences.
        """
        npa, npb = as_numpy(a), as_numpy(b)
        # Forward rtol explicitly; otherwise np.allclose silently uses its own
        # default (1e-5) and the caller's rtol argument has no effect.
        is_close = np.allclose(npa, npb, atol=atol, rtol=rtol)
        abs_diff = np.abs(npa - npb)
        # Use |a| in the denominator so negative entries do not collapse to the
        # 1e-5 floor and inflate the reported relative difference.
        rel_diff = abs_diff / np.fmax(np.abs(npa), 1e-5)
        self.assertTrue(
            is_close,
            'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(
                a, b, abs_diff.max(), rel_diff.max()
            ),
        )