from typing import Optional

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from tqdm import tqdm

from swift.llm import to_device
from swift.utils import get_logger

logger = get_logger()


class BatchSamplerShard:
    """Batch sampler that shards a dataset across data-parallel ranks.

    Every rank builds the same (optionally shuffled) global index permutation,
    takes its own contiguous slice, and yields index batches of ``batch_size``.
    """

    def __init__(self,
                 total_samples: int,
                 batch_size: int,
                 shuffle: bool,
                 drop_last: bool,
                 data_seed: Optional[int],
                 tp_size: int = 1):
        self.tp_size = tp_size
        # Number of samples owned by this data-parallel rank.
        self.total_samples = total_samples // self.world_size
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.base_seed = data_seed or 0
        self.curr_seed = self.base_seed

    @property
    def rank(self):
        return (dist.get_rank() // self.tp_size) if dist.is_initialized() else 0

    @property
    def world_size(self):
        return (dist.get_world_size() // self.tp_size) if dist.is_initialized() else 1

    def __iter__(self):
        start_idx = self.rank * self.total_samples
        if self.shuffle:
            # Seed identically on all ranks so the global permutation matches,
            # then keep only this rank's slice of it.
            generator = torch.Generator()
            generator.manual_seed(self.curr_seed)
            total_idx = torch.randperm(self.total_samples * self.world_size, generator=generator).tolist()
            total_idx = total_idx[start_idx:start_idx + self.total_samples]
        else:
            total_idx = list(range(start_idx, start_idx + self.total_samples))
        batch = []
        for idx in total_idx:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # The final incomplete batch is dropped only when drop_last is set.
        if not self.drop_last and len(batch) > 0:
            yield batch

    def set_epoch(self, epoch: int):
        self.curr_seed = self.base_seed + epoch

    def __len__(self) -> int:
        if self.drop_last:
            return self.total_samples // self.batch_size
        else:
            return (self.total_samples + self.batch_size - 1) // self.batch_size


class DataLoaderShard(DataLoader):
    """DataLoader that optionally moves each batch to ``device`` and forwards
    ``set_epoch`` to its (batch) sampler for per-epoch reshuffling."""

    def __init__(self, dataset, device=None, **dataloader_params):
        self.device = device
        super().__init__(dataset, **dataloader_params)

    def set_epoch(self, epoch: int):
        if self.batch_sampler is not None and hasattr(self.batch_sampler, 'set_epoch'):
            self.batch_sampler.set_epoch(epoch)
        elif self.sampler is not None and hasattr(self.sampler, 'set_epoch'):
            self.sampler.set_epoch(epoch)

    def __iter__(self):
        for item in super().__iter__():
            if self.device:
                item = to_device(item, self.device)
            yield item
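# Typical wiring of the two classes above (illustrative sketch only; names such
# as `dataset` and `num_epochs` are placeholders, not part of this module):
# each data-parallel rank constructs its own sampler/loader pair and reshuffles
# per epoch via set_epoch.
#
#     sampler = BatchSamplerShard(len(dataset), batch_size=8, shuffle=True,
#                                 drop_last=False, data_seed=42)
#     loader = DataLoaderShard(dataset, device='cuda', batch_sampler=sampler)
#     for epoch in range(num_epochs):
#         loader.set_epoch(epoch)
#         for batch in loader:
#             ...  # batch tensors are already on `device`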
class DataLoaderDispatcher:
    """Dispatches batches from a single base dataloader to all ranks.

    Rank 0 reads ``world_size`` batches per step and scatters one to each rank.
    When torch.distributed is not initialized it degrades to plain iteration.
    """

    def __init__(self, base_dataloader, device=None, skip_batches: int = 0):
        self.base_dataloader = base_dataloader
        self.device = device
        self.skip_batches = skip_batches

    @property
    def rank(self):
        return dist.get_rank(self.group) if dist.is_initialized() else 0

    @property
    def world_size(self):
        return dist.get_world_size(self.group) if dist.is_initialized() else 1

    @property
    def group(self):
        return dist.group.WORLD if dist.is_initialized() else 1

    def _scatter_object_list(self, inputs):
        if not dist.is_initialized():
            return inputs[0]
        outputs = [None]
        global_src_rank = dist.get_global_rank(self.group, 0)
        dist.scatter_object_list(outputs, inputs, global_src_rank, group=self.group)
        return outputs[0]

    def _skip_batches(self, base_iter):
        if self.rank == 0 and self.skip_batches > 0:
            for _ in tqdm(range(self.skip_batches), dynamic_ncols=True, desc='Skip Batches: '):
                [next(base_iter) for _ in range(self.world_size)]

    def __iter__(self):
        base_iter = iter(self.base_dataloader)
        self._skip_batches(base_iter)
        while True:
            if self.rank == 0:
                try:
                    # Read one batch per rank from the base dataloader.
                    data = [next(base_iter) for _ in range(self.world_size)]
                except StopIteration:
                    # Scatter None to every rank so all of them stop together.
                    data = [None] * self.world_size
                data = self._scatter_object_list(data)
            else:
                data = self._scatter_object_list(None)
            if data is None:
                break
            if self.device:
                data = to_device(data, self.device)
            yield data
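# Minimal single-process smoke test (illustrative sketch, not part of the
# original module). With torch.distributed uninitialized, rank is 0 and
# world_size is 1, so the dispatcher simply replays the base dataloader.
if __name__ == '__main__':
    from torch.utils.data import TensorDataset

    dataset = TensorDataset(torch.arange(10).float())
    dispatcher = DataLoaderDispatcher(DataLoader(dataset, batch_size=4))
    for i, batch in enumerate(dispatcher):
        print(f'batch {i}:', batch)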