|
|
|
|
|
|
|
|
|
|
|
import torch as th |
|
|
|
|
|
def get_generator(generator, num_samples=0, seed=0):
    """Factory: map a generator name to an RNG helper instance.

    :param generator: one of "dummy", "determ", "determ-indiv".
    :param num_samples: total sample count for the deterministic generators.
    :param seed: base RNG seed for the deterministic generators.
    :raises NotImplementedError: for any unrecognised generator name.
    """
    # Dispatch table instead of an if/elif chain; lambdas defer construction
    # so only the requested generator is instantiated.
    factories = {
        "dummy": lambda: DummyGenerator(),
        "determ": lambda: DeterministicGenerator(num_samples, seed),
        "determ-indiv": lambda: DeterministicIndividualGenerator(num_samples, seed),
    }
    if generator not in factories:
        raise NotImplementedError
    return factories[generator]()
|
|
|
|
|
class DummyGenerator:
    """Pass-through RNG that delegates every call to torch's global generator.

    Provides the same method surface as the deterministic generators so
    callers can swap implementations without branching.
    """

    def randn(self, *args, **kwargs):
        """Forward to ``th.randn`` unchanged."""
        return th.randn(*args, **kwargs)

    def randint(self, *args, **kwargs):
        """Forward to ``th.randint`` unchanged."""
        return th.randint(*args, **kwargs)

    def randn_like(self, *args, **kwargs):
        """Forward to ``th.randn_like`` unchanged."""
        return th.randn_like(*args, **kwargs)
|
|
|
|
|
class DeterministicGenerator:
    """
    RNG to deterministically sample num_samples samples that does not depend
    on batch_size or mpi_machines.

    Uses a single rng: every call draws num_samples-sized randomness and then
    subsamples the rows belonging to the current batch, so the value assigned
    to sample index i is identical no matter how the batches are split.
    """

    def __init__(self, num_samples, seed=0):
        # Distributed mode is not initialised in this build; run as a single
        # rank that owns all samples.
        print("Warning: Distributed not initialised, using single rank")
        self.rank = 0
        self.world_size = 1
        self.num_samples = num_samples
        self.done_samples = 0
        self.seed = seed
        self.rng_cpu = th.Generator()
        if th.cuda.is_available():
            # Fix: the original called dist_util.dev(), but dist_util is not
            # imported anywhere in this file, so this line raised NameError on
            # CUDA machines. Use the current CUDA device directly.
            self.rng_cuda = th.Generator(th.device("cuda"))
        self.set_seed(seed)

    def get_global_size_and_indices(self, size):
        """Return the full (num_samples, ...) size and this rank's row indices.

        Indices step by world_size starting at (done_samples + rank), and are
        clamped to [0, num_samples - 1] so a final partial batch stays in
        range (trailing rows then repeat the last sample).
        """
        global_size = (self.num_samples, *size[1:])
        indices = th.arange(
            self.done_samples + self.rank,
            self.done_samples + self.world_size * int(size[0]),
            self.world_size,
        )
        indices = th.clamp(indices, 0, self.num_samples - 1)
        assert (
            len(indices) == size[0]
        ), f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}"
        return global_size, indices

    def get_generator(self, device):
        """Pick the CPU or CUDA generator matching ``device``."""
        return self.rng_cpu if th.device(device).type == "cpu" else self.rng_cuda

    def randn(self, *size, dtype=th.float, device="cpu"):
        """Sample standard-normal values; size[0] is the batch dimension."""
        global_size, indices = self.get_global_size_and_indices(size)
        generator = self.get_generator(device)
        return th.randn(*global_size, generator=generator, dtype=dtype, device=device)[
            indices
        ]

    def randint(self, low, high, size, dtype=th.long, device="cpu"):
        """Sample integers in [low, high); size[0] is the batch dimension."""
        global_size, indices = self.get_global_size_and_indices(size)
        generator = self.get_generator(device)
        return th.randint(
            low, high, generator=generator, size=global_size, dtype=dtype, device=device
        )[indices]

    def randn_like(self, tensor):
        """Sample standard-normal values with ``tensor``'s size/dtype/device."""
        size, dtype, device = tensor.size(), tensor.dtype, tensor.device
        return self.randn(*size, dtype=dtype, device=device)

    def set_done_samples(self, done_samples):
        """Advance the sample cursor and re-seed so later draws reproduce the
        same per-sample values as an uninterrupted run."""
        self.done_samples = done_samples
        self.set_seed(self.seed)

    def get_seed(self):
        """Return the base seed this generator was configured with."""
        return self.seed

    def set_seed(self, seed):
        """Reset both generators (CPU, and CUDA when available) to ``seed``."""
        self.rng_cpu.manual_seed(seed)
        if th.cuda.is_available():
            self.rng_cuda.manual_seed(seed)
