File size: 760 Bytes
c5cd586
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from torch.utils.data import Dataset, DataLoader
import torch


class NewsDataset(Dataset):
    """Map-style dataset pairing each news title with its text and optional label.

    Args:
        titles: Sized, indexable collection of titles. Its length defines the
            length of the dataset.
        texts: Sized, indexable collection of texts, aligned with ``titles``.
        labels: Optional sized, indexable collection of labels aligned with
            ``titles``. When ``None``, items are returned without a label.

    Raises:
        ValueError: If ``texts`` (or ``labels``, when given) does not have the
            same length as ``titles``.
    """

    def __init__(self, titles, texts, labels=None):
        # Fail fast on misaligned inputs: __len__ is derived from titles only,
        # so a mismatch would otherwise surface as an IndexError during
        # iteration or silently drop the trailing entries.
        expected = len(titles)
        if len(texts) != expected:
            raise ValueError(
                f"titles and texts must have the same length, "
                f"got {expected} and {len(texts)}"
            )
        if labels is not None and len(labels) != expected:
            raise ValueError(
                f"titles and labels must have the same length, "
                f"got {expected} and {len(labels)}"
            )
        self.titles = titles
        self.texts = texts
        self.labels = labels

    def __len__(self):
        """Return the number of samples (the number of titles)."""
        return len(self.titles)

    def __getitem__(self, idx):
        """Return ``(title, text, label)`` at ``idx``, or ``(title, text)`` if unlabeled."""
        if self.labels is not None:
            return self.titles[idx], self.texts[idx], self.labels[idx]
        return self.titles[idx], self.texts[idx]


def create_data_loader(titles, texts, labels=None, batch_size=32, shuffle=False, num_workers=6):
    """Wrap titles, texts and optional labels in a ``NewsDataset`` and return a ``DataLoader``.

    Args:
        titles: Collection of titles passed to ``NewsDataset``.
        texts: Collection of texts passed to ``NewsDataset``.
        labels: Optional collection of labels passed to ``NewsDataset``.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the data.
        num_workers: Number of worker processes; ``0`` loads in the calling
            process.

    Returns:
        A ``DataLoader`` over the dataset with ``pin_memory=True``. Workers
        are kept alive between epochs whenever ``num_workers`` is positive.
    """
    dataset = NewsDataset(titles, texts, labels)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        # DataLoader raises ValueError for persistent_workers=True when
        # num_workers == 0, so only request it when workers actually exist.
        persistent_workers=num_workers > 0,
    )