{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# How to use OpenNMT-py as a Library"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook (available [here](https://github.com/OpenNMT/OpenNMT-py/blob/master/docs/source/examples/Library.ipynb)) should run standalone, provided `onmt` is importable (for instance, installed via `pip`).\n",
"\n",
"Some parts of the API are not yet fully 'library-friendly', but most of it is workable."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Import the necessary modules and functions"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import yaml\n",
"import torch\n",
"import torch.nn as nn\n",
"from argparse import Namespace\n",
"from collections import defaultdict, Counter"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import onmt\n",
"from onmt.inputters.inputter import _load_vocab, _build_fields_vocab, get_fields, IterOnDevice\n",
"from onmt.inputters.corpus import ParallelCorpus\n",
"from onmt.inputters.dynamic_iterator import DynamicDatasetIter\n",
"from onmt.translate import GNMTGlobalScorer, Translator, TranslationBuilder\n",
"from onmt.utils.misc import set_random_seed"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Enable logging"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<RootLogger root (INFO)>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# enable logging\n",
"from onmt.utils.logging import init_logger, logger\n",
"init_logger()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Set random seed"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"is_cuda = torch.cuda.is_available()\n",
"set_random_seed(1111, is_cuda)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Retrieve data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To build a proper example, we need some data, as well as one or more vocabularies.\n",
"\n",
"Let's use the same data as in the [quickstart](https://opennmt.net/OpenNMT-py/quickstart.html):"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2020-09-25 15:28:05-- https://s3.amazonaws.com/opennmt-trainingdata/toy-ende.tar.gz\n",
"Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.18.38\n",
"Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.18.38|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 1662081 (1,6M) [application/x-gzip]\n",
"Saving to: ‘toy-ende.tar.gz.5’\n",
"\n",
"toy-ende.tar.gz.5 100%[===================>] 1,58M 2,33MB/s in 0,7s \n",
"\n",
"2020-09-25 15:28:07 (2,33 MB/s) - ‘toy-ende.tar.gz.5’ saved [1662081/1662081]\n",
"\n"
]
}
],
"source": [
"!wget https://s3.amazonaws.com/opennmt-trainingdata/toy-ende.tar.gz"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"!tar xf toy-ende.tar.gz"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"config.yaml src-test.txt src-val.txt tgt-train.txt\r\n",
"\u001b[0m\u001b[01;34mrun\u001b[0m/ src-train.txt tgt-test.txt tgt-val.txt\r\n"
]
}
],
"source": [
"ls toy-ende"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prepare data and vocab"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As with any OpenNMT-py 2.0 use case, we start by writing a simple YAML configuration that lists our datasets. This is the easiest way to build the `opts` `Namespace` needed to create the vocabularies."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"yaml_config = \"\"\"\n",
"## Where the vocab(s) will be written\n",
"save_data: toy-ende/run/example\n",
"src_vocab: toy-ende/run/example.vocab.src\n",
"tgt_vocab: toy-ende/run/example.vocab.tgt\n",
"# Corpus opts:\n",
"data:\n",
"    corpus:\n",
"        path_src: toy-ende/src-train.txt\n",
"        path_tgt: toy-ende/tgt-train.txt\n",
"        transforms: []\n",
"        weight: 1\n",
"    valid:\n",
"        path_src: toy-ende/src-val.txt\n",
"        path_tgt: toy-ende/tgt-val.txt\n",
"        transforms: []\n",
"\"\"\"\n",
"config = yaml.safe_load(yaml_config)\n",
"with open(\"toy-ende/config.yaml\", \"w\") as f:\n",
"    f.write(yaml_config)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from onmt.utils.parse import ArgumentParser\n",
"parser = ArgumentParser(description='build_vocab.py')"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from onmt.opts import dynamic_prepare_opts\n",
"dynamic_prepare_opts(parser, build_vocab_only=True)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"base_args = [\"-config\", \"toy-ende/config.yaml\", \"-n_sample\", \"10000\"]\n",
"opts, unknown = parser.parse_known_args(base_args)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Namespace(config='toy-ende/config.yaml', data=\"{'corpus': {'path_src': 'toy-ende/src-train.txt', 'path_tgt': 'toy-ende/tgt-train.txt', 'transforms': [], 'weight': 1}, 'valid': {'path_src': 'toy-ende/src-val.txt', 'path_tgt': 'toy-ende/tgt-val.txt', 'transforms': []}}\", insert_ratio=0.0, mask_length='subword', mask_ratio=0.0, n_sample=10000, src_onmttok_kwargs=\"{'mode': 'none'}\", tgt_onmttok_kwargs=\"{'mode': 'none'}\", overwrite=False, permute_sent_ratio=0.0, poisson_lambda=0.0, random_ratio=0.0, replace_length=-1, rotate_ratio=0.5, save_config=None, save_data='toy-ende/run/example', seed=-1, share_vocab=False, skip_empty_level='warning', src_seq_length=200, src_subword_model=None, src_subword_type='none', src_vocab=None, subword_alpha=0, subword_nbest=1, switchout_temperature=1.0, tgt_seq_length=200, tgt_subword_model=None, tgt_subword_type='none', tgt_vocab=None, tokendrop_temperature=1.0, tokenmask_temperature=1.0, transforms=[])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"opts"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2020-09-25 15:28:08,068 INFO] Parsed 2 corpora from -data.\n",
"[2020-09-25 15:28:08,069 INFO] Counter vocab from 10000 samples.\n",
"[2020-09-25 15:28:08,070 INFO] Save 10000 transformed example/corpus.\n",
"[2020-09-25 15:28:08,070 INFO] corpus's transforms: TransformPipe()\n",
"[2020-09-25 15:28:08,101 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n",
"[2020-09-25 15:28:08,320 INFO] Just finished the first loop\n",
"[2020-09-25 15:28:08,320 INFO] Counters src:24995\n",
"[2020-09-25 15:28:08,321 INFO] Counters tgt:35816\n"
]
}
],
"source": [
"from onmt.bin.build_vocab import build_vocab_main\n",
"build_vocab_main(opts)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"example.vocab.src example.vocab.tgt \u001b[0m\u001b[01;34msample\u001b[0m/\r\n"
]
}
],
"source": [
"ls toy-ende/run"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We have just created the source and target vocabularies, `toy-ende/run/example.vocab.src` and `toy-ende/run/example.vocab.tgt` respectively."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Build fields"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now build the fields from the vocabulary files that were just created."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"src_vocab_path = \"toy-ende/run/example.vocab.src\"\n",
"tgt_vocab_path = \"toy-ende/run/example.vocab.tgt\""
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2020-09-25 15:28:08,495 INFO] Loading src vocabulary from toy-ende/run/example.vocab.src\n",
"[2020-09-25 15:28:08,554 INFO] Loaded src vocab has 24995 tokens.\n",
"[2020-09-25 15:28:08,562 INFO] Loading tgt vocabulary from toy-ende/run/example.vocab.tgt\n",
"[2020-09-25 15:28:08,617 INFO] Loaded tgt vocab has 35816 tokens.\n"
]
}
],
"source": [
"# initialize the frequency counter\n",
"counters = defaultdict(Counter)\n",
"# load source vocab\n",
"_src_vocab, _src_vocab_size = _load_vocab(\n",
"    src_vocab_path,\n",
"    'src',\n",
"    counters)\n",
"# load target vocab\n",
"_tgt_vocab, _tgt_vocab_size = _load_vocab(\n",
"    tgt_vocab_path,\n",
"    'tgt',\n",
"    counters)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# initialize fields\n",
"src_nfeats, tgt_nfeats = 0, 0  # word features are not supported for now\n",
"fields = get_fields(\n",
"    'text', src_nfeats, tgt_nfeats)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'src': <onmt.inputters.text_dataset.TextMultiField at 0x7fca93802c50>,\n",
" 'tgt': <onmt.inputters.text_dataset.TextMultiField at 0x7fca93802f60>,\n",
" 'indices': <torchtext.data.field.Field at 0x7fca93802940>}"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fields"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2020-09-25 15:28:08,699 INFO] * tgt vocab size: 30004.\n",
"[2020-09-25 15:28:08,749 INFO] * src vocab size: 24997.\n"
]
}
],
"source": [
"# build fields vocab\n",
"share_vocab = False\n",
"vocab_size_multiple = 1\n",
"src_vocab_size = 30000\n",
"tgt_vocab_size = 30000\n",
"src_words_min_frequency = 1\n",
"tgt_words_min_frequency = 1\n",
"vocab_fields = _build_fields_vocab(\n",
"    fields, counters, 'text', share_vocab,\n",
"    vocab_size_multiple,\n",
"    src_vocab_size, src_words_min_frequency,\n",
"    tgt_vocab_size, tgt_words_min_frequency)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"An alternative way of creating these fields is to run `onmt_train` without actually training, to just output the necessary files."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prepare for training: model and optimizer creation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's extract a few field- and vocab-related variables to simplify model creation:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"src_text_field = vocab_fields[\"src\"].base_field\n",
"src_vocab = src_text_field.vocab\n",
"src_padding = src_vocab.stoi[src_text_field.pad_token]\n",
"\n",
"tgt_text_field = vocab_fields['tgt'].base_field\n",
"tgt_vocab = tgt_text_field.vocab\n",
"tgt_padding = tgt_vocab.stoi[tgt_text_field.pad_token]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next we specify the core model itself. Here we build a small model with an encoder and an attention-based, input-feeding decoder. Both are RNNs (LSTMs), and the encoder is bidirectional."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"emb_size = 100\n",
"rnn_size = 500\n",
"# Specify the core model.\n",
"\n",
"encoder_embeddings = onmt.modules.Embeddings(\n",
"    emb_size, len(src_vocab), word_padding_idx=src_padding)\n",
"\n",
"encoder = onmt.encoders.RNNEncoder(\n",
"    hidden_size=rnn_size, num_layers=1,\n",
"    rnn_type=\"LSTM\", bidirectional=True,\n",
"    embeddings=encoder_embeddings)\n",
"\n",
"decoder_embeddings = onmt.modules.Embeddings(\n",
"    emb_size, len(tgt_vocab), word_padding_idx=tgt_padding)\n",
"decoder = onmt.decoders.decoder.InputFeedRNNDecoder(\n",
"    hidden_size=rnn_size, num_layers=1, bidirectional_encoder=True,\n",
"    rnn_type=\"LSTM\", embeddings=decoder_embeddings)\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"model = onmt.models.model.NMTModel(encoder, decoder)\n",
"model.to(device)\n",
"\n",
"# Specify the tgt word generator and loss computation module\n",
"model.generator = nn.Sequential(\n",
"    nn.Linear(rnn_size, len(tgt_vocab)),\n",
"    nn.LogSoftmax(dim=-1)).to(device)\n",
"\n",
"loss = onmt.utils.loss.NMTLossCompute(\n",
"    criterion=nn.NLLLoss(ignore_index=tgt_padding, reduction=\"sum\"),\n",
"    generator=model.generator)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we set up the optimizer. This can be a plain `torch.optim` class or the `onmt.utils.optimizers.Optimizer` wrapper, which handles learning rate updates and gradient normalization automatically."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"lr = 1\n",
"torch_optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n",
"optim = onmt.utils.optimizers.Optimizer(\n",
"    torch_optimizer, learning_rate=lr, max_grad_norm=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create the training and validation data iterators"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we need to create the dynamic dataset iterator.\n",
"\n",
"This is not very 'library-friendly' for now because of the way the `DynamicDatasetIter` constructor is defined. It may evolve in the future."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"src_train = \"toy-ende/src-train.txt\"\n",
"tgt_train = \"toy-ende/tgt-train.txt\"\n",
"src_val = \"toy-ende/src-val.txt\"\n",
"tgt_val = \"toy-ende/tgt-val.txt\"\n",
"\n",
"# build the ParallelCorpus\n",
"corpus = ParallelCorpus(\"corpus\", src_train, tgt_train)\n",
"valid = ParallelCorpus(\"valid\", src_val, tgt_val)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"# build the training iterator\n",
"train_iter = DynamicDatasetIter(\n",
"    corpora={\"corpus\": corpus},\n",
"    corpora_info={\"corpus\": {\"weight\": 1}},\n",
"    transforms={},\n",
"    fields=vocab_fields,\n",
"    is_train=True,\n",
"    batch_type=\"tokens\",\n",
"    batch_size=4096,\n",
"    batch_size_multiple=1,\n",
"    data_type=\"text\")"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"# iterate on GPU 0 if CUDA is available, otherwise on CPU (-1 for CPU, N for GPU N)\n",
"train_iter = iter(IterOnDevice(train_iter, 0 if is_cuda else -1))"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"# build the validation iterator\n",
"valid_iter = DynamicDatasetIter(\n",
"    corpora={\"valid\": valid},\n",
"    corpora_info={\"valid\": {\"weight\": 1}},\n",
"    transforms={},\n",
"    fields=vocab_fields,\n",
"    is_train=False,\n",
"    batch_type=\"sents\",\n",
"    batch_size=8,\n",
"    batch_size_multiple=1,\n",
"    data_type=\"text\")"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"# same device selection as for training (-1 for CPU, N for GPU N)\n",
"valid_iter = IterOnDevice(valid_iter, 0 if is_cuda else -1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we train the model."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2020-09-25 15:28:15,184 INFO] Start training loop and validate every 500 steps...\n",
"[2020-09-25 15:28:15,185 INFO] corpus's transforms: TransformPipe()\n",
"[2020-09-25 15:28:15,187 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n",
"[2020-09-25 15:28:21,140 INFO] Step 50/ 1000; acc: 7.52; ppl: 8832.29; xent: 9.09; lr: 1.00000; 18916/18871 tok/s; 6 sec\n",
"[2020-09-25 15:28:24,869 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n",
"[2020-09-25 15:28:27,121 INFO] Step 100/ 1000; acc: 9.34; ppl: 1840.06; xent: 7.52; lr: 1.00000; 18911/18785 tok/s; 12 sec\n",
"[2020-09-25 15:28:33,048 INFO] Step 150/ 1000; acc: 10.35; ppl: 1419.18; xent: 7.26; lr: 1.00000; 19062/19017 tok/s; 18 sec\n",
"[2020-09-25 15:28:37,019 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n",
"[2020-09-25 15:28:39,022 INFO] Step 200/ 1000; acc: 11.14; ppl: 1127.44; xent: 7.03; lr: 1.00000; 19084/18911 tok/s; 24 sec\n",
"[2020-09-25 15:28:45,073 INFO] Step 250/ 1000; acc: 12.46; ppl: 912.13; xent: 6.82; lr: 1.00000; 18575/18570 tok/s; 30 sec\n",
"[2020-09-25 15:28:49,301 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n",
"[2020-09-25 15:28:51,151 INFO] Step 300/ 1000; acc: 13.04; ppl: 779.50; xent: 6.66; lr: 1.00000; 18394/18307 tok/s; 36 sec\n",
"[2020-09-25 15:28:57,316 INFO] Step 350/ 1000; acc: 14.04; ppl: 685.48; xent: 6.53; lr: 1.00000; 18339/18173 tok/s; 42 sec\n",
"[2020-09-25 15:29:02,117 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n",
"[2020-09-25 15:29:03,576 INFO] Step 400/ 1000; acc: 14.99; ppl: 590.20; xent: 6.38; lr: 1.00000; 18090/18029 tok/s; 48 sec\n",
"[2020-09-25 15:29:09,546 INFO] Step 450/ 1000; acc: 16.00; ppl: 524.51; xent: 6.26; lr: 1.00000; 18726/18536 tok/s; 54 sec\n",
"[2020-09-25 15:29:14,585 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n",
"[2020-09-25 15:29:15,596 INFO] Step 500/ 1000; acc: 16.78; ppl: 453.38; xent: 6.12; lr: 1.00000; 17877/17980 tok/s; 60 sec\n",
"[2020-09-25 15:29:15,597 INFO] valid's transforms: TransformPipe()\n",
"[2020-09-25 15:29:15,599 INFO] Loading ParallelCorpus(toy-ende/src-val.txt, toy-ende/tgt-val.txt, align=None)...\n",
"[2020-09-25 15:29:24,528 INFO] Validation perplexity: 295.1\n",
"[2020-09-25 15:29:24,529 INFO] Validation accuracy: 17.6533\n",
"[2020-09-25 15:29:30,592 INFO] Step 550/ 1000; acc: 17.47; ppl: 421.26; xent: 6.04; lr: 1.00000; 7726/7610 tok/s; 75 sec\n",
"[2020-09-25 15:29:36,055 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n",
"[2020-09-25 15:29:36,695 INFO] Step 600/ 1000; acc: 18.95; ppl: 354.44; xent: 5.87; lr: 1.00000; 17470/17598 tok/s; 82 sec\n",
"[2020-09-25 15:29:42,794 INFO] Step 650/ 1000; acc: 19.60; ppl: 328.47; xent: 5.79; lr: 1.00000; 18994/18793 tok/s; 88 sec\n",
"[2020-09-25 15:29:48,635 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n",
"[2020-09-25 15:29:48,924 INFO] Step 700/ 1000; acc: 20.57; ppl: 285.48; xent: 5.65; lr: 1.00000; 17856/17788 tok/s; 94 sec\n",
"[2020-09-25 15:29:54,898 INFO] Step 750/ 1000; acc: 21.97; ppl: 249.06; xent: 5.52; lr: 1.00000; 19030/18924 tok/s; 100 sec\n",
"[2020-09-25 15:30:01,233 INFO] Step 800/ 1000; acc: 22.66; ppl: 228.54; xent: 5.43; lr: 1.00000; 17571/17471 tok/s; 106 sec\n",
"[2020-09-25 15:30:01,357 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n",
"[2020-09-25 15:30:07,345 INFO] Step 850/ 1000; acc: 24.32; ppl: 193.65; xent: 5.27; lr: 1.00000; 18344/18313 tok/s; 112 sec\n",
"[2020-09-25 15:30:11,363 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n",
"[2020-09-25 15:30:13,487 INFO] Step 900/ 1000; acc: 24.93; ppl: 177.65; xent: 5.18; lr: 1.00000; 18293/18105 tok/s; 118 sec\n",
"[2020-09-25 15:30:19,670 INFO] Step 950/ 1000; acc: 26.33; ppl: 157.10; xent: 5.06; lr: 1.00000; 17791/17746 tok/s; 124 sec\n",
"[2020-09-25 15:30:24,072 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n",
"[2020-09-25 15:30:25,820 INFO] Step 1000/ 1000; acc: 27.47; ppl: 137.64; xent: 4.92; lr: 1.00000; 17942/17962 tok/s; 131 sec\n",
"[2020-09-25 15:30:25,822 INFO] Loading ParallelCorpus(toy-ende/src-val.txt, toy-ende/tgt-val.txt, align=None)...\n",
"[2020-09-25 15:30:34,665 INFO] Validation perplexity: 241.801\n",
"[2020-09-25 15:30:34,666 INFO] Validation accuracy: 20.2837\n"
]
},
{
"data": {
"text/plain": [
"<onmt.utils.statistics.Statistics at 0x7fca934e8e80>"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"report_manager = onmt.utils.ReportMgr(\n",
"    report_every=50, start_time=None, tensorboard_writer=None)\n",
"\n",
"trainer = onmt.Trainer(\n",
"    model=model,\n",
"    train_loss=loss,\n",
"    valid_loss=loss,\n",
"    optim=optim,\n",
"    report_manager=report_manager,\n",
"    dropout=[0.1])\n",
"\n",
"trainer.train(\n",
"    train_iter=train_iter,\n",
"    train_steps=1000,\n",
"    valid_iter=valid_iter,\n",
"    valid_steps=500)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Translate"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For translation, we can build a \"traditional\" (as opposed to dynamic) dataset for now."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"src_data = {\"reader\": onmt.inputters.str2reader[\"text\"](), \"data\": src_val}\n",
"tgt_data = {\"reader\": onmt.inputters.str2reader[\"text\"](), \"data\": tgt_val}\n",
"_readers, _data = onmt.inputters.Dataset.config(\n",
"    [('src', src_data), ('tgt', tgt_data)])"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"dataset = onmt.inputters.Dataset(\n",
"    vocab_fields, readers=_readers, data=_data,\n",
"    sort_key=onmt.inputters.str2sortkey[\"text\"])"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"data_iter = onmt.inputters.OrderedIterator(\n",
"    dataset=dataset,\n",
"    device=device,  # same device as the model (\"cuda\" or \"cpu\")\n",
"    batch_size=10,\n",
"    train=False,\n",
"    sort=False,\n",
"    sort_within_batch=True,\n",
"    shuffle=False)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"src_reader = onmt.inputters.str2reader[\"text\"]\n",
"tgt_reader = onmt.inputters.str2reader[\"text\"]\n",
"scorer = GNMTGlobalScorer(\n",
"    alpha=0.7,\n",
"    beta=0.,\n",
"    length_penalty=\"avg\",\n",
"    coverage_penalty=\"none\")\n",
"gpu = 0 if torch.cuda.is_available() else -1\n",
"translator = Translator(\n",
"    model=model,\n",
"    fields=vocab_fields,\n",
"    src_reader=src_reader,\n",
"    tgt_reader=tgt_reader,\n",
"    global_scorer=scorer,\n",
"    gpu=gpu)\n",
"builder = TranslationBuilder(\n",
"    data=dataset,\n",
"    fields=vocab_fields)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note**: the translations will be very poor because of the small amount of training data, the absence of proper tokenization, and the short training run."
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"SENT 0: ['Parliament', 'Does', 'Not', 'Support', 'Amendment', 'Freeing', 'Tymoshenko']\n",
"PRED 0: Parlament das Parlament über die Europäische Parlament , die sich in der Lage in der Lage ist , die es in der Lage sind .\n",
"PRED SCORE: -1.5935\n",
"\n",
"\n",
"SENT 0: ['Today', ',', 'the', 'Ukraine', 'parliament', 'dismissed', ',', 'within', 'the', 'Code', 'of', 'Criminal', 'Procedure', 'amendment', ',', 'the', 'motion', 'to', 'revoke', 'an', 'article', 'based', 'on', 'which', 'the', 'opposition', 'leader', ',', 'Yulia', 'Tymoshenko', ',', 'was', 'sentenced', '.']\n",
"PRED 0: In der Nähe des Hotels , die in der Lage , die sich in der Lage ist , in der Lage , die in der Lage , die in der Lage ist .\n",
"PRED SCORE: -1.7173\n",
"\n",
"\n",
"SENT 0: ['The', 'amendment', 'that', 'would', 'lead', 'to', 'freeing', 'the', 'imprisoned', 'former', 'Prime', 'Minister', 'was', 'revoked', 'during', 'second', 'reading', 'of', 'the', 'proposal', 'for', 'mitigation', 'of', 'sentences', 'for', 'economic', 'offences', '.']\n",
"PRED 0: Die Tatsache , die sich in der Lage in der Lage ist , die für eine Antwort der Entwicklung für die Entwicklung von Präsident .\n",
"PRED SCORE: -1.6834\n",
"\n",
"\n",
"SENT 0: ['In', 'October', ',', 'Tymoshenko', 'was', 'sentenced', 'to', 'seven', 'years', 'in', 'prison', 'for', 'entering', 'into', 'what', 'was', 'reported', 'to', 'be', 'a', 'disadvantageous', 'gas', 'deal', 'with', 'Russia', '.']\n",
"PRED 0: In der Nähe wurde die Menschen in der Lage ist , die sich in der Lage <unk> .\n",
"PRED SCORE: -1.5765\n",
"\n",
"\n",
"SENT 0: ['The', 'verdict', 'is', 'not', 'yet', 'final;', 'the', 'court', 'will', 'hear', 'Tymoshenko', ''s', 'appeal', 'in', 'December', '.']\n",
"PRED 0: Es ist nicht der Fall , die in der Lage in der Lage sind .\n",
"PRED SCORE: -1.3287\n",
"\n",
"\n",
"SENT 0: ['Tymoshenko', 'claims', 'the', 'verdict', 'is', 'a', 'political', 'revenge', 'of', 'the', 'regime;', 'in', 'the', 'West', ',', 'the', 'trial', 'has', 'also', 'evoked', 'suspicion', 'of', 'being', 'biased', '.']\n",
"PRED 0: Um in der Lage ist auch eine Lösung Rolle .\n",
"PRED SCORE: -1.3975\n",
"\n",
"\n",
"SENT 0: ['The', 'proposal', 'to', 'remove', 'Article', '365', 'from', 'the', 'Code', 'of', 'Criminal', 'Procedure', ',', 'upon', 'which', 'the', 'former', 'Prime', 'Minister', 'was', 'sentenced', ',', 'was', 'supported', 'by', '147', 'members', 'of', 'parliament', '.']\n",
"PRED 0: Der Vorschlag , die in der Lage , die in der Lage , die in der Lage ist , war er von der Fall <unk> wurde .\n",
"PRED SCORE: -1.6062\n",
"\n",
"\n",
"SENT 0: ['Its', 'ratification', 'would', 'require', '226', 'votes', '.']\n",
"PRED 0: Es wäre noch einmal noch einmal <unk> .\n",
"PRED SCORE: -1.8001\n",
"\n",
"\n",
"SENT 0: ['Libya', ''s', 'Victory']\n",
"PRED 0: In der Nähe des Hotels befindet sich in der Nähe des Hotels in der Lage .\n",
"PRED SCORE: -1.7097\n",
"\n",
"\n",
"SENT 0: ['The', 'story', 'of', 'Libya', ''s', 'liberation', ',', 'or', 'rebellion', ',', 'already', 'has', 'its', 'defeated', '.']\n",
"PRED 0: In der Nähe des Hotels in der Lage ist in der Lage .\n",
"PRED SCORE: -1.7885\n",
"\n"
]
}
],
"source": [
"for batch in data_iter:\n",
"    trans_batch = translator.translate_batch(\n",
"        batch=batch, src_vocabs=[src_vocab],\n",
"        attn_debug=False)\n",
"    translations = builder.from_batch(trans_batch)\n",
"    for trans in translations:\n",
"        print(trans.log(0))\n",
"    # stop after the first batch to keep the output short\n",
"    break"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}