{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The primary codes below are based on [akpe12/JP-KR-ocr-translator-for-travel](https://github.com/akpe12/JP-KR-ocr-translator-for-travel)." ] }, { "cell_type": "markdown", "metadata": { "id": "TrHlPFqwFAgj" }, "source": [ "## Import" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "t-jXeSJKE1WM" }, "outputs": [], "source": [ "\n", "from typing import Dict, List\n", "import csv\n", "import torch\n", "from transformers import (\n", " EncoderDecoderModel,\n", " GPT2Tokenizer as BaseGPT2Tokenizer,\n", " PreTrainedTokenizer, BertTokenizerFast,\n", " PreTrainedTokenizerFast,\n", " DataCollatorForSeq2Seq,\n", " Seq2SeqTrainingArguments,\n", " AutoTokenizer,\n", " XLMRobertaTokenizerFast,\n", " BertJapaneseTokenizer,\n", " Trainer\n", ")\n", "from torch.utils.data import DataLoader\n", "from transformers.models.encoder_decoder.modeling_encoder_decoder import EncoderDecoderModel\n", "\n", "# encoder_model_name = \"xlm-roberta-base\"\n", "encoder_model_name = \"cl-tohoku/bert-base-japanese-v2\"\n", "decoder_model_name = \"skt/kogpt2-base-v2\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nEW5trBtbykK" }, "outputs": [], "source": [ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "# device = torch.device(\"cpu\")\n", "device, torch.cuda.device_count()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5ic7pUUBFU_v" }, "outputs": [], "source": [ "class GPT2Tokenizer(PreTrainedTokenizerFast):\n", " def build_inputs_with_special_tokens(self, token_ids: List[int]) -> List[int]:\n", " return token_ids + [self.eos_token_id] \n", "\n", "src_tokenizer = BertJapaneseTokenizer.from_pretrained(encoder_model_name)\n", "trg_tokenizer = GPT2Tokenizer.from_pretrained(decoder_model_name, bos_token='', eos_token='', unk_token='',\n", " pad_token='', mask_token='')" ] }, { "cell_type": "markdown", "metadata": { "id": "DTf4U1fmFQFh" }, "source": [ "## Data" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "65L4O1c5FLKt" }, "outputs": [], "source": [ "class PairedDataset:\n", " def __init__(self, \n", " src_tokenizer: PreTrainedTokenizerFast, tgt_tokenizer: PreTrainedTokenizerFast,\n", " file_path: str\n", " ):\n", " self.src_tokenizer = src_tokenizer\n", " self.trg_tokenizer = tgt_tokenizer\n", " with open(file_path, 'r') as fd:\n", " reader = csv.reader(fd)\n", " next(reader)\n", " self.data = [row for row in reader]\n", "\n", " def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:\n", " src, trg = self.data[index]\n", " embeddings = self.src_tokenizer(src, return_attention_mask=False, return_token_type_ids=False)\n", " embeddings['labels'] = self.trg_tokenizer.build_inputs_with_special_tokens(self.trg_tokenizer(trg, return_attention_mask=False)['input_ids'])\n", "\n", " return embeddings\n", "\n", " def __len__(self):\n", " return len(self.data)\n", " \n", "DATA_ROOT = './output'\n", "FILE_FFAC_FULL = 'ffac_full.csv'\n", "FILE_FFAC_TEST = 'ffac_test.csv'\n", "# FILE_JA_KO_TRAIN = 'ja_ko_train.csv'\n", "# FILE_JA_KO_TEST = 'ja_ko_test.csv'\n", "\n", "train_dataset = PairedDataset(src_tokenizer, trg_tokenizer, f'{DATA_ROOT}/{FILE_FFAC_FULL}')\n", "eval_dataset = PairedDataset(src_tokenizer, trg_tokenizer, f'{DATA_ROOT}/{FILE_FFAC_TEST}') \n", "# train_dataset = PairedDataset(src_tokenizer, trg_tokenizer, f'{DATA_ROOT}/{FILE_JA_KO_TRAIN}')\n", "# eval_dataset = PairedDataset(src_tokenizer, trg_tokenizer, f'{DATA_ROOT}/{FILE_JA_KO_TEST}') " ] }, { "cell_type": "markdown", "metadata": { "id": "uCBiLouSFiZY" }, "source": [ "## Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "I7uFbFYJFje8" }, "outputs": [], "source": [ "model = EncoderDecoderModel.from_encoder_decoder_pretrained(\n", " encoder_model_name,\n", " decoder_model_name,\n", " pad_token_id=trg_tokenizer.bos_token_id,\n", ")\n", "model.config.decoder_start_token_id = trg_tokenizer.bos_token_id" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YFq2GyOAUV0W" }, "outputs": [], "source": [ "# for Trainer\n", "import wandb\n", "\n", "collate_fn = DataCollatorForSeq2Seq(src_tokenizer, model)\n", "wandb.init(project=\"fftr-poc1\", name='jbert+kogpt2')\n", "\n", "arguments = Seq2SeqTrainingArguments(\n", " output_dir='dump',\n", " do_train=True,\n", " do_eval=True,\n", " evaluation_strategy=\"epoch\",\n", " save_strategy=\"epoch\",\n", "# num_train_epochs=5,\n", " num_train_epochs=25,\n", "# per_device_train_batch_size=32,\n", " per_device_train_batch_size=64,\n", "# per_device_eval_batch_size=32,\n", " per_device_eval_batch_size=64,\n", " warmup_ratio=0.1,\n", " gradient_accumulation_steps=4,\n", " save_total_limit=5,\n", " dataloader_num_workers=1,\n", " fp16=True,\n", " load_best_model_at_end=True,\n", " report_to='wandb'\n", ")\n", "\n", "trainer = Trainer(\n", " model,\n", " arguments,\n", " data_collator=collate_fn,\n", " train_dataset=train_dataset,\n", " eval_dataset=eval_dataset\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "pPsjDHO5Vc3y" }, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_T4P4XunmK-C" }, "outputs": [], "source": [ "# model = EncoderDecoderModel.from_encoder_decoder_pretrained(\"xlm-roberta-base\", \"skt/kogpt2-base-v2\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7vTqAgW6Ve3J" }, "outputs": [], "source": [ "trainer.train()\n", "\n", "model.save_pretrained(\"dump/best_model\")" ] } ], "metadata": { "colab": { "machine_shape": "hm", "provenance": [] }, "gpuClass": "premium", "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 0 }