File size: 514 Bytes
7730b56 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
import numpy as np
import mlx.core as mx
from glob import glob
from safetensors.numpy import save_file
# Patch script: for every "model*.safetensors" file matched in the current
# directory, replace each tensor whose name also appears in the patch archive
# with the archive's version, convert every tensor to a NumPy array, and write
# the result next to the input as "patched_<original name>".

# Replacement tensors, looked up by tensor name below.
patch_weights = mx.load("49-20240112-184735.npz")

for file in glob("model*.safetensors"):
    print(f"{file=}")
    weights = mx.load(file)
    for k, v in weights.items():
        # Prefer the patched tensor when the archive has one under this name;
        # otherwise keep the tensor that was loaded from the shard.
        if k in patch_weights:
            print(f"patching {k}")
            source = patch_weights[k]
        else:
            source = v
        # np.asarray copies only when it has to, on every NumPy version.
        # np.array(..., copy=False) raises ValueError under NumPy 2.x whenever
        # a copy cannot be avoided, so it is not a safe spelling of that intent.
        # Overwriting an existing key does not resize the dict, so assigning
        # while iterating over .items() is safe.
        weights[k] = np.asarray(source)
    # The output name starts with "patched_", so it does not match the
    # "model*" pattern used above.
    save_file(weights, "patched_" + file)
|