gpt2-chitchat / dataset.py
mason0510's picture
Duplicate from xibaozi/gpt2-chitchat
bd5d31b
raw
history blame contribute delete
No virus
479 Bytes
from torch.utils.data import Dataset
import torch
class MyDataset(Dataset):
    """Map-style dataset that serves pre-encoded sequences as long tensors.

    Each item of ``input_list`` is truncated to its first ``max_len``
    elements and converted to a ``torch.long`` tensor on access.

    NOTE(review): items are presumably flat lists of integer token ids
    (one list per training sample) — confirm against the caller that
    builds ``input_list``.

    Args:
        input_list: Indexable collection of sequences, one per sample.
        max_len: Maximum number of leading elements kept per sample.
            Because truncation is plain slicing, ``None`` keeps the
            whole sequence.
    """

    def __init__(self, input_list, max_len):
        self.input_list = input_list
        self.max_len = max_len

    def __getitem__(self, index):
        """Return sample ``index`` truncated to ``max_len``.

        Returns:
            A ``torch.long`` tensor built from the first ``max_len``
            elements of ``input_list[index]``.
        """
        # Truncate before conversion so oversize samples never get
        # materialized as full-length tensors.
        input_ids = self.input_list[index][:self.max_len]
        return torch.tensor(input_ids, dtype=torch.long)

    def __len__(self):
        """Return the number of samples in ``input_list``."""
        return len(self.input_list)