import glob
import os

from torch.utils.data import DataLoader, Dataset, random_split
from transformers import AutoTokenizer


class GithubDataset(Dataset):
    """Dataset over a directory of source files, tokenized with BERT's tokenizer."""

    def __init__(
        self,
        root_dir=os.path.expanduser("~/torch_datasets/github-python/corpus"),
        train=False,  # unused; the train/test split happens via random_split below
        max_length=512,
    ):
        self.root = root_dir
        # "*.*" matches only top-level files whose names contain a dot;
        # files in subdirectories are not picked up.
        self.file_list = glob.glob(os.path.join(root_dir, "*.*"))
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.max_length = max_length

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        path = self.file_list[idx]

        # Corpus files may contain bytes that are not valid UTF-8;
        # errors="ignore" drops them instead of raising.
        with open(path, "r", encoding="utf-8", errors="ignore") as file:
            code = file.read()

        # Pad/truncate every file to a fixed length so samples batch cleanly.
        encoding = self.tokenizer(
            code,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # The tokenizer returns tensors with a batch dimension of 1; squeeze it
        # away so the DataLoader can stack samples itself.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        return input_ids, attention_mask

dataset = GithubDataset(root_dir="./test-data/")

# 80/20 train/test split over the corpus.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
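
# Note: random_split draws from torch's global RNG, so the split differs
# between runs. For a reproducible split, pass a seeded generator
# (requires `import torch`), e.g.:
#     train_dataset, test_dataset = random_split(
#         dataset,
#         [train_size, test_size],
#         generator=torch.Generator().manual_seed(42),
#     )
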
def get_train_dataset():
    return train_dataset


def get_test_dataset():
    return test_dataset


def get_dataloader(dataset, batch_size=64):
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


if __name__ == "__main__":
    d = get_train_dataset()
    print("Number of samples: ", len(d))

    input_ids, attention_mask = d[4]
    t = AutoTokenizer.from_pretrained("bert-base-uncased")
    # Decode the whole id sequence at once rather than one token at a time,
    # and skip [PAD]/[CLS]/[SEP] so only the file's content is printed.
    print(t.decode(input_ids, skip_special_tokens=True))
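
    # Sketch: batching through get_dataloader. batch_size=8 is an arbitrary
    # illustrative choice; with max_length=512, each tensor comes out as
    # (8, 512) when the split holds at least 8 samples.
    loader = get_dataloader(d, batch_size=8)
    batch_ids, batch_mask = next(iter(loader))
    print(batch_ids.shape, batch_mask.shape)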