No gather single gpu (#523)
Browse files* don't attempt to gather on multi-gpu
* also check distributed status in bench callback
src/axolotl/utils/callbacks.py
CHANGED
@@ -27,6 +27,7 @@ from axolotl.utils.distributed import (
|
|
27 |
barrier,
|
28 |
gather_scalar_from_all_ranks,
|
29 |
get_world_size,
|
|
|
30 |
is_main_process,
|
31 |
zero_first,
|
32 |
)
|
@@ -270,10 +271,13 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|
270 |
lambda: len(data_loader), get_world_size()
|
271 |
)
|
272 |
|
273 |
-
if not is_main_process():
|
274 |
dist.gather_object(local_bench_names, dst=0)
|
275 |
else:
|
276 |
-
|
|
|
|
|
|
|
277 |
bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
|
278 |
results = {f"{bench_split}_bench_loss": bench_loss}
|
279 |
|
|
|
27 |
barrier,
|
28 |
gather_scalar_from_all_ranks,
|
29 |
get_world_size,
|
30 |
+
is_distributed,
|
31 |
is_main_process,
|
32 |
zero_first,
|
33 |
)
|
|
|
271 |
lambda: len(data_loader), get_world_size()
|
272 |
)
|
273 |
|
274 |
+
if is_distributed() and not is_main_process():
|
275 |
dist.gather_object(local_bench_names, dst=0)
|
276 |
else:
|
277 |
+
if is_distributed():
|
278 |
+
dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
|
279 |
+
else:
|
280 |
+
gathered_bench_names = [local_bench_names]
|
281 |
bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
|
282 |
results = {f"{bench_split}_bench_loss": bench_loss}
|
283 |
|
src/axolotl/utils/distributed.py
CHANGED
@@ -74,6 +74,8 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
|
|
74 |
- A list of computed values from all ranks if on the gathering rank, otherwise None.
|
75 |
"""
|
76 |
value_scalar = fn()
|
|
|
|
|
77 |
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
|
78 |
|
79 |
if not is_main_process():
|
|
|
74 |
- A list of computed values from all ranks if on the gathering rank, otherwise None.
|
75 |
"""
|
76 |
value_scalar = fn()
|
77 |
+
if not is_distributed():
|
78 |
+
return [value_scalar]
|
79 |
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
|
80 |
|
81 |
if not is_main_process():
|