boris commited on
Commit
696422e
1 Parent(s): 5996680

fix: distributed shampoo class

Browse files
Files changed (1) hide show
  1. tools/train/distributed_shampoo.py +3 -1
tools/train/distributed_shampoo.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  # coding=utf-8
2
  # Copyright 2022 The Google Research Authors.
3
  #
@@ -235,7 +237,7 @@ class GraftingType(enum.IntEnum):
235
  RMSPROP = 3
236
  RMSPROP_NORMALIZED = 4
237
  SQRT_N = 5
238
- ADAGRAD_NORMALIZED = 5
239
 
240
 
241
  def power_iteration(
 
1
+ # file from: https://github.com/google-research/google-research/blob/master/scalable_shampoo/optax/distributed_shampoo.py
2
+
3
  # coding=utf-8
4
  # Copyright 2022 The Google Research Authors.
5
  #
 
237
  RMSPROP = 3
238
  RMSPROP_NORMALIZED = 4
239
  SQRT_N = 5
240
+ ADAGRAD_NORMALIZED = 6
241
 
242
 
243
  def power_iteration(