# Source: taim-gan/src/data/collate.py (commit c8ddb9b, "three-model version", author: Dmmc).
"""Custom collate function for the data loader."""
from typing import Any, List
import torch
from torch.nn.utils.rnn import pad_sequence
def custom_collate(batch: List[Any], device: Any) -> Any:
    """
    Custom collate function to be used in the data loader.

    Each sample in ``batch`` is unpacked as a 4-tuple
    ``(img, correct_capt, curr_class, word_labels)``. Image and class tensors
    are stacked along a new leading batch dimension. The variable-length
    caption and word-label tensors are right-padded with 0 to the longest
    sequence of their kind in the batch.

    :param batch: list of samples, one per batch element (its length is the
        batch size, not the number of batches).
    :param device: target passed to ``Tensor.to`` (e.g. ``"cpu"``, ``"cuda"``
        or a ``torch.device``) for every returned tensor except the caption
        lengths.
    :return: tuple ``(batched_img, batched_correct_capt, correct_capt_len,
        batched_curr_class, batched_word_labels)``.
    :raises ValueError: if ``batch`` is empty.
    """
    if not batch:
        # Fail with a clear message instead of the opaque unpacking error
        # that ``zip(*[])`` would otherwise raise on the next line.
        raise ValueError("custom_collate received an empty batch")

    img, correct_capt, curr_class, word_labels = zip(*batch)

    # shape: (batch_size, *per-image shape); presumably
    # (batch_size, 3, height, width) -- TODO confirm against the dataset.
    batched_img = torch.stack(img, dim=0).to(device)

    # Unpadded caption lengths, shape (batch_size, 1), dtype int64.
    # NOTE(review): kept on the CPU, unlike every other output -- presumably
    # because torch's pack_padded_sequence requires CPU lengths; confirm
    # against the consumer before moving it.
    correct_capt_len = torch.tensor(
        [len(capt) for capt in correct_capt], dtype=torch.int64
    ).unsqueeze(1)

    # shape: (batch_size, longest caption in batch), right-padded with 0.
    batched_correct_capt = pad_sequence(
        correct_capt, batch_first=True, padding_value=0
    ).to(device)

    # shape: (batch_size, *per-label shape); (batch_size, 1) if each label is
    # a 1-element tensor -- TODO confirm against the dataset.
    batched_curr_class = torch.stack(curr_class, dim=0).to(device)

    # shape: (batch_size, longest word-label sequence in batch), right-padded
    # with 0.
    batched_word_labels = pad_sequence(
        word_labels, batch_first=True, padding_value=0
    ).to(device)

    return (
        batched_img,
        batched_correct_capt,
        correct_capt_len,
        batched_curr_class,
        batched_word_labels,
    )