|
import glob |
|
import json |
|
import torch |
|
from safetensors import safe_open |
|
from safetensors.torch import save_file |
|
|
|
# Substring renames applied to every tensor name (in the weight files and in the
# index). Only expert 0 is mapped onto the dense MLP projections; presumably the
# source checkpoint has one expert per layer -- TODO confirm. Tensors of any
# other expert would keep their original names.
_EXPERT_RENAMES = (
    ("block_sparse_moe.experts.0.w2.weight", "mlp.down_proj.weight"),
    ("block_sparse_moe.experts.0.w1.weight", "mlp.gate_proj.weight"),
    ("block_sparse_moe.experts.0.w3.weight", "mlp.up_proj.weight"),
)

# Any tensor whose name contains this substring is dropped from the output.
_ROUTER_GATE_MARKER = "block_sparse_moe.gate.weight"

# MoE routing hyperparameters removed from config.json before it is relabelled
# as a Mistral model.
_MOE_CONFIG_KEYS = (
    "num_local_experts",
    "num_experts_per_tok",
    "output_router_logits",
    "router_aux_loss_coef",
    "router_jitter_noise",
)


def _rename_key(key):
    """Return the dense name for tensor *key* (unchanged if no rename matches)."""
    for old, new in _EXPERT_RENAMES:
        key = key.replace(old, new)
    return key


def _is_router_gate(key):
    """Return True if *key* names a router gate weight, which is dropped."""
    return _ROUTER_GATE_MARKER in key


def _densify_shards(shard_paths):
    """Rewrite each safetensors shard in place: rename expert tensors, drop gates.

    Args:
        shard_paths: iterable of shard file paths, processed in order.

    Returns:
        Total number of bytes of tensor data that was dropped (the router gates),
        so the index's ``total_size`` can be corrected.
    """
    import os

    removed_bytes = 0
    for shard in shard_paths:
        print(f"processing {shard}")
        kept = {}
        # Use the context manager so the handle is released before the shard is
        # replaced (the original code never closed it).
        with safe_open(shard, framework="pt") as st:
            # Carry the header metadata over; save_file writes none unless given it.
            metadata = st.metadata()
            for key in st.keys():
                tensor = st.get_tensor(key)
                if _is_router_gate(key):
                    removed_bytes += tensor.numel() * tensor.element_size()
                    print(f"- {key}")
                    continue
                new_key = _rename_key(key)
                kept[new_key] = tensor
                print(f"+ {key} -> {new_key}")

        # Write to a sibling temp file and swap it in atomically, so an
        # interrupted run never leaves a half-written shard in place.
        tmp_path = f"{shard}.tmp"
        save_file(kept, tmp_path, metadata=metadata)
        # Release the tensors before replacing: they may still be backed by a
        # mapping of the original file -- assumption, depends on safetensors/torch version.
        kept.clear()
        os.replace(tmp_path, shard)
    return removed_bytes


def _rewrite_config(path="config.json"):
    """Back up *path*, strip MoE keys from it and relabel the model as Mistral."""
    import os
    import shutil

    with open(path, "r", encoding="utf-8") as read_conf_file:
        conf = json.load(read_conf_file)

    backup = f"{path}.old"
    # On a re-run config.json is already converted; keep the first backup.
    if not os.path.exists(backup):
        shutil.copyfile(path, backup)

    for moe_key in _MOE_CONFIG_KEYS:
        conf.pop(moe_key, None)

    conf["architectures"] = ["MistralForCausalLM"]
    conf["model_type"] = "mistral"
    conf["_name_or_path"] = "lodrick-the-lafted/Densefin-291-Mistral-22B"

    with open(path, "w", encoding="utf-8") as write_conf_file:
        json.dump(conf, write_conf_file, indent=4)


def _rewrite_index(removed_bytes, path="model.safetensors.index.json"):
    """Back up and rewrite the shard index with renamed keys and no gate entries.

    Args:
        removed_bytes: bytes of tensor data dropped from the shards; subtracted
            from ``metadata.total_size`` when that field is an integer.
        path: index file to rewrite in place.
    """
    import os
    import shutil

    with open(path, "r", encoding="utf-8") as index_file:
        index = json.load(index_file)

    metadata = dict(index.get("metadata", {}))
    # total_size no longer matches once the gate tensors are gone.
    if isinstance(metadata.get("total_size"), int):
        metadata["total_size"] -= removed_bytes

    weight_map = {
        _rename_key(name): shard
        for name, shard in index["weight_map"].items()
        if not _is_router_gate(name)
    }

    backup = f"{path}.old"
    # On a re-run the index is already converted; keep the first backup.
    if not os.path.exists(backup):
        shutil.copyfile(path, backup)

    with open(path, "w", encoding="utf-8") as new_st_index:
        json.dump({"metadata": metadata, "weight_map": weight_map}, new_st_index, indent=4)


def main():
    """Densify the MoE checkpoint in the current directory, in place.

    The steps are:

    1. Rewrite every ``model-*-of-*.safetensors`` shard. Expert-0 weights are
       renamed to ``mlp.{down,gate,up}_proj.weight`` and router gate weights
       are dropped.
    2. Back up ``config.json`` to ``config.json.old``, strip the MoE keys and
       set the Mistral architecture.
    3. Back up ``model.safetensors.index.json`` and rewrite it to match the
       new shards.
    """
    print("densifying...")

    shards = sorted(glob.glob("model-*-of-*.safetensors"))
    if not shards:
        # Relabelling the config without converting any weights would leave an
        # inconsistent checkpoint, so stop before touching anything.
        print("no model-*-of-*.safetensors shards found; nothing to do")
        return

    removed_bytes = _densify_shards(shards)
    _rewrite_config()
    _rewrite_index(removed_bytes)
|
|
if __name__ == "__main__":
    # Run the conversion only when executed as a script, never on import.
    main()
|
|