
Joint Speech Text Training for the MuST-C English to German Speech Translation task

Joint Training Baseline: based on the paper "A General Multi-Task Learning Framework to Leverage Text Data for Speech to Text Tasks"

Enhanced Joint Training: the joint training is enhanced with pre-trained models, cross attentive regularization, and online knowledge distillation, based on the paper "Improving Speech Translation by Understanding and Learning from the Auxiliary Text Translation Task"

Prepare Data

Download files

Prepare MuST-C data set

    python examples/speech_text_joint_to_text/scripts/g2p_encode.py \
        --lower-case --do-filter --use-word-start --no-punc \
        --reserve-word examples/speech_text_joint_to_text/configs/mustc_noise.list \
        --data-path ${must_c_en_de_src_text} \
        --out-path ${must_c_en_de_src_text_pho}
  • Update the tsv manifests with the src_text generated above and save them to $MANIFEST_ROOT
  • Prepare the phoneme dictionary and save it to $MANIFEST_ROOT as src_dict.txt (see the sketch after this list)
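
A minimal sketch of these two steps, assuming the g2p script writes one phonemized line per manifest row, in the same order; the manifest and output file names are placeholders, not part of the recipe:

    # Sketch only: file names, row ordering, and column layout are assumptions.
    # Append the phonemized text as a src_text column to the tsv manifest.
    { printf 'src_text\n'; cat ${must_c_en_de_src_text_pho}; } \
        | paste ${MANIFEST_ROOT}/train_st.tsv - > ${MANIFEST_ROOT}/train_st_pho.tsv
    # Build a fairseq-style dictionary ("<token> <count>" per line, most frequent
    # first) over the phoneme tokens and save it as src_dict.txt.
    tr ' ' '\n' < ${must_c_en_de_src_text_pho} | grep -v '^$' \
        | sort | uniq -c | sort -k1,1nr \
        | awk '{print $2, $1}' > ${MANIFEST_ROOT}/src_dict.txt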

Prepare WMT text data

  • Download the WMT English-German data
  • Convert the source text (English) into its phoneme representation, as above
  • Generate binary parallel files for training (following the fairseq translation example) and save them in $parallel_text_data (see the sketch after this list)
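
A hedged sketch of the binarization step using fairseq-preprocess; ${wmt_root} and the train/valid prefixes are placeholder names, and --srcdict reuses the phoneme dictionary so the speech and text inputs share a source vocabulary:

    # Sketch only: paths are assumptions; see the fairseq translation example for details.
    fairseq-preprocess \
        --source-lang en --target-lang de \
        --trainpref ${wmt_root}/train.pho \
        --validpref ${wmt_root}/valid.pho \
        --srcdict ${MANIFEST_ROOT}/src_dict.txt \
        --workers 20 \
        --destdir ${parallel_text_data}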

Training

The models are trained on 8 V100 GPUs.
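
If fewer GPUs are available, the effective batch size (number of GPUs x --update-freq x --max-tokens) can be preserved by scaling --update-freq. A sketch, assuming 4 GPUs instead of 8:

    # Sketch only: double --update-freq (4 -> 8) to compensate for half the GPUs;
    # all other flags stay as in the training commands below.
    CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py ${MANIFEST_ROOT} --update-freq 8 ...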

Download pretrained models

Training scripts

  • Jointly trained model from scratch
python train.py ${MANIFEST_ROOT} \
    --save-dir ${save_dir} \
    --num-workers 8 \
    --task speech_text_joint_to_text \
    --arch dualinputs2ttransformer_s \
    --user-dir examples/speech_text_joint_to_text \
    --max-epoch 100 --update-mix-data \
    --optimizer adam --lr-scheduler inverse_sqrt \
    --lr 0.001 --update-freq 4 --clip-norm 10.0 \
    --criterion guided_label_smoothed_cross_entropy_with_accuracy \
    --label-smoothing 0.1 --max-tokens 10000 --max-tokens-text 10000 \
    --max-positions-text 400 --seed 2 --speech-encoder-layers 12 \
    --text-encoder-layers 6 --encoder-shared-layers 6 --decoder-layers 6 \
    --dropout 0.1 --warmup-updates 20000  \
    --text-sample-ratio 0.25 --parallel-text-data ${parallel_text_data} \
    --text-input-cost-ratio 0.5 --enc-grad-mult 2.0 --add-speech-eos \
    --log-format json --langpairs en-de --noise-token '"'"'▁NOISE'"'"' \
    --mask-text-ratio 0.0 --max-tokens-valid 20000 --ddp-backend no_c10d \
    --log-interval 100 --data-buffer-size 50 --config-yaml config.yaml \
    --keep-last-epochs 10
  • Jointly trained model with pre-trained initialization, cross attentive regularization, and online knowledge distillation
python train.py ${MANIFEST_ROOT} \
    --save-dir ${save_dir} \
    --num-workers 8 \
    --task speech_text_joint_to_text \
    --arch dualinputs2ttransformer_m \
    --user-dir examples/speech_text_joint_to_text \
    --max-epoch 100 --update-mix-data \
    --optimizer adam --lr-scheduler inverse_sqrt \
    --lr 0.002 --update-freq 4 --clip-norm 10.0 \
    --criterion guided_label_smoothed_cross_entropy_with_accuracy \
    --guide-alpha 0.8 --disable-text-guide-update-num 5000 \
    --label-smoothing 0.1 --max-tokens 10000 --max-tokens-text 10000 \
    --max-positions-text 400 --seed 2 --speech-encoder-layers 12 \
    --text-encoder-layers 6 --encoder-shared-layers 6 --decoder-layers 6 \
    --dropout 0.1 --warmup-updates 20000 --attentive-cost-regularization 0.02 \
    --text-sample-ratio 0.25 --parallel-text-data ${parallel_text_data} \
    --text-input-cost-ratio 0.5 --enc-grad-mult 2.0 --add-speech-eos \
    --log-format json --langpairs en-de --noise-token '"'"'▁NOISE'"'"' \
    --mask-text-ratio 0.0 --max-tokens-valid 20000 --ddp-backend no_c10d \
    --log-interval 100 --data-buffer-size 50 --config-yaml config.yaml \
    --load-pretrain-speech-encoder ${pretrain_encoder} \
    --load-pretrain-decoder ${pretrain_nmt} \
    --load-pretrain-text-encoder-last ${pretrain_nmt} \
    --keep-last-epochs 10
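
Both commands read the data manifest configuration from config.yaml in ${MANIFEST_ROOT}. A minimal sketch of that file, assuming the field names of fairseq's S2TJointDataConfig; the exact contents depend on your feature extraction and vocabulary files:

cat > ${MANIFEST_ROOT}/config.yaml <<'EOF'
# Sketch only: values are assumptions, not part of the recipe.
input_feat_per_channel: 80
input_channels: 1
vocab_filename: dict.txt
src_vocab_filename: src_dict.txt
EOF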

Evaluation

python ./fairseq_cli/generate.py \
        ${MANIFEST_ROOT} \
        --task speech_text_joint_to_text \
        --max-tokens 25000 \
        --nbest 1 \
        --results-path ${infer_results} \
        --batch-size 512 \
        --path ${model} \
        --gen-subset tst-COMMON \
        --config-yaml config_spm.yaml \
        --scoring sacrebleu \
        --beam 5 --lenpen 1.0 \
        --user-dir examples/speech_text_joint_to_text \
        --load-speech-only
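
Since training keeps the last 10 epoch checkpoints, a common fairseq recipe (optional, not prescribed above) is to average them before decoding and pass the averaged file as ${model}:

python scripts/average_checkpoints.py \
    --inputs ${save_dir} \
    --num-epoch-checkpoints 10 \
    --output ${save_dir}/avg_last_10_checkpoint.pt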

Results (joint training with pre-trained initialization + cross attentive regularization (CAR) + online knowledge distillation (KD))

Direction    En-De   En-Es   En-Fr
BLEU         27.4    31.2    37.6
checkpoint   link    link    link