bornjre committed on
Commit 422090a
1 Parent(s): 2cffbf2

delete file

Files changed (1)
  1. btlm_3b_convert.py +0 -242
btlm_3b_convert.py DELETED
@@ -1,242 +0,0 @@
- # -*- coding: utf-8 -*-
- """btlm_3b_convert.ipynb
-
- Automatically generated by Colaboratory.
-
- Original file is located at
- https://colab.research.google.com/drive/1kuWDvfRlc0BplrityIIwpdf404mljhRK
- """
-
- !pip install -U transformers
- !pip install -U accelerate    # or: git+https://github.com/huggingface/accelerate.git
- !pip install -U bitsandbytes  # or: git+https://github.com/timdettmers/bitsandbytes.git
-
- from transformers import AutoTokenizer, AutoModelForCausalLM
-
- # Load the tokenizer and model
- tokenizer = AutoTokenizer.from_pretrained("cerebras/btlm-3b-8k-base")
- model = AutoModelForCausalLM.from_pretrained(
-     "cerebras/btlm-3b-8k-base",
-     trust_remote_code=True,
-     torch_dtype="auto",
-     load_in_8bit=True,
-     offload_folder="offload",
- )
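- # Loading with load_in_8bit=True routes the linear layers through
- # bitsandbytes; that is why the state dict dumped below also contains
- # quantization scale constants (".SCB") and ".weight_format" entries,
- # which the conversion loop has to special-case.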
-
- # Set the prompt for generating text
- prompt = "Albert Einstein was known for "
-
- # Tokenize the prompt and convert to PyTorch tensors
- inputs = tokenizer(prompt, return_tensors="pt")
-
- # Generate text using the model
- outputs = model.generate(
-     **inputs,
-     num_beams=5,
-     max_new_tokens=50,
-     early_stopping=True,
-     no_repeat_ngram_size=2,
- )
-
- # Convert the generated token IDs back to text
- generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
-
- # Print the generated text
- print(generated_text[0])
-
- print(model)
-
- # Dump the state dict for inspection (prints full tensor values;
- # expect a lot of output)
- list_vars = model.state_dict()
- for name in list_vars.keys():
-     print(name, "=>", list_vars[name])
-
- import struct
- import json
-
- import torch
- from transformers import AutoConfig
-
- config = AutoConfig.from_pretrained("cerebras/btlm-3b-8k-base", trust_remote_code=True)
- hparams = config.to_dict()
- fname_out = "btlm-3b.ggml.bin"
-
- print(json.dumps(hparams, indent=4, sort_keys=True))
-
- import re
- import numpy as np
-
- fout = open(fname_out, "wb")
-
- fout.write(struct.pack("i", 0x67676D6C))  # magic: "ggml" in hex
- fout.write(struct.pack("i", hparams["vocab_size"]))
- fout.write(struct.pack("i", hparams["n_positions"]))
- fout.write(struct.pack("i", hparams["n_embd"]))
- fout.write(struct.pack("i", hparams["n_head"]))
- fout.write(struct.pack("i", hparams["n_layer"]))
- fout.write(struct.pack("i", hparams["n_inner"]))
- fout.write(struct.pack("i", 1))  # use_f16 flag
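- # The header is eight little-endian int32 values in total:
- #   magic, vocab_size, n_positions, n_embd, n_head, n_layer, n_inner, use_f16
- # so a loader can recover it with, e.g.:
- #   magic, vocab, n_pos, n_embd, n_head, n_layer, n_inner, f16 = struct.unpack("8i", f.read(32))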
-
- # Vocab: for each token id, an int32 byte length followed by the raw
- # UTF-8 bytes of the decoded token.
- for i in range(hparams["vocab_size"]):
-     text = tokenizer.decode([i]).encode('utf-8')
-     fout.write(struct.pack("i", len(text)))
-     fout.write(text)
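- # A matching reader would mirror this exactly (hypothetical sketch):
- #   for _ in range(vocab_size):
- #       (length,) = struct.unpack("i", f.read(4))
- #       token = f.read(length).decode("utf-8", errors="replace")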
-
- for name in list_vars.keys():
-     # bitsandbytes bookkeeping entries carry no tensor data; skip them
-     if name.endswith(".weight_format"):
-         print("FOUND " + name)
-         continue
-
-     print("Processing variable: " + name)
-     data = list_vars[name].squeeze().cpu().to(torch.float16).numpy()
-     print("  with shape: ", data.shape)
-
-     # rename tensors to ggml-style paths to keep compatibility
-     # (weights -> /w, biases -> /b, LayerNorm gains -> /g)
-     if name == "transformer.ln_f.weight":
-         name = "model/ln_f/g"
-     elif name == "transformer.ln_f.bias":
-         name = "model/ln_f/b"
-     elif name == "transformer.wte.weight":
-         name = "model/wte"
-     elif name == "transformer.wpe.weight":
-         name = "model/wpe"
-     elif name == "lm_head.weight":
-         name = "model/lm_head"
-     elif name == "transformer.relative_pe.slopes":
-         name = "model/relative_pe/slopes"
-     elif re.match(r"transformer\.h\.\d+\.ln_1\.weight", name):
-         i = re.findall(r"\d+", name)[0]
-         name = f"model/h{i}/ln_1/g"
-     elif re.match(r"transformer\.h\.\d+\.ln_1\.bias", name):
-         i = re.findall(r"\d+", name)[0]
-         name = f"model/h{i}/ln_1/b"
-     elif re.match(r"transformer\.h\.\d+\.attn\.c_attn\.weight", name):
-         i = re.findall(r"\d+", name)[0]
-         name = f"model/h{i}/attn/c_attn/w"
-     elif re.match(r"transformer\.h\.\d+\.attn\.c_attn\.bias", name):
-         i = re.findall(r"\d+", name)[0]
-         name = f"model/h{i}/attn/c_attn/b"
-     elif re.match(r"transformer\.h\.\d+\.attn\.c_proj\.weight", name):
-         i = re.findall(r"\d+", name)[0]
-         name = f"model/h{i}/attn/c_proj/w"
-     elif re.match(r"transformer\.h\.\d+\.attn\.c_proj\.bias", name):
-         i = re.findall(r"\d+", name)[0]
-         name = f"model/h{i}/attn/c_proj/b"
-     elif re.match(r"transformer\.h\.\d+\.ln_2\.weight", name):
-         i = re.findall(r"\d+", name)[0]
-         name = f"model/h{i}/ln_2/g"
-     elif re.match(r"transformer\.h\.\d+\.ln_2\.bias", name):
-         i = re.findall(r"\d+", name)[0]
-         name = f"model/h{i}/ln_2/b"
-     elif re.match(r"transformer\.h\.\d+\.mlp\.c_fc\.weight", name):
-         i = re.findall(r"\d+", name)[0]
-         name = f"model/h{i}/mlp/c_fc/w"
-     elif re.match(r"transformer\.h\.\d+\.mlp\.c_fc\.bias", name):
-         i = re.findall(r"\d+", name)[0]
-         name = f"model/h{i}/mlp/c_fc/b"
-     elif re.match(r"transformer\.h\.\d+\.mlp\.c_proj\.weight", name):
-         i = re.findall(r"\d+", name)[0]
-         name = f"model/h{i}/mlp/c_proj/w"
-     elif re.match(r"transformer\.h\.\d+\.mlp\.c_proj\.bias", name):
-         i = re.findall(r"\d+", name)[0]
-         name = f"model/h{i}/mlp/c_proj/b"
-     # NEW: bitsandbytes 8-bit scale tensors (SCB) and BTLM's second
-     # MLP projection (c_fc2)
-     elif re.match(r"transformer\.h\.\d+\.attn\.c_proj\.SCB", name):
-         i = re.findall(r"\d+", name)[0]
-         name = f"model/h{i}/attn/c_proj/scb"
-     elif re.match(r"transformer\.h\.\d+\.attn\.c_attn\.SCB", name):
-         i = re.findall(r"\d+", name)[0]
-         name = f"model/h{i}/attn/c_attn/scb"
-     elif re.match(r"transformer\.h\.\d+\.mlp\.c_fc\.SCB", name):
-         i = re.findall(r"\d+", name)[0]
-         name = f"model/h{i}/mlp/c_fc/scb"
-     elif re.match(r"transformer\.h\.\d+\.mlp\.c_fc2\.weight", name):
-         i = re.findall(r"\d+", name)[0]
-         name = f"model/h{i}/mlp/c_fc2/weight"
-     elif re.match(r"transformer\.h\.\d+\.mlp\.c_fc2\.bias", name):
-         i = re.findall(r"\d+", name)[0]
-         name = f"model/h{i}/mlp/c_fc2/bias"
-     elif re.match(r"transformer\.h\.\d+\.mlp\.c_fc2\.SCB", name):
-         i = re.findall(r"\d+", name)[0]
-         name = f"model/h{i}/mlp/c_fc2/scb"
-     elif re.match(r"transformer\.h\.\d+\.mlp\.c_proj\.SCB", name):
-         i = re.findall(r"\d+", name)[0]
-         name = f"model/h{i}/mlp/c_proj/scb"
-     else:
-         print("Unrecognized variable name: " + name)
-
-     # we don't need these
-     if name.endswith("attn.masked_bias") or name.endswith(".attn.bias"):
-         print("  Skipping variable: " + name)
-         continue
-
-     n_dims = len(data.shape)
-
-     # ftype == 0 -> float32, ftype == 1 -> float16
-     # keep 2-D weight matrices in f16; everything else in f32
-     if (name == "model/wte" or name == "model/lm_head"
-             or name[-2:] == "/g" or name[-2:] == "/w") and n_dims == 2:
-         print("  Converting to float16")
-         data = data.astype(np.float16)
-         ftype = 1
-     else:
-         print("  Converting to float32")
-         data = data.astype(np.float32)
-         ftype = 0
-
-     # for efficiency - transpose the projection matrices
-     #   "model/h.*/attn/c_attn/w"
-     #   "model/h.*/attn/c_proj/w"
-     #   "model/h.*/mlp/c_fc/w"
-     #   "model/h.*/mlp/c_proj/w"
-     if name.endswith("/attn/c_attn/w") or \
-        name.endswith("/attn/c_proj/w") or \
-        name.endswith("/mlp/c_fc/w") or \
-        name.endswith("/mlp/c_proj/w"):
-         print("  Transposing")
-         data = data.transpose()
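-     # HF GPT-2-style Conv1D stores these weights as (n_in, n_out) while
-     # ggml's matmul expects (n_out, n_in); assuming BTLM keeps the GPT-2
-     # Conv1D convention, the transpose fixes the layout.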
-
-     # per-tensor header
-     name_bytes = name.encode('utf-8')
-     fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype))
-     for i in range(n_dims):
-         # ggml expects the dimensions in reverse order
-         fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
-     fout.write(name_bytes)
-
-     # raw tensor data
-     data.tofile(fout)
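-     # Each tensor record on disk is therefore:
-     #   int32 n_dims | int32 len(name) | int32 ftype |
-     #   n_dims x int32 dims (reversed) | name bytes | raw tensor data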
-
- fout.close()
-
- print("Done. Output file: " + fname_out)
- print("")
-
-
- from huggingface_hub import login, HfApi
-
- login()
-
- api = HfApi()
-
- # NOTE: the converter above writes btlm-3b.ggml.bin to the current
- # directory; move it into /content/btlm-3b-ggml before uploading.
- api.upload_folder(
-     folder_path="/content/btlm-3b-ggml",
-     repo_id="bornjre/btlm-3b-ggml",
-     repo_type="model",
- )