spelling-correction / dataset /autocorrect_dataset.py
hoang1007's picture
Upload 69 files
44db343
raw
history blame
No virus
414 Bytes
import torch
import numpy as np
class SpellCorrectDataset(torch.utils.data.Dataset):
def __init__(self, dataset):
self.dataset = dataset
def __getitem__(self, idx):
return self.dataset[idx]
def __len__(self):
return len(self.dataset)
def take(self, n = 1):
indies = np.random.choice(len(self.dataset), n)
return [self.dataset[idx] for idx in indies]