|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Utilities for constructing PyTrees of PartitionSpecs.""" |
|
|
|
|
|
|
|
import re |
|
|
|
from flax.core.frozen_dict import freeze |
|
from flax.traverse_util import flatten_dict, unflatten_dict |
|
from jax.experimental import PartitionSpec as P |
|
|
|
|
|
|
|
# Sentinel used as the placeholder value for a flattened key before any
# partition rule has been applied to it; set_partitions treats a key still
# holding this object as unmatched.
_unmatched = object()

# Second module-level sentinel; it is not referenced in this section.
# NOTE(review): presumably stands in for an empty leaf dict `{}` -- confirm
# against callers before relying on that.
empty_dict = object()
|
|
|
|
|
def _match(qs, ks): |
|
"""Return True if regexes in qs match any window of strings in tuple ks.""" |
|
|
|
qts = tuple(map(lambda x: re.compile(x + "$"), qs)) |
|
for i in range(len(ks) - len(qs) + 1): |
|
matches = [x.match(y) for x, y in zip(qts, ks[i:])] |
|
if matches and all(matches): |
|
return True |
|
return False |
|
|
|
|
|
def _replacement_rules(rules): |
|
def replace(key, val): |
|
for rule, replacement in rules: |
|
if _match(rule, key): |
|
return replacement |
|
return val |
|
|
|
return replace |
|
|
|
|
|
|
|
|
|
def _get_partition_rules(): |
|
return [ |
|
|
|
(("transformer", "wpe", "embedding"), P("mp", None)), |
|
(("transformer", "wte", "embedding"), P("mp", None)), |
|
|
|
(("attention", "(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")), |
|
(("attention", "out_proj", "kernel"), P("mp", None)), |
|
(("attention", "out_proj", "bias"), None), |
|
|
|
(("mlp", "c_fc", "kernel"), P(None, "mp")), |
|
(("mlp", "c_fc", "bias"), P("mp")), |
|
(("mlp", "c_proj", "kernel"), P("mp", None)), |
|
(("mlp", "c_proj", "bias"), None), |
|
|
|
((r"ln_\d+", "bias"), None), |
|
((r"\d+", r"ln_\d+", "scale"), None), |
|
(("ln_f", "bias"), None), |
|
(("ln_f", "scale"), None), |
|
] |
|
|
|
|
|
def set_partitions(in_dict):
    """Build a frozen tree of partition specs keyed like ``in_dict``.

    ``in_dict`` is flattened to tuple keys, each key is looked up in the
    rules returned by ``_get_partition_rules()``, and the results are
    unflattened and frozen. Leaf values of ``in_dict`` are ignored; only the
    key paths are used.

    Args:
        in_dict: Nested dict accepted by ``flatten_dict``.

    Returns:
        The frozen, unflattened dict whose leaves are the replacement of the
        first matching rule (a ``PartitionSpec`` or ``None``).

    Raises:
        AssertionError: If any flattened key is not matched by a rule.
    """
    replace = _replacement_rules(_get_partition_rules())
    result = {key: replace(key, _unmatched) for key in flatten_dict(in_dict)}
    # Raised explicitly instead of via `assert` so the check still runs under
    # `python -O`; the identity test avoids invoking __eq__ on spec objects.
    missing = [key for key, spec in result.items() if spec is _unmatched]
    if missing:
        raise AssertionError(
            f"Incomplete partition spec. Unmatched keys: {missing}"
        )
    return freeze(unflatten_dict(result))
|
|