# Author: lodrick-the-lafted
# Commit: 1f3503c (verified) - "Update densify.py"
import glob
import json
import torch
from safetensors import safe_open
from safetensors.torch import save_file
# Substring renames applied to tensor names: the expert-0 projections of each
# sparse-MoE block become the dense MLP projections.
_EXPERT_RENAMES = (
    ("block_sparse_moe.experts.0.w2.weight", "mlp.down_proj.weight"),
    ("block_sparse_moe.experts.0.w1.weight", "mlp.gate_proj.weight"),
    ("block_sparse_moe.experts.0.w3.weight", "mlp.up_proj.weight"),
)

# Tensors whose name contains this substring are dropped entirely.
_ROUTER_GATE = "block_sparse_moe.gate.weight"

# Config entries that are removed when present.
_MOE_CONFIG_KEYS = (
    'num_local_experts',
    'num_experts_per_tok',
    'output_router_logits',
    'router_aux_loss_coef',
    'router_jitter_noise',
)


def _densify_key(key):
    """Return the dense name for *key*, or ``None`` if it is a router gate."""
    if _ROUTER_GATE in key:
        return None
    for old, new in _EXPERT_RENAMES:
        key = key.replace(old, new)
    return key


def _densify_shards():
    """Rewrite every ``model-*-of-*.safetensors`` shard in the CWD in place."""
    for mf in sorted(glob.glob("model-*-of-*.safetensors")):
        print(f"processing {mf}")
        densified = {}
        # Close the reader before the same path is overwritten below.
        with safe_open(mf, framework='pt') as st:
            # Carry the header metadata over so the rewritten shard keeps
            # whatever the source shard declared (may be None).
            metadata = st.metadata()
            for key in st.keys():
                nk = _densify_key(key)
                if nk is None:
                    # Router gate: skipped without loading the tensor.
                    print(f"- {key}")
                    continue
                densified[nk] = st.get_tensor(key)
                print(f"+ {key} -> {nk}")
        # NOTE: this overwrites the shard; no backup of the tensors is kept.
        save_file(densified, mf, metadata=metadata)


def _rewrite_config():
    """Back up ``config.json`` to ``config.json.old`` and rewrite it."""
    with open("config.json", "r", encoding="utf-8") as read_conf_file:
        conf = json.load(read_conf_file)
    with open("config.json.old", "w", encoding="utf-8") as conf_file:
        json.dump(conf, conf_file, indent=4)

    for moe_key in _MOE_CONFIG_KEYS:
        conf.pop(moe_key, None)
    conf['architectures'] = ['MistralForCausalLM']
    conf['model_type'] = 'mistral'
    conf['_name_or_path'] = 'lodrick-the-lafted/Densefin-291-Mistral-22B'

    with open("config.json", "w", encoding="utf-8") as write_conf_file:
        json.dump(conf, write_conf_file, indent=4)


def _rewrite_index():
    """Back up ``model.safetensors.index.json`` and rename its weight map."""
    with open("model.safetensors.index.json", "r", encoding="utf-8") as index_file:
        index = json.load(index_file)

    nwm = {}
    for weight, shard in index['weight_map'].items():
        nk = _densify_key(weight)
        if nk is not None:
            nwm[nk] = shard
    # NOTE(review): 'metadata' is copied unchanged; if it records a total
    # size, that value is not recomputed after the gate weights are dropped.
    new_index = {'metadata': index['metadata'], 'weight_map': nwm}

    with open("model.safetensors.index.json.old", "w", encoding="utf-8") as backup:
        json.dump(index, backup, indent=4)
    with open("model.safetensors.index.json", "w", encoding="utf-8") as new_st_index:
        json.dump(new_index, new_st_index, indent=4)


def main():
    """Convert the checkpoint in the current directory to dense tensor names.

    Works on files in the current working directory, in this order:

    1. Every ``model-*-of-*.safetensors`` shard is rewritten in place:
       expert-0 ``w1``/``w2``/``w3`` weights are renamed to
       ``mlp.gate_proj``/``mlp.down_proj``/``mlp.up_proj``, tensors whose
       name contains ``block_sparse_moe.gate.weight`` are dropped, and all
       other tensors are kept under their original names.
    2. ``config.json`` is copied to ``config.json.old``; the MoE-specific
       keys are removed and ``architectures``, ``model_type`` and
       ``_name_or_path`` are overwritten.
    3. ``model.safetensors.index.json`` is copied to
       ``model.safetensors.index.json.old`` and its ``weight_map`` is
       renamed/filtered with the same rules as the shards.

    Raises:
        FileNotFoundError: if ``config.json`` or
            ``model.safetensors.index.json`` is missing.
        KeyError: if the index file lacks ``metadata`` or ``weight_map``.
    """
    print("densifying...")
    _densify_shards()
    _rewrite_config()
    _rewrite_index()
# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()