# interactSpeech/swift/llm/data_loader.py
# Uploaded by Student0809 via the upload-large-folder tool (commit cb2428f, verified).
from typing import Optional
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
class BatchSamplerShard:
    """Batch sampler that shards a dataset's indices across distributed ranks.

    Each rank gets ``total_samples // world_size`` indices — a contiguous slice
    when ``shuffle`` is False, or its slice of a seeded global permutation when
    ``shuffle`` is True — and yields them in lists of ``batch_size`` indices.

    Args:
        total_samples: Global number of samples across all ranks. Any remainder
            of ``total_samples % world_size`` is discarded so every rank sees
            the same number of samples.
        batch_size: Number of indices per yielded batch.
        shuffle: Whether to shuffle the global index permutation each epoch.
        drop_last: If True, drop the trailing incomplete batch.
        data_seed: Base seed for shuffling; ``None`` is treated as 0.
    """

    def __init__(self, total_samples: int, batch_size: int, shuffle: bool, drop_last: bool, data_seed: Optional[int]):
        # Per-rank sample count; remainder samples are dropped (see class docstring).
        self.total_samples = total_samples // self.world_size
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.base_seed = 0 if data_seed is None else data_seed
        self.curr_seed = self.base_seed

    @property
    def rank(self) -> int:
        """Rank of this process; 0 when torch.distributed is not initialized."""
        return dist.get_rank() if dist.is_initialized() else 0

    @property
    def world_size(self) -> int:
        """Number of distributed processes; 1 when dist is not initialized."""
        return dist.get_world_size() if dist.is_initialized() else 1

    def __iter__(self):
        start_idx = self.rank * self.total_samples
        if self.shuffle:
            # Every rank builds the same seeded global permutation and takes its
            # own contiguous slice, so the shards are disjoint and exhaustive.
            generator = torch.Generator()
            generator.manual_seed(self.curr_seed)
            total_idx = torch.randperm(self.total_samples * self.world_size, generator=generator).tolist()
            total_idx = total_idx[start_idx:start_idx + self.total_samples]
        else:
            total_idx = list(range(start_idx, start_idx + self.total_samples))
        batch = []
        for idx in total_idx:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # The trailing incomplete batch is yielded unless drop_last is set.
        if not self.drop_last and batch:
            yield batch

    def set_epoch(self, epoch: int):
        """Reseed shuffling so each epoch yields a different permutation."""
        self.curr_seed = self.base_seed + epoch

    def __len__(self) -> int:
        """Number of batches this rank will yield per epoch."""
        if self.drop_last:
            return self.total_samples // self.batch_size
        return (self.total_samples + self.batch_size - 1) // self.batch_size
class DataLoaderShard(DataLoader):
    """DataLoader that pairs a dataset with a ``BatchSamplerShard``.

    Exposes ``set_epoch`` so a training loop can forward the epoch number to
    the sampler, reseeding its shuffle each epoch.
    """

    def __init__(self, dataset, batch_sampler: BatchSamplerShard, **dataloader_params):
        self.batch_sampler = batch_sampler
        super().__init__(dataset, batch_sampler=batch_sampler, **dataloader_params)

    def set_epoch(self, epoch: int):
        """Propagate the epoch number to the underlying batch sampler."""
        self.batch_sampler.set_epoch(epoch)
class DataLoaderDispatcher:
    """Reads batches on rank 0 and scatters one batch to every rank.

    Each round, rank 0 draws ``world_size`` batches from ``base_dataloader``
    and scatters them across the process group; the other ranks obtain their
    batch purely via the scatter. A round of ``None`` values signals
    exhaustion and stops iteration on all ranks. Without torch.distributed
    initialized this degenerates to plain single-process iteration.
    """

    def __init__(self, base_dataloader):
        self.base_dataloader = base_dataloader

    @property
    def group(self):
        """Process group used for the scatter (placeholder when dist is inactive)."""
        return dist.group.WORLD if dist.is_initialized() else 1

    @property
    def rank(self):
        """Rank within the group; 0 when dist is not initialized."""
        return dist.get_rank(self.group) if dist.is_initialized() else 0

    @property
    def world_size(self):
        """Size of the group; 1 when dist is not initialized."""
        return dist.get_world_size(self.group) if dist.is_initialized() else 1

    def _scatter_object_list(self, inputs):
        """Scatter ``inputs`` from rank 0 and return this rank's object."""
        if not dist.is_initialized():
            return inputs[0]
        received = [None]
        src = dist.get_global_rank(self.group, 0)
        dist.scatter_object_list(received, inputs, src, group=self.group)
        return received[0]

    def __iter__(self):
        it = iter(self.base_dataloader)
        while True:
            if self.rank == 0:
                try:
                    round_data = [next(it) for _ in range(self.world_size)]
                except StopIteration:
                    # NOTE(review): batches already drawn inside a partially
                    # filled round are discarded here — presumably intentional
                    # (acts like a cross-rank drop_last); confirm with callers.
                    round_data = [None] * self.world_size
                batch = self._scatter_object_list(round_data)
            else:
                batch = self._scatter_object_list(None)
            if batch is None:
                break
            yield batch