par-meta committed
Commit ea1fc75 · unverified · 1 Parent(s): 9bd51df

Add approximate state persistence (#73)

***
More verbose multiprocess logging, fix get_state_and_recycle
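
At a glance: the commit introduces a PersistType enum ("exact" / "approximate") in bytelatent/data/iterators/multiprocess_iterator.py and threads a new async_persist_type field from DataloaderArgs down into MultiprocessIterator. A minimal usage sketch (the classes and fields below come from this diff; constructing DataloaderArgs with only these fields is an assumption, relying on defaults for the rest):

    from bytelatent.args import DataloaderArgs
    from bytelatent.data.iterators.multiprocess_iterator import PersistType

    # Opt in to the new approximate persistence for the async dataloader.
    data_args = DataloaderArgs(
        load_async=True,
        async_persist_type=PersistType.APPROXIMATE,  # default is PersistType.EXACT
    )
    # PersistType subclasses str, so pydantic should also accept the raw string:
    data_args = DataloaderArgs(load_async=True, async_persist_type="approximate")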

bytelatent/args.py CHANGED
@@ -13,7 +13,10 @@ from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import StatefulIterator
 from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
 from bytelatent.data.iterators.looping_iterator import LoopingIterator
-from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
+from bytelatent.data.iterators.multiprocess_iterator import (
+    MultiprocessIterator,
+    PersistType,
+)
 from bytelatent.data.iterators.packing_iterator import (
     PackingArgs,
     PackingIterator,
@@ -130,6 +133,7 @@ class DataloaderArgs(BaseModel):
     add_bos: bool = True
     add_eos: bool = True
     load_async: bool = True
+    async_persist_type: PersistType = PersistType.EXACT
     prefetch_size: int = 64
     preprocess_dir: str | None = None
     dataset_files: list[str] | None = None
@@ -215,7 +219,9 @@ class DataloaderArgs(BaseModel):
        packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
        if self.load_async:
            mp_iterator = MultiprocessIterator(
-               packing_iterator, n_batches_to_prefetch=self.prefetch_size
+               packing_iterator,
+               n_batches_to_prefetch=self.prefetch_size,
+               persist_type=self.async_persist_type,
            )
            return mp_iterator
        else:
bytelatent/data/iterators/multiprocess_iterator.py CHANGED
@@ -2,6 +2,7 @@
 import json
 import logging
 import multiprocessing as mp
+from enum import Enum
 from multiprocessing.synchronize import Event as EventClass
 from queue import Empty, Full
 
@@ -19,11 +20,17 @@ from bytelatent.data.iterators.packing_iterator import PackingIteratorState
 logger = logging.getLogger()
 
 
+class PersistType(str, Enum):
+    EXACT = "exact"
+    APPROXIMATE = "approximate"
+
+
 class MultiprocessIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     base_iterator_state: PackingIteratorState
     n_batches_to_prefetch: int
     serialized_prefetch_buffer: str
+    persist_type: PersistType
 
     def build(self):
         base_iterator = self.base_iterator_state.build()
@@ -33,14 +40,19 @@ class MultiprocessIteratorState(PydanticIteratorState):
             base_iterator,
             n_batches_to_prefetch=self.n_batches_to_prefetch,
             prefetch_buffer=prefetch_buffer,
+            persist_type=self.persist_type,
         )
 
 
 def start_work_from_state(
     batch_queue: mp.Queue,
     state_queue: mp.Queue,
+    approximate_state_queue: mp.Queue,
     stop_event: EventClass,
     state_dumped_event: EventClass,
+    trigger_approximate_send_state_event: EventClass,
+    sent_approximate_state_event: EventClass,
+    received_approximate_state_event: EventClass,
     state: IteratorState,
 ):
     logging.info("Worker thread: Starting base_iterator work")
@@ -49,6 +61,25 @@ def start_work_from_state(
     for item in iterator:
         while not stop_event.is_set():
             try:
+                if trigger_approximate_send_state_event.is_set():
+                    logger.info("WT: trigger_approximate_send ack")
+                    # Since this can be triggered again (but only after the state
+                    # is received by the main thread), clean up as soon as possible.
+                    trigger_approximate_send_state_event.clear()
+                    logging.info("WT: Computing approximate state")
+                    approximate_state = stateful_iterator.get_state()
+                    # At this point, there should always be exactly 1 free slot.
+                    # Blocking here would be a bug.
+                    logger.info("WT: Attempting to send approximate state")
+                    approximate_state_queue.put(
+                        approximate_state, block=True, timeout=None
+                    )
+                    sent_approximate_state_event.set()
+                    logger.info("WT: Approximate state sent")
+                    # Same here, clear events as we no longer need them.
+                    received_approximate_state_event.wait()
+                    received_approximate_state_event.clear()
+                    logger.info("WT: State received by MT, resuming batch iteration")
                 # Attempt to put on queue or timeout to try again (maybe main thread is busy)
                 batch_queue.put(item, timeout=0.1)
                 # On success, stop trying
@@ -58,10 +89,10 @@ def start_work_from_state(
             if stop_event.is_set():
                 # Signal the end of output, this ensures that even if the queue takes a while to
                 # buffer, that the main thread receives everything (and tosses this fake batch)
-                logging.debug(
+                logging.info(
                     "Worker thread: Stop event detected, outputting is_final=True batch"
                 )
-                logging.debug("Worker thread: batch_queue full=%s", batch_queue.full())
+                logging.info("Worker thread: batch_queue full=%s", batch_queue.full())
                 batch_queue.put(
                     Batch(
                         x=np.zeros((1, 1)),
@@ -72,23 +103,26 @@ def start_work_from_state(
                         ngram_ids=None,
                     )
                 )
-                logging.debug(
+                logging.info(
                     "Worker thread: is_final=True batch put in queue, breaking from loop."
                 )
                 break
 
     try:
-        logging.debug("Worker thread: outputting state")
+        logging.info("Worker thread: outputting state")
         state_queue.put(stateful_iterator.get_state(), timeout=1)
-        logging.debug("Worker thread: state dump complete")
+        logging.info("Worker thread: state dump complete")
         state_dumped_event.set()
-        logging.debug("Worker thread: set state_dump_event")
+        logging.info("Worker thread: set state_dump_event")
     except Full:
         raise ValueError(
             "Attempted to dump state into the state queue, but it was full"
         )
 
 
+FETCH_STATE_TIMEOUT = 120
+
+
 class MultiprocessIterator(StatefulIterator):
     """
     Design sketch of the multiprocess iterator:
@@ -124,18 +158,24 @@ class MultiprocessIterator(StatefulIterator):
         base_iterator: StatefulIterator,
         *,
         n_batches_to_prefetch: int,
-        prefetch_buffer: list | None = None
+        prefetch_buffer: list | None = None,
+        persist_type: PersistType = PersistType.EXACT,
     ):
         self.base_iterator = base_iterator
         self.n_batches_to_prefetch = n_batches_to_prefetch
+        self.persist_type = persist_type
        if prefetch_buffer is None:
            prefetch_buffer = []
        self.prefetch_buffer = prefetch_buffer
        self.batch_queue = None
        self.state_queue = None
+       self.approximate_state_queue = None
        self.producer = None
        self.stop_iterating_event = None
        self.state_dumped_event = None
+       self.trigger_approximate_send_state_event = None
+       self.sent_approximate_state_event = None
+       self.received_approximate_state_event = None
        self.force_shutdown = False
 
     def shutdown(self):
@@ -144,6 +184,92 @@ class MultiprocessIterator(StatefulIterator):
             self.producer.kill()
         self.force_shutdown = True
 
+    def _get_state_exact(self):
+        logging.info("Main thread: Sending stop iteration event")
+        self.stop_iterating_event.set()
+        logging.info(
+            "Main thread: Emptying the batch_queue until batch.is_final=True is found."
+        )
+        self.prefetch_buffer = []
+        final_batch_received = False
+        while True:
+            try:
+                batch = self.batch_queue.get(timeout=1)
+                if batch.is_final:
+                    logging.info(
+                        "Main thread: is_final=True batch found, stopping fetch from batch_queue"
+                    )
+                    final_batch_received = True
+                    break
+                self.prefetch_buffer.append(batch)
+            except Empty:
+                logging.warning("Main thread: batch_queue is abnormally empty")
+        assert final_batch_received
+
+        logging.info("Main thread: Waiting for state_dumped event")
+        self.state_dumped_event.wait()
+
+        try:
+            logging.info(
+                "Main thread: state_dumped_event received, waiting for state from queue"
+            )
+            base_iterator_state = self.state_queue.get(timeout=FETCH_STATE_TIMEOUT)
+            logging.info("Main thread: received state from queue")
+            assert isinstance(base_iterator_state, IteratorState)
+        except Empty:
+            raise ValueError(
+                "Attempted to get the state, but it was unexpectedly missing"
+            )
+
+        self.base_iterator = base_iterator_state.build()
+        self.producer.close()
+        self.producer = None
+        self.batch_queue = None
+        self.state_queue = None
+        self.approximate_state_queue = None
+        self.stop_iterating_event = None
+        self.state_dumped_event = None
+        self.trigger_approximate_send_state_event = None
+        self.sent_approximate_state_event = None
+        self.received_approximate_state_event = None
+
+        return MultiprocessIteratorState(
+            base_iterator_state=self.base_iterator.get_state(),
+            n_batches_to_prefetch=self.n_batches_to_prefetch,
+            serialized_prefetch_buffer=json.dumps(
+                [b.to_python_dict() for b in self.prefetch_buffer]
+            ),
+            persist_type=self.persist_type,
+        )
+
+    def _get_state_approximate(self):
+        logging.info("MT: Sending approximate get_state request")
+        self.trigger_approximate_send_state_event.set()
+        logging.info("MT: Waiting for sent_approximate_state_event")
+        self.sent_approximate_state_event.wait()
+        logging.info("MT: sent_approximate_state_event ack")
+        try:
+            logging.info("MT: waiting for approximate state in queue")
+            base_iterator_state = self.approximate_state_queue.get(
+                timeout=FETCH_STATE_TIMEOUT
+            )
+            logging.info("MT: approximate state received")
+            assert isinstance(base_iterator_state, IteratorState)
+            assert self.approximate_state_queue.empty()
+        except Empty:
+            raise ValueError(
+                "Attempted to get approximate state, but queue was erroneously empty."
+            )
+        self.received_approximate_state_event.set()
+        return MultiprocessIteratorState(
+            base_iterator_state=base_iterator_state,
+            n_batches_to_prefetch=self.n_batches_to_prefetch,
+            serialized_prefetch_buffer=json.dumps(
+                [b.to_python_dict() for b in self.prefetch_buffer]
+            ),
+            persist_type=self.persist_type,
+        )
+
     def get_state(self) -> MultiprocessIteratorState:
         """
         This is slightly unusual in effectively destroying the current iterator; it's necessary
@@ -162,55 +288,15 @@ class MultiprocessIterator(StatefulIterator):
                 base_iterator_state=self.base_iterator.get_state(),
                 n_batches_to_prefetch=self.n_batches_to_prefetch,
                 serialized_prefetch_buffer=serialized_prefetch_buffer,
+                persist_type=self.persist_type,
             )
         else:
-            logging.debug("Main thread: Sending stop iteration event")
-            self.stop_iterating_event.set()
-            logging.debug(
-                "Main thread: Emptying the batch_queue until batch.is_final=True is found."
-            )
-            self.prefetch_buffer = []
-            final_batch_received = False
-            while True:
-                try:
-                    batch = self.batch_queue.get(timeout=1)
-                    if batch.is_final:
-                        logging.debug(
-                            "Main thread: is_final=True batch found, stopping fetch from batch_queue"
-                        )
-                        final_batch_received = True
-                        break
-                    self.prefetch_buffer.append(batch)
-                except Empty:
-                    logging.warning("Main thread: batch_queue is abnormally empty")
-            assert final_batch_received
-
-            logging.debug("Main thread: Waiting for state_dumped event")
-            self.state_dumped_event.wait()
-
-            try:
-                base_iterator_state = self.state_queue.get(timeout=1)
-                assert isinstance(base_iterator_state, IteratorState)
-            except Empty:
-                raise ValueError(
-                    "Attempted to get the state, but it was unexpectedly missing"
-                )
-
-            self.base_iterator = base_iterator_state.build()
-            self.producer.close()
-            self.producer = None
-            self.batch_queue = None
-            self.state_queue = None
-            self.stop_iterating_event = None
-            self.state_dumped_event = None
-
-            return MultiprocessIteratorState(
-                base_iterator_state=self.base_iterator.get_state(),
-                n_batches_to_prefetch=self.n_batches_to_prefetch,
-                serialized_prefetch_buffer=json.dumps(
-                    [b.to_python_dict() for b in self.prefetch_buffer]
-                ),
-            )
+            if self.persist_type == PersistType.EXACT:
+                return self._get_state_exact()
+            elif self.persist_type == PersistType.APPROXIMATE:
+                return self._get_state_approximate()
+            else:
+                raise ValueError("invalid persist_type")
 
     def create_iter(self):
         if self.force_shutdown:
@@ -236,8 +322,14 @@ class MultiprocessIterator(StatefulIterator):
             # We should only ever have one state, which is output at the detection of a stop event
             self.state_queue = ctx.Manager().Queue(maxsize=1)
 
+            # Similarly, there should only ever be one state in flight due to event signals
+            self.approximate_state_queue = ctx.Manager().Queue(maxsize=1)
+
             self.stop_iterating_event = ctx.Event()
             self.state_dumped_event = ctx.Event()
+            self.trigger_approximate_send_state_event = ctx.Event()
+            self.sent_approximate_state_event = ctx.Event()
+            self.received_approximate_state_event = ctx.Event()
 
             self.producer = mp.Process(
                 name="blt_data_loader",
@@ -245,8 +337,12 @@ class MultiprocessIterator(StatefulIterator):
                 args=(
                     self.batch_queue,
                     self.state_queue,
+                    self.approximate_state_queue,
                     self.stop_iterating_event,
                     self.state_dumped_event,
+                    self.trigger_approximate_send_state_event,
+                    self.sent_approximate_state_event,
+                    self.received_approximate_state_event,
                     self.base_iterator.get_state(),
                 ),
             )
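
The practical difference between the two modes: EXACT (the old behavior, now in _get_state_exact) stops the worker, drains every prefetched batch into the serialized state, and tears the process down, while APPROXIMATE asks the still-running worker for a snapshot via the new three-event handshake, so batches already prefetched but not yet consumed are not captured and a restore resumes slightly ahead of where the main thread left off, hence "approximate". Below is a standalone, runnable sketch of just that handshake (the queue and event roles mirror the diff; the counter-based worker and names like `worker` and `stop` are illustrative stand-ins for the real batch loop):

    import multiprocessing as mp
    import time

    def worker(state_queue, trigger, sent, received, stop):
        state = 0
        while not stop.is_set():
            if trigger.is_set():
                trigger.clear()         # allow the next snapshot request
                state_queue.put(state)  # maxsize=1: exactly one free slot here
                sent.set()              # tell the main process the state is queued
                received.wait()         # block until the main process has read it...
                received.clear()        # ...then resume producing batches
            state += 1                  # stand-in for producing one batch
            time.sleep(0.001)

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        state_queue = ctx.Manager().Queue(maxsize=1)
        trigger, sent, received, stop = (ctx.Event() for _ in range(4))
        p = ctx.Process(target=worker, args=(state_queue, trigger, sent, received, stop))
        p.start()
        time.sleep(0.05)   # let the worker produce a few "batches"
        trigger.set()      # main process: request an approximate snapshot
        sent.wait()        # wait until the worker has queued it
        print("approximate state:", state_queue.get(timeout=5))
        received.set()     # unblock the worker
        stop.set()
        p.join()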
bytelatent/train.py CHANGED
@@ -31,6 +31,7 @@ from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh
 from bytelatent.data.iterators.multiprocess_iterator import (
     MultiprocessIterator,
     MultiprocessIteratorState,
+    PersistType,
 )
 from bytelatent.data.iterators.packing_iterator import PackingIteratorState
 from bytelatent.distributed import (
@@ -712,9 +713,15 @@ def train(args: TrainArgs):
                if every_n_steps(
                    train_state, args.checkpoint.dump.every, acc_step=0
                ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
-                   train_state.data_loader_state, data_loader, batch_iterator = (
-                       get_state_and_refresh(data_loader)
-                   )
+                   if (
+                       args.data.load_async
+                       and args.data.async_persist_type == PersistType.EXACT
+                   ):
+                       train_state.data_loader_state, data_loader, batch_iterator = (
+                           get_state_and_refresh(data_loader)
+                       )
+                   else:
+                       train_state.data_loader_state = data_loader.get_state()
                    saved = checkpoint.save(
                        model,
                        optimizer,
@@ -756,9 +763,16 @@ def train(args: TrainArgs):
 
        if preemption_flag["flag"]:
            if not saved:
-               train_state.data_loader_state, data_loader, batch_iterator = (
-                   get_state_and_refresh(data_loader)
-               )
+               if (
+                   args.data.load_async
+                   and args.data.async_persist_type == PersistType.EXACT
+               ):
+                   train_state.data_loader_state, data_loader, batch_iterator = (
+                       get_state_and_refresh(data_loader)
+                   )
+               else:
+                   train_state.data_loader_state = data_loader.get_state()
+
                checkpoint.save(
                    model,
                    optimizer,
@@ -769,21 +783,27 @@ def train(args: TrainArgs):
            requeue_slurm_job()
            sys.exit(0)
 
-   if not saved:
-       train_state.data_loader_state, data_loader, batch_iterator = (
-           get_state_and_refresh(data_loader)
-       )
-       checkpoint.save(
-           model,
-           optimizer,
-           train_state,
-           args,
-           device_mesh=world_mesh,
-       )
-   if isinstance(data_loader, MultiprocessIterator):
-       logger.info("Closing MP iterator before exiting")
-       data_loader.shutdown()
-   gc.collect()
+   if not saved:
+       if (
+           args.data.load_async
+           and args.data.async_persist_type == PersistType.EXACT
+       ):
+           train_state.data_loader_state, data_loader, batch_iterator = (
+               get_state_and_refresh(data_loader)
+           )
+       else:
+           train_state.data_loader_state = data_loader.get_state()
+       checkpoint.save(
+           model,
+           optimizer,
+           train_state,
+           args,
+           device_mesh=world_mesh,
+       )
+   if isinstance(data_loader, MultiprocessIterator):
+       logger.info("Closing MP iterator before exiting")
+       data_loader.shutdown()
+   gc.collect()
 
 
 def main():
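
train.py now repeats the same branch at all three save points (periodic dump, preemption, and the final save). Distilled into one helper for readability (a sketch: snapshot_dataloader_state is a hypothetical name and the commit inlines this logic three times; the other names come from the diff):

    from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh
    from bytelatent.data.iterators.multiprocess_iterator import PersistType

    def snapshot_dataloader_state(args, data_loader, batch_iterator):
        if args.data.load_async and args.data.async_persist_type == PersistType.EXACT:
            # EXACT: tear down the worker, drain its queues into the state,
            # and rebuild the dataloader and iterator before training resumes.
            state, data_loader, batch_iterator = get_state_and_refresh(data_loader)
        else:
            # APPROXIMATE (or synchronous loading): snapshot in place,
            # leaving the worker and its prefetch queue running.
            state = data_loader.get_state()
        return state, data_loader, batch_iterator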