{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The primary codes below are based on [akpe12/JP-KR-ocr-translator-for-travel](https://github.com/akpe12/JP-KR-ocr-translator-for-travel)." ] }, { "cell_type": "markdown", "metadata": { "id": "TrHlPFqwFAgj" }, "source": [ "## Import" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "t-jXeSJKE1WM" }, "outputs": [], "source": [ "from typing import Dict, List\n", "import csv\n", "\n", "import datasets\n", "import torch\n", "from transformers import (\n", " PreTrainedTokenizerFast,\n", " DataCollatorForSeq2Seq,\n", " Seq2SeqTrainingArguments,\n", " BertJapaneseTokenizer,\n", " Trainer\n", ")\n", "from transformers.models.encoder_decoder.modeling_encoder_decoder import EncoderDecoderModel\n", "\n", "from datasets import load_dataset\n", "\n", "# encoder_model_name = \"xlm-roberta-base\"\n", "encoder_model_name = \"cl-tohoku/bert-base-japanese-v2\"\n", "decoder_model_name = \"skt/kogpt2-base-v2\"" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "nEW5trBtbykK" }, "outputs": [ { "data": { "text/plain": [ "(device(type='cpu'), 0)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "# device = torch.device(\"cpu\")\n", "device, torch.cuda.device_count()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "5ic7pUUBFU_v" }, "outputs": [], "source": [ "class GPT2Tokenizer(PreTrainedTokenizerFast):\n", " def build_inputs_with_special_tokens(self, token_ids: List[int]) -> List[int]:\n", " return token_ids + [self.eos_token_id] \n", "\n", "src_tokenizer = BertJapaneseTokenizer.from_pretrained(encoder_model_name)\n", "trg_tokenizer = GPT2Tokenizer.from_pretrained(decoder_model_name, bos_token='', eos_token='', unk_token='',\n", " pad_token='', mask_token='')" ] }, { "cell_type": "markdown", "metadata": { "id": "DTf4U1fmFQFh" }, "source": [ "## Data" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [], "source": [ "dataset = load_dataset(\"sappho192/Tatoeba-Challenge-jpn-kor\")\n", "# dataset = load_dataset(\"D:\\\\REPO\\\\Tatoeba-Challenge-jpn-kor\")\n", "\n", "train_dataset = dataset['train']\n", "test_dataset = dataset['test']\n", "\n", "train_first_row = train_dataset[0]\n", "test_first_row = test_dataset[0]" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "65L4O1c5FLKt" }, "outputs": [], "source": [ "class PairedDataset:\n", " def __init__(self, \n", " source_tokenizer: PreTrainedTokenizerFast, target_tokenizer: PreTrainedTokenizerFast,\n", " file_path: str = None,\n", " dataset_raw: datasets.Dataset = None\n", " ):\n", " self.src_tokenizer = source_tokenizer\n", " self.trg_tokenizer = target_tokenizer\n", " \n", " if file_path is not None:\n", " with open(file_path, 'r') as fd:\n", " reader = csv.reader(fd)\n", " next(reader)\n", " self.data = [row for row in reader]\n", " elif dataset_raw is not None:\n", " self.data = dataset_raw\n", " else:\n", " raise ValueError('file_path or dataset_raw must be specified')\n", "\n", " def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:\n", "# with open('train_log.txt', 'a+') as log_file:\n", "# log_file.write(f'reading data[{index}] {self.data[index]}\\n')\n", " if isinstance(self.data, datasets.Dataset):\n", " src, trg = self.data[index]['sourceString'], self.data[index]['targetString']\n", " else:\n", " src, trg = self.data[index]\n", " embeddings = self.src_tokenizer(src, return_attention_mask=False, return_token_type_ids=False)\n", " embeddings['labels'] = self.trg_tokenizer.build_inputs_with_special_tokens(self.trg_tokenizer(trg, return_attention_mask=False)['input_ids'])\n", "\n", " return embeddings\n", "\n", " def __len__(self):\n", " return len(self.data)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [], "source": [ "DATA_ROOT = './output'\n", "FILE_FFAC_FULL = 'ffac_full.csv'\n", "FILE_FFAC_TEST = 'ffac_test.csv'\n", "FILE_JA_KO_TRAIN = 'ja_ko_train.csv'\n", "FILE_JA_KO_TEST = 'ja_ko_test.csv'\n", "\n", "# train_dataset = PairedDataset(src_tokenizer, trg_tokenizer, file_path=f'{DATA_ROOT}/{FILE_FFAC_FULL}')\n", "# eval_dataset = PairedDataset(src_tokenizer, trg_tokenizer, file_path=f'{DATA_ROOT}/{FILE_FFAC_TEST}') \n", "\n", "# train_dataset = PairedDataset(src_tokenizer, trg_tokenizer, file_path=f'{DATA_ROOT}/{FILE_JA_KO_TRAIN}')\n", "# eval_dataset = PairedDataset(src_tokenizer, trg_tokenizer, file_path=f'{DATA_ROOT}/{FILE_JA_KO_TEST}')" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "{'input_ids': [2, 33, 2181, 1402, 893, 15200, 893, 13507, 881, 933, 882, 829, 3], 'labels': [9085, 10936, 10993, 23363, 9134, 18368, 8006, 389, 1]}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_dataset = PairedDataset(src_tokenizer, trg_tokenizer, dataset_raw=train_dataset)\n", "eval_dataset = PairedDataset(src_tokenizer, trg_tokenizer, dataset_raw=test_dataset)\n", "eval_dataset[0]" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# be sure to check the column count of each dataset if you encounter \"ValueError: too many values to unpack (expected 2)\"\n", "# at the `src, trg = self.data[index]`\n", "# The `cat ffac_full.csv tteb_train.csv > ja_ko_train.csv` command may be the reason.\n", "# the last row of first csv and first row of second csv is merged and that's why 3rd column is created (which arouse ValueError)\n", "# debug_data = train_dataset.data\n" ] }, { "cell_type": "markdown", "metadata": { "id": "uCBiLouSFiZY" }, "source": [ "## Model" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "I7uFbFYJFje8" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at skt/kogpt2-base-v2 and are newly initialized: ['transformer.h.0.crossattention.c_attn.bias', 'transformer.h.0.crossattention.c_attn.weight', 'transformer.h.0.crossattention.c_proj.bias', 'transformer.h.0.crossattention.c_proj.weight', 'transformer.h.0.crossattention.q_attn.bias', 'transformer.h.0.crossattention.q_attn.weight', 'transformer.h.0.ln_cross_attn.bias', 'transformer.h.0.ln_cross_attn.weight', 'transformer.h.1.crossattention.c_attn.bias', 'transformer.h.1.crossattention.c_attn.weight', 'transformer.h.1.crossattention.c_proj.bias', 'transformer.h.1.crossattention.c_proj.weight', 'transformer.h.1.crossattention.q_attn.bias', 'transformer.h.1.crossattention.q_attn.weight', 'transformer.h.1.ln_cross_attn.bias', 'transformer.h.1.ln_cross_attn.weight', 'transformer.h.10.crossattention.c_attn.bias', 'transformer.h.10.crossattention.c_attn.weight', 'transformer.h.10.crossattention.c_proj.bias', 'transformer.h.10.crossattention.c_proj.weight', 'transformer.h.10.crossattention.q_attn.bias', 'transformer.h.10.crossattention.q_attn.weight', 'transformer.h.10.ln_cross_attn.bias', 'transformer.h.10.ln_cross_attn.weight', 'transformer.h.11.crossattention.c_attn.bias', 'transformer.h.11.crossattention.c_attn.weight', 'transformer.h.11.crossattention.c_proj.bias', 'transformer.h.11.crossattention.c_proj.weight', 'transformer.h.11.crossattention.q_attn.bias', 'transformer.h.11.crossattention.q_attn.weight', 'transformer.h.11.ln_cross_attn.bias', 'transformer.h.11.ln_cross_attn.weight', 'transformer.h.2.crossattention.c_attn.bias', 'transformer.h.2.crossattention.c_attn.weight', 'transformer.h.2.crossattention.c_proj.bias', 'transformer.h.2.crossattention.c_proj.weight', 'transformer.h.2.crossattention.q_attn.bias', 'transformer.h.2.crossattention.q_attn.weight', 'transformer.h.2.ln_cross_attn.bias', 'transformer.h.2.ln_cross_attn.weight', 'transformer.h.3.crossattention.c_attn.bias', 'transformer.h.3.crossattention.c_attn.weight', 'transformer.h.3.crossattention.c_proj.bias', 'transformer.h.3.crossattention.c_proj.weight', 'transformer.h.3.crossattention.q_attn.bias', 'transformer.h.3.crossattention.q_attn.weight', 'transformer.h.3.ln_cross_attn.bias', 'transformer.h.3.ln_cross_attn.weight', 'transformer.h.4.crossattention.c_attn.bias', 'transformer.h.4.crossattention.c_attn.weight', 'transformer.h.4.crossattention.c_proj.bias', 'transformer.h.4.crossattention.c_proj.weight', 'transformer.h.4.crossattention.q_attn.bias', 'transformer.h.4.crossattention.q_attn.weight', 'transformer.h.4.ln_cross_attn.bias', 'transformer.h.4.ln_cross_attn.weight', 'transformer.h.5.crossattention.c_attn.bias', 'transformer.h.5.crossattention.c_attn.weight', 'transformer.h.5.crossattention.c_proj.bias', 'transformer.h.5.crossattention.c_proj.weight', 'transformer.h.5.crossattention.q_attn.bias', 'transformer.h.5.crossattention.q_attn.weight', 'transformer.h.5.ln_cross_attn.bias', 'transformer.h.5.ln_cross_attn.weight', 'transformer.h.6.crossattention.c_attn.bias', 'transformer.h.6.crossattention.c_attn.weight', 'transformer.h.6.crossattention.c_proj.bias', 'transformer.h.6.crossattention.c_proj.weight', 'transformer.h.6.crossattention.q_attn.bias', 'transformer.h.6.crossattention.q_attn.weight', 'transformer.h.6.ln_cross_attn.bias', 'transformer.h.6.ln_cross_attn.weight', 'transformer.h.7.crossattention.c_attn.bias', 'transformer.h.7.crossattention.c_attn.weight', 'transformer.h.7.crossattention.c_proj.bias', 'transformer.h.7.crossattention.c_proj.weight', 'transformer.h.7.crossattention.q_attn.bias', 'transformer.h.7.crossattention.q_attn.weight', 'transformer.h.7.ln_cross_attn.bias', 'transformer.h.7.ln_cross_attn.weight', 'transformer.h.8.crossattention.c_attn.bias', 'transformer.h.8.crossattention.c_attn.weight', 'transformer.h.8.crossattention.c_proj.bias', 'transformer.h.8.crossattention.c_proj.weight', 'transformer.h.8.crossattention.q_attn.bias', 'transformer.h.8.crossattention.q_attn.weight', 'transformer.h.8.ln_cross_attn.bias', 'transformer.h.8.ln_cross_attn.weight', 'transformer.h.9.crossattention.c_attn.bias', 'transformer.h.9.crossattention.c_attn.weight', 'transformer.h.9.crossattention.c_proj.bias', 'transformer.h.9.crossattention.c_proj.weight', 'transformer.h.9.crossattention.q_attn.bias', 'transformer.h.9.crossattention.q_attn.weight', 'transformer.h.9.ln_cross_attn.bias', 'transformer.h.9.ln_cross_attn.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "model = EncoderDecoderModel.from_encoder_decoder_pretrained(\n", " encoder_model_name,\n", " decoder_model_name,\n", " pad_token_id=trg_tokenizer.bos_token_id,\n", ")\n", "model.config.decoder_start_token_id = trg_tokenizer.bos_token_id" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "YFq2GyOAUV0W" }, "outputs": [ { "data": { "text/html": [ "Finishing last run (ID:1vwqqxps) before initializing another..." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a82aa19a250b43f28d7ecc72eeebc88d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(Label(value='0.001 MB of 0.010 MB uploaded\\r'), FloatProgress(value=0.10972568578553615, max=1.…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run jbert+kogpt2 at: https://wandb.ai/sappho192/fftr-poc1/runs/1vwqqxps
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Find logs at: .\\wandb\\run-20240131_135356-1vwqqxps\\logs" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Successfully finished last run (ID:1vwqqxps). Initializing new run:
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c2cd7f6fb5b1428b98b80a3cc82ec303", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.011288888888884685, max=1.0…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Tracking run with wandb version 0.16.2" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in d:\\REPO\\ffxiv-ja-ko-translator\\wandb\\run-20240131_135421-etxsdxw2" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run jbert+kogpt2 to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/sappho192/fftr-poc1" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/sappho192/fftr-poc1/runs/etxsdxw2" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# for Trainer\n", "import wandb\n", "\n", "collate_fn = DataCollatorForSeq2Seq(src_tokenizer, model)\n", "wandb.init(project=\"fftr-poc1\", name='jbert+kogpt2')\n", "\n", "arguments = Seq2SeqTrainingArguments(\n", " output_dir='dump',\n", " do_train=True,\n", " do_eval=True,\n", " evaluation_strategy=\"epoch\",\n", " save_strategy=\"epoch\",\n", " num_train_epochs=3,\n", " # num_train_epochs=25,\n", " per_device_train_batch_size=1,\n", " # per_device_train_batch_size=30, # takes 40GB\n", " # per_device_train_batch_size=64,\n", " per_device_eval_batch_size=1,\n", " # per_device_eval_batch_size=30,\n", " # per_device_eval_batch_size=64,\n", " warmup_ratio=0.1,\n", " gradient_accumulation_steps=4,\n", " save_total_limit=5,\n", " dataloader_num_workers=1,\n", " # fp16=True, # ENABLE if CUDA is enabled\n", " load_best_model_at_end=True,\n", " report_to='wandb'\n", ")\n", "\n", "trainer = Trainer(\n", " model,\n", " arguments,\n", " data_collator=collate_fn,\n", " train_dataset=train_dataset,\n", " eval_dataset=eval_dataset\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "pPsjDHO5Vc3y" }, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_T4P4XunmK-C" }, "outputs": [], "source": [ "# model = EncoderDecoderModel.from_encoder_decoder_pretrained(\"xlm-roberta-base\", \"skt/kogpt2-base-v2\")" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "7vTqAgW6Ve3J" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0afe460e9f614d9a90379cf99fcf8af3", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/9671328 [00:00