sappho192 committed
Commit 893e943
1 Parent(s): fd5fd91

Update training code

Files changed (1)
  1. training.ipynb +112 -251
training.ipynb CHANGED
@@ -19,7 +19,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 1,
+    "execution_count": null,
     "metadata": {
      "id": "t-jXeSJKE1WM"
     },
@@ -32,15 +32,19 @@
      "import torch\n",
      "from transformers import (\n",
      "    PreTrainedTokenizerFast,\n",
+     "    AutoTokenizer,\n",
      "    DataCollatorForSeq2Seq,\n",
      "    Seq2SeqTrainingArguments,\n",
-     "    BertJapaneseTokenizer,\n",
      "    Trainer\n",
      ")\n",
      "from transformers.models.encoder_decoder.modeling_encoder_decoder import EncoderDecoderModel\n",
      "\n",
      "from datasets import load_dataset\n",
      "\n",
+     "import os\n",
+     "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
+     "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
+     "\n",
      "# encoder_model_name = \"xlm-roberta-base\"\n",
      "encoder_model_name = \"cl-tohoku/bert-base-japanese-v2\"\n",
      "decoder_model_name = \"skt/kogpt2-base-v2\""
@@ -48,31 +52,21 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 2,
+    "execution_count": null,
     "metadata": {
      "id": "nEW5trBtbykK"
     },
-    "outputs": [
-     {
-      "data": {
-       "text/plain": [
-        "(device(type='cpu'), 0)"
-       ]
-      },
-      "execution_count": 2,
-      "metadata": {},
-      "output_type": "execute_result"
-     }
-    ],
+    "outputs": [],
     "source": [
-     "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
-     "# device = torch.device(\"cpu\")\n",
-     "device, torch.cuda.device_count()"
+     "# device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
+     "# # device = torch.device(\"cpu\")\n",
+     "# torch.cuda.set_device(device)\n",
+     "# device, torch.cuda.device_count()"
     ]
    },
    {
     "cell_type": "code",
-    "execution_count": 3,
+    "execution_count": null,
     "metadata": {
      "id": "5ic7pUUBFU_v"
     },
@@ -82,9 +76,9 @@
      "    def build_inputs_with_special_tokens(self, token_ids: List[int]) -> List[int]:\n",
      "        return token_ids + [self.eos_token_id]\n",
      "\n",
-     "src_tokenizer = BertJapaneseTokenizer.from_pretrained(encoder_model_name)\n",
-     "trg_tokenizer = GPT2Tokenizer.from_pretrained(decoder_model_name, bos_token='</s>', eos_token='</s>', unk_token='<unk>',\n",
-     "                                              pad_token='<pad>', mask_token='<mask>')"
+     "src_tokenizer = AutoTokenizer.from_pretrained(encoder_model_name, use_fast=False)\n",
+     "trg_tokenizer = GPT2Tokenizer.from_pretrained(decoder_model_name, use_fast=False,\n",
+     "                                              bos_token='</s>', eos_token='</s>', unk_token='<unk>', pad_token='<pad>', mask_token='<mask>')"
     ]
    },
    {
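
Note on this hunk: AutoTokenizer resolves cl-tohoku/bert-base-japanese-v2 to the same slow BertJapaneseTokenizer class, so behavior is unchanged while the explicit class import goes away; use_fast=False just makes the slow path explicit. A minimal sketch of the resolved encoder tokenizer (assumes transformers plus the fugashi/unidic-lite MeCab bindings are installed):

```python
# Minimal sketch, assuming: pip install transformers fugashi unidic-lite
from transformers import AutoTokenizer

src_tokenizer = AutoTokenizer.from_pretrained(
    "cl-tohoku/bert-base-japanese-v2", use_fast=False)

enc = src_tokenizer("猫を飼っています。")  # arbitrary Japanese example
# MeCab-based wordpiece tokens wrapped in [CLS] ... [SEP]
print(src_tokenizer.convert_ids_to_tokens(enc["input_ids"]))
```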
@@ -98,25 +92,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 4,
-    "metadata": {
-     "collapsed": false
-    },
-    "outputs": [],
-    "source": [
-     "dataset = load_dataset(\"sappho192/Tatoeba-Challenge-jpn-kor\")\n",
-     "# dataset = load_dataset(\"D:\\\\REPO\\\\Tatoeba-Challenge-jpn-kor\")\n",
-     "\n",
-     "train_dataset = dataset['train']\n",
-     "test_dataset = dataset['test']\n",
-     "\n",
-     "train_first_row = train_dataset[0]\n",
-     "test_first_row = test_dataset[0]"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": 5,
+    "execution_count": null,
     "metadata": {
      "id": "65L4O1c5FLKt"
     },
@@ -124,7 +100,7 @@
     "source": [
      "class PairedDataset:\n",
      "    def __init__(self, \n",
-     "                 source_tokenizer: PreTrainedTokenizerFast, target_tokenizer: PreTrainedTokenizerFast,\n",
+     "                 source_tokenizer: AutoTokenizer, target_tokenizer: GPT2Tokenizer,\n",
      "                 file_path: str = None,\n",
      "                 dataset_raw: datasets.Dataset = None\n",
      "                 ):\n",
@@ -132,7 +108,7 @@
      "        self.trg_tokenizer = target_tokenizer\n",
      "        \n",
      "        if file_path is not None:\n",
-     "            with open(file_path, 'r') as fd:\n",
+     "            with open(file_path, 'r', encoding=\"utf-8\") as fd:\n",
      "                reader = csv.reader(fd)\n",
      "                next(reader)\n",
      "                self.data = [row for row in reader]\n",
@@ -159,52 +135,66 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 6,
+    "execution_count": null,
     "metadata": {
      "collapsed": false
     },
     "outputs": [],
     "source": [
-     "DATA_ROOT = './output'\n",
-     "FILE_FFAC_FULL = 'ffac_full.csv'\n",
-     "FILE_FFAC_TEST = 'ffac_test.csv'\n",
-     "FILE_JA_KO_TRAIN = 'ja_ko_train.csv'\n",
-     "FILE_JA_KO_TEST = 'ja_ko_test.csv'\n",
+     "# DATASET_TARGET = \"TATOEBA_2023\"\n",
+     "# DATASET_TARGET = \"FFAC\"\n",
+     "DATASET_TARGET = \"AIHUB\"\n",
+     "\n",
+     "if (DATASET_TARGET == \"TATOEBA_2023\"):\n",
+     "    # dataset = load_dataset(\"sappho192/Tatoeba-Challenge-jpn-kor\")\n",
+     "    dataset = load_dataset(\"/home/akalive/dataset/Tatoeba-Challenge-jpn-kor\")\n",
+     "\n",
+     "    train_dataset = dataset['train']\n",
+     "    test_dataset = dataset['test']\n",
+     "\n",
+     "    train_dataset = PairedDataset(src_tokenizer, trg_tokenizer, dataset_raw=train_dataset)\n",
+     "    eval_dataset = PairedDataset(src_tokenizer, trg_tokenizer, dataset_raw=test_dataset)\n",
+     "elif (DATASET_TARGET == \"FFAC\"):\n",
+     "    DATA_ROOT = '/home/akalive/dataset/ffac/output'\n",
+     "    FILE_FFAC_FULL = 'ffac_full.csv'\n",
+     "    FILE_FFAC_TEST = 'ffac_test.csv'\n",
+     "    FILE_JA_KO_TRAIN = 'tteb_train.csv'\n",
+     "    FILE_JA_KO_TEST = 'tteb_test.csv'\n",
+     "\n",
+     "    # train_dataset = PairedDataset(src_tokenizer, trg_tokenizer, file_path=f'{DATA_ROOT}/{FILE_FFAC_FULL}')\n",
+     "    # eval_dataset = PairedDataset(src_tokenizer, trg_tokenizer, file_path=f'{DATA_ROOT}/{FILE_FFAC_TEST}')\n",
+     "\n",
+     "    train_dataset = PairedDataset(src_tokenizer, trg_tokenizer, file_path=f'{DATA_ROOT}/{FILE_JA_KO_TRAIN}')\n",
+     "    eval_dataset = PairedDataset(src_tokenizer, trg_tokenizer, file_path=f'{DATA_ROOT}/{FILE_JA_KO_TEST}')\n",
+     "elif (DATASET_TARGET == \"AIHUB\"):\n",
+     "    # AIHUB dataset spent 25~33GB of VRAM with batch_size=30 while training.\n",
+     "    DATA_ROOT = '/home/akalive/dataset/jkpair/data'\n",
+     "    FILE_TRAIN = 'train.csv'\n",
+     "    FILE_VAL = 'validation.csv'\n",
+     "\n",
+     "    train_dataset = PairedDataset(src_tokenizer, trg_tokenizer, file_path=f'{DATA_ROOT}/{FILE_TRAIN}')\n",
+     "    eval_dataset = PairedDataset(src_tokenizer, trg_tokenizer, file_path=f'{DATA_ROOT}/{FILE_VAL}')\n",
      "\n",
-     "# train_dataset = PairedDataset(src_tokenizer, trg_tokenizer, file_path=f'{DATA_ROOT}/{FILE_FFAC_FULL}')\n",
-     "# eval_dataset = PairedDataset(src_tokenizer, trg_tokenizer, file_path=f'{DATA_ROOT}/{FILE_FFAC_TEST}')\n",
+     "train_first_row = train_dataset[0]\n",
+     "eval_first_row = eval_dataset[0]\n",
      "\n",
-     "# train_dataset = PairedDataset(src_tokenizer, trg_tokenizer, file_path=f'{DATA_ROOT}/{FILE_JA_KO_TRAIN}')\n",
-     "# eval_dataset = PairedDataset(src_tokenizer, trg_tokenizer, file_path=f'{DATA_ROOT}/{FILE_JA_KO_TEST}')"
+     "print(train_first_row)\n",
+     "print(eval_first_row)"
     ]
    },
    {
     "cell_type": "code",
-    "execution_count": 7,
-    "metadata": {
-     "collapsed": false
-    },
-    "outputs": [
-     {
-      "data": {
-       "text/plain": [
-        "{'input_ids': [2, 33, 2181, 1402, 893, 15200, 893, 13507, 881, 933, 882, 829, 3], 'labels': [9085, 10936, 10993, 23363, 9134, 18368, 8006, 389, 1]}"
-       ]
-      },
-      "execution_count": 7,
-      "metadata": {},
-      "output_type": "execute_result"
-     }
-    ],
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
     "source": [
-     "train_dataset = PairedDataset(src_tokenizer, trg_tokenizer, dataset_raw=train_dataset)\n",
-     "eval_dataset = PairedDataset(src_tokenizer, trg_tokenizer, dataset_raw=test_dataset)\n",
-     "eval_dataset[0]"
+     "print(train_dataset)\n",
+     "train_dataset[0]"
     ]
    },
    {
     "cell_type": "code",
-    "execution_count": 8,
+    "execution_count": null,
     "metadata": {},
     "outputs": [],
     "source": [
@@ -226,20 +216,11 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 9,
+    "execution_count": null,
     "metadata": {
      "id": "I7uFbFYJFje8"
     },
-    "outputs": [
-     {
-      "name": "stderr",
-      "output_type": "stream",
-      "text": [
-       "Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at skt/kogpt2-base-v2 and are newly initialized: ['transformer.h.0.crossattention.c_attn.bias', 'transformer.h.0.crossattention.c_attn.weight', 'transformer.h.0.crossattention.c_proj.bias', 'transformer.h.0.crossattention.c_proj.weight', 'transformer.h.0.crossattention.q_attn.bias', 'transformer.h.0.crossattention.q_attn.weight', 'transformer.h.0.ln_cross_attn.bias', 'transformer.h.0.ln_cross_attn.weight', 'transformer.h.1.crossattention.c_attn.bias', 'transformer.h.1.crossattention.c_attn.weight', 'transformer.h.1.crossattention.c_proj.bias', 'transformer.h.1.crossattention.c_proj.weight', 'transformer.h.1.crossattention.q_attn.bias', 'transformer.h.1.crossattention.q_attn.weight', 'transformer.h.1.ln_cross_attn.bias', 'transformer.h.1.ln_cross_attn.weight', 'transformer.h.10.crossattention.c_attn.bias', 'transformer.h.10.crossattention.c_attn.weight', 'transformer.h.10.crossattention.c_proj.bias', 'transformer.h.10.crossattention.c_proj.weight', 'transformer.h.10.crossattention.q_attn.bias', 'transformer.h.10.crossattention.q_attn.weight', 'transformer.h.10.ln_cross_attn.bias', 'transformer.h.10.ln_cross_attn.weight', 'transformer.h.11.crossattention.c_attn.bias', 'transformer.h.11.crossattention.c_attn.weight', 'transformer.h.11.crossattention.c_proj.bias', 'transformer.h.11.crossattention.c_proj.weight', 'transformer.h.11.crossattention.q_attn.bias', 'transformer.h.11.crossattention.q_attn.weight', 'transformer.h.11.ln_cross_attn.bias', 'transformer.h.11.ln_cross_attn.weight', 'transformer.h.2.crossattention.c_attn.bias', 'transformer.h.2.crossattention.c_attn.weight', 'transformer.h.2.crossattention.c_proj.bias', 'transformer.h.2.crossattention.c_proj.weight', 'transformer.h.2.crossattention.q_attn.bias', 'transformer.h.2.crossattention.q_attn.weight', 'transformer.h.2.ln_cross_attn.bias', 'transformer.h.2.ln_cross_attn.weight', 'transformer.h.3.crossattention.c_attn.bias', 'transformer.h.3.crossattention.c_attn.weight', 'transformer.h.3.crossattention.c_proj.bias', 'transformer.h.3.crossattention.c_proj.weight', 'transformer.h.3.crossattention.q_attn.bias', 'transformer.h.3.crossattention.q_attn.weight', 'transformer.h.3.ln_cross_attn.bias', 'transformer.h.3.ln_cross_attn.weight', 'transformer.h.4.crossattention.c_attn.bias', 'transformer.h.4.crossattention.c_attn.weight', 'transformer.h.4.crossattention.c_proj.bias', 'transformer.h.4.crossattention.c_proj.weight', 'transformer.h.4.crossattention.q_attn.bias', 'transformer.h.4.crossattention.q_attn.weight', 'transformer.h.4.ln_cross_attn.bias', 'transformer.h.4.ln_cross_attn.weight', 'transformer.h.5.crossattention.c_attn.bias', 'transformer.h.5.crossattention.c_attn.weight', 'transformer.h.5.crossattention.c_proj.bias', 'transformer.h.5.crossattention.c_proj.weight', 'transformer.h.5.crossattention.q_attn.bias', 'transformer.h.5.crossattention.q_attn.weight', 'transformer.h.5.ln_cross_attn.bias', 'transformer.h.5.ln_cross_attn.weight', 'transformer.h.6.crossattention.c_attn.bias', 'transformer.h.6.crossattention.c_attn.weight', 'transformer.h.6.crossattention.c_proj.bias', 'transformer.h.6.crossattention.c_proj.weight', 'transformer.h.6.crossattention.q_attn.bias', 'transformer.h.6.crossattention.q_attn.weight', 'transformer.h.6.ln_cross_attn.bias', 'transformer.h.6.ln_cross_attn.weight', 'transformer.h.7.crossattention.c_attn.bias', 'transformer.h.7.crossattention.c_attn.weight', 'transformer.h.7.crossattention.c_proj.bias', 'transformer.h.7.crossattention.c_proj.weight', 'transformer.h.7.crossattention.q_attn.bias', 'transformer.h.7.crossattention.q_attn.weight', 'transformer.h.7.ln_cross_attn.bias', 'transformer.h.7.ln_cross_attn.weight', 'transformer.h.8.crossattention.c_attn.bias', 'transformer.h.8.crossattention.c_attn.weight', 'transformer.h.8.crossattention.c_proj.bias', 'transformer.h.8.crossattention.c_proj.weight', 'transformer.h.8.crossattention.q_attn.bias', 'transformer.h.8.crossattention.q_attn.weight', 'transformer.h.8.ln_cross_attn.bias', 'transformer.h.8.ln_cross_attn.weight', 'transformer.h.9.crossattention.c_attn.bias', 'transformer.h.9.crossattention.c_attn.weight', 'transformer.h.9.crossattention.c_proj.bias', 'transformer.h.9.crossattention.c_proj.weight', 'transformer.h.9.crossattention.q_attn.bias', 'transformer.h.9.crossattention.q_attn.weight', 'transformer.h.9.ln_cross_attn.bias', 'transformer.h.9.ln_cross_attn.weight']\n",
-       "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
      "model = EncoderDecoderModel.from_encoder_decoder_pretrained(\n",
      "    encoder_model_name,\n",
@@ -251,174 +232,69 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 11,
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "class CustomTrainingArguments(Seq2SeqTrainingArguments):\n",
+     "    def __init__(self, *args, **kwargs):\n",
+     "        super(CustomTrainingArguments, self).__init__(*args, **kwargs)\n",
+     "\n",
+     "    @property\n",
+     "    def device(self) -> \"torch.device\":\n",
+     "        \"\"\"\n",
+     "        The device used by this process.\n",
+     "        Set this to the device index you use.\n",
+     "        \"\"\"\n",
+     "        return torch.device(\"cuda:0\")\n",
+     "\n",
+     "    @property\n",
+     "    def n_gpu(self):\n",
+     "        \"\"\"\n",
+     "        The number of GPUs used by this process.\n",
+     "        Note:\n",
+     "            This will only be greater than one when you have multiple GPUs available but are not using distributed\n",
+     "            training. For distributed training, it will always be 1.\n",
+     "        \"\"\"\n",
+     "        # Make sure `self._n_gpu` is properly set up.\n",
+     "        # _ = self._setup_devices\n",
+     "        # Set to one manually here.\n",
+     "        self._n_gpu = 1\n",
+     "        return self._n_gpu\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
     "metadata": {
      "id": "YFq2GyOAUV0W"
     },
-    "outputs": [
-     {
-      "data": {
-       "text/html": [
-        "Finishing last run (ID:1vwqqxps) before initializing another..."
-       ],
-       "text/plain": [
-        "<IPython.core.display.HTML object>"
-       ]
-      },
-      "metadata": {},
-      "output_type": "display_data"
-     },
-     {
-      "data": {
-       "application/vnd.jupyter.widget-view+json": {
-        "model_id": "a82aa19a250b43f28d7ecc72eeebc88d",
-        "version_major": 2,
-        "version_minor": 0
-       },
-       "text/plain": [
-        "VBox(children=(Label(value='0.001 MB of 0.010 MB uploaded\\r'), FloatProgress(value=0.10972568578553615, max=1.…"
-       ]
-      },
-      "metadata": {},
-      "output_type": "display_data"
-     },
-     {
-      "data": {
-       "text/html": [
-        " View run <strong style=\"color:#cdcd00\">jbert+kogpt2</strong> at: <a href='https://wandb.ai/sappho192/fftr-poc1/runs/1vwqqxps' target=\"_blank\">https://wandb.ai/sappho192/fftr-poc1/runs/1vwqqxps</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
-       ],
-       "text/plain": [
-        "<IPython.core.display.HTML object>"
-       ]
-      },
-      "metadata": {},
-      "output_type": "display_data"
-     },
-     {
-      "data": {
-       "text/html": [
-        "Find logs at: <code>.\\wandb\\run-20240131_135356-1vwqqxps\\logs</code>"
-       ],
-       "text/plain": [
-        "<IPython.core.display.HTML object>"
-       ]
-      },
-      "metadata": {},
-      "output_type": "display_data"
-     },
-     {
-      "data": {
-       "text/html": [
-        "Successfully finished last run (ID:1vwqqxps). Initializing new run:<br/>"
-       ],
-       "text/plain": [
-        "<IPython.core.display.HTML object>"
-       ]
-      },
-      "metadata": {},
-      "output_type": "display_data"
-     },
-     {
-      "data": {
-       "application/vnd.jupyter.widget-view+json": {
-        "model_id": "c2cd7f6fb5b1428b98b80a3cc82ec303",
-        "version_major": 2,
-        "version_minor": 0
-       },
-       "text/plain": [
-        "VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.011288888888884685, max=1.0…"
-       ]
-      },
-      "metadata": {},
-      "output_type": "display_data"
-     },
-     {
-      "data": {
-       "text/html": [
-        "Tracking run with wandb version 0.16.2"
-       ],
-       "text/plain": [
-        "<IPython.core.display.HTML object>"
-       ]
-      },
-      "metadata": {},
-      "output_type": "display_data"
-     },
-     {
-      "data": {
-       "text/html": [
-        "Run data is saved locally in <code>d:\\REPO\\ffxiv-ja-ko-translator\\wandb\\run-20240131_135421-etxsdxw2</code>"
-       ],
-       "text/plain": [
-        "<IPython.core.display.HTML object>"
-       ]
-      },
-      "metadata": {},
-      "output_type": "display_data"
-     },
-     {
-      "data": {
-       "text/html": [
-        "Syncing run <strong><a href='https://wandb.ai/sappho192/fftr-poc1/runs/etxsdxw2' target=\"_blank\">jbert+kogpt2</a></strong> to <a href='https://wandb.ai/sappho192/fftr-poc1' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
-       ],
-       "text/plain": [
-        "<IPython.core.display.HTML object>"
-       ]
-      },
-      "metadata": {},
-      "output_type": "display_data"
-     },
-     {
-      "data": {
-       "text/html": [
-        " View project at <a href='https://wandb.ai/sappho192/fftr-poc1' target=\"_blank\">https://wandb.ai/sappho192/fftr-poc1</a>"
-       ],
-       "text/plain": [
-        "<IPython.core.display.HTML object>"
-       ]
-      },
-      "metadata": {},
-      "output_type": "display_data"
-     },
-     {
-      "data": {
-       "text/html": [
-        " View run at <a href='https://wandb.ai/sappho192/fftr-poc1/runs/etxsdxw2' target=\"_blank\">https://wandb.ai/sappho192/fftr-poc1/runs/etxsdxw2</a>"
-       ],
-       "text/plain": [
-        "<IPython.core.display.HTML object>"
-       ]
-      },
-      "metadata": {},
-      "output_type": "display_data"
-     }
-    ],
+    "outputs": [],
     "source": [
      "# for Trainer\n",
      "import wandb\n",
      "\n",
      "collate_fn = DataCollatorForSeq2Seq(src_tokenizer, model)\n",
-     "wandb.init(project=\"fftr-poc1\", name='jbert+kogpt2')\n",
+     "wandb.init(project=\"aihub-gt-2023\", name='jbert+kogpt2')\n",
      "\n",
      "arguments = Seq2SeqTrainingArguments(\n",
+     "# arguments = CustomTrainingArguments(\n",
      "    output_dir='dump',\n",
      "    do_train=True,\n",
      "    do_eval=True,\n",
      "    evaluation_strategy=\"epoch\",\n",
      "    save_strategy=\"epoch\",\n",
-     "    num_train_epochs=3,\n",
+     "    num_train_epochs=5, # for 40GB\n",
      "    # num_train_epochs=25,\n",
-     "    per_device_train_batch_size=1,\n",
-     "    # per_device_train_batch_size=30, # takes 40GB\n",
-     "    # per_device_train_batch_size=64,\n",
-     "    per_device_eval_batch_size=1,\n",
-     "    # per_device_eval_batch_size=30,\n",
-     "    # per_device_eval_batch_size=64,\n",
+     "    # per_device_train_batch_size=15,\n",
+     "    per_device_train_batch_size=30, # takes 40GB\n",
+     "    # per_device_eval_batch_size=10,\n",
+     "    per_device_eval_batch_size=10,\n",
      "    warmup_ratio=0.1,\n",
      "    gradient_accumulation_steps=4,\n",
      "    save_total_limit=5,\n",
      "    dataloader_num_workers=1,\n",
-     "    # fp16=True, # ENABLE if CUDA is enabled\n",
+     "    fp16=True, # ENABLE if CUDA is enabled\n",
      "    load_best_model_at_end=True,\n",
      "    report_to='wandb'\n",
      ")\n",
@@ -454,26 +330,11 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 12,
+    "execution_count": null,
     "metadata": {
      "id": "7vTqAgW6Ve3J"
     },
-    "outputs": [
-     {
-      "data": {
-       "application/vnd.jupyter.widget-view+json": {
-        "model_id": "0afe460e9f614d9a90379cf99fcf8af3",
-        "version_major": 2,
-        "version_minor": 0
-       },
-       "text/plain": [
-        "  0%|          | 0/9671328 [00:00<?, ?it/s]"
-       ]
-      },
-      "metadata": {},
-      "output_type": "display_data"
-     }
-    ],
+    "outputs": [],
     "source": [
      "trainer.train()\n",
      "\n",
@@ -484,12 +345,12 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 2,
+    "execution_count": null,
     "metadata": {},
     "outputs": [],
     "source": [
      "# import wandb\n",
-     "# wandb.finish()"
+     "wandb.finish()"
     ]
    }
  ],
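
A quick post-training smoke test could look like the sketch below; it assumes the notebook's src_tokenizer, trg_tokenizer, and fine-tuned model are still in scope, and the input sentence is an arbitrary example:

```python
# Sketch only: reuses objects defined in the notebook above.
text = "猫を飼っています。"  # arbitrary Japanese input
inputs = src_tokenizer(text, return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_length=64)
print(trg_tokenizer.decode(output_ids[0], skip_special_tokens=True))
```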