Panchovix Thireus commited on
Commit
7a60a8c
1 Parent(s): eac3f1d

Update bin2safetensors/convert.py (#1)

Browse files

- Update bin2safetensors/convert.py (2d879e05c129c0640abefd5718d17ae5ab5b5b4f)


Co-authored-by: None <Thireus@users.noreply.huggingface.co>

Files changed (1) hide show
  1. bin2safetensors/convert.py +56 -10
bin2safetensors/convert.py CHANGED
@@ -2,6 +2,7 @@ import argparse
2
  import json
3
  import os
4
  import shutil
 
5
  from collections import defaultdict
6
  from inspect import signature
7
  from tempfile import TemporaryDirectory
@@ -311,14 +312,59 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitIn
311
  return new_pr, errors
312
 
313
 
314
- if __name__ == "__main__":
315
- DESCRIPTION = """
316
- Simple utility tool to convert automatically some weights on the hub to `safetensors` format.
317
- It is PyTorch exclusive for now.
318
- It works by downloading the weights (PT), converting them locally, and uploading them back
319
- as a PR on the hub.
320
- """
321
- for i in range(1, 16): # Range starts at 1 and ends at 15
322
- input_filename = f"jondurbin_airoboros-l2-70b-gpt4-1.4.1/pytorch_model-{i:05d}-of-00015.bin"
323
- output_filename = f"pytorch_model-{i:05d}-of-00015.safetensors"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
  convert_file(input_filename, output_filename)
 
 
 
 
 
 
 
 
 
 
2
  import json
3
  import os
4
  import shutil
5
+ import re
6
  from collections import defaultdict
7
  from inspect import signature
8
  from tempfile import TemporaryDirectory
 
312
  return new_pr, errors
313
 
314
 
315
+ def main(input_directory, output_directory):
316
+ # Get a list of all files in the input directory
317
+ files = os.listdir(input_directory)
318
+
319
+ # Filter the list to get only the relevant files
320
+ model_files = [file for file in files if re.match(r'pytorch_model-\d{5}-of-\d{5}\.bin', file)]
321
+
322
+ # Determine the range for the loop based on the number of model files
323
+ num_models = len(model_files)
324
+
325
+ if num_models == 0:
326
+ print("No model files found in the input directory.")
327
+ return
328
+
329
+ # Extract yyyyy from the first model filename
330
+ match = re.search(r'pytorch_model-\d{5}-of-(\d{5})\.bin', model_files[0])
331
+ if match:
332
+ yyyyy = int(match.group(1))
333
+ else:
334
+ print("Unable to determine the number of shards from the filename.")
335
+ return
336
+
337
+ if num_models != yyyyy:
338
+ print("Error: Number of shards mismatch.")
339
+ return
340
+
341
+ # Copy *.json files (except pytorch_model.bin.index.json) from input to output directory
342
+ for file in files:
343
+ if file.endswith('.json') and not file == 'pytorch_model.bin.index.json':
344
+ src = os.path.join(input_directory, file)
345
+ dest = os.path.join(output_directory, file)
346
+ shutil.copy2(src, dest)
347
+ print(f"Copied {src} to {dest}")
348
+
349
+ # Copy *.model files from input to output directory
350
+ for file in files:
351
+ if file.endswith('.model'):
352
+ src = os.path.join(input_directory, file)
353
+ dest = os.path.join(output_directory, file)
354
+ shutil.copy2(src, dest)
355
+ print(f"Copied {src} to {dest}")
356
+
357
+ # Convert and rename model files
358
+ for i in range(1, num_models + 1):
359
+ input_filename = os.path.join(input_directory, f"pytorch_model-{i:05d}-of-{yyyyy:05d}.bin")
360
+ output_filename = os.path.join(output_directory, f"model-{i:05d}-of-{yyyyy:05d}.safetensors")
361
  convert_file(input_filename, output_filename)
362
+ print(f"Converted {input_filename} to {output_filename}")
363
+
364
+ if __name__ == "__main__":
365
+ parser = argparse.ArgumentParser(description="Convert pytorch_model model to safetensor and copy JSON and .model files.")
366
+ parser.add_argument("input_directory", help="Path to the input directory containing pytorch_model files")
367
+ parser.add_argument("output_directory", help="Path to the output directory for converted safetensor files")
368
+ args = parser.parse_args()
369
+
370
+ main(args.input_directory, args.output_directory)