LoCoNet_ASD / videoloaders /collate_functions.py
xiziwang
push files
2e36228
raw
history blame
No virus
620 Bytes
import torch
from torch.utils.data.dataloader import default_collate
def collate_video(batch):
'''
Our video is (temporal_crops, C, T, H, W) where temporal_crops differes from clip to clip
We can't use standard collate function.
Instead of stacking, let's do cat
Keep in mind that this will also need list of frame length in order to restore each videos later.
'''
elem = batch[0]
assert isinstance(elem,dict)
output = {key: default_collate([d[key] for d in batch]) for key in elem if key!='input'}
output["input"] = torch.cat([d["input"] for d in batch])
return output