File size: 435 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import pytest
import torch

from ding.torch_utils.backend_helper import enable_tf32


@pytest.mark.cudatest
class TestBackendHelper:

    def test_tf32(self):
        r"""
        Overview:
            Test that a forward/backward pass through a small ``torch.nn.Linear`` still works after
            ``enable_tf32()`` has been called. The pass runs on CUDA when a GPU is available, because
            TF32 only affects CUDA matmul / cuDNN kernels; on a CPU-only machine it degrades to a plain
            smoke test of the call. The global TF32 backend flags are restored afterwards so this test
            does not change numerical precision for tests that run after it.
        """
        # Snapshot the process-wide TF32 switches before touching them.
        # NOTE(review): assumes enable_tf32 toggles exactly these two standard PyTorch flags —
        # confirm against ding.torch_utils.backend_helper.
        old_matmul_tf32 = torch.backends.cuda.matmul.allow_tf32
        old_cudnn_tf32 = torch.backends.cudnn.allow_tf32
        try:
            enable_tf32()
            # TF32 has no effect on CPU kernels, so prefer CUDA to actually exercise the setting.
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            net = torch.nn.Linear(3, 4).to(device)
            x = torch.randn(1, 3, device=device)
            y = torch.sum(net(x))
            net.zero_grad()
            y.backward()
            assert net.weight.grad is not None
            # The gradient of a parameter always has the parameter's shape.
            assert net.weight.grad.shape == net.weight.shape
        finally:
            # Restore the global flags so later tests run at their original precision.
            torch.backends.cuda.matmul.allow_tf32 = old_matmul_tf32
            torch.backends.cudnn.allow_tf32 = old_cudnn_tf32