Text Generation
Transformers
Safetensors
Finnish
llama
finnish
conversational
text-generation-inference
RASMUS commited on
Commit
607981f
1 Parent(s): 63ea4c7

Add finetuning example notebooks

Browse files
Files changed (2) hide show
  1. Finetune_Ahma_3B_example.ipynb +1039 -0
  2. setup_steps.txt +10 -0
Finetune_Ahma_3B_example.ipynb ADDED
@@ -0,0 +1,1039 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "9cf22489-4421-49b1-b5f8-f61093ce2fe6",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "data": {
11
+ "text/plain": [
12
+ "'2.1.0+cu121'"
13
+ ]
14
+ },
15
+ "execution_count": 1,
16
+ "metadata": {},
17
+ "output_type": "execute_result"
18
+ }
19
+ ],
20
+ "source": [
21
+ "import torch\n",
22
+ "torch.__version__"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": 2,
28
+ "id": "c70f64f5-1d44-4139-adf8-f1b0b149831f",
29
+ "metadata": {},
30
+ "outputs": [
31
+ {
32
+ "name": "stdout",
33
+ "output_type": "stream",
34
+ "text": [
35
+ "nvcc: NVIDIA (R) Cuda compiler driver\n",
36
+ "Copyright (c) 2005-2023 NVIDIA Corporation\n",
37
+ "Built on Mon_Apr__3_17:16:06_PDT_2023\n",
38
+ "Cuda compilation tools, release 12.1, V12.1.105\n",
39
+ "Build cuda_12.1.r12.1/compiler.32688072_0\n"
40
+ ]
41
+ }
42
+ ],
43
+ "source": [
44
+ "!nvcc --version"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": 3,
50
+ "id": "f3c2cb90-7d57-4bcb-84f3-ce30f8bbd200",
51
+ "metadata": {},
52
+ "outputs": [
53
+ {
54
+ "data": {
55
+ "text/plain": [
56
+ "'0.0.22.post7'"
57
+ ]
58
+ },
59
+ "execution_count": 3,
60
+ "metadata": {},
61
+ "output_type": "execute_result"
62
+ }
63
+ ],
64
+ "source": [
65
+ "import xformers\n",
66
+ "xformers.__version__"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": 4,
72
+ "id": "6a217b07-a6f5-43dc-a32f-38b3f055707b",
73
+ "metadata": {},
74
+ "outputs": [
75
+ {
76
+ "name": "stdout",
77
+ "output_type": "stream",
78
+ "text": [
79
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n",
80
+ "++++++++++++++++++ BUG REPORT INFORMATION ++++++++++++++++++\n",
81
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n",
82
+ "++++++++++++++++++++++++++ OTHER +++++++++++++++++++++++++++\n",
83
+ "CUDA specs: CUDASpecs(highest_compute_capability=(8, 9), cuda_version_string='121', cuda_version_tuple=(12, 1))\n",
84
+ "PyTorch settings found: CUDA_VERSION=121, Highest Compute Capability: (8, 9).\n",
85
+ "To manually override the PyTorch CUDA version please see: https://github.com/TimDettmers/bitsandbytes/blob/main/docs/source/nonpytorchcuda.mdx\n",
86
+ "The directory listed in your path is found to be non-existent: /usr/local/nvidia/lib\n",
87
+ "The directory listed in your path is found to be non-existent: /usr/local/nvidia/lib64\n",
88
+ "The directory listed in your path is found to be non-existent: /workspace/Untitled.ipynb\n",
89
+ "The directory listed in your path is found to be non-existent: //matplotlib_inline.backend_inline\n",
90
+ "CUDA SETUP: WARNING! CUDA runtime files not found in any environmental path.\n",
91
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n",
92
+ "++++++++++++++++++++++ DEBUG INFO END ++++++++++++++++++++++\n",
93
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n",
94
+ "Checking that the library is importable and CUDA is callable...\n",
95
+ "SUCCESS!\n",
96
+ "Installation was successful!\n"
97
+ ]
98
+ }
99
+ ],
100
+ "source": [
101
+ "!python -m bitsandbytes"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": 5,
107
+ "id": "4ea1d820-a1e1-401d-b2ee-76bda33e7de9",
108
+ "metadata": {},
109
+ "outputs": [
110
+ {
111
+ "name": "stdout",
112
+ "output_type": "stream",
113
+ "text": [
114
+ "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n"
115
+ ]
116
+ },
117
+ {
118
+ "data": {
119
+ "application/vnd.jupyter.widget-view+json": {
120
+ "model_id": "a14499d3ba274ea0be10db40f11b830e",
121
+ "version_major": 2,
122
+ "version_minor": 0
123
+ },
124
+ "text/plain": [
125
+ "config.json: 0%| | 0.00/662 [00:00<?, ?B/s]"
126
+ ]
127
+ },
128
+ "metadata": {},
129
+ "output_type": "display_data"
130
+ },
131
+ {
132
+ "name": "stdout",
133
+ "output_type": "stream",
134
+ "text": [
135
+ "==((====))== Unsloth: Fast Llama patching release 2024.6\n",
136
+ " \\\\ /| GPU: NVIDIA GeForce RTX 4090. Max memory: 23.65 GB. Platform = Linux.\n",
137
+ "O^O/ \\_/ \\ Pytorch: 2.1.0+cu121. CUDA = 8.9. CUDA Toolkit = 12.1.\n",
138
+ "\\ / Bfloat16 = TRUE. Xformers = 0.0.22.post7. FA = True.\n",
139
+ " \"-____-\" Free Apache license: http://github.com/unslothai/unsloth\n"
140
+ ]
141
+ },
142
+ {
143
+ "data": {
144
+ "application/vnd.jupyter.widget-view+json": {
145
+ "model_id": "4f0c5cd22e124bd7a40b8bcfe2146fa9",
146
+ "version_major": 2,
147
+ "version_minor": 0
148
+ },
149
+ "text/plain": [
150
+ "model.safetensors.index.json: 0%| | 0.00/19.5k [00:00<?, ?B/s]"
151
+ ]
152
+ },
153
+ "metadata": {},
154
+ "output_type": "display_data"
155
+ },
156
+ {
157
+ "data": {
158
+ "application/vnd.jupyter.widget-view+json": {
159
+ "model_id": "e4fc84e4a41343219fc6f257491d1e93",
160
+ "version_major": 2,
161
+ "version_minor": 0
162
+ },
163
+ "text/plain": [
164
+ "Downloading shards: 0%| | 0/2 [00:00<?, ?it/s]"
165
+ ]
166
+ },
167
+ "metadata": {},
168
+ "output_type": "display_data"
169
+ },
170
+ {
171
+ "data": {
172
+ "application/vnd.jupyter.widget-view+json": {
173
+ "model_id": "1f919f19324c44f7917a4120588e8a23",
174
+ "version_major": 2,
175
+ "version_minor": 0
176
+ },
177
+ "text/plain": [
178
+ "model-00001-of-00002.safetensors: 0%| | 0.00/4.95G [00:00<?, ?B/s]"
179
+ ]
180
+ },
181
+ "metadata": {},
182
+ "output_type": "display_data"
183
+ },
184
+ {
185
+ "data": {
186
+ "application/vnd.jupyter.widget-view+json": {
187
+ "model_id": "b156005b83854886af6b30eac5fd1c25",
188
+ "version_major": 2,
189
+ "version_minor": 0
190
+ },
191
+ "text/plain": [
192
+ "model-00002-of-00002.safetensors: 0%| | 0.00/2.31G [00:00<?, ?B/s]"
193
+ ]
194
+ },
195
+ "metadata": {},
196
+ "output_type": "display_data"
197
+ },
198
+ {
199
+ "data": {
200
+ "application/vnd.jupyter.widget-view+json": {
201
+ "model_id": "9baf738f18d042339c89bef5a1ec5c02",
202
+ "version_major": 2,
203
+ "version_minor": 0
204
+ },
205
+ "text/plain": [
206
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
207
+ ]
208
+ },
209
+ "metadata": {},
210
+ "output_type": "display_data"
211
+ },
212
+ {
213
+ "data": {
214
+ "application/vnd.jupyter.widget-view+json": {
215
+ "model_id": "0ca53bdb15864ad7826b89b38bd13613",
216
+ "version_major": 2,
217
+ "version_minor": 0
218
+ },
219
+ "text/plain": [
220
+ "generation_config.json: 0%| | 0.00/116 [00:00<?, ?B/s]"
221
+ ]
222
+ },
223
+ "metadata": {},
224
+ "output_type": "display_data"
225
+ },
226
+ {
227
+ "data": {
228
+ "application/vnd.jupyter.widget-view+json": {
229
+ "model_id": "3004dbbac0db49bcb29ca96a6fb54c9b",
230
+ "version_major": 2,
231
+ "version_minor": 0
232
+ },
233
+ "text/plain": [
234
+ "tokenizer_config.json: 0%| | 0.00/2.90k [00:00<?, ?B/s]"
235
+ ]
236
+ },
237
+ "metadata": {},
238
+ "output_type": "display_data"
239
+ },
240
+ {
241
+ "data": {
242
+ "application/vnd.jupyter.widget-view+json": {
243
+ "model_id": "7c2dc2d55b7c453cb2e8619947b51a68",
244
+ "version_major": 2,
245
+ "version_minor": 0
246
+ },
247
+ "text/plain": [
248
+ "tokenizer.json: 0%| | 0.00/4.84M [00:00<?, ?B/s]"
249
+ ]
250
+ },
251
+ "metadata": {},
252
+ "output_type": "display_data"
253
+ },
254
+ {
255
+ "data": {
256
+ "application/vnd.jupyter.widget-view+json": {
257
+ "model_id": "91955239a02141ff986c5fb34158ab9f",
258
+ "version_major": 2,
259
+ "version_minor": 0
260
+ },
261
+ "text/plain": [
262
+ "special_tokens_map.json: 0%| | 0.00/414 [00:00<?, ?B/s]"
263
+ ]
264
+ },
265
+ "metadata": {},
266
+ "output_type": "display_data"
267
+ },
268
+ {
269
+ "name": "stderr",
270
+ "output_type": "stream",
271
+ "text": [
272
+ "Finnish-NLP/Ahma-3B does not have a padding token! Will use pad_token = <unk>.\n"
273
+ ]
274
+ }
275
+ ],
276
+ "source": [
277
+ "from unsloth import FastLanguageModel\n",
278
+ "from transformers import AutoTokenizer\n",
279
+ "\n",
280
+ "max_seq_length = 2048\n",
281
+ "load_in_4bit = True\n",
282
+ "dtype = None\n",
283
+ "\n",
284
+ "revision = \"63ea4c7c4f4cae078655294c96973d0db75cd656\"\n",
285
+ "pretrained_model_hf = \"Finnish-NLP/Ahma-3B\"\n",
286
+ "\n",
287
+ "model, _ = FastLanguageModel.from_pretrained(\n",
288
+ " model_name = pretrained_model_hf,\n",
289
+ " max_seq_length = max_seq_length,\n",
290
+ " dtype = torch.bfloat16,\n",
291
+ " load_in_4bit = load_in_4bit,\n",
292
+ " revision = revision\n",
293
+ ")"
294
+ ]
295
+ },
296
+ {
297
+ "cell_type": "code",
298
+ "execution_count": 6,
299
+ "id": "00baa980-c4ad-45c5-86aa-93664653b6fd",
300
+ "metadata": {},
301
+ "outputs": [],
302
+ "source": [
303
+ "tokenizer = AutoTokenizer.from_pretrained(pretrained_model_hf, revision=revision)"
304
+ ]
305
+ },
306
+ {
307
+ "cell_type": "code",
308
+ "execution_count": 8,
309
+ "id": "0e632d54-b5cd-4cd1-8d3e-742e475a664a",
310
+ "metadata": {},
311
+ "outputs": [],
312
+ "source": [
313
+ "tokenizer.clean_up_tokenization_spaces = True\n",
314
+ "tokenizer.add_tokens([\"<PAD>\"])\n",
315
+ "tokenizer.pad_token = \"<PAD>\"\n",
316
+ "tokenizer.add_eos_token = False\n",
317
+ "\n",
318
+ "model.resize_token_embeddings(new_num_tokens=len(tokenizer))\n",
319
+ "model.config.eos_token_id = tokenizer.eos_token_id"
320
+ ]
321
+ },
322
+ {
323
+ "cell_type": "code",
324
+ "execution_count": 10,
325
+ "id": "5f9e02fa-b89a-453f-adf8-383d7bd11b3c",
326
+ "metadata": {},
327
+ "outputs": [
328
+ {
329
+ "name": "stdout",
330
+ "output_type": "stream",
331
+ "text": [
332
+ "Unsloth: Offloading input_embeddings to disk to save VRAM\n",
333
+ "Unsloth: Offloading output_embeddings to disk to save VRAM\n"
334
+ ]
335
+ },
336
+ {
337
+ "name": "stderr",
338
+ "output_type": "stream",
339
+ "text": [
340
+ "Unsloth 2024.6 patched 26 layers with 26 QKV layers, 26 O layers and 26 MLP layers.\n"
341
+ ]
342
+ },
343
+ {
344
+ "name": "stdout",
345
+ "output_type": "stream",
346
+ "text": [
347
+ "Unsloth: Casting embed_tokens to float32\n",
348
+ "Unsloth: Casting lm_head to float32\n"
349
+ ]
350
+ }
351
+ ],
352
+ "source": [
353
+ "model = FastLanguageModel.get_peft_model(\n",
354
+ " model,\n",
355
+ " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
356
+ " \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
357
+ " modules_to_save = [\"lm_head\", \"embed_tokens\"],\n",
358
+ " lora_alpha = 16,\n",
359
+ " r = 32,\n",
360
+ " lora_dropout = 0,\n",
361
+ " bias = \"none\",\n",
362
+ " use_gradient_checkpointing=\"unsloth\",\n",
363
+ " random_state=3407,\n",
364
+ " use_rslora = True,\n",
365
+ " loftq_config = None,\n",
366
+ " max_seq_length = max_seq_length\n",
367
+ ")\n",
368
+ " "
369
+ ]
370
+ },
371
+ {
372
+ "cell_type": "code",
373
+ "execution_count": 11,
374
+ "id": "93e95410-8f6b-4d4c-951b-5602397f3ce0",
375
+ "metadata": {},
376
+ "outputs": [
377
+ {
378
+ "data": {
379
+ "application/vnd.jupyter.widget-view+json": {
380
+ "model_id": "7a9d00c6b33f423bb9c25b0261676f89",
381
+ "version_major": 2,
382
+ "version_minor": 0
383
+ },
384
+ "text/plain": [
385
+ "Downloading builder script: 0%| | 0.00/25.0k [00:00<?, ?B/s]"
386
+ ]
387
+ },
388
+ "metadata": {},
389
+ "output_type": "display_data"
390
+ },
391
+ {
392
+ "data": {
393
+ "application/vnd.jupyter.widget-view+json": {
394
+ "model_id": "6501843986c740f1b4340f75f6cd8d1e",
395
+ "version_major": 2,
396
+ "version_minor": 0
397
+ },
398
+ "text/plain": [
399
+ "Downloading readme: 0%| | 0.00/1.32k [00:00<?, ?B/s]"
400
+ ]
401
+ },
402
+ "metadata": {},
403
+ "output_type": "display_data"
404
+ },
405
+ {
406
+ "name": "stdin",
407
+ "output_type": "stream",
408
+ "text": [
409
+ "The repository for intfloat/multilingual_cc_news contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/intfloat/multilingual_cc_news.\n",
410
+ "You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n",
411
+ "\n",
412
+ "Do you wish to run the custom code? [y/N] y\n"
413
+ ]
414
+ },
415
+ {
416
+ "data": {
417
+ "application/vnd.jupyter.widget-view+json": {
418
+ "model_id": "cf251dae77124091ac23d7dac518edfc",
419
+ "version_major": 2,
420
+ "version_minor": 0
421
+ },
422
+ "text/plain": [
423
+ "Downloading data: 0%| | 0.00/36.8M [00:00<?, ?B/s]"
424
+ ]
425
+ },
426
+ "metadata": {},
427
+ "output_type": "display_data"
428
+ },
429
+ {
430
+ "data": {
431
+ "application/vnd.jupyter.widget-view+json": {
432
+ "model_id": "5325f85116f14f83a69100a036d1a81b",
433
+ "version_major": 2,
434
+ "version_minor": 0
435
+ },
436
+ "text/plain": [
437
+ "Downloading data: 0%| | 0.00/99.3M [00:00<?, ?B/s]"
438
+ ]
439
+ },
440
+ "metadata": {},
441
+ "output_type": "display_data"
442
+ },
443
+ {
444
+ "data": {
445
+ "application/vnd.jupyter.widget-view+json": {
446
+ "model_id": "8abd4fa848d146ff84e9ff5491c9f128",
447
+ "version_major": 2,
448
+ "version_minor": 0
449
+ },
450
+ "text/plain": [
451
+ "Downloading data: 0%| | 0.00/236M [00:00<?, ?B/s]"
452
+ ]
453
+ },
454
+ "metadata": {},
455
+ "output_type": "display_data"
456
+ },
457
+ {
458
+ "data": {
459
+ "application/vnd.jupyter.widget-view+json": {
460
+ "model_id": "5528f07d7d8c49f5b835f5db63be75ba",
461
+ "version_major": 2,
462
+ "version_minor": 0
463
+ },
464
+ "text/plain": [
465
+ "Downloading data: 0%| | 0.00/296M [00:00<?, ?B/s]"
466
+ ]
467
+ },
468
+ "metadata": {},
469
+ "output_type": "display_data"
470
+ },
471
+ {
472
+ "data": {
473
+ "application/vnd.jupyter.widget-view+json": {
474
+ "model_id": "371dee77de1749728312dea44c4b4f3d",
475
+ "version_major": 2,
476
+ "version_minor": 0
477
+ },
478
+ "text/plain": [
479
+ "Downloading data: 0%| | 0.00/400M [00:00<?, ?B/s]"
480
+ ]
481
+ },
482
+ "metadata": {},
483
+ "output_type": "display_data"
484
+ },
485
+ {
486
+ "data": {
487
+ "application/vnd.jupyter.widget-view+json": {
488
+ "model_id": "0b400224b4504e8c8580c394ba7848c0",
489
+ "version_major": 2,
490
+ "version_minor": 0
491
+ },
492
+ "text/plain": [
493
+ "Downloading data: 0%| | 0.00/367M [00:00<?, ?B/s]"
494
+ ]
495
+ },
496
+ "metadata": {},
497
+ "output_type": "display_data"
498
+ },
499
+ {
500
+ "data": {
501
+ "application/vnd.jupyter.widget-view+json": {
502
+ "model_id": "ae2383cc6ca948c89ce20348bbf20c52",
503
+ "version_major": 2,
504
+ "version_minor": 0
505
+ },
506
+ "text/plain": [
507
+ "Generating train split: 0 examples [00:00, ? examples/s]"
508
+ ]
509
+ },
510
+ "metadata": {},
511
+ "output_type": "display_data"
512
+ },
513
+ {
514
+ "name": "stdout",
515
+ "output_type": "stream",
516
+ "text": [
517
+ "Dataset contains: 1536679 articles\n"
518
+ ]
519
+ }
520
+ ],
521
+ "source": [
522
+ "from datasets import load_dataset\n",
523
+ "\n",
524
+ "dataset = load_dataset(\"intfloat/multilingual_cc_news\", languages=[\"fi\"])[\"train\"]\n",
525
+ "print(f'Dataset contains: {len(dataset)} articles')"
526
+ ]
527
+ },
528
+ {
529
+ "cell_type": "code",
530
+ "execution_count": 12,
531
+ "id": "0802a3bb-1e25-4f57-9de7-60e7cecb1531",
532
+ "metadata": {},
533
+ "outputs": [
534
+ {
535
+ "data": {
536
+ "text/plain": [
537
+ "Dataset({\n",
538
+ " features: ['title', 'maintext', 'url', 'date_publish'],\n",
539
+ " num_rows: 1536679\n",
540
+ "})"
541
+ ]
542
+ },
543
+ "execution_count": 12,
544
+ "metadata": {},
545
+ "output_type": "execute_result"
546
+ }
547
+ ],
548
+ "source": [
549
+ "dataset"
550
+ ]
551
+ },
552
+ {
553
+ "cell_type": "code",
554
+ "execution_count": 13,
555
+ "id": "9ceea32b-fa31-4f3e-b03a-a0efdc12e3e3",
556
+ "metadata": {},
557
+ "outputs": [
558
+ {
559
+ "name": "stdout",
560
+ "output_type": "stream",
561
+ "text": [
562
+ "Articles title:\n",
563
+ "Postin tulisi kiinnostua myös tuoreista lehdistä\n",
564
+ "--------------------------------------------------\n",
565
+ "Articles title:\n",
566
+ "Posti kertoi eilen tiedotteessaan, että sen työntekijät siirtyvät tällä viikolla nurmikonleikkuusta lehtien haravoimiseen. Postin verkkokaupasta ja Postin omista myymälöistä voi tilata pihan haravoinnin haluamilleen kahdelle tiistaipäivälle. Syyslehdille voi myös tilata Postilta poisviennin.\n",
567
+ "Samaan aikaan, kun postin varsinaista palvelutehtävää — postilähetysten kuljettamista asiakkailleen — on heikennetty, valtionyhtiö tekee kaikkensa tunkeutuakseen yksityisten yritysten tonteille kiinteistönhoito- ja sosiaalitöihin.\n",
568
+ "Syksyn putoavien lehtien sijasta Postin tulisi keskittää tarmonsa siihen, miten se hoitaa tuoreiden lehtien eli paikallis- , sanoma- ja aikakauslehtien kuljetuksen. Tienoo tosin on hankkinut varhaiskantopalvelunsa yksityiseltä toimijalta, mutta Posti vie tai sen pitäisi viedä ydinaluetta kauempana oleville tilaajille lehdet määräaikoina. Käytännössä lehti menee asiakkaalle vasta seuraavana tai sitä seuraavana päivänä. Tuollainen toiminta on häpeällistä ja asiakasta väheksyvää.\n",
569
+ "Nyt Posti haluaisi rukata postilakia niin, että postia kannettaisiin vain kolmena päivänä viikossa. Asiaan on ottanut jyrkän kielteisen kannan muun muassa Sanomalehtien Liitto. Se katsoo, että viisipäiväisyys on turvattava. Keinoina ovat eri jakelujen ja muiden palvelujen yhdistäminen. Taajamissa esimerkiksi päiväposti olisi mahdollista jakaa aamun sanomalehtijakelun yhteydessä. Haja-asutusalueella yhteistyötä Posti voisi tehdä yksityisten yrittäjien kanssa.\n",
570
+ "Kun jo nyt monin paikoin Posti on omatoimisesti siirtynyt nelipäiväiseen jakeluun ja pantannut joskus loppuviikon postit seuraavalle viikolle, tärkeät kirjeet saattavat jäädä saamatta määräaikoina. Monipuolisen tiedonvälityksen vuoksi on lehdetkin toimitetava perille ilmestymispäivinä. Kun se on ollut ennekin mahdollista, miksei se nyt onnistuisi?\n",
571
+ "Rauli Ala-Karvia\n"
572
+ ]
573
+ }
574
+ ],
575
+ "source": [
576
+ "print(f'Articles title:\\n{dataset[0][\"title\"]}')\n",
577
+ "print('-' * 50)\n",
578
+ "print(f'Articles title:\\n{dataset[0][\"maintext\"]}')"
579
+ ]
580
+ },
581
+ {
582
+ "cell_type": "code",
583
+ "execution_count": 14,
584
+ "id": "7457d23e-c228-433f-870e-e402f6ca52fd",
585
+ "metadata": {},
586
+ "outputs": [],
587
+ "source": [
588
+ "dataset_filtered = dataset.select([i for i in range(3000)])\n",
589
+ "\n",
590
+ "dataset_filtered = dataset_filtered.rename_column('maintext', 'instruction')\n",
591
+ "dataset_filtered = dataset_filtered.rename_column('title', 'response')"
592
+ ]
593
+ },
594
+ {
595
+ "cell_type": "code",
596
+ "execution_count": 16,
597
+ "id": "bcd33d1f-7530-4785-a03a-0e0e48c85532",
598
+ "metadata": {},
599
+ "outputs": [
600
+ {
601
+ "data": {
602
+ "application/vnd.jupyter.widget-view+json": {
603
+ "model_id": "dfd349b056394d26abd32d0a98295ae6",
604
+ "version_major": 2,
605
+ "version_minor": 0
606
+ },
607
+ "text/plain": [
608
+ "Map: 0%| | 0/3000 [00:00<?, ? examples/s]"
609
+ ]
610
+ },
611
+ "metadata": {},
612
+ "output_type": "display_data"
613
+ }
614
+ ],
615
+ "source": [
616
+ "def form_instruction(row):\n",
617
+ " row[\"instruction\"] = f'Saat seuraavana artikkelin tekstin. Tehtävänäsi on tuottaa otsikko artikkelin perusteella.\\n Artikkeli:\\n{row[\"instruction\"]}\\n Nyt luo otsikko artikkelin perusteella. \\n Otsikko: \\n'\n",
618
+ " return row\n",
619
+ "\n",
620
+ "dataset_filtered = dataset_filtered.map(form_instruction)"
621
+ ]
622
+ },
623
+ {
624
+ "cell_type": "code",
625
+ "execution_count": 19,
626
+ "id": "dfb38fa1-7d85-48a0-9941-f540324f38d9",
627
+ "metadata": {},
628
+ "outputs": [
629
+ {
630
+ "data": {
631
+ "application/vnd.jupyter.widget-view+json": {
632
+ "model_id": "045c6fbbfbdd49bca1f4e9a824985a2d",
633
+ "version_major": 2,
634
+ "version_minor": 0
635
+ },
636
+ "text/plain": [
637
+ "Map: 0%| | 0/3000 [00:00<?, ? examples/s]"
638
+ ]
639
+ },
640
+ "metadata": {},
641
+ "output_type": "display_data"
642
+ }
643
+ ],
644
+ "source": [
645
+ "def form_messages(row):\n",
646
+ " row[\"messages\"] = [{'role': 'user', 'content': row[\"instruction\"]}, {'role': 'assistant', 'content': row[\"response\"]}]\n",
647
+ " return row\n",
648
+ "\n",
649
+ "dataset_filtered = dataset_filtered.map(form_messages)"
650
+ ]
651
+ },
652
+ {
653
+ "cell_type": "code",
654
+ "execution_count": 22,
655
+ "id": "e6cef361-588e-4d25-8d79-dde7a6150387",
656
+ "metadata": {},
657
+ "outputs": [
658
+ {
659
+ "data": {
660
+ "application/vnd.jupyter.widget-view+json": {
661
+ "model_id": "a93736c087ac4c75bd3b130343b36ceb",
662
+ "version_major": 2,
663
+ "version_minor": 0
664
+ },
665
+ "text/plain": [
666
+ "Map: 0%| | 0/3000 [00:00<?, ? examples/s]"
667
+ ]
668
+ },
669
+ "metadata": {},
670
+ "output_type": "display_data"
671
+ }
672
+ ],
673
+ "source": [
674
+ "def form_prompt(row):\n",
675
+ " row[\"text\"] = tokenizer.apply_chat_template(row[\"messages\"], tokenize=False)\n",
676
+ " return row\n",
677
+ "\n",
678
+ "dataset_filtered = dataset_filtered.map(form_prompt)"
679
+ ]
680
+ },
681
+ {
682
+ "cell_type": "code",
683
+ "execution_count": 23,
684
+ "id": "4daf13d2-ef17-48d5-8176-8a47452e9424",
685
+ "metadata": {},
686
+ "outputs": [
687
+ {
688
+ "data": {
689
+ "text/plain": [
690
+ "{'response': 'Postin tulisi kiinnostua myös tuoreista lehdistä',\n",
691
+ " 'instruction': 'Saat seuraavana artikkelin tekstin. Tehtävänäsi on tuottaa otsikko artikkelin perusteella.\\n Artikkeli:\\nPosti kertoi eilen tiedotteessaan, että sen työntekijät siirtyvät tällä viikolla nurmikonleikkuusta lehtien haravoimiseen. Postin verkkokaupasta ja Postin omista myymälöistä voi tilata pihan haravoinnin haluamilleen kahdelle tiistaipäivälle. Syyslehdille voi myös tilata Postilta poisviennin.\\nSamaan aikaan, kun postin varsinaista palvelutehtävää — postilähetysten kuljettamista asiakkailleen — on heikennetty, valtionyhtiö tekee kaikkensa tunkeutuakseen yksityisten yritysten tonteille kiinteistönhoito- ja sosiaalitöihin.\\nSyksyn putoavien lehtien sijasta Postin tulisi keskittää tarmonsa siihen, miten se hoitaa tuoreiden lehtien eli paikallis- , sanoma- ja aikakauslehtien kuljetuksen. Tienoo tosin on hankkinut varhaiskantopalvelunsa yksityiseltä toimijalta, mutta Posti vie tai sen pitäisi viedä ydinaluetta kauempana oleville tilaajille lehdet määräaikoina. Käytännössä lehti menee asiakkaalle vasta seuraavana tai sitä seuraavana päivänä. Tuollainen toiminta on häpeällistä ja asiakasta väheksyvää.\\nNyt Posti haluaisi rukata postilakia niin, että postia kannettaisiin vain kolmena päivänä viikossa. Asiaan on ottanut jyrkän kielteisen kannan muun muassa Sanomalehtien Liitto. Se katsoo, että viisipäiväisyys on turvattava. Keinoina ovat eri jakelujen ja muiden palvelujen yhdistäminen. Taajamissa esimerkiksi päiväposti olisi mahdollista jakaa aamun sanomalehtijakelun yhteydessä. Haja-asutusalueella yhteistyötä Posti voisi tehdä yksityisten yrittäjien kanssa.\\nKun jo nyt monin paikoin Posti on omatoimisesti siirtynyt nelipäiväiseen jakeluun ja pantannut joskus loppuviikon postit seuraavalle viikolle, tärkeät kirjeet saattavat jäädä saamatta määräaikoina. Monipuolisen tiedonvälityksen vuoksi on lehdetkin toimitetava perille ilmestymispäivinä. Kun se on ollut ennekin mahdollista, miksei se nyt onnistuisi?\\nRauli Ala-Karvia\\n Nyt luo otsikko artikkelin perusteella. \\n Otsikko: \\n',\n",
692
+ " 'url': 'http://www.turuntienoo.fi/index.php/2904-',\n",
693
+ " 'date_publish': '2016-09-15 00:00:00',\n",
694
+ " 'messages': [{'content': 'Saat seuraavana artikkelin tekstin. Tehtävänäsi on tuottaa otsikko artikkelin perusteella.\\n Artikkeli:\\nPosti kertoi eilen tiedotteessaan, että sen työntekijät siirtyvät tällä viikolla nurmikonleikkuusta lehtien haravoimiseen. Postin verkkokaupasta ja Postin omista myymälöistä voi tilata pihan haravoinnin haluamilleen kahdelle tiistaipäivälle. Syyslehdille voi myös tilata Postilta poisviennin.\\nSamaan aikaan, kun postin varsinaista palvelutehtävää — postilähetysten kuljettamista asiakkailleen — on heikennetty, valtionyhtiö tekee kaikkensa tunkeutuakseen yksityisten yritysten tonteille kiinteistönhoito- ja sosiaalitöihin.\\nSyksyn putoavien lehtien sijasta Postin tulisi keskittää tarmonsa siihen, miten se hoitaa tuoreiden lehtien eli paikallis- , sanoma- ja aikakauslehtien kuljetuksen. Tienoo tosin on hankkinut varhaiskantopalvelunsa yksityiseltä toimijalta, mutta Posti vie tai sen pitäisi viedä ydinaluetta kauempana oleville tilaajille lehdet määräaikoina. Käytännössä lehti menee asiakkaalle vasta seuraavana tai sitä seuraavana päivänä. Tuollainen toiminta on häpeällistä ja asiakasta väheksyvää.\\nNyt Posti haluaisi rukata postilakia niin, että postia kannettaisiin vain kolmena päivänä viikossa. Asiaan on ottanut jyrkän kielteisen kannan muun muassa Sanomalehtien Liitto. Se katsoo, että viisipäiväisyys on turvattava. Keinoina ovat eri jakelujen ja muiden palvelujen yhdistäminen. Taajamissa esimerkiksi päiväposti olisi mahdollista jakaa aamun sanomalehtijakelun yhteydessä. Haja-asutusalueella yhteistyötä Posti voisi tehdä yksityisten yrittäjien kanssa.\\nKun jo nyt monin paikoin Posti on omatoimisesti siirtynyt nelipäiväiseen jakeluun ja pantannut joskus loppuviikon postit seuraavalle viikolle, tärkeät kirjeet saattavat jäädä saamatta määräaikoina. Monipuolisen tiedonvälityksen vuoksi on lehdetkin toimitetava perille ilmestymispäivinä. Kun se on ollut ennekin mahdollista, miksei se nyt onnistuisi?\\nRauli Ala-Karvia\\n Nyt luo otsikko artikkelin perusteella. \\n Otsikko: \\n',\n",
695
+ " 'role': 'user'},\n",
696
+ " {'content': 'Postin tulisi kiinnostua myös tuoreista lehdistä',\n",
697
+ " 'role': 'assistant'}],\n",
698
+ " 'text': '<s> [INST] <<SYS>>\\nOlet tekoälyavustaja. Vastaat aina mahdollisimman avuliaasti. Vastauksesi eivät saa sisältää mitään haitallista, epäeettistä, rasistista, seksististä, vaarallista tai laitonta sisältöä. Jos kysymyksessä ei ole mitään järkeä tai se ei ole asiasisällöltään johdonmukainen, selitä miksi sen sijaan, että vastaisit jotain väärin. Jos et tiedä vastausta kysymykseen, älä kerro väärää tietoa.\\n<</SYS>>\\n\\nSaat seuraavana artikkelin tekstin. Tehtävänäsi on tuottaa otsikko artikkelin perusteella.\\n Artikkeli:\\nPosti kertoi eilen tiedotteessaan, että sen työntekijät siirtyvät tällä viikolla nurmikonleikkuusta lehtien haravoimiseen. Postin verkkokaupasta ja Postin omista myymälöistä voi tilata pihan haravoinnin haluamilleen kahdelle tiistaipäivälle. Syyslehdille voi myös tilata Postilta poisviennin.\\nSamaan aikaan, kun postin varsinaista palvelutehtävää — postilähetysten kuljettamista asiakkailleen — on heikennetty, valtionyhtiö tekee kaikkensa tunkeutuakseen yksityisten yritysten tonteille kiinteistönhoito- ja sosiaalitöihin.\\nSyksyn putoavien lehtien sijasta Postin tulisi keskittää tarmonsa siihen, miten se hoitaa tuoreiden lehtien eli paikallis- , sanoma- ja aikakauslehtien kuljetuksen. Tienoo tosin on hankkinut varhaiskantopalvelunsa yksityiseltä toimijalta, mutta Posti vie tai sen pitäisi viedä ydinaluetta kauempana oleville tilaajille lehdet määräaikoina. Käytännössä lehti menee asiakkaalle vasta seuraavana tai sitä seuraavana päivänä. Tuollainen toiminta on häpeällistä ja asiakasta väheksyvää.\\nNyt Posti haluaisi rukata postilakia niin, että postia kannettaisiin vain kolmena päivänä viikossa. Asiaan on ottanut jyrkän kielteisen kannan muun muassa Sanomalehtien Liitto. Se katsoo, että viisipäiväisyys on turvattava. Keinoina ovat eri jakelujen ja muiden palvelujen yhdistäminen. Taajamissa esimerkiksi päiväposti olisi mahdollista jakaa aamun sanomalehtijakelun yhteydessä. Haja-asutusalueella yhteistyötä Posti voisi tehdä yksityisten yrittäjien kanssa.\\nKun jo nyt monin paikoin Posti on omatoimisesti siirtynyt nelipäiväiseen jakeluun ja pantannut joskus loppuviikon postit seuraavalle viikolle, tärkeät kirjeet saattavat jäädä saamatta määräaikoina. Monipuolisen tiedonvälityksen vuoksi on lehdetkin toimitetava perille ilmestymispäivinä. Kun se on ollut ennekin mahdollista, miksei se nyt onnistuisi?\\nRauli Ala-Karvia\\n Nyt luo otsikko artikkelin perusteella. \\n Otsikko: [/INST] Postin tulisi kiinnostua myös tuoreista lehdistä</s>'}"
699
+ ]
700
+ },
701
+ "execution_count": 23,
702
+ "metadata": {},
703
+ "output_type": "execute_result"
704
+ }
705
+ ],
706
+ "source": [
707
+ "dataset_filtered[0]"
708
+ ]
709
+ },
710
+ {
711
+ "cell_type": "code",
712
+ "execution_count": 24,
713
+ "id": "a83682b6-ae4e-418c-939d-95bb9ec8b4e5",
714
+ "metadata": {},
715
+ "outputs": [],
716
+ "source": [
717
+ "from unsloth import UnslothTrainer, UnslothTrainingArguments\n",
718
+ "import math\n",
719
+ "\n",
720
+ "batch_size = 4\n",
721
+ "eval_batch_size = 4\n",
722
+ "gradient_accumulation_steps = 4\n",
723
+ "epochs = 1\n",
724
+ "train_steps = math.ceil(len(dataset_filtered) / batch_size / gradient_accumulation_steps * epochs)\n",
725
+ "eval_steps = math.floor(train_steps/epochs/5)\n",
726
+ "warmup_steps = math.ceil(train_steps * 0.1)\n",
727
+ "dataset_split = dataset_filtered.train_test_split(test_size=0.05)\n",
728
+ "output_dir = 'train_checkpoints'"
729
+ ]
730
+ },
731
+ {
732
+ "cell_type": "code",
733
+ "execution_count": 25,
734
+ "id": "a7447baf-4346-4916-a279-7e5a70bd2b06",
735
+ "metadata": {},
736
+ "outputs": [
737
+ {
738
+ "name": "stderr",
739
+ "output_type": "stream",
740
+ "text": [
741
+ "/usr/local/lib/python3.10/dist-packages/transformers/training_args.py:1474: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
742
+ " warnings.warn(\n"
743
+ ]
744
+ },
745
+ {
746
+ "data": {
747
+ "application/vnd.jupyter.widget-view+json": {
748
+ "model_id": "3617098845674e3392bedfc21b9fab4a",
749
+ "version_major": 2,
750
+ "version_minor": 0
751
+ },
752
+ "text/plain": [
753
+ "Map (num_proc=2): 0%| | 0/2850 [00:00<?, ? examples/s]"
754
+ ]
755
+ },
756
+ "metadata": {},
757
+ "output_type": "display_data"
758
+ },
759
+ {
760
+ "data": {
761
+ "application/vnd.jupyter.widget-view+json": {
762
+ "model_id": "2faae87a88924eacbd93437b69042949",
763
+ "version_major": 2,
764
+ "version_minor": 0
765
+ },
766
+ "text/plain": [
767
+ "Map (num_proc=2): 0%| | 0/150 [00:00<?, ? examples/s]"
768
+ ]
769
+ },
770
+ "metadata": {},
771
+ "output_type": "display_data"
772
+ },
773
+ {
774
+ "name": "stderr",
775
+ "output_type": "stream",
776
+ "text": [
777
+ "max_steps is given, it will override any value given in num_train_epochs\n"
778
+ ]
779
+ }
780
+ ],
781
+ "source": [
782
+ "trainer = UnslothTrainer(\n",
783
+ " model = model,\n",
784
+ " tokenizer = tokenizer,\n",
785
+ " train_dataset = dataset_split[\"train\"],\n",
786
+ " eval_dataset = dataset_split[\"test\"],\n",
787
+ " dataset_text_field = 'text',\n",
788
+ " max_seq_length = max_seq_length,\n",
789
+ " dataset_num_proc = 2,\n",
790
+ " packing=False,\n",
791
+ " args = UnslothTrainingArguments(\n",
792
+ " per_device_train_batch_size = batch_size,\n",
793
+ " per_device_eval_batch_size = eval_batch_size,\n",
794
+ " gradient_accumulation_steps = gradient_accumulation_steps,\n",
795
+ " warmup_steps = warmup_steps,\n",
796
+ " max_steps = train_steps,\n",
797
+ " eval_steps = eval_steps,\n",
798
+ " save_steps = eval_steps,\n",
799
+ " evaluation_strategy = 'steps',\n",
800
+ " save_strategy = 'steps',\n",
801
+ " learning_rate = 0.00002,\n",
802
+ " fp16=False,\n",
803
+ " bf16=True,\n",
804
+ " logging_steps=5,\n",
805
+ " optim = \"paged_adamw_8bit\",\n",
806
+ " weight_decay = 0.005,\n",
807
+ " lr_scheduler_type = 'cosine',\n",
808
+ " seed=3407,\n",
809
+ " output_dir = output_dir\n",
810
+ " ),\n",
811
+ ")\n",
812
+ " "
813
+ ]
814
+ },
815
+ {
816
+ "cell_type": "code",
817
+ "execution_count": 26,
818
+ "id": "1dc9f3e2-99b8-441f-b8da-0624e9d56db4",
819
+ "metadata": {},
820
+ "outputs": [
821
+ {
822
+ "name": "stderr",
823
+ "output_type": "stream",
824
+ "text": [
825
+ "==((====))== Unsloth - 2x faster free finetuning | Num GPUs = 1\n",
826
+ " \\\\ /| Num examples = 2,850 | Num Epochs = 2\n",
827
+ "O^O/ \\_/ \\ Batch size per device = 4 | Gradient Accumulation steps = 4\n",
828
+ "\\ / Total batch size = 16 | Total steps = 188\n",
829
+ " \"-____-\" Number of trainable parameters = 462,096,640\n"
830
+ ]
831
+ },
832
+ {
833
+ "data": {
834
+ "text/html": [
835
+ "\n",
836
+ " <div>\n",
837
+ " \n",
838
+ " <progress value='188' max='188' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
839
+ " [188/188 12:30, Epoch 1/2]\n",
840
+ " </div>\n",
841
+ " <table border=\"1\" class=\"dataframe\">\n",
842
+ " <thead>\n",
843
+ " <tr style=\"text-align: left;\">\n",
844
+ " <th>Step</th>\n",
845
+ " <th>Training Loss</th>\n",
846
+ " <th>Validation Loss</th>\n",
847
+ " </tr>\n",
848
+ " </thead>\n",
849
+ " <tbody>\n",
850
+ " <tr>\n",
851
+ " <td>37</td>\n",
852
+ " <td>2.150700</td>\n",
853
+ " <td>2.066879</td>\n",
854
+ " </tr>\n",
855
+ " <tr>\n",
856
+ " <td>74</td>\n",
857
+ " <td>1.930400</td>\n",
858
+ " <td>1.996720</td>\n",
859
+ " </tr>\n",
860
+ " <tr>\n",
861
+ " <td>111</td>\n",
862
+ " <td>1.834200</td>\n",
863
+ " <td>1.981537</td>\n",
864
+ " </tr>\n",
865
+ " <tr>\n",
866
+ " <td>148</td>\n",
867
+ " <td>2.056900</td>\n",
868
+ " <td>1.977489</td>\n",
869
+ " </tr>\n",
870
+ " <tr>\n",
871
+ " <td>185</td>\n",
872
+ " <td>1.643500</td>\n",
873
+ " <td>1.976829</td>\n",
874
+ " </tr>\n",
875
+ " </tbody>\n",
876
+ "</table><p>"
877
+ ],
878
+ "text/plain": [
879
+ "<IPython.core.display.HTML object>"
880
+ ]
881
+ },
882
+ "metadata": {},
883
+ "output_type": "display_data"
884
+ },
885
+ {
886
+ "name": "stderr",
887
+ "output_type": "stream",
888
+ "text": [
889
+ "/usr/local/lib/python3.10/dist-packages/peft/utils/save_and_load.py:209: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n",
890
+ " warnings.warn(\n",
891
+ "/usr/local/lib/python3.10/dist-packages/peft/utils/save_and_load.py:209: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n",
892
+ " warnings.warn(\n",
893
+ "/usr/local/lib/python3.10/dist-packages/peft/utils/save_and_load.py:209: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n",
894
+ " warnings.warn(\n",
895
+ "/usr/local/lib/python3.10/dist-packages/peft/utils/save_and_load.py:209: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n",
896
+ " warnings.warn(\n",
897
+ "/usr/local/lib/python3.10/dist-packages/peft/utils/save_and_load.py:209: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n",
898
+ " warnings.warn(\n"
899
+ ]
900
+ },
901
+ {
902
+ "data": {
903
+ "text/plain": [
904
+ "TrainOutput(global_step=188, training_loss=1.9880720529150455, metrics={'train_runtime': 754.5522, 'train_samples_per_second': 3.986, 'train_steps_per_second': 0.249, 'total_flos': 5.843729764618752e+16, 'train_loss': 1.9880720529150455, 'epoch': 1.0546984572230014})"
905
+ ]
906
+ },
907
+ "execution_count": 26,
908
+ "metadata": {},
909
+ "output_type": "execute_result"
910
+ }
911
+ ],
912
+ "source": [
913
+ "trainer.train()"
914
+ ]
915
+ },
916
+ {
917
+ "cell_type": "code",
918
+ "execution_count": null,
919
+ "id": "a06d56e0-f460-4f49-a737-8d74bdb49ca6",
920
+ "metadata": {},
921
+ "outputs": [],
922
+ "source": [
923
+ "from transfomers import GenerationConfig\n",
924
+ "\n",
925
+ "generation_config = GenerationConfig(\n",
926
+ " pad_token_id = tokenizer.pad_token_id,\n",
927
+ " eos_token_id = tokenizer.convert_tokens_to_ids(\"</s>\")\n",
928
+ ")"
929
+ ]
930
+ },
931
+ {
932
+ "cell_type": "code",
933
+ "execution_count": 37,
934
+ "id": "df1a3b68-2212-4b95-a2dd-e973084e474a",
935
+ "metadata": {},
936
+ "outputs": [
937
+ {
938
+ "data": {
939
+ "text/plain": [
940
+ "150"
941
+ ]
942
+ },
943
+ "execution_count": 37,
944
+ "metadata": {},
945
+ "output_type": "execute_result"
946
+ }
947
+ ],
948
+ "source": [
949
+ "len(dataset_split[\"test\"])"
950
+ ]
951
+ },
952
+ {
953
+ "cell_type": "code",
954
+ "execution_count": 52,
955
+ "id": "0e53b5a9-b5e4-4453-a275-e13f4fc67223",
956
+ "metadata": {},
957
+ "outputs": [
958
+ {
959
+ "name": "stdout",
960
+ "output_type": "stream",
961
+ "text": [
962
+ "--------\n",
963
+ "<s><s> [INST] <<SYS>>\n",
964
+ "Olet tekoälyavustaja. Vastaat aina mahdollisimman avuliaasti. Vastauksesi eivät saa sisältää mitään haitallista, epäeettistä, rasistista, seksististä, vaarallista tai laitonta sisältöä. Jos kysymyksessä ei ole mitään järkeä tai se ei ole asiasisällöltään johdonmukainen, selitä miksi sen sijaan, että vastaisit jotain väärin. Jos et tiedä vastausta kysymykseen, älä kerro väärää tietoa.\n",
965
+ "<</SYS>>\n",
966
+ "\n",
967
+ "Saat seuraavana artikkelin tekstin. Tehtävänäsi on tuottaa otsikko artikkelin perusteella.\n",
968
+ " Artikkeli:\n",
969
+ "Tiedekeskus Heurekan Pokémon Go -päivä keräsi lauantaina paikalle järjestäjien laskujen mukaan noin 2 000 pelaajaa. Päivään kuului paitsi pelaamista, myös tutkijoiden luentoja pelaamisesta ja lisätystä todellisuudesta.\n",
970
+ "”Lisätty todellisuus on kiinnostava ilmiö, ja nyt kun tällainen hyvin sosiaalinen peli on noussut pinnalle, halusimme tuoda aiheeseen tieteellisen näkökulman”, kertoo Heurekan tapahtumatuottaja Siina Vasama.\n",
971
+ "”Paikalla oli teiniporukoita, paljon perheitä ja myös isovanhempia lapsenlapsineen”, Vasama kertoo.\n",
972
+ "Pelaaminen tapahtui ulkona Heurekan edustalla. Tilaisuudessa jaettiin myös pelivinkkejä.\n",
973
+ "Auditoriossa puolestaan yli 300 ihmistä oli kuuntelemassa pelitutkimukseen erikoistuneen professorin Frans Mäyrän ja Pokémonia tutkineen väitöskirjatutkijan Johannes Kosken puheenvuoroja, joissa he pohtivat muun muassa pelaamisen muuttumista ja lisätyn todellisuuden tulevaisuutta.\n",
974
+ "Puheenvuoroissa pohdittiin myös pelaamisen hyväksyttävyyttä. Pelien on muun muassa pelätty vähentävän ihmisten välistä vuorovaikutusta.\n",
975
+ "Tapahtuman järjestäjien kokemus on päinvastainen.\n",
976
+ "”Ihmiset viettivät hyvin liikunnallista ja sosiaalista päivää ulkona ja juttelivat paljon keskenään. Tämän päivän kokemuksen perusteella kallistun sille kannalle, että peli lisäsi ulkoilua ja kanssakäymistä”, Vasama sanoo.\n",
977
+ " Nyt luo otsikko artikkelin perusteella. \n",
978
+ " Otsikko: [/INST] Heurekassa pelattiin Pokémon Go -peliä</s>\n"
979
+ ]
980
+ }
981
+ ],
982
+ "source": [
983
+ "import random\n",
984
+ "i = random.randint(0, len(dataset_split[\"test\"]))\n",
985
+ "\n",
986
+ "model.eval()\n",
987
+ "\n",
988
+ "inputs = tokenizer([tokenizer.apply_chat_template([{'role': 'user', 'content': dataset_split[\"test\"][i][\"instruction\"]}], tokenize=False)] * 1, return_tensors='pt').to(\"cuda\")\n",
989
+ "\n",
990
+ "with torch.no_grad():\n",
991
+ " generation_ids = model.generate(\n",
992
+ " input_ids = inputs[\"input_ids\"],\n",
993
+ " attention_mask = inputs[\"attention_mask\"],\n",
994
+ " generation_config = generation_config, **{\n",
995
+ " \"temperature\": 0.7,\n",
996
+ " \"penalty_alpha\": 0.6,\n",
997
+ " \"min_p\": 0.5,\n",
998
+ " \"do_sample\": True,\n",
999
+ " \"repetition_penalty\": 1.28,\n",
1000
+ " \"min_length\": 6,\n",
1001
+ " \"max_new_tokens\": 50})\n",
1002
+ "\n",
1003
+ "generated_text = tokenizer.batch_decode(generation_ids, skip_special_tokens = False, cleanup_tokenization_spaces = True)[0]\n",
1004
+ "print('--------')\n",
1005
+ "print(generated_text)\n",
1006
+ " "
1007
+ ]
1008
+ },
1009
+ {
1010
+ "cell_type": "code",
1011
+ "execution_count": null,
1012
+ "id": "aba8b4ce-2914-45fb-ad3e-a23f75b57f28",
1013
+ "metadata": {},
1014
+ "outputs": [],
1015
+ "source": []
1016
+ }
1017
+ ],
1018
+ "metadata": {
1019
+ "kernelspec": {
1020
+ "display_name": "Python 3 (ipykernel)",
1021
+ "language": "python",
1022
+ "name": "python3"
1023
+ },
1024
+ "language_info": {
1025
+ "codemirror_mode": {
1026
+ "name": "ipython",
1027
+ "version": 3
1028
+ },
1029
+ "file_extension": ".py",
1030
+ "mimetype": "text/x-python",
1031
+ "name": "python",
1032
+ "nbconvert_exporter": "python",
1033
+ "pygments_lexer": "ipython3",
1034
+ "version": "3.10.12"
1035
+ }
1036
+ },
1037
+ "nbformat": 4,
1038
+ "nbformat_minor": 5
1039
+ }
setup_steps.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ Setup Runpod account and add few dollars.
2
+ Setup RTX 4090 with
3
+ Disk 50GB/50GB
4
+ Select image:
5
+ runpod/pytorch:2.2.0-py3.10-cuda12.1.1-devel-ubuntu22.04
6
+
7
+ pip install --upgrade --force-reinstall --no-cache-dir torch==2.1.1 triton --index-url https://download.pytorch.org/whl/cu121
8
+ pip install "unsloth[cu121-ampere] @ git+https://github.com/unslothai/unsloth.git"
9
+
10
+ --> run notebook