# EasyAnimate: easyanimate/utils/discrete_sampler.py
# Author: bubbliiiing; commit f62c8b9 ("Update V5"); 2.02 kB
"""Modified from https://github.com/THUDM/CogVideo/blob/3710a612d8760f5cdb1741befeebb65b9e0f2fe0/sat/sgm/modules/diffusionmodules/sigma_sampling.py
"""
import logging
import math

import torch
class DiscreteSampling:
    """Draw random integer indices from ``[0, num_idx)``.

    When ``uniform_sampling`` is enabled and ``torch.distributed`` is already
    initialized at construction time, the index range is stratified across
    ranks:

    - The range is split into ``group_num`` contiguous intervals of
      ``sigma_interval`` indices each.
    - The ranks are split into ``group_num`` groups of ``group_width``
      consecutive ranks.
    - Each rank draws only from its own group's interval.

    In every other case indices are drawn from the full range.

    Args:
        num_idx: Exclusive upper bound of the sampled indices.
        uniform_sampling: Whether to stratify the index range across ranks.
            This only has an effect if the default process group is
            initialized before this object is constructed.
    """

    # The per-call message is emitted at DEBUG level, so it is silent unless
    # the application enables it.
    _logger = logging.getLogger(__name__)

    def __init__(self, num_idx, uniform_sampling=False):
        self.num_idx = num_idx
        self.uniform_sampling = uniform_sampling
        # Captured once; a process group initialized later is not picked up.
        self.is_distributed = (
            torch.distributed.is_available() and torch.distributed.is_initialized()
        )
        if self.is_distributed and self.uniform_sampling:
            world_size = torch.distributed.get_world_size()
            self.rank = torch.distributed.get_rank()
            self.group_num, self.group_width, self.sigma_interval = (
                self._group_layout(num_idx, world_size)
            )
            self._logger.info(
                'rank=%d world_size=%d group_num=%d group_width=%d sigma_interval=%s',
                self.rank, world_size, self.group_num,
                self.group_width, self.sigma_interval,
            )

    @staticmethod
    def _group_layout(num_idx, world_size):
        """Return ``(group_num, group_width, sigma_interval)`` for stratified sampling.

        ``group_num`` is the largest group count that evenly divides both the
        ranks and the index range, i.e. ``gcd(world_size, num_idx)``.
        ``group_width`` is the number of ranks per group. ``sigma_interval`` is
        the number of indices per group.
        """
        group_num = math.gcd(world_size, num_idx)
        return group_num, world_size // group_num, num_idx // group_num

    def __call__(self, n_samples, generator=None, device=None):
        """Sample ``n_samples`` indices with ``torch.randint``.

        Args:
            n_samples: Number of indices to draw; the result has shape
                ``(n_samples,)``.
            generator: Optional ``torch.Generator`` forwarded to ``torch.randint``.
            device: Optional device forwarded to ``torch.randint``.

        Returns:
            An integer tensor (``torch.randint``'s default int64 dtype). In
            stratified mode the values lie in this rank's group interval;
            otherwise they lie in ``[0, num_idx)``.
        """
        if self.is_distributed and self.uniform_sampling:
            # Consecutive ranks share a group: ranks [0, group_width) -> group 0, ...
            group_index = self.rank // self.group_width
            low = group_index * self.sigma_interval
            idx = torch.randint(
                low,
                low + self.sigma_interval,
                (n_samples,),
                generator=generator, device=device,
            )
            # Lazy %-formatting: the tensor is only rendered to text (which
            # needs its values on the host) when DEBUG logging is enabled.
            self._logger.debug('proc[%d] idx=%s', self.rank, idx)
        else:
            idx = torch.randint(
                0, self.num_idx, (n_samples,),
                generator=generator, device=device,
            )
        return idx