MartialTerran committed
Commit 6b7bcb2 · verified · 1 Parent(s): 2d8700a

Create Print_and_save_list_of_contents_within_model.safetensors.py

Print_and_save_list_of_contents_within_model.safetensors.py ADDED
# Run from the command line:
#   python Print_and_save_list_of_contents_within_model.safetensors.py model.safetensors

# The script creates a text file named safetensors_contents.txt (or the name you specify)
# listing every tensor stored in your model.safetensors file: its name, shape (dimensionality),
# and data type (e.g., float32, int64), plus a preview of the data when an entry is an object
# type. This gives insight into the model's structure and parameters. A confirmation message
# reports where the file is saved. You can then search the text file for lm_head.weight: if
# lm_head.weight appears in safetensors_contents.txt but your original script cannot find it,
# that strongly suggests something is wrong with your original script's load_weights function.

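# A minimal sketch of a quicker check, assuming the checkpoint sits in the current directory
# and is named model.safetensors: load the tensors with the same safetensors helper this
# script uses and test for the key directly, without writing the full listing to disk.
#
#   from safetensors.numpy import load_file
#   tensors = load_file("model.safetensors")  # path is an assumption; adjust as needed
#   print("lm_head.weight present:", "lm_head.weight" in tensors)
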
import numpy as np

print(f"NumPy version = {np.__version__}")
# Probe for a bfloat16 dtype. Standard NumPy builds do not include one, so guard the
# lookup instead of letting an AttributeError stop the script.
if hasattr(np, "bfloat16"):
    print(np.bfloat16, "- a bfloat16 dtype is available")
else:
    print("numpy has no attribute 'bfloat16' -- see the notes on extension packages below")

# The error "module 'numpy' has no attribute 'bfloat16'" means the installed NumPy does not
# expose a bfloat16 data type. bfloat16 (Brain Floating Point) is a relatively new format that
# standard NumPy builds do not ship; it is provided by extension packages (see below). Running
# an old NumPy can also cause other compatibility problems, so upgrading is still worthwhile.
# Converting away from bfloat16 is a last resort if you cannot update numpy.

# 1. Update NumPy: the most straightforward first step is to upgrade to a recent NumPy
#    (1.20.0 or later), which avoids many other compatibility problems, even though bfloat16
#    itself still comes from an extension package:
#    pip install --upgrade numpy
#    or python3 -m pip install --upgrade numpy to target a specific Python interpreter,
#    or py -m pip install --upgrade numpy on Windows.

# 2. Uninstall and reinstall: completely remove NumPy and then reinstall it:
#    pip uninstall numpy
#    pip install numpy

# Verify: after the upgrade completes, open a new command prompt or terminal (or restart your
# current one if you use it for Python) and run the Python interpreter:

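# For example, a minimal check in the interpreter (the exact version string depends on what
# pip installed):
#
#   >>> import numpy as np
#   >>> np.__version__
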
# 3. Force reinstallation: sometimes a simple upgrade does not replace files correctly.
#    Try a forced reinstallation:
#    pip install --force-reinstall numpy

# The error message "AttributeError: module 'numpy' has no attribute 'bfloat16'. Did you mean:
# 'float16'?" appears because standard NumPy installations do not include bfloat16 as a
# built-in data type. NumPy does support float16, but bfloat16 (Brain Floating Point 16) is a
# different 16-bit floating-point format commonly used in machine learning for training and
# inference. It is designed for better performance on specific hardware, usually at the cost
# of some precision compared to float32.

# To use bfloat16 with NumPy, install the bfloat16 package:
#   pip install bfloat16

# This package extends NumPy, adding the bfloat16 dtype and enabling its use in most standard
# NumPy operations.[1][2] After installation, you should be able to use bfloat16 like other
# NumPy dtypes. For example (guarded so the script still runs when the package is absent):

try:
    import bfloat16

    # Create a NumPy array with the bfloat16 dtype provided by the extension package
    x = np.array([1.2, 3.4, 5.6], dtype=bfloat16.bfloat16)

    # Perform operations with the bfloat16 array
    y = x * 2
    print(f"bfloat16: {y}")

    # Convert to other dtypes
    z = y.astype(np.float32)
    print(f"z = y.astype(np.float32) {z}")
except ImportError:
    print("The 'bfloat16' package is not installed; skipping the bfloat16 demo.")

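# An alternative not covered in the notes above: the ml_dtypes package (used by the JAX and
# TensorFlow ecosystems) also registers a bfloat16 dtype that works with NumPy arrays.
# A minimal sketch, assuming `pip install ml_dtypes` has been run:
#
#   import ml_dtypes
#   a = np.array([1.2, 3.4, 5.6], dtype=ml_dtypes.bfloat16)
#   print(a, a.dtype)
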
import os
import sys
from safetensors.numpy import load_file

def print_and_save_safetensors_contents(weights_path: str, output_file: str = "safetensors_contents.txt"):
    """
    Prints and saves the contents of a safetensors file to a text file.

    Args:
        weights_path: The path to the safetensors file.
        output_file: The name of the output text file.
    """

    # File existence and absolute path verification
    weights_path = os.path.abspath(weights_path)  # Get absolute path
    if not os.path.exists(weights_path):
        print(f"Error: Weights file not found at: {weights_path}")
        sys.exit(1)

    try:
        # load_file from safetensors.numpy returns a dict of NumPy arrays.
        # Note: the NumPy backend may not be able to represent bfloat16 tensors; if loading
        # fails on a bfloat16 checkpoint, safetensors.torch.load_file is an alternative.
        tensors = load_file(weights_path)

        with open(output_file, "w", encoding="utf-8") as f:
            for key, tensor in tensors.items():
                f.write(f"Tensor: {key}\n")
                f.write(f"  Shape: {tensor.shape}\n")
                f.write(f"  Dtype: {tensor.dtype}\n")

                # Describe data characteristics in more detail
                if tensor.dtype == "object":
                    try:
                        # Attempt to show the item as text, handling bytes and other objects
                        raw = tensor.item()
                        decoded_data = raw.decode("utf-8", errors="replace") if isinstance(raw, bytes) else str(raw)
                        f.write(f"  Decoded data: {decoded_data}\n")
                    except Exception:
                        f.write("  Data: Cannot be displayed (likely binary or complex object)\n")
                elif np.issubdtype(tensor.dtype, np.integer):
                    f.write("  Data Type: Integer\n")
                elif np.issubdtype(tensor.dtype, np.floating):
                    f.write("  Data Type: Float\n")
                elif tensor.dtype == np.bool_:
                    f.write("  Data Type: Boolean\n")
                else:
                    f.write("  Data: Cannot determine characteristics.\n")

                # if key == "lm_head.weight":
                #     f.write("lm_head.weight exists")

                f.write("\n")  # Add a separator between tensors

        print(f"Safetensors contents saved to: {output_file}")

    except Exception as e:
        print(f"An error occurred: {e}")
        sys.exit(1)


if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python Print_and_save_list_of_contents_within_model.safetensors.py <path_to_model.safetensors>")
        sys.exit(1)

    weights_file = sys.argv[1]
    print_and_save_safetensors_contents(weights_file)
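
# For reference, the listing written to safetensors_contents.txt looks roughly like the block
# below. The tensor name is common in language models, but the shape and dtype shown here are
# illustrative only and depend entirely on your checkpoint:
#
#   Tensor: lm_head.weight
#     Shape: (50257, 768)
#     Dtype: float32
#     Data Type: Float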