"""Overlay patched tensors onto every ``model*.safetensors`` shard.

Loads a set of replacement tensors from an ``.npz`` checkpoint, then, for each
``model*.safetensors`` file in the current directory, swaps in the replacement
for every tensor whose name appears in the checkpoint and writes the result to
``patched_<original file name>``. Tensors without a replacement are written
through unchanged.
"""
from glob import glob

import mlx.core as mx
import numpy as np
from safetensors.numpy import save_file

# Replacement tensors, keyed by the same names used inside the shards.
patch_weights = mx.load("49-20240112-184735.npz")

for file in glob("model*.safetensors"):
    print(f"{file=}")
    weights = mx.load(file)

    # Build a fresh dict rather than rewriting ``weights`` while iterating it.
    patched = {}
    for k, v in weights.items():
        if k in patch_weights:
            print(f"patching {k}")
            source = patch_weights[k]
        else:
            source = v
        # np.asarray copies only when it has to; np.array(..., copy=False)
        # raises ValueError on NumPy 2.x whenever a copy is unavoidable.
        patched[k] = np.asarray(source)

    # Output name does not match "model*", so it is never picked up as input.
    save_file(patched, "patched_" + file)