import pytest | |
import torch | |
from ding.torch_utils.backend_helper import enable_tf32 | |
class TestBackendHelper:
    r"""
    Overview:
        Tests for ``enable_tf32`` from ``ding.torch_utils.backend_helper``.
    """

    def test_tf32(self):
        r"""
        Overview:
            Call ``enable_tf32()``, then check that a small ``torch.nn.Linear`` still runs a forward
            and backward pass and produces gradients of the expected shape. The TF32 backend flags
            are restored afterwards so this test does not alter global settings for other tests.
        """
        # NOTE(review): enable_tf32 presumably toggles torch's process-wide TF32 flags
        # (torch.backends.cuda.matmul.allow_tf32 / torch.backends.cudnn.allow_tf32) - confirm.
        # Snapshot both so they are restored regardless of what enable_tf32 actually changes.
        saved_matmul_tf32 = torch.backends.cuda.matmul.allow_tf32
        saved_cudnn_tf32 = torch.backends.cudnn.allow_tf32
        try:
            enable_tf32()
            net = torch.nn.Linear(3, 4)
            x = torch.randn(1, 3)
            y = torch.sum(net(x))
            net.zero_grad()
            y.backward()
            assert net.weight.grad is not None
            # Weight gradient must match the (out_features, in_features) parameter shape.
            assert net.weight.grad.shape == net.weight.shape
            # nn.Linear has a bias by default, and it must receive a gradient too.
            assert net.bias.grad is not None
        finally:
            torch.backends.cuda.matmul.allow_tf32 = saved_matmul_tf32
            torch.backends.cudnn.allow_tf32 = saved_cudnn_tf32