|
from torch.utils.data import Dataset, DataLoader |
|
import torch |
|
|
|
|
|
class NewsDataset(Dataset):
    """Map-style dataset over parallel sequences of titles, texts and optional labels."""

    def __init__(self, titles, texts, labels=None):
        """Keep references to the given sequences.

        Args:
            titles: Indexable, sized collection of titles.
            texts: Indexable collection of texts, indexed in step with ``titles``.
            labels: Optional indexable collection of labels; ``None`` means
                samples are returned without a label.
        """
        self.titles, self.texts, self.labels = titles, texts, labels

    def __len__(self):
        """Return the number of samples, as reported by ``titles``."""
        return len(self.titles)

    def __getitem__(self, idx):
        """Return ``(title, text)``, or ``(title, text, label)`` when labels are set."""
        sample = (self.titles[idx], self.texts[idx])
        if self.labels is None:
            return sample
        return sample + (self.labels[idx],)
|
|
|
|
|
def create_data_loader(titles, texts, labels=None, batch_size=32, shuffle=False, num_workers=6):
    """Wrap the given samples in a ``NewsDataset`` and return a ``DataLoader`` over it.

    Args:
        titles: Indexable, sized collection of titles.
        texts: Indexable collection of texts, indexed in step with ``titles``.
        labels: Optional indexable collection of labels; when ``None`` each
            sample holds only a title and a text.
        batch_size: Number of samples per batch.
        shuffle: Passed through to ``DataLoader``'s ``shuffle`` option.
        num_workers: Number of loader worker processes; ``0`` loads data in
            the calling process.

    Returns:
        A ``DataLoader`` built with ``pin_memory=True``. Workers are kept
        alive between epochs only when ``num_workers`` is positive.
    """
    dataset = NewsDataset(titles, texts, labels)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        # DataLoader raises ValueError when persistent_workers is requested
        # with num_workers == 0, so only enable it when workers exist.
        persistent_workers=num_workers > 0,
    )
|
|