Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- fairseq/fairseq.egg-info/PKG-INFO +283 -0
- fairseq/fairseq.egg-info/SOURCES.txt +1546 -0
- fairseq/fairseq.egg-info/entry_points.txt +9 -0
- fairseq/fairseq.egg-info/requires.txt +22 -0
- fairseq/fairseq.egg-info/top_level.txt +4 -0
- fairseq/fairseq/__pycache__/incremental_decoding_utils.cpython-310.pyc +0 -0
- fairseq/fairseq/__pycache__/iterative_refinement_generator.cpython-310.pyc +0 -0
- fairseq/fairseq/__pycache__/ngram_repeat_block.cpython-310.pyc +0 -0
- fairseq/fairseq/__pycache__/pdb.cpython-310.pyc +0 -0
- fairseq/fairseq_cli/__init__.py +0 -0
- fairseq/fairseq_cli/eval_lm.py +347 -0
- fairseq/fairseq_cli/generate.py +417 -0
- fairseq/fairseq_cli/hydra_train.py +91 -0
- fairseq/fairseq_cli/hydra_validate.py +188 -0
- fairseq/fairseq_cli/interactive.py +317 -0
- fairseq/fairseq_cli/preprocess.py +393 -0
- fairseq/fairseq_cli/score.py +102 -0
- fairseq/fairseq_cli/train.py +581 -0
- fairseq/fairseq_cli/validate.py +153 -0
- fairseq/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/__init__.py +3 -0
- fairseq/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/config.py +23 -0
- fairseq/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/launcher.py +121 -0
- fairseq/hydra_plugins/dependency_submitit_launcher/setup.py +29 -0
- fairseq/scripts/__init__.py +0 -0
- fairseq/scripts/average_checkpoints.py +176 -0
- fairseq/scripts/build_sym_alignment.py +97 -0
- fairseq/scripts/check_installation.py +36 -0
- fairseq/scripts/compare_namespaces.py +46 -0
- fairseq/scripts/compound_split_bleu.sh +20 -0
- fairseq/scripts/constraints/extract.py +90 -0
- fairseq/scripts/constraints/validate.py +34 -0
- fairseq/scripts/convert_dictionary.lua +34 -0
- fairseq/scripts/convert_model.lua +108 -0
- fairseq/scripts/count_docs.py +58 -0
- fairseq/scripts/read_binarized.py +48 -0
- fairseq/scripts/rm_pt.py +141 -0
- fairseq/scripts/sacrebleu.sh +27 -0
- fairseq/scripts/shard_docs.py +54 -0
- fairseq/scripts/split_train_valid_docs.py +86 -0
- fairseq/scripts/spm_decode.py +53 -0
- fairseq/scripts/spm_encode.py +119 -0
- fairseq/scripts/spm_train.py +16 -0
- fairseq/scripts/test_fsdp.sh +24 -0
- fairseq/tests/__init__.py +0 -0
- fairseq/tests/tasks/test_masked_lm.py +78 -0
- fairseq/tests/tasks/test_span_masked_lm.py +106 -0
- fairseq/tests/test_activation_checkpointing.py +79 -0
- fairseq/tests/test_amp_optimizer.py +75 -0
- fairseq/tests/test_average_checkpoints.py +134 -0
- fairseq/tests/test_backtranslation_dataset.py +123 -0
fairseq/fairseq.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,283 @@
Metadata-Version: 2.2
Name: fairseq
Version: 0.12.2
Summary: Facebook AI Research Sequence-to-Sequence Toolkit
Home-page: https://github.com/pytorch/fairseq
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3.6
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: cffi
Requires-Dist: cython
Requires-Dist: hydra-core<1.1,>=1.0.7
Requires-Dist: omegaconf<2.1
Requires-Dist: numpy>=1.21.3
Requires-Dist: regex
Requires-Dist: sacrebleu>=1.4.12
Requires-Dist: torch>=1.13
Requires-Dist: tqdm
Requires-Dist: bitarray
Requires-Dist: torchaudio>=0.8.0
Requires-Dist: scikit-learn
Requires-Dist: packaging
Provides-Extra: dev
Requires-Dist: flake8; extra == "dev"
Requires-Dist: pytest; extra == "dev"
Requires-Dist: black==22.3.0; extra == "dev"
Provides-Extra: docs
Requires-Dist: sphinx; extra == "docs"
Requires-Dist: sphinx-argparse; extra == "docs"
Dynamic: classifier
Dynamic: description
Dynamic: description-content-type
Dynamic: home-page
Dynamic: provides-extra
Dynamic: requires-dist
Dynamic: summary

<p align="center">
  <img src="docs/fairseq_logo.png" width="150">
  <br />
  <br />
  <a href="https://opensource.fb.com/support-ukraine"><img alt="Support Ukraine" src="https://img.shields.io/badge/Support-Ukraine-FFD500?style=flat&labelColor=005BBB" /></a>
  <a href="https://github.com/pytorch/fairseq/blob/main/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/license-MIT-blue.svg" /></a>
  <a href="https://github.com/pytorch/fairseq/releases"><img alt="Latest Release" src="https://img.shields.io/github/release/pytorch/fairseq.svg" /></a>
  <a href="https://github.com/pytorch/fairseq/actions?query=workflow:build"><img alt="Build Status" src="https://github.com/pytorch/fairseq/workflows/build/badge.svg" /></a>
  <a href="https://fairseq.readthedocs.io/en/latest/?badge=latest"><img alt="Documentation Status" src="https://readthedocs.org/projects/fairseq/badge/?version=latest" /></a>
  <a href="https://app.circleci.com/pipelines/github/facebookresearch/fairseq/"><img alt="CircleCI Status" src="https://circleci.com/gh/facebookresearch/fairseq.svg?style=shield" /></a>
</p>

--------------------------------------------------------------------------------

Fairseq(-py) is a sequence modeling toolkit that allows researchers and
developers to train custom models for translation, summarization, language
modeling and other text generation tasks.

We provide reference implementations of various sequence modeling papers:

<details><summary>List of implemented papers</summary><p>

* **Convolutional Neural Networks (CNN)**
  + [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md)
  + [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
  + [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
  + [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
  + [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
* **LightConv and DynamicConv models**
  + [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
* **Long Short-Term Memory (LSTM) networks**
  + Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015)
* **Transformer (self-attention) networks**
  + Attention Is All You Need (Vaswani et al., 2017)
  + [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
  + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
  + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md)
  + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md)
  + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md)
  + [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md)
  + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
  + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
  + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
  + [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
  + [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et al., 2020)](examples/mbart/README.md)
  + [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
  + [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
  + [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
  + [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md)
  + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md)
  + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
  + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md)
  + [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979)
  + [Self-training and Pre-training are Complementary for Speech Recognition (Xu et al., 2020)](https://arxiv.org/abs/2010.11430)
  + [Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training (Hsu et al., 2021)](https://arxiv.org/abs/2104.01027)
  + [Unsupervised Speech Recognition (Baevski et al., 2021)](https://arxiv.org/abs/2105.11084)
  + [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition (Xu et al., 2021)](https://arxiv.org/abs/2109.11680)
  + [VideoCLIP: Contrastive Pre-training for Zero-shot Video-Text Understanding (Xu et al., 2021)](https://arxiv.org/pdf/2109.14084.pdf)
  + [VLM: Task-agnostic Video-Language Model Pre-training for Video Understanding (Xu et al., 2021)](https://aclanthology.org/2021.findings-acl.370.pdf)
  + [NormFormer: Improved Transformer Pretraining with Extra Normalization (Shleifer et al., 2021)](examples/normformer/README.md)
* **Non-autoregressive Transformers**
  + Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
  + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al., 2018)
  + Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al., 2019)
  + Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)
  + [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
* **Finetuning**
  + [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al., 2020)](examples/rxf/README.md)

</p></details>

### What's New:

* May 2023 [Released models for Scaling Speech Technology to 1,000+ Languages (Pratap et al., 2023)](examples/mms/README.md)
* June 2022 [Released code for wav2vec-U 2.0 from Towards End-to-end Unsupervised Speech Recognition (Liu et al., 2022)](examples/wav2vec/unsupervised/README.md)
* May 2022 [Integration with xFormers](https://github.com/facebookresearch/xformers)
* December 2021 [Released Direct speech-to-speech translation code](examples/speech_to_speech/README.md)
* October 2021 [Released VideoCLIP and VLM models](examples/MMPT/README.md)
* October 2021 [Released multilingual finetuned XLSR-53 model](examples/wav2vec/README.md)
* September 2021 [`master` branch renamed to `main`](https://github.com/github/renaming)
* July 2021 [Released DrNMT code](examples/discriminative_reranking_nmt/README.md)
* July 2021 [Released Robust wav2vec 2.0 model](examples/wav2vec/README.md)
* June 2021 [Released XLMR-XL and XLMR-XXL models](examples/xlmr/README.md)
* May 2021 [Released Unsupervised Speech Recognition code](examples/wav2vec/unsupervised/README.md)
* March 2021 [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md)
* February 2021 [Added LASER training code](examples/laser/README.md)
* December 2020: [Added Adaptive Attention Span code](examples/adaptive_span/README.md)
* December 2020: [GottBERT model and code released](examples/gottbert/README.md)
* November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework
  * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md)
* November 2020: [fairseq 0.10.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.10.0)
* October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md)
* October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md)
* October 2020: [Added CRISS models and code](examples/criss/README.md)

<details><summary>Previous updates</summary><p>

* September 2020: [Added Linformer code](examples/linformer/README.md)
* September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md)
* August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md)
* August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md)
* July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md)
* May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq)
* April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md)
* April 2020: [Quant-Noise code released](examples/quant_noise/README.md)
* April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md)
* March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md)
* February 2020: [mBART model and code released](examples/mbart/README.md)
* February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/main/examples/backtranslation#training-your-own-model-wmt18-english-german)
* December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0)
* November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example)
* November 2019: [CamemBERT model and code released](examples/camembert/README.md)
* November 2019: [BART model and code released](examples/bart/README.md)
* November 2019: [XLM-R models and code released](examples/xlmr/README.md)
* September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
* August 2019: [WMT'19 models released](examples/wmt19/README.md)
* July 2019: fairseq relicensed under MIT license
* July 2019: [RoBERTa models and code released](examples/roberta/README.md)
* June 2019: [wav2vec models and code released](examples/wav2vec/README.md)

</p></details>

### Features:

* multi-GPU training on one machine or across multiple machines (data and model parallel)
* fast generation on both CPU and GPU with multiple search algorithms implemented:
  + beam search
  + Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
  + sampling (unconstrained, top-k and top-p/nucleus)
  + [lexically constrained decoding](examples/constrained_decoding/README.md) (Post & Vilar, 2018)
* [gradient accumulation](https://fairseq.readthedocs.io/en/latest/getting_started.html#large-mini-batch-training-with-delayed-updates) enables training with large mini-batches even on a single GPU (see the sketch after this list)
* [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores))
* [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers
* [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration
* [full parameter and optimizer state sharding](examples/fully_sharded_data_parallel/README.md)
* [offloading parameters to CPU](examples/fully_sharded_data_parallel/README.md)
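
For example, the delayed-update behavior behind gradient accumulation is exposed through `fairseq-train`'s `--update-freq` flag. A minimal sketch combining it with fp16 training (the data directory and hyperparameters below are placeholders, not a tuned recipe; see the linked getting-started docs for real ones):

``` bash
# Accumulate gradients over 8 batches before each optimizer step, which
# roughly simulates 8-GPU training on a single GPU, and train in fp16.
# data-bin/my_dataset is a placeholder for a binarized dataset directory.
fairseq-train data-bin/my_dataset \
    --arch transformer \
    --optimizer adam --lr 0.0005 --max-tokens 4096 \
    --update-freq 8 \
    --fp16
```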
We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples)
with a convenient `torch.hub` interface:

``` python
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
en2de.translate('Hello world', beam=5)
# 'Hallo Welt'
```

See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/)
and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples.
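
As a further illustration, RoBERTa checkpoints are exposed through the same `torch.hub` interface; a minimal sketch following the RoBERTa tutorial linked above (model name and calls are taken from that tutorial):

``` python
import torch

# Load a pre-trained RoBERTa model and extract features for a sentence.
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
roberta.eval()  # disable dropout for deterministic inference
tokens = roberta.encode('Hello world!')       # BPE-encode to a tensor of token ids
features = roberta.extract_features(tokens)   # last-layer features, shape (1, seq_len, 1024)
```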
# Requirements and Installation

* [PyTorch](http://pytorch.org/) version >= 1.10.0
* Python version >= 3.8
* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
* **To install fairseq** and develop locally:

``` bash
git clone https://github.com/pytorch/fairseq
cd fairseq
pip install --editable ./

# on macOS:
# CFLAGS="-stdlib=libc++" pip install --editable ./

# to install the latest stable release (0.10.x)
# pip install fairseq
```

* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library:

``` bash
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
  --global-option="--deprecated_fused_adam" --global-option="--xentropy" \
  --global-option="--fast_multihead_attn" ./
```

* **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow`
* If you use Docker make sure to increase the shared memory size either with `--ipc=host` or `--shm-size`
  as command line options to `nvidia-docker run` (see the example below).
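
A minimal example of the Docker flags mentioned above (the image name is a placeholder):

``` bash
# Share the host's IPC namespace so PyTorch data loaders get enough shared memory.
nvidia-docker run --ipc=host -it my/fairseq-image bash
```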
# Getting Started

The [full documentation](https://fairseq.readthedocs.io/) contains instructions
for getting started, training new models and extending fairseq with new model
types and tasks.
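
The command-line entry points shipped with the package (see `fairseq_cli/` in the file list above) follow a preprocess/train/generate workflow; a hedged sketch of that pipeline (all paths and language codes are placeholders; consult the documentation for complete recipes):

``` bash
# Binarize a parallel corpus (expects e.g. data/train.de and data/train.en).
fairseq-preprocess --source-lang de --target-lang en \
    --trainpref data/train --validpref data/valid --testpref data/test \
    --destdir data-bin/demo

# Train a Transformer translation model on the binarized data.
fairseq-train data-bin/demo --arch transformer \
    --optimizer adam --lr 0.0005 --max-tokens 4096 \
    --save-dir checkpoints

# Translate the test set with beam search.
fairseq-generate data-bin/demo --path checkpoints/checkpoint_best.pt --beam 5
```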
# Pre-trained models and examples

We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below,
as well as example training and evaluation commands.

* [Translation](examples/translation/README.md): convolutional and transformer models are available
* [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available

We also have more detailed READMEs to reproduce results from specific papers:

* [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale (Babu et al., 2021)](examples/wav2vec/xlsr/README.md)
* [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
* [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
* [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
* [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md)
* [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
* [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et al., 2020)](examples/mbart/README.md)
* [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md)
* [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
* [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
* [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
* [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
* [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
* [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
* [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
* [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
* [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
* [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
* [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
* [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
* [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md)

# Join the fairseq community

* Twitter: https://twitter.com/fairseq
* Facebook page: https://www.facebook.com/groups/fairseq.users
* Google group: https://groups.google.com/forum/#!forum/fairseq-users

# License

fairseq(-py) is MIT-licensed.
The license applies to the pre-trained models as well.

# Citation

Please cite as:

``` bibtex
@inproceedings{ott2019fairseq,
  title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
  author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
  booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
  year = {2019},
}
```
fairseq/fairseq.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,1546 @@
LICENSE
MANIFEST.in
README.md
pyproject.toml
setup.cfg
setup.py
examples/operators/alignment_train_cpu.cpp
examples/operators/alignment_train_cuda.cpp
examples/operators/alignment_train_kernel.cu
fairseq/__init__.py
fairseq/binarizer.py
fairseq/checkpoint_utils.py
fairseq/file_chunker_utils.py
fairseq/file_io.py
fairseq/file_utils.py
fairseq/hub_utils.py
fairseq/incremental_decoding_utils.py
fairseq/iterative_refinement_generator.py
fairseq/nan_detector.py
fairseq/ngram_repeat_block.py
fairseq/options.py
fairseq/pdb.py
fairseq/quantization_utils.py
fairseq/registry.py
fairseq/search.py
fairseq/sequence_generator.py
fairseq/sequence_scorer.py
fairseq/speech_generator.py
fairseq/token_generation_constraints.py
fairseq/tokenizer.py
fairseq/trainer.py
fairseq/utils.py
fairseq/version.py
fairseq/version.txt
fairseq.egg-info/PKG-INFO
fairseq.egg-info/SOURCES.txt
fairseq.egg-info/dependency_links.txt
fairseq.egg-info/entry_points.txt
fairseq.egg-info/not-zip-safe
fairseq.egg-info/requires.txt
fairseq.egg-info/top_level.txt
fairseq/benchmark/__init__.py
fairseq/benchmark/benchmark_multihead_attention.py
fairseq/benchmark/dummy_dataset.py
fairseq/benchmark/dummy_lm.py
fairseq/benchmark/dummy_masked_lm.py
fairseq/benchmark/dummy_model.py
fairseq/benchmark/dummy_mt.py
fairseq/clib/cuda/ngram_repeat_block_cuda.cpp
fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu
fairseq/clib/libbase/balanced_assignment.cpp
fairseq/clib/libbleu/libbleu.cpp
fairseq/clib/libbleu/module.cpp
fairseq/clib/libnat/edit_dist.cpp
fairseq/clib/libnat_cuda/binding.cpp
fairseq/clib/libnat_cuda/edit_dist.cu
fairseq/config/__init__.py
fairseq/config/config.yaml
fairseq/config/fb_run_config/slurm.yaml
fairseq/config/model/transformer_lm/transformer_lm_baevski_gbw.yaml
fairseq/config/model/transformer_lm/transformer_lm_baevski_wiki103.yaml
fairseq/config/model/transformer_lm/transformer_lm_big.yaml
fairseq/config/model/transformer_lm/transformer_lm_gbw.yaml
fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml
fairseq/config/model/transformer_lm/transformer_lm_gpt2_big.yaml
fairseq/config/model/transformer_lm/transformer_lm_gpt2_medium.yaml
fairseq/config/model/transformer_lm/transformer_lm_gpt2_small.yaml
fairseq/config/model/transformer_lm/transformer_lm_wiki103.yaml
fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml
fairseq/config/model/wav2vec2/wav2vec2_base.yaml
fairseq/config/model/wav2vec2/wav2vec2_large.yaml
fairseq/criterions/__init__.py
fairseq/criterions/adaptive_loss.py
fairseq/criterions/composite_loss.py
fairseq/criterions/cross_entropy.py
fairseq/criterions/ctc.py
fairseq/criterions/fairseq_criterion.py
fairseq/criterions/fastspeech2_loss.py
fairseq/criterions/hubert_criterion.py
fairseq/criterions/label_smoothed_cross_entropy.py
fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py
fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py
fairseq/criterions/label_smoothed_cross_entropy_with_ctc.py
fairseq/criterions/label_smoothed_cross_entropy_with_rdrop.py
fairseq/criterions/legacy_masked_lm.py
fairseq/criterions/masked_lm.py
fairseq/criterions/model_criterion.py
fairseq/criterions/nat_loss.py
fairseq/criterions/sentence_prediction.py
fairseq/criterions/sentence_prediction_adapters.py
fairseq/criterions/sentence_ranking.py
fairseq/criterions/speech_dlm_criterion.py
fairseq/criterions/speech_to_speech_criterion.py
fairseq/criterions/speech_ulm_criterion.py
fairseq/criterions/tacotron2_loss.py
fairseq/criterions/wav2vec_criterion.py
fairseq/data/__init__.py
fairseq/data/add_class_target_dataset.py
fairseq/data/add_target_dataset.py
fairseq/data/append_token_dataset.py
fairseq/data/backtranslation_dataset.py
fairseq/data/base_wrapper_dataset.py
fairseq/data/bucket_pad_length_dataset.py
fairseq/data/codedataset.py
fairseq/data/colorize_dataset.py
fairseq/data/concat_dataset.py
fairseq/data/concat_sentences_dataset.py
fairseq/data/data_utils.py
fairseq/data/data_utils_fast.pyx
fairseq/data/denoising_dataset.py
fairseq/data/dictionary.py
fairseq/data/fairseq_dataset.py
fairseq/data/fasta_dataset.py
fairseq/data/id_dataset.py
fairseq/data/indexed_dataset.py
fairseq/data/iterators.py
fairseq/data/language_pair_dataset.py
fairseq/data/list_dataset.py
fairseq/data/lm_context_window_dataset.py
fairseq/data/lru_cache_dataset.py
fairseq/data/mask_tokens_dataset.py
fairseq/data/monolingual_dataset.py
fairseq/data/multi_corpus_dataset.py
fairseq/data/multi_corpus_sampled_dataset.py
fairseq/data/nested_dictionary_dataset.py
fairseq/data/noising.py
fairseq/data/num_samples_dataset.py
fairseq/data/numel_dataset.py
fairseq/data/offset_tokens_dataset.py
fairseq/data/pad_dataset.py
fairseq/data/padding_mask_dataset.py
fairseq/data/plasma_utils.py
fairseq/data/prepend_dataset.py
fairseq/data/prepend_token_dataset.py
fairseq/data/raw_label_dataset.py
fairseq/data/replace_dataset.py
fairseq/data/resampling_dataset.py
fairseq/data/roll_dataset.py
fairseq/data/round_robin_zip_datasets.py
fairseq/data/shorten_dataset.py
fairseq/data/sort_dataset.py
fairseq/data/span_mask_tokens_dataset.py
fairseq/data/speech_dlm_dataset.py
fairseq/data/strip_token_dataset.py
fairseq/data/subsample_dataset.py
fairseq/data/text_compressor.py
fairseq/data/token_block_dataset.py
fairseq/data/token_block_utils_fast.pyx
fairseq/data/transform_eos_concat_langpair_dataset.py
fairseq/data/transform_eos_dataset.py
fairseq/data/transform_eos_lang_pair_dataset.py
fairseq/data/audio/__init__.py
fairseq/data/audio/audio_utils.py
fairseq/data/audio/data_cfg.py
fairseq/data/audio/frm_text_to_speech_dataset.py
fairseq/data/audio/hubert_dataset.py
fairseq/data/audio/multi_modality_dataset.py
fairseq/data/audio/raw_audio_dataset.py
fairseq/data/audio/speech_to_speech_dataset.py
fairseq/data/audio/speech_to_text_dataset.py
fairseq/data/audio/speech_to_text_joint_dataset.py
fairseq/data/audio/text_to_speech_dataset.py
fairseq/data/audio/dataset_transforms/__init__.py
fairseq/data/audio/dataset_transforms/concataugment.py
fairseq/data/audio/dataset_transforms/noisyoverlapaugment.py
fairseq/data/audio/feature_transforms/__init__.py
fairseq/data/audio/feature_transforms/delta_deltas.py
fairseq/data/audio/feature_transforms/global_cmvn.py
fairseq/data/audio/feature_transforms/specaugment.py
fairseq/data/audio/feature_transforms/utterance_cmvn.py
fairseq/data/audio/waveform_transforms/__init__.py
fairseq/data/audio/waveform_transforms/noiseaugment.py
fairseq/data/encoders/__init__.py
fairseq/data/encoders/byte_bpe.py
fairseq/data/encoders/byte_utils.py
fairseq/data/encoders/bytes.py
fairseq/data/encoders/characters.py
fairseq/data/encoders/fastbpe.py
fairseq/data/encoders/gpt2_bpe.py
fairseq/data/encoders/gpt2_bpe_utils.py
fairseq/data/encoders/hf_bert_bpe.py
fairseq/data/encoders/hf_byte_bpe.py
fairseq/data/encoders/moses_tokenizer.py
fairseq/data/encoders/nltk_tokenizer.py
fairseq/data/encoders/sentencepiece_bpe.py
fairseq/data/encoders/space_tokenizer.py
fairseq/data/encoders/subword_nmt_bpe.py
fairseq/data/encoders/utils.py
fairseq/data/huffman/__init__.py
fairseq/data/huffman/huffman_coder.py
fairseq/data/huffman/huffman_mmap_indexed_dataset.py
fairseq/data/legacy/__init__.py
fairseq/data/legacy/block_pair_dataset.py
fairseq/data/legacy/masked_lm_dataset.py
fairseq/data/legacy/masked_lm_dictionary.py
fairseq/data/multilingual/__init__.py
fairseq/data/multilingual/multilingual_data_manager.py
fairseq/data/multilingual/multilingual_utils.py
fairseq/data/multilingual/sampled_multi_dataset.py
fairseq/data/multilingual/sampled_multi_epoch_dataset.py
fairseq/data/multilingual/sampling_method.py
fairseq/dataclass/__init__.py
fairseq/dataclass/configs.py
fairseq/dataclass/constants.py
fairseq/dataclass/initialize.py
fairseq/dataclass/utils.py
fairseq/distributed/__init__.py
fairseq/distributed/distributed_timeout_wrapper.py
fairseq/distributed/fully_sharded_data_parallel.py
fairseq/distributed/legacy_distributed_data_parallel.py
fairseq/distributed/module_proxy_wrapper.py
fairseq/distributed/tpu_distributed_data_parallel.py
fairseq/distributed/utils.py
fairseq/examples/.gitignore
fairseq/examples/__init__.py
fairseq/examples/MMPT/.gitignore
fairseq/examples/MMPT/CONFIG.md
fairseq/examples/MMPT/DATASET.md
fairseq/examples/MMPT/README.md
fairseq/examples/MMPT/endtask.md
fairseq/examples/MMPT/locallaunch.py
fairseq/examples/MMPT/pretraining.md
fairseq/examples/MMPT/setup.py
fairseq/examples/MMPT/videoclip.png
fairseq/examples/MMPT/vlm.png
fairseq/examples/MMPT/mmpt/__init__.py
fairseq/examples/MMPT/mmpt/datasets/__init__.py
fairseq/examples/MMPT/mmpt/datasets/fairseqmmdataset.py
fairseq/examples/MMPT/mmpt/datasets/mmdataset.py
fairseq/examples/MMPT/mmpt/evaluators/__init__.py
fairseq/examples/MMPT/mmpt/evaluators/evaluator.py
fairseq/examples/MMPT/mmpt/evaluators/metric.py
fairseq/examples/MMPT/mmpt/evaluators/predictor.py
fairseq/examples/MMPT/mmpt/losses/__init__.py
fairseq/examples/MMPT/mmpt/losses/fairseqmmloss.py
fairseq/examples/MMPT/mmpt/losses/loss.py
fairseq/examples/MMPT/mmpt/losses/nce.py
fairseq/examples/MMPT/mmpt/models/__init__.py
fairseq/examples/MMPT/mmpt/models/fairseqmmmodel.py
fairseq/examples/MMPT/mmpt/models/mmfusion.py
fairseq/examples/MMPT/mmpt/models/mmfusionnlg.py
fairseq/examples/MMPT/mmpt/models/transformermodel.py
fairseq/examples/MMPT/mmpt/modules/__init__.py
fairseq/examples/MMPT/mmpt/modules/mm.py
fairseq/examples/MMPT/mmpt/modules/retri.py
fairseq/examples/MMPT/mmpt/modules/vectorpool.py
fairseq/examples/MMPT/mmpt/processors/__init__.py
fairseq/examples/MMPT/mmpt/processors/dedupprocessor.py
fairseq/examples/MMPT/mmpt/processors/dsprocessor.py
fairseq/examples/MMPT/mmpt/processors/how2processor.py
fairseq/examples/MMPT/mmpt/processors/how2retriprocessor.py
fairseq/examples/MMPT/mmpt/processors/processor.py
fairseq/examples/MMPT/mmpt/processors/models/s3dg.py
fairseq/examples/MMPT/mmpt/tasks/__init__.py
fairseq/examples/MMPT/mmpt/tasks/fairseqmmtask.py
fairseq/examples/MMPT/mmpt/tasks/milncetask.py
fairseq/examples/MMPT/mmpt/tasks/retritask.py
fairseq/examples/MMPT/mmpt/tasks/task.py
fairseq/examples/MMPT/mmpt/tasks/vlmtask.py
fairseq/examples/MMPT/mmpt/utils/__init__.py
fairseq/examples/MMPT/mmpt/utils/load_config.py
fairseq/examples/MMPT/mmpt/utils/shardedtensor.py
fairseq/examples/MMPT/mmpt_cli/localjob.py
fairseq/examples/MMPT/mmpt_cli/predict.py
fairseq/examples/MMPT/projects/mfmmlm.yaml
fairseq/examples/MMPT/projects/mtm/mmfusionmtm.yaml
fairseq/examples/MMPT/projects/mtm/vlm.yaml
fairseq/examples/MMPT/projects/mtm/vlm/coin.yaml
fairseq/examples/MMPT/projects/mtm/vlm/crosstask.yaml
fairseq/examples/MMPT/projects/mtm/vlm/how2.yaml
fairseq/examples/MMPT/projects/mtm/vlm/test_coin.yaml
fairseq/examples/MMPT/projects/mtm/vlm/test_crosstask.yaml
fairseq/examples/MMPT/projects/mtm/vlm/test_crosstask_zs.yaml
fairseq/examples/MMPT/projects/mtm/vlm/test_vtt.yaml
fairseq/examples/MMPT/projects/mtm/vlm/test_vttqa.yaml
fairseq/examples/MMPT/projects/mtm/vlm/test_youcook.yaml
fairseq/examples/MMPT/projects/mtm/vlm/test_youcookcap.yaml
fairseq/examples/MMPT/projects/mtm/vlm/vtt.yaml
fairseq/examples/MMPT/projects/mtm/vlm/vttqa.yaml
fairseq/examples/MMPT/projects/mtm/vlm/youcook.yaml
fairseq/examples/MMPT/projects/mtm/vlm/youcookcap.yaml
fairseq/examples/MMPT/projects/retri/videoclip.yaml
fairseq/examples/MMPT/projects/retri/videoretri.yaml
fairseq/examples/MMPT/projects/retri/videoclip/coin_videoclip.yaml
fairseq/examples/MMPT/projects/retri/videoclip/crosstask_videoclip.yaml
fairseq/examples/MMPT/projects/retri/videoclip/how2.yaml
fairseq/examples/MMPT/projects/retri/videoclip/test_coin_videoclip.yaml
fairseq/examples/MMPT/projects/retri/videoclip/test_coin_zs.yaml
fairseq/examples/MMPT/projects/retri/videoclip/test_crosstask_videoclip.yaml
fairseq/examples/MMPT/projects/retri/videoclip/test_crosstask_zs_videoclip.yaml
fairseq/examples/MMPT/projects/retri/videoclip/test_didemo_zs.yaml
fairseq/examples/MMPT/projects/retri/videoclip/test_vtt_videoclip.yaml
fairseq/examples/MMPT/projects/retri/videoclip/test_vtt_zs.yaml
fairseq/examples/MMPT/projects/retri/videoclip/test_vttqa_videoclip.yaml
fairseq/examples/MMPT/projects/retri/videoclip/test_vttqa_zs.yaml
fairseq/examples/MMPT/projects/retri/videoclip/test_youcook_videoclip.yaml
fairseq/examples/MMPT/projects/retri/videoclip/test_youcook_zs.yaml
fairseq/examples/MMPT/projects/retri/videoclip/vtt_videoclip.yaml
fairseq/examples/MMPT/projects/retri/videoclip/vttqa_videoclip.yaml
fairseq/examples/MMPT/projects/retri/videoclip/youcook_videoclip.yaml
fairseq/examples/MMPT/projects/task/coin.yaml
fairseq/examples/MMPT/projects/task/coin_videoclip.yaml
fairseq/examples/MMPT/projects/task/crosstask.yaml
fairseq/examples/MMPT/projects/task/crosstask_videoclip.yaml
fairseq/examples/MMPT/projects/task/default.yaml
fairseq/examples/MMPT/projects/task/ft.yaml
fairseq/examples/MMPT/projects/task/how2.yaml
fairseq/examples/MMPT/projects/task/test.yaml
fairseq/examples/MMPT/projects/task/test_coin.yaml
fairseq/examples/MMPT/projects/task/test_coin_videoclip.yaml
fairseq/examples/MMPT/projects/task/test_coin_zs.yaml
fairseq/examples/MMPT/projects/task/test_crosstask.yaml
fairseq/examples/MMPT/projects/task/test_crosstask_videoclip.yaml
fairseq/examples/MMPT/projects/task/test_crosstask_zs.yaml
fairseq/examples/MMPT/projects/task/test_crosstask_zs_videoclip.yaml
fairseq/examples/MMPT/projects/task/test_didemo_zs.yaml
fairseq/examples/MMPT/projects/task/test_vtt.yaml
fairseq/examples/MMPT/projects/task/test_vtt_videoclip.yaml
fairseq/examples/MMPT/projects/task/test_vtt_zs.yaml
fairseq/examples/MMPT/projects/task/test_vttqa.yaml
fairseq/examples/MMPT/projects/task/test_vttqa_videoclip.yaml
fairseq/examples/MMPT/projects/task/test_vttqa_zs.yaml
fairseq/examples/MMPT/projects/task/test_youcook.yaml
fairseq/examples/MMPT/projects/task/test_youcook_videoclip.yaml
fairseq/examples/MMPT/projects/task/test_youcook_zs.yaml
fairseq/examples/MMPT/projects/task/test_youcookcap.yaml
fairseq/examples/MMPT/projects/task/vtt.yaml
fairseq/examples/MMPT/projects/task/vtt_videoclip.yaml
fairseq/examples/MMPT/projects/task/vttqa.yaml
fairseq/examples/MMPT/projects/task/vttqa_videoclip.yaml
fairseq/examples/MMPT/projects/task/youcook.yaml
fairseq/examples/MMPT/projects/task/youcook_videoclip.yaml
fairseq/examples/MMPT/projects/task/youcookcap.yaml
fairseq/examples/MMPT/scripts/text_token_extractor/pretokenization.py
fairseq/examples/MMPT/scripts/text_token_extractor/configs/bert-base-uncased.yaml
fairseq/examples/MMPT/scripts/video_feature_extractor/extract.py
fairseq/examples/MMPT/scripts/video_feature_extractor/model.py
fairseq/examples/MMPT/scripts/video_feature_extractor/pathbuilder.py
fairseq/examples/MMPT/scripts/video_feature_extractor/preprocessing.py
fairseq/examples/MMPT/scripts/video_feature_extractor/random_sequence_shuffler.py
fairseq/examples/MMPT/scripts/video_feature_extractor/shard_feature.py
fairseq/examples/MMPT/scripts/video_feature_extractor/videoreader.py
fairseq/examples/MMPT/scripts/video_feature_extractor/how2/s3d.sh
fairseq/examples/adaptive_span/README.md
fairseq/examples/adaptive_span/__init__.py
fairseq/examples/adaptive_span/adagrad_with_grad_clip.py
fairseq/examples/adaptive_span/adaptive_span_attention.py
fairseq/examples/adaptive_span/adaptive_span_loss.py
fairseq/examples/adaptive_span/adaptive_span_model.py
fairseq/examples/adaptive_span/adaptive_span_model_wrapper.py
fairseq/examples/adaptive_span/truncated_bptt_lm_task.py
fairseq/examples/attention_head_selection/README.md
fairseq/examples/attention_head_selection/src/__init__.py
fairseq/examples/attention_head_selection/src/speech_to_text_head_selection.py
fairseq/examples/attention_head_selection/src/data/__init__.py
fairseq/examples/attention_head_selection/src/data/speech_to_text_dataset_with_domain.py
fairseq/examples/attention_head_selection/src/loss/__init__.py
fairseq/examples/attention_head_selection/src/loss/attention_head_selection.py
fairseq/examples/attention_head_selection/src/models/__init__.py
fairseq/examples/attention_head_selection/src/models/head_selection_s2t_transformer.py
fairseq/examples/attention_head_selection/src/models/head_selection_transformer.py
fairseq/examples/attention_head_selection/src/modules/__init__.py
fairseq/examples/attention_head_selection/src/modules/attn_head_selector.py
fairseq/examples/attention_head_selection/src/modules/head_selection_transformer_layer.py
fairseq/examples/attention_head_selection/src/modules/multihead_attention_selection.py
fairseq/examples/attention_head_selection/src/modules/multihead_functional.py
fairseq/examples/audio_nlp/nlu/README.md
fairseq/examples/audio_nlp/nlu/create_dict_stop.sh
fairseq/examples/audio_nlp/nlu/generate_manifests.py
fairseq/examples/audio_nlp/nlu/configs/nlu_finetuning.yaml
fairseq/examples/backtranslation/README.md
fairseq/examples/backtranslation/deduplicate_lines.py
fairseq/examples/backtranslation/extract_bt_data.py
fairseq/examples/backtranslation/prepare-de-monolingual.sh
fairseq/examples/backtranslation/prepare-wmt18en2de.sh
fairseq/examples/backtranslation/sacrebleu.sh
fairseq/examples/backtranslation/tokenized_bleu.sh
fairseq/examples/bart/README.glue.md
fairseq/examples/bart/README.md
fairseq/examples/bart/README.summarization.md
fairseq/examples/bart/summarize.py
fairseq/examples/byte_level_bpe/README.md
fairseq/examples/byte_level_bpe/get_bitext.py
fairseq/examples/byte_level_bpe/get_data.sh
fairseq/examples/byte_level_bpe/gru_transformer.py
fairseq/examples/camembert/README.md
fairseq/examples/constrained_decoding/README.md
fairseq/examples/constrained_decoding/normalize.py
fairseq/examples/constrained_decoding/tok.py
fairseq/examples/conv_seq2seq/README.md
fairseq/examples/criss/README.md
fairseq/examples/criss/download_and_preprocess_flores_test.sh
fairseq/examples/criss/download_and_preprocess_tatoeba.sh
fairseq/examples/criss/save_encoder.py
fairseq/examples/criss/mining/mine.py
fairseq/examples/criss/mining/mine_example.sh
fairseq/examples/criss/sentence_retrieval/encoder_analysis.py
fairseq/examples/criss/sentence_retrieval/sentence_retrieval_tatoeba.sh
fairseq/examples/criss/unsupervised_mt/eval.sh
fairseq/examples/cross_lingual_language_model/README.md
fairseq/examples/data2vec/README.md
fairseq/examples/data2vec/__init__.py
fairseq/examples/data2vec/fb_convert_beit_cp.py
fairseq/examples/data2vec/config/audio/classification/base_classification.yaml
fairseq/examples/data2vec/config/audio/classification/run_config/slurm_1.yaml
fairseq/examples/data2vec/config/audio/classification/run_config/slurm_1g.yaml
fairseq/examples/data2vec/config/audio/classification/run_config/slurm_2.yaml
fairseq/examples/data2vec/config/audio/pretraining/audioset.yaml
fairseq/examples/data2vec/config/audio/pretraining/base_librispeech.yaml
fairseq/examples/data2vec/config/audio/pretraining/run_config/local.yaml
fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_1.yaml
fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_1_aws.yaml
fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_2.yaml
fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_2_aws.yaml
fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_3.yaml
fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_4.yaml
fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_4_aws.yaml
fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_6_aws.yaml
fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_8_aws.yaml
fairseq/examples/data2vec/config/text/pretraining/base.yaml
fairseq/examples/data2vec/config/text/pretraining/run_config/local.yaml
fairseq/examples/data2vec/config/text/pretraining/run_config/slurm_1_aws.yaml
fairseq/examples/data2vec/config/text/pretraining/run_config/slurm_2.yaml
fairseq/examples/data2vec/config/text/pretraining/run_config/slurm_2_aws.yaml
fairseq/examples/data2vec/config/text/pretraining/run_config/slurm_3.yaml
fairseq/examples/data2vec/config/text/pretraining/run_config/slurm_4.yaml
fairseq/examples/data2vec/config/text/pretraining/run_config/slurm_4_aws.yaml
fairseq/examples/data2vec/config/text/pretraining/run_config/slurm_8_aws.yaml
fairseq/examples/data2vec/config/v2/base_audio_only_task.yaml
fairseq/examples/data2vec/config/v2/base_images_only_task.yaml
fairseq/examples/data2vec/config/v2/base_text_only_task.yaml
fairseq/examples/data2vec/config/v2/huge_images14_only_task.yaml
fairseq/examples/data2vec/config/v2/huge_images_only_task.yaml
fairseq/examples/data2vec/config/v2/large_audio_only_task.yaml
fairseq/examples/data2vec/config/v2/large_images_only_task.yaml
fairseq/examples/data2vec/config/v2/large_text_only_task.yaml
fairseq/examples/data2vec/config/v2/large_text_only_task_pgrp_1M.yaml
fairseq/examples/data2vec/config/v2/run_config/local.yaml
fairseq/examples/data2vec/config/v2/run_config/slurm_1.yaml
fairseq/examples/data2vec/config/v2/run_config/slurm_1_aws.yaml
fairseq/examples/data2vec/config/v2/run_config/slurm_2.yaml
fairseq/examples/data2vec/config/v2/run_config/slurm_2_aws.yaml
fairseq/examples/data2vec/config/v2/run_config/slurm_3.yaml
fairseq/examples/data2vec/config/v2/run_config/slurm_4.yaml
fairseq/examples/data2vec/config/v2/run_config/slurm_4_aws.yaml
fairseq/examples/data2vec/config/v2/run_config/slurm_6_aws.yaml
fairseq/examples/data2vec/config/v2/run_config/slurm_8.yaml
fairseq/examples/data2vec/config/v2/run_config/slurm_8_aws.yaml
fairseq/examples/data2vec/config/v2/text_finetuning/cola.yaml
fairseq/examples/data2vec/config/v2/text_finetuning/mnli.yaml
fairseq/examples/data2vec/config/v2/text_finetuning/mrpc.yaml
fairseq/examples/data2vec/config/v2/text_finetuning/qnli.yaml
fairseq/examples/data2vec/config/v2/text_finetuning/qqp.yaml
fairseq/examples/data2vec/config/v2/text_finetuning/rte.yaml
fairseq/examples/data2vec/config/v2/text_finetuning/sst_2.yaml
fairseq/examples/data2vec/config/v2/text_finetuning/sts_b.yaml
fairseq/examples/data2vec/config/v2/text_finetuning/run_config/local.yaml
fairseq/examples/data2vec/config/vision/finetuning/imagenet.yaml
fairseq/examples/data2vec/config/vision/finetuning/mae_imagenet_clean.yaml
fairseq/examples/data2vec/config/vision/finetuning/mae_imagenet_huge_clean.yaml
fairseq/examples/data2vec/config/vision/finetuning/mae_imagenet_large_clean.yaml
fairseq/examples/data2vec/config/vision/finetuning/run_config/local.yaml
fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_1.yaml
fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_1_aws.yaml
fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_2.yaml
fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_2_aws.yaml
fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_3.yaml
fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_4.yaml
fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_4_aws.yaml
fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_6_aws.yaml
fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_8_aws.yaml
fairseq/examples/data2vec/config/vision/pretraining/base_imagenet.yaml
fairseq/examples/data2vec/config/vision/pretraining/base_imagenet_d2v1.yaml
fairseq/examples/data2vec/config/vision/pretraining/base_mae_imagenet.yaml
fairseq/examples/data2vec/config/vision/pretraining/run_config/local.yaml
fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_1.yaml
fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_1_aws.yaml
fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_2.yaml
fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_2_aws.yaml
fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_3.yaml
fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_4.yaml
fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_4_aws.yaml
fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_6_aws.yaml
fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_8_aws.yaml
fairseq/examples/data2vec/data/__init__.py
fairseq/examples/data2vec/data/add_class_target_dataset.py
fairseq/examples/data2vec/data/image_dataset.py
fairseq/examples/data2vec/data/mae_finetuning_image_dataset.py
fairseq/examples/data2vec/data/mae_image_dataset.py
|
490 |
+
fairseq/examples/data2vec/data/modality.py
|
491 |
+
fairseq/examples/data2vec/data/path_dataset.py
|
492 |
+
fairseq/examples/data2vec/models/__init__.py
|
493 |
+
fairseq/examples/data2vec/models/audio_classification.py
|
494 |
+
fairseq/examples/data2vec/models/data2vec2.py
|
495 |
+
fairseq/examples/data2vec/models/data2vec_audio.py
|
496 |
+
fairseq/examples/data2vec/models/data2vec_image_classification.py
|
497 |
+
fairseq/examples/data2vec/models/data2vec_text.py
|
498 |
+
fairseq/examples/data2vec/models/data2vec_text_classification.py
|
499 |
+
fairseq/examples/data2vec/models/data2vec_vision.py
|
500 |
+
fairseq/examples/data2vec/models/mae.py
|
501 |
+
fairseq/examples/data2vec/models/mae_image_classification.py
|
502 |
+
fairseq/examples/data2vec/models/utils.py
|
503 |
+
fairseq/examples/data2vec/models/modalities/__init__.py
|
504 |
+
fairseq/examples/data2vec/models/modalities/audio.py
|
505 |
+
fairseq/examples/data2vec/models/modalities/base.py
|
506 |
+
fairseq/examples/data2vec/models/modalities/images.py
|
507 |
+
fairseq/examples/data2vec/models/modalities/modules.py
|
508 |
+
fairseq/examples/data2vec/models/modalities/text.py
|
509 |
+
fairseq/examples/data2vec/scripts/convert_audioset_labels.py
|
510 |
+
fairseq/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr.sh
|
511 |
+
fairseq/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr_nodep.sh
|
512 |
+
fairseq/examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh
|
513 |
+
fairseq/examples/data2vec/scripts/text/finetune_all_char_fair_aws_local_lr.sh
|
514 |
+
fairseq/examples/data2vec/scripts/text/finetune_all_fair.sh
|
515 |
+
fairseq/examples/data2vec/scripts/text/finetune_all_fair_aws.sh
|
516 |
+
fairseq/examples/data2vec/scripts/text/finetune_all_fair_aws_local_lr.sh
|
517 |
+
fairseq/examples/data2vec/scripts/text/finetune_all_fair_aws_lr.sh
|
518 |
+
fairseq/examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh
|
519 |
+
fairseq/examples/data2vec/scripts/text/finetune_all_fair_nodep.sh
|
520 |
+
fairseq/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws.sh
|
521 |
+
fairseq/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_local_lr.sh
|
522 |
+
fairseq/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr.sh
|
523 |
+
fairseq/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr_nopos.sh
|
524 |
+
fairseq/examples/data2vec/scripts/text/finetune_all_large_fair_aws_local_lr.sh
|
525 |
+
fairseq/examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh
|
526 |
+
fairseq/examples/data2vec/scripts/text/finetune_all_large_fair_nodep_aws_local_lr.sh
|
527 |
+
fairseq/examples/data2vec/scripts/text/finetune_sst2_qnli_sweep_fair_nodep.sh
|
528 |
+
fairseq/examples/data2vec/scripts/text/glue.py
|
529 |
+
fairseq/examples/data2vec/scripts/text/glue_lr.py
|
530 |
+
fairseq/examples/data2vec/scripts/text/unprocess_data.py
|
531 |
+
fairseq/examples/data2vec/scripts/text/valids.py
|
532 |
+
fairseq/examples/data2vec/tasks/__init__.py
|
533 |
+
fairseq/examples/data2vec/tasks/audio_classification.py
|
534 |
+
fairseq/examples/data2vec/tasks/image_classification.py
|
535 |
+
fairseq/examples/data2vec/tasks/image_pretraining.py
|
536 |
+
fairseq/examples/data2vec/tasks/mae_image_classification.py
|
537 |
+
fairseq/examples/data2vec/tasks/mae_image_pretraining.py
|
538 |
+
fairseq/examples/data2vec/tasks/multimodal.py
|
539 |
+
fairseq/examples/discriminative_reranking_nmt/README.md
|
540 |
+
fairseq/examples/discriminative_reranking_nmt/__init__.py
|
541 |
+
fairseq/examples/discriminative_reranking_nmt/drnmt_rerank.py
|
542 |
+
fairseq/examples/discriminative_reranking_nmt/config/deen.yaml
|
543 |
+
fairseq/examples/discriminative_reranking_nmt/criterions/__init__.py
|
544 |
+
fairseq/examples/discriminative_reranking_nmt/criterions/discriminative_reranking_criterion.py
|
545 |
+
fairseq/examples/discriminative_reranking_nmt/models/__init__.py
|
546 |
+
fairseq/examples/discriminative_reranking_nmt/models/discriminative_reranking_model.py
|
547 |
+
fairseq/examples/discriminative_reranking_nmt/scripts/prep_data.py
|
548 |
+
fairseq/examples/discriminative_reranking_nmt/tasks/__init__.py
|
549 |
+
fairseq/examples/discriminative_reranking_nmt/tasks/discriminative_reranking_task.py
|
550 |
+
fairseq/examples/emotion_conversion/README.md
|
551 |
+
fairseq/examples/emotion_conversion/requirements.txt
|
552 |
+
fairseq/examples/emotion_conversion/synthesize.py
|
553 |
+
fairseq/examples/emotion_conversion/emotion_models/__init__.py
|
554 |
+
fairseq/examples/emotion_conversion/emotion_models/duration_predictor.py
|
555 |
+
fairseq/examples/emotion_conversion/emotion_models/duration_predictor.yaml
|
556 |
+
fairseq/examples/emotion_conversion/emotion_models/pitch_predictor.py
|
557 |
+
fairseq/examples/emotion_conversion/emotion_models/pitch_predictor.yaml
|
558 |
+
fairseq/examples/emotion_conversion/emotion_models/utils.py
|
559 |
+
fairseq/examples/emotion_conversion/fairseq_models/__init__.py
|
560 |
+
fairseq/examples/emotion_conversion/preprocess/__init__.py
|
561 |
+
fairseq/examples/emotion_conversion/preprocess/build_hifigan_manifest.py
|
562 |
+
fairseq/examples/emotion_conversion/preprocess/build_translation_manifests.py
|
563 |
+
fairseq/examples/emotion_conversion/preprocess/create_core_manifest.py
|
564 |
+
fairseq/examples/emotion_conversion/preprocess/extract_f0.py
|
565 |
+
fairseq/examples/emotion_conversion/preprocess/process_km.py
|
566 |
+
fairseq/examples/emotion_conversion/preprocess/split_emov_km_tsv_by_uttid.py
|
567 |
+
fairseq/examples/emotion_conversion/preprocess/split_km.py
|
568 |
+
fairseq/examples/emotion_conversion/preprocess/split_km_tsv.py
|
569 |
+
fairseq/examples/fast_noisy_channel/README.md
|
570 |
+
fairseq/examples/fast_noisy_channel/__init__.py
|
571 |
+
fairseq/examples/fast_noisy_channel/noisy_channel_beam_search.py
|
572 |
+
fairseq/examples/fast_noisy_channel/noisy_channel_sequence_generator.py
|
573 |
+
fairseq/examples/fast_noisy_channel/noisy_channel_translation.py
|
574 |
+
fairseq/examples/flores101/README.md
|
575 |
+
fairseq/examples/flores101/flores_logo.png
|
576 |
+
fairseq/examples/fully_sharded_data_parallel/README.md
|
577 |
+
fairseq/examples/gottbert/README.md
|
578 |
+
fairseq/examples/hubert/README.md
|
579 |
+
fairseq/examples/hubert/measure_teacher_quality.py
|
580 |
+
fairseq/examples/hubert/update_ckpt.py
|
581 |
+
fairseq/examples/hubert/config/decode/infer_fsqlm.yaml
|
582 |
+
fairseq/examples/hubert/config/decode/infer_kenlm.yaml
|
583 |
+
fairseq/examples/hubert/config/decode/infer_viterbi.yaml
|
584 |
+
fairseq/examples/hubert/config/decode/ax_sweep/ngram.yaml
|
585 |
+
fairseq/examples/hubert/config/decode/ax_sweep/transformer.yaml
|
586 |
+
fairseq/examples/hubert/config/decode/run/submitit_slurm.yaml
|
587 |
+
fairseq/examples/hubert/config/decode/run/submitit_slurm_8gpu.yaml
|
588 |
+
fairseq/examples/hubert/config/finetune/base_10h.yaml
|
589 |
+
fairseq/examples/hubert/config/finetune/ckpt/it1.yaml
|
590 |
+
fairseq/examples/hubert/config/finetune/lm/ls_4gram.yaml
|
591 |
+
fairseq/examples/hubert/config/finetune/run/submitit_reg.yaml
|
592 |
+
fairseq/examples/hubert/config/pretrain/hubert_base_librispeech.yaml
|
593 |
+
fairseq/examples/hubert/config/pretrain/hubert_large_librivox.yaml
|
594 |
+
fairseq/examples/hubert/config/pretrain/hubert_xlarge_librivox.yaml
|
595 |
+
fairseq/examples/hubert/config/pretrain/data/iter1.yaml
|
596 |
+
fairseq/examples/hubert/config/pretrain/data/iter2.yaml
|
597 |
+
fairseq/examples/hubert/config/pretrain/run/submitit_reg.yaml
|
598 |
+
fairseq/examples/hubert/simple_kmeans/README.md
|
599 |
+
fairseq/examples/hubert/simple_kmeans/dump_hubert_feature.py
|
600 |
+
fairseq/examples/hubert/simple_kmeans/dump_hubert_feature_s2t.py
|
601 |
+
fairseq/examples/hubert/simple_kmeans/dump_km_label.py
|
602 |
+
fairseq/examples/hubert/simple_kmeans/dump_mfcc_feature.py
|
603 |
+
fairseq/examples/hubert/simple_kmeans/dump_w2v2_feature.py
|
604 |
+
fairseq/examples/hubert/simple_kmeans/feature_utils.py
|
605 |
+
fairseq/examples/hubert/simple_kmeans/learn_kmeans.py
|
606 |
+
fairseq/examples/hubert/tests/6313-76958-0021.flac
|
607 |
+
fairseq/examples/hubert/tests/sample.base.L9.km500.km
|
608 |
+
fairseq/examples/hubert/tests/sample.base.L9.len
|
609 |
+
fairseq/examples/hubert/tests/sample.base.L9.npy
|
610 |
+
fairseq/examples/hubert/tests/sample.large.L20.len
|
611 |
+
fairseq/examples/hubert/tests/sample.large.L20.npy
|
612 |
+
fairseq/examples/hubert/tests/sample.large.hypo.word
|
613 |
+
fairseq/examples/hubert/tests/sample.xlarge.L30.len
|
614 |
+
fairseq/examples/hubert/tests/sample.xlarge.L30.npy
|
615 |
+
fairseq/examples/hubert/tests/sample.xlarge.hypo.word
|
616 |
+
fairseq/examples/hubert/tests/test_feature_and_unit.sh
|
617 |
+
fairseq/examples/hubert/tests/test_finetuned_asr.sh
|
618 |
+
fairseq/examples/joint_alignment_translation/README.md
|
619 |
+
fairseq/examples/joint_alignment_translation/prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh
|
620 |
+
fairseq/examples/language_model/README.adaptive_inputs.md
|
621 |
+
fairseq/examples/language_model/README.conv.md
|
622 |
+
fairseq/examples/language_model/README.md
|
623 |
+
fairseq/examples/language_model/prepare-wikitext-103.sh
|
624 |
+
fairseq/examples/laser/README.md
|
625 |
+
fairseq/examples/laser/laser_src/__init__.py
|
626 |
+
fairseq/examples/laser/laser_src/laser_lstm.py
|
627 |
+
fairseq/examples/laser/laser_src/laser_task.py
|
628 |
+
fairseq/examples/laser/laser_src/laser_transformer.py
|
629 |
+
fairseq/examples/laser/laser_src/multitask_data_utils.py
|
630 |
+
fairseq/examples/latent_depth/README.md
|
631 |
+
fairseq/examples/latent_depth/latent_depth_src/__init__.py
|
632 |
+
fairseq/examples/latent_depth/latent_depth_src/multilingual_translation_latent_depth.py
|
633 |
+
fairseq/examples/latent_depth/latent_depth_src/loss/__init__.py
|
634 |
+
fairseq/examples/latent_depth/latent_depth_src/loss/latent_depth.py
|
635 |
+
fairseq/examples/latent_depth/latent_depth_src/models/__init__.py
|
636 |
+
fairseq/examples/latent_depth/latent_depth_src/models/latent_multilingual_transformer.py
|
637 |
+
fairseq/examples/latent_depth/latent_depth_src/models/latent_transformer.py
|
638 |
+
fairseq/examples/latent_depth/latent_depth_src/modules/__init__.py
|
639 |
+
fairseq/examples/latent_depth/latent_depth_src/modules/latent_layers.py
|
640 |
+
fairseq/examples/layerdrop/README.md
|
641 |
+
fairseq/examples/linformer/README.md
|
642 |
+
fairseq/examples/linformer/linformer_src/__init__.py
|
643 |
+
fairseq/examples/linformer/linformer_src/models/__init__.py
|
644 |
+
fairseq/examples/linformer/linformer_src/models/linformer_roberta.py
|
645 |
+
fairseq/examples/linformer/linformer_src/modules/__init__.py
|
646 |
+
fairseq/examples/linformer/linformer_src/modules/linformer_sentence_encoder.py
|
647 |
+
fairseq/examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py
|
648 |
+
fairseq/examples/linformer/linformer_src/modules/multihead_linear_attention.py
|
649 |
+
fairseq/examples/m2m_100/README.md
|
650 |
+
fairseq/examples/m2m_100/install_dependecies.sh
|
651 |
+
fairseq/examples/m2m_100/tok.sh
|
652 |
+
fairseq/examples/m2m_100/process_data/clean_histogram.py
|
653 |
+
fairseq/examples/m2m_100/process_data/dedup_data.py
|
654 |
+
fairseq/examples/m2m_100/process_data/remove_too_much_punc.py
|
655 |
+
fairseq/examples/m2m_100/tokenizers/README.md
|
656 |
+
fairseq/examples/m2m_100/tokenizers/seg_ja.sh
|
657 |
+
fairseq/examples/m2m_100/tokenizers/seg_ko.sh
|
658 |
+
fairseq/examples/m2m_100/tokenizers/tokenize_indic.py
|
659 |
+
fairseq/examples/m2m_100/tokenizers/tokenize_thai.py
|
660 |
+
fairseq/examples/m2m_100/tokenizers/tokenize_zh.py
|
661 |
+
fairseq/examples/m2m_100/tokenizers/tokenizer_ar.sh
|
662 |
+
fairseq/examples/m2m_100/tokenizers/thirdparty/.gitignore
|
663 |
+
fairseq/examples/mbart/README.md
|
664 |
+
fairseq/examples/megatron_11b/README.md
|
665 |
+
fairseq/examples/megatron_11b/detok.py
|
666 |
+
fairseq/examples/mms/MODEL_CARD.md
|
667 |
+
fairseq/examples/mms/README.md
|
668 |
+
fairseq/examples/mms/asr/config/infer_common.yaml
|
669 |
+
fairseq/examples/mms/asr/infer/example_infer_adapter.sh
|
670 |
+
fairseq/examples/mms/asr/infer/mms_infer.py
|
671 |
+
fairseq/examples/mms/asr/tutorial/MMS_ASR_Inference_Colab.ipynb
|
672 |
+
fairseq/examples/mms/data_prep/README.md
|
673 |
+
fairseq/examples/mms/data_prep/align_and_segment.py
|
674 |
+
fairseq/examples/mms/data_prep/align_utils.py
|
675 |
+
fairseq/examples/mms/data_prep/norm_config.py
|
676 |
+
fairseq/examples/mms/data_prep/punctuations.lst
|
677 |
+
fairseq/examples/mms/data_prep/text_normalization.py
|
678 |
+
fairseq/examples/mms/lid/infer.py
|
679 |
+
fairseq/examples/mms/lid/tutorial/MMS_LID_Inference_Colab.ipynb
|
680 |
+
fairseq/examples/mms/lid_rerank/README.md
|
681 |
+
fairseq/examples/mms/lid_rerank/cer_langs.txt
|
682 |
+
fairseq/examples/mms/lid_rerank/requirements.txt
|
683 |
+
fairseq/examples/mms/lid_rerank/mala/infer.py
|
684 |
+
fairseq/examples/mms/lid_rerank/mms/make_parallel_single_runs.py
|
685 |
+
fairseq/examples/mms/lid_rerank/mms/merge_by_lang.py
|
686 |
+
fairseq/examples/mms/lid_rerank/mms/prep_wav_list.py
|
687 |
+
fairseq/examples/mms/lid_rerank/mms/run_single_lang.py
|
688 |
+
fairseq/examples/mms/lid_rerank/mms/split_by_lang.py
|
689 |
+
fairseq/examples/mms/lid_rerank/mms-zs/falign.py
|
690 |
+
fairseq/examples/mms/lid_rerank/mms-zs/lib.py
|
691 |
+
fairseq/examples/mms/lid_rerank/mms-zs/uromanize.py
|
692 |
+
fairseq/examples/mms/lid_rerank/nllb/infer.py
|
693 |
+
fairseq/examples/mms/lid_rerank/rerank/rerank.py
|
694 |
+
fairseq/examples/mms/lid_rerank/rerank/tune_coefficients.py
|
695 |
+
fairseq/examples/mms/lid_rerank/whisper/infer_asr.py
|
696 |
+
fairseq/examples/mms/lid_rerank/whisper/infer_lid.py
|
697 |
+
fairseq/examples/mms/lid_rerank/whisper/lid_mapping.txt
|
698 |
+
fairseq/examples/mms/misc/get_sample_size.py
|
699 |
+
fairseq/examples/mms/tts/infer.py
|
700 |
+
fairseq/examples/mms/tts/tutorial/MMS_TTS_Inference_Colab.ipynb
|
701 |
+
fairseq/examples/mms/zero_shot/README.md
|
702 |
+
fairseq/examples/moe_lm/README.md
|
703 |
+
fairseq/examples/moe_lm/data_card.md
|
704 |
+
fairseq/examples/moe_lm/model_card.md
|
705 |
+
fairseq/examples/mr_hubert/README.md
|
706 |
+
fairseq/examples/mr_hubert/decode.sh
|
707 |
+
fairseq/examples/mr_hubert/finetune.sh
|
708 |
+
fairseq/examples/mr_hubert/train.sh
|
709 |
+
fairseq/examples/mr_hubert/config/decode/infer.yaml
|
710 |
+
fairseq/examples/mr_hubert/config/decode/infer_lm.yaml
|
711 |
+
fairseq/examples/mr_hubert/config/decode/run/submitit_slurm.yaml
|
712 |
+
fairseq/examples/mr_hubert/config/decode/run/submitit_slurm_8gpu.yaml
|
713 |
+
fairseq/examples/mr_hubert/config/finetune/base_100h.yaml
|
714 |
+
fairseq/examples/mr_hubert/config/finetune/base_100h_large.yaml
|
715 |
+
fairseq/examples/mr_hubert/config/finetune/base_10h.yaml
|
716 |
+
fairseq/examples/mr_hubert/config/finetune/base_10h_large.yaml
|
717 |
+
fairseq/examples/mr_hubert/config/finetune/base_1h.yaml
|
718 |
+
fairseq/examples/mr_hubert/config/finetune/base_1h_large.yaml
|
719 |
+
fairseq/examples/mr_hubert/config/pretrain/mrhubert_base_librispeech.yaml
|
720 |
+
fairseq/examples/mr_hubert/config/pretrain/mrhubert_large_librilight.yaml
|
721 |
+
fairseq/examples/mr_hubert/config/pretrain/run/submitit_reg.yaml
|
722 |
+
fairseq/examples/mr_hubert/simple_kmeans/README.md
|
723 |
+
fairseq/examples/mr_hubert/simple_kmeans/dump_hubert_feature.py
|
724 |
+
fairseq/examples/mr_hubert/simple_kmeans/dump_hubert_feature_s2t.py
|
725 |
+
fairseq/examples/mr_hubert/simple_kmeans/dump_km_label.py
|
726 |
+
fairseq/examples/mr_hubert/simple_kmeans/dump_mfcc_feature.py
|
727 |
+
fairseq/examples/mr_hubert/simple_kmeans/dump_w2v2_feature.py
|
728 |
+
fairseq/examples/mr_hubert/simple_kmeans/feature_utils.py
|
729 |
+
fairseq/examples/mr_hubert/simple_kmeans/learn_kmeans.py
|
730 |
+
fairseq/examples/multilingual/ML50_langs.txt
|
731 |
+
fairseq/examples/multilingual/README.md
|
732 |
+
fairseq/examples/multilingual/finetune_multilingual_model.sh
|
733 |
+
fairseq/examples/multilingual/multilingual_fairseq_gen.sh
|
734 |
+
fairseq/examples/multilingual/train_multilingual_model.sh
|
735 |
+
fairseq/examples/multilingual/data_scripts/README.md
|
736 |
+
fairseq/examples/multilingual/data_scripts/binarize.py
|
737 |
+
fairseq/examples/multilingual/data_scripts/check_iswlt_test_data.py
|
738 |
+
fairseq/examples/multilingual/data_scripts/check_self_overlaps.py
|
739 |
+
fairseq/examples/multilingual/data_scripts/check_valid_test_overlaps.py
|
740 |
+
fairseq/examples/multilingual/data_scripts/dedup_all.py
|
741 |
+
fairseq/examples/multilingual/data_scripts/download_ML50_v1.sh
|
742 |
+
fairseq/examples/multilingual/data_scripts/download_af_xh.sh
|
743 |
+
fairseq/examples/multilingual/data_scripts/download_flores_data.sh
|
744 |
+
fairseq/examples/multilingual/data_scripts/download_iitb.sh
|
745 |
+
fairseq/examples/multilingual/data_scripts/download_iwslt_and_extract.sh
|
746 |
+
fairseq/examples/multilingual/data_scripts/download_lotus.sh
|
747 |
+
fairseq/examples/multilingual/data_scripts/download_ted_and_extract.py
|
748 |
+
fairseq/examples/multilingual/data_scripts/download_wat19_my.sh
|
749 |
+
fairseq/examples/multilingual/data_scripts/download_wmt19_and_before.py
|
750 |
+
fairseq/examples/multilingual/data_scripts/download_wmt20.sh
|
751 |
+
fairseq/examples/multilingual/data_scripts/preprocess_ML50_v1.sh
|
752 |
+
fairseq/examples/multilingual/data_scripts/remove_valid_test_in_train.py
|
753 |
+
fairseq/examples/multilingual/data_scripts/requirement.txt
|
754 |
+
fairseq/examples/multilingual/data_scripts/utils/dedup.py
|
755 |
+
fairseq/examples/multilingual/data_scripts/utils/fasttext_multi_filter.py
|
756 |
+
fairseq/examples/multilingual/data_scripts/utils/strip_sgm.sh
|
757 |
+
fairseq/examples/noisychannel/README.md
|
758 |
+
fairseq/examples/noisychannel/__init__.py
|
759 |
+
fairseq/examples/noisychannel/rerank.py
|
760 |
+
fairseq/examples/noisychannel/rerank_generate.py
|
761 |
+
fairseq/examples/noisychannel/rerank_options.py
|
762 |
+
fairseq/examples/noisychannel/rerank_score_bw.py
|
763 |
+
fairseq/examples/noisychannel/rerank_score_lm.py
|
764 |
+
fairseq/examples/noisychannel/rerank_tune.py
|
765 |
+
fairseq/examples/noisychannel/rerank_utils.py
|
766 |
+
fairseq/examples/nonautoregressive_translation/README.md
|
767 |
+
fairseq/examples/nonautoregressive_translation/scripts.md
|
768 |
+
fairseq/examples/normformer/README.md
|
769 |
+
fairseq/examples/normformer/train_lm.sh
|
770 |
+
fairseq/examples/operators/alignment_train_cpu.cpp
|
771 |
+
fairseq/examples/operators/alignment_train_cuda.cpp
|
772 |
+
fairseq/examples/operators/alignment_train_cuda.h
|
773 |
+
fairseq/examples/operators/alignment_train_kernel.cu
|
774 |
+
fairseq/examples/operators/utils.h
|
775 |
+
fairseq/examples/paraphraser/README.md
|
776 |
+
fairseq/examples/paraphraser/paraphrase.py
|
777 |
+
fairseq/examples/pay_less_attention_paper/README.md
|
778 |
+
fairseq/examples/pointer_generator/README.md
|
779 |
+
fairseq/examples/pointer_generator/README.xsum.md
|
780 |
+
fairseq/examples/pointer_generator/postprocess.py
|
781 |
+
fairseq/examples/pointer_generator/preprocess.py
|
782 |
+
fairseq/examples/pointer_generator/pointer_generator_src/__init__.py
|
783 |
+
fairseq/examples/pointer_generator/pointer_generator_src/transformer_pg.py
|
784 |
+
fairseq/examples/quant_noise/README.md
|
785 |
+
fairseq/examples/quant_noise/transformer_quantization_config.yaml
|
786 |
+
fairseq/examples/roberta/README.custom_classification.md
|
787 |
+
fairseq/examples/roberta/README.glue.md
|
788 |
+
fairseq/examples/roberta/README.md
|
789 |
+
fairseq/examples/roberta/README.pretraining.md
|
790 |
+
fairseq/examples/roberta/README.race.md
|
791 |
+
fairseq/examples/roberta/multiprocessing_bpe_encoder.py
|
792 |
+
fairseq/examples/roberta/preprocess_GLUE_tasks.sh
|
793 |
+
fairseq/examples/roberta/preprocess_RACE.py
|
794 |
+
fairseq/examples/roberta/preprocess_RACE.sh
|
795 |
+
fairseq/examples/roberta/commonsense_qa/README.md
|
796 |
+
fairseq/examples/roberta/commonsense_qa/__init__.py
|
797 |
+
fairseq/examples/roberta/commonsense_qa/commonsense_qa_task.py
|
798 |
+
fairseq/examples/roberta/commonsense_qa/download_cqa_data.sh
|
799 |
+
fairseq/examples/roberta/config/finetuning/cola.yaml
|
800 |
+
fairseq/examples/roberta/config/finetuning/mnli.yaml
|
801 |
+
fairseq/examples/roberta/config/finetuning/mrpc.yaml
|
802 |
+
fairseq/examples/roberta/config/finetuning/qnli.yaml
|
803 |
+
fairseq/examples/roberta/config/finetuning/qqp.yaml
|
804 |
+
fairseq/examples/roberta/config/finetuning/rte.yaml
|
805 |
+
fairseq/examples/roberta/config/finetuning/sst_2.yaml
|
806 |
+
fairseq/examples/roberta/config/finetuning/sts_b.yaml
|
807 |
+
fairseq/examples/roberta/config/finetuning/run_config/local.yaml
|
808 |
+
fairseq/examples/roberta/config/finetuning/run_config/slurm_1g.yaml
|
809 |
+
fairseq/examples/roberta/config/finetuning/run_config/slurm_1g_aws.yaml
|
810 |
+
fairseq/examples/roberta/config/pretraining/base.yaml
|
811 |
+
fairseq/examples/roberta/config/pretraining/run_config/local.yaml
|
812 |
+
fairseq/examples/roberta/config/pretraining/run_config/slurm_2.yaml
|
813 |
+
fairseq/examples/roberta/config/pretraining/run_config/slurm_2_aws.yaml
|
814 |
+
fairseq/examples/roberta/config/pretraining/run_config/slurm_3.yaml
|
815 |
+
fairseq/examples/roberta/config/pretraining/run_config/slurm_4.yaml
|
816 |
+
fairseq/examples/roberta/fb_multilingual/README.multilingual.pretraining.md
|
817 |
+
fairseq/examples/roberta/wsc/README.md
|
818 |
+
fairseq/examples/roberta/wsc/__init__.py
|
819 |
+
fairseq/examples/roberta/wsc/wsc_criterion.py
|
820 |
+
fairseq/examples/roberta/wsc/wsc_task.py
|
821 |
+
fairseq/examples/roberta/wsc/wsc_utils.py
|
822 |
+
fairseq/examples/rxf/README.md
|
823 |
+
fairseq/examples/rxf/__init__.py
|
824 |
+
fairseq/examples/rxf/rxf_src/__init__.py
|
825 |
+
fairseq/examples/rxf/rxf_src/label_smoothed_cross_entropy_r3f.py
|
826 |
+
fairseq/examples/rxf/rxf_src/sentence_prediction_r3f.py
|
827 |
+
fairseq/examples/scaling_nmt/README.md
|
828 |
+
fairseq/examples/shuffled_word_order/README.finetuning.md
|
829 |
+
fairseq/examples/shuffled_word_order/README.md
|
830 |
+
fairseq/examples/simultaneous_translation/README.md
|
831 |
+
fairseq/examples/simultaneous_translation/__init__.py
|
832 |
+
fairseq/examples/simultaneous_translation/docs/ende-mma.md
|
833 |
+
fairseq/examples/simultaneous_translation/docs/enja-waitk.md
|
834 |
+
fairseq/examples/simultaneous_translation/eval/agents/simul_t2t_enja.py
|
835 |
+
fairseq/examples/simultaneous_translation/models/__init__.py
|
836 |
+
fairseq/examples/simultaneous_translation/models/convtransformer_simul_trans.py
|
837 |
+
fairseq/examples/simultaneous_translation/models/transformer_monotonic_attention.py
|
838 |
+
fairseq/examples/simultaneous_translation/modules/__init__.py
|
839 |
+
fairseq/examples/simultaneous_translation/modules/fixed_pre_decision.py
|
840 |
+
fairseq/examples/simultaneous_translation/modules/monotonic_multihead_attention.py
|
841 |
+
fairseq/examples/simultaneous_translation/modules/monotonic_transformer_layer.py
|
842 |
+
fairseq/examples/simultaneous_translation/tests/test_alignment_train.py
|
843 |
+
fairseq/examples/simultaneous_translation/tests/test_text_models.py
|
844 |
+
fairseq/examples/simultaneous_translation/utils/__init__.py
|
845 |
+
fairseq/examples/simultaneous_translation/utils/functions.py
|
846 |
+
fairseq/examples/simultaneous_translation/utils/monotonic_attention.py
|
847 |
+
fairseq/examples/simultaneous_translation/utils/p_choose_strategy.py
|
848 |
+
fairseq/examples/speech_recognition/README.md
|
849 |
+
fairseq/examples/speech_recognition/__init__.py
|
850 |
+
fairseq/examples/speech_recognition/infer.py
|
851 |
+
fairseq/examples/speech_recognition/w2l_decoder.py
|
852 |
+
fairseq/examples/speech_recognition/criterions/ASG_loss.py
|
853 |
+
fairseq/examples/speech_recognition/criterions/__init__.py
|
854 |
+
fairseq/examples/speech_recognition/criterions/cross_entropy_acc.py
|
855 |
+
fairseq/examples/speech_recognition/data/__init__.py
|
856 |
+
fairseq/examples/speech_recognition/data/asr_dataset.py
|
857 |
+
fairseq/examples/speech_recognition/data/collaters.py
|
858 |
+
fairseq/examples/speech_recognition/data/data_utils.py
|
859 |
+
fairseq/examples/speech_recognition/data/replabels.py
|
860 |
+
fairseq/examples/speech_recognition/datasets/asr_prep_json.py
|
861 |
+
fairseq/examples/speech_recognition/datasets/prepare-librispeech.sh
|
862 |
+
fairseq/examples/speech_recognition/kaldi/__init__.py
|
863 |
+
fairseq/examples/speech_recognition/kaldi/add-self-loop-simple.cc
|
864 |
+
fairseq/examples/speech_recognition/kaldi/kaldi_decoder.py
|
865 |
+
fairseq/examples/speech_recognition/kaldi/kaldi_initializer.py
|
866 |
+
fairseq/examples/speech_recognition/kaldi/config/kaldi_initializer.yaml
|
867 |
+
fairseq/examples/speech_recognition/models/__init__.py
|
868 |
+
fairseq/examples/speech_recognition/models/vggtransformer.py
|
869 |
+
fairseq/examples/speech_recognition/models/w2l_conv_glu_enc.py
|
870 |
+
fairseq/examples/speech_recognition/new/README.md
|
871 |
+
fairseq/examples/speech_recognition/new/__init__.py
|
872 |
+
fairseq/examples/speech_recognition/new/infer.py
|
873 |
+
fairseq/examples/speech_recognition/new/conf/infer.yaml
|
874 |
+
fairseq/examples/speech_recognition/new/conf/hydra/sweeper/ax.yaml
|
875 |
+
fairseq/examples/speech_recognition/new/conf/hydra/sweeper/ax_sil.yaml
|
876 |
+
fairseq/examples/speech_recognition/new/conf/run_config/fb_slurm_1.yaml
|
877 |
+
fairseq/examples/speech_recognition/new/conf/run_config/fb_slurm_2g.yaml
|
878 |
+
fairseq/examples/speech_recognition/new/decoders/__init__.py
|
879 |
+
fairseq/examples/speech_recognition/new/decoders/base_decoder.py
|
880 |
+
fairseq/examples/speech_recognition/new/decoders/decoder.py
|
881 |
+
fairseq/examples/speech_recognition/new/decoders/decoder_config.py
|
882 |
+
fairseq/examples/speech_recognition/new/decoders/flashlight_decoder.py
|
883 |
+
fairseq/examples/speech_recognition/new/decoders/viterbi_decoder.py
|
884 |
+
fairseq/examples/speech_recognition/tasks/__init__.py
|
885 |
+
fairseq/examples/speech_recognition/tasks/speech_recognition.py
|
886 |
+
fairseq/examples/speech_recognition/utils/wer_utils.py
|
887 |
+
fairseq/examples/speech_synthesis/README.md
|
888 |
+
fairseq/examples/speech_synthesis/__init__.py
|
889 |
+
fairseq/examples/speech_synthesis/data_utils.py
|
890 |
+
fairseq/examples/speech_synthesis/generate_waveform.py
|
891 |
+
fairseq/examples/speech_synthesis/utils.py
|
892 |
+
fairseq/examples/speech_synthesis/docs/common_voice_example.md
|
893 |
+
fairseq/examples/speech_synthesis/docs/ljspeech_example.md
|
894 |
+
fairseq/examples/speech_synthesis/docs/vctk_example.md
|
895 |
+
fairseq/examples/speech_synthesis/evaluation/__init__.py
|
896 |
+
fairseq/examples/speech_synthesis/evaluation/eval_asr.py
|
897 |
+
fairseq/examples/speech_synthesis/evaluation/eval_f0.py
|
898 |
+
fairseq/examples/speech_synthesis/evaluation/eval_sp.py
|
899 |
+
fairseq/examples/speech_synthesis/evaluation/get_eval_manifest.py
|
900 |
+
fairseq/examples/speech_synthesis/preprocessing/__init__.py
|
901 |
+
fairseq/examples/speech_synthesis/preprocessing/denoise_and_vad_audio.py
|
902 |
+
fairseq/examples/speech_synthesis/preprocessing/get_common_voice_audio_manifest.py
|
903 |
+
fairseq/examples/speech_synthesis/preprocessing/get_feature_manifest.py
|
904 |
+
fairseq/examples/speech_synthesis/preprocessing/get_ljspeech_audio_manifest.py
|
905 |
+
fairseq/examples/speech_synthesis/preprocessing/get_speaker_embedding.py
|
906 |
+
fairseq/examples/speech_synthesis/preprocessing/get_vctk_audio_manifest.py
|
907 |
+
fairseq/examples/speech_synthesis/preprocessing/denoiser/__init__.py
|
908 |
+
fairseq/examples/speech_synthesis/preprocessing/denoiser/demucs.py
|
909 |
+
fairseq/examples/speech_synthesis/preprocessing/denoiser/pretrained.py
|
910 |
+
fairseq/examples/speech_synthesis/preprocessing/denoiser/resample.py
|
911 |
+
fairseq/examples/speech_synthesis/preprocessing/denoiser/utils.py
|
912 |
+
fairseq/examples/speech_synthesis/preprocessing/speaker_embedder/__init__.py
|
913 |
+
fairseq/examples/speech_synthesis/preprocessing/vad/__init__.py
|
914 |
+
fairseq/examples/speech_text_joint_to_text/README.md
|
915 |
+
fairseq/examples/speech_text_joint_to_text/__init__.py
|
916 |
+
fairseq/examples/speech_text_joint_to_text/configs/mustc_noise.list
|
917 |
+
fairseq/examples/speech_text_joint_to_text/criterions/__init__.py
|
918 |
+
fairseq/examples/speech_text_joint_to_text/criterions/multi_modality_compound.py
|
919 |
+
fairseq/examples/speech_text_joint_to_text/criterions/multi_modality_cross_entropy.py
|
920 |
+
fairseq/examples/speech_text_joint_to_text/criterions/text_guide_cross_entropy_acc.py
|
921 |
+
fairseq/examples/speech_text_joint_to_text/data/pair_denoising_dataset.py
|
922 |
+
fairseq/examples/speech_text_joint_to_text/docs/ende-mustc.md
|
923 |
+
fairseq/examples/speech_text_joint_to_text/docs/iwslt2021.md
|
924 |
+
fairseq/examples/speech_text_joint_to_text/docs/pre-training.md
|
925 |
+
fairseq/examples/speech_text_joint_to_text/models/__init__.py
|
926 |
+
fairseq/examples/speech_text_joint_to_text/models/joint_speech_text_pretrain_transformer.py
|
927 |
+
fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputtransformer.py
|
928 |
+
fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputwavtransformer.py
|
929 |
+
fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputxmtransformer.py
|
930 |
+
fairseq/examples/speech_text_joint_to_text/scripts/convert_model.py
|
931 |
+
fairseq/examples/speech_text_joint_to_text/scripts/g2p_encode.py
|
932 |
+
fairseq/examples/speech_text_joint_to_text/tasks/__init__.py
|
933 |
+
fairseq/examples/speech_text_joint_to_text/tasks/pair_denoising.py
|
934 |
+
fairseq/examples/speech_text_joint_to_text/tasks/speech_text_denoise_pretrain.py
|
935 |
+
fairseq/examples/speech_text_joint_to_text/tasks/speech_text_joint.py
|
936 |
+
fairseq/examples/speech_to_speech/README.md
|
937 |
+
fairseq/examples/speech_to_speech/__init__.py
|
938 |
+
fairseq/examples/speech_to_speech/generate_waveform_from_code.py
|
939 |
+
fairseq/examples/speech_to_speech/asr_bleu/README.md
|
940 |
+
fairseq/examples/speech_to_speech/asr_bleu/__init__.py
|
941 |
+
fairseq/examples/speech_to_speech/asr_bleu/asr_model_cfgs.json
|
942 |
+
fairseq/examples/speech_to_speech/asr_bleu/compute_asr_bleu.py
|
943 |
+
fairseq/examples/speech_to_speech/asr_bleu/requirements.txt
|
944 |
+
fairseq/examples/speech_to_speech/asr_bleu/utils.py
|
945 |
+
fairseq/examples/speech_to_speech/benchmarking/README.md
|
946 |
+
fairseq/examples/speech_to_speech/benchmarking/core.py
|
947 |
+
fairseq/examples/speech_to_speech/benchmarking/data_utils.py
|
948 |
+
fairseq/examples/speech_to_speech/benchmarking/get_metrics.py
|
949 |
+
fairseq/examples/speech_to_speech/benchmarking/configs/2StageS2ST.yaml
|
950 |
+
fairseq/examples/speech_to_speech/benchmarking/configs/3StageS2ST.yaml
|
951 |
+
fairseq/examples/speech_to_speech/benchmarking/configs/DirectS2U.yaml
|
952 |
+
fairseq/examples/speech_to_speech/benchmarking/configs/S2T.yaml
|
953 |
+
fairseq/examples/speech_to_speech/docs/data_augmentation.md
|
954 |
+
fairseq/examples/speech_to_speech/docs/direct_s2st_discrete_units.md
|
955 |
+
fairseq/examples/speech_to_speech/docs/enhanced_direct_s2st_discrete_units.md
|
956 |
+
fairseq/examples/speech_to_speech/docs/textless_s2st_real_data.md
|
957 |
+
fairseq/examples/speech_to_speech/preprocessing/__init__.py
|
958 |
+
fairseq/examples/speech_to_speech/preprocessing/data_utils.py
|
959 |
+
fairseq/examples/speech_to_speech/preprocessing/prep_s2spect_data.py
|
960 |
+
fairseq/examples/speech_to_speech/preprocessing/prep_s2ut_data.py
|
961 |
+
fairseq/examples/speech_to_speech/preprocessing/prep_sn_data.py
|
962 |
+
fairseq/examples/speech_to_speech/preprocessing/prep_sn_output_data.py
|
963 |
+
fairseq/examples/speech_to_speech/unity/__init__.py
|
964 |
+
fairseq/examples/speech_to_speech/unity/sequence_generator.py
|
965 |
+
fairseq/examples/speech_to_speech/unity/sequence_generator_multi_decoder.py
|
966 |
+
fairseq/examples/speech_to_text/README.md
|
967 |
+
fairseq/examples/speech_to_text/data_utils.py
|
968 |
+
fairseq/examples/speech_to_text/prep_covost_data.py
|
969 |
+
fairseq/examples/speech_to_text/prep_librispeech_data.py
|
970 |
+
fairseq/examples/speech_to_text/prep_mtedx_data.py
|
971 |
+
fairseq/examples/speech_to_text/prep_mustc_data.py
|
972 |
+
fairseq/examples/speech_to_text/seg_mustc_data.py
|
973 |
+
fairseq/examples/speech_to_text/docs/covost_example.md
|
974 |
+
fairseq/examples/speech_to_text/docs/librispeech_example.md
|
975 |
+
fairseq/examples/speech_to_text/docs/mtedx_example.md
|
976 |
+
fairseq/examples/speech_to_text/docs/mustc_example.md
|
977 |
+
fairseq/examples/speech_to_text/docs/simulst_mustc_example.md
|
978 |
+
fairseq/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py
|
979 |
+
fairseq/examples/stories/README.md
|
980 |
+
fairseq/examples/textless_nlp/dgslm/README.md
|
981 |
+
fairseq/examples/textless_nlp/dgslm/create_code_file.py
|
982 |
+
fairseq/examples/textless_nlp/dgslm/dgslm_utils.py
|
983 |
+
fairseq/examples/textless_nlp/dgslm/sample_speech_dlm.py
|
984 |
+
fairseq/examples/textless_nlp/dgslm/hubert_fisher/README.md
|
985 |
+
fairseq/examples/textless_nlp/dgslm/vocoder_hifigan/README.md
|
986 |
+
fairseq/examples/textless_nlp/dgslm/vocoder_hifigan/generate_stereo_waveform.py
|
987 |
+
fairseq/examples/textless_nlp/gslm/README.md
|
988 |
+
fairseq/examples/textless_nlp/gslm/metrics/README.md
|
989 |
+
fairseq/examples/textless_nlp/gslm/metrics/abx_metrics/README.md
|
990 |
+
fairseq/examples/textless_nlp/gslm/metrics/abx_metrics/dump_abx_feats.py
|
991 |
+
fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/README.md
|
992 |
+
fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/continuation_eval.py
|
993 |
+
fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/ppx.py
|
994 |
+
fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/self_auto_bleu.py
|
995 |
+
fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/misc/bleu_utils.py
|
996 |
+
fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/misc/cut_as.py
|
997 |
+
fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/misc/dict.ltr.txt
|
998 |
+
fairseq/examples/textless_nlp/gslm/speech2unit/README.md
|
999 |
+
fairseq/examples/textless_nlp/gslm/speech2unit/__init__.py
|
1000 |
+
fairseq/examples/textless_nlp/gslm/speech2unit/clustering/__init__.py
|
1001 |
+
fairseq/examples/textless_nlp/gslm/speech2unit/clustering/cluster_kmeans.py
|
1002 |
+
fairseq/examples/textless_nlp/gslm/speech2unit/clustering/dump_feats.py
|
1003 |
+
fairseq/examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py
|
1004 |
+
fairseq/examples/textless_nlp/gslm/speech2unit/clustering/utils.py
|
1005 |
+
fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/cpc_feature_reader.py
|
1006 |
+
fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/hubert_feature_reader.py
|
1007 |
+
fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/logmel_feature_reader.py
|
1008 |
+
fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/utils.py
|
1009 |
+
fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/w2v2_feature_reader.py
|
1010 |
+
fairseq/examples/textless_nlp/gslm/tools/README.md
|
1011 |
+
fairseq/examples/textless_nlp/gslm/tools/resynthesize_speech.py
|
1012 |
+
fairseq/examples/textless_nlp/gslm/ulm/README.md
|
1013 |
+
fairseq/examples/textless_nlp/gslm/ulm/sample.py
|
1014 |
+
fairseq/examples/textless_nlp/gslm/unit2speech/README.md
|
1015 |
+
fairseq/examples/textless_nlp/gslm/unit2speech/convert_to_16k.py
|
1016 |
+
fairseq/examples/textless_nlp/gslm/unit2speech/glow.py
|
1017 |
+
fairseq/examples/textless_nlp/gslm/unit2speech/multiproc.py
|
1018 |
+
fairseq/examples/textless_nlp/gslm/unit2speech/synthesize_audio_from_units.py
|
1019 |
+
fairseq/examples/textless_nlp/gslm/unit2speech/tts_data.py
|
1020 |
+
fairseq/examples/textless_nlp/gslm/unit2speech/utils.py
|
1021 |
+
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/__init__.py
|
1022 |
+
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/audio_processing.py
|
1023 |
+
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cleaners.py
|
1024 |
+
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cmudict.py
|
1025 |
+
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/layers.py
|
1026 |
+
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/model.py
|
1027 |
+
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/numbers.py
|
1028 |
+
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/stft.py
|
1029 |
+
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/symbols.py
|
1030 |
+
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/text.py
|
1031 |
+
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/utils.py
|
1032 |
+
fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/waveglow_denoiser.py
|
1033 |
+
fairseq/examples/textless_nlp/pgslm/README.md
|
1034 |
+
fairseq/examples/textless_nlp/pgslm/data_utils.py
|
1035 |
+
fairseq/examples/textless_nlp/pgslm/generate_waveform.py
|
1036 |
+
fairseq/examples/textless_nlp/pgslm/inference_dataset.py
|
1037 |
+
fairseq/examples/textless_nlp/pgslm/naive_decoder.py
|
1038 |
+
fairseq/examples/textless_nlp/pgslm/prepare_dataset.py
|
1039 |
+
fairseq/examples/textless_nlp/pgslm/preprocess_f0.py
|
1040 |
+
fairseq/examples/textless_nlp/pgslm/quantize_f0.py
|
1041 |
+
fairseq/examples/textless_nlp/pgslm/truncated_laplace.py
|
1042 |
+
fairseq/examples/textless_nlp/pgslm/eval/__init__.py
|
1043 |
+
fairseq/examples/textless_nlp/pgslm/eval/cont_metrics.py
|
1044 |
+
fairseq/examples/textless_nlp/pgslm/sample/__init__.py
|
1045 |
+
fairseq/examples/textless_nlp/pgslm/sample/sample.py
|
1046 |
+
fairseq/examples/textless_nlp/pgslm/scripts/join_units_manifest.py
|
1047 |
+
fairseq/examples/textless_nlp/pgslm/scripts/prepare_data.sh
|
1048 |
+
fairseq/examples/textless_nlp/pgslm/scripts/prepare_f0_quantization.sh
|
1049 |
+
fairseq/examples/textless_nlp/speech-resynth/README.md
|
1050 |
+
fairseq/examples/textless_nlp/speech-resynth/img/fig.png
|
1051 |
+
fairseq/examples/translation/README.md
|
1052 |
+
fairseq/examples/translation/prepare-iwslt14.sh
|
1053 |
+
fairseq/examples/translation/prepare-iwslt17-multilingual.sh
|
1054 |
+
fairseq/examples/translation/prepare-wmt14en2de.sh
|
1055 |
+
fairseq/examples/translation/prepare-wmt14en2fr.sh
|
1056 |
+
fairseq/examples/translation_moe/README.md
|
1057 |
+
fairseq/examples/translation_moe/score.py
|
1058 |
+
fairseq/examples/translation_moe/translation_moe_src/__init__.py
|
1059 |
+
fairseq/examples/translation_moe/translation_moe_src/logsumexp_moe.py
|
1060 |
+
fairseq/examples/translation_moe/translation_moe_src/mean_pool_gating_network.py
|
1061 |
+
fairseq/examples/translation_moe/translation_moe_src/translation_moe.py
|
1062 |
+
fairseq/examples/truncated_bptt/README.md
|
1063 |
+
fairseq/examples/truncated_bptt/__init__.py
|
1064 |
+
fairseq/examples/truncated_bptt/transformer_xl_model.py
|
1065 |
+
fairseq/examples/truncated_bptt/truncated_bptt_lm_task.py
|
1066 |
+
fairseq/examples/unsupervised_quality_estimation/README.md
|
1067 |
+
fairseq/examples/unsupervised_quality_estimation/aggregate_scores.py
|
1068 |
+
fairseq/examples/unsupervised_quality_estimation/meteor.py
|
1069 |
+
fairseq/examples/unsupervised_quality_estimation/repeat_lines.py
|
1070 |
+
fairseq/examples/wav2vec/README.md
|
1071 |
+
fairseq/examples/wav2vec/__init__.py
|
1072 |
+
fairseq/examples/wav2vec/libri_labels.py
|
1073 |
+
fairseq/examples/wav2vec/vq-wav2vec_featurize.py
|
1074 |
+
fairseq/examples/wav2vec/wav2vec_featurize.py
|
1075 |
+
fairseq/examples/wav2vec/wav2vec_manifest.py
|
1076 |
+
fairseq/examples/wav2vec/config/finetuning/base_100h.yaml
|
1077 |
+
fairseq/examples/wav2vec/config/finetuning/base_10h.yaml
|
1078 |
+
fairseq/examples/wav2vec/config/finetuning/base_10m.yaml
|
1079 |
+
fairseq/examples/wav2vec/config/finetuning/base_1h.yaml
|
1080 |
+
fairseq/examples/wav2vec/config/finetuning/base_960h.yaml
|
1081 |
+
fairseq/examples/wav2vec/config/finetuning/vox_100h.yaml
|
1082 |
+
fairseq/examples/wav2vec/config/finetuning/vox_100h_2.yaml
|
1083 |
+
fairseq/examples/wav2vec/config/finetuning/vox_100h_2_aws.yaml
|
1084 |
+
fairseq/examples/wav2vec/config/finetuning/vox_100h_3.yaml
|
1085 |
+
fairseq/examples/wav2vec/config/finetuning/vox_10h.yaml
|
1086 |
+
fairseq/examples/wav2vec/config/finetuning/vox_10h_2.yaml
|
1087 |
+
fairseq/examples/wav2vec/config/finetuning/vox_10h_2_aws.yaml
|
1088 |
+
fairseq/examples/wav2vec/config/finetuning/vox_10h_aws.yaml
|
1089 |
+
fairseq/examples/wav2vec/config/finetuning/vox_10h_aws_v100.yaml
|
1090 |
+
fairseq/examples/wav2vec/config/finetuning/vox_10m.yaml
|
1091 |
+
fairseq/examples/wav2vec/config/finetuning/vox_10m_2.yaml
|
1092 |
+
fairseq/examples/wav2vec/config/finetuning/vox_10m_2_aws.yaml
|
1093 |
+
fairseq/examples/wav2vec/config/finetuning/vox_10m_3.yaml
|
1094 |
+
fairseq/examples/wav2vec/config/finetuning/vox_1h.yaml
|
1095 |
+
fairseq/examples/wav2vec/config/finetuning/vox_1h_2.yaml
|
1096 |
+
fairseq/examples/wav2vec/config/finetuning/vox_1h_2_aws.yaml
|
1097 |
+
fairseq/examples/wav2vec/config/finetuning/vox_1h_3.yaml
|
1098 |
+
fairseq/examples/wav2vec/config/finetuning/vox_1h_4.yaml
|
1099 |
+
fairseq/examples/wav2vec/config/finetuning/vox_1h_aws.yaml
|
1100 |
+
fairseq/examples/wav2vec/config/finetuning/vox_960h.yaml
|
1101 |
+
fairseq/examples/wav2vec/config/finetuning/vox_960h_2.yaml
|
1102 |
+
fairseq/examples/wav2vec/config/finetuning/vox_960h_2_aws.yaml
|
1103 |
+
fairseq/examples/wav2vec/config/finetuning/vox_960h_3.yaml
|
1104 |
+
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_1.yaml
|
1105 |
+
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_16.yaml
|
1106 |
+
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_1_aws.yaml
|
1107 |
+
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_1_old.yaml
|
1108 |
+
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_2.yaml
|
1109 |
+
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_2_aws.yaml
|
1110 |
+
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_2g.yaml
|
1111 |
+
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_3.yaml
|
1112 |
+
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_4g.yaml
|
1113 |
+
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_4g_aws.yaml
|
1114 |
+
fairseq/examples/wav2vec/config/finetuning/run_config/slurm_8.yaml
|
1115 |
+
fairseq/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml
|
1116 |
+
fairseq/examples/wav2vec/config/pretraining/wav2vec2_conformer_base_librispeech.yaml
|
1117 |
+
fairseq/examples/wav2vec/config/pretraining/wav2vec2_conformer_large_librivox.yaml
|
1118 |
+
fairseq/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml
|
1119 |
+
fairseq/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml
|
1120 |
+
fairseq/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml
|
1121 |
+
fairseq/examples/wav2vec/scripts/binarize_manifest.sh
|
1122 |
+
fairseq/examples/wav2vec/unsupervised/README.md
|
1123 |
+
fairseq/examples/wav2vec/unsupervised/__init__.py
|
1124 |
+
fairseq/examples/wav2vec/unsupervised/w2vu_generate.py
|
1125 |
+
fairseq/examples/wav2vec/unsupervised/config/finetuning/w2v_finetune.yaml
|
1126 |
+
fairseq/examples/wav2vec/unsupervised/config/gan/w2vu.yaml
|
1127 |
+
fairseq/examples/wav2vec/unsupervised/config/gan/w2vu2.yaml
|
1128 |
+
fairseq/examples/wav2vec/unsupervised/config/generate/viterbi.yaml
|
1129 |
+
fairseq/examples/wav2vec/unsupervised/config/timit_matched/test.uid
|
1130 |
+
fairseq/examples/wav2vec/unsupervised/config/timit_matched/train.uid
|
1131 |
+
fairseq/examples/wav2vec/unsupervised/config/timit_matched/train_text.uid
|
1132 |
+
fairseq/examples/wav2vec/unsupervised/config/timit_matched/valid.uid
|
1133 |
+
fairseq/examples/wav2vec/unsupervised/config/timit_unmatched/test.uid
|
1134 |
+
fairseq/examples/wav2vec/unsupervised/config/timit_unmatched/train.uid
|
1135 |
+
fairseq/examples/wav2vec/unsupervised/config/timit_unmatched/train_text.uid
|
1136 |
+
fairseq/examples/wav2vec/unsupervised/config/timit_unmatched/valid.uid
|
1137 |
+
fairseq/examples/wav2vec/unsupervised/data/__init__.py
|
1138 |
+
fairseq/examples/wav2vec/unsupervised/data/extracted_features_dataset.py
|
1139 |
+
fairseq/examples/wav2vec/unsupervised/data/random_input_dataset.py
|
1140 |
+
fairseq/examples/wav2vec/unsupervised/kaldi_self_train/README.md
|
1141 |
+
fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/cmd.sh
|
1142 |
+
fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/decode_phone.sh
|
1143 |
+
fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/decode_word_step1.sh
|
1144 |
+
fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/decode_word_step2.sh
|
1145 |
+
fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/path.sh
|
1146 |
+
fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/train.sh
|
1147 |
+
fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/copy_aligned_text.py
|
1148 |
+
fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/decode.sh
|
1149 |
+
fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_data_from_w2v.py
|
1150 |
+
fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lang.sh
|
1151 |
+
fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lang_word.sh
|
1152 |
+
fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lm.sh
|
1153 |
+
fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/score.sh
|
1154 |
+
fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/show_wer.sh
|
1155 |
+
fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/train_subset_lgbeam.sh
|
1156 |
+
fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select.py
|
1157 |
+
fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode.sh
|
1158 |
+
fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode_word.sh
|
1159 |
+
fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_deltas.sh
|
1160 |
+
fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_lda_mllt.sh
|
1161 |
+
fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_sat.sh
|
1162 |
+
fairseq/examples/wav2vec/unsupervised/models/__init__.py
|
1163 |
+
fairseq/examples/wav2vec/unsupervised/models/wav2vec_u.py
|
1164 |
+
fairseq/examples/wav2vec/unsupervised/scripts/apply_pca.py
|
1165 |
+
fairseq/examples/wav2vec/unsupervised/scripts/copy_labels.py
|
1166 |
+
fairseq/examples/wav2vec/unsupervised/scripts/filter_lexicon.py
|
1167 |
+
fairseq/examples/wav2vec/unsupervised/scripts/filter_tsv.py
|
1168 |
+
fairseq/examples/wav2vec/unsupervised/scripts/g2p_wrd_to_phn.py
|
1169 |
+
fairseq/examples/wav2vec/unsupervised/scripts/ltr_to_wrd.py
|
1170 |
+
fairseq/examples/wav2vec/unsupervised/scripts/mean_pool.py
|
1171 |
+
fairseq/examples/wav2vec/unsupervised/scripts/merge_clusters.py
|
1172 |
+
fairseq/examples/wav2vec/unsupervised/scripts/normalize_and_filter_text.py
|
1173 |
+
fairseq/examples/wav2vec/unsupervised/scripts/normalize_text.py
|
1174 |
+
fairseq/examples/wav2vec/unsupervised/scripts/pca.py
|
1175 |
+
fairseq/examples/wav2vec/unsupervised/scripts/phonemize_with_sil.py
|
1176 |
+
fairseq/examples/wav2vec/unsupervised/scripts/prepare_audio.sh
|
1177 |
+
fairseq/examples/wav2vec/unsupervised/scripts/prepare_audio_v2.sh
|
1178 |
+
fairseq/examples/wav2vec/unsupervised/scripts/prepare_text.sh
|
1179 |
+
fairseq/examples/wav2vec/unsupervised/scripts/prepare_timit.sh
|
1180 |
+
fairseq/examples/wav2vec/unsupervised/scripts/remove_silence.py
|
1181 |
+
fairseq/examples/wav2vec/unsupervised/scripts/vads.py
|
1182 |
+
fairseq/examples/wav2vec/unsupervised/scripts/wav2vec_apply_cluster_faiss.py
|
1183 |
+
fairseq/examples/wav2vec/unsupervised/scripts/wav2vec_cluster_faiss.py
|
1184 |
+
fairseq/examples/wav2vec/unsupervised/scripts/wav2vec_extract_features.py
|
1185 |
+
fairseq/examples/wav2vec/unsupervised/scripts/wer.py
|
1186 |
+
fairseq/examples/wav2vec/unsupervised/scripts/wrd_to_ltr.py
|
1187 |
+
fairseq/examples/wav2vec/unsupervised/tasks/__init__.py
|
1188 |
+
fairseq/examples/wav2vec/unsupervised/tasks/unpaired_audio_text.py
|
1189 |
+
fairseq/examples/wav2vec/xlsr/README.md
|
1190 |
+
fairseq/examples/wav2vec/xlsr/config/finetune.yaml
|
1191 |
+
fairseq/examples/wav2vec/xlsr/scripts/eval_speaker_clf_task.py
|
1192 |
+
fairseq/examples/wav2vec/xlsr/scripts/gen_audio_embedding.py
|
1193 |
+
fairseq/examples/wmt19/README.md
|
1194 |
+
fairseq/examples/wmt20/README.md
|
1195 |
+
fairseq/examples/wmt21/README.md
|
1196 |
+
fairseq/examples/wmt21/eval.sh
|
1197 |
+
fairseq/examples/wmt21/scripts/normalize-punctuation.perl
|
1198 |
+
fairseq/examples/wmt21/scripts/replace-unicode-punctuation.perl
|
1199 |
+
fairseq/examples/womens_bios/README.md
|
1200 |
+
fairseq/examples/womens_bios/query_occupations_from_wikidata.py
fairseq/examples/xformers/README.md
fairseq/examples/xglm/README.md
fairseq/examples/xglm/XStoryCloze.md
fairseq/examples/xglm/model_card.md
fairseq/examples/xlmr/README.md
fairseq/examples/xmod/README.md
fairseq/examples/xmod/preprocess_nli.py
fairseq/logging/__init__.py
fairseq/logging/meters.py
fairseq/logging/metrics.py
fairseq/logging/progress_bar.py
fairseq/model_parallel/__init__.py
fairseq/model_parallel/megatron_trainer.py
fairseq/model_parallel/criterions/__init__.py
fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py
fairseq/model_parallel/models/__init__.py
fairseq/model_parallel/models/transformer.py
fairseq/model_parallel/models/transformer_lm.py
fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py
fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py
fairseq/model_parallel/models/pipeline_parallel_transformer/model.py
fairseq/model_parallel/models/roberta/__init__.py
fairseq/model_parallel/models/roberta/model.py
fairseq/model_parallel/modules/__init__.py
fairseq/model_parallel/modules/multihead_attention.py
fairseq/model_parallel/modules/transformer_layer.py
fairseq/models/__init__.py
fairseq/models/composite_encoder.py
fairseq/models/distributed_fairseq_model.py
fairseq/models/fairseq_decoder.py
fairseq/models/fairseq_encoder.py
fairseq/models/fairseq_incremental_decoder.py
fairseq/models/fairseq_model.py
fairseq/models/fconv.py
fairseq/models/fconv_lm.py
fairseq/models/fconv_self_att.py
fairseq/models/lightconv.py
fairseq/models/lightconv_lm.py
fairseq/models/lstm.py
fairseq/models/lstm_lm.py
fairseq/models/masked_lm.py
fairseq/models/model_utils.py
fairseq/models/multilingual_transformer.py
fairseq/models/transformer_align.py
fairseq/models/transformer_from_pretrained_xlm.py
fairseq/models/transformer_lm.py
fairseq/models/transformer_ulm.py
fairseq/models/bart/__init__.py
fairseq/models/bart/hub_interface.py
fairseq/models/bart/model.py
fairseq/models/ema/__init__.py
fairseq/models/ema/ema.py
fairseq/models/hubert/__init__.py
fairseq/models/hubert/hubert.py
fairseq/models/hubert/hubert_asr.py
fairseq/models/huggingface/__init__.py
fairseq/models/huggingface/hf_gpt2.py
fairseq/models/multires_hubert/__init__.py
fairseq/models/multires_hubert/multires_hubert.py
fairseq/models/multires_hubert/multires_hubert_asr.py
fairseq/models/nat/__init__.py
fairseq/models/nat/cmlm_transformer.py
fairseq/models/nat/fairseq_nat_model.py
fairseq/models/nat/insertion_transformer.py
fairseq/models/nat/iterative_nonautoregressive_transformer.py
fairseq/models/nat/levenshtein_transformer.py
fairseq/models/nat/levenshtein_utils.py
fairseq/models/nat/nat_crf_transformer.py
fairseq/models/nat/nonautoregressive_ensembles.py
fairseq/models/nat/nonautoregressive_transformer.py
fairseq/models/roberta/__init__.py
fairseq/models/roberta/alignment_utils.py
fairseq/models/roberta/enc_dec.py
fairseq/models/roberta/hub_interface.py
fairseq/models/roberta/model.py
fairseq/models/roberta/model_camembert.py
fairseq/models/roberta/model_gottbert.py
fairseq/models/roberta/model_xlmr.py
fairseq/models/speech_dlm/__init__.py
fairseq/models/speech_dlm/hub_interface.py
fairseq/models/speech_dlm/speech_dlm.py
fairseq/models/speech_dlm/modules/__init__.py
fairseq/models/speech_dlm/modules/speech_dlm_decoder.py
fairseq/models/speech_dlm/modules/speech_dlm_decoder_layer.py
fairseq/models/speech_dlm/sequence_generator/__init__.py
fairseq/models/speech_dlm/sequence_generator/multichannel_search.py
fairseq/models/speech_dlm/sequence_generator/multichannel_sequence_generator.py
fairseq/models/speech_to_speech/__init__.py
fairseq/models/speech_to_speech/s2s_conformer.py
fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py
fairseq/models/speech_to_speech/s2s_conformer_unity.py
fairseq/models/speech_to_speech/s2s_transformer.py
fairseq/models/speech_to_speech/modules/__init__.py
fairseq/models/speech_to_speech/modules/ctc_decoder.py
fairseq/models/speech_to_speech/modules/stacked_embedding.py
fairseq/models/speech_to_speech/modules/transformer_decoder_aug.py
fairseq/models/speech_to_speech/modules/transformer_encoder.py
fairseq/models/speech_to_text/__init__.py
fairseq/models/speech_to_text/berard.py
fairseq/models/speech_to_text/convtransformer.py
fairseq/models/speech_to_text/hub_interface.py
fairseq/models/speech_to_text/multi_modality_model.py
fairseq/models/speech_to_text/s2t_conformer.py
fairseq/models/speech_to_text/s2t_transformer.py
fairseq/models/speech_to_text/s2t_wav_transformer.py
fairseq/models/speech_to_text/utils.py
fairseq/models/speech_to_text/xm_transformer.py
fairseq/models/speech_to_text/xm_transformer_unity.py
fairseq/models/speech_to_text/modules/__init__.py
fairseq/models/speech_to_text/modules/augmented_memory_attention.py
fairseq/models/speech_to_text/modules/convolution.py
fairseq/models/speech_to_text/modules/emformer.py
fairseq/models/text_to_speech/__init__.py
fairseq/models/text_to_speech/codehifigan.py
fairseq/models/text_to_speech/fastspeech2.py
fairseq/models/text_to_speech/hifigan.py
fairseq/models/text_to_speech/hub_interface.py
fairseq/models/text_to_speech/tacotron2.py
fairseq/models/text_to_speech/tts_transformer.py
fairseq/models/text_to_speech/vocoder.py
fairseq/models/transformer/__init__.py
fairseq/models/transformer/transformer_base.py
fairseq/models/transformer/transformer_config.py
fairseq/models/transformer/transformer_decoder.py
fairseq/models/transformer/transformer_decoder_aug.py
fairseq/models/transformer/transformer_encoder.py
fairseq/models/transformer/transformer_legacy.py
fairseq/models/wav2vec/__init__.py
fairseq/models/wav2vec/utils.py
fairseq/models/wav2vec/wav2vec.py
fairseq/models/wav2vec/wav2vec2.py
fairseq/models/wav2vec/wav2vec2_asr.py
fairseq/models/wav2vec/wav2vec2_classification.py
fairseq/models/wav2vec/wav2vec2_laser.py
fairseq/models/xmod/__init__.py
fairseq/models/xmod/hub_interface.py
fairseq/models/xmod/model.py
fairseq/models/xmod/transformer_layer_xmod.py
fairseq/modules/__init__.py
fairseq/modules/adaptive_input.py
fairseq/modules/adaptive_softmax.py
fairseq/modules/base_layer.py
fairseq/modules/beamable_mm.py
fairseq/modules/character_token_embedder.py
fairseq/modules/checkpoint_activations.py
fairseq/modules/conformer_layer.py
fairseq/modules/conv_tbc.py
fairseq/modules/cross_entropy.py
fairseq/modules/downsampled_multihead_attention.py
fairseq/modules/dynamic_convolution.py
fairseq/modules/dynamic_crf_layer.py
fairseq/modules/ema_module.py
fairseq/modules/espnet_multihead_attention.py
fairseq/modules/fairseq_dropout.py
fairseq/modules/fp32_batch_norm.py
fairseq/modules/fp32_group_norm.py
fairseq/modules/fp32_instance_norm.py
fairseq/modules/gelu.py
fairseq/modules/grad_multiply.py
fairseq/modules/gumbel_vector_quantizer.py
fairseq/modules/kmeans_attention.py
fairseq/modules/kmeans_vector_quantizer.py
fairseq/modules/layer_drop.py
fairseq/modules/layer_norm.py
fairseq/modules/learned_positional_embedding.py
fairseq/modules/lightweight_convolution.py
fairseq/modules/linearized_convolution.py
fairseq/modules/location_attention.py
fairseq/modules/lstm_cell_with_zoneout.py
fairseq/modules/multihead_attention.py
fairseq/modules/positional_embedding.py
fairseq/modules/positional_encoding.py
fairseq/modules/quant_noise.py
fairseq/modules/rotary_positional_embedding.py
fairseq/modules/same_pad.py
fairseq/modules/scalar_bias.py
fairseq/modules/sinusoidal_positional_embedding.py
fairseq/modules/sparse_multihead_attention.py
fairseq/modules/sparse_transformer_sentence_encoder.py
fairseq/modules/sparse_transformer_sentence_encoder_layer.py
fairseq/modules/transformer_layer.py
fairseq/modules/transformer_layer_aug.py
fairseq/modules/transformer_sentence_encoder.py
fairseq/modules/transformer_sentence_encoder_layer.py
fairseq/modules/transpose_last.py
fairseq/modules/unfold.py
fairseq/modules/vggblock.py
fairseq/modules/dynamicconv_layer/__init__.py
fairseq/modules/dynamicconv_layer/cuda_function_gen.py
fairseq/modules/dynamicconv_layer/dynamicconv_layer.py
fairseq/modules/dynamicconv_layer/setup.py
fairseq/modules/lightconv_layer/__init__.py
fairseq/modules/lightconv_layer/cuda_function_gen.py
fairseq/modules/lightconv_layer/lightconv_layer.py
fairseq/modules/lightconv_layer/setup.py
fairseq/modules/quantization/__init__.py
fairseq/modules/quantization/quantization_options.py
fairseq/modules/quantization/pq/__init__.py
fairseq/modules/quantization/pq/em.py
fairseq/modules/quantization/pq/pq.py
fairseq/modules/quantization/pq/utils.py
fairseq/modules/quantization/pq/modules/__init__.py
fairseq/modules/quantization/pq/modules/qconv.py
fairseq/modules/quantization/pq/modules/qemb.py
fairseq/modules/quantization/pq/modules/qlinear.py
fairseq/modules/quantization/scalar/__init__.py
fairseq/modules/quantization/scalar/ops.py
fairseq/modules/quantization/scalar/utils.py
fairseq/modules/quantization/scalar/modules/__init__.py
fairseq/modules/quantization/scalar/modules/qact.py
fairseq/modules/quantization/scalar/modules/qconv.py
fairseq/modules/quantization/scalar/modules/qemb.py
fairseq/modules/quantization/scalar/modules/qlinear.py
fairseq/optim/__init__.py
fairseq/optim/adadelta.py
fairseq/optim/adafactor.py
fairseq/optim/adagrad.py
fairseq/optim/adam.py
fairseq/optim/adamax.py
fairseq/optim/amp_optimizer.py
fairseq/optim/bmuf.py
fairseq/optim/composite.py
fairseq/optim/cpu_adam.py
fairseq/optim/dynamic_loss_scaler.py
fairseq/optim/fairseq_optimizer.py
fairseq/optim/fp16_optimizer.py
fairseq/optim/fused_adam.py
fairseq/optim/fused_lamb.py
fairseq/optim/nag.py
fairseq/optim/sgd.py
fairseq/optim/shard.py
fairseq/optim/lr_scheduler/__init__.py
fairseq/optim/lr_scheduler/cosine_lr_scheduler.py
fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py
fairseq/optim/lr_scheduler/fixed_schedule.py
fairseq/optim/lr_scheduler/inverse_square_root_schedule.py
fairseq/optim/lr_scheduler/manual_lr_scheduler.py
fairseq/optim/lr_scheduler/pass_through.py
fairseq/optim/lr_scheduler/polynomial_decay_schedule.py
fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py
fairseq/optim/lr_scheduler/step_lr_scheduler.py
fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py
fairseq/optim/lr_scheduler/triangular_lr_scheduler.py
fairseq/scoring/__init__.py
fairseq/scoring/bertscore.py
fairseq/scoring/bleu.py
fairseq/scoring/chrf.py
fairseq/scoring/meteor.py
fairseq/scoring/tokenizer.py
fairseq/scoring/wer.py
fairseq/tasks/__init__.py
fairseq/tasks/audio_classification.py
fairseq/tasks/audio_finetuning.py
fairseq/tasks/audio_pretraining.py
fairseq/tasks/cross_lingual_lm.py
fairseq/tasks/denoising.py
fairseq/tasks/fairseq_task.py
fairseq/tasks/frm_text_to_speech.py
fairseq/tasks/hubert_pretraining.py
fairseq/tasks/language_modeling.py
fairseq/tasks/legacy_masked_lm.py
fairseq/tasks/masked_lm.py
fairseq/tasks/multilingual_denoising.py
fairseq/tasks/multilingual_language_modeling.py
fairseq/tasks/multilingual_masked_lm.py
fairseq/tasks/multilingual_translation.py
fairseq/tasks/multires_hubert_pretraining.py
fairseq/tasks/nlu_finetuning.py
fairseq/tasks/online_backtranslation.py
fairseq/tasks/semisupervised_translation.py
fairseq/tasks/sentence_prediction.py
fairseq/tasks/sentence_prediction_adapters.py
fairseq/tasks/sentence_ranking.py
fairseq/tasks/simultaneous_translation.py
fairseq/tasks/span_masked_lm.py
fairseq/tasks/speech_dlm_task.py
fairseq/tasks/speech_to_speech.py
fairseq/tasks/speech_to_text.py
fairseq/tasks/speech_ulm_task.py
fairseq/tasks/text_to_speech.py
fairseq/tasks/translation.py
fairseq/tasks/translation_from_pretrained_bart.py
fairseq/tasks/translation_from_pretrained_xlm.py
fairseq/tasks/translation_lev.py
fairseq/tasks/translation_multi_simple_epoch.py
fairseq_cli/__init__.py
fairseq_cli/eval_lm.py
fairseq_cli/generate.py
fairseq_cli/hydra_train.py
fairseq_cli/hydra_validate.py
fairseq_cli/interactive.py
fairseq_cli/preprocess.py
fairseq_cli/score.py
fairseq_cli/train.py
fairseq_cli/validate.py
tests/test_activation_checkpointing.py
tests/test_amp_optimizer.py
tests/test_average_checkpoints.py
tests/test_backtranslation_dataset.py
tests/test_binaries.py
tests/test_binarizer.py
tests/test_character_token_embedder.py
tests/test_checkpoint_utils.py
tests/test_checkpoint_utils_for_task_level_attributes.py
tests/test_concat_dataset.py
tests/test_constraints.py
tests/test_convtbc.py
tests/test_data_utils.py
tests/test_dataclass_utils.py
tests/test_dataset.py
tests/test_dictionary.py
tests/test_ema.py
tests/test_espnet_multihead_attention.py
tests/test_export.py
tests/test_file_chunker_utils.py
tests/test_file_io.py
tests/test_fp16_optimizer.py
tests/test_hf_hub.py
tests/test_huffman.py
tests/test_inference_dropout.py
tests/test_iopath.py
tests/test_iterators.py
tests/test_label_smoothing.py
tests/test_lm_context_window.py
tests/test_lstm_jitable.py
tests/test_memory_efficient_fp16.py
tests/test_metrics.py
tests/test_multi_corpus_dataset.py
tests/test_multi_corpus_sampled_dataset.py
tests/test_multihead_attention.py
tests/test_noising.py
tests/test_online_backtranslation.py
tests/test_plasma_utils.py
tests/test_positional_encoding.py
tests/test_reproducibility.py
tests/test_resampling_dataset.py
tests/test_roberta.py
tests/test_rotary_positional_embedding.py
tests/test_sequence_generator.py
tests/test_sequence_scorer.py
tests/test_sparse_multihead_attention.py
tests/test_token_block_dataset.py
tests/test_train.py
tests/test_transformer.py
tests/test_utils.py
tests/test_valid_subset_checks.py
fairseq/fairseq.egg-info/entry_points.txt
ADDED
@@ -0,0 +1,9 @@
[console_scripts]
fairseq-eval-lm = fairseq_cli.eval_lm:cli_main
fairseq-generate = fairseq_cli.generate:cli_main
fairseq-hydra-train = fairseq_cli.hydra_train:cli_main
fairseq-interactive = fairseq_cli.interactive:cli_main
fairseq-preprocess = fairseq_cli.preprocess:cli_main
fairseq-score = fairseq_cli.score:cli_main
fairseq-train = fairseq_cli.train:cli_main
fairseq-validate = fairseq_cli.validate:cli_main
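Each console_scripts entry maps a command name to a module:function target, so installing the package puts thin wrappers on PATH: running fairseq-eval-lm is equivalent to importing fairseq_cli.eval_lm and calling its cli_main(). A minimal sketch (assuming an installed fairseq and Python 3.10+) that lists these mappings at runtime:

from importlib.metadata import entry_points

# Select the console_scripts group and filter to the fairseq commands.
for ep in entry_points(group="console_scripts"):
    if ep.name.startswith("fairseq-"):
        print(ep.name, "->", ep.value)  # e.g. fairseq-train -> fairseq_cli.train:cli_main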
fairseq/fairseq.egg-info/requires.txt
ADDED
@@ -0,0 +1,22 @@
cffi
cython
hydra-core<1.1,>=1.0.7
omegaconf<2.1
numpy>=1.21.3
regex
sacrebleu>=1.4.12
torch>=1.13
tqdm
bitarray
torchaudio>=0.8.0
scikit-learn
packaging

[dev]
flake8
pytest
black==22.3.0

[docs]
sphinx
sphinx-argparse
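The unsectioned names above are unconditional install requirements; the [dev] and [docs] sections are setuptools extras, pulled in with e.g. pip install "fairseq[dev]". A minimal sketch (assuming an installed fairseq distribution) that reads the same requirement set back through importlib.metadata:

from importlib.metadata import requires

# Extras appear as environment markers, e.g. 'flake8; extra == "dev"'.
for req in requires("fairseq") or []:
    print(req)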
fairseq/fairseq.egg-info/top_level.txt
ADDED
@@ -0,0 +1,4 @@
alignment_train_cpu_binding
alignment_train_cuda_binding
fairseq
fairseq_cli
fairseq/fairseq/__pycache__/incremental_decoding_utils.cpython-310.pyc
ADDED
Binary file (2.27 kB).
fairseq/fairseq/__pycache__/iterative_refinement_generator.cpython-310.pyc
ADDED
Binary file (8.77 kB).
fairseq/fairseq/__pycache__/ngram_repeat_block.cpython-310.pyc
ADDED
Binary file (3.84 kB).
fairseq/fairseq/__pycache__/pdb.cpython-310.pyc
ADDED
Binary file (1.37 kB).
fairseq/fairseq_cli/__init__.py
ADDED
File without changes
fairseq/fairseq_cli/eval_lm.py
ADDED
@@ -0,0 +1,347 @@
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Evaluate the perplexity of a trained language model.
"""

import logging
import math
import os
import sys
from argparse import Namespace
from typing import Iterable, List, Optional

import torch
from omegaconf import DictConfig

import fairseq
from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar
from fairseq.logging.meters import StopwatchMeter
from fairseq.sequence_scorer import SequenceScorer

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("fairseq_cli.eval_lm")


def eval_lm(
    models: List[fairseq.models.FairseqModel],
    source_dictionary: fairseq.data.Dictionary,
    batch_iterator: Iterable,
    post_process: Optional[str] = None,
    output_word_probs: bool = False,
    output_word_stats: bool = False,
    target_dictionary: Optional[fairseq.data.Dictionary] = None,
    softmax_batch: int = 0,
    remove_bos_token: bool = False,
    device: Optional[torch.device] = None,
):
    """
    Args:
        models (List[~fairseq.models.FairseqModel]): list of models to
            evaluate. Models are essentially `nn.Module` instances, but
            must be compatible with fairseq's `SequenceScorer`.
        source_dictionary (~fairseq.data.Dictionary): dictionary for
            applying any relevant post processing or outputting word
            probs/stats.
        batch_iterator (Iterable): yield batches of data
        post_process (Optional[str]): post-process text by removing BPE,
            letter segmentation, etc. Valid options can be found in
            fairseq.data.utils.post_process, although not all options
            are implemented here.
        output_word_probs (Optional[bool]): output words and their
            predicted log probabilities
        output_word_stats (Optional[bool]): output word statistics such
            as word count and average probability
        target_dictionary (Optional[~fairseq.data.Dictionary]): output
            dictionary (defaults to *source_dictionary*)
        softmax_batch (Optional[int]): if BxT is more than this, will
            batch the softmax over vocab to this amount of tokens, in
            order to fit into GPU memory
        remove_bos_token (Optional[bool]): if True, confirm that the
            first token is the beginning-of-sentence symbol (according
            to the relevant dictionary) and remove it from the output
        device (Optional[torch.device]): device to use for evaluation
            (defaults to device of first model parameter)
    """
    if target_dictionary is None:
        target_dictionary = source_dictionary
    if device is None:
        device = next(models[0].parameters()).device

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(target_dictionary, softmax_batch)

    score_sum = 0.0
    count = 0

    if post_process is not None:
        if post_process in {"subword_nmt", "@@ "}:
            bpe_cont = post_process.rstrip()
            bpe_toks = {
                i
                for i in range(len(source_dictionary))
                if source_dictionary[i].endswith(bpe_cont)
            }
        else:
            raise NotImplementedError(
                f"--post-process={post_process} is not implemented"
            )
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    for sample in batch_iterator:
        if "net_input" not in sample:
            continue

        sample = utils.move_to_cuda(sample, device=device)

        gen_timer.start()
        hypos = scorer.generate(models, sample)
        gen_timer.stop(sample["ntokens"])

        for i, hypos_i in enumerate(hypos):
            hypo = hypos_i[0]
            sample_id = sample["id"][i]

            tokens = hypo["tokens"]
            tgt_len = tokens.numel()
            pos_scores = hypo["positional_scores"].float()

            if remove_bos_token:
                assert hypo["tokens"][0].item() == target_dictionary.bos()
                tokens = tokens[1:]
                pos_scores = pos_scores[1:]

            skipped_toks = 0
            if bpe_toks is not None:
                for i in range(tgt_len - 1):
                    if tokens[i].item() in bpe_toks:
                        skipped_toks += 1
                        pos_scores[i + 1] += pos_scores[i]
                        pos_scores[i] = 0

            inf_scores = pos_scores.eq(float("inf")) | pos_scores.eq(float("-inf"))
            if inf_scores.any():
                logger.info(
                    "skipping tokens with inf scores: %s",
                    target_dictionary.string(tokens[inf_scores.nonzero()]),
                )
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum().cpu()
            count += pos_scores.numel() - skipped_toks

            if output_word_probs or output_word_stats:
                w = ""
                word_prob = []
                is_bpe = False
                for i in range(len(tokens)):
                    w_ind = tokens[i].item()
                    w += source_dictionary[w_ind]
                    if bpe_toks is not None and w_ind in bpe_toks:
                        w = w[:-bpe_len]
                        is_bpe = True
                    else:
                        word_prob.append((w, pos_scores[i].item()))

                        next_prob = None
                        ind = i + 1
                        while ind < len(tokens):
                            if pos_scores[ind].item() != 0:
                                next_prob = pos_scores[ind]
                                break
                            ind += 1

                        word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                            pos_scores[i].item(), next_prob
                        )
                        is_bpe = False
                        w = ""
                if output_word_probs:
                    logger.info(
                        str(int(sample_id))
                        + " "
                        + (
                            "\t".join(
                                "{} [{:2f}]".format(x[0], x[1]) for x in word_prob
                            )
                        )
                    )

    avg_nll_loss = (
        -score_sum / count / math.log(2) if count > 0 else 0
    )  # convert to base 2
    logger.info(
        "Evaluated {:,} tokens in {:.1f}s ({:.2f} tokens/s)".format(
            gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg if gen_timer.avg > 0 else 0
        )
    )

    if output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            logger.info(ws)

    return {
        "loss": avg_nll_loss,
        "perplexity": 2**avg_nll_loss,
    }


class WordStat(object):
    def __init__(self, word, is_bpe):
        self.word = word
        self.is_bpe = is_bpe
        self.log_prob = 0
        self.next_word_prob = 0
        self.count = 0
        self.missing_next_words = 0

    def add(self, log_prob, next_word_prob):
        """increments counters for the sum of log probs of current word and next
        word (given context ending at current word). Since the next word might be at the end of the example,
        or it might be not counted because it is not an ending subword unit,
        also keeps track of how many of those we have seen"""
        if next_word_prob is not None:
            self.next_word_prob += next_word_prob
        else:
            self.missing_next_words += 1
        self.log_prob += log_prob
        self.count += 1

    def __str__(self):
        return "{}\t{}\t{}\t{}\t{}\t{}".format(
            self.word,
            self.count,
            self.log_prob,
            self.is_bpe,
            self.next_word_prob,
            self.count - self.missing_next_words,
        )


def main(cfg: DictConfig, **unused_kwargs):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    logger.info(cfg)

    if cfg.eval_lm.context_window > 0:
        # reduce tokens per sample by the required context window size
        cfg.task.tokens_per_sample -= cfg.eval_lm.context_window

    # Initialize the task using the current *cfg*
    task = tasks.setup_task(cfg.task)

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [cfg.common_eval.path],
        arg_overrides=eval(cfg.common_eval.model_overrides),
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
        task=task,
    )

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu
    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    # Optimize ensemble for generation and set the source and dest dicts on the model
    # (required by scorer)
    for model in models:
        if use_fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    assert len(models) > 0

    logger.info(
        "num. model params: {:,}".format(sum(p.numel() for p in models[0].parameters()))
    )

    # Load dataset splits
    task.load_dataset(cfg.dataset.gen_subset)
    dataset = task.dataset(cfg.dataset.gen_subset)
    logger.info(
        "{} {} {:,} examples".format(
            cfg.task.data, cfg.dataset.gen_subset, len(dataset)
        )
    )

    itr = task.eval_lm_dataloader(
        dataset=dataset,
        max_tokens=cfg.dataset.max_tokens or 36000,
        batch_size=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]
        ),
        num_shards=max(
            cfg.dataset.num_shards,
            cfg.distributed_training.distributed_world_size,
        ),
        shard_id=max(
            cfg.dataset.shard_id,
            cfg.distributed_training.distributed_rank,
        ),
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
        context_window=cfg.eval_lm.context_window,
    )

    itr = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    results = eval_lm(
        models=models,
        source_dictionary=task.source_dictionary,
        batch_iterator=itr,
        post_process=cfg.common_eval.post_process,
        output_word_probs=cfg.eval_lm.output_word_probs,
        output_word_stats=cfg.eval_lm.output_word_stats,
        target_dictionary=task.target_dictionary,
        softmax_batch=cfg.eval_lm.softmax_batch,
        remove_bos_token=getattr(cfg.task, "add_bos_token", False),
    )

    logger.info(
        "Loss (base 2): {:.4f}, Perplexity: {:.2f}".format(
            results["loss"], results["perplexity"]
        )
    )

    return results


def cli_main():
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)

    distributed_utils.call_main(convert_namespace_to_omegaconf(args), main)


if __name__ == "__main__":
    cli_main()
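The eval_lm() function above is usable on its own, outside the fairseq-eval-lm entry point. A minimal sketch of a programmatic call; the checkpoint path is a hypothetical placeholder, and the checkpoint is assumed to come from a language-modeling task, which is what provides eval_lm_dataloader():

from fairseq import checkpoint_utils
from fairseq_cli.eval_lm import eval_lm

# Load the model(s) together with the task they were trained with.
models, _args, task = checkpoint_utils.load_model_ensemble_and_task(
    ["checkpoints/checkpoint_best.pt"]  # hypothetical checkpoint path
)
task.load_dataset("valid")
itr = task.eval_lm_dataloader(dataset=task.dataset("valid"), max_tokens=36000)

results = eval_lm(
    models=models,
    source_dictionary=task.source_dictionary,
    batch_iterator=itr,
)
print("loss (base 2): {:.4f} | ppl: {:.2f}".format(results["loss"], results["perplexity"]))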
fairseq/fairseq_cli/generate.py
ADDED
@@ -0,0 +1,417 @@
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Translate pre-processed data with a trained model.
"""

import ast
import logging
import math
import os
import sys
from argparse import Namespace
from itertools import chain

import numpy as np
import torch
from omegaconf import DictConfig

from fairseq import checkpoint_utils, options, scoring, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar
from fairseq.logging.meters import StopwatchMeter, TimeMeter


def main(cfg: DictConfig):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    assert cfg.common_eval.path is not None, "--path required for generation!"
    assert (
        not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        cfg.generation.replace_unk is None or cfg.dataset.dataset_impl == "raw"
    ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)"

    if cfg.common_eval.results_path is not None:
        os.makedirs(cfg.common_eval.results_path, exist_ok=True)
        output_path = os.path.join(
            cfg.common_eval.results_path,
            "generate-{}.txt".format(cfg.dataset.gen_subset),
        )
        with open(output_path, "w", buffering=1, encoding="utf-8") as h:
            return _main(cfg, h)
    else:
        return _main(cfg, sys.stdout)


def get_symbols_to_strip_from_output(generator):
    if hasattr(generator, "symbols_to_strip_from_output"):
        return generator.symbols_to_strip_from_output
    else:
        return {generator.eos}


def _main(cfg: DictConfig, output_file):
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "INFO").upper(),
        stream=output_file,
    )
    logger = logging.getLogger("fairseq_cli.generate")

    utils.import_user_module(cfg.common)

    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
        cfg.dataset.max_tokens = 12000
    logger.info(cfg)

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    # Load dataset splits
    task = tasks.setup_task(cfg.task)

    # Set dictionaries
    try:
        src_dict = getattr(task, "source_dictionary", None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    overrides = ast.literal_eval(cfg.common_eval.model_overrides)

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

    if cfg.generation.lm_path is not None:
        overrides["data"] = cfg.task.data

        try:
            lms, _ = checkpoint_utils.load_model_ensemble(
                [cfg.generation.lm_path], arg_overrides=overrides, task=None
            )
        except Exception:
            logger.warning(
                f"Failed to load language model! Please make sure that the language model dict is the same "
                f"as target dict and is located in the data dir ({cfg.task.data})"
            )
            raise

        assert len(lms) == 1
    else:
        lms = [None]

    # Optimize ensemble for generation
    for model in chain(models, lms):
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(cfg.generation.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]
        ),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()

    extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight}
    generator = task.build_generator(
        models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs
    )

    # Handle tokenization and BPE
    tokenizer = task.build_tokenizer(cfg.tokenizer)
    bpe = task.build_bpe(cfg.bpe)

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    scorer = scoring.build_scorer(cfg.scoring, tgt_dict)

    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if "net_input" not in sample:
            continue

        prefix_tokens = None
        if cfg.generation.prefix_size > 0:
            prefix_tokens = sample["target"][:, : cfg.generation.prefix_size]

        constraints = None
        if "constraints" in sample:
            constraints = sample["constraints"]

        gen_timer.start()
        hypos = task.inference_step(
            generator,
            models,
            sample,
            prefix_tokens=prefix_tokens,
            constraints=constraints,
        )
        num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample["id"].tolist()):
            has_target = sample["target"] is not None

            # Remove padding
            if "src_tokens" in sample["net_input"]:
                src_tokens = utils.strip_pad(
                    sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
                )
            else:
                src_tokens = None

            target_tokens = None
            if has_target:
                target_tokens = (
                    utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu()
                )

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(cfg.dataset.gen_subset).src.get_original_text(
                    sample_id
                )
                target_str = task.dataset(cfg.dataset.gen_subset).tgt.get_original_text(
                    sample_id
                )
            else:
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens, cfg.common_eval.post_process)
                else:
                    src_str = ""
                if has_target:
                    target_str = tgt_dict.string(
                        target_tokens,
                        cfg.common_eval.post_process,
                        escape_unk=True,
                        extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                            generator
                        ),
                    )

            src_str = decode_fn(src_str)
            if has_target:
                target_str = decode_fn(target_str)

            if not cfg.common_eval.quiet:
                if src_dict is not None:
                    print("S-{}\t{}".format(sample_id, src_str), file=output_file)
                if has_target:
                    print("T-{}\t{}".format(sample_id, target_str), file=output_file)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][: cfg.generation.nbest]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo["tokens"].int().cpu(),
                    src_str=src_str,
                    alignment=hypo["alignment"],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=cfg.common_eval.post_process,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
                )
                detok_hypo_str = decode_fn(hypo_str)
                if not cfg.common_eval.quiet:
                    score = hypo["score"] / math.log(2)  # convert to base 2
                    # original hypothesis (after tokenization and BPE)
                    print(
                        "H-{}\t{}\t{}".format(sample_id, score, hypo_str),
                        file=output_file,
                    )
                    # detokenized hypothesis
                    print(
                        "D-{}\t{}\t{}".format(sample_id, score, detok_hypo_str),
                        file=output_file,
                    )
                    print(
                        "P-{}\t{}".format(
                            sample_id,
                            " ".join(
                                map(
                                    lambda x: "{:.4f}".format(x),
                                    # convert from base e to base 2
                                    hypo["positional_scores"]
                                    .div_(math.log(2))
                                    .tolist(),
                                )
                            ),
                        ),
                        file=output_file,
                    )

                    if cfg.generation.print_alignment == "hard":
                        print(
                            "A-{}\t{}".format(
                                sample_id,
                                " ".join(
                                    [
                                        "{}-{}".format(src_idx, tgt_idx)
                                        for src_idx, tgt_idx in alignment
                                    ]
                                ),
                            ),
                            file=output_file,
                        )
                    if cfg.generation.print_alignment == "soft":
                        print(
                            "A-{}\t{}".format(
                                sample_id,
                                " ".join(
                                    [",".join(src_probs) for src_probs in alignment]
                                ),
                            ),
                            file=output_file,
                        )

                    if cfg.generation.print_step:
                        print(
                            "I-{}\t{}".format(sample_id, hypo["steps"]),
                            file=output_file,
                        )

                    if cfg.generation.retain_iter_history:
                        for step, h in enumerate(hypo["history"]):
                            _, h_str, _ = utils.post_process_prediction(
                                hypo_tokens=h["tokens"].int().cpu(),
                                src_str=src_str,
                                alignment=None,
                                align_dict=None,
                                tgt_dict=tgt_dict,
                                remove_bpe=None,
                            )
                            print(
                                "E-{}_{}\t{}".format(sample_id, step, h_str),
                                file=output_file,
                            )

                # Score only the top hypothesis
                if has_target and j == 0:
                    if (
                        align_dict is not None
                        or cfg.common_eval.post_process is not None
                    ):
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tgt_dict.encode_line(
                            target_str, add_if_not_exist=True
                        )
                        hypo_tokens = tgt_dict.encode_line(
                            detok_hypo_str, add_if_not_exist=True
                        )
                    if hasattr(scorer, "add_string"):
                        scorer.add_string(target_str, detok_hypo_str)
                    else:
                        scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(num_generated_tokens)
        progress.log({"wps": round(wps_meter.avg)})
        num_sentences += (
            sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
        )

    logger.info("NOTE: hypothesis and token scores are output in base 2")
    logger.info(
        "Translated {:,} sentences ({:,} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format(
            num_sentences,
            gen_timer.n,
            gen_timer.sum,
            num_sentences / gen_timer.sum,
            1.0 / gen_timer.avg,
        )
    )
    if has_target:
        if cfg.bpe and not cfg.generation.sacrebleu:
            if cfg.common_eval.post_process:
                logger.warning(
                    "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
                )
            else:
                logger.warning(
                    "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization"
                )
        # use print to be consistent with other main outputs: S-, H-, T-, D- and so on
        print(
            "Generate {} with beam={}: {}".format(
                cfg.dataset.gen_subset, cfg.generation.beam, scorer.result_string()
            ),
            file=output_file,
        )

    return scorer


def cli_main():
    parser = options.get_generation_parser()
    # TODO: replace this workaround with refactoring of `AudioPretraining`
    parser.add_argument(
        "--arch",
        "-a",
        metavar="ARCH",
        default="wav2vec2",
        help="Model architecture. For constructing tasks that rely on "
        "model args (e.g. `AudioPretraining`)",
    )
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == "__main__":
    cli_main()
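Everything _main() emits is a tab-separated, tag-prefixed line (S- source, T- target, H-/D- raw and detokenized hypotheses, P- positional scores). A minimal sketch (the output file name is a hypothetical example) that recovers the detokenized hypotheses from the D- lines:

# Parse "D-<id>\t<score>\t<text>" lines written by fairseq-generate.
hyps = {}
with open("generate-test.txt", encoding="utf-8") as f:
    for line in f:
        if line.startswith("D-"):
            tag, score, text = line.rstrip("\n").split("\t", maxsplit=2)
            hyps[int(tag[2:])] = (float(score), text)

for sample_id in sorted(hyps):
    print(sample_id, hyps[sample_id][1])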
fairseq/fairseq_cli/hydra_train.py
ADDED
@@ -0,0 +1,91 @@
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os

import hydra
import torch
from hydra.core.hydra_config import HydraConfig
from omegaconf import OmegaConf, open_dict

from fairseq import distributed_utils, metrics
from fairseq.dataclass.configs import FairseqConfig
from fairseq.dataclass.initialize import add_defaults, hydra_init
from fairseq.dataclass.utils import omegaconf_no_object_check
from fairseq.utils import reset_logging
from fairseq_cli.train import main as pre_main

logger = logging.getLogger("fairseq_cli.hydra_train")


@hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config")
def hydra_main(cfg: FairseqConfig) -> float:
    _hydra_main(cfg)


def _hydra_main(cfg: FairseqConfig, **kwargs) -> float:
    add_defaults(cfg)

    if cfg.common.reset_logging:
        reset_logging()  # Hydra hijacks logging, fix that
    else:
        # check if directly called or called through hydra_main
        if HydraConfig.initialized():
            with open_dict(cfg):
                # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
                cfg.job_logging_cfg = OmegaConf.to_container(
                    HydraConfig.get().job_logging, resolve=True
                )

    with omegaconf_no_object_check():
        cfg = OmegaConf.create(
            OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
        )
    OmegaConf.set_struct(cfg, True)

    try:
        if cfg.common.profile:
            with torch.cuda.profiler.profile():
                with torch.autograd.profiler.emit_nvtx():
                    distributed_utils.call_main(cfg, pre_main, **kwargs)
        else:
            distributed_utils.call_main(cfg, pre_main, **kwargs)
    except BaseException as e:
        if not cfg.common.suppress_crashes:
            raise
        else:
            logger.error("Crashed! " + str(e))

    # get best val and return - useful for sweepers
    try:
        best_val = metrics.get_smoothed_value(
            "valid", cfg.checkpoint.best_checkpoint_metric
        )
    except Exception:
        best_val = None

    if best_val is None:
        best_val = float("inf")

    return best_val


def cli_main():
    try:
        from hydra._internal.utils import get_args

        cfg_name = get_args().config_name or "config"
    except Exception:
        logger.warning("Failed to get config name from hydra args")
        cfg_name = "config"

    hydra_init(cfg_name)
    hydra_main()


if __name__ == "__main__":
    cli_main()
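Because hydra_main() is driven by Hydra rather than argparse, fairseq-hydra-train is configured with key=value overrides plus Hydra's standard --config-dir/--config-name flags. A minimal sketch of a programmatic launch; the data path, config directory, and config name are hypothetical placeholders:

import sys
from fairseq_cli.hydra_train import cli_main

# Hydra reads sys.argv, so a programmatic launch stages the same
# arguments a shell invocation would pass.
sys.argv = [
    "fairseq-hydra-train",
    "task.data=/path/to/data-bin",        # hypothetical dataset path
    "--config-dir", "/path/to/configs",   # hypothetical config directory
    "--config-name", "my_config",         # hypothetical config name
]
cli_main()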
fairseq/fairseq_cli/hydra_validate.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3 -u
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the MIT license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import logging
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
from itertools import chain
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from hydra.core.hydra_config import HydraConfig
|
14 |
+
from omegaconf import OmegaConf, open_dict
|
15 |
+
import hydra
|
16 |
+
|
17 |
+
from fairseq import checkpoint_utils, distributed_utils, utils
|
18 |
+
from fairseq.dataclass.configs import FairseqConfig
|
19 |
+
from fairseq.dataclass.initialize import add_defaults, hydra_init
|
20 |
+
from fairseq.dataclass.utils import omegaconf_no_object_check
|
21 |
+
from fairseq.distributed import utils as distributed_utils
|
22 |
+
from fairseq.logging import metrics, progress_bar
|
23 |
+
from fairseq.utils import reset_logging
|
24 |
+
|
25 |
+
logging.basicConfig(
|
26 |
+
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
27 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
28 |
+
level=os.environ.get("LOGLEVEL", "INFO").upper(),
|
29 |
+
stream=sys.stdout,
|
30 |
+
)
|
31 |
+
logger = logging.getLogger("fairseq_cli.validate")
|
32 |
+
|
33 |
+
|
34 |
+
@hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config")
|
35 |
+
def hydra_main(cfg: FairseqConfig) -> float:
|
36 |
+
return _hydra_main(cfg)
|
37 |
+
|
38 |
+
|
39 |
+
def _hydra_main(cfg: FairseqConfig, **kwargs) -> float:
|
40 |
+
add_defaults(cfg)
|
41 |
+
|
42 |
+
if cfg.common.reset_logging:
|
43 |
+
reset_logging() # Hydra hijacks logging, fix that
|
44 |
+
else:
|
45 |
+
# check if directly called or called through hydra_main
|
46 |
+
if HydraConfig.initialized():
|
47 |
+
with open_dict(cfg):
|
48 |
+
# make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126)
|
49 |
+
cfg.job_logging_cfg = OmegaConf.to_container(
|
50 |
+
HydraConfig.get().job_logging, resolve=True
|
51 |
+
)
|
52 |
+
|
53 |
+
with omegaconf_no_object_check():
|
54 |
+
cfg = OmegaConf.create(
|
55 |
+
OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
|
56 |
+
)
|
57 |
+
OmegaConf.set_struct(cfg, True)
|
58 |
+
|
59 |
+
assert (
|
60 |
+
cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
|
61 |
+
), "Must specify batch size either with --max-tokens or --batch-size"
|
62 |
+
|
63 |
+
distributed_utils.call_main(cfg, validate, **kwargs)
|
64 |
+
|
65 |
+
|
66 |
+
def validate(cfg):
|
67 |
+
utils.import_user_module(cfg.common)
|
68 |
+
|
69 |
+
    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    if cfg.distributed_training.distributed_world_size > 1:
        data_parallel_world_size = distributed_utils.get_data_parallel_world_size()
        data_parallel_rank = distributed_utils.get_data_parallel_rank()
    else:
        data_parallel_world_size = 1
        data_parallel_rank = 0

    overrides = {"task": {"data": cfg.task.data}}

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
        [cfg.common_eval.path],
        arg_overrides=overrides,
        suffix=cfg.checkpoint.checkpoint_suffix,
    )
    model = models[0]

    # Move models to GPU
    for model in models:
        model.eval()
        if use_fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Print args
    logger.info(saved_cfg)

    # Build criterion
    criterion = task.build_criterion(saved_cfg.criterion, from_checkpoint=True)
    criterion.eval()

    for subset in cfg.dataset.valid_subset.split(","):
        try:
            task.load_dataset(subset, combine=False, epoch=1, task_cfg=saved_cfg.task)
            dataset = task.dataset(subset)
        except KeyError:
            raise Exception("Cannot find dataset: " + subset)

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=cfg.dataset.max_tokens,
            max_sentences=cfg.dataset.batch_size,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
            seed=cfg.common.seed,
            num_shards=data_parallel_world_size,
            shard_id=data_parallel_rank,
            num_workers=cfg.dataset.num_workers,
            data_buffer_size=cfg.dataset.data_buffer_size,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            prefix=f"valid on '{subset}' subset",
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
        )

        def apply_half(t):
            if t.dtype is torch.float32:
                return t.to(dtype=torch.half)
            return t

        log_outputs = []
        for i, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample

            if use_fp16:
                sample = utils.apply_to_sample(apply_half, sample)

            _loss, _sample_size, log_output = task.valid_step(sample, model, criterion)
            with metrics.aggregate() as agg:
                task.reduce_metrics([log_output], criterion)
                progress.log(agg.get_smoothed_values(), step=i)
            # progress.log(log_output, step=i) from vision
            log_outputs.append(log_output)

        if data_parallel_world_size > 1:
            log_outputs = distributed_utils.all_gather_list(
                log_outputs,
                max_size=cfg.common.all_gather_list_size,
                group=distributed_utils.get_data_parallel_group(),
            )
            log_outputs = list(chain.from_iterable(log_outputs))

        with metrics.aggregate() as agg:
            task.reduce_metrics(log_outputs, criterion)
            log_output = agg.get_smoothed_values()

        progress.print(log_output, tag=subset, step=i)


def cli_main():
    try:
        from hydra._internal.utils import get_args

        cfg_name = get_args().config_name or "config"
    except:
        logger.warning("Failed to get config name from hydra args")
        cfg_name = "config"

    hydra_init(cfg_name)
    hydra_main()


if __name__ == "__main__":
    cli_main()
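The fp16 path above casts every float32 tensor in a (possibly nested) sample to half precision via utils.apply_to_sample. A minimal standalone sketch of that pattern, assuming samples are nested dicts/lists of tensors; the helper name map_sample is invented for illustration:

    import torch

    def apply_half(t):
        if t.dtype is torch.float32:
            return t.to(dtype=torch.half)
        return t

    def map_sample(fn, sample):
        # Recurse through dicts and lists, applying fn to tensors only.
        if torch.is_tensor(sample):
            return fn(sample)
        if isinstance(sample, dict):
            return {k: map_sample(fn, v) for k, v in sample.items()}
        if isinstance(sample, list):
            return [map_sample(fn, v) for v in sample]
        return sample  # ints, strings, None pass through unchanged

    sample = {"net_input": {"src_tokens": torch.ones(2, 3, dtype=torch.long)}}
    half_sample = map_sample(apply_half, sample)  # long tensors are left as-is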
fairseq/fairseq_cli/interactive.py
ADDED
@@ -0,0 +1,317 @@
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Translate raw text with a trained model. Batches data on-the-fly.
"""

import ast
import fileinput
import logging
import math
import os
import sys
import time
from argparse import Namespace
from collections import namedtuple

import numpy as np
import torch

from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
from fairseq.dataclass.configs import FairseqConfig
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
from fairseq_cli.generate import get_symbols_to_strip_from_output

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("fairseq_cli.interactive")


Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints")
Translation = namedtuple("Translation", "src_str hypos pos_scores alignments")


def buffered_read(input, buffer_size):
    buffer = []
    with fileinput.input(files=[input], openhook=fileinput.hook_encoded("utf-8")) as h:
        for src_str in h:
            buffer.append(src_str.strip())
            if len(buffer) >= buffer_size:
                yield buffer
                buffer = []

    if len(buffer) > 0:
        yield buffer


def make_batches(lines, cfg, task, max_positions, encode_fn):
    def encode_fn_target(x):
        return encode_fn(x)

    if cfg.generation.constraints:
        # Strip (tab-delimited) constraints, if present, from input lines,
        # store them in batch_constraints
        batch_constraints = [list() for _ in lines]
        for i, line in enumerate(lines):
            if "\t" in line:
                lines[i], *batch_constraints[i] = line.split("\t")

        # Convert each List[str] to List[Tensor]
        for i, constraint_list in enumerate(batch_constraints):
            batch_constraints[i] = [
                task.target_dictionary.encode_line(
                    encode_fn_target(constraint),
                    append_eos=False,
                    add_if_not_exist=False,
                )
                for constraint in constraint_list
            ]

    if cfg.generation.constraints:
        constraints_tensor = pack_constraints(batch_constraints)
    else:
        constraints_tensor = None

    tokens, lengths = task.get_interactive_tokens_and_lengths(lines, encode_fn)

    itr = task.get_batch_iterator(
        dataset=task.build_dataset_for_inference(
            tokens, lengths, constraints=constraints_tensor
        ),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=max_positions,
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
    ).next_epoch_itr(shuffle=False)
    for batch in itr:
        ids = batch["id"]
        src_tokens = batch["net_input"]["src_tokens"]
        src_lengths = batch["net_input"]["src_lengths"]
        constraints = batch.get("constraints", None)

        yield Batch(
            ids=ids,
            src_tokens=src_tokens,
            src_lengths=src_lengths,
            constraints=constraints,
        )
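A small illustration of the tab-delimited constraint format consumed by make_batches() above when constrained decoding is enabled; the sample strings are invented:

    line = "the house is small\tdas Haus\tklein"
    sentence, *constraints = line.split("\t")
    print(sentence)     # "the house is small"
    print(constraints)  # ["das Haus", "klein"]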
def main(cfg: FairseqConfig):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    start_time = time.time()
    total_translate_time = 0

    utils.import_user_module(cfg.common)

    if cfg.interactive.buffer_size < 1:
        cfg.interactive.buffer_size = 1
    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
        cfg.dataset.batch_size = 1

    assert (
        not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        not cfg.dataset.batch_size
        or cfg.dataset.batch_size <= cfg.interactive.buffer_size
    ), "--batch-size cannot be larger than --buffer-size"

    logger.info(cfg)

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(cfg.task)

    # Load ensemble
    overrides = ast.literal_eval(cfg.common_eval.model_overrides)
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Initialize generator
    generator = task.build_generator(models, cfg.generation)

    # Handle tokenization and BPE
    tokenizer = task.build_tokenizer(cfg.tokenizer)
    bpe = task.build_bpe(cfg.bpe)

    def encode_fn(x):
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(cfg.generation.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(), *[model.max_positions() for model in models]
    )

    if cfg.generation.constraints:
        logger.warning(
            "NOTE: Constrained decoding currently assumes a shared subword vocabulary."
        )

    if cfg.interactive.buffer_size > 1:
        logger.info("Sentence buffer size: %s", cfg.interactive.buffer_size)
    logger.info("NOTE: hypothesis and token scores are output in base 2")
    logger.info("Type the input sentence and press return:")
    start_id = 0
    for inputs in buffered_read(cfg.interactive.input, cfg.interactive.buffer_size):
        results = []
        for batch in make_batches(inputs, cfg, task, max_positions, encode_fn):
            bsz = batch.src_tokens.size(0)
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            constraints = batch.constraints
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()
                if constraints is not None:
                    constraints = constraints.cuda()

            sample = {
                "net_input": {
                    "src_tokens": src_tokens,
                    "src_lengths": src_lengths,
                },
            }
            translate_start_time = time.time()
            translations = task.inference_step(
                generator, models, sample, constraints=constraints
            )
            translate_time = time.time() - translate_start_time
            total_translate_time += translate_time
            list_constraints = [[] for _ in range(bsz)]
            if cfg.generation.constraints:
                list_constraints = [unpack_constraints(c) for c in constraints]
            for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                constraints = list_constraints[i]
                results.append(
                    (
                        start_id + id,
                        src_tokens_i,
                        hypos,
                        {
                            "constraints": constraints,
                            "time": translate_time / len(translations),
                        },
                    )
                )

        # sort output to match input order
        for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]):
            src_str = ""
            if src_dict is not None:
                src_str = src_dict.string(src_tokens, cfg.common_eval.post_process)
                print("S-{}\t{}".format(id_, src_str))
                print("W-{}\t{:.3f}\tseconds".format(id_, info["time"]))
                for constraint in info["constraints"]:
                    print(
                        "C-{}\t{}".format(
                            id_,
                            tgt_dict.string(constraint, cfg.common_eval.post_process),
                        )
                    )

            # Process top predictions
            for hypo in hypos[: min(len(hypos), cfg.generation.nbest)]:
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo["tokens"].int().cpu(),
                    src_str=src_str,
                    alignment=hypo["alignment"],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=cfg.common_eval.post_process,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
                )
                detok_hypo_str = decode_fn(hypo_str)
                score = hypo["score"] / math.log(2)  # convert to base 2
                # original hypothesis (after tokenization and BPE)
                print("H-{}\t{}\t{}".format(id_, score, hypo_str))
                # detokenized hypothesis
                print("D-{}\t{}\t{}".format(id_, score, detok_hypo_str))
                print(
                    "P-{}\t{}".format(
                        id_,
                        " ".join(
                            map(
                                lambda x: "{:.4f}".format(x),
                                # convert from base e to base 2
                                hypo["positional_scores"].div_(math.log(2)).tolist(),
                            )
                        ),
                    )
                )
                if cfg.generation.print_alignment:
                    alignment_str = " ".join(
                        ["{}-{}".format(src, tgt) for src, tgt in alignment]
                    )
                    print("A-{}\t{}".format(id_, alignment_str))

        # update running id_ counter
        start_id += len(inputs)

    logger.info(
        "Total time: {:.3f} seconds; translation time: {:.3f}".format(
            time.time() - start_time, total_translate_time
        )
    )


def cli_main():
    parser = options.get_interactive_generation_parser()
    args = options.parse_args_and_arch(parser)
    distributed_utils.call_main(convert_namespace_to_omegaconf(args), main)


if __name__ == "__main__":
    cli_main()
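buffered_read() above groups the input stream into chunks of at most buffer_size lines so that batching can happen per chunk. A minimal sketch of the same generator pattern without the fileinput dependency; the name chunked_lines is invented for illustration:

    import io

    def chunked_lines(handle, buffer_size):
        buffer = []
        for line in handle:
            buffer.append(line.strip())
            if len(buffer) >= buffer_size:
                yield buffer
                buffer = []
        if buffer:  # flush the final partial chunk
            yield buffer

    for chunk in chunked_lines(io.StringIO("a\nb\nc\n"), buffer_size=2):
        print(chunk)  # ['a', 'b'] then ['c']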
fairseq/fairseq_cli/preprocess.py
ADDED
@@ -0,0 +1,393 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Data pre-processing: build vocabularies and binarize training data.
"""

import logging
import os
import shutil
import sys
import typing as tp
from argparse import Namespace
from itertools import zip_longest

from fairseq import options, tasks, utils
from fairseq.binarizer import (
    AlignmentDatasetBinarizer,
    FileBinarizer,
    VocabularyDatasetBinarizer,
)
from fairseq.data import Dictionary

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("fairseq_cli.preprocess")

#####################################################################
# file name tools
#####################################################################


def _train_path(lang, trainpref):
    return "{}{}".format(trainpref, ("." + lang) if lang else "")


def _file_name(prefix, lang):
    fname = prefix
    if lang is not None:
        fname += ".{lang}".format(lang=lang)
    return fname


def _dest_path(prefix, lang, destdir):
    return os.path.join(destdir, _file_name(prefix, lang))


def _dict_path(lang, destdir):
    return _dest_path("dict", lang, destdir) + ".txt"


def dataset_dest_prefix(args, output_prefix, lang):
    base = os.path.join(args.destdir, output_prefix)
    if lang is not None:
        lang_part = f".{args.source_lang}-{args.target_lang}.{lang}"
    elif args.only_source:
        lang_part = ""
    else:
        lang_part = f".{args.source_lang}-{args.target_lang}"

    return "{}{}".format(base, lang_part)


def dataset_dest_file(args, output_prefix, lang, extension):
    return "{}.{}".format(dataset_dest_prefix(args, output_prefix, lang), extension)


#####################################################################
# dictionary tools
#####################################################################


def _build_dictionary(
    filenames,
    task,
    args,
    src=False,
    tgt=False,
):
    assert src ^ tgt
    return task.build_dictionary(
        filenames,
        workers=args.workers,
        threshold=args.thresholdsrc if src else args.thresholdtgt,
        nwords=args.nwordssrc if src else args.nwordstgt,
        padding_factor=args.padding_factor,
    )


#####################################################################
# bin file creation logic
#####################################################################


def _make_binary_dataset(
    vocab: Dictionary,
    input_prefix: str,
    output_prefix: str,
    lang: tp.Optional[str],
    num_workers: int,
    args: Namespace,
):
    logger.info("[{}] Dictionary: {} types".format(lang, len(vocab)))

    binarizer = VocabularyDatasetBinarizer(
        vocab,
        append_eos=True,
    )

    input_file = "{}{}".format(input_prefix, ("." + lang) if lang is not None else "")
    full_output_prefix = dataset_dest_prefix(args, output_prefix, lang)

    final_summary = FileBinarizer.multiprocess_dataset(
        input_file,
        args.dataset_impl,
        binarizer,
        full_output_prefix,
        vocab_size=len(vocab),
        num_workers=num_workers,
    )

    logger.info(f"[{lang}] {input_file}: {final_summary} (by {vocab.unk_word})")


def _make_binary_alignment_dataset(
    input_prefix: str, output_prefix: str, num_workers: int, args: Namespace
):

    binarizer = AlignmentDatasetBinarizer(utils.parse_alignment)

    input_file = input_prefix
    full_output_prefix = dataset_dest_prefix(args, output_prefix, lang=None)

    final_summary = FileBinarizer.multiprocess_dataset(
        input_file,
        args.dataset_impl,
        binarizer,
        full_output_prefix,
        vocab_size=None,
        num_workers=num_workers,
    )

    logger.info(
        "[alignments] {}: parsed {} alignments".format(
            input_file, final_summary.num_seq
        )
    )


#####################################################################
# routing logic
#####################################################################


def _make_dataset(
    vocab: Dictionary,
    input_prefix: str,
    output_prefix: str,
    lang: tp.Optional[str],
    args: Namespace,
    num_workers: int,
):
    if args.dataset_impl == "raw":
        # Copy original text file to destination folder
        output_text_file = _dest_path(
            output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
            lang,
            args.destdir,
        )
        shutil.copyfile(_file_name(input_prefix, lang), output_text_file)
    else:
        _make_binary_dataset(
            vocab, input_prefix, output_prefix, lang, num_workers, args
        )


def _make_all(lang, vocab, args):
    if args.trainpref:
        _make_dataset(
            vocab, args.trainpref, "train", lang, args=args, num_workers=args.workers
        )
    if args.validpref:
        for k, validpref in enumerate(args.validpref.split(",")):
            outprefix = "valid{}".format(k) if k > 0 else "valid"
            _make_dataset(
                vocab, validpref, outprefix, lang, args=args, num_workers=args.workers
            )
    if args.testpref:
        for k, testpref in enumerate(args.testpref.split(",")):
            outprefix = "test{}".format(k) if k > 0 else "test"
            _make_dataset(
                vocab, testpref, outprefix, lang, args=args, num_workers=args.workers
            )


def _make_all_alignments(args):
    if args.trainpref and os.path.exists(args.trainpref + "." + args.align_suffix):
        _make_binary_alignment_dataset(
            args.trainpref + "." + args.align_suffix,
            "train.align",
            num_workers=args.workers,
            args=args,
        )
    if args.validpref and os.path.exists(args.validpref + "." + args.align_suffix):
        _make_binary_alignment_dataset(
            args.validpref + "." + args.align_suffix,
            "valid.align",
            num_workers=args.workers,
            args=args,
        )
    if args.testpref and os.path.exists(args.testpref + "." + args.align_suffix):
        _make_binary_alignment_dataset(
            args.testpref + "." + args.align_suffix,
            "test.align",
            num_workers=args.workers,
            args=args,
        )


#####################################################################
# align
#####################################################################


def _align_files(args, src_dict, tgt_dict):
    assert args.trainpref, "--trainpref must be set if --alignfile is specified"
    src_file_name = _train_path(args.source_lang, args.trainpref)
    tgt_file_name = _train_path(args.target_lang, args.trainpref)
    freq_map = {}
    with open(args.alignfile, "r", encoding="utf-8") as align_file:
        with open(src_file_name, "r", encoding="utf-8") as src_file:
            with open(tgt_file_name, "r", encoding="utf-8") as tgt_file:
                for a, s, t in zip_longest(align_file, src_file, tgt_file):
                    si = src_dict.encode_line(s, add_if_not_exist=False)
                    ti = tgt_dict.encode_line(t, add_if_not_exist=False)
                    ai = list(map(lambda x: tuple(x.split("-")), a.split()))
                    for sai, tai in ai:
                        srcidx = si[int(sai)]
                        tgtidx = ti[int(tai)]
                        if srcidx != src_dict.unk() and tgtidx != tgt_dict.unk():
                            assert srcidx != src_dict.pad()
                            assert srcidx != src_dict.eos()
                            assert tgtidx != tgt_dict.pad()
                            assert tgtidx != tgt_dict.eos()
                            if srcidx not in freq_map:
                                freq_map[srcidx] = {}
                            if tgtidx not in freq_map[srcidx]:
                                freq_map[srcidx][tgtidx] = 1
                            else:
                                freq_map[srcidx][tgtidx] += 1
    align_dict = {}
    for srcidx in freq_map.keys():
        align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get)
    with open(
        os.path.join(
            args.destdir,
            "alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
        ),
        "w",
        encoding="utf-8",
    ) as f:
        for k, v in align_dict.items():
            print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)


#####################################################################
# MAIN
#####################################################################


def main(args):
    # setup some basic things
    utils.import_user_module(args)

    os.makedirs(args.destdir, exist_ok=True)

    logger.addHandler(
        logging.FileHandler(
            filename=os.path.join(args.destdir, "preprocess.log"),
        )
    )
    logger.info(args)

    assert (
        args.dataset_impl != "huffman"
    ), "preprocessing.py doesn't support Huffman yet, use HuffmanCodeBuilder directly."

    # build dictionaries

    target = not args.only_source

    if not args.srcdict and os.path.exists(_dict_path(args.source_lang, args.destdir)):
        raise FileExistsError(_dict_path(args.source_lang, args.destdir))

    if (
        target
        and not args.tgtdict
        and os.path.exists(_dict_path(args.target_lang, args.destdir))
    ):
        raise FileExistsError(_dict_path(args.target_lang, args.destdir))

    task = tasks.get_task(args.task)

    if args.joined_dictionary:
        assert (
            not args.srcdict or not args.tgtdict
        ), "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        elif args.tgtdict:
            src_dict = task.load_dictionary(args.tgtdict)
        else:
            assert (
                args.trainpref
            ), "--trainpref must be set if --srcdict is not specified"
            src_dict = _build_dictionary(
                {
                    _train_path(lang, args.trainpref)
                    for lang in [args.source_lang, args.target_lang]
                },
                task=task,
                args=args,
                src=True,
            )
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        else:
            assert (
                args.trainpref
            ), "--trainpref must be set if --srcdict is not specified"
            src_dict = _build_dictionary(
                [_train_path(args.source_lang, args.trainpref)],
                task=task,
                args=args,
                src=True,
            )

        if target:
            if args.tgtdict:
                tgt_dict = task.load_dictionary(args.tgtdict)
            else:
                assert (
                    args.trainpref
                ), "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = _build_dictionary(
                    [_train_path(args.target_lang, args.trainpref)],
                    task=task,
                    args=args,
                    tgt=True,
                )
        else:
            tgt_dict = None

    # save dictionaries

    src_dict.save(_dict_path(args.source_lang, args.destdir))
    if target and tgt_dict is not None:
        tgt_dict.save(_dict_path(args.target_lang, args.destdir))

    if args.dict_only:
        return

    _make_all(args.source_lang, src_dict, args)
    if target:
        _make_all(args.target_lang, tgt_dict, args)

    # align the datasets if needed
    if args.align_suffix:
        _make_all_alignments(args)

    logger.info("Wrote preprocessed data to {}".format(args.destdir))

    if args.alignfile:
        _align_files(args, src_dict=src_dict, tgt_dict=tgt_dict)


def cli_main():
    parser = options.get_preprocessing_parser()
    args = parser.parse_args()
    main(args)


if __name__ == "__main__":
    cli_main()
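dataset_dest_prefix() above encodes the language pair into output file names. A standalone sketch of that convention, assuming a de-en pair; the paths shown are illustrative, not produced by running the script:

    import os

    def dest_prefix(destdir, output_prefix, src, tgt, lang=None, only_source=False):
        if lang is not None:
            suffix = f".{src}-{tgt}.{lang}"
        elif only_source:
            suffix = ""
        else:
            suffix = f".{src}-{tgt}"
        return os.path.join(destdir, output_prefix) + suffix

    print(dest_prefix("data-bin", "train", "de", "en", lang="de"))  # data-bin/train.de-en.de
    print(dest_prefix("data-bin", "valid", "de", "en"))             # data-bin/valid.de-en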
fairseq/fairseq_cli/score.py
ADDED
@@ -0,0 +1,102 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
BLEU scoring of generated translations against reference translations.
"""

import argparse
import os
import sys

from fairseq.data import dictionary
from fairseq.scoring import bleu


def get_parser():
    parser = argparse.ArgumentParser(
        description="Command-line script for BLEU scoring."
    )
    # fmt: off
    parser.add_argument('-s', '--sys', default='-', help='system output')
    parser.add_argument('-r', '--ref', required=True, help='references')
    parser.add_argument('-o', '--order', default=4, metavar='N',
                        type=int, help='consider ngrams up to this order')
    parser.add_argument('--ignore-case', action='store_true',
                        help='case-insensitive scoring')
    parser.add_argument('--sacrebleu', action='store_true',
                        help='score with sacrebleu')
    parser.add_argument('--sentence-bleu', action='store_true',
                        help='report sentence-level BLEUs (i.e., with +1 smoothing)')
    # fmt: on
    return parser


def cli_main():
    parser = get_parser()
    args = parser.parse_args()
    print(args)

    assert args.sys == "-" or os.path.exists(
        args.sys
    ), "System output file {} does not exist".format(args.sys)
    assert os.path.exists(args.ref), "Reference file {} does not exist".format(args.ref)

    dict = dictionary.Dictionary()

    def readlines(fd):
        for line in fd.readlines():
            if args.ignore_case:
                yield line.lower()
            else:
                yield line

    if args.sacrebleu:
        import sacrebleu

        def score(fdsys):
            with open(args.ref) as fdref:
                print(sacrebleu.corpus_bleu(fdsys, [fdref]).format())

    elif args.sentence_bleu:

        def score(fdsys):
            with open(args.ref) as fdref:
                scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
                for i, (sys_tok, ref_tok) in enumerate(
                    zip(readlines(fdsys), readlines(fdref))
                ):
                    scorer.reset(one_init=True)
                    sys_tok = dict.encode_line(sys_tok)
                    ref_tok = dict.encode_line(ref_tok)
                    scorer.add(ref_tok, sys_tok)
                    print(i, scorer.result_string(args.order))

    else:

        def score(fdsys):
            with open(args.ref) as fdref:
                scorer = bleu.Scorer(
                    bleu.BleuConfig(
                        pad=dict.pad(),
                        eos=dict.eos(),
                        unk=dict.unk(),
                    )
                )
                for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)):
                    sys_tok = dict.encode_line(sys_tok)
                    ref_tok = dict.encode_line(ref_tok)
                    scorer.add(ref_tok, sys_tok)
                print(scorer.result_string(args.order))

    if args.sys == "-":
        score(sys.stdin)
    else:
        with open(args.sys, "r") as f:
            score(f)


if __name__ == "__main__":
    cli_main()
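The script follows the common CLI convention of reading system output from stdin when --sys is "-". The same dispatch in a standalone form; open_sys is a hypothetical helper name:

    import sys
    from contextlib import nullcontext

    def open_sys(path):
        # stdin must not be closed by the with-block, hence nullcontext
        return nullcontext(sys.stdin) if path == "-" else open(path, "r")

    with open_sys("-") as handle:
        pass  # stream lines from `handle` to the scorer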
fairseq/fairseq_cli/train.py
ADDED
@@ -0,0 +1,581 @@
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Train a new model on one or across multiple GPUs.
"""

import argparse
import logging
import math
import os
import sys
from typing import Any, Callable, Dict, List, Optional, Tuple

# We need to setup root logger before importing any fairseq libraries.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("fairseq_cli.train")

import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf

from fairseq import checkpoint_utils, options, quantization_utils, tasks, utils
from fairseq.data import data_utils, iterators
from fairseq.data.plasma_utils import PlasmaStore
from fairseq.dataclass.configs import FairseqConfig
from fairseq.dataclass.initialize import add_defaults
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap
from fairseq.distributed import utils as distributed_utils
from fairseq.file_io import PathManager
from fairseq.logging import meters, metrics, progress_bar
from fairseq.model_parallel.megatron_trainer import MegatronTrainer
from fairseq.trainer import Trainer


def main(cfg: FairseqConfig) -> None:
    if isinstance(cfg, argparse.Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)
    add_defaults(cfg)

    if (
        distributed_utils.is_master(cfg.distributed_training)
        and "job_logging_cfg" in cfg
    ):
        # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
        logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"
    metrics.reset()

    if cfg.common.log_file is not None:
        handler = logging.FileHandler(filename=cfg.common.log_file)
        logger.addHandler(handler)

    np.random.seed(cfg.common.seed)
    utils.set_torch_seed(cfg.common.seed)

    if distributed_utils.is_master(cfg.distributed_training):
        checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)

    # Print args
    logger.info(cfg)

    if cfg.checkpoint.write_checkpoints_asynchronously:
        try:
            import iopath  # noqa: F401
        except ImportError:
            logging.exception(
                "Asynchronous checkpoint writing is specified but iopath is "
                "not installed: `pip install iopath`"
            )
            return

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(cfg.task)

    assert cfg.criterion, "Please specify criterion to train a model"

    # Build model and criterion
    if cfg.distributed_training.ddp_backend == "fully_sharded":
        with fsdp_enable_wrap(cfg.distributed_training):
            model = fsdp_wrap(task.build_model(cfg.model))
    else:
        model = task.build_model(cfg.model)
    criterion = task.build_criterion(cfg.criterion)
    logger.info(model)
    logger.info("task: {}".format(task.__class__.__name__))
    logger.info("model: {}".format(model.__class__.__name__))
    logger.info("criterion: {}".format(criterion.__class__.__name__))
    logger.info(
        "num. shared model params: {:,} (num. trained: {:,})".format(
            sum(
                p.numel() for p in model.parameters() if not getattr(p, "expert", False)
            ),
            sum(
                p.numel()
                for p in model.parameters()
                if not getattr(p, "expert", False) and p.requires_grad
            ),
        )
    )

    logger.info(
        "num. expert model params: {} (num. trained: {})".format(
            sum(p.numel() for p in model.parameters() if getattr(p, "expert", False)),
            sum(
                p.numel()
                for p in model.parameters()
                if getattr(p, "expert", False) and p.requires_grad
            ),
        )
    )

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    # We load the valid dataset AFTER building the model
    if not cfg.dataset.disable_validation:
        data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg)
        if cfg.dataset.combine_valid_subsets:
            task.load_dataset("valid", combine=True, epoch=1)
        else:
            for valid_sub_split in cfg.dataset.valid_subset.split(","):
                task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # (optionally) Configure quantization
    if cfg.common.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=cfg.common.quantization_config_path,
            max_epoch=cfg.optimization.max_epoch,
            max_update=cfg.optimization.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if cfg.common.model_parallel_size == 1:
        trainer = Trainer(cfg, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(cfg, task, model, criterion)
    logger.info(
        "training on {} devices (GPUs/TPUs)".format(
            cfg.distributed_training.distributed_world_size
        )
    )
    logger.info(
        "max tokens per device = {} and max sentences per device = {}".format(
            cfg.dataset.max_tokens,
            cfg.dataset.batch_size,
        )
    )

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        cfg.checkpoint,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )
    if cfg.common.tpu:
        import torch_xla.core.xla_model as xm

        xm.rendezvous("load_checkpoint")  # wait for all workers

    max_epoch = cfg.optimization.max_epoch or math.inf
    lr = trainer.get_lr()

    # TODO: a dry run on validation set to pin the memory
    valid_subsets = cfg.dataset.valid_subset.split(",")
    if not cfg.dataset.disable_validation:
        for subset in valid_subsets:
            logger.info('begin dry-run validation on "{}" subset'.format(subset))
            itr = trainer.get_valid_iterator(subset).next_epoch_itr(
                shuffle=False, set_dataset_epoch=False  # use a fixed valid set
            )
            if cfg.common.tpu:
                itr = utils.tpu_data_loader(itr)
            for _ in itr:
                pass
    # TODO: end of dry run section

    train_meter = meters.StopwatchMeter()
    train_meter.start()
    while epoch_itr.next_epoch_idx <= max_epoch:
        if lr <= cfg.optimization.stop_min_lr:
            logger.info(
                f"stopping training because current learning rate ({lr}) is smaller "
                "than or equal to minimum learning rate "
                f"(--stop-min-lr={cfg.optimization.stop_min_lr})"
            )
            break

        # train for one epoch
        valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))

    # ioPath implementation to wait for all asynchronous file writes to complete.
    if cfg.checkpoint.write_checkpoints_asynchronously:
        logger.info(
            "ioPath PathManager waiting for all asynchronous checkpoint "
            "writes to finish."
        )
        PathManager.async_close()
        logger.info("ioPath PathManager finished waiting.")
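train() further below reads cfg.optimization.update_freq as a per-epoch list of gradient-accumulation factors, reusing the last entry once the epoch index runs past the end of the list. A standalone sketch of that lookup; the name accumulation_steps is invented for illustration:

    def accumulation_steps(update_freq, epoch):
        # epoch is 1-based; epochs beyond the list reuse the final entry
        return update_freq[epoch - 1] if epoch <= len(update_freq) else update_freq[-1]

    print(accumulation_steps([4, 2, 1], epoch=2))   # 2
    print(accumulation_steps([4, 2, 1], epoch=10))  # 1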
232 |
+
def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool:
|
233 |
+
# skip check if no validation was done in the current epoch
|
234 |
+
if valid_loss is None:
|
235 |
+
return False
|
236 |
+
if cfg.checkpoint.patience <= 0:
|
237 |
+
return False
|
238 |
+
|
239 |
+
def is_better(a, b):
|
240 |
+
return a > b if cfg.checkpoint.maximize_best_checkpoint_metric else a < b
|
241 |
+
|
242 |
+
prev_best = getattr(should_stop_early, "best", None)
|
243 |
+
if prev_best is None or is_better(valid_loss, prev_best):
|
244 |
+
should_stop_early.best = valid_loss
|
245 |
+
should_stop_early.num_runs = 0
|
246 |
+
return False
|
247 |
+
else:
|
248 |
+
should_stop_early.num_runs += 1
|
249 |
+
if should_stop_early.num_runs >= cfg.checkpoint.patience:
|
250 |
+
logger.info(
|
251 |
+
"early stop since valid performance hasn't improved for last {} runs".format(
|
252 |
+
cfg.checkpoint.patience
|
253 |
+
)
|
254 |
+
)
|
255 |
+
return True
|
256 |
+
else:
|
257 |
+
return False
|
258 |
+
|
259 |
+
|
260 |
+
@metrics.aggregate("train")
|
261 |
+
def train(
|
262 |
+
cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr
|
263 |
+
) -> Tuple[List[Optional[float]], bool]:
|
264 |
+
"""Train the model for one epoch and return validation losses."""
|
265 |
+
# Initialize data iterator
|
266 |
+
itr = epoch_itr.next_epoch_itr(
|
267 |
+
fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
|
268 |
+
shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
|
269 |
+
)
|
270 |
+
update_freq = (
|
271 |
+
cfg.optimization.update_freq[epoch_itr.epoch - 1]
|
272 |
+
if epoch_itr.epoch <= len(cfg.optimization.update_freq)
|
273 |
+
else cfg.optimization.update_freq[-1]
|
274 |
+
)
|
275 |
+
itr = iterators.GroupedIterator(
|
276 |
+
itr,
|
277 |
+
update_freq,
|
278 |
+
skip_remainder_batch=cfg.optimization.skip_remainder_batch,
|
279 |
+
)
|
280 |
+
if cfg.common.tpu:
|
281 |
+
itr = utils.tpu_data_loader(itr)
|
282 |
+
progress = progress_bar.progress_bar(
|
283 |
+
itr,
|
284 |
+
log_format=cfg.common.log_format,
|
285 |
+
log_file=cfg.common.log_file,
|
286 |
+
log_interval=cfg.common.log_interval,
|
287 |
+
epoch=epoch_itr.epoch,
|
288 |
+
aim_repo=(
|
289 |
+
cfg.common.aim_repo
|
290 |
+
if distributed_utils.is_master(cfg.distributed_training)
|
291 |
+
else None
|
292 |
+
),
|
293 |
+
aim_run_hash=(
|
294 |
+
cfg.common.aim_run_hash
|
295 |
+
if distributed_utils.is_master(cfg.distributed_training)
|
296 |
+
else None
|
297 |
+
),
|
298 |
+
aim_param_checkpoint_dir=cfg.checkpoint.save_dir,
|
299 |
+
tensorboard_logdir=(
|
300 |
+
cfg.common.tensorboard_logdir
|
301 |
+
if distributed_utils.is_master(cfg.distributed_training)
|
302 |
+
else None
|
303 |
+
),
|
304 |
+
default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
|
305 |
+
wandb_project=(
|
306 |
+
cfg.common.wandb_project
|
307 |
+
if distributed_utils.is_master(cfg.distributed_training)
|
308 |
+
else None
|
309 |
+
),
|
310 |
+
wandb_run_name=os.environ.get(
|
311 |
+
"WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
|
312 |
+
),
|
313 |
+
azureml_logging=(
|
314 |
+
cfg.common.azureml_logging
|
315 |
+
if distributed_utils.is_master(cfg.distributed_training)
|
316 |
+
else False
|
317 |
+
),
|
318 |
+
)
|
319 |
+
progress.update_config(_flatten_config(cfg))
|
320 |
+
|
321 |
+
trainer.begin_epoch(epoch_itr.epoch)
|
322 |
+
|
323 |
+
valid_subsets = cfg.dataset.valid_subset.split(",")
|
324 |
+
should_stop = False
|
325 |
+
num_updates = trainer.get_num_updates()
|
326 |
+
logger.info("Start iterating over samples")
|
327 |
+
for i, samples in enumerate(progress):
|
328 |
+
with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
|
329 |
+
"train_step-%d" % i
|
330 |
+
):
|
331 |
+
log_output = trainer.train_step(samples)
|
332 |
+
|
333 |
+
if log_output is not None: # not OOM, overflow, ...
|
334 |
+
# log mid-epoch stats
|
335 |
+
num_updates = trainer.get_num_updates()
|
336 |
+
if num_updates % cfg.common.log_interval == 0:
|
337 |
+
stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
|
338 |
+
progress.log(stats, tag="train_inner", step=num_updates)
|
339 |
+
|
340 |
+
# reset mid-epoch stats after each log interval
|
341 |
+
# the end-of-epoch stats will still be preserved
|
342 |
+
metrics.reset_meters("train_inner")
|
343 |
+
|
344 |
+
end_of_epoch = not itr.has_next()
|
345 |
+
valid_losses, should_stop = validate_and_save(
|
346 |
+
cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch
|
347 |
+
)
|
348 |
+
|
349 |
+
if should_stop:
|
350 |
+
break
|
351 |
+
|
352 |
+
# log end-of-epoch stats
|
353 |
+
logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch))
|
354 |
+
stats = get_training_stats(metrics.get_smoothed_values("train"))
|
355 |
+
progress.print(stats, tag="train", step=num_updates)
|
356 |
+
|
357 |
+
# reset epoch-level meters
|
358 |
+
metrics.reset_meters("train")
|
359 |
+
return valid_losses, should_stop
|
360 |
+
|
361 |
+
|
362 |
+
def _flatten_config(cfg: DictConfig):
|
363 |
+
config = OmegaConf.to_container(cfg)
|
364 |
+
# remove any legacy Namespaces and replace with a single "args"
|
365 |
+
namespace = None
|
366 |
+
for k, v in list(config.items()):
|
367 |
+
if isinstance(v, argparse.Namespace):
|
368 |
+
namespace = v
|
369 |
+
del config[k]
|
370 |
+
if namespace is not None:
|
371 |
+
config["args"] = vars(namespace)
|
372 |
+
return config
|
373 |
+
|
374 |
+
|
375 |
+
def validate_and_save(
|
376 |
+
cfg: DictConfig,
|
377 |
+
trainer: Trainer,
|
378 |
+
task: tasks.FairseqTask,
|
379 |
+
epoch_itr,
|
380 |
+
valid_subsets: List[str],
|
381 |
+
end_of_epoch: bool,
|
382 |
+
) -> Tuple[List[Optional[float]], bool]:
|
383 |
+
num_updates = trainer.get_num_updates()
|
384 |
+
max_update = cfg.optimization.max_update or math.inf
|
385 |
+
|
386 |
+
# Stopping conditions (and an additional one based on validation loss later
|
387 |
+
# on)
|
388 |
+
should_stop = False
|
389 |
+
if num_updates >= max_update:
|
390 |
+
should_stop = True
|
391 |
+
logger.info(
|
392 |
+
f"Stopping training due to "
|
393 |
+
f"num_updates: {num_updates} >= max_update: {max_update}"
|
394 |
+
)
|
395 |
+
|
396 |
+
training_time_hours = trainer.cumulative_training_time() / (60 * 60)
|
397 |
+
if (
|
398 |
+
cfg.optimization.stop_time_hours > 0
|
399 |
+
and training_time_hours > cfg.optimization.stop_time_hours
|
400 |
+
):
|
401 |
+
should_stop = True
|
402 |
+
logger.info(
|
403 |
+
f"Stopping training due to "
|
404 |
+
f"cumulative_training_time: {training_time_hours} > "
|
405 |
+
f"stop_time_hours: {cfg.optimization.stop_time_hours} hour(s)"
|
406 |
+
)
|
407 |
+
|
408 |
+
do_save = (
|
409 |
+
(end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0)
|
410 |
+
or should_stop
|
411 |
+
or (
|
412 |
+
cfg.checkpoint.save_interval_updates > 0
|
413 |
+
and num_updates > 0
|
414 |
+
and num_updates % cfg.checkpoint.save_interval_updates == 0
|
415 |
+
and num_updates >= cfg.dataset.validate_after_updates
|
416 |
+
)
|
417 |
+
)
|
418 |
+
do_validate = (
|
419 |
+
(
|
420 |
+
(not end_of_epoch and do_save) # validate during mid-epoch saves
|
421 |
+
or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0)
|
422 |
+
or should_stop
|
423 |
+
or (
|
424 |
+
cfg.dataset.validate_interval_updates > 0
|
425 |
+
and num_updates > 0
|
426 |
+
and num_updates % cfg.dataset.validate_interval_updates == 0
|
427 |
+
)
|
428 |
+
)
|
429 |
+
and not cfg.dataset.disable_validation
|
430 |
+
and num_updates >= cfg.dataset.validate_after_updates
|
431 |
+
)
|
432 |
+
|
433 |
+
# Validate
|
434 |
+
valid_losses = [None]
|
435 |
+
if do_validate:
|
436 |
+
valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets)
|
437 |
+
|
438 |
+
should_stop |= should_stop_early(cfg, valid_losses[0])
|
439 |
+
|
440 |
+
# Save checkpoint
|
441 |
+
if do_save or should_stop:
|
442 |
+
cp_path = checkpoint_utils.save_checkpoint(
|
443 |
+
cfg.checkpoint, trainer, epoch_itr, valid_losses[0]
|
444 |
+
)
|
445 |
+
if cp_path is not None and hasattr(task, "post_save"):
|
446 |
+
task.post_save(cp_path, num_updates)
|
447 |
+
|
448 |
+
return valid_losses, should_stop
|
449 |
+
|
450 |
+
|
451 |
+
def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]:
|
452 |
+
stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0)
|
453 |
+
return stats
|
454 |
+
|
455 |
+
|
456 |
+
def validate(
|
457 |
+
cfg: DictConfig,
|
458 |
+
trainer: Trainer,
|
459 |
+
task: tasks.FairseqTask,
|
460 |
+
epoch_itr,
|
461 |
+
    subsets: List[str],
) -> List[Optional[float]]:
    """Evaluate the model on the validation set(s) and return the losses."""

    if cfg.dataset.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(cfg.dataset.fixed_validation_seed)

    trainer.begin_valid_epoch(epoch_itr.epoch)
    valid_losses = []
    for subset_idx, subset in enumerate(subsets):
        logger.info('begin validation on "{}" subset'.format(subset))

        # Initialize data iterator
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(
            shuffle=False, set_dataset_epoch=False  # use a fixed valid set
        )
        if cfg.common.tpu:
            itr = utils.tpu_data_loader(itr)
        progress = progress_bar.progress_bar(
            itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            aim_repo=(
                cfg.common.aim_repo
                if distributed_utils.is_master(cfg.distributed_training)
                else None
            ),
            aim_run_hash=(
                cfg.common.aim_run_hash
                if distributed_utils.is_master(cfg.distributed_training)
                else None
            ),
            aim_param_checkpoint_dir=cfg.checkpoint.save_dir,
            tensorboard_logdir=(
                cfg.common.tensorboard_logdir
                if distributed_utils.is_master(cfg.distributed_training)
                else None
            ),
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
            wandb_project=(
                cfg.common.wandb_project
                if distributed_utils.is_master(cfg.distributed_training)
                else None
            ),
            wandb_run_name=os.environ.get(
                "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
            ),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for i, sample in enumerate(progress):
                if (
                    cfg.dataset.max_valid_steps is not None
                    and i > cfg.dataset.max_valid_steps
                ):
                    break
                trainer.valid_step(sample)

        # log validation stats
        # only tracking the best metric on the 1st validation subset
        tracking_best = subset_idx == 0
        stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values(), tracking_best)

        if hasattr(task, "post_validate"):
            task.post_validate(trainer.get_model(), stats, agg)

        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
    return valid_losses


def get_valid_stats(
    cfg: DictConfig,
    trainer: Trainer,
    stats: Dict[str, Any],
    tracking_best: bool,
) -> Dict[str, Any]:
    stats["num_updates"] = trainer.get_num_updates()
    if tracking_best and hasattr(checkpoint_utils.save_checkpoint, "best"):
        key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric)
        best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min
        stats[key] = best_function(
            checkpoint_utils.save_checkpoint.best,
            stats[cfg.checkpoint.best_checkpoint_metric],
        )
    return stats


def cli_main(
    modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None
) -> None:
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser, modify_parser=modify_parser)

    cfg = convert_namespace_to_omegaconf(args)

    if cfg.common.use_plasma_view:
        server = PlasmaStore(path=cfg.common.plasma_path)
        logger.info(
            f"Started plasma server pid {server.server.pid} {cfg.common.plasma_path}"
        )

    if args.profile:
        with torch.cuda.profiler.profile():
            with torch.autograd.profiler.emit_nvtx():
                distributed_utils.call_main(cfg, main)
    else:
        distributed_utils.call_main(cfg, main)

    # if cfg.common.use_plasma_view:
    #     server.server.kill()


if __name__ == "__main__":
    cli_main()

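The best-checkpoint bookkeeping in get_valid_stats() reduces to a running max or min over validation values, depending on whether the metric should be maximized. A minimal standalone sketch of that logic (update_best is an illustrative name, not fairseq API):

def update_best(prev_best, current, maximize=False):
    # Running best for a validation metric: max if higher is better
    # (e.g., BLEU), min if lower is better (e.g., loss).
    if prev_best is None:
        return current
    return max(prev_best, current) if maximize else min(prev_best, current)

assert update_best(2.0, 1.5) == 1.5                    # loss: lower is better
assert update_best(30.1, 29.8, maximize=True) == 30.1  # BLEU: higher is better
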
fairseq/fairseq_cli/validate.py
ADDED
@@ -0,0 +1,153 @@
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import sys
from argparse import Namespace
from itertools import chain

import torch
from omegaconf import DictConfig

from fairseq import checkpoint_utils, distributed_utils, options, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import metrics, progress_bar
from fairseq.utils import reset_logging

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("fairseq_cli.validate")


def main(cfg: DictConfig, override_args=None):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    reset_logging()

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    if cfg.distributed_training.distributed_world_size > 1:
        data_parallel_world_size = distributed_utils.get_data_parallel_world_size()
        data_parallel_rank = distributed_utils.get_data_parallel_rank()
    else:
        data_parallel_world_size = 1
        data_parallel_rank = 0

    if override_args is not None:
        overrides = vars(override_args)
        overrides.update(eval(getattr(override_args, "model_overrides", "{}")))
    else:
        overrides = None

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
        [cfg.common_eval.path],
        arg_overrides=overrides,
        suffix=cfg.checkpoint.checkpoint_suffix,
    )
    model = models[0]

    # Move models to GPU
    for model in models:
        model.eval()
        if use_fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Print args
    logger.info(saved_cfg)

    # Build criterion
    criterion = task.build_criterion(saved_cfg.criterion)
    criterion.eval()

    for subset in cfg.dataset.valid_subset.split(","):
        try:
            task.load_dataset(subset, combine=False, epoch=1, task_cfg=saved_cfg.task)
            dataset = task.dataset(subset)
        except KeyError:
            raise Exception("Cannot find dataset: " + subset)

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=cfg.dataset.max_tokens,
            max_sentences=cfg.dataset.batch_size,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
            seed=cfg.common.seed,
            num_shards=data_parallel_world_size,
            shard_id=data_parallel_rank,
            num_workers=cfg.dataset.num_workers,
            data_buffer_size=cfg.dataset.data_buffer_size,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            prefix=f"valid on '{subset}' subset",
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
        )

        log_outputs = []
        for i, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            _loss, _sample_size, log_output = task.valid_step(sample, model, criterion)
            progress.log(log_output, step=i)
            log_outputs.append(log_output)

        if data_parallel_world_size > 1:
            log_outputs = distributed_utils.all_gather_list(
                log_outputs,
                max_size=cfg.common.all_gather_list_size,
                group=distributed_utils.get_data_parallel_group(),
            )
            log_outputs = list(chain.from_iterable(log_outputs))

        with metrics.aggregate() as agg:
            task.reduce_metrics(log_outputs, criterion)
            log_output = agg.get_smoothed_values()

        progress.print(log_output, tag=subset, step=i)


def cli_main():
    parser = options.get_validation_parser()
    args = options.parse_args_and_arch(parser)

    # only override args that are explicitly given on the command line
    override_parser = options.get_validation_parser()
    override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True)

    distributed_utils.call_main(
        convert_namespace_to_omegaconf(args), main, override_args=override_args
    )


if __name__ == "__main__":
    cli_main()

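A sketch of how this CLI can be driven programmatically; the data directory and checkpoint path below are placeholders rather than files shipped with the repository, and the flags are the standard validation options handled by options.get_validation_parser():

import sys
from fairseq_cli.validate import cli_main

# Hypothetical paths; substitute a real preprocessed data dir and checkpoint.
sys.argv = [
    "fairseq-validate",
    "data-bin/example",
    "--path", "checkpoints/checkpoint_best.pt",
    "--valid-subset", "valid",
    "--max-tokens", "4096",
]
cli_main()
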
fairseq/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/__init__.py
ADDED
@@ -0,0 +1,3 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

__version__ = "0.1"

fairseq/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/config.py
ADDED
@@ -0,0 +1,23 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from dataclasses import dataclass, field

from hydra.core.config_store import ConfigStore

from hydra_plugins.hydra_submitit_launcher.config import SlurmQueueConf


@dataclass
class DependencySubmititConf(SlurmQueueConf):
    """Slurm configuration overrides and specific parameters"""

    _target_: str = (
        "hydra_plugins.dependency_submitit_launcher.launcher.DependencySubmititLauncher"
    )


ConfigStore.instance().store(
    group="hydra/launcher",
    name="dependency_submitit_slurm",
    node=DependencySubmititConf(),
    provider="dependency_submitit_slurm",
)

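Hydra instantiates the launcher from the _target_ dotted path registered above. Resolving such a path by hand looks roughly like the sketch below (an illustration of the mechanism, not Hydra's actual implementation; json.dumps stands in for the launcher class so the snippet runs anywhere):

import importlib

def resolve(dotted_path: str):
    # Split "pkg.module.Attr" into module path and attribute, then import.
    module_name, _, attr = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)

assert resolve("json.dumps")({"a": 1}) == '{"a": 1}'
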
fairseq/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/launcher.py
ADDED
@@ -0,0 +1,121 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import os
import subprocess
from pathlib import Path
from typing import Any, List, Sequence

from hydra.core.singleton import Singleton
from hydra.core.utils import JobReturn, filter_overrides
from omegaconf import OmegaConf

log = logging.getLogger(__name__)

from .config import DependencySubmititConf
from hydra_plugins.hydra_submitit_launcher.submitit_launcher import BaseSubmititLauncher


class DependencySubmititLauncher(BaseSubmititLauncher):
    _EXECUTOR = "slurm"

    def launch(
        self, job_overrides: Sequence[Sequence[str]], initial_job_idx: int
    ) -> Sequence[JobReturn]:

        # lazy import to ensure plugin discovery remains fast
        import submitit

        assert self.config is not None

        num_jobs = len(job_overrides)
        assert num_jobs > 0

        next_script = None

        for jo in job_overrides:
            if next_script is None:
                for item in jo:
                    if "next_script=" in item:
                        next_script = item
                        break
                assert (
                    next_script is not None
                ), "job overrides must contain +next_script=path/to/next/script"
            jo.remove(next_script)

        idx = next_script.find("=")
        next_script = next_script[idx + 1 :]

        params = self.params
        # build executor
        init_params = {"folder": self.params["submitit_folder"]}
        specific_init_keys = {"max_num_timeout"}

        init_params.update(
            **{
                f"{self._EXECUTOR}_{x}": y
                for x, y in params.items()
                if x in specific_init_keys
            }
        )
        init_keys = specific_init_keys | {"submitit_folder"}
        executor = submitit.AutoExecutor(cluster=self._EXECUTOR, **init_params)

        # specify resources/parameters
        baseparams = set(OmegaConf.structured(DependencySubmititConf).keys())
        params = {
            x if x in baseparams else f"{self._EXECUTOR}_{x}": y
            for x, y in params.items()
            if x not in init_keys
        }
        executor.update_parameters(**params)

        log.info(
            f"Submitit '{self._EXECUTOR}' sweep output dir : "
            f"{self.config.hydra.sweep.dir}"
        )
        sweep_dir = Path(str(self.config.hydra.sweep.dir))
        sweep_dir.mkdir(parents=True, exist_ok=True)
        if "mode" in self.config.hydra.sweep:
            mode = int(str(self.config.hydra.sweep.mode), 8)
            os.chmod(sweep_dir, mode=mode)

        job_params: List[Any] = []
        for idx, overrides in enumerate(job_overrides):
            idx = initial_job_idx + idx
            lst = " ".join(filter_overrides(overrides))
            log.info(f"\t#{idx} : {lst}")
            job_params.append(
                (
                    list(overrides),
                    "hydra.sweep.dir",
                    idx,
                    f"job_id_for_{idx}",
                    Singleton.get_state(),
                )
            )

        jobs = executor.map_array(self, *zip(*job_params))

        for j, jp in zip(jobs, job_params):
            job_id = str(j.job_id)
            task_id = "0" if "_" not in job_id else job_id.split("_")[1]
            sweep_config = self.config_loader.load_sweep_config(self.config, jp[0])
            dir = sweep_config.hydra.sweep.dir

            dir = (
                dir.replace("[", "")
                .replace("]", "")
                .replace("{", "")
                .replace("}", "")
                .replace(",", "_")
                .replace("'", "")
                .replace('"', "")
            )

            subprocess.call(
                [next_script, job_id, task_id, dir],
                shell=False,
            )

        return [j.results()[0] for j in jobs]

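The launcher requires each job's overrides to carry a +next_script=... entry, which it strips before submission and later invokes with the job id, task id, and sweep dir. The parsing step reduces to the following (a standalone illustration with made-up override strings):

overrides = ["task=translation", "next_script=scripts/stage2.sh"]
next_script = next(o for o in overrides if "next_script=" in o)
overrides.remove(next_script)
next_script = next_script[next_script.find("=") + 1 :]
assert next_script == "scripts/stage2.sh"
assert overrides == ["task=translation"]
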
fairseq/hydra_plugins/dependency_submitit_launcher/setup.py
ADDED
@@ -0,0 +1,29 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# type: ignore
from pathlib import Path

from read_version import read_version
from setuptools import find_namespace_packages, setup

setup(
    name="dependency-submitit-launcher",
    version=read_version("hydra_plugins/dependency_submitit_launcher", "__init__.py"),
    author="Alexei Baevski",
    author_email="abaevski@fb.com",
    description="Dependency-supporting Submitit Launcher for Hydra apps",
    packages=find_namespace_packages(include=["hydra_plugins.*"]),
    classifiers=[
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Operating System :: MacOS",
        "Operating System :: POSIX :: Linux",
        "Development Status :: 4 - Beta",
    ],
    install_requires=[
        "hydra-core>=1.0.4",
        "submitit>=1.0.0",
    ],
    include_package_data=True,
)

fairseq/scripts/__init__.py
ADDED
File without changes
fairseq/scripts/average_checkpoints.py
ADDED
@@ -0,0 +1,176 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import collections
import os
import re

import torch

from fairseq.file_io import PathManager


def average_checkpoints(inputs):
    """Loads checkpoints from inputs and returns a model with averaged weights.

    Args:
      inputs: An iterable of string paths of checkpoints to load from.

    Returns:
      A dict of string keys mapping to various values. The 'model' key
      from the returned dict should correspond to an OrderedDict mapping
      string parameter names to torch Tensors.
    """
    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None
    num_models = len(inputs)

    for fpath in inputs:
        with PathManager.open(fpath, "rb") as f:
            state = torch.load(
                f,
                map_location=(
                    lambda s, _: torch.serialization.default_restore_location(s, "cpu")
                ),
            )
        # Copies over the settings from the first checkpoint
        if new_state is None:
            new_state = state

        model_params = state["model"]

        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                "For checkpoint {}, expected list of params: {}, "
                "but found: {}".format(f, params_keys, model_params_keys)
            )

        for k in params_keys:
            p = model_params[k]
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            if k not in params_dict:
                params_dict[k] = p.clone()
                # NOTE: clone() is needed in case p is a shared parameter
            else:
                params_dict[k] += p

    averaged_params = collections.OrderedDict()
    for k, v in params_dict.items():
        averaged_params[k] = v
        if averaged_params[k].is_floating_point():
            averaged_params[k].div_(num_models)
        else:
            averaged_params[k] //= num_models
    new_state["model"] = averaged_params
    return new_state


def last_n_checkpoints(paths, n, update_based, upper_bound=None):
    assert len(paths) == 1
    path = paths[0]
    if update_based:
        pt_regexp = re.compile(r"checkpoint_\d+_(\d+)\.pt")
    else:
        pt_regexp = re.compile(r"checkpoint(\d+)\.pt")
    files = PathManager.ls(path)

    entries = []
    for f in files:
        m = pt_regexp.fullmatch(f)
        if m is not None:
            sort_key = int(m.group(1))
            if upper_bound is None or sort_key <= upper_bound:
                entries.append((sort_key, m.group(0)))
    if len(entries) < n:
        raise Exception(
            "Found {} checkpoint files but need at least {}", len(entries), n
        )
    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]


def main():
    parser = argparse.ArgumentParser(
        description="Tool to average the params of input checkpoints to "
        "produce a new checkpoint",
    )
    # fmt: off
    parser.add_argument('--inputs', required=True, nargs='+',
                        help='Input checkpoint file paths.')
    parser.add_argument('--output', required=True, metavar='FILE',
                        help='Write the new checkpoint containing the averaged weights to this path.')
    num_group = parser.add_mutually_exclusive_group()
    num_group.add_argument('--num-epoch-checkpoints', type=int,
                           help='if set, will try to find checkpoints with names checkpoint_xx.pt in the '
                                'path specified by input, and average last this many of them.')
    num_group.add_argument('--num-update-checkpoints', type=int,
                           help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by'
                                ' input, and average last this many of them.')
    num_group.add_argument('--num-best-checkpoints', type=int, default=0,
                           help='if set, will try to find checkpoints with names checkpoint_best_ee_xx.pt in the path specified by'
                                ' input, and average last this many of them.')
    parser.add_argument('--checkpoint-upper-bound', type=int,
                        help='when using --num-epoch-checkpoints, this will set an upper bound on which epoch to use; '
                             'when using --num-update-checkpoints, this will set an upper bound on which update to use. '
                             'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be'
                             ' averaged. '
                             'e.g., with --num-update-checkpoints=10 --checkpoint-upper-bound=50000, checkpoints 40500-50000 would'
                             ' be averaged assuming --save-interval-updates 500'
                        )
    # fmt: on
    args = parser.parse_args()
    print(args)

    num = None
    is_update_based = False
    if args.num_update_checkpoints is not None:
        num = args.num_update_checkpoints
        is_update_based = True
    elif args.num_epoch_checkpoints is not None:
        num = args.num_epoch_checkpoints

    assert args.checkpoint_upper_bound is None or (
        args.num_epoch_checkpoints is not None
        or args.num_update_checkpoints is not None
    ), "--checkpoint-upper-bound requires --num-epoch-checkpoints or --num-update-checkpoints"
    assert (
        args.num_epoch_checkpoints is None or args.num_update_checkpoints is None
    ), "Cannot combine --num-epoch-checkpoints and --num-update-checkpoints"

    if num is not None:
        args.inputs = last_n_checkpoints(
            args.inputs,
            num,
            is_update_based,
            upper_bound=args.checkpoint_upper_bound,
        )
        print("averaging checkpoints: ", args.inputs)

    if args.num_best_checkpoints > 0:
        args.inputs = list(
            sorted(
                args.inputs,
                key=lambda x: float(
                    os.path.basename(x).split("_")[-1].replace(".pt", "")
                ),
            )
        )
        args.inputs = args.inputs[: args.num_best_checkpoints]
        for path in args.inputs:
            print(os.path.basename(path))
    new_state = average_checkpoints(args.inputs)
    with PathManager.open(args.output, "wb") as f:
        torch.save(new_state, f)
    print("Finished writing averaged checkpoint to {}".format(args.output))


if __name__ == "__main__":
    main()

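A toy demonstration of the averaging arithmetic in average_checkpoints(): floating-point parameters are summed and divided in place by the number of models, while integer tensors fall back to floor division.

import torch

states = [{"model": {"w": torch.tensor([1.0, 3.0])}},
          {"model": {"w": torch.tensor([3.0, 5.0])}}]
avg = states[0]["model"]["w"].clone()  # clone, as the script does
for state in states[1:]:
    avg += state["model"]["w"]
avg.div_(len(states))  # in-place divide, as in the script
assert torch.equal(avg, torch.tensor([2.0, 4.0]))
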
fairseq/scripts/build_sym_alignment.py
ADDED
@@ -0,0 +1,97 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Use this script in order to build symmetric alignments for your translation
dataset.
This script depends on fast_align and mosesdecoder tools. You will need to
build those before running the script.
fast_align:
    github: http://github.com/clab/fast_align
    instructions: follow the instructions in README.md
mosesdecoder:
    github: http://github.com/moses-smt/mosesdecoder
    instructions: http://www.statmt.org/moses/?n=Development.GetStarted
The script produces the following files under --output_dir:
    text.joined - concatenation of lines from the source_file and the
    target_file.
    align.forward - forward pass of fast_align.
    align.backward - backward pass of fast_align.
    aligned.sym_heuristic - symmetrized alignment.
"""

import argparse
import os
from itertools import zip_longest


def main():
    parser = argparse.ArgumentParser(description="symmetric alignment builder")
    # fmt: off
    parser.add_argument('--fast_align_dir',
                        help='path to fast_align build directory')
    parser.add_argument('--mosesdecoder_dir',
                        help='path to mosesdecoder root directory')
    parser.add_argument('--sym_heuristic',
                        help='heuristic to use for symmetrization',
                        default='grow-diag-final-and')
    parser.add_argument('--source_file',
                        help='path to a file with sentences '
                             'in the source language')
    parser.add_argument('--target_file',
                        help='path to a file with sentences '
                             'in the target language')
    parser.add_argument('--output_dir',
                        help='output directory')
    # fmt: on
    args = parser.parse_args()

    fast_align_bin = os.path.join(args.fast_align_dir, "fast_align")
    symal_bin = os.path.join(args.mosesdecoder_dir, "bin", "symal")
    sym_fast_align_bin = os.path.join(
        args.mosesdecoder_dir, "scripts", "ems", "support", "symmetrize-fast-align.perl"
    )

    # create joined file
    joined_file = os.path.join(args.output_dir, "text.joined")
    with open(args.source_file, "r", encoding="utf-8") as src, open(
        args.target_file, "r", encoding="utf-8"
    ) as tgt:
        with open(joined_file, "w", encoding="utf-8") as joined:
            for s, t in zip_longest(src, tgt):
                print("{} ||| {}".format(s.strip(), t.strip()), file=joined)

    # run forward alignment
    fwd_align_file = os.path.join(args.output_dir, "align.forward")
    fwd_fast_align_cmd = "{FASTALIGN} -i {JOINED} -d -o -v > {FWD}".format(
        FASTALIGN=fast_align_bin, JOINED=joined_file, FWD=fwd_align_file
    )
    assert os.system(fwd_fast_align_cmd) == 0

    # run backward alignment
    bwd_align_file = os.path.join(args.output_dir, "align.backward")
    bwd_fast_align_cmd = "{FASTALIGN} -i {JOINED} -d -o -v -r > {BWD}".format(
        FASTALIGN=fast_align_bin, JOINED=joined_file, BWD=bwd_align_file
    )
    assert os.system(bwd_fast_align_cmd) == 0

    # run symmetrization
    sym_out_file = os.path.join(args.output_dir, "aligned")
    sym_cmd = "{SYMFASTALIGN} {FWD} {BWD} {SRC} {TGT} {OUT} {HEURISTIC} {SYMAL}".format(
        SYMFASTALIGN=sym_fast_align_bin,
        FWD=fwd_align_file,
        BWD=bwd_align_file,
        SRC=args.source_file,
        TGT=args.target_file,
        OUT=sym_out_file,
        HEURISTIC=args.sym_heuristic,
        SYMAL=symal_bin,
    )
    assert os.system(sym_cmd) == 0


if __name__ == "__main__":
    main()

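The script shells out via os.system and asserts on the exit code. An equivalent, slightly safer pattern (a sketch, not part of the script) uses subprocess.run with check=True, which raises CalledProcessError on failure instead of relying on an assert:

import subprocess

def run_checked(cmd: str) -> None:
    # shell=True keeps the "> output_file" redirections working.
    subprocess.run(cmd, shell=True, check=True)

run_checked("echo 'fast_align would run here'")
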
fairseq/scripts/check_installation.py
ADDED
@@ -0,0 +1,36 @@
from pathlib import Path
import os

cwd = Path(".").resolve()
print("running 'check_installation.py' from:", cwd)

# Old versions of numpy/torch can prevent loading the .so files
import torch

print("torch:", torch.__version__)
import numpy

print("numpy:", numpy.__version__)

import fairseq

print("Fairseq installed at:", fairseq.__file__)
import fairseq.criterions
import fairseq.dataclass.configs

import _imp

print("Should load following .so suffixes:", _imp.extension_suffixes())

so_files = list(Path(fairseq.__file__).parent.glob("*.so"))
so_files.extend(Path(fairseq.__file__).parent.glob("data/*.so"))
print("Found following .so files:")
for so_file in so_files:
    print(f"- {so_file}")

from fairseq import libbleu

print("Found libbleu at", libbleu.__file__)
from fairseq.data import data_utils_fast

print("Found data_utils_fast at", data_utils_fast.__file__)

fairseq/scripts/compare_namespaces.py
ADDED
@@ -0,0 +1,46 @@
#!/usr/bin/env python
"""Helper script to compare two argparse.Namespace objects."""

from argparse import Namespace  # noqa


def main():

    ns1 = eval(input("Namespace 1: "))
    ns2 = eval(input("Namespace 2: "))

    def keys(ns):
        ks = set()
        for k in dir(ns):
            if not k.startswith("_"):
                ks.add(k)
        return ks

    k1 = keys(ns1)
    k2 = keys(ns2)

    def print_keys(ks, ns1, ns2=None):
        for k in ks:
            if ns2 is None:
                print("{}\t{}".format(k, getattr(ns1, k, None)))
            else:
                print(
                    "{}\t{}\t{}".format(k, getattr(ns1, k, None), getattr(ns2, k, None))
                )

    print("Keys unique to namespace 1:")
    print_keys(k1 - k2, ns1)
    print()

    print("Keys unique to namespace 2:")
    print_keys(k2 - k1, ns2)
    print()

    print("Overlapping keys with different values:")
    ks = [k for k in k1 & k2 if getattr(ns1, k, "None") != getattr(ns2, k, "None")]
    print_keys(ks, ns1, ns2)
    print()


if __name__ == "__main__":
    main()

fairseq/scripts/compound_split_bleu.sh
ADDED
@@ -0,0 +1,20 @@
#!/bin/bash

if [ $# -ne 1 ]; then
    echo "usage: $0 GENERATE_PY_OUTPUT"
    exit 1
fi

GEN=$1

SYS=$GEN.sys
REF=$GEN.ref

if [ $(tail -n 1 $GEN | grep BLEU | wc -l) -ne 1 ]; then
    echo "not done generating"
    exit
fi

grep ^H $GEN | awk -F '\t' '{print $NF}' | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $SYS
grep ^T $GEN | cut -f2- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $REF
fairseq-score --sys $SYS --ref $REF

fairseq/scripts/constraints/extract.py
ADDED
@@ -0,0 +1,90 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Extracts random constraints from reference files."""

import argparse
import random
import sys


def get_phrase(words, index, length):
    assert index < len(words) - length + 1
    phr = " ".join(words[index : index + length])
    for i in range(index, index + length):
        words.pop(index)
    return phr


def main(args):

    if args.seed:
        random.seed(args.seed)

    for line in sys.stdin:
        constraints = []

        def add_constraint(constraint):
            constraints.append(constraint)

        source = line.rstrip()
        if "\t" in line:
            source, target = line.split("\t")
            if args.add_sos:
                target = f"<s> {target}"
            if args.add_eos:
                target = f"{target} </s>"

            if len(target.split()) >= args.len:
                words = [target]

                num = args.number

                choices = {}
                for i in range(num):
                    if len(words) == 0:
                        break
                    segmentno = random.choice(range(len(words)))
                    segment = words.pop(segmentno)
                    tokens = segment.split()
                    phrase_index = random.choice(range(len(tokens)))
                    choice = " ".join(
                        tokens[phrase_index : min(len(tokens), phrase_index + args.len)]
                    )
                    for j in range(
                        phrase_index, min(len(tokens), phrase_index + args.len)
                    ):
                        tokens.pop(phrase_index)
                    if phrase_index > 0:
                        words.append(" ".join(tokens[0:phrase_index]))
                    if phrase_index + 1 < len(tokens):
                        words.append(" ".join(tokens[phrase_index:]))
                    choices[target.find(choice)] = choice

                    # mask out with spaces
                    target = target.replace(choice, " " * len(choice), 1)

                for key in sorted(choices.keys()):
                    add_constraint(choices[key])

        print(source, *constraints, sep="\t")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--number", "-n", type=int, default=1, help="number of phrases")
    parser.add_argument("--len", "-l", type=int, default=1, help="phrase length")
    parser.add_argument(
        "--add-sos", default=False, action="store_true", help="add <s> token"
    )
    parser.add_argument(
        "--add-eos", default=False, action="store_true", help="add </s> token"
    )
    parser.add_argument("--seed", "-s", default=0, type=int)
    args = parser.parse_args()

    main(args)

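A quick check of the get_phrase() helper's contract: it returns the selected phrase and removes those tokens from the word list in place (reproduced standalone here for illustration):

def get_phrase(words, index, length):
    assert index < len(words) - length + 1
    phr = " ".join(words[index : index + length])
    for _ in range(length):
        words.pop(index)
    return phr

ws = "a b c d".split()
assert get_phrase(ws, 1, 2) == "b c"
assert ws == ["a", "d"]
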
fairseq/scripts/constraints/validate.py
ADDED
@@ -0,0 +1,34 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import sys


"""Reads in a fairseq output file, and verifies that the constraints
(C- lines) are present in the output (the first H- line). Assumes that
constraints are listed prior to the first hypothesis.
"""

constraints = []
found = 0
total = 0
for line in sys.stdin:
    if line.startswith("C-"):
        constraints.append(line.rstrip().split("\t")[1])
    elif line.startswith("H-"):
        text = line.split("\t")[2]

        for constraint in constraints:
            total += 1
            if constraint in text:
                found += 1
            else:
                print(f"No {constraint} in {text}", file=sys.stderr)

        constraints = []

print(f"Found {found} / {total} = {100 * found / total:.1f}%")

fairseq/scripts/convert_dictionary.lua
ADDED
@@ -0,0 +1,34 @@
-- Copyright (c) Facebook, Inc. and its affiliates.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree.
--
-- Usage: convert_dictionary.lua <dict.th7>
require 'fairseq'
require 'torch'
require 'paths'

if #arg < 1 then
  print('usage: convert_dictionary.lua <dict.th7>')
  os.exit(1)
end
if not paths.filep(arg[1]) then
  print('error: file does not exist: ' .. arg[1])
  os.exit(1)
end

dict = torch.load(arg[1])
dst = paths.basename(arg[1]):gsub('.th7', '.txt')
assert(dst:match('.txt$'))

f = io.open(dst, 'w')
for idx, symbol in ipairs(dict.index_to_symbol) do
  if idx > dict.cutoff then
    break
  end
  f:write(symbol)
  f:write(' ')
  f:write(dict.index_to_freq[idx])
  f:write('\n')
end
f:close()

fairseq/scripts/convert_model.lua
ADDED
@@ -0,0 +1,108 @@
-- Copyright (c) Facebook, Inc. and its affiliates.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree.
--
-- Usage: convert_model.lua <model_epoch1.th7>
require 'torch'
local fairseq = require 'fairseq'

model = torch.load(arg[1])

function find_weight_norm(container, module)
  for _, wn in ipairs(container:listModules()) do
    if torch.type(wn) == 'nn.WeightNorm' and wn.modules[1] == module then
      return wn
    end
  end
end

function push_state(dict, key, module)
  if torch.type(module) == 'nn.Linear' then
    local wn = find_weight_norm(model.module, module)
    assert(wn)
    dict[key .. '.weight_v'] = wn.v:float()
    dict[key .. '.weight_g'] = wn.g:float()
  elseif torch.type(module) == 'nn.TemporalConvolutionTBC' then
    local wn = find_weight_norm(model.module, module)
    assert(wn)
    local v = wn.v:float():view(wn.viewOut):transpose(2, 3)
    dict[key .. '.weight_v'] = v
    dict[key .. '.weight_g'] = wn.g:float():view(module.weight:size(3), 1, 1)
  else
    dict[key .. '.weight'] = module.weight:float()
  end
  if module.bias then
    dict[key .. '.bias'] = module.bias:float()
  end
end

encoder_dict = {}
decoder_dict = {}
combined_dict = {}

function encoder_state(encoder)
  luts = encoder:findModules('nn.LookupTable')
  push_state(encoder_dict, 'embed_tokens', luts[1])
  push_state(encoder_dict, 'embed_positions', luts[2])

  fcs = encoder:findModules('nn.Linear')
  assert(#fcs >= 2)
  local nInputPlane = fcs[1].weight:size(1)
  push_state(encoder_dict, 'fc1', table.remove(fcs, 1))
  push_state(encoder_dict, 'fc2', table.remove(fcs, #fcs))

  for i, module in ipairs(encoder:findModules('nn.TemporalConvolutionTBC')) do
    push_state(encoder_dict, 'convolutions.' .. tostring(i - 1), module)
    if nInputPlane ~= module.weight:size(3) / 2 then
      push_state(encoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1))
    end
    nInputPlane = module.weight:size(3) / 2
  end
  assert(#fcs == 0)
end

function decoder_state(decoder)
  luts = decoder:findModules('nn.LookupTable')
  push_state(decoder_dict, 'embed_tokens', luts[1])
  push_state(decoder_dict, 'embed_positions', luts[2])

  fcs = decoder:findModules('nn.Linear')
  local nInputPlane = fcs[1].weight:size(1)
  push_state(decoder_dict, 'fc1', table.remove(fcs, 1))
  push_state(decoder_dict, 'fc2', fcs[#fcs - 1])
  push_state(decoder_dict, 'fc3', fcs[#fcs])

  table.remove(fcs, #fcs)
  table.remove(fcs, #fcs)

  for i, module in ipairs(decoder:findModules('nn.TemporalConvolutionTBC')) do
    if nInputPlane ~= module.weight:size(3) / 2 then
      push_state(decoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1))
    end
    nInputPlane = module.weight:size(3) / 2

    local prefix = 'attention.' .. tostring(i - 1)
    push_state(decoder_dict, prefix .. '.in_projection', table.remove(fcs, 1))
    push_state(decoder_dict, prefix .. '.out_projection', table.remove(fcs, 1))
    push_state(decoder_dict, 'convolutions.' .. tostring(i - 1), module)
  end
  assert(#fcs == 0)
end


_encoder = model.module.modules[2]
_decoder = model.module.modules[3]

encoder_state(_encoder)
decoder_state(_decoder)

for k, v in pairs(encoder_dict) do
  combined_dict['encoder.' .. k] = v
end
for k, v in pairs(decoder_dict) do
  combined_dict['decoder.' .. k] = v
end


torch.save('state_dict.t7', combined_dict)

fairseq/scripts/count_docs.py
ADDED
@@ -0,0 +1,58 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Count the number of documents and average number of lines and tokens per
document in a large file. Documents should be separated by a single empty line.
"""

import argparse
import gzip
import sys

import numpy as np


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("input")
    parser.add_argument("--gzip", action="store_true")
    args = parser.parse_args()

    def gopen():
        if args.gzip:
            return gzip.open(args.input, "r")
        else:
            return open(args.input, "r", encoding="utf-8")

    num_lines = []
    num_toks = []
    with gopen() as h:
        num_docs = 1
        num_lines_in_doc = 0
        num_toks_in_doc = 0
        for i, line in enumerate(h):
            if len(line.strip()) == 0:  # empty line indicates new document
                num_docs += 1
                num_lines.append(num_lines_in_doc)
                num_toks.append(num_toks_in_doc)
                num_lines_in_doc = 0
                num_toks_in_doc = 0
            else:
                num_lines_in_doc += 1
                num_toks_in_doc += len(line.rstrip().split())
            if i % 1000000 == 0:
                print(i, file=sys.stderr, end="", flush=True)
            elif i % 100000 == 0:
                print(".", file=sys.stderr, end="", flush=True)
        print(file=sys.stderr, flush=True)

    print("found {} docs".format(num_docs))
    print("average num lines per doc: {}".format(np.mean(num_lines)))
    print("average num toks per doc: {}".format(np.mean(num_toks)))


if __name__ == "__main__":
    main()

fairseq/scripts/read_binarized.py
ADDED
@@ -0,0 +1,48 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse

from fairseq.data import Dictionary, data_utils, indexed_dataset


def get_parser():
    parser = argparse.ArgumentParser(
        description="writes text from binarized file to stdout"
    )
    # fmt: off
    parser.add_argument('--dataset-impl', help='dataset implementation',
                        choices=indexed_dataset.get_available_dataset_impl())
    parser.add_argument('--dict', metavar='FP', help='dictionary containing known words', default=None)
    parser.add_argument('--input', metavar='FP', required=True, help='binarized file to read')
    # fmt: on

    return parser


def main():
    parser = get_parser()
    args = parser.parse_args()

    dictionary = Dictionary.load(args.dict) if args.dict is not None else None
    dataset = data_utils.load_indexed_dataset(
        args.input,
        dictionary,
        dataset_impl=args.dataset_impl,
        default="lazy",
    )

    for tensor_line in dataset:
        if dictionary is None:
            line = " ".join([str(int(x)) for x in tensor_line])
        else:
            line = dictionary.string(tensor_line)

        print(line)


if __name__ == "__main__":
    main()

fairseq/scripts/rm_pt.py
ADDED
@@ -0,0 +1,141 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
import re
import shutil
import sys


pt_regexp = re.compile(r"checkpoint(\d+|_\d+_\d+|_[a-z]+)\.pt")
pt_regexp_epoch_based = re.compile(r"checkpoint(\d+)\.pt")
pt_regexp_update_based = re.compile(r"checkpoint_\d+_(\d+)\.pt")


def parse_checkpoints(files):
    entries = []
    for f in files:
        m = pt_regexp_epoch_based.fullmatch(f)
        if m is not None:
            entries.append((int(m.group(1)), m.group(0)))
        else:
            m = pt_regexp_update_based.fullmatch(f)
            if m is not None:
                entries.append((int(m.group(1)), m.group(0)))
    return entries


def last_n_checkpoints(files, n):
    entries = parse_checkpoints(files)
    return [x[1] for x in sorted(entries, reverse=True)[:n]]


def every_n_checkpoints(files, n):
    entries = parse_checkpoints(files)
    return [x[1] for x in sorted(sorted(entries)[::-n])]


def main():
    parser = argparse.ArgumentParser(
        description=(
            "Recursively delete checkpoint files from `root_dir`, "
            "but preserve checkpoint_best.pt and checkpoint_last.pt"
        )
    )
    parser.add_argument("root_dirs", nargs="*")
    parser.add_argument(
        "--save-last", type=int, default=0, help="number of last checkpoints to save"
    )
    parser.add_argument(
        "--save-every", type=int, default=0, help="interval of checkpoints to save"
    )
    parser.add_argument(
        "--preserve-test",
        action="store_true",
        help="preserve checkpoints in dirs that start with test_ prefix (default: delete them)",
    )
    parser.add_argument(
        "--delete-best", action="store_true", help="delete checkpoint_best.pt"
    )
    parser.add_argument(
        "--delete-last", action="store_true", help="delete checkpoint_last.pt"
    )
    parser.add_argument(
        "--no-dereference", action="store_true", help="don't dereference symlinks"
    )
    args = parser.parse_args()

    files_to_desymlink = []
    files_to_preserve = []
    files_to_delete = []
    for root_dir in args.root_dirs:
        for root, _subdirs, files in os.walk(root_dir):
            if args.save_last > 0:
                to_save = last_n_checkpoints(files, args.save_last)
            else:
                to_save = []
            if args.save_every > 0:
                to_save += every_n_checkpoints(files, args.save_every)
            for file in files:
                if not pt_regexp.fullmatch(file):
                    continue
                full_path = os.path.join(root, file)
                if (
                    not os.path.basename(root).startswith("test_") or args.preserve_test
                ) and (
                    (file == "checkpoint_last.pt" and not args.delete_last)
                    or (file == "checkpoint_best.pt" and not args.delete_best)
                    or file in to_save
                ):
                    if os.path.islink(full_path) and not args.no_dereference:
                        files_to_desymlink.append(full_path)
                    else:
                        files_to_preserve.append(full_path)
                else:
                    files_to_delete.append(full_path)

    if len(files_to_desymlink) == 0 and len(files_to_delete) == 0:
        print("Nothing to do.")
        sys.exit(0)

    files_to_desymlink = sorted(files_to_desymlink)
    files_to_preserve = sorted(files_to_preserve)
    files_to_delete = sorted(files_to_delete)

    print("Operations to perform (in order):")
    if len(files_to_desymlink) > 0:
        for file in files_to_desymlink:
            print(" - preserve (and dereference symlink): " + file)
    if len(files_to_preserve) > 0:
        for file in files_to_preserve:
            print(" - preserve: " + file)
    if len(files_to_delete) > 0:
        for file in files_to_delete:
            print(" - delete: " + file)
    while True:
        resp = input("Continue? (Y/N): ")
        if resp.strip().lower() == "y":
            break
        elif resp.strip().lower() == "n":
            sys.exit(0)

    print("Executing...")
    if len(files_to_desymlink) > 0:
        for file in files_to_desymlink:
            realpath = os.path.realpath(file)
            print("rm " + file)
            os.remove(file)
            print("cp {} {}".format(realpath, file))
            shutil.copyfile(realpath, file)
    if len(files_to_delete) > 0:
        for file in files_to_delete:
            print("rm " + file)
            os.remove(file)


if __name__ == "__main__":
    main()

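The three regexes at the top distinguish epoch-based, update-based, and named checkpoints. A quick demonstration of what the catch-all pattern matches:

import re

pt_regexp = re.compile(r"checkpoint(\d+|_\d+_\d+|_[a-z]+)\.pt")

assert pt_regexp.fullmatch("checkpoint12.pt")        # epoch-based
assert pt_regexp.fullmatch("checkpoint_3_45000.pt")  # update-based
assert pt_regexp.fullmatch("checkpoint_best.pt")     # named
assert pt_regexp.fullmatch("model.pt") is None
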
fairseq/scripts/sacrebleu.sh
ADDED
@@ -0,0 +1,27 @@
#!/bin/bash

if [ $# -ne 4 ]; then
    echo "usage: $0 TESTSET SRCLANG TGTLANG GEN"
    exit 1
fi

TESTSET=$1
SRCLANG=$2
TGTLANG=$3

GEN=$4

if ! command -v sacremoses &> /dev/null
then
    echo "sacremoses could not be found, please install with: pip install sacremoses"
    exit
fi

grep ^H $GEN \
| sed 's/^H\-//' \
| sort -n -k 1 \
| cut -f 3 \
| sacremoses detokenize \
> $GEN.sorted.detok

sacrebleu --test-set $TESTSET --language-pair "${SRCLANG}-${TGTLANG}" < $GEN.sorted.detok

fairseq/scripts/shard_docs.py
ADDED
@@ -0,0 +1,54 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Split a large file into shards while respecting document boundaries. Documents
should be separated by a single empty line.
"""

import argparse
import contextlib


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("input")
    parser.add_argument("--num-shards", type=int)
    args = parser.parse_args()

    assert args.num_shards is not None and args.num_shards > 1

    with open(args.input, "r", encoding="utf-8") as h:
        with contextlib.ExitStack() as stack:
            outputs = [
                stack.enter_context(
                    open(args.input + ".shard" + str(i), "w", encoding="utf-8")
                )
                for i in range(args.num_shards)
            ]

            doc = []
            first_doc = [True] * args.num_shards

            def output_doc(i):
                if not first_doc[i]:
                    outputs[i].write("\n")
                first_doc[i] = False
                for line in doc:
                    outputs[i].write(line)
                doc.clear()

            num_docs = 0
            for line in h:
                if line.strip() == "":  # empty line indicates new document
                    output_doc(num_docs % args.num_shards)
                    num_docs += 1
                else:
                    doc.append(line)
            output_doc(num_docs % args.num_shards)


if __name__ == "__main__":
    main()

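Documents are assigned round-robin: document i goes to shard i % num_shards. A minimal in-memory illustration of the same assignment:

docs = ["doc0", "doc1", "doc2", "doc3", "doc4"]
num_shards = 2
shards = [[] for _ in range(num_shards)]
for i, doc in enumerate(docs):
    shards[i % num_shards].append(doc)
assert shards == [["doc0", "doc2", "doc4"], ["doc1", "doc3"]]
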
fairseq/scripts/split_train_valid_docs.py
ADDED
@@ -0,0 +1,86 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the MIT license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
Split a large file into a train and valid set while respecting document
|
8 |
+
boundaries. Documents should be separated by a single empty line.
|
9 |
+
"""
|
10 |
+
|
11 |
+
import argparse
|
12 |
+
import random
|
13 |
+
import sys
|
14 |
+
|
15 |
+
|
16 |
+
def main():
|
17 |
+
parser = argparse.ArgumentParser()
|
18 |
+
parser.add_argument("input")
|
19 |
+
parser.add_argument("sample_output", help="train output file")
|
20 |
+
parser.add_argument("remainder_output", help="valid output file")
|
21 |
+
parser.add_argument("-k", type=int, help="remainder size")
|
22 |
+
parser.add_argument(
|
23 |
+
"--lines", action="store_true", help="split lines instead of docs"
|
24 |
+
)
|
25 |
+
args = parser.parse_args()
|
26 |
+
|
27 |
+
assert args.k is not None
|
28 |
+
|
29 |
+
sample = []
|
30 |
+
remainder = []
|
31 |
+
num_docs = [0]
|
32 |
+
|
33 |
+
def update_sample(doc):
|
34 |
+
if len(sample) < args.k:
|
35 |
+
sample.append(doc.copy())
|
36 |
+
else:
|
37 |
+
i = num_docs[0]
|
38 |
+
j = random.randrange(i + 1)
|
39 |
+
if j < args.k:
|
40 |
+
remainder.append(sample[j])
|
41 |
+
sample[j] = doc.copy()
|
42 |
+
else:
|
43 |
+
remainder.append(doc.copy())
|
44 |
+
num_docs[0] += 1
|
45 |
+
doc.clear()
|
46 |
+
|
47 |
+
with open(args.input, "r", encoding="utf-8") as h:
|
48 |
+
doc = []
|
49 |
+
for i, line in enumerate(h):
|
50 |
+
if line.strip() == "": # empty line indicates new document
|
51 |
+
update_sample(doc)
|
52 |
+
else:
|
53 |
+
doc.append(line)
|
54 |
+
if args.lines:
|
55 |
+
update_sample(doc)
|
56 |
+
if i % 1000000 == 0:
|
57 |
+
print(i, file=sys.stderr, end="", flush=True)
|
58 |
+
elif i % 100000 == 0:
|
59 |
+
print(".", file=sys.stderr, end="", flush=True)
|
60 |
+
if len(doc) > 0:
|
61 |
+
update_sample(doc)
|
62 |
+
print(file=sys.stderr, flush=True)
|
63 |
+
|
64 |
+
assert len(sample) == args.k
|
65 |
+
|
66 |
+
with open(args.sample_output, "w", encoding="utf-8") as out:
|
67 |
+
first = True
|
68 |
+
for doc in sample:
|
69 |
+
if not first and not args.lines:
|
70 |
+
out.write("\n")
|
71 |
+
first = False
|
72 |
+
for line in doc:
|
73 |
+
out.write(line)
|
74 |
+
|
75 |
+
with open(args.remainder_output, "w", encoding="utf-8") as out:
|
76 |
+
first = True
|
77 |
+
for doc in remainder:
|
78 |
+
if not first and not args.lines:
|
79 |
+
out.write("\n")
|
80 |
+
first = False
|
81 |
+
for line in doc:
|
82 |
+
out.write(line)
|
83 |
+
|
84 |
+
|
85 |
+
if __name__ == "__main__":
|
86 |
+
main()
|
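update_sample is reservoir sampling (Algorithm R): it keeps a uniform random sample of exactly k documents in `sample` and spills every displaced or rejected document into `remainder`, so the split happens in one pass without knowing the document count in advance. A minimal sketch of the core algorithm:

import random

def reservoir_sample(stream, k):
    # Uniform random sample of k items from a stream of unknown length.
    sample = []
    for n, item in enumerate(stream):
        if len(sample) < k:
            sample.append(item)  # fill the reservoir first
        else:
            j = random.randrange(n + 1)  # 0 <= j <= n
            if j < k:
                sample[j] = item  # replace with probability k / (n + 1)
    return sample

print(reservoir_sample(range(1_000_000), 5))
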
fairseq/scripts/spm_decode.py
ADDED
@@ -0,0 +1,53 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function, unicode_literals

import argparse

import sentencepiece as spm


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", required=True, help="sentencepiece model to use for decoding"
    )
    parser.add_argument("--input", required=True, help="input file to decode")
    parser.add_argument("--input_format", choices=["piece", "id"], default="piece")
    args = parser.parse_args()

    sp = spm.SentencePieceProcessor()
    sp.Load(args.model)

    if args.input_format == "piece":

        def decode(input):
            return "".join(sp.DecodePieces(input))

    elif args.input_format == "id":

        def decode(input):
            return "".join(sp.DecodeIds(input))

    else:
        raise NotImplementedError

    def tok2int(tok):
        # remap reference-side <unk> (represented as <<unk>>) to 0
        return int(tok) if tok != "<<unk>>" else 0

    with open(args.input, "r", encoding="utf-8") as h:
        for line in h:
            if args.input_format == "id":
                print(decode(list(map(tok2int, line.rstrip().split()))))
            elif args.input_format == "piece":
                print(decode(line.rstrip().split()))


if __name__ == "__main__":
    main()

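DecodePieces and DecodeIds invert sentencepiece's two output formats; with recent sentencepiece versions both already return a single detokenized string, so the "".join above is effectively a pass-through. A minimal round-trip sketch, assuming a trained model file m.model (hypothetical path) is available:

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("m.model")  # hypothetical: any trained sentencepiece model

pieces = sp.EncodeAsPieces("Hello world")  # e.g. ['▁Hello', '▁world']
ids = sp.EncodeAsIds("Hello world")        # e.g. [151, 88]
print(sp.DecodePieces(pieces))             # 'Hello world'
print(sp.DecodeIds(ids))                   # 'Hello world'
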
fairseq/scripts/spm_encode.py
ADDED
@@ -0,0 +1,119 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import contextlib
import sys

import sentencepiece as spm


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", required=True, help="sentencepiece model to use for encoding"
    )
    parser.add_argument(
        "--inputs", nargs="+", default=["-"], help="input files to filter/encode"
    )
    parser.add_argument(
        "--outputs", nargs="+", default=["-"], help="path to save encoded outputs"
    )
    parser.add_argument("--output_format", choices=["piece", "id"], default="piece")
    parser.add_argument(
        "--min-len",
        type=int,
        metavar="N",
        help="filter sentence pairs with fewer than N tokens",
    )
    parser.add_argument(
        "--max-len",
        type=int,
        metavar="N",
        help="filter sentence pairs with more than N tokens",
    )
    args = parser.parse_args()

    assert len(args.inputs) == len(
        args.outputs
    ), "number of input and output paths should match"

    sp = spm.SentencePieceProcessor()
    sp.Load(args.model)

    if args.output_format == "piece":

        def encode(input):
            return sp.EncodeAsPieces(input)

    elif args.output_format == "id":

        def encode(input):
            return list(map(str, sp.EncodeAsIds(input)))

    else:
        raise NotImplementedError

    if args.min_len is not None or args.max_len is not None:

        def valid(line):
            return (args.min_len is None or len(line) >= args.min_len) and (
                args.max_len is None or len(line) <= args.max_len
            )

    else:

        def valid(lines):
            return True

    with contextlib.ExitStack() as stack:
        inputs = [
            stack.enter_context(open(input, "r", encoding="utf-8"))
            if input != "-"
            else sys.stdin
            for input in args.inputs
        ]
        outputs = [
            stack.enter_context(open(output, "w", encoding="utf-8"))
            if output != "-"
            else sys.stdout
            for output in args.outputs
        ]

        stats = {
            "num_empty": 0,
            "num_filtered": 0,
        }

        def encode_line(line):
            line = line.strip()
            if len(line) > 0:
                line = encode(line)
                if valid(line):
                    return line
                else:
                    stats["num_filtered"] += 1
            else:
                stats["num_empty"] += 1
            return None

        for i, lines in enumerate(zip(*inputs), start=1):
            enc_lines = list(map(encode_line, lines))
            if not any(enc_line is None for enc_line in enc_lines):
                for enc_line, output_h in zip(enc_lines, outputs):
                    print(" ".join(enc_line), file=output_h)
            if i % 10000 == 0:
                print("processed {} lines".format(i), file=sys.stderr)

        print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr)
        print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr)


if __name__ == "__main__":
    main()

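Because zip(*inputs) reads all parallel files in lockstep and a tuple is written only when every side survives encode_line, dropping an empty or out-of-range line on one side drops the whole pair, which keeps parallel corpora aligned. A minimal sketch of that lockstep filtering with toy data:

src = ["a b c", "", "x y"]
tgt = ["A B", "Q", "X Y Z"]

kept = [
    (s, t)
    for s, t in zip(src, tgt)
    if s.strip() and t.strip()  # drop the pair if either side is empty
]
print(kept)  # [('a b c', 'A B'), ('x y', 'X Y Z')]
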
fairseq/scripts/spm_train.py
ADDED
@@ -0,0 +1,16 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function, unicode_literals

import sys

import sentencepiece as spm


if __name__ == "__main__":
    spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:]))

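The script simply forwards its command line to sentencepiece's trainer, so any native sentencepiece flag can be passed through unchanged. An illustrative equivalent call from Python, assuming a corpus.txt file (hypothetical) and standard sentencepiece options:

import sentencepiece as spm

# Same effect as:
#   python spm_train.py --input=corpus.txt --model_prefix=m --vocab_size=8000 --model_type=unigram
spm.SentencePieceTrainer.Train(
    "--input=corpus.txt --model_prefix=m --vocab_size=8000 --model_type=unigram"
)
# Writes m.model and m.vocab to the working directory.
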
fairseq/scripts/test_fsdp.sh
ADDED
@@ -0,0 +1,24 @@
#!/usr/bin/env bash
rm -rf fsdp_dummy
mkdir -p fsdp_dummy
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train /private/home/sshleifer/data-bin/stories_mmap \
    --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \
    --cpu-offload --checkpoint-activations \
    --task language_modeling --tokens-per-sample 256 --batch-size 8 \
    --arch transformer_lm_gpt2_tiny \
    --optimizer cpu_adam --adam-betas "(0.9,0.98)" \
    --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \
    --max-update 5 --log-format json --log-interval 1 \
    --save-interval-updates 5 --save-dir fsdp_dummy --disable-validation \
    --restore-file x.pt "$@"

# Now we try to load the checkpoint
CUDA_VISIBLE_DEVICES=0,1 fairseq-train /private/home/sshleifer/data-bin/stories_mmap \
    --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \
    --cpu-offload --checkpoint-activations \
    --task language_modeling --tokens-per-sample 256 --batch-size 8 \
    --arch transformer_lm_gpt2_tiny \
    --optimizer cpu_adam --adam-betas "(0.9,0.98)" \
    --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \
    --max-update 2 --log-format json --log-interval 1 \
    --save-interval-updates 2 --save-dir fsdp_dummy

fairseq/tests/__init__.py
ADDED
File without changes
fairseq/tests/tasks/test_masked_lm.py
ADDED
@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import unittest
from tempfile import TemporaryDirectory

from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
from fairseq.tasks.masked_lm import MaskedLMConfig, MaskedLMTask
from tests.utils import build_vocab, make_data


class TestMaskedLM(unittest.TestCase):
    def test_masks_tokens(self):
        with TemporaryDirectory() as dirname:

            # prep input file
            raw_file = os.path.join(dirname, "raw")
            data = make_data(out_file=raw_file)
            vocab = build_vocab(data)

            # binarize
            binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False)
            split = "train"
            bin_file = os.path.join(dirname, split)
            FileBinarizer.multiprocess_dataset(
                input_file=raw_file,
                binarizer=binarizer,
                dataset_impl="mmap",
                vocab_size=len(vocab),
                output_prefix=bin_file,
            )

            # setup task
            cfg = MaskedLMConfig(
                data=dirname,
                seed=42,
                mask_prob=0.5,  # increasing the odds of masking
                random_token_prob=0,  # avoiding random tokens for exact match
                leave_unmasked_prob=0,  # always masking for exact match
            )
            task = MaskedLMTask(cfg, binarizer.dict)

            original_dataset = task._load_dataset_split(bin_file, 1, False)

            # load datasets
            task.load_dataset(split)
            masked_dataset = task.dataset(split)

            mask_index = task.source_dictionary.index("<mask>")
            iterator = task.get_batch_iterator(
                dataset=masked_dataset,
                max_tokens=65_536,
                max_positions=4_096,
            ).next_epoch_itr(shuffle=False)
            for batch in iterator:
                for sample in range(len(batch)):
                    net_input = batch["net_input"]
                    masked_src_tokens = net_input["src_tokens"][sample]
                    masked_src_length = net_input["src_lengths"][sample]
                    masked_tgt_tokens = batch["target"][sample]

                    sample_id = batch["id"][sample]
                    original_tokens = original_dataset[sample_id]
                    original_tokens = original_tokens.masked_select(
                        masked_src_tokens[:masked_src_length] == mask_index
                    )
                    masked_tokens = masked_tgt_tokens.masked_select(
                        masked_tgt_tokens != task.source_dictionary.pad()
                    )

                    assert masked_tokens.equal(original_tokens)


if __name__ == "__main__":
    unittest.main()

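The equality check rests on torch.Tensor.masked_select, which gathers the elements at positions where a boolean mask is True: the test selects the original tokens at <mask> positions and the non-pad targets, and the two must coincide when random_token_prob and leave_unmasked_prob are zero. A minimal sketch of the primitive:

import torch

tokens = torch.tensor([10, 11, 12, 13])
mask = torch.tensor([True, False, True, False])
print(tokens.masked_select(mask))  # tensor([10, 12])
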
fairseq/tests/tasks/test_span_masked_lm.py
ADDED
@@ -0,0 +1,106 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import unittest
from tempfile import TemporaryDirectory

from fairseq import options
from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.tasks.span_masked_lm import SpanMaskedLMTask
from tests.utils import build_vocab, make_data


class TestSpanMaskedLM(unittest.TestCase):
    def test_masks_token_spans(self):
        with TemporaryDirectory() as dirname:

            # prep input file
            raw_file = os.path.join(dirname, "raw")
            data = make_data(out_file=raw_file)
            vocab = build_vocab(data)

            # binarize
            binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False)
            split = "train"
            bin_file = os.path.join(dirname, split)
            dataset_impl = "mmap"

            FileBinarizer.multiprocess_dataset(
                input_file=raw_file,
                binarizer=binarizer,
                dataset_impl=dataset_impl,
                vocab_size=len(vocab),
                output_prefix=bin_file,
            )

            # adding sentinel tokens
            for i in range(100):
                vocab.add_symbol(f"<extra_id_{i}>")

            # setup task
            train_args = options.parse_args_and_arch(
                options.get_training_parser(),
                [
                    "--task",
                    "span_masked_lm",
                    "--arch",
                    "bart_base",
                    "--seed",
                    "42",
                    dirname,
                ],
            )
            cfg = convert_namespace_to_omegaconf(train_args)
            task = SpanMaskedLMTask(cfg.task, binarizer.dict)

            # load datasets
            original_dataset = task._load_dataset_split(bin_file, 1, False)
            task.load_dataset(split)
            masked_dataset = task.dataset(split)

            iterator = task.get_batch_iterator(
                dataset=masked_dataset,
                max_tokens=65_536,
                max_positions=4_096,
            ).next_epoch_itr(shuffle=False)
            num_tokens = len(vocab)
            for batch in iterator:
                for sample in range(len(batch)):
                    sample_id = batch["id"][sample]
                    original_tokens = original_dataset[sample_id]
                    masked_src_tokens = batch["net_input"]["src_tokens"][sample]
                    masked_src_length = batch["net_input"]["src_lengths"][sample]
                    masked_tgt_tokens = batch["target"][sample]

                    original_offset = 0
                    masked_tgt_offset = 0
                    extra_id_token = len(vocab) - 1
                    for masked_src_token in masked_src_tokens[:masked_src_length]:
                        if masked_src_token == extra_id_token:
                            assert (
                                masked_src_token == masked_tgt_tokens[masked_tgt_offset]
                            )
                            extra_id_token -= 1
                            masked_tgt_offset += 1
                            while (
                                original_offset < len(original_tokens)
                                and masked_tgt_tokens[masked_tgt_offset]
                                != extra_id_token
                            ):
                                assert (
                                    original_tokens[original_offset]
                                    == masked_tgt_tokens[masked_tgt_offset]
                                )
                                original_offset += 1
                                masked_tgt_offset += 1
                        else:
                            assert original_tokens[original_offset] == masked_src_token
                            original_offset += 1


if __name__ == "__main__":
    unittest.main()

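The nested walk verifies a T5-style span-corruption format: each masked span in the source is collapsed to a single sentinel token, consumed in a fixed order, and the target lists each sentinel followed by the tokens it replaced. A minimal sketch of reconstructing the original sequence from such a pair, using toy string sentinels rather than fairseq's real vocabulary ids:

SENTINELS = {"<s0>", "<s1>"}
src = ["A", "B", "<s0>", "E"]     # span C D replaced by <s0>
tgt = ["<s0>", "C", "D", "<s1>"]  # sentinel, then the tokens it hides

restored, t = [], 0
for tok in src:
    if tok in SENTINELS:
        t += 1  # step past the matching sentinel in tgt
        while t < len(tgt) and tgt[t] not in SENTINELS:
            restored.append(tgt[t])  # copy the masked-out tokens back
            t += 1
    else:
        restored.append(tok)
print(restored)  # ['A', 'B', 'C', 'D', 'E']
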
fairseq/tests/test_activation_checkpointing.py
ADDED
@@ -0,0 +1,79 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
import torch.nn as nn
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from torch.utils.checkpoint import checkpoint


class Model(nn.Module):
    def __init__(
        self, use_pytorch_checkpoint=False, use_fairseq_checkpoint=False, **kwargs
    ):
        super().__init__()
        torch.manual_seed(0)
        self.use_pytorch_checkpoint = use_pytorch_checkpoint
        self.ffn = nn.Sequential(
            nn.Linear(32, 128),
            # add a Dropout layer to test RNG save/restore
            nn.Dropout(p=0.5),
            nn.Linear(128, 32),
        )
        if use_fairseq_checkpoint:
            self.ffn = checkpoint_wrapper(self.ffn, **kwargs)
        self.out = nn.Linear(32, 1)

    def forward(self, x):
        if self.use_pytorch_checkpoint:
            x = checkpoint(self.ffn, x)
        else:
            x = self.ffn(x)
        return self.out(x)


class TestActivationCheckpointing(unittest.TestCase):
    def _test_checkpoint_wrapper(self, device, log_memory_usage=False):
        def get_loss_and_gnorm(model):
            torch.manual_seed(1)
            input = torch.rand(2, 16, 32).requires_grad_(True).to(device)
            model.zero_grad()
            loss = model(input).sum()
            loss.backward()
            gnorm = torch.norm(
                torch.stack([torch.norm(p.grad.detach()) for p in model.parameters()])
            )
            return {"loss": loss, "gnorm": gnorm}

        model = Model().to(device)
        no_cpt = get_loss_and_gnorm(model)

        model = Model(use_pytorch_checkpoint=True).to(device)
        pyt_cpt = get_loss_and_gnorm(model)
        torch.testing.assert_allclose(no_cpt["loss"], pyt_cpt["loss"])
        torch.testing.assert_allclose(no_cpt["gnorm"], pyt_cpt["gnorm"])

        model = Model(use_fairseq_checkpoint=True).to(device)
        fairseq_cpt = get_loss_and_gnorm(model)
        torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt["loss"])
        torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt["gnorm"])

        model = Model(use_fairseq_checkpoint=True, offload_to_cpu=True).to(device)
        fairseq_cpt_offload = get_loss_and_gnorm(model)
        torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt_offload["loss"])
        torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt_offload["gnorm"])

    def test_checkpoint_wrapper_cpu(self):
        self._test_checkpoint_wrapper(device=torch.device("cpu"))

    @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
    def test_checkpoint_wrapper_cuda(self):
        self._test_checkpoint_wrapper(device=torch.device("cuda"))


if __name__ == "__main__":
    unittest.main()

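All three checkpointed variants must reproduce the baseline loss and gradient norm exactly; the mechanism under test is activation checkpointing, where torch.utils.checkpoint.checkpoint discards a block's intermediate activations during forward and recomputes them during backward, restoring RNG state so the Dropout mask matches. A minimal standalone sketch:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

ffn = nn.Sequential(nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, 32))
x = torch.rand(4, 32, requires_grad=True)

y = checkpoint(ffn, x)  # activations inside ffn are recomputed in backward
y.sum().backward()
print(x.grad.shape)  # torch.Size([4, 32])
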
fairseq/tests/test_amp_optimizer.py
ADDED
@@ -0,0 +1,75 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import copy
import unittest

import torch
from torch.cuda.amp import GradScaler, autocast

from fairseq.optim import build_optimizer


@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestGradientScalingAMP(unittest.TestCase):
    def setUp(self):
        self.x = torch.tensor([2.0]).cuda().half()
        weight = 3.0
        bias = 5.0
        self.error = 1.0
        self.target = torch.tensor([self.x * weight + bias + self.error]).cuda()
        self.loss_fn = torch.nn.L1Loss()

        self.model = torch.nn.Linear(1, 1)
        self.model.weight.data = torch.tensor([[weight]])
        self.model.bias.data = torch.tensor([bias])
        self.model.cuda()
        self.params = list(self.model.parameters())

        self.namespace_dls = argparse.Namespace(
            optimizer="adam",
            lr=[0.1],
            adam_betas="(0.9, 0.999)",
            adam_eps=1e-8,
            weight_decay=0.0,
            threshold_loss_scale=1,
            min_loss_scale=1e-4,
        )
        self.scaler = GradScaler(
            init_scale=1,
            growth_interval=1,
        )

    def run_iter(self, model, params, optimizer):
        optimizer.zero_grad()
        with autocast():
            y = model(self.x)
            loss = self.loss_fn(y, self.target)
        self.scaler.scale(loss).backward()
        self.assertEqual(loss, torch.tensor(1.0, device="cuda:0", dtype=torch.float16))

        self.scaler.unscale_(optimizer)
        grad_norm = optimizer.clip_grad_norm(0)
        self.assertAlmostEqual(grad_norm.item(), 2.2361, 4)

        self.scaler.step(optimizer)
        self.scaler.update()
        self.assertEqual(
            model.weight,
            torch.tensor([[3.1]], device="cuda:0", requires_grad=True),
        )
        self.assertEqual(
            model.bias,
            torch.tensor([5.1], device="cuda:0", requires_grad=True),
        )
        self.assertEqual(self.scaler.get_scale(), 2.0)

    def test_automatic_mixed_precision(self):
        model = copy.deepcopy(self.model)
        params = list(model.parameters())
        optimizer = build_optimizer(self.namespace_dls, params)

        self.run_iter(model, params, optimizer)

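run_iter follows the canonical torch.cuda.amp recipe: scale the loss before backward so fp16 gradients don't underflow, unscale before clipping so the threshold applies to true gradient values, then let the scaler step (skipping the step on overflow) and adjust its scale factor. A minimal sketch of that loop, assuming a CUDA device:

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(1, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()

for _ in range(3):
    optimizer.zero_grad()
    with autocast():  # forward pass runs in mixed precision
        loss = model(torch.randn(8, 1, device="cuda")).pow(2).mean()
    scaler.scale(loss).backward()  # scaled loss avoids fp16 underflow
    scaler.unscale_(optimizer)     # so clipping sees unscaled gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)         # skips the step if grads overflowed
    scaler.update()                # grows or shrinks the scale factor
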
fairseq/tests/test_average_checkpoints.py
ADDED
@@ -0,0 +1,134 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import collections
import os
import shutil
import tempfile
import unittest

import numpy as np
import torch
from scripts.average_checkpoints import average_checkpoints
from torch import nn


class ModelWithSharedParameter(nn.Module):
    def __init__(self):
        super(ModelWithSharedParameter, self).__init__()
        self.embedding = nn.Embedding(1000, 200)
        self.FC1 = nn.Linear(200, 200)
        self.FC2 = nn.Linear(200, 200)
        # tie weight in FC2 to FC1
        self.FC2.weight = nn.Parameter(self.FC1.weight)
        self.FC2.bias = nn.Parameter(self.FC1.bias)

        self.relu = nn.ReLU()

    def forward(self, input):
        return self.FC2(self.relu(self.FC1(input))) + self.FC1(input)


class TestAverageCheckpoints(unittest.TestCase):
    def test_average_checkpoints(self):
        params_0 = collections.OrderedDict(
            [
                ("a", torch.DoubleTensor([100.0])),
                ("b", torch.FloatTensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])),
                ("c", torch.IntTensor([7, 8, 9])),
            ]
        )
        params_1 = collections.OrderedDict(
            [
                ("a", torch.DoubleTensor([1.0])),
                ("b", torch.FloatTensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])),
                ("c", torch.IntTensor([2, 2, 2])),
            ]
        )
        params_avg = collections.OrderedDict(
            [
                ("a", torch.DoubleTensor([50.5])),
                ("b", torch.FloatTensor([[1.0, 1.5, 2.0], [2.5, 3.0, 3.5]])),
                # We expect truncation for integer division
                ("c", torch.IntTensor([4, 5, 5])),
            ]
        )

        fd_0, path_0 = tempfile.mkstemp()
        fd_1, path_1 = tempfile.mkstemp()
        torch.save(collections.OrderedDict([("model", params_0)]), path_0)
        torch.save(collections.OrderedDict([("model", params_1)]), path_1)

        output = average_checkpoints([path_0, path_1])["model"]

        os.close(fd_0)
        os.remove(path_0)
        os.close(fd_1)
        os.remove(path_1)

        for (k_expected, v_expected), (k_out, v_out) in zip(
            params_avg.items(), output.items()
        ):
            self.assertEqual(
                k_expected,
                k_out,
                "Key mismatch - expected {} but found {}. "
                "(Expected list of keys: {} vs actual list of keys: {})".format(
                    k_expected, k_out, params_avg.keys(), output.keys()
                ),
            )
            np.testing.assert_allclose(
                v_expected.numpy(),
                v_out.numpy(),
                err_msg="Tensor value mismatch for key {}".format(k_expected),
            )

    def test_average_checkpoints_with_shared_parameters(self):
        def _construct_model_with_shared_parameters(path, value):
            m = ModelWithSharedParameter()
            nn.init.constant_(m.FC1.weight, value)
            torch.save({"model": m.state_dict()}, path)
            return m

        tmpdir = tempfile.mkdtemp()
        paths = []
        path = os.path.join(tmpdir, "m1.pt")
        m1 = _construct_model_with_shared_parameters(path, 1.0)
        paths.append(path)

        path = os.path.join(tmpdir, "m2.pt")
        m2 = _construct_model_with_shared_parameters(path, 2.0)
        paths.append(path)

        path = os.path.join(tmpdir, "m3.pt")
        m3 = _construct_model_with_shared_parameters(path, 3.0)
        paths.append(path)

        new_model = average_checkpoints(paths)
        self.assertTrue(
            torch.equal(
                new_model["model"]["embedding.weight"],
                (m1.embedding.weight + m2.embedding.weight + m3.embedding.weight) / 3.0,
            )
        )

        self.assertTrue(
            torch.equal(
                new_model["model"]["FC1.weight"],
                (m1.FC1.weight + m2.FC1.weight + m3.FC1.weight) / 3.0,
            )
        )

        self.assertTrue(
            torch.equal(
                new_model["model"]["FC2.weight"],
                (m1.FC2.weight + m2.FC2.weight + m3.FC2.weight) / 3.0,
            )
        )
        shutil.rmtree(tmpdir)


if __name__ == "__main__":
    unittest.main()

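The expected values encode elementwise averaging with dtype-preserving division: floating tensors average exactly ((100 + 1) / 2 == 50.5) while IntTensor entries truncate ((7 + 2) // 2 == 4). A minimal sketch of that averaging rule, as an illustration rather than fairseq's actual average_checkpoints implementation:

import collections
import torch

def average_state_dicts(dicts):
    avg = collections.OrderedDict()
    n = len(dicts)
    for k in dicts[0]:
        total = sum(d[k] for d in dicts)
        # integer tensors truncate, matching the expected values above
        avg[k] = total / n if torch.is_floating_point(total) else total // n
    return avg

a = {"c": torch.IntTensor([7, 8, 9])}
b = {"c": torch.IntTensor([2, 2, 2])}
print(average_state_dicts([a, b])["c"])  # tensor([4, 5, 5], dtype=torch.int32)
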
fairseq/tests/test_backtranslation_dataset.py
ADDED
@@ -0,0 +1,123 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import tests.utils as test_utils
import torch
from fairseq.data import (
    BacktranslationDataset,
    LanguagePairDataset,
    TransformEosDataset,
)
from fairseq.sequence_generator import SequenceGenerator


class TestBacktranslationDataset(unittest.TestCase):
    def setUp(self):
        (
            self.tgt_dict,
            self.w1,
            self.w2,
            self.src_tokens,
            self.src_lengths,
            self.model,
        ) = test_utils.sequence_generator_setup()

        dummy_src_samples = self.src_tokens

        self.tgt_dataset = test_utils.TestDataset(data=dummy_src_samples)
        self.cuda = torch.cuda.is_available()

    def _backtranslation_dataset_helper(
        self,
        remove_eos_from_input_src,
        remove_eos_from_output_src,
    ):
        tgt_dataset = LanguagePairDataset(
            src=self.tgt_dataset,
            src_sizes=self.tgt_dataset.sizes,
            src_dict=self.tgt_dict,
            tgt=None,
            tgt_sizes=None,
            tgt_dict=None,
        )

        generator = SequenceGenerator(
            [self.model],
            tgt_dict=self.tgt_dict,
            max_len_a=0,
            max_len_b=200,
            beam_size=2,
            unk_penalty=0,
        )

        backtranslation_dataset = BacktranslationDataset(
            tgt_dataset=TransformEosDataset(
                dataset=tgt_dataset,
                eos=self.tgt_dict.eos(),
                # remove eos from the input src
                remove_eos_from_src=remove_eos_from_input_src,
            ),
            src_dict=self.tgt_dict,
            backtranslation_fn=(
                lambda sample: generator.generate([self.model], sample)
            ),
            output_collater=TransformEosDataset(
                dataset=tgt_dataset,
                eos=self.tgt_dict.eos(),
                # if we remove eos from the input src, then we need to add it
                # back to the output tgt
                append_eos_to_tgt=remove_eos_from_input_src,
                remove_eos_from_src=remove_eos_from_output_src,
            ).collater,
            cuda=self.cuda,
        )
        dataloader = torch.utils.data.DataLoader(
            backtranslation_dataset,
            batch_size=2,
            collate_fn=backtranslation_dataset.collater,
        )
        backtranslation_batch_result = next(iter(dataloader))

        eos, pad, w1, w2 = self.tgt_dict.eos(), self.tgt_dict.pad(), self.w1, self.w2

        # Note that we sort by src_lengths and add left padding, so actually
        # ids will look like: [1, 0]
        expected_src = torch.LongTensor([[w1, w2, w1, eos], [pad, pad, w1, eos]])
        if remove_eos_from_output_src:
            expected_src = expected_src[:, :-1]
        expected_tgt = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
        generated_src = backtranslation_batch_result["net_input"]["src_tokens"]
        tgt_tokens = backtranslation_batch_result["target"]

        self.assertTensorEqual(expected_src, generated_src)
        self.assertTensorEqual(expected_tgt, tgt_tokens)

    def test_backtranslation_dataset_no_eos_in_output_src(self):
        self._backtranslation_dataset_helper(
            remove_eos_from_input_src=False,
            remove_eos_from_output_src=True,
        )

    def test_backtranslation_dataset_with_eos_in_output_src(self):
        self._backtranslation_dataset_helper(
            remove_eos_from_input_src=False,
            remove_eos_from_output_src=False,
        )

    def test_backtranslation_dataset_no_eos_in_input_src(self):
        self._backtranslation_dataset_helper(
            remove_eos_from_input_src=True,
            remove_eos_from_output_src=False,
        )

    def assertTensorEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertEqual(t1.ne(t2).long().sum(), 0)


if __name__ == "__main__":
    unittest.main()