fix: style
Browse files
tools/train/distributed_shampoo.py
CHANGED
|
@@ -36,13 +36,13 @@ import itertools
|
|
| 36 |
from typing import Any, List, NamedTuple
|
| 37 |
|
| 38 |
import chex
|
| 39 |
-
from flax import struct
|
| 40 |
import jax
|
| 41 |
-
from jax import lax
|
| 42 |
import jax.experimental.pjit as pjit
|
| 43 |
import jax.numpy as jnp
|
| 44 |
import numpy as np
|
| 45 |
import optax
|
|
|
|
|
|
|
| 46 |
|
| 47 |
|
| 48 |
# pylint:disable=no-value-for-parameter
|
|
|
|
| 36 |
from typing import Any, List, NamedTuple
|
| 37 |
|
| 38 |
import chex
|
|
|
|
| 39 |
import jax
|
|
|
|
| 40 |
import jax.experimental.pjit as pjit
|
| 41 |
import jax.numpy as jnp
|
| 42 |
import numpy as np
|
| 43 |
import optax
|
| 44 |
+
from flax import struct
|
| 45 |
+
from jax import lax
|
| 46 |
|
| 47 |
|
| 48 |
# pylint:disable=no-value-for-parameter
|