Add Safetensor Conversion Script + Safetensor File
Browse files- checkpoints/model.safetensors +3 -0
- convert_to_safetensors.py +47 -0
checkpoints/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9bbcbf73561f6bc5d0a17ea6a2081feed2d1304e87602d8c502d9a5c4bd85576
|
| 3 |
+
size 16
|
convert_to_safetensors.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from safetensors.torch import save_file
|
| 4 |
+
import glob
|
| 5 |
+
|
| 6 |
+
def convert_model_to_safetensors(model_path, output_path):
|
| 7 |
+
print(f"Looking for PyTorch model files in {model_path}")
|
| 8 |
+
|
| 9 |
+
# Create the output directory if it doesn't exist
|
| 10 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
| 11 |
+
|
| 12 |
+
# Load the PyTorch model file
|
| 13 |
+
model_files = glob.glob(os.path.join(model_path, "*.pt")) + \
|
| 14 |
+
glob.glob(os.path.join(model_path, "*.pth")) + \
|
| 15 |
+
glob.glob(os.path.join(model_path, "pytorch_model.bin"))
|
| 16 |
+
|
| 17 |
+
if not model_files:
|
| 18 |
+
raise FileNotFoundError(f"No PyTorch model files found in {model_path}")
|
| 19 |
+
|
| 20 |
+
print(f"Found model file(s): {model_files}")
|
| 21 |
+
model_file = model_files[0] # Use the first found model file
|
| 22 |
+
|
| 23 |
+
# Load the state dict
|
| 24 |
+
print(f"Loading model from {model_file}")
|
| 25 |
+
checkpoint = torch.load(model_file, map_location='cpu')
|
| 26 |
+
|
| 27 |
+
# Extract only the model weights, removing metadata
|
| 28 |
+
model_state_dict = {}
|
| 29 |
+
if isinstance(checkpoint, dict):
|
| 30 |
+
if 'state_dict' in checkpoint:
|
| 31 |
+
checkpoint = checkpoint['state_dict']
|
| 32 |
+
# Only keep tensor values
|
| 33 |
+
for key, value in checkpoint.items():
|
| 34 |
+
if isinstance(value, torch.Tensor):
|
| 35 |
+
model_state_dict[key] = value
|
| 36 |
+
|
| 37 |
+
# Save as safetensors
|
| 38 |
+
print(f"Converting to safetensors and saving to {output_path}")
|
| 39 |
+
save_file(model_state_dict, output_path)
|
| 40 |
+
print("Conversion completed successfully!")
|
| 41 |
+
|
| 42 |
+
if __name__ == "__main__":
|
| 43 |
+
# Update these paths according to your model location
|
| 44 |
+
model_path = "./checkpoints" # Path to your checkpoints directory
|
| 45 |
+
output_path = "./checkpoints/model.safetensors"
|
| 46 |
+
|
| 47 |
+
convert_model_to_safetensors(model_path, output_path)
|