from torch.utils.data import Dataset, DataLoader | |
import torch | |
class NewsDataset(Dataset):
    """Map-style dataset of news titles, article texts and optional labels.

    Item ``idx`` is ``(title, text, label)`` when labels were supplied and
    ``(title, text)`` otherwise. The three containers are indexed with the
    same ``idx``, so they must be aligned position by position.
    """

    def __init__(self, titles, texts, labels=None):
        """Store the parallel containers after checking they line up.

        Args:
            titles: Indexable, sized collection of titles.
            texts: Indexable, sized collection of texts, same length as
                ``titles``.
            labels: Optional indexable, sized collection of labels, same
                length as ``titles``. ``None`` means an unlabeled dataset.

        Raises:
            ValueError: If ``texts`` or ``labels`` differ in length from
                ``titles``.
        """
        expected = len(titles)
        # __len__ is driven by `titles` alone, so a mismatch would otherwise
        # surface late as an IndexError (shorter) or silently drop samples
        # (longer). Fail at construction time instead.
        if len(texts) != expected:
            raise ValueError(
                f"titles and texts must have the same length, "
                f"got {expected} and {len(texts)}"
            )
        if labels is not None and len(labels) != expected:
            raise ValueError(
                f"titles and labels must have the same length, "
                f"got {expected} and {len(labels)}"
            )
        self.titles = titles
        self.texts = texts
        self.labels = labels

    def __len__(self):
        """Return the number of samples (the number of titles)."""
        return len(self.titles)

    def __getitem__(self, idx):
        """Return the sample at ``idx``; the label is included only if present."""
        if self.labels is not None:
            return self.titles[idx], self.texts[idx], self.labels[idx]
        return self.titles[idx], self.texts[idx]
def create_data_loader(titles, texts, labels=None, batch_size=32, shuffle=False, num_workers=6):
    """Wrap titles/texts (and optional labels) in a ``DataLoader``.

    Args:
        titles: Collection of titles passed to ``NewsDataset``.
        texts: Collection of texts passed to ``NewsDataset``.
        labels: Optional collection of labels; ``None`` for unlabeled data.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the data.
        num_workers: Number of worker processes; ``0`` loads in the main
            process.

    Returns:
        A ``DataLoader`` over a ``NewsDataset`` built from the inputs, with
        ``pin_memory=True``.
    """
    dataset = NewsDataset(titles, texts, labels)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        # DataLoader raises ValueError when persistent_workers is requested
        # with num_workers == 0, so only enable it when workers exist.
        persistent_workers=num_workers > 0,
    )