clecho52 commited on
Commit
7d59799
1 Parent(s): a1d0bac

Upload RWKV_v4_RNN_Pile_Fine_Tuning.ipynb

Browse files
Files changed (1) hide show
  1. RWKV_v4_RNN_Pile_Fine_Tuning.ipynb +299 -0
RWKV_v4_RNN_Pile_Fine_Tuning.ipynb ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "Vx7KFfeieD7z"
7
+ },
8
+ "source": [
9
+ "# RWKV-v4-RNN-Pile Fine-Tuning\n",
10
+ "\n",
11
+ "[RWKV](https://github.com/BlinkDL/RWKV-LM) is an RNN with transformer-level performance\n",
12
+ "\n",
13
+ "\n",
14
+ "This notebook aims to streamline fine-tuning RWKV-v4 models"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "markdown",
19
+ "metadata": {
20
+ "id": "7JFIiAsrfvJy"
21
+ },
22
+ "source": [
23
+ "\n",
24
+ "## Setup"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "metadata": {
31
+ "id": "g_qFjgYmtSfK"
32
+ },
33
+ "outputs": [],
34
+ "source": [
35
+ "#@title Google Drive Options { display-mode: \"form\" }\n",
36
+ "save_models_to_drive = True #@param {type:\"boolean\"}\n",
37
+ "drive_mount = '/content/drive' #@param {type:\"string\"}\n",
38
+ "output_dir = 'rwkv-v4-rnn-pile-tuning' #@param {type:\"string\"}\n",
39
+ "tuned_model_name = 'tuned' #@param {type:\"string\"}\n",
40
+ "\n",
41
+ "import os\n",
42
+ "from google.colab import drive\n",
43
+ "if save_models_to_drive:\n",
44
+ " from google.colab import drive\n",
45
+ " drive.mount(drive_mount)\n",
46
+ " \n",
47
+ "output_path = f\"{drive_mount}/MyDrive/{output_dir}\" if save_models_to_drive else f\"/content/{output_dir}\"\n",
48
+ "os.makedirs(f\"{output_path}/{tuned_model_name}\", exist_ok=True)\n",
49
+ "os.makedirs(f\"{output_path}/base_models/\", exist_ok=True)\n",
50
+ "\n",
51
+ "print(f\"Saving models to {output_path}\")"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "metadata": {
58
+ "id": "eivKJ6FP1_9z",
59
+ "outputId": "a687e3ad-8158-492a-da86-4f4ed8804699",
60
+ "colab": {
61
+ "base_uri": "https://localhost:8080/"
62
+ }
63
+ },
64
+ "outputs": [
65
+ {
66
+ "output_type": "stream",
67
+ "name": "stdout",
68
+ "text": [
69
+ "Fri Sep 2 16:11:37 2022 \n",
70
+ "+-----------------------------------------------------------------------------+\n",
71
+ "| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
72
+ "|-------------------------------+----------------------+----------------------+\n",
73
+ "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
74
+ "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
75
+ "| | | MIG M. |\n",
76
+ "|===============================+======================+======================|\n",
77
+ "| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |\n",
78
+ "| N/A 35C P0 28W / 250W | 0MiB / 16280MiB | 0% Default |\n",
79
+ "| | | N/A |\n",
80
+ "+-------------------------------+----------------------+----------------------+\n",
81
+ " \n",
82
+ "+-----------------------------------------------------------------------------+\n",
83
+ "| Processes: |\n",
84
+ "| GPU GI CI PID Type Process name GPU Memory |\n",
85
+ "| ID ID Usage |\n",
86
+ "|=============================================================================|\n",
87
+ "| No running processes found |\n",
88
+ "+-----------------------------------------------------------------------------+\n"
89
+ ]
90
+ }
91
+ ],
92
+ "source": [
93
+ "!nvidia-smi"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": null,
99
+ "metadata": {
100
+ "id": "R4lt0FTegJw9"
101
+ },
102
+ "outputs": [],
103
+ "source": [
104
+ "!git clone https://github.com/blinkdl/RWKV-LM\n",
105
+ "repo_dir = \"/content/RWKV-LM/RWKV-v4\"\n",
106
+ "%cd $repo_dir"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": null,
112
+ "metadata": {
113
+ "id": "RDavUrBsgKIV"
114
+ },
115
+ "outputs": [],
116
+ "source": [
117
+ "!pip install transformers pytorch-lightning==1.9 deepspeed wandb ninja"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "markdown",
122
+ "metadata": {
123
+ "id": "Wt7y7vR6e6U3"
124
+ },
125
+ "source": [
126
+ "## Load Base Model\n",
127
+ "\n",
128
+ "\n"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": null,
134
+ "metadata": {
135
+ "id": "KIgagN-Se3wi"
136
+ },
137
+ "outputs": [],
138
+ "source": [
139
+ "#@title Base Model Options\n",
140
+ "#@markdown Using any of the listed options will download the checkpoint from huggingface\n",
141
+ "\n",
142
+ "base_model_name = \"RWKV-4-Pile-169M\" #@param [\"RWKV-4-Pile-1B5\", \"RWKV-4-Pile-430M\", \"RWKV-4-Pile-169M\"]\n",
143
+ "base_model_url = f\"https://huggingface.co/BlinkDL/{base_model_name.lower()}\"\n",
144
+ "\n",
145
+ "# This may take a while\n",
146
+ "!git lfs clone $base_model_url\n",
147
+ "\n",
148
+ "from glob import glob\n",
149
+ "base_model_path = glob(f\"{base_model_name.lower()}/{base_model_name}*.pth\")[0]\n",
150
+ "\n",
151
+ "print(f\"Using {base_model_path} as base\")"
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "markdown",
156
+ "metadata": {
157
+ "id": "hCOPnLelfJgP"
158
+ },
159
+ "source": [
160
+ "## Generate Training Data"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "code",
165
+ "execution_count": null,
166
+ "metadata": {
167
+ "id": "wW5OmlXmvaIU",
168
+ "cellView": "form"
169
+ },
170
+ "outputs": [],
171
+ "source": [
172
+ "#@title Training Data Options\n",
173
+ "#@markdown `input_file` should be the path to a single file that contains the text you want to fine-tune with.\n",
174
+ "#@markdown Either upload a file to this notebook instance or reference a file in your Google drive.\n",
175
+ "\n",
176
+ "import numpy as np\n",
177
+ "from transformers import PreTrainedTokenizerFast\n",
178
+ "\n",
179
+ "tokenizer = PreTrainedTokenizerFast(tokenizer_file=f'{repo_dir}/20B_tokenizer.json')\n",
180
+ "\n",
181
+ "input_file = \"/content/drive/MyDrive/training.txt\" #@param {type:\"string\"}\n",
182
+ "output_file = 'train.npy'\n",
183
+ "\n",
184
+ "print(f'Tokenizing {input_file} (VERY slow. please wait)')\n",
185
+ "\n",
186
+ "data_raw = open(input_file, encoding=\"utf-8\").read()\n",
187
+ "print(f'Raw length = {len(data_raw)}')\n",
188
+ "\n",
189
+ "data_code = tokenizer.encode(data_raw)\n",
190
+ "print(f'Tokenized length = {len(data_code)}')\n",
191
+ "\n",
192
+ "out = np.array(data_code, dtype='uint16')\n",
193
+ "np.save(output_file, out, allow_pickle=False)"
194
+ ]
195
+ },
196
+ {
197
+ "cell_type": "markdown",
198
+ "metadata": {
199
+ "id": "I4lz-3maeIwY"
200
+ },
201
+ "source": [
202
+ "## Training"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": null,
208
+ "metadata": {
209
+ "id": "fuCw5_ASwMud"
210
+ },
211
+ "outputs": [],
212
+ "source": [
213
+ "#@title Training Options { display-mode: \"form\" }\n",
214
+ "from shutil import copy\n",
215
+ "import os\n",
216
+ "\n",
217
+ "def training_options():\n",
218
+ " EXPRESS_PILE_MODE = True\n",
219
+ " EXPRESS_PILE_MODEL_NAME = base_model_path.split(\".\")[0]\n",
220
+ " EXPRESS_PILE_MODEL_TYPE = base_model_name\n",
221
+ " n_epoch = 100 #@param {type:\"integer\"}\n",
222
+ " epoch_save_frequency = 25 #@param {type:\"integer\"}\n",
223
+ " batch_size = 11#@param {type:\"integer\"} \n",
224
+ " ctx_len = 384 #@param {type:\"integer\"}\n",
225
+ " epoch_save_path = f\"{output_path}/{tuned_model_name}\"\n",
226
+ " return locals()\n",
227
+ "\n",
228
+ "def model_options():\n",
229
+ " T_MAX = 384 #@param {type:\"integer\"}\n",
230
+ " return locals()\n",
231
+ "\n",
232
+ "def env_vars():\n",
233
+ " RWKV_FLOAT_MODE = 'fp16' #@param ['fp16', 'bf16', 'bf32'] {type:\"string\"}\n",
234
+ " RWKV_DEEPSPEED = '0' #@param ['0', '1'] {type:\"string\"}\n",
235
+ " return {f\"os.environ['{key}']\": value for key, value in locals().items()}\n",
236
+ "\n",
237
+ "def replace_lines(file_name, to_replace):\n",
238
+ " with open(file_name, 'r') as f:\n",
239
+ " lines = f.readlines()\n",
240
+ " with open(f'{file_name}.tmp', 'w') as f:\n",
241
+ " for line in lines:\n",
242
+ " key = line.split(\" =\")[0]\n",
243
+ " if key.strip() in to_replace:\n",
244
+ " value = to_replace[key.strip()]\n",
245
+ " if isinstance(value, str):\n",
246
+ " f.write(f'{key} = \"{value}\"\\n')\n",
247
+ " else:\n",
248
+ " f.write(f'{key} = {value}\\n')\n",
249
+ " else:\n",
250
+ " f.write(line)\n",
251
+ " copy(f'{file_name}.tmp', file_name)\n",
252
+ " os.remove(f'{file_name}.tmp')\n",
253
+ "\n",
254
+ "values = training_options()\n",
255
+ "values.update(env_vars())\n",
256
+ "replace_lines('train.py', values)\n",
257
+ "replace_lines('src/model.py', model_options())"
258
+ ]
259
+ },
260
+ {
261
+ "cell_type": "code",
262
+ "source": [
263
+ "!python train.py "
264
+ ],
265
+ "metadata": {
266
+ "id": "0ZSF8U-nzylI"
267
+ },
268
+ "execution_count": null,
269
+ "outputs": []
270
+ },
271
+ {
272
+ "cell_type": "code",
273
+ "source": [],
274
+ "metadata": {
275
+ "id": "pcDci4O7xJiZ"
276
+ },
277
+ "execution_count": null,
278
+ "outputs": []
279
+ }
280
+ ],
281
+ "metadata": {
282
+ "accelerator": "GPU",
283
+ "colab": {
284
+ "name": "RWKV-v4-RNN-Pile Fine-Tuning",
285
+ "provenance": [],
286
+ "toc_visible": true
287
+ },
288
+ "gpuClass": "standard",
289
+ "kernelspec": {
290
+ "display_name": "Python 3",
291
+ "name": "python3"
292
+ },
293
+ "language_info": {
294
+ "name": "python"
295
+ }
296
+ },
297
+ "nbformat": 4,
298
+ "nbformat_minor": 0
299
+ }