File size: 375 Bytes
214ea91 f4f441a 214ea91 f4f441a 214ea91 f4f441a 214ea91 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
import torch
import spaces
class TorchSeedContext:
    """Context manager that temporarily seeds torch's RNGs and restores them on exit.

    On entry the current RNG state is saved and ``torch.manual_seed(seed)`` is
    applied. On exit the saved state is restored, so code outside the ``with``
    block continues with the random stream it had before the block ran.

    Example::

        with TorchSeedContext(42):
            noise = torch.randn(4)  # reproducible for a given seed

    NOTE(review): ``@spaces.GPU`` wraps ``__enter__``/``__exit__``. On Hugging
    Face ZeroGPU hardware the decorated call may execute in a separate worker
    process, in which case RNG changes made here would not be visible to the
    caller -- confirm this behaves as intended in that environment.
    """

    def __init__(self, seed):
        # Value forwarded to torch.manual_seed on entry.
        self.seed = seed
        # Saved CPU generator state; populated by __enter__.
        self.state = None
        # Saved per-device CUDA generator states, or None when CUDA is unavailable.
        self.cuda_states = None

    @spaces.GPU
    def __enter__(self):
        """Save the current RNG states, seed all generators, and return ``self``."""
        self.state = torch.random.get_rng_state()
        # torch.manual_seed also reseeds every CUDA generator, so their states
        # must be captured too or the seeded CUDA state would leak past the
        # end of the ``with`` block.
        self.cuda_states = (
            torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
        )
        torch.manual_seed(self.seed)
        return self

    @spaces.GPU
    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the RNG states saved by ``__enter__``.

        Returns None, so any exception raised inside the ``with`` block
        propagates to the caller.
        """
        torch.random.set_rng_state(self.state)
        if self.cuda_states is not None:
            torch.cuda.set_rng_state_all(self.cuda_states)
|