Spaces:
Running
Running
style: isort
Browse files
tools/train/distributed_shampoo.py
CHANGED
@@ -36,13 +36,13 @@ import itertools
|
|
36 |
from typing import Any, List, NamedTuple
|
37 |
|
38 |
import chex
|
39 |
-
from flax import struct
|
40 |
import jax
|
41 |
-
from jax import lax
|
42 |
import jax.experimental.pjit as pjit
|
43 |
import jax.numpy as jnp
|
44 |
import numpy as np
|
45 |
import optax
|
|
|
|
|
46 |
|
47 |
|
48 |
# pylint:disable=no-value-for-parameter
|
|
|
36 |
from typing import Any, List, NamedTuple
|
37 |
|
38 |
import chex
|
|
|
39 |
import jax
|
|
|
40 |
import jax.experimental.pjit as pjit
|
41 |
import jax.numpy as jnp
|
42 |
import numpy as np
|
43 |
import optax
|
44 |
+
from flax import struct
|
45 |
+
from jax import lax
|
46 |
|
47 |
|
48 |
# pylint:disable=no-value-for-parameter
|
tools/train/train.py
CHANGED
@@ -34,6 +34,7 @@ import optax
|
|
34 |
import transformers
|
35 |
import wandb
|
36 |
from datasets import Dataset
|
|
|
37 |
from flax import jax_utils, traverse_util
|
38 |
from flax.jax_utils import unreplicate
|
39 |
from flax.serialization import from_bytes, to_bytes
|
@@ -45,8 +46,6 @@ from transformers import AutoTokenizer, HfArgumentParser
|
|
45 |
from dalle_mini.data import Dataset
|
46 |
from dalle_mini.model import DalleBart, DalleBartConfig
|
47 |
|
48 |
-
from distributed_shampoo import distributed_shampoo, GraftingType
|
49 |
-
|
50 |
logger = logging.getLogger(__name__)
|
51 |
|
52 |
|
|
|
34 |
import transformers
|
35 |
import wandb
|
36 |
from datasets import Dataset
|
37 |
+
from distributed_shampoo import GraftingType, distributed_shampoo
|
38 |
from flax import jax_utils, traverse_util
|
39 |
from flax.jax_utils import unreplicate
|
40 |
from flax.serialization import from_bytes, to_bytes
|
|
|
46 |
from dalle_mini.data import Dataset
|
47 |
from dalle_mini.model import DalleBart, DalleBartConfig
|
48 |
|
|
|
|
|
49 |
logger = logging.getLogger(__name__)
|
50 |
|
51 |
|