Spaces:
Runtime error
Runtime error
import numpy as np | |
from torch.utils.data import DistributedSampler as _DistributedSampler | |
class DistributedVideoSampler(_DistributedSampler):
    """Distributed sampler that keeps each video sequence on a single rank.

    Entries of ``dataset.data_infos`` whose ``"frame_id"`` equals 0 are treated
    as sequence starts. The sequence starts are split into ``num_replicas``
    contiguous groups with ``np.array_split``. Each rank then iterates, in
    order, over the contiguous index range from its group's first start up to
    the next group's first start (or the end of the dataset). Order is never
    shuffled.

    Args:
        dataset: Dataset exposing ``__len__`` and ``data_infos``, a sequence of
            dicts that each contain a ``"frame_id"`` key.
        num_replicas: Number of participating processes; ``None`` lets the
            parent ``DistributedSampler`` resolve it.
        rank: Rank of the current process; ``None`` lets the parent resolve it.
        shuffle: Must be false; accepted for signature compatibility only.

    Raises:
        ValueError: If ``shuffle`` is true (the parent class also raises
            ``ValueError`` for an out-of-range ``rank``).
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False):
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if shuffle:
            raise ValueError("Specific for video sequential testing.")
        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
        self.shuffle = shuffle
        # NOTE(review): the inherited __len__ returns this value, so len(self)
        # is the full dataset size on every rank rather than this rank's
        # share. Kept as-is; confirm callers depend on it.
        self.num_samples = len(dataset)

        starts = [
            idx
            for idx, info in enumerate(self.dataset.data_infos)
            if info["frame_id"] == 0
        ]
        # Samples before the first frame_id == 0 entry would otherwise belong
        # to no rank (or, with no starts at all, crash below); attach them to
        # the first group.
        if self.num_samples and (not starts or starts[0] != 0):
            starts.insert(0, 0)

        # Use the parent-resolved self.num_replicas: the raw argument may be None.
        chunks = np.array_split(starts, self.num_replicas)
        # array_split puts any empty chunks last (fewer sequences than ranks).
        # Those ranks start at num_samples, so they receive an empty range
        # instead of raising IndexError.
        split_flags = [int(c[0]) if len(c) else self.num_samples for c in chunks]
        split_flags.append(self.num_samples)
        self.indices = [
            list(range(lo, hi)) for lo, hi in zip(split_flags[:-1], split_flags[1:])
        ]

    def __iter__(self):
        """Yield this rank's contiguous block of dataset indices, in order."""
        return iter(self.indices[self.rank])