File size: 309 Bytes
03f6091
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
import torch

def collate_fn(batch):
    """Recursively collate ``batch`` field by field.

    Dispatch is on the container type of ``batch``:

    * ``list`` -- treated as a list of samples. The samples are transposed
      with ``zip`` (so the i-th group holds the i-th item of every sample,
      truncated to the shortest sample) and each group is collated by a
      recursive call. The result is a ``list`` with one entry per group.
    * ``tuple`` whose first element is a ``list`` -- returned as is, with
      no stacking or conversion.
    * anything else -- delegated to ``torch.utils.data.default_collate``.
    """
    if isinstance(batch, list):
        # Materialise the transposed groups up front; zip yields tuples, so
        # each recursive call lands in one of the two branches below.
        groups = tuple(zip(*batch))
        return list(map(collate_fn, groups))

    # Only a tuple that starts with a list is passed through untouched.
    if not isinstance(batch, tuple) or not isinstance(batch[0], list):
        return torch.utils.data.default_collate(batch)

    return batch