File size: 375 Bytes
214ea91
f4f441a
 
214ea91
 
 
 
 
 
f4f441a
214ea91
 
 
 
f4f441a
214ea91
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
import spaces


class TorchSeedContext:
    """Temporarily seed torch's global RNGs inside a ``with`` block.

    On entry the current CPU RNG state, and the per-device CUDA RNG states
    when CUDA is available, are saved. Then ``torch.manual_seed(seed)`` is
    applied. On exit, including exit via an exception, the saved states are
    restored, so code after the block continues the random stream it would
    have seen without it. Exceptions raised inside the block are not
    suppressed.

    Args:
        seed: Value passed to ``torch.manual_seed`` on entry.
    """

    # NOTE(review): an earlier version wrapped __enter__/__exit__ in
    # @spaces.GPU. On ZeroGPU hardware that decorator may execute the call in
    # a separate worker process. In that case the captured state and the new
    # seed would not be visible to the caller. Seeding needs no GPU
    # allocation, so these methods are plain.

    def __init__(self, seed):
        self.seed = seed
        self.state = None  # CPU RNG state captured by __enter__
        self.cuda_states = None  # per-device CUDA RNG states, or None if no CUDA

    def __enter__(self):
        self.state = torch.random.get_rng_state()
        # torch.manual_seed also reseeds every CUDA device, so the CUDA states
        # must be saved as well for __exit__ to fully undo the seeding.
        self.cuda_states = (
            torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
        )
        torch.manual_seed(self.seed)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        torch.random.set_rng_state(self.state)
        if self.cuda_states is not None:
            torch.cuda.set_rng_state_all(self.cuda_states)
        # Implicit None return: exceptions from the block propagate.