|
|
|
|
|
class DeterministicIndividualGenerator:
    """
    RNG to deterministically sample num_samples samples that does not depend
    on batch_size or mpi_machines.

    Uses a separate rng for each sample to reduce memory usage: each sample
    index i has its own generator seeded as ``i + num_samples * seed``, so a
    sample's randomness never depends on which batch it lands in.
    """

    def __init__(self, num_samples, seed=0):
        # Distributed mode is not initialised in this build; run as a single
        # rank that owns all samples.
        print("Warning: Distributed not initialised, using single rank")
        self.rank = 0
        self.world_size = 1
        self.num_samples = num_samples
        self.done_samples = 0
        self.seed = seed
        self.rng_cpu = [th.Generator() for _ in range(num_samples)]
        if th.cuda.is_available():
            # Fix: the original called dist_util.dev(), but dist_util is not
            # imported anywhere in this file, so this line raised NameError on
            # CUDA machines. Use the current CUDA device directly.
            self.rng_cuda = [th.Generator(th.device("cuda")) for _ in range(num_samples)]
        self.set_seed(seed)

    def get_size_and_indices(self, size):
        """Return a per-sample size (1, ...) and this rank's sample indices.

        Indices step by world_size starting at (done_samples + rank), and are
        clamped to [0, num_samples - 1] so a final partial batch stays in
        range (trailing rows then repeat the last sample).
        """
        indices = th.arange(
            self.done_samples + self.rank,
            self.done_samples + self.world_size * int(size[0]),
            self.world_size,
        )
        indices = th.clamp(indices, 0, self.num_samples - 1)
        assert (
            len(indices) == size[0]
        ), f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}"
        return (1, *size[1:]), indices

    def get_generator(self, device):
        """Pick the list of CPU or CUDA generators matching ``device``."""
        return self.rng_cpu if th.device(device).type == "cpu" else self.rng_cuda

    def randn(self, *size, dtype=th.float, device="cpu"):
        """Sample standard-normal values, one row per sample index."""
        size, indices = self.get_size_and_indices(size)
        generator = self.get_generator(device)
        return th.cat(
            [
                th.randn(*size, generator=generator[i], dtype=dtype, device=device)
                for i in indices
            ],
            dim=0,
        )

    def randint(self, low, high, size, dtype=th.long, device="cpu"):
        """Sample integers in [low, high), one row per sample index."""
        size, indices = self.get_size_and_indices(size)
        generator = self.get_generator(device)
        return th.cat(
            [
                th.randint(
                    low,
                    high,
                    generator=generator[i],
                    size=size,
                    dtype=dtype,
                    device=device,
                )
                for i in indices
            ],
            dim=0,
        )

    def randn_like(self, tensor):
        """Sample standard-normal values with ``tensor``'s size/dtype/device."""
        size, dtype, device = tensor.size(), tensor.dtype, tensor.device
        return self.randn(*size, dtype=dtype, device=device)

    def set_done_samples(self, done_samples):
        """Advance the sample cursor; per-sample rngs need no re-seeding."""
        self.done_samples = done_samples

    def get_seed(self):
        """Return the base seed this generator was configured with."""
        return self.seed

    def set_seed(self, seed):
        # Plain loops instead of the original side-effect list comprehensions;
        # seeding is identical: generator i gets i + num_samples * seed.
        for i, rng_cpu in enumerate(self.rng_cpu):
            rng_cpu.manual_seed(i + self.num_samples * seed)
        if th.cuda.is_available():
            for i, rng_cuda in enumerate(self.rng_cuda):
                rng_cuda.manual_seed(i + self.num_samples * seed)
|
|