boris commited on
Commit
531cd78
1 Parent(s): e669c1b

style: isort

Browse files
tools/train/distributed_shampoo.py CHANGED
@@ -36,13 +36,13 @@ import itertools
36
  from typing import Any, List, NamedTuple
37
 
38
  import chex
39
- from flax import struct
40
  import jax
41
- from jax import lax
42
  import jax.experimental.pjit as pjit
43
  import jax.numpy as jnp
44
  import numpy as np
45
  import optax
 
 
46
 
47
 
48
  # pylint:disable=no-value-for-parameter
 
36
  from typing import Any, List, NamedTuple
37
 
38
  import chex
 
39
  import jax
 
40
  import jax.experimental.pjit as pjit
41
  import jax.numpy as jnp
42
  import numpy as np
43
  import optax
44
+ from flax import struct
45
+ from jax import lax
46
 
47
 
48
  # pylint:disable=no-value-for-parameter
tools/train/train.py CHANGED
@@ -34,6 +34,7 @@ import optax
34
  import transformers
35
  import wandb
36
  from datasets import Dataset
 
37
  from flax import jax_utils, traverse_util
38
  from flax.jax_utils import unreplicate
39
  from flax.serialization import from_bytes, to_bytes
@@ -45,8 +46,6 @@ from transformers import AutoTokenizer, HfArgumentParser
45
  from dalle_mini.data import Dataset
46
  from dalle_mini.model import DalleBart, DalleBartConfig
47
 
48
- from distributed_shampoo import distributed_shampoo, GraftingType
49
-
50
  logger = logging.getLogger(__name__)
51
 
52
 
 
34
  import transformers
35
  import wandb
36
  from datasets import Dataset
37
+ from distributed_shampoo import GraftingType, distributed_shampoo
38
  from flax import jax_utils, traverse_util
39
  from flax.jax_utils import unreplicate
40
  from flax.serialization import from_bytes, to_bytes
 
46
  from dalle_mini.data import Dataset
47
  from dalle_mini.model import DalleBart, DalleBartConfig
48
 
 
 
49
  logger = logging.getLogger(__name__)
50
 
51