| import torch
|
| import torch.distributed as dist
|
| from collections.abc import Iterator
|
|
|
| def _is_batch_valid(batch):
|
| """
|
| Check if a batch is valid for training/evaluation.
|
| A valid batch must have input_ids and at least one image.
|
| """
|
| if not batch:
|
| return False
|
|
|
| if len(batch['input_ids']) == 0:
|
| return False
|
|
|
| if len(batch['images']) == 0:
|
| return False
|
|
|
|
|
| if len([img for sublist in batch['images'] for img in sublist]) == 0:
|
|
|
|
|
| return False
|
|
|
| return True
|
|
|
|
|
def synchronized_dataloader_step(train_loader, is_dist):
    """
    Create a synchronized iterator that handles uneven data distribution in DDP.

    All ranks stop as soon as the first rank runs out of data — this can
    happen because, when packing a presharded dataset, one rank may end up
    with fewer groups than the others. Invalid/empty batches produced by the
    collator on any rank are skipped locally; each rank keeps fetching until
    it has a valid batch (or is exhausted) before the ranks synchronize.

    Args:
        train_loader: An iterable (or iterator) yielding batches.
        is_dist: Whether distributed training is active.

    Yields:
        Valid batches, as judged by ``_is_batch_valid``.
    """
    if not is_dist:
        # Single-process: just filter out invalid batches.
        yield from (b for b in train_loader if _is_batch_valid(b))
        return

    # Normalize to an iterator so next() works either way.
    data_iter = train_loader if isinstance(train_loader, Iterator) else iter(train_loader)

    while True:
        try:
            # Fetch until this rank has a valid batch.
            batch = next(data_iter)
            while not _is_batch_valid(batch):
                batch = next(data_iter)
            flag = torch.tensor(1, device=torch.cuda.current_device())
        except StopIteration:
            batch = None
            flag = torch.tensor(0, device=torch.cuda.current_device())

        # MIN-reduce: the flag becomes 0 on every rank once any rank is
        # exhausted, so all ranks exit the loop together.
        dist.all_reduce(flag, op=dist.ReduceOp.MIN)

        if flag.item() == 0:
            break
        yield batch