|
"""Custom collate function for the data loader.""" |
|
|
|
from typing import Any, List |
|
|
|
import torch |
|
from torch.nn.utils.rnn import pad_sequence |
|
|
|
|
|
def custom_collate(batch: List[Any], device: Any) -> Any:
    """
    Merge a list of samples into a single padded/stacked batch.

    Each sample is unpacked as a 4-tuple ``(image, caption, class, word_labels)``.
    Images and classes are stacked along a new leading dimension; captions and
    word labels are padded with 0 up to the longest sequence in the batch.

    :param batch: list of samples, one entry per sample in the batch.
    :param device: target handed to ``Tensor.to`` for the stacked and padded tensors.
    :return: tuple ``(images, captions, caption_lengths, classes, word_labels)``.
        ``caption_lengths`` is an int64 tensor with one row per sample and a
        single column holding the caption length before padding; it is the only
        returned tensor that is not moved to ``device``.
    """

    def _pad_to_device(sequences: Any) -> Any:
        # Pad with 0 to the longest sequence, batch dimension first, then move.
        padded = pad_sequence(sequences, batch_first=True, padding_value=0)
        return padded.to(device)

    # Transpose the list of samples into per-field tuples.
    images, captions, classes, labels = zip(*batch)

    image_batch = torch.stack(images, dim=0).to(device)

    # Lengths must be taken before padding erases them.
    raw_lengths = torch.tensor(
        [len(caption) for caption in captions], dtype=torch.int64
    )
    caption_lengths = torch.unsqueeze(raw_lengths, 1)

    caption_batch = _pad_to_device(captions)
    class_batch = torch.stack(classes, dim=0).to(device)
    label_batch = _pad_to_device(labels)

    return (
        image_batch,
        caption_batch,
        caption_lengths,
        class_batch,
        label_batch,
    )
|
|