Transformers
Safetensors
Japanese
text-generation-inference
unsloth
llama
trl
Inference Endpoints
poprap commited on
Commit
6530c70
1 Parent(s): aba7058

Upload dakesan0-inference-testcode.ipynb

Browse files
Files changed (1) hide show
  1. dakesan0-inference-testcode.ipynb +284 -0
dakesan0-inference-testcode.ipynb ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "2a3eb6d8",
6
+ "metadata": {},
7
+ "source": [
8
+ "# 推論テストコード\n",
9
+ "\n",
10
+ "運営様より提供されているテストコードをベースにした推論用コードです。unslothを使用しますが、conda環境を作らなければ動作しませんので、ご注意ください。\n",
11
+ "12-16 (2024)\n",
12
+ "\n",
13
+ "## 環境構築例\n",
14
+ "\n",
15
+ "```bash\n",
16
+ "# install conda\n",
17
+ "curl -L -O \"https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh\"\n",
18
+ "bash Miniforge3-$(uname)-$(uname -m).sh\n",
19
+ "\n",
20
+ "\n",
21
+ "source ~/miniforge3/etc/profile.d/mamba.sh\n",
22
+ "\n",
23
+ "mamba create --name unsloth_env \\\n",
24
+ " python=3.10 \\\n",
25
+ " pytorch-cuda=12.1 \\\n",
26
+ " pytorch cudatoolkit xformers -c pytorch -c nvidia -c xformers \\\n",
27
+ " -y\n",
28
+ " \n",
29
+ "mamba activate unsloth_env\n",
30
+ "\n",
31
+ "pip install \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
32
+ "\n",
33
+ "pip install --no-deps \"trl<0.9.0\" peft accelerate bitsandbytes\n",
34
+ "\n",
35
+ "pip install ipykernel\n",
36
+ "\n",
37
+ "ipython kernel install --name=unsloth --display-name=unsloth\n",
38
+ "```\n",
39
+ "\n",
40
+ "上記環境構築後、`unsloth`カーネルで本jupyter notebookを動作させてください。"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": 7,
46
+ "id": "1ed5ea31",
47
+ "metadata": {},
48
+ "outputs": [],
49
+ "source": [
50
+ "from unsloth import FastLanguageModel\n",
51
+ "from peft import PeftModel\n",
52
+ "import torch\n",
53
+ "import json\n",
54
+ "from tqdm import tqdm\n",
55
+ "import re\n",
56
+ "import datasets"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "markdown",
61
+ "id": "8e07b721",
62
+ "metadata": {},
63
+ "source": [
64
+ "## モデル読み込み"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "code",
69
+ "execution_count": 2,
70
+ "id": "50a5cebd",
71
+ "metadata": {},
72
+ "outputs": [],
73
+ "source": [
74
+ "model_id = \"llm-jp/llm-jp-3-13b\"\n",
75
+ "adapter_id = \"poprap/llm-jp-3-13b-it-2-3\"\n",
76
+ "adapter_dpo_id = \"poprap/llm-jp-3-13b-dpo\""
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": 3,
82
+ "id": "e800c15b",
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
+ "HF_TOKEN = \"\" "
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": 4,
92
+ "id": "a1240544",
93
+ "metadata": {},
94
+ "outputs": [
95
+ {
96
+ "name": "stdout",
97
+ "output_type": "stream",
98
+ "text": [
99
+ "Unsloth: WARNING `trust_remote_code` is True.\n",
100
+ "Are you certain you want to do remote code execution?\n",
101
+ "==((====))== Unsloth 2024.12.4: Fast Llama patching. Transformers:4.46.3.\n",
102
+ " \\\\ /| GPU: NVIDIA L4. Max memory: 21.964 GB. Platform: Linux.\n",
103
+ "O^O/ \\_/ \\ Torch: 2.5.1. CUDA: 8.9. CUDA Toolkit: 12.1. Triton: 3.1.0\n",
104
+ "\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.28.post3. FA2 = False]\n",
105
+ " \"-____-\" Free Apache license: http://github.com/unslothai/unsloth\n",
106
+ "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n"
107
+ ]
108
+ },
109
+ {
110
+ "name": "stderr",
111
+ "output_type": "stream",
112
+ "text": [
113
+ "Downloading shards: 100%|██████████| 6/6 [01:38<00:00, 16.49s/it]\n",
114
+ "Loading checkpoint shards: 100%|██████████| 6/6 [00:09<00:00, 1.61s/it]\n"
115
+ ]
116
+ }
117
+ ],
118
+ "source": [
119
+ "# unslothのFastLanguageModelで元のモデルをロード。\n",
120
+ "dtype = None # Noneにしておけば自動で設定\n",
121
+ "load_in_4bit = True # 今回は13Bモデルを扱うためTrue\n",
122
+ "\n",
123
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
124
+ " model_name=model_id,\n",
125
+ " dtype=dtype,\n",
126
+ " load_in_4bit=load_in_4bit,\n",
127
+ " trust_remote_code=True,\n",
128
+ ")"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": 5,
134
+ "id": "e0599d87",
135
+ "metadata": {},
136
+ "outputs": [],
137
+ "source": [
138
+ "# 元のモデルにLoRAのアダプタを統合。\n",
139
+ "model = PeftModel.from_pretrained(model, adapter_id, token = HF_TOKEN)\n",
140
+ "model = PeftModel.from_pretrained(model, adapter_dpo_id, token = HF_TOKEN)"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "markdown",
145
+ "id": "2a2830ce",
146
+ "metadata": {},
147
+ "source": [
148
+ "## タスクjsonlの読み込み"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "execution_count": 9,
154
+ "id": "3547c974",
155
+ "metadata": {},
156
+ "outputs": [],
157
+ "source": [
158
+ "ds = []\n",
159
+ "\n",
160
+ "with open(\"elyza-tasks-100-TV_0.jsonl\", \"r\") as f:\n",
161
+ " item = \"\"\n",
162
+ " for line in f:\n",
163
+ " line = line.strip()\n",
164
+ " item += line\n",
165
+ " if item.endswith(\"}\"):\n",
166
+ " ds.append(json.loads(item))\n",
167
+ " item = \"\""
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": 10,
173
+ "id": "0c1a580f",
174
+ "metadata": {},
175
+ "outputs": [
176
+ {
177
+ "data": {
178
+ "text/plain": [
179
+ "{'task_id': 0, 'input': '野球選手が今シーズン活躍するために取り組むべき5つのことを教えてください。'}"
180
+ ]
181
+ },
182
+ "execution_count": 10,
183
+ "metadata": {},
184
+ "output_type": "execute_result"
185
+ }
186
+ ],
187
+ "source": [
188
+ "ds[0]"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "markdown",
193
+ "id": "a18d3ccd",
194
+ "metadata": {},
195
+ "source": [
196
+ "## 推論 \n",
197
+ "\n",
198
+ "何度か試したところ推論に要する時間はまちまちです。サーバーのリソースの問題でしょうか。\n",
199
+ "一時間はかかりません。"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": 15,
205
+ "id": "db654962",
206
+ "metadata": {},
207
+ "outputs": [
208
+ {
209
+ "name": "stderr",
210
+ "output_type": "stream",
211
+ "text": [
212
+ "100%|██████████| 100/100 [14:17<00:00, 8.58s/it]\n"
213
+ ]
214
+ }
215
+ ],
216
+ "source": [
217
+ "# 推論するためにモデルのモードを変更\n",
218
+ "FastLanguageModel.for_inference(model)\n",
219
+ "\n",
220
+ "results = []\n",
221
+ "for dt in tqdm(ds):\n",
222
+ " input = dt[\"input\"]\n",
223
+ "\n",
224
+ " prompt = f\"\"\"### 指示\\n{input}\\n### 回答\\n\"\"\"\n",
225
+ "\n",
226
+ " inputs = tokenizer([prompt], return_tensors = \"pt\").to(model.device)\n",
227
+ "\n",
228
+ " outputs = model.generate(\n",
229
+ " **inputs,\n",
230
+ " max_new_tokens=1024,\n",
231
+ " use_cache = True, \n",
232
+ " do_sample=False, \n",
233
+ " repetition_penalty=1.2\n",
234
+ " )\n",
235
+ " prediction = tokenizer.decode(outputs[0], skip_special_tokens=True).split('\\n### 回答')[-1]\n",
236
+ " \n",
237
+ " results.append({\"task_id\": dt['task_id'], \"input\": input, \"output\": prediction})"
238
+ ]
239
+ },
240
+ {
241
+ "cell_type": "code",
242
+ "execution_count": 20,
243
+ "id": "9a18a4f8",
244
+ "metadata": {},
245
+ "outputs": [],
246
+ "source": [
247
+ "json_file_id = re.sub(\".*/\", \"\", adapter_id)\n",
248
+ "with open(f\"{json_file_id}_output.jsonl\", 'w', encoding='utf-8') as f:\n",
249
+ " for result in results:\n",
250
+ " json.dump(result, f, ensure_ascii=False)\n",
251
+ " f.write('\\n')"
252
+ ]
253
+ },
254
+ {
255
+ "cell_type": "code",
256
+ "execution_count": null,
257
+ "id": "a2ebf493",
258
+ "metadata": {},
259
+ "outputs": [],
260
+ "source": []
261
+ }
262
+ ],
263
+ "metadata": {
264
+ "kernelspec": {
265
+ "display_name": "unsloth",
266
+ "language": "python",
267
+ "name": "unsloth"
268
+ },
269
+ "language_info": {
270
+ "codemirror_mode": {
271
+ "name": "ipython",
272
+ "version": 3
273
+ },
274
+ "file_extension": ".py",
275
+ "mimetype": "text/x-python",
276
+ "name": "python",
277
+ "nbconvert_exporter": "python",
278
+ "pygments_lexer": "ipython3",
279
+ "version": "3.10.16"
280
+ }
281
+ },
282
+ "nbformat": 4,
283
+ "nbformat_minor": 5
284
+ }