Mentors4EDU commited on
Commit
e9af1a0
·
verified ·
1 Parent(s): b73e538

Add Safetensor Conversion Script + Safetensor File

Browse files
checkpoints/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9bbcbf73561f6bc5d0a17ea6a2081feed2d1304e87602d8c502d9a5c4bd85576
3
+ size 16
convert_to_safetensors.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from safetensors.torch import save_file
4
+ import glob
5
+
6
+ def convert_model_to_safetensors(model_path, output_path):
7
+ print(f"Looking for PyTorch model files in {model_path}")
8
+
9
+ # Create the output directory if it doesn't exist
10
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
11
+
12
+ # Load the PyTorch model file
13
+ model_files = glob.glob(os.path.join(model_path, "*.pt")) + \
14
+ glob.glob(os.path.join(model_path, "*.pth")) + \
15
+ glob.glob(os.path.join(model_path, "pytorch_model.bin"))
16
+
17
+ if not model_files:
18
+ raise FileNotFoundError(f"No PyTorch model files found in {model_path}")
19
+
20
+ print(f"Found model file(s): {model_files}")
21
+ model_file = model_files[0] # Use the first found model file
22
+
23
+ # Load the state dict
24
+ print(f"Loading model from {model_file}")
25
+ checkpoint = torch.load(model_file, map_location='cpu')
26
+
27
+ # Extract only the model weights, removing metadata
28
+ model_state_dict = {}
29
+ if isinstance(checkpoint, dict):
30
+ if 'state_dict' in checkpoint:
31
+ checkpoint = checkpoint['state_dict']
32
+ # Only keep tensor values
33
+ for key, value in checkpoint.items():
34
+ if isinstance(value, torch.Tensor):
35
+ model_state_dict[key] = value
36
+
37
+ # Save as safetensors
38
+ print(f"Converting to safetensors and saving to {output_path}")
39
+ save_file(model_state_dict, output_path)
40
+ print("Conversion completed successfully!")
41
+
42
+ if __name__ == "__main__":
43
+ # Update these paths according to your model location
44
+ model_path = "./checkpoints" # Path to your checkpoints directory
45
+ output_path = "./checkpoints/model.safetensors"
46
+
47
+ convert_model_to_safetensors(model_path, output_path)