# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# All contributions by Andy Brock:
# Copyright (c) 2019 Andy Brock
#
# -*- coding: utf-8 -*-
# File   : unittest.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import unittest

import torch


class TorchTestCase(unittest.TestCase):
    """``unittest.TestCase`` extended with a tensor closeness assertion."""

    def assertTensorClose(self, x, y, *, rtol=1e-5, atol=1e-8):
        """Assert that ``x`` and ``y`` are element-wise close.

        The pass/fail decision is delegated to ``torch.allclose``. On
        failure the message reports the largest absolute difference
        (``adiff``) and the largest element-wise relative difference
        (``rdiff``) to help diagnose the mismatch.

        Args:
            x: Tensor under test.
            y: Reference tensor. Relative differences are measured
                against its non-zero entries.
            rtol: Relative tolerance forwarded to ``torch.allclose``.
            atol: Absolute tolerance forwarded to ``torch.allclose``.

        Raises:
            AssertionError: If ``torch.allclose(x, y, ...)`` is false.
        """
        abs_err = (x - y).abs()

        if abs_err.numel() == 0:
            # ``max()`` is undefined on an empty tensor; there is nothing
            # to differ, so report neutral diagnostics instead of raising.
            adiff = 0.0
            rdiff = "NaN"
        else:
            adiff = float(abs_err.max())
            # Align ``y`` with the (possibly broadcast) error tensor so the
            # same boolean mask can index both.
            y_full = y.expand_as(abs_err)
            nonzero = y_full != 0
            if nonzero.any():
                # Only divide where the reference is non-zero so a single
                # zero entry cannot turn the whole statistic into ``inf``.
                rdiff = float(
                    (abs_err[nonzero] / y_full[nonzero].abs()).max()
                )
            else:
                # Relative error is undefined against an all-zero reference.
                rdiff = "NaN"

        message = (
            "Tensor close check failed\n"
            "adiff={}\n"
            "rdiff={}\n"
        ).format(adiff, rdiff)
        self.assertTrue(torch.allclose(x, y, rtol=rtol, atol=atol), message)