{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Inference" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "from transformers import(\n", " EncoderDecoderModel,\n", " PreTrainedTokenizerFast,\n", " # XLMRobertaTokenizerFast,\n", " BertJapaneseTokenizer,\n", " BertTokenizerFast,\n", ")\n", "\n", "import torch\n", "import csv" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n", "The tokenizer class you load from this checkpoint is 'GPT2Tokenizer'. \n", "The class this function is called from is 'PreTrainedTokenizerFast'.\n" ] } ], "source": [ "encoder_model_name = \"cl-tohoku/bert-base-japanese-v2\"\n", "decoder_model_name = \"skt/kogpt2-base-v2\"\n", "\n", "src_tokenizer = BertJapaneseTokenizer.from_pretrained(encoder_model_name)\n", "trg_tokenizer = PreTrainedTokenizerFast.from_pretrained(decoder_model_name)\n", "model = EncoderDecoderModel.from_pretrained(\"./dump/best_model\")" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'길가메시 토벌전'" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "text = \"ギルガメッシュ討伐戦\"\n", "# text = \"ギルガメッシュ討伐戦に行ってきます。一緒に行きましょうか?\"\n", "\n", "def translate(text_src):\n", " embeddings = src_tokenizer(text_src, return_attention_mask=False, return_token_type_ids=False, return_tensors='pt')\n", " embeddings = {k: v for k, v in embeddings.items()}\n", " output = model.generate(**embeddings)[0, 1:-1]\n", " text_trg = trg_tokenizer.decode(output.cpu())\n", " return text_trg\n", "\n", "print(translate(text))" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Evaluation" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction\n", "smoothie = SmoothingFunction().method4" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Testing: 0%| | 0/267 [00:00