osanseviero (HF staff) committed
Commit cbe1813
1 Parent(s): fc67275
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full changeset.
Files changed (50)
  1. CODE_OF_CONDUCT.md +0 -77
  2. CONTRIBUTING.md +0 -28
  3. LICENSE +0 -21
  4. gradiodemo.py → app.py +1 -1
  5. docs/Makefile +0 -20
  6. docs/_static/theme_overrides.css +0 -9
  7. docs/command_line_tools.rst +0 -85
  8. docs/conf.py +0 -134
  9. docs/criterions.rst +0 -31
  10. docs/data.rst +0 -58
  11. docs/docutils.conf +0 -2
  12. docs/fairseq.gif +0 -0
  13. docs/fairseq_logo.png +0 -0
  14. docs/getting_started.rst +0 -216
  15. docs/hydra_integration.md +0 -284
  16. docs/index.rst +0 -49
  17. docs/lr_scheduler.rst +0 -34
  18. docs/make.bat +0 -36
  19. docs/models.rst +0 -104
  20. docs/modules.rst +0 -9
  21. docs/optim.rst +0 -38
  22. docs/overview.rst +0 -74
  23. docs/requirements.txt +0 -2
  24. docs/tasks.rst +0 -61
  25. docs/tutorial_classifying_names.rst +0 -415
  26. docs/tutorial_simple_lstm.rst +0 -518
  27. examples/.gitignore +0 -2
  28. examples/__init__.py +0 -9
  29. examples/adaptive_span/README.md +0 -90
  30. examples/adaptive_span/__init__.py +0 -19
  31. examples/adaptive_span/adagrad_with_grad_clip.py +0 -128
  32. examples/adaptive_span/adaptive_span_attention.py +0 -160
  33. examples/adaptive_span/adaptive_span_loss.py +0 -106
  34. examples/adaptive_span/adaptive_span_model.py +0 -263
  35. examples/adaptive_span/adaptive_span_model_wrapper.py +0 -145
  36. examples/adaptive_span/truncated_bptt_lm_task.py +0 -1
  37. examples/backtranslation/README.md +0 -297
  38. examples/backtranslation/deduplicate_lines.py +0 -41
  39. examples/backtranslation/extract_bt_data.py +0 -72
  40. examples/backtranslation/prepare-de-monolingual.sh +0 -98
  41. examples/backtranslation/prepare-wmt18en2de.sh +0 -135
  42. examples/backtranslation/sacrebleu.sh +0 -37
  43. examples/backtranslation/tokenized_bleu.sh +0 -46
  44. examples/bart/README.glue.md +0 -99
  45. examples/bart/README.md +0 -228
  46. examples/bart/README.summarization.md +0 -102
  47. examples/bart/summarize.py +0 -100
  48. examples/byte_level_bpe/README.md +0 -88
  49. examples/byte_level_bpe/get_bitext.py +0 -254
  50. examples/byte_level_bpe/get_data.sh +0 -47
CODE_OF_CONDUCT.md DELETED
@@ -1,77 +0,0 @@
1
- # Code of Conduct
2
-
3
- ## Our Pledge
4
-
5
- In the interest of fostering an open and welcoming environment, we as
6
- contributors and maintainers pledge to make participation in our project and
7
- our community a harassment-free experience for everyone, regardless of age, body
8
- size, disability, ethnicity, sex characteristics, gender identity and expression,
9
- level of experience, education, socio-economic status, nationality, personal
10
- appearance, race, religion, or sexual identity and orientation.
11
-
12
- ## Our Standards
13
-
14
- Examples of behavior that contributes to creating a positive environment
15
- include:
16
-
17
- * Using welcoming and inclusive language
18
- * Being respectful of differing viewpoints and experiences
19
- * Gracefully accepting constructive criticism
20
- * Focusing on what is best for the community
21
- * Showing empathy towards other community members
22
-
23
- Examples of unacceptable behavior by participants include:
24
-
25
- * The use of sexualized language or imagery and unwelcome sexual attention or
26
- advances
27
- * Trolling, insulting/derogatory comments, and personal or political attacks
28
- * Public or private harassment
29
- * Publishing others' private information, such as a physical or electronic
30
- address, without explicit permission
31
- * Other conduct which could reasonably be considered inappropriate in a
32
- professional setting
33
-
34
- ## Our Responsibilities
35
-
36
- Project maintainers are responsible for clarifying the standards of acceptable
37
- behavior and are expected to take appropriate and fair corrective action in
38
- response to any instances of unacceptable behavior.
39
-
40
- Project maintainers have the right and responsibility to remove, edit, or
41
- reject comments, commits, code, wiki edits, issues, and other contributions
42
- that are not aligned to this Code of Conduct, or to ban temporarily or
43
- permanently any contributor for other behaviors that they deem inappropriate,
44
- threatening, offensive, or harmful.
45
-
46
- ## Scope
47
-
48
- This Code of Conduct applies within all project spaces, and it also applies when
49
- an individual is representing the project or its community in public spaces.
50
- Examples of representing a project or community include using an official
51
- project e-mail address, posting via an official social media account, or acting
52
- as an appointed representative at an online or offline event. Representation of
53
- a project may be further defined and clarified by project maintainers.
54
-
55
- ## Enforcement
56
-
57
- Instances of abusive, harassing, or otherwise unacceptable behavior may be
58
- reported by contacting the project team at <conduct@pytorch.org>. All
59
- complaints will be reviewed and investigated and will result in a response that
60
- is deemed necessary and appropriate to the circumstances. The project team is
61
- obligated to maintain confidentiality with regard to the reporter of an incident.
62
- Further details of specific enforcement policies may be posted separately.
63
-
64
- Project maintainers who do not follow or enforce the Code of Conduct in good
65
- faith may face temporary or permanent repercussions as determined by other
66
- members of the project's leadership.
67
-
68
- ## Attribution
69
-
70
- This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71
- available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72
-
73
- [homepage]: https://www.contributor-covenant.org
74
-
75
- For answers to common questions about this code of conduct, see
76
- https://www.contributor-covenant.org/faq
77
-
 
CONTRIBUTING.md DELETED
@@ -1,28 +0,0 @@
1
- # Contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq)
2
- We want to make contributing to this project as easy and transparent as
3
- possible.
4
-
5
- ## Pull Requests
6
- We actively welcome your pull requests.
7
-
8
- 1. Fork the repo and create your branch from `master`.
9
- 2. If you've added code that should be tested, add tests.
10
- 3. If you've changed APIs, update the documentation.
11
- 4. Ensure the test suite passes.
12
- 5. Make sure your code lints.
13
- 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14
-
15
- ## Contributor License Agreement ("CLA")
16
- In order to accept your pull request, we need you to submit a CLA. You only need
17
- to do this once to work on any of Facebook's open source projects.
18
-
19
- Complete your CLA here: <https://code.facebook.com/cla>
20
-
21
- ## Issues
22
- We use GitHub issues to track public bugs. Please ensure your description is
23
- clear and has sufficient instructions to be able to reproduce the issue.
24
-
25
- ## License
26
- By contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq),
27
- you agree that your contributions will be licensed under the LICENSE file in
28
- the root directory of this source tree.
 
LICENSE DELETED
@@ -1,21 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) Facebook, Inc. and its affiliates.
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.
 
gradiodemo.py → app.py RENAMED
@@ -2,7 +2,7 @@ import gradio as gr
- description = "demo for HuBERT. To use it, simply add your audio or click one of the examples to load them. Read more at the links below."
+ description = "Demo for HuBERT. Add your audio or click one of the examples to load them. Read more at the links below."
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2106.07447'>HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units</a> | <a href='https://github.com/pytorch/fairseq/tree/master/examples/hubert'>Github Repo</a></p>"

  gr.Interface.load("huggingface/facebook/hubert-large-ls960-ft",
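The rename of `gradiodemo.py` to `app.py` plus the reworded description above is the only functional change in this commit, and the diff is truncated at the `gr.Interface.load(` call. Below is a minimal sketch of how such a Gradio app is typically assembled; only `description`, `article`, and the model id come from the diff, while the keyword arguments passed to `load` and the final `launch()` call are assumptions for illustration and may not match the actual file.

```python
# Hypothetical sketch only -- not the verbatim contents of app.py.
# description, article, and the model id are taken from the diff above;
# everything else is assumed for illustration.
import gradio as gr

description = "Demo for HuBERT. Add your audio or click one of the examples to load them. Read more at the links below."
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2106.07447'>HuBERT: Self-Supervised Speech "
    "Representation Learning by Masked Prediction of Hidden Units</a> | "
    "<a href='https://github.com/pytorch/fairseq/tree/master/examples/hubert'>Github Repo</a></p>"
)

# gr.Interface.load (Gradio 2.x/3.x API) wraps a model hosted on the
# Hugging Face Hub in an Interface; extra keyword arguments override the
# generated interface's metadata.
demo = gr.Interface.load(
    "huggingface/facebook/hubert-large-ls960-ft",
    description=description,
    article=article,
)

if __name__ == "__main__":
    demo.launch()
```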
docs/Makefile DELETED
@@ -1,20 +0,0 @@
1
- # Minimal makefile for Sphinx documentation
2
- #
3
-
4
- # You can set these variables from the command line.
5
- SPHINXOPTS =
6
- SPHINXBUILD = python -msphinx
7
- SPHINXPROJ = fairseq
8
- SOURCEDIR = .
9
- BUILDDIR = _build
10
-
11
- # Put it first so that "make" without argument is like "make help".
12
- help:
13
- @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14
-
15
- .PHONY: help Makefile
16
-
17
- # Catch-all target: route all unknown targets to Sphinx using the new
18
- # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19
- %: Makefile
20
- @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
 
docs/_static/theme_overrides.css DELETED
@@ -1,9 +0,0 @@
1
- .wy-table-responsive table td kbd {
2
- white-space: nowrap;
3
- }
4
- .wy-table-responsive table td {
5
- white-space: normal !important;
6
- }
7
- .wy-table-responsive {
8
- overflow: visible !important;
9
- }
 
docs/command_line_tools.rst DELETED
@@ -1,85 +0,0 @@
1
- .. _Command-line Tools:
2
-
3
- Command-line Tools
4
- ==================
5
-
6
- Fairseq provides several command-line tools for training and evaluating models:
7
-
8
- - :ref:`fairseq-preprocess`: Data pre-processing: build vocabularies and binarize training data
9
- - :ref:`fairseq-train`: Train a new model on one or multiple GPUs
10
- - :ref:`fairseq-generate`: Translate pre-processed data with a trained model
11
- - :ref:`fairseq-interactive`: Translate raw text with a trained model
12
- - :ref:`fairseq-score`: BLEU scoring of generated translations against reference translations
13
- - :ref:`fairseq-eval-lm`: Language model evaluation
14
-
15
-
16
- .. _fairseq-preprocess:
17
-
18
- fairseq-preprocess
19
- ~~~~~~~~~~~~~~~~~~
20
- .. automodule:: fairseq_cli.preprocess
21
-
22
- .. argparse::
23
- :module: fairseq.options
24
- :func: get_preprocessing_parser
25
- :prog: fairseq-preprocess
26
-
27
-
28
- .. _fairseq-train:
29
-
30
- fairseq-train
31
- ~~~~~~~~~~~~~
32
- .. automodule:: fairseq_cli.train
33
-
34
- .. argparse::
35
- :module: fairseq.options
36
- :func: get_training_parser
37
- :prog: fairseq-train
38
-
39
-
40
- .. _fairseq-generate:
41
-
42
- fairseq-generate
43
- ~~~~~~~~~~~~~~~~
44
- .. automodule:: fairseq_cli.generate
45
-
46
- .. argparse::
47
- :module: fairseq.options
48
- :func: get_generation_parser
49
- :prog: fairseq-generate
50
-
51
-
52
- .. _fairseq-interactive:
53
-
54
- fairseq-interactive
55
- ~~~~~~~~~~~~~~~~~~~
56
- .. automodule:: fairseq_cli.interactive
57
-
58
- .. argparse::
59
- :module: fairseq.options
60
- :func: get_interactive_generation_parser
61
- :prog: fairseq-interactive
62
-
63
-
64
- .. _fairseq-score:
65
-
66
- fairseq-score
67
- ~~~~~~~~~~~~~
68
- .. automodule:: fairseq_cli.score
69
-
70
- .. argparse::
71
- :module: fairseq_cli.score
72
- :func: get_parser
73
- :prog: fairseq-score
74
-
75
-
76
- .. _fairseq-eval-lm:
77
-
78
- fairseq-eval-lm
79
- ~~~~~~~~~~~~~~~
80
- .. automodule:: fairseq_cli.eval_lm
81
-
82
- .. argparse::
83
- :module: fairseq.options
84
- :func: get_eval_lm_parser
85
- :prog: fairseq-eval-lm
 
docs/conf.py DELETED
@@ -1,134 +0,0 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- #
4
- # fairseq documentation build configuration file, created by
5
- # sphinx-quickstart on Fri Aug 17 21:45:30 2018.
6
- #
7
- # This file is execfile()d with the current directory set to its
8
- # containing dir.
9
- #
10
- # Note that not all possible configuration values are present in this
11
- # autogenerated file.
12
- #
13
- # All configuration values have a default; values that are commented out
14
- # serve to show the default.
15
-
16
- # If extensions (or modules to document with autodoc) are in another directory,
17
- # add these directories to sys.path here. If the directory is relative to the
18
- # documentation root, use os.path.abspath to make it absolute, like shown here.
19
-
20
- import os
21
- import sys
22
- from fairseq import __version__
23
-
24
-
25
- # source code directory, relative to this file, for sphinx-autobuild
26
- sys.path.insert(0, os.path.abspath(".."))
27
-
28
- source_suffix = [".rst"]
29
-
30
- # -- General configuration ------------------------------------------------
31
-
32
- # If your documentation needs a minimal Sphinx version, state it here.
33
- #
34
- # needs_sphinx = '1.0'
35
-
36
- # Add any Sphinx extension module names here, as strings. They can be
37
- # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
38
- # ones.
39
- extensions = [
40
- "sphinx.ext.autodoc",
41
- "sphinx.ext.intersphinx",
42
- "sphinx.ext.viewcode",
43
- "sphinx.ext.napoleon",
44
- "sphinxarg.ext",
45
- ]
46
-
47
- # Add any paths that contain templates here, relative to this directory.
48
- templates_path = ["_templates"]
49
-
50
- # The master toctree document.
51
- master_doc = "index"
52
-
53
- # General information about the project.
54
- project = "fairseq"
55
- copyright = "Facebook AI Research (FAIR)"
56
- author = "Facebook AI Research (FAIR)"
57
-
58
- github_doc_root = "https://github.com/pytorch/fairseq/tree/master/docs/"
59
-
60
- # The version info for the project you're documenting, acts as replacement for
61
- # |version| and |release|, also used in various other places throughout the
62
- # built documents.
63
- #
64
- # The short X.Y version.
65
- version = __version__
66
- # The full version, including alpha/beta/rc tags.
67
- release = __version__
68
-
69
- # The language for content autogenerated by Sphinx. Refer to documentation
70
- # for a list of supported languages.
71
- #
72
- # This is also used if you do content translation via gettext catalogs.
73
- # Usually you set "language" from the command line for these cases.
74
- language = None
75
-
76
- # List of patterns, relative to source directory, that match files and
77
- # directories to ignore when looking for source files.
78
- # This patterns also effect to html_static_path and html_extra_path
79
- exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
80
-
81
- # The name of the Pygments (syntax highlighting) style to use.
82
- pygments_style = "sphinx"
83
- highlight_language = "python"
84
-
85
- # If true, `todo` and `todoList` produce output, else they produce nothing.
86
- todo_include_todos = False
87
-
88
-
89
- # -- Options for HTML output ----------------------------------------------
90
-
91
- # The theme to use for HTML and HTML Help pages. See the documentation for
92
- # a list of builtin themes.
93
- #
94
- html_theme = "sphinx_rtd_theme"
95
-
96
- # Theme options are theme-specific and customize the look and feel of a theme
97
- # further. For a list of options available for each theme, see the
98
- # documentation.
99
- #
100
- # html_theme_options = {}
101
-
102
- # Add any paths that contain custom static files (such as style sheets) here,
103
- # relative to this directory. They are copied after the builtin static files,
104
- # so a file named "default.css" will overwrite the builtin "default.css".
105
- html_static_path = ["_static"]
106
-
107
- html_context = {
108
- "css_files": [
109
- "_static/theme_overrides.css", # override wide tables in RTD theme
110
- ],
111
- }
112
-
113
- # Custom sidebar templates, must be a dictionary that maps document names
114
- # to template names.
115
- #
116
- # This is required for the alabaster theme
117
- # refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
118
- # html_sidebars = {
119
- # '**': [
120
- # 'about.html',
121
- # 'navigation.html',
122
- # 'relations.html', # needs 'show_related': True theme option to display
123
- # 'searchbox.html',
124
- # 'donate.html',
125
- # ]
126
- # }
127
-
128
-
129
- # Example configuration for intersphinx: refer to the Python standard library.
130
- intersphinx_mapping = {
131
- "numpy": ("http://docs.scipy.org/doc/numpy/", None),
132
- "python": ("https://docs.python.org/", None),
133
- "torch": ("https://pytorch.org/docs/master/", None),
134
- }
 
docs/criterions.rst DELETED
@@ -1,31 +0,0 @@
1
- .. role:: hidden
2
- :class: hidden-section
3
-
4
- .. _Criterions:
5
-
6
- Criterions
7
- ==========
8
-
9
- Criterions compute the loss function given the model and batch, roughly::
10
-
11
- loss = criterion(model, batch)
12
-
13
- .. automodule:: fairseq.criterions
14
- :members:
15
-
16
- .. autoclass:: fairseq.criterions.FairseqCriterion
17
- :members:
18
- :undoc-members:
19
-
20
- .. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss
21
- :members:
22
- :undoc-members:
23
- .. autoclass:: fairseq.criterions.composite_loss.CompositeLoss
24
- :members:
25
- :undoc-members:
26
- .. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion
27
- :members:
28
- :undoc-members:
29
- .. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion
30
- :members:
31
- :undoc-members:
 
docs/data.rst DELETED
@@ -1,58 +0,0 @@
1
- .. role:: hidden
2
- :class: hidden-section
3
-
4
- .. module:: fairseq.data
5
-
6
- Data Loading and Utilities
7
- ==========================
8
-
9
- .. _datasets:
10
-
11
- Datasets
12
- --------
13
-
14
- **Datasets** define the data format and provide helpers for creating
15
- mini-batches.
16
-
17
- .. autoclass:: fairseq.data.FairseqDataset
18
- :members:
19
- .. autoclass:: fairseq.data.LanguagePairDataset
20
- :members:
21
- .. autoclass:: fairseq.data.MonolingualDataset
22
- :members:
23
-
24
- **Helper Datasets**
25
-
26
- These datasets wrap other :class:`fairseq.data.FairseqDataset` instances and
27
- provide additional functionality:
28
-
29
- .. autoclass:: fairseq.data.BacktranslationDataset
30
- :members:
31
- .. autoclass:: fairseq.data.ConcatDataset
32
- :members:
33
- .. autoclass:: fairseq.data.ResamplingDataset
34
- :members:
35
- .. autoclass:: fairseq.data.RoundRobinZipDatasets
36
- :members:
37
- .. autoclass:: fairseq.data.TransformEosDataset
38
- :members:
39
-
40
-
41
- Dictionary
42
- ----------
43
-
44
- .. autoclass:: fairseq.data.Dictionary
45
- :members:
46
-
47
-
48
- Iterators
49
- ---------
50
-
51
- .. autoclass:: fairseq.data.CountingIterator
52
- :members:
53
- .. autoclass:: fairseq.data.EpochBatchIterator
54
- :members:
55
- .. autoclass:: fairseq.data.GroupedIterator
56
- :members:
57
- .. autoclass:: fairseq.data.ShardedIterator
58
- :members:
 
docs/docutils.conf DELETED
@@ -1,2 +0,0 @@
1
- [writers]
2
- option-limit=0
 
 
 
docs/fairseq.gif DELETED
Binary file (2.66 MB)
 
docs/fairseq_logo.png DELETED
Binary file (73 kB)
 
docs/getting_started.rst DELETED
@@ -1,216 +0,0 @@
1
- Evaluating Pre-trained Models
2
- =============================
3
-
4
- First, download a pre-trained model along with its vocabularies:
5
-
6
- .. code-block:: console
7
-
8
- > curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf -
9
-
10
- This model uses a `Byte Pair Encoding (BPE)
11
- vocabulary <https://arxiv.org/abs/1508.07909>`__, so we'll have to apply
12
- the encoding to the source text before it can be translated. This can be
13
- done with the
14
- `apply\_bpe.py <https://github.com/rsennrich/subword-nmt/blob/master/subword_nmt/apply_bpe.py>`__
15
- script using the ``wmt14.en-fr.fconv-cuda/bpecodes`` file. ``@@`` is
16
- used as a continuation marker and the original text can be easily
17
- recovered with e.g. ``sed s/@@ //g`` or by passing the ``--remove-bpe``
18
- flag to :ref:`fairseq-generate`. Prior to BPE, input text needs to be tokenized
19
- using ``tokenizer.perl`` from
20
- `mosesdecoder <https://github.com/moses-smt/mosesdecoder>`__.
21
-
22
- Let's use :ref:`fairseq-interactive` to generate translations interactively.
23
- Here, we use a beam size of 5 and preprocess the input with the Moses
24
- tokenizer and the given Byte-Pair Encoding vocabulary. It will automatically
25
- remove the BPE continuation markers and detokenize the output.
26
-
27
- .. code-block:: console
28
-
29
- > MODEL_DIR=wmt14.en-fr.fconv-py
30
- > fairseq-interactive \
31
- --path $MODEL_DIR/model.pt $MODEL_DIR \
32
- --beam 5 --source-lang en --target-lang fr \
33
- --tokenizer moses \
34
- --bpe subword_nmt --bpe-codes $MODEL_DIR/bpecodes
35
- | loading model(s) from wmt14.en-fr.fconv-py/model.pt
36
- | [en] dictionary: 44206 types
37
- | [fr] dictionary: 44463 types
38
- | Type the input sentence and press return:
39
- Why is it rare to discover new marine mammal species?
40
- S-0 Why is it rare to discover new marine mam@@ mal species ?
41
- H-0 -0.0643349438905716 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins?
42
- P-0 -0.0763 -0.1849 -0.0956 -0.0946 -0.0735 -0.1150 -0.1301 -0.0042 -0.0321 -0.0171 -0.0052 -0.0062 -0.0015
43
-
44
- This generation script produces three types of outputs: a line prefixed
45
- with *O* is a copy of the original source sentence; *H* is the
46
- hypothesis along with an average log-likelihood; and *P* is the
47
- positional score per token position, including the
48
- end-of-sentence marker which is omitted from the text.
49
-
50
- Other types of output lines you might see are *D*, the detokenized hypothesis,
51
- *T*, the reference target, *A*, alignment info, *E* the history of generation steps.
52
-
53
- See the `README <https://github.com/pytorch/fairseq#pre-trained-models>`__ for a
54
- full list of pre-trained models available.
55
-
56
- Training a New Model
57
- ====================
58
-
59
- The following tutorial is for machine translation. For an example of how
60
- to use Fairseq for other tasks, such as :ref:`language modeling`, please see the
61
- ``examples/`` directory.
62
-
63
- Data Pre-processing
64
- -------------------
65
-
66
- Fairseq contains example pre-processing scripts for several translation
67
- datasets: IWSLT 2014 (German-English), WMT 2014 (English-French) and WMT
68
- 2014 (English-German). To pre-process and binarize the IWSLT dataset:
69
-
70
- .. code-block:: console
71
-
72
- > cd examples/translation/
73
- > bash prepare-iwslt14.sh
74
- > cd ../..
75
- > TEXT=examples/translation/iwslt14.tokenized.de-en
76
- > fairseq-preprocess --source-lang de --target-lang en \
77
- --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
78
- --destdir data-bin/iwslt14.tokenized.de-en
79
-
80
- This will write binarized data that can be used for model training to
81
- ``data-bin/iwslt14.tokenized.de-en``.
82
-
83
- Training
84
- --------
85
-
86
- Use :ref:`fairseq-train` to train a new model. Here a few example settings that work
87
- well for the IWSLT 2014 dataset:
88
-
89
- .. code-block:: console
90
-
91
- > mkdir -p checkpoints/fconv
92
- > CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en \
93
- --optimizer nag --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
94
- --arch fconv_iwslt_de_en --save-dir checkpoints/fconv
95
-
96
- By default, :ref:`fairseq-train` will use all available GPUs on your machine. Use the
97
- ``CUDA_VISIBLE_DEVICES`` environment variable to select specific GPUs and/or to
98
- change the number of GPU devices that will be used.
99
-
100
- Also note that the batch size is specified in terms of the maximum
101
- number of tokens per batch (``--max-tokens``). You may need to use a
102
- smaller value depending on the available GPU memory on your system.
103
-
104
- Generation
105
- ----------
106
-
107
- Once your model is trained, you can generate translations using
108
- :ref:`fairseq-generate` **(for binarized data)** or
109
- :ref:`fairseq-interactive` **(for raw text)**:
110
-
111
- .. code-block:: console
112
-
113
- > fairseq-generate data-bin/iwslt14.tokenized.de-en \
114
- --path checkpoints/fconv/checkpoint_best.pt \
115
- --batch-size 128 --beam 5
116
- | [de] dictionary: 35475 types
117
- | [en] dictionary: 24739 types
118
- | data-bin/iwslt14.tokenized.de-en test 6750 examples
119
- | model fconv
120
- | loaded checkpoint trainings/fconv/checkpoint_best.pt
121
- S-721 danke .
122
- T-721 thank you .
123
- ...
124
-
125
- To generate translations with only a CPU, use the ``--cpu`` flag. BPE
126
- continuation markers can be removed with the ``--remove-bpe`` flag.
127
-
128
- Advanced Training Options
129
- =========================
130
-
131
- Large mini-batch training with delayed updates
132
- ----------------------------------------------
133
-
134
- The ``--update-freq`` option can be used to accumulate gradients from
135
- multiple mini-batches and delay updating, creating a larger effective
136
- batch size. Delayed updates can also improve training speed by reducing
137
- inter-GPU communication costs and by saving idle time caused by variance
138
- in workload across GPUs. See `Ott et al.
139
- (2018) <https://arxiv.org/abs/1806.00187>`__ for more details.
140
-
141
- To train on a single GPU with an effective batch size that is equivalent
142
- to training on 8 GPUs:
143
-
144
- .. code-block:: console
145
-
146
- > CUDA_VISIBLE_DEVICES=0 fairseq-train --update-freq 8 (...)
147
-
148
- Training with half precision floating point (FP16)
149
- --------------------------------------------------
150
-
151
- .. note::
152
-
153
- FP16 training requires a Volta GPU and CUDA 9.1 or greater
154
-
155
- Recent GPUs enable efficient half precision floating point computation,
156
- e.g., using `Nvidia Tensor Cores
157
- <https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html>`__.
158
- Fairseq supports FP16 training with the ``--fp16`` flag:
159
-
160
- .. code-block:: console
161
-
162
- > fairseq-train --fp16 (...)
163
-
164
- Distributed training
165
- --------------------
166
-
167
- Distributed training in fairseq is implemented on top of ``torch.distributed``.
168
- The easiest way to launch jobs is with the `torch.distributed.launch
169
- <https://pytorch.org/docs/stable/distributed.html#launch-utility>`__ tool.
170
-
171
- For example, to train a large English-German Transformer model on 2 nodes each
172
- with 8 GPUs (in total 16 GPUs), run the following command on each node,
173
- replacing ``node_rank=0`` with ``node_rank=1`` on the second node and making
174
- sure to update ``--master_addr`` to the IP address of the first node:
175
-
176
- .. code-block:: console
177
-
178
- > python -m torch.distributed.launch --nproc_per_node=8 \
179
- --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" \
180
- --master_port=12345 \
181
- $(which fairseq-train) data-bin/wmt16_en_de_bpe32k \
182
- --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
183
- --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
184
- --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
185
- --lr 0.0005 \
186
- --dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
187
- --max-tokens 3584 \
188
- --max-epoch 70 \
189
- --fp16
190
-
191
- On SLURM clusters, fairseq will automatically detect the number of nodes and
192
- GPUs, but a port number must be provided:
193
-
194
- .. code-block:: console
195
-
196
- > salloc --gpus=16 --nodes 2 (...)
197
- > srun fairseq-train --distributed-port 12345 (...).
198
-
199
- Sharding very large datasets
200
- ----------------------------
201
-
202
- It can be challenging to train over very large datasets, particularly if your
203
- machine does not have much system RAM. Most tasks in fairseq support training
204
- over "sharded" datasets, in which the original dataset has been preprocessed
205
- into non-overlapping chunks (or "shards").
206
-
207
- For example, instead of preprocessing all your data into a single "data-bin"
208
- directory, you can split the data and create "data-bin1", "data-bin2", etc.
209
- Then you can adapt your training command like so:
210
-
211
- .. code-block:: console
212
-
213
- > fairseq-train data-bin1:data-bin2:data-bin3 (...)
214
-
215
- Training will now iterate over each shard, one by one, with each shard
216
- corresponding to an "epoch", thus reducing system memory usage.
 
docs/hydra_integration.md DELETED
@@ -1,284 +0,0 @@
1
- ## Hydra
2
-
3
- [Hydra](https://github.com/facebookresearch/hydra) is an open-source Python
4
- framework that simplifies the development of research and other complex
5
- applications. The key feature is the ability to dynamically create a
6
- hierarchical configuration by composition and override it through config files
7
- and the command line. The name Hydra comes from its ability to run multiple
8
- similar jobs - much like a Hydra with multiple heads.
9
-
10
- ## Motivation
11
-
12
- Until recently, all components in fairseq were configured through a shared
13
- `args` namespace that was created at application startup. Components declared
14
- their own `add_args` method to update the argparse parser, hoping that the names
15
- would not clash with arguments from other components. While this model works for
16
- smaller applications, as fairseq grew and became integrated into other
17
- applications, this became problematic. In order to determine how to configure
18
- each component, one needed to a) examine what args were added by this component,
19
- and b) read the code to figure out what shared arguments it is using that were
20
- added in other places. Reproducing models involved sharing commands that often
21
- contained dozens of command line switches.
22
-
23
- The model described above is still supported by fairseq for backward
24
- compatibility, but will be deprecated some time in the future.
25
-
26
- New components in fairseq should now create a dataclass that encapsulates all
27
- parameters required to configure this component. The dataclass is registered
28
- along with the component, and fairseq takes care of constructing and providing
29
- this configuration object to the component's constructor. Note that sharing
30
- parameters can optionally still work, but one has to explicitly point to the
31
- "source of truth" (see inheritance example below). These changes make components
32
- in fairseq more independent and re-usable by other applications: all that is
33
- needed to create a component is to initialize its dataclass and overwrite some
34
- of the defaults.
35
-
36
- While configuring fairseq through command line (using either the legacy argparse
37
- based or the new Hydra based entry points) is still fully supported, you can now
38
- take advantage of configuring fairseq completely or piece-by-piece through
39
- hierarchical YAML configuration files. These files can also be shipped as
40
- examples that others can use to run an identically configured job.
41
-
42
- Additionally, Hydra has a rich and growing [library of
43
- plugins](https://github.com/facebookresearch/hydra/tree/master/plugins) that
44
- provide functionality such as hyperparameter sweeping (including using bayesian
45
- optimization through the [Ax](https://github.com/facebook/Ax) library), job
46
- launching across various platforms, and more.
47
-
48
- ## Creating or migrating components
49
-
50
- In general, each new (or updated) component should provide a companion
51
- [dataclass](https://www.python.org/dev/peps/pep-0557/). These dataclass are
52
- typically located in the same file as the component and are passed as arguments
53
- to the `register_*()` functions. Top-level configs that should be present in
54
- every fairseq application are placed in the
55
- [global](fairseq/dataclass/configs.py) config file and added to the
56
- `FairseqConfig` object.
57
-
58
- Each dataclass is a plain-old-data object, similar to a `NamedTuple`. These
59
- classes are decorated with a `@dataclass` decorator, and typically inherit from
60
- `FairseqDataclass` (which adds some functionality for backward compatibility).
61
- Each field must have a type, and generally has metadata (such as a help string)
62
- and a default value. Only primitive types or other config objects are allowed as
63
- data types for each field.
64
-
65
- #### Example:
66
-
67
- ```python
68
- from dataclasses import dataclass, field
69
- from fairseq.dataclass import FairseqDataclass
70
-
71
- @dataclass
72
- class InteractiveConfig(FairseqDataclass):
73
- buffer_size: int = field(
74
- default=0,
75
- metadata={
76
- "help": "read this many sentences into a buffer before processing them"
77
- },
78
- )
79
- input: str = field(
80
- default="-",
81
- metadata={"help": "file to read from; use - for stdin"},
82
- )
83
- ```
84
-
85
- ### Inherting values
86
-
87
- Some components require sharing a value. For example, a learning rate scheduler
88
- and an optimizer may both need to know the initial learning rate value. One can
89
- declare a field that, by default, will inherit its value from another config
90
- node in the same hierarchy:
91
-
92
- ```python
93
- @dataclass
94
- FairseqAdamConfig(FairseqDataclass):
95
- ...
96
- lr: List[float] = II("optimization.lr")
97
- ...
98
- ```
99
-
100
- `II("optimization.lr")` is syntactic sugar for `"${optimization.lr}"`, which is
101
- the value one can use in a YAML config file or through command line to achieve
102
- the same effect. Note that this assumes that there is an "optimization" config
103
- object in the root config and it has a field called "lr".
104
-
105
- ### Tasks and Models
106
-
107
- Creating Tasks and Models works same as before, except that legacy
108
- implementations now inherit from `LegacyFairseq*` base classes, while new
109
- components inherit from `FairseqTask` and `FairseqModel` and provide a dataclass
110
- to the `register_*()` functions.
111
-
112
- #### Task example:
113
-
114
- ```python
115
- @dataclass
116
- class LanguageModelingConfig(FairseqDataclass):
117
- data: Optional[str] = field(
118
- default=None, metadata={"help": "path to data directory"}
119
- )
120
- ...
121
-
122
- @register_task("language_modeling", dataclass=LanguageModelingConfig)
123
- class LanguageModelingTask(FairseqTask):
124
- ...
125
- @classmethod
126
- def setup_task(cls, cfg: LanguageModelingConfig):
127
- ...
128
- ```
129
-
130
- #### Model example:
131
-
132
- ```python
133
- @dataclass
134
- class TransformerLanguageModelConfig(FairseqDataclass):
135
- activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
136
- default="relu", metadata={"help": "activation function to use"}
137
- )
138
- dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
139
- ...
140
-
141
- @register_model("transformer_lm", dataclass=TransformerLanguageModelConfig)
142
- class TransformerLanguageModel(FairseqLanguageModel):
143
- ...
144
- @classmethod
145
- def build_model(cls, cfg: TransformerLanguageModelConfig, task: FairseqTask):
146
- ...
147
- ```
148
-
149
- ### Other components
150
-
151
- Other components work as before, but they now take their configuration dataclass
152
- as the only constructor argument:
153
-
154
- ```python
155
- @dataclass
156
- class MosesTokenizerConfig(FairseqDataclass):
157
- source_lang: str = field(default="en", metadata={"help": "source language"})
158
- ...
159
-
160
- @register_tokenizer("moses", dataclass=MosesTokenizerConfig)
161
- class MosesTokenizer(object):
162
- def __init__(self, cfg: MosesTokenizerConfig):
163
- ...
164
- ```
165
-
166
- Note that if you are adding a new registry for a new set of components, you need
167
- to add it to the `FairseqConfig` object in `fairseq/dataclass/configs.py`:
168
-
169
- ```python
170
- @dataclass
171
- class FairseqConfig(object):
172
- ...
173
- my_new_registry: Any = None
174
- ```
175
-
176
- ## Training with `fairseq-hydra-train`
177
-
178
- To fully take advantage of configuration flexibility offered by Hydra, you may
179
- want to train new models using the `fairseq-hydra-train` entry point. Legacy CLI
180
- tools such as `fairseq-train` will remain supported for the foreseeable future
181
- but will be deprecated eventually.
182
-
183
- On startup, Hydra will create a configuration object that contains a hierarchy
184
- of all the necessary dataclasses populated with their default values in the
185
- code. The default values are overwritten by values found in YAML files in
186
- `fairseq/config` directory (which currently sets minimal defaults) and then
187
- further overwritten by values provided through command line arguments.
188
-
189
- Some of the most common use cases are shown below:
190
-
191
- ### 1. Override default values through command line:
192
-
193
- ```shell script
194
- $ fairseq-hydra-train \
195
- distributed_training.distributed_world_size=1 \
196
- dataset.batch_size=2 \
197
- task.data=data-bin \
198
- model=transformer_lm/transformer_lm_gpt \
199
- task=language_modeling \
200
- optimization.max_update=5000
201
- ```
202
-
203
- Note that along with explicitly providing values for parameters such as
204
- `dataset.batch_size`, this also tells Hydra to overlay configuration found in
205
- `fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml` over the default
206
- values in the dataclass. If you want to train a model without specifying a
207
- particular architecture you can simply specify `model=transformer_lm`. This only
208
- works for migrated tasks and models.
209
-
210
- ### 2. Replace bundled configs with an external config:
211
-
212
- ```shell script
213
- $ fairseq-hydra-train \
214
- --config-dir /path/to/external/configs \
215
- --config-name wiki103
216
- ```
217
-
218
- where `/path/to/external/configs/wiki103.yaml` contains:
219
-
220
- ```yaml
221
- # @package _group_
222
-
223
- model:
224
- _name: transformer_lm
225
- distributed_training:
226
- distributed_world_size: 1
227
- dataset:
228
- batch_size: 2
229
- task:
230
- _name: language_modeling
231
- data: /path/to/data
232
- add_bos_token: false
233
- max_target_positions: 1024
234
- optimization:
235
- max_update: 50000
236
- lr: [ 0.25 ]
237
- criterion: cross_entropy
238
- optimizer: adam
239
- lr_scheduler:
240
- _name: cosine
241
- ```
242
-
243
- Note that here bundled configs from `fairseq/config` directory are not used,
244
- however the defaults from each dataclass will still be used (unless overwritten
245
- by your external config).
246
-
247
- Additionally you can choose to break up your configs by creating a directory
248
- structure in the same location as your main config file, with the names of the
249
- top-level fields (such as "model", "dataset", etc), and placing config files
250
- with meaningful names that would populate that specific section of your
251
- top-level config file (for example, you might have
252
- `model/small_transformer_lm.yaml`, `model/big_transformer_lm.yaml`, etc). You
253
- can then specify the correct configuration via command line, defaults in the
254
- main config, or even launch all of them as a sweep (see Hydra documentation on
255
- how to do this).
256
-
257
- ### 3. Add an external config directory to Hydra search path:
258
-
259
- This allows combining default configuration (including using any bundled config
260
- files), while specifying your own config files for some parts of the
261
- configuration.
262
-
263
- ```shell script
264
- $ fairseq-hydra-train \
265
- distributed_training.distributed_world_size=1 \
266
- dataset.batch_size=2 \
267
- task.data=/path/to/data/ \
268
- model=transformer_lm/2_layers \
269
- task=language_modeling \
270
- optimization.max_update=5000 \
271
- --config-dir /path/to/external/configs
272
- ```
273
-
274
- where `/path/to/external/configs` has the following structure:
275
- ```
276
- .
277
- +-- model
278
- | +-- transformer_lm
279
- | | +-- 2_layers.yaml
280
- ```
281
-
282
- and `2_layers.yaml` contains a copy of `transformer_lm_gpt.yaml` but with
283
- `decoder_layers` set to 2. You can add other configs to configure other
284
- components as well.
 
docs/index.rst DELETED
@@ -1,49 +0,0 @@
1
- .. fairseq documentation master file, created by
2
- sphinx-quickstart on Fri Aug 17 21:45:30 2018.
3
- You can adapt this file completely to your liking, but it should at least
4
- contain the root `toctree` directive.
5
-
6
- :github_url: https://github.com/pytorch/fairseq
7
-
8
-
9
- fairseq documentation
10
- =====================
11
-
12
- Fairseq is a sequence modeling toolkit written in `PyTorch
13
- <http://pytorch.org/>`_ that allows researchers and developers to
14
- train custom models for translation, summarization, language modeling and other
15
- text generation tasks.
16
-
17
- .. toctree::
18
- :maxdepth: 1
19
- :caption: Getting Started
20
-
21
- getting_started
22
- command_line_tools
23
-
24
- .. toctree::
25
- :maxdepth: 1
26
- :caption: Extending Fairseq
27
-
28
- overview
29
- tutorial_simple_lstm
30
- tutorial_classifying_names
31
-
32
- .. toctree::
33
- :maxdepth: 2
34
- :caption: Library Reference
35
-
36
- tasks
37
- models
38
- criterions
39
- optim
40
- lr_scheduler
41
- data
42
- modules
43
-
44
-
45
- Indices and tables
46
- ==================
47
-
48
- * :ref:`genindex`
49
- * :ref:`search`
 
docs/lr_scheduler.rst DELETED
@@ -1,34 +0,0 @@
1
- .. role:: hidden
2
- :class: hidden-section
3
-
4
- .. _Learning Rate Schedulers:
5
-
6
- Learning Rate Schedulers
7
- ========================
8
-
9
- Learning Rate Schedulers update the learning rate over the course of training.
10
- Learning rates can be updated after each update via :func:`step_update` or at
11
- epoch boundaries via :func:`step`.
12
-
13
- .. automodule:: fairseq.optim.lr_scheduler
14
- :members:
15
-
16
- .. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler
17
- :members:
18
- :undoc-members:
19
-
20
- .. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule
21
- :members:
22
- :undoc-members:
23
- .. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule
24
- :members:
25
- :undoc-members:
26
- .. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule
27
- :members:
28
- :undoc-members:
29
- .. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau
30
- :members:
31
- :undoc-members:
32
- .. autoclass:: fairseq.optim.lr_scheduler.triangular_lr_scheduler.TriangularSchedule
33
- :members:
34
- :undoc-members:
 
docs/make.bat DELETED
@@ -1,36 +0,0 @@
1
- @ECHO OFF
2
-
3
- pushd %~dp0
4
-
5
- REM Command file for Sphinx documentation
6
-
7
- if "%SPHINXBUILD%" == "" (
8
- set SPHINXBUILD=python -msphinx
9
- )
10
- set SOURCEDIR=.
11
- set BUILDDIR=_build
12
- set SPHINXPROJ=fairseq
13
-
14
- if "%1" == "" goto help
15
-
16
- %SPHINXBUILD% >NUL 2>NUL
17
- if errorlevel 9009 (
18
- echo.
19
- echo.The Sphinx module was not found. Make sure you have Sphinx installed,
20
- echo.then set the SPHINXBUILD environment variable to point to the full
21
- echo.path of the 'sphinx-build' executable. Alternatively you may add the
22
- echo.Sphinx directory to PATH.
23
- echo.
24
- echo.If you don't have Sphinx installed, grab it from
25
- echo.http://sphinx-doc.org/
26
- exit /b 1
27
- )
28
-
29
- %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
30
- goto end
31
-
32
- :help
33
- %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
34
-
35
- :end
36
- popd
 
docs/models.rst DELETED
@@ -1,104 +0,0 @@
1
- .. role:: hidden
2
- :class: hidden-section
3
-
4
- .. module:: fairseq.models
5
-
6
- .. _Models:
7
-
8
- Models
9
- ======
10
-
11
- A Model defines the neural network's ``forward()`` method and encapsulates all
12
- of the learnable parameters in the network. Each model also provides a set of
13
- named *architectures* that define the precise network configuration (e.g.,
14
- embedding dimension, number of layers, etc.).
15
-
16
- Both the model type and architecture are selected via the ``--arch``
17
- command-line argument. Once selected, a model may expose additional command-line
18
- arguments for further configuration.
19
-
20
- .. note::
21
-
22
- All fairseq Models extend :class:`BaseFairseqModel`, which in turn extends
23
- :class:`torch.nn.Module`. Thus any fairseq Model can be used as a
24
- stand-alone Module in other PyTorch code.
25
-
26
-
27
- Convolutional Neural Networks (CNN)
28
- -----------------------------------
29
-
30
- .. module:: fairseq.models.fconv
31
- .. autoclass:: fairseq.models.fconv.FConvModel
32
- :members:
33
- .. autoclass:: fairseq.models.fconv.FConvEncoder
34
- :members:
35
- :undoc-members:
36
- .. autoclass:: fairseq.models.fconv.FConvDecoder
37
- :members:
38
-
39
-
40
- Long Short-Term Memory (LSTM) networks
41
- --------------------------------------
42
-
43
- .. module:: fairseq.models.lstm
44
- .. autoclass:: fairseq.models.lstm.LSTMModel
45
- :members:
46
- .. autoclass:: fairseq.models.lstm.LSTMEncoder
47
- :members:
48
- .. autoclass:: fairseq.models.lstm.LSTMDecoder
49
- :members:
50
-
51
-
52
- Transformer (self-attention) networks
53
- -------------------------------------
54
-
55
- .. module:: fairseq.models.transformer
56
- .. autoclass:: fairseq.models.transformer.TransformerModel
57
- :members:
58
- .. autoclass:: fairseq.models.transformer.TransformerEncoder
59
- :members:
60
- .. autoclass:: fairseq.models.transformer.TransformerEncoderLayer
61
- :members:
62
- .. autoclass:: fairseq.models.transformer.TransformerDecoder
63
- :members:
64
- .. autoclass:: fairseq.models.transformer.TransformerDecoderLayer
65
- :members:
66
-
67
-
68
- Adding new models
69
- -----------------
70
-
71
- .. currentmodule:: fairseq.models
72
- .. autofunction:: fairseq.models.register_model
73
- .. autofunction:: fairseq.models.register_model_architecture
74
- .. autoclass:: fairseq.models.BaseFairseqModel
75
- :members:
76
- :undoc-members:
77
- .. autoclass:: fairseq.models.FairseqEncoderDecoderModel
78
- :members:
79
- :undoc-members:
80
- .. autoclass:: fairseq.models.FairseqEncoderModel
81
- :members:
82
- :undoc-members:
83
- .. autoclass:: fairseq.models.FairseqLanguageModel
84
- :members:
85
- :undoc-members:
86
- .. autoclass:: fairseq.models.FairseqMultiModel
87
- :members:
88
- :undoc-members:
89
- .. autoclass:: fairseq.models.FairseqEncoder
90
- :members:
91
- .. autoclass:: fairseq.models.CompositeEncoder
92
- :members:
93
- .. autoclass:: fairseq.models.FairseqDecoder
94
- :members:
95
-
96
-
97
- .. _Incremental decoding:
98
-
99
- Incremental decoding
100
- --------------------
101
-
102
- .. autoclass:: fairseq.models.FairseqIncrementalDecoder
103
- :members:
104
- :undoc-members:
 
docs/modules.rst DELETED
@@ -1,9 +0,0 @@
1
- Modules
2
- =======
3
-
4
- Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may
5
- be helpful when implementing a new :class:`~fairseq.models.BaseFairseqModel`.
6
-
7
- .. automodule:: fairseq.modules
8
- :members:
9
- :undoc-members:
 
docs/optim.rst DELETED
@@ -1,38 +0,0 @@
1
- .. role:: hidden
2
- :class: hidden-section
3
-
4
- .. _optimizers:
5
-
6
- Optimizers
7
- ==========
8
-
9
- Optimizers update the Model parameters based on the gradients.
10
-
11
- .. automodule:: fairseq.optim
12
- :members:
13
-
14
- .. autoclass:: fairseq.optim.FairseqOptimizer
15
- :members:
16
- :undoc-members:
17
-
18
- .. autoclass:: fairseq.optim.adadelta.Adadelta
19
- :members:
20
- :undoc-members:
21
- .. autoclass:: fairseq.optim.adagrad.Adagrad
22
- :members:
23
- :undoc-members:
24
- .. autoclass:: fairseq.optim.adafactor.FairseqAdafactor
25
- :members:
26
- :undoc-members:
27
- .. autoclass:: fairseq.optim.adam.FairseqAdam
28
- :members:
29
- :undoc-members:
30
- .. autoclass:: fairseq.optim.fp16_optimizer.FP16Optimizer
31
- :members:
32
- :undoc-members:
33
- .. autoclass:: fairseq.optim.nag.FairseqNAG
34
- :members:
35
- :undoc-members:
36
- .. autoclass:: fairseq.optim.sgd.SGD
37
- :members:
38
- :undoc-members:
 
docs/overview.rst DELETED
@@ -1,74 +0,0 @@
1
- Overview
2
- ========
3
-
4
- Fairseq can be extended through user-supplied `plug-ins
5
- <https://en.wikipedia.org/wiki/Plug-in_(computing)>`_. We support five kinds of
6
- plug-ins:
7
-
8
- - :ref:`Models` define the neural network architecture and encapsulate all of the
9
- learnable parameters.
10
- - :ref:`Criterions` compute the loss function given the model outputs and targets.
11
- - :ref:`Tasks` store dictionaries and provide helpers for loading/iterating over
12
- Datasets, initializing the Model/Criterion and calculating the loss.
13
- - :ref:`Optimizers` update the Model parameters based on the gradients.
14
- - :ref:`Learning Rate Schedulers` update the learning rate over the course of
15
- training.
16
-
17
- **Training Flow**
18
-
19
- Given a ``model``, ``criterion``, ``task``, ``optimizer`` and ``lr_scheduler``,
20
- fairseq implements the following high-level training flow::
21
-
22
- for epoch in range(num_epochs):
23
- itr = task.get_batch_iterator(task.dataset('train'))
24
- for num_updates, batch in enumerate(itr):
25
- task.train_step(batch, model, criterion, optimizer)
26
- average_and_clip_gradients()
27
- optimizer.step()
28
- lr_scheduler.step_update(num_updates)
29
- lr_scheduler.step(epoch)
30
-
31
- where the default implementation for ``task.train_step`` is roughly::
32
-
33
- def train_step(self, batch, model, criterion, optimizer, **unused):
34
- loss = criterion(model, batch)
35
- optimizer.backward(loss)
36
- return loss
37
-
38
- **Registering new plug-ins**
39
-
40
- New plug-ins are *registered* through a set of ``@register`` function
41
- decorators, for example::
42
-
43
- @register_model('my_lstm')
44
- class MyLSTM(FairseqEncoderDecoderModel):
45
- (...)
46
-
47
- Once registered, new plug-ins can be used with the existing :ref:`Command-line
48
- Tools`. See the Tutorial sections for more detailed walkthroughs of how to add
49
- new plug-ins.
50
-
51
- **Loading plug-ins from another directory**
52
-
53
- New plug-ins can be defined in a custom module stored in the user system. In
54
- order to import the module, and make the plugin available to *fairseq*, the
55
- command line supports the ``--user-dir`` flag that can be used to specify a
56
- custom location for additional modules to load into *fairseq*.
57
-
58
- For example, assuming this directory tree::
59
-
60
- /home/user/my-module/
61
- └── __init__.py
62
-
63
- with ``__init__.py``::
64
-
65
- from fairseq.models import register_model_architecture
66
- from fairseq.models.transformer import transformer_vaswani_wmt_en_de_big
67
-
68
- @register_model_architecture('transformer', 'my_transformer')
69
- def transformer_mmt_big(args):
70
- transformer_vaswani_wmt_en_de_big(args)
71
-
72
- it is possible to invoke the :ref:`fairseq-train` script with the new architecture with::
73
-
74
- fairseq-train ... --user-dir /home/user/my-module -a my_transformer --task translation
 
docs/requirements.txt DELETED
@@ -1,2 +0,0 @@
1
- sphinx<2.0
2
- sphinx-argparse
 
 
 
docs/tasks.rst DELETED
@@ -1,61 +0,0 @@
1
- .. role:: hidden
2
- :class: hidden-section
3
-
4
- .. module:: fairseq.tasks
5
-
6
- .. _Tasks:
7
-
8
- Tasks
9
- =====
10
-
11
- Tasks store dictionaries and provide helpers for loading/iterating over
12
- Datasets, initializing the Model/Criterion and calculating the loss.
13
-
14
- Tasks can be selected via the ``--task`` command-line argument. Once selected, a
15
- task may expose additional command-line arguments for further configuration.
16
-
17
- Example usage::
18
-
19
- # setup the task (e.g., load dictionaries)
20
- task = fairseq.tasks.setup_task(args)
21
-
22
- # build model and criterion
23
- model = task.build_model(args)
24
- criterion = task.build_criterion(args)
25
-
26
- # load datasets
27
- task.load_dataset('train')
28
- task.load_dataset('valid')
29
-
30
- # iterate over mini-batches of data
31
- batch_itr = task.get_batch_iterator(
32
- task.dataset('train'), max_tokens=4096,
33
- )
34
- for batch in batch_itr:
35
- # compute the loss
36
- loss, sample_size, logging_output = task.get_loss(
37
- model, criterion, batch,
38
- )
39
- loss.backward()
40
-
41
-
42
- Translation
43
- -----------
44
-
45
- .. autoclass:: fairseq.tasks.translation.TranslationTask
46
-
47
- .. _language modeling:
48
-
49
- Language Modeling
50
- -----------------
51
-
52
- .. autoclass:: fairseq.tasks.language_modeling.LanguageModelingTask
53
-
54
-
55
- Adding new tasks
56
- ----------------
57
-
58
- .. autofunction:: fairseq.tasks.register_task
59
- .. autoclass:: fairseq.tasks.FairseqTask
60
- :members:
61
- :undoc-members:
 
docs/tutorial_classifying_names.rst DELETED
@@ -1,415 +0,0 @@
1
- Tutorial: Classifying Names with a Character-Level RNN
2
- ======================================================
3
-
4
- In this tutorial we will extend fairseq to support *classification* tasks. In
5
- particular we will re-implement the PyTorch tutorial for `Classifying Names with
6
- a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html>`_
7
- in fairseq. It is recommended to quickly skim that tutorial before beginning
8
- this one.
9
-
10
- This tutorial covers:
11
-
12
- 1. **Preprocessing the data** to create dictionaries.
13
- 2. **Registering a new Model** that encodes an input sentence with a simple RNN
14
- and predicts the output label.
15
- 3. **Registering a new Task** that loads our dictionaries and dataset.
16
- 4. **Training the Model** using the existing command-line tools.
17
- 5. **Writing an evaluation script** that imports fairseq and allows us to
18
- interactively evaluate our model on new inputs.
19
-
20
-
21
- 1. Preprocessing the data
22
- -------------------------
23
-
24
- The original tutorial provides raw data, but we'll work with a modified version
25
- of the data that is already tokenized into characters and split into separate
26
- train, valid and test sets.
27
-
28
- Download and extract the data from here:
29
- `tutorial_names.tar.gz <https://dl.fbaipublicfiles.com/fairseq/data/tutorial_names.tar.gz>`_
30
-
31
- Once extracted, let's preprocess the data using the :ref:`fairseq-preprocess`
32
- command-line tool to create the dictionaries. While this tool is primarily
33
- intended for sequence-to-sequence problems, we're able to reuse it here by
34
- treating the label as a "target" sequence of length 1. We'll also output the
35
- preprocessed files in "raw" format using the ``--dataset-impl`` option to
36
- enhance readability:
37
-
38
- .. code-block:: console
39
-
40
- > fairseq-preprocess \
41
- --trainpref names/train --validpref names/valid --testpref names/test \
42
- --source-lang input --target-lang label \
43
- --destdir names-bin --dataset-impl raw
44
-
45
- After running the above command you should see a new directory,
46
- :file:`names-bin/`, containing the dictionaries for *inputs* and *labels*.
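-
- As a quick sanity check, the new directory should contain at least the two
- dictionary files that are loaded later in this tutorial (other files in
- :file:`names-bin/` are elided here):
-
- .. code-block:: console
-
-     > ls names-bin/
-     dict.input.txt  dict.label.txt  (...)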
47
-
48
-
49
- 2. Registering a new Model
50
- --------------------------
51
-
52
- Next we'll register a new model in fairseq that will encode an input sentence
53
- with a simple RNN and predict the output label. Compared to the original PyTorch
54
- tutorial, our version will also work with batches of data and GPU Tensors.
55
-
56
- First let's copy the simple RNN module implemented in the `PyTorch tutorial
57
- <https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html#creating-the-network>`_.
58
- Create a new file named :file:`fairseq/models/rnn_classifier.py` with the
59
- following contents::
60
-
61
- import torch
62
- import torch.nn as nn
63
-
64
- class RNN(nn.Module):
65
-
66
- def __init__(self, input_size, hidden_size, output_size):
67
- super(RNN, self).__init__()
68
-
69
- self.hidden_size = hidden_size
70
-
71
- self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
72
- self.i2o = nn.Linear(input_size + hidden_size, output_size)
73
- self.softmax = nn.LogSoftmax(dim=1)
74
-
75
- def forward(self, input, hidden):
76
- combined = torch.cat((input, hidden), 1)
77
- hidden = self.i2h(combined)
78
- output = self.i2o(combined)
79
- output = self.softmax(output)
80
- return output, hidden
81
-
82
- def initHidden(self):
83
- return torch.zeros(1, self.hidden_size)
84
-
85
- We must also *register* this model with fairseq using the
86
- :func:`~fairseq.models.register_model` function decorator. Once the model is
87
- registered we'll be able to use it with the existing :ref:`Command-line Tools`.
88
-
89
- All registered models must implement the :class:`~fairseq.models.BaseFairseqModel`
90
- interface, so we'll create a small wrapper class in the same file and register
91
- it in fairseq with the name ``'rnn_classifier'``::
92
-
93
- from fairseq.models import BaseFairseqModel, register_model
94
-
95
- # Note: the register_model "decorator" should immediately precede the
96
- # definition of the Model class.
97
-
98
- @register_model('rnn_classifier')
99
- class FairseqRNNClassifier(BaseFairseqModel):
100
-
101
- @staticmethod
102
- def add_args(parser):
103
- # Models can override this method to add new command-line arguments.
104
- # Here we'll add a new command-line argument to configure the
105
- # dimensionality of the hidden state.
106
- parser.add_argument(
107
- '--hidden-dim', type=int, metavar='N',
108
- help='dimensionality of the hidden state',
109
- )
110
-
111
- @classmethod
112
- def build_model(cls, args, task):
113
- # Fairseq initializes models by calling the ``build_model()``
114
- # function. This provides more flexibility, since the returned model
115
- # instance can be of a different type than the one that was called.
116
- # In this case we'll just return a FairseqRNNClassifier instance.
117
-
118
- # Initialize our RNN module
119
- rnn = RNN(
120
- # We'll define the Task in the next section, but for now just
121
- # notice that the task holds the dictionaries for the "source"
122
- # (i.e., the input sentence) and "target" (i.e., the label).
123
- input_size=len(task.source_dictionary),
124
- hidden_size=args.hidden_dim,
125
- output_size=len(task.target_dictionary),
126
- )
127
-
128
- # Return the wrapped version of the module
129
- return FairseqRNNClassifier(
130
- rnn=rnn,
131
- input_vocab=task.source_dictionary,
132
- )
133
-
134
- def __init__(self, rnn, input_vocab):
135
- super(FairseqRNNClassifier, self).__init__()
136
-
137
- self.rnn = rnn
138
- self.input_vocab = input_vocab
139
-
140
- # The RNN module in the tutorial expects one-hot inputs, so we can
141
- # precompute the identity matrix to help convert from indices to
142
- # one-hot vectors. We register it as a buffer so that it is moved to
143
- # the GPU when ``cuda()`` is called.
144
- self.register_buffer('one_hot_inputs', torch.eye(len(input_vocab)))
145
-
146
- def forward(self, src_tokens, src_lengths):
147
- # The inputs to the ``forward()`` function are determined by the
148
- # Task, and in particular the ``'net_input'`` key in each
149
- # mini-batch. We'll define the Task in the next section, but for
150
- # now just know that *src_tokens* has shape `(batch, src_len)` and
151
- # *src_lengths* has shape `(batch)`.
152
- bsz, max_src_len = src_tokens.size()
153
-
154
- # Initialize the RNN hidden state. Compared to the original PyTorch
155
- # tutorial we'll also handle batched inputs and work on the GPU.
156
- hidden = self.rnn.initHidden()
157
- hidden = hidden.repeat(bsz, 1) # expand for batched inputs
158
- hidden = hidden.to(src_tokens.device) # move to GPU
159
-
160
- for i in range(max_src_len):
161
- # WARNING: The inputs have padding, so we should mask those
162
- # elements here so that padding doesn't affect the results.
163
- # This is left as an exercise for the reader. The padding symbol
164
- # is given by ``self.input_vocab.pad()`` and the unpadded length
165
- # of each input is given by *src_lengths*.
166
-
167
- # One-hot encode a batch of input characters.
168
- input = self.one_hot_inputs[src_tokens[:, i].long()]
169
-
170
- # Feed the input to our RNN.
171
- output, hidden = self.rnn(input, hidden)
172
-
173
- # Return the final output state for making a prediction
174
- return output
175
-
176
- Finally let's define a *named architecture* with the configuration for our
177
- model. This is done with the :func:`~fairseq.models.register_model_architecture`
178
- function decorator. Thereafter this named architecture can be used with the
179
- ``--arch`` command-line argument, e.g., ``--arch pytorch_tutorial_rnn``::
180
-
181
- from fairseq.models import register_model_architecture
182
-
183
- # The first argument to ``register_model_architecture()`` should be the name
184
- # of the model we registered above (i.e., 'rnn_classifier'). The function we
185
- # register here should take a single argument *args* and modify it in-place
186
- # to match the desired architecture.
187
-
188
- @register_model_architecture('rnn_classifier', 'pytorch_tutorial_rnn')
189
- def pytorch_tutorial_rnn(args):
190
- # We use ``getattr()`` to prioritize arguments that are explicitly given
191
- # on the command-line, so that the defaults defined below are only used
192
- # when no other value has been specified.
193
- args.hidden_dim = getattr(args, 'hidden_dim', 128)
194
-
195
-
196
- 3. Registering a new Task
197
- -------------------------
198
-
199
- Now we'll register a new :class:`~fairseq.tasks.FairseqTask` that will load our
200
- dictionaries and dataset. Tasks can also control how the data is batched into
201
- mini-batches, but in this tutorial we'll reuse the batching provided by
202
- :class:`fairseq.data.LanguagePairDataset`.
203
-
204
- Create a new file named :file:`fairseq/tasks/simple_classification.py` with the
205
- following contents::
206
-
207
- import os
208
- import torch
209
-
210
- from fairseq.data import Dictionary, LanguagePairDataset
211
- from fairseq.tasks import FairseqTask, register_task
212
-
213
-
214
- @register_task('simple_classification')
215
- class SimpleClassificationTask(FairseqTask):
216
-
217
- @staticmethod
218
- def add_args(parser):
219
- # Add some command-line arguments for specifying where the data is
220
- # located and the maximum supported input length.
221
- parser.add_argument('data', metavar='FILE',
222
- help='file prefix for data')
223
- parser.add_argument('--max-positions', default=1024, type=int,
224
- help='max input length')
225
-
226
- @classmethod
227
- def setup_task(cls, args, **kwargs):
228
- # Here we can perform any setup required for the task. This may include
229
- # loading Dictionaries, initializing shared Embedding layers, etc.
230
- # In this case we'll just load the Dictionaries.
231
- input_vocab = Dictionary.load(os.path.join(args.data, 'dict.input.txt'))
232
- label_vocab = Dictionary.load(os.path.join(args.data, 'dict.label.txt'))
233
- print('| [input] dictionary: {} types'.format(len(input_vocab)))
234
- print('| [label] dictionary: {} types'.format(len(label_vocab)))
235
-
236
- return SimpleClassificationTask(args, input_vocab, label_vocab)
237
-
238
- def __init__(self, args, input_vocab, label_vocab):
239
- super().__init__(args)
240
- self.input_vocab = input_vocab
241
- self.label_vocab = label_vocab
242
-
243
- def load_dataset(self, split, **kwargs):
244
- """Load a given dataset split (e.g., train, valid, test)."""
245
-
246
- prefix = os.path.join(self.args.data, '{}.input-label'.format(split))
247
-
248
- # Read input sentences.
249
- sentences, lengths = [], []
250
- with open(prefix + '.input', encoding='utf-8') as file:
251
- for line in file:
252
- sentence = line.strip()
253
-
254
- # Tokenize the sentence, splitting on spaces
255
- tokens = self.input_vocab.encode_line(
256
- sentence, add_if_not_exist=False,
257
- )
258
-
259
- sentences.append(tokens)
260
- lengths.append(tokens.numel())
261
-
262
- # Read labels.
263
- labels = []
264
- with open(prefix + '.label', encoding='utf-8') as file:
265
- for line in file:
266
- label = line.strip()
267
- labels.append(
268
- # Convert label to a numeric ID.
269
- torch.LongTensor([self.label_vocab.add_symbol(label)])
270
- )
271
-
272
- assert len(sentences) == len(labels)
273
- print('| {} {} {} examples'.format(self.args.data, split, len(sentences)))
274
-
275
- # We reuse LanguagePairDataset since classification can be modeled as a
276
- # sequence-to-sequence task where the target sequence has length 1.
277
- self.datasets[split] = LanguagePairDataset(
278
- src=sentences,
279
- src_sizes=lengths,
280
- src_dict=self.input_vocab,
281
- tgt=labels,
282
- tgt_sizes=torch.ones(len(labels)), # targets have length 1
283
- tgt_dict=self.label_vocab,
284
- left_pad_source=False,
285
- # Since our target is a single class label, there's no need for
286
- # teacher forcing. If we set this to ``True`` then our Model's
287
- # ``forward()`` method would receive an additional argument called
288
- # *prev_output_tokens* that would contain a shifted version of the
289
- # target sequence.
290
- input_feeding=False,
291
- )
292
-
293
- def max_positions(self):
294
- """Return the max input length allowed by the task."""
295
- # The source should be less than *args.max_positions* and the "target"
296
- # has max length 1.
297
- return (self.args.max_positions, 1)
298
-
299
- @property
300
- def source_dictionary(self):
301
- """Return the source :class:`~fairseq.data.Dictionary`."""
302
- return self.input_vocab
303
-
304
- @property
305
- def target_dictionary(self):
306
- """Return the target :class:`~fairseq.data.Dictionary`."""
307
- return self.label_vocab
308
-
309
- # We could override this method if we wanted more control over how batches
310
- # are constructed, but it's not necessary for this tutorial since we can
311
- # reuse the batching provided by LanguagePairDataset.
312
- #
313
- # def get_batch_iterator(
314
- # self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
315
- # ignore_invalid_inputs=False, required_batch_size_multiple=1,
316
- # seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1,
317
- # data_buffer_size=0, disable_iterator_cache=False,
318
- # ):
319
- # (...)
320
-
321
-
322
- 4. Training the Model
323
- ---------------------
324
-
325
- Now we're ready to train the model. We can use the existing :ref:`fairseq-train`
326
- command-line tool for this, making sure to specify our new Task (``--task
327
- simple_classification``) and Model architecture (``--arch
328
- pytorch_tutorial_rnn``):
329
-
330
- .. note::
331
-
332
- You can also configure the dimensionality of the hidden state by passing the
333
- ``--hidden-dim`` argument to :ref:`fairseq-train`.
334
-
335
- .. code-block:: console
336
-
337
- > fairseq-train names-bin \
338
- --task simple_classification \
339
- --arch pytorch_tutorial_rnn \
340
- --optimizer adam --lr 0.001 --lr-shrink 0.5 \
341
- --max-tokens 1000
342
- (...)
343
- | epoch 027 | loss 1.200 | ppl 2.30 | wps 15728 | ups 119.4 | wpb 116 | bsz 116 | num_updates 3726 | lr 1.5625e-05 | gnorm 1.290 | clip 0% | oom 0 | wall 32 | train_wall 21
344
- | epoch 027 | valid on 'valid' subset | valid_loss 1.41304 | valid_ppl 2.66 | num_updates 3726 | best 1.41208
345
- | done training in 31.6 seconds
346
-
347
- The model files should appear in the :file:`checkpoints/` directory.
348
-
349
-
350
- 5. Writing an evaluation script
351
- -------------------------------
352
-
353
- Finally we can write a short script to evaluate our model on new inputs. Create
354
- a new file named :file:`eval_classifier.py` with the following contents::
355
-
356
- from fairseq import checkpoint_utils, data, options, tasks
357
-
358
- # Parse command-line arguments for generation
359
- parser = options.get_generation_parser(default_task='simple_classification')
360
- args = options.parse_args_and_arch(parser)
361
-
362
- # Setup task
363
- task = tasks.setup_task(args)
364
-
365
- # Load model
366
- print('| loading model from {}'.format(args.path))
367
- models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task)
368
- model = models[0]
369
-
370
- while True:
371
- sentence = input('\nInput: ')
372
-
373
- # Tokenize into characters
374
- chars = ' '.join(list(sentence.strip()))
375
- tokens = task.source_dictionary.encode_line(
376
- chars, add_if_not_exist=False,
377
- )
378
-
379
- # Build mini-batch to feed to the model
380
- batch = data.language_pair_dataset.collate(
381
- samples=[{'id': -1, 'source': tokens}], # bsz = 1
382
- pad_idx=task.source_dictionary.pad(),
383
- eos_idx=task.source_dictionary.eos(),
384
- left_pad_source=False,
385
- input_feeding=False,
386
- )
387
-
388
- # Feed batch to the model and get predictions
389
- preds = model(**batch['net_input'])
390
-
391
- # Print top 3 predictions and their log-probabilities
392
- top_scores, top_labels = preds[0].topk(k=3)
393
- for score, label_idx in zip(top_scores, top_labels):
394
- label_name = task.target_dictionary.string([label_idx])
395
- print('({:.2f})\t{}'.format(score, label_name))
396
-
397
- Now we can evaluate our model interactively. Note that we have included the
398
- original data path (:file:`names-bin/`) so that the dictionaries can be loaded:
399
-
400
- .. code-block:: console
401
-
402
- > python eval_classifier.py names-bin --path checkpoints/checkpoint_best.pt
403
- | [input] dictionary: 64 types
404
- | [label] dictionary: 24 types
405
- | loading model from checkpoints/checkpoint_best.pt
406
-
407
- Input: Satoshi
408
- (-0.61) Japanese
409
- (-1.20) Arabic
410
- (-2.86) Italian
411
-
412
- Input: Sinbad
413
- (-0.30) Arabic
414
- (-1.76) English
415
- (-4.08) Russian
 
docs/tutorial_simple_lstm.rst DELETED
@@ -1,518 +0,0 @@
1
- Tutorial: Simple LSTM
2
- =====================
3
-
4
- In this tutorial we will extend fairseq by adding a new
5
- :class:`~fairseq.models.FairseqEncoderDecoderModel` that encodes a source
6
- sentence with an LSTM and then passes the final hidden state to a second LSTM
7
- that decodes the target sentence (without attention).
8
-
9
- This tutorial covers:
10
-
11
- 1. **Writing an Encoder and Decoder** to encode/decode the source/target
12
- sentence, respectively.
13
- 2. **Registering a new Model** so that it can be used with the existing
14
- :ref:`Command-line tools`.
15
- 3. **Training the Model** using the existing command-line tools.
16
- 4. **Making generation faster** by modifying the Decoder to use
17
- :ref:`Incremental decoding`.
18
-
19
-
20
- 1. Building an Encoder and Decoder
21
- ----------------------------------
22
-
23
- In this section we'll define a simple LSTM Encoder and Decoder. All Encoders
24
- should implement the :class:`~fairseq.models.FairseqEncoder` interface and
25
- Decoders should implement the :class:`~fairseq.models.FairseqDecoder` interface.
26
- These interfaces themselves extend :class:`torch.nn.Module`, so FairseqEncoders
27
- and FairseqDecoders can be written and used in the same ways as ordinary PyTorch
28
- Modules.
29
-
30
-
31
- Encoder
32
- ~~~~~~~
33
-
34
- Our Encoder will embed the tokens in the source sentence, feed them to a
35
- :class:`torch.nn.LSTM` and return the final hidden state. To create our encoder
36
- save the following in a new file named :file:`fairseq/models/simple_lstm.py`::
37
-
38
- import torch.nn as nn
39
- from fairseq import utils
40
- from fairseq.models import FairseqEncoder
41
-
42
- class SimpleLSTMEncoder(FairseqEncoder):
43
-
44
- def __init__(
45
- self, args, dictionary, embed_dim=128, hidden_dim=128, dropout=0.1,
46
- ):
47
- super().__init__(dictionary)
48
- self.args = args
49
-
50
- # Our encoder will embed the inputs before feeding them to the LSTM.
51
- self.embed_tokens = nn.Embedding(
52
- num_embeddings=len(dictionary),
53
- embedding_dim=embed_dim,
54
- padding_idx=dictionary.pad(),
55
- )
56
- self.dropout = nn.Dropout(p=dropout)
57
-
58
- # We'll use a single-layer, unidirectional LSTM for simplicity.
59
- self.lstm = nn.LSTM(
60
- input_size=embed_dim,
61
- hidden_size=hidden_dim,
62
- num_layers=1,
63
- bidirectional=False,
64
- batch_first=True,
65
- )
66
-
67
- def forward(self, src_tokens, src_lengths):
68
- # The inputs to the ``forward()`` function are determined by the
69
- # Task, and in particular the ``'net_input'`` key in each
70
- # mini-batch. We discuss Tasks in the next tutorial, but for now just
71
- # know that *src_tokens* has shape `(batch, src_len)` and *src_lengths*
72
- # has shape `(batch)`.
73
-
74
- # Note that the source is typically padded on the left. This can be
75
- # configured by adding the `--left-pad-source "False"` command-line
76
- # argument, but here we'll make the Encoder handle either kind of
77
- # padding by converting everything to be right-padded.
78
- if self.args.left_pad_source:
79
- # Convert left-padding to right-padding.
80
- src_tokens = utils.convert_padding_direction(
81
- src_tokens,
82
- padding_idx=self.dictionary.pad(),
83
- left_to_right=True
84
- )
85
-
86
- # Embed the source.
87
- x = self.embed_tokens(src_tokens)
88
-
89
- # Apply dropout.
90
- x = self.dropout(x)
91
-
92
- # Pack the sequence into a PackedSequence object to feed to the LSTM.
93
- x = nn.utils.rnn.pack_padded_sequence(x, src_lengths, batch_first=True)
94
-
95
- # Get the output from the LSTM.
96
- _outputs, (final_hidden, _final_cell) = self.lstm(x)
97
-
98
- # Return the Encoder's output. This can be any object and will be
99
- # passed directly to the Decoder.
100
- return {
101
- # this will have shape `(bsz, hidden_dim)`
102
- 'final_hidden': final_hidden.squeeze(0),
103
- }
104
-
105
- # Encoders are required to implement this method so that we can rearrange
106
- # the order of the batch elements during inference (e.g., beam search).
107
- def reorder_encoder_out(self, encoder_out, new_order):
108
- """
109
- Reorder encoder output according to `new_order`.
110
-
111
- Args:
112
- encoder_out: output from the ``forward()`` method
113
- new_order (LongTensor): desired order
114
-
115
- Returns:
116
- `encoder_out` rearranged according to `new_order`
117
- """
118
- final_hidden = encoder_out['final_hidden']
119
- return {
120
- 'final_hidden': final_hidden.index_select(0, new_order),
121
- }
122
-
123
-
124
- Decoder
125
- ~~~~~~~
126
-
127
- Our Decoder will predict the next word, conditioned on the Encoder's final
128
- hidden state and an embedded representation of the previous target word -- which
129
- is sometimes called *teacher forcing*. More specifically, we'll use a
130
- :class:`torch.nn.LSTM` to produce a sequence of hidden states that we'll project
131
- to the size of the output vocabulary to predict each target word.
132
-
133
- ::
134
-
135
- import torch
- import torch.nn as nn
- from fairseq.models import FairseqDecoder
137
-
138
- class SimpleLSTMDecoder(FairseqDecoder):
139
-
140
- def __init__(
141
- self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
142
- dropout=0.1,
143
- ):
144
- super().__init__(dictionary)
145
-
146
- # Our decoder will embed the inputs before feeding them to the LSTM.
147
- self.embed_tokens = nn.Embedding(
148
- num_embeddings=len(dictionary),
149
- embedding_dim=embed_dim,
150
- padding_idx=dictionary.pad(),
151
- )
152
- self.dropout = nn.Dropout(p=dropout)
153
-
154
- # We'll use a single-layer, unidirectional LSTM for simplicity.
155
- self.lstm = nn.LSTM(
156
- # For the first layer we'll concatenate the Encoder's final hidden
157
- # state with the embedded target tokens.
158
- input_size=encoder_hidden_dim + embed_dim,
159
- hidden_size=hidden_dim,
160
- num_layers=1,
161
- bidirectional=False,
162
- )
163
-
164
- # Define the output projection.
165
- self.output_projection = nn.Linear(hidden_dim, len(dictionary))
166
-
167
- # During training Decoders are expected to take the entire target sequence
168
- # (shifted right by one position) and produce logits over the vocabulary.
169
- # The *prev_output_tokens* tensor begins with the end-of-sentence symbol,
170
- # ``dictionary.eos()``, followed by the target sequence.
171
- def forward(self, prev_output_tokens, encoder_out):
172
- """
173
- Args:
174
- prev_output_tokens (LongTensor): previous decoder outputs of shape
175
- `(batch, tgt_len)`, for teacher forcing
176
- encoder_out (Tensor, optional): output from the encoder, used for
177
- encoder-side attention
178
-
179
- Returns:
180
- tuple:
181
- - the last decoder layer's output of shape
182
- `(batch, tgt_len, vocab)`
183
- - the last decoder layer's attention weights of shape
184
- `(batch, tgt_len, src_len)`
185
- """
186
- bsz, tgt_len = prev_output_tokens.size()
187
-
188
- # Extract the final hidden state from the Encoder.
189
- final_encoder_hidden = encoder_out['final_hidden']
190
-
191
- # Embed the target sequence, which has been shifted right by one
192
- # position and now starts with the end-of-sentence symbol.
193
- x = self.embed_tokens(prev_output_tokens)
194
-
195
- # Apply dropout.
196
- x = self.dropout(x)
197
-
198
- # Concatenate the Encoder's final hidden state to *every* embedded
199
- # target token.
200
- x = torch.cat(
201
- [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
202
- dim=2,
203
- )
204
-
205
- # Using PackedSequence objects in the Decoder is harder than in the
206
- # Encoder, since the targets are not sorted in descending length order,
207
- # which is a requirement of ``pack_padded_sequence()``. Instead we'll
208
- # feed nn.LSTM directly.
209
- initial_state = (
210
- final_encoder_hidden.unsqueeze(0), # hidden
211
- torch.zeros_like(final_encoder_hidden).unsqueeze(0), # cell
212
- )
213
- output, _ = self.lstm(
214
- x.transpose(0, 1), # convert to shape `(tgt_len, bsz, dim)`
215
- initial_state,
216
- )
217
- x = output.transpose(0, 1) # convert to shape `(bsz, tgt_len, hidden)`
218
-
219
- # Project the outputs to the size of the vocabulary.
220
- x = self.output_projection(x)
221
-
222
- # Return the logits and ``None`` for the attention weights
223
- return x, None
224
-
225
-
226
- 2. Registering the Model
227
- ------------------------
228
-
229
- Now that we've defined our Encoder and Decoder we must *register* our model with
230
- fairseq using the :func:`~fairseq.models.register_model` function decorator.
231
- Once the model is registered we'll be able to use it with the existing
232
- :ref:`Command-line Tools`.
233
-
234
- All registered models must implement the
235
- :class:`~fairseq.models.BaseFairseqModel` interface. For sequence-to-sequence
236
- models (i.e., any model with a single Encoder and Decoder), we can instead
237
- implement the :class:`~fairseq.models.FairseqEncoderDecoderModel` interface.
238
-
239
- Create a small wrapper class in the same file and register it in fairseq with
240
- the name ``'simple_lstm'``::
241
-
242
- from fairseq.models import FairseqEncoderDecoderModel, register_model
243
-
244
- # Note: the register_model "decorator" should immediately precede the
245
- # definition of the Model class.
246
-
247
- @register_model('simple_lstm')
248
- class SimpleLSTMModel(FairseqEncoderDecoderModel):
249
-
250
- @staticmethod
251
- def add_args(parser):
252
- # Models can override this method to add new command-line arguments.
253
- # Here we'll add some new command-line arguments to configure dropout
254
- # and the dimensionality of the embeddings and hidden states.
255
- parser.add_argument(
256
- '--encoder-embed-dim', type=int, metavar='N',
257
- help='dimensionality of the encoder embeddings',
258
- )
259
- parser.add_argument(
260
- '--encoder-hidden-dim', type=int, metavar='N',
261
- help='dimensionality of the encoder hidden state',
262
- )
263
- parser.add_argument(
264
- '--encoder-dropout', type=float, default=0.1,
265
- help='encoder dropout probability',
266
- )
267
- parser.add_argument(
268
- '--decoder-embed-dim', type=int, metavar='N',
269
- help='dimensionality of the decoder embeddings',
270
- )
271
- parser.add_argument(
272
- '--decoder-hidden-dim', type=int, metavar='N',
273
- help='dimensionality of the decoder hidden state',
274
- )
275
- parser.add_argument(
276
- '--decoder-dropout', type=float, default=0.1,
277
- help='decoder dropout probability',
278
- )
279
-
280
- @classmethod
281
- def build_model(cls, args, task):
282
- # Fairseq initializes models by calling the ``build_model()``
283
- # function. This provides more flexibility, since the returned model
284
- # instance can be of a different type than the one that was called.
285
- # In this case we'll just return a SimpleLSTMModel instance.
286
-
287
- # Initialize our Encoder and Decoder.
288
- encoder = SimpleLSTMEncoder(
289
- args=args,
290
- dictionary=task.source_dictionary,
291
- embed_dim=args.encoder_embed_dim,
292
- hidden_dim=args.encoder_hidden_dim,
293
- dropout=args.encoder_dropout,
294
- )
295
- decoder = SimpleLSTMDecoder(
296
- dictionary=task.target_dictionary,
297
- encoder_hidden_dim=args.encoder_hidden_dim,
298
- embed_dim=args.decoder_embed_dim,
299
- hidden_dim=args.decoder_hidden_dim,
300
- dropout=args.decoder_dropout,
301
- )
302
- model = SimpleLSTMModel(encoder, decoder)
303
-
304
- # Print the model architecture.
305
- print(model)
306
-
307
- return model
308
-
309
- # We could override the ``forward()`` if we wanted more control over how
310
- # the encoder and decoder interact, but it's not necessary for this
311
- # tutorial since we can inherit the default implementation provided by
312
- # the FairseqEncoderDecoderModel base class, which looks like:
313
- #
314
- # def forward(self, src_tokens, src_lengths, prev_output_tokens):
315
- # encoder_out = self.encoder(src_tokens, src_lengths)
316
- # decoder_out = self.decoder(prev_output_tokens, encoder_out)
317
- # return decoder_out
318
-
319
- Finally let's define a *named architecture* with the configuration for our
320
- model. This is done with the :func:`~fairseq.models.register_model_architecture`
321
- function decorator. Thereafter this named architecture can be used with the
322
- ``--arch`` command-line argument, e.g., ``--arch tutorial_simple_lstm``::
323
-
324
- from fairseq.models import register_model_architecture
325
-
326
- # The first argument to ``register_model_architecture()`` should be the name
327
- # of the model we registered above (i.e., 'simple_lstm'). The function we
328
- # register here should take a single argument *args* and modify it in-place
329
- # to match the desired architecture.
330
-
331
- @register_model_architecture('simple_lstm', 'tutorial_simple_lstm')
332
- def tutorial_simple_lstm(args):
333
- # We use ``getattr()`` to prioritize arguments that are explicitly given
334
- # on the command-line, so that the defaults defined below are only used
335
- # when no other value has been specified.
336
- args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
337
- args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 256)
338
- args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
339
- args.decoder_hidden_dim = getattr(args, 'decoder_hidden_dim', 256)
340
-
341
-
342
- 3. Training the Model
343
- ---------------------
344
-
345
- Now we're ready to train the model. We can use the existing :ref:`fairseq-train`
346
- command-line tool for this, making sure to specify our new Model architecture
347
- (``--arch tutorial_simple_lstm``).
348
-
349
- .. note::
350
-
351
- Make sure you've already preprocessed the data from the IWSLT example in the
352
- :file:`examples/translation/` directory.
353
-
354
- .. code-block:: console
355
-
356
- > fairseq-train data-bin/iwslt14.tokenized.de-en \
357
- --arch tutorial_simple_lstm \
358
- --encoder-dropout 0.2 --decoder-dropout 0.2 \
359
- --optimizer adam --lr 0.005 --lr-shrink 0.5 \
360
- --max-tokens 12000
361
- (...)
362
- | epoch 052 | loss 4.027 | ppl 16.30 | wps 420805 | ups 39.7 | wpb 9841 | bsz 400 | num_updates 20852 | lr 1.95313e-05 | gnorm 0.218 | clip 0% | oom 0 | wall 529 | train_wall 396
363
- | epoch 052 | valid on 'valid' subset | valid_loss 4.74989 | valid_ppl 26.91 | num_updates 20852 | best 4.74954
364
-
365
- The model files should appear in the :file:`checkpoints/` directory. While this
366
- model architecture is not very good, we can use the :ref:`fairseq-generate` script to
367
- generate translations and compute our BLEU score over the test set:
368
-
369
- .. code-block:: console
370
-
371
- > fairseq-generate data-bin/iwslt14.tokenized.de-en \
372
- --path checkpoints/checkpoint_best.pt \
373
- --beam 5 \
374
- --remove-bpe
375
- (...)
376
- | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
377
- | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)
378
-
379
-
380
- 4. Making generation faster
381
- ---------------------------
382
-
383
- While autoregressive generation from sequence-to-sequence models is inherently
384
- slow, our implementation above is especially slow because it recomputes the
385
- entire sequence of Decoder hidden states for every output token (i.e., it is
386
- ``O(n^2)``). We can make this significantly faster by instead caching the
387
- previous hidden states.
388
-
389
- In fairseq this is called :ref:`Incremental decoding`. Incremental decoding is a
390
- special mode at inference time where the Model only receives a single timestep
391
- of input corresponding to the immediately previous output token (for teacher
392
- forcing) and must produce the next output incrementally. Thus the model must
393
- cache any long-term state that is needed about the sequence, e.g., hidden
394
- states, convolutional states, etc.
395
-
396
- To implement incremental decoding we will modify our model to implement the
397
- :class:`~fairseq.models.FairseqIncrementalDecoder` interface. Compared to the
398
- standard :class:`~fairseq.models.FairseqDecoder` interface, the incremental
399
- decoder interface allows ``forward()`` methods to take an extra keyword argument
400
- (*incremental_state*) that can be used to cache state across time-steps.
401
-
402
- Let's replace our ``SimpleLSTMDecoder`` with an incremental one::
403
-
404
- import torch
- import torch.nn as nn
- from fairseq import utils
- from fairseq.models import FairseqIncrementalDecoder
406
-
407
- class SimpleLSTMDecoder(FairseqIncrementalDecoder):
408
-
409
- def __init__(
410
- self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
411
- dropout=0.1,
412
- ):
413
- # This remains the same as before.
414
- super().__init__(dictionary)
415
- self.embed_tokens = nn.Embedding(
416
- num_embeddings=len(dictionary),
417
- embedding_dim=embed_dim,
418
- padding_idx=dictionary.pad(),
419
- )
420
- self.dropout = nn.Dropout(p=dropout)
421
- self.lstm = nn.LSTM(
422
- input_size=encoder_hidden_dim + embed_dim,
423
- hidden_size=hidden_dim,
424
- num_layers=1,
425
- bidirectional=False,
426
- )
427
- self.output_projection = nn.Linear(hidden_dim, len(dictionary))
428
-
429
- # We now take an additional kwarg (*incremental_state*) for caching the
430
- # previous hidden and cell states.
431
- def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
432
- if incremental_state is not None:
433
- # If the *incremental_state* argument is not ``None`` then we are
434
- # in incremental inference mode. While *prev_output_tokens* will
435
- # still contain the entire decoded prefix, we will only use the
436
- # last step and assume that the rest of the state is cached.
437
- prev_output_tokens = prev_output_tokens[:, -1:]
438
-
439
- # This remains the same as before.
440
- bsz, tgt_len = prev_output_tokens.size()
441
- final_encoder_hidden = encoder_out['final_hidden']
442
- x = self.embed_tokens(prev_output_tokens)
443
- x = self.dropout(x)
444
- x = torch.cat(
445
- [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
446
- dim=2,
447
- )
448
-
449
- # We will now check the cache and load the cached previous hidden and
450
- # cell states, if they exist, otherwise we will initialize them to
451
- # zeros (as before). We will use the ``utils.get_incremental_state()``
452
- # and ``utils.set_incremental_state()`` helpers.
453
- initial_state = utils.get_incremental_state(
454
- self, incremental_state, 'prev_state',
455
- )
456
- if initial_state is None:
457
- # first time initialization, same as the original version
458
- initial_state = (
459
- final_encoder_hidden.unsqueeze(0), # hidden
460
- torch.zeros_like(final_encoder_hidden).unsqueeze(0), # cell
461
- )
462
-
463
- # Run one step of our LSTM.
464
- output, latest_state = self.lstm(x.transpose(0, 1), initial_state)
465
-
466
- # Update the cache with the latest hidden and cell states.
467
- utils.set_incremental_state(
468
- self, incremental_state, 'prev_state', latest_state,
469
- )
470
-
471
- # This remains the same as before
472
- x = output.transpose(0, 1)
473
- x = self.output_projection(x)
474
- return x, None
475
-
476
- # The ``FairseqIncrementalDecoder`` interface also requires implementing a
477
- # ``reorder_incremental_state()`` method, which is used during beam search
478
- # to select and reorder the incremental state.
479
- def reorder_incremental_state(self, incremental_state, new_order):
480
- # Load the cached state.
481
- prev_state = utils.get_incremental_state(
482
- self, incremental_state, 'prev_state',
483
- )
484
-
485
- # Reorder batches according to *new_order*.
486
- reordered_state = (
487
- prev_state[0].index_select(1, new_order), # hidden
488
- prev_state[1].index_select(1, new_order), # cell
489
- )
490
-
491
- # Update the cached state.
492
- utils.set_incremental_state(
493
- self, incremental_state, 'prev_state', reordered_state,
494
- )
495
-
496
- Finally, we can rerun generation and observe the speedup:
497
-
498
- .. code-block:: console
499
-
500
- # Before
501
-
502
- > fairseq-generate data-bin/iwslt14.tokenized.de-en \
503
- --path checkpoints/checkpoint_best.pt \
504
- --beam 5 \
505
- --remove-bpe
506
- (...)
507
- | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
508
- | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)
509
-
510
- # After
511
-
512
- > fairseq-generate data-bin/iwslt14.tokenized.de-en \
513
- --path checkpoints/checkpoint_best.pt \
514
- --beam 5 \
515
- --remove-bpe
516
- (...)
517
- | Translated 6750 sentences (153132 tokens) in 5.5s (1225.54 sentences/s, 27802.94 tokens/s)
518
- | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)
 
examples/.gitignore DELETED
@@ -1,2 +0,0 @@
1
- !*/*.sh
2
- !*/*.md
 
 
 
examples/__init__.py DELETED
@@ -1,9 +0,0 @@
1
- # Copyright (c) Facebook, Inc. and its affiliates.
2
- #
3
- # This source code is licensed under the MIT license found in the
4
- # LICENSE file in the root directory of this source tree.
5
-
6
- try:
7
- from fairseq.version import __version__ # noqa
8
- except ImportError:
9
- pass
 
examples/adaptive_span/README.md DELETED
@@ -1,90 +0,0 @@
1
- # Adaptive Span
2
-
3
- Adaptive Span is a novel self-attention mechanism that can learn its optimal
- attention span. This allows us to significantly extend the maximum context size
- used in Transformers, while maintaining control over memory footprint and
- computational time. It uses the Truncated BPTT technique for training,
- as in [transformerXL](https://github.com/pytorch/fairseq/blob/master/examples/truncated_bptt/README.md).
-
- Adaptive Span was introduced in the paper
- [Adaptive Attention Span in Transformers](https://arxiv.org/abs/1905.07799),
- which achieved state-of-the-art language modeling results at the time of publication.
-
- We manage to reproduce their result in fairseq and keep most of the
- [original implementation](https://github.com/facebookresearch/adaptive-span) untouched.
- You can also refer to their sweep file if any combination of hyperparameters is unclear.
16
-
17
- ##### 0. Setup
18
-
19
- First you need to process the Enwik8 dataset. We use the pre-tokenized dataset
- from the [adaptive span paper](https://github.com/facebookresearch/adaptive-span/blob/master/get_data.sh).
- You can download the dataset and then run:
22
- ```bash
23
- fairseq-preprocess --only-source --trainpref ~/data/enwik8/train.txt \
24
- --validpref ~/data/enwik8/valid.txt --testpref ~/data/enwik8/test.txt \
25
- --destdir ~/data/enwik8/data-bin/ --joined-dictionary --workers 20
26
- ```
27
-
28
- ##### 1. Train a Adaptive Span model on Enwik8
29
-
30
- We will train a 12-layer Adaptive Span model following the [hyperparameters
31
- used in the original
32
- paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).
33
-
34
- The following command assumes 4 GPUs, so that the total batch size is 64
35
- sequences (4 x 16). Training should take 2-3 days on 4 V100 GPUs:
36
- ```bash
37
- CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
38
- --user-dir examples/adaptive_span \
39
- --data ~/data/enwik8/data-bin/ \
40
- --fp16 --fp16-no-flatten-grads --max-update 600000 \
41
- --task truncated_bptt_lm --tokens-per-sample 512 --arch adaptive_span \
42
- --n-layer 12 --d-model 512 --n-head 8 --d-inner 2048 --dropout 0.3 \
43
- --attn-span 8192 --optimizer adagrad_with_grad_clip --adagrad-clip 0.03 \
44
- --validate-interval-updates 1000 \
45
- --lr-scheduler fixed --warmup-updates 32000 --batch-size-valid 32 \
46
- --lr 0.07 --criterion adaptive_span_loss --batch-size 16 --update-freq 1 \
47
- --seed 2 --log-format json --log-interval 25 --aux-loss-scaler 5e-07
48
- ```
49
- This should land around 1.05 on validation and 1.03 on test. You can lower
- `--aux-loss-scaler` for better performance (a longer span); it gives a ~0.03 bpc
- improvement over the transformerXL baseline here.
- If training on a single GPU, set `--update-freq=4` to accumulate 4x gradients
- and simulate training on 4 GPUs.
- You can also reproduce the transformerXL result on enwik8 using this code base.
- It should land around 1.06 on test, matching the [original paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_enwik8_base.sh).
- You can try it with:
57
- ```bash
58
- CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
59
- --user-dir examples/truncated_bptt \
60
- ~/data/enwik8/data-bin/ \
61
- --task truncated_bptt_lm --fp16 --max-update 400000 \
62
- --tokens-per-sample 512 --arch transformer_xl --n-layer 12 \
63
- --d-model 512 --n-head 8 --d-head 64 --d-inner 2048 --dropout 0.1 \
64
- --dropatt 0.0 --mem-len 512 --optimizer adam --clip-norm 0.25 \
65
- --lr-scheduler cosine --warmup-updates 0 \
66
- --lr 0.0 --lr 0.00025 --batch-size 15 \
67
- --update-freq 1 --seed 2 --log-format json --log-interval 25 \
68
- --fp16
69
- ```
70
-
71
- ##### 2. Evaluate
72
- For Adaptive Span:
73
- ```bash
74
- fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
75
- --user-dir examples/adaptive_span \
76
- --task truncated_bptt_lm --batch-size 8 --tokens-per-sample 512 --gen-subset test
77
- ```
78
- For Transformer-XL evaluation:
79
- ```bash
80
- fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
81
- --user-dir examples/truncated_bptt/ --task truncated_bptt_lm --batch-size 8 \
82
- --tokens-per-sample 80 \
83
- --model-overrides '{"mem_len":2100,"clamp_len":820,"same_length":True}' \
84
- --gen-subset valid
85
- ```
86
-
87
- *Note:* During training the model saw 512 tokens of context
88
- (``--tokens-per-sample=512``), with batch size 8. These settings match the evaluation
89
- settings from [the original
90
- paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).
 
examples/adaptive_span/__init__.py DELETED
@@ -1,19 +0,0 @@
1
- # Copyright (c) Facebook, Inc. and its affiliates.
2
- #
3
- # This source code is licensed under the MIT license found in the
4
- # LICENSE file in the root directory of this source tree.
5
-
6
- import importlib
7
- import os
8
-
9
- # automatically import any Python files in the current directory
10
- cur_dir = os.path.dirname(__file__)
11
- for file in os.listdir(cur_dir):
12
- path = os.path.join(cur_dir, file)
13
- if (
14
- not file.startswith("_")
15
- and not file.startswith(".")
16
- and (file.endswith(".py") or os.path.isdir(path))
17
- ):
18
- mod_name = file[: file.find(".py")] if file.endswith(".py") else file
19
- module = importlib.import_module(__name__ + "." + mod_name)
 
examples/adaptive_span/adagrad_with_grad_clip.py DELETED
@@ -1,128 +0,0 @@
1
- # Copyright (c) Facebook, Inc. and its affiliates.
2
- #
3
- # This source code is licensed under the MIT license found in the
4
- # LICENSE file in the root directory of this source tree.
5
-
6
- from torch.optim import Adagrad
7
-
8
- from fairseq.optim import LegacyFairseqOptimizer, register_optimizer
9
-
10
-
11
- @register_optimizer("adagrad_with_grad_clip")
12
- class FairseqAdagradWithGradClip(LegacyFairseqOptimizer):
13
- def __init__(self, args, params):
14
- super().__init__(args)
15
- self._optimizer = AdagradWithGradClip(params, **self.optimizer_config)
16
-
17
- @staticmethod
18
- def add_args(parser):
19
- """Add optimizer-specific arguments to the parser."""
20
- # fmt: off
21
- parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
22
- help='weight decay')
23
- parser.add_argument('--adagrad-clip', default=0.0, type=float, metavar='D',
24
- help='internal grad clip')
25
- # fmt: on
26
-
27
- @property
28
- def optimizer_config(self):
29
- """
30
- Return a kwarg dictionary that will be used to override optimizer
31
- args stored in checkpoints. This allows us to load a checkpoint and
32
- resume training using a different set of optimizer args, e.g., with a
33
- different learning rate.
34
- """
35
- return {
36
- "lr": self.args.lr[0],
37
- "weight_decay": self.args.weight_decay,
38
- "grad_clip": self.args.adagrad_clip,
39
- }
40
-
41
- @property
42
- def supports_flat_params(self):
43
- return False
44
-
45
-
46
- def _clip_grad(clr, grad, group_grad_clip):
47
- if group_grad_clip > 0:
48
- norm = grad.norm(2).item()
49
- if norm > group_grad_clip:
50
- clr *= group_grad_clip / (norm + 1e-10)
51
- return clr
52
-
53
-
54
- class AdagradWithGradClip(Adagrad):
55
- """Adagrad algorithm with custom gradient clipping"""
56
-
57
- def __init__(
58
- self,
59
- params,
60
- lr=1e-2,
61
- lr_decay=0,
62
- weight_decay=0,
63
- initial_accumulator_value=0,
64
- grad_clip=0,
65
- ):
66
- Adagrad.__init__(
67
- self,
68
- params,
69
- lr=lr,
70
- lr_decay=lr_decay,
71
- weight_decay=weight_decay,
72
- initial_accumulator_value=initial_accumulator_value,
73
- )
74
- self.defaults["grad_clip"] = grad_clip
75
- self.param_groups[0].setdefault("grad_clip", grad_clip)
76
-
77
- def step(self, closure=None):
78
- loss = None
79
- if closure is not None:
80
- loss = closure()
81
-
82
- for group in self.param_groups:
83
- for p in group["params"]:
84
- if p.grad is None:
85
- continue
86
-
87
- grad = p.grad.data
88
- state = self.state[p]
89
-
90
- state["step"] += 1
91
-
92
- if group["weight_decay"] != 0:
93
- if p.grad.data.is_sparse:
94
- raise RuntimeError(
95
- "weight_decay option is "
96
- "not compatible with sparse "
97
- "gradients"
98
- )
99
- grad = grad.add(group["weight_decay"], p.data)
100
-
101
- clr = group["lr"] / (1 + (state["step"] - 1) * group["lr_decay"])
102
-
103
- # clip
104
- clr = _clip_grad(clr=clr, grad=grad, group_grad_clip=group["grad_clip"])
105
-
106
- if grad.is_sparse:
107
- # the update is non-linear so indices must be unique
108
- grad = grad.coalesce()
109
- grad_indices = grad._indices()
110
- grad_values = grad._values()
111
- size = grad.size()
112
-
113
- def make_sparse(values):
114
- constructor = grad.new
115
- if grad_indices.dim() == 0 or values.dim() == 0:
116
- return constructor().resize_as_(grad)
117
- return constructor(grad_indices, values, size)
118
-
119
- state["sum"].add_(make_sparse(grad_values.pow(2)))
120
- std = state["sum"]._sparse_mask(grad)
121
- std_values = std._values().sqrt_().add_(1e-10)
122
- p.data.add_(-clr, make_sparse(grad_values / std_values))
123
- else:
124
- state["sum"].addcmul_(1, grad, grad)
125
- std = state["sum"].sqrt().add_(1e-10)
126
- p.data.addcdiv_(-clr, grad, std)
127
-
128
- return loss
 
examples/adaptive_span/adaptive_span_attention.py DELETED
@@ -1,160 +0,0 @@
1
- # Copyright (c) Facebook, Inc. and its affiliates.
2
- #
3
- # This source code is licensed under the MIT license found in the
4
- # LICENSE file in the root directory of this source tree.
5
- import math
6
-
7
- import torch
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
-
11
-
12
- class AdaptiveMask(nn.Module):
13
- """Soft masking function for adaptive size.
14
- It masks out the last K values of an input. The masking value
15
- goes from 1 to 0 gradually, so K can be learned with
16
- back-propagation.
17
- Args:
18
- max_size: maximum size (i.e. input dimension)
19
- ramp_size: size of the ramp going from 0 to 1
20
- init_val: initial size proportion not to be masked out
21
- shape: learn multiple sizes independent of each other
22
- """
23
-
24
- def __init__(self, max_size, ramp_size, init_val=0, shape=(1,)):
25
- nn.Module.__init__(self)
26
- self._max_size = max_size
27
- self._ramp_size = ramp_size
28
- self.current_val = nn.Parameter(torch.zeros(*shape) + init_val)
29
- mask_template = torch.linspace(1 - max_size, 0, steps=max_size)
30
- self.register_buffer("mask_template", mask_template)
31
-
32
- def forward(self, x):
33
- mask = self.mask_template.float() + self.current_val.float() * self._max_size
34
- mask = mask / self._ramp_size + 1
35
- mask = mask.clamp(0, 1)
36
- if x.size(-1) < self._max_size:
37
- # the input could have been trimmed beforehand to save computation
38
- mask = mask.narrow(-1, self._max_size - x.size(-1), x.size(-1))
39
- x = (x * mask).type_as(x)
40
- return x
41
-
42
- def get_current_max_size(self, include_ramp=True):
43
- current_size = math.ceil(self.current_val.max().item() * self._max_size)
44
- if include_ramp:
45
- current_size += self._ramp_size
46
- current_size = max(0, min(self._max_size, current_size))
47
- return current_size
48
-
49
- def get_current_avg_size(self, include_ramp=True):
50
- current_size = math.ceil(
51
- self.current_val.float().mean().item() * self._max_size
52
- )
53
- if include_ramp:
54
- current_size += self._ramp_size
55
- current_size = max(0, min(self._max_size, current_size))
56
- return current_size
57
-
58
- def clamp_param(self):
59
- """this need to be called after each update"""
60
- self.current_val.data.clamp_(0, 1)
61
-
62
-
63
- class AdaptiveSpan(nn.Module):
64
- """Adaptive attention span for Transformerself.
65
- This module learns an attention span length from data for each
66
- self-attention head.
67
- Args:
68
- attn_span: maximum attention span
69
- adapt_span_loss: loss coefficient for the span length
70
- adapt_span_ramp: length of the masking ramp
71
- adapt_span_init: initial size ratio
72
- adapt_span_cache: adapt cache size to reduce memory usage
73
- """
74
-
75
- def __init__(
76
- self,
77
- attn_span,
78
- adapt_span_ramp,
79
- adapt_span_init,
80
- n_head,
81
- adapt_span_layer,
82
- **kargs
83
- ):
84
- nn.Module.__init__(self)
85
- self._max_span = attn_span
86
- self._n_head = n_head
87
- self._adapt_span_layer = adapt_span_layer
88
- if self._adapt_span_layer:
89
- self._mask = AdaptiveMask(
90
- max_size=self._max_span,
91
- ramp_size=adapt_span_ramp,
92
- init_val=adapt_span_init,
93
- )
94
- else:
95
- self._mask = AdaptiveMask(
96
- max_size=self._max_span,
97
- ramp_size=adapt_span_ramp,
98
- init_val=adapt_span_init,
99
- shape=(n_head, 1, 1),
100
- )
101
-
102
- def forward(self, attn, normalize=True):
103
- """mask attention with the right span"""
104
- # batch and head dimensions are merged together, so separate them first
105
- self.clamp_param()
106
- if self._adapt_span_layer:
107
- attn = self._mask(attn)
108
- else:
109
- B = attn.size(0) # batch size
110
- M = attn.size(1) # block size
111
- attn = attn.reshape(B // self._n_head, self._n_head, M, -1)
112
- attn = self._mask(attn)
113
- attn = attn.view(B, M, -1)
114
- return attn
115
-
116
- def get_trim_len(self):
117
- """how much of memory can be trimmed to reduce computation"""
118
- L = self._max_span
119
- trim_len = min(L - 1, L - self._mask.get_current_max_size())
120
- # too fine granularity might be bad for the memory management
121
- trim_len = math.floor(trim_len / 64) * 64
122
- return trim_len
123
-
124
- def trim_memory(self, query, key, value, key_pe):
125
- """trim out unnecessary memory beforehand to reduce computation"""
126
- trim_len = self.get_trim_len()
127
- cache_size = key.size(1) - query.size(1)
128
- trim_len_cache = trim_len - (self._max_span - cache_size)
129
- if trim_len_cache > 0:
130
- key = key[:, trim_len_cache:, :]
131
- value = value[:, trim_len_cache:, :]
132
- elif trim_len_cache < 0:
133
- # cache is too short! this happens when validation resumes
134
- # after a lot of updates.
135
- key = F.pad(key, [0, 0, -trim_len_cache, 0])
136
- value = F.pad(value, [0, 0, -trim_len_cache, 0])
137
- if trim_len > 0:
138
- if key_pe is not None:
139
- key_pe = key_pe[:, :, trim_len:]
140
- return key, value, key_pe
141
-
142
- def get_cache_size(self):
143
- """determine how long the cache should be"""
144
- trim_len = self.get_trim_len()
145
- # give a buffer of 64 steps since a span might increase
146
- # in future updates
147
- return min(self._max_span, self._max_span - trim_len + 64)
148
-
149
- def get_loss(self):
150
- """a loss term for regularizing the span length"""
151
- return self._max_span * self._mask.current_val.float().mean()
152
-
153
- def get_current_max_span(self):
154
- return self._mask.get_current_max_size()
155
-
156
- def get_current_avg_span(self):
157
- return self._mask.get_current_avg_size()
158
-
159
- def clamp_param(self):
160
- self._mask.clamp_param()
 
examples/adaptive_span/adaptive_span_loss.py DELETED
@@ -1,106 +0,0 @@
1
- # Copyright (c) Facebook, Inc. and its affiliates.
2
- #
3
- # This source code is licensed under the MIT license found in the
4
- # LICENSE file in the root directory of this source tree.
5
-
6
- import math
7
- from dataclasses import dataclass
8
-
9
- import torch.nn.functional as F
10
- from fairseq import metrics, utils
11
- from fairseq.criterions import register_criterion
12
- from fairseq.criterions.cross_entropy import CrossEntropyCriterion
13
- from fairseq.dataclass import FairseqDataclass
14
- from omegaconf import II
15
-
16
-
17
- @dataclass
18
- class AdaptiveSpanCriterionConfig(FairseqDataclass):
19
- sentence_avg: bool = II("optimization.sentence_avg")
20
-
21
-
22
- @register_criterion("adaptive_span_loss", dataclass=AdaptiveSpanCriterionConfig)
23
- class AdaptiveSpanCriterion(CrossEntropyCriterion):
24
- def __init__(self, task, sentence_avg):
25
- super().__init__(task, sentence_avg)
26
-
27
- def forward(self, model, sample, reduce=True):
28
- """Compute the loss for the given sample.
29
-
30
- Returns a tuple with three elements:
31
- 1) the loss here is summed, different from the adaptive span code
32
- 2) the sample size, which is used as the denominator for the gradient
33
- 3) logging outputs to display while training
34
- """
35
- net_output = model(**sample["net_input"])
36
- loss, aux_loss, avg_span, max_span = self.compute_loss(
37
- model, net_output, sample, reduce=reduce
38
- )
39
- sample_size = (
40
- sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
41
- )
42
- loss /= sample_size
43
- total_loss = loss + aux_loss
44
- sample_size = 1
45
-
46
- logging_output = {
47
- "loss": loss.data,
48
- "ntokens": sample["ntokens"],
49
- "nsentences": sample["target"].size(0),
50
- "sample_size": sample_size,
51
- "total_loss": total_loss.data,
52
- "avg_span": avg_span * sample_size,
53
- "max_span": max_span * sample_size,
54
- }
55
- return total_loss, sample_size, logging_output
56
-
57
- def compute_loss(self, model, net_output, sample, reduce=True):
58
- loss, _ = super().compute_loss(model, net_output, sample, reduce)
59
- aux_loss = model.get_aux_loss()
60
- avg_span = model.get_current_avg_span()
61
- max_span = model.get_current_max_span()
62
- return loss, aux_loss, avg_span, max_span
63
-
64
- @staticmethod
65
- def reduce_metrics(logging_outputs) -> None:
66
- """Aggregate logging outputs from data parallel training."""
67
- loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
68
- ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
69
- sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
70
- total_loss_sum = sum(log.get("total_loss", 0) for log in logging_outputs)
71
- avg_span_sum = sum(log.get("avg_span", 0) for log in logging_outputs)
72
- max_span_sum = sum(log.get("max_span", 0) for log in logging_outputs)
73
-
74
- # we divide by log(2) to convert the loss from base e to base 2
75
- metrics.log_scalar(
76
- "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
77
- )
78
- metrics.log_scalar("avg_span", avg_span_sum / sample_size, sample_size, round=3)
79
- metrics.log_scalar("max_span", max_span_sum / sample_size, sample_size, round=3)
80
- # total loss contains the L1 norm on adaptive-span
81
- metrics.log_scalar(
82
- "total_loss",
83
- total_loss_sum / sample_size / math.log(2),
84
- sample_size,
85
- round=3,
86
- )
87
- if sample_size != ntokens:
88
- metrics.log_scalar(
89
- "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
90
- )
91
- metrics.log_derived(
92
- "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
93
- )
94
- else:
95
- metrics.log_derived(
96
- "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
97
- )
98
-
99
- @staticmethod
100
- def logging_outputs_can_be_summed() -> bool:
101
- """
102
- Whether the logging outputs returned by `forward` can be summed
103
- across workers prior to calling `reduce_metrics`. Setting this
104
-         to True will improve distributed training speed.
105
- """
106
- return True
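
As a quick numeric sketch of the base conversion done in `reduce_metrics` above: the values below are invented, and the last line assumes fairseq's `utils.get_perplexity` exponentiates the base-2 loss (i.e. `2 ** avg`).

```python
import math

# Invented aggregated values, for illustration only.
loss_sum = 693.1   # summed cross-entropy in nats (natural-log base)
ntokens = 1000

# reduce_metrics divides by math.log(2) to report the loss in bits per token.
nll_loss_bits = loss_sum / ntokens / math.log(2)

# Perplexity is then 2 ** (loss in bits), which equals e ** (loss in nats).
ppl = 2 ** nll_loss_bits

print(round(nll_loss_bits, 3), round(ppl, 3))  # 1.0 2.0
```
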
 
examples/adaptive_span/adaptive_span_model.py DELETED
@@ -1,263 +0,0 @@
1
- # Copyright (c) Facebook, Inc. and its affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import math
8
-
9
- import torch
10
- import torch.nn as nn
11
- import torch.nn.functional as F
12
-
13
- from fairseq.modules.layer_norm import LayerNorm
14
-
15
- from .adaptive_span_attention import AdaptiveSpan
16
-
17
- # Size notations:
18
- # B = batch_size, H = d_model, M = block_size, L = attn_span
19
-
20
-
21
- def _skew(X, pad_value):
22
- """shift every row 1 step to right"""
23
- # X = B x M x L
24
- B, M, L = X.size()
25
- X = F.pad(X, (0, M + 1), value=pad_value) # B x M x (L+M+1)
26
- X = X.view(B, -1) # B x ML+MM+M
27
- X = X[:, :-M] # B x ML+MM
28
- X = X.view(B, M, M + L) # B x M x L+M
29
- return X
30
-
31
-
32
- def _unskew(X):
33
- """reverse _skew operation"""
34
- # X = B x M x L+M
35
- B, M, L = X.size()
36
- L -= M
37
- X = X.view(B, -1) # B x ML+MM
38
- X = F.pad(X, (0, M)) # B x ML+MM+M
39
- X = X.view(B, M, M + L + 1) # B x M x L+M+1
40
- X = X[:, :, :L] # B x M x L
41
- return X
42
-
43
-
44
- class SeqAttention(nn.Module):
45
- """Sequential self-attention layer.
46
- Each token will attend to its previous fixed number of steps.
47
- Note that attention doesn't include the current step itself.
48
- """
49
-
50
- def __init__(self, d_model, n_head, attn_span, dropout, adapt_span_layer, **kargs):
51
- nn.Module.__init__(self)
52
- self.dropout = nn.Dropout(dropout)
53
- self.d_model = d_model # size of a single head
54
- self.attn_span = attn_span
55
- self.adaptive_span = AdaptiveSpan(
56
- attn_span=attn_span,
57
- n_head=n_head,
58
- adapt_span_layer=adapt_span_layer,
59
- **kargs
60
- )
61
-
62
- def forward(self, query, key, value, key_pe):
63
- # query size = B x M x H
64
- # key, value sizes = B x (M+L) x H
65
-
66
- key, value, key_pe = self.adaptive_span.trim_memory(query, key, value, key_pe)
67
-
68
- # compute attention from context
69
- # B x M (dest) x (M+L) (src)
70
- attn_cont = torch.matmul(query, key.transpose(-1, -2))
71
- attn_cont = _unskew(attn_cont) # B x M x L
72
-
73
- # compute the effect of position embedding
74
- attn_pos = torch.matmul(query, key_pe) # B x M x L_pos
75
- attn = attn_cont + attn_pos
76
-
77
- attn = attn / math.sqrt(self.d_model) # B x M X L_pos
78
-
79
- attn = F.softmax(attn.float(), dim=-1).type_as(attn)
80
-
81
- # trim attention lengths according to the learned span
82
- attn = self.adaptive_span(attn)
83
-
84
- attn = self.dropout(attn) # B x M X L_pos
85
-
86
- attn_cont = _skew(attn, 0) # B x M X (L+M)
87
- out = torch.matmul(attn_cont, value) # B x M x H
88
- return out
89
-
90
- def get_cache_size(self):
91
- return self.adaptive_span.get_cache_size()
92
-
93
-
94
- class MultiHeadSeqAttention(nn.Module):
95
- def __init__(self, d_model, n_head, **kargs):
96
- nn.Module.__init__(self)
97
- assert d_model % n_head == 0
98
- self.n_head = n_head
99
- self.head_dim = d_model // n_head
100
- self.attn = SeqAttention(d_model=self.head_dim, n_head=n_head, **kargs)
101
- self.proj_query = nn.Linear(d_model, d_model, bias=False)
102
- nn.init.xavier_normal_(self.proj_query.weight)
103
- self.proj_out = nn.Linear(d_model, d_model, bias=False)
104
- nn.init.xavier_normal_(self.proj_out.weight)
105
- self.proj_val = nn.Linear(d_model, d_model, bias=False)
106
- nn.init.xavier_normal_(self.proj_val.weight)
107
- self.proj_key = nn.Linear(d_model, d_model, bias=False)
108
- nn.init.xavier_normal_(self.proj_key.weight)
109
-
110
- def head_reshape(self, x):
111
- K = self.n_head
112
- D = self.head_dim
113
- x = x.view(x.size()[:-1] + (K, D)) # B x (M+L) x K x D
114
- x = x.transpose(1, 2).contiguous() # B x K x (M+L) x D
115
- x = x.view(-1, x.size(-2), x.size(-1)) # B_K x (M+L) x D
116
- return x
117
-
118
- def forward(self, query, key, value, key_pe):
119
- B = query.size(0)
120
- K = self.n_head
121
- D = self.head_dim
122
- M = query.size(1)
123
-
124
- query = self.proj_query(query)
125
- query = self.head_reshape(query)
126
- value = self.proj_val(value)
127
- value = self.head_reshape(value)
128
- key = self.proj_key(key)
129
- key = self.head_reshape(key)
130
-
131
- out = self.attn(query, key, value, key_pe) # B_K x M x D
132
- out = out.view(B, K, M, D) # B x K x M x D
133
- out = out.transpose(1, 2).contiguous() # B x M x K x D
134
- out = out.view(B, M, -1) # B x M x K_D
135
- out = self.proj_out(out)
136
- return out
137
-
138
-
139
- class FeedForwardLayer(nn.Module):
140
- def __init__(self, d_model, d_inner, dropout, **kargs):
141
- nn.Module.__init__(self)
142
- self.fc1 = nn.Linear(d_model, d_inner)
143
- self.fc2 = nn.Linear(d_inner, d_model)
144
- nn.init.xavier_uniform_(self.fc1.weight)
145
- nn.init.xavier_uniform_(self.fc2.weight)
146
- self.dropout = nn.Dropout(dropout)
147
-
148
- def forward(self, h):
149
- h1 = F.relu(self.fc1(h))
150
- h1 = self.dropout(h1)
151
- h2 = self.fc2(h1)
152
- return h2
153
-
154
-
155
- class TransformerSeqLayer(nn.Module):
156
- def __init__(self, d_model, **kargs):
157
- nn.Module.__init__(self)
158
- self.attn = MultiHeadSeqAttention(d_model=d_model, **kargs)
159
- self.norm1 = LayerNorm(d_model)
160
- self.ff = FeedForwardLayer(d_model=d_model, **kargs)
161
- self.norm2 = LayerNorm(d_model)
162
-
163
- def forward(self, h, h_cache, key_pe):
164
- # h = B x M x H
165
- # h_cache = B x L x H
166
- h_all = torch.cat([h_cache, h], dim=1) # B x (M+L) x H
167
- attn_out = self.attn(h, h_all, h_all, key_pe)
168
- h = self.norm1(h + attn_out) # B x M x H
169
- if self.ff is not None:
170
- ff_out = self.ff(h)
171
- out = self.norm2(h + ff_out) # B x M x H
172
- else:
173
- out = h
174
- return out
175
-
176
- def get_cache_size(self):
177
- return self.attn.attn.get_cache_size()
178
-
179
-
180
- class TransformerSeq(nn.Module):
181
- def __init__(
182
- self,
183
- vocab_size,
184
- d_model,
185
- n_head,
186
- n_layer,
187
- attn_span,
188
- emb_dropout,
189
- aux_loss_scaler,
190
- adapt_span_layer,
191
- **kargs
192
- ):
193
- nn.Module.__init__(self)
194
- # token embeddings
195
- self.in_emb = nn.Embedding(vocab_size, d_model)
196
- nn.init.normal_(self.in_emb.weight, mean=0, std=d_model ** -0.5)
197
- self.out_emb = nn.Linear(d_model, vocab_size)
198
- self.aux_loss_scaler = aux_loss_scaler
199
- if emb_dropout > 0:
200
- self.emb_dropout = nn.Dropout(emb_dropout)
201
- else:
202
- self.emb_dropout = None
203
- # position embeddings
204
- self.key_pe = nn.Parameter(torch.randn(1, d_model // n_head, attn_span))
205
-
206
- self.layers = nn.ModuleList()
207
- self.layers.extend(
208
- TransformerSeqLayer(
209
- d_model=d_model,
210
- n_head=n_head,
211
- attn_span=attn_span,
212
- adapt_span_layer=adapt_span_layer,
213
- **kargs
214
- )
215
- for _ in range(n_layer)
216
- )
217
-
218
- def forward(self, x, h_cache, target=None):
219
- # x size = B x M
220
- block_size = x.size(1)
221
- h = self.in_emb(x) # B x M x H
222
- if self.emb_dropout is not None:
223
- h = self.emb_dropout(h)
224
-
225
- h_cache_next = []
226
- for l, layer in enumerate(self.layers):
227
- cache_size = layer.attn.attn.get_cache_size()
228
- if cache_size > block_size:
229
- h_cache_next_l = torch.cat(
230
- [h_cache[l][:, -cache_size + block_size :, :], h], dim=1
231
- ).detach()
232
- else:
233
- h_cache_next_l = h[:, -cache_size:, :].detach()
234
- h_cache_next.append(h_cache_next_l)
235
- h = layer(h, h_cache[l], self.key_pe) # B x M x H
236
-
237
- if self.emb_dropout is not None:
238
- h = self.emb_dropout(h)
239
-
240
- out = F.log_softmax(self.out_emb(h).float(), dim=-1).type_as(h)
241
- dummy_loss = None
242
-
243
- return out, h_cache_next, dummy_loss
244
-
245
- def get_aux_loss(self):
246
- loss = 0.0
247
- for layer in self.layers:
248
- loss += layer.attn.attn.adaptive_span.get_loss()
249
- return self.aux_loss_scaler * loss
250
-
251
- def get_current_max_span(self):
252
- max_span = 0.0
253
- for layer in self.layers:
254
- max_span = max(
255
- max_span, layer.attn.attn.adaptive_span.get_current_max_span()
256
- )
257
- return max_span
258
-
259
- def get_current_avg_span(self):
260
- avg_span = 0.0
261
- for layer in self.layers:
262
- avg_span += layer.attn.attn.adaptive_span.get_current_avg_span()
263
- return avg_span / len(self.layers)
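
To make the `_skew`/`_unskew` helpers near the top of this file easier to follow, here is a small self-contained demo on a toy 1 x 2 x 3 tensor; the logic mirrors the deleted functions above and the values are arbitrary.

```python
import torch
import torch.nn.functional as F


def _skew(X, pad_value):
    """Shift the i-th row of each B x M x L block i steps to the right."""
    B, M, L = X.size()
    X = F.pad(X, (0, M + 1), value=pad_value)  # B x M x (L+M+1)
    X = X.view(B, -1)                          # B x (ML+MM+M)
    X = X[:, :-M]                              # B x (ML+MM)
    return X.view(B, M, M + L)                 # B x M x (L+M)


def _unskew(X):
    """Reverse the _skew operation."""
    B, M, L = X.size()
    L -= M
    X = X.view(B, -1)                          # B x (ML+MM)
    X = F.pad(X, (0, M))                       # B x (ML+MM+M)
    X = X.view(B, M, M + L + 1)                # B x M x (L+M+1)
    return X[:, :, :L]                         # B x M x L


X = torch.arange(6.0).view(1, 2, 3)  # B=1, M=2, L=3
skewed = _skew(X, pad_value=0)
print(skewed)
# tensor([[[0., 1., 2., 0., 0.],
#          [0., 3., 4., 5., 0.]]])
assert torch.equal(_unskew(skewed), X)  # the two operations are inverses
```
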
 
examples/adaptive_span/adaptive_span_model_wrapper.py DELETED
@@ -1,145 +0,0 @@
1
- # Copyright (c) Facebook, Inc. and its affiliates.
2
- #
3
- # This source code is licensed under the MIT license found in the
4
- # LICENSE file in the root directory of this source tree.
5
-
6
- import logging
7
- from dataclasses import dataclass
8
- from typing import Dict, List, Optional
9
-
10
- import torch
11
- from fairseq.dataclass import FairseqDataclass
12
- from fairseq.models import (
13
- FairseqIncrementalDecoder,
14
- FairseqLanguageModel,
15
- register_model,
16
- )
17
- from .adaptive_span_model import TransformerSeq as AdaptiveSpanTransformerModel
18
-
19
-
20
- logger = logging.getLogger(__name__)
21
-
22
-
23
- @dataclass
24
- class AdaptiveSpanSmallConfig(FairseqDataclass):
25
- # defaults come from https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8_small.sh
26
- vocab_size: int = 50
27
- d_model: int = 256
28
- n_head: int = 4
29
- d_inner: int = 1024
30
- n_layer: int = 8
31
- attn_span: int = 1024
32
- dropout: float = 0.0
33
- emb_dropout: float = 0.0
34
- adapt_span_ramp: int = 32
35
- adapt_span_init: float = 0.0
36
- aux_loss_scaler: float = 0.000002
37
- adapt_span_layer: bool = False
38
-
39
-
40
- @register_model("adaptive_span", dataclass=AdaptiveSpanSmallConfig)
41
- class AdaptiveSpanTransformer(FairseqLanguageModel):
42
- @classmethod
43
- def build_model(cls, cfg: AdaptiveSpanSmallConfig, task):
44
- return cls(AdaptiveSpanDecoder(cfg, task))
45
-
46
- def get_aux_loss(self):
47
- return self.decoder.get_aux_loss()
48
-
49
- def get_current_max_span(self):
50
- return self.decoder.get_current_max_span()
51
-
52
- def get_current_avg_span(self):
53
- return self.decoder.get_current_avg_span()
54
-
55
-
56
- class AdaptiveSpanDecoder(FairseqIncrementalDecoder):
57
- def __init__(self, cfg, task):
58
-
59
- super().__init__(task.target_dictionary)
60
-
61
- self.config = cfg
62
- config = AdaptiveSpanSmallConfig(
63
- vocab_size=len(task.target_dictionary),
64
- d_model=cfg.d_model,
65
- n_head=cfg.n_head,
66
- d_inner=cfg.d_inner,
67
- n_layer=cfg.n_layer,
68
- attn_span=cfg.attn_span,
69
- dropout=cfg.dropout,
70
- emb_dropout=cfg.emb_dropout,
71
- adapt_span_ramp=cfg.adapt_span_ramp,
72
- adapt_span_init=cfg.adapt_span_init,
73
- aux_loss_scaler=cfg.aux_loss_scaler,
74
- adapt_span_layer=cfg.adapt_span_layer,
75
- )
76
- logger.info(config)
77
- self.model = AdaptiveSpanTransformerModel(**config.__dict__)
78
-
79
- self._mems = None
80
-
81
- def forward(
82
- self,
83
- src_tokens,
84
- incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
85
- encoder_out=None,
86
- ):
87
- bsz = src_tokens.size(0)
88
- if incremental_state is not None: # used during inference
89
- mems = self.get_incremental_state("mems")
90
- src_tokens = src_tokens[:, -1:] # only keep the most recent token
91
- else:
92
- mems = self._mems
93
-
94
- if mems is None:
95
- # first time init
96
- mems = self.init_hid_cache(bsz)
97
- output = self.model(x=src_tokens, h_cache=mems,)
98
- if incremental_state is not None:
99
- self.set_incremental_state(incremental_state, "mems", output[1])
100
- else:
101
- self._mems = output[1]
102
- return (output[0],)
103
-
104
- def max_positions(self):
105
- return self.config.attn_span
106
-
107
- def init_hid_cache(self, batch_sz):
108
- hid = []
109
- for layer in self.model.layers:
110
- param = next(self.model.parameters())
111
- h = torch.zeros(
112
- batch_sz,
113
- layer.get_cache_size(),
114
- self.config.d_model,
115
- dtype=param.dtype,
116
- device=param.device,
117
- )
118
- hid.append(h)
119
- return hid
120
-
121
- def get_aux_loss(self):
122
- return self.model.get_aux_loss()
123
-
124
- def get_current_max_span(self):
125
- return self.model.get_current_max_span()
126
-
127
- def get_current_avg_span(self):
128
- return self.model.get_current_avg_span()
129
-
130
- def reorder_incremental_state(
131
- self,
132
- incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]],
133
- new_order: torch.Tensor,
134
- ):
135
- """Reorder incremental state.
136
-
137
- This will be called when the order of the input has changed from the
138
- previous time step. A typical use case is beam search, where the input
139
- order changes between time steps based on the selection of beams.
140
- """
141
- raise NotImplementedError("This is required for generation/beam search")
142
- # mems = self.get_incremental_state(incremental_state, "mems")
143
- # if mems is not None:
144
- # new_mems = [mems_i.index_select(1, new_order) for mems_i in mems]
145
- # self.set_incremental_state(incremental_state, "mems", new_mems)
 
examples/adaptive_span/truncated_bptt_lm_task.py DELETED
@@ -1 +0,0 @@
1
- ../truncated_bptt/truncated_bptt_lm_task.py
 
 
examples/backtranslation/README.md DELETED
@@ -1,297 +0,0 @@
1
- # Understanding Back-Translation at Scale (Edunov et al., 2018)
2
-
3
- This page includes pre-trained models from the paper [Understanding Back-Translation at Scale (Edunov et al., 2018)](https://arxiv.org/abs/1808.09381).
4
-
5
- ## Pre-trained models
6
-
7
- Model | Description | Dataset | Download
8
- ---|---|---|---
9
- `transformer.wmt18.en-de` | Transformer <br> ([Edunov et al., 2018](https://arxiv.org/abs/1808.09381)) <br> WMT'18 winner | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz) <br> See NOTE in the archive
10
-
11
- ## Example usage (torch.hub)
12
-
13
- We require a few additional Python dependencies for preprocessing:
14
- ```bash
15
- pip install subword_nmt sacremoses
16
- ```
17
-
18
- Then to generate translations from the full model ensemble:
19
- ```python
20
- import torch
21
-
22
- # List available models
23
- torch.hub.list('pytorch/fairseq') # [..., 'transformer.wmt18.en-de', ... ]
24
-
25
- # Load the WMT'18 En-De ensemble
26
- en2de_ensemble = torch.hub.load(
27
- 'pytorch/fairseq', 'transformer.wmt18.en-de',
28
- checkpoint_file='wmt18.model1.pt:wmt18.model2.pt:wmt18.model3.pt:wmt18.model4.pt:wmt18.model5.pt',
29
- tokenizer='moses', bpe='subword_nmt')
30
-
31
- # The ensemble contains 5 models
32
- len(en2de_ensemble.models)
33
- # 5
34
-
35
- # Translate
36
- en2de_ensemble.translate('Hello world!')
37
- # 'Hallo Welt!'
38
- ```
39
-
40
- ## Training your own model (WMT'18 English-German)
41
-
42
- The following instructions can be adapted to reproduce the models from the paper.
43
-
44
-
45
- #### Step 1. Prepare parallel data and optionally train a baseline (English-German) model
46
-
47
- First download and preprocess the data:
48
- ```bash
49
- # Download and prepare the data
50
- cd examples/backtranslation/
51
- bash prepare-wmt18en2de.sh
52
- cd ../..
53
-
54
- # Binarize the data
55
- TEXT=examples/backtranslation/wmt18_en_de
56
- fairseq-preprocess \
57
- --joined-dictionary \
58
- --source-lang en --target-lang de \
59
- --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
60
- --destdir data-bin/wmt18_en_de --thresholdtgt 0 --thresholdsrc 0 \
61
- --workers 20
62
-
63
- # Copy the BPE code into the data-bin directory for future use
64
- cp examples/backtranslation/wmt18_en_de/code data-bin/wmt18_en_de/code
65
- ```
66
-
67
- (Optionally) Train a baseline model (English-German) using just the parallel data:
68
- ```bash
69
- CHECKPOINT_DIR=checkpoints_en_de_parallel
70
- fairseq-train --fp16 \
71
- data-bin/wmt18_en_de \
72
- --source-lang en --target-lang de \
73
- --arch transformer_wmt_en_de_big --share-all-embeddings \
74
- --dropout 0.3 --weight-decay 0.0 \
75
- --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
76
- --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
77
- --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
78
- --max-tokens 3584 --update-freq 16 \
79
- --max-update 30000 \
80
- --save-dir $CHECKPOINT_DIR
81
- # Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
82
- # different number of GPUs.
83
- ```
84
-
85
- Average the last 10 checkpoints:
86
- ```bash
87
- python scripts/average_checkpoints.py \
88
- --inputs $CHECKPOINT_DIR \
89
- --num-epoch-checkpoints 10 \
90
- --output $CHECKPOINT_DIR/checkpoint.avg10.pt
91
- ```
92
-
93
- Evaluate BLEU:
94
- ```bash
95
- # tokenized BLEU on newstest2017:
96
- bash examples/backtranslation/tokenized_bleu.sh \
97
- wmt17 \
98
- en-de \
99
- data-bin/wmt18_en_de \
100
- data-bin/wmt18_en_de/code \
101
- $CHECKPOINT_DIR/checkpoint.avg10.pt
102
- # BLEU4 = 29.57, 60.9/35.4/22.9/15.5 (BP=1.000, ratio=1.014, syslen=63049, reflen=62152)
103
- # compare to 29.46 in Table 1, which is also for tokenized BLEU
104
-
105
- # generally it's better to report (detokenized) sacrebleu though:
106
- bash examples/backtranslation/sacrebleu.sh \
107
- wmt17 \
108
- en-de \
109
- data-bin/wmt18_en_de \
110
- data-bin/wmt18_en_de/code \
111
- $CHECKPOINT_DIR/checkpoint.avg10.pt
112
- # BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 29.0 60.6/34.7/22.4/14.9 (BP = 1.000 ratio = 1.013 hyp_len = 62099 ref_len = 61287)
113
- ```
114
-
115
-
116
- #### Step 2. Back-translate monolingual German data
117
-
118
- Train a reverse model (German-English) to do the back-translation:
119
- ```bash
120
- CHECKPOINT_DIR=checkpoints_de_en_parallel
121
- fairseq-train --fp16 \
122
- data-bin/wmt18_en_de \
123
- --source-lang de --target-lang en \
124
- --arch transformer_wmt_en_de_big --share-all-embeddings \
125
- --dropout 0.3 --weight-decay 0.0 \
126
- --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
127
- --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
128
- --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
129
- --max-tokens 3584 --update-freq 16 \
130
- --max-update 30000 \
131
- --save-dir $CHECKPOINT_DIR
132
- # Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
133
- # different number of GPUs.
134
- ```
135
-
136
- Let's evaluate the back-translation (BT) model to make sure it is well trained:
137
- ```bash
138
- bash examples/backtranslation/sacrebleu.sh \
139
- wmt17 \
140
- de-en \
141
- data-bin/wmt18_en_de \
142
- data-bin/wmt18_en_de/code \
143
- $CHECKPOINT_DIR/checkpoint_best.py
144
- # BLEU+case.mixed+lang.de-en+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 34.9 66.9/41.8/28.5/19.9 (BP = 0.983 ratio = 0.984 hyp_len = 63342 ref_len = 64399)
145
- # compare to the best system from WMT'17 which scored 35.1: http://matrix.statmt.org/matrix/systems_list/1868
146
- ```
147
-
148
- Next prepare the monolingual data:
149
- ```bash
150
- # Download and prepare the monolingual data
151
- # By default the script samples 25M monolingual sentences, which after
152
- # deduplication should be just over 24M sentences. These are split into 25
153
- # shards, each with 1M sentences (except for the last shard).
154
- cd examples/backtranslation/
155
- bash prepare-de-monolingual.sh
156
- cd ../..
157
-
158
- # Binarize each shard of the monolingual data
159
- TEXT=examples/backtranslation/wmt18_de_mono
160
- for SHARD in $(seq -f "%02g" 0 24); do \
161
- fairseq-preprocess \
162
- --only-source \
163
- --source-lang de --target-lang en \
164
- --joined-dictionary \
165
- --srcdict data-bin/wmt18_en_de/dict.de.txt \
166
- --testpref $TEXT/bpe.monolingual.dedup.${SHARD} \
167
- --destdir data-bin/wmt18_de_mono/shard${SHARD} \
168
- --workers 20; \
169
- cp data-bin/wmt18_en_de/dict.en.txt data-bin/wmt18_de_mono/shard${SHARD}/; \
170
- done
171
- ```
172
-
173
- Now we're ready to perform back-translation over the monolingual data. The
174
- following command generates via sampling, but it's possible to use greedy
175
- decoding (`--beam 1`), beam search (`--beam 5`),
176
- top-k sampling (`--sampling --beam 1 --sampling-topk 10`), etc.:
177
- ```bash
178
- mkdir backtranslation_output
179
- for SHARD in $(seq -f "%02g" 0 24); do \
180
- fairseq-generate --fp16 \
181
- data-bin/wmt18_de_mono/shard${SHARD} \
182
- --path $CHECKPOINT_DIR/checkpoint_best.pt \
183
- --skip-invalid-size-inputs-valid-test \
184
- --max-tokens 4096 \
185
- --sampling --beam 1 \
186
- > backtranslation_output/sampling.shard${SHARD}.out; \
187
- done
188
- ```
189
-
190
- After BT, use the `extract_bt_data.py` script to re-combine the shards, extract
191
- the back-translations and apply length ratio filters:
192
- ```bash
193
- python examples/backtranslation/extract_bt_data.py \
194
- --minlen 1 --maxlen 250 --ratio 1.5 \
195
- --output backtranslation_output/bt_data --srclang en --tgtlang de \
196
- backtranslation_output/sampling.shard*.out
197
-
198
- # Ensure lengths are the same:
199
- # wc -l backtranslation_output/bt_data.{en,de}
200
- # 21795614 backtranslation_output/bt_data.en
201
- # 21795614 backtranslation_output/bt_data.de
202
- # 43591228 total
203
- ```
204
-
205
- Binarize the filtered BT data and combine it with the parallel data:
206
- ```bash
207
- TEXT=backtranslation_output
208
- fairseq-preprocess \
209
- --source-lang en --target-lang de \
210
- --joined-dictionary \
211
- --srcdict data-bin/wmt18_en_de/dict.en.txt \
212
- --trainpref $TEXT/bt_data \
213
- --destdir data-bin/wmt18_en_de_bt \
214
- --workers 20
215
-
216
- # We want to train on the combined data, so we'll symlink the parallel + BT data
217
- # in the wmt18_en_de_para_plus_bt directory. We link the parallel data as "train"
218
- # and the BT data as "train1", so that fairseq will combine them automatically
219
- # and so that we can use the `--upsample-primary` option to upsample the
220
- # parallel data (if desired).
221
- PARA_DATA=$(readlink -f data-bin/wmt18_en_de)
222
- BT_DATA=$(readlink -f data-bin/wmt18_en_de_bt)
223
- COMB_DATA=data-bin/wmt18_en_de_para_plus_bt
224
- mkdir -p $COMB_DATA
225
- for LANG in en de; do \
226
- ln -s ${PARA_DATA}/dict.$LANG.txt ${COMB_DATA}/dict.$LANG.txt; \
227
- for EXT in bin idx; do \
228
- ln -s ${PARA_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train.en-de.$LANG.$EXT; \
229
- ln -s ${BT_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train1.en-de.$LANG.$EXT; \
230
- ln -s ${PARA_DATA}/valid.en-de.$LANG.$EXT ${COMB_DATA}/valid.en-de.$LANG.$EXT; \
231
- ln -s ${PARA_DATA}/test.en-de.$LANG.$EXT ${COMB_DATA}/test.en-de.$LANG.$EXT; \
232
- done; \
233
- done
234
- ```
235
-
236
-
237
- #### 3. Train an English-German model over the combined parallel + BT data
238
-
239
- Finally we can train a model over the parallel + BT data:
240
- ```bash
241
- CHECKPOINT_DIR=checkpoints_en_de_parallel_plus_bt
242
- fairseq-train --fp16 \
243
- data-bin/wmt18_en_de_para_plus_bt \
244
- --upsample-primary 16 \
245
- --source-lang en --target-lang de \
246
- --arch transformer_wmt_en_de_big --share-all-embeddings \
247
- --dropout 0.3 --weight-decay 0.0 \
248
- --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
249
- --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
250
- --lr 0.0007 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
251
- --max-tokens 3584 --update-freq 16 \
252
- --max-update 100000 \
253
- --save-dir $CHECKPOINT_DIR
254
- # Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
255
- # different number of GPUs.
256
- ```
257
-
258
- Average the last 10 checkpoints:
259
- ```bash
260
- python scripts/average_checkpoints.py \
261
- --inputs $CHECKPOINT_DIR \
262
- --num-epoch-checkpoints 10 \
263
- --output $CHECKPOINT_DIR/checkpoint.avg10.pt
264
- ```
265
-
266
- Evaluate BLEU:
267
- ```bash
268
- # tokenized BLEU on newstest2017:
269
- bash examples/backtranslation/tokenized_bleu.sh \
270
- wmt17 \
271
- en-de \
272
- data-bin/wmt18_en_de \
273
- data-bin/wmt18_en_de/code \
274
- $CHECKPOINT_DIR/checkpoint.avg10.pt
275
- # BLEU4 = 32.35, 64.4/38.9/26.2/18.3 (BP=0.977, ratio=0.977, syslen=60729, reflen=62152)
276
- # compare to 32.35 in Table 1, which is also for tokenized BLEU
277
-
278
- # generally it's better to report (detokenized) sacrebleu:
279
- bash examples/backtranslation/sacrebleu.sh \
280
- wmt17 \
281
- en-de \
282
- data-bin/wmt18_en_de \
283
- data-bin/wmt18_en_de/code \
284
- $CHECKPOINT_DIR/checkpoint.avg10.pt
285
- # BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 31.5 64.3/38.2/25.6/17.6 (BP = 0.971 ratio = 0.971 hyp_len = 59515 ref_len = 61287)
286
- ```
287
-
288
-
289
- ## Citation
290
- ```bibtex
291
- @inproceedings{edunov2018backtranslation,
292
- title = {Understanding Back-Translation at Scale},
293
- author = {Edunov, Sergey and Ott, Myle and Auli, Michael and Grangier, David},
294
- booktitle = {Conference of the Association for Computational Linguistics (ACL)},
295
- year = 2018,
296
- }
297
- ```
 
examples/backtranslation/deduplicate_lines.py DELETED
@@ -1,41 +0,0 @@
1
- #!/usr/bin/python3
2
- # Copyright (c) Facebook, Inc. and its affiliates.
3
- #
4
- # This source code is licensed under the MIT license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import argparse
8
- import fileinput
9
- import hashlib
10
- import sys
11
- from multiprocessing import Pool
12
-
13
-
14
- def get_hashes_and_lines(raw_line):
15
- hash = hashlib.md5(raw_line).hexdigest()
16
- return hash, raw_line
17
-
18
-
19
- def main():
20
- parser = argparse.ArgumentParser()
21
- parser.add_argument("--workers", type=int, default=10)
22
- parser.add_argument("files", nargs="*", help="input files")
23
- args = parser.parse_args()
24
-
25
- seen = set()
26
- with fileinput.input(args.files, mode="rb") as h:
27
- pool = Pool(args.workers)
28
- results = pool.imap_unordered(get_hashes_and_lines, h, 1000)
29
- for i, (hash, raw_line) in enumerate(results):
30
- if hash not in seen:
31
- seen.add(hash)
32
- sys.stdout.buffer.write(raw_line)
33
- if i % 1000000 == 0:
34
- print(i, file=sys.stderr, end="", flush=True)
35
- elif i % 100000 == 0:
36
- print(".", file=sys.stderr, end="", flush=True)
37
- print(file=sys.stderr, flush=True)
38
-
39
-
40
- if __name__ == "__main__":
41
- main()
 
examples/backtranslation/extract_bt_data.py DELETED
@@ -1,72 +0,0 @@
1
- #!/usr/bin/env python
2
- # Copyright (c) Facebook, Inc. and its affiliates.
3
- #
4
- # This source code is licensed under the MIT license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import argparse
8
- import fileinput
9
-
10
- from tqdm import tqdm
11
-
12
-
13
- def main():
14
- parser = argparse.ArgumentParser(
15
- description=(
16
- "Extract back-translations from the stdout of fairseq-generate. "
17
- "If there are multiply hypotheses for a source, we only keep the first one. "
18
- )
19
- )
20
- parser.add_argument("--output", required=True, help="output prefix")
21
- parser.add_argument(
22
- "--srclang", required=True, help="source language (extracted from H-* lines)"
23
- )
24
- parser.add_argument(
25
- "--tgtlang", required=True, help="target language (extracted from S-* lines)"
26
- )
27
- parser.add_argument("--minlen", type=int, help="min length filter")
28
- parser.add_argument("--maxlen", type=int, help="max length filter")
29
- parser.add_argument("--ratio", type=float, help="ratio filter")
30
- parser.add_argument("files", nargs="*", help="input files")
31
- args = parser.parse_args()
32
-
33
- def validate(src, tgt):
34
- srclen = len(src.split(" ")) if src != "" else 0
35
- tgtlen = len(tgt.split(" ")) if tgt != "" else 0
36
- if (
37
- (args.minlen is not None and (srclen < args.minlen or tgtlen < args.minlen))
38
- or (
39
- args.maxlen is not None
40
- and (srclen > args.maxlen or tgtlen > args.maxlen)
41
- )
42
- or (
43
- args.ratio is not None
44
- and (max(srclen, tgtlen) / float(min(srclen, tgtlen)) > args.ratio)
45
- )
46
- ):
47
- return False
48
- return True
49
-
50
- def safe_index(toks, index, default):
51
- try:
52
- return toks[index]
53
- except IndexError:
54
- return default
55
-
56
- with open(args.output + "." + args.srclang, "w") as src_h, open(
57
- args.output + "." + args.tgtlang, "w"
58
- ) as tgt_h:
59
- for line in tqdm(fileinput.input(args.files)):
60
- if line.startswith("S-"):
61
- tgt = safe_index(line.rstrip().split("\t"), 1, "")
62
- elif line.startswith("H-"):
63
- if tgt is not None:
64
- src = safe_index(line.rstrip().split("\t"), 2, "")
65
- if validate(src, tgt):
66
- print(src, file=src_h)
67
- print(tgt, file=tgt_h)
68
- tgt = None
69
-
70
-
71
- if __name__ == "__main__":
72
- main()
 
examples/backtranslation/prepare-de-monolingual.sh DELETED
@@ -1,98 +0,0 @@
1
- #!/bin/bash
2
-
3
- SCRIPTS=mosesdecoder/scripts
4
- TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
5
- NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
6
- REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
7
- BPEROOT=subword-nmt/subword_nmt
8
-
9
-
10
- BPE_CODE=wmt18_en_de/code
11
- SUBSAMPLE_SIZE=25000000
12
- LANG=de
13
-
14
-
15
- OUTDIR=wmt18_${LANG}_mono
16
- orig=orig
17
- tmp=$OUTDIR/tmp
18
- mkdir -p $OUTDIR $tmp
19
-
20
-
21
- URLS=(
22
- "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2007.de.shuffled.gz"
23
- "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2008.de.shuffled.gz"
24
- "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2009.de.shuffled.gz"
25
- "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2010.de.shuffled.gz"
26
- "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2011.de.shuffled.gz"
27
- "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2012.de.shuffled.gz"
28
- "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2013.de.shuffled.gz"
29
- "http://www.statmt.org/wmt15/training-monolingual-news-crawl-v2/news.2014.de.shuffled.v2.gz"
30
- "http://data.statmt.org/wmt16/translation-task/news.2015.de.shuffled.gz"
31
- "http://data.statmt.org/wmt17/translation-task/news.2016.de.shuffled.gz"
32
- "http://data.statmt.org/wmt18/translation-task/news.2017.de.shuffled.deduped.gz"
33
- )
34
- FILES=(
35
- "news.2007.de.shuffled.gz"
36
- "news.2008.de.shuffled.gz"
37
- "news.2009.de.shuffled.gz"
38
- "news.2010.de.shuffled.gz"
39
- "news.2011.de.shuffled.gz"
40
- "news.2012.de.shuffled.gz"
41
- "news.2013.de.shuffled.gz"
42
- "news.2014.de.shuffled.v2.gz"
43
- "news.2015.de.shuffled.gz"
44
- "news.2016.de.shuffled.gz"
45
- "news.2017.de.shuffled.deduped.gz"
46
- )
47
-
48
-
49
- cd $orig
50
- for ((i=0;i<${#URLS[@]};++i)); do
51
- file=${FILES[i]}
52
- if [ -f $file ]; then
53
- echo "$file already exists, skipping download"
54
- else
55
- url=${URLS[i]}
56
- wget "$url"
57
- fi
58
- done
59
- cd ..
60
-
61
-
62
- if [ -f $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG} ]; then
63
- echo "found monolingual sample, skipping shuffle/sample/tokenize"
64
- else
65
- gzip -c -d -k $(for FILE in "${FILES[@]}"; do echo $orig/$FILE; done) \
66
- | shuf -n $SUBSAMPLE_SIZE \
67
- | perl $NORM_PUNC $LANG \
68
- | perl $REM_NON_PRINT_CHAR \
69
- | perl $TOKENIZER -threads 8 -a -l $LANG \
70
- > $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG}
71
- fi
72
-
73
-
74
- if [ -f $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG} ]; then
75
- echo "found BPE monolingual sample, skipping BPE step"
76
- else
77
- python $BPEROOT/apply_bpe.py -c $BPE_CODE \
78
- < $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG} \
79
- > $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG}
80
- fi
81
-
82
-
83
- if [ -f $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG} ]; then
84
- echo "found deduplicated monolingual sample, skipping deduplication step"
85
- else
86
- python deduplicate_lines.py $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG} \
87
- > $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG}
88
- fi
89
-
90
-
91
- if [ -f $OUTDIR/bpe.monolingual.dedup.00.de ]; then
92
- echo "found sharded data, skipping sharding step"
93
- else
94
- split --lines 1000000 --numeric-suffixes \
95
- --additional-suffix .${LANG} \
96
- $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG} \
97
- $OUTDIR/bpe.monolingual.dedup.
98
- fi
 
examples/backtranslation/prepare-wmt18en2de.sh DELETED
@@ -1,135 +0,0 @@
1
- #!/bin/bash
2
- # Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh
3
-
4
- echo 'Cloning Moses github repository (for tokenization scripts)...'
5
- git clone https://github.com/moses-smt/mosesdecoder.git
6
-
7
- echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
8
- git clone https://github.com/rsennrich/subword-nmt.git
9
-
10
- SCRIPTS=mosesdecoder/scripts
11
- TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
12
- CLEAN=$SCRIPTS/training/clean-corpus-n.perl
13
- NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
14
- REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
15
- BPEROOT=subword-nmt/subword_nmt
16
- BPE_TOKENS=32000
17
-
18
- URLS=(
19
- "http://statmt.org/wmt13/training-parallel-europarl-v7.tgz"
20
- "http://statmt.org/wmt13/training-parallel-commoncrawl.tgz"
21
- "http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz"
22
- "http://data.statmt.org/wmt18/translation-task/rapid2016.tgz"
23
- "http://data.statmt.org/wmt17/translation-task/dev.tgz"
24
- "http://statmt.org/wmt14/test-full.tgz"
25
- )
26
- FILES=(
27
- "training-parallel-europarl-v7.tgz"
28
- "training-parallel-commoncrawl.tgz"
29
- "training-parallel-nc-v13.tgz"
30
- "rapid2016.tgz"
31
- "dev.tgz"
32
- "test-full.tgz"
33
- )
34
- CORPORA=(
35
- "training/europarl-v7.de-en"
36
- "commoncrawl.de-en"
37
- "training-parallel-nc-v13/news-commentary-v13.de-en"
38
- "rapid2016.de-en"
39
- )
40
-
41
- if [ ! -d "$SCRIPTS" ]; then
42
- echo "Please set SCRIPTS variable correctly to point to Moses scripts."
43
- exit 1
44
- fi
45
-
46
- OUTDIR=wmt18_en_de
47
-
48
- src=en
49
- tgt=de
50
- lang=en-de
51
- prep=$OUTDIR
52
- tmp=$prep/tmp
53
- orig=orig
54
-
55
- mkdir -p $orig $tmp $prep
56
-
57
- cd $orig
58
-
59
- for ((i=0;i<${#URLS[@]};++i)); do
60
- file=${FILES[i]}
61
- if [ -f $file ]; then
62
- echo "$file already exists, skipping download"
63
- else
64
- url=${URLS[i]}
65
- wget "$url"
66
- if [ -f $file ]; then
67
- echo "$url successfully downloaded."
68
- else
69
- echo "$url not successfully downloaded."
70
- exit 1
71
- fi
72
- if [ ${file: -4} == ".tgz" ]; then
73
- tar zxvf $file
74
- elif [ ${file: -4} == ".tar" ]; then
75
- tar xvf $file
76
- fi
77
- fi
78
- done
79
- cd ..
80
-
81
- echo "pre-processing train data..."
82
- for l in $src $tgt; do
83
- rm $tmp/train.tags.$lang.tok.$l
84
- for f in "${CORPORA[@]}"; do
85
- cat $orig/$f.$l | \
86
- perl $NORM_PUNC $l | \
87
- perl $REM_NON_PRINT_CHAR | \
88
- perl $TOKENIZER -threads 8 -a -l $l >> $tmp/train.tags.$lang.tok.$l
89
- done
90
- done
91
-
92
- echo "pre-processing test data..."
93
- for l in $src $tgt; do
94
- if [ "$l" == "$src" ]; then
95
- t="src"
96
- else
97
- t="ref"
98
- fi
99
- grep '<seg id' $orig/test-full/newstest2014-deen-$t.$l.sgm | \
100
- sed -e 's/<seg id="[0-9]*">\s*//g' | \
101
- sed -e 's/\s*<\/seg>\s*//g' | \
102
- sed -e "s/\’/\'/g" | \
103
- perl $TOKENIZER -threads 8 -a -l $l > $tmp/test.$l
104
- echo ""
105
- done
106
-
107
- echo "splitting train and valid..."
108
- for l in $src $tgt; do
109
- awk '{if (NR%100 == 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l
110
- awk '{if (NR%100 != 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l
111
- done
112
-
113
- TRAIN=$tmp/train.de-en
114
- BPE_CODE=$prep/code
115
- rm -f $TRAIN
116
- for l in $src $tgt; do
117
- cat $tmp/train.$l >> $TRAIN
118
- done
119
-
120
- echo "learn_bpe.py on ${TRAIN}..."
121
- python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE
122
-
123
- for L in $src $tgt; do
124
- for f in train.$L valid.$L test.$L; do
125
- echo "apply_bpe.py to ${f}..."
126
- python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $tmp/bpe.$f
127
- done
128
- done
129
-
130
- perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250
131
- perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250
132
-
133
- for L in $src $tgt; do
134
- cp $tmp/bpe.test.$L $prep/test.$L
135
- done
 
examples/backtranslation/sacrebleu.sh DELETED
@@ -1,37 +0,0 @@
1
- #!/bin/bash
2
-
3
- if [ $# -ne 5 ]; then
4
- echo "usage: $0 [dataset=wmt14/full] [langpair=en-de] [databin] [bpecode] [model]"
5
- exit
6
- fi
7
-
8
-
9
- DATASET=$1
10
- LANGPAIR=$2
11
- DATABIN=$3
12
- BPECODE=$4
13
- MODEL=$5
14
-
15
- SRCLANG=$(echo $LANGPAIR | cut -d '-' -f 1)
16
- TGTLANG=$(echo $LANGPAIR | cut -d '-' -f 2)
17
-
18
-
19
- BPEROOT=examples/backtranslation/subword-nmt/subword_nmt
20
- if [ ! -e $BPEROOT ]; then
21
- BPEROOT=subword-nmt/subword_nmt
22
- if [ ! -e $BPEROOT ]; then
23
- echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
24
- git clone https://github.com/rsennrich/subword-nmt.git
25
- fi
26
- fi
27
-
28
-
29
- sacrebleu -t $DATASET -l $LANGPAIR --echo src \
30
- | sacremoses tokenize -a -l $SRCLANG -q \
31
- | python $BPEROOT/apply_bpe.py -c $BPECODE \
32
- | fairseq-interactive $DATABIN --path $MODEL \
33
- -s $SRCLANG -t $TGTLANG \
34
- --beam 5 --remove-bpe --buffer-size 1024 --max-tokens 8000 \
35
- | grep ^H- | cut -f 3- \
36
- | sacremoses detokenize -l $TGTLANG -q \
37
- | sacrebleu -t $DATASET -l $LANGPAIR
 
examples/backtranslation/tokenized_bleu.sh DELETED
@@ -1,46 +0,0 @@
1
- #!/bin/bash
2
-
3
- if [ $# -ne 5 ]; then
4
- echo "usage: $0 [dataset=wmt14/full] [langpair=en-de] [databin] [bpecode] [model]"
5
- exit
6
- fi
7
-
8
-
9
- DATASET=$1
10
- LANGPAIR=$2
11
- DATABIN=$3
12
- BPECODE=$4
13
- MODEL=$5
14
-
15
- SRCLANG=$(echo $LANGPAIR | cut -d '-' -f 1)
16
- TGTLANG=$(echo $LANGPAIR | cut -d '-' -f 2)
17
-
18
-
19
- BPEROOT=examples/backtranslation/subword-nmt/subword_nmt
20
- if [ ! -e $BPEROOT ]; then
21
- BPEROOT=subword-nmt/subword_nmt
22
- if [ ! -e $BPEROOT ]; then
23
- echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
24
- git clone https://github.com/rsennrich/subword-nmt.git
25
- fi
26
- fi
27
-
28
-
29
- TMP_REF=$(mktemp)
30
-
31
- sacrebleu -t $DATASET -l $LANGPAIR --echo ref -q \
32
- | sacremoses normalize -l $TGTLANG -q \
33
- | sacremoses tokenize -a -l $TGTLANG -q \
34
- > $TMP_REF
35
-
36
- sacrebleu -t $DATASET -l $LANGPAIR --echo src -q \
37
- | sacremoses normalize -l $SRCLANG -q \
38
- | sacremoses tokenize -a -l $SRCLANG -q \
39
- | python $BPEROOT/apply_bpe.py -c $BPECODE \
40
- | fairseq-interactive $DATABIN --path $MODEL \
41
- -s $SRCLANG -t $TGTLANG \
42
- --beam 5 --remove-bpe --buffer-size 1024 --max-tokens 8000 \
43
- | grep ^H- | cut -f 3- \
44
- | fairseq-score --ref $TMP_REF
45
-
46
- rm -f $TMP_REF
 
examples/bart/README.glue.md DELETED
@@ -1,99 +0,0 @@
1
- # Fine-tuning BART on GLUE tasks
2
-
3
- ### 1) Download the data from GLUE website (https://gluebenchmark.com/tasks) using following commands:
4
- ```bash
5
- wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py
6
- python download_glue_data.py --data_dir glue_data --tasks all
7
- ```
8
-
9
- ### 2) Preprocess GLUE task data (same as RoBERTa):
10
- ```bash
11
- ./examples/roberta/preprocess_GLUE_tasks.sh glue_data <glue_task_name>
12
- ```
13
- `glue_task_name` is one of the following:
14
- `{ALL, QQP, MNLI, QNLI, MRPC, RTE, STS-B, SST-2, CoLA}`
15
- Use `ALL` for preprocessing all the glue tasks.
16
-
17
- ### 3) Fine-tuning on GLUE task:
18
- Example fine-tuning cmd for `RTE` task
19
- ```bash
20
- TOTAL_NUM_UPDATES=2036 # 10 epochs through RTE for bsz 16
21
- WARMUP_UPDATES=61 # 6 percent of the number of updates
22
- LR=1e-05 # Peak LR for polynomial LR scheduler.
23
- NUM_CLASSES=2
24
- MAX_SENTENCES=16 # Batch size.
25
- BART_PATH=/path/to/bart/model.pt
26
-
27
- CUDA_VISIBLE_DEVICES=0,1 fairseq-train RTE-bin/ \
28
- --restore-file $BART_PATH \
29
- --batch-size $MAX_SENTENCES \
30
- --max-tokens 4400 \
31
- --task sentence_prediction \
32
- --add-prev-output-tokens \
33
- --layernorm-embedding \
34
- --share-all-embeddings \
35
- --share-decoder-input-output-embed \
36
- --reset-optimizer --reset-dataloader --reset-meters \
37
- --required-batch-size-multiple 1 \
38
- --init-token 0 \
39
- --arch bart_large \
40
- --criterion sentence_prediction \
41
- --num-classes $NUM_CLASSES \
42
- --dropout 0.1 --attention-dropout 0.1 \
43
- --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 \
44
- --clip-norm 0.0 \
45
- --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
46
- --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
47
- --max-epoch 10 \
48
- --find-unused-parameters \
49
- --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric;
50
- ```
51
-
52
- For each of the GLUE task, you will need to use following cmd-line arguments:
53
-
54
- Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
55
- ---|---|---|---|---|---|---|---|---
56
- `--num-classes` | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 1
57
- `--lr` | 5e-6 | 1e-5 | 1e-5 | 1e-5 | 5e-6 | 2e-5 | 2e-5 | 2e-5
58
- `bsz` | 128 | 32 | 32 | 32 | 128 | 64 | 64 | 32
59
- `--total-num-update` | 30968 | 33112 | 113272 | 1018 | 5233 | 1148 | 1334 | 1799
60
- `--warmup-updates` | 1858 | 1986 | 6796 | 61 | 314 | 68 | 80 | 107
61
-
62
- For `STS-B` additionally add `--regression-target --best-checkpoint-metric loss` and remove `--maximize-best-checkpoint-metric`.
63
-
64
- **Note:**
65
-
66
- a) `--total-num-update` is used by the `polynomial_decay` LR scheduler and is calculated for `--max-epoch=10` and `--batch-size=32/64/128`, depending on the task.
67
-
68
- b) The above cmd-args and hyperparams were tested on an Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory available to you, you can increase `--update-freq` and reduce `--batch-size`.
69
-
70
- ### Inference on GLUE task
71
- After training the model as described in the previous step, you can perform inference with the checkpoints in the `checkpoints/` directory using the following Python code snippet:
72
-
73
- ```python
74
- from fairseq.models.bart import BARTModel
75
-
76
- bart = BARTModel.from_pretrained(
77
- 'checkpoints/',
78
- checkpoint_file='checkpoint_best.pt',
79
- data_name_or_path='RTE-bin'
80
- )
81
-
82
- label_fn = lambda label: bart.task.label_dictionary.string(
83
- [label + bart.task.label_dictionary.nspecial]
84
- )
85
- ncorrect, nsamples = 0, 0
86
- bart.cuda()
87
- bart.eval()
88
- with open('glue_data/RTE/dev.tsv') as fin:
89
- fin.readline()
90
- for index, line in enumerate(fin):
91
- tokens = line.strip().split('\t')
92
- sent1, sent2, target = tokens[1], tokens[2], tokens[3]
93
- tokens = bart.encode(sent1, sent2)
94
- prediction = bart.predict('sentence_classification_head', tokens).argmax().item()
95
- prediction_label = label_fn(prediction)
96
- ncorrect += int(prediction_label == target)
97
- nsamples += 1
98
- print('| Accuracy: ', float(ncorrect)/float(nsamples))
99
- ```
 
examples/bart/README.md DELETED
@@ -1,228 +0,0 @@
1
- # BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension
2
-
3
- [https://arxiv.org/abs/1910.13461](https://arxiv.org/abs/1910.13461)
4
-
5
- ## Introduction
6
-
7
- BART is a sequence-to-sequence model trained with a denoising pretraining objective. We show that this pretraining objective is more generic: BART matches [RoBERTa](../roberta) results on SQuAD and GLUE and achieves state-of-the-art results on summarization (XSum, CNN dataset), long-form generative question answering (ELI5) and dialog response generation (ConvAI2). See the associated paper for more details.
8
-
9
- ## Pre-trained models
10
-
11
- Model | Description | # params | Download
12
- ---|---|---|---
13
- `bart.base` | BART model with 6 encoder and decoder layers | 140M | [bart.base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz)
14
- `bart.large` | BART model with 12 encoder and decoder layers | 400M | [bart.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz)
15
- `bart.large.mnli` | `bart.large` finetuned on `MNLI` | 400M | [bart.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz)
16
- `bart.large.cnn` | `bart.large` finetuned on `CNN-DM` | 400M | [bart.large.cnn.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz)
17
- `bart.large.xsum` | `bart.large` finetuned on `Xsum` | 400M | [bart.large.xsum.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.xsum.tar.gz)
18
-
19
- ## Results
20
-
21
- **[GLUE (Wang et al., 2019)](https://gluebenchmark.com/)**
22
- _(dev set, single model, single-task finetuning)_
23
-
24
- Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
25
- ---|---|---|---|---|---|---|---|---
26
- `roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
27
- `bart.large` | 89.9 | 94.9 | 92.5 | 87.0 | 96.6 | 90.4 | 62.8 | 91.2
28
-
29
- **[SQuAD (Rajpurkar et al., 2018)](https://rajpurkar.github.io/SQuAD-explorer/)**
30
- _(dev set, no additional data used)_
31
-
32
- Model | SQuAD 1.1 EM/F1 | SQuAD 2.0 EM/F1
33
- ---|---|---
34
- `roberta.large` | 88.9/94.6 | 86.5/89.4
35
- `bart.large` | 88.8/94.6 | 86.1/89.2
36
-
37
- **[CNN/Daily Mail](http://nlpprogress.com/english/summarization.html)**
38
- _(test set, no additional data used)_
39
-
40
- Model | R1 | R2 | RL
41
- ---|---|---|---
42
- `BERTSUMEXTABS` | 42.13 | 19.60 | 39.18
43
- `bart.large` | 44.16 | 21.28 | 40.90
44
-
45
- ## Example usage
46
-
47
- ##### Load BART from torch.hub (PyTorch >= 1.1):
48
- ```python
49
- import torch
50
- bart = torch.hub.load('pytorch/fairseq', 'bart.large')
51
- bart.eval() # disable dropout (or leave in train mode to finetune)
52
- ```
53
-
54
- ##### Load BART (for PyTorch 1.0 or custom models):
55
- ```python
56
- # Download bart.large model
57
- wget https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz
58
- tar -xzvf bart.large.tar.gz
59
-
60
- # Load the model in fairseq
61
- from fairseq.models.bart import BARTModel
62
- bart = BARTModel.from_pretrained('/path/to/bart.large', checkpoint_file='model.pt')
63
- bart.eval() # disable dropout (or leave in train mode to finetune)
64
- ```
65
-
66
- ##### Apply Byte-Pair Encoding (BPE) to input text:
67
- ```python
68
- tokens = bart.encode('Hello world!')
69
- assert tokens.tolist() == [0, 31414, 232, 328, 2]
70
- bart.decode(tokens) # 'Hello world!'
71
- ```
72
-
73
- ##### Extract features from BART:
74
- ```python
75
- # Extract the last layer's features
76
- last_layer_features = bart.extract_features(tokens)
77
- assert last_layer_features.size() == torch.Size([1, 5, 1024])
78
-
79
- # Extract all layer's features from decoder (layer 0 is the embedding layer)
80
- all_layers = bart.extract_features(tokens, return_all_hiddens=True)
81
- assert len(all_layers) == 13
82
- assert torch.all(all_layers[-1] == last_layer_features)
83
- ```
84
-
85
- ##### Use BART for sentence-pair classification tasks:
86
- ```python
87
- # Download BART already finetuned for MNLI
88
- bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
89
- bart.eval() # disable dropout for evaluation
90
-
91
- # Encode a pair of sentences and make a prediction
92
- tokens = bart.encode('BART is a seq2seq model.', 'BART is not sequence to sequence.')
93
- bart.predict('mnli', tokens).argmax() # 0: contradiction
94
-
95
- # Encode another pair of sentences
96
- tokens = bart.encode('BART is denoising autoencoder.', 'BART is version of autoencoder.')
97
- bart.predict('mnli', tokens).argmax() # 2: entailment
98
- ```
99
-
100
- ##### Register a new (randomly initialized) classification head:
101
- ```python
102
- bart.register_classification_head('new_task', num_classes=3)
103
- logprobs = bart.predict('new_task', tokens)
104
- ```
105
-
106
- ##### Batched prediction:
107
- ```python
108
- import torch
109
- from fairseq.data.data_utils import collate_tokens
110
-
111
- bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
112
- bart.eval()
113
-
114
- batch_of_pairs = [
115
- ['BART is a seq2seq model.', 'BART is not sequence to sequence.'],
116
- ['BART is denoising autoencoder.', 'BART is version of autoencoder.'],
117
- ]
118
-
119
- batch = collate_tokens(
120
- [bart.encode(pair[0], pair[1]) for pair in batch_of_pairs], pad_idx=1
121
- )
122
-
123
- logprobs = bart.predict('mnli', batch)
124
- print(logprobs.argmax(dim=1))
125
- # tensor([0, 2])
126
- ```
127
-
128
- ##### Using the GPU:
129
- ```python
130
- bart.cuda()
131
- bart.predict('new_task', tokens)
132
- ```
133
-
134
- #### Filling masks:
135
-
136
- BART can be used to fill multiple `<mask>` tokens in the input.
137
- ```python
138
- bart = torch.hub.load('pytorch/fairseq', 'bart.base')
139
- bart.eval()
140
- bart.fill_mask(['The cat <mask> on the <mask>.'], topk=3, beam=10)
141
- # [[('The cat was on the ground.', tensor(-0.6183)), ('The cat was on the floor.', tensor(-0.6798)), ('The cat sleeps on the couch.', tensor(-0.6830))]]
142
- ```
143
-
144
- Note that by default we enforce the output length to match the input length.
145
- This can be disabled by setting ``match_source_len=False``:
146
- ```
147
- bart.fill_mask(['The cat <mask> on the <mask>.'], topk=3, beam=10, match_source_len=False)
148
- # [[('The cat was on the ground.', tensor(-0.6185)), ('The cat was asleep on the couch.', tensor(-0.6276)), ('The cat was on the floor.', tensor(-0.6800))]]
149
- ```
150
-
151
- Example code to fill masks for a batch of sentences using GPU
152
- ```
153
- bart.cuda()
154
- bart.fill_mask(['The cat <mask> on the <mask>.', 'The dog <mask> on the <mask>.'], topk=3, beam=10)
155
- # [[('The cat was on the ground.', tensor(-0.6183)), ('The cat was on the floor.', tensor(-0.6798)), ('The cat sleeps on the couch.', tensor(-0.6830))], [('The dog was on the ground.', tensor(-0.6190)), ('The dog lay on the ground.', tensor(-0.6711)),
156
- ('The dog was asleep on the couch', tensor(-0.6796))]]
157
- ```
158
-
159
- #### Evaluating the `bart.large.mnli` model:
160
-
161
- Example python code snippet to evaluate accuracy on the MNLI `dev_matched` set.
162
- ```python
163
- label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
164
- ncorrect, nsamples = 0, 0
165
- bart.cuda()
166
- bart.eval()
167
- with open('glue_data/MNLI/dev_matched.tsv') as fin:
168
- fin.readline()
169
- for index, line in enumerate(fin):
170
- tokens = line.strip().split('\t')
171
- sent1, sent2, target = tokens[8], tokens[9], tokens[-1]
172
- tokens = bart.encode(sent1, sent2)
173
- prediction = bart.predict('mnli', tokens).argmax().item()
174
- prediction_label = label_map[prediction]
175
- ncorrect += int(prediction_label == target)
176
- nsamples += 1
177
- print('| Accuracy: ', float(ncorrect)/float(nsamples))
178
- # Expected output: 0.9010
179
- ```
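-
- The loop above scores one pair at a time. A batched variant is usually faster on GPU; the following sketch reuses `collate_tokens` from the batched-prediction example above (the batch size of 32 is an arbitrary choice):
- ```python
- from fairseq.data.data_utils import collate_tokens
-
- label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
- ncorrect, nsamples = 0, 0
- pairs, targets = [], []
-
- def score_batch(pairs, targets):
-     # pad_idx=1 matches the batched-prediction example above
-     batch = collate_tokens(pairs, pad_idx=1)
-     preds = bart.predict('mnli', batch).argmax(dim=1).tolist()
-     return sum(int(label_map[p] == t) for p, t in zip(preds, targets))
-
- with open('glue_data/MNLI/dev_matched.tsv') as fin:
-     fin.readline()  # skip the TSV header
-     for line in fin:
-         cols = line.strip().split('\t')
-         pairs.append(bart.encode(cols[8], cols[9]))
-         targets.append(cols[-1])
-         if len(pairs) == 32:
-             ncorrect += score_batch(pairs, targets)
-             nsamples += len(pairs)
-             pairs, targets = [], []
-     if pairs:  # score the last partial batch
-         ncorrect += score_batch(pairs, targets)
-         nsamples += len(pairs)
- print('| Accuracy: ', float(ncorrect) / float(nsamples))
- ```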
180
-
181
- #### Evaluating the `bart.large.cnn` model:
182
- - Follow the instructions [here](https://github.com/abisee/cnn-dailymail) to download and process the data into files such that `test.source` and `test.target` each have one line per non-tokenized sample.
183
- - For simpler preprocessing, you can also `wget https://cdn-datasets.huggingface.co/summarization/cnn_dm_v2.tgz`, although identical scores are not guaranteed.
184
- - `huggingface/transformers` has a simpler interface that supports [single-gpu](https://github.com/huggingface/transformers/blob/master/examples/legacy/seq2seq/run_eval.py) and [multi-gpu](https://github.com/huggingface/transformers/blob/master/examples/legacy/seq2seq/run_distributed_eval.py) beam search.
185
- In `huggingface/transformers`, the BART models' paths are `facebook/bart-large-cnn` and `facebook/bart-large-xsum`.
186
-
187
- In `fairseq`, summaries can be generated using:
188
-
189
- ```bash
190
- cp data-bin/cnn_dm/dict.source.txt checkpoints/
191
- python examples/bart/summarize.py \
192
- --model-dir pytorch/fairseq \
193
- --model-file bart.large.cnn \
194
- --src cnn_dm/test.source \
195
- --out cnn_dm/test.hypo
196
- ```
197
-
198
- To calculate ROUGE, install `files2rouge` from [here](https://github.com/pltrdy/files2rouge).
199
-
200
- ```bash
201
- export CLASSPATH=/path/to/stanford-corenlp-full-2016-10-31/stanford-corenlp-3.7.0.jar
202
-
203
- # Tokenize hypothesis and target files.
204
- cat test.hypo | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > test.hypo.tokenized
205
- cat test.target | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > test.hypo.target
206
- files2rouge test.hypo.tokenized test.hypo.target
207
- # Expected output: (ROUGE-2 Average_F: 0.21238)
208
- ```
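-
- If installing Stanford CoreNLP and `files2rouge` is inconvenient, an approximate sanity check can be run in Python with the `rouge-score` package (`pip install rouge-score`); note that the numbers will not exactly match the official tokenization-based scores above:
- ```python
- from rouge_score import rouge_scorer
-
- scorer = rouge_scorer.RougeScorer(['rouge2'], use_stemmer=True)
- scores = []
- with open('test.hypo') as hypos, open('test.target') as refs:
-     for hypo, ref in zip(hypos, refs):
-         scores.append(scorer.score(ref.strip(), hypo.strip())['rouge2'].fmeasure)
- print('Approximate ROUGE-2 F1:', sum(scores) / len(scores))
- ```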
209
-
210
-
211
- ## Finetuning
212
-
213
- - [Finetuning on GLUE](README.glue.md)
214
- - [Finetuning on CNN-DM](README.summarization.md)
215
-
216
- ## Citation
217
-
218
- ```bibtex
219
- @article{lewis2019bart,
220
- title = {BART: Denoising Sequence-to-Sequence Pre-training for Natural
221
- Language Generation, Translation, and Comprehension},
222
- author = {Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and
223
- Abdelrahman Mohamed and Omer Levy and Veselin Stoyanov
224
- and Luke Zettlemoyer },
225
- journal={arXiv preprint arXiv:1910.13461},
226
- year = {2019},
227
- }
228
- ```
 
examples/bart/README.summarization.md DELETED
@@ -1,102 +0,0 @@
1
- # Fine-tuning BART on CNN-Dailymail summarization task
2
-
3
- ### 1) Download the CNN and Daily Mail data and preprocess it into data files with non-tokenized cased samples.
4
-
5
- Follow the instructions [here](https://github.com/abisee/cnn-dailymail) to download the original CNN and Daily Mail datasets. To preprocess the data, refer to the pointers in [this issue](https://github.com/pytorch/fairseq/issues/1391) or check out the code [here](https://github.com/artmatsak/cnn-dailymail).
6
-
7
- Follow the instructions [here](https://github.com/EdinburghNLP/XSum) to download the original Extreme Summarization (XSum) dataset, or check out the code [here](https://github.com/EdinburghNLP/XSum/tree/master/XSum-Dataset). Keep the dataset raw: no tokenization and no BPE should be applied at this stage.
8
-
9
- ### 2) BPE preprocess:
10
-
11
- ```bash
12
- wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
13
- wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
14
- wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'
15
-
16
- TASK=cnn_dm
17
- for SPLIT in train val
18
- do
19
- for LANG in source target
20
- do
21
- python -m examples.roberta.multiprocessing_bpe_encoder \
22
- --encoder-json encoder.json \
23
- --vocab-bpe vocab.bpe \
24
- --inputs "$TASK/$SPLIT.$LANG" \
25
- --outputs "$TASK/$SPLIT.bpe.$LANG" \
26
- --workers 60 \
27
- --keep-empty;
28
- done
29
- done
30
- ```
31
-
32
- ### 3) Binarize dataset:
33
- ```bash
34
- fairseq-preprocess \
35
- --source-lang "source" \
36
- --target-lang "target" \
37
- --trainpref "${TASK}/train.bpe" \
38
- --validpref "${TASK}/val.bpe" \
39
- --destdir "${TASK}-bin/" \
40
- --workers 60 \
41
- --srcdict dict.txt \
42
- --tgtdict dict.txt;
43
- ```
44
-
45
- ### 4) Fine-tuning on CNN-DM summarization task:
46
- Example fine-tuning command for CNN-DM:
47
- ```bash
48
- TOTAL_NUM_UPDATES=20000
49
- WARMUP_UPDATES=500
50
- LR=3e-05
51
- MAX_TOKENS=2048
52
- UPDATE_FREQ=4
53
- BART_PATH=/path/to/bart/model.pt
54
-
55
- CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 fairseq-train cnn_dm-bin \
56
- --restore-file $BART_PATH \
57
- --max-tokens $MAX_TOKENS \
58
- --task translation \
59
- --source-lang source --target-lang target \
60
- --truncate-source \
61
- --layernorm-embedding \
62
- --share-all-embeddings \
63
- --share-decoder-input-output-embed \
64
- --reset-optimizer --reset-dataloader --reset-meters \
65
- --required-batch-size-multiple 1 \
66
- --arch bart_large \
67
- --criterion label_smoothed_cross_entropy \
68
- --label-smoothing 0.1 \
69
- --dropout 0.1 --attention-dropout 0.1 \
70
- --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
71
- --clip-norm 0.1 \
72
- --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
73
- --fp16 --update-freq $UPDATE_FREQ \
74
- --skip-invalid-size-inputs-valid-test \
75
- --find-unused-parameters;
76
- ```
77
- The above is expected to run on `1` node with `8` 32GB V100 GPUs.
78
- Expected training time is about `5 hours`. Training time can be reduced with distributed training on `4` nodes and `--update-freq 1`.
79
-
80
- Use `TOTAL_NUM_UPDATES=15000` and `UPDATE_FREQ=2` for the XSum task.
81
-
82
- ### Inference on the CNN-DM test data using the above trained checkpoint
83
- After training the model as described in the previous step, you can perform inference with the checkpoints in the `checkpoints/` directory using `summarize.py`, for example:
84
-
85
- ```bash
86
- cp data-bin/cnn_dm/dict.source.txt checkpoints/
87
- python examples/bart/summarize.py \
88
- --model-dir checkpoints \
89
- --model-file checkpoint_best.pt \
90
- --src cnn_dm/test.source \
91
- --out cnn_dm/test.hypo
92
- ```
93
- For XSum, which uses `beam=6, lenpen=1.0, max_len_b=60, min_len=10`:
94
- ```bash
95
- cp data-bin/cnn_dm/dict.source.txt checkpoints/
96
- python examples/bart/summarize.py \
97
- --model-dir checkpoints \
98
- --model-file checkpoint_best.pt \
99
- --src cnn_dm/test.source \
100
- --out cnn_dm/test.hypo \
101
- --xsum-kwargs
102
- ```
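-
- For quick spot checks without writing a hypothesis file, the fine-tuned checkpoint can also be loaded directly in Python. The sketch below mirrors what `summarize.py` does; the generation kwargs are the CNN-DM defaults from that script, and `cnn_dm-bin` is the binarized directory produced in step 3:
- ```python
- import torch
- from fairseq.models.bart import BARTModel
-
- bart = BARTModel.from_pretrained(
-     'checkpoints',                        # directory containing checkpoint_best.pt
-     checkpoint_file='checkpoint_best.pt',
-     data_name_or_path='cnn_dm-bin',       # holds the source/target dictionaries
- )
- bart.eval()
- if torch.cuda.is_available():
-     bart = bart.cuda().half()
-
- with open('cnn_dm/test.source') as f:
-     sources = [next(f).strip() for _ in range(4)]  # just a few test articles
- for hypo in bart.sample(sources, beam=4, lenpen=2.0, max_len_b=140,
-                         min_len=55, no_repeat_ngram_size=3):
-     print(hypo)
- ```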
 
examples/bart/summarize.py DELETED
@@ -1,100 +0,0 @@
1
- # Copyright (c) Facebook, Inc. and its affiliates.
2
- #
3
- # This source code is licensed under the MIT license found in the
4
- # LICENSE file in the root directory of this source tree.
5
-
6
- import torch
7
- from fairseq.models.bart import BARTModel
8
- import argparse
9
-
10
- XSUM_KWARGS = dict(beam=6, lenpen=1.0, max_len_b=60, min_len=10, no_repeat_ngram_size=3)
11
- CNN_KWARGS = dict(beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
12
-
13
-
14
- @torch.no_grad()
15
- def generate(bart, infile, outfile="bart_hypo.txt", bsz=32, n_obs=None, **eval_kwargs):
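- # Summarize `infile` line by line: batch up to `bsz` source lines at a time,
- # call bart.sample() on each batch, and write one hypothesis per input line to `outfile`.
- # `n_obs`, if given, caps the number of input lines that are summarized.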
16
- count = 1
17
-
18
- # if n_obs is not None: bsz = min(bsz, n_obs)
19
-
20
- with open(infile) as source, open(outfile, "w") as fout:
21
- sline = source.readline().strip()
22
- slines = [sline]
23
- for sline in source:
24
- if n_obs is not None and count > n_obs:
25
- break
26
- if count % bsz == 0:
27
- hypotheses_batch = bart.sample(slines, **eval_kwargs)
28
- for hypothesis in hypotheses_batch:
29
- fout.write(hypothesis + "\n")
30
- fout.flush()
31
- slines = []
32
-
33
- slines.append(sline.strip())
34
- count += 1
35
-
36
- if slines != []:
37
- hypotheses_batch = bart.sample(slines, **eval_kwargs)
38
- for hypothesis in hypotheses_batch:
39
- fout.write(hypothesis + "\n")
40
- fout.flush()
41
-
42
-
43
- def main():
44
- """
45
- Usage::
46
-
47
- python examples/bart/summarize.py \
48
- --model-dir $HOME/bart.large.cnn \
49
- --model-file model.pt \
50
- --src $HOME/data-bin/cnn_dm/test.source
51
- """
52
- parser = argparse.ArgumentParser()
53
- parser.add_argument(
54
- "--model-dir",
55
- required=True,
56
- type=str,
57
- default="bart.large.cnn/",
58
- help="path containing model file and src_dict.txt",
59
- )
60
- parser.add_argument(
61
- "--model-file",
62
- default="checkpoint_best.pt",
63
- help="where in model_dir are weights saved",
64
- )
65
- parser.add_argument(
66
- "--src", default="test.source", help="text to summarize", type=str
67
- )
68
- parser.add_argument(
69
- "--out", default="test.hypo", help="where to save summaries", type=str
70
- )
71
- parser.add_argument("--bsz", default=32, help="where to save summaries", type=int)
72
- parser.add_argument(
73
- "--n", default=None, help="how many examples to summarize", type=int
74
- )
75
- parser.add_argument(
76
- "--xsum-kwargs",
77
- action="store_true",
78
- default=False,
79
- help="if true use XSUM_KWARGS else CNN_KWARGS",
80
- )
81
- args = parser.parse_args()
82
- eval_kwargs = XSUM_KWARGS if args.xsum_kwargs else CNN_KWARGS
83
- if args.model_dir == "pytorch/fairseq":
84
- bart = torch.hub.load("pytorch/fairseq", args.model_file)
85
- else:
86
- bart = BARTModel.from_pretrained(
87
- args.model_dir,
88
- checkpoint_file=args.model_file,
89
- data_name_or_path=args.model_dir,
90
- )
91
- bart = bart.eval()
92
- if torch.cuda.is_available():
93
- bart = bart.cuda().half()
94
- generate(
95
- bart, args.src, bsz=args.bsz, n_obs=args.n, outfile=args.out, **eval_kwargs
96
- )
97
-
98
-
99
- if __name__ == "__main__":
100
- main()
 
examples/byte_level_bpe/README.md DELETED
@@ -1,88 +0,0 @@
1
- # Neural Machine Translation with Byte-Level Subwords
2
-
3
- https://arxiv.org/abs/1909.03341
4
-
5
- We provide an implementation of byte-level byte-pair encoding (BBPE), taking IWSLT 2017 Fr-En translation as
6
- an example.
7
-
8
- ## Data
9
- Get data and generate fairseq binary dataset:
10
- ```bash
11
- bash ./get_data.sh
12
- ```
13
-
14
- ## Model Training
15
- Train Transformer model with Bi-GRU embedding contextualization (implemented in `gru_transformer.py`):
16
- ```bash
17
- # VOCAB=bytes
18
- # VOCAB=chars
19
- VOCAB=bbpe2048
20
- # VOCAB=bpe2048
21
- # VOCAB=bbpe4096
22
- # VOCAB=bpe4096
23
- # VOCAB=bpe16384
24
- ```
25
- ```bash
26
- fairseq-train "data/bin_${VOCAB}" --task translation --user-dir examples/byte_level_bpe/gru_transformer \
27
- --arch gru_transformer --encoder-layers 2 --decoder-layers 2 --dropout 0.3 --share-all-embeddings \
28
- --optimizer adam --adam-betas '(0.9, 0.98)' \
29
- --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
30
- --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
31
- --log-format 'simple' --log-interval 100 --save-dir "checkpoints/${VOCAB}" \
32
- --batch-size 100 --max-update 100000 --update-freq 2
33
- ```
34
-
35
- ## Generation
36
- `fairseq-generate` requires the bytes (BBPE) decoder to convert the byte-level representation back to characters:
37
- ```bash
38
- # BPE=--bpe bytes
39
- # BPE=--bpe characters
40
- BPE=--bpe byte_bpe --sentencepiece-model-path data/spm_bbpe2048.model
41
- # BPE=--bpe sentencepiece --sentencepiece-model data/spm_bpe2048.model
42
- # BPE=--bpe byte_bpe --sentencepiece-model-path data/spm_bbpe4096.model
43
- # BPE=--bpe sentencepiece --sentencepiece-model data/spm_bpe4096.model
44
- # BPE=--bpe sentencepiece --sentencepiece-model data/spm_bpe16384.model
45
- ```
46
-
47
- ```bash
48
- fairseq-generate "data/bin_${VOCAB}" --task translation --user-dir examples/byte_level_bpe/gru_transformer \
49
- --source-lang fr --gen-subset test --sacrebleu --path "checkpoints/${VOCAB}/checkpoint_last.pt" \
50
- --tokenizer moses --moses-target-lang en ${BPE}
51
- ```
52
- When using `fairseq-interactive`, the bytes (BBPE) encoder/decoder is required to tokenize the input data and detokenize model predictions:
53
- ```bash
54
- fairseq-interactive "data/bin_${VOCAB}" --task translation --user-dir examples/byte_level_bpe/gru_transformer \
55
- --path "checkpoints/${VOCAB}/checkpoint_last.pt" --input data/test.fr --tokenizer moses --moses-source-lang fr \
56
- --moses-target-lang en ${BPE} --buffer-size 1000 --max-tokens 10000
57
- ```
58
-
59
- ## Results
60
- | Vocabulary | Model | BLEU |
61
- |:-------------:|:-------------:|:-------------:|
62
- | Joint BPE 16k ([Kudo, 2018](https://arxiv.org/abs/1804.10959)) | 512d LSTM 2+2 | 33.81 |
63
- | Joint BPE 16k | Transformer base 2+2 (w/ GRU) | 36.64 (36.72) |
64
- | Joint BPE 4k | Transformer base 2+2 (w/ GRU) | 35.49 (36.10) |
65
- | Joint BBPE 4k | Transformer base 2+2 (w/ GRU) | 35.61 (35.82) |
66
- | Joint BPE 2k | Transformer base 2+2 (w/ GRU) | 34.87 (36.13) |
67
- | Joint BBPE 2k | Transformer base 2+2 (w/ GRU) | 34.98 (35.43) |
68
- | Characters | Transformer base 2+2 (w/ GRU) | 31.78 (33.30) |
69
- | Bytes | Transformer base 2+2 (w/ GRU) | 31.57 (33.62) |
70
-
71
-
72
- ## Citation
73
- ```
74
- @misc{wang2019neural,
75
- title={Neural Machine Translation with Byte-Level Subwords},
76
- author={Changhan Wang and Kyunghyun Cho and Jiatao Gu},
77
- year={2019},
78
- eprint={1909.03341},
79
- archivePrefix={arXiv},
80
- primaryClass={cs.CL}
81
- }
82
- ```
83
-
84
-
85
- ## Contact
86
- Changhan Wang ([changhan@fb.com](mailto:changhan@fb.com)),
87
- Kyunghyun Cho ([kyunghyuncho@fb.com](mailto:kyunghyuncho@fb.com)),
88
- Jiatao Gu ([jgu@fb.com](mailto:jgu@fb.com))
 
examples/byte_level_bpe/get_bitext.py DELETED
@@ -1,254 +0,0 @@
1
- # Copyright (c) Facebook, Inc. and its affiliates.
2
- #
3
- # This source code is licensed under the MIT license found in the
4
- # LICENSE file in the root directory of this source tree.
5
-
6
-
7
- import argparse
8
- import os
9
- import os.path as op
10
- from collections import namedtuple
11
- from multiprocessing import cpu_count
12
- from typing import List, Optional
13
-
14
- import sentencepiece as sp
15
- from fairseq.data.encoders.byte_bpe import ByteBPE
16
- from fairseq.data.encoders.byte_utils import byte_encode
17
- from fairseq.data.encoders.bytes import Bytes
18
- from fairseq.data.encoders.characters import Characters
19
- from fairseq.data.encoders.moses_tokenizer import MosesTokenizer
20
- from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
21
-
22
-
23
- SPLITS = ["train", "valid", "test"]
24
-
25
-
26
- def _convert_xml(in_path: str, out_path: str):
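- # Extract plain text from IWSLT XML dev/test files: keep only <seg ...> lines
- # and strip the surrounding markup.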
27
- with open(in_path) as f, open(out_path, "w") as f_o:
28
- for s in f:
29
- ss = s.strip()
30
- if not ss.startswith("<seg"):
31
- continue
32
- ss = ss.replace("</seg>", "").split('">')
33
- assert len(ss) == 2
34
- f_o.write(ss[1].strip() + "\n")
35
-
36
-
37
- def _convert_train(in_path: str, out_path: str):
38
- with open(in_path) as f, open(out_path, "w") as f_o:
39
- for s in f:
40
- ss = s.strip()
41
- if ss.startswith("<"):
42
- continue
43
- f_o.write(ss.strip() + "\n")
44
-
45
-
46
- def _get_bytes(in_path: str, out_path: str):
47
- with open(in_path) as f, open(out_path, "w") as f_o:
48
- for s in f:
49
- f_o.write(Bytes.encode(s.strip()) + "\n")
50
-
51
-
52
- def _get_chars(in_path: str, out_path: str):
53
- with open(in_path) as f, open(out_path, "w") as f_o:
54
- for s in f:
55
- f_o.write(Characters.encode(s.strip()) + "\n")
56
-
57
-
58
- def pretokenize(in_path: str, out_path: str, src: str, tgt: str):
59
- Args = namedtuple(
60
- "Args",
61
- [
62
- "moses_source_lang",
63
- "moses_target_lang",
64
- "moses_no_dash_splits",
65
- "moses_no_escape",
66
- ],
67
- )
68
- args = Args(
69
- moses_source_lang=src,
70
- moses_target_lang=tgt,
71
- moses_no_dash_splits=False,
72
- moses_no_escape=False,
73
- )
74
- pretokenizer = MosesTokenizer(args)
75
- with open(in_path) as f, open(out_path, "w") as f_o:
76
- for s in f:
77
- f_o.write(pretokenizer.encode(s.strip()) + "\n")
78
-
79
-
80
- def _convert_to_bchar(in_path_prefix: str, src: str, tgt: str, out_path: str):
81
- with open(out_path, "w") as f_o:
82
- for lang in [src, tgt]:
83
- with open(f"{in_path_prefix}.{lang}") as f:
84
- for s in f:
85
- f_o.write(byte_encode(s.strip()) + "\n")
86
-
87
-
88
- def _get_bpe(in_path: str, model_prefix: str, vocab_size: int):
89
- arguments = [
90
- f"--input={in_path}",
91
- f"--model_prefix={model_prefix}",
92
- f"--model_type=bpe",
93
- f"--vocab_size={vocab_size}",
94
- "--character_coverage=1.0",
95
- "--normalization_rule_name=identity",
96
- f"--num_threads={cpu_count()}",
97
- ]
98
- sp.SentencePieceTrainer.Train(" ".join(arguments))
99
-
100
-
101
- def _apply_bbpe(model_path: str, in_path: str, out_path: str):
102
- Args = namedtuple("Args", ["sentencepiece_model_path"])
103
- args = Args(sentencepiece_model_path=model_path)
104
- tokenizer = ByteBPE(args)
105
- with open(in_path) as f, open(out_path, "w") as f_o:
106
- for s in f:
107
- f_o.write(tokenizer.encode(s.strip()) + "\n")
108
-
109
-
110
- def _apply_bpe(model_path: str, in_path: str, out_path: str):
111
- Args = namedtuple("Args", ["sentencepiece_model"])
112
- args = Args(sentencepiece_model=model_path)
113
- tokenizer = SentencepieceBPE(args)
114
- with open(in_path) as f, open(out_path, "w") as f_o:
115
- for s in f:
116
- f_o.write(tokenizer.encode(s.strip()) + "\n")
117
-
118
-
119
- def _concat_files(in_paths: List[str], out_path: str):
120
- with open(out_path, "w") as f_o:
121
- for p in in_paths:
122
- with open(p) as f:
123
- for r in f:
124
- f_o.write(r)
125
-
126
-
127
- def preprocess_iwslt17(
128
- root: str,
129
- src: str,
130
- tgt: str,
131
- bpe_size: Optional[int],
132
- need_chars: bool,
133
- bbpe_size: Optional[int],
134
- need_bytes: bool,
135
- ):
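- # Pipeline: extract the raw bitext from the IWSLT17 archives, pre-tokenize it with Moses,
- # then optionally produce BPE, byte, character and byte-level BPE (BBPE) versions of every
- # split, depending on which vocabulary options were requested.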
136
- # extract bitext
137
- in_root = op.join(root, f"{src}-{tgt}")
138
- for lang in [src, tgt]:
139
- _convert_train(
140
- op.join(in_root, f"train.tags.{src}-{tgt}.{lang}"),
141
- op.join(root, f"train.{lang}"),
142
- )
143
- _convert_xml(
144
- op.join(in_root, f"IWSLT17.TED.dev2010.{src}-{tgt}.{lang}.xml"),
145
- op.join(root, f"valid.{lang}"),
146
- )
147
- _convert_xml(
148
- op.join(in_root, f"IWSLT17.TED.tst2015.{src}-{tgt}.{lang}.xml"),
149
- op.join(root, f"test.{lang}"),
150
- )
151
- # pre-tokenize
152
- for lang in [src, tgt]:
153
- for split in SPLITS:
154
- pretokenize(
155
- op.join(root, f"{split}.{lang}"),
156
- op.join(root, f"{split}.moses.{lang}"),
157
- src,
158
- tgt,
159
- )
160
- # tokenize with BPE vocabulary
161
- if bpe_size is not None:
162
- # learn vocabulary
163
- concated_train_path = op.join(root, "train.all")
164
- _concat_files(
165
- [op.join(root, "train.moses.fr"), op.join(root, "train.moses.en")],
166
- concated_train_path,
167
- )
168
- bpe_model_prefix = op.join(root, f"spm_bpe{bpe_size}")
169
- _get_bpe(concated_train_path, bpe_model_prefix, bpe_size)
170
- os.remove(concated_train_path)
171
- # apply
172
- for lang in [src, tgt]:
173
- for split in SPLITS:
174
- _apply_bpe(
175
- bpe_model_prefix + ".model",
176
- op.join(root, f"{split}.moses.{lang}"),
177
- op.join(root, f"{split}.moses.bpe{bpe_size}.{lang}"),
178
- )
179
- # tokenize with bytes vocabulary
180
- if need_bytes:
181
- for lang in [src, tgt]:
182
- for split in SPLITS:
183
- _get_bytes(
184
- op.join(root, f"{split}.moses.{lang}"),
185
- op.join(root, f"{split}.moses.bytes.{lang}"),
186
- )
187
- # tokenize with characters vocabulary
188
- if need_chars:
189
- for lang in [src, tgt]:
190
- for split in SPLITS:
191
- _get_chars(
192
- op.join(root, f"{split}.moses.{lang}"),
193
- op.join(root, f"{split}.moses.chars.{lang}"),
194
- )
195
- # tokenize with byte-level BPE vocabulary
196
- if bbpe_size is not None:
197
- # learn vocabulary
198
- bchar_path = op.join(root, "train.bchar")
199
- _convert_to_bchar(op.join(root, "train.moses"), src, tgt, bchar_path)
200
- bbpe_model_prefix = op.join(root, f"spm_bbpe{bbpe_size}")
201
- _get_bpe(bchar_path, bbpe_model_prefix, bbpe_size)
202
- os.remove(bchar_path)
203
- # apply
204
- for lang in [src, tgt]:
205
- for split in SPLITS:
206
- _apply_bbpe(
207
- bbpe_model_prefix + ".model",
208
- op.join(root, f"{split}.moses.{lang}"),
209
- op.join(root, f"{split}.moses.bbpe{bbpe_size}.{lang}"),
210
- )
211
-
212
-
213
- def main():
214
- parser = argparse.ArgumentParser()
215
- parser.add_argument("--root", type=str, default="data")
216
- parser.add_argument(
217
- "--bpe-vocab",
218
- default=None,
219
- type=int,
220
- help="Generate tokenized bitext with BPE of size K."
221
- "Default to None (disabled).",
222
- )
223
- parser.add_argument(
224
- "--bbpe-vocab",
225
- default=None,
226
- type=int,
227
- help="Generate tokenized bitext with BBPE of size K."
228
- "Default to None (disabled).",
229
- )
230
- parser.add_argument(
231
- "--byte-vocab",
232
- action="store_true",
233
- help="Generate tokenized bitext with bytes vocabulary",
234
- )
235
- parser.add_argument(
236
- "--char-vocab",
237
- action="store_true",
238
- help="Generate tokenized bitext with chars vocabulary",
239
- )
240
- args = parser.parse_args()
241
-
242
- preprocess_iwslt17(
243
- args.root,
244
- "fr",
245
- "en",
246
- args.bpe_vocab,
247
- args.char_vocab,
248
- args.bbpe_vocab,
249
- args.byte_vocab,
250
- )
251
-
252
-
253
- if __name__ == "__main__":
254
- main()
 
examples/byte_level_bpe/get_data.sh DELETED
@@ -1,47 +0,0 @@
1
- #!/bin/bash
2
-
3
- # Copyright (c) Facebook, Inc. and its affiliates.
4
- #
5
- # This source code is licensed under the MIT license found in the
6
- # LICENSE file in the root directory of this source tree.
7
-
8
- PY_BIN_ROOT=
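- # PY_BIN_ROOT: optional prefix pointing at a Python environment's bin/ directory
- # (keep the trailing slash), e.g. PY_BIN_ROOT=~/venvs/fairseq/bin/ ;
- # leave empty to use whatever pip/python/fairseq-preprocess is on PATH.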
9
-
10
- # PyPI dependency
11
- ${PY_BIN_ROOT}pip install sentencepiece sacremoses
12
-
13
- # Get data
14
- if [ ! -d "data" ]; then
15
- mkdir data
16
- fi
17
-
18
- if [ ! -f "data/fr-en.tgz" ]; then
19
- wget https://wit3.fbk.eu/archive/2017-01-trnted/texts/fr/en/fr-en.tgz -P data
20
- tar xvf data/fr-en.tgz -C data
21
- fi
22
- ${PY_BIN_ROOT}python get_bitext.py --bpe-vocab 16384 --byte-vocab --char-vocab
23
- for VOCAB_SIZE in 2048 4096; do
24
- ${PY_BIN_ROOT}python get_bitext.py --bpe-vocab ${VOCAB_SIZE} --bbpe-vocab ${VOCAB_SIZE}
25
- done
26
- rm -r data/fr-en data/fr-en.tgz
27
-
28
- # Generate binary dataset
29
- ${PY_BIN_ROOT}fairseq-preprocess --source-lang fr --target-lang en --destdir data/bin_bpe16384 --joined-dictionary \
30
- --workers "$(nproc)" --trainpref data/train.moses.bpe16384 --validpref data/valid.moses.bpe16384 \
31
- --testpref data/test.moses.bpe16384
32
-
33
- ${PY_BIN_ROOT}fairseq-preprocess --source-lang fr --target-lang en --destdir data/bin_bytes --joined-dictionary \
34
- --workers "$(nproc)" --trainpref data/train.moses.bytes --validpref data/valid.moses.bytes \
35
- --testpref data/test.moses.bytes
36
-
37
- ${PY_BIN_ROOT}fairseq-preprocess --source-lang fr --target-lang en --destdir data/bin_chars --joined-dictionary \
38
- --workers "$(nproc)" --trainpref data/train.moses.chars --validpref data/valid.moses.chars \
39
- --testpref data/test.moses.chars
40
-
41
- for VOCAB_SIZE in 2048 4096; do
42
- for TYPE in bbpe bpe; do
43
- ${PY_BIN_ROOT}/fairseq-preprocess --source-lang fr --target-lang en --destdir "data/bin_${TYPE}${VOCAB_SIZE}" \
44
- --joined-dictionary --workers "$(nproc)" --trainpref "data/train.moses.${TYPE}${VOCAB_SIZE}" \
45
- --validpref "data/valid.moses.${TYPE}${VOCAB_SIZE}" --testpref "data/test.moses.${TYPE}${VOCAB_SIZE}"
46
- done
47
- done