# Joint Speech Text Training for the MuST-C English to German Speech Translation task
- Joint Training Baseline: based on the paper ["A general multi-task learning framework to leverage text data for speech to text tasks"](https://arxiv.org/pdf/2010.11338.pdf).
- Enhanced Joint Training: the baseline joint training enhanced with pre-trained models, cross attentive regularization and online knowledge distillation, based on the paper ["Improving Speech Translation by Understanding and Learning from the Auxiliary Text Translation Task"](https://research.fb.com/publications/improving-speech-translation-by-understanding-and-learning-from-the-auxiliary-text-translation-task).
## Prepare Data
#### Download files
- SentencePiece model: [spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/spm.model)
- Dictionary: [dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/dict.txt)
- Config: [config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/config.yaml)
#### Prepare MuST-C data set
- Follow the data preparation in the [S2T example](https://github.com/pytorch/fairseq/blob/main/examples/speech_to_text/docs/mustc_example.md).
- Convert the English source text (`src_text`) into a phoneme representation:
```bash
python examples/speech_text_joint_to_text/scripts/g2p_encode.py \
    --lower-case --do-filter --use-word-start --no-punc \
    --reserve-word examples/speech_text_joint_to_text/configs/mustc_noise.list \
    --data-path ${must_c_en_de_src_text} \
    --out-path ${must_c_en_de_src_text_pho}
```
- Add the phoneme sequences generated above to the TSV files as the `src_text` column, and save the updated TSV files to `${MANIFEST_ROOT}`.
- Prepare the phoneme dictionary and save it to `${MANIFEST_ROOT}` as [src_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/src_dict.txt).
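The two TSV steps above can be sketched in bash as follows. This is a minimal sketch under stated assumptions rather than part of the recipe: `append_src_text` and `build_src_dict` are hypothetical helper names, the TSV files are assumed to have a header row, and the output of `g2p_encode.py` is assumed to hold one space-separated phoneme sequence per line, in the same order as the TSV data rows. The dictionary uses fairseq's plain-text `<symbol> <count>` format, but it may not match the released `src_dict.txt` symbol for symbol, so download that file if you plan to load the released checkpoints.

```bash
# Hypothetical helpers (not part of the recipe's scripts).

# append_src_text IN_TSV PHONEME_TXT OUT_TSV
# Assumption: PHONEME_TXT has one phoneme sequence per line, aligned with the
# data rows of IN_TSV (header row not counted). The row-count check catches
# misalignment, e.g. if lines were dropped during phonemization.
append_src_text() {
  local in_tsv=$1 pho_txt=$2 out_tsv=$3
  local n_rows n_pho
  n_rows=$(( $(wc -l < "$in_tsv") - 1 ))
  n_pho=$(( $(wc -l < "$pho_txt") ))
  if (( n_rows != n_pho )); then
    echo "mismatch: $n_rows TSV rows vs $n_pho phoneme lines" >&2
    return 1
  fi
  {
    # Header: existing columns plus a new src_text column.
    printf '%s\tsrc_text\n' "$(head -n 1 "$in_tsv")"
    # Data rows: each row gets its phoneme sequence as the last field.
    paste <(tail -n +2 "$in_tsv") "$pho_txt"
  } > "$out_tsv"
}

# build_src_dict PHONEME_TXT OUT_DICT
# Writes "<symbol> <count>" lines, most frequent symbol first
# (ties broken by symbol), from space-separated phoneme sequences.
build_src_dict() {
  local pho_txt=$1 out_dict=$2
  tr ' ' '\n' < "$pho_txt" | grep -v '^$' | LC_ALL=C sort | uniq -c \
    | LC_ALL=C sort -k1,1nr -k2,2 | awk '{print $2, $1}' > "$out_dict"
}
```

For example, `append_src_text train_st.tsv ${must_c_en_de_src_text_pho} ${MANIFEST_ROOT}/train_st.tsv` (illustrative file names). Write to a different path than the input, since the output file is truncated before the input is read.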
#### Prepare WMT text data
- Download the WMT data with [prepare-wmt14en2de.sh](https://github.com/pytorch/fairseq/blob/main/examples/translation/prepare-wmt14en2de.sh).
- Convert the source text (English) into a phoneme representation as above.
- Binarize the parallel data for training (as in the translation example) and save it to `${parallel_text_data}`.
## Training
The models are trained on 8 V100 GPUs.
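When fewer than 8 GPUs are available, a common fairseq practice is to raise `--update-freq` (gradient accumulation) so that GPUs × `--update-freq` stays at 8 × 4 = 32. This keeps the amount of data per update, and therefore the meaning of `--lr` and `--warmup-updates`, roughly comparable. The sketch below only does that arithmetic; the helper names are hypothetical, and reproducing the reported numbers this way is not guaranteed.

```bash
# Hypothetical helper: pick --update-freq so that N GPUs accumulate the same
# number of per-GPU batches per update as the reference setup (8 GPUs x 4).
update_freq_for() {
  local n_gpus=$1 ref_gpus=8 ref_update_freq=4
  local total=$(( ref_gpus * ref_update_freq ))  # 32 per-GPU batches per update
  if (( n_gpus < 1 )) || (( total % n_gpus != 0 )); then
    echo "pick a GPU count that divides $total" >&2
    return 1
  fi
  echo $(( total / n_gpus ))
}

# Upper bound on input tokens per update: GPUs x update-freq x --max-tokens
# (defaults to the recipe's --max-tokens 10000).
max_tokens_per_update() {
  local n_gpus=$1 update_freq=$2 max_tokens=${3:-10000}
  echo $(( n_gpus * update_freq * max_tokens ))
}
```

For example, `update_freq_for 2` prints `16`, and `max_tokens_per_update 8 4` prints `320000`.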
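#### Download pretrained models
These are needed only for the enhanced model.
- [pretrain_encoder](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_multilingual_asr_transformer_m.pt): initializes the speech encoder (`${pretrain_encoder}`)
- [pretrain_nmt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/checkpoint_mt.pt): MT model that initializes the text encoder and decoder (`${pretrain_nmt}`)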
#### Training scripts
- Joint training from scratch (baseline):
```bash
python train.py ${MANIFEST_ROOT} \
    --save-dir ${save_dir} \
    --num-workers 8 \
    --task speech_text_joint_to_text \
    --arch dualinputs2ttransformer_s \
    --user-dir examples/speech_text_joint_to_text \
    --max-epoch 100 --update-mix-data \
    --optimizer adam --lr-scheduler inverse_sqrt \
    --lr 0.001 --update-freq 4 --clip-norm 10.0 \
    --criterion guided_label_smoothed_cross_entropy_with_accuracy \
    --label-smoothing 0.1 --max-tokens 10000 --max-tokens-text 10000 \
    --max-positions-text 400 --seed 2 --speech-encoder-layers 12 \
    --text-encoder-layers 6 --encoder-shared-layers 6 --decoder-layers 6 \
    --dropout 0.1 --warmup-updates 20000 \
    --text-sample-ratio 0.25 --parallel-text-data ${parallel_text_data} \
    --text-input-cost-ratio 0.5 --enc-grad-mult 2.0 --add-speech-eos \
    --log-format json --langpairs en-de --noise-token '▁NOISE' \
    --mask-text-ratio 0.0 --max-tokens-valid 20000 --ddp-backend no_c10d \
    --log-interval 100 --data-buffer-size 50 --config-yaml config.yaml \
    --keep-last-epochs 10
```
- Joint training with pre-trained initialization, cross attentive regularization and online knowledge distillation (enhanced):
```bash
python train.py ${MANIFEST_ROOT} \
    --save-dir ${save_dir} \
    --num-workers 8 \
    --task speech_text_joint_to_text \
    --arch dualinputs2ttransformer_m \
    --user-dir examples/speech_text_joint_to_text \
    --max-epoch 100 --update-mix-data \
    --optimizer adam --lr-scheduler inverse_sqrt \
    --lr 0.002 --update-freq 4 --clip-norm 10.0 \
    --criterion guided_label_smoothed_cross_entropy_with_accuracy \
    --guide-alpha 0.8 --disable-text-guide-update-num 5000 \
    --label-smoothing 0.1 --max-tokens 10000 --max-tokens-text 10000 \
    --max-positions-text 400 --seed 2 --speech-encoder-layers 12 \
    --text-encoder-layers 6 --encoder-shared-layers 6 --decoder-layers 6 \
    --dropout 0.1 --warmup-updates 20000 --attentive-cost-regularization 0.02 \
    --text-sample-ratio 0.25 --parallel-text-data ${parallel_text_data} \
    --text-input-cost-ratio 0.5 --enc-grad-mult 2.0 --add-speech-eos \
    --log-format json --langpairs en-de --noise-token '▁NOISE' \
    --mask-text-ratio 0.0 --max-tokens-valid 20000 --ddp-backend no_c10d \
    --log-interval 100 --data-buffer-size 50 --config-yaml config.yaml \
    --load-pretrain-speech-encoder ${pretrain_encoder} \
    --load-pretrain-decoder ${pretrain_nmt} \
    --load-pretrain-text-encoder-last ${pretrain_nmt} \
    --keep-last-epochs 10
```
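The released checkpoints are named `checkpoint_ave_10.pt`, and both training scripts keep the last 10 epoch checkpoints (`--keep-last-epochs 10`), which suggests the last 10 epochs were averaged before evaluation. Assuming that, an averaged model could be produced with fairseq's `scripts/average_checkpoints.py`, run from the repository root like the commands above. The wrapper name `average_last_epochs` is hypothetical.

```bash
# Hypothetical wrapper around fairseq's scripts/average_checkpoints.py:
# average the last N epoch checkpoints (checkpointXX.pt) in SAVE_DIR.
average_last_epochs() {
  local save_dir=$1 n=${2:-10}
  python scripts/average_checkpoints.py \
    --inputs "$save_dir" \
    --num-epoch-checkpoints "$n" \
    --output "$save_dir/checkpoint_ave_${n}.pt"
}
```

For example, `average_last_epochs ${save_dir} 10` writes `${save_dir}/checkpoint_ave_10.pt`, which can then be passed to evaluation as `${model}`.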
## Evaluation
Evaluate on `tst-COMMON` with speech input only:
```bash
python ./fairseq_cli/generate.py \
    ${MANIFEST_ROOT} \
    --task speech_text_joint_to_text \
    --max-tokens 25000 \
    --nbest 1 \
    --results-path ${infer_results} \
    --batch-size 512 \
    --path ${model} \
    --gen-subset tst-COMMON \
    --config-yaml config_spm.yaml \
    --scoring sacrebleu \
    --beam 5 --lenpen 1.0 \
    --user-dir examples/speech_text_joint_to_text \
    --load-speech-only
```
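With `--results-path`, fairseq's `generate.py` writes its log to `${infer_results}/generate-tst-COMMON.txt`: `D-<id>` lines hold the detokenized hypotheses (tab-separated id, score and text), and the last line reports the corpus score from `--scoring sacrebleu`. A minimal sketch for pulling both out, assuming this log layout; the helper names are hypothetical.

```bash
# Hypothetical helpers for reading a fairseq generation log.

# Print detokenized hypotheses (D-<id> lines) in original sample order.
extract_hyps() {
  grep '^D-' "$1" | sed 's/^D-//' | sort -t $'\t' -k1,1n | cut -f 3
}

# Print the last line that mentions BLEU (the corpus-level score line).
score_line() {
  grep 'BLEU' "$1" | tail -n 1
}
```

For example, `extract_hyps ${infer_results}/generate-tst-COMMON.txt > hyps.txt`.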
## Results (joint training with pre-trained initialization, cross attentive regularization and online knowledge distillation)
|Direction|En-De|En-Es|En-Fr|
|---|---|---|---|
|BLEU|27.4|31.2|37.6|
|Checkpoint|[link](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/checkpoint_ave_10.pt)|[link](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_es/checkpoint_ave_10.pt)|[link](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_fr/checkpoint_ave_10.pt)|