from torch.utils.data import Dataset, DataLoader
import torch


class NewsDataset(Dataset):
    """Map-style dataset pairing news titles with their texts and optional labels.

    Items are ``(title, text, label)`` tuples when labels were supplied,
    otherwise ``(title, text)`` tuples.
    """

    def __init__(self, titles, texts, labels=None):
        """Store the parallel sequences.

        Args:
            titles: Indexable sequence of titles.
            texts: Indexable sequence of texts, same length as ``titles``.
            labels: Optional indexable sequence of labels, same length as
                ``titles``; pass ``None`` for unlabeled data.

        Raises:
            ValueError: If the sequences do not all have the same length.
        """
        # Fail fast here instead of with an IndexError inside a worker
        # process (or silently dropping trailing items) at iteration time.
        if len(texts) != len(titles):
            raise ValueError(
                f"titles and texts must have the same length "
                f"({len(titles)} != {len(texts)})"
            )
        if labels is not None and len(labels) != len(titles):
            raise ValueError(
                f"labels must have the same length as titles "
                f"({len(labels)} != {len(titles)})"
            )
        self.titles = titles
        self.texts = texts
        self.labels = labels

    def __len__(self):
        """Return the number of samples (the number of titles)."""
        return len(self.titles)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a tuple, including the label if present."""
        if self.labels is not None:
            return self.titles[idx], self.texts[idx], self.labels[idx]
        return self.titles[idx], self.texts[idx]


def create_data_loader(titles, texts, labels=None, batch_size=32, shuffle=False,
                       num_workers=6, pin_memory=None):
    """Build a ``DataLoader`` over a ``NewsDataset``.

    Args:
        titles: Sequence of titles.
        texts: Sequence of texts, same length as ``titles``.
        labels: Optional sequence of labels, same length as ``titles``.
        batch_size: Number of samples per batch.
        shuffle: Whether to reshuffle the data every epoch.
        num_workers: Number of worker subprocesses; ``0`` loads in the main
            process.
        pin_memory: Whether to pin host memory for faster GPU transfer.
            ``None`` (default) enables it only when CUDA is available.

    Returns:
        A configured ``torch.utils.data.DataLoader``.

    Raises:
        ValueError: If the input sequences have mismatched lengths.
    """
    dataset = NewsDataset(titles, texts, labels)
    if pin_memory is None:
        # Pinned memory only benefits host-to-CUDA copies.
        pin_memory = torch.cuda.is_available()
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=num_workers > 0,
    )