{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# How to use OpenNMT-py as a Library" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The example notebook (available [here](https://github.com/OpenNMT/OpenNMT-py/blob/master/docs/source/examples/Library.ipynb)) should be able to run as a standalone execution, provided `onmt` is in the path (installed via `pip` for instance).\n", "\n", "Some parts may not be 100% 'library-friendly' but it's mostly workable." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Import a few modules and functions that will be necessary" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import yaml\n", "import torch\n", "import torch.nn as nn\n", "from argparse import Namespace\n", "from collections import defaultdict, Counter" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import onmt\n", "from onmt.inputters.inputter import _load_vocab, _build_fields_vocab, get_fields, IterOnDevice\n", "from onmt.inputters.corpus import ParallelCorpus\n", "from onmt.inputters.dynamic_iterator import DynamicDatasetIter\n", "from onmt.translate import GNMTGlobalScorer, Translator, TranslationBuilder\n", "from onmt.utils.misc import set_random_seed" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Enable logging" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# enable logging\n", "from onmt.utils.logging import init_logger, logger\n", "init_logger()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Set random seed" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "is_cuda = torch.cuda.is_available()\n", "set_random_seed(1111, is_cuda)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Retrieve data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make a proper example, we will need some data, as well as some vocabulary(ies).\n", "\n", "Let's take the same data as in the [quickstart](https://opennmt.net/OpenNMT-py/quickstart.html):" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--2020-09-25 15:28:05-- https://s3.amazonaws.com/opennmt-trainingdata/toy-ende.tar.gz\n", "Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.18.38\n", "Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.18.38|:443... connected.\n", "HTTP request sent, awaiting response... 
200 OK\n", "Length: 1662081 (1,6M) [application/x-gzip]\n", "Saving to: ‘toy-ende.tar.gz.5’\n", "\n", "toy-ende.tar.gz.5 100%[===================>] 1,58M 2,33MB/s in 0,7s \n", "\n", "2020-09-25 15:28:07 (2,33 MB/s) - ‘toy-ende.tar.gz.5’ saved [1662081/1662081]\n", "\n" ] } ], "source": [ "!wget https://s3.amazonaws.com/opennmt-trainingdata/toy-ende.tar.gz" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "!tar xf toy-ende.tar.gz" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "config.yaml src-test.txt src-val.txt tgt-train.txt\r\n", "\u001b[0m\u001b[01;34mrun\u001b[0m/ src-train.txt tgt-test.txt tgt-val.txt\r\n" ] } ], "source": [ "ls toy-ende" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare data and vocab" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As for any use case of OpenNMT-py 2.0, we can start by creating a simple YAML configuration with our datasets. This is the easiest way to build the proper `opts` `Namespace` that will be used to create the vocabulary(ies)." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "yaml_config = \"\"\"\n", "## Where the vocab(s) will be written\n", "save_data: toy-ende/run/example\n", "src_vocab: toy-ende/run/example.vocab.src\n", "tgt_vocab: toy-ende/run/example.vocab.tgt\n", "# Corpus opts:\n", "data:\n", " corpus:\n", " path_src: toy-ende/src-train.txt\n", " path_tgt: toy-ende/tgt-train.txt\n", " transforms: []\n", " weight: 1\n", " valid:\n", " path_src: toy-ende/src-val.txt\n", " path_tgt: toy-ende/tgt-val.txt\n", " transforms: []\n", "\"\"\"\n", "config = yaml.safe_load(yaml_config)\n", "with open(\"toy-ende/config.yaml\", \"w\") as f:\n", " f.write(yaml_config)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "from onmt.utils.parse import ArgumentParser\n", "parser = ArgumentParser(description='build_vocab.py')" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "from onmt.opts import dynamic_prepare_opts\n", "dynamic_prepare_opts(parser, build_vocab_only=True)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "base_args = ([\"-config\", \"toy-ende/config.yaml\", \"-n_sample\", \"10000\"])\n", "opts, unknown = parser.parse_known_args(base_args)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Namespace(config='toy-ende/config.yaml', data=\"{'corpus': {'path_src': 'toy-ende/src-train.txt', 'path_tgt': 'toy-ende/tgt-train.txt', 'transforms': [], 'weight': 1}, 'valid': {'path_src': 'toy-ende/src-val.txt', 'path_tgt': 'toy-ende/tgt-val.txt', 'transforms': []}}\", insert_ratio=0.0, mask_length='subword', mask_ratio=0.0, n_sample=10000, src_onmttok_kwargs=\"{'mode': 'none'}\", tgt_onmttok_kwargs=\"{'mode': 'none'}\", overwrite=False, permute_sent_ratio=0.0, poisson_lambda=0.0, random_ratio=0.0, replace_length=-1, rotate_ratio=0.5, save_config=None, save_data='toy-ende/run/example', seed=-1, share_vocab=False, skip_empty_level='warning', src_seq_length=200, src_subword_model=None, src_subword_type='none', src_vocab=None, subword_alpha=0, subword_nbest=1, switchout_temperature=1.0, tgt_seq_length=200, tgt_subword_model=None, tgt_subword_type='none', tgt_vocab=None, tokendrop_temperature=1.0, tokenmask_temperature=1.0, transforms=[])" ] }, 
"execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "opts" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[2020-09-25 15:28:08,068 INFO] Parsed 2 corpora from -data.\n", "[2020-09-25 15:28:08,069 INFO] Counter vocab from 10000 samples.\n", "[2020-09-25 15:28:08,070 INFO] Save 10000 transformed example/corpus.\n", "[2020-09-25 15:28:08,070 INFO] corpus's transforms: TransformPipe()\n", "[2020-09-25 15:28:08,101 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n", "[2020-09-25 15:28:08,320 INFO] Just finished the first loop\n", "[2020-09-25 15:28:08,320 INFO] Counters src:24995\n", "[2020-09-25 15:28:08,321 INFO] Counters tgt:35816\n" ] } ], "source": [ "from onmt.bin.build_vocab import build_vocab_main\n", "build_vocab_main(opts)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "example.vocab.src example.vocab.tgt \u001b[0m\u001b[01;34msample\u001b[0m/\r\n" ] } ], "source": [ "ls toy-ende/run" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We just created our source and target vocabularies, respectively `toy-ende/run/example.vocab.src` and `toy-ende/run/example.vocab.tgt`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Build fields" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can build the fields from the text files that were just created." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "src_vocab_path = \"toy-ende/run/example.vocab.src\"\n", "tgt_vocab_path = \"toy-ende/run/example.vocab.tgt\"" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[2020-09-25 15:28:08,495 INFO] Loading src vocabulary from toy-ende/run/example.vocab.src\n", "[2020-09-25 15:28:08,554 INFO] Loaded src vocab has 24995 tokens.\n", "[2020-09-25 15:28:08,562 INFO] Loading tgt vocabulary from toy-ende/run/example.vocab.tgt\n", "[2020-09-25 15:28:08,617 INFO] Loaded tgt vocab has 35816 tokens.\n" ] } ], "source": [ "# initialize the frequency counter\n", "counters = defaultdict(Counter)\n", "# load source vocab\n", "_src_vocab, _src_vocab_size = _load_vocab(\n", " src_vocab_path,\n", " 'src',\n", " counters)\n", "# load target vocab\n", "_tgt_vocab, _tgt_vocab_size = _load_vocab(\n", " tgt_vocab_path,\n", " 'tgt',\n", " counters)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# initialize fields\n", "src_nfeats, tgt_nfeats = 0, 0 # do not support word features for now\n", "fields = get_fields(\n", " 'text', src_nfeats, tgt_nfeats)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'src': ,\n", " 'tgt': ,\n", " 'indices': }" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fields" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[2020-09-25 15:28:08,699 INFO] * tgt vocab size: 30004.\n", "[2020-09-25 15:28:08,749 INFO] * src vocab size: 24997.\n" ] } ], "source": [ "# build fields vocab\n", "share_vocab = False\n", "vocab_size_multiple = 1\n", "src_vocab_size = 30000\n", "tgt_vocab_size = 30000\n", "src_words_min_frequency = 1\n", 
"tgt_words_min_frequency = 1\n", "vocab_fields = _build_fields_vocab(\n", " fields, counters, 'text', share_vocab,\n", " vocab_size_multiple,\n", " src_vocab_size, src_words_min_frequency,\n", " tgt_vocab_size, tgt_words_min_frequency)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An alternative way of creating these fields is to run `onmt_train` without actually training, to just output the necessary files." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare for training: model and optimizer creation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's get a few fields/vocab related variables to simplify the model creation a bit:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "src_text_field = vocab_fields[\"src\"].base_field\n", "src_vocab = src_text_field.vocab\n", "src_padding = src_vocab.stoi[src_text_field.pad_token]\n", "\n", "tgt_text_field = vocab_fields['tgt'].base_field\n", "tgt_vocab = tgt_text_field.vocab\n", "tgt_padding = tgt_vocab.stoi[tgt_text_field.pad_token]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we specify the core model itself. Here we will build a small model with an encoder and an attention based input feeding decoder. Both models will be RNNs and the encoder will be bidirectional" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "emb_size = 100\n", "rnn_size = 500\n", "# Specify the core model.\n", "\n", "encoder_embeddings = onmt.modules.Embeddings(emb_size, len(src_vocab),\n", " word_padding_idx=src_padding)\n", "\n", "encoder = onmt.encoders.RNNEncoder(hidden_size=rnn_size, num_layers=1,\n", " rnn_type=\"LSTM\", bidirectional=True,\n", " embeddings=encoder_embeddings)\n", "\n", "decoder_embeddings = onmt.modules.Embeddings(emb_size, len(tgt_vocab),\n", " word_padding_idx=tgt_padding)\n", "decoder = onmt.decoders.decoder.InputFeedRNNDecoder(\n", " hidden_size=rnn_size, num_layers=1, bidirectional_encoder=True, \n", " rnn_type=\"LSTM\", embeddings=decoder_embeddings)\n", "\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "model = onmt.models.model.NMTModel(encoder, decoder)\n", "model.to(device)\n", "\n", "# Specify the tgt word generator and loss computation module\n", "model.generator = nn.Sequential(\n", " nn.Linear(rnn_size, len(tgt_vocab)),\n", " nn.LogSoftmax(dim=-1)).to(device)\n", "\n", "loss = onmt.utils.loss.NMTLossCompute(\n", " criterion=nn.NLLLoss(ignore_index=tgt_padding, reduction=\"sum\"),\n", " generator=model.generator)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we set up the optimizer. This could be a core torch optim class, or our wrapper which handles learning rate updates and gradient normalization automatically." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "lr = 1\n", "torch_optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n", "optim = onmt.utils.optimizers.Optimizer(\n", " torch_optimizer, learning_rate=lr, max_grad_norm=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create the training and validation data iterators" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we need to create the dynamic dataset iterator.\n", "\n", "This is not very 'library-friendly' for now because of the way the `DynamicDatasetIter` constructor is defined. It may evolve in the future." 
] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "src_train = \"toy-ende/src-train.txt\"\n", "tgt_train = \"toy-ende/tgt-train.txt\"\n", "src_val = \"toy-ende/src-val.txt\"\n", "tgt_val = \"toy-ende/tgt-val.txt\"\n", "\n", "# build the ParallelCorpus\n", "corpus = ParallelCorpus(\"corpus\", src_train, tgt_train)\n", "valid = ParallelCorpus(\"valid\", src_val, tgt_val)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "# build the training iterator\n", "train_iter = DynamicDatasetIter(\n", " corpora={\"corpus\": corpus},\n", " corpora_info={\"corpus\": {\"weight\": 1}},\n", " transforms={},\n", " fields=vocab_fields,\n", " is_train=True,\n", " batch_type=\"tokens\",\n", " batch_size=4096,\n", " batch_size_multiple=1,\n", " data_type=\"text\")" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "# make sure the iteration happens on GPU 0 (-1 for CPU, N for GPU N)\n", "train_iter = iter(IterOnDevice(train_iter, 0))" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "# build the validation iterator\n", "valid_iter = DynamicDatasetIter(\n", " corpora={\"valid\": valid},\n", " corpora_info={\"valid\": {\"weight\": 1}},\n", " transforms={},\n", " fields=vocab_fields,\n", " is_train=False,\n", " batch_type=\"sents\",\n", " batch_size=8,\n", " batch_size_multiple=1,\n", " data_type=\"text\")" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "valid_iter = IterOnDevice(valid_iter, 0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally we train." ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[2020-09-25 15:28:15,184 INFO] Start training loop and validate every 500 steps...\n", "[2020-09-25 15:28:15,185 INFO] corpus's transforms: TransformPipe()\n", "[2020-09-25 15:28:15,187 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n", "[2020-09-25 15:28:21,140 INFO] Step 50/ 1000; acc: 7.52; ppl: 8832.29; xent: 9.09; lr: 1.00000; 18916/18871 tok/s; 6 sec\n", "[2020-09-25 15:28:24,869 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n", "[2020-09-25 15:28:27,121 INFO] Step 100/ 1000; acc: 9.34; ppl: 1840.06; xent: 7.52; lr: 1.00000; 18911/18785 tok/s; 12 sec\n", "[2020-09-25 15:28:33,048 INFO] Step 150/ 1000; acc: 10.35; ppl: 1419.18; xent: 7.26; lr: 1.00000; 19062/19017 tok/s; 18 sec\n", "[2020-09-25 15:28:37,019 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n", "[2020-09-25 15:28:39,022 INFO] Step 200/ 1000; acc: 11.14; ppl: 1127.44; xent: 7.03; lr: 1.00000; 19084/18911 tok/s; 24 sec\n", "[2020-09-25 15:28:45,073 INFO] Step 250/ 1000; acc: 12.46; ppl: 912.13; xent: 6.82; lr: 1.00000; 18575/18570 tok/s; 30 sec\n", "[2020-09-25 15:28:49,301 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n", "[2020-09-25 15:28:51,151 INFO] Step 300/ 1000; acc: 13.04; ppl: 779.50; xent: 6.66; lr: 1.00000; 18394/18307 tok/s; 36 sec\n", "[2020-09-25 15:28:57,316 INFO] Step 350/ 1000; acc: 14.04; ppl: 685.48; xent: 6.53; lr: 1.00000; 18339/18173 tok/s; 42 sec\n", "[2020-09-25 15:29:02,117 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, 
toy-ende/tgt-train.txt, align=None)...\n", "[2020-09-25 15:29:03,576 INFO] Step 400/ 1000; acc: 14.99; ppl: 590.20; xent: 6.38; lr: 1.00000; 18090/18029 tok/s; 48 sec\n", "[2020-09-25 15:29:09,546 INFO] Step 450/ 1000; acc: 16.00; ppl: 524.51; xent: 6.26; lr: 1.00000; 18726/18536 tok/s; 54 sec\n", "[2020-09-25 15:29:14,585 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n", "[2020-09-25 15:29:15,596 INFO] Step 500/ 1000; acc: 16.78; ppl: 453.38; xent: 6.12; lr: 1.00000; 17877/17980 tok/s; 60 sec\n", "[2020-09-25 15:29:15,597 INFO] valid's transforms: TransformPipe()\n", "[2020-09-25 15:29:15,599 INFO] Loading ParallelCorpus(toy-ende/src-val.txt, toy-ende/tgt-val.txt, align=None)...\n", "[2020-09-25 15:29:24,528 INFO] Validation perplexity: 295.1\n", "[2020-09-25 15:29:24,529 INFO] Validation accuracy: 17.6533\n", "[2020-09-25 15:29:30,592 INFO] Step 550/ 1000; acc: 17.47; ppl: 421.26; xent: 6.04; lr: 1.00000; 7726/7610 tok/s; 75 sec\n", "[2020-09-25 15:29:36,055 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n", "[2020-09-25 15:29:36,695 INFO] Step 600/ 1000; acc: 18.95; ppl: 354.44; xent: 5.87; lr: 1.00000; 17470/17598 tok/s; 82 sec\n", "[2020-09-25 15:29:42,794 INFO] Step 650/ 1000; acc: 19.60; ppl: 328.47; xent: 5.79; lr: 1.00000; 18994/18793 tok/s; 88 sec\n", "[2020-09-25 15:29:48,635 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n", "[2020-09-25 15:29:48,924 INFO] Step 700/ 1000; acc: 20.57; ppl: 285.48; xent: 5.65; lr: 1.00000; 17856/17788 tok/s; 94 sec\n", "[2020-09-25 15:29:54,898 INFO] Step 750/ 1000; acc: 21.97; ppl: 249.06; xent: 5.52; lr: 1.00000; 19030/18924 tok/s; 100 sec\n", "[2020-09-25 15:30:01,233 INFO] Step 800/ 1000; acc: 22.66; ppl: 228.54; xent: 5.43; lr: 1.00000; 17571/17471 tok/s; 106 sec\n", "[2020-09-25 15:30:01,357 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n", "[2020-09-25 15:30:07,345 INFO] Step 850/ 1000; acc: 24.32; ppl: 193.65; xent: 5.27; lr: 1.00000; 18344/18313 tok/s; 112 sec\n", "[2020-09-25 15:30:11,363 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n", "[2020-09-25 15:30:13,487 INFO] Step 900/ 1000; acc: 24.93; ppl: 177.65; xent: 5.18; lr: 1.00000; 18293/18105 tok/s; 118 sec\n", "[2020-09-25 15:30:19,670 INFO] Step 950/ 1000; acc: 26.33; ppl: 157.10; xent: 5.06; lr: 1.00000; 17791/17746 tok/s; 124 sec\n", "[2020-09-25 15:30:24,072 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n", "[2020-09-25 15:30:25,820 INFO] Step 1000/ 1000; acc: 27.47; ppl: 137.64; xent: 4.92; lr: 1.00000; 17942/17962 tok/s; 131 sec\n", "[2020-09-25 15:30:25,822 INFO] Loading ParallelCorpus(toy-ende/src-val.txt, toy-ende/tgt-val.txt, align=None)...\n", "[2020-09-25 15:30:34,665 INFO] Validation perplexity: 241.801\n", "[2020-09-25 15:30:34,666 INFO] Validation accuracy: 20.2837\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "report_manager = onmt.utils.ReportMgr(\n", " report_every=50, start_time=None, tensorboard_writer=None)\n", "\n", "trainer = onmt.Trainer(model=model,\n", " train_loss=loss,\n", " valid_loss=loss,\n", " optim=optim,\n", " report_manager=report_manager,\n", " dropout=[0.1])\n", "\n", "trainer.train(train_iter=train_iter,\n", " train_steps=1000,\n", " valid_iter=valid_iter,\n", " valid_steps=500)" ] }, { 
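"cell_type": "markdown", "metadata": {}, "source": [ "Before moving on to translation, we can optionally persist the trained weights. This is just a minimal sketch using plain `torch.save` on the model's `state_dict` (the file name below is only an example); it is not the checkpoint format that `onmt_train` would produce, so it is meant for quick reuse within this session rather than for `onmt_translate`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Save the trained weights for later reuse (the generator is included,\n", "# since it was attached to the model above). This is a plain PyTorch\n", "# state_dict, not an OpenNMT-py checkpoint.\n", "torch.save(model.state_dict(), \"toy-ende/run/model_weights.pt\")" ] }, {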
"cell_type": "markdown", "metadata": {}, "source": [ "### Translate" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For translation, we can build a \"traditional\" (as opposed to dynamic) dataset for now." ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "src_data = {\"reader\": onmt.inputters.str2reader[\"text\"](), \"data\": src_val}\n", "tgt_data = {\"reader\": onmt.inputters.str2reader[\"text\"](), \"data\": tgt_val}\n", "_readers, _data = onmt.inputters.Dataset.config(\n", " [('src', src_data), ('tgt', tgt_data)])" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "dataset = onmt.inputters.Dataset(\n", " vocab_fields, readers=_readers, data=_data,\n", " sort_key=onmt.inputters.str2sortkey[\"text\"])" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "data_iter = onmt.inputters.OrderedIterator(\n", " dataset=dataset,\n", " device=\"cuda\",\n", " batch_size=10,\n", " train=False,\n", " sort=False,\n", " sort_within_batch=True,\n", " shuffle=False\n", " )" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "src_reader = onmt.inputters.str2reader[\"text\"]\n", "tgt_reader = onmt.inputters.str2reader[\"text\"]\n", "scorer = GNMTGlobalScorer(alpha=0.7, \n", " beta=0., \n", " length_penalty=\"avg\", \n", " coverage_penalty=\"none\")\n", "gpu = 0 if torch.cuda.is_available() else -1\n", "translator = Translator(model=model, \n", " fields=vocab_fields, \n", " src_reader=src_reader, \n", " tgt_reader=tgt_reader, \n", " global_scorer=scorer,\n", " gpu=gpu)\n", "builder = onmt.translate.TranslationBuilder(data=dataset, \n", " fields=vocab_fields)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Note**: translations will be very poor, because of the very low quantity of data, the absence of proper tokenization, and the brevity of the training." 
] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "SENT 0: ['Parliament', 'Does', 'Not', 'Support', 'Amendment', 'Freeing', 'Tymoshenko']\n", "PRED 0: Parlament das Parlament über die Europäische Parlament , die sich in der Lage in der Lage ist , die es in der Lage sind .\n", "PRED SCORE: -1.5935\n", "\n", "\n", "SENT 0: ['Today', ',', 'the', 'Ukraine', 'parliament', 'dismissed', ',', 'within', 'the', 'Code', 'of', 'Criminal', 'Procedure', 'amendment', ',', 'the', 'motion', 'to', 'revoke', 'an', 'article', 'based', 'on', 'which', 'the', 'opposition', 'leader', ',', 'Yulia', 'Tymoshenko', ',', 'was', 'sentenced', '.']\n", "PRED 0: In der Nähe des Hotels , die in der Lage , die sich in der Lage ist , in der Lage , die in der Lage , die in der Lage ist .\n", "PRED SCORE: -1.7173\n", "\n", "\n", "SENT 0: ['The', 'amendment', 'that', 'would', 'lead', 'to', 'freeing', 'the', 'imprisoned', 'former', 'Prime', 'Minister', 'was', 'revoked', 'during', 'second', 'reading', 'of', 'the', 'proposal', 'for', 'mitigation', 'of', 'sentences', 'for', 'economic', 'offences', '.']\n", "PRED 0: Die Tatsache , die sich in der Lage in der Lage ist , die für eine Antwort der Entwicklung für die Entwicklung von Präsident .\n", "PRED SCORE: -1.6834\n", "\n", "\n", "SENT 0: ['In', 'October', ',', 'Tymoshenko', 'was', 'sentenced', 'to', 'seven', 'years', 'in', 'prison', 'for', 'entering', 'into', 'what', 'was', 'reported', 'to', 'be', 'a', 'disadvantageous', 'gas', 'deal', 'with', 'Russia', '.']\n", "PRED 0: In der Nähe wurde die Menschen in der Lage ist , die sich in der Lage .\n", "PRED SCORE: -1.5765\n", "\n", "\n", "SENT 0: ['The', 'verdict', 'is', 'not', 'yet', 'final;', 'the', 'court', 'will', 'hear', 'Tymoshenko', ''s', 'appeal', 'in', 'December', '.']\n", "PRED 0: Es ist nicht der Fall , die in der Lage in der Lage sind .\n", "PRED SCORE: -1.3287\n", "\n", "\n", "SENT 0: ['Tymoshenko', 'claims', 'the', 'verdict', 'is', 'a', 'political', 'revenge', 'of', 'the', 'regime;', 'in', 'the', 'West', ',', 'the', 'trial', 'has', 'also', 'evoked', 'suspicion', 'of', 'being', 'biased', '.']\n", "PRED 0: Um in der Lage ist auch eine Lösung Rolle .\n", "PRED SCORE: -1.3975\n", "\n", "\n", "SENT 0: ['The', 'proposal', 'to', 'remove', 'Article', '365', 'from', 'the', 'Code', 'of', 'Criminal', 'Procedure', ',', 'upon', 'which', 'the', 'former', 'Prime', 'Minister', 'was', 'sentenced', ',', 'was', 'supported', 'by', '147', 'members', 'of', 'parliament', '.']\n", "PRED 0: Der Vorschlag , die in der Lage , die in der Lage , die in der Lage ist , war er von der Fall wurde .\n", "PRED SCORE: -1.6062\n", "\n", "\n", "SENT 0: ['Its', 'ratification', 'would', 'require', '226', 'votes', '.']\n", "PRED 0: Es wäre noch einmal noch einmal .\n", "PRED SCORE: -1.8001\n", "\n", "\n", "SENT 0: ['Libya', ''s', 'Victory']\n", "PRED 0: In der Nähe des Hotels befindet sich in der Nähe des Hotels in der Lage .\n", "PRED SCORE: -1.7097\n", "\n", "\n", "SENT 0: ['The', 'story', 'of', 'Libya', ''s', 'liberation', ',', 'or', 'rebellion', ',', 'already', 'has', 'its', 'defeated', '.']\n", "PRED 0: In der Nähe des Hotels in der Lage ist in der Lage .\n", "PRED SCORE: -1.7885\n", "\n" ] } ], "source": [ "for batch in data_iter:\n", " trans_batch = translator.translate_batch(\n", " batch=batch, src_vocabs=[src_vocab],\n", " attn_debug=False)\n", " translations = builder.from_batch(trans_batch)\n", " for trans in translations:\n", " print(trans.log(0))\n", " break" ] 
} ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }