Nanobit committed on
Commit
28fd429
2 Parent(s): edd6980 45ac7c4

Merge pull request #293 from NanoCode012/fix/tokenize-speed

Browse files
Files changed (1) hide show
  1. src/axolotl/datasets.py +11 -12
src/axolotl/datasets.py CHANGED
@@ -1,12 +1,13 @@
1
  """Module containing Dataset functionality"""
2
 
3
  import logging
 
4
  from typing import List
5
 
6
  import torch
7
  from datasets import IterableDataset
8
 
9
- from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
10
 
11
  # We want this to be a wrapper for an existing dataset that we have loaded
12
  # lets use the concept of middlewares to wrap each dataset, for example
@@ -34,17 +35,15 @@ class TokenizedPromptDataset(IterableDataset):
34
  self.dataset = dataset
35
 
36
  def __iter__(self):
37
- iterator = iter(self.dataset)
38
- count = 0
39
- # Loop through the entire dataset
40
- for example in iterator:
41
- try:
42
- yield self.prompt_tokenizer.tokenize_prompt(example)
43
- count += 1
44
- except InvalidDataException:
45
- pass
46
- if count == 0:
47
- raise RuntimeError("Expected at least one datapoint in dataset.")
48
 
49
 
50
  # TODO this isn't the best since it can't interleave datasets
 
1
  """Module containing Dataset functionality"""
2
 
3
  import logging
4
+ import os
5
  from typing import List
6
 
7
  import torch
8
  from datasets import IterableDataset
9
 
10
+ from .prompt_tokenizers import PromptTokenizingStrategy
11
 
12
  # We want this to be a wrapper for an existing dataset that we have loaded
13
  # lets use the concept of middlewares to wrap each dataset, for example
 
35
  self.dataset = dataset
36
 
37
  def __iter__(self):
38
+ features = self.dataset.features.keys()
39
+ num_proc = os.cpu_count()
40
+ return iter(
41
+ self.dataset.map(
42
+ self.prompt_tokenizer.tokenize_prompt,
43
+ num_proc=num_proc,
44
+ remove_columns=features,
45
+ )
46
+ )
 
 
47
 
48
 
49
  # TODO this isn't the best since it can't interleave datasets