par-meta committed
Commit 63913e4 · unverified · 1 Parent(s): 8f2cf88

Reduce per-file resources Arrow uses (#77)

bytelatent/args.py CHANGED
@@ -138,7 +138,8 @@ class DataloaderArgs(BaseModel):
     preprocess_dir: str | None = None
     dataset_files: list[str] | None = None
     entropy_model_name: str | None = "transformer_100m"
-    arrow_batch_size: int = 100
+    # Be very careful with increasing, increases memory usage by that factor per rank, per data source
+    arrow_batch_size: int = 20
     buffer_size: int = 64
     file_format: str = "arrow"
 
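The drop in the default from 100 to 20 matters because each rank materializes one Arrow batch per data source, so the setting multiplies across ranks and sources. A rough sizing sketch (the row size, rank count, and source count are invented for illustration, not taken from the repo):

    # Hypothetical back-of-envelope for peak bytes held in materialized batches.
    avg_row_bytes = 512 * 1024  # assumed: long documents such as books
    n_ranks = 8                 # assumed number of training ranks
    n_sources = 10              # assumed data sources per rank

    def batch_memory(arrow_batch_size: int) -> int:
        return arrow_batch_size * avg_row_bytes * n_ranks * n_sources

    print(batch_memory(100) / 2**30)  # ~3.9 GiB at the old default
    print(batch_memory(20) / 2**30)   # ~0.8 GiB at the new default

Readahead multiplies this further, which is what the arrow_iterator.py change below limits.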
bytelatent/data/iterators/arrow_iterator.py CHANGED
@@ -226,7 +226,13 @@ class ArrowFileIterator(StatefulIterator):
             if (self.row_num - 1) % self.num_workers == self.worker_id:
                 yield out
 
-        self.batch_iterator = self.dataset.to_batches(batch_size=self.arrow_batch_size)
+        self.batch_iterator = self.dataset.to_batches(
+            batch_size=self.arrow_batch_size,
+            # We have large files in GBs, no need to readahead
+            fragment_readahead=1,
+            # Don't readahead in case batches are huge (e.g., books)
+            batch_readahead=1,
+        )
         for batch in self.batch_iterator:
             batch_columns = batch.to_pydict()
             if self.file_format == "arrow":
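For context, pyarrow.dataset scans read ahead aggressively by default (recent PyArrow releases document batch_readahead=16 and fragment_readahead=4), so a scanner can buffer many batches at once. A minimal standalone sketch of the same knobs, assuming a placeholder dataset path:

    import pyarrow.dataset as ds

    # Placeholder path; any Arrow IPC or Parquet dataset works here.
    dataset = ds.dataset("/data/shard_00.arrow", format="arrow")

    # Conservative scan: keep roughly one batch in flight rather than
    # PyArrow's default readahead across batches and fragments.
    for batch in dataset.to_batches(
        batch_size=20,
        batch_readahead=1,
        fragment_readahead=1,
    ):
        rows = batch.to_pydict()  # materializes only this one small batch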
bytelatent/data/iterators/sequence_iterator.py CHANGED
@@ -10,6 +10,9 @@ from bytelatent.data.iterators.abstract_iterator import (
     PydanticIteratorState,
     StatefulIterator,
 )
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
+from bytelatent.data.iterators.limit_iterator import LimitIterator
+from bytelatent.data.iterators.looping_iterator import LoopingIterator
 from bytelatent.data.iterators.preprocess_iterator import (
     PreprocessIterator,
     PreprocessIteratorState,
@@ -40,6 +43,21 @@ class SequenceIteratorState(PydanticIteratorState):
     )
 
 
+def get_datafile(
+    iterator: PreprocessIterator | ArrowFileIterator | LoopingIterator | LimitIterator,
+):
+    if isinstance(iterator, ArrowFileIterator):
+        return f"file={iterator.file_path} n_shards={len(iterator.dataset_files) if iterator.dataset_files is not None else None}"
+    elif isinstance(iterator, PreprocessIterator):
+        return get_datafile(iterator.arrow_iterator)
+    elif isinstance(iterator, LoopingIterator):
+        return get_datafile(iterator.file_iterator)
+    elif isinstance(iterator, LimitIterator):
+        return get_datafile(iterator.base_iterator)
+    else:
+        raise NotImplementedError()
+
+
 class SequenceIterator(StatefulIterator):
     def __init__(
         self,
@@ -74,6 +92,10 @@ class SequenceIterator(StatefulIterator):
         tokens: list[int] = []
         mask: list[bool] = []
         first = True
+        logger.info(
+            "Starting first buffer for: %s",
+            get_datafile(self.preprocess_iterator),
+        )
         for example in example_iter:
             assert example.tokens is not None
             assert example.mask is not None
@@ -97,7 +119,10 @@ class SequenceIterator(StatefulIterator):
             while len(patch_lengths) >= n_buffer_patches:
                 if first:
                     first = False
-                    logger.info("First buffer complete")
+                    logger.info(
+                        "First buffer complete for: %s",
+                        get_datafile(self.preprocess_iterator),
+                    )
 
                 x_patches = np.array(patch_lengths[:n_buffer_patches]).reshape(
                     self.buffer_size, self.output_seq_len
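get_datafile simply unwraps the chain of wrapper iterators (preprocessing, looping, limiting) until it reaches the underlying ArrowFileIterator, so the buffer logs can name the file being read. With a hypothetical path and shard count, the new log lines would read:

    Starting first buffer for: file=/data/shard_00.arrow n_shards=32
    First buffer complete for: file=/data/shard_00.arrow n_shards=32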
bytelatent/iterate_data.py ADDED
@@ -0,0 +1,30 @@
+import json
+
+import pyarrow
+import typer
+from rich.progress import track
+
+from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIteratorState
+from bytelatent.logger import init_logger
+
+
+def main(state_file: str):
+    init_logger()
+    pyarrow.set_io_thread_count(4)
+    pyarrow.set_cpu_count(4)
+    with open(state_file) as f:
+        train_state = json.load(f)
+    dl_state = MultiprocessIteratorState(**train_state["data_loader_state"])
+    packing_iterator_state = dl_state.base_iterator_state
+    print("building")
+    packing_iterator = packing_iterator_state.build()
+    print("iter")
+    batch_iter = packing_iterator.create_iter()
+    batch = None
+    print("looping")
+    for i in track(range(1_000)):
+        batch = next(batch_iter)
+
+
+if __name__ == "__main__":
+    typer.run(main)
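This script replays the dataloader in isolation from a saved training state, which makes it easy to profile dataloader memory and throughput without launching a training run. Since typer.run(main) maps main's state_file parameter to a single positional argument, a hypothetical invocation (the path is made up) would be:

    python -m bytelatent.iterate_data /checkpoints/run0/train_state.json

It rebuilds the iterator stack from the data_loader_state field and pulls 1,000 batches under a rich progress bar.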
bytelatent/train.py CHANGED
@@ -13,6 +13,7 @@ from timeit import default_timer as timer
 from typing import Any, TypeVar
 
 import numpy as np
+import pyarrow
 import torch
 import torch.distributed
 import torch.nn.functional
@@ -266,6 +267,8 @@ def compute_loss(p, y, mask, scale):
 
 def train(args: TrainArgs):
     with ExitStack() as context_stack:
+        pyarrow.set_io_thread_count(4)
+        pyarrow.set_cpu_count(4)
         tokenizer = args.data.tokenizer_args.build()
         validate_train_args(
             args,
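By default PyArrow sizes its CPU thread pool from the machine's core count and keeps a separate I/O pool, so on a many-core node every training rank would otherwise own large pools of its own; pinning both to 4 bounds per-rank thread and memory overhead. A small sketch using the standard PyArrow calls to verify the effective values:

    import pyarrow

    pyarrow.set_cpu_count(4)        # compute pool: decoding, casts, etc.
    pyarrow.set_io_thread_count(4)  # I/O pool: file reads
    print(pyarrow.cpu_count(), pyarrow.io_thread_count())  # -> 4 4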