File size: 484 Bytes
7730b56
2f1b075
7730b56
2f1b075
7730b56
2f1b075
7730b56
 
 
2f1b075
 
7730b56
 
2f1b075
0e6349c
7730b56
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from glob import glob
from pathlib import Path

import numpy as np
import torch
from safetensors.torch import save_file, load_file

# Defaults reproduce the original hard-coded behaviour of this script.
PATCH_PATH = "589-20240113-071533.npz"
SHARD_PATTERN = "model*.safetensors"
OUTPUT_PREFIX = "patched_"


def apply_patch(weights, patch):
    """Overwrite tensors in *weights* in place with same-named arrays from *patch*.

    Args:
        weights: mapping of tensor name -> ``torch.Tensor`` (e.g. the dict
            returned by ``safetensors.torch.load_file``). Mutated in place.
        patch: mapping of tensor name -> ``numpy.ndarray`` (e.g. an ``NpzFile``
            or a plain dict).

    Returns:
        List of the keys that were replaced, in ``weights`` iteration order.

    Raises:
        ValueError: if a patch array's shape differs from the tensor it replaces.
    """
    patched = []
    # Only existing keys are reassigned, so mutating while iterating is safe.
    for key in weights:
        if key not in patch:
            continue
        print(f"patching {key}")
        original = weights[key]
        replacement = torch.from_numpy(patch[key])
        if tuple(replacement.shape) != tuple(original.shape):
            raise ValueError(
                f"shape mismatch for {key}: patch {tuple(replacement.shape)} "
                f"vs model {tuple(original.shape)}"
            )
        # Keep the shard's dtype (numpy has no bfloat16, so patches are often
        # float32). safetensors also rejects non-contiguous tensors on save.
        weights[key] = replacement.to(dtype=original.dtype).contiguous()
        patched.append(key)
    return patched


def main(patch_path=PATCH_PATH, pattern=SHARD_PATTERN, prefix=OUTPUT_PREFIX):
    """Patch every shard matching *pattern* and write ``<prefix><name>`` beside it.

    Tensors whose names appear in the ``.npz`` archive at *patch_path* are
    replaced. Other tensors are copied unchanged. Prints a warning for patch
    entries that matched no shard, and when no shard matched *pattern*.
    """
    applied = set()
    # NpzFile keeps the zip archive open; the context manager closes it.
    with np.load(patch_path) as patch_weights:
        files = sorted(glob(pattern))  # glob order is arbitrary
        if not files:
            print(f"warning: no files match {pattern!r}")
        for file in files:
            print(f"{file=}")
            weights = load_file(file)
            applied.update(apply_patch(weights, patch_weights))
            src = Path(file)
            save_file(
                weights,
                str(src.with_name(prefix + src.name)),
                metadata={"format": "pt"},
            )
        unused = sorted(set(patch_weights.files) - applied)
    if unused:
        print(f"warning: {len(unused)} patch tensor(s) matched no shard: {unused}")


if __name__ == "__main__":
    main()