|
import torch |
|
from torch.utils.data.dataloader import default_collate |
|
def collate_video(batch):
    """Collate dict samples whose ``"input"`` tensors differ in length along dim 0.

    Per the original author's note, each video is expected to be shaped
    ``(temporal_crops, C, T, H, W)``, where ``temporal_crops`` differs from
    clip to clip, so the standard collate function (which stacks) cannot be
    used. The ``"input"`` tensors are therefore concatenated along dim 0
    instead, while every other key goes through ``default_collate``.

    NOTE: this function does not record how many crops each sample
    contributed. To split the concatenated tensor back into individual
    videos later, the per-sample lengths have to be available separately
    (presumably as another key of each sample -- verify against the dataset).

    Args:
        batch: Sequence of dict samples. Each one must contain an ``"input"``
            tensor; the keys of the first sample decide which other fields
            are collated.

    Returns:
        A dict with the default-collated fields plus the concatenated
        ``"input"`` tensor.

    Raises:
        TypeError: If the first sample is not a dict.
    """
    elem = batch[0]
    # Explicit check instead of ``assert``: assertions are stripped under
    # ``python -O``, which would let a malformed sample fail later with a
    # much less helpful error.
    if not isinstance(elem, dict):
        raise TypeError(
            f"collate_video expects dict samples, got {type(elem).__name__}"
        )

    # Every field except the variable-length "input" can be stacked as usual.
    output = {
        key: default_collate([d[key] for d in batch])
        for key in elem
        if key != "input"
    }
    # Variable-length inputs cannot be stacked; join them along dim 0.
    output["input"] = torch.cat([d["input"] for d in batch])
    return output
|
|