Spaces:
Restarting
on
Zero
Restarting
on
Zero
Reduce per file resources arrow uses (#77)
Browse files
bytelatent/args.py
CHANGED
@@ -138,7 +138,8 @@ class DataloaderArgs(BaseModel):
|
|
138 |
preprocess_dir: str | None = None
|
139 |
dataset_files: list[str] | None = None
|
140 |
entropy_model_name: str | None = "transformer_100m"
|
141 |
-
|
|
|
142 |
buffer_size: int = 64
|
143 |
file_format: str = "arrow"
|
144 |
|
|
|
138 |
preprocess_dir: str | None = None
|
139 |
dataset_files: list[str] | None = None
|
140 |
entropy_model_name: str | None = "transformer_100m"
|
141 |
+
# Be very careful with increasing, increases memory usage by that factor per rank, per data source
|
142 |
+
arrow_batch_size: int = 20
|
143 |
buffer_size: int = 64
|
144 |
file_format: str = "arrow"
|
145 |
|
bytelatent/data/iterators/arrow_iterator.py
CHANGED
@@ -226,7 +226,13 @@ class ArrowFileIterator(StatefulIterator):
|
|
226 |
if (self.row_num - 1) % self.num_workers == self.worker_id:
|
227 |
yield out
|
228 |
|
229 |
-
self.batch_iterator = self.dataset.to_batches(
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
for batch in self.batch_iterator:
|
231 |
batch_columns = batch.to_pydict()
|
232 |
if self.file_format == "arrow":
|
|
|
226 |
if (self.row_num - 1) % self.num_workers == self.worker_id:
|
227 |
yield out
|
228 |
|
229 |
+
self.batch_iterator = self.dataset.to_batches(
|
230 |
+
batch_size=self.arrow_batch_size,
|
231 |
+
# We have large files in GBs, no need to readahead
|
232 |
+
fragment_readahead=1,
|
233 |
+
# Don't readahead in case batches are huge (e.g., books)
|
234 |
+
batch_readahead=1,
|
235 |
+
)
|
236 |
for batch in self.batch_iterator:
|
237 |
batch_columns = batch.to_pydict()
|
238 |
if self.file_format == "arrow":
|
bytelatent/data/iterators/sequence_iterator.py
CHANGED
@@ -10,6 +10,9 @@ from bytelatent.data.iterators.abstract_iterator import (
|
|
10 |
PydanticIteratorState,
|
11 |
StatefulIterator,
|
12 |
)
|
|
|
|
|
|
|
13 |
from bytelatent.data.iterators.preprocess_iterator import (
|
14 |
PreprocessIterator,
|
15 |
PreprocessIteratorState,
|
@@ -40,6 +43,21 @@ class SequenceIteratorState(PydanticIteratorState):
|
|
40 |
)
|
41 |
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
class SequenceIterator(StatefulIterator):
|
44 |
def __init__(
|
45 |
self,
|
@@ -74,6 +92,10 @@ class SequenceIterator(StatefulIterator):
|
|
74 |
tokens: list[int] = []
|
75 |
mask: list[bool] = []
|
76 |
first = True
|
|
|
|
|
|
|
|
|
77 |
for example in example_iter:
|
78 |
assert example.tokens is not None
|
79 |
assert example.mask is not None
|
@@ -97,7 +119,10 @@ class SequenceIterator(StatefulIterator):
|
|
97 |
while len(patch_lengths) >= n_buffer_patches:
|
98 |
if first:
|
99 |
first = False
|
100 |
-
logger.info(
|
|
|
|
|
|
|
101 |
|
102 |
x_patches = np.array(patch_lengths[:n_buffer_patches]).reshape(
|
103 |
self.buffer_size, self.output_seq_len
|
|
|
10 |
PydanticIteratorState,
|
11 |
StatefulIterator,
|
12 |
)
|
13 |
+
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
|
14 |
+
from bytelatent.data.iterators.limit_iterator import LimitIterator
|
15 |
+
from bytelatent.data.iterators.looping_iterator import LoopingIterator
|
16 |
from bytelatent.data.iterators.preprocess_iterator import (
|
17 |
PreprocessIterator,
|
18 |
PreprocessIteratorState,
|
|
|
43 |
)
|
44 |
|
45 |
|
46 |
+
def get_datafile(
|
47 |
+
iterator: PreprocessIterator | ArrowFileIterator | LoopingIterator | LimitIterator,
|
48 |
+
):
|
49 |
+
if isinstance(iterator, ArrowFileIterator):
|
50 |
+
return f"file={iterator.file_path} n_shards={len(iterator.dataset_files) if iterator.dataset_files is not None else None}"
|
51 |
+
elif isinstance(iterator, PreprocessIterator):
|
52 |
+
return get_datafile(iterator.arrow_iterator)
|
53 |
+
elif isinstance(iterator, LoopingIterator):
|
54 |
+
return get_datafile(iterator.file_iterator)
|
55 |
+
elif isinstance(iterator, LimitIterator):
|
56 |
+
return get_datafile(iterator.base_iterator)
|
57 |
+
else:
|
58 |
+
raise NotImplementedError()
|
59 |
+
|
60 |
+
|
61 |
class SequenceIterator(StatefulIterator):
|
62 |
def __init__(
|
63 |
self,
|
|
|
92 |
tokens: list[int] = []
|
93 |
mask: list[bool] = []
|
94 |
first = True
|
95 |
+
logger.info(
|
96 |
+
"Starting first buffer for: %s",
|
97 |
+
get_datafile(self.preprocess_iterator),
|
98 |
+
)
|
99 |
for example in example_iter:
|
100 |
assert example.tokens is not None
|
101 |
assert example.mask is not None
|
|
|
119 |
while len(patch_lengths) >= n_buffer_patches:
|
120 |
if first:
|
121 |
first = False
|
122 |
+
logger.info(
|
123 |
+
"First buffer complete for: %s",
|
124 |
+
get_datafile(self.preprocess_iterator),
|
125 |
+
)
|
126 |
|
127 |
x_patches = np.array(patch_lengths[:n_buffer_patches]).reshape(
|
128 |
self.buffer_size, self.output_seq_len
|
bytelatent/iterate_data.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
import pyarrow
|
4 |
+
import typer
|
5 |
+
from rich.progress import track
|
6 |
+
|
7 |
+
from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIteratorState
|
8 |
+
from bytelatent.logger import init_logger
|
9 |
+
|
10 |
+
|
11 |
+
def main(state_file: str):
|
12 |
+
init_logger()
|
13 |
+
pyarrow.set_io_thread_count(4)
|
14 |
+
pyarrow.set_cpu_count(4)
|
15 |
+
with open(state_file) as f:
|
16 |
+
train_state = json.load(f)
|
17 |
+
dl_state = MultiprocessIteratorState(**train_state["data_loader_state"])
|
18 |
+
packing_iterator_state = dl_state.base_iterator_state
|
19 |
+
print("building")
|
20 |
+
packing_iterator = packing_iterator_state.build()
|
21 |
+
print("iter")
|
22 |
+
batch_iter = packing_iterator.create_iter()
|
23 |
+
batch = None
|
24 |
+
print("looping")
|
25 |
+
for i in track(range(1_000)):
|
26 |
+
batch = next(batch_iter)
|
27 |
+
|
28 |
+
|
29 |
+
if __name__ == "__main__":
|
30 |
+
typer.run(main)
|
bytelatent/train.py
CHANGED
@@ -13,6 +13,7 @@ from timeit import default_timer as timer
|
|
13 |
from typing import Any, TypeVar
|
14 |
|
15 |
import numpy as np
|
|
|
16 |
import torch
|
17 |
import torch.distributed
|
18 |
import torch.nn.functional
|
@@ -266,6 +267,8 @@ def compute_loss(p, y, mask, scale):
|
|
266 |
|
267 |
def train(args: TrainArgs):
|
268 |
with ExitStack() as context_stack:
|
|
|
|
|
269 |
tokenizer = args.data.tokenizer_args.build()
|
270 |
validate_train_args(
|
271 |
args,
|
|
|
13 |
from typing import Any, TypeVar
|
14 |
|
15 |
import numpy as np
|
16 |
+
import pyarrow
|
17 |
import torch
|
18 |
import torch.distributed
|
19 |
import torch.nn.functional
|
|
|
267 |
|
268 |
def train(args: TrainArgs):
|
269 |
with ExitStack() as context_stack:
|
270 |
+
pyarrow.set_io_thread_count(4)
|
271 |
+
pyarrow.set_cpu_count(4)
|
272 |
tokenizer = args.data.tokenizer_args.build()
|
273 |
validate_train_args(
|
274 |
args,
|