Update bin2safetensors/convert.py
#1
by
Thireus
- opened
- bin2safetensors/convert.py +56 -10
bin2safetensors/convert.py
CHANGED
@@ -2,6 +2,7 @@ import argparse
|
|
2 |
import json
|
3 |
import os
|
4 |
import shutil
|
|
|
5 |
from collections import defaultdict
|
6 |
from inspect import signature
|
7 |
from tempfile import TemporaryDirectory
|
@@ -311,14 +312,59 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitIn
|
|
311 |
return new_pr, errors
|
312 |
|
313 |
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
324 |
convert_file(input_filename, output_filename)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import json
|
3 |
import os
|
4 |
import shutil
|
5 |
+
import re
|
6 |
from collections import defaultdict
|
7 |
from inspect import signature
|
8 |
from tempfile import TemporaryDirectory
|
|
|
312 |
return new_pr, errors
|
313 |
|
314 |
|
315 |
+
def main(input_directory, output_directory):
    """Convert sharded ``pytorch_model-XXXXX-of-YYYYY.bin`` files to safetensors.

    Scans ``input_directory`` for shard files, validates that the shard set is
    complete, copies the auxiliary ``*.json`` and ``*.model`` files to
    ``output_directory``, and converts every shard to
    ``model-XXXXX-of-YYYYY.safetensors`` via ``convert_file``.

    Validation problems are reported with ``print`` and the function returns
    early without touching ``output_directory``.

    Args:
        input_directory: Directory containing the ``pytorch_model-*.bin`` shards.
        output_directory: Directory that receives the copied and converted
            files; it is created if it does not exist yet.

    Returns:
        None.
    """
    shard_pattern = re.compile(r'pytorch_model-(\d{5})-of-(\d{5})\.bin')

    # Get a list of all files in the input directory
    files = os.listdir(input_directory)

    # fullmatch (not match) so that e.g. "....bin.bak" is not counted as a shard.
    shards = [
        (int(found.group(1)), int(found.group(2)))
        for found in map(shard_pattern.fullmatch, files)
        if found
    ]

    if not shards:
        print("No model files found in the input directory.")
        return

    # Every shard must agree on the total, and indices 1..total must all be
    # present; comparing only the file count against the first filename would
    # let incomplete or mixed shard sets through.
    totals = {total for _, total in shards}
    if len(totals) != 1:
        print("Error: Number of shards mismatch.")
        return

    yyyyy = totals.pop()
    if sorted(index for index, _ in shards) != list(range(1, yyyyy + 1)):
        print("Error: Number of shards mismatch.")
        return

    # Only create the destination once the input has been validated.
    os.makedirs(output_directory, exist_ok=True)

    # Copy *.json files (except pytorch_model.bin.index.json) first, then
    # *.model files, from the input to the output directory.
    # NOTE(review): the index file is presumably skipped because it refers to
    # the .bin shard names - confirm.
    to_copy = [
        file for file in files
        if file.endswith('.json') and file != 'pytorch_model.bin.index.json'
    ]
    to_copy += [file for file in files if file.endswith('.model')]

    for file in to_copy:
        src = os.path.join(input_directory, file)
        dest = os.path.join(output_directory, file)
        shutil.copy2(src, dest)
        print(f"Copied {src} to {dest}")

    # Convert and rename model files
    for i in range(1, yyyyy + 1):
        input_filename = os.path.join(input_directory, f"pytorch_model-{i:05d}-of-{yyyyy:05d}.bin")
        output_filename = os.path.join(output_directory, f"model-{i:05d}-of-{yyyyy:05d}.safetensors")
        convert_file(input_filename, output_filename)
        print(f"Converted {input_filename} to {output_filename}")
|
363 |
+
|
364 |
+
if __name__ == "__main__":
    # Command-line entry point: two positional paths are parsed and handed
    # straight to main().
    cli = argparse.ArgumentParser(
        description="Convert pytorch_model model to safetensor and copy JSON and .model files.",
    )
    cli.add_argument(
        "input_directory",
        help="Path to the input directory containing pytorch_model files",
    )
    cli.add_argument(
        "output_directory",
        help="Path to the output directory for converted safetensor files",
    )
    options = cli.parse_args()

    main(options.input_directory, options.output_directory)
|