File size: 853 Bytes
be0843c
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from datasets import load_dataset
from torch.utils.data import DataLoader

def get_dataloader(config, tokenizer, split='train'):
    """Build a DataLoader over the CodeSearchNet Python corpus.

    Loads the ``"python"`` configuration of the ``"code_search_net"`` dataset
    for the requested split, tokenizes each example's ``whole_func_string``
    to a fixed length, and wraps the result in a torch ``DataLoader``.

    Args:
        config: Nested mapping; reads ``config['model']['max_length']`` (the
            length passed to the tokenizer for truncation/padding) and
            ``config['training']['batch_size']``.
        tokenizer: Callable invoked as ``tokenizer(texts, truncation=True,
            padding='max_length', max_length=...)`` on a batch of strings.
            Presumably a Hugging Face tokenizer; it must be able to pad
            (e.g. have a pad token configured) -- TODO confirm with callers.
        split: Dataset split to load (e.g. ``'train'``). Shuffling is enabled
            only when ``split == 'train'``.

    Returns:
        A ``DataLoader`` over a dataset that contains only the tokenizer's
        output columns, formatted as torch tensors.
    """
    dataset = load_dataset("code_search_net", "python", split=split)
    # Read once when building the loader rather than on every mapped batch.
    max_length = config['model']['max_length']

    def tokenize_function(examples):
        return tokenizer(
            examples['whole_func_string'],
            truncation=True,
            padding='max_length',
            max_length=max_length,
        )

    # Drop every raw column while mapping so only tokenizer outputs remain.
    # Taking the names from the dataset itself avoids hard-coding its schema:
    # the previous explicit list named 'repo'/'path', which (per the dataset
    # card) are not column names here, so remove_columns raised.
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset.column_names,
    )
    tokenized_dataset.set_format("torch")

    return DataLoader(
        tokenized_dataset,
        batch_size=config['training']['batch_size'],
        shuffle=(split == 'train'),
    )