JustinLin610 commited on
Commit
8437114
1 Parent(s): 6c83714
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +1 -1
  2. fairseq/.github/ISSUE_TEMPLATE.md +3 -0
  3. fairseq/.github/ISSUE_TEMPLATE/bug_report.md +43 -0
  4. fairseq/.github/ISSUE_TEMPLATE/documentation.md +15 -0
  5. fairseq/.github/ISSUE_TEMPLATE/feature_request.md +24 -0
  6. fairseq/.github/ISSUE_TEMPLATE/how-to-question.md +33 -0
  7. fairseq/.github/PULL_REQUEST_TEMPLATE.md +16 -0
  8. fairseq/.github/stale.yml +30 -0
  9. fairseq/.github/workflows/build.yml +55 -0
  10. fairseq/.github/workflows/build_wheels.yml +41 -0
  11. fairseq/.gitmodules +4 -0
  12. fairseq/CODE_OF_CONDUCT.md +77 -0
  13. fairseq/CONTRIBUTING.md +28 -0
  14. fairseq/LICENSE +21 -0
  15. fairseq/README.md +229 -0
  16. fairseq/examples/__init__.py +9 -0
  17. fairseq/examples/adaptive_span/README.md +90 -0
  18. fairseq/examples/adaptive_span/__init__.py +19 -0
  19. fairseq/examples/adaptive_span/adagrad_with_grad_clip.py +128 -0
  20. fairseq/examples/adaptive_span/adaptive_span_attention.py +160 -0
  21. fairseq/examples/adaptive_span/adaptive_span_loss.py +106 -0
  22. fairseq/examples/adaptive_span/adaptive_span_model.py +263 -0
  23. fairseq/examples/adaptive_span/adaptive_span_model_wrapper.py +145 -0
  24. fairseq/examples/adaptive_span/truncated_bptt_lm_task.py +281 -0
  25. fairseq/examples/backtranslation/README.md +297 -0
  26. fairseq/examples/backtranslation/deduplicate_lines.py +41 -0
  27. fairseq/examples/backtranslation/extract_bt_data.py +72 -0
  28. fairseq/examples/backtranslation/prepare-de-monolingual.sh +98 -0
  29. fairseq/examples/backtranslation/prepare-wmt18en2de.sh +135 -0
  30. fairseq/examples/backtranslation/sacrebleu.sh +37 -0
  31. fairseq/examples/backtranslation/tokenized_bleu.sh +46 -0
  32. fairseq/examples/bart/README.glue.md +99 -0
  33. fairseq/examples/bart/README.md +228 -0
  34. fairseq/examples/bart/README.summarization.md +102 -0
  35. fairseq/examples/bart/summarize.py +100 -0
  36. fairseq/examples/byte_level_bpe/README.md +88 -0
  37. fairseq/examples/byte_level_bpe/get_bitext.py +254 -0
  38. fairseq/examples/byte_level_bpe/get_data.sh +47 -0
  39. fairseq/examples/byte_level_bpe/gru_transformer.py +107 -0
  40. fairseq/examples/camembert/README.md +75 -0
  41. fairseq/examples/constrained_decoding/README.md +123 -0
  42. fairseq/examples/constrained_decoding/normalize.py +27 -0
  43. fairseq/examples/constrained_decoding/tok.py +34 -0
  44. fairseq/examples/conv_seq2seq/README.md +25 -0
  45. fairseq/examples/criss/README.md +61 -0
  46. fairseq/examples/criss/download_and_preprocess_flores_test.sh +64 -0
  47. fairseq/examples/criss/download_and_preprocess_tatoeba.sh +46 -0
  48. fairseq/examples/criss/mining/mine.py +240 -0
  49. fairseq/examples/criss/mining/mine_example.sh +103 -0
  50. fairseq/examples/criss/save_encoder.py +214 -0
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
 
3
- os.system('git clone https://github.com/pytorch/fairseq.git; cd fairseq;'
4
  'pip install --use-feature=in-tree-build ./; cd ..')
5
  os.system('ls -l')
6
 
1
  import os
2
 
3
+ os.system('cd fairseq;'
4
  'pip install --use-feature=in-tree-build ./; cd ..')
5
  os.system('ls -l')
6
 
fairseq/.github/ISSUE_TEMPLATE.md ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ ## 👉 [Please follow one of these issue templates](https://github.com/pytorch/fairseq/issues/new/choose) 👈
2
+
3
+ Note: to keep the backlog clean and actionable, issues may be immediately closed if they do not follow one of the above issue templates.
fairseq/.github/ISSUE_TEMPLATE/bug_report.md ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: 🐛 Bug Report
3
+ about: Submit a bug report to help us improve
4
+ labels: 'bug, needs triage'
5
+ ---
6
+
7
+ ## 🐛 Bug
8
+
9
+ <!-- A clear and concise description of what the bug is. -->
10
+
11
+ ### To Reproduce
12
+
13
+ Steps to reproduce the behavior (**always include the command you ran**):
14
+
15
+ 1. Run cmd '....'
16
+ 2. See error
17
+
18
+ <!-- If you have a code sample, error messages, stack traces, please provide it here as well -->
19
+
20
+
21
+ #### Code sample
22
+ <!-- Ideally attach a minimal code sample to reproduce the decried issue.
23
+ Minimal means having the shortest code but still preserving the bug. -->
24
+
25
+ ### Expected behavior
26
+
27
+ <!-- A clear and concise description of what you expected to happen. -->
28
+
29
+ ### Environment
30
+
31
+ - fairseq Version (e.g., 1.0 or main):
32
+ - PyTorch Version (e.g., 1.0)
33
+ - OS (e.g., Linux):
34
+ - How you installed fairseq (`pip`, source):
35
+ - Build command you used (if compiling from source):
36
+ - Python version:
37
+ - CUDA/cuDNN version:
38
+ - GPU models and configuration:
39
+ - Any other relevant information:
40
+
41
+ ### Additional context
42
+
43
+ <!-- Add any other context about the problem here. -->
fairseq/.github/ISSUE_TEMPLATE/documentation.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: 📚 Documentation/Typos
3
+ about: Report an issue related to documentation or a typo
4
+ labels: 'documentation, needs triage'
5
+ ---
6
+
7
+ ## 📚 Documentation
8
+
9
+ For typos and doc fixes, please go ahead and:
10
+
11
+ 1. Create an issue.
12
+ 2. Fix the typo.
13
+ 3. Submit a PR.
14
+
15
+ Thanks!
fairseq/.github/ISSUE_TEMPLATE/feature_request.md ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: 🚀 Feature Request
3
+ about: Submit a proposal/request for a new feature
4
+ labels: 'enhancement, help wanted, needs triage'
5
+ ---
6
+
7
+ ## 🚀 Feature Request
8
+ <!-- A clear and concise description of the feature proposal -->
9
+
10
+ ### Motivation
11
+
12
+ <!-- Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too -->
13
+
14
+ ### Pitch
15
+
16
+ <!-- A clear and concise description of what you want to happen. -->
17
+
18
+ ### Alternatives
19
+
20
+ <!-- A clear and concise description of any alternative solutions or features you've considered, if any. -->
21
+
22
+ ### Additional context
23
+
24
+ <!-- Add any other context or screenshots about the feature request here. -->
fairseq/.github/ISSUE_TEMPLATE/how-to-question.md ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: ❓ Questions/Help
3
+ about: If you have questions, please first search existing issues and docs
4
+ labels: 'question, needs triage'
5
+ ---
6
+
7
+ ## ❓ Questions and Help
8
+
9
+ ### Before asking:
10
+ 1. search the issues.
11
+ 2. search the docs.
12
+
13
+ <!-- If you still can't find what you need: -->
14
+
15
+ #### What is your question?
16
+
17
+ #### Code
18
+
19
+ <!-- Please paste a code snippet if your question requires it! -->
20
+
21
+ #### What have you tried?
22
+
23
+ #### What's your environment?
24
+
25
+ - fairseq Version (e.g., 1.0 or main):
26
+ - PyTorch Version (e.g., 1.0)
27
+ - OS (e.g., Linux):
28
+ - How you installed fairseq (`pip`, source):
29
+ - Build command you used (if compiling from source):
30
+ - Python version:
31
+ - CUDA/cuDNN version:
32
+ - GPU models and configuration:
33
+ - Any other relevant information:
fairseq/.github/PULL_REQUEST_TEMPLATE.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Before submitting
2
+
3
+ - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
4
+ - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
5
+ - [ ] Did you make sure to update the docs?
6
+ - [ ] Did you write any new necessary tests?
7
+
8
+ ## What does this PR do?
9
+ Fixes # (issue).
10
+
11
+ ## PR review
12
+ Anyone in the community is free to review the PR once the tests have passed.
13
+ If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
14
+
15
+ ## Did you have fun?
16
+ Make sure you had fun coding 🙃
fairseq/.github/stale.yml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Configuration for probot-stale - https://github.com/probot/stale
2
+ # Mostly copied from github.com/facebook/react/blob/master/.github/stale.yml
3
+ # Number of days of inactivity before an issue becomes stale
4
+ daysUntilStale: 90
5
+ # Number of days of inactivity before a stale issue is closed
6
+ daysUntilClose: 7
7
+ # Issues with these labels will never be considered stale
8
+ exemptLabels:
9
+ - bug
10
+ # Label to use when marking an issue as stale
11
+ staleLabel: stale
12
+ issues:
13
+ # Comment to post when marking an issue as stale.
14
+ markComment: >
15
+ This issue has been automatically marked as stale.
16
+ **If this issue is still affecting you, please leave any comment** (for example, "bump"), and we'll keep it open.
17
+ We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment!
18
+ # Comment to post when closing a stale issue.
19
+ closeComment: >
20
+ Closing this issue after a prolonged period of inactivity. If this issue is still present in the latest release, please create a new issue with up-to-date information. Thank you!
21
+ pulls:
22
+ # Comment to post when marking a pull request as stale.
23
+ markComment: >
24
+ This pull request has been automatically marked as stale.
25
+ **If this pull request is still relevant, please leave any comment** (for example, "bump"), and we'll keep it open.
26
+ We are sorry that we haven't been able to prioritize reviewing it yet. Your contribution is very much appreciated.
27
+ # Comment to post when closing a stale pull request.
28
+ closeComment: >
29
+ Closing this pull request after a prolonged period of inactivity. If this issue is still present in the latest release, please ask for this pull request to be reopened. Thank you!
30
+
fairseq/.github/workflows/build.yml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: build
2
+
3
+ on:
4
+ # Trigger the workflow on push to main or any pull request
5
+ push:
6
+ branches:
7
+ - main
8
+ pull_request:
9
+
10
+ jobs:
11
+ build:
12
+
13
+ strategy:
14
+ max-parallel: 4
15
+ matrix:
16
+ platform: [ubuntu-latest, macos-latest]
17
+ python-version: [3.6, 3.7]
18
+
19
+ runs-on: ${{ matrix.platform }}
20
+
21
+ steps:
22
+ - uses: actions/checkout@v2
23
+
24
+ - name: Set up Python ${{ matrix.python-version }}
25
+ uses: actions/setup-python@v2
26
+ with:
27
+ python-version: ${{ matrix.python-version }}
28
+
29
+ - name: Conditionally install pytorch
30
+ if: matrix.platform == 'windows-latest'
31
+ run: pip3 install torch -f https://download.pytorch.org/whl/torch_stable.html
32
+
33
+ - name: Install locally
34
+ run: |
35
+ python -m pip install --upgrade pip
36
+ git submodule update --init --recursive
37
+ python setup.py build_ext --inplace
38
+ python -m pip install --editable .
39
+
40
+ - name: Install optional test requirements
41
+ run: |
42
+ python -m pip install iopath transformers pyarrow
43
+ python -m pip install git+https://github.com/facebookresearch/fairscale.git@main
44
+
45
+ - name: Lint with flake8
46
+ run: |
47
+ pip install flake8
48
+ # stop the build if there are Python syntax errors or undefined names
49
+ flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --extend-exclude fairseq/model_parallel/megatron
50
+ # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
51
+ flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --extend-exclude fairseq/model_parallel/megatron
52
+
53
+ - name: Run tests
54
+ run: |
55
+ python setup.py test
fairseq/.github/workflows/build_wheels.yml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: build_wheels
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - v[0-9]+.[0-9]+.[x0-9]+
7
+ tags:
8
+ - v*
9
+
10
+ jobs:
11
+ build_wheels:
12
+ name: Build wheels on ${{ matrix.os }}
13
+ runs-on: ${{ matrix.os }}
14
+ strategy:
15
+ matrix:
16
+ os: [ubuntu-latest, macos-latest]
17
+
18
+ steps:
19
+ - uses: actions/checkout@v2
20
+
21
+ - name: Install Python
22
+ uses: actions/setup-python@v2
23
+ with:
24
+ python-version: '3.7'
25
+
26
+ - name: Install cibuildwheel
27
+ run: |
28
+ python -m pip install cibuildwheel
29
+
30
+ - name: Build wheels for CPython
31
+ run: |
32
+ python -m cibuildwheel --output-dir dist
33
+ env:
34
+ CIBW_BUILD: "cp36-*64 cp37-*64 cp38-*64"
35
+ CIBW_MANYLINUX_X86_64_IMAGE: manylinux1
36
+ CIBW_BEFORE_BUILD: git submodule update --init --recursive && pip install .
37
+
38
+ - uses: actions/upload-artifact@v2
39
+ with:
40
+ name: wheels
41
+ path: ./dist/*.whl
fairseq/.gitmodules ADDED
@@ -0,0 +1,4 @@
 
 
 
 
1
+ [submodule "fairseq/model_parallel/megatron"]
2
+ path = fairseq/model_parallel/megatron
3
+ url = https://github.com/ngoyal2707/Megatron-LM
4
+ branch = fairseq
fairseq/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ ## Enforcement
56
+
57
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
58
+ reported by contacting the project team at <conduct@pytorch.org>. All
59
+ complaints will be reviewed and investigated and will result in a response that
60
+ is deemed necessary and appropriate to the circumstances. The project team is
61
+ obligated to maintain confidentiality with regard to the reporter of an incident.
62
+ Further details of specific enforcement policies may be posted separately.
63
+
64
+ Project maintainers who do not follow or enforce the Code of Conduct in good
65
+ faith may face temporary or permanent repercussions as determined by other
66
+ members of the project's leadership.
67
+
68
+ ## Attribution
69
+
70
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72
+
73
+ [homepage]: https://www.contributor-covenant.org
74
+
75
+ For answers to common questions about this code of conduct, see
76
+ https://www.contributor-covenant.org/faq
77
+
fairseq/CONTRIBUTING.md ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq)
2
+ We want to make contributing to this project as easy and transparent as
3
+ possible.
4
+
5
+ ## Pull Requests
6
+ We actively welcome your pull requests.
7
+
8
+ 1. Fork the repo and create your branch from `main`.
9
+ 2. If you've added code that should be tested, add tests.
10
+ 3. If you've changed APIs, update the documentation.
11
+ 4. Ensure the test suite passes.
12
+ 5. Make sure your code lints.
13
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14
+
15
+ ## Contributor License Agreement ("CLA")
16
+ In order to accept your pull request, we need you to submit a CLA. You only need
17
+ to do this once to work on any of Facebook's open source projects.
18
+
19
+ Complete your CLA here: <https://code.facebook.com/cla>
20
+
21
+ ## Issues
22
+ We use GitHub issues to track public bugs. Please ensure your description is
23
+ clear and has sufficient instructions to be able to reproduce the issue.
24
+
25
+ ## License
26
+ By contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq),
27
+ you agree that your contributions will be licensed under the LICENSE file in
28
+ the root directory of this source tree.
fairseq/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) Facebook, Inc. and its affiliates.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
fairseq/README.md ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="center">
2
+ <img src="docs/fairseq_logo.png" width="150">
3
+ <br />
4
+ <br />
5
+ <a href="https://github.com/pytorch/fairseq/blob/main/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/license-MIT-blue.svg" /></a>
6
+ <a href="https://github.com/pytorch/fairseq/releases"><img alt="Latest Release" src="https://img.shields.io/github/release/pytorch/fairseq.svg" /></a>
7
+ <a href="https://github.com/pytorch/fairseq/actions?query=workflow:build"><img alt="Build Status" src="https://github.com/pytorch/fairseq/workflows/build/badge.svg" /></a>
8
+ <a href="https://fairseq.readthedocs.io/en/latest/?badge=latest"><img alt="Documentation Status" src="https://readthedocs.org/projects/fairseq/badge/?version=latest" /></a>
9
+ </p>
10
+
11
+ --------------------------------------------------------------------------------
12
+
13
+ Fairseq(-py) is a sequence modeling toolkit that allows researchers and
14
+ developers to train custom models for translation, summarization, language
15
+ modeling and other text generation tasks.
16
+
17
+ We provide reference implementations of various sequence modeling papers:
18
+
19
+ <details><summary>List of implemented papers</summary><p>
20
+
21
+ * **Convolutional Neural Networks (CNN)**
22
+ + [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md)
23
+ + [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
24
+ + [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
25
+ + [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
26
+ + [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
27
+ * **LightConv and DynamicConv models**
28
+ + [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
29
+ * **Long Short-Term Memory (LSTM) networks**
30
+ + Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015)
31
+ * **Transformer (self-attention) networks**
32
+ + Attention Is All You Need (Vaswani et al., 2017)
33
+ + [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
34
+ + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
35
+ + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md)
36
+ + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md)
37
+ + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md)
38
+ + [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md)
39
+ + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
40
+ + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
41
+ + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
42
+ + [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md )
43
+ + [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md)
44
+ + [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
45
+ + [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
46
+ + [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
47
+ + [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md)
48
+ + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md)
49
+ + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
50
+ + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md)
51
+ + [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979)
52
+ + [Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training (Hsu, et al., 2021)](https://arxiv.org/abs/2104.01027)
53
+ + [Unsupervised Speech Recognition (Baevski, et al., 2021)](https://arxiv.org/abs/2105.11084)
54
+ * **Non-autoregressive Transformers**
55
+ + Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
56
+ + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
57
+ + Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 2019)
58
+ + Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)
59
+ + [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
60
+ * **Finetuning**
61
+ + [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al. 2020)](examples/rxf/README.md)
62
+
63
+ </p></details>
64
+
65
+ ### What's New:
66
+
67
+ * September 2021 [`master` branch renamed to `main`](https://github.com/github/renaming).
68
+ * July 2021 [Released DrNMT code](examples/discriminative_reranking_nmt/README.md)
69
+ * July 2021 [Released Robust wav2vec 2.0 model](examples/wav2vec/README.md)
70
+ * June 2021 [Released XLMR-XL and XLMR-XXL models](examples/xlmr/README.md)
71
+ * May 2021 [Released Unsupervised Speech Recognition code](examples/wav2vec/unsupervised/README.md)
72
+ * March 2021 [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md)
73
+ * February 2021 [Added LASER training code](examples/laser/README.md)
74
+ * December 2020: [Added Adaptive Attention Span code](examples/adaptive_span/README.md)
75
+ * December 2020: [GottBERT model and code released](examples/gottbert/README.md)
76
+ * November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework
77
+ * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md)
78
+ * November 2020: [fairseq 0.10.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.10.0)
79
+ * October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md)
80
+ * October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md)
81
+ * October 2020: [Added CRISS models and code](examples/criss/README.md)
82
+
83
+ <details><summary>Previous updates</summary><p>
84
+
85
+ * September 2020: [Added Linformer code](examples/linformer/README.md)
86
+ * September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md)
87
+ * August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md)
88
+ * August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md)
89
+ * July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md)
90
+ * May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq)
91
+ * April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md)
92
+ * April 2020: [Quant-Noise code released](examples/quant_noise/README.md)
93
+ * April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md)
94
+ * March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md)
95
+ * February 2020: [mBART model and code released](examples/mbart/README.md)
96
+ * February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/main/examples/backtranslation#training-your-own-model-wmt18-english-german)
97
+ * December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0)
98
+ * November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example)
99
+ * November 2019: [CamemBERT model and code released](examples/camembert/README.md)
100
+ * November 2019: [BART model and code released](examples/bart/README.md)
101
+ * November 2019: [XLM-R models and code released](examples/xlmr/README.md)
102
+ * September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
103
+ * August 2019: [WMT'19 models released](examples/wmt19/README.md)
104
+ * July 2019: fairseq relicensed under MIT license
105
+ * July 2019: [RoBERTa models and code released](examples/roberta/README.md)
106
+ * June 2019: [wav2vec models and code released](examples/wav2vec/README.md)
107
+
108
+ </p></details>
109
+
110
+ ### Features:
111
+
112
+ * multi-GPU training on one machine or across multiple machines (data and model parallel)
113
+ * fast generation on both CPU and GPU with multiple search algorithms implemented:
114
+ + beam search
115
+ + Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
116
+ + sampling (unconstrained, top-k and top-p/nucleus)
117
+ + [lexically constrained decoding](examples/constrained_decoding/README.md) (Post & Vilar, 2018)
118
+ * [gradient accumulation](https://fairseq.readthedocs.io/en/latest/getting_started.html#large-mini-batch-training-with-delayed-updates) enables training with large mini-batches even on a single GPU
119
+ * [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores))
120
+ * [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers
121
+ * [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration
122
+ * [full parameter and optimizer state sharding](examples/fully_sharded_data_parallel/README.md)
123
+ * [offloading parameters to CPU](examples/fully_sharded_data_parallel/README.md)
124
+
125
+ We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples)
126
+ with a convenient `torch.hub` interface:
127
+
128
+ ``` python
129
+ en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
130
+ en2de.translate('Hello world', beam=5)
131
+ # 'Hallo Welt'
132
+ ```
133
+
134
+ See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/)
135
+ and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples.
136
+
137
+ # Requirements and Installation
138
+
139
+ * [PyTorch](http://pytorch.org/) version >= 1.5.0
140
+ * Python version >= 3.6
141
+ * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
142
+ * **To install fairseq** and develop locally:
143
+
144
+ ``` bash
145
+ git clone https://github.com/pytorch/fairseq
146
+ cd fairseq
147
+ pip install --editable ./
148
+
149
+ # on MacOS:
150
+ # CFLAGS="-stdlib=libc++" pip install --editable ./
151
+
152
+ # to install the latest stable release (0.10.x)
153
+ # pip install fairseq
154
+ ```
155
+
156
+ * **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library:
157
+
158
+ ``` bash
159
+ git clone https://github.com/NVIDIA/apex
160
+ cd apex
161
+ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
162
+ --global-option="--deprecated_fused_adam" --global-option="--xentropy" \
163
+ --global-option="--fast_multihead_attn" ./
164
+ ```
165
+
166
+ * **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow`
167
+ * If you use Docker make sure to increase the shared memory size either with `--ipc=host` or `--shm-size`
168
+ as command line options to `nvidia-docker run` .
169
+
170
+ # Getting Started
171
+
172
+ The [full documentation](https://fairseq.readthedocs.io/) contains instructions
173
+ for getting started, training new models and extending fairseq with new model
174
+ types and tasks.
175
+
176
+ # Pre-trained models and examples
177
+
178
+ We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below,
179
+ as well as example training and evaluation commands.
180
+
181
+ * [Translation](examples/translation/README.md): convolutional and transformer models are available
182
+ * [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available
183
+
184
+ We also have more detailed READMEs to reproduce results from specific papers:
185
+
186
+ * [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
187
+ * [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
188
+ * [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
189
+ * [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md)
190
+ * [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
191
+ * [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md)
192
+ * [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md)
193
+ * [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
194
+ * [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
195
+ * [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
196
+ * [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
197
+ * [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
198
+ * [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
199
+ * [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
200
+ * [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
201
+ * [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
202
+ * [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
203
+ * [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
204
+ * [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
205
+ * [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md)
206
+
207
+ # Join the fairseq community
208
+
209
+ * Twitter: https://twitter.com/fairseq
210
+ * Facebook page: https://www.facebook.com/groups/fairseq.users
211
+ * Google group: https://groups.google.com/forum/#!forum/fairseq-users
212
+
213
+ # License
214
+
215
+ fairseq(-py) is MIT-licensed.
216
+ The license applies to the pre-trained models as well.
217
+
218
+ # Citation
219
+
220
+ Please cite as:
221
+
222
+ ``` bibtex
223
+ @inproceedings{ott2019fairseq,
224
+ title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
225
+ author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
226
+ booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
227
+ year = {2019},
228
+ }
229
+ ```
fairseq/examples/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ try:
7
+ from fairseq.version import __version__ # noqa
8
+ except ImportError:
9
+ pass
fairseq/examples/adaptive_span/README.md ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adaptive Span
2
+
3
+ Adaptive Span is a novel self-attention mechanism that can learn its optimal
4
+ attention span. This allows us to extend significantly the maximum context size
5
+ used in Transformer, while maintaining control over their memory footprint
6
+ and computational time. It uses the Truncated BPTT technique for training,
7
+ as in [transformerXL](https://github.com/pytorch/fairseq/blob/main/examples/truncated_bptt/README.md).
8
+
9
+ Adaptive Span was introduced by paper:
10
+ [Adaptive Attention Span in Transformers](https://arxiv.org/abs/1905.07799),
11
+ which achieved state-of-the-art language modeling results at the time of publication.
12
+
13
+ We manage to reproduce their result in fairseq and keep most of the
14
+ [original implementation](https://github.com/facebookresearch/adaptive-span) untouched.
15
+ You can refer to the their sweep file as well if any combination of hyperparameter is not clear.
16
+
17
+ ##### 0. Setup
18
+
19
+ First you need to process the Enwik8 dataset, we use the pre-tokenized dataset
20
+ from [adaptive span paper](https://github.com/facebookresearch/adaptive-span/blob/master/get_data.sh).
21
+ You can download the dataset, and then run:
22
+ ```bash
23
+ fairseq-preprocess --only-source --trainpref ~/data/enwik8/train.txt \
24
+ --validpref ~/data/enwik8/valid.txt --testpref ~/data/enwik8/test.txt \
25
+ --destdir ~/data/enwik8/data-bin/ --joined-dictionary --workers 20
26
+ ```
27
+
28
+ ##### 1. Train a Adaptive Span model on Enwik8
29
+
30
+ We will train a 12-layer Adaptive Span model following the [hyperparameters
31
+ used in the original
32
+ paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).
33
+
34
+ The following command assumes 4 GPUs, so that the total batch size is 64
35
+ sequences (4 x 16). Training should take 2-3 days on 4 V100 GPUs:
36
+ ```bash
37
+ CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
38
+ --user-dir examples/adaptive_span \
39
+ --data ~/data/enwik8/data-bin/ \
40
+ --fp16 --fp16-no-flatten-grads --max-update 600000 \
41
+ --task truncated_bptt_lm --tokens-per-sample 512 --arch adaptive_span \
42
+ --n-layer 12 --d-model 512 --n-head 8 --d-inner 2048 --dropout 0.3 \
43
+ --attn-span 8192 --optimizer adagrad_with_grad_clip --adagrad-clip 0.03 \
44
+ --validate-interval-updates 1000 \
45
+ --lr-scheduler fixed --warmup-updates 32000 --batch-size-valid 32 \
46
+ --lr 0.07 --criterion adaptive_span_loss --batch-size 16 --update-freq 1 \
47
+ --seed 2 --log-format json --log-interval 25 --aux-loss-scaler 5e-07
48
+ ```
49
+ This should land around 1.05 on validation, 1.03 on test. You can lower the
50
+ --aux-loss-scaler for better performance (longer span). It gives ~0.03 bpc
51
+ improvement to the transformerXL baseline here.
52
+ If training on a single GPU, set `--update-freq=4` to accumulate 4x gradients
53
+ and simulate training on 4 GPUs.
54
+ You can also reproduce the transformerXL result on enwik8 using this code base.
55
+ It should land around 1.06 on test,matching the [original paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_enwik8_base.sh).
56
+ You can try by
57
+ ```bash
58
+ CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
59
+ --user-dir examples/truncated_bptt \
60
+ ~/data/enwik8/data-bin/ \
61
+ --task truncated_bptt_lm --fp16 --max-update 400000 \
62
+ --tokens-per-sample 512 --arch transformer_xl --n-layer 12 \
63
+ --d-model 512 --n-head 8 --d-head 64 --d-inner 2048 --dropout 0.1 \
64
+ --dropatt 0.0 --mem-len 512 --optimizer adam --clip-norm 0.25 \
65
+ --lr-scheduler cosine --warmup-updates 0 \
66
+ --lr 0.0 --lr 0.00025 --batch-size 15 \
67
+ --update-freq 1 --seed 2 --log-format json --log-interval 25 \
68
+ --fp16
69
+ ```
70
+
71
+ ##### 2. Evaluate
72
+ For Adaptive Span:
73
+ ```bash
74
+ fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
75
+ --user-dir examples/adaptive_span \
76
+ --task truncated_bptt_lm --batch-size 8 --tokens-per-sample 512 --gen-subset test
77
+ ```
78
+ For Transformer-XL evaluation:
79
+ ```bash
80
+ fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
81
+ --user-dir examples/truncated_bptt/ --task truncated_bptt_lm --batch-size 8 \
82
+ --tokens-per-sample 80 \
83
+ --model-overrides '{"mem_len":2100,"clamp_len":820,"same_length":True}' \
84
+ --gen-subset valid
85
+ ```
86
+
87
+ *Note:* During training the model saw 512 tokens of context
88
+ (``--tokens-per-sample=512``), with batch size 8. These settings match the evaluation
89
+ settings from [the original
90
+ paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).
fairseq/examples/adaptive_span/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import importlib
7
+ import os
8
+
9
+ # automatically import any Python files in the current directory
10
+ cur_dir = os.path.dirname(__file__)
11
+ for file in os.listdir(cur_dir):
12
+ path = os.path.join(cur_dir, file)
13
+ if (
14
+ not file.startswith("_")
15
+ and not file.startswith(".")
16
+ and (file.endswith(".py") or os.path.isdir(path))
17
+ ):
18
+ mod_name = file[: file.find(".py")] if file.endswith(".py") else file
19
+ module = importlib.import_module(__name__ + "." + mod_name)
fairseq/examples/adaptive_span/adagrad_with_grad_clip.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from torch.optim import Adagrad
7
+
8
+ from fairseq.optim import LegacyFairseqOptimizer, register_optimizer
9
+
10
+
11
+ @register_optimizer("adagrad_with_grad_clip")
12
+ class FairseqAdagradWithGradClip(LegacyFairseqOptimizer):
13
+ def __init__(self, args, params):
14
+ super().__init__(args)
15
+ self._optimizer = AdagradWithGradClip(params, **self.optimizer_config)
16
+
17
+ @staticmethod
18
+ def add_args(parser):
19
+ """Add optimizer-specific arguments to the parser."""
20
+ # fmt: off
21
+ parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
22
+ help='weight decay')
23
+ parser.add_argument('--adagrad-clip', default=0.0, type=float, metavar='D',
24
+ help='internal grad clip')
25
+ # fmt: on
26
+
27
+ @property
28
+ def optimizer_config(self):
29
+ """
30
+ Return a kwarg dictionary that will be used to override optimizer
31
+ args stored in checkpoints. This allows us to load a checkpoint and
32
+ resume training using a different set of optimizer args, e.g., with a
33
+ different learning rate.
34
+ """
35
+ return {
36
+ "lr": self.args.lr[0],
37
+ "weight_decay": self.args.weight_decay,
38
+ "grad_clip": self.args.adagrad_clip,
39
+ }
40
+
41
+ @property
42
+ def supports_flat_params(self):
43
+ return False
44
+
45
+
46
+ def _clip_grad(clr, grad, group_grad_clip):
47
+ if group_grad_clip > 0:
48
+ norm = grad.norm(2).item()
49
+ if norm > group_grad_clip:
50
+ clr *= group_grad_clip / (norm + 1e-10)
51
+ return clr
52
+
53
+
54
+ class AdagradWithGradClip(Adagrad):
55
+ """Adagrad algorithm with custom gradient clipping"""
56
+
57
+ def __init__(
58
+ self,
59
+ params,
60
+ lr=1e-2,
61
+ lr_decay=0,
62
+ weight_decay=0,
63
+ initial_accumulator_value=0,
64
+ grad_clip=0,
65
+ ):
66
+ Adagrad.__init__(
67
+ self,
68
+ params,
69
+ lr=lr,
70
+ lr_decay=lr_decay,
71
+ weight_decay=weight_decay,
72
+ initial_accumulator_value=initial_accumulator_value,
73
+ )
74
+ self.defaults["grad_clip"] = grad_clip
75
+ self.param_groups[0].setdefault("grad_clip", grad_clip)
76
+
77
+ def step(self, closure=None):
78
+ loss = None
79
+ if closure is not None:
80
+ loss = closure()
81
+
82
+ for group in self.param_groups:
83
+ for p in group["params"]:
84
+ if p.grad is None:
85
+ continue
86
+
87
+ grad = p.grad.data
88
+ state = self.state[p]
89
+
90
+ state["step"] += 1
91
+
92
+ if group["weight_decay"] != 0:
93
+ if p.grad.data.is_sparse:
94
+ raise RuntimeError(
95
+ "weight_decay option is "
96
+ "not compatible with sparse "
97
+ "gradients"
98
+ )
99
+ grad = grad.add(group["weight_decay"], p.data)
100
+
101
+ clr = group["lr"] / (1 + (state["step"] - 1) * group["lr_decay"])
102
+
103
+ # clip
104
+ clr = _clip_grad(clr=clr, grad=grad, group_grad_clip=group["grad_clip"])
105
+
106
+ if grad.is_sparse:
107
+ # the update is non-linear so indices must be unique
108
+ grad = grad.coalesce()
109
+ grad_indices = grad._indices()
110
+ grad_values = grad._values()
111
+ size = grad.size()
112
+
113
+ def make_sparse(values):
114
+ constructor = grad.new
115
+ if grad_indices.dim() == 0 or values.dim() == 0:
116
+ return constructor().resize_as_(grad)
117
+ return constructor(grad_indices, values, size)
118
+
119
+ state["sum"].add_(make_sparse(grad_values.pow(2)))
120
+ std = state["sum"]._sparse_mask(grad)
121
+ std_values = std._values().sqrt_().add_(1e-10)
122
+ p.data.add_(-clr, make_sparse(grad_values / std_values))
123
+ else:
124
+ state["sum"].addcmul_(1, grad, grad)
125
+ std = state["sum"].sqrt().add_(1e-10)
126
+ p.data.addcdiv_(-clr, grad, std)
127
+
128
+ return loss
fairseq/examples/adaptive_span/adaptive_span_attention.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import math
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+
12
+ class AdaptiveMask(nn.Module):
13
+ """Soft masking function for adaptive size.
14
+ It masks out the last K values of an input. The masking value
15
+ goes from 1 to 0 gradually, so K can be learned with
16
+ back-propagation.
17
+ Args:
18
+ max_size: maximum size (i.e. input dimension)
19
+ ramp_size: size of the ramp going from 0 to 1
20
+ init_val: initial size proportion not to be masked out
21
+ shape: learn multiple sizes independent of each other
22
+ """
23
+
24
+ def __init__(self, max_size, ramp_size, init_val=0, shape=(1,)):
25
+ nn.Module.__init__(self)
26
+ self._max_size = max_size
27
+ self._ramp_size = ramp_size
28
+ self.current_val = nn.Parameter(torch.zeros(*shape) + init_val)
29
+ mask_template = torch.linspace(1 - max_size, 0, steps=max_size)
30
+ self.register_buffer("mask_template", mask_template)
31
+
32
+ def forward(self, x):
33
+ mask = self.mask_template.float() + self.current_val.float() * self._max_size
34
+ mask = mask / self._ramp_size + 1
35
+ mask = mask.clamp(0, 1)
36
+ if x.size(-1) < self._max_size:
37
+ # the input could have been trimmed beforehand to save computation
38
+ mask = mask.narrow(-1, self._max_size - x.size(-1), x.size(-1))
39
+ x = (x * mask).type_as(x)
40
+ return x
41
+
42
+ def get_current_max_size(self, include_ramp=True):
43
+ current_size = math.ceil(self.current_val.max().item() * self._max_size)
44
+ if include_ramp:
45
+ current_size += self._ramp_size
46
+ current_size = max(0, min(self._max_size, current_size))
47
+ return current_size
48
+
49
+ def get_current_avg_size(self, include_ramp=True):
50
+ current_size = math.ceil(
51
+ self.current_val.float().mean().item() * self._max_size
52
+ )
53
+ if include_ramp:
54
+ current_size += self._ramp_size
55
+ current_size = max(0, min(self._max_size, current_size))
56
+ return current_size
57
+
58
+ def clamp_param(self):
59
+ """this need to be called after each update"""
60
+ self.current_val.data.clamp_(0, 1)
61
+
62
+
63
+ class AdaptiveSpan(nn.Module):
64
+ """Adaptive attention span for Transformerself.
65
+ This module learns an attention span length from data for each
66
+ self-attention head.
67
+ Args:
68
+ attn_span: maximum attention span
69
+ adapt_span_loss: loss coefficient for the span length
70
+ adapt_span_ramp: length of the masking ramp
71
+ adapt_span_init: initial size ratio
72
+ adapt_span_cache: adapt cache size to reduce memory usage
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ attn_span,
78
+ adapt_span_ramp,
79
+ adapt_span_init,
80
+ n_head,
81
+ adapt_span_layer,
82
+ **kargs
83
+ ):
84
+ nn.Module.__init__(self)
85
+ self._max_span = attn_span
86
+ self._n_head = n_head
87
+ self._adapt_span_layer = adapt_span_layer
88
+ if self._adapt_span_layer:
89
+ self._mask = AdaptiveMask(
90
+ max_size=self._max_span,
91
+ ramp_size=adapt_span_ramp,
92
+ init_val=adapt_span_init,
93
+ )
94
+ else:
95
+ self._mask = AdaptiveMask(
96
+ max_size=self._max_span,
97
+ ramp_size=adapt_span_ramp,
98
+ init_val=adapt_span_init,
99
+ shape=(n_head, 1, 1),
100
+ )
101
+
102
+ def forward(self, attn, normalize=True):
103
+ """mask attention with the right span"""
104
+ # batch and head dimensions are merged together, so separate them first
105
+ self.clamp_param()
106
+ if self._adapt_span_layer:
107
+ attn = self._mask(attn)
108
+ else:
109
+ B = attn.size(0) # batch size
110
+ M = attn.size(1) # block size
111
+ attn = attn.reshape(B // self._n_head, self._n_head, M, -1)
112
+ attn = self._mask(attn)
113
+ attn = attn.view(B, M, -1)
114
+ return attn
115
+
116
+ def get_trim_len(self):
117
+ """how much of memory can be trimmed to reduce computation"""
118
+ L = self._max_span
119
+ trim_len = min(L - 1, L - self._mask.get_current_max_size())
120
+ # too fine granularity might be bad for the memory management
121
+ trim_len = math.floor(trim_len / 64) * 64
122
+ return trim_len
123
+
124
+ def trim_memory(self, query, key, value, key_pe):
125
+ """trim out unnecessary memory beforehand to reduce computation"""
126
+ trim_len = self.get_trim_len()
127
+ cache_size = key.size(1) - query.size(1)
128
+ trim_len_cache = trim_len - (self._max_span - cache_size)
129
+ if trim_len_cache > 0:
130
+ key = key[:, trim_len_cache:, :]
131
+ value = value[:, trim_len_cache:, :]
132
+ elif trim_len_cache < 0:
133
+ # cache is too short! this happens when validation resumes
134
+ # after a lot of updates.
135
+ key = F.pad(key, [0, 0, -trim_len_cache, 0])
136
+ value = F.pad(value, [0, 0, -trim_len_cache, 0])
137
+ if trim_len > 0:
138
+ if key_pe is not None:
139
+ key_pe = key_pe[:, :, trim_len:]
140
+ return key, value, key_pe
141
+
142
+ def get_cache_size(self):
143
+ """determine how long the cache should be"""
144
+ trim_len = self.get_trim_len()
145
+ # give a buffer of 64 steps since a span might increase
146
+ # in future updates
147
+ return min(self._max_span, self._max_span - trim_len + 64)
148
+
149
+ def get_loss(self):
150
+ """a loss term for regularizing the span length"""
151
+ return self._max_span * self._mask.current_val.float().mean()
152
+
153
+ def get_current_max_span(self):
154
+ return self._mask.get_current_max_size()
155
+
156
+ def get_current_avg_span(self):
157
+ return self._mask.get_current_avg_size()
158
+
159
+ def clamp_param(self):
160
+ self._mask.clamp_param()
fairseq/examples/adaptive_span/adaptive_span_loss.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from dataclasses import dataclass
8
+
9
+ import torch.nn.functional as F
10
+ from fairseq import metrics, utils
11
+ from fairseq.criterions import register_criterion
12
+ from fairseq.criterions.cross_entropy import CrossEntropyCriterion
13
+ from fairseq.dataclass import FairseqDataclass
14
+ from omegaconf import II
15
+
16
+
17
+ @dataclass
18
+ class AdaptiveSpanCriterionConfig(FairseqDataclass):
19
+ sentence_avg: bool = II("optimization.sentence_avg")
20
+
21
+
22
+ @register_criterion("adaptive_span_loss", dataclass=AdaptiveSpanCriterionConfig)
23
+ class AdaptiveSpanCriterion(CrossEntropyCriterion):
24
+ def __init__(self, task, sentence_avg):
25
+ super().__init__(task, sentence_avg)
26
+
27
+ def forward(self, model, sample, reduce=True):
28
+ """Compute the loss for the given sample.
29
+
30
+ Returns a tuple with three elements:
31
+ 1) the loss here is summed, different from the adaptive span code
32
+ 2) the sample size, which is used as the denominator for the gradient
33
+ 3) logging outputs to display while training
34
+ """
35
+ net_output = model(**sample["net_input"])
36
+ loss, aux_loss, avg_span, max_span = self.compute_loss(
37
+ model, net_output, sample, reduce=reduce
38
+ )
39
+ sample_size = (
40
+ sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
41
+ )
42
+ loss /= sample_size
43
+ total_loss = loss + aux_loss
44
+ sample_size = 1
45
+
46
+ logging_output = {
47
+ "loss": loss.data,
48
+ "ntokens": sample["ntokens"],
49
+ "nsentences": sample["target"].size(0),
50
+ "sample_size": sample_size,
51
+ "total_loss": total_loss.data,
52
+ "avg_span": avg_span * sample_size,
53
+ "max_span": max_span * sample_size,
54
+ }
55
+ return total_loss, sample_size, logging_output
56
+
57
+ def compute_loss(self, model, net_output, sample, reduce=True):
58
+ loss, _ = super().compute_loss(model, net_output, sample, reduce)
59
+ aux_loss = model.get_aux_loss()
60
+ avg_span = model.get_current_avg_span()
61
+ max_span = model.get_current_max_span()
62
+ return loss, aux_loss, avg_span, max_span
63
+
64
+ @staticmethod
65
+ def reduce_metrics(logging_outputs) -> None:
66
+ """Aggregate logging outputs from data parallel training."""
67
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
68
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
69
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
70
+ total_loss_sum = sum(log.get("total_loss", 0) for log in logging_outputs)
71
+ avg_span_sum = sum(log.get("avg_span", 0) for log in logging_outputs)
72
+ max_span_sum = sum(log.get("max_span", 0) for log in logging_outputs)
73
+
74
+ # we divide by log(2) to convert the loss from base e to base 2
75
+ metrics.log_scalar(
76
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
77
+ )
78
+ metrics.log_scalar("avg_span", avg_span_sum / sample_size, sample_size, round=3)
79
+ metrics.log_scalar("max_span", max_span_sum / sample_size, sample_size, round=3)
80
+ # total loss contains the L1 norm on adaptive-span
81
+ metrics.log_scalar(
82
+ "total_loss",
83
+ total_loss_sum / sample_size / math.log(2),
84
+ sample_size,
85
+ round=3,
86
+ )
87
+ if sample_size != ntokens:
88
+ metrics.log_scalar(
89
+ "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
90
+ )
91
+ metrics.log_derived(
92
+ "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
93
+ )
94
+ else:
95
+ metrics.log_derived(
96
+ "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
97
+ )
98
+
99
+ @staticmethod
100
+ def logging_outputs_can_be_summed() -> bool:
101
+ """
102
+ Whether the logging outputs returned by `forward` can be summed
103
+ across workers prior to calling `reduce_metrics`. Setting this
104
+ to True will improves distributed training speed.
105
+ """
106
+ return True
fairseq/examples/adaptive_span/adaptive_span_model.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from fairseq.modules.layer_norm import LayerNorm
14
+
15
+ from .adaptive_span_attention import AdaptiveSpan
16
+
17
+ # Size notations:
18
+ # B = batch_size, H = d_model, M = block_size, L = attn_span
19
+
20
+
21
+ def _skew(X, pad_value):
22
+ """shift every row 1 step to right"""
23
+ # X = B x M x L
24
+ B, M, L = X.size()
25
+ X = F.pad(X, (0, M + 1), value=pad_value) # B x M x (L+M+1)
26
+ X = X.view(B, -1) # B x ML+MM+M
27
+ X = X[:, :-M] # B x ML+MM
28
+ X = X.view(B, M, M + L) # B x M x L+M
29
+ return X
30
+
31
+
32
+ def _unskew(X):
33
+ """reverse _skew operation"""
34
+ # X = B x M x L+M
35
+ B, M, L = X.size()
36
+ L -= M
37
+ X = X.view(B, -1) # B x ML+MM
38
+ X = F.pad(X, (0, M)) # B x ML+MM+M
39
+ X = X.view(B, M, M + L + 1) # B x M x L+M+1
40
+ X = X[:, :, :L] # B x M x L
41
+ return X
42
+
43
+
44
+ class SeqAttention(nn.Module):
45
+ """Sequential self-attention layer.
46
+ Each token will attend to its previous fixed number of steps.
47
+ Note that attention doesn't include the current step itself.
48
+ """
49
+
50
+ def __init__(self, d_model, n_head, attn_span, dropout, adapt_span_layer, **kargs):
51
+ nn.Module.__init__(self)
52
+ self.dropout = nn.Dropout(dropout)
53
+ self.d_model = d_model # size of a single head
54
+ self.attn_span = attn_span
55
+ self.adaptive_span = AdaptiveSpan(
56
+ attn_span=attn_span,
57
+ n_head=n_head,
58
+ adapt_span_layer=adapt_span_layer,
59
+ **kargs
60
+ )
61
+
62
+ def forward(self, query, key, value, key_pe):
63
+ # query size = B x M x H
64
+ # key, value sizes = B x (M+L) x H
65
+
66
+ key, value, key_pe = self.adaptive_span.trim_memory(query, key, value, key_pe)
67
+
68
+ # compute attention from context
69
+ # B x M (dest) x (M+L) (src)
70
+ attn_cont = torch.matmul(query, key.transpose(-1, -2))
71
+ attn_cont = _unskew(attn_cont) # B x M x L
72
+
73
+ # compute the effect of position embedding
74
+ attn_pos = torch.matmul(query, key_pe) # B x M x L_pos
75
+ attn = attn_cont + attn_pos
76
+
77
+ attn = attn / math.sqrt(self.d_model) # B x M X L_pos
78
+
79
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
80
+
81
+ # trim attention lengths according to the learned span
82
+ attn = self.adaptive_span(attn)
83
+
84
+ attn = self.dropout(attn) # B x M X L_pos
85
+
86
+ attn_cont = _skew(attn, 0) # B x M X (L+M)
87
+ out = torch.matmul(attn_cont, value) # B x M x H
88
+ return out
89
+
90
+ def get_cache_size(self):
91
+ return self.adaptive_span.get_cache_size()
92
+
93
+
94
+ class MultiHeadSeqAttention(nn.Module):
95
+ def __init__(self, d_model, n_head, **kargs):
96
+ nn.Module.__init__(self)
97
+ assert d_model % n_head == 0
98
+ self.n_head = n_head
99
+ self.head_dim = d_model // n_head
100
+ self.attn = SeqAttention(d_model=self.head_dim, n_head=n_head, **kargs)
101
+ self.proj_query = nn.Linear(d_model, d_model, bias=False)
102
+ nn.init.xavier_normal_(self.proj_query.weight)
103
+ self.proj_out = nn.Linear(d_model, d_model, bias=False)
104
+ nn.init.xavier_normal_(self.proj_out.weight)
105
+ self.proj_val = nn.Linear(d_model, d_model, bias=False)
106
+ nn.init.xavier_normal_(self.proj_val.weight)
107
+ self.proj_key = nn.Linear(d_model, d_model, bias=False)
108
+ nn.init.xavier_normal_(self.proj_key.weight)
109
+
110
+ def head_reshape(self, x):
111
+ K = self.n_head
112
+ D = self.head_dim
113
+ x = x.view(x.size()[:-1] + (K, D)) # B x (M+L) x K x D
114
+ x = x.transpose(1, 2).contiguous() # B x K x (M+L) x D
115
+ x = x.view(-1, x.size(-2), x.size(-1)) # B_K x (M+L) x D
116
+ return x
117
+
118
+ def forward(self, query, key, value, key_pe):
119
+ B = query.size(0)
120
+ K = self.n_head
121
+ D = self.head_dim
122
+ M = query.size(1)
123
+
124
+ query = self.proj_query(query)
125
+ query = self.head_reshape(query)
126
+ value = self.proj_val(value)
127
+ value = self.head_reshape(value)
128
+ key = self.proj_key(key)
129
+ key = self.head_reshape(key)
130
+
131
+ out = self.attn(query, key, value, key_pe) # B_K x M x D
132
+ out = out.view(B, K, M, D) # B x K x M x D
133
+ out = out.transpose(1, 2).contiguous() # B x M x K x D
134
+ out = out.view(B, M, -1) # B x M x K_D
135
+ out = self.proj_out(out)
136
+ return out
137
+
138
+
139
+ class FeedForwardLayer(nn.Module):
140
+ def __init__(self, d_model, d_inner, dropout, **kargs):
141
+ nn.Module.__init__(self)
142
+ self.fc1 = nn.Linear(d_model, d_inner)
143
+ self.fc2 = nn.Linear(d_inner, d_model)
144
+ nn.init.xavier_uniform_(self.fc1.weight)
145
+ nn.init.xavier_uniform_(self.fc2.weight)
146
+ self.dropout = nn.Dropout(dropout)
147
+
148
+ def forward(self, h):
149
+ h1 = F.relu(self.fc1(h))
150
+ h1 = self.dropout(h1)
151
+ h2 = self.fc2(h1)
152
+ return h2
153
+
154
+
155
+ class TransformerSeqLayer(nn.Module):
156
+ def __init__(self, d_model, **kargs):
157
+ nn.Module.__init__(self)
158
+ self.attn = MultiHeadSeqAttention(d_model=d_model, **kargs)
159
+ self.norm1 = LayerNorm(d_model)
160
+ self.ff = FeedForwardLayer(d_model=d_model, **kargs)
161
+ self.norm2 = LayerNorm(d_model)
162
+
163
+ def forward(self, h, h_cache, key_pe):
164
+ # h = B x M x H
165
+ # h_cache = B x L x H
166
+ h_all = torch.cat([h_cache, h], dim=1) # B x (M+L) x H
167
+ attn_out = self.attn(h, h_all, h_all, key_pe)
168
+ h = self.norm1(h + attn_out) # B x M x H
169
+ if self.ff is not None:
170
+ ff_out = self.ff(h)
171
+ out = self.norm2(h + ff_out) # B x M x H
172
+ else:
173
+ out = h
174
+ return out
175
+
176
+ def get_cache_size(self):
177
+ return self.attn.attn.get_cache_size()
178
+
179
+
180
+ class TransformerSeq(nn.Module):
181
+ def __init__(
182
+ self,
183
+ vocab_size,
184
+ d_model,
185
+ n_head,
186
+ n_layer,
187
+ attn_span,
188
+ emb_dropout,
189
+ aux_loss_scaler,
190
+ adapt_span_layer,
191
+ **kargs
192
+ ):
193
+ nn.Module.__init__(self)
194
+ # token embeddings
195
+ self.in_emb = nn.Embedding(vocab_size, d_model)
196
+ nn.init.normal_(self.in_emb.weight, mean=0, std=d_model ** -0.5)
197
+ self.out_emb = nn.Linear(d_model, vocab_size)
198
+ self.aux_loss_scaler = aux_loss_scaler
199
+ if emb_dropout > 0:
200
+ self.emb_dropout = nn.Dropout(emb_dropout)
201
+ else:
202
+ self.emb_dropout = None
203
+ # position embeddings
204
+ self.key_pe = nn.Parameter(torch.randn(1, d_model // n_head, attn_span))
205
+
206
+ self.layers = nn.ModuleList()
207
+ self.layers.extend(
208
+ TransformerSeqLayer(
209
+ d_model=d_model,
210
+ n_head=n_head,
211
+ attn_span=attn_span,
212
+ adapt_span_layer=adapt_span_layer,
213
+ **kargs
214
+ )
215
+ for _ in range(n_layer)
216
+ )
217
+
218
+ def forward(self, x, h_cache, target=None):
219
+ # x size = B x M
220
+ block_size = x.size(1)
221
+ h = self.in_emb(x) # B x M x H
222
+ if self.emb_dropout is not None:
223
+ h = self.emb_dropout(h)
224
+
225
+ h_cache_next = []
226
+ for l, layer in enumerate(self.layers):
227
+ cache_size = layer.attn.attn.get_cache_size()
228
+ if cache_size > block_size:
229
+ h_cache_next_l = torch.cat(
230
+ [h_cache[l][:, -cache_size + block_size :, :], h], dim=1
231
+ ).detach()
232
+ else:
233
+ h_cache_next_l = h[:, -cache_size:, :].detach()
234
+ h_cache_next.append(h_cache_next_l)
235
+ h = layer(h, h_cache[l], self.key_pe) # B x M x H
236
+
237
+ if self.emb_dropout is not None:
238
+ h = self.emb_dropout(h)
239
+
240
+ out = F.log_softmax(self.out_emb(h).float(), dim=-1).type_as(h)
241
+ dummy_loss = None
242
+
243
+ return out, h_cache_next, dummy_loss
244
+
245
+ def get_aux_loss(self):
246
+ loss = 0.0
247
+ for layer in self.layers:
248
+ loss += layer.attn.attn.adaptive_span.get_loss()
249
+ return self.aux_loss_scaler * loss
250
+
251
+ def get_current_max_span(self):
252
+ max_span = 0.0
253
+ for layer in self.layers:
254
+ max_span = max(
255
+ max_span, layer.attn.attn.adaptive_span.get_current_max_span()
256
+ )
257
+ return max_span
258
+
259
+ def get_current_avg_span(self):
260
+ avg_span = 0.0
261
+ for layer in self.layers:
262
+ avg_span += layer.attn.attn.adaptive_span.get_current_avg_span()
263
+ return avg_span / len(self.layers)
fairseq/examples/adaptive_span/adaptive_span_model_wrapper.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from dataclasses import dataclass
8
+ from typing import Dict, List, Optional
9
+
10
+ import torch
11
+ from fairseq.dataclass import FairseqDataclass
12
+ from fairseq.models import (
13
+ FairseqIncrementalDecoder,
14
+ FairseqLanguageModel,
15
+ register_model,
16
+ )
17
+ from .adaptive_span_model import TransformerSeq as AdaptiveSpanTransformerModel
18
+
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ @dataclass
24
+ class AdaptiveSpanSmallConfig(FairseqDataclass):
25
+ # defaults come from https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8_small.sh
26
+ vocab_size: int = 50
27
+ d_model: int = 256
28
+ n_head: int = 4
29
+ d_inner: int = 1024
30
+ n_layer: int = 8
31
+ attn_span: int = 1024
32
+ dropout: float = 0.0
33
+ emb_dropout: float = 0.0
34
+ adapt_span_ramp: int = 32
35
+ adapt_span_init: float = 0.0
36
+ aux_loss_scaler: float = 0.000002
37
+ adapt_span_layer: bool = False
38
+
39
+
40
+ @register_model("adaptive_span", dataclass=AdaptiveSpanSmallConfig)
41
+ class AdaptiveSpanTransformer(FairseqLanguageModel):
42
+ @classmethod
43
+ def build_model(cls, cfg: AdaptiveSpanSmallConfig, task):
44
+ return cls(AdaptiveSpanDecoder(cfg, task))
45
+
46
+ def get_aux_loss(self):
47
+ return self.decoder.get_aux_loss()
48
+
49
+ def get_current_max_span(self):
50
+ return self.decoder.get_current_max_span()
51
+
52
+ def get_current_avg_span(self):
53
+ return self.decoder.get_current_avg_span()
54
+
55
+
56
+ class AdaptiveSpanDecoder(FairseqIncrementalDecoder):
57
+ def __init__(self, cfg, task):
58
+
59
+ super().__init__(task.target_dictionary)
60
+
61
+ self.config = cfg
62
+ config = AdaptiveSpanSmallConfig(
63
+ vocab_size=len(task.target_dictionary),
64
+ d_model=cfg.d_model,
65
+ n_head=cfg.n_head,
66
+ d_inner=cfg.d_inner,
67
+ n_layer=cfg.n_layer,
68
+ attn_span=cfg.attn_span,
69
+ dropout=cfg.dropout,
70
+ emb_dropout=cfg.emb_dropout,
71
+ adapt_span_ramp=cfg.adapt_span_ramp,
72
+ adapt_span_init=cfg.adapt_span_init,
73
+ aux_loss_scaler=cfg.aux_loss_scaler,
74
+ adapt_span_layer=cfg.adapt_span_layer,
75
+ )
76
+ logger.info(config)
77
+ self.model = AdaptiveSpanTransformerModel(**config.__dict__)
78
+
79
+ self._mems = None
80
+
81
+ def forward(
82
+ self,
83
+ src_tokens,
84
+ incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
85
+ encoder_out=None,
86
+ ):
87
+ bsz = src_tokens.size(0)
88
+ if incremental_state is not None: # used during inference
89
+ mems = self.get_incremental_state("mems")
90
+ src_tokens = src_tokens[:, -1:] # only keep the most recent token
91
+ else:
92
+ mems = self._mems
93
+
94
+ if mems is None:
95
+ # first time init
96
+ mems = self.init_hid_cache(bsz)
97
+ output = self.model(x=src_tokens, h_cache=mems,)
98
+ if incremental_state is not None:
99
+ self.set_incremental_state(incremental_state, "mems", output[1])
100
+ else:
101
+ self._mems = output[1]
102
+ return (output[0],)
103
+
104
+ def max_positions(self):
105
+ return self.config.attn_span
106
+
107
+ def init_hid_cache(self, batch_sz):
108
+ hid = []
109
+ for layer in self.model.layers:
110
+ param = next(self.model.parameters())
111
+ h = torch.zeros(
112
+ batch_sz,
113
+ layer.get_cache_size(),
114
+ self.config.d_model,
115
+ dtype=param.dtype,
116
+ device=param.device,
117
+ )
118
+ hid.append(h)
119
+ return hid
120
+
121
+ def get_aux_loss(self):
122
+ return self.model.get_aux_loss()
123
+
124
+ def get_current_max_span(self):
125
+ return self.model.get_current_max_span()
126
+
127
+ def get_current_avg_span(self):
128
+ return self.model.get_current_avg_span()
129
+
130
+ def reorder_incremental_state(
131
+ self,
132
+ incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]],
133
+ new_order: torch.Tensor,
134
+ ):
135
+ """Reorder incremental state.
136
+
137
+ This will be called when the order of the input has changed from the
138
+ previous time step. A typical use case is beam search, where the input
139
+ order changes between time steps based on the selection of beams.
140
+ """
141
+ raise NotImplementedError("This is required for generation/beam search")
142
+ # mems = self.get_incremental_state(incremental_state, "mems")
143
+ # if mems is not None:
144
+ # new_mems = [mems_i.index_select(1, new_order) for mems_i in mems]
145
+ # self.set_incremental_state(incremental_state, "mems", new_mems)
fairseq/examples/adaptive_span/truncated_bptt_lm_task.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import os
8
+ from dataclasses import dataclass, field
9
+ from typing import List, Optional, Tuple
10
+
11
+ import torch
12
+ from fairseq import utils
13
+ from fairseq.data import (
14
+ Dictionary,
15
+ TokenBlockDataset,
16
+ data_utils,
17
+ iterators,
18
+ )
19
+ from fairseq.dataclass import FairseqDataclass
20
+ from fairseq.distributed import utils as dist_utils
21
+ from fairseq.tasks import FairseqTask, register_task
22
+ from omegaconf import II
23
+
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ @dataclass
29
+ class TruncatedBPTTLMConfig(FairseqDataclass):
30
+ data: str = field(default="???", metadata={"help": "path to data directory"})
31
+ tokens_per_sample: int = field(
32
+ default=1024,
33
+ metadata={"help": "max number of tokens per sequence"},
34
+ )
35
+ batch_size: int = II("dataset.batch_size")
36
+ # Some models use *max_target_positions* to know how many positional
37
+ # embeddings to learn. We use II(...) to make it default to
38
+ # *tokens_per_sample*, but in principle there could be more positional
39
+ # embeddings than tokens in a single batch. This may also be irrelevant for
40
+ # custom model implementations.
41
+ max_target_positions: int = II("task.tokens_per_sample")
42
+ # these will be populated automatically if not provided
43
+ data_parallel_rank: Optional[int] = None
44
+ data_parallel_size: Optional[int] = None
45
+
46
+
47
+ @register_task("truncated_bptt_lm", dataclass=TruncatedBPTTLMConfig)
48
+ class TruncatedBPTTLMTask(FairseqTask):
49
+ def __init__(self, cfg: TruncatedBPTTLMConfig):
50
+ super().__init__(cfg)
51
+
52
+ if cfg.data_parallel_rank is None or cfg.data_parallel_size is None:
53
+ if torch.distributed.is_initialized():
54
+ cfg.data_parallel_rank = dist_utils.get_data_parallel_rank()
55
+ cfg.data_parallel_size = dist_utils.get_data_parallel_world_size()
56
+ else:
57
+ cfg.data_parallel_rank = 0
58
+ cfg.data_parallel_size = 1
59
+
60
+ # load the dictionary
61
+ paths = utils.split_paths(cfg.data)
62
+ assert len(paths) > 0
63
+ self.dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
64
+ logger.info("dictionary: {} types".format(len(self.dictionary)))
65
+
66
+ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
67
+ """Load a given dataset split (e.g., train, valid, test)"""
68
+
69
+ # support sharded datasets
70
+ paths = utils.split_paths(self.cfg.data)
71
+ assert len(paths) > 0
72
+ data_path = paths[(epoch - 1) % len(paths)]
73
+ split_path = os.path.join(data_path, split)
74
+
75
+ # each element of *data* will be a tensorized line from the original
76
+ # text dataset, similar to ``open(split_path).readlines()``
77
+ data = data_utils.load_indexed_dataset(
78
+ split_path, self.dictionary, combine=combine
79
+ )
80
+ if data is None:
81
+ raise FileNotFoundError(
82
+ "Dataset not found: {} ({})".format(split, split_path)
83
+ )
84
+
85
+ # this is similar to ``data.view(-1).split(tokens_per_sample)``
86
+ data = TokenBlockDataset(
87
+ data,
88
+ data.sizes,
89
+ block_size=self.cfg.tokens_per_sample,
90
+ pad=None, # unused
91
+ eos=None, # unused
92
+ break_mode="none",
93
+ )
94
+
95
+ self.datasets[split] = TruncatedBPTTDataset(
96
+ data=data,
97
+ bsz_per_shard=self.cfg.batch_size,
98
+ shard_id=self.cfg.data_parallel_rank,
99
+ num_shards=self.cfg.data_parallel_size,
100
+ )
101
+
102
+ def dataset(self, split):
103
+ return self.datasets[split]
104
+
105
+ def get_batch_iterator(
106
+ self, dataset, num_workers=0, epoch=1, data_buffer_size=0, **kwargs
107
+ ):
108
+ return iterators.EpochBatchIterator(
109
+ dataset=dataset,
110
+ collate_fn=self._collate_fn,
111
+ num_workers=num_workers,
112
+ epoch=epoch,
113
+ buffer_size=data_buffer_size,
114
+ # we don't use the batching functionality from EpochBatchIterator;
115
+ # instead every item in *dataset* is a whole batch
116
+ batch_sampler=[[i] for i in range(len(dataset))],
117
+ disable_shuffling=True,
118
+ )
119
+
120
+ def _collate_fn(self, items: List[List[torch.Tensor]]):
121
+ # we don't use fairseq's batching functionality, so we expect a single
122
+ # Tensor of type List[torch.Tensor]
123
+ assert len(items) == 1
124
+
125
+ # item will have shape B x T (the last batch may have length < T)
126
+ id, item = items[0]
127
+ item = data_utils.collate_tokens(item, pad_idx=self.source_dictionary.pad())
128
+ B, T = item.size()
129
+
130
+ # shift item one position over and append a padding token for the target
131
+ target = torch.nn.functional.pad(
132
+ item[:, 1:], (0, 1, 0, 0), value=self.target_dictionary.pad()
133
+ )
134
+
135
+ # fairseq expects batches to have the following structure
136
+ return {
137
+ "id": torch.tensor([id]*item.size(0)),
138
+ "net_input": {
139
+ "src_tokens": item,
140
+ },
141
+ "target": target,
142
+ "nsentences": item.size(0),
143
+ "ntokens": item.numel(),
144
+ }
145
+
146
+ def build_dataset_for_inference(
147
+ self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs
148
+ ) -> torch.utils.data.Dataset:
149
+ eos = self.source_dictionary.eos()
150
+ dataset = TokenBlockDataset(
151
+ src_tokens,
152
+ src_lengths,
153
+ block_size=None, # ignored for "eos" break mode
154
+ pad=self.source_dictionary.pad(),
155
+ eos=eos,
156
+ break_mode="eos",
157
+ )
158
+
159
+ class Dataset(torch.utils.data.Dataset):
160
+ def __getitem__(self, i):
161
+ item = dataset[i]
162
+ if item[-1] == eos:
163
+ # remove eos to support generating with a prefix
164
+ item = item[:-1]
165
+ return (i, [item])
166
+
167
+ def __len__(self):
168
+ return len(dataset)
169
+
170
+ return Dataset()
171
+
172
+ def inference_step(
173
+ self, generator, models, sample, prefix_tokens=None, constraints=None
174
+ ):
175
+ with torch.no_grad():
176
+ if constraints is not None:
177
+ raise NotImplementedError
178
+
179
+ # SequenceGenerator doesn't use *src_tokens* directly, we need to
180
+ # pass the *prefix_tokens* argument instead.
181
+ if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement():
182
+ prefix_tokens = sample["net_input"]["src_tokens"]
183
+
184
+ # begin generation with the end-of-sentence token
185
+ bos_token = self.source_dictionary.eos()
186
+
187
+ return generator.generate(
188
+ models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token
189
+ )
190
+
191
+ def eval_lm_dataloader(
192
+ self,
193
+ dataset,
194
+ max_tokens: Optional[int] = 36000,
195
+ batch_size: Optional[int] = None,
196
+ max_positions: Optional[int] = None,
197
+ num_shards: int = 1,
198
+ shard_id: int = 0,
199
+ num_workers: int = 1,
200
+ data_buffer_size: int = 10,
201
+ context_window: int = 0,
202
+ ):
203
+ if context_window > 0:
204
+ raise NotImplementedError(
205
+ "Transformer-XL doesn't need --context-window, try "
206
+ "--model-overrides '{\"mem_len\":42}' instead "
207
+ )
208
+ return self.get_batch_iterator(
209
+ dataset=dataset,
210
+ max_tokens=max_tokens,
211
+ max_sentences=batch_size,
212
+ max_positions=max_positions,
213
+ ignore_invalid_inputs=True,
214
+ num_shards=num_shards,
215
+ shard_id=shard_id,
216
+ num_workers=num_workers,
217
+ data_buffer_size=data_buffer_size,
218
+ ).next_epoch_itr(shuffle=False)
219
+
220
+ @property
221
+ def source_dictionary(self):
222
+ return self.dictionary
223
+
224
+ @property
225
+ def target_dictionary(self):
226
+ return self.dictionary
227
+
228
+
229
+ class TruncatedBPTTDataset(torch.utils.data.Dataset):
230
+ def __init__(
231
+ self,
232
+ data: List[torch.Tensor], # ordered list of items
233
+ bsz_per_shard, # number of items processed per GPUs per forward
234
+ shard_id, # current GPU ID
235
+ num_shards, # number of GPUs
236
+ ):
237
+ super().__init__()
238
+ self.data = data
239
+
240
+ def batchify(data, bsz):
241
+ # Work out how cleanly we can divide the dataset into bsz parts.
242
+ nbatch = data.size(0) // bsz
243
+ # Trim off any extra elements that wouldn't cleanly fit (remainders).
244
+ data = data.narrow(0, 0, nbatch * bsz)
245
+ # Evenly divide the data across the bsz batches.
246
+ data = data.view(bsz, -1).contiguous()
247
+ return data
248
+
249
+ # total number of sequences processed by all GPUs in each forward pass
250
+ global_batch_size = bsz_per_shard * num_shards
251
+
252
+ """
253
+ With a 16 item dataset, bsz_per_shard=2 and num_shards=3,
254
+ *indices* might look like:
255
+
256
+ indices = [[0, 1],
257
+ [2, 3],
258
+ [4, 5],
259
+ [6, 7],
260
+ [8, 9],
261
+ [10, 11]]
262
+
263
+ The size of the TruncatedBPTTDataset instance will be 2,
264
+ and shard 1 will see items:
265
+
266
+ [(0, [data[4], data[6]]),
267
+ (1, [data[5], data[7]])]
268
+ """
269
+ indices = batchify(torch.arange(len(data)), global_batch_size)
270
+ assert indices.size(0) == global_batch_size
271
+
272
+ self.my_indices = indices[
273
+ shard_id * bsz_per_shard : (shard_id + 1) * bsz_per_shard
274
+ ]
275
+ assert self.my_indices.size(0) == bsz_per_shard
276
+
277
+ def __len__(self):
278
+ return self.my_indices.size(1)
279
+
280
+ def __getitem__(self, i) -> Tuple[int, List[torch.Tensor]]:
281
+ return (i, [self.data[idx] for idx in self.my_indices[:, i]])
fairseq/examples/backtranslation/README.md ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Understanding Back-Translation at Scale (Edunov et al., 2018)
2
+
3
+ This page includes pre-trained models from the paper [Understanding Back-Translation at Scale (Edunov et al., 2018)](https://arxiv.org/abs/1808.09381).
4
+
5
+ ## Pre-trained models
6
+
7
+ Model | Description | Dataset | Download
8
+ ---|---|---|---
9
+ `transformer.wmt18.en-de` | Transformer <br> ([Edunov et al., 2018](https://arxiv.org/abs/1808.09381)) <br> WMT'18 winner | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz) <br> See NOTE in the archive
10
+
11
+ ## Example usage (torch.hub)
12
+
13
+ We require a few additional Python dependencies for preprocessing:
14
+ ```bash
15
+ pip install subword_nmt sacremoses
16
+ ```
17
+
18
+ Then to generate translations from the full model ensemble:
19
+ ```python
20
+ import torch
21
+
22
+ # List available models
23
+ torch.hub.list('pytorch/fairseq') # [..., 'transformer.wmt18.en-de', ... ]
24
+
25
+ # Load the WMT'18 En-De ensemble
26
+ en2de_ensemble = torch.hub.load(
27
+ 'pytorch/fairseq', 'transformer.wmt18.en-de',
28
+ checkpoint_file='wmt18.model1.pt:wmt18.model2.pt:wmt18.model3.pt:wmt18.model4.pt:wmt18.model5.pt',
29
+ tokenizer='moses', bpe='subword_nmt')
30
+
31
+ # The ensemble contains 5 models
32
+ len(en2de_ensemble.models)
33
+ # 5
34
+
35
+ # Translate
36
+ en2de_ensemble.translate('Hello world!')
37
+ # 'Hallo Welt!'
38
+ ```
39
+
40
+ ## Training your own model (WMT'18 English-German)
41
+
42
+ The following instructions can be adapted to reproduce the models from the paper.
43
+
44
+
45
+ #### Step 1. Prepare parallel data and optionally train a baseline (English-German) model
46
+
47
+ First download and preprocess the data:
48
+ ```bash
49
+ # Download and prepare the data
50
+ cd examples/backtranslation/
51
+ bash prepare-wmt18en2de.sh
52
+ cd ../..
53
+
54
+ # Binarize the data
55
+ TEXT=examples/backtranslation/wmt18_en_de
56
+ fairseq-preprocess \
57
+ --joined-dictionary \
58
+ --source-lang en --target-lang de \
59
+ --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
60
+ --destdir data-bin/wmt18_en_de --thresholdtgt 0 --thresholdsrc 0 \
61
+ --workers 20
62
+
63
+ # Copy the BPE code into the data-bin directory for future use
64
+ cp examples/backtranslation/wmt18_en_de/code data-bin/wmt18_en_de/code
65
+ ```
66
+
67
+ (Optionally) Train a baseline model (English-German) using just the parallel data:
68
+ ```bash
69
+ CHECKPOINT_DIR=checkpoints_en_de_parallel
70
+ fairseq-train --fp16 \
71
+ data-bin/wmt18_en_de \
72
+ --source-lang en --target-lang de \
73
+ --arch transformer_wmt_en_de_big --share-all-embeddings \
74
+ --dropout 0.3 --weight-decay 0.0 \
75
+ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
76
+ --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
77
+ --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
78
+ --max-tokens 3584 --update-freq 16 \
79
+ --max-update 30000 \
80
+ --save-dir $CHECKPOINT_DIR
81
+ # Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
82
+ # different number of GPUs.
83
+ ```
84
+
85
+ Average the last 10 checkpoints:
86
+ ```bash
87
+ python scripts/average_checkpoints.py \
88
+ --inputs $CHECKPOINT_DIR \
89
+ --num-epoch-checkpoints 10 \
90
+ --output $CHECKPOINT_DIR/checkpoint.avg10.pt
91
+ ```
92
+
93
+ Evaluate BLEU:
94
+ ```bash
95
+ # tokenized BLEU on newstest2017:
96
+ bash examples/backtranslation/tokenized_bleu.sh \
97
+ wmt17 \
98
+ en-de \
99
+ data-bin/wmt18_en_de \
100
+ data-bin/wmt18_en_de/code \
101
+ $CHECKPOINT_DIR/checkpoint.avg10.pt
102
+ # BLEU4 = 29.57, 60.9/35.4/22.9/15.5 (BP=1.000, ratio=1.014, syslen=63049, reflen=62152)
103
+ # compare to 29.46 in Table 1, which is also for tokenized BLEU
104
+
105
+ # generally it's better to report (detokenized) sacrebleu though:
106
+ bash examples/backtranslation/sacrebleu.sh \
107
+ wmt17 \
108
+ en-de \
109
+ data-bin/wmt18_en_de \
110
+ data-bin/wmt18_en_de/code \
111
+ $CHECKPOINT_DIR/checkpoint.avg10.pt
112
+ # BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 29.0 60.6/34.7/22.4/14.9 (BP = 1.000 ratio = 1.013 hyp_len = 62099 ref_len = 61287)
113
+ ```
114
+
115
+
116
+ #### Step 2. Back-translate monolingual German data
117
+
118
+ Train a reverse model (German-English) to do the back-translation:
119
+ ```bash
120
+ CHECKPOINT_DIR=checkpoints_de_en_parallel
121
+ fairseq-train --fp16 \
122
+ data-bin/wmt18_en_de \
123
+ --source-lang de --target-lang en \
124
+ --arch transformer_wmt_en_de_big --share-all-embeddings \
125
+ --dropout 0.3 --weight-decay 0.0 \
126
+ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
127
+ --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
128
+ --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
129
+ --max-tokens 3584 --update-freq 16 \
130
+ --max-update 30000 \
131
+ --save-dir $CHECKPOINT_DIR
132
+ # Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
133
+ # different number of GPUs.
134
+ ```
135
+
136
+ Let's evaluate the back-translation (BT) model to make sure it is well trained:
137
+ ```bash
138
+ bash examples/backtranslation/sacrebleu.sh \
139
+ wmt17 \
140
+ de-en \
141
+ data-bin/wmt18_en_de \
142
+ data-bin/wmt18_en_de/code \
143
+ $CHECKPOINT_DIR/checkpoint_best.py
144
+ # BLEU+case.mixed+lang.de-en+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 34.9 66.9/41.8/28.5/19.9 (BP = 0.983 ratio = 0.984 hyp_len = 63342 ref_len = 64399)
145
+ # compare to the best system from WMT'17 which scored 35.1: http://matrix.statmt.org/matrix/systems_list/1868
146
+ ```
147
+
148
+ Next prepare the monolingual data:
149
+ ```bash
150
+ # Download and prepare the monolingual data
151
+ # By default the script samples 25M monolingual sentences, which after
152
+ # deduplication should be just over 24M sentences. These are split into 25
153
+ # shards, each with 1M sentences (except for the last shard).
154
+ cd examples/backtranslation/
155
+ bash prepare-de-monolingual.sh
156
+ cd ../..
157
+
158
+ # Binarize each shard of the monolingual data
159
+ TEXT=examples/backtranslation/wmt18_de_mono
160
+ for SHARD in $(seq -f "%02g" 0 24); do \
161
+ fairseq-preprocess \
162
+ --only-source \
163
+ --source-lang de --target-lang en \
164
+ --joined-dictionary \
165
+ --srcdict data-bin/wmt18_en_de/dict.de.txt \
166
+ --testpref $TEXT/bpe.monolingual.dedup.${SHARD} \
167
+ --destdir data-bin/wmt18_de_mono/shard${SHARD} \
168
+ --workers 20; \
169
+ cp data-bin/wmt18_en_de/dict.en.txt data-bin/wmt18_de_mono/shard${SHARD}/; \
170
+ done
171
+ ```
172
+
173
+ Now we're ready to perform back-translation over the monolingual data. The
174
+ following command generates via sampling, but it's possible to use greedy
175
+ decoding (`--beam 1`), beam search (`--beam 5`),
176
+ top-k sampling (`--sampling --beam 1 --sampling-topk 10`), etc.:
177
+ ```bash
178
+ mkdir backtranslation_output
179
+ for SHARD in $(seq -f "%02g" 0 24); do \
180
+ fairseq-generate --fp16 \
181
+ data-bin/wmt18_de_mono/shard${SHARD} \
182
+ --path $CHECKPOINT_DIR/checkpoint_best.pt \
183
+ --skip-invalid-size-inputs-valid-test \
184
+ --max-tokens 4096 \
185
+ --sampling --beam 1 \
186
+ > backtranslation_output/sampling.shard${SHARD}.out; \
187
+ done
188
+ ```
189
+
190
+ After BT, use the `extract_bt_data.py` script to re-combine the shards, extract
191
+ the back-translations and apply length ratio filters:
192
+ ```bash
193
+ python examples/backtranslation/extract_bt_data.py \
194
+ --minlen 1 --maxlen 250 --ratio 1.5 \
195
+ --output backtranslation_output/bt_data --srclang en --tgtlang de \
196
+ backtranslation_output/sampling.shard*.out
197
+
198
+ # Ensure lengths are the same:
199
+ # wc -l backtranslation_output/bt_data.{en,de}
200
+ # 21795614 backtranslation_output/bt_data.en
201
+ # 21795614 backtranslation_output/bt_data.de
202
+ # 43591228 total
203
+ ```
204
+
205
+ Binarize the filtered BT data and combine it with the parallel data:
206
+ ```bash
207
+ TEXT=backtranslation_output
208
+ fairseq-preprocess \
209
+ --source-lang en --target-lang de \
210
+ --joined-dictionary \
211
+ --srcdict data-bin/wmt18_en_de/dict.en.txt \
212
+ --trainpref $TEXT/bt_data \
213
+ --destdir data-bin/wmt18_en_de_bt \
214
+ --workers 20
215
+
216
+ # We want to train on the combined data, so we'll symlink the parallel + BT data
217
+ # in the wmt18_en_de_para_plus_bt directory. We link the parallel data as "train"
218
+ # and the BT data as "train1", so that fairseq will combine them automatically
219
+ # and so that we can use the `--upsample-primary` option to upsample the
220
+ # parallel data (if desired).
221
+ PARA_DATA=$(readlink -f data-bin/wmt18_en_de)
222
+ BT_DATA=$(readlink -f data-bin/wmt18_en_de_bt)
223
+ COMB_DATA=data-bin/wmt18_en_de_para_plus_bt
224
+ mkdir -p $COMB_DATA
225
+ for LANG in en de; do \
226
+ ln -s ${PARA_DATA}/dict.$LANG.txt ${COMB_DATA}/dict.$LANG.txt; \
227
+ for EXT in bin idx; do \
228
+ ln -s ${PARA_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train.en-de.$LANG.$EXT; \
229
+ ln -s ${BT_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train1.en-de.$LANG.$EXT; \
230
+ ln -s ${PARA_DATA}/valid.en-de.$LANG.$EXT ${COMB_DATA}/valid.en-de.$LANG.$EXT; \
231
+ ln -s ${PARA_DATA}/test.en-de.$LANG.$EXT ${COMB_DATA}/test.en-de.$LANG.$EXT; \
232
+ done; \
233
+ done
234
+ ```
235
+
236
+
237
+ #### 3. Train an English-German model over the combined parallel + BT data
238
+
239
+ Finally we can train a model over the parallel + BT data:
240
+ ```bash
241
+ CHECKPOINT_DIR=checkpoints_en_de_parallel_plus_bt
242
+ fairseq-train --fp16 \
243
+ data-bin/wmt18_en_de_para_plus_bt \
244
+ --upsample-primary 16 \
245
+ --source-lang en --target-lang de \
246
+ --arch transformer_wmt_en_de_big --share-all-embeddings \
247
+ --dropout 0.3 --weight-decay 0.0 \
248
+ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
249
+ --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
250
+ --lr 0.0007 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
251
+ --max-tokens 3584 --update-freq 16 \
252
+ --max-update 100000 \
253
+ --save-dir $CHECKPOINT_DIR
254
+ # Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
255
+ # different number of GPUs.
256
+ ```
257
+
258
+ Average the last 10 checkpoints:
259
+ ```bash
260
+ python scripts/average_checkpoints.py \
261
+ --inputs $CHECKPOINT_DIR \
262
+ --num-epoch-checkpoints 10 \
263
+ --output $CHECKPOINT_DIR/checkpoint.avg10.pt
264
+ ```
265
+
266
+ Evaluate BLEU:
267
+ ```bash
268
+ # tokenized BLEU on newstest2017:
269
+ bash examples/backtranslation/tokenized_bleu.sh \
270
+ wmt17 \
271
+ en-de \
272
+ data-bin/wmt18_en_de \
273
+ data-bin/wmt18_en_de/code \
274
+ $CHECKPOINT_DIR/checkpoint.avg10.pt
275
+ # BLEU4 = 32.35, 64.4/38.9/26.2/18.3 (BP=0.977, ratio=0.977, syslen=60729, reflen=62152)
276
+ # compare to 32.35 in Table 1, which is also for tokenized BLEU
277
+
278
+ # generally it's better to report (detokenized) sacrebleu:
279
+ bash examples/backtranslation/sacrebleu.sh \
280
+ wmt17 \
281
+ en-de \
282
+ data-bin/wmt18_en_de \
283
+ data-bin/wmt18_en_de/code \
284
+ $CHECKPOINT_DIR/checkpoint.avg10.pt
285
+ # BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 31.5 64.3/38.2/25.6/17.6 (BP = 0.971 ratio = 0.971 hyp_len = 59515 ref_len = 61287)
286
+ ```
287
+
288
+
289
+ ## Citation
290
+ ```bibtex
291
+ @inproceedings{edunov2018backtranslation,
292
+ title = {Understanding Back-Translation at Scale},
293
+ author = {Edunov, Sergey and Ott, Myle and Auli, Michael and Grangier, David},
294
+ booktitle = {Conference of the Association for Computational Linguistics (ACL)},
295
+ year = 2018,
296
+ }
297
+ ```
fairseq/examples/backtranslation/deduplicate_lines.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import fileinput
9
+ import hashlib
10
+ import sys
11
+ from multiprocessing import Pool
12
+
13
+
14
+ def get_hashes_and_lines(raw_line):
15
+ hash = hashlib.md5(raw_line).hexdigest()
16
+ return hash, raw_line
17
+
18
+
19
+ def main():
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument("--workers", type=int, default=10)
22
+ parser.add_argument("files", nargs="*", help="input files")
23
+ args = parser.parse_args()
24
+
25
+ seen = set()
26
+ with fileinput.input(args.files, mode="rb") as h:
27
+ pool = Pool(args.workers)
28
+ results = pool.imap_unordered(get_hashes_and_lines, h, 1000)
29
+ for i, (hash, raw_line) in enumerate(results):
30
+ if hash not in seen:
31
+ seen.add(hash)
32
+ sys.stdout.buffer.write(raw_line)
33
+ if i % 1000000 == 0:
34
+ print(i, file=sys.stderr, end="", flush=True)
35
+ elif i % 100000 == 0:
36
+ print(".", file=sys.stderr, end="", flush=True)
37
+ print(file=sys.stderr, flush=True)
38
+
39
+
40
+ if __name__ == "__main__":
41
+ main()
fairseq/examples/backtranslation/extract_bt_data.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import fileinput
9
+
10
+ from tqdm import tqdm
11
+
12
+
13
+ def main():
14
+ parser = argparse.ArgumentParser(
15
+ description=(
16
+ "Extract back-translations from the stdout of fairseq-generate. "
17
+ "If there are multiply hypotheses for a source, we only keep the first one. "
18
+ )
19
+ )
20
+ parser.add_argument("--output", required=True, help="output prefix")
21
+ parser.add_argument(
22
+ "--srclang", required=True, help="source language (extracted from H-* lines)"
23
+ )
24
+ parser.add_argument(
25
+ "--tgtlang", required=True, help="target language (extracted from S-* lines)"
26
+ )
27
+ parser.add_argument("--minlen", type=int, help="min length filter")
28
+ parser.add_argument("--maxlen", type=int, help="max length filter")
29
+ parser.add_argument("--ratio", type=float, help="ratio filter")
30
+ parser.add_argument("files", nargs="*", help="input files")
31
+ args = parser.parse_args()
32
+
33
+ def validate(src, tgt):
34
+ srclen = len(src.split(" ")) if src != "" else 0
35
+ tgtlen = len(tgt.split(" ")) if tgt != "" else 0
36
+ if (
37
+ (args.minlen is not None and (srclen < args.minlen or tgtlen < args.minlen))
38
+ or (
39
+ args.maxlen is not None
40
+ and (srclen > args.maxlen or tgtlen > args.maxlen)
41
+ )
42
+ or (
43
+ args.ratio is not None
44
+ and (max(srclen, tgtlen) / float(min(srclen, tgtlen)) > args.ratio)
45
+ )
46
+ ):
47
+ return False
48
+ return True
49
+
50
+ def safe_index(toks, index, default):
51
+ try:
52
+ return toks[index]
53
+ except IndexError:
54
+ return default
55
+
56
+ with open(args.output + "." + args.srclang, "w") as src_h, open(
57
+ args.output + "." + args.tgtlang, "w"
58
+ ) as tgt_h:
59
+ for line in tqdm(fileinput.input(args.files)):
60
+ if line.startswith("S-"):
61
+ tgt = safe_index(line.rstrip().split("\t"), 1, "")
62
+ elif line.startswith("H-"):
63
+ if tgt is not None:
64
+ src = safe_index(line.rstrip().split("\t"), 2, "")
65
+ if validate(src, tgt):
66
+ print(src, file=src_h)
67
+ print(tgt, file=tgt_h)
68
+ tgt = None
69
+
70
+
71
+ if __name__ == "__main__":
72
+ main()
fairseq/examples/backtranslation/prepare-de-monolingual.sh ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ SCRIPTS=mosesdecoder/scripts
4
+ TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
5
+ NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
6
+ REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
7
+ BPEROOT=subword-nmt/subword_nmt
8
+
9
+
10
+ BPE_CODE=wmt18_en_de/code
11
+ SUBSAMPLE_SIZE=25000000
12
+ LANG=de
13
+
14
+
15
+ OUTDIR=wmt18_${LANG}_mono
16
+ orig=orig
17
+ tmp=$OUTDIR/tmp
18
+ mkdir -p $OUTDIR $tmp
19
+
20
+
21
+ URLS=(
22
+ "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2007.de.shuffled.gz"
23
+ "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2008.de.shuffled.gz"
24
+ "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2009.de.shuffled.gz"
25
+ "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2010.de.shuffled.gz"
26
+ "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2011.de.shuffled.gz"
27
+ "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2012.de.shuffled.gz"
28
+ "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2013.de.shuffled.gz"
29
+ "http://www.statmt.org/wmt15/training-monolingual-news-crawl-v2/news.2014.de.shuffled.v2.gz"
30
+ "http://data.statmt.org/wmt16/translation-task/news.2015.de.shuffled.gz"
31
+ "http://data.statmt.org/wmt17/translation-task/news.2016.de.shuffled.gz"
32
+ "http://data.statmt.org/wmt18/translation-task/news.2017.de.shuffled.deduped.gz"
33
+ )
34
+ FILES=(
35
+ "news.2007.de.shuffled.gz"
36
+ "news.2008.de.shuffled.gz"
37
+ "news.2009.de.shuffled.gz"
38
+ "news.2010.de.shuffled.gz"
39
+ "news.2011.de.shuffled.gz"
40
+ "news.2012.de.shuffled.gz"
41
+ "news.2013.de.shuffled.gz"
42
+ "news.2014.de.shuffled.v2.gz"
43
+ "news.2015.de.shuffled.gz"
44
+ "news.2016.de.shuffled.gz"
45
+ "news.2017.de.shuffled.deduped.gz"
46
+ )
47
+
48
+
49
+ cd $orig
50
+ for ((i=0;i<${#URLS[@]};++i)); do
51
+ file=${FILES[i]}
52
+ if [ -f $file ]; then
53
+ echo "$file already exists, skipping download"
54
+ else
55
+ url=${URLS[i]}
56
+ wget "$url"
57
+ fi
58
+ done
59
+ cd ..
60
+
61
+
62
+ if [ -f $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG} ]; then
63
+ echo "found monolingual sample, skipping shuffle/sample/tokenize"
64
+ else
65
+ gzip -c -d -k $(for FILE in "${FILES[@]}"; do echo $orig/$FILE; done) \
66
+ | shuf -n $SUBSAMPLE_SIZE \
67
+ | perl $NORM_PUNC $LANG \
68
+ | perl $REM_NON_PRINT_CHAR \
69
+ | perl $TOKENIZER -threads 8 -a -l $LANG \
70
+ > $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG}
71
+ fi
72
+
73
+
74
+ if [ -f $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG} ]; then
75
+ echo "found BPE monolingual sample, skipping BPE step"
76
+ else
77
+ python $BPEROOT/apply_bpe.py -c $BPE_CODE \
78
+ < $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG} \
79
+ > $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG}
80
+ fi
81
+
82
+
83
+ if [ -f $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG} ]; then
84
+ echo "found deduplicated monolingual sample, skipping deduplication step"
85
+ else
86
+ python deduplicate_lines.py $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG} \
87
+ > $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG}
88
+ fi
89
+
90
+
91
+ if [ -f $OUTDIR/bpe.monolingual.dedup.00.de ]; then
92
+ echo "found sharded data, skipping sharding step"
93
+ else
94
+ split --lines 1000000 --numeric-suffixes \
95
+ --additional-suffix .${LANG} \
96
+ $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG} \
97
+ $OUTDIR/bpe.monolingual.dedup.
98
+ fi
fairseq/examples/backtranslation/prepare-wmt18en2de.sh ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh
3
+
4
+ echo 'Cloning Moses github repository (for tokenization scripts)...'
5
+ git clone https://github.com/moses-smt/mosesdecoder.git
6
+
7
+ echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
8
+ git clone https://github.com/rsennrich/subword-nmt.git
9
+
10
+ SCRIPTS=mosesdecoder/scripts
11
+ TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
12
+ CLEAN=$SCRIPTS/training/clean-corpus-n.perl
13
+ NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
14
+ REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
15
+ BPEROOT=subword-nmt/subword_nmt
16
+ BPE_TOKENS=32000
17
+
18
+ URLS=(
19
+ "http://statmt.org/wmt13/training-parallel-europarl-v7.tgz"
20
+ "http://statmt.org/wmt13/training-parallel-commoncrawl.tgz"
21
+ "http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz"
22
+ "http://data.statmt.org/wmt18/translation-task/rapid2016.tgz"
23
+ "http://data.statmt.org/wmt17/translation-task/dev.tgz"
24
+ "http://statmt.org/wmt14/test-full.tgz"
25
+ )
26
+ FILES=(
27
+ "training-parallel-europarl-v7.tgz"
28
+ "training-parallel-commoncrawl.tgz"
29
+ "training-parallel-nc-v13.tgz"
30
+ "rapid2016.tgz"
31
+ "dev.tgz"
32
+ "test-full.tgz"
33
+ )
34
+ CORPORA=(
35
+ "training/europarl-v7.de-en"
36
+ "commoncrawl.de-en"
37
+ "training-parallel-nc-v13/news-commentary-v13.de-en"
38
+ "rapid2016.de-en"
39
+ )
40
+
41
+ if [ ! -d "$SCRIPTS" ]; then
42
+ echo "Please set SCRIPTS variable correctly to point to Moses scripts."
43
+ exit 1
44
+ fi
45
+
46
+ OUTDIR=wmt18_en_de
47
+
48
+ src=en
49
+ tgt=de
50
+ lang=en-de
51
+ prep=$OUTDIR
52
+ tmp=$prep/tmp
53
+ orig=orig
54
+
55
+ mkdir -p $orig $tmp $prep
56
+
57
+ cd $orig
58
+
59
+ for ((i=0;i<${#URLS[@]};++i)); do
60
+ file=${FILES[i]}
61
+ if [ -f $file ]; then
62
+ echo "$file already exists, skipping download"
63
+ else
64
+ url=${URLS[i]}
65
+ wget "$url"
66
+ if [ -f $file ]; then
67
+ echo "$url successfully downloaded."
68
+ else
69
+ echo "$url not successfully downloaded."
70
+ exit 1
71
+ fi
72
+ if [ ${file: -4} == ".tgz" ]; then
73
+ tar zxvf $file
74
+ elif [ ${file: -4} == ".tar" ]; then
75
+ tar xvf $file
76
+ fi
77
+ fi
78
+ done
79
+ cd ..
80
+
81
+ echo "pre-processing train data..."
82
+ for l in $src $tgt; do
83
+ rm $tmp/train.tags.$lang.tok.$l
84
+ for f in "${CORPORA[@]}"; do
85
+ cat $orig/$f.$l | \
86
+ perl $NORM_PUNC $l | \
87
+ perl $REM_NON_PRINT_CHAR | \
88
+ perl $TOKENIZER -threads 8 -a -l $l >> $tmp/train.tags.$lang.tok.$l
89
+ done
90
+ done
91
+
92
+ echo "pre-processing test data..."
93
+ for l in $src $tgt; do
94
+ if [ "$l" == "$src" ]; then
95
+ t="src"
96
+ else
97
+ t="ref"
98
+ fi
99
+ grep '<seg id' $orig/test-full/newstest2014-deen-$t.$l.sgm | \
100
+ sed -e 's/<seg id="[0-9]*">\s*//g' | \
101
+ sed -e 's/\s*<\/seg>\s*//g' | \
102
+ sed -e "s/\’/\'/g" | \
103
+ perl $TOKENIZER -threads 8 -a -l $l > $tmp/test.$l
104
+ echo ""
105
+ done
106
+
107
+ echo "splitting train and valid..."
108
+ for l in $src $tgt; do
109
+ awk '{if (NR%100 == 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l
110
+ awk '{if (NR%100 != 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l
111
+ done
112
+
113
+ TRAIN=$tmp/train.de-en
114
+ BPE_CODE=$prep/code
115
+ rm -f $TRAIN
116
+ for l in $src $tgt; do
117
+ cat $tmp/train.$l >> $TRAIN
118
+ done
119
+
120
+ echo "learn_bpe.py on ${TRAIN}..."
121
+ python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE
122
+
123
+ for L in $src $tgt; do
124
+ for f in train.$L valid.$L test.$L; do
125
+ echo "apply_bpe.py to ${f}..."
126
+ python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $tmp/bpe.$f
127
+ done
128
+ done
129
+
130
+ perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250
131
+ perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250
132
+
133
+ for L in $src $tgt; do
134
+ cp $tmp/bpe.test.$L $prep/test.$L
135
+ done
fairseq/examples/backtranslation/sacrebleu.sh ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ if [ $# -ne 5 ]; then
4
+ echo "usage: $0 [dataset=wmt14/full] [langpair=en-de] [databin] [bpecode] [model]"
5
+ exit
6
+ fi
7
+
8
+
9
+ DATASET=$1
10
+ LANGPAIR=$2
11
+ DATABIN=$3
12
+ BPECODE=$4
13
+ MODEL=$5
14
+
15
+ SRCLANG=$(echo $LANGPAIR | cut -d '-' -f 1)
16
+ TGTLANG=$(echo $LANGPAIR | cut -d '-' -f 2)
17
+
18
+
19
+ BPEROOT=examples/backtranslation/subword-nmt/subword_nmt
20
+ if [ ! -e $BPEROOT ]; then
21
+ BPEROOT=subword-nmt/subword_nmt
22
+ if [ ! -e $BPEROOT ]; then
23
+ echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
24
+ git clone https://github.com/rsennrich/subword-nmt.git
25
+ fi
26
+ fi
27
+
28
+
29
+ sacrebleu -t $DATASET -l $LANGPAIR --echo src \
30
+ | sacremoses tokenize -a -l $SRCLANG -q \
31
+ | python $BPEROOT/apply_bpe.py -c $BPECODE \
32
+ | fairseq-interactive $DATABIN --path $MODEL \
33
+ -s $SRCLANG -t $TGTLANG \
34
+ --beam 5 --remove-bpe --buffer-size 1024 --max-tokens 8000 \
35
+ | grep ^H- | cut -f 3- \
36
+ | sacremoses detokenize -l $TGTLANG -q \
37
+ | sacrebleu -t $DATASET -l $LANGPAIR
fairseq/examples/backtranslation/tokenized_bleu.sh ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ if [ $# -ne 5 ]; then
4
+ echo "usage: $0 [dataset=wmt14/full] [langpair=en-de] [databin] [bpecode] [model]"
5
+ exit
6
+ fi
7
+
8
+
9
+ DATASET=$1
10
+ LANGPAIR=$2
11
+ DATABIN=$3
12
+ BPECODE=$4
13
+ MODEL=$5
14
+
15
+ SRCLANG=$(echo $LANGPAIR | cut -d '-' -f 1)
16
+ TGTLANG=$(echo $LANGPAIR | cut -d '-' -f 2)
17
+
18
+
19
+ BPEROOT=examples/backtranslation/subword-nmt/subword_nmt
20
+ if [ ! -e $BPEROOT ]; then
21
+ BPEROOT=subword-nmt/subword_nmt
22
+ if [ ! -e $BPEROOT ]; then
23
+ echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
24
+ git clone https://github.com/rsennrich/subword-nmt.git
25
+ fi
26
+ fi
27
+
28
+
29
+ TMP_REF=$(mktemp)
30
+
31
+ sacrebleu -t $DATASET -l $LANGPAIR --echo ref -q \
32
+ | sacremoses normalize -l $TGTLANG -q \
33
+ | sacremoses tokenize -a -l $TGTLANG -q \
34
+ > $TMP_REF
35
+
36
+ sacrebleu -t $DATASET -l $LANGPAIR --echo src -q \
37
+ | sacremoses normalize -l $SRCLANG -q \
38
+ | sacremoses tokenize -a -l $SRCLANG -q \
39
+ | python $BPEROOT/apply_bpe.py -c $BPECODE \
40
+ | fairseq-interactive $DATABIN --path $MODEL \
41
+ -s $SRCLANG -t $TGTLANG \
42
+ --beam 5 --remove-bpe --buffer-size 1024 --max-tokens 8000 \
43
+ | grep ^H- | cut -f 3- \
44
+ | fairseq-score --ref $TMP_REF
45
+
46
+ rm -f $TMP_REF
fairseq/examples/bart/README.glue.md ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Fine-tuning BART on GLUE tasks
2
+
3
+ ### 1) Download the data from GLUE website (https://gluebenchmark.com/tasks) using following commands:
4
+ ```bash
5
+ wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py
6
+ python download_glue_data.py --data_dir glue_data --tasks all
7
+ ```
8
+
9
+ ### 2) Preprocess GLUE task data (same as RoBERTa):
10
+ ```bash
11
+ ./examples/roberta/preprocess_GLUE_tasks.sh glue_data <glue_task_name>
12
+ ```
13
+ `glue_task_name` is one of the following:
14
+ `{ALL, QQP, MNLI, QNLI, MRPC, RTE, STS-B, SST-2, CoLA}`
15
+ Use `ALL` for preprocessing all the glue tasks.
16
+
17
+ ### 3) Fine-tuning on GLUE task:
18
+ Example fine-tuning cmd for `RTE` task
19
+ ```bash
20
+ TOTAL_NUM_UPDATES=2036 # 10 epochs through RTE for bsz 16
21
+ WARMUP_UPDATES=61 # 6 percent of the number of updates
22
+ LR=1e-05 # Peak LR for polynomial LR scheduler.
23
+ NUM_CLASSES=2
24
+ MAX_SENTENCES=16 # Batch size.
25
+ BART_PATH=/path/to/bart/model.pt
26
+
27
+ CUDA_VISIBLE_DEVICES=0,1 fairseq-train RTE-bin/ \
28
+ --restore-file $BART_PATH \
29
+ --batch-size $MAX_SENTENCES \
30
+ --max-tokens 4400 \
31
+ --task sentence_prediction \
32
+ --add-prev-output-tokens \
33
+ --layernorm-embedding \
34
+ --share-all-embeddings \
35
+ --share-decoder-input-output-embed \
36
+ --reset-optimizer --reset-dataloader --reset-meters \
37
+ --required-batch-size-multiple 1 \
38
+ --init-token 0 \
39
+ --arch bart_large \
40
+ --criterion sentence_prediction \
41
+ --num-classes $NUM_CLASSES \
42
+ --dropout 0.1 --attention-dropout 0.1 \
43
+ --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 \
44
+ --clip-norm 0.0 \
45
+ --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
46
+ --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
47
+ --max-epoch 10 \
48
+ --find-unused-parameters \
49
+ --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric;
50
+ ```
51
+
52
+ For each of the GLUE task, you will need to use following cmd-line arguments:
53
+
54
+ Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
55
+ ---|---|---|---|---|---|---|---|---
56
+ `--num-classes` | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 1
57
+ `--lr` | 5e-6 | 1e-5 | 1e-5 | 1e-5 | 5e-6 | 2e-5 | 2e-5 | 2e-5
58
+ `bsz` | 128 | 32 | 32 | 32 | 128 | 64 | 64 | 32
59
+ `--total-num-update` | 30968 | 33112 | 113272 | 1018 | 5233 | 1148 | 1334 | 1799
60
+ `--warmup-updates` | 1858 | 1986 | 6796 | 61 | 314 | 68 | 80 | 107
61
+
62
+ For `STS-B` additionally add `--regression-target --best-checkpoint-metric loss` and remove `--maximize-best-checkpoint-metric`.
63
+
64
+ **Note:**
65
+
66
+ a) `--total-num-updates` is used by `--polynomial_decay` scheduler and is calculated for `--max-epoch=10` and `--batch-size=32/64/128` depending on the task.
67
+
68
+ b) Above cmd-args and hyperparams are tested on Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory resources available to you, you can use increase `--update-freq` and reduce `--batch-size`.
69
+
70
+ ### Inference on GLUE task
71
+ After training the model as mentioned in previous step, you can perform inference with checkpoints in `checkpoints/` directory using following python code snippet:
72
+
73
+ ```python
74
+ from fairseq.models.bart import BARTModel
75
+
76
+ bart = BARTModel.from_pretrained(
77
+ 'checkpoints/',
78
+ checkpoint_file='checkpoint_best.pt',
79
+ data_name_or_path='RTE-bin'
80
+ )
81
+
82
+ label_fn = lambda label: bart.task.label_dictionary.string(
83
+ [label + bart.task.label_dictionary.nspecial]
84
+ )
85
+ ncorrect, nsamples = 0, 0
86
+ bart.cuda()
87
+ bart.eval()
88
+ with open('glue_data/RTE/dev.tsv') as fin:
89
+ fin.readline()
90
+ for index, line in enumerate(fin):
91
+ tokens = line.strip().split('\t')
92
+ sent1, sent2, target = tokens[1], tokens[2], tokens[3]
93
+ tokens = bart.encode(sent1, sent2)
94
+ prediction = bart.predict('sentence_classification_head', tokens).argmax().item()
95
+ prediction_label = label_fn(prediction)
96
+ ncorrect += int(prediction_label == target)
97
+ nsamples += 1
98
+ print('| Accuracy: ', float(ncorrect)/float(nsamples))
99
+ ```
fairseq/examples/bart/README.md ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension
2
+
3
+ [https://arxiv.org/abs/1910.13461](https://arxiv.org/abs/1910.13461)
4
+
5
+ ## Introduction
6
+
7
+ BART is sequence-to-sequence model trained with denoising as pretraining objective. We show that this pretraining objective is more generic and show that we can match [RoBERTa](../roberta) results on SQuAD and GLUE and gain state-of-the-art results on summarization (XSum, CNN dataset), long form generative question answering (ELI5) and dialog response genration (ConvAI2). See the associated paper for more details.
8
+
9
+ ## Pre-trained models
10
+
11
+ Model | Description | # params | Download
12
+ ---|---|---|---
13
+ `bart.base` | BART model with 6 encoder and decoder layers | 140M | [bart.base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz)
14
+ `bart.large` | BART model with 12 encoder and decoder layers | 400M | [bart.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz)
15
+ `bart.large.mnli` | `bart.large` finetuned on `MNLI` | 400M | [bart.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz)
16
+ `bart.large.cnn` | `bart.large` finetuned on `CNN-DM` | 400M | [bart.large.cnn.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz)
17
+ `bart.large.xsum` | `bart.large` finetuned on `Xsum` | 400M | [bart.large.xsum.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.xsum.tar.gz)
18
+
19
+ ## Results
20
+
21
+ **[GLUE (Wang et al., 2019)](https://gluebenchmark.com/)**
22
+ _(dev set, single model, single-task finetuning)_
23
+
24
+ Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
25
+ ---|---|---|---|---|---|---|---|---
26
+ `roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
27
+ `bart.large` | 89.9 | 94.9 | 92.5 | 87.0 | 96.6 | 90.4 | 62.8 | 91.2
28
+
29
+ **[SQuAD (Rajpurkar et al., 2018)](https://rajpurkar.github.io/SQuAD-explorer/)**
30
+ _(dev set, no additional data used)_
31
+
32
+ Model | SQuAD 1.1 EM/F1 | SQuAD 2.0 EM/F1
33
+ ---|---|---
34
+ `roberta.large` | 88.9/94.6 | 86.5/89.4
35
+ `bart.large` | 88.8/94.6 | 86.1/89.2
36
+
37
+ **[CNN/Daily Mail](http://nlpprogress.com/english/summarization.html)**
38
+ _(test set, no additional data used)_
39
+
40
+ Model | R1 | R2 | RL
41
+ ---|---|---|---
42
+ `BERTSUMEXTABS` | 42.13 | 19.60 | 39.18
43
+ `bart.large` | 44.16 | 21.28 | 40.90
44
+
45
+ ## Example usage
46
+
47
+ ##### Load BART from torch.hub (PyTorch >= 1.1):
48
+ ```python
49
+ import torch
50
+ bart = torch.hub.load('pytorch/fairseq', 'bart.large')
51
+ bart.eval() # disable dropout (or leave in train mode to finetune)
52
+ ```
53
+
54
+ ##### Load BART (for PyTorch 1.0 or custom models):
55
+ ```python
56
+ # Download bart.large model
57
+ wget https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz
58
+ tar -xzvf bart.large.tar.gz
59
+
60
+ # Load the model in fairseq
61
+ from fairseq.models.bart import BARTModel
62
+ bart = BARTModel.from_pretrained('/path/to/bart.large', checkpoint_file='model.pt')
63
+ bart.eval() # disable dropout (or leave in train mode to finetune)
64
+ ```
65
+
66
+ ##### Apply Byte-Pair Encoding (BPE) to input text:
67
+ ```python
68
+ tokens = bart.encode('Hello world!')
69
+ assert tokens.tolist() == [0, 31414, 232, 328, 2]
70
+ bart.decode(tokens) # 'Hello world!'
71
+ ```
72
+
73
+ ##### Extract features from BART:
74
+ ```python
75
+ # Extract the last layer's features
76
+ last_layer_features = bart.extract_features(tokens)
77
+ assert last_layer_features.size() == torch.Size([1, 5, 1024])
78
+
79
+ # Extract all layer's features from decoder (layer 0 is the embedding layer)
80
+ all_layers = bart.extract_features(tokens, return_all_hiddens=True)
81
+ assert len(all_layers) == 13
82
+ assert torch.all(all_layers[-1] == last_layer_features)
83
+ ```
84
+
85
+ ##### Use BART for sentence-pair classification tasks:
86
+ ```python
87
+ # Download BART already finetuned for MNLI
88
+ bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
89
+ bart.eval() # disable dropout for evaluation
90
+
91
+ # Encode a pair of sentences and make a prediction
92
+ tokens = bart.encode('BART is a seq2seq model.', 'BART is not sequence to sequence.')
93
+ bart.predict('mnli', tokens).argmax() # 0: contradiction
94
+
95
+ # Encode another pair of sentences
96
+ tokens = bart.encode('BART is denoising autoencoder.', 'BART is version of autoencoder.')
97
+ bart.predict('mnli', tokens).argmax() # 2: entailment
98
+ ```
99
+
100
+ ##### Register a new (randomly initialized) classification head:
101
+ ```python
102
+ bart.register_classification_head('new_task', num_classes=3)
103
+ logprobs = bart.predict('new_task', tokens)
104
+ ```
105
+
106
+ ##### Batched prediction:
107
+ ```python
108
+ import torch
109
+ from fairseq.data.data_utils import collate_tokens
110
+
111
+ bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
112
+ bart.eval()
113
+
114
+ batch_of_pairs = [
115
+ ['BART is a seq2seq model.', 'BART is not sequence to sequence.'],
116
+ ['BART is denoising autoencoder.', 'BART is version of autoencoder.'],
117
+ ]
118
+
119
+ batch = collate_tokens(
120
+ [bart.encode(pair[0], pair[1]) for pair in batch_of_pairs], pad_idx=1
121
+ )
122
+
123
+ logprobs = bart.predict('mnli', batch)
124
+ print(logprobs.argmax(dim=1))
125
+ # tensor([0, 2])
126
+ ```
127
+
128
+ ##### Using the GPU:
129
+ ```python
130
+ bart.cuda()
131
+ bart.predict('new_task', tokens)
132
+ ```
133
+
134
+ #### Filling masks:
135
+
136
+ BART can be used to fill multiple `<mask>` tokens in the input.
137
+ ```python
138
+ bart = torch.hub.load('pytorch/fairseq', 'bart.base')
139
+ bart.eval()
140
+ bart.fill_mask(['The cat <mask> on the <mask>.'], topk=3, beam=10)
141
+ # [[('The cat was on the ground.', tensor(-0.6183)), ('The cat was on the floor.', tensor(-0.6798)), ('The cat sleeps on the couch.', tensor(-0.6830))]]
142
+ ```
143
+
144
+ Note that by default we enforce the output length to match the input length.
145
+ This can be disabled by setting ``match_source_len=False``:
146
+ ```
147
+ bart.fill_mask(['The cat <mask> on the <mask>.'], topk=3, beam=10, match_source_len=False)
148
+ # [[('The cat was on the ground.', tensor(-0.6185)), ('The cat was asleep on the couch.', tensor(-0.6276)), ('The cat was on the floor.', tensor(-0.6800))]]
149
+ ```
150
+
151
+ Example code to fill masks for a batch of sentences using GPU
152
+ ```
153
+ bart.cuda()
154
+ bart.fill_mask(['The cat <mask> on the <mask>.', 'The dog <mask> on the <mask>.'], topk=3, beam=10)
155
+ # [[('The cat was on the ground.', tensor(-0.6183)), ('The cat was on the floor.', tensor(-0.6798)), ('The cat sleeps on the couch.', tensor(-0.6830))], [('The dog was on the ground.', tensor(-0.6190)), ('The dog lay on the ground.', tensor(-0.6711)),
156
+ ('The dog was asleep on the couch', tensor(-0.6796))]]
157
+ ```
158
+
159
+ #### Evaluating the `bart.large.mnli` model:
160
+
161
+ Example python code snippet to evaluate accuracy on the MNLI `dev_matched` set.
162
+ ```python
163
+ label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
164
+ ncorrect, nsamples = 0, 0
165
+ bart.cuda()
166
+ bart.eval()
167
+ with open('glue_data/MNLI/dev_matched.tsv') as fin:
168
+ fin.readline()
169
+ for index, line in enumerate(fin):
170
+ tokens = line.strip().split('\t')
171
+ sent1, sent2, target = tokens[8], tokens[9], tokens[-1]
172
+ tokens = bart.encode(sent1, sent2)
173
+ prediction = bart.predict('mnli', tokens).argmax().item()
174
+ prediction_label = label_map[prediction]
175
+ ncorrect += int(prediction_label == target)
176
+ nsamples += 1
177
+ print('| Accuracy: ', float(ncorrect)/float(nsamples))
178
+ # Expected output: 0.9010
179
+ ```
180
+
181
+ #### Evaluating the `bart.large.cnn` model:
182
+ - Follow instructions [here](https://github.com/abisee/cnn-dailymail) to download and process into data-files such that `test.source` and `test.target` has one line for each non-tokenized sample.
183
+ - For simpler preprocessing, you can also `wget https://cdn-datasets.huggingface.co/summarization/cnn_dm_v2.tgz`, although there is no guarantee of identical scores
184
+ - `huggingface/transformers` has a simpler interface that supports [single-gpu](https://github.com/huggingface/transformers/blob/master/examples/legacy/seq2seq/run_eval.py) and [multi-gpu](https://github.com/huggingface/transformers/blob/master/examples/legacy/seq2seq/run_distributed_eval.py) beam search.
185
+ In `huggingface/transformers`, the BART models' paths are `facebook/bart-large-cnn` and `facebook/bart-large-xsum`.
186
+
187
+ In `fairseq`, summaries can be generated using:
188
+
189
+ ```bash
190
+ cp data-bin/cnn_dm/dict.source.txt checkpoints/
191
+ python examples/bart/summarize.py \
192
+ --model-dir pytorch/fairseq \
193
+ --model-file bart.large.cnn \
194
+ --src cnn_dm/test.source \
195
+ --out cnn_dm/test.hypo
196
+ ```
197
+
198
+ For calculating rouge, install `files2rouge` from [here](https://github.com/pltrdy/files2rouge).
199
+
200
+ ```bash
201
+ export CLASSPATH=/path/to/stanford-corenlp-full-2016-10-31/stanford-corenlp-3.7.0.jar
202
+
203
+ # Tokenize hypothesis and target files.
204
+ cat test.hypo | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > test.hypo.tokenized
205
+ cat test.target | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > test.hypo.target
206
+ files2rouge test.hypo.tokenized test.hypo.target
207
+ # Expected output: (ROUGE-2 Average_F: 0.21238)
208
+ ```
209
+
210
+
211
+ ## Finetuning
212
+
213
+ - [Finetuning on GLUE](README.glue.md)
214
+ - [Finetuning on CNN-DM](README.summarization.md)
215
+
216
+ ## Citation
217
+
218
+ ```bibtex
219
+ @article{lewis2019bart,
220
+ title = {BART: Denoising Sequence-to-Sequence Pre-training for Natural
221
+ Language Generation, Translation, and Comprehension},
222
+ author = {Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and
223
+ Abdelrahman Mohamed and Omer Levy and Veselin Stoyanov
224
+ and Luke Zettlemoyer },
225
+ journal={arXiv preprint arXiv:1910.13461},
226
+ year = {2019},
227
+ }
228
+ ```
fairseq/examples/bart/README.summarization.md ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Fine-tuning BART on CNN-Dailymail summarization task
2
+
3
+ ### 1) Download the CNN and Daily Mail data and preprocess it into data files with non-tokenized cased samples.
4
+
5
+ Follow the instructions [here](https://github.com/abisee/cnn-dailymail) to download the original CNN and Daily Mail datasets. To preprocess the data, refer to the pointers in [this issue](https://github.com/pytorch/fairseq/issues/1391) or check out the code [here](https://github.com/artmatsak/cnn-dailymail).
6
+
7
+ Follow the instructions [here](https://github.com/EdinburghNLP/XSum) to download the original Extreme Summarization datasets, or check out the code [here](https://github.com/EdinburghNLP/XSum/tree/master/XSum-Dataset), Please keep the raw dataset and make sure no tokenization nor BPE on the dataset.
8
+
9
+ ### 2) BPE preprocess:
10
+
11
+ ```bash
12
+ wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
13
+ wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
14
+ wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'
15
+
16
+ TASK=cnn_dm
17
+ for SPLIT in train val
18
+ do
19
+ for LANG in source target
20
+ do
21
+ python -m examples.roberta.multiprocessing_bpe_encoder \
22
+ --encoder-json encoder.json \
23
+ --vocab-bpe vocab.bpe \
24
+ --inputs "$TASK/$SPLIT.$LANG" \
25
+ --outputs "$TASK/$SPLIT.bpe.$LANG" \
26
+ --workers 60 \
27
+ --keep-empty;
28
+ done
29
+ done
30
+ ```
31
+
32
+ ### 3) Binarize dataset:
33
+ ```bash
34
+ fairseq-preprocess \
35
+ --source-lang "source" \
36
+ --target-lang "target" \
37
+ --trainpref "${TASK}/train.bpe" \
38
+ --validpref "${TASK}/val.bpe" \
39
+ --destdir "${TASK}-bin/" \
40
+ --workers 60 \
41
+ --srcdict dict.txt \
42
+ --tgtdict dict.txt;
43
+ ```
44
+
45
+ ### 4) Fine-tuning on CNN-DM summarization task:
46
+ Example fine-tuning CNN-DM
47
+ ```bash
48
+ TOTAL_NUM_UPDATES=20000
49
+ WARMUP_UPDATES=500
50
+ LR=3e-05
51
+ MAX_TOKENS=2048
52
+ UPDATE_FREQ=4
53
+ BART_PATH=/path/to/bart/model.pt
54
+
55
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 fairseq-train cnn_dm-bin \
56
+ --restore-file $BART_PATH \
57
+ --max-tokens $MAX_TOKENS \
58
+ --task translation \
59
+ --source-lang source --target-lang target \
60
+ --truncate-source \
61
+ --layernorm-embedding \
62
+ --share-all-embeddings \
63
+ --share-decoder-input-output-embed \
64
+ --reset-optimizer --reset-dataloader --reset-meters \
65
+ --required-batch-size-multiple 1 \
66
+ --arch bart_large \
67
+ --criterion label_smoothed_cross_entropy \
68
+ --label-smoothing 0.1 \
69
+ --dropout 0.1 --attention-dropout 0.1 \
70
+ --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
71
+ --clip-norm 0.1 \
72
+ --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
73
+ --fp16 --update-freq $UPDATE_FREQ \
74
+ --skip-invalid-size-inputs-valid-test \
75
+ --find-unused-parameters;
76
+ ```
77
+ Above is expected to run on `1` node with `8 32gb-V100`.
78
+ Expected training time is about `5 hours`. Training time can be reduced with distributed training on `4` nodes and `--update-freq 1`.
79
+
80
+ Use TOTAL_NUM_UPDATES=15000 UPDATE_FREQ=2 for Xsum task
81
+
82
+ ### Inference for CNN-DM test data using above trained checkpoint.
83
+ After training the model as mentioned in previous step, you can perform inference with checkpoints in `checkpoints/` directory using `eval_cnn.py`, for example
84
+
85
+ ```bash
86
+ cp data-bin/cnn_dm/dict.source.txt checkpoints/
87
+ python examples/bart/summarize.py \
88
+ --model-dir checkpoints \
89
+ --model-file checkpoint_best.pt \
90
+ --src cnn_dm/test.source \
91
+ --out cnn_dm/test.hypo
92
+ ```
93
+ For XSUM, which uses beam=6, lenpen=1.0, max_len_b=60, min_len=10:
94
+ ```bash
95
+ cp data-bin/cnn_dm/dict.source.txt checkpoints/
96
+ python examples/bart/summarize.py \
97
+ --model-dir checkpoints \
98
+ --model-file checkpoint_best.pt \
99
+ --src cnn_dm/test.source \
100
+ --out cnn_dm/test.hypo \
101
+ --xsum-kwargs
102
+ ```
fairseq/examples/bart/summarize.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from fairseq.models.bart import BARTModel
8
+ import argparse
9
+
10
+ XSUM_KWARGS = dict(beam=6, lenpen=1.0, max_len_b=60, min_len=10, no_repeat_ngram_size=3)
11
+ CNN_KWARGS = dict(beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
12
+
13
+
14
+ @torch.no_grad()
15
+ def generate(bart, infile, outfile="bart_hypo.txt", bsz=32, n_obs=None, **eval_kwargs):
16
+ count = 1
17
+
18
+ # if n_obs is not None: bsz = min(bsz, n_obs)
19
+
20
+ with open(infile) as source, open(outfile, "w") as fout:
21
+ sline = source.readline().strip()
22
+ slines = [sline]
23
+ for sline in source:
24
+ if n_obs is not None and count > n_obs:
25
+ break
26
+ if count % bsz == 0:
27
+ hypotheses_batch = bart.sample(slines, **eval_kwargs)
28
+ for hypothesis in hypotheses_batch:
29
+ fout.write(hypothesis + "\n")
30
+ fout.flush()
31
+ slines = []
32
+
33
+ slines.append(sline.strip())
34
+ count += 1
35
+
36
+ if slines != []:
37
+ hypotheses_batch = bart.sample(slines, **eval_kwargs)
38
+ for hypothesis in hypotheses_batch:
39
+ fout.write(hypothesis + "\n")
40
+ fout.flush()
41
+
42
+
43
+ def main():
44
+ """
45
+ Usage::
46
+
47
+ python examples/bart/summarize.py \
48
+ --model-dir $HOME/bart.large.cnn \
49
+ --model-file model.pt \
50
+ --src $HOME/data-bin/cnn_dm/test.source
51
+ """
52
+ parser = argparse.ArgumentParser()
53
+ parser.add_argument(
54
+ "--model-dir",
55
+ required=True,
56
+ type=str,
57
+ default="bart.large.cnn/",
58
+ help="path containing model file and src_dict.txt",
59
+ )
60
+ parser.add_argument(
61
+ "--model-file",
62
+ default="checkpoint_best.pt",
63
+ help="where in model_dir are weights saved",
64
+ )
65
+ parser.add_argument(
66
+ "--src", default="test.source", help="text to summarize", type=str
67
+ )
68
+ parser.add_argument(
69
+ "--out", default="test.hypo", help="where to save summaries", type=str
70
+ )
71
+ parser.add_argument("--bsz", default=32, help="where to save summaries", type=int)
72
+ parser.add_argument(
73
+ "--n", default=None, help="how many examples to summarize", type=int
74
+ )
75
+ parser.add_argument(
76
+ "--xsum-kwargs",
77
+ action="store_true",
78
+ default=False,
79
+ help="if true use XSUM_KWARGS else CNN_KWARGS",
80
+ )
81
+ args = parser.parse_args()
82
+ eval_kwargs = XSUM_KWARGS if args.xsum_kwargs else CNN_KWARGS
83
+ if args.model_dir == "pytorch/fairseq":
84
+ bart = torch.hub.load("pytorch/fairseq", args.model_file)
85
+ else:
86
+ bart = BARTModel.from_pretrained(
87
+ args.model_dir,
88
+ checkpoint_file=args.model_file,
89
+ data_name_or_path=args.model_dir,
90
+ )
91
+ bart = bart.eval()
92
+ if torch.cuda.is_available():
93
+ bart = bart.cuda().half()
94
+ generate(
95
+ bart, args.src, bsz=args.bsz, n_obs=args.n, outfile=args.out, **eval_kwargs
96
+ )
97
+
98
+
99
+ if __name__ == "__main__":
100
+ main()
fairseq/examples/byte_level_bpe/README.md ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Neural Machine Translation with Byte-Level Subwords
2
+
3
+ https://arxiv.org/abs/1909.03341
4
+
5
+ We provide an implementation of byte-level byte-pair encoding (BBPE), taking IWSLT 2017 Fr-En translation as
6
+ example.
7
+
8
+ ## Data
9
+ Get data and generate fairseq binary dataset:
10
+ ```bash
11
+ bash ./get_data.sh
12
+ ```
13
+
14
+ ## Model Training
15
+ Train Transformer model with Bi-GRU embedding contextualization (implemented in `gru_transformer.py`):
16
+ ```bash
17
+ # VOCAB=bytes
18
+ # VOCAB=chars
19
+ VOCAB=bbpe2048
20
+ # VOCAB=bpe2048
21
+ # VOCAB=bbpe4096
22
+ # VOCAB=bpe4096
23
+ # VOCAB=bpe16384
24
+ ```
25
+ ```bash
26
+ fairseq-train "data/bin_${VOCAB}" --task translation --user-dir examples/byte_level_bpe/gru_transformer \
27
+ --arch gru_transformer --encoder-layers 2 --decoder-layers 2 --dropout 0.3 --share-all-embeddings \
28
+ --optimizer adam --adam-betas '(0.9, 0.98)' \
29
+ --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
30
+ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
31
+ --log-format 'simple' --log-interval 100 --save-dir "checkpoints/${VOCAB}" \
32
+ --batch-size 100 --max-update 100000 --update-freq 2
33
+ ```
34
+
35
+ ## Generation
36
+ `fairseq-generate` requires bytes (BBPE) decoder to convert byte-level representation back to characters:
37
+ ```bash
38
+ # BPE=--bpe bytes
39
+ # BPE=--bpe characters
40
+ BPE=--bpe byte_bpe --sentencepiece-model-path data/spm_bbpe2048.model
41
+ # BPE=--bpe sentencepiece --sentencepiece-model data/spm_bpe2048.model
42
+ # BPE=--bpe byte_bpe --sentencepiece-model-path data/spm_bbpe4096.model
43
+ # BPE=--bpe sentencepiece --sentencepiece-model data/spm_bpe4096.model
44
+ # BPE=--bpe sentencepiece --sentencepiece-model data/spm_bpe16384.model
45
+ ```
46
+
47
+ ```bash
48
+ fairseq-generate "data/bin_${VOCAB}" --task translation --user-dir examples/byte_level_bpe/gru_transformer \
49
+ --source-lang fr --gen-subset test --sacrebleu --path "checkpoints/${VOCAB}/checkpoint_last.pt" \
50
+ --tokenizer moses --moses-target-lang en ${BPE}
51
+ ```
52
+ When using `fairseq-interactive`, bytes (BBPE) encoder/decoder is required to tokenize input data and detokenize model predictions:
53
+ ```bash
54
+ fairseq-interactive "data/bin_${VOCAB}" --task translation --user-dir examples/byte_level_bpe/gru_transformer \
55
+ --path "checkpoints/${VOCAB}/checkpoint_last.pt" --input data/test.fr --tokenizer moses --moses-source-lang fr \
56
+ --moses-target-lang en ${BPE} --buffer-size 1000 --max-tokens 10000
57
+ ```
58
+
59
+ ## Results
60
+ | Vocabulary | Model | BLEU |
61
+ |:-------------:|:-------------:|:-------------:|
62
+ | Joint BPE 16k ([Kudo, 2018](https://arxiv.org/abs/1804.10959)) | 512d LSTM 2+2 | 33.81 |
63
+ | Joint BPE 16k | Transformer base 2+2 (w/ GRU) | 36.64 (36.72) |
64
+ | Joint BPE 4k | Transformer base 2+2 (w/ GRU) | 35.49 (36.10) |
65
+ | Joint BBPE 4k | Transformer base 2+2 (w/ GRU) | 35.61 (35.82) |
66
+ | Joint BPE 2k | Transformer base 2+2 (w/ GRU) | 34.87 (36.13) |
67
+ | Joint BBPE 2k | Transformer base 2+2 (w/ GRU) | 34.98 (35.43) |
68
+ | Characters | Transformer base 2+2 (w/ GRU) | 31.78 (33.30) |
69
+ | Bytes | Transformer base 2+2 (w/ GRU) | 31.57 (33.62) |
70
+
71
+
72
+ ## Citation
73
+ ```
74
+ @misc{wang2019neural,
75
+ title={Neural Machine Translation with Byte-Level Subwords},
76
+ author={Changhan Wang and Kyunghyun Cho and Jiatao Gu},
77
+ year={2019},
78
+ eprint={1909.03341},
79
+ archivePrefix={arXiv},
80
+ primaryClass={cs.CL}
81
+ }
82
+ ```
83
+
84
+
85
+ ## Contact
86
+ Changhan Wang ([changhan@fb.com](mailto:changhan@fb.com)),
87
+ Kyunghyun Cho ([kyunghyuncho@fb.com](mailto:kyunghyuncho@fb.com)),
88
+ Jiatao Gu ([jgu@fb.com](mailto:jgu@fb.com))
fairseq/examples/byte_level_bpe/get_bitext.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import argparse
8
+ import os
9
+ import os.path as op
10
+ from collections import namedtuple
11
+ from multiprocessing import cpu_count
12
+ from typing import List, Optional
13
+
14
+ import sentencepiece as sp
15
+ from fairseq.data.encoders.byte_bpe import ByteBPE
16
+ from fairseq.data.encoders.byte_utils import byte_encode
17
+ from fairseq.data.encoders.bytes import Bytes
18
+ from fairseq.data.encoders.characters import Characters
19
+ from fairseq.data.encoders.moses_tokenizer import MosesTokenizer
20
+ from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
21
+
22
+
23
+ SPLITS = ["train", "valid", "test"]
24
+
25
+
26
+ def _convert_xml(in_path: str, out_path: str):
27
+ with open(in_path) as f, open(out_path, "w") as f_o:
28
+ for s in f:
29
+ ss = s.strip()
30
+ if not ss.startswith("<seg"):
31
+ continue
32
+ ss = ss.replace("</seg>", "").split('">')
33
+ assert len(ss) == 2
34
+ f_o.write(ss[1].strip() + "\n")
35
+
36
+
37
+ def _convert_train(in_path: str, out_path: str):
38
+ with open(in_path) as f, open(out_path, "w") as f_o:
39
+ for s in f:
40
+ ss = s.strip()
41
+ if ss.startswith("<"):
42
+ continue
43
+ f_o.write(ss.strip() + "\n")
44
+
45
+
46
+ def _get_bytes(in_path: str, out_path: str):
47
+ with open(in_path) as f, open(out_path, "w") as f_o:
48
+ for s in f:
49
+ f_o.write(Bytes.encode(s.strip()) + "\n")
50
+
51
+
52
+ def _get_chars(in_path: str, out_path: str):
53
+ with open(in_path) as f, open(out_path, "w") as f_o:
54
+ for s in f:
55
+ f_o.write(Characters.encode(s.strip()) + "\n")
56
+
57
+
58
+ def pretokenize(in_path: str, out_path: str, src: str, tgt: str):
59
+ Args = namedtuple(
60
+ "Args",
61
+ [
62
+ "moses_source_lang",
63
+ "moses_target_lang",
64
+ "moses_no_dash_splits",
65
+ "moses_no_escape",
66
+ ],
67
+ )
68
+ args = Args(
69
+ moses_source_lang=src,
70
+ moses_target_lang=tgt,
71
+ moses_no_dash_splits=False,
72
+ moses_no_escape=False,
73
+ )
74
+ pretokenizer = MosesTokenizer(args)
75
+ with open(in_path) as f, open(out_path, "w") as f_o:
76
+ for s in f:
77
+ f_o.write(pretokenizer.encode(s.strip()) + "\n")
78
+
79
+
80
+ def _convert_to_bchar(in_path_prefix: str, src: str, tgt: str, out_path: str):
81
+ with open(out_path, "w") as f_o:
82
+ for lang in [src, tgt]:
83
+ with open(f"{in_path_prefix}.{lang}") as f:
84
+ for s in f:
85
+ f_o.write(byte_encode(s.strip()) + "\n")
86
+
87
+
88
+ def _get_bpe(in_path: str, model_prefix: str, vocab_size: int):
89
+ arguments = [
90
+ f"--input={in_path}",
91
+ f"--model_prefix={model_prefix}",
92
+ f"--model_type=bpe",
93
+ f"--vocab_size={vocab_size}",
94
+ "--character_coverage=1.0",
95
+ "--normalization_rule_name=identity",
96
+ f"--num_threads={cpu_count()}",
97
+ ]
98
+ sp.SentencePieceTrainer.Train(" ".join(arguments))
99
+
100
+
101
+ def _apply_bbpe(model_path: str, in_path: str, out_path: str):
102
+ Args = namedtuple("Args", ["sentencepiece_model_path"])
103
+ args = Args(sentencepiece_model_path=model_path)
104
+ tokenizer = ByteBPE(args)
105
+ with open(in_path) as f, open(out_path, "w") as f_o:
106
+ for s in f:
107
+ f_o.write(tokenizer.encode(s.strip()) + "\n")
108
+
109
+
110
+ def _apply_bpe(model_path: str, in_path: str, out_path: str):
111
+ Args = namedtuple("Args", ["sentencepiece_model"])
112
+ args = Args(sentencepiece_model=model_path)
113
+ tokenizer = SentencepieceBPE(args)
114
+ with open(in_path) as f, open(out_path, "w") as f_o:
115
+ for s in f:
116
+ f_o.write(tokenizer.encode(s.strip()) + "\n")
117
+
118
+
119
+ def _concat_files(in_paths: List[str], out_path: str):
120
+ with open(out_path, "w") as f_o:
121
+ for p in in_paths:
122
+ with open(p) as f:
123
+ for r in f:
124
+ f_o.write(r)
125
+
126
+
127
+ def preprocess_iwslt17(
128
+ root: str,
129
+ src: str,
130
+ tgt: str,
131
+ bpe_size: Optional[int],
132
+ need_chars: bool,
133
+ bbpe_size: Optional[int],
134
+ need_bytes: bool,
135
+ ):
136
+ # extract bitext
137
+ in_root = op.join(root, f"{src}-{tgt}")
138
+ for lang in [src, tgt]:
139
+ _convert_train(
140
+ op.join(in_root, f"train.tags.{src}-{tgt}.{lang}"),
141
+ op.join(root, f"train.{lang}"),
142
+ )
143
+ _convert_xml(
144
+ op.join(in_root, f"IWSLT17.TED.dev2010.{src}-{tgt}.{lang}.xml"),
145
+ op.join(root, f"valid.{lang}"),
146
+ )
147
+ _convert_xml(
148
+ op.join(in_root, f"IWSLT17.TED.tst2015.{src}-{tgt}.{lang}.xml"),
149
+ op.join(root, f"test.{lang}"),
150
+ )
151
+ # pre-tokenize
152
+ for lang in [src, tgt]:
153
+ for split in SPLITS:
154
+ pretokenize(
155
+ op.join(root, f"{split}.{lang}"),
156
+ op.join(root, f"{split}.moses.{lang}"),
157
+ src,
158
+ tgt,
159
+ )
160
+ # tokenize with BPE vocabulary
161
+ if bpe_size is not None:
162
+ # learn vocabulary
163
+ concated_train_path = op.join(root, "train.all")
164
+ _concat_files(
165
+ [op.join(root, "train.moses.fr"), op.join(root, "train.moses.en")],
166
+ concated_train_path,
167
+ )
168
+ bpe_model_prefix = op.join(root, f"spm_bpe{bpe_size}")
169
+ _get_bpe(concated_train_path, bpe_model_prefix, bpe_size)
170
+ os.remove(concated_train_path)
171
+ # apply
172
+ for lang in [src, tgt]:
173
+ for split in SPLITS:
174
+ _apply_bpe(
175
+ bpe_model_prefix + ".model",
176
+ op.join(root, f"{split}.moses.{lang}"),
177
+ op.join(root, f"{split}.moses.bpe{bpe_size}.{lang}"),
178
+ )
179
+ # tokenize with bytes vocabulary
180
+ if need_bytes:
181
+ for lang in [src, tgt]:
182
+ for split in SPLITS:
183
+ _get_bytes(
184
+ op.join(root, f"{split}.moses.{lang}"),
185
+ op.join(root, f"{split}.moses.bytes.{lang}"),
186
+ )
187
+ # tokenize with characters vocabulary
188
+ if need_chars:
189
+ for lang in [src, tgt]:
190
+ for split in SPLITS:
191
+ _get_chars(
192
+ op.join(root, f"{split}.moses.{lang}"),
193
+ op.join(root, f"{split}.moses.chars.{lang}"),
194
+ )
195
+ # tokenize with byte-level BPE vocabulary
196
+ if bbpe_size is not None:
197
+ # learn vocabulary
198
+ bchar_path = op.join(root, "train.bchar")
199
+ _convert_to_bchar(op.join(root, "train.moses"), src, tgt, bchar_path)
200
+ bbpe_model_prefix = op.join(root, f"spm_bbpe{bbpe_size}")
201
+ _get_bpe(bchar_path, bbpe_model_prefix, bbpe_size)
202
+ os.remove(bchar_path)
203
+ # apply
204
+ for lang in [src, tgt]:
205
+ for split in SPLITS:
206
+ _apply_bbpe(
207
+ bbpe_model_prefix + ".model",
208
+ op.join(root, f"{split}.moses.{lang}"),
209
+ op.join(root, f"{split}.moses.bbpe{bbpe_size}.{lang}"),
210
+ )
211
+
212
+
213
+ def main():
214
+ parser = argparse.ArgumentParser()
215
+ parser.add_argument("--root", type=str, default="data")
216
+ parser.add_argument(
217
+ "--bpe-vocab",
218
+ default=None,
219
+ type=int,
220
+ help="Generate tokenized bitext with BPE of size K."
221
+ "Default to None (disabled).",
222
+ )
223
+ parser.add_argument(
224
+ "--bbpe-vocab",
225
+ default=None,
226
+ type=int,
227
+ help="Generate tokenized bitext with BBPE of size K."
228
+ "Default to None (disabled).",
229
+ )
230
+ parser.add_argument(
231
+ "--byte-vocab",
232
+ action="store_true",
233
+ help="Generate tokenized bitext with bytes vocabulary",
234
+ )
235
+ parser.add_argument(
236
+ "--char-vocab",
237
+ action="store_true",
238
+ help="Generate tokenized bitext with chars vocabulary",
239
+ )
240
+ args = parser.parse_args()
241
+
242
+ preprocess_iwslt17(
243
+ args.root,
244
+ "fr",
245
+ "en",
246
+ args.bpe_vocab,
247
+ args.char_vocab,
248
+ args.bbpe_vocab,
249
+ args.byte_vocab,
250
+ )
251
+
252
+
253
+ if __name__ == "__main__":
254
+ main()
fairseq/examples/byte_level_bpe/get_data.sh ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ PY_BIN_ROOT=
9
+
10
+ # PyPI dependency
11
+ ${PY_BIN_ROOT}pip install sentencepiece sacremoses
12
+
13
+ # Get data
14
+ if [ ! -d "data" ]; then
15
+ mkdir data
16
+ fi
17
+
18
+ if [ ! -f "data/fr-en.tgz" ]; then
19
+ wget https://wit3.fbk.eu/archive/2017-01-trnted/texts/fr/en/fr-en.tgz -P data
20
+ tar xvf data/fr-en.tgz -C data
21
+ fi
22
+ ${PY_BIN_ROOT}python get_bitext.py --bpe-vocab 16384 --byte-vocab --char-vocab
23
+ for VOCAB_SIZE in 2048 4096; do
24
+ ${PY_BIN_ROOT}python get_bitext.py --bpe-vocab ${VOCAB_SIZE} --bbpe-vocab ${VOCAB_SIZE}
25
+ done
26
+ rm -r data/fr-en data/fr-en.tgz
27
+
28
+ # Generate binary dataset
29
+ ${PY_BIN_ROOT}/fairseq-preprocess --source-lang fr --target-lang en --destdir data/bin_bpe16384 --joined-dictionary \
30
+ --workers "$(nproc)" --trainpref data/train.moses.bpe16384 --validpref data/valid.moses.bpe16384 \
31
+ --testpref data/test.moses.bpe16384
32
+
33
+ ${PY_BIN_ROOT}/fairseq-preprocess --source-lang fr --target-lang en --destdir data/bin_bytes --joined-dictionary \
34
+ --workers "$(nproc)" --trainpref data/train.moses.bytes --validpref data/valid.moses.bytes \
35
+ --testpref data/test.moses.bytes
36
+
37
+ ${PY_BIN_ROOT}/fairseq-preprocess --source-lang fr --target-lang en --destdir data/bin_chars --joined-dictionary \
38
+ --workers "$(nproc)" --trainpref data/train.moses.chars --validpref data/valid.moses.chars \
39
+ --testpref data/test.moses.chars
40
+
41
+ for VOCAB_SIZE in 2048 4096; do
42
+ for TYPE in bbpe bpe; do
43
+ ${PY_BIN_ROOT}/fairseq-preprocess --source-lang fr --target-lang en --destdir "data/bin_${TYPE}${VOCAB_SIZE}" \
44
+ --joined-dictionary --workers "$(nproc)" --trainpref "data/train.moses.${TYPE}${VOCAB_SIZE}" \
45
+ --validpref "data/valid.moses.${TYPE}${VOCAB_SIZE}" --testpref "data/test.moses.${TYPE}${VOCAB_SIZE}"
46
+ done
47
+ done
fairseq/examples/byte_level_bpe/gru_transformer.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # Copyright (c) Facebook, Inc. and its affiliates.
7
+ #
8
+ # This source code is licensed under the MIT license found in the
9
+ # LICENSE file in the root directory of this source tree.
10
+
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from fairseq.models import register_model, register_model_architecture
14
+ from fairseq.models.transformer import TransformerEncoder, TransformerModel
15
+
16
+
17
+ @register_model("gru_transformer")
18
+ class GRUTransformerModel(TransformerModel):
19
+ @classmethod
20
+ def build_encoder(cls, args, src_dict, embed_tokens):
21
+ return GRUTransformerEncoder(args, src_dict, embed_tokens)
22
+
23
+
24
+ class GRUTransformerEncoder(TransformerEncoder):
25
+ def __init__(self, args, dictionary, embed_tokens):
26
+ super().__init__(args, dictionary, embed_tokens)
27
+ self.emb_ctx = nn.GRU(
28
+ input_size=embed_tokens.embedding_dim,
29
+ hidden_size=embed_tokens.embedding_dim // 2,
30
+ num_layers=1,
31
+ bidirectional=True,
32
+ )
33
+
34
+ def forward_embedding(self, src_tokens):
35
+ # embed tokens and positions
36
+ x = embed = self.embed_scale * self.embed_tokens(src_tokens)
37
+ if self.embed_positions is not None:
38
+ x = embed + self.embed_positions(src_tokens)
39
+
40
+ # contextualize embeddings
41
+ x = x.transpose(0, 1)
42
+ x = self.dropout_module(x)
43
+ x, _ = self.emb_ctx.forward(x)
44
+ x = x.transpose(0, 1)
45
+
46
+ if self.layernorm_embedding is not None:
47
+ x = self.layernorm_embedding(x)
48
+ x = self.dropout_module(x)
49
+ return x, embed
50
+
51
+
52
+ @register_model_architecture("gru_transformer", "gru_transformer")
53
+ def gru_transformer_base_architecture(args):
54
+ args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
55
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
56
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
57
+ args.encoder_layers = getattr(args, "encoder_layers", 6)
58
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
59
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
60
+ args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
61
+ args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
62
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
63
+ args.decoder_ffn_embed_dim = getattr(
64
+ args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
65
+ )
66
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
67
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
68
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
69
+ args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
70
+ args.attention_dropout = getattr(args, "attention_dropout", 0.0)
71
+ args.activation_dropout = getattr(args, "activation_dropout", 0.0)
72
+ args.activation_fn = getattr(args, "activation_fn", "relu")
73
+ args.dropout = getattr(args, "dropout", 0.1)
74
+ args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
75
+ args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
76
+ args.share_decoder_input_output_embed = getattr(
77
+ args, "share_decoder_input_output_embed", False
78
+ )
79
+ args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
80
+ args.no_token_positional_embeddings = getattr(
81
+ args, "no_token_positional_embeddings", False
82
+ )
83
+ args.adaptive_input = getattr(args, "adaptive_input", False)
84
+ args.no_cross_attention = getattr(args, "no_cross_attention", False)
85
+ args.cross_self_attention = getattr(args, "cross_self_attention", False)
86
+ args.layer_wise_attention = getattr(args, "layer_wise_attention", False)
87
+
88
+ args.decoder_output_dim = getattr(
89
+ args, "decoder_output_dim", args.decoder_embed_dim
90
+ )
91
+ args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
92
+
93
+ args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
94
+ args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
95
+
96
+
97
+ @register_model_architecture("gru_transformer", "gru_transformer_big")
98
+ def gru_transformer_big(args):
99
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
100
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
101
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
102
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
103
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
104
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
105
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
106
+ args.dropout = getattr(args, "dropout", 0.3)
107
+ gru_transformer_base_architecture(args)
fairseq/examples/camembert/README.md ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CamemBERT: a Tasty French Language Model
2
+
3
+ ## Introduction
4
+
5
+ [CamemBERT](https://arxiv.org/abs/1911.03894) is a pretrained language model trained on 138GB of French text based on RoBERTa.
6
+
7
+ Also available in [github.com/huggingface/transformers](https://github.com/huggingface/transformers/).
8
+
9
+ ## Pre-trained models
10
+
11
+ | Model | #params | Download | Arch. | Training data |
12
+ |--------------------------------|---------|--------------------------------------------------------------------------------------------------------------------------|-------|-----------------------------------|
13
+ | `camembert` / `camembert-base` | 110M | [camembert-base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz) | Base | OSCAR (138 GB of text) |
14
+ | `camembert-large` | 335M | [camembert-large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-large.tar.gz) | Large | CCNet (135 GB of text) |
15
+ | `camembert-base-ccnet` | 110M | [camembert-base-ccnet.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet.tar.gz) | Base | CCNet (135 GB of text) |
16
+ | `camembert-base-wikipedia-4gb` | 110M | [camembert-base-wikipedia-4gb.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base-wikipedia-4gb.tar.gz) | Base | Wikipedia (4 GB of text) |
17
+ | `camembert-base-oscar-4gb` | 110M | [camembert-base-oscar-4gb.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base-oscar-4gb.tar.gz) | Base | Subsample of OSCAR (4 GB of text) |
18
+ | `camembert-base-ccnet-4gb` | 110M | [camembert-base-ccnet-4gb.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet-4gb.tar.gz) | Base | Subsample of CCNet (4 GB of text) |
19
+
20
+ ## Example usage
21
+
22
+ ### fairseq
23
+ ##### Load CamemBERT from torch.hub (PyTorch >= 1.1):
24
+ ```python
25
+ import torch
26
+ camembert = torch.hub.load('pytorch/fairseq', 'camembert')
27
+ camembert.eval() # disable dropout (or leave in train mode to finetune)
28
+ ```
29
+
30
+ ##### Load CamemBERT (for PyTorch 1.0 or custom models):
31
+ ```python
32
+ # Download camembert model
33
+ wget https://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz
34
+ tar -xzvf camembert.tar.gz
35
+
36
+ # Load the model in fairseq
37
+ from fairseq.models.roberta import CamembertModel
38
+ camembert = CamembertModel.from_pretrained('/path/to/camembert')
39
+ camembert.eval() # disable dropout (or leave in train mode to finetune)
40
+ ```
41
+
42
+ ##### Filling masks:
43
+ ```python
44
+ masked_line = 'Le camembert est <mask> :)'
45
+ camembert.fill_mask(masked_line, topk=3)
46
+ # [('Le camembert est délicieux :)', 0.4909118115901947, ' délicieux'),
47
+ # ('Le camembert est excellent :)', 0.10556942224502563, ' excellent'),
48
+ # ('Le camembert est succulent :)', 0.03453322499990463, ' succulent')]
49
+ ```
50
+
51
+ ##### Extract features from Camembert:
52
+ ```python
53
+ # Extract the last layer's features
54
+ line = "J'aime le camembert !"
55
+ tokens = camembert.encode(line)
56
+ last_layer_features = camembert.extract_features(tokens)
57
+ assert last_layer_features.size() == torch.Size([1, 10, 768])
58
+
59
+ # Extract all layer's features (layer 0 is the embedding layer)
60
+ all_layers = camembert.extract_features(tokens, return_all_hiddens=True)
61
+ assert len(all_layers) == 13
62
+ assert torch.all(all_layers[-1] == last_layer_features)
63
+ ```
64
+
65
+ ## Citation
66
+ If you use our work, please cite:
67
+
68
+ ```bibtex
69
+ @inproceedings{martin2020camembert,
70
+ title={CamemBERT: a Tasty French Language Model},
71
+ author={Martin, Louis and Muller, Benjamin and Su{\'a}rez, Pedro Javier Ortiz and Dupont, Yoann and Romary, Laurent and de la Clergerie, {\'E}ric Villemonte and Seddah, Djam{\'e} and Sagot, Beno{\^\i}t},
72
+ booktitle={Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics},
73
+ year={2020}
74
+ }
75
+ ```
fairseq/examples/constrained_decoding/README.md ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # (Vectorized) Lexically constrained decoding with dynamic beam allocation
2
+
3
+ This page provides instructions for how to use lexically constrained decoding in Fairseq.
4
+ Fairseq implements the code described in the following papers:
5
+
6
+ * [Fast Lexically Constrained Decoding With Dynamic Beam Allocation](https://www.aclweb.org/anthology/N18-1119/) (Post & Vilar, 2018)
7
+ * [Improved Lexically Constrained Decoding for Translation and Monolingual Rewriting](https://www.aclweb.org/anthology/N19-1090/) (Hu et al., 2019)
8
+
9
+ ## Quick start
10
+
11
+ Constrained search is enabled by adding the command-line argument `--constraints` to `fairseq-interactive`.
12
+ Constraints are appended to each line of input, separated by tabs. Each constraint (one or more tokens)
13
+ is a separate field.
14
+
15
+ The following command, using [Fairseq's WMT19 German--English model](https://github.com/pytorch/fairseq/blob/main/examples/wmt19/README.md),
16
+ translates the sentence *Die maschinelle Übersetzung ist schwer zu kontrollieren.* with the constraints
17
+ "hard" and "to influence".
18
+
19
+ echo -e "Die maschinelle Übersetzung ist schwer zu kontrollieren.\thard\ttoinfluence" \
20
+ | normalize.py | tok.py \
21
+ | fairseq-interactive /path/to/model \
22
+ --path /path/to/model/model1.pt \
23
+ --bpe fastbpe \
24
+ --bpe-codes /path/to/model/bpecodes \
25
+ --constraints \
26
+ -s de -t en \
27
+ --beam 10
28
+
29
+ (tok.py and normalize.py can be found in the same directory as this README; they are just shortcuts around Fairseq's WMT19 preprocessing).
30
+ This will generate the following output:
31
+
32
+ [snip]
33
+ S-0 Die masch@@ in@@ elle Über@@ setzung ist schwer zu kontrollieren .
34
+ W-0 1.844 seconds
35
+ C-0 hard
36
+ C-0 influence
37
+ H-0 -1.5333266258239746 Mach@@ ine trans@@ lation is hard to influence .
38
+ D-0 -1.5333266258239746 Machine translation is hard to influence .
39
+ P-0 -0.5434 -0.1423 -0.1930 -0.1415 -0.2346 -1.8031 -0.1701 -11.7727 -0.1815 -0.1511
40
+
41
+ By default, constraints are generated in the order supplied, with any number (zero or more) of tokens generated
42
+ between constraints. If you wish for the decoder to order the constraints, then use `--constraints unordered`.
43
+ Note that you may want to use a larger beam.
44
+
45
+ ## Implementation details
46
+
47
+ The heart of the implementation is in `fairseq/search.py`, which adds a `LexicallyConstrainedBeamSearch` instance.
48
+ This instance of beam search tracks the progress of each hypothesis in the beam through the set of constraints
49
+ provided for each input sentence. It does this using one of two classes, both found in `fairseq/token_generation_contstraints.py`:
50
+
51
+ * OrderedConstraintState: assumes the `C` input constraints will be generated in the provided order
52
+ * UnorderedConstraintState: tries to apply `C` (phrasal) constraints in all `C!` orders
53
+
54
+ ## Differences from Sockeye
55
+
56
+ There are a number of [differences from Sockeye's implementation](https://awslabs.github.io/sockeye/inference.html#lexical-constraints).
57
+
58
+ * Generating constraints in the order supplied (the default option here) is not available in Sockeye.
59
+ * Due to an improved beam allocation method, there is no need to prune the beam.
60
+ * Again due to better allocation, beam sizes as low as 10 or even 5 are often sufficient.
61
+ * [The vector extensions described in Hu et al.](https://github.com/edwardjhu/sockeye/tree/trie_constraints) (NAACL 2019) were never merged
62
+ into the main Sockeye branch.
63
+
64
+ ## Citation
65
+
66
+ The paper first describing lexical constraints for seq2seq decoding is:
67
+
68
+ ```bibtex
69
+ @inproceedings{hokamp-liu-2017-lexically,
70
+ title = "Lexically Constrained Decoding for Sequence Generation Using Grid Beam Search",
71
+ author = "Hokamp, Chris and
72
+ Liu, Qun",
73
+ booktitle = "Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
74
+ month = jul,
75
+ year = "2017",
76
+ address = "Vancouver, Canada",
77
+ publisher = "Association for Computational Linguistics",
78
+ url = "https://www.aclweb.org/anthology/P17-1141",
79
+ doi = "10.18653/v1/P17-1141",
80
+ pages = "1535--1546",
81
+ }
82
+ ```
83
+
84
+ The fairseq implementation uses the extensions described in
85
+
86
+ ```bibtex
87
+ @inproceedings{post-vilar-2018-fast,
88
+ title = "Fast Lexically Constrained Decoding with Dynamic Beam Allocation for Neural Machine Translation",
89
+ author = "Post, Matt and
90
+ Vilar, David",
91
+ booktitle = "Proceedings of the 2018 Conference of the North {A}merican Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long Papers)",
92
+ month = jun,
93
+ year = "2018",
94
+ address = "New Orleans, Louisiana",
95
+ publisher = "Association for Computational Linguistics",
96
+ url = "https://www.aclweb.org/anthology/N18-1119",
97
+ doi = "10.18653/v1/N18-1119",
98
+ pages = "1314--1324",
99
+ }
100
+ ```
101
+
102
+ and
103
+
104
+ ```bibtex
105
+ @inproceedings{hu-etal-2019-improved,
106
+ title = "Improved Lexically Constrained Decoding for Translation and Monolingual Rewriting",
107
+ author = "Hu, J. Edward and
108
+ Khayrallah, Huda and
109
+ Culkin, Ryan and
110
+ Xia, Patrick and
111
+ Chen, Tongfei and
112
+ Post, Matt and
113
+ Van Durme, Benjamin",
114
+ booktitle = "Proceedings of the 2019 Conference of the North {A}merican Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)",
115
+ month = jun,
116
+ year = "2019",
117
+ address = "Minneapolis, Minnesota",
118
+ publisher = "Association for Computational Linguistics",
119
+ url = "https://www.aclweb.org/anthology/N19-1090",
120
+ doi = "10.18653/v1/N19-1090",
121
+ pages = "839--850",
122
+ }
123
+ ```
fairseq/examples/constrained_decoding/normalize.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ #
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import sys
9
+
10
+ from sacremoses.normalize import MosesPunctNormalizer
11
+
12
+
13
+ def main(args):
14
+ normalizer = MosesPunctNormalizer(lang=args.lang, penn=args.penn)
15
+ for line in sys.stdin:
16
+ print(normalizer.normalize(line.rstrip()), flush=True)
17
+
18
+
19
+ if __name__ == "__main__":
20
+ import argparse
21
+
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument("--lang", "-l", default="en")
24
+ parser.add_argument("--penn", "-p", action="store_true")
25
+ args = parser.parse_args()
26
+
27
+ main(args)
fairseq/examples/constrained_decoding/tok.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ #
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import sys
9
+
10
+ import sacremoses
11
+
12
+
13
+ def main(args):
14
+ """Tokenizes, preserving tabs"""
15
+ mt = sacremoses.MosesTokenizer(lang=args.lang)
16
+
17
+ def tok(s):
18
+ return mt.tokenize(s, return_str=True)
19
+
20
+ for line in sys.stdin:
21
+ parts = list(map(tok, line.split("\t")))
22
+ print(*parts, sep="\t", flush=True)
23
+
24
+
25
+ if __name__ == "__main__":
26
+ import argparse
27
+
28
+ parser = argparse.ArgumentParser()
29
+ parser.add_argument("--lang", "-l", default="en")
30
+ parser.add_argument("--penn", "-p", action="store_true")
31
+ parser.add_argument("--fields", "-f", help="fields to tokenize")
32
+ args = parser.parse_args()
33
+
34
+ main(args)
fairseq/examples/conv_seq2seq/README.md ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Convolutional Sequence to Sequence Learning (Gehring et al., 2017)
2
+
3
+ ## Pre-trained models
4
+
5
+ Description | Dataset | Model | Test set(s)
6
+ ---|---|---|---
7
+ Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2) <br> newstest2012/2013: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.ntst1213.tar.bz2)
8
+ Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-German](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-de.newstest2014.tar.bz2)
9
+ Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT17 English-German](http://statmt.org/wmt17/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.v2.en-de.newstest2014.tar.bz2)
10
+
11
+ ## Example usage
12
+
13
+ See the [translation README](../translation/README.md) for instructions on reproducing results for WMT'14 En-De and
14
+ WMT'14 En-Fr using the `fconv_wmt_en_de` and `fconv_wmt_en_fr` model architectures.
15
+
16
+ ## Citation
17
+
18
+ ```bibtex
19
+ @inproceedings{gehring2017convs2s,
20
+ title = {Convolutional Sequence to Sequence Learning},
21
+ author = {Gehring, Jonas, and Auli, Michael and Grangier, David and Yarats, Denis and Dauphin, Yann N},
22
+ booktitle = {Proc. of ICML},
23
+ year = 2017,
24
+ }
25
+ ```
fairseq/examples/criss/README.md ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Cross-lingual Retrieval for Iterative Self-Supervised Training
2
+
3
+ https://arxiv.org/pdf/2006.09526.pdf
4
+
5
+ ## Introduction
6
+
7
+ CRISS is a multilingual sequence-to-sequnce pretraining method where mining and training processes are applied iteratively, improving cross-lingual alignment and translation ability at the same time.
8
+
9
+ ## Requirements:
10
+
11
+ * faiss: https://github.com/facebookresearch/faiss
12
+ * mosesdecoder: https://github.com/moses-smt/mosesdecoder
13
+ * flores: https://github.com/facebookresearch/flores
14
+ * LASER: https://github.com/facebookresearch/LASER
15
+
16
+ ## Unsupervised Machine Translation
17
+ ##### 1. Download and decompress CRISS checkpoints
18
+ ```
19
+ cd examples/criss
20
+ wget https://dl.fbaipublicfiles.com/criss/criss_3rd_checkpoints.tar.gz
21
+ tar -xf criss_checkpoints.tar.gz
22
+ ```
23
+ ##### 2. Download and preprocess Flores test dataset
24
+ Make sure to run all scripts from examples/criss directory
25
+ ```
26
+ bash download_and_preprocess_flores_test.sh
27
+ ```
28
+
29
+ ##### 3. Run Evaluation on Sinhala-English
30
+ ```
31
+ bash unsupervised_mt/eval.sh
32
+ ```
33
+
34
+ ## Sentence Retrieval
35
+ ##### 1. Download and preprocess Tatoeba dataset
36
+ ```
37
+ bash download_and_preprocess_tatoeba.sh
38
+ ```
39
+
40
+ ##### 2. Run Sentence Retrieval on Tatoeba Kazakh-English
41
+ ```
42
+ bash sentence_retrieval/sentence_retrieval_tatoeba.sh
43
+ ```
44
+
45
+ ## Mining
46
+ ##### 1. Install faiss
47
+ Follow instructions on https://github.com/facebookresearch/faiss/blob/master/INSTALL.md
48
+ ##### 2. Mine pseudo-parallel data between Kazakh and English
49
+ ```
50
+ bash mining/mine_example.sh
51
+ ```
52
+
53
+ ## Citation
54
+ ```bibtex
55
+ @article{tran2020cross,
56
+ title={Cross-lingual retrieval for iterative self-supervised training},
57
+ author={Tran, Chau and Tang, Yuqing and Li, Xian and Gu, Jiatao},
58
+ journal={arXiv preprint arXiv:2006.09526},
59
+ year={2020}
60
+ }
61
+ ```
fairseq/examples/criss/download_and_preprocess_flores_test.sh ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ SPM_ENCODE=flores/scripts/spm_encode.py
9
+ DATA=data_tmp
10
+ SPM_MODEL=criss_checkpoints/sentence.bpe.model
11
+ DICT=criss_checkpoints/dict.txt
12
+
13
+ download_data() {
14
+ CORPORA=$1
15
+ URL=$2
16
+
17
+ if [ -f $CORPORA ]; then
18
+ echo "$CORPORA already exists, skipping download"
19
+ else
20
+ echo "Downloading $URL"
21
+ wget $URL -O $CORPORA --no-check-certificate || rm -f $CORPORA
22
+ if [ -f $CORPORA ]; then
23
+ echo "$URL successfully downloaded."
24
+ else
25
+ echo "$URL not successfully downloaded."
26
+ rm -f $CORPORA
27
+ fi
28
+ fi
29
+ }
30
+
31
+ if [[ -f flores ]]; then
32
+ echo "flores already cloned"
33
+ else
34
+ git clone https://github.com/facebookresearch/flores
35
+ fi
36
+
37
+ mkdir -p $DATA
38
+ download_data $DATA/wikipedia_en_ne_si_test_sets.tgz "https://github.com/facebookresearch/flores/raw/master/data/wikipedia_en_ne_si_test_sets.tgz"
39
+ pushd $DATA
40
+ pwd
41
+ tar -vxf wikipedia_en_ne_si_test_sets.tgz
42
+ popd
43
+
44
+
45
+ for lang in ne_NP si_LK; do
46
+ datadir=$DATA/${lang}-en_XX-flores
47
+ rm -rf $datadir
48
+ mkdir -p $datadir
49
+ TEST_PREFIX=$DATA/wikipedia_en_ne_si_test_sets/wikipedia.test
50
+ python $SPM_ENCODE \
51
+ --model ${SPM_MODEL} \
52
+ --output_format=piece \
53
+ --inputs ${TEST_PREFIX}.${lang:0:2}-en.${lang:0:2} ${TEST_PREFIX}.${lang:0:2}-en.en \
54
+ --outputs $datadir/test.bpe.${lang}-en_XX.${lang} $datadir/test.bpe.${lang}-en_XX.en_XX
55
+
56
+ # binarize data
57
+ fairseq-preprocess \
58
+ --source-lang ${lang} --target-lang en_XX \
59
+ --testpref $datadir/test.bpe.${lang}-en_XX \
60
+ --destdir $datadir \
61
+ --srcdict ${DICT} \
62
+ --joined-dictionary \
63
+ --workers 4
64
+ done
fairseq/examples/criss/download_and_preprocess_tatoeba.sh ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ SPM_ENCODE=flores/scripts/spm_encode.py
9
+ DATA=data_tmp
10
+ SPM_MODEL=criss_checkpoints/sentence.bpe.model
11
+ DICT=criss_checkpoints/dict.txt
12
+
13
+ if [[ -f flores ]]; then
14
+ echo "flores already cloned"
15
+ else
16
+ git clone https://github.com/facebookresearch/flores
17
+ fi
18
+ if [[ -f LASER ]]; then
19
+ echo "LASER already cloned"
20
+ else
21
+ git clone https://github.com/facebookresearch/LASER
22
+ fi
23
+ mkdir -p data_tmp
24
+ declare -A lang_tatoeba_map=( ["ar_AR"]="ara" ["de_DE"]="deu" ["es_XX"]="spa" ["et_EE"]="est" ["fi_FI"]="fin" ["fr_XX"]="fra" ["hi_IN"]="hin" ["it_IT"]="ita" ["ja_XX"]="jpn" ["ko_KR"]="kor" ["kk_KZ"]="kaz" ["nl_XX"]="nld" ["ru_RU"]="rus" ["tr_TR"]="tur" ["vi_VN"]="vie" ["zh_CN"]="cmn")
25
+ for lang in ar_AR de_DE es_XX et_EE fi_FI fr_XX hi_IN it_IT ja_XX kk_KZ ko_KR nl_XX ru_RU tr_TR vi_VN zh_CN; do
26
+ lang_tatoeba=${lang_tatoeba_map[$lang]}
27
+ echo $lang_tatoeba
28
+ datadir=$DATA/${lang}-en_XX-tatoeba
29
+ rm -rf $datadir
30
+ mkdir -p $datadir
31
+ TEST_PREFIX=LASER/data/tatoeba/v1/tatoeba
32
+ python $SPM_ENCODE \
33
+ --model ${SPM_MODEL} \
34
+ --output_format=piece \
35
+ --inputs ${TEST_PREFIX}.${lang_tatoeba}-eng.${lang_tatoeba} ${TEST_PREFIX}.${lang_tatoeba}-eng.eng \
36
+ --outputs $datadir/test.bpe.${lang}-en_XX.${lang} $datadir/test.bpe.${lang}-en_XX.en_XX
37
+
38
+ # binarize data
39
+ fairseq-preprocess \
40
+ --source-lang ${lang} --target-lang en_XX \
41
+ --testpref $datadir/test.bpe.${lang}-en_XX \
42
+ --destdir $datadir \
43
+ --srcdict ${DICT} \
44
+ --joined-dictionary \
45
+ --workers 4
46
+ done
fairseq/examples/criss/mining/mine.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3 -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import argparse
7
+ import glob
8
+ from subprocess import check_call
9
+
10
+ try:
11
+ import faiss
12
+
13
+ has_faiss = True
14
+ except ImportError:
15
+ has_faiss = False
16
+ import numpy as np
17
+
18
+
19
+ GB = 1024 * 1024 * 1024
20
+
21
+
22
+ def call(cmd):
23
+ print(cmd)
24
+ check_call(cmd, shell=True)
25
+
26
+
27
+ def get_batches(directory, lang, prefix="all_avg_pool"):
28
+ print(f"Finding in {directory}/{prefix}.{lang}*")
29
+ files = glob.glob(f"{directory}/{prefix}.{lang}*")
30
+ emb_files = []
31
+ txt_files = []
32
+ for emb_fi in files:
33
+ emb_files.append(emb_fi)
34
+ txt_fi = emb_fi.replace(prefix, "sentences")
35
+ txt_files.append(txt_fi)
36
+ return emb_files, txt_files
37
+
38
+
39
+ def load_batch(emb_file, dim):
40
+ embeddings = np.fromfile(emb_file, dtype=np.float32)
41
+ num_rows = int(embeddings.shape[0] / dim)
42
+ embeddings = embeddings.reshape((num_rows, dim))
43
+ faiss.normalize_L2(embeddings)
44
+ return embeddings
45
+
46
+
47
+ def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction="x2y"):
48
+ if not has_faiss:
49
+ raise ImportError("Please install Faiss")
50
+ sims = []
51
+ inds = []
52
+ xfrom = 0
53
+ xto = 0
54
+ for x_batch_f in x_batches_f:
55
+ yfrom = 0
56
+ yto = 0
57
+ x_batch = load_batch(x_batch_f, dim)
58
+ xto = xfrom + x_batch.shape[0]
59
+ bsims, binds = [], []
60
+ for y_batch_f in y_batches_f:
61
+ y_batch = load_batch(y_batch_f, dim)
62
+ neighbor_size = min(k, y_batch.shape[0])
63
+ yto = yfrom + y_batch.shape[0]
64
+ print("{}-{} -> {}-{}".format(xfrom, xto, yfrom, yto))
65
+ idx = faiss.IndexFlatIP(dim)
66
+ idx = faiss.index_cpu_to_all_gpus(idx)
67
+ idx.add(y_batch)
68
+ bsim, bind = idx.search(x_batch, neighbor_size)
69
+
70
+ bsims.append(bsim)
71
+ binds.append(bind + yfrom)
72
+ yfrom += y_batch.shape[0]
73
+ del idx
74
+ del y_batch
75
+ bsims = np.concatenate(bsims, axis=1)
76
+ binds = np.concatenate(binds, axis=1)
77
+ aux = np.argsort(-bsims, axis=1)
78
+ sim_batch = np.zeros((x_batch.shape[0], k), dtype=np.float32)
79
+ ind_batch = np.zeros((x_batch.shape[0], k), dtype=np.int64)
80
+ for i in range(x_batch.shape[0]):
81
+ for j in range(k):
82
+ sim_batch[i, j] = bsims[i, aux[i, j]]
83
+ ind_batch[i, j] = binds[i, aux[i, j]]
84
+ sims.append(sim_batch)
85
+ inds.append(ind_batch)
86
+ xfrom += x_batch.shape[0]
87
+ del x_batch
88
+ sim = np.concatenate(sims, axis=0)
89
+ ind = np.concatenate(inds, axis=0)
90
+ return sim, ind
91
+
92
+
93
+ def score(sim, fwd_mean, bwd_mean, margin):
94
+ return margin(sim, (fwd_mean + bwd_mean) / 2)
95
+
96
+
97
+ def score_candidates(
98
+ sim_mat, candidate_inds, fwd_mean, bwd_mean, margin, verbose=False
99
+ ):
100
+ print(" - scoring {:d} candidates".format(sim_mat.shape[0]))
101
+ scores = np.zeros(candidate_inds.shape)
102
+ for i in range(scores.shape[0]):
103
+ for j in range(scores.shape[1]):
104
+ k = int(candidate_inds[i, j])
105
+ scores[i, j] = score(sim_mat[i, j], fwd_mean[i], bwd_mean[k], margin)
106
+ return scores
107
+
108
+
109
+ def load_text(files):
110
+ all_sentences = []
111
+ for fi in files:
112
+ with open(fi) as sentence_fi:
113
+ for line in sentence_fi:
114
+ all_sentences.append(line.strip())
115
+ print(f"Read {len(all_sentences)} sentences")
116
+ return all_sentences
117
+
118
+
119
+ if __name__ == "__main__":
120
+ parser = argparse.ArgumentParser(description="Mine bitext")
121
+ parser.add_argument("--src-lang", help="Source language")
122
+ parser.add_argument("--tgt-lang", help="Target language")
123
+ parser.add_argument(
124
+ "--dict-path", help="Path to dictionary file", default="dict.txt"
125
+ )
126
+ parser.add_argument(
127
+ "--spm-path", help="Path to SPM model file", default="sentence.bpe.model"
128
+ )
129
+ parser.add_argument("--dim", type=int, default=1024, help="Embedding dimension")
130
+ parser.add_argument("--mem", type=int, default=5, help="Memory in GB")
131
+ parser.add_argument("--src-dir", help="Source directory")
132
+ parser.add_argument("--tgt-dir", help="Target directory")
133
+ parser.add_argument("--output", help="Output path")
134
+ parser.add_argument(
135
+ "--neighborhood", type=int, default=4, help="Embedding dimension"
136
+ )
137
+ parser.add_argument(
138
+ "--threshold", type=float, default=1.06, help="Threshold on mined bitext"
139
+ )
140
+ parser.add_argument(
141
+ "--valid-size",
142
+ type=int,
143
+ default=2000,
144
+ help="Number of sentences used for validation set",
145
+ )
146
+ parser.add_argument(
147
+ "--min-count",
148
+ type=int,
149
+ default=50000,
150
+ help="Min num sentences used for each language",
151
+ )
152
+ args = parser.parse_args()
153
+
154
+ x_batches_f, x_sents_f = get_batches(args.src_dir, args.src_lang)
155
+ y_batches_f, y_sents_f = get_batches(args.tgt_dir, args.tgt_lang)
156
+ margin = lambda a, b: a / b
157
+ y2x_sim, y2x_ind = knnGPU_sharded(
158
+ y_batches_f, x_batches_f, args.dim, args.neighborhood, direction="y2x"
159
+ )
160
+ x2y_sim, x2y_ind = knnGPU_sharded(
161
+ x_batches_f, y_batches_f, args.dim, args.neighborhood, direction="x2y"
162
+ )
163
+
164
+ x2y_mean = x2y_sim.mean(axis=1)
165
+ y2x_mean = y2x_sim.mean(axis=1)
166
+ fwd_scores = score_candidates(x2y_sim, x2y_ind, x2y_mean, y2x_mean, margin)
167
+ bwd_scores = score_candidates(y2x_sim, y2x_ind, y2x_mean, x2y_mean, margin)
168
+ fwd_best = x2y_ind[np.arange(x2y_sim.shape[0]), fwd_scores.argmax(axis=1)]
169
+ bwd_best = y2x_ind[np.arange(y2x_sim.shape[0]), bwd_scores.argmax(axis=1)]
170
+ indices = np.stack(
171
+ (
172
+ np.concatenate((np.arange(x2y_ind.shape[0]), bwd_best)),
173
+ np.concatenate((fwd_best, np.arange(y2x_ind.shape[0]))),
174
+ ),
175
+ axis=1,
176
+ )
177
+ scores = np.concatenate((fwd_scores.max(axis=1), bwd_scores.max(axis=1)))
178
+
179
+ x_sentences = load_text(x_sents_f)
180
+ y_sentences = load_text(y_sents_f)
181
+
182
+ threshold = args.threshold
183
+ min_count = args.min_count
184
+ seen_src, seen_trg = set(), set()
185
+ directory = args.output
186
+ call(f"mkdir -p {directory}")
187
+ src_out = open(
188
+ f"{directory}/all.{args.src_lang}",
189
+ mode="w",
190
+ encoding="utf-8",
191
+ errors="surrogateescape",
192
+ )
193
+ tgt_out = open(
194
+ f"{directory}/all.{args.tgt_lang}",
195
+ mode="w",
196
+ encoding="utf-8",
197
+ errors="surrogateescape",
198
+ )
199
+ scores_out = open(
200
+ f"{directory}/all.scores", mode="w", encoding="utf-8", errors="surrogateescape"
201
+ )
202
+ count = 0
203
+ for i in np.argsort(-scores):
204
+ src_ind, trg_ind = indices[i]
205
+ if src_ind not in seen_src and trg_ind not in seen_trg:
206
+ seen_src.add(src_ind)
207
+ seen_trg.add(trg_ind)
208
+ if scores[i] > threshold or count < min_count:
209
+ if x_sentences[src_ind]:
210
+ print(scores[i], file=scores_out)
211
+ print(x_sentences[src_ind], file=src_out)
212
+ print(y_sentences[trg_ind], file=tgt_out)
213
+ count += 1
214
+ else:
215
+ print(f"Ignoring sentence: {x_sentences[src_ind]}")
216
+ src_out.close()
217
+ tgt_out.close()
218
+ scores_out.close()
219
+
220
+ print(f"Found {count} pairs for threshold={threshold}")
221
+ with open(f"{directory}/all.{args.src_lang}") as all_s, open(
222
+ f"{directory}/all.{args.tgt_lang}"
223
+ ) as all_t, open(f"{directory}/valid.{args.src_lang}", "w") as valid_s, open(
224
+ f"{directory}/valid.{args.tgt_lang}", "w"
225
+ ) as valid_t, open(
226
+ f"{directory}/train.{args.src_lang}", "w"
227
+ ) as train_s, open(
228
+ f"{directory}/train.{args.tgt_lang}", "w"
229
+ ) as train_t:
230
+ count = 0
231
+ for s_line, t_line in zip(all_s, all_t):
232
+ s_line = s_line.split("\t")[1]
233
+ t_line = t_line.split("\t")[1]
234
+ if count >= args.valid_size:
235
+ train_s.write(s_line)
236
+ train_t.write(t_line)
237
+ else:
238
+ valid_s.write(s_line)
239
+ valid_t.write(t_line)
240
+ count += 1
fairseq/examples/criss/mining/mine_example.sh ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+ #
8
+ source_lang=kk_KZ
9
+ target_lang=en_XX
10
+ MODEL=criss_checkpoints/criss.3rd.pt
11
+ SPM=criss_checkpoints/sentence.bpe.model
12
+ SPLIT=test
13
+ LANG_DICT=criss_checkpoints/lang_dict.txt
14
+ SPM_ENCODE=flores/scripts/spm_encode.py
15
+ SAVE_ENCODER=save_encoder.py
16
+ ENCODER_SAVE_ROOT=sentence_embeddings/$MODEL
17
+ DICT=criss_checkpoints/dict.txt
18
+ THRESHOLD=1.02
19
+ MIN_COUNT=500
20
+
21
+ DATA_DIR=data_tmp
22
+ SAVE_DIR=mining/${source_lang}_${target_lang}_mined
23
+ ENCODER_SAVE_DIR=${ENCODER_SAVE_ROOT}/${source_lang}-${target_lang}
24
+ INPUT_DIR=$DATA_DIR/${source_lang}-${target_lang}-tatoeba
25
+
26
+ mkdir -p $ENCODER_SAVE_DIR/${target_lang}
27
+ mkdir -p $ENCODER_SAVE_DIR/${source_lang}
28
+ mkdir -p $SAVE_DIR
29
+
30
+ ## Save encoder outputs
31
+
32
+ # Save encoder outputs for source sentences
33
+ python $SAVE_ENCODER \
34
+ ${INPUT_DIR} \
35
+ --path ${MODEL} \
36
+ --task translation_multi_simple_epoch \
37
+ --lang-pairs ${source_lang}-${target_lang} \
38
+ --lang-dict ${LANG_DICT} \
39
+ --gen-subset ${SPLIT} \
40
+ --bpe 'sentencepiece' \
41
+ -s ${source_lang} -t ${target_lang} \
42
+ --sentencepiece-model ${SPM} \
43
+ --remove-bpe 'sentencepiece' \
44
+ --beam 1 \
45
+ --lang-tok-style mbart \
46
+ --encoder-save-dir ${ENCODER_SAVE_DIR}/${source_lang}
47
+
48
+ ## Save encoder outputs for target sentences
49
+ python $SAVE_ENCODER \
50
+ ${INPUT_DIR} \
51
+ --path ${MODEL} \
52
+ --lang-pairs ${source_lang}-${target_lang} \
53
+ --lang-dict ${LANG_DICT} \
54
+ --task translation_multi_simple_epoch \
55
+ --gen-subset ${SPLIT} \
56
+ --bpe 'sentencepiece' \
57
+ -t ${source_lang} -s ${target_lang} \
58
+ --sentencepiece-model ${SPM} \
59
+ --remove-bpe 'sentencepiece' \
60
+ --beam 1 \
61
+ --lang-tok-style mbart \
62
+ --encoder-save-dir ${ENCODER_SAVE_DIR}/${target_lang}
63
+
64
+ ## Mining
65
+ python mining/mine.py \
66
+ --src-lang ${source_lang} \
67
+ --tgt-lang ${target_lang} \
68
+ --dim 1024 \
69
+ --mem 10 \
70
+ --neighborhood 4 \
71
+ --src-dir ${ENCODER_SAVE_DIR}/${source_lang} \
72
+ --tgt-dir ${ENCODER_SAVE_DIR}/${target_lang} \
73
+ --output $SAVE_DIR \
74
+ --threshold ${THRESHOLD} \
75
+ --min-count ${MIN_COUNT} \
76
+ --valid-size 100 \
77
+ --dict-path ${DICT} \
78
+ --spm-path ${SPM} \
79
+
80
+
81
+ ## Process and binarize mined data
82
+ python $SPM_ENCODE \
83
+ --model ${SPM} \
84
+ --output_format=piece \
85
+ --inputs mining/${source_lang}_${target_lang}_mined/train.${source_lang} mining/${source_lang}_${target_lang}_mined/train.${target_lang} \
86
+ --outputs mining/${source_lang}_${target_lang}_mined/train.bpe.${source_lang} mining/${source_lang}_${target_lang}_mined/train.bpe.${target_lang}
87
+
88
+ python $SPM_ENCODE \
89
+ --model ${SPM} \
90
+ --output_format=piece \
91
+ --inputs mining/${source_lang}_${target_lang}_mined/valid.${source_lang} mining/${source_lang}_${target_lang}_mined/valid.${target_lang} \
92
+ --outputs mining/${source_lang}_${target_lang}_mined/valid.bpe.${source_lang} mining/${source_lang}_${target_lang}_mined/valid.bpe.${target_lang}
93
+
94
+
95
+ fairseq-preprocess \
96
+ --source-lang ${source_lang} \
97
+ --target-lang ${target_lang} \
98
+ --trainpref mining/${source_lang}_${target_lang}_mined/train.bpe \
99
+ --validpref mining/${source_lang}_${target_lang}_mined/valid.bpe \
100
+ --destdir mining/${source_lang}_${target_lang}_mined \
101
+ --srcdict ${DICT} \
102
+ --joined-dictionary \
103
+ --workers 8
fairseq/examples/criss/save_encoder.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3 -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Translate pre-processed data with a trained model.
8
+ """
9
+
10
+ import numpy as np
11
+ import torch
12
+ from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
13
+ from fairseq.sequence_generator import EnsembleModel
14
+ from fairseq.utils import safe_hasattr
15
+
16
+
17
+ def get_avg_pool(
18
+ models, sample, prefix_tokens, src_dict, remove_bpe, has_langtok=False
19
+ ):
20
+ model = EnsembleModel(models)
21
+
22
+ # model.forward normally channels prev_output_tokens into the decoder
23
+ # separately, but SequenceGenerator directly calls model.encoder
24
+ encoder_input = {
25
+ k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
26
+ }
27
+
28
+ # compute the encoder output for each beam
29
+ encoder_outs = model.forward_encoder(encoder_input)
30
+ np_encoder_outs = encoder_outs[0].encoder_out.cpu().numpy().astype(np.float32)
31
+ encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.cpu().numpy().astype(
32
+ np.float32
33
+ )
34
+ encoder_mask = np.expand_dims(encoder_mask.T, axis=2)
35
+ if has_langtok:
36
+ encoder_mask = encoder_mask[1:, :, :]
37
+ np_encoder_outs = np_encoder_outs[1, :, :]
38
+ masked_encoder_outs = encoder_mask * np_encoder_outs
39
+ avg_pool = (masked_encoder_outs / encoder_mask.sum(axis=0)).sum(axis=0)
40
+ return avg_pool
41
+
42
+
43
+ def main(args):
44
+ assert args.path is not None, "--path required for generation!"
45
+ assert (
46
+ not args.sampling or args.nbest == args.beam
47
+ ), "--sampling requires --nbest to be equal to --beam"
48
+ assert (
49
+ args.replace_unk is None or args.raw_text
50
+ ), "--replace-unk requires a raw text dataset (--raw-text)"
51
+
52
+ args.beam = 1
53
+ utils.import_user_module(args)
54
+
55
+ if args.max_tokens is None:
56
+ args.max_tokens = 12000
57
+ print(args)
58
+ use_cuda = torch.cuda.is_available() and not args.cpu
59
+
60
+ # Load dataset splits
61
+ task = tasks.setup_task(args)
62
+ task.load_dataset(args.gen_subset)
63
+
64
+ # Set dictionaries
65
+ try:
66
+ src_dict = getattr(task, "source_dictionary", None)
67
+ except NotImplementedError:
68
+ src_dict = None
69
+ tgt_dict = task.target_dictionary
70
+
71
+ # Load ensemble
72
+ print("| loading model(s) from {}".format(args.path))
73
+ models, _model_args = checkpoint_utils.load_model_ensemble(
74
+ args.path.split(":"),
75
+ arg_overrides=eval(args.model_overrides),
76
+ task=task,
77
+ )
78
+
79
+ # Optimize ensemble for generation
80
+ for model in models:
81
+ model.make_generation_fast_(
82
+ beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
83
+ need_attn=args.print_alignment,
84
+ )
85
+ if args.fp16:
86
+ model.half()
87
+ if use_cuda:
88
+ model.cuda()
89
+
90
+ # Load alignment dictionary for unknown word replacement
91
+ # (None if no unknown word replacement, empty if no path to align dictionary)
92
+ align_dict = utils.load_align_dict(args.replace_unk)
93
+
94
+ # Load dataset (possibly sharded)
95
+ itr = task.get_batch_iterator(
96
+ dataset=task.dataset(args.gen_subset),
97
+ max_tokens=args.max_tokens,
98
+ max_positions=utils.resolve_max_positions(
99
+ task.max_positions(),
100
+ ),
101
+ ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
102
+ required_batch_size_multiple=args.required_batch_size_multiple,
103
+ num_shards=args.num_shards,
104
+ shard_id=args.shard_id,
105
+ num_workers=args.num_workers,
106
+ ).next_epoch_itr(shuffle=False)
107
+
108
+ num_sentences = 0
109
+ source_sentences = []
110
+ shard_id = 0
111
+ all_avg_pool = None
112
+ encoder_has_langtok = (
113
+ safe_hasattr(task.args, "encoder_langtok")
114
+ and task.args.encoder_langtok is not None
115
+ and safe_hasattr(task.args, "lang_tok_replacing_bos_eos")
116
+ and not task.args.lang_tok_replacing_bos_eos
117
+ )
118
+ with progress_bar.build_progress_bar(args, itr) as t:
119
+ for sample in t:
120
+ if sample is None:
121
+ print("Skipping None")
122
+ continue
123
+ sample = utils.move_to_cuda(sample) if use_cuda else sample
124
+ if "net_input" not in sample:
125
+ continue
126
+
127
+ prefix_tokens = None
128
+ if args.prefix_size > 0:
129
+ prefix_tokens = sample["target"][:, : args.prefix_size]
130
+
131
+ with torch.no_grad():
132
+ avg_pool = get_avg_pool(
133
+ models,
134
+ sample,
135
+ prefix_tokens,
136
+ src_dict,
137
+ args.post_process,
138
+ has_langtok=encoder_has_langtok,
139
+ )
140
+ if all_avg_pool is not None:
141
+ all_avg_pool = np.concatenate((all_avg_pool, avg_pool))
142
+ else:
143
+ all_avg_pool = avg_pool
144
+
145
+ if not isinstance(sample["id"], list):
146
+ sample_ids = sample["id"].tolist()
147
+ else:
148
+ sample_ids = sample["id"]
149
+ for i, sample_id in enumerate(sample_ids):
150
+ # Remove padding
151
+ src_tokens = utils.strip_pad(
152
+ sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
153
+ )
154
+
155
+ # Either retrieve the original sentences or regenerate them from tokens.
156
+ if align_dict is not None:
157
+ src_str = task.dataset(args.gen_subset).src.get_original_text(
158
+ sample_id
159
+ )
160
+ else:
161
+ if src_dict is not None:
162
+ src_str = src_dict.string(src_tokens, args.post_process)
163
+ else:
164
+ src_str = ""
165
+
166
+ if not args.quiet:
167
+ if src_dict is not None:
168
+ print("S-{}\t{}".format(sample_id, src_str))
169
+
170
+ source_sentences.append(f"{sample_id}\t{src_str}")
171
+
172
+ num_sentences += sample["nsentences"]
173
+ if all_avg_pool.shape[0] >= 1000000:
174
+ with open(
175
+ f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}",
176
+ "w",
177
+ ) as avg_pool_file:
178
+ all_avg_pool.tofile(avg_pool_file)
179
+ with open(
180
+ f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}",
181
+ "w",
182
+ ) as sentence_file:
183
+ sentence_file.writelines(f"{line}\n" for line in source_sentences)
184
+ all_avg_pool = None
185
+ source_sentences = []
186
+ shard_id += 1
187
+
188
+ if all_avg_pool is not None:
189
+ with open(
190
+ f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}", "w"
191
+ ) as avg_pool_file:
192
+ all_avg_pool.tofile(avg_pool_file)
193
+ with open(
194
+ f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}", "w"
195
+ ) as sentence_file:
196
+ sentence_file.writelines(f"{line}\n" for line in source_sentences)
197
+ return None
198
+
199
+
200
+ def cli_main():
201
+ parser = options.get_generation_parser()
202
+ parser.add_argument(
203
+ "--encoder-save-dir",
204
+ default="",
205
+ type=str,
206
+ metavar="N",
207
+ help="directory to save encoder outputs",
208
+ )
209
+ args = options.parse_args_and_arch(parser)
210
+ main(args)
211
+
212
+
213
+ if __name__ == "__main__":
214
+ cli_main()