# kimic — commit 6f9bfc0: Initial commit for LSTM with GloVe embeddings
from torch.utils.data import Dataset, DataLoader
import torch
class NewsDataset(Dataset):
    """Map-style dataset pairing news titles with their texts.

    Each item is ``(title, text, label)`` when ``labels`` is given, otherwise
    ``(title, text)``.

    Args:
        titles: Indexable, sized sequence of titles.
        texts: Indexable, sized sequence of texts aligned with ``titles``.
        labels: Optional indexable, sized sequence of labels aligned with
            ``titles``.

    Raises:
        ValueError: If ``texts`` (or ``labels``, when given) does not have the
            same length as ``titles``.
    """

    def __init__(self, titles, texts, labels=None):
        # __len__ is driven by titles alone, so a length mismatch would
        # otherwise silently drop samples or raise IndexError mid-epoch.
        if len(texts) != len(titles):
            raise ValueError(
                "titles and texts must have the same length "
                f"(got {len(titles)} and {len(texts)})"
            )
        if labels is not None and len(labels) != len(titles):
            raise ValueError(
                "titles and labels must have the same length "
                f"(got {len(titles)} and {len(labels)})"
            )
        self.titles = titles
        self.texts = texts
        self.labels = labels

    def __len__(self):
        """Return the number of samples (the number of titles)."""
        return len(self.titles)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a tuple.

        The tuple is ``(title, text, label)`` if labels were provided,
        otherwise ``(title, text)``.
        """
        if self.labels is not None:
            return self.titles[idx], self.texts[idx], self.labels[idx]
        return self.titles[idx], self.texts[idx]
def create_data_loader(titles, texts, labels=None, batch_size=32, shuffle=False, num_workers=6):
    """Build a ``DataLoader`` over a ``NewsDataset`` of titles and texts.

    Args:
        titles: Sequence of titles, forwarded to ``NewsDataset``.
        texts: Sequence of texts aligned with ``titles``.
        labels: Optional sequence of labels aligned with ``titles``.
        batch_size: Number of samples per batch.
        shuffle: Whether to reshuffle the samples every epoch.
        num_workers: Number of loader worker processes; 0 loads data in the
            main process.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``pin_memory=True``.
    """
    dataset = NewsDataset(titles, texts, labels)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        # DataLoader rejects persistent_workers=True when num_workers == 0,
        # so only keep workers alive across epochs when there are workers.
        persistent_workers=num_workers > 0,
    )