Marcel Bischoff commited on
Commit
2f1b075
1 Parent(s): b3c4e71

new checkpoint, safetensor torch

Browse files
Files changed (2) hide show
  1. 589-20240113-071533.npz +3 -0
  2. patch_weights.py +6 -8
589-20240113-071533.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d0bcec2b0181b164a34572c12fecf766321cfdd63b96b551e9345532ebae3da
3
+ size 663426
patch_weights.py CHANGED
@@ -1,18 +1,16 @@
1
  import numpy as np
2
- import mlx.core as mx
3
  from glob import glob
4
- from safetensors.numpy import save_file
5
 
6
- patch_weights = mx.load("49-20240112-184735.npz")
7
 
8
  for file in glob("model*.safetensors"):
9
  print(f"{file=}")
10
- weights = mx.load(file)
11
- for k, v in weights.items():
12
  if k in patch_weights:
13
  print(f"patching {k}")
14
- weights[k] = np.array(patch_weights[k], copy=False)
15
- else:
16
- weights[k] = np.array(v, copy=False)
17
  save_file(weights, "patched_" + file)
18
 
 
1
  import numpy as np
2
+ import torch
3
  from glob import glob
4
+ from safetensors.torch import save_file, load_file
5
 
6
+ patch_weights = np.load("589-20240113-071533.npz")
7
 
8
  for file in glob("model*.safetensors"):
9
  print(f"{file=}")
10
+ weights = load_file(file)
11
+ for k, tensor in weights.items():
12
  if k in patch_weights:
13
  print(f"patching {k}")
14
+ weights[k] = torch.from_numpy(patch_weights[k])
 
 
15
  save_file(weights, "patched_" + file)
16