p1atdev committed on
Commit
ba679da
1 Parent(s): a349f5c

Upload merge.ipynb

Browse files
Files changed (1) hide show
  1. merge.ipynb +504 -0
merge.ipynb ADDED
@@ -0,0 +1,504 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import torch\n",
10
+ "from transformers import AutoModelForCausalLM, AutoTokenizer"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "metadata": {},
17
+ "outputs": [],
18
+ "source": [
19
+ "from safetensors.torch import save_file, load_file"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": null,
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "CACHE_DIR = \"/huggingface/cache\""
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": null,
34
+ "metadata": {},
35
+ "outputs": [],
36
+ "source": [
37
+ "NEKOMATA_MODEL = \"rinna/nekomata-14b\"\n",
38
+ "QARASU_MODEL = \"lightblue/qarasu-14B-chat-plus-unleashed\"\n",
39
+ "QWEN_14B_MODEL = \"Qwen/Qwen-14B\""
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": null,
45
+ "metadata": {},
46
+ "outputs": [],
47
+ "source": [
48
+ "nekomata = AutoModelForCausalLM.from_pretrained(\n",
49
+ " NEKOMATA_MODEL,\n",
50
+ " cache_dir=CACHE_DIR,\n",
51
+ " torch_dtype=torch.bfloat16,\n",
52
+ " device_map=\"cpu\",\n",
53
+ " offload_folder=\"nekomata\",\n",
54
+ " offload_state_dict=True,\n",
55
+ " trust_remote_code=True,\n",
56
+ ")\n",
57
+ "nekomata.eval()\n",
58
+ "nekomata.hf_device_map"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": null,
64
+ "metadata": {},
65
+ "outputs": [],
66
+ "source": [
67
+ "nekomata_state_dict = nekomata.state_dict().copy()"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": null,
73
+ "metadata": {},
74
+ "outputs": [],
75
+ "source": [
76
+ "for key in nekomata_state_dict.keys():\n",
77
+ " nekomata_value = nekomata_state_dict[key].clone().to(\"cpu\")\n",
78
+ " print(key, nekomata_value.dtype, nekomata_value.shape, nekomata_value)\n",
79
+ " break"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": null,
85
+ "metadata": {},
86
+ "outputs": [],
87
+ "source": [
88
+ "save_file(nekomata_state_dict, \"./nekomata_state.safetensors\")"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "markdown",
93
+ "metadata": {},
94
+ "source": [
95
+ "**Restart Runtime**\n"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": null,
101
+ "metadata": {},
102
+ "outputs": [],
103
+ "source": [
104
+ "qarasu = AutoModelForCausalLM.from_pretrained(\n",
105
+ " QARASU_MODEL,\n",
106
+ " cache_dir=CACHE_DIR,\n",
107
+ " torch_dtype=torch.bfloat16,\n",
108
+ " device_map=\"cpu\",\n",
109
+ " offload_folder=\"qarasu\",\n",
110
+ " offload_state_dict=True,\n",
111
+ " trust_remote_code=True,\n",
112
+ ")\n",
113
+ "qarasu.eval()\n",
114
+ "qarasu.hf_device_map"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "markdown",
119
+ "metadata": {},
120
+ "source": [
121
+ "**Restart Runtime**\n"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": null,
127
+ "metadata": {},
128
+ "outputs": [],
129
+ "source": [
130
+ "qarasu_state_dict = qarasu.state_dict().copy()"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": null,
136
+ "metadata": {},
137
+ "outputs": [],
138
+ "source": [
139
+ "for key in qarasu_state_dict.keys():\n",
140
+ " qarasu_value = qarasu_state_dict[key].clone().to(\"cpu\")\n",
141
+ " print(key, qarasu_value.dtype, qarasu_value.shape, qarasu_value)\n",
142
+ " break"
143
+ ]
144
+ },
145
+ {
146
+ "cell_type": "code",
147
+ "execution_count": null,
148
+ "metadata": {},
149
+ "outputs": [],
150
+ "source": [
151
+ "save_file(qarasu_state_dict, \"./qarasu_state.safetensors\")"
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "markdown",
156
+ "metadata": {},
157
+ "source": [
158
+ "**Restart Runtime**\n"
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": null,
164
+ "metadata": {},
165
+ "outputs": [],
166
+ "source": [
167
+ "qwen14b = AutoModelForCausalLM.from_pretrained(\n",
168
+ " QWEN_14B_MODEL,\n",
169
+ " cache_dir=CACHE_DIR,\n",
170
+ " torch_dtype=torch.bfloat16,\n",
171
+ " device_map=\"cpu\",\n",
172
+ " offload_folder=\"qwen\",\n",
173
+ " offload_state_dict=True,\n",
174
+ " trust_remote_code=True,\n",
175
+ ")\n",
176
+ "qwen14b.eval()\n",
177
+ "qwen14b.hf_device_map"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": null,
183
+ "metadata": {},
184
+ "outputs": [],
185
+ "source": [
186
+ "qwen14b_state_dict = qwen14b.state_dict().copy()"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "execution_count": null,
192
+ "metadata": {},
193
+ "outputs": [],
194
+ "source": [
195
+ "save_file(qwen14b_state_dict, \"./qwen14b_state.safetensors\")"
196
+ ]
197
+ },
198
+ {
199
+ "cell_type": "markdown",
200
+ "metadata": {},
201
+ "source": [
202
+ "**Restart Runtime**\n"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": null,
208
+ "metadata": {},
209
+ "outputs": [],
210
+ "source": [
211
+ "import torch"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": null,
217
+ "metadata": {},
218
+ "outputs": [],
219
+ "source": [
220
+ "from safetensors.torch import save_file, load_file"
221
+ ]
222
+ },
223
+ {
224
+ "cell_type": "code",
225
+ "execution_count": null,
226
+ "metadata": {},
227
+ "outputs": [],
228
+ "source": [
229
+ "nekomata_state_dict = load_file(\"./nekomata_state.safetensors\", device=\"cpu\")"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": null,
235
+ "metadata": {},
236
+ "outputs": [],
237
+ "source": [
238
+ "qarasu_state_dict = load_file(\"./qarasu_state.safetensors\", device=\"cpu\")"
239
+ ]
240
+ },
241
+ {
242
+ "cell_type": "code",
243
+ "execution_count": null,
244
+ "metadata": {},
245
+ "outputs": [],
246
+ "source": [
247
+ "new_state_dict = nekomata_state_dict\n",
248
+ "\n",
249
+ "with torch.no_grad():\n",
250
+ " for key in nekomata_state_dict.keys():\n",
251
+ " print(key)\n",
252
+ "\n",
253
+ " new_state_dict[key] = (\n",
254
+ " new_state_dict[key].to(\"cuda\") + qarasu_state_dict[key].to(\"cuda\")\n",
255
+ " ).to(\"cpu\")\n",
256
+ "\n",
257
+ "new_state_dict"
258
+ ]
259
+ },
260
+ {
261
+ "cell_type": "code",
262
+ "execution_count": null,
263
+ "metadata": {},
264
+ "outputs": [],
265
+ "source": [
266
+ "del nekomata_state_dict, qarasu_state_dict"
267
+ ]
268
+ },
269
+ {
270
+ "cell_type": "code",
271
+ "execution_count": null,
272
+ "metadata": {},
273
+ "outputs": [],
274
+ "source": [
275
+ "torch.cuda.empty_cache()"
276
+ ]
277
+ },
278
+ {
279
+ "cell_type": "code",
280
+ "execution_count": null,
281
+ "metadata": {},
282
+ "outputs": [],
283
+ "source": [
284
+ "save_file(new_state_dict, \"./nekomata+qarasu_state.safetensors\")"
285
+ ]
286
+ },
287
+ {
288
+ "cell_type": "markdown",
289
+ "metadata": {},
290
+ "source": [
291
+ "**Restart Runtime**\n"
292
+ ]
293
+ },
294
+ {
295
+ "cell_type": "code",
296
+ "execution_count": null,
297
+ "metadata": {},
298
+ "outputs": [],
299
+ "source": [
300
+ "import torch\n",
301
+ "from safetensors.torch import load_file, save_file"
302
+ ]
303
+ },
304
+ {
305
+ "cell_type": "code",
306
+ "execution_count": null,
307
+ "metadata": {},
308
+ "outputs": [],
309
+ "source": [
310
+ "nekomata_qarasu_state_dict = load_file(\n",
311
+ " \"./nekomata+qarasu_state.safetensors\", device=\"cpu\"\n",
312
+ ")"
313
+ ]
314
+ },
315
+ {
316
+ "cell_type": "code",
317
+ "execution_count": null,
318
+ "metadata": {},
319
+ "outputs": [],
320
+ "source": [
321
+ "qwen14b_state_dict = load_file(\"./qwen14b_state.safetensors\", device=\"cpu\")"
322
+ ]
323
+ },
324
+ {
325
+ "cell_type": "code",
326
+ "execution_count": null,
327
+ "metadata": {},
328
+ "outputs": [],
329
+ "source": [
330
+ "# キー名が同じことを確認\n",
331
+ "for neko_key, qwen14b_key in zip(\n",
332
+ " nekomata_qarasu_state_dict.keys(), qwen14b_state_dict.keys()\n",
333
+ "):\n",
334
+ " assert neko_key == qwen14b_key"
335
+ ]
336
+ },
337
+ {
338
+ "cell_type": "code",
339
+ "execution_count": null,
340
+ "metadata": {},
341
+ "outputs": [],
342
+ "source": [
343
+ "new_state_dict = nekomata_qarasu_state_dict\n",
344
+ "\n",
345
+ "with torch.no_grad():\n",
346
+ " for key in new_state_dict.keys():\n",
347
+ " print(key)\n",
348
+ "\n",
349
+ " new_state_dict[key] = (\n",
350
+ " new_state_dict[key].to(\"cuda\") - qwen14b_state_dict[key].to(\"cuda\")\n",
351
+ " ).to(\"cpu\")\n",
352
+ "\n",
353
+ "new_state_dict"
354
+ ]
355
+ },
356
+ {
357
+ "cell_type": "code",
358
+ "execution_count": null,
359
+ "metadata": {},
360
+ "outputs": [],
361
+ "source": [
362
+ "save_file(new_state_dict, \"./nekoqarasu_state.safetensors\")"
363
+ ]
364
+ },
365
+ {
366
+ "cell_type": "markdown",
367
+ "metadata": {},
368
+ "source": [
369
+ "**Restart Runtime**\n"
370
+ ]
371
+ },
372
+ {
373
+ "cell_type": "code",
374
+ "execution_count": null,
375
+ "metadata": {},
376
+ "outputs": [],
377
+ "source": [
378
+ "import torch\n",
379
+ "from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM"
380
+ ]
381
+ },
382
+ {
383
+ "cell_type": "code",
384
+ "execution_count": null,
385
+ "metadata": {},
386
+ "outputs": [],
387
+ "source": [
388
+ "CACHE_DIR = \"/huggingface/cache\""
389
+ ]
390
+ },
391
+ {
392
+ "cell_type": "code",
393
+ "execution_count": null,
394
+ "metadata": {},
395
+ "outputs": [],
396
+ "source": [
397
+ "QWEN_14B_CHAT_MODEL = \"Qwen/Qwen-14B-Chat\""
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "execution_count": null,
403
+ "metadata": {},
404
+ "outputs": [],
405
+ "source": [
406
+ "qwen_chat_config = AutoConfig.from_pretrained(\n",
407
+ " QWEN_14B_CHAT_MODEL, trust_remote_code=True, cache_dir=CACHE_DIR\n",
408
+ ")"
409
+ ]
410
+ },
411
+ {
412
+ "cell_type": "code",
413
+ "execution_count": null,
414
+ "metadata": {},
415
+ "outputs": [],
416
+ "source": [
417
+ "nekoqarasu = AutoModelForCausalLM.from_config(\n",
418
+ " qwen_chat_config,\n",
419
+ " trust_remote_code=True,\n",
420
+ ")"
421
+ ]
422
+ },
423
+ {
424
+ "cell_type": "code",
425
+ "execution_count": null,
426
+ "metadata": {},
427
+ "outputs": [],
428
+ "source": [
429
+ "from safetensors.torch import load_file"
430
+ ]
431
+ },
432
+ {
433
+ "cell_type": "code",
434
+ "execution_count": null,
435
+ "metadata": {},
436
+ "outputs": [],
437
+ "source": [
438
+ "state_dict = load_file(\"./nekoqarasu_state.safetensors\", device=\"cpu\")"
439
+ ]
440
+ },
441
+ {
442
+ "cell_type": "code",
443
+ "execution_count": null,
444
+ "metadata": {},
445
+ "outputs": [],
446
+ "source": [
447
+ "nekoqarasu.load_state_dict(\n",
448
+ " state_dict,\n",
449
+ " strict=False,\n",
450
+ ")"
451
+ ]
452
+ },
453
+ {
454
+ "cell_type": "code",
455
+ "execution_count": null,
456
+ "metadata": {},
457
+ "outputs": [],
458
+ "source": [
459
+ "nekoqarasu.push_to_hub(\"nekoqarasu-14b-chat\", private=True)"
460
+ ]
461
+ },
462
+ {
463
+ "cell_type": "code",
464
+ "execution_count": null,
465
+ "metadata": {},
466
+ "outputs": [],
467
+ "source": [
468
+ "tokenizer = AutoTokenizer.from_pretrained(\n",
469
+ " QWEN_14B_CHAT_MODEL, cache_dir=CACHE_DIR, trust_remote_code=True\n",
470
+ ")"
471
+ ]
472
+ },
473
+ {
474
+ "cell_type": "code",
475
+ "execution_count": null,
476
+ "metadata": {},
477
+ "outputs": [],
478
+ "source": [
479
+ "tokenizer.push_to_hub(\"nekoqarasu-14b-chat\", private=True)"
480
+ ]
481
+ }
482
+ ],
483
+ "metadata": {
484
+ "kernelspec": {
485
+ "display_name": "py310",
486
+ "language": "python",
487
+ "name": "python3"
488
+ },
489
+ "language_info": {
490
+ "codemirror_mode": {
491
+ "name": "ipython",
492
+ "version": 3
493
+ },
494
+ "file_extension": ".py",
495
+ "mimetype": "text/x-python",
496
+ "name": "python",
497
+ "nbconvert_exporter": "python",
498
+ "pygments_lexer": "ipython3",
499
+ "version": "3.10.10"
500
+ }
501
+ },
502
+ "nbformat": 4,
503
+ "nbformat_minor": 2
504
+ }