Text Generation
Transformers
PyTorch
Safetensors
Finnish
llama
finnish
text-generation-inference
aapot commited on
Commit
edcd2f1
1 Parent(s): a971b09

Fix distributed evall metric

Browse files
Files changed (1) hide show
  1. EasyLM/jax_utils.py +5 -4
EasyLM/jax_utils.py CHANGED
@@ -283,10 +283,11 @@ def global_norm(tree):
283
 
284
 
285
  def average_metrics(metrics):
286
- return jax.tree_map(
287
- lambda *args: jnp.mean(jnp.stack(args)),
288
- *metrics
289
- )
 
290
 
291
 
292
  def get_float_dtype_by_name(dtype):
 
283
 
284
 
285
  def average_metrics(metrics):
286
+ with jax.spmd_mode("allow_all"):
287
+ return jax.tree_map(
288
+ lambda *args: jnp.mean(jnp.stack(args)),
289
+ *metrics
290
+ )
291
 
292
 
293
  def get_float_dtype_by_name(dtype):