"""Overwrite selected tensors in sharded safetensors checkpoints.

Every file matching ``model*.safetensors`` in the current directory is loaded,
each tensor whose name also appears in the patch ``.npz`` archive is replaced
by the patch array, and the result is written to ``patched_<original name>``
next to the input. The original shards are not modified.
"""

from glob import glob
from pathlib import Path

import numpy as np
import torch

PATCH_FILE = "589-20240113-071533.npz"
SHARD_PATTERN = "model*.safetensors"
OUTPUT_PREFIX = "patched_"


def patch_tensors(weights, patch_weights, *, match_dtype=True):
    """Replace entries of *weights* in place with arrays from *patch_weights*.

    Args:
        weights: Mapping of tensor name -> ``torch.Tensor`` (e.g. the dict
            returned by ``safetensors.torch.load_file``). Modified in place.
        patch_weights: Mapping of tensor name -> array-like (e.g. an
            ``np.load`` ``NpzFile``). Only names already present in *weights*
            are used.
        match_dtype: When true (default), each replacement is cast to the
            dtype of the tensor it replaces. numpy has no bfloat16, so patches
            exported through numpy would otherwise silently change a bf16
            checkpoint's dtype.

    Returns:
        The names of the tensors that were replaced, in *weights* order.
    """
    patched = []
    # Iterate over a snapshot so reassigning values cannot disturb iteration.
    for name, original in list(weights.items()):
        if name not in patch_weights:
            continue
        print(f"patching {name}")
        replacement = torch.from_numpy(np.asarray(patch_weights[name]))
        if match_dtype:
            replacement = replacement.to(original.dtype)
        # safetensors refuses to serialize non-contiguous tensors, which
        # torch.from_numpy produces for Fortran-ordered arrays.
        weights[name] = replacement.contiguous()
        patched.append(name)
    return patched


def main(patch_path=PATCH_FILE, pattern=SHARD_PATTERN, prefix=OUTPUT_PREFIX):
    """Patch every shard matching *pattern* and write ``prefix + name`` copies.

    Prints a warning listing any patch tensors that matched no shard.
    """
    # Imported here so `patch_tensors` stays importable without safetensors.
    from safetensors.torch import load_file, save_file

    # NpzFile keeps the archive open; the context manager closes it.
    with np.load(patch_path) as patch_weights:
        unused = set(patch_weights.files)
        for file in sorted(glob(pattern)):
            print(f"{file=}")
            weights = load_file(file)
            unused.difference_update(patch_tensors(weights, patch_weights))
            source = Path(file)
            save_file(
                weights,
                str(source.with_name(prefix + source.name)),
                metadata={"format": "pt"},
            )
    if unused:
        print(f"warning: patch tensors not found in any shard: {sorted(unused)}")


if __name__ == "__main__":
    main()