omitakahiro commited on
Commit
e72c788
1 Parent(s): b4a2ab7

Upload LoRA.ipynb

Browse files
Files changed (1) hide show
  1. notebooks/LoRA.ipynb +253 -0
notebooks/LoRA.ipynb ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "markdown",
21
+ "source": [
22
+ "このnotebookは`stockmark/gpt-neox-japanese-1.4b`のモデルを`kunishou/databricks-dolly-15k-ja`のデータセットを用いてLoRA tuningするためのコードの例です。以下の例では、学習を1 epochを行います。T4 GPUで実行すると30分ほどかかります。\n",
23
+ "\n",
24
+ "- モデル:https://huggingface.co/stockmark/gpt-neox-japanese-1.4b\n",
25
+ "- データ:https://github.com/kunishou/databricks-dolly-15k-ja\n",
26
+ "\n",
27
+ "\n",
28
+ "また、ここで用いている設定は暫定的なもので、必要に応じて調整してください。"
29
+ ],
30
+ "metadata": {
31
+ "id": "BPGgCZtMdMsv"
32
+ }
33
+ },
34
+ {
35
+ "cell_type": "markdown",
36
+ "source": [
37
+ "# ライブラリのインストール"
38
+ ],
39
+ "metadata": {
40
+ "id": "hCZH9e6EcZyj"
41
+ }
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "metadata": {
47
+ "id": "cmn52bx3v5Ha"
48
+ },
49
+ "outputs": [],
50
+ "source": [
51
+ "!python3 -m pip install -U pip\n",
52
+ "!python3 -m pip install transformers accelerate datasets peft"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "markdown",
57
+ "source": [
58
+ "# 準備"
59
+ ],
60
+ "metadata": {
61
+ "id": "4t3Cqs9_ce3J"
62
+ }
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "source": [
67
+ "import torch\n",
68
+ "import datasets\n",
69
+ "from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments\n",
70
+ "from peft import get_peft_model, LoraConfig, TaskType, PeftModel, PeftConfig\n",
71
+ "\n",
72
+ "model_name = \"stockmark/gpt-neox-japanese-1.4b\"\n",
73
+ "peft_model_name = \"peft_model\"\n",
74
+ "\n",
75
+ "prompt_template = \"\"\"### Instruction:\n",
76
+ "{instruction}\n",
77
+ "\n",
78
+ "### Input:\n",
79
+ "{input}\n",
80
+ "\n",
81
+ "### Response:\n",
82
+ "\"\"\"\n",
83
+ "\n",
84
+ "def encode(sample):\n",
85
+ " prompt = prompt_template.format(instruction=sample[\"instruction\"], input=sample[\"input\"])\n",
86
+ " target = sample[\"output\"] + tokenizer.eos_token\n",
87
+ " input_ids_prompt, input_ids_target = tokenizer([prompt, target]).input_ids\n",
88
+ " input_ids = input_ids_prompt + input_ids_target\n",
89
+ " labels = input_ids.copy()\n",
90
+ " labels[:len(input_ids_prompt)] = [-100] * len(input_ids_prompt)\n",
91
+ " return {\"input_ids\": input_ids, \"labels\": labels}\n",
92
+ "\n",
93
+ "def get_collator(tokenizer, max_length):\n",
94
+ " def collator(batch):\n",
95
+ " batch = [{ key: value[:max_length] for key, value in sample.items() } for sample in batch ]\n",
96
+ " batch = tokenizer.pad(batch, padding=True)\n",
97
+ " batch[\"labels\"] = [ e + [-100] * (len(batch[\"input_ids\"][0]) - len(e)) for e in batch[\"labels\"] ]\n",
98
+ " batch = { key: torch.tensor(value) for key, value in batch.items() }\n",
99
+ " return batch\n",
100
+ "\n",
101
+ " return collator\n"
102
+ ],
103
+ "metadata": {
104
+ "id": "hNdYMGMRzAVn"
105
+ },
106
+ "execution_count": null,
107
+ "outputs": []
108
+ },
109
+ {
110
+ "cell_type": "markdown",
111
+ "source": [
112
+ "# データセットとモデルの準備\n"
113
+ ],
114
+ "metadata": {
115
+ "id": "UqXxPjJ_cliu"
116
+ }
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "source": [
121
+ "# prepare dataset\n",
122
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
123
+ "\n",
124
+ "dataset_name = \"kunishou/databricks-dolly-15k-ja\"\n",
125
+ "dataset = datasets.load_dataset(dataset_name)\n",
126
+ "dataset = dataset.map(encode)\n",
127
+ "dataset = dataset[\"train\"].train_test_split(0.2)\n",
128
+ "train_dataset = dataset[\"train\"]\n",
129
+ "val_dataset = dataset[\"test\"]\n",
130
+ "\n",
131
+ "# load model\n",
132
+ "model = AutoModelForCausalLM.from_pretrained(model_name, device_map={\"\": 0}, torch_dtype=torch.float16)\n",
133
+ "\n",
134
+ "peft_config = LoraConfig(\n",
135
+ " task_type=TaskType.CAUSAL_LM,\n",
136
+ " inference_mode=False,\n",
137
+ " target_modules=[\"query_key_value\"],\n",
138
+ " r=16,\n",
139
+ " lora_alpha=32,\n",
140
+ " lora_dropout=0.05\n",
141
+ ")\n",
142
+ "\n",
143
+ "model = get_peft_model(model, peft_config)\n",
144
+ "model.print_trainable_parameters()"
145
+ ],
146
+ "metadata": {
147
+ "id": "ZWdN-p7t0Grk"
148
+ },
149
+ "execution_count": null,
150
+ "outputs": []
151
+ },
152
+ {
153
+ "cell_type": "markdown",
154
+ "source": [
155
+ "# LoRA tuning"
156
+ ],
157
+ "metadata": {
158
+ "id": "XCrdVAJYc88c"
159
+ }
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "source": [
164
+ "training_args = TrainingArguments(\n",
165
+ " output_dir=\"./train_results\",\n",
166
+ " learning_rate=2e-4,\n",
167
+ " per_device_train_batch_size=4,\n",
168
+ " gradient_accumulation_steps=4,\n",
169
+ " per_device_eval_batch_size=16,\n",
170
+ " num_train_epochs=1,\n",
171
+ " logging_strategy='steps',\n",
172
+ " logging_steps=10,\n",
173
+ " save_strategy='epoch',\n",
174
+ " evaluation_strategy='epoch',\n",
175
+ " load_best_model_at_end=True,\n",
176
+ " metric_for_best_model=\"eval_loss\",\n",
177
+ " greater_is_better=False,\n",
178
+ " save_total_limit=2\n",
179
+ ")\n",
180
+ "\n",
181
+ "trainer = Trainer(\n",
182
+ " model=model,\n",
183
+ " args=training_args,\n",
184
+ " train_dataset=train_dataset,\n",
185
+ " eval_dataset=val_dataset,\n",
186
+ " data_collator=get_collator(tokenizer, 512)\n",
187
+ ")\n",
188
+ "\n",
189
+ "trainer.train()\n",
190
+ "model = trainer.model\n",
191
+ "model.save_pretrained(peft_model_name)"
192
+ ],
193
+ "metadata": {
194
+ "id": "4LH9tOCTJVk1"
195
+ },
196
+ "execution_count": null,
197
+ "outputs": []
198
+ },
199
+ {
200
+ "cell_type": "markdown",
201
+ "source": [
202
+ "# 学習したモデルのロード"
203
+ ],
204
+ "metadata": {
205
+ "id": "ORgzOPAqdEZR"
206
+ }
207
+ },
208
+ {
209
+ "cell_type": "code",
210
+ "source": [
211
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
212
+ "model = AutoModelForCausalLM.from_pretrained(model_name, device_map={\"\": 0}, torch_dtype=torch.float16)\n",
213
+ "model = PeftModel.from_pretrained(model, peft_model_name)"
214
+ ],
215
+ "metadata": {
216
+ "id": "yrExyO9EOvzR"
217
+ },
218
+ "execution_count": null,
219
+ "outputs": []
220
+ },
221
+ {
222
+ "cell_type": "markdown",
223
+ "source": [
224
+ "# 推論"
225
+ ],
226
+ "metadata": {
227
+ "id": "-dttR6tkdG0k"
228
+ }
229
+ },
230
+ {
231
+ "cell_type": "code",
232
+ "source": [
233
+ "prompt = prompt_template.format(instruction=\"日本で人気のスポーツは?\", input=\"\")\n",
234
+ "\n",
235
+ "inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
236
+ "with torch.no_grad():\n",
237
+ " tokens = model.generate(\n",
238
+ " **inputs,\n",
239
+ " max_new_tokens=128,\n",
240
+ " repetition_penalty=1.1\n",
241
+ " )\n",
242
+ "\n",
243
+ "output = tokenizer.decode(tokens[0], skip_special_tokens=True)\n",
244
+ "print(output)"
245
+ ],
246
+ "metadata": {
247
+ "id": "pC5t9F1GJuFN"
248
+ },
249
+ "execution_count": null,
250
+ "outputs": []
251
+ }
252
+ ]
253
+ }