aapot
commited on
Commit
•
a71634a
1
Parent(s):
646ac2a
Fix distributed evall metric
Browse files- EasyLM/jax_utils.py +5 -4
EasyLM/jax_utils.py
CHANGED
@@ -283,10 +283,11 @@ def global_norm(tree):
|
|
283 |
|
284 |
|
285 |
def average_metrics(metrics):
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
|
|
290 |
|
291 |
|
292 |
def get_float_dtype_by_name(dtype):
|
|
|
283 |
|
284 |
|
285 |
def average_metrics(metrics):
|
286 |
+
with jax.spmd_mode("allow_all"):
|
287 |
+
return jax.tree_map(
|
288 |
+
lambda *args: jnp.mean(jnp.stack(args)),
|
289 |
+
*metrics
|
290 |
+
)
|
291 |
|
292 |
|
293 |
def get_float_dtype_by_name(dtype):
|