File size: 1,394 Bytes
f7a5cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import numpy as np
import random
import torch


def set_random_seed(seed: int):
    """Seed the torch, NumPy and stdlib random generators and pin cuDNN flags.

    The seed is reduced modulo ``2**31`` before use, so every generator
    receives the same value in ``[0, 2**31)``.

    Args:
        seed: Integer seed; values outside ``[0, 2**31)`` wrap around.
    """
    wrapped = seed % (1 << 31)
    # Same order as always: torch (CPU), current CUDA device, all CUDA
    # devices, NumPy, then the stdlib ``random`` module.
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(wrapped)
    # Disable cuDNN benchmarking and request deterministic cuDNN behaviour.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


class StackedRandomGenerator:
    """
    Holds one ``torch.Generator`` per minibatch sample so that every sample
    can be drawn from its own, individually seeded random stream.
    """

    def __init__(self, device, seeds):
        """Create one generator on ``device`` for each entry of ``seeds``.

        Each seed is converted with ``int`` and reduced modulo ``2**31``
        before being handed to ``Generator.manual_seed``.
        """
        super().__init__()
        modulus = 1 << 31
        self.generators = []
        for seed in seeds:
            generator = torch.Generator(device)
            generator.manual_seed(int(seed) % modulus)
            self.generators.append(generator)

    def randn_rn(self, size, **kwargs):
        """Return ``torch.randn`` samples of shape ``size``, one generator per row.

        ``size[0]`` must equal the number of generators; slice ``i`` along the
        first dimension has shape ``size[1:]`` and is drawn from generator
        ``i``. Extra keyword arguments are forwarded to ``torch.randn``.
        """
        assert size[0] == len(self.generators)
        sample_shape = size[1:]
        samples = []
        for generator in self.generators:
            samples.append(torch.randn(sample_shape, generator=generator, **kwargs))
        return torch.stack(samples)

    def randn_like(self, input):
        """Return per-sample normal noise matching ``input``'s shape, dtype, layout and device."""
        options = {
            "dtype": input.dtype,
            "layout": input.layout,
            "device": input.device,
        }
        return self.randn_rn(input.shape, **options)

    def randint(self, *args, size, **kwargs):
        """Return ``torch.randint`` samples of shape ``size``, one generator per row.

        Positional arguments and extra keyword arguments are forwarded to
        ``torch.randint``; ``size[0]`` must equal the number of generators.
        """
        assert size[0] == len(self.generators)
        sample_shape = size[1:]
        samples = []
        for generator in self.generators:
            samples.append(
                torch.randint(*args, size=sample_shape, generator=generator, **kwargs)
            )
        return torch.stack(samples)