boris commited on
Commit
df1fe19
·
1 Parent(s): 49597a2

feat(train): no batch dimension with pjit

Browse files
src/dalle_mini/model/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
  from .configuration import DalleBartConfig
2
  from .modeling import DalleBart
 
3
  from .tokenizer import DalleBartTokenizer
 
1
  from .configuration import DalleBartConfig
2
  from .modeling import DalleBart
3
+ from .partitions import set_partitions
4
  from .tokenizer import DalleBartTokenizer
tools/train/train.py CHANGED
@@ -38,7 +38,7 @@ from distributed_shampoo import GraftingType, distributed_shampoo
38
  from flax.core.frozen_dict import freeze
39
  from flax.serialization import from_bytes, to_bytes
40
  from flax.training import train_state
41
- from flax.training.common_utils import get_metrics, onehot
42
  from jax.experimental import PartitionSpec, maps
43
  from jax.experimental.pjit import pjit
44
  from tqdm import tqdm
@@ -764,7 +764,6 @@ def main():
764
  ),
765
  )
766
 
767
- grads = jax.lax.pmean(grads, "batch")
768
  state = state.apply_gradients(
769
  grads=grads,
770
  dropout_rng=new_dropout_rng,
@@ -776,7 +775,6 @@ def main():
776
  "loss": loss,
777
  "learning_rate": learning_rate_fn(state.step),
778
  }
779
- metrics = jax.lax.pmean(metrics, axis_name="batch")
780
 
781
  return state, metrics
782
 
@@ -788,7 +786,6 @@ def main():
788
 
789
  # summarize metrics
790
  metrics = {"loss": loss}
791
- metrics = jax.lax.pmean(metrics, axis_name="batch")
792
  return metrics
793
 
794
  # Create parallel version of the train and eval step
@@ -861,7 +858,7 @@ def main():
861
  eval_metrics.append(metrics)
862
 
863
  # normalize eval metrics
864
- eval_metrics = get_metrics(eval_metrics)
865
  eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
866
 
867
  # log metrics
 
38
  from flax.core.frozen_dict import freeze
39
  from flax.serialization import from_bytes, to_bytes
40
  from flax.training import train_state
41
+ from flax.training.common_utils import onehot, stack_forest
42
  from jax.experimental import PartitionSpec, maps
43
  from jax.experimental.pjit import pjit
44
  from tqdm import tqdm
 
764
  ),
765
  )
766
 
 
767
  state = state.apply_gradients(
768
  grads=grads,
769
  dropout_rng=new_dropout_rng,
 
775
  "loss": loss,
776
  "learning_rate": learning_rate_fn(state.step),
777
  }
 
778
 
779
  return state, metrics
780
 
 
786
 
787
  # summarize metrics
788
  metrics = {"loss": loss}
 
789
  return metrics
790
 
791
  # Create parallel version of the train and eval step
 
858
  eval_metrics.append(metrics)
859
 
860
  # normalize eval metrics
861
+ eval_metrics = stack_forest(eval_metrics)
862
  eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
863
 
864
  # log metrics