Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- fairseq/examples/speech_text_joint_to_text/criterions/multi_modality_cross_entropy.py +101 -0
- fairseq/examples/speech_text_joint_to_text/criterions/text_guide_cross_entropy_acc.py +224 -0
- fairseq/examples/speech_text_joint_to_text/data/pair_denoising_dataset.py +318 -0
- fairseq/examples/speech_text_joint_to_text/docs/ende-mustc.md +118 -0
- fairseq/examples/speech_text_joint_to_text/docs/iwslt2021.md +76 -0
- fairseq/examples/speech_text_joint_to_text/docs/pre-training.md +192 -0
- fairseq/examples/speech_text_joint_to_text/models/__init__.py +8 -0
- fairseq/examples/speech_text_joint_to_text/models/joint_speech_text_pretrain_transformer.py +698 -0
- fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputtransformer.py +1093 -0
- fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputwavtransformer.py +526 -0
- fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputxmtransformer.py +584 -0
- fairseq/examples/speech_text_joint_to_text/scripts/convert_model.py +71 -0
- fairseq/examples/speech_text_joint_to_text/scripts/g2p_encode.py +191 -0
- fairseq/examples/speech_text_joint_to_text/tasks/__init__.py +8 -0
- fairseq/examples/speech_text_joint_to_text/tasks/pair_denoising.py +447 -0
- fairseq/examples/speech_text_joint_to_text/tasks/speech_text_denoise_pretrain.py +654 -0
- fairseq/examples/speech_text_joint_to_text/tasks/speech_text_joint.py +377 -0
- fairseq/examples/speech_to_speech/README.md +7 -0
- fairseq/examples/speech_to_speech/__init__.py +6 -0
- fairseq/examples/speech_to_speech/asr_bleu/README.md +34 -0
- fairseq/examples/speech_to_speech/asr_bleu/__init__.py +0 -0
- fairseq/examples/speech_to_speech/asr_bleu/asr_model_cfgs.json +198 -0
- fairseq/examples/speech_to_speech/asr_bleu/compute_asr_bleu.py +244 -0
- fairseq/examples/speech_to_speech/asr_bleu/requirements.txt +7 -0
- fairseq/examples/speech_to_speech/asr_bleu/utils.py +306 -0
- fairseq/examples/speech_to_speech/benchmarking/README.md +31 -0
- fairseq/examples/speech_to_speech/benchmarking/configs/2StageS2ST.yaml +19 -0
- fairseq/examples/speech_to_speech/benchmarking/configs/3StageS2ST.yaml +28 -0
- fairseq/examples/speech_to_speech/benchmarking/configs/DirectS2U.yaml +22 -0
- fairseq/examples/speech_to_speech/benchmarking/configs/S2T.yaml +13 -0
- fairseq/examples/speech_to_speech/benchmarking/core.py +487 -0
- fairseq/examples/speech_to_speech/benchmarking/data_utils.py +264 -0
- fairseq/examples/speech_to_speech/benchmarking/get_metrics.py +162 -0
- fairseq/examples/speech_to_speech/docs/data_augmentation.md +435 -0
- fairseq/examples/speech_to_speech/docs/direct_s2st_discrete_units.md +181 -0
- fairseq/examples/speech_to_speech/docs/enhanced_direct_s2st_discrete_units.md +125 -0
- fairseq/examples/speech_to_speech/docs/textless_s2st_real_data.md +89 -0
- fairseq/examples/speech_to_speech/generate_waveform_from_code.py +116 -0
- fairseq/examples/speech_to_speech/preprocessing/__init__.py +4 -0
- fairseq/examples/speech_to_speech/preprocessing/data_utils.py +88 -0
- fairseq/examples/speech_to_speech/preprocessing/prep_s2spect_data.py +169 -0
- fairseq/examples/speech_to_speech/preprocessing/prep_s2ut_data.py +114 -0
- fairseq/examples/speech_to_speech/preprocessing/prep_sn_data.py +88 -0
- fairseq/examples/speech_to_speech/preprocessing/prep_sn_output_data.py +58 -0
- fairseq/examples/speech_to_speech/unity/__init__.py +7 -0
- fairseq/examples/speech_to_speech/unity/sequence_generator.py +626 -0
- fairseq/examples/speech_to_speech/unity/sequence_generator_multi_decoder.py +267 -0
- fairseq/examples/speech_to_text/README.md +77 -0
- fairseq/examples/speech_to_text/data_utils.py +383 -0
- fairseq/examples/speech_to_text/docs/covost_example.md +140 -0
fairseq/examples/speech_text_joint_to_text/criterions/multi_modality_cross_entropy.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from fairseq import utils
|
8 |
+
from fairseq.criterions import register_criterion
|
9 |
+
from fairseq.criterions.label_smoothed_cross_entropy import (
|
10 |
+
LabelSmoothedCrossEntropyCriterion,
|
11 |
+
LabelSmoothedCrossEntropyCriterionConfig,
|
12 |
+
label_smoothed_nll_loss,
|
13 |
+
)
|
14 |
+
|
15 |
+
|
16 |
+
@register_criterion(
|
17 |
+
"speech_text_pretrain_cross_entropy",
|
18 |
+
dataclass=LabelSmoothedCrossEntropyCriterionConfig,
|
19 |
+
)
|
20 |
+
class SpeechTextPreTrainCrossEntCriterion(LabelSmoothedCrossEntropyCriterion):
|
21 |
+
def __init__(self, task, sentence_avg, label_smoothing, report_accuracy=False):
|
22 |
+
super().__init__(
|
23 |
+
task, sentence_avg, label_smoothing, report_accuracy=report_accuracy
|
24 |
+
)
|
25 |
+
|
26 |
+
def forward(self, model, sample, reduce=True):
|
27 |
+
net_output = model(**sample["net_input"])
|
28 |
+
loss, nll_loss, nsentences, ntokens, n_correct = self.compute_loss(
|
29 |
+
model, net_output, sample, reduce=reduce
|
30 |
+
)
|
31 |
+
sample_size = nsentences if self.sentence_avg else ntokens
|
32 |
+
logging_output = {
|
33 |
+
"loss": loss.data,
|
34 |
+
"nll_loss": nll_loss.data,
|
35 |
+
"ntokens": ntokens,
|
36 |
+
"nsentences": nsentences,
|
37 |
+
"sample_size": sample_size,
|
38 |
+
}
|
39 |
+
if self.report_accuracy:
|
40 |
+
logging_output["n_correct"] = utils.item(n_correct)
|
41 |
+
logging_output["total"] = utils.item(ntokens)
|
42 |
+
return loss, sample_size, logging_output
|
43 |
+
|
44 |
+
def get_lprobs_and_target(self, model, net_output, sample):
|
45 |
+
lprobs = model.get_normalized_probs(net_output, log_probs=True)
|
46 |
+
target = model.get_targets(sample, net_output)
|
47 |
+
assert self.ignore_prefix_size == 0
|
48 |
+
if self.ignore_prefix_size > 0:
|
49 |
+
if getattr(lprobs, "batch_first", False):
|
50 |
+
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
|
51 |
+
target = target[:, self.ignore_prefix_size :].contiguous()
|
52 |
+
else:
|
53 |
+
lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
|
54 |
+
target = target[self.ignore_prefix_size :, :].contiguous()
|
55 |
+
return lprobs, target
|
56 |
+
|
57 |
+
def compute_loss(self, model, net_output, sample, reduce=True):
|
58 |
+
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
|
59 |
+
n_correct = 0
|
60 |
+
if isinstance(target, dict):
|
61 |
+
t_lprobs = target["target_logprobs"]
|
62 |
+
|
63 |
+
if not lprobs.batch_first:
|
64 |
+
lprobs = lprobs.transpose(0, 1)
|
65 |
+
t_lprobs = t_lprobs.transpose(0, 1)
|
66 |
+
nsentences, seq_len = lprobs.size()[:2]
|
67 |
+
ntokens = nsentences * seq_len
|
68 |
+
t_probs = t_lprobs.exp()
|
69 |
+
mask_indices = (
|
70 |
+
net_output[1]["mask_indices"][0]
|
71 |
+
if len(net_output[1]["mask_indices"]) > 0
|
72 |
+
else None
|
73 |
+
)
|
74 |
+
|
75 |
+
# mask_indices is True for those masking frames
|
76 |
+
if mask_indices is not None: # B X T
|
77 |
+
t_probs = t_probs.masked_fill(mask_indices.eq(False).unsqueeze(-1), 0)
|
78 |
+
ntokens = mask_indices.int().sum()
|
79 |
+
t_probs = t_probs.detach()
|
80 |
+
t_lprobs = t_lprobs.detach()
|
81 |
+
loss = (
|
82 |
+
-(t_probs * (lprobs - t_lprobs)).sum()
|
83 |
+
if reduce
|
84 |
+
else -(t_probs * (lprobs - t_lprobs)).sum(-1, keepdim=True)
|
85 |
+
)
|
86 |
+
nll_loss = loss
|
87 |
+
else:
|
88 |
+
nsentences = target.size(0)
|
89 |
+
mask = target.ne(self.padding_idx)
|
90 |
+
loss, nll_loss = label_smoothed_nll_loss(
|
91 |
+
lprobs.view(-1, lprobs.size(-1)),
|
92 |
+
target.view(-1),
|
93 |
+
self.eps,
|
94 |
+
ignore_index=self.padding_idx,
|
95 |
+
reduce=reduce,
|
96 |
+
)
|
97 |
+
n_correct = torch.sum(
|
98 |
+
lprobs.argmax(-1).masked_select(mask).eq(target.masked_select(mask))
|
99 |
+
)
|
100 |
+
ntokens = torch.sum(mask)
|
101 |
+
return loss, nll_loss, nsentences, ntokens, n_correct
|
fairseq/examples/speech_text_joint_to_text/criterions/text_guide_cross_entropy_acc.py
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
import math
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from fairseq import utils
|
10 |
+
from fairseq.criterions import FairseqCriterion, register_criterion
|
11 |
+
from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss
|
12 |
+
from fairseq.logging import metrics
|
13 |
+
|
14 |
+
|
15 |
+
@register_criterion("guided_label_smoothed_cross_entropy_with_accuracy")
|
16 |
+
class GuidedCrossEntAccCriterion(FairseqCriterion):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
task,
|
20 |
+
sentence_avg,
|
21 |
+
guide_alpha,
|
22 |
+
text_input_cost_ratio,
|
23 |
+
label_smoothing,
|
24 |
+
disable_text_guide_update_num=0,
|
25 |
+
attentive_cost_regularization=0,
|
26 |
+
):
|
27 |
+
"""
|
28 |
+
guide_alpha: alpha to inteplate nll and kd loss
|
29 |
+
text_input_cost_ratio: loss ratio for text only input data
|
30 |
+
label_smoothing: label smoothing ratio
|
31 |
+
disable_text_guide_update_num: only use nll loss for the first N updates
|
32 |
+
attentive_cost_regularization: ratio fo attentive cost
|
33 |
+
"""
|
34 |
+
super().__init__(task)
|
35 |
+
self.alpha = guide_alpha
|
36 |
+
self.attn_beta = attentive_cost_regularization
|
37 |
+
self.sentence_avg = sentence_avg
|
38 |
+
self.eps = label_smoothing
|
39 |
+
self.text_input_cost_ratio = text_input_cost_ratio
|
40 |
+
self.disable_update_num = disable_text_guide_update_num
|
41 |
+
assert self.alpha >= 0 and self.alpha <= 1.0
|
42 |
+
|
43 |
+
@staticmethod
|
44 |
+
def add_args(parser):
|
45 |
+
"""Add criterion-specific arguments to the parser."""
|
46 |
+
# fmt: off
|
47 |
+
parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
|
48 |
+
help='epsilon for label smoothing, 0 means no label smoothing')
|
49 |
+
# fmt: off
|
50 |
+
parser.add_argument('--guide-alpha', default=0., type=float, metavar='D',
|
51 |
+
help='alpha to merge kd cost from text to speech input with ce loss')
|
52 |
+
# fmt: off
|
53 |
+
parser.add_argument('--disable-text-guide-update-num', default=0, type=int, metavar='D',
|
54 |
+
help='disable guided target from text for the first N updates.')
|
55 |
+
parser.add_argument("--attentive-cost-regularization", default=0.0, type=float, metavar='D',
|
56 |
+
help="use encoder attentive loss regularization with cost ratio D")
|
57 |
+
parser.add_argument("--attentive-cost-without-normalize", action='store_true',
|
58 |
+
help="Don't do normalization during attentive cost computation")
|
59 |
+
|
60 |
+
def forward(self, model, sample, reduce=True):
|
61 |
+
reduction = 'sum' if reduce else 'none'
|
62 |
+
net_input = sample["net_input"]
|
63 |
+
net_output = model(**net_input)
|
64 |
+
attn_cost = None
|
65 |
+
lprobs = model.get_normalized_probs(net_output, log_probs=True)
|
66 |
+
is_dual_input = True if net_input['src_tokens'] is not None and net_input.get('src_txt_tokens') is not None else False
|
67 |
+
target = model.get_targets(sample, net_output)
|
68 |
+
src_token_num = 0
|
69 |
+
if is_dual_input:
|
70 |
+
# lprobs_spch from speech encoder and lprobs_text from text encoder
|
71 |
+
lprobs_spch, lprobs_text = torch.chunk(lprobs, 2)
|
72 |
+
lprobs_spch.batch_first = lprobs.batch_first
|
73 |
+
lprobs_text.batch_first = lprobs.batch_first
|
74 |
+
|
75 |
+
speech_loss, speech_nll_loss, speech_correct, speech_total = \
|
76 |
+
self.guide_loss_and_acc(model, lprobs_spch, lprobs_text, target, reduce=(reduction == 'sum'))
|
77 |
+
text_loss, text_nll_loss, text_correct, text_total = self.compute_loss_and_acc(model, lprobs_text, target, reduction=reduction)
|
78 |
+
loss = (speech_loss + text_loss)
|
79 |
+
nll_loss = (speech_nll_loss + text_nll_loss)
|
80 |
+
correct = speech_correct + text_correct
|
81 |
+
total = speech_total + text_total
|
82 |
+
|
83 |
+
attn_cost = net_output[1].get('attn_cost')
|
84 |
+
if attn_cost is not None:
|
85 |
+
# attn_cost is batch_first and padding tokens have been masked already
|
86 |
+
src_token_num = attn_cost.ne(0).sum()
|
87 |
+
attn_cost = attn_cost.sum()
|
88 |
+
loss = loss + attn_cost * self.attn_beta
|
89 |
+
else:
|
90 |
+
attn_cost = 0
|
91 |
+
else:
|
92 |
+
loss, nll_loss, correct, total = self.compute_loss_and_acc(model, lprobs, target, reduction=reduction)
|
93 |
+
if sample["net_input"]['src_tokens'] is None: # text input only
|
94 |
+
loss = loss * self.text_input_cost_ratio
|
95 |
+
speech_loss = None
|
96 |
+
speech_nll_loss = None
|
97 |
+
|
98 |
+
sample_size, logging_output = self.get_logging_output(
|
99 |
+
sample, loss, nll_loss, correct, total, src_token_num, speech_loss, speech_nll_loss, attn_cost, is_dual_input
|
100 |
+
)
|
101 |
+
return loss, sample_size, logging_output
|
102 |
+
|
103 |
+
def compute_loss_and_acc(self, model, lprobs, target, reduction='sum'):
|
104 |
+
if not lprobs.batch_first:
|
105 |
+
lprobs = lprobs.transpose(0, 1)
|
106 |
+
lprobs = lprobs.view(-1, lprobs.size(-1)) # -> (B x T) x C
|
107 |
+
target = target.view(-1)
|
108 |
+
loss, nll_loss = label_smoothed_nll_loss(
|
109 |
+
lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=(reduction == 'sum'),
|
110 |
+
)
|
111 |
+
|
112 |
+
mask = target.ne(self.padding_idx)
|
113 |
+
correct = torch.sum(lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask)))
|
114 |
+
total = torch.sum(mask)
|
115 |
+
return loss, nll_loss, correct, total
|
116 |
+
|
117 |
+
def guide_loss_and_acc(self, model, lprobs, lprobs_teacher, target, reduce=True):
|
118 |
+
""" lprobs_teacher is used as guide for lprobs """
|
119 |
+
if self.alpha == 0.0 or model.num_updates < self.disable_update_num:
|
120 |
+
return self.compute_loss_and_acc(model, lprobs, target, reduction=('sum' if reduce else 'none'))
|
121 |
+
if not lprobs.batch_first:
|
122 |
+
lprobs = lprobs.transpose(0, 1)
|
123 |
+
lprobs_teacher = lprobs_teacher.transpose(0, 1)
|
124 |
+
|
125 |
+
lprobs = lprobs.view(-1, lprobs.size(-1)).float() # -> (B x T) x C
|
126 |
+
lprobs_teacher = lprobs_teacher.view(-1, lprobs_teacher.size(-1)).float() # -> (B x T) x C
|
127 |
+
target = target.view(-1)
|
128 |
+
loss = F.nll_loss(lprobs, target, ignore_index=self.padding_idx, reduction='sum' if reduce else 'none')
|
129 |
+
nll_loss = loss
|
130 |
+
probs_teacher = lprobs_teacher.exp().masked_fill_(target.unsqueeze(-1).eq(self.padding_idx), 0)
|
131 |
+
probs_teacher = probs_teacher.detach()
|
132 |
+
guide_loss = -(probs_teacher*lprobs).sum() if reduce else -(probs_teacher*lprobs).sum(-1, keepdim=True)
|
133 |
+
loss = self.alpha*guide_loss + (1.0 - self.alpha)*loss
|
134 |
+
|
135 |
+
mask = target.ne(self.padding_idx)
|
136 |
+
correct = torch.sum(lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask)))
|
137 |
+
total = torch.sum(mask)
|
138 |
+
return loss, nll_loss, correct, total
|
139 |
+
|
140 |
+
def get_logging_output(
|
141 |
+
self,
|
142 |
+
sample,
|
143 |
+
loss,
|
144 |
+
nll_loss,
|
145 |
+
correct,
|
146 |
+
total,
|
147 |
+
src_token_num=0,
|
148 |
+
speech_loss=None,
|
149 |
+
speech_nll_loss=None,
|
150 |
+
attn_cost=None,
|
151 |
+
is_dual_input=False,
|
152 |
+
):
|
153 |
+
|
154 |
+
sample_size = (
|
155 |
+
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
|
156 |
+
)
|
157 |
+
mul_size = 2 if is_dual_input else 1
|
158 |
+
|
159 |
+
logging_output = {
|
160 |
+
"loss": utils.item(loss.data), # * sample['ntokens'],
|
161 |
+
"nll_loss": utils.item(nll_loss.data), # * sample['ntokens'],
|
162 |
+
"ntokens": sample["ntokens"]*mul_size,
|
163 |
+
"nsentences": sample["target"].size(0)*mul_size,
|
164 |
+
"sample_size": sample_size*mul_size,
|
165 |
+
"correct": utils.item(correct.data),
|
166 |
+
"total": utils.item(total.data),
|
167 |
+
"src_token_num": utils.item(src_token_num.data) if src_token_num > 0 else 0,
|
168 |
+
"nframes": torch.sum(sample["net_input"]["src_lengths"]).item(),
|
169 |
+
}
|
170 |
+
|
171 |
+
if speech_loss is not None:
|
172 |
+
logging_output["speech_loss"] = utils.item(speech_loss.data)
|
173 |
+
logging_output["speech_nll_loss"] = utils.item(speech_nll_loss.data)
|
174 |
+
logging_output["sample_size_speech_cost"] = sample_size
|
175 |
+
logging_output["speech_attn_loss"] = attn_cost
|
176 |
+
|
177 |
+
return sample_size*mul_size, logging_output
|
178 |
+
|
179 |
+
@staticmethod
|
180 |
+
def aggregate_logging_outputs(logging_outputs):
|
181 |
+
"""Aggregate logging outputs from data parallel training."""
|
182 |
+
correct_sum = sum(log.get("correct", 0) for log in logging_outputs)
|
183 |
+
total_sum = sum(log.get("total", 0) for log in logging_outputs)
|
184 |
+
src_token_sum = sum(log.get("src_token_num", 0) for log in logging_outputs)
|
185 |
+
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
|
186 |
+
nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
|
187 |
+
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
|
188 |
+
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
|
189 |
+
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
|
190 |
+
nframes = sum(log.get("nframes", 0) for log in logging_outputs)
|
191 |
+
speech_loss_sum = sum(log.get("speech_loss", 0) for log in logging_outputs)
|
192 |
+
speech_nll_loss_sum = sum(log.get("speech_nll_loss", 0) for log in logging_outputs)
|
193 |
+
speech_attn_loss_sum = sum(log.get("speech_attn_loss", 0) for log in logging_outputs)
|
194 |
+
sample_size_speech = sum(log.get("sample_size_speech_cost", 0) for log in logging_outputs)
|
195 |
+
|
196 |
+
agg_output = {
|
197 |
+
"loss": loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0,
|
198 |
+
"nll_loss": nll_loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0,
|
199 |
+
# if args.sentence_avg, then sample_size is nsentences, and loss
|
200 |
+
# is per-sentence loss; else sample_size is ntokens, and the loss
|
201 |
+
# becomes per-output token loss
|
202 |
+
"speech_loss": speech_loss_sum / sample_size_speech / math.log(2) if sample_size_speech > 0 else 0.0,
|
203 |
+
"speech_nll_loss": speech_nll_loss_sum / sample_size_speech / math.log(2) if sample_size_speech > 0 else 0.0,
|
204 |
+
"speech_attn_loss": speech_attn_loss_sum / src_token_sum / math.log(2) if src_token_sum > 0 else 0.0,
|
205 |
+
"ntokens": ntokens,
|
206 |
+
"nsentences": nsentences,
|
207 |
+
"nframes": nframes,
|
208 |
+
"sample_size": sample_size,
|
209 |
+
"acc": correct_sum * 100.0 / total_sum if total_sum > 0 else 0.0,
|
210 |
+
"correct": correct_sum,
|
211 |
+
"total": total_sum,
|
212 |
+
"src_token_num": src_token_sum,
|
213 |
+
# total is the number of validate tokens
|
214 |
+
}
|
215 |
+
return agg_output
|
216 |
+
|
217 |
+
@classmethod
|
218 |
+
def reduce_metrics(cls, logging_outputs):
|
219 |
+
"""Aggregate logging outputs from data parallel training."""
|
220 |
+
agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
|
221 |
+
for k, v in agg_logging_outputs.items():
|
222 |
+
if k in {'nsentences', 'ntokens', 'sample_size'}:
|
223 |
+
continue
|
224 |
+
metrics.log_scalar(k, v, round=3)
|
fairseq/examples/speech_text_joint_to_text/data/pair_denoising_dataset.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import copy
|
7 |
+
import math
|
8 |
+
import re
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from fairseq.data import data_utils
|
13 |
+
from fairseq.data.language_pair_dataset import LanguagePairDataset
|
14 |
+
|
15 |
+
|
16 |
+
# Part of the code is modified from DenoisingDataset
|
17 |
+
# compared with DenoisingDataset, no permute_sentences or documents (rotate_ratio, permute_sentence_ratio)
|
18 |
+
class LanguagePairDenoisingDataset(LanguagePairDataset):
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
src,
|
22 |
+
src_sizes,
|
23 |
+
src_dict,
|
24 |
+
tgt,
|
25 |
+
tgt_sizes,
|
26 |
+
tgt_dict,
|
27 |
+
mask_idx,
|
28 |
+
mask_whole_words,
|
29 |
+
seed,
|
30 |
+
args,
|
31 |
+
left_pad_source=True,
|
32 |
+
left_pad_target=False,
|
33 |
+
shuffle=True,
|
34 |
+
input_feeding=True,
|
35 |
+
remove_eos_from_source=False,
|
36 |
+
append_eos_to_target=False,
|
37 |
+
align_dataset=None,
|
38 |
+
constraints=None,
|
39 |
+
append_bos=False,
|
40 |
+
eos=None,
|
41 |
+
num_buckets=0,
|
42 |
+
src_lang_id=None,
|
43 |
+
tgt_lang_id=None,
|
44 |
+
pad_to_multiple=1,
|
45 |
+
):
|
46 |
+
super().__init__(
|
47 |
+
src,
|
48 |
+
src_sizes,
|
49 |
+
src_dict,
|
50 |
+
tgt,
|
51 |
+
tgt_sizes,
|
52 |
+
tgt_dict,
|
53 |
+
left_pad_source,
|
54 |
+
left_pad_target,
|
55 |
+
shuffle,
|
56 |
+
input_feeding,
|
57 |
+
remove_eos_from_source,
|
58 |
+
append_eos_to_target,
|
59 |
+
align_dataset,
|
60 |
+
constraints,
|
61 |
+
append_bos,
|
62 |
+
eos,
|
63 |
+
num_buckets,
|
64 |
+
src_lang_id,
|
65 |
+
tgt_lang_id,
|
66 |
+
pad_to_multiple,
|
67 |
+
)
|
68 |
+
|
69 |
+
self.mask_idx = mask_idx
|
70 |
+
self.mask_whole_word = mask_whole_words
|
71 |
+
self.mask_ratio = args.mask
|
72 |
+
self.random_ratio = args.mask_random
|
73 |
+
self.insert_ratio = args.insert
|
74 |
+
|
75 |
+
self.replace_length = args.replace_length
|
76 |
+
|
77 |
+
if self.replace_length not in [-1, 0, 1]:
|
78 |
+
raise ValueError(f"invalid arg: replace_length={self.replace_length}")
|
79 |
+
if args.mask_length not in ["subword", "word", "span-poisson"]:
|
80 |
+
raise ValueError(f"invalid arg: mask-length={args.mask_length}")
|
81 |
+
if args.mask_length == "subword" and args.replace_length not in [0, 1]:
|
82 |
+
raise ValueError("if using subwords, use replace-length=1 or 0")
|
83 |
+
|
84 |
+
self.mask_span_distribution = None
|
85 |
+
if args.mask_length == "span-poisson":
|
86 |
+
# Text infilling: "A number of text spans are sampled, with span lengths drawn from a Poisson distribution (λ = 3). Each span is replaced with a single [MASK] token. 0-length spans correspond to the insertion of [MASK] tokens."
|
87 |
+
_lambda = args.poisson_lambda
|
88 |
+
|
89 |
+
lambda_to_the_k = 1
|
90 |
+
e_to_the_minus_lambda = math.exp(-_lambda)
|
91 |
+
k_factorial = 1
|
92 |
+
ps = []
|
93 |
+
for k in range(0, 128):
|
94 |
+
ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
|
95 |
+
lambda_to_the_k *= _lambda
|
96 |
+
k_factorial *= k + 1
|
97 |
+
if ps[-1] < 0.0000001:
|
98 |
+
break
|
99 |
+
ps = torch.FloatTensor(ps)
|
100 |
+
self.mask_span_distribution = torch.distributions.Categorical(ps)
|
101 |
+
|
102 |
+
self.epoch = 0
|
103 |
+
self.seed = seed
|
104 |
+
|
105 |
+
def _is_phoneme(x):
|
106 |
+
if re.search("<lang:", x) or x in (
|
107 |
+
"<mask>",
|
108 |
+
"<sil>",
|
109 |
+
"<pad>",
|
110 |
+
"<s>",
|
111 |
+
"</s>",
|
112 |
+
"<unk>",
|
113 |
+
):
|
114 |
+
return False
|
115 |
+
return True
|
116 |
+
|
117 |
+
self.voc_valid_ids = torch.LongTensor(
|
118 |
+
[i for i, x in enumerate(self.src_dict.symbols) if _is_phoneme(x)]
|
119 |
+
)
|
120 |
+
self.voc_valid_size = self.voc_valid_ids.size(0)
|
121 |
+
|
122 |
+
@property
|
123 |
+
def can_reuse_epoch_itr_across_epochs(self):
|
124 |
+
return False
|
125 |
+
|
126 |
+
def set_epoch(self, epoch, **unused):
|
127 |
+
self.epoch = epoch
|
128 |
+
|
129 |
+
def __getitem__(self, index):
|
130 |
+
tgt_item = self.tgt[index] if self.tgt is not None else None
|
131 |
+
src_item = copy.deepcopy(self.src[index])
|
132 |
+
with data_utils.numpy_seed(self.seed, self.epoch, index):
|
133 |
+
source = src_item
|
134 |
+
assert source[-1] == self.eos
|
135 |
+
if self.mask_ratio > 0:
|
136 |
+
source = self.add_whole_word_mask(source, self.mask_ratio)
|
137 |
+
|
138 |
+
if self.insert_ratio > 0:
|
139 |
+
source = self.add_insertion_noise(source, self.insert_ratio)
|
140 |
+
src_item = source
|
141 |
+
|
142 |
+
if self.append_eos_to_target:
|
143 |
+
eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
|
144 |
+
if self.tgt and self.tgt[index][-1] != eos:
|
145 |
+
tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])])
|
146 |
+
|
147 |
+
if self.append_bos:
|
148 |
+
bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
|
149 |
+
if self.tgt and self.tgt[index][0] != bos:
|
150 |
+
tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]])
|
151 |
+
|
152 |
+
bos = self.src_dict.bos()
|
153 |
+
if src_item[0] != bos:
|
154 |
+
src_item = torch.cat([torch.LongTensor([bos]), src_item])
|
155 |
+
|
156 |
+
if self.remove_eos_from_source:
|
157 |
+
eos = self.src_dict.eos()
|
158 |
+
if src_item[-1] == eos:
|
159 |
+
src_item = src_item[:-1]
|
160 |
+
|
161 |
+
example = {
|
162 |
+
"id": index,
|
163 |
+
"source": src_item,
|
164 |
+
"target": tgt_item,
|
165 |
+
}
|
166 |
+
if self.align_dataset is not None:
|
167 |
+
example["alignment"] = self.align_dataset[index]
|
168 |
+
if self.constraints is not None:
|
169 |
+
example["constraints"] = self.constraints[index]
|
170 |
+
if self.src_lang_id is not None:
|
171 |
+
example["src_lang_id"] = self.src_lang_id
|
172 |
+
if self.tgt_lang_id is not None:
|
173 |
+
example["tgt_lang_id"] = self.tgt_lang_id
|
174 |
+
return example
|
175 |
+
|
176 |
+
# following functions are borrowed from denoising_dataset
|
177 |
+
def word_starts(self, source):
|
178 |
+
if self.mask_whole_word is not None:
|
179 |
+
is_word_start = self.mask_whole_word.gather(0, source)
|
180 |
+
else:
|
181 |
+
is_word_start = torch.ones(source.size())
|
182 |
+
is_word_start[0] = 0
|
183 |
+
is_word_start[-1] = 0
|
184 |
+
return is_word_start
|
185 |
+
|
186 |
+
def add_whole_word_mask(self, source, p):
|
187 |
+
is_word_start = self.word_starts(source)
|
188 |
+
num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
|
189 |
+
num_inserts = 0
|
190 |
+
if num_to_mask == 0:
|
191 |
+
return source
|
192 |
+
|
193 |
+
if self.mask_span_distribution is not None:
|
194 |
+
lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))
|
195 |
+
|
196 |
+
# Make sure we have enough to mask
|
197 |
+
cum_length = torch.cumsum(lengths, 0)
|
198 |
+
while cum_length[-1] < num_to_mask:
|
199 |
+
lengths = torch.cat(
|
200 |
+
[
|
201 |
+
lengths,
|
202 |
+
self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
|
203 |
+
],
|
204 |
+
dim=0,
|
205 |
+
)
|
206 |
+
cum_length = torch.cumsum(lengths, 0)
|
207 |
+
|
208 |
+
# Trim to masking budget
|
209 |
+
i = 0
|
210 |
+
while cum_length[i] < num_to_mask:
|
211 |
+
i += 1
|
212 |
+
lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
|
213 |
+
num_to_mask = i + 1
|
214 |
+
lengths = lengths[:num_to_mask]
|
215 |
+
|
216 |
+
# Handle 0-length mask (inserts) separately
|
217 |
+
lengths = lengths[lengths > 0]
|
218 |
+
num_inserts = num_to_mask - lengths.size(0)
|
219 |
+
num_to_mask -= num_inserts
|
220 |
+
if num_to_mask == 0:
|
221 |
+
return self.add_insertion_noise(source, num_inserts / source.size(0))
|
222 |
+
|
223 |
+
assert (lengths > 0).all()
|
224 |
+
else:
|
225 |
+
lengths = torch.ones((num_to_mask,)).long()
|
226 |
+
assert is_word_start[-1] == 0
|
227 |
+
word_starts = is_word_start.nonzero(as_tuple=False)
|
228 |
+
indices = word_starts[
|
229 |
+
torch.randperm(word_starts.size(0))[:num_to_mask]
|
230 |
+
].squeeze(1)
|
231 |
+
mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio
|
232 |
+
|
233 |
+
source_length = source.size(0)
|
234 |
+
assert source_length - 1 not in indices
|
235 |
+
to_keep = torch.ones(source_length, dtype=torch.bool)
|
236 |
+
is_word_start[
|
237 |
+
-1
|
238 |
+
] = 255 # acts as a long length, so spans don't go over the end of doc
|
239 |
+
if self.replace_length == 0:
|
240 |
+
to_keep[indices] = 0
|
241 |
+
else:
|
242 |
+
# keep index, but replace it with [MASK]
|
243 |
+
source[indices] = self.mask_idx
|
244 |
+
source[indices[mask_random]] = self.voc_valid_ids[
|
245 |
+
torch.randint(0, self.voc_valid_size - 1, size=(mask_random.sum(),))
|
246 |
+
]
|
247 |
+
|
248 |
+
if self.mask_span_distribution is not None:
|
249 |
+
assert len(lengths.size()) == 1
|
250 |
+
assert lengths.size() == indices.size()
|
251 |
+
lengths -= 1
|
252 |
+
while indices.size(0) > 0:
|
253 |
+
assert lengths.size() == indices.size()
|
254 |
+
lengths -= is_word_start[indices + 1].long()
|
255 |
+
uncompleted = lengths >= 0
|
256 |
+
indices = indices[uncompleted] + 1
|
257 |
+
mask_random = mask_random[uncompleted]
|
258 |
+
lengths = lengths[uncompleted]
|
259 |
+
if self.replace_length != -1:
|
260 |
+
# delete token
|
261 |
+
to_keep[indices] = 0
|
262 |
+
else:
|
263 |
+
# keep index, but replace it with [MASK]
|
264 |
+
source[indices] = self.mask_idx
|
265 |
+
source[indices[mask_random]] = self.voc_valid_ids[
|
266 |
+
torch.randint(
|
267 |
+
0, self.voc_valid_size - 1, size=(mask_random.sum(),)
|
268 |
+
)
|
269 |
+
]
|
270 |
+
else:
|
271 |
+
# A bit faster when all lengths are 1
|
272 |
+
while indices.size(0) > 0:
|
273 |
+
uncompleted = is_word_start[indices + 1] == 0
|
274 |
+
indices = indices[uncompleted] + 1
|
275 |
+
mask_random = mask_random[uncompleted]
|
276 |
+
if self.replace_length != -1:
|
277 |
+
# delete token
|
278 |
+
to_keep[indices] = 0
|
279 |
+
else:
|
280 |
+
# keep index, but replace it with [MASK]
|
281 |
+
source[indices] = self.mask_idx
|
282 |
+
source[indices[mask_random]] = self.voc_valid_ids[
|
283 |
+
torch.randint(
|
284 |
+
0, self.voc_valid_size - 1, size=(mask_random.sum(),)
|
285 |
+
)
|
286 |
+
]
|
287 |
+
|
288 |
+
assert source_length - 1 not in indices
|
289 |
+
|
290 |
+
source = source[to_keep]
|
291 |
+
|
292 |
+
if num_inserts > 0:
|
293 |
+
source = self.add_insertion_noise(source, num_inserts / source.size(0))
|
294 |
+
|
295 |
+
return source
|
296 |
+
|
297 |
+
def add_insertion_noise(self, tokens, p):
|
298 |
+
if p == 0.0:
|
299 |
+
return tokens
|
300 |
+
|
301 |
+
num_tokens = len(tokens)
|
302 |
+
n = int(math.ceil(num_tokens * p))
|
303 |
+
|
304 |
+
noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
|
305 |
+
noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
|
306 |
+
noise_mask[noise_indices] = 1
|
307 |
+
result = torch.LongTensor(n + len(tokens)).fill_(-1)
|
308 |
+
|
309 |
+
num_random = int(math.ceil(n * self.random_ratio))
|
310 |
+
result[noise_indices[num_random:]] = self.mask_idx
|
311 |
+
result[noise_indices[:num_random]] = self.voc_valid_ids[
|
312 |
+
torch.randint(0, self.voc_valid_size - 1, size=(num_random,))
|
313 |
+
]
|
314 |
+
|
315 |
+
result[~noise_mask] = tokens
|
316 |
+
|
317 |
+
assert (result >= 0).all()
|
318 |
+
return result
|
fairseq/examples/speech_text_joint_to_text/docs/ende-mustc.md
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[[Back]](..)
|
2 |
+
|
3 |
+
# Joint Speech Text Training for the MuST-C English to German Speech Translation task
|
4 |
+
|
5 |
+
Joint Training Baseline: it is based on paper ["A general multi-task learning framework to leverage text data for speech to text tasks"](https://arxiv.org/pdf/2010.11338.pdf)
|
6 |
+
|
7 |
+
Enhanced Joint Training: the joint training is enhanced with pre-trained models, cross attentive regularization and online knowledge distillation based on paper ["Improving Speech Translation by Understanding and Learning from the Auxiliary Text Translation Task"](https://research.fb.com/publications/improving-speech-translation-by-understanding-and-learning-from-the-auxiliary-text-translation-task)
|
8 |
+
|
9 |
+
## Prepare Data
|
10 |
+
#### Download files
|
11 |
+
- Sentence piece model [spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/spm.model)
|
12 |
+
- Dictionary [dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/dict.txt)
|
13 |
+
- config [config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/config.yaml)
|
14 |
+
#### Prepare MuST-C data set
|
15 |
+
- Please follow the data preparation in the [S2T example](https://github.com/pytorch/fairseq/blob/main/examples/speech_to_text/docs/mustc_example.md)
|
16 |
+
- Convert source text under the "src_text" column in the tsv file into phoneme representation.
|
17 |
+
```bash
|
18 |
+
python examples/speech_text_joint_to_text/scripts/g2p_encode.py \
|
19 |
+
--lower-case --do-filter --use-word-start --no-punc \
|
20 |
+
--reserve-word examples/speech_text_joint_to_text/configs/mustc_noise.list \
|
21 |
+
--data-path ${must_c_en_de_src_text} \
|
22 |
+
--out-path ${must_c_en_de_src_text_pho}
|
23 |
+
```
|
24 |
+
- Replace the source text under the "src_text" column in the tsv file with the corresponding phoneme reprentation generated in the step above.
|
25 |
+
Below is the snapshot for the MuST-C en-de dev tsv
|
26 |
+
```
|
27 |
+
id audio n_frames tgt_text src_text speaker
|
28 |
+
ted_767_0 en-de/flac.zip:10071514743:48445 56160 Heute spreche ich zu Ihnen über Energie und Klima. ▁AY1 M ▁G OW1 IH0 NG ▁T UW1 ▁T AO1 K ▁T AH0 D EY1 ▁AH0 B AW1 T ▁EH1 N ER0 JH IY0 ▁AH0 N D ▁K L AY1 M AH0 T spk.767_
|
29 |
+
ted_767_1 en-de/flac.zip:1214217978:205678 226080 Und das überrascht vielleicht etwas, weil sich meine Vollzeitbeschäftigung bei der Stiftung hauptsächlich um Impfstoffe und Saatgut dreht, um die Dinge, die wir erfinden und liefern müssen um den ärmsten 2 Milliarden ein besseres Leben zu ermöglichen. ▁AH0 N D ▁DH AE1 T ▁M AY1 T ▁S IY1 M ▁AH0 ▁B IH1 T ▁S ER0 P R AY1 Z IH0 NG ▁B IH0 K AO1 Z ▁M AY1 ▁F UH1 L ▁T AY1 M ▁W ER1 K ▁AE1 T ▁DH AH0 ▁F AW0 N D EY1 SH AH0 N ▁IH1 Z ▁M OW1 S T L IY0 ▁AH0 B AW1 T ▁V AE2 K S IY1 N Z ▁AH0 N D ▁S IY1 D Z ▁AH0 B AW1 T ▁DH AH0 ▁TH IH1 NG Z ▁DH AE1 T ▁W IY1 ▁N IY1 D ▁T UW1 ▁IH0 N V EH1 N T ▁AH0 N D ▁D IH0 L IH1 V ER0 ▁T UW1 ▁HH EH1 L P ▁DH AH0 ▁P UH1 R IH0 S T ▁T UW1 ▁B IH1 L Y AH0 N ▁L AY1 V ▁B EH1 T ER0 ▁L IH1 V Z spk.767_
|
30 |
+
```
|
31 |
+
- Prepare phoneme dictionary and save to $MANIFEST_ROOT as [src_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/src_dict.txt)
|
32 |
+
#### Prepare WMT text data
|
33 |
+
- [Download wmt data](https://github.com/pytorch/fairseq/blob/main/examples/translation/prepare-wmt14en2de.sh)
|
34 |
+
- Convert source text (English) into phoneme representation as above
|
35 |
+
- Generate binary parallel files with "fairseq-preprocess" from fairseq for training and validation. The source input is English phoneme representation and the target input is German sentencepiece token . The output is saved under $parallel_text_data
|
36 |
+
|
37 |
+
## Training
|
38 |
+
The model is trained with 8 v100 GPUs.
|
39 |
+
|
40 |
+
#### Download pretrained models
|
41 |
+
- [pretrain_encoder](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_joint_asr_transformer_m.pt)
|
42 |
+
- [pretrain_nmt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/checkpoint_mt.pt)
|
43 |
+
|
44 |
+
#### Training scripts
|
45 |
+
- Jointly trained model from scratch
|
46 |
+
```bash
|
47 |
+
python train.py ${MANIFEST_ROOT} \
|
48 |
+
--save-dir ${save_dir} \
|
49 |
+
--num-workers 8 \
|
50 |
+
--task speech_text_joint_to_text \
|
51 |
+
--arch dualinputs2ttransformer_s \
|
52 |
+
--user-dir examples/speech_text_joint_to_text \
|
53 |
+
--max-epoch 100 --update-mix-data \
|
54 |
+
--optimizer adam --lr-scheduler inverse_sqrt \
|
55 |
+
--lr 0.001 --update-freq 4 --clip-norm 10.0 \
|
56 |
+
--criterion guided_label_smoothed_cross_entropy_with_accuracy \
|
57 |
+
--label-smoothing 0.1 --max-tokens 10000 --max-tokens-text 10000 \
|
58 |
+
--max-positions-text 400 --seed 2 --speech-encoder-layers 12 \
|
59 |
+
--text-encoder-layers 6 --encoder-shared-layers 6 --decoder-layers 6 \
|
60 |
+
--dropout 0.1 --warmup-updates 20000 \
|
61 |
+
--text-sample-ratio 0.25 --parallel-text-data ${parallel_text_data} \
|
62 |
+
--text-input-cost-ratio 0.5 --enc-grad-mult 2.0 --add-speech-eos \
|
63 |
+
--log-format json --langpairs en-de --noise-token '"'"'▁NOISE'"'"' \
|
64 |
+
--mask-text-ratio 0.0 --max-tokens-valid 20000 --ddp-backend no_c10d \
|
65 |
+
--log-interval 100 --data-buffer-size 50 --config-yaml config.yaml \
|
66 |
+
--keep-last-epochs 10
|
67 |
+
```
|
68 |
+
- Jointly trained model with good initialization, cross attentive loss and online knowledge distillation
|
69 |
+
```bash
|
70 |
+
python train.py ${MANIFEST_ROOT} \
|
71 |
+
--save-dir ${save_dir} \
|
72 |
+
--num-workers 8 \
|
73 |
+
--task speech_text_joint_to_text \
|
74 |
+
--arch dualinputs2ttransformer_m \
|
75 |
+
--user-dir examples/speech_text_joint_to_text \
|
76 |
+
--max-epoch 100 --update-mix-data \
|
77 |
+
--optimizer adam --lr-scheduler inverse_sqrt \
|
78 |
+
--lr 0.002 --update-freq 4 --clip-norm 10.0 \
|
79 |
+
--criterion guided_label_smoothed_cross_entropy_with_accuracy \
|
80 |
+
--guide-alpha 0.8 --disable-text-guide-update-num 5000 \
|
81 |
+
--label-smoothing 0.1 --max-tokens 10000 --max-tokens-text 10000 \
|
82 |
+
--max-positions-text 400 --seed 2 --speech-encoder-layers 12 \
|
83 |
+
--text-encoder-layers 6 --encoder-shared-layers 6 --decoder-layers 6 \
|
84 |
+
--dropout 0.1 --warmup-updates 20000 --attentive-cost-regularization 0.02 \
|
85 |
+
--text-sample-ratio 0.25 --parallel-text-data ${parallel_text_data} \
|
86 |
+
--text-input-cost-ratio 0.5 --enc-grad-mult 2.0 --add-speech-eos \
|
87 |
+
--log-format json --langpairs en-de --noise-token '"'"'▁NOISE'"'"' \
|
88 |
+
--mask-text-ratio 0.0 --max-tokens-valid 20000 --ddp-backend no_c10d \
|
89 |
+
--log-interval 100 --data-buffer-size 50 --config-yaml config.yaml \
|
90 |
+
--load-pretrain-speech-encoder ${pretrain_encoder} \
|
91 |
+
--load-pretrain-decoder ${pretrain_nmt} \
|
92 |
+
--load-pretrain-text-encoder-last ${pretrain_nmt} \
|
93 |
+
--keep-last-epochs 10
|
94 |
+
```
|
95 |
+
|
96 |
+
## Evaluation
|
97 |
+
```bash
|
98 |
+
python ./fairseq_cli/generate.py \
|
99 |
+
${MANIFEST_ROOT} \
|
100 |
+
--task speech_text_joint_to_text \
|
101 |
+
--max-tokens 25000 \
|
102 |
+
--nbest 1 \
|
103 |
+
--results-path ${infer_results} \
|
104 |
+
--batch-size 512 \
|
105 |
+
--path ${model} \
|
106 |
+
--gen-subset tst-COMMON_st \
|
107 |
+
--config-yaml config.yaml \
|
108 |
+
--scoring sacrebleu \
|
109 |
+
--beam 5 --lenpen 1.0 \
|
110 |
+
--user-dir examples/speech_text_joint_to_text \
|
111 |
+
--load-speech-only
|
112 |
+
```
|
113 |
+
|
114 |
+
## Results (Joint training with initialization + CAR + online KD)
|
115 |
+
|Direction|En-De | En-Es | En-Fr |
|
116 |
+
|---|---|---|---|
|
117 |
+
|BLEU|27.4| 31.2 | 37.6 |
|
118 |
+
|checkpoint | [link](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/checkpoint_ave_10.pt) |[link](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_es/checkpoint_ave_10.pt)|[link](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_fr/checkpoint_ave_10.pt)|
|
fairseq/examples/speech_text_joint_to_text/docs/iwslt2021.md
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[[Back]](..)
|
2 |
+
|
3 |
+
# Joint Speech Text Training for the 2021 IWSLT multilingual speech translation
|
4 |
+
|
5 |
+
This directory contains the code from paper ["FST: the FAIR Speech Translation System for the IWSLT21 Multilingual Shared Task"](https://arxiv.org/pdf/2107.06959.pdf).
|
6 |
+
|
7 |
+
## Prepare Data
|
8 |
+
#### Download files
|
9 |
+
- Sentence piece model [spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/spm.model)
|
10 |
+
- Dictionary [tgt_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/dict.txt)
|
11 |
+
- Config [config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/config.yaml)
|
12 |
+
|
13 |
+
#### Prepare
|
14 |
+
- Please follow the data preparation in [speech-to-text](https://github.com/pytorch/fairseq/blob/main/examples/speech_to_text/docs/mtedx_example.md) with option "--use-audio-input" for raw audio tsv files.
|
15 |
+
- Prepare tsv files with phoneme based source text (under column 'src_text') as [MuST-C](ende-mustc.md) example.
|
16 |
+
|
17 |
+
|
18 |
+
## Training
|
19 |
+
|
20 |
+
#### Download pretrained models
|
21 |
+
- [Pretrained mbart model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/mbart.pt)
|
22 |
+
- [Pretrained w2v model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/xlsr_53_56k.pt)
|
23 |
+
|
24 |
+
|
25 |
+
#### Training scripts
|
26 |
+
|
27 |
+
```bash
|
28 |
+
python train.py ${MANIFEST_ROOT} \
|
29 |
+
--save-dir ${save_dir} \
|
30 |
+
--user-dir examples/speech_text_joint_to_text \
|
31 |
+
--train-subset train_es_en_tedx,train_es_es_tedx,train_fr_en_tedx,train_fr_es_tedx,train_fr_fr_tedx,train_it_it_tedx,train_pt_en_tedx,train_pt_pt_tedx \
|
32 |
+
--valid-subset valid_es_en_tedx,valid_es_es_tedx,valid_es_fr_tedx,valid_es_it_tedx,valid_es_pt_tedx,valid_fr_en_tedx,valid_fr_es_tedx,valid_fr_fr_tedx,valid_fr_pt_tedx,valid_it_en_tedx,valid_it_es_tedx,valid_it_it_tedx,valid_pt_en_tedx,valid_pt_es_tedx,valid_pt_pt_tedx \
|
33 |
+
--config-yaml config.yaml --ddp-backend no_c10d \
|
34 |
+
--num-workers 2 --task speech_text_joint_to_text \
|
35 |
+
--criterion guided_label_smoothed_cross_entropy_with_accuracy \
|
36 |
+
--label-smoothing 0.3 --guide-alpha 0.8 \
|
37 |
+
--disable-text-guide-update-num 5000 --arch dualinputxmtransformer_base \
|
38 |
+
--max-tokens 500000 --max-sentences 3 --max-tokens-valid 800000 \
|
39 |
+
--max-source-positions 800000 --enc-grad-mult 2.0 \
|
40 |
+
--attentive-cost-regularization 0.02 --optimizer adam \
|
41 |
+
--clip-norm 1.0 --log-format simple --log-interval 200 \
|
42 |
+
--keep-last-epochs 5 --seed 1 \
|
43 |
+
--w2v-path ${w2v_path} \
|
44 |
+
--load-pretrained-mbart-from ${mbart_path} \
|
45 |
+
--max-update 1000000 --update-freq 4 \
|
46 |
+
--skip-invalid-size-inputs-valid-test \
|
47 |
+
--skip-encoder-projection --save-interval 1 \
|
48 |
+
--attention-dropout 0.3 --mbart-dropout 0.3 \
|
49 |
+
--finetune-w2v-params all --finetune-mbart-decoder-params all \
|
50 |
+
--finetune-mbart-encoder-params all --stack-w2v-mbart-encoder \
|
51 |
+
--drop-w2v-layers 12 --normalize \
|
52 |
+
--lr 5e-05 --lr-scheduler inverse_sqrt --warmup-updates 5000
|
53 |
+
```
|
54 |
+
|
55 |
+
## Evaluation
|
56 |
+
```bash
|
57 |
+
python ./fairseq_cli/generate.py
|
58 |
+
${MANIFEST_ROOT} \
|
59 |
+
--task speech_text_joint_to_text \
|
60 |
+
--user-dir ./examples/speech_text_joint_to_text \
|
61 |
+
--load-speech-only --gen-subset test_es_en_tedx \
|
62 |
+
--path ${model} \
|
63 |
+
--max-source-positions 800000 \
|
64 |
+
--skip-invalid-size-inputs-valid-test \
|
65 |
+
--config-yaml config.yaml \
|
66 |
+
--infer-target-lang en \
|
67 |
+
--max-tokens 800000 \
|
68 |
+
--beam 5 \
|
69 |
+
--results-path ${RESULTS_DIR} \
|
70 |
+
--scoring sacrebleu
|
71 |
+
```
|
72 |
+
The trained model can be downloaded [here](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/checkpoint17.pt)
|
73 |
+
|
74 |
+
|direction|es_en|fr_en|pt_en|it_en|fr_es|pt_es|it_es|es_es|fr_fr|pt_pt|it_it|
|
75 |
+
|---|---|---|---|---|---|---|---|---|---|---|---|
|
76 |
+
|BLEU|31.62|36.93|35.07|27.12|38.87|35.57|34.13|74.59|74.64|70.84|69.76|
|
fairseq/examples/speech_text_joint_to_text/docs/pre-training.md
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[[Back]](..)
|
2 |
+
|
3 |
+
# Unified Speech-Text Pre-training for Speech Translation and Recognition
|
4 |
+
|
5 |
+
This directory contains the pre-training recipes from paper ["Unified Speech-Text Pre-training for Speech Translation and Recognition"](https://arxiv.org/abs/2204.05409).
|
6 |
+
|
7 |
+
## Librispeech ASR Pre-training
|
8 |
+
### Prepare Data
|
9 |
+
#### Download files
|
10 |
+
#### Prepare pre-training data
|
11 |
+
- Text to text task (T2T): prepare the binary data following the similar steps in [EN_DE Joint training](./ende-mustc.md). The source data is presented as phomeme token sequence and the target data is coded as subword tokens via SentencePiece. The text data is downloaded from [openslr](https://www.openslr.org/12)
|
12 |
+
- Self-supervised speech learning task (SSL): The data is prepared as [wav2vec 2.0](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec/README.md)
|
13 |
+
- Speech to phoneme classification task (S2P): The tsv file contains 5 fields: "id", "audio", "n_frames", "tgt_text", and "align". The tgt_text field is corresponding to the phoneme based representation of the speech data. "align" field contains the alignment information. The phoneme level forced alignment for the labelled speech data (i.e. Librispeech) can be obtained via [kaldi](http://kaldi-asr.org) or [MFA](https://montrealcorpustools.github.io/Montreal-Forced-Aligner/). The segmentation information is normalized to 0$\sim$1 for the whole utterance. The snapshot of the tsv file is below:
|
14 |
+
```
|
15 |
+
id audio n_frames tgt_text align
|
16 |
+
116-288045-0000 /librispeech/dev-other/116/288045/116-288045-0000.flac 170400 <sil> ▁AE1 Z AY1 ▁AH0 P R OW1 CH T ▁DH AH1 ▁S IH1 T IY0 <sil> AY1 ▁HH ER1 D ▁B EH1 L Z ▁R IH1 NG IH0 NG <sil> ▁AE1 N D AH0 ▁L IH1 T AH0 L ▁L EY1 T ER0 AY1 ▁F AW1 N D ▁DH AH0 ▁S T R IY1 T S ▁AH0 S T IH1 R ▁W IH0 TH ▁TH R AO1 NG Z ▁AH0 V ▁W EH1 L ▁D R EH1 S T ▁P IY1 P AH0 L ▁IH1 N ▁F AE1 M L IY0 ▁G R UW1 P S <sil> ▁W EH1 N D IH0 NG ▁DH EH1 R ▁W EY1 <sil> ▁HH IH1 DH ER0 ▁AH0 N D ▁TH IH1 DH ER0 <sil> 0.047977 0.056444 0.064911 0.075259 0.081844 0.089370 0.095014 0.104421 0.109125 0.111947 0.115710 0.120414 0.134525 0.141110 0.143932 0.174036 0.176858 0.190028 0.199436 0.207902 0.218250 0.224835 0.231421 0.242709 0.251176 0.257761 0.263405 0.268109 0.270931 0.290687 0.342427 0.349953 0.353716 0.356538 0.360301 0.363123 0.365945 0.368768 0.371590 0.376294 0.384760 0.394167 0.401693 0.409219 0.419567 0.430856 0.441204 0.444026 0.446849 0.449671 0.456256 0.463782 0.471308 0.477893 0.486359 0.491063 0.494826 0.501411 0.512700 0.517404 0.520226 0.534337 0.540922 0.545626 0.550329 0.559737 0.568203 0.583255 0.592662 0.600188 0.603951 0.611477 0.619003 0.624647 0.634055 0.639699 0.646284 0.653810 0.659454 0.664158 0.670743 0.682032 0.687676 0.692380 0.708373 0.713076 0.719661 0.729069 0.740357 0.744120 0.748824 0.752587 0.761994 0.770461 0.781750 0.790216 0.805268 0.808090 0.823142 0.832549 0.836312 0.840075 0.843838 0.851364 0.854186 0.857008 0.862653 0.878645 0.898401 0.901223 0.906867 0.913452 0.920038 0.926623 0.934149 0.939793 0.942615 0.945437 0.952023 0.957667 0.977422 1.000000
|
17 |
+
|
18 |
+
```
|
19 |
+
- Speech to text task (S2T): The data preparation follow the steps in [EN_DE Joint training](./ende-mustc.md).
|
20 |
+
|
21 |
+
#### Prepare fine-tuning data:
|
22 |
+
We re-use the data from T2T and S2T tasks in the fine-tuning stage.
|
23 |
+
|
24 |
+
### Model Build
|
25 |
+
#### Pre-training
|
26 |
+
```
|
27 |
+
python train.py $T2T_DATA \
|
28 |
+
--save-dir $SAVE_PRE_PATH --user-dir examples/speech_text_joint_to_text --task speech_text_joint_denoising \
|
29 |
+
--criterion speech_text_pretrain_cross_entropy --optimizer adam --weight-decay 0.01 --config-yaml config_s2p.yaml --config-s2s-yaml config.yaml --ddp-backend no_c10d \
|
30 |
+
--lang-pairs pho-wrd --num-workers 4 --log-interval 500 --save-interval-updates 5000 --keep-interval-updates 1 --no-emb-update-unsup --report-accuracy --lr 0.001 --end-learning-rate 1e-06 \
|
31 |
+
--lr-scheduler polynomial_decay --warmup-updates 10000 --total-num-update 800000 --update-freq 6 --validate-interval-updates 10000 --train-subset train \
|
32 |
+
--valid-subset valid,valid_sup_speech,valid_sup_speech_s2s,valid_unsup_speech --dataset-impl mmap \
|
33 |
+
--sup-speech-data $S2P_DATA_PATH --sup-speech-train-subset train_960.ali --sup-speech-valid-subset dev-clean.ali --sup-speech-s2s-data $S2T_DATA_PATH \
|
34 |
+
--sup-speech-s2s-train-subset train --sup-speech-s2s-valid-subset dev-clean --unsup-speech-train-data $SSL_DATA_PATH/train.tsv --unsup-speech-valid-data $SSL_DATA_PATH/valid.tsv \
|
35 |
+
--batch-size 200 --batch-size-valid 150 --max-source-positions 1024 --max-target-positions 1024 --max-text-tokens 3072 --max-speech-positions 600000 \
|
36 |
+
--max-sample-size 750000 --min-sample-size 64000 --max-speech-tokens 750000 --max-tokens-valid 750000 --skip-invalid-size-inputs-valid-test \
|
37 |
+
--unsupervised-speech-sample-ratio 3.0 --supervised-speech-sample-ratio 5 --supervised-speech-s2s-sample-ratio 5 --text-sample-ratio 1.0 --mask 0.3 --mask-random 0.1 \
|
38 |
+
--mask-length span-poisson --speech-sup-mask-prob 0.3 --speech-unsup-mask-prob 0.7 --use-mask-whole-words --arch speech_text_pretrain_bart_base_stack \
|
39 |
+
--no-scale-feature --activation-fn gelu --speech-extractor-mode default --stacked-encoder all --encoder-normalize-before --decoder-normalize-before \
|
40 |
+
--encoder-learned-pos --decoder-learned-pos --dropout 0.1 --load-pretrained-mbart-encoder-from $BART --load-pretrained-mbart-decoder-from $BART
|
41 |
+
```
|
42 |
+
The current implementation also supports model pre-training without the forced alignment supervised data. In this case, CTC is used to optimize the S2P task. We need to do following changes for the setting:
|
43 |
+
1. options to be added
|
44 |
+
```
|
45 |
+
--use-sup-speech-ctc --criterion speech_text_pretrain_compound
|
46 |
+
```
|
47 |
+
2. options to be deleted
|
48 |
+
```
|
49 |
+
--same-data-update --criterion speech_text_pretrain_cross_entropy
|
50 |
+
```
|
51 |
+
However, we find the CTC based pre-training is still worse than the forced alignment based setting. It could be partially due to the inferior pre-training setting that we re-use the forced alignment based pre-training setting for the CTC based pre-training.
|
52 |
+
|
53 |
+
#### Fine-tuning
|
54 |
+
```
|
55 |
+
python train.py $S2T_DATA_PATH \
|
56 |
+
--save-dir $SAVE_FT_PATH --num-workers 8 --task speech_text_joint_to_text --arch dualinputs2twavtransformer_base_stack \
|
57 |
+
--user-dir examples/speech_text_joint_to_text --max-update 100000 --optimizer adam --lr-scheduler inverse_sqrt --lr 0.0003 --update-freq 3 --clip-norm 10.0 \
|
58 |
+
--criterion guided_label_smoothed_cross_entropy_with_accuracy --guide-alpha 0.8 --label-smoothing 0.1 --warmup-updates 20000 --attentive-cost-regularization 0.02 \
|
59 |
+
--enc-grad-mult 2.0 --max-tokens 800000 --max-source-positions 800000 --max-tokens-text 10000 --max-positions-text 1024 --max-target-positions 1024 --no-scale-feature \
|
60 |
+
--activation-fn gelu --load-pretrained-speech-text-encoder $SAVE_PRE_PATH/checkpoint_last.pt --load-pretrained-speech-text-decoder $SAVE_PRE_PATH/checkpoint_last.pt \
|
61 |
+
--encoder-normalize-before --decoder-normalize-before --speech-extractor-mode default --speech-mask-channel-length 64 --speech-mask-channel-prob 0.5 \
|
62 |
+
--speech-mask-length 10 --speech-mask-prob 0.65 --text-sample-ratio 0.25 --mask-text-ratio 0.3 --mask-text-type random --parallel-text-data text_bin \
|
63 |
+
--text-input-cost-ratio 0.5 --langpairs pho-wrd --update-mix-data --log-format json --max-tokens-valid 800000 --ddp-backend no_c10d --log-interval 500 \
|
64 |
+
--config-yaml config.yaml --skip-invalid-size-inputs-valid-test --keep-last-epochs 50 --layernorm-embedding --encoder-learned-pos --decoder-learned-pos
|
65 |
+
```
|
66 |
+
|
67 |
+
### Evaluation
|
68 |
+
The last 10 epoch models from fine-tuning is conducted model average to get $FINAL_MODEL
|
69 |
+
```
|
70 |
+
python ./fairseq_cli/generate.py \
|
71 |
+
$S2T_DATA_PATH \
|
72 |
+
--task speech_text_joint_to_text \
|
73 |
+
--max-tokens 800000 \
|
74 |
+
--max-source-positions 800000 \
|
75 |
+
--nbest 1 \
|
76 |
+
--results-path $RESULTS_LOG \
|
77 |
+
--batch-size 512 \
|
78 |
+
--path $FINAL_MODEL \
|
79 |
+
--gen-subset $SUBSET \
|
80 |
+
--config-yaml config.yaml \
|
81 |
+
--scoring wer \
|
82 |
+
--beam 10 --lenpen 1.0 examples/speech_text_joint_to_text \
|
83 |
+
--user-dir examples/speech_text_joint_to_text --load-speech-only \
|
84 |
+
--model-overrides {'load_pretrained_speech_text_decoder':'','load_pretrained_speech_text_encoder':''}
|
85 |
+
```
|
86 |
+
|
87 |
+
### Results and models
|
88 |
+
| | dev-clean | dev-other | test-clean | test-other |
|
89 |
+
|---|---|---|---|---|
|
90 |
+
| WER| 2.0 | 4.4 | 2.1 |4.6 |
|
91 |
+
|
92 |
+
**Model Links**:
|
93 |
+
- [config_s2p.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/pretrain/config_s2p.yaml): Config for S2P
|
94 |
+
- [spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned/spm.model): Sentence Piece model
|
95 |
+
- [src_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned/src_dict.txt): Source Phoneme Dictionary
|
96 |
+
- [tgt_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned/tgt_dict.txt): Target Sentence Piece Dictionary
|
97 |
+
- [config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned/config.yaml): Config for S2T
|
98 |
+
- [BART](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/pretrain/bart.pt): trained from Librispeech text data
|
99 |
+
- [Joint Pre-trained model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/pretrain/checkpoint6.pt): model pre-trained with 960 hours Librispeech data (S2P, S2T) Librispeech text training data (T2T) and Librilight data (SSL)
|
100 |
+
- [Fine-tuned model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned/checkpoint_ave_10.pt): the pre-trained model is fined one 960 hours Librispeech speech and text data. (S2T + T2T)
|
101 |
+
|
102 |
+
## MuST-C
|
103 |
+
### Prepare Data
|
104 |
+
Compared with the ASR Librispeech ASR recipe, the differences are below:
|
105 |
+
- Replace the speech data with corresponding MuST-C data
|
106 |
+
- Parallel text data from WMT is replaced the Librispeech text data
|
107 |
+
|
108 |
+
### Model Build
|
109 |
+
#### Pre-training
|
110 |
+
EN-DE is used as an example
|
111 |
+
```
|
112 |
+
python train.py $TXT_DATA \
|
113 |
+
--save-dir $SAVE_PRE_PATH --user-dir examples/speech_text_joint_to_text --task speech_text_joint_denoising --criterion speech_text_pretrain_cross_entropy --optimizer adam --weight-decay 0.01 \
|
114 |
+
--config-yaml config_s2p.yaml --config-s2s-yaml config.yaml --ddp-backend no_c10d --lang-pairs-bitext en-fr --num-workers 4 --log-interval 500 --save-interval-updates 5000 --keep-interval-updates 1 \
|
115 |
+
--no-emb-update-unsup --use-decoder-output-proj --report-accuracy --lr 0.001 --end-learning-rate 1e-06 --lr-scheduler polynomial_decay --warmup-updates 10000 --total-num-update 800000 \
|
116 |
+
--update-freq 8 --validate-interval-updates 10000 --train-subset train --valid-subset valid_sup_speech,valid_sup_speech_s2s,valid_unsup_speech --dataset-impl mmap \
|
117 |
+
--sup-speech-data $S2P_DATA_PATH --sup-speech-train-subset train --sup-speech-valid-subset dev --sup-speech-s2s-data $S2T_DATA_PATH --sup-speech-s2s-train-subset train \
|
118 |
+
--sup-speech-s2s-valid-subset dev --unsup-speech-train-data $SSL_DATA_PATH/train.tsv --unsup-speech-valid-data $SSL_DATA_PATH/valid.tsv --batch-size 200 --batch-size-valid 100 \
|
119 |
+
--max-source-positions 1024 --max-target-positions 1024 --max-text-tokens 2048 --max-speech-positions 600000 --max-sample-size 600000 --min-sample-size 64000 \
|
120 |
+
--max-speech-tokens 600000 --max-tokens-valid 600000 --skip-invalid-size-inputs-valid-test --unsupervised-speech-sample-ratio 1.2 --supervised-speech-sample-ratio 10 \
|
121 |
+
--supervised-speech-s2s-sample-ratio 10 --bitext-sample-ratio 0.5 --mask 0.3 --mask-random 0.1 --mask-length span-poisson --speech-sup-mask-prob 0.3 \
|
122 |
+
--speech-unsup-mask-prob 0.7 --use-mask-whole-words --arch speech_text_pretrain_bart_base_stack --no-scale-feature --activation-fn gelu --speech-extractor-mode default \
|
123 |
+
--stacked-encoder s2s --encoder-normalize-before --decoder-normalize-before --encoder-learned-pos --decoder-learned-pos --dropout 0.1 \
|
124 |
+
--load-pretrained-mbart-encoder-from $EN_FR_NMT --load-pretrained-mbart-decoder-from $EN_FR_NMT
|
125 |
+
```
|
126 |
+
#### Fine-tuning
|
127 |
+
```
|
128 |
+
python train.py $S2T_DATA_PATH \
|
129 |
+
--save-dir $SAVE_FT_PATH --num-workers 8 --task speech_text_joint_to_text --arch dualinputs2twavtransformer_base_stack --user-dir examples/speech_text_joint_to_text \
|
130 |
+
--max-epoch 25 --update-mix-data --optimizer adam --lr-scheduler inverse_sqrt --lr 0.0003 --update-freq 4 --clip-norm 10.0 --warmup-updates 20000 \
|
131 |
+
--criterion guided_label_smoothed_cross_entropy_with_accuracy --guide-alpha 0.8 --attentive-cost-regularization 0.02 --enc-grad-mult 2.0 --label-smoothing 0.1 \
|
132 |
+
--max-tokens 800000 --max-source-positions 800000 --max-tokens-text 10000 --max-positions-text 1024 --load-pretrained-speech-text-encoder $SAVE_PRE_PATH/checkpoint_last.pt \
|
133 |
+
--load-pretrained-speech-text-decoder $SAVE_PRE_PATH/checkpoint_last.pt --speech-mask-channel-length 64 --speech-mask-channel-prob 0.5 --speech-mask-length 10 \
|
134 |
+
--speech-mask-prob 0.65 --text-sample-ratio 0.05 --mask-text-ratio 0.3 --mask-text-type random --parallel-text-data data-bin-wt --text-input-cost-ratio 0.5 \
|
135 |
+
--langpairs en-fr --log-format json --max-tokens-valid 800000 --ddp-backend no_c10d --log-interval 100 --config-yaml config.yaml --skip-invalid-size-inputs-valid-test \
|
136 |
+
--noise-token '▁NOISE' --keep-last-epochs 40 --layernorm-embedding --encoder-learned-pos --decoder-learned-pos --activation-fn gelu \
|
137 |
+
--speech-extractor-mode default --max-target-positions 1024 --encoder-normalize-before --decoder-normalize-before
|
138 |
+
```
|
139 |
+
|
140 |
+
### Evaluation
|
141 |
+
The last 10 epoch models from fine-tuning is conducted model average to get $FINAL_MODEL
|
142 |
+
```
|
143 |
+
python fairseq_cli/generate.py \
|
144 |
+
$S2T_DATA_PATH \
|
145 |
+
--task speech_text_joint_to_text \
|
146 |
+
--nbest 1 \
|
147 |
+
--max-tokens 800000 \
|
148 |
+
--max-source-positions 800000 \
|
149 |
+
--results-path $RESULTS_LOG \
|
150 |
+
--batch-size 512 \
|
151 |
+
--path $FINAL_MODEL \
|
152 |
+
--gen-subset $SUBSET \
|
153 |
+
--config-yaml config.yaml \
|
154 |
+
--scoring sacrebleu \
|
155 |
+
--beam 10 --lenpen 1.0 examples/speech_text_joint_to_text \
|
156 |
+
--user-dir examples/speech_text_joint_to_text --load-speech-only \
|
157 |
+
--model-overrides {'load_pretrained_speech_text_decoder':'','load_pretrained_speech_text_encoder':''}
|
158 |
+
```
|
159 |
+
|
160 |
+
|
161 |
+
### Results and models
|
162 |
+
| | en-fr | en-es | en-de |
|
163 |
+
|---|---|---|---|
|
164 |
+
| BLEU| 39.7 | 33.2 |29.2 |
|
165 |
+
|
166 |
+
|
167 |
+
**Model Links**:
|
168 |
+
1. DE
|
169 |
+
- [de config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/config.yaml)
|
170 |
+
- [de src_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/src_dict.txt)
|
171 |
+
- [de tgt_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/tgt_dict.txt)
|
172 |
+
- [de spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/spm.model)
|
173 |
+
- [de pre-trained nmt model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/nmt.pt)
|
174 |
+
- [de pre-trained model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/checkpoint_pretraing.pt)
|
175 |
+
- [de fine-tuned model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/checkpoint_finetune_ave10.pt)
|
176 |
+
2. ES
|
177 |
+
- [es config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/config.yaml)
|
178 |
+
- [es src_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/src_dict.txt)
|
179 |
+
- [es tgt_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/tgt_dict.txt)
|
180 |
+
- [es spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/spm.model)
|
181 |
+
- [es pre-trained nmt model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/nmt.pt)
|
182 |
+
- [es pre-trained model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/checkpoint_pretraing.pt)
|
183 |
+
- [es fine-tuned model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/checkpoint_finetune_ave10.pt)
|
184 |
+
3. FR
|
185 |
+
- [fr config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/config.yaml)
|
186 |
+
- [fr src_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/src_dict.txt)
|
187 |
+
- [fr tgt_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/tgt_dict.txt)
|
188 |
+
- [fr spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/spm.model)
|
189 |
+
- [fr pre-trained nmt model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/nmt.pt)
|
190 |
+
- [fr pre-trained model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/checkpoint_pretraing.pt)
|
191 |
+
- [fr fine-tuned model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/checkpoint_finetune_ave10.pt)
|
192 |
+
4. [config_s2p.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/config_s2p.yaml)
|
fairseq/examples/speech_text_joint_to_text/models/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import importlib
|
7 |
+
import os
|
8 |
+
|
fairseq/examples/speech_text_joint_to_text/models/joint_speech_text_pretrain_transformer.py
ADDED
@@ -0,0 +1,698 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import logging
|
4 |
+
from collections import OrderedDict, namedtuple
|
5 |
+
from typing import Dict, Optional
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from torch import Tensor
|
11 |
+
|
12 |
+
from fairseq import checkpoint_utils, utils
|
13 |
+
from fairseq.file_io import PathManager
|
14 |
+
from fairseq.models import (
|
15 |
+
FairseqDecoder,
|
16 |
+
FairseqEncoderDecoderModel,
|
17 |
+
register_model,
|
18 |
+
register_model_architecture,
|
19 |
+
)
|
20 |
+
from fairseq.models.speech_to_text import (
|
21 |
+
MultiInputDecoder,
|
22 |
+
MultiModalityEncoder,
|
23 |
+
SpeechWavTransformerEncoder,
|
24 |
+
StackedSpeechWavTransformerEncoder,
|
25 |
+
)
|
26 |
+
from fairseq.models.transformer import (
|
27 |
+
TransformerDecoder,
|
28 |
+
TransformerEncoder,
|
29 |
+
TransformerModel,
|
30 |
+
)
|
31 |
+
|
32 |
+
logger = logging.getLogger(__name__)
|
33 |
+
|
34 |
+
|
35 |
+
class SpeechTextPreTrainEncoder(MultiModalityEncoder):
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
dictionary,
|
39 |
+
sup_speech_encoder,
|
40 |
+
sup_s2s_speech_encoder,
|
41 |
+
unsup_speech_encoder,
|
42 |
+
text_encoder,
|
43 |
+
):
|
44 |
+
super().__init__(dictionary)
|
45 |
+
self.sup_speech_encoder = sup_speech_encoder
|
46 |
+
self.sup_s2s_speech_encoder = sup_s2s_speech_encoder
|
47 |
+
self.unsup_speech_encoder = unsup_speech_encoder
|
48 |
+
self.text_encoder = text_encoder
|
49 |
+
|
50 |
+
@classmethod
|
51 |
+
def update_transformer_encoder_cfg(cls, args, update_dict):
|
52 |
+
cfg = dict(args._get_kwargs())
|
53 |
+
for fkey in update_dict.keys():
|
54 |
+
cfg[fkey] = update_dict[fkey]
|
55 |
+
cfg.pop("_name", None) # remove keys start with _
|
56 |
+
model_args = namedtuple("args", cfg.keys())(*cfg.values())
|
57 |
+
return model_args
|
58 |
+
|
59 |
+
@classmethod
|
60 |
+
def build_text_encoder(cls, args, src_dictionary):
|
61 |
+
enc_emb = nn.Embedding(
|
62 |
+
len(src_dictionary), args.encoder_embed_dim, src_dictionary.pad()
|
63 |
+
)
|
64 |
+
model_args = cls.update_transformer_encoder_cfg(
|
65 |
+
args, {"encoder_layers": args.text_encoder_layers}
|
66 |
+
)
|
67 |
+
text_encoder = TransformerEncoder(model_args, src_dictionary, enc_emb)
|
68 |
+
return text_encoder
|
69 |
+
|
70 |
+
@classmethod
|
71 |
+
def build_speech_encoder(cls, args):
|
72 |
+
model_args = cls.update_transformer_encoder_cfg(
|
73 |
+
args,
|
74 |
+
{
|
75 |
+
"encoder_layers": args.speech_encoder_layers,
|
76 |
+
"speech_mask_prob": args.speech_sup_mask_prob,
|
77 |
+
},
|
78 |
+
)
|
79 |
+
speech_encoder = SpeechWavTransformerEncoder(model_args)
|
80 |
+
return speech_encoder
|
81 |
+
|
82 |
+
@classmethod
|
83 |
+
def share_layers(cls, src_layers, tgt_layers): # share layer but not dropout
|
84 |
+
# share parameters in src_layers with tgt_layers
|
85 |
+
assert len(src_layers) == len(tgt_layers)
|
86 |
+
for i, ly in enumerate(src_layers):
|
87 |
+
tly = tgt_layers[i]
|
88 |
+
tly.self_attn = ly.self_attn
|
89 |
+
tly.self_attn_layer_norm = ly.self_attn_layer_norm
|
90 |
+
tly.activation_fn = ly.activation_fn
|
91 |
+
tly.normalize_before = ly.normalize_before
|
92 |
+
tly.fc1 = ly.fc1
|
93 |
+
tly.fc2 = ly.fc2
|
94 |
+
tly.final_layer_norm = ly.final_layer_norm
|
95 |
+
if hasattr(tly, "encoder_attn"):
|
96 |
+
tly.encoder_attn = ly.encoder_attn
|
97 |
+
tly.encoder_attn_layer_norm = ly.encoder_attn_layer_norm
|
98 |
+
return tgt_layers
|
99 |
+
|
100 |
+
@classmethod
|
101 |
+
def build_unsup_speech_encoder(cls, args, sup_speech_encoder):
|
102 |
+
model_args = cls.update_transformer_encoder_cfg(
|
103 |
+
args,
|
104 |
+
{
|
105 |
+
"encoder_layers": args.speech_encoder_layers,
|
106 |
+
"speech_mask_prob": args.speech_unsup_mask_prob,
|
107 |
+
"encoder_layerdrop": 0.0,
|
108 |
+
"decoder_layerdrop": 0.0,
|
109 |
+
"dropout": args.speech_unsup_dropout,
|
110 |
+
"activation_dropout": args.speech_unsup_dropout,
|
111 |
+
"attention_dropout": 0.0,
|
112 |
+
"dropout_features": args.speech_unsup_feature_dropout,
|
113 |
+
"dropout_input": args.speech_unsup_feature_dropout,
|
114 |
+
},
|
115 |
+
)
|
116 |
+
|
117 |
+
unsup_speech_encoder = SpeechWavTransformerEncoder(model_args, alway_mask=True)
|
118 |
+
unsup_speech_encoder.layer_norm = sup_speech_encoder.layer_norm
|
119 |
+
unsup_speech_encoder.layers = cls.share_layers(
|
120 |
+
sup_speech_encoder.layers, unsup_speech_encoder.layers
|
121 |
+
)
|
122 |
+
unsup_speech_encoder.mask_emb = sup_speech_encoder.mask_emb
|
123 |
+
unsup_speech_encoder.embed_positions = sup_speech_encoder.embed_positions
|
124 |
+
unsup_speech_encoder.feat_layer_norm = sup_speech_encoder.feat_layer_norm
|
125 |
+
unsup_speech_encoder.feat_proj = sup_speech_encoder.feat_proj
|
126 |
+
unsup_speech_encoder.subsample = sup_speech_encoder.subsample
|
127 |
+
return unsup_speech_encoder
|
128 |
+
|
129 |
+
@classmethod
|
130 |
+
def build_encoder(cls, args, dictionary):
|
131 |
+
text_encoder = cls.build_text_encoder(args, dictionary)
|
132 |
+
if getattr(args, "load_pretrained_mbart_encoder_from", None):
|
133 |
+
text_encoder = checkpoint_utils.load_pretrained_component_from_model(
|
134 |
+
component=text_encoder,
|
135 |
+
checkpoint=args.load_pretrained_mbart_encoder_from,
|
136 |
+
)
|
137 |
+
speech_encoder = cls.build_speech_encoder(args)
|
138 |
+
if getattr(args, "load_pretrained_feature_extractor_from", None):
|
139 |
+
|
140 |
+
def load_feature_extractor(component, checkpoint):
|
141 |
+
if not PathManager.exists(checkpoint):
|
142 |
+
raise IOError("Model file not found: {}".format(checkpoint))
|
143 |
+
state = checkpoint_utils.load_checkpoint_to_cpu(checkpoint)
|
144 |
+
component_state_dict = OrderedDict()
|
145 |
+
|
146 |
+
component_prefix = "feature_extractor"
|
147 |
+
for key in state["model"].keys():
|
148 |
+
if key.startswith(component_prefix):
|
149 |
+
component_subkey = key[len(component_prefix) + 1 :]
|
150 |
+
component_state_dict[component_subkey] = state["model"][key]
|
151 |
+
component.load_state_dict(component_state_dict, strict=True)
|
152 |
+
return component
|
153 |
+
|
154 |
+
speech_encoder.subsample = load_feature_extractor(
|
155 |
+
speech_encoder.subsample, args.load_pretrained_feature_extractor_from
|
156 |
+
)
|
157 |
+
speech_s2s_encoder = speech_encoder
|
158 |
+
unsup_speech_encoder = cls.build_unsup_speech_encoder(args, speech_encoder)
|
159 |
+
if getattr(args, "stacked_encoder", "none") != "none":
|
160 |
+
if args.encoder_shared_text_layers_from_begin > 0:
|
161 |
+
raise ValueError(
|
162 |
+
"We can not stack encoders and share encoders at the same time!"
|
163 |
+
)
|
164 |
+
speech_s2s_encoder = StackedSpeechWavTransformerEncoder(
|
165 |
+
speech_encoder, text_encoder.layers, text_encoder.layer_norm
|
166 |
+
)
|
167 |
+
if args.stacked_encoder == "all":
|
168 |
+
speech_encoder = speech_s2s_encoder
|
169 |
+
unsup_speech_encoder = StackedSpeechWavTransformerEncoder(
|
170 |
+
unsup_speech_encoder, text_encoder.layers, text_encoder.layer_norm
|
171 |
+
)
|
172 |
+
else:
|
173 |
+
cls.share_speech_text_encoder(
|
174 |
+
speech_encoder, text_encoder, args.encoder_shared_text_layers_from_begin
|
175 |
+
)
|
176 |
+
return SpeechTextPreTrainEncoder(
|
177 |
+
dictionary,
|
178 |
+
speech_encoder,
|
179 |
+
speech_s2s_encoder,
|
180 |
+
unsup_speech_encoder,
|
181 |
+
text_encoder,
|
182 |
+
)
|
183 |
+
|
184 |
+
@classmethod
|
185 |
+
def share_speech_text_encoder(
|
186 |
+
cls, speech_encoder, text_encoder, shared_layers_from_begin
|
187 |
+
):
|
188 |
+
if shared_layers_from_begin > 0:
|
189 |
+
num_text_encoder_layers = len(text_encoder.layers)
|
190 |
+
assert len(speech_encoder.layers) >= shared_layers_from_begin
|
191 |
+
assert num_text_encoder_layers >= shared_layers_from_begin
|
192 |
+
assert len(speech_encoder.layers) >= num_text_encoder_layers
|
193 |
+
for i, ly in enumerate(
|
194 |
+
speech_encoder.layers[
|
195 |
+
-num_text_encoder_layers : -num_text_encoder_layers
|
196 |
+
+ shared_layers_from_begin
|
197 |
+
]
|
198 |
+
):
|
199 |
+
assert isinstance(text_encoder.layers[i], type(ly))
|
200 |
+
text_encoder.layers[i] = ly
|
201 |
+
|
202 |
+
def select_encoder(self, mode, **kwargs):
|
203 |
+
if mode in ("speech", "sup_speech_ctc", "sup_speech_ali", "sup_speech_s2s"):
|
204 |
+
kwargs["features_only"] = True
|
205 |
+
if mode == "sup_speech_s2s":
|
206 |
+
return self.sup_s2s_speech_encoder, kwargs
|
207 |
+
return self.sup_speech_encoder, kwargs
|
208 |
+
elif mode == "unsup_speech":
|
209 |
+
kwargs["features_only"] = False
|
210 |
+
return self.unsup_speech_encoder, kwargs
|
211 |
+
elif mode in ("text", "bitext"):
|
212 |
+
return self.text_encoder, kwargs
|
213 |
+
else:
|
214 |
+
raise NotImplementedError(f"{mode} is not supported")
|
215 |
+
return None, kwargs
|
216 |
+
|
217 |
+
def forward(self, src_tokens, src_lengths=None, mode="", alignment=None, **kwargs):
|
218 |
+
return super().forward(src_tokens, src_lengths, mode, **kwargs)
|
219 |
+
|
220 |
+
|
221 |
+
# SpeechDummyDecoder works as an extension of encoder, so we could fit encoder only training into seq2seq training
|
222 |
+
class SpeechDummyDecoder(FairseqDecoder):
|
223 |
+
def __init__(
|
224 |
+
self,
|
225 |
+
dictionary,
|
226 |
+
output_embedding,
|
227 |
+
no_emb_update_unsup=False,
|
228 |
+
use_output_proj=False,
|
229 |
+
):
|
230 |
+
super().__init__(dictionary)
|
231 |
+
self.output_embedding = output_embedding
|
232 |
+
num_embedding, num_dim = self.output_embedding.weight.size()
|
233 |
+
self.out_proj = (
|
234 |
+
None if use_output_proj is False else nn.Linear(num_dim, num_dim)
|
235 |
+
)
|
236 |
+
self.no_emb_update_unsup = no_emb_update_unsup
|
237 |
+
|
238 |
+
def extend_alignment(self, alignment, src_lengths, prev_output_tokens):
|
239 |
+
# alignment: B X N
|
240 |
+
# src_lengths: B X T
|
241 |
+
# prev_output_tokens: B X (N + 1)
|
242 |
+
tgt_tokens = prev_output_tokens[
|
243 |
+
:, 1:
|
244 |
+
] # remove the leading start of sentence token
|
245 |
+
ext_alignment = (
|
246 |
+
torch.ones(len(src_lengths), src_lengths.max(), device=src_lengths.device)
|
247 |
+
.long()
|
248 |
+
.fill_(self.dictionary.pad())
|
249 |
+
)
|
250 |
+
for bs in range(src_lengths.size(0)):
|
251 |
+
tgt_length = tgt_tokens[bs].ne(self.dictionary.pad()).sum().item()
|
252 |
+
assert tgt_length == sum(alignment[bs].ne(1)) + 1
|
253 |
+
src_st = 0
|
254 |
+
for i in range(tgt_length):
|
255 |
+
tok = tgt_tokens[bs][i]
|
256 |
+
src_ed = (alignment[bs][i] * src_lengths[bs]).int().item()
|
257 |
+
ext_alignment[bs][src_st:src_ed].fill_(tok)
|
258 |
+
src_st = src_ed
|
259 |
+
return ext_alignment
|
260 |
+
|
261 |
+
def forward(
|
262 |
+
self,
|
263 |
+
prev_output_tokens,
|
264 |
+
encoder_out,
|
265 |
+
incremental_state=None,
|
266 |
+
mode="speech",
|
267 |
+
alignment=None,
|
268 |
+
**kwargs,
|
269 |
+
):
|
270 |
+
"""
|
271 |
+
Args:
|
272 |
+
prev_output_tokens (LongTensor): previous decoder outputs of shape
|
273 |
+
`(batch, tgt_len)`, for teacher forcing
|
274 |
+
encoder_out (optional): output from the encoder, used for
|
275 |
+
encoder-side attention
|
276 |
+
incremental_state (dict): dictionary used for storing state during
|
277 |
+
:ref:`Incremental decoding`
|
278 |
+
features_only (bool, optional): only return features without
|
279 |
+
applying output layer (default: False).
|
280 |
+
full_context_alignment (bool, optional): don't apply
|
281 |
+
auto-regressive mask to self-attention (default: False).
|
282 |
+
|
283 |
+
Returns:
|
284 |
+
sup_speech_ctc:
|
285 |
+
dictionary{"logits": logits, "padding_mask": padding_mask}
|
286 |
+
sup_speech_ali and unsup_speech:
|
287 |
+
tuple:
|
288 |
+
- the decoder's output of shape `(batch, tgt_len, vocab)`
|
289 |
+
- a dictionary with any model-specific outputs
|
290 |
+
"""
|
291 |
+
emb_weight = self.output_embedding.weight
|
292 |
+
if (
|
293 |
+
mode == "unsup_speech" and self.no_emb_update_unsup
|
294 |
+
): # no gradient for embedding here
|
295 |
+
emb_weight = emb_weight.detach()
|
296 |
+
enc_out = (
|
297 |
+
encoder_out["encoder_out"][0]
|
298 |
+
if self.out_proj is None
|
299 |
+
else self.out_proj(encoder_out["encoder_out"][0])
|
300 |
+
)
|
301 |
+
logits = F.linear(enc_out, emb_weight, None).transpose(0, 1) # B X T X C
|
302 |
+
others = None
|
303 |
+
if mode in (
|
304 |
+
"speech",
|
305 |
+
"sup_speech_ctc",
|
306 |
+
): # speech data with label, do forcealignment
|
307 |
+
if len(encoder_out["encoder_padding_mask"]) > 0:
|
308 |
+
padding_mask = encoder_out["encoder_padding_mask"][0]
|
309 |
+
logits = logits.masked_fill(padding_mask, float("-inf"))
|
310 |
+
else:
|
311 |
+
seq_len, bsz = encoder_out["encoder_out"][0].size()[:2]
|
312 |
+
padding_mask = torch.zeros(
|
313 |
+
bsz, seq_len, device=encoder_out["encoder_out"][0].device
|
314 |
+
).bool()
|
315 |
+
return {"x": logits, "padding_mask": padding_mask}
|
316 |
+
elif mode == "sup_speech_ali":
|
317 |
+
src_lengths = None
|
318 |
+
if len(encoder_out["encoder_padding_mask"]) > 0:
|
319 |
+
src_lengths = (1 - encoder_out["encoder_padding_mask"][0].long()).sum(
|
320 |
+
-1
|
321 |
+
)
|
322 |
+
else:
|
323 |
+
seq_len, bsz = encoder_out["encoder_out"][0].size()[:2]
|
324 |
+
src_lengths = (
|
325 |
+
torch.ones(bsz, device=encoder_out["encoder_out"][0].device).long()
|
326 |
+
* seq_len
|
327 |
+
)
|
328 |
+
assert alignment is not None
|
329 |
+
alignment = self.extend_alignment(
|
330 |
+
alignment, src_lengths, prev_output_tokens
|
331 |
+
)
|
332 |
+
others = {"pseudo_target_tokens": alignment}
|
333 |
+
elif mode == "unsup_speech":
|
334 |
+
enc_out_ori = (
|
335 |
+
encoder_out["encoder_unmasked_out"][0]
|
336 |
+
if self.out_proj is None
|
337 |
+
else self.out_proj(encoder_out["encoder_unmasked_out"][0])
|
338 |
+
)
|
339 |
+
logits_ori = F.linear(enc_out_ori, emb_weight, None).transpose(0, 1)
|
340 |
+
if len(encoder_out["encoder_padding_mask"]) > 0:
|
341 |
+
encoder_padding_mask = encoder_out["encoder_padding_mask"][0]
|
342 |
+
logits_ori = logits_ori.masked_fill(encoder_padding_mask, float("-inf"))
|
343 |
+
pseudo_labels = utils.log_softmax(logits_ori, dim=-1)
|
344 |
+
others = {
|
345 |
+
"pseudo_target_logprobs": pseudo_labels,
|
346 |
+
"padding_mask": encoder_out["encoder_padding_mask"], # B X T
|
347 |
+
"mask_indices": encoder_out[
|
348 |
+
"mask_indices"
|
349 |
+
], # True for masked frames B X T
|
350 |
+
}
|
351 |
+
return logits, others
|
352 |
+
|
353 |
+
def get_normalized_probs(
|
354 |
+
self,
|
355 |
+
net_output: Dict[str, Tensor],
|
356 |
+
log_probs: bool,
|
357 |
+
sample: Optional[Dict[str, Tensor]] = None,
|
358 |
+
):
|
359 |
+
return self.get_normalized_probs_scriptable(
|
360 |
+
(net_output["x"], None), log_probs, sample
|
361 |
+
)
|
362 |
+
|
363 |
+
|
364 |
+
class SpeechTextPreTrainDecoder(MultiInputDecoder):
|
365 |
+
def __init__(self, dictionary, speech_decoder, text_decoder):
|
366 |
+
super().__init__(dictionary)
|
367 |
+
self.speech_decoder = speech_decoder
|
368 |
+
self.text_decoder = text_decoder
|
369 |
+
|
370 |
+
def select_decoder(self, mode, **kwargs):
|
371 |
+
if mode == "unsup_speech":
|
372 |
+
kwargs["mode"] = mode
|
373 |
+
return self.speech_decoder, kwargs
|
374 |
+
if mode in ("text", "bitext"):
|
375 |
+
return self.text_decoder, kwargs
|
376 |
+
if mode in ("speech", "sup_speech_ctc", "sup_speech_ali"):
|
377 |
+
kwargs["mode"] = mode
|
378 |
+
return self.speech_decoder, kwargs
|
379 |
+
if mode in ("speech", "sup_speech_s2s"):
|
380 |
+
if "alignment" in kwargs:
|
381 |
+
del kwargs["alignment"]
|
382 |
+
return self.text_decoder, kwargs
|
383 |
+
|
384 |
+
raise NotImplementedError(f"{mode} is not supported")
|
385 |
+
return None, kwargs
|
386 |
+
|
387 |
+
def get_normalized_probs(
|
388 |
+
self,
|
389 |
+
net_output,
|
390 |
+
log_probs,
|
391 |
+
sample=None,
|
392 |
+
):
|
393 |
+
"""Get normalized probabilities (or log probs) from a net's output."""
|
394 |
+
if isinstance(net_output, dict):
|
395 |
+
return self.speech_decoder.get_normalized_probs(
|
396 |
+
net_output, log_probs, sample
|
397 |
+
)
|
398 |
+
return self.text_decoder.get_normalized_probs(net_output, log_probs, sample)
|
399 |
+
|
400 |
+
@classmethod
|
401 |
+
def build_text_decoder(cls, args, tgt_dictionary, dec_emb_share=None):
|
402 |
+
dec_emb = (
|
403 |
+
nn.Embedding(
|
404 |
+
len(tgt_dictionary), args.decoder_embed_dim, tgt_dictionary.pad()
|
405 |
+
)
|
406 |
+
if dec_emb_share is None
|
407 |
+
else dec_emb_share
|
408 |
+
)
|
409 |
+
text_decoder = TransformerDecoder(args, tgt_dictionary, dec_emb)
|
410 |
+
return text_decoder
|
411 |
+
|
412 |
+
@classmethod
|
413 |
+
def build_dummy_speech_decoder(cls, args, dictionary, dec_emb_share=None):
|
414 |
+
dec_emb = (
|
415 |
+
nn.Embedding(len(dictionary), args.decoder_embed_dim, dictionary.pad())
|
416 |
+
if dec_emb_share is None
|
417 |
+
else dec_emb_share
|
418 |
+
)
|
419 |
+
speech_decoder = SpeechDummyDecoder(
|
420 |
+
dictionary,
|
421 |
+
dec_emb,
|
422 |
+
no_emb_update_unsup=getattr(args, "no_emb_update_unsup", False),
|
423 |
+
use_output_proj=getattr(args, "use_decoder_output_proj", False),
|
424 |
+
)
|
425 |
+
return speech_decoder
|
426 |
+
|
427 |
+
@classmethod
|
428 |
+
def build_decoder(
|
429 |
+
cls, args, text_dictionary, speech_dictionary, speech_output_embedding
|
430 |
+
):
|
431 |
+
text_decoder = cls.build_text_decoder(args, text_dictionary)
|
432 |
+
speech_decoder = cls.build_dummy_speech_decoder(
|
433 |
+
args, speech_dictionary, speech_output_embedding
|
434 |
+
)
|
435 |
+
if getattr(args, "load_pretrained_mbart_decoder_from", None):
|
436 |
+
text_decoder = checkpoint_utils.load_pretrained_component_from_model(
|
437 |
+
component=text_decoder,
|
438 |
+
checkpoint=args.load_pretrained_mbart_decoder_from,
|
439 |
+
)
|
440 |
+
return SpeechTextPreTrainDecoder(text_dictionary, speech_decoder, text_decoder)
|
441 |
+
|
442 |
+
|
443 |
+
@register_model("speech_text_pretrain_bart")
|
444 |
+
class SpeechTextPreTrainModel(FairseqEncoderDecoderModel):
|
445 |
+
def __init__(self, encoder, decoder):
|
446 |
+
super().__init__(encoder, decoder)
|
447 |
+
self.num_updates = 0
|
448 |
+
|
449 |
+
def forward(
|
450 |
+
self, src_tokens, src_lengths, prev_output_tokens, src_lang_ids=None, **kwargs
|
451 |
+
):
|
452 |
+
if src_lang_ids is not None:
|
453 |
+
encoder_out = self.encoder(
|
454 |
+
src_tokens, src_lengths=src_lengths, src_lang_ids=src_lang_ids, **kwargs
|
455 |
+
)
|
456 |
+
else:
|
457 |
+
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
|
458 |
+
decoder_out = self.decoder(
|
459 |
+
prev_output_tokens, encoder_out=encoder_out, **kwargs
|
460 |
+
)
|
461 |
+
return decoder_out
|
462 |
+
|
463 |
+
def max_positions(self):
|
464 |
+
return None # it is provided in task
|
465 |
+
|
466 |
+
def get_targets(self, sample, net_output):
|
467 |
+
mode = sample["net_input"]["mode"]
|
468 |
+
if mode == "unsup_speech":
|
469 |
+
return {"target_logprobs": net_output[1]["pseudo_target_logprobs"]}
|
470 |
+
if mode == "sup_speech_ali":
|
471 |
+
return net_output[1]["pseudo_target_tokens"]
|
472 |
+
return sample["target"]
|
473 |
+
|
474 |
+
def get_normalized_probs(
|
475 |
+
self,
|
476 |
+
net_output,
|
477 |
+
log_probs,
|
478 |
+
sample=None,
|
479 |
+
):
|
480 |
+
# net_output['encoder_out'] is a (B, T, D) tensor
|
481 |
+
lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
|
482 |
+
lprobs.batch_first = True
|
483 |
+
return lprobs
|
484 |
+
|
485 |
+
@staticmethod
|
486 |
+
def add_args(parser):
|
487 |
+
TransformerModel.add_args(parser)
|
488 |
+
SpeechWavTransformerEncoder.add_args(parser)
|
489 |
+
parser.add_argument(
|
490 |
+
"--speech-sup-mask-prob",
|
491 |
+
type=float,
|
492 |
+
help="probability of replacing a token with mask (sup-speech)",
|
493 |
+
)
|
494 |
+
parser.add_argument(
|
495 |
+
"--speech-unsup-mask-prob",
|
496 |
+
type=float,
|
497 |
+
help="probability of replacing a token with mask (unsup-speech)",
|
498 |
+
)
|
499 |
+
parser.add_argument(
|
500 |
+
"--load-pretrained-mbart-encoder-from",
|
501 |
+
type=str,
|
502 |
+
metavar="STR",
|
503 |
+
help="model to take text encoder weights from (for initialization)",
|
504 |
+
)
|
505 |
+
|
506 |
+
parser.add_argument(
|
507 |
+
"--load-pretrained-mbart-decoder-from",
|
508 |
+
type=str,
|
509 |
+
metavar="STR",
|
510 |
+
help="model to take text decoder weights from (for initialization)",
|
511 |
+
)
|
512 |
+
|
513 |
+
parser.add_argument(
|
514 |
+
"--load-pretrained-feature-extractor-from",
|
515 |
+
type=str,
|
516 |
+
metavar="STR",
|
517 |
+
help="model to take feature extractor weights from (for initialization)",
|
518 |
+
)
|
519 |
+
|
520 |
+
parser.add_argument(
|
521 |
+
"--speech-unsup-dropout",
|
522 |
+
type=float,
|
523 |
+
default=0,
|
524 |
+
help="dropout for unsupervised speech encoder",
|
525 |
+
)
|
526 |
+
|
527 |
+
parser.add_argument(
|
528 |
+
"--speech-unsup-feature-dropout",
|
529 |
+
type=float,
|
530 |
+
default=0,
|
531 |
+
help="dropout for unsupervised speech feature encoder",
|
532 |
+
)
|
533 |
+
|
534 |
+
parser.add_argument(
|
535 |
+
"--encoder-shared-text-layers-from-begin",
|
536 |
+
type=int,
|
537 |
+
help="number of text encoder layers shared with speech encoder (from first layer)",
|
538 |
+
)
|
539 |
+
|
540 |
+
parser.add_argument(
|
541 |
+
"--stacked-encoder",
|
542 |
+
default="none",
|
543 |
+
choices=["none", "s2s", "all"],
|
544 |
+
help="stack speech and text encoders",
|
545 |
+
)
|
546 |
+
|
547 |
+
parser.add_argument("--use-decoder-output-proj", action="store_true")
|
548 |
+
|
549 |
+
@classmethod
|
550 |
+
def build_model(cls, args, task):
|
551 |
+
encoder = SpeechTextPreTrainEncoder.build_encoder(args, task.src_dict)
|
552 |
+
decoder = SpeechTextPreTrainDecoder.build_decoder(
|
553 |
+
args, task.tgt_dict, task.src_dict, encoder.text_encoder.embed_tokens
|
554 |
+
)
|
555 |
+
model = SpeechTextPreTrainModel(encoder, decoder)
|
556 |
+
return model
|
557 |
+
|
558 |
+
def upgrade_state_dict(self, state_dict):
|
559 |
+
"""Upgrade old state dicts to work with newer code."""
|
560 |
+
if "decoder.speech_decoder.output_projection.weight" in state_dict:
|
561 |
+
del state_dict["decoder.speech_decoder.output_projection.weight"]
|
562 |
+
self.upgrade_state_dict_named(state_dict, "")
|
563 |
+
|
564 |
+
|
565 |
+
@register_model_architecture(
|
566 |
+
"speech_text_pretrain_bart", "speech_text_pretrain_bart_base"
|
567 |
+
)
|
568 |
+
def speech_text_pretrain_bart_base(args):
|
569 |
+
# speech masking
|
570 |
+
args.dropout_input = getattr(args, "dropout_input", 0)
|
571 |
+
args.dropout_features = getattr(args, "dropout_features", 0)
|
572 |
+
args.speech_mask_length = getattr(args, "speech_mask_length", 10)
|
573 |
+
args.speech_mask_prob = getattr(args, "speech_mask_prob", 0.65)
|
574 |
+
args.speech_sup_mask_prob = getattr(args, "speech_sup_mask_prob", 0.3)
|
575 |
+
args.speech_unsup_mask_prob = getattr(
|
576 |
+
args, "speech_unsup_mask_prob", args.speech_mask_prob
|
577 |
+
)
|
578 |
+
args.speech_mask_selection = getattr(args, "speech_mask_selection", "static")
|
579 |
+
args.speech_mask_other = getattr(args, "speech_mask_other", 0)
|
580 |
+
args.speech_mask_min_space = getattr(args, "speech_mask_min_space", 1)
|
581 |
+
args.speech_no_mask_overlap = getattr(args, "speech_no_mask_overlap", False)
|
582 |
+
|
583 |
+
args.speech_mask_channel_length = getattr(args, "speech_mask_channel_length", 10)
|
584 |
+
args.speech_mask_channel_prob = getattr(args, "speech_mask_channel_prob", 0.0)
|
585 |
+
args.speech_mask_channel_selection = getattr(
|
586 |
+
args, "speech_mask_channel_selection", "static"
|
587 |
+
)
|
588 |
+
args.speech_mask_channel_other = getattr(args, "speech_mask_channel_other", 0)
|
589 |
+
args.speech_mask_channel_min_space = getattr(
|
590 |
+
args, "speech_mask_channel_min_space", 1
|
591 |
+
)
|
592 |
+
args.speech_no_mask_channel_overlap = getattr(
|
593 |
+
args, "speech_no_mask_channel_overlap", False
|
594 |
+
)
|
595 |
+
args.no_scale_feature = getattr(args, "", False)
|
596 |
+
args.feature_grad_mult = getattr(args, "feature_grad_mult", 1.0) # 0.1
|
597 |
+
|
598 |
+
# Transformer
|
599 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
|
600 |
+
args.encoder_ffn_embed_dim = getattr(
|
601 |
+
args, "encoder_ffn_embed_dim", args.encoder_embed_dim * 4
|
602 |
+
)
|
603 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
|
604 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
605 |
+
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
|
606 |
+
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
|
607 |
+
args.speech_conv_bias = getattr(args, "speech_conv_bias", False)
|
608 |
+
|
609 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
|
610 |
+
args.decoder_ffn_embed_dim = getattr(
|
611 |
+
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
|
612 |
+
)
|
613 |
+
args.decoder_attention_heads = getattr(
|
614 |
+
args, "decoder_attention_heads", args.encoder_attention_heads
|
615 |
+
)
|
616 |
+
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
|
617 |
+
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
|
618 |
+
args.dropout = getattr(args, "dropout", 0.1)
|
619 |
+
args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
|
620 |
+
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
|
621 |
+
args.activation_fn = getattr(args, "activation_fn", "relu") # gelu?
|
622 |
+
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
|
623 |
+
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
|
624 |
+
|
625 |
+
args.speech_unsup_dropout = getattr(args, "speech_unsup_dropout", 0)
|
626 |
+
args.speech_unsup_feature_dropout = getattr(args, "speech_unsup_feature_dropout", 0)
|
627 |
+
|
628 |
+
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
|
629 |
+
args.share_decoder_input_output_embed = getattr(
|
630 |
+
args, "share_decoder_input_output_embed", False
|
631 |
+
)
|
632 |
+
args.no_token_positional_embeddings = getattr(
|
633 |
+
args, "no_token_positional_embeddings", False
|
634 |
+
)
|
635 |
+
args.adaptive_input = getattr(args, "adaptive_input", False)
|
636 |
+
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
|
637 |
+
args.decoder_output_dim = getattr(
|
638 |
+
args, "decoder_output_dim", args.decoder_embed_dim
|
639 |
+
)
|
640 |
+
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
|
641 |
+
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
|
642 |
+
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
|
643 |
+
|
644 |
+
args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 12)
|
645 |
+
args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
|
646 |
+
args.encoder_shared_text_layers_from_begin = getattr(
|
647 |
+
args, "encoder_shared_text_layers_from_begin", 6
|
648 |
+
)
|
649 |
+
args.decoder_layers = getattr(args, "decoder_layers", 6)
|
650 |
+
|
651 |
+
args.no_emb_update_unsup = getattr(args, "no_emb_update_unsup", False)
|
652 |
+
|
653 |
+
|
654 |
+
@register_model_architecture(
|
655 |
+
"speech_text_pretrain_bart", "speech_text_pretrain_bart_base_stack"
|
656 |
+
)
|
657 |
+
def speech_text_pretrain_bart_base_stack(args):
|
658 |
+
args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 6)
|
659 |
+
args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
|
660 |
+
args.encoder_shared_text_layers_from_begin = getattr(
|
661 |
+
args, "encoder_shared_text_layers_from_begin", 0
|
662 |
+
)
|
663 |
+
args.stacked_encoder = getattr(args, "stacked_encoder", "all")
|
664 |
+
args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
|
665 |
+
speech_text_pretrain_bart_base(args)
|
666 |
+
|
667 |
+
|
668 |
+
@register_model_architecture(
|
669 |
+
"speech_text_pretrain_bart", "speech_text_pretrain_bart_large"
|
670 |
+
)
|
671 |
+
def speech_text_pretrain_bart_large(args):
|
672 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
|
673 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
|
674 |
+
args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 24)
|
675 |
+
args.text_encoder_layers = getattr(args, "text_encoder_layers", 12)
|
676 |
+
args.encoder_shared_text_layers_from_begin = getattr(
|
677 |
+
args, "encoder_shared_text_layers_from_begin", 12
|
678 |
+
)
|
679 |
+
args.decoder_layers = getattr(args, "decoder_layers", 12)
|
680 |
+
args.dropout = getattr(args, "dropout", 0.3)
|
681 |
+
speech_text_pretrain_bart_base(args)
|
682 |
+
|
683 |
+
|
684 |
+
@register_model_architecture(
|
685 |
+
"speech_text_pretrain_bart", "speech_text_pretrain_bart_large_stack"
|
686 |
+
)
|
687 |
+
def speech_text_pretrain_bart_large_stack(args):
|
688 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
|
689 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
|
690 |
+
args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 6)
|
691 |
+
args.text_encoder_layers = getattr(args, "text_encoder_layers", 12)
|
692 |
+
args.encoder_shared_text_layers_from_begin = getattr(
|
693 |
+
args, "encoder_shared_text_layers_from_begin", 0
|
694 |
+
)
|
695 |
+
args.decoder_layers = getattr(args, "decoder_layers", 12)
|
696 |
+
args.stacked_encoder = getattr(args, "stacked_encoder", "s2s")
|
697 |
+
args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
|
698 |
+
speech_text_pretrain_bart_base(args)
|
fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputtransformer.py
ADDED
@@ -0,0 +1,1093 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import logging
|
7 |
+
from collections import namedtuple
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from fairseq import checkpoint_utils
|
12 |
+
from fairseq import utils
|
13 |
+
from fairseq.models import (
|
14 |
+
FairseqEncoder,
|
15 |
+
FairseqDecoder,
|
16 |
+
FairseqEncoderDecoderModel,
|
17 |
+
register_model,
|
18 |
+
register_model_architecture,
|
19 |
+
)
|
20 |
+
from fairseq.models.fairseq_encoder import EncoderOut
|
21 |
+
from fairseq.models.speech_to_text import (
|
22 |
+
TransformerDecoder,
|
23 |
+
S2TTransformerEncoder,
|
24 |
+
)
|
25 |
+
from fairseq.models.transformer import TransformerEncoder
|
26 |
+
from fairseq.modules import (
|
27 |
+
TransformerEncoderLayer,
|
28 |
+
GradMultiply,
|
29 |
+
LayerNorm,
|
30 |
+
)
|
31 |
+
|
32 |
+
logger = logging.getLogger(__name__)
|
33 |
+
|
34 |
+
|
35 |
+
class SpeechEoSEncoder(FairseqEncoder):
|
36 |
+
def __init__(self, encoder, eos_num, feat_dim, adapter_type="None", adapter_dim=0):
|
37 |
+
super().__init__(None)
|
38 |
+
self.encoder = encoder
|
39 |
+
self.eos_num = eos_num # downsampling rate for speech input feature
|
40 |
+
self.eos_emb = (
|
41 |
+
nn.Parameter(torch.zeros(1, feat_dim), requires_grad=True)
|
42 |
+
if eos_num > 0
|
43 |
+
else None
|
44 |
+
)
|
45 |
+
self.adapter = self.add_adapter(adapter_type, adapter_dim)
|
46 |
+
|
47 |
+
def add_adapter(self, adapter_type, adapter_dim):
|
48 |
+
def _make_identity(linear, eps=1e-5):
|
49 |
+
assert isinstance(linear, nn.Linear)
|
50 |
+
linear.weight.data.mul_(eps)
|
51 |
+
linear.weight.data.fill_diagonal_(1.0)
|
52 |
+
if linear.bias is not None:
|
53 |
+
linear.bias.data.mul_(eps)
|
54 |
+
|
55 |
+
adapter = None
|
56 |
+
if adapter_type == "Linear":
|
57 |
+
assert adapter_dim > 0
|
58 |
+
adapter = nn.Sequential(
|
59 |
+
nn.Linear(adapter_dim, adapter_dim), LayerNorm(adapter_dim)
|
60 |
+
)
|
61 |
+
# initialize the adapter as identity matrix first
|
62 |
+
_make_identity(adapter[0])
|
63 |
+
|
64 |
+
elif adapter_type == "MLP":
|
65 |
+
assert adapter_dim > 0
|
66 |
+
# assume the model is pre-norm model
|
67 |
+
adapter = nn.Sequential(
|
68 |
+
nn.Linear(adapter_dim, 2 * adapter_dim),
|
69 |
+
nn.ReLU(),
|
70 |
+
nn.Linear(2 * adapter_dim, adapter_dim),
|
71 |
+
LayerNorm(adapter_dim),
|
72 |
+
)
|
73 |
+
_make_identity(adapter[0])
|
74 |
+
_make_identity(adapter[2])
|
75 |
+
return adapter
|
76 |
+
|
77 |
+
def add_eos(self, src_tokens, src_lengths):
|
78 |
+
bsz, max_seq_len, fdim = src_tokens.size()
|
79 |
+
if self.eos_num > 0:
|
80 |
+
src_token_eos = torch.zeros(
|
81 |
+
[bsz, max_seq_len + self.eos_num, fdim],
|
82 |
+
dtype=src_tokens.dtype,
|
83 |
+
device=src_tokens.device,
|
84 |
+
)
|
85 |
+
src_token_eos[:, :max_seq_len] = src_tokens
|
86 |
+
for bi in range(bsz):
|
87 |
+
src_token_eos[bi][
|
88 |
+
src_lengths[bi] : src_lengths[bi] + self.eos_num
|
89 |
+
] = self.eos_emb.expand(self.eos_num, fdim)
|
90 |
+
src_lengths = src_lengths + self.eos_num
|
91 |
+
src_tokens = src_token_eos
|
92 |
+
return src_tokens, src_lengths
|
93 |
+
|
94 |
+
def apply_adapter(self, enc_out):
|
95 |
+
if self.adapter is None:
|
96 |
+
return enc_out
|
97 |
+
rst = self.adapter(enc_out.encoder_out)
|
98 |
+
if enc_out.encoder_padding_mask is not None:
|
99 |
+
rst.masked_fill_(
|
100 |
+
enc_out.encoder_padding_mask.transpose(0, 1).unsqueeze(-1), 0
|
101 |
+
)
|
102 |
+
return EncoderOut(
|
103 |
+
encoder_out=rst,
|
104 |
+
encoder_padding_mask=enc_out.encoder_padding_mask,
|
105 |
+
encoder_embedding=enc_out.encoder_embedding,
|
106 |
+
encoder_states=enc_out.encoder_states,
|
107 |
+
src_tokens=enc_out.src_tokens,
|
108 |
+
src_lengths=enc_out.src_lengths,
|
109 |
+
)
|
110 |
+
|
111 |
+
def forward(self, src_tokens, src_lengths=None, return_all_hiddens=False, **kwargs):
|
112 |
+
"""
|
113 |
+
src_tokens: padded tensor (B, T, C * feat)
|
114 |
+
src_lengths: tensor of original lengths of input utterances (B,)
|
115 |
+
"""
|
116 |
+
src_tokens, src_lengths = self.add_eos(src_tokens, src_lengths)
|
117 |
+
enc_out = self.encoder(src_tokens, src_lengths, return_all_hiddens)
|
118 |
+
enc_out = self.apply_adapter(enc_out)
|
119 |
+
return enc_out
|
120 |
+
|
121 |
+
def reorder_encoder_out(self, encoder_out, new_order):
|
122 |
+
return self.encoder.reorder_encoder_out(encoder_out, new_order)
|
123 |
+
|
124 |
+
|
125 |
+
class DualInputEncoder(FairseqEncoder):
|
126 |
+
def __init__(
|
127 |
+
self,
|
128 |
+
args,
|
129 |
+
spch_encoder,
|
130 |
+
text_encoder,
|
131 |
+
dictionary,
|
132 |
+
cross_attentive_loss_before_last_layer=-1,
|
133 |
+
):
|
134 |
+
super().__init__(dictionary)
|
135 |
+
|
136 |
+
self.spch_encoder = spch_encoder
|
137 |
+
self.text_encoder = text_encoder
|
138 |
+
self.enc_grad_mult = args.enc_grad_mult
|
139 |
+
self.cross_attentive_loss_before_last_layer = (
|
140 |
+
cross_attentive_loss_before_last_layer
|
141 |
+
)
|
142 |
+
self.use_cross_attentive_loss = (
|
143 |
+
False if cross_attentive_loss_before_last_layer <= -1 else True
|
144 |
+
)
|
145 |
+
self.enc2_along_grad_mult = args.enc2_along_grad_mult
|
146 |
+
|
147 |
+
@classmethod
|
148 |
+
def set_shared_layer(cls, share_level, src_layer, tgt_layer):
|
149 |
+
"""
|
150 |
+
share parameters from tgt_layer to src_layer
|
151 |
+
share_level:
|
152 |
+
0: share everything
|
153 |
+
1: share everything but different model
|
154 |
+
2: share weight but not bias, layernorm
|
155 |
+
"""
|
156 |
+
if share_level == 0:
|
157 |
+
return tgt_layer
|
158 |
+
if isinstance(src_layer, nn.Linear):
|
159 |
+
return tgt_layer
|
160 |
+
if isinstance(src_layer, TransformerEncoderLayer):
|
161 |
+
assert src_layer.embed_dim == tgt_layer.embed_dim
|
162 |
+
assert src_layer.normalize_before == tgt_layer.normalize_before
|
163 |
+
if share_level == 1:
|
164 |
+
src_layer.fc1 = tgt_layer.fc1
|
165 |
+
src_layer.fc2 = tgt_layer.fc2
|
166 |
+
src_layer.self_attn = tgt_layer.self_attn
|
167 |
+
src_layer.final_layer_norm = tgt_layer.final_layer_norm
|
168 |
+
src_layer.self_attn_layer_norm = tgt_layer.self_attn_layer_norm
|
169 |
+
src_layer.layernorm_embedding = tgt_layer.layernorm_embedding
|
170 |
+
else:
|
171 |
+
src_layer.fc1.weight = tgt_layer.fc1.weight
|
172 |
+
src_layer.fc2.weight = tgt_layer.fc2.weight
|
173 |
+
src_layer.self_attn.k_proj.weight = tgt_layer.self_attn.k_proj.weight
|
174 |
+
src_layer.self_attn.v_proj.weight = tgt_layer.self_attn.v_proj.weight
|
175 |
+
src_layer.self_attn.q_proj.weight = tgt_layer.self_attn.q_proj.weight
|
176 |
+
src_layer.self_attn.out_proj.weight = (
|
177 |
+
tgt_layer.self_attn.out_proj.weight
|
178 |
+
)
|
179 |
+
else:
|
180 |
+
if share_level == 1:
|
181 |
+
return tgt_layer
|
182 |
+
return src_layer
|
183 |
+
|
184 |
+
@classmethod
|
185 |
+
def build_spch_encoder(cls, args):
|
186 |
+
cfg = {
|
187 |
+
"input_feat_per_channel": args.input_feat_per_channel,
|
188 |
+
"input_channels": args.input_channels,
|
189 |
+
"conv_kernel_sizes": args.conv_kernel_sizes,
|
190 |
+
"conv_channels": args.conv_channels,
|
191 |
+
"encoder_embed_dim": args.encoder_embed_dim,
|
192 |
+
"encoder_ffn_embed_dim": args.encoder_ffn_embed_dim,
|
193 |
+
"encoder_layers": args.speech_encoder_layers,
|
194 |
+
"encoder_layerdrop": args.encoder_layerdrop,
|
195 |
+
"encoder_attention_heads": args.encoder_attention_heads,
|
196 |
+
"max_source_positions": args.max_source_positions,
|
197 |
+
"dropout": args.dropout,
|
198 |
+
"encoder_normalize_before": args.encoder_normalize_before,
|
199 |
+
"activation_dropout": args.activation_dropout,
|
200 |
+
"attention_dropout": args.attention_dropout,
|
201 |
+
"activation_fn": args.activation_fn,
|
202 |
+
"layernorm_embedding": args.layernorm_embedding,
|
203 |
+
"no_token_positional_embeddings": args.no_token_positional_embeddings,
|
204 |
+
"no_scale_embedding": args.no_scale_embedding,
|
205 |
+
"quant_noise_pq": args.quant_noise_pq,
|
206 |
+
"encoder_freezing_updates": 0,
|
207 |
+
}
|
208 |
+
model_args = namedtuple("args", cfg.keys())(*cfg.values())
|
209 |
+
spch_encoder = S2TTransformerEncoder(model_args)
|
210 |
+
if args.add_speech_eos:
|
211 |
+
spch_encoder = SpeechEoSEncoder(
|
212 |
+
spch_encoder,
|
213 |
+
2 * len(args.conv_kernel_sizes.split(",")),
|
214 |
+
args.input_feat_per_channel,
|
215 |
+
adapter_type=getattr(args, "speech_encoder_adapter_type", "None"),
|
216 |
+
adapter_dim=args.encoder_embed_dim,
|
217 |
+
)
|
218 |
+
return spch_encoder
|
219 |
+
|
220 |
+
@classmethod
|
221 |
+
def build_text_encoder(cls, args, src_dictionary, spch_encoder):
|
222 |
+
if args.encoder_shared_layers > 0:
|
223 |
+
mx_shared_layers = (
|
224 |
+
args.speech_encoder_layers
|
225 |
+
if args.speech_encoder_layers < args.text_encoder_layers
|
226 |
+
else args.text_encoder_layers
|
227 |
+
)
|
228 |
+
args.encoder_shared_layers = (
|
229 |
+
args.encoder_shared_layers
|
230 |
+
if args.encoder_shared_layers <= mx_shared_layers
|
231 |
+
else mx_shared_layers
|
232 |
+
)
|
233 |
+
cfg = {
|
234 |
+
"encoder_embed_dim": args.encoder_text_embed_dim,
|
235 |
+
"encoder_ffn_embed_dim": args.encoder_ffn_embed_dim,
|
236 |
+
"encoder_layers": args.text_encoder_layers,
|
237 |
+
"encoder_layerdrop": args.encoder_layerdrop,
|
238 |
+
"encoder_attention_heads": args.encoder_attention_heads,
|
239 |
+
"encoder_learned_pos": args.encoder_learned_pos,
|
240 |
+
"max_source_positions": args.max_source_positions,
|
241 |
+
"dropout": args.dropout,
|
242 |
+
"encoder_normalize_before": args.encoder_normalize_before,
|
243 |
+
"activation_dropout": args.activation_dropout,
|
244 |
+
"attention_dropout": args.attention_dropout,
|
245 |
+
"activation_fn": args.activation_fn,
|
246 |
+
"adaptive_input": args.adaptive_input,
|
247 |
+
"no_token_positional_embeddings": args.no_token_positional_embeddings,
|
248 |
+
"no_scale_embedding": args.no_scale_embedding,
|
249 |
+
"quant_noise_pq": args.quant_noise_pq,
|
250 |
+
}
|
251 |
+
model_args = namedtuple("args", cfg.keys())(*cfg.values())
|
252 |
+
enc_emb = nn.Embedding(
|
253 |
+
len(src_dictionary), model_args.encoder_embed_dim, src_dictionary.pad()
|
254 |
+
)
|
255 |
+
text_encoder = TransformerEncoder(model_args, src_dictionary, enc_emb)
|
256 |
+
if args.add_speech_eos:
|
257 |
+
spch_encoder = spch_encoder.encoder
|
258 |
+
if args.encoder_shared_layers > 0:
|
259 |
+
text_encoder.layer_norm = cls.set_shared_layer(
|
260 |
+
args.encoder_shared_layer_level,
|
261 |
+
text_encoder.layer_norm,
|
262 |
+
spch_encoder.layer_norm,
|
263 |
+
)
|
264 |
+
for i, ly in enumerate(
|
265 |
+
spch_encoder.transformer_layers[-args.encoder_shared_layers :]
|
266 |
+
):
|
267 |
+
ly_id = i + args.text_encoder_layers - args.encoder_shared_layers
|
268 |
+
if not isinstance(text_encoder.layers[ly_id], type(ly)):
|
269 |
+
if text_encoder.layers[ly_id]._get_name() not in ('TransformerEncoderLayerBase', 'TransformerEncoderLayer'):
|
270 |
+
raise ValueError("The shared layers are expected from the same class")
|
271 |
+
text_encoder.layers[ly_id] = cls.set_shared_layer(
|
272 |
+
args.encoder_shared_layer_level,
|
273 |
+
text_encoder.layers[ly_id],
|
274 |
+
ly,
|
275 |
+
)
|
276 |
+
return text_encoder
|
277 |
+
|
278 |
+
def mult_rst_grad(self, rst, ratio):
|
279 |
+
assert isinstance(rst, dict) # instead of EncoderOut
|
280 |
+
assert len(rst["encoder_out"]) == 1
|
281 |
+
rst["encoder_out"][0] = GradMultiply.apply(rst["encoder_out"][0], ratio)
|
282 |
+
return rst
|
283 |
+
|
284 |
+
def process_attentive_loss_states(self, rst, interstates):
|
285 |
+
assert isinstance(rst, dict) # instead of EncoderOut
|
286 |
+
rst["encoder_states"] = interstates
|
287 |
+
return rst
|
288 |
+
|
289 |
+
def forward(
|
290 |
+
self,
|
291 |
+
src_tokens,
|
292 |
+
src_lengths=None,
|
293 |
+
src_txt_tokens=None,
|
294 |
+
src_txt_lengths=None,
|
295 |
+
**kwargs
|
296 |
+
):
|
297 |
+
"""
|
298 |
+
Args:
|
299 |
+
src_tokens: padded tensor (B, T, C * feat)
|
300 |
+
src_lengths: tensor of original lengths of input utterances (speech) (B,)
|
301 |
+
src_txt_tokens: padded tensor (B, T)
|
302 |
+
src_txt_lengths: tensor of original lengths of input utterances (text) (B,)
|
303 |
+
"""
|
304 |
+
# src_tokens only: inference
|
305 |
+
# src_tokens, src_lengths: speech only training
|
306 |
+
# src_txt_tokens, src_txt_lengths: text only training
|
307 |
+
# all valid: speech + text training
|
308 |
+
|
309 |
+
if src_tokens is None and src_txt_tokens is None:
|
310 |
+
raise ValueError(
|
311 |
+
"src_tokens and src_txt_tokens cannot be None at the same time"
|
312 |
+
)
|
313 |
+
ret1 = None
|
314 |
+
ret2 = None
|
315 |
+
return_all_hiddens = False
|
316 |
+
if src_tokens is not None:
|
317 |
+
if (
|
318 |
+
self.use_cross_attentive_loss and src_txt_tokens is not None
|
319 |
+
): # remove self.training so we can get attn score during validation step
|
320 |
+
return_all_hiddens = True
|
321 |
+
ret1 = self.spch_encoder(
|
322 |
+
src_tokens, src_lengths, return_all_hiddens=return_all_hiddens
|
323 |
+
)
|
324 |
+
|
325 |
+
if self.use_cross_attentive_loss and src_txt_tokens is not None:
|
326 |
+
assert self.cross_attentive_loss_before_last_layer < len(
|
327 |
+
ret1["encoder_states"]
|
328 |
+
)
|
329 |
+
ret1 = self.process_attentive_loss_states(
|
330 |
+
ret1,
|
331 |
+
ret1["encoder_states"][
|
332 |
+
-self.cross_attentive_loss_before_last_layer - 1
|
333 |
+
],
|
334 |
+
)
|
335 |
+
|
336 |
+
if src_txt_tokens is not None:
|
337 |
+
ret2 = self.text_encoder(
|
338 |
+
src_txt_tokens, src_txt_lengths, return_all_hiddens=return_all_hiddens
|
339 |
+
)
|
340 |
+
if return_all_hiddens:
|
341 |
+
if self.cross_attentive_loss_before_last_layer == len(
|
342 |
+
self.text_encoder.layers
|
343 |
+
):
|
344 |
+
text_embedding, _ = self.text_encoder.forward_embedding(
|
345 |
+
src_txt_tokens
|
346 |
+
)
|
347 |
+
text_embedding = text_embedding.transpose(0, 1)
|
348 |
+
ret2 = self.process_attentive_loss_states(ret2, text_embedding)
|
349 |
+
else:
|
350 |
+
assert self.cross_attentive_loss_before_last_layer < len(
|
351 |
+
self.text_encoder.layers
|
352 |
+
)
|
353 |
+
ret2 = self.process_attentive_loss_states(
|
354 |
+
ret2,
|
355 |
+
ret2["encoder_states"][
|
356 |
+
-self.cross_attentive_loss_before_last_layer - 1
|
357 |
+
],
|
358 |
+
)
|
359 |
+
|
360 |
+
def merge_output(rst1, rst2):
|
361 |
+
if rst1 is None:
|
362 |
+
if not (self.enc2_along_grad_mult == 1.0 or self.training):
|
363 |
+
rst2 = self.mult_rst_grad(rst2, self.enc2_along_grad_mult)
|
364 |
+
return rst2
|
365 |
+
if rst2 is None:
|
366 |
+
return rst1
|
367 |
+
if self.enc_grad_mult != 1.0 and self.training:
|
368 |
+
rst1 = self.mult_rst_grad(rst1, self.enc_grad_mult)
|
369 |
+
rst2 = self.mult_rst_grad(rst2, self.enc_grad_mult)
|
370 |
+
rst = (rst1, rst2)
|
371 |
+
return rst
|
372 |
+
|
373 |
+
return merge_output(ret1, ret2)
|
374 |
+
|
375 |
+
def reorder_encoder_out(self, encoder_out, new_order):
|
376 |
+
assert self.training is False # used for inference only
|
377 |
+
return self.spch_encoder.reorder_encoder_out(encoder_out, new_order)
|
378 |
+
|
379 |
+
|
380 |
+
# TransformerMultiInputDecoder: take one or two encoder inputs
|
381 |
+
class TransformerMultiInputDecoder(FairseqDecoder):
|
382 |
+
def __init__(
|
383 |
+
self,
|
384 |
+
dictionary,
|
385 |
+
spch_decoder,
|
386 |
+
text_decoder,
|
387 |
+
compute_cross_attentive_loss=False,
|
388 |
+
cross_attentive_loss_with_norm=True,
|
389 |
+
cross_attentive_loss_reverse=False,
|
390 |
+
):
|
391 |
+
|
392 |
+
super().__init__(dictionary)
|
393 |
+
self.spch_decoder = spch_decoder
|
394 |
+
self.text_decoder = text_decoder
|
395 |
+
self.compute_cross_attentive_loss = compute_cross_attentive_loss
|
396 |
+
self.cross_attentive_loss_with_norm = cross_attentive_loss_with_norm
|
397 |
+
self.cross_attentive_loss_reverse = cross_attentive_loss_reverse
|
398 |
+
|
399 |
+
@classmethod
|
400 |
+
def share_spchdecoder(cls, task_args, text_decoder, spch_decoder):
|
401 |
+
if task_args.decoder_shared_layer_level == 0:
|
402 |
+
return text_decoder
|
403 |
+
assert text_decoder.embed_tokens == spch_decoder.embed_tokens
|
404 |
+
spch_decoder.project_in_dim = text_decoder.project_in_dim
|
405 |
+
spch_decoder.embed_positions = text_decoder.embed_positions
|
406 |
+
spch_decoder.layernorm_embedding = text_decoder.layernorm_embedding
|
407 |
+
spch_decoder.project_out_dim = text_decoder.project_out_dim
|
408 |
+
spch_decoder.adaptive_softmax = text_decoder.adaptive_softmax
|
409 |
+
if task_args.decoder_shared_layer_level == 1:
|
410 |
+
spch_decoder.output_projection = text_decoder.output_projection
|
411 |
+
spch_decoder.layer_norm = text_decoder.layer_norm
|
412 |
+
else: # 2
|
413 |
+
spch_decoder.output_projection.weight = (
|
414 |
+
text_decoder.output_projection.weight
|
415 |
+
)
|
416 |
+
for i, ly in enumerate(text_decoder.layers):
|
417 |
+
sly = spch_decoder.layers[i]
|
418 |
+
sly.self_attn = ly.self_attn
|
419 |
+
sly.self_attn_layer_norm = ly.self_attn_layer_norm
|
420 |
+
# sly.encoder_attn = ly.encoder_attn
|
421 |
+
if (
|
422 |
+
task_args.decoder_shared_layer_level == 1
|
423 |
+
): # share everything, but under different models
|
424 |
+
sly.encoder_attn = ly.encoder_attn
|
425 |
+
sly.encoder_attn_layer_norm = ly.encoder_attn_layer_norm
|
426 |
+
sly.fc1 = ly.fc1
|
427 |
+
sly.fc2 = ly.fc2
|
428 |
+
sly.final_layer_norm = ly.final_layer_norm
|
429 |
+
else: # task_args.decoder_shared_layer_level == 2: #separated encoder_attn_layer_norm and bias
|
430 |
+
sly.encoder_attn.k_proj.weight = ly.encoder_attn.k_proj.weight
|
431 |
+
sly.encoder_attn.v_proj.weight = ly.encoder_attn.v_proj.weight
|
432 |
+
sly.encoder_attn.q_proj.weight = ly.encoder_attn.q_proj.weight
|
433 |
+
sly.encoder_attn.out_proj.weight = ly.encoder_attn.out_proj.weight
|
434 |
+
sly.fc1.weight = ly.fc1.weight
|
435 |
+
sly.fc2.weight = ly.fc2.weight
|
436 |
+
|
437 |
+
return spch_decoder
|
438 |
+
|
439 |
+
def cross_attentive_loss(
|
440 |
+
self, teacher_states, student_states, teacher_masking, student_masking, eps=1e-6
|
441 |
+
):
|
442 |
+
x = teacher_states.transpose(0, 1) # from T X B X D to B X T X D
|
443 |
+
y = student_states.transpose(0, 1)
|
444 |
+
if self.cross_attentive_loss_with_norm:
|
445 |
+
x = x / (x.norm(dim=2, keepdim=True) + eps)
|
446 |
+
y = y / (y.norm(dim=2, keepdim=True) + eps)
|
447 |
+
dim = x.size(-1)
|
448 |
+
# lengths: batch X seqLen
|
449 |
+
sim_scores_xy = torch.bmm(x, y.transpose(1, 2)) # batch X lenx X leny ]
|
450 |
+
if y.dtype == torch.float16:
|
451 |
+
sim_scores_xy = sim_scores_xy.float()
|
452 |
+
y = y.float()
|
453 |
+
x = x.float()
|
454 |
+
if teacher_masking != []:
|
455 |
+
assert len(teacher_masking) == 1
|
456 |
+
sim_scores_xy = sim_scores_xy.masked_fill(
|
457 |
+
teacher_masking[0].unsqueeze(-1), float("-inf")
|
458 |
+
)
|
459 |
+
if student_masking != []:
|
460 |
+
sim_scores_xy = sim_scores_xy.masked_fill(
|
461 |
+
student_masking[0].unsqueeze(1), float("-inf")
|
462 |
+
)
|
463 |
+
# do masking
|
464 |
+
y_weights = utils.softmax(sim_scores_xy, dim=-1)
|
465 |
+
if teacher_masking != []:
|
466 |
+
y_weights = y_weights.masked_fill(teacher_masking[0].unsqueeze(-1), 0)
|
467 |
+
x_reconstruct_from_y = torch.bmm(y_weights, y)
|
468 |
+
|
469 |
+
sim_scores_xx = torch.bmm(x, x.transpose(1, 2)) # batch X lenx X lenx ]
|
470 |
+
x_weights = utils.softmax(sim_scores_xx, dim=-1)
|
471 |
+
if teacher_masking != []:
|
472 |
+
x_weights = x_weights.masked_fill(teacher_masking[0].unsqueeze(-1), 0)
|
473 |
+
|
474 |
+
# no gradient for teacher state
|
475 |
+
x_reconstruct_from_x = torch.bmm(x_weights, x).detach()
|
476 |
+
cost = (x_reconstruct_from_x - x_reconstruct_from_y).norm(dim=2)
|
477 |
+
if teacher_masking != []:
|
478 |
+
cost = cost.masked_fill(teacher_masking[0], 0)
|
479 |
+
|
480 |
+
if not self.cross_attentive_loss_with_norm:
|
481 |
+
cost = cost / dim
|
482 |
+
return cost
|
483 |
+
|
484 |
+
def forward(
|
485 |
+
self,
|
486 |
+
prev_output_tokens,
|
487 |
+
encoder_out,
|
488 |
+
incremental_state=None,
|
489 |
+
has_txt_input=False,
|
490 |
+
**kwargs
|
491 |
+
):
|
492 |
+
"""
|
493 |
+
Args:
|
494 |
+
prev_output_tokens (LongTensor): previous decoder outputs of shape
|
495 |
+
`(batch, tgt_len)`, for input feeding/teacher forcing. If there are
|
496 |
+
two or more input during training, they will share the same prev_output_tokens
|
497 |
+
encoder_out (tuple[Tensor]): output from the encoder, used for
|
498 |
+
encoder-side attention. It will be tuple if there are more inputs, but a tensor
|
499 |
+
if only one input
|
500 |
+
incremental_state ([dict]): dictionary used for storing state during
|
501 |
+
:ref:`Incremental decoding`. It is only valid for inference, only from single
|
502 |
+
input
|
503 |
+
Returns:
|
504 |
+
tuple:
|
505 |
+
- the last decoder layer's output of shape `(batch, tgt_len,
|
506 |
+
vocab)`. If there are N inputs, batch will be N bigger than a single input
|
507 |
+
- the last decoder layer's attention weights of shape `(batch,
|
508 |
+
tgt_len, src_len)`
|
509 |
+
"""
|
510 |
+
assert not isinstance(encoder_out, EncoderOut)
|
511 |
+
if isinstance(encoder_out, tuple): # training with mulitple input
|
512 |
+
rst = []
|
513 |
+
assert len(encoder_out) == 2
|
514 |
+
for i, eo in enumerate(encoder_out):
|
515 |
+
assert incremental_state is None
|
516 |
+
if i == 0:
|
517 |
+
rst.append(
|
518 |
+
self.spch_decoder(prev_output_tokens, eo, incremental_state)
|
519 |
+
)
|
520 |
+
else:
|
521 |
+
rst.append(
|
522 |
+
self.text_decoder(prev_output_tokens, eo, incremental_state)
|
523 |
+
)
|
524 |
+
dec_out = torch.cat([r[0] for r in rst], dim=0)
|
525 |
+
attn_cost = None
|
526 |
+
if self.compute_cross_attentive_loss:
|
527 |
+
assert isinstance(encoder_out[0], dict)
|
528 |
+
if self.cross_attentive_loss_reverse:
|
529 |
+
attn_cost = self.cross_attentive_loss(
|
530 |
+
teacher_states=encoder_out[1]["encoder_states"], # text_states
|
531 |
+
student_states=encoder_out[0]["encoder_states"], # spch_states
|
532 |
+
teacher_masking=encoder_out[1]["encoder_padding_mask"],
|
533 |
+
student_masking=encoder_out[0]["encoder_padding_mask"],
|
534 |
+
)
|
535 |
+
else:
|
536 |
+
attn_cost = self.cross_attentive_loss(
|
537 |
+
teacher_states=encoder_out[0]["encoder_states"], # spch_states
|
538 |
+
student_states=encoder_out[1]["encoder_states"], # text_states
|
539 |
+
teacher_masking=encoder_out[0]["encoder_padding_mask"],
|
540 |
+
student_masking=encoder_out[1]["encoder_padding_mask"],
|
541 |
+
)
|
542 |
+
|
543 |
+
return (dec_out, {"attn_cost": attn_cost})
|
544 |
+
else: # inference or training with one input
|
545 |
+
if has_txt_input:
|
546 |
+
return self.text_decoder(
|
547 |
+
prev_output_tokens, encoder_out, incremental_state
|
548 |
+
)
|
549 |
+
return self.spch_decoder(prev_output_tokens, encoder_out, incremental_state)
|
550 |
+
|
551 |
+
|
552 |
+
# Note:
|
553 |
+
# dual input transformer:
|
554 |
+
# encoder: S2TTransformerEncoder for speech + TransformerEncoder for text
|
555 |
+
# decoder: TransformerDecoder for text
|
556 |
+
@register_model("dual_input_s2t_transformer")
|
557 |
+
class DualInputS2TTransformerModel(FairseqEncoderDecoderModel):
|
558 |
+
def __init__(self, encoder, decoder):
|
559 |
+
super().__init__(encoder, decoder)
|
560 |
+
self.num_updates = 0
|
561 |
+
|
562 |
+
def max_positions(self):
|
563 |
+
return None # it is provided in task
|
564 |
+
|
565 |
+
@staticmethod
|
566 |
+
def add_args(parser):
|
567 |
+
"""Add model-specific arguments to the parser."""
|
568 |
+
# encoder 1: S2TTransformerEncoder for speech
|
569 |
+
parser.add_argument(
|
570 |
+
"--conv-kernel-sizes",
|
571 |
+
type=str,
|
572 |
+
metavar="N",
|
573 |
+
help="kernel sizes of Conv1d subsampling layers",
|
574 |
+
)
|
575 |
+
parser.add_argument(
|
576 |
+
"--conv-channels",
|
577 |
+
type=int,
|
578 |
+
metavar="N",
|
579 |
+
help="# of channels in Conv1d subsampling layers",
|
580 |
+
)
|
581 |
+
parser.add_argument(
|
582 |
+
"--enc-output-dim",
|
583 |
+
type=int,
|
584 |
+
metavar="N",
|
585 |
+
help="""
|
586 |
+
encoder output dimension, can be None. If specified, projecting the
|
587 |
+
transformer output to the specified dimension""",
|
588 |
+
)
|
589 |
+
# standard Transformer
|
590 |
+
parser.add_argument(
|
591 |
+
"--activation-fn",
|
592 |
+
type=str,
|
593 |
+
default="relu",
|
594 |
+
choices=utils.get_available_activation_fns(),
|
595 |
+
help="activation function to use",
|
596 |
+
)
|
597 |
+
parser.add_argument(
|
598 |
+
"--dropout", type=float, metavar="D", help="dropout probability"
|
599 |
+
)
|
600 |
+
parser.add_argument(
|
601 |
+
"--attention-dropout",
|
602 |
+
type=float,
|
603 |
+
metavar="D",
|
604 |
+
help="dropout probability for attention weights",
|
605 |
+
)
|
606 |
+
parser.add_argument(
|
607 |
+
"--activation-dropout",
|
608 |
+
"--relu-dropout",
|
609 |
+
type=float,
|
610 |
+
metavar="D",
|
611 |
+
help="dropout probability after activation in FFN.",
|
612 |
+
)
|
613 |
+
parser.add_argument(
|
614 |
+
"--encoder-embed-dim",
|
615 |
+
type=int,
|
616 |
+
metavar="N",
|
617 |
+
help="encoder embedding dimension",
|
618 |
+
)
|
619 |
+
parser.add_argument(
|
620 |
+
"--encoder-text-embed-dim",
|
621 |
+
type=int,
|
622 |
+
metavar="N",
|
623 |
+
help="encoder text embedding dimension",
|
624 |
+
)
|
625 |
+
parser.add_argument(
|
626 |
+
"--encoder-ffn-embed-dim",
|
627 |
+
type=int,
|
628 |
+
metavar="N",
|
629 |
+
help="encoder embedding dimension for FFN",
|
630 |
+
)
|
631 |
+
parser.add_argument(
|
632 |
+
"--encoder-attention-heads",
|
633 |
+
type=int,
|
634 |
+
metavar="N",
|
635 |
+
help="num encoder attention heads",
|
636 |
+
)
|
637 |
+
parser.add_argument(
|
638 |
+
"--decoder-embed-dim",
|
639 |
+
type=int,
|
640 |
+
metavar="N",
|
641 |
+
help="decoder embedding dimension",
|
642 |
+
)
|
643 |
+
parser.add_argument(
|
644 |
+
"--decoder-ffn-embed-dim",
|
645 |
+
type=int,
|
646 |
+
metavar="N",
|
647 |
+
help="decoder embedding dimension for FFN",
|
648 |
+
)
|
649 |
+
parser.add_argument(
|
650 |
+
"--decoder-layers", type=int, metavar="N", help="num decoder layers"
|
651 |
+
)
|
652 |
+
parser.add_argument(
|
653 |
+
"--decoder-attention-heads",
|
654 |
+
type=int,
|
655 |
+
metavar="N",
|
656 |
+
help="num decoder attention heads",
|
657 |
+
)
|
658 |
+
parser.add_argument(
|
659 |
+
"--layernorm-embedding",
|
660 |
+
action="store_true",
|
661 |
+
help="add layernorm to embedding",
|
662 |
+
)
|
663 |
+
parser.add_argument(
|
664 |
+
"--no-scale-embedding",
|
665 |
+
action="store_true",
|
666 |
+
help="if True, dont scale embeddings",
|
667 |
+
)
|
668 |
+
# non-standard transformer parameters
|
669 |
+
parser.add_argument(
|
670 |
+
"--speech-encoder-layers",
|
671 |
+
type=int,
|
672 |
+
metavar="N",
|
673 |
+
help="num speech encoder layers",
|
674 |
+
)
|
675 |
+
parser.add_argument(
|
676 |
+
"--text-encoder-layers",
|
677 |
+
type=int,
|
678 |
+
metavar="N",
|
679 |
+
help="num text encoder layers",
|
680 |
+
)
|
681 |
+
parser.add_argument(
|
682 |
+
"--encoder-shared-layers",
|
683 |
+
type=int,
|
684 |
+
metavar="N",
|
685 |
+
help="num shared encoder layers",
|
686 |
+
)
|
687 |
+
parser.add_argument(
|
688 |
+
"--encoder-shared-layer-level",
|
689 |
+
type=int,
|
690 |
+
metavar="N",
|
691 |
+
default=0,
|
692 |
+
choices=[0, 1, 2],
|
693 |
+
help="share layer level 0: all share 1: all share with separate model 2: share weight but not bias and layernorm",
|
694 |
+
)
|
695 |
+
|
696 |
+
parser.add_argument(
|
697 |
+
"--decoder-shared-layer-level",
|
698 |
+
default=0,
|
699 |
+
choices=[0, 1, 2],
|
700 |
+
type=int,
|
701 |
+
metavar="N",
|
702 |
+
help="0: share everything; 1: share everything with different model 2: no share layer_norm and bias",
|
703 |
+
)
|
704 |
+
###
|
705 |
+
parser.add_argument(
|
706 |
+
"--text-input-cost-ratio",
|
707 |
+
type=float,
|
708 |
+
default=1.0,
|
709 |
+
metavar="V",
|
710 |
+
help="text input cost ratio relative to speech input cost",
|
711 |
+
)
|
712 |
+
parser.add_argument(
|
713 |
+
"--init-scale",
|
714 |
+
type=float,
|
715 |
+
default=1.0,
|
716 |
+
metavar="V",
|
717 |
+
help="scale the initial weight by given factor",
|
718 |
+
)
|
719 |
+
parser.add_argument(
|
720 |
+
"--enc-grad-mult",
|
721 |
+
type=float,
|
722 |
+
metavar="V",
|
723 |
+
default=1.0,
|
724 |
+
help="multiply enc1 and enc2 gradient by V",
|
725 |
+
)
|
726 |
+
parser.add_argument(
|
727 |
+
"--enc2-along-grad-mult",
|
728 |
+
type=float,
|
729 |
+
metavar="V",
|
730 |
+
default=1.0,
|
731 |
+
help="multiply enc2 gradient by V if only enc2 is used",
|
732 |
+
)
|
733 |
+
parser.add_argument(
|
734 |
+
"--load-pretrain-encoder",
|
735 |
+
type=str,
|
736 |
+
default="",
|
737 |
+
metavar="EXPR",
|
738 |
+
help=""" path to the pretrained encoder """,
|
739 |
+
)
|
740 |
+
parser.add_argument(
|
741 |
+
"--load-pretrain-speech-encoder",
|
742 |
+
type=str,
|
743 |
+
default="",
|
744 |
+
metavar="EXPR",
|
745 |
+
help=""" path to the pretrained speech encoder """,
|
746 |
+
)
|
747 |
+
parser.add_argument(
|
748 |
+
"--load-pretrain-text-encoder",
|
749 |
+
type=str,
|
750 |
+
default="",
|
751 |
+
metavar="EXPR",
|
752 |
+
help=""" path to the pretrained text encoder """,
|
753 |
+
)
|
754 |
+
parser.add_argument(
|
755 |
+
"--load-pretrain-text-encoder-last",
|
756 |
+
type=str,
|
757 |
+
default="",
|
758 |
+
metavar="EXPR",
|
759 |
+
help=""" path to the pretrained text encoder """,
|
760 |
+
)
|
761 |
+
parser.add_argument(
|
762 |
+
"--load-pretrain-decoder",
|
763 |
+
type=str,
|
764 |
+
metavar="EXPR",
|
765 |
+
default="",
|
766 |
+
help=""" path to the pretrained encoder """,
|
767 |
+
)
|
768 |
+
parser.add_argument(
|
769 |
+
"--add-speech-eos",
|
770 |
+
action="store_true",
|
771 |
+
help="add eos token at the end of input feature",
|
772 |
+
)
|
773 |
+
parser.add_argument(
|
774 |
+
"--speech-encoder-adapter-type",
|
775 |
+
type=str,
|
776 |
+
metavar="EXPR",
|
777 |
+
default="None",
|
778 |
+
choices=["None", "Linear", "MLP"],
|
779 |
+
help="add speech encoder adapter",
|
780 |
+
)
|
781 |
+
|
782 |
+
@classmethod
|
783 |
+
def build_encoder(cls, args, task):
|
784 |
+
spch_encoder = DualInputEncoder.build_spch_encoder(args)
|
785 |
+
text_encoder = DualInputEncoder.build_text_encoder(
|
786 |
+
args, task.src_dict, spch_encoder
|
787 |
+
)
|
788 |
+
cross_attentive_loss_before_last_layer = (
|
789 |
+
0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1
|
790 |
+
)
|
791 |
+
encoder = DualInputEncoder(
|
792 |
+
args,
|
793 |
+
spch_encoder,
|
794 |
+
text_encoder,
|
795 |
+
task.src_dict,
|
796 |
+
cross_attentive_loss_before_last_layer,
|
797 |
+
)
|
798 |
+
if args.init_scale != 1.0:
|
799 |
+
with torch.no_grad():
|
800 |
+
for param in encoder.parameters():
|
801 |
+
param.data.mul_(args.init_scale)
|
802 |
+
if args.load_pretrain_text_encoder != "":
|
803 |
+
checkpoint_utils.load_pretrained_component_from_model(
|
804 |
+
text_encoder, args.load_pretrain_text_encoder
|
805 |
+
)
|
806 |
+
if args.load_pretrain_speech_encoder != "":
|
807 |
+
if hasattr(spch_encoder, "encoder"):
|
808 |
+
checkpoint_utils.load_pretrained_component_from_model(
|
809 |
+
spch_encoder.encoder, args.load_pretrain_speech_encoder
|
810 |
+
)
|
811 |
+
else:
|
812 |
+
checkpoint_utils.load_pretrained_component_from_model(
|
813 |
+
spch_encoder, args.load_pretrain_speech_encoder
|
814 |
+
)
|
815 |
+
if (
|
816 |
+
args.load_pretrain_text_encoder_last != ""
|
817 |
+
): # if share encoder, speech encoder parameters will be used.
|
818 |
+
# It provides a chance to use pre-trained mt encoder instead
|
819 |
+
checkpoint_utils.load_pretrained_component_from_model(
|
820 |
+
text_encoder, args.load_pretrain_text_encoder_last
|
821 |
+
)
|
822 |
+
|
823 |
+
if args.load_pretrain_encoder != "":
|
824 |
+
checkpoint_utils.load_pretrained_component_from_model(
|
825 |
+
encoder, args.load_pretrain_encoder
|
826 |
+
)
|
827 |
+
return encoder
|
828 |
+
|
829 |
+
@classmethod
|
830 |
+
def build_decoder(cls, args, task):
|
831 |
+
dec_cfg = {
|
832 |
+
"decoder_layerdrop": args.decoder_layerdrop,
|
833 |
+
"share_decoder_input_output_embed": args.share_decoder_input_output_embed,
|
834 |
+
"decoder_embed_dim": args.decoder_embed_dim,
|
835 |
+
"max_target_positions": args.max_target_positions,
|
836 |
+
"dropout": args.dropout,
|
837 |
+
"encoder_learned_pos": args.encoder_learned_pos,
|
838 |
+
"decoder_learned_pos": args.decoder_learned_pos,
|
839 |
+
"layernorm_embedding": args.layernorm_embedding,
|
840 |
+
"decoder_normalize_before": args.decoder_normalize_before,
|
841 |
+
"activation_dropout": args.activation_dropout,
|
842 |
+
"attention_dropout": args.attention_dropout,
|
843 |
+
"decoder_ffn_embed_dim": args.decoder_ffn_embed_dim,
|
844 |
+
"decoder_layers": args.decoder_layers,
|
845 |
+
"decoder_attention_heads": args.decoder_attention_heads,
|
846 |
+
"decoder_output_dim": args.decoder_embed_dim,
|
847 |
+
"no_scale_embedding": args.no_scale_embedding,
|
848 |
+
"adaptive_input": args.adaptive_input,
|
849 |
+
"quant_noise_pq": args.quant_noise_pq,
|
850 |
+
"adaptive_softmax_cutoff": args.adaptive_softmax_cutoff,
|
851 |
+
"tie_adaptive_weights": args.tie_adaptive_weights,
|
852 |
+
"no_token_positional_embeddings": args.no_token_positional_embeddings,
|
853 |
+
"encoder": {"embed_dim":args.encoder_embed_dim}
|
854 |
+
}
|
855 |
+
dec_cfg = namedtuple("args", dec_cfg.keys())(*dec_cfg.values())
|
856 |
+
dec_emb = nn.Embedding(
|
857 |
+
len(task.target_dictionary),
|
858 |
+
args.decoder_embed_dim,
|
859 |
+
task.target_dictionary.pad(),
|
860 |
+
)
|
861 |
+
compute_cross_attentive_loss = (
|
862 |
+
True if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else False
|
863 |
+
)
|
864 |
+
cross_attentive_loss_without_norm = getattr(
|
865 |
+
args, "attentive_cost_without_normalize", False
|
866 |
+
)
|
867 |
+
cross_attentive_loss_reverse = (
|
868 |
+
False # getattr(args, "attentive_cost_reverse", False)
|
869 |
+
)
|
870 |
+
|
871 |
+
text_decoder = TransformerDecoder(dec_cfg, task.target_dictionary, dec_emb)
|
872 |
+
spch_decoder = TransformerDecoder(dec_cfg, task.target_dictionary, dec_emb)
|
873 |
+
spch_decoder = TransformerMultiInputDecoder.share_spchdecoder(
|
874 |
+
args, text_decoder, spch_decoder
|
875 |
+
)
|
876 |
+
decoder = TransformerMultiInputDecoder(
|
877 |
+
dictionary=task.target_dictionary,
|
878 |
+
spch_decoder=spch_decoder,
|
879 |
+
text_decoder=text_decoder,
|
880 |
+
compute_cross_attentive_loss=compute_cross_attentive_loss,
|
881 |
+
cross_attentive_loss_with_norm=True
|
882 |
+
if not cross_attentive_loss_without_norm
|
883 |
+
else False,
|
884 |
+
cross_attentive_loss_reverse=cross_attentive_loss_reverse,
|
885 |
+
)
|
886 |
+
if args.init_scale != 1.0:
|
887 |
+
with torch.no_grad():
|
888 |
+
for param in decoder.parameters():
|
889 |
+
param.data.mul_(args.init_scale)
|
890 |
+
if args.load_pretrain_decoder != "":
|
891 |
+
try:
|
892 |
+
checkpoint_utils.load_pretrained_component_from_model(
|
893 |
+
decoder, args.load_pretrain_decoder
|
894 |
+
)
|
895 |
+
except RuntimeError:
|
896 |
+
checkpoint_utils.load_pretrained_component_from_model(
|
897 |
+
decoder.text_decoder, args.load_pretrain_decoder
|
898 |
+
)
|
899 |
+
if args.decoder_shared_layer_level > 0:
|
900 |
+
checkpoint_utils.load_pretrained_component_from_model(
|
901 |
+
decoder.spch_decoder, args.load_pretrain_decoder
|
902 |
+
)
|
903 |
+
|
904 |
+
return decoder
|
905 |
+
|
906 |
+
@classmethod
|
907 |
+
def build_model(cls, args, task):
|
908 |
+
"""Build a new model instance."""
|
909 |
+
# make sure that all args are properly defaulted
|
910 |
+
# (in case there are any new ones)
|
911 |
+
dualinputs2ttransformer_base(args)
|
912 |
+
|
913 |
+
encoder = cls.build_encoder(args, task)
|
914 |
+
decoder = cls.build_decoder(args, task)
|
915 |
+
return cls(encoder, decoder)
|
916 |
+
|
917 |
+
def get_normalized_probs(self, net_output, log_probs, sample=None):
|
918 |
+
# net_output['encoder_out'] is a (B, T, D) tensor
|
919 |
+
lprobs = super().get_normalized_probs(net_output, log_probs, sample)
|
920 |
+
lprobs.batch_first = True
|
921 |
+
return lprobs
|
922 |
+
|
923 |
+
def set_num_updates(self, num_updates):
|
924 |
+
"""Set the number of parameters updates."""
|
925 |
+
super().set_num_updates(num_updates)
|
926 |
+
self.num_updates = num_updates
|
927 |
+
|
928 |
+
def forward(
|
929 |
+
self,
|
930 |
+
src_tokens,
|
931 |
+
src_lengths,
|
932 |
+
prev_output_tokens,
|
933 |
+
use_encoder_outputs=False,
|
934 |
+
src_txt_tokens=None,
|
935 |
+
src_txt_lengths=None,
|
936 |
+
mode="sup_speech",
|
937 |
+
**kwargs
|
938 |
+
):
|
939 |
+
"""
|
940 |
+
Run the forward pass for an encoder-decoder model.
|
941 |
+
|
942 |
+
First feed a batch of source tokens through the encoder. Then, feed the
|
943 |
+
encoder output and previous decoder outputs (i.e., teacher forcing) to
|
944 |
+
the decoder to produce the next outputs::
|
945 |
+
|
946 |
+
encoder_out = self.encoder(src_tokens, src_lengths)
|
947 |
+
return self.decoder(prev_output_tokens, encoder_out)
|
948 |
+
|
949 |
+
Args:
|
950 |
+
src_tokens (LongTensor): tokens in the source language of shape
|
951 |
+
`(batch, src_len)`
|
952 |
+
src_lengths (LongTensor): source sentence lengths of shape `(batch)`
|
953 |
+
prev_output_tokens (LongTensor): previous decoder outputs of shape
|
954 |
+
`(batch, tgt_len)`, for teacher forcing
|
955 |
+
mode = 'sup_speech' or 'text'
|
956 |
+
|
957 |
+
Returns:
|
958 |
+
tuple:
|
959 |
+
- the decoder's output of shape `(batch, tgt_len, vocab)`
|
960 |
+
- a dictionary with any model-specific outputs
|
961 |
+
"""
|
962 |
+
if mode == "text":
|
963 |
+
assert src_txt_tokens is None
|
964 |
+
src_txt_tokens = src_tokens
|
965 |
+
src_txt_lengths = src_lengths
|
966 |
+
src_tokens = None
|
967 |
+
src_lengths = None
|
968 |
+
encoder_out = self.encoder(
|
969 |
+
src_tokens,
|
970 |
+
src_lengths=src_lengths,
|
971 |
+
src_txt_tokens=src_txt_tokens,
|
972 |
+
src_txt_lengths=src_txt_lengths,
|
973 |
+
**kwargs
|
974 |
+
)
|
975 |
+
has_txt_input = True if src_txt_tokens is not None else False
|
976 |
+
decoder_out = self.decoder(
|
977 |
+
prev_output_tokens,
|
978 |
+
encoder_out=encoder_out,
|
979 |
+
has_txt_input=has_txt_input,
|
980 |
+
**kwargs
|
981 |
+
)
|
982 |
+
if use_encoder_outputs:
|
983 |
+
return decoder_out, encoder_out
|
984 |
+
return decoder_out
|
985 |
+
|
986 |
+
|
987 |
+
@register_model_architecture(
|
988 |
+
"dual_input_s2t_transformer", "dualinputs2ttransformer_base"
|
989 |
+
)
|
990 |
+
def dualinputs2ttransformer_base(args):
|
991 |
+
args.encoder_freezing_updates = getattr(args, "encoder_freezing_updates", 0)
|
992 |
+
# Convolutional subsampler
|
993 |
+
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
|
994 |
+
args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
|
995 |
+
args.conv_channels = getattr(args, "conv_channels", 1024)
|
996 |
+
# Transformer
|
997 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
|
998 |
+
args.encoder_text_embed_dim = getattr(
|
999 |
+
args, "encoder_text_embed_dim", args.encoder_embed_dim
|
1000 |
+
)
|
1001 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
|
1002 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
|
1003 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
|
1004 |
+
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
|
1005 |
+
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
|
1006 |
+
|
1007 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
|
1008 |
+
args.decoder_ffn_embed_dim = getattr(
|
1009 |
+
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
|
1010 |
+
)
|
1011 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
|
1012 |
+
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
|
1013 |
+
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
|
1014 |
+
args.dropout = getattr(args, "dropout", 0.1)
|
1015 |
+
args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
|
1016 |
+
args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
|
1017 |
+
args.activation_fn = getattr(args, "activation_fn", "relu")
|
1018 |
+
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
|
1019 |
+
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
|
1020 |
+
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
|
1021 |
+
args.share_decoder_input_output_embed = getattr(
|
1022 |
+
args, "share_decoder_input_output_embed", False
|
1023 |
+
)
|
1024 |
+
args.no_token_positional_embeddings = getattr(
|
1025 |
+
args, "no_token_positional_embeddings", False
|
1026 |
+
)
|
1027 |
+
args.adaptive_input = getattr(args, "adaptive_input", False)
|
1028 |
+
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
|
1029 |
+
args.decoder_output_dim = getattr(
|
1030 |
+
args, "decoder_output_dim", args.decoder_embed_dim
|
1031 |
+
)
|
1032 |
+
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
|
1033 |
+
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
|
1034 |
+
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
|
1035 |
+
|
1036 |
+
args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 10)
|
1037 |
+
args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
|
1038 |
+
args.encoder_shared_layers = getattr(args, "encoder_shared_layers", 0)
|
1039 |
+
args.decoder_layers = getattr(args, "decoder_layers", 6)
|
1040 |
+
|
1041 |
+
args.add_speech_eos = getattr(args, "add_speech_eos", False)
|
1042 |
+
|
1043 |
+
|
1044 |
+
@register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_s")
|
1045 |
+
def dualinputs2ttransformer_s(args):
|
1046 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
|
1047 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 4)
|
1048 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
|
1049 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
|
1050 |
+
args.dropout = getattr(args, "dropout", 0.1)
|
1051 |
+
args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 7)
|
1052 |
+
args.text_encoder_layers = getattr(args, "text_encoder_layers", 7)
|
1053 |
+
args.decoder_layers = getattr(args, "decoder_layers", 7)
|
1054 |
+
dualinputs2ttransformer_base(args)
|
1055 |
+
|
1056 |
+
|
1057 |
+
@register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_m")
|
1058 |
+
def dualinputs2ttransformer_m(args):
|
1059 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
|
1060 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4)
|
1061 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
|
1062 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
|
1063 |
+
args.dropout = getattr(args, "dropout", 0.15)
|
1064 |
+
args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 10)
|
1065 |
+
args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
|
1066 |
+
args.decoder_layers = getattr(args, "decoder_layers", 6)
|
1067 |
+
dualinputs2ttransformer_base(args)
|
1068 |
+
|
1069 |
+
|
1070 |
+
@register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_b")
|
1071 |
+
def dualinputs2ttransformer_b(args):
|
1072 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
|
1073 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 768 * 4)
|
1074 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
|
1075 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12)
|
1076 |
+
args.dropout = getattr(args, "dropout", 0.15)
|
1077 |
+
args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 12)
|
1078 |
+
args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
|
1079 |
+
args.decoder_layers = getattr(args, "decoder_layers", 6)
|
1080 |
+
dualinputs2ttransformer_base(args)
|
1081 |
+
|
1082 |
+
|
1083 |
+
@register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_l")
|
1084 |
+
def dualinputs2ttransformer_l(args):
|
1085 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
|
1086 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024 * 4)
|
1087 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
|
1088 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
|
1089 |
+
args.dropout = getattr(args, "dropout", 0.2)
|
1090 |
+
args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 12)
|
1091 |
+
args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
|
1092 |
+
args.decoder_layers = getattr(args, "decoder_layers", 6)
|
1093 |
+
dualinputs2ttransformer_base(args)
|
fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputwavtransformer.py
ADDED
@@ -0,0 +1,526 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import logging
|
7 |
+
from collections import OrderedDict, namedtuple
|
8 |
+
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
from fairseq import checkpoint_utils, utils
|
12 |
+
from fairseq.checkpoint_utils import load_checkpoint_to_cpu
|
13 |
+
from fairseq.file_io import PathManager
|
14 |
+
from fairseq.models import register_model, register_model_architecture
|
15 |
+
from fairseq.models.speech_to_text import (
|
16 |
+
SpeechWavTransformerEncoder,
|
17 |
+
StackedSpeechWavTransformerEncoder,
|
18 |
+
TransformerDecoder,
|
19 |
+
)
|
20 |
+
from fairseq.models.transformer import TransformerEncoder
|
21 |
+
|
22 |
+
from .s2t_dualinputtransformer import (
|
23 |
+
DualInputEncoder,
|
24 |
+
DualInputS2TTransformerModel,
|
25 |
+
TransformerMultiInputDecoder,
|
26 |
+
)
|
27 |
+
|
28 |
+
logger = logging.getLogger(__name__)
|
29 |
+
|
30 |
+
|
31 |
+
@register_model("dual_input_wav_transformer")
|
32 |
+
class DualInputWavTransformerModel(DualInputS2TTransformerModel):
|
33 |
+
def __init__(self, encoder, decoder):
|
34 |
+
super().__init__(encoder, decoder)
|
35 |
+
|
36 |
+
@staticmethod
|
37 |
+
def add_args(parser):
|
38 |
+
def add_transformer_args(parser):
|
39 |
+
# We can't use TransformerModel.add_args(parser), since it defines max-source-positions which is duplicated with tasks/speech_to_text.py
|
40 |
+
# Transformer
|
41 |
+
parser.add_argument(
|
42 |
+
"--activation-fn",
|
43 |
+
type=str,
|
44 |
+
default="relu",
|
45 |
+
choices=utils.get_available_activation_fns(),
|
46 |
+
help="activation function to use",
|
47 |
+
)
|
48 |
+
parser.add_argument(
|
49 |
+
"--dropout", type=float, metavar="D", help="dropout probability"
|
50 |
+
)
|
51 |
+
parser.add_argument(
|
52 |
+
"--attention-dropout",
|
53 |
+
type=float,
|
54 |
+
metavar="D",
|
55 |
+
help="dropout probability for attention weights",
|
56 |
+
)
|
57 |
+
parser.add_argument(
|
58 |
+
"--activation-dropout",
|
59 |
+
"--relu-dropout",
|
60 |
+
type=float,
|
61 |
+
metavar="D",
|
62 |
+
help="dropout probability after activation in FFN.",
|
63 |
+
)
|
64 |
+
parser.add_argument(
|
65 |
+
"--encoder-embed-dim",
|
66 |
+
type=int,
|
67 |
+
metavar="N",
|
68 |
+
help="encoder embedding dimension",
|
69 |
+
)
|
70 |
+
parser.add_argument(
|
71 |
+
"--encoder-ffn-embed-dim",
|
72 |
+
type=int,
|
73 |
+
metavar="N",
|
74 |
+
help="encoder embedding dimension for FFN",
|
75 |
+
)
|
76 |
+
parser.add_argument(
|
77 |
+
"--encoder-layers", type=int, metavar="N", help="num encoder layers"
|
78 |
+
)
|
79 |
+
parser.add_argument(
|
80 |
+
"--encoder-attention-heads",
|
81 |
+
type=int,
|
82 |
+
metavar="N",
|
83 |
+
help="num encoder attention heads",
|
84 |
+
)
|
85 |
+
parser.add_argument(
|
86 |
+
"--encoder-normalize-before",
|
87 |
+
action="store_true",
|
88 |
+
help="apply layernorm before each encoder block",
|
89 |
+
)
|
90 |
+
parser.add_argument(
|
91 |
+
"--decoder-embed-dim",
|
92 |
+
type=int,
|
93 |
+
metavar="N",
|
94 |
+
help="decoder embedding dimension",
|
95 |
+
)
|
96 |
+
parser.add_argument(
|
97 |
+
"--decoder-ffn-embed-dim",
|
98 |
+
type=int,
|
99 |
+
metavar="N",
|
100 |
+
help="decoder embedding dimension for FFN",
|
101 |
+
)
|
102 |
+
parser.add_argument(
|
103 |
+
"--decoder-layers", type=int, metavar="N", help="num decoder layers"
|
104 |
+
)
|
105 |
+
parser.add_argument(
|
106 |
+
"--decoder-attention-heads",
|
107 |
+
type=int,
|
108 |
+
metavar="N",
|
109 |
+
help="num decoder attention heads",
|
110 |
+
)
|
111 |
+
parser.add_argument(
|
112 |
+
"--decoder-normalize-before",
|
113 |
+
action="store_true",
|
114 |
+
help="apply layernorm before each decoder block",
|
115 |
+
)
|
116 |
+
parser.add_argument(
|
117 |
+
"--share-decoder-input-output-embed",
|
118 |
+
action="store_true",
|
119 |
+
help="share decoder input and output embeddings",
|
120 |
+
)
|
121 |
+
parser.add_argument(
|
122 |
+
"--layernorm-embedding",
|
123 |
+
action="store_true",
|
124 |
+
help="add layernorm to embedding",
|
125 |
+
)
|
126 |
+
parser.add_argument(
|
127 |
+
"--no-scale-embedding",
|
128 |
+
action="store_true",
|
129 |
+
help="if True, dont scale embeddings",
|
130 |
+
)
|
131 |
+
|
132 |
+
parser.add_argument(
|
133 |
+
"--encoder-learned-pos",
|
134 |
+
action="store_true",
|
135 |
+
help="use learned positional embeddings",
|
136 |
+
)
|
137 |
+
parser.add_argument(
|
138 |
+
"--decoder-learned-pos",
|
139 |
+
action="store_true",
|
140 |
+
help="use learned positional embeddings",
|
141 |
+
)
|
142 |
+
|
143 |
+
add_transformer_args(parser)
|
144 |
+
SpeechWavTransformerEncoder.add_args(parser)
|
145 |
+
parser.add_argument(
|
146 |
+
"--load-pretrained-speech-text-encoder",
|
147 |
+
type=str,
|
148 |
+
default="",
|
149 |
+
metavar="EXPR",
|
150 |
+
help=""" path to the pretrained speech text encoder from SpeechTextPreTrainModel """,
|
151 |
+
)
|
152 |
+
parser.add_argument(
|
153 |
+
"--load-pretrained-wav2vec-encoder",
|
154 |
+
type=str,
|
155 |
+
default="",
|
156 |
+
metavar="EXPR",
|
157 |
+
help=""" path to the pretrained speech text encoder from wav2vec """,
|
158 |
+
)
|
159 |
+
|
160 |
+
parser.add_argument(
|
161 |
+
"--load-pretrained-speech-text-decoder",
|
162 |
+
type=str,
|
163 |
+
default="",
|
164 |
+
metavar="EXPR",
|
165 |
+
help=""" path to the pretrained speech text decoder from SpeechTextPreTrainModel """,
|
166 |
+
)
|
167 |
+
parser.add_argument(
|
168 |
+
"--load-pretrained-text-decoder",
|
169 |
+
type=str,
|
170 |
+
default="",
|
171 |
+
metavar="EXPR",
|
172 |
+
help=""" path to the pretrained text decoder """,
|
173 |
+
)
|
174 |
+
parser.add_argument(
|
175 |
+
"--load-init-encoder",
|
176 |
+
type=str,
|
177 |
+
default="",
|
178 |
+
metavar="EXPR",
|
179 |
+
help=""" path to load seed encoder model """,
|
180 |
+
)
|
181 |
+
parser.add_argument(
|
182 |
+
"--load-init-decoder",
|
183 |
+
type=str,
|
184 |
+
default="",
|
185 |
+
metavar="EXPR",
|
186 |
+
help=""" path to load seed decoder model """,
|
187 |
+
)
|
188 |
+
|
189 |
+
parser.add_argument(
|
190 |
+
"--text-input-cost-ratio",
|
191 |
+
type=float,
|
192 |
+
default=1.0,
|
193 |
+
metavar="V",
|
194 |
+
help="text input cost ratio relative to speech input cost",
|
195 |
+
)
|
196 |
+
parser.add_argument(
|
197 |
+
"--enc-grad-mult",
|
198 |
+
type=float,
|
199 |
+
metavar="V",
|
200 |
+
default=1.0,
|
201 |
+
help="multiply enc1 and enc2 gradient by V",
|
202 |
+
)
|
203 |
+
parser.add_argument(
|
204 |
+
"--enc2-along-grad-mult",
|
205 |
+
type=float,
|
206 |
+
metavar="V",
|
207 |
+
default=1.0,
|
208 |
+
help="multiply enc2 gradient by V if only enc2 is used",
|
209 |
+
)
|
210 |
+
parser.add_argument(
|
211 |
+
"--no-strict-check-pretrain-model",
|
212 |
+
action="store_true",
|
213 |
+
help="Don't apply strict model check for the pretrained model",
|
214 |
+
)
|
215 |
+
|
216 |
+
parser.add_argument(
|
217 |
+
"--stacked-encoder",
|
218 |
+
action="store_true",
|
219 |
+
help="stack speech and text encoders",
|
220 |
+
)
|
221 |
+
|
222 |
+
@classmethod
|
223 |
+
def update_transformer_encoder_cfg(cls, args, update_dict):
|
224 |
+
cfg = dict(args._get_kwargs())
|
225 |
+
for fkey in update_dict.keys():
|
226 |
+
cfg[fkey] = update_dict[fkey]
|
227 |
+
cfg.pop("_name", None) # remove keys start with _
|
228 |
+
model_args = namedtuple("args", cfg.keys())(*cfg.values())
|
229 |
+
return model_args
|
230 |
+
|
231 |
+
@classmethod
|
232 |
+
def build_text_encoder(cls, args, src_dictionary):
|
233 |
+
enc_emb = nn.Embedding(
|
234 |
+
len(src_dictionary), args.encoder_embed_dim, src_dictionary.pad()
|
235 |
+
)
|
236 |
+
model_args = cls.update_transformer_encoder_cfg(
|
237 |
+
args,
|
238 |
+
{
|
239 |
+
"encoder_layers": args.text_encoder_layers,
|
240 |
+
"max_source_positions": args.max_positions_text,
|
241 |
+
},
|
242 |
+
)
|
243 |
+
text_encoder = TransformerEncoder(model_args, src_dictionary, enc_emb)
|
244 |
+
return text_encoder
|
245 |
+
|
246 |
+
@classmethod
|
247 |
+
def build_speech_encoder(cls, args):
|
248 |
+
model_args = cls.update_transformer_encoder_cfg(
|
249 |
+
args, {"encoder_layers": args.speech_encoder_layers}
|
250 |
+
)
|
251 |
+
speech_encoder = SpeechWavTransformerEncoder(model_args)
|
252 |
+
return speech_encoder
|
253 |
+
|
254 |
+
@classmethod
|
255 |
+
def check_args(cls, condition, is_strict, msg):
|
256 |
+
if condition:
|
257 |
+
return
|
258 |
+
if is_strict:
|
259 |
+
raise ValueError(msg)
|
260 |
+
logger.warn(msg)
|
261 |
+
|
262 |
+
@classmethod
|
263 |
+
def build_encoder(cls, args, task):
|
264 |
+
# text_encoder = cls.build_text_encoder(args, task.source_dictionary )
|
265 |
+
text_encoder = cls.build_text_encoder(args, task.src_dict)
|
266 |
+
speech_encoder = cls.build_speech_encoder(args)
|
267 |
+
if args.load_pretrained_wav2vec_encoder:
|
268 |
+
component_pairs = (
|
269 |
+
("feature_extractor", speech_encoder.subsample),
|
270 |
+
("post_extract_proj", speech_encoder.feat_proj),
|
271 |
+
("layer_norm", speech_encoder.feat_layer_norm),
|
272 |
+
("encoder.pos_conv", speech_encoder.embed_positions),
|
273 |
+
("encoder.layers", speech_encoder.layers),
|
274 |
+
("encoder.layer_norm", speech_encoder.layer_norm),
|
275 |
+
("mask_emb", speech_encoder.mask_emb),
|
276 |
+
)
|
277 |
+
state = cls.load_pretrained_speech_text_components(
|
278 |
+
args.load_pretrained_wav2vec_encoder, component_pairs
|
279 |
+
)
|
280 |
+
cls.check_args(
|
281 |
+
args.encoder_normalize_before
|
282 |
+
== state["cfg"]["model"]["layer_norm_first"],
|
283 |
+
not args.no_strict_check_pretrain_model,
|
284 |
+
f"encoder_normalize_before {args.encoder_normalize_before} doesn't match with the pretrained model",
|
285 |
+
)
|
286 |
+
cls.check_args(
|
287 |
+
args.activation_fn == state["cfg"]["model"]["activation_fn"],
|
288 |
+
not args.no_strict_check_pretrain_model,
|
289 |
+
f"activation_fn {args.activation_fn} doesn't match with the pretrained model",
|
290 |
+
)
|
291 |
+
|
292 |
+
if getattr(args, "stacked_encoder", False):
|
293 |
+
if args.encoder_shared_text_layers_from_begin > 0:
|
294 |
+
raise ValueError(
|
295 |
+
"We can not stack encoders and share encoders at the same time!"
|
296 |
+
)
|
297 |
+
speech_encoder = StackedSpeechWavTransformerEncoder(
|
298 |
+
speech_encoder, text_encoder.layers, text_encoder.layer_norm
|
299 |
+
)
|
300 |
+
else:
|
301 |
+
cls.share_speech_text_encoder(
|
302 |
+
speech_encoder, text_encoder, args.encoder_shared_text_layers_from_begin
|
303 |
+
)
|
304 |
+
|
305 |
+
cross_attentive_loss_before_last_layer = (
|
306 |
+
0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1
|
307 |
+
)
|
308 |
+
encoder = DualInputEncoder(
|
309 |
+
args,
|
310 |
+
speech_encoder,
|
311 |
+
text_encoder,
|
312 |
+
task.src_dict,
|
313 |
+
cross_attentive_loss_before_last_layer,
|
314 |
+
)
|
315 |
+
if args.load_pretrained_speech_text_encoder:
|
316 |
+
component_pairs = (
|
317 |
+
("encoder.sup_s2s_speech_encoder", encoder.spch_encoder),
|
318 |
+
("encoder.text_encoder", encoder.text_encoder),
|
319 |
+
)
|
320 |
+
cls.load_pretrained_speech_text_components(
|
321 |
+
args.load_pretrained_speech_text_encoder, component_pairs
|
322 |
+
)
|
323 |
+
if getattr(args, "load_init_encoder", "") != "":
|
324 |
+
checkpoint_utils.load_pretrained_component_from_model(
|
325 |
+
encoder, args.load_init_encoder
|
326 |
+
)
|
327 |
+
return encoder
|
328 |
+
|
329 |
+
@classmethod
|
330 |
+
def build_text_decoder(cls, args, tgt_dictionary, dec_emb_share=None):
|
331 |
+
dec_emb = (
|
332 |
+
nn.Embedding(
|
333 |
+
len(tgt_dictionary), args.decoder_embed_dim, tgt_dictionary.pad()
|
334 |
+
)
|
335 |
+
if dec_emb_share is None
|
336 |
+
else dec_emb_share
|
337 |
+
)
|
338 |
+
text_decoder = TransformerDecoder(args, tgt_dictionary, dec_emb)
|
339 |
+
return text_decoder
|
340 |
+
|
341 |
+
@classmethod
|
342 |
+
def build_decoder(cls, args, task):
|
343 |
+
text_decoder = cls.build_text_decoder(args, task.target_dictionary)
|
344 |
+
compute_cross_attentive_loss = (
|
345 |
+
True if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else False
|
346 |
+
)
|
347 |
+
cross_attentive_loss_without_norm = getattr(
|
348 |
+
args, "attentive_cost_without_normalize", False
|
349 |
+
)
|
350 |
+
cross_attentive_loss_reverse = (
|
351 |
+
False # getattr(args, "attentive_cost_reverse", False)
|
352 |
+
)
|
353 |
+
if getattr(args, "load_pretrained_text_decoder", "") != "":
|
354 |
+
checkpoint_utils.load_pretrained_component_from_model(
|
355 |
+
text_decoder, args.load_pretrained_text_decoder
|
356 |
+
)
|
357 |
+
|
358 |
+
if args.load_pretrained_speech_text_decoder:
|
359 |
+
component_pairs = (("decoder.text_decoder", text_decoder),)
|
360 |
+
cls.load_pretrained_speech_text_components(
|
361 |
+
args.load_pretrained_speech_text_decoder, component_pairs
|
362 |
+
)
|
363 |
+
|
364 |
+
decoder = TransformerMultiInputDecoder(
|
365 |
+
dictionary=task.target_dictionary,
|
366 |
+
spch_decoder=text_decoder,
|
367 |
+
text_decoder=text_decoder,
|
368 |
+
compute_cross_attentive_loss=compute_cross_attentive_loss,
|
369 |
+
cross_attentive_loss_with_norm=True
|
370 |
+
if not cross_attentive_loss_without_norm
|
371 |
+
else False,
|
372 |
+
cross_attentive_loss_reverse=cross_attentive_loss_reverse,
|
373 |
+
)
|
374 |
+
if getattr(args, "load_init_decoder", "") != "":
|
375 |
+
checkpoint_utils.load_pretrained_component_from_model(
|
376 |
+
decoder, args.load_init_decoder
|
377 |
+
)
|
378 |
+
return decoder
|
379 |
+
|
380 |
+
@classmethod
|
381 |
+
def load_pretrained_speech_text_components(cls, checkpoint, component_pairs):
|
382 |
+
if not PathManager.exists(checkpoint):
|
383 |
+
raise IOError("Model file not found: {}".format(checkpoint))
|
384 |
+
state = load_checkpoint_to_cpu(checkpoint)
|
385 |
+
for component_type, component in component_pairs:
|
386 |
+
if isinstance(component, nn.parameter.Parameter):
|
387 |
+
component.data.copy_(state["model"][component_type])
|
388 |
+
else:
|
389 |
+
component_state_dict = OrderedDict()
|
390 |
+
for key in state["model"].keys():
|
391 |
+
if key.startswith(component_type):
|
392 |
+
component_subkey = key[len(component_type) + 1 :]
|
393 |
+
component_state_dict[component_subkey] = state["model"][key]
|
394 |
+
component.load_state_dict(component_state_dict, strict=True)
|
395 |
+
return state
|
396 |
+
|
397 |
+
@classmethod
|
398 |
+
def share_speech_text_encoder(
|
399 |
+
cls, speech_encoder, text_encoder, shared_layers_from_begin
|
400 |
+
):
|
401 |
+
if shared_layers_from_begin > 0:
|
402 |
+
num_text_encoder_layers = len(text_encoder.layers)
|
403 |
+
assert len(speech_encoder.layers) >= shared_layers_from_begin
|
404 |
+
assert num_text_encoder_layers >= shared_layers_from_begin
|
405 |
+
assert len(speech_encoder.layers) >= num_text_encoder_layers
|
406 |
+
for i, ly in enumerate(
|
407 |
+
speech_encoder.layers[
|
408 |
+
-num_text_encoder_layers : -num_text_encoder_layers
|
409 |
+
+ shared_layers_from_begin
|
410 |
+
]
|
411 |
+
):
|
412 |
+
assert isinstance(text_encoder.layers[i], type(ly))
|
413 |
+
text_encoder.layers[i] = ly
|
414 |
+
|
415 |
+
|
416 |
+
@register_model_architecture(
|
417 |
+
"dual_input_wav_transformer", "dualinputs2twavtransformer_base"
|
418 |
+
)
|
419 |
+
def dualinputs2twavtransformer_base(args):
|
420 |
+
# speech masking
|
421 |
+
args.dropout_input = getattr(args, "dropout_input", 0)
|
422 |
+
args.dropout_features = getattr(args, "dropout_features", 0)
|
423 |
+
args.speech_mask_length = getattr(args, "speech_mask_length", 10)
|
424 |
+
args.speech_mask_prob = getattr(args, "speech_mask_prob", 0.65)
|
425 |
+
args.speech_mask_selection = getattr(args, "speech_mask_selection", "static")
|
426 |
+
args.speech_mask_other = getattr(args, "speech_mask_other", 0)
|
427 |
+
args.speech_mask_min_space = getattr(args, "speech_mask_min_space", 1)
|
428 |
+
args.speech_no_mask_overlap = getattr(args, "speech_no_mask_overlap", False)
|
429 |
+
args.speech_conv_bias = getattr(args, "speech_conv_bias", False)
|
430 |
+
args.speech_extractor_mode = getattr(args, "speech_extractor_mode", "default")
|
431 |
+
args.no_strict_check_pretrain_model = getattr(
|
432 |
+
args, "no_strict_check_pretrain_model", False
|
433 |
+
)
|
434 |
+
|
435 |
+
args.speech_mask_channel_length = getattr(args, "speech_mask_channel_length", 10)
|
436 |
+
args.speech_mask_channel_prob = getattr(args, "speech_mask_channel_prob", 0.0)
|
437 |
+
args.speech_mask_channel_selection = getattr(
|
438 |
+
args, "speech_mask_channel_selection", "static"
|
439 |
+
)
|
440 |
+
args.speech_mask_channel_other = getattr(args, "speech_mask_channel_other", 0)
|
441 |
+
args.speech_mask_channel_min_space = getattr(
|
442 |
+
args, "speech_mask_channel_min_space", 1
|
443 |
+
)
|
444 |
+
args.speech_no_mask_channel_overlap = getattr(
|
445 |
+
args, "speech_no_mask_channel_overlap", False
|
446 |
+
)
|
447 |
+
args.no_scale_feature = getattr(args, "", False)
|
448 |
+
args.feature_grad_mult = getattr(args, "feature_grad_mult", 0.0) # 0.1
|
449 |
+
|
450 |
+
# Transformer
|
451 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
|
452 |
+
args.encoder_ffn_embed_dim = getattr(
|
453 |
+
args, "encoder_ffn_embed_dim", args.encoder_embed_dim * 4
|
454 |
+
)
|
455 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
|
456 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
457 |
+
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.1)
|
458 |
+
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
|
459 |
+
|
460 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
|
461 |
+
args.decoder_ffn_embed_dim = getattr(
|
462 |
+
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
|
463 |
+
)
|
464 |
+
args.decoder_attention_heads = getattr(
|
465 |
+
args, "decoder_attention_heads", args.encoder_attention_heads
|
466 |
+
)
|
467 |
+
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
|
468 |
+
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
|
469 |
+
args.dropout = getattr(args, "dropout", 0.1)
|
470 |
+
args.attention_dropout = getattr(args, "attention_dropout", 0)
|
471 |
+
args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
|
472 |
+
args.activation_fn = getattr(args, "activation_fn", "relu") # gelu?
|
473 |
+
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
|
474 |
+
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
|
475 |
+
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
|
476 |
+
args.share_decoder_input_output_embed = getattr(
|
477 |
+
args, "share_decoder_input_output_embed", False
|
478 |
+
)
|
479 |
+
args.no_token_positional_embeddings = getattr(
|
480 |
+
args, "no_token_positional_embeddings", False
|
481 |
+
)
|
482 |
+
args.adaptive_input = getattr(args, "adaptive_input", False)
|
483 |
+
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
|
484 |
+
args.decoder_output_dim = getattr(
|
485 |
+
args, "decoder_output_dim", args.decoder_embed_dim
|
486 |
+
)
|
487 |
+
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
|
488 |
+
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
|
489 |
+
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
|
490 |
+
|
491 |
+
args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 12)
|
492 |
+
args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
|
493 |
+
args.encoder_shared_text_layers_from_begin = getattr(
|
494 |
+
args, "encoder_shared_text_layers_from_begin", 6
|
495 |
+
)
|
496 |
+
args.decoder_layers = getattr(args, "decoder_layers", 6)
|
497 |
+
|
498 |
+
|
499 |
+
@register_model_architecture(
|
500 |
+
"dual_input_wav_transformer", "dualinputs2twavtransformer_base_stack"
|
501 |
+
)
|
502 |
+
def dualinputs2twavtransformer_base_stack(args):
|
503 |
+
args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 6)
|
504 |
+
args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
|
505 |
+
args.encoder_shared_text_layers_from_begin = getattr(
|
506 |
+
args, "encoder_shared_text_layers_from_begin", 0
|
507 |
+
)
|
508 |
+
args.decoder_layers = getattr(args, "decoder_layers", 6)
|
509 |
+
args.stacked_encoder = getattr(args, "stacked_encoder", True)
|
510 |
+
args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
|
511 |
+
dualinputs2twavtransformer_base(args)
|
512 |
+
|
513 |
+
|
514 |
+
@register_model_architecture(
|
515 |
+
"dual_input_wav_transformer", "dualinputs2twavtransformer_large"
|
516 |
+
)
|
517 |
+
def dualinputs2twavtransformer_large(args):
|
518 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
|
519 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
|
520 |
+
args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 24)
|
521 |
+
args.text_encoder_layers = getattr(args, "text_encoder_layers", 12)
|
522 |
+
args.encoder_shared_text_layers_from_begin = getattr(
|
523 |
+
args, "encoder_shared_text_layers_from_begin", 12
|
524 |
+
)
|
525 |
+
args.decoder_layers = getattr(args, "decoder_layers", 12)
|
526 |
+
dualinputs2twavtransformer_base(args)
|
fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputxmtransformer.py
ADDED
@@ -0,0 +1,584 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import copy
|
7 |
+
|
8 |
+
import torch.nn as nn
|
9 |
+
from fairseq import checkpoint_utils
|
10 |
+
from fairseq import utils
|
11 |
+
from fairseq.data.data_utils import lengths_to_padding_mask
|
12 |
+
from fairseq.models import (
|
13 |
+
register_model,
|
14 |
+
register_model_architecture,
|
15 |
+
FairseqEncoder,
|
16 |
+
)
|
17 |
+
from fairseq.models.speech_to_text import Wav2VecEncoderWithAdaptor
|
18 |
+
from fairseq.models.speech_to_text.xm_transformer import (
|
19 |
+
set_default_adaptor_args,
|
20 |
+
set_default_w2v_encoder_args,
|
21 |
+
need_finetuning
|
22 |
+
)
|
23 |
+
from fairseq.models.transformer import TransformerEncoder, TransformerDecoder
|
24 |
+
from fairseq.models.wav2vec import TransformerSentenceEncoderLayer
|
25 |
+
from fairseq.utils import safe_hasattr
|
26 |
+
|
27 |
+
from .s2t_dualinputtransformer import (
|
28 |
+
DualInputS2TTransformerModel,
|
29 |
+
TransformerMultiInputDecoder,
|
30 |
+
DualInputEncoder,
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
class TransformerSentenceEncoderLayerStd(TransformerSentenceEncoderLayer):
|
35 |
+
def __init__(self, sent_enc_layer):
|
36 |
+
super(TransformerSentenceEncoderLayer, self).__init__()
|
37 |
+
self.embedding_dim = sent_enc_layer.embedding_dim
|
38 |
+
self.dropout = sent_enc_layer.dropout
|
39 |
+
self.activation_dropout = sent_enc_layer.activation_dropout
|
40 |
+
|
41 |
+
# Initialize blocks
|
42 |
+
self.activation_fn = sent_enc_layer.activation_fn
|
43 |
+
self.self_attn = sent_enc_layer.self_attn
|
44 |
+
|
45 |
+
self.dropout1 = sent_enc_layer.dropout1
|
46 |
+
self.dropout2 = sent_enc_layer.dropout2
|
47 |
+
self.dropout3 = sent_enc_layer.dropout3
|
48 |
+
|
49 |
+
self.layer_norm_first = sent_enc_layer.layer_norm_first
|
50 |
+
|
51 |
+
# layer norm associated with the self attention layer
|
52 |
+
self.self_attn_layer_norm = sent_enc_layer.self_attn_layer_norm
|
53 |
+
self.fc1 = sent_enc_layer.fc1
|
54 |
+
self.fc2 = sent_enc_layer.fc2
|
55 |
+
|
56 |
+
# layer norm associated with the position wise feed-forward NN
|
57 |
+
self.final_layer_norm = sent_enc_layer.final_layer_norm
|
58 |
+
|
59 |
+
def forward(
|
60 |
+
self,
|
61 |
+
x,
|
62 |
+
self_attn_mask=None,
|
63 |
+
self_attn_padding_mask=None,
|
64 |
+
need_weights=None,
|
65 |
+
att_args=None,
|
66 |
+
):
|
67 |
+
x, attn = super().forward(
|
68 |
+
x, self_attn_mask, self_attn_padding_mask, need_weights, att_args
|
69 |
+
)
|
70 |
+
return x
|
71 |
+
|
72 |
+
|
73 |
+
# TODO retire SharedEncoder
|
74 |
+
class SharedEncoder(FairseqEncoder):
|
75 |
+
def __init__(self, wav2vec_enc, mbart_enc, adaptor, shared_layers):
|
76 |
+
super().__init__(None)
|
77 |
+
self.w2v_encoder = wav2vec_enc
|
78 |
+
self.shared_layers = self.w2v_encoder.w2v_model.encoder.layers[-shared_layers:]
|
79 |
+
self.w2v_encoder.w2v_model.encoder.layers = (
|
80 |
+
self.w2v_encoder.w2v_model.encoder.layers[:-shared_layers]
|
81 |
+
)
|
82 |
+
self.adaptor = adaptor
|
83 |
+
if self.shared_layers[-1].layer_norm_first:
|
84 |
+
self.final_layer_norm = mbart_enc.layer_norm
|
85 |
+
else:
|
86 |
+
mbart_enc.layer_norm = None
|
87 |
+
self.final_layer_norm = None
|
88 |
+
shared_layer_from = len(mbart_enc.layers) - shared_layers
|
89 |
+
if shared_layer_from < 0:
|
90 |
+
shared_layer_from = 0
|
91 |
+
for layer_id, layer in enumerate(self.shared_layers):
|
92 |
+
mbart_enc.layers[
|
93 |
+
shared_layer_from + layer_id
|
94 |
+
] = TransformerSentenceEncoderLayerStd(layer)
|
95 |
+
|
96 |
+
def forward(self, src_tokens, src_lengths=None, **kwargs):
|
97 |
+
padding_mask = lengths_to_padding_mask(src_lengths)
|
98 |
+
if not padding_mask.any():
|
99 |
+
padding_mask = None
|
100 |
+
|
101 |
+
out = self.w2v_encoder.forward(src_tokens, padding_mask, tbc=True)
|
102 |
+
x = out["encoder_out"]
|
103 |
+
enc_padding_mask = None
|
104 |
+
if out["encoder_padding_mask"] is not None:
|
105 |
+
enc_padding_mask = out["encoder_padding_mask"].transpose(
|
106 |
+
0, 1
|
107 |
+
) # T X B --> B X T
|
108 |
+
|
109 |
+
x, enc_padding_mask = self.adaptor(x, enc_padding_mask)
|
110 |
+
for layer in self.shared_layers:
|
111 |
+
x, _ = layer(x, enc_padding_mask)
|
112 |
+
if self.final_layer_norm is not None:
|
113 |
+
x = self.final_layer_norm(x)
|
114 |
+
|
115 |
+
return {
|
116 |
+
"encoder_out": [x], # T x B x C
|
117 |
+
"encoder_padding_mask": [enc_padding_mask]
|
118 |
+
if enc_padding_mask is not None
|
119 |
+
else [], # B x T
|
120 |
+
"encoder_embedding": [], # B x T x C
|
121 |
+
"encoder_states": [], # List[T x B x C]
|
122 |
+
"src_tokens": [],
|
123 |
+
"src_lengths": [],
|
124 |
+
}
|
125 |
+
|
126 |
+
|
127 |
+
class StackedWav2VecEncoderWithAdaptor(FairseqEncoder):
|
128 |
+
def __init__(
|
129 |
+
self,
|
130 |
+
wav2vec_enc,
|
131 |
+
mbart_enc_layers,
|
132 |
+
mbart_layer_norm,
|
133 |
+
adaptor,
|
134 |
+
drop_w2v_layers=0,
|
135 |
+
):
|
136 |
+
super().__init__(None)
|
137 |
+
self.w2v_encoder = wav2vec_enc
|
138 |
+
self.adaptor = adaptor
|
139 |
+
self.mbart_encoder_layers = mbart_enc_layers
|
140 |
+
self.final_layer_norm = mbart_layer_norm
|
141 |
+
if drop_w2v_layers > 0:
|
142 |
+
self.w2v_encoder.w2v_model.encoder.layers = (
|
143 |
+
self.w2v_encoder.w2v_model.encoder.layers[:-drop_w2v_layers]
|
144 |
+
)
|
145 |
+
|
146 |
+
def forward(self, src_tokens, src_lengths=None, return_all_hiddens=False, **kwargs):
|
147 |
+
padding_mask = lengths_to_padding_mask(src_lengths)
|
148 |
+
if not padding_mask.any():
|
149 |
+
padding_mask = None
|
150 |
+
|
151 |
+
out = self.w2v_encoder.forward(src_tokens, padding_mask, tbc=True)
|
152 |
+
x = out["encoder_out"]
|
153 |
+
enc_padding_mask = None
|
154 |
+
if out["padding_mask"] is not None:
|
155 |
+
enc_padding_mask = out["padding_mask"] # B X T
|
156 |
+
|
157 |
+
x, enc_padding_mask = self.adaptor(x, enc_padding_mask)
|
158 |
+
encoder_states = []
|
159 |
+
for layer in self.mbart_encoder_layers:
|
160 |
+
x = layer(x, enc_padding_mask)
|
161 |
+
if return_all_hiddens:
|
162 |
+
encoder_states.append(x)
|
163 |
+
if self.final_layer_norm is not None:
|
164 |
+
x = self.final_layer_norm(x)
|
165 |
+
|
166 |
+
return {
|
167 |
+
"encoder_out": [x], # T x B x C
|
168 |
+
"encoder_padding_mask": [enc_padding_mask]
|
169 |
+
if enc_padding_mask is not None
|
170 |
+
else [], # B x T
|
171 |
+
"encoder_embedding": [], # B x T x C
|
172 |
+
"encoder_states": encoder_states, # List[T x B x C]
|
173 |
+
"src_tokens": [],
|
174 |
+
"src_lengths": [],
|
175 |
+
}
|
176 |
+
|
177 |
+
def reorder_encoder_out(self, encoder_out, new_order):
|
178 |
+
new_encoder_out = (
|
179 |
+
[]
|
180 |
+
if len(encoder_out["encoder_out"]) == 0
|
181 |
+
else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
|
182 |
+
)
|
183 |
+
|
184 |
+
new_encoder_padding_mask = (
|
185 |
+
[]
|
186 |
+
if len(encoder_out["encoder_padding_mask"]) == 0
|
187 |
+
else [
|
188 |
+
x.index_select(0, new_order)
|
189 |
+
for x in encoder_out["encoder_padding_mask"]
|
190 |
+
]
|
191 |
+
)
|
192 |
+
|
193 |
+
new_encoder_embedding = (
|
194 |
+
[]
|
195 |
+
if len(encoder_out["encoder_embedding"]) == 0
|
196 |
+
else [
|
197 |
+
x.index_select(0, new_order) for x in encoder_out["encoder_embedding"]
|
198 |
+
]
|
199 |
+
)
|
200 |
+
|
201 |
+
encoder_states = encoder_out["encoder_states"]
|
202 |
+
if len(encoder_states) > 0:
|
203 |
+
for idx, state in enumerate(encoder_states):
|
204 |
+
encoder_states[idx] = state.index_select(1, new_order)
|
205 |
+
|
206 |
+
return {
|
207 |
+
"encoder_out": new_encoder_out, # T x B x C
|
208 |
+
"encoder_padding_mask": new_encoder_padding_mask, # B x T
|
209 |
+
"encoder_embedding": new_encoder_embedding, # B x T x C
|
210 |
+
"encoder_states": encoder_states, # List[T x B x C]
|
211 |
+
"src_tokens": [], # B x T
|
212 |
+
"src_lengths": [], # B x 1
|
213 |
+
}
|
214 |
+
|
215 |
+
|
216 |
+
# Note:
|
217 |
+
# dual input transformer:
|
218 |
+
# encoder: wav2vec for speech + mbart encoder for text
|
219 |
+
# decoder: mbart decoder for text
|
220 |
+
@register_model("dual_input_xm_transformer")
|
221 |
+
class DualInputXMTransformerModel(DualInputS2TTransformerModel):
|
222 |
+
def __init__(self, encoder, decoder):
|
223 |
+
super().__init__(encoder, decoder)
|
224 |
+
|
225 |
+
@staticmethod
|
226 |
+
def add_args(parser):
|
227 |
+
"""Add model-specific arguments to the parser."""
|
228 |
+
# wav2vec encoder
|
229 |
+
Wav2VecEncoderWithAdaptor.add_args(parser)
|
230 |
+
# add_decoder_args(parser)
|
231 |
+
# mbart Transformer
|
232 |
+
parser.add_argument(
|
233 |
+
"--activation-fn",
|
234 |
+
type=str,
|
235 |
+
default="relu",
|
236 |
+
choices=utils.get_available_activation_fns(),
|
237 |
+
help="activation function to use",
|
238 |
+
)
|
239 |
+
|
240 |
+
parser.add_argument(
|
241 |
+
"--mbart-dropout", type=float, metavar="D", help="dropout probability"
|
242 |
+
)
|
243 |
+
parser.add_argument(
|
244 |
+
"--mbart-attention-dropout",
|
245 |
+
type=float,
|
246 |
+
metavar="D",
|
247 |
+
help="dropout probability for attention weights",
|
248 |
+
)
|
249 |
+
parser.add_argument(
|
250 |
+
"--mbart-activation-dropout",
|
251 |
+
type=float,
|
252 |
+
metavar="D",
|
253 |
+
help="dropout probability after activation in FFN.",
|
254 |
+
)
|
255 |
+
|
256 |
+
parser.add_argument(
|
257 |
+
"--encoder-embed-dim",
|
258 |
+
type=int,
|
259 |
+
metavar="N",
|
260 |
+
help="encoder embedding dimension",
|
261 |
+
)
|
262 |
+
parser.add_argument(
|
263 |
+
"--encoder-ffn-embed-dim",
|
264 |
+
type=int,
|
265 |
+
metavar="N",
|
266 |
+
help="encoder embedding dimension for FFN",
|
267 |
+
)
|
268 |
+
parser.add_argument(
|
269 |
+
"--encoder-layers", type=int, metavar="N", help="num encoder layers"
|
270 |
+
)
|
271 |
+
parser.add_argument(
|
272 |
+
"--encoder-attention-heads",
|
273 |
+
type=int,
|
274 |
+
metavar="N",
|
275 |
+
help="num encoder attention heads",
|
276 |
+
)
|
277 |
+
parser.add_argument(
|
278 |
+
"--encoder-normalize-before",
|
279 |
+
action="store_true",
|
280 |
+
help="apply layernorm before each encoder block",
|
281 |
+
)
|
282 |
+
|
283 |
+
parser.add_argument(
|
284 |
+
"--decoder-embed-dim",
|
285 |
+
type=int,
|
286 |
+
metavar="N",
|
287 |
+
help="decoder embedding dimension",
|
288 |
+
)
|
289 |
+
parser.add_argument(
|
290 |
+
"--decoder-ffn-embed-dim",
|
291 |
+
type=int,
|
292 |
+
metavar="N",
|
293 |
+
help="decoder embedding dimension for FFN",
|
294 |
+
)
|
295 |
+
parser.add_argument(
|
296 |
+
"--decoder-layers", type=int, metavar="N", help="num decoder layers"
|
297 |
+
)
|
298 |
+
parser.add_argument(
|
299 |
+
"--decoder-attention-heads",
|
300 |
+
type=int,
|
301 |
+
metavar="N",
|
302 |
+
help="num decoder attention heads",
|
303 |
+
)
|
304 |
+
parser.add_argument(
|
305 |
+
"--decoder-normalize-before",
|
306 |
+
action="store_true",
|
307 |
+
help="apply layernorm before each decoder block",
|
308 |
+
)
|
309 |
+
parser.add_argument(
|
310 |
+
"--layernorm-embedding",
|
311 |
+
action="store_true",
|
312 |
+
help="add layernorm to embedding",
|
313 |
+
)
|
314 |
+
parser.add_argument(
|
315 |
+
"--no-scale-embedding",
|
316 |
+
action="store_true",
|
317 |
+
help="if True, dont scale embeddings",
|
318 |
+
)
|
319 |
+
parser.add_argument(
|
320 |
+
"--load-pretrained-mbart-from",
|
321 |
+
type=str,
|
322 |
+
metavar="STR",
|
323 |
+
help="model to take text encoder decoder weights from (for initialization)",
|
324 |
+
)
|
325 |
+
# parser.add_argument("--finetune-w2v-params", type=str, metavar="STR",
|
326 |
+
# help="comma-separated param strings to finetune.")
|
327 |
+
parser.add_argument(
|
328 |
+
"--finetune-mbart-decoder-params",
|
329 |
+
type=str,
|
330 |
+
metavar="STR",
|
331 |
+
help="comma-separated param strings to finetune.",
|
332 |
+
)
|
333 |
+
parser.add_argument(
|
334 |
+
"--finetune-mbart-encoder-params",
|
335 |
+
type=str,
|
336 |
+
metavar="STR",
|
337 |
+
help="comma-separated param strings to finetune.",
|
338 |
+
)
|
339 |
+
parser.add_argument(
|
340 |
+
"--skip-encoder-projection",
|
341 |
+
action="store_true",
|
342 |
+
help="skip the projection layer in encoder",
|
343 |
+
)
|
344 |
+
|
345 |
+
parser.add_argument(
|
346 |
+
"--enc-grad-mult",
|
347 |
+
type=float,
|
348 |
+
metavar="V",
|
349 |
+
default=1.0,
|
350 |
+
help="multiply enc1 and enc2 gradient by V",
|
351 |
+
)
|
352 |
+
parser.add_argument(
|
353 |
+
"--enc2-along-grad-mult",
|
354 |
+
type=float,
|
355 |
+
metavar="V",
|
356 |
+
default=1.0,
|
357 |
+
help="multiply enc2 gradient by V if only enc2 is used",
|
358 |
+
)
|
359 |
+
parser.add_argument(
|
360 |
+
"--text-input-cost-ratio",
|
361 |
+
type=float,
|
362 |
+
default=1.0,
|
363 |
+
metavar="V",
|
364 |
+
help="text input cost ratio relative to speech input cost",
|
365 |
+
)
|
366 |
+
parser.add_argument(
|
367 |
+
"--stack-w2v-mbart-encoder",
|
368 |
+
action="store_true",
|
369 |
+
help="stack w2v and mbart encoder",
|
370 |
+
)
|
371 |
+
parser.add_argument(
|
372 |
+
"--stack-w2v-mbart-nonorm-encoder",
|
373 |
+
action="store_true",
|
374 |
+
help="stack w2v and mbart encoder",
|
375 |
+
)
|
376 |
+
parser.add_argument(
|
377 |
+
"--no-final-norm-decoder", action="store_true", help="no layer norm"
|
378 |
+
)
|
379 |
+
parser.add_argument(
|
380 |
+
"--drop-w2v-layers",
|
381 |
+
type=int,
|
382 |
+
default=0,
|
383 |
+
metavar="N",
|
384 |
+
help="drop w2v encoder layers",
|
385 |
+
)
|
386 |
+
|
387 |
+
parser.add_argument(
|
388 |
+
"--share-w2v-text-encoder",
|
389 |
+
action="store_true",
|
390 |
+
help="share w2v encoder layers with text encoder",
|
391 |
+
)
|
392 |
+
parser.add_argument(
|
393 |
+
"--shared-w2v-layers",
|
394 |
+
type=int,
|
395 |
+
default=0,
|
396 |
+
metavar="N",
|
397 |
+
help="shared encoder layers from w2v encoder",
|
398 |
+
)
|
399 |
+
|
400 |
+
@classmethod
|
401 |
+
def build_encoder(cls, args, task):
|
402 |
+
_args = copy.deepcopy(args)
|
403 |
+
_args.dropout = args.mbart_dropout
|
404 |
+
_args.attention_dropout = args.mbart_attention_dropout
|
405 |
+
_args.activation_dropout = args.mbart_activation_dropout
|
406 |
+
_args.max_source_positions = 1024
|
407 |
+
enc_emb = nn.Embedding(
|
408 |
+
len(task.src_dict), _args.encoder_embed_dim, task.src_dict.pad()
|
409 |
+
)
|
410 |
+
text_encoder = TransformerEncoder(_args, task.src_dict, enc_emb)
|
411 |
+
spch_encoder = Wav2VecEncoderWithAdaptor(args)
|
412 |
+
if getattr(args, "load_pretrained_mbart_from", None):
|
413 |
+
text_encoder = checkpoint_utils.load_pretrained_component_from_model(
|
414 |
+
component=text_encoder, checkpoint=args.load_pretrained_mbart_from
|
415 |
+
)
|
416 |
+
if getattr(args, "stack_w2v_mbart_encoder", False):
|
417 |
+
assert getattr(args, "share_w2v_text_encoder", False) is False
|
418 |
+
spch_encoder = StackedWav2VecEncoderWithAdaptor(
|
419 |
+
spch_encoder.w2v_encoder,
|
420 |
+
text_encoder.layers,
|
421 |
+
text_encoder.layer_norm,
|
422 |
+
spch_encoder.adaptor,
|
423 |
+
args.drop_w2v_layers,
|
424 |
+
)
|
425 |
+
elif getattr(args, "stack_w2v_mbart_nonorm_encoder", False):
|
426 |
+
text_encoder.layer_norm = None
|
427 |
+
spch_encoder = StackedWav2VecEncoderWithAdaptor(
|
428 |
+
spch_encoder.w2v_encoder,
|
429 |
+
text_encoder.layers,
|
430 |
+
text_encoder.layer_norm,
|
431 |
+
spch_encoder.adaptor,
|
432 |
+
args.drop_w2v_layers,
|
433 |
+
)
|
434 |
+
elif getattr(args, "share_w2v_text_encoder", False):
|
435 |
+
spch_encoder = SharedEncoder(
|
436 |
+
spch_encoder.w2v_encoder,
|
437 |
+
text_encoder,
|
438 |
+
spch_encoder.adaptor,
|
439 |
+
args.shared_w2v_layers,
|
440 |
+
)
|
441 |
+
|
442 |
+
for k, p in spch_encoder.named_parameters():
|
443 |
+
# Freeze pretrained models by default
|
444 |
+
if safe_hasattr(
|
445 |
+
args, "finetune_w2v_params"
|
446 |
+
) and need_finetuning(args.finetune_w2v_params, k):
|
447 |
+
p.requires_grad = True
|
448 |
+
else:
|
449 |
+
p.requires_grad = False
|
450 |
+
for k, p in text_encoder.named_parameters():
|
451 |
+
# Freeze pretrained models by default
|
452 |
+
if safe_hasattr(
|
453 |
+
args, "finetune_mbart_encoder_params"
|
454 |
+
) and need_finetuning(
|
455 |
+
args.finetune_mbart_encoder_params, k
|
456 |
+
):
|
457 |
+
p.requires_grad = True
|
458 |
+
else:
|
459 |
+
p.requires_grad = False
|
460 |
+
cross_attentive_loss_before_last_layer = (
|
461 |
+
0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1
|
462 |
+
)
|
463 |
+
encoder = DualInputEncoder(
|
464 |
+
args,
|
465 |
+
spch_encoder,
|
466 |
+
text_encoder,
|
467 |
+
task.src_dict,
|
468 |
+
cross_attentive_loss_before_last_layer,
|
469 |
+
)
|
470 |
+
return encoder
|
471 |
+
|
472 |
+
@classmethod
|
473 |
+
def build_decoder(cls, args, task):
|
474 |
+
_args = copy.deepcopy(args)
|
475 |
+
_args.dropout = args.mbart_dropout
|
476 |
+
_args.attention_dropout = args.mbart_attention_dropout
|
477 |
+
_args.activation_dropout = args.mbart_activation_dropout
|
478 |
+
_args.max_target_positions = 1024
|
479 |
+
dec_emb = nn.Embedding(
|
480 |
+
len(task.tgt_dict), _args.encoder_embed_dim, task.tgt_dict.pad()
|
481 |
+
)
|
482 |
+
decoder = TransformerDecoder(_args, task.tgt_dict, dec_emb)
|
483 |
+
if getattr(args, "load_pretrained_mbart_from", None):
|
484 |
+
decoder = checkpoint_utils.load_pretrained_component_from_model(
|
485 |
+
component=decoder, checkpoint=args.load_pretrained_mbart_from
|
486 |
+
)
|
487 |
+
if getattr(args, "no_final_norm_decoder", False):
|
488 |
+
decoder.layer_norm = None
|
489 |
+
for k, p in decoder.named_parameters():
|
490 |
+
# Freeze pretrained models by default
|
491 |
+
if safe_hasattr(
|
492 |
+
args, "finetune_mbart_decoder_params"
|
493 |
+
) and need_finetuning(
|
494 |
+
args.finetune_mbart_decoder_params, k
|
495 |
+
):
|
496 |
+
p.requires_grad = True
|
497 |
+
else:
|
498 |
+
p.requires_grad = False
|
499 |
+
|
500 |
+
compute_cross_attentive_loss = (
|
501 |
+
True if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else False
|
502 |
+
)
|
503 |
+
cross_attentive_loss_without_norm = getattr(
|
504 |
+
args, "attentive_cost_without_normalize", False
|
505 |
+
)
|
506 |
+
cross_attentive_loss_reverse = (
|
507 |
+
False # getattr(args, "attentive_cost_reverse", False)
|
508 |
+
)
|
509 |
+
decoder = TransformerMultiInputDecoder(
|
510 |
+
dictionary=task.target_dictionary,
|
511 |
+
spch_decoder=decoder,
|
512 |
+
text_decoder=decoder,
|
513 |
+
compute_cross_attentive_loss=compute_cross_attentive_loss,
|
514 |
+
cross_attentive_loss_with_norm=True
|
515 |
+
if not cross_attentive_loss_without_norm
|
516 |
+
else False,
|
517 |
+
cross_attentive_loss_reverse=cross_attentive_loss_reverse,
|
518 |
+
)
|
519 |
+
return decoder
|
520 |
+
|
521 |
+
@classmethod
|
522 |
+
def build_model(cls, args, task):
|
523 |
+
"""Build a new model instance."""
|
524 |
+
# make sure that all args are properly defaulted
|
525 |
+
# (in case there are any new ones)
|
526 |
+
dualinputxmtransformer_base(args)
|
527 |
+
|
528 |
+
encoder = cls.build_encoder(args, task)
|
529 |
+
decoder = cls.build_decoder(args, task)
|
530 |
+
return cls(encoder, decoder)
|
531 |
+
|
532 |
+
|
533 |
+
@register_model_architecture("dual_input_xm_transformer", "dualinputxmtransformer_base")
|
534 |
+
def dualinputxmtransformer_base(args):
|
535 |
+
# wav2vec encoder
|
536 |
+
set_default_w2v_encoder_args(args)
|
537 |
+
set_default_adaptor_args(args)
|
538 |
+
|
539 |
+
# mbart model
|
540 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
|
541 |
+
args.encoder_ffn_embed_dim = getattr(
|
542 |
+
args, "encoder_ffn_embed_dim", 4 * args.encoder_embed_dim
|
543 |
+
)
|
544 |
+
args.encoder_layers = getattr(args, "encoder_layers", 12)
|
545 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
|
546 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
|
547 |
+
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
|
548 |
+
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True)
|
549 |
+
|
550 |
+
args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
|
551 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
|
552 |
+
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4 * 1024)
|
553 |
+
args.decoder_layers = getattr(args, "decoder_layers", 12)
|
554 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
|
555 |
+
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
|
556 |
+
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True)
|
557 |
+
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
|
558 |
+
|
559 |
+
args.adaptive_input = getattr(args, "adaptive_input", False)
|
560 |
+
|
561 |
+
args.mbart_attention_dropout = getattr(args, "mbart_attention_dropout", 0.0)
|
562 |
+
args.mbart_activation_dropout = getattr(args, "mbart_activation_dropout", 0.0)
|
563 |
+
args.mbart_dropout = getattr(args, "mbart_dropout", 0.1)
|
564 |
+
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
|
565 |
+
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
|
566 |
+
args.share_decoder_input_output_embed = getattr(
|
567 |
+
args, "share_decoder_input_output_embed", True
|
568 |
+
)
|
569 |
+
args.no_token_positional_embeddings = getattr(
|
570 |
+
args, "no_token_positional_embeddings", False
|
571 |
+
)
|
572 |
+
|
573 |
+
args.decoder_output_dim = getattr(
|
574 |
+
args, "decoder_output_dim", args.decoder_embed_dim
|
575 |
+
)
|
576 |
+
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
|
577 |
+
|
578 |
+
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
|
579 |
+
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
|
580 |
+
args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
|
581 |
+
|
582 |
+
args.activation_fn = getattr(args, "activation_fn", "gelu")
|
583 |
+
args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
|
584 |
+
args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
|
fairseq/examples/speech_text_joint_to_text/scripts/convert_model.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the MIT license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import re
|
9 |
+
from collections import OrderedDict
|
10 |
+
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from fairseq.file_io import PathManager
|
14 |
+
|
15 |
+
|
16 |
+
def is_update(param_name, module_name):
|
17 |
+
if module_name in param_name:
|
18 |
+
return True
|
19 |
+
return False
|
20 |
+
|
21 |
+
|
22 |
+
def load_checkpoint(src_cpt):
|
23 |
+
|
24 |
+
with PathManager.open(src_cpt, "rb") as f:
|
25 |
+
state_src = torch.load(
|
26 |
+
f,
|
27 |
+
map_location=(
|
28 |
+
lambda s, _: torch.serialization.default_restore_location(s, "cpu")
|
29 |
+
),
|
30 |
+
)
|
31 |
+
|
32 |
+
return state_src
|
33 |
+
|
34 |
+
|
35 |
+
def save_checkpoint(tgt_cpt, states):
|
36 |
+
|
37 |
+
with PathManager.open(tgt_cpt, "wb") as f:
|
38 |
+
torch.save(
|
39 |
+
states,
|
40 |
+
f,
|
41 |
+
)
|
42 |
+
|
43 |
+
|
44 |
+
# convert the pre-trained model into bart model
|
45 |
+
def main():
|
46 |
+
parser = argparse.ArgumentParser()
|
47 |
+
# fmt: off
|
48 |
+
parser.add_argument('--input-model', required=True,
|
49 |
+
help='Input checkpoint file path.')
|
50 |
+
parser.add_argument('--output-model', required=True,
|
51 |
+
help='output checkpoint file path.')
|
52 |
+
# fmt: on
|
53 |
+
args = parser.parse_args()
|
54 |
+
print(args)
|
55 |
+
|
56 |
+
states = load_checkpoint(args.input_model)
|
57 |
+
model = states["model"]
|
58 |
+
new_model = OrderedDict()
|
59 |
+
for key in model.keys():
|
60 |
+
if re.search("^encoder.text_encoder", key):
|
61 |
+
new_key = re.sub("encoder.text_encoder", "encoder", key)
|
62 |
+
new_model[new_key] = model[key]
|
63 |
+
elif re.search("^decoder.text_decoder", key):
|
64 |
+
new_key = re.sub("decoder.text_decoder", "decoder", key)
|
65 |
+
new_model[new_key] = model[key]
|
66 |
+
states["model"] = new_model
|
67 |
+
save_checkpoint(args.output_model, states)
|
68 |
+
|
69 |
+
|
70 |
+
if __name__ == "__main__":
|
71 |
+
main()
|
fairseq/examples/speech_text_joint_to_text/scripts/g2p_encode.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
import itertools
|
8 |
+
import logging
|
9 |
+
import re
|
10 |
+
import time
|
11 |
+
|
12 |
+
from g2p_en import G2p
|
13 |
+
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
FAIL_SENT = "FAILED_SENTENCE"
|
17 |
+
|
18 |
+
|
19 |
+
def parse():
|
20 |
+
parser = argparse.ArgumentParser()
|
21 |
+
parser.add_argument("--data-path", type=str, required=True)
|
22 |
+
parser.add_argument("--out-path", type=str, required=True)
|
23 |
+
parser.add_argument("--lower-case", action="store_true")
|
24 |
+
parser.add_argument("--do-filter", action="store_true")
|
25 |
+
parser.add_argument("--use-word-start", action="store_true")
|
26 |
+
parser.add_argument("--dup-vowel", default=1, type=int)
|
27 |
+
parser.add_argument("--dup-consonant", default=1, type=int)
|
28 |
+
parser.add_argument("--no-punc", action="store_true")
|
29 |
+
parser.add_argument("--reserve-word", type=str, default="")
|
30 |
+
parser.add_argument(
|
31 |
+
"--reserve-first-column",
|
32 |
+
action="store_true",
|
33 |
+
help="first column is sentence id",
|
34 |
+
)
|
35 |
+
###
|
36 |
+
parser.add_argument("--parallel-process-num", default=1, type=int)
|
37 |
+
parser.add_argument("--logdir", default="")
|
38 |
+
args = parser.parse_args()
|
39 |
+
return args
|
40 |
+
|
41 |
+
|
42 |
+
def process_sent(sent, g2p, res_wrds, args):
|
43 |
+
sents = pre_process_sent(sent, args.do_filter, args.lower_case, res_wrds)
|
44 |
+
pho_seqs = [do_g2p(g2p, s, res_wrds, i == 0) for i, s in enumerate(sents)]
|
45 |
+
pho_seq = (
|
46 |
+
[FAIL_SENT]
|
47 |
+
if [FAIL_SENT] in pho_seqs
|
48 |
+
else list(itertools.chain.from_iterable(pho_seqs))
|
49 |
+
)
|
50 |
+
if args.no_punc:
|
51 |
+
pho_seq = remove_punc(pho_seq)
|
52 |
+
if args.dup_vowel > 1 or args.dup_consonant > 1:
|
53 |
+
pho_seq = dup_pho(pho_seq, args.dup_vowel, args.dup_consonant)
|
54 |
+
if args.use_word_start:
|
55 |
+
pho_seq = add_word_start(pho_seq)
|
56 |
+
return " ".join(pho_seq)
|
57 |
+
|
58 |
+
|
59 |
+
def remove_punc(sent):
|
60 |
+
ns = []
|
61 |
+
regex = re.compile("[^a-zA-Z0-9 ]")
|
62 |
+
for p in sent:
|
63 |
+
if (not regex.search(p)) or p == FAIL_SENT:
|
64 |
+
if p == " " and (len(ns) == 0 or ns[-1] == " "):
|
65 |
+
continue
|
66 |
+
ns.append(p)
|
67 |
+
return ns
|
68 |
+
|
69 |
+
|
70 |
+
def do_g2p(g2p, sent, res_wrds, is_first_sent):
|
71 |
+
if sent in res_wrds:
|
72 |
+
pho_seq = [res_wrds[sent]]
|
73 |
+
else:
|
74 |
+
pho_seq = g2p(sent)
|
75 |
+
if not is_first_sent:
|
76 |
+
pho_seq = [" "] + pho_seq # add space to separate
|
77 |
+
return pho_seq
|
78 |
+
|
79 |
+
|
80 |
+
def pre_process_sent(sent, do_filter, lower_case, res_wrds):
|
81 |
+
if do_filter:
|
82 |
+
sent = re.sub("-", " ", sent)
|
83 |
+
sent = re.sub("—", " ", sent)
|
84 |
+
if len(res_wrds) > 0:
|
85 |
+
wrds = sent.split()
|
86 |
+
wrds = ["SPLIT_ME " + w + " SPLIT_ME" if w in res_wrds else w for w in wrds]
|
87 |
+
sents = [x.strip() for x in " ".join(wrds).split("SPLIT_ME") if x.strip() != ""]
|
88 |
+
else:
|
89 |
+
sents = [sent]
|
90 |
+
if lower_case:
|
91 |
+
sents = [s.lower() if s not in res_wrds else s for s in sents]
|
92 |
+
return sents
|
93 |
+
|
94 |
+
|
95 |
+
def dup_pho(sent, dup_v_num, dup_c_num):
|
96 |
+
"""
|
97 |
+
duplicate phoneme defined as cmudict
|
98 |
+
http://www.speech.cs.cmu.edu/cgi-bin/cmudict
|
99 |
+
"""
|
100 |
+
if dup_v_num == 1 and dup_c_num == 1:
|
101 |
+
return sent
|
102 |
+
ns = []
|
103 |
+
for p in sent:
|
104 |
+
ns.append(p)
|
105 |
+
if re.search(r"\d$", p):
|
106 |
+
for i in range(1, dup_v_num):
|
107 |
+
ns.append(f"{p}-{i}P")
|
108 |
+
elif re.search(r"\w", p):
|
109 |
+
for i in range(1, dup_c_num):
|
110 |
+
ns.append(f"{p}-{i}P")
|
111 |
+
return ns
|
112 |
+
|
113 |
+
|
114 |
+
def add_word_start(sent):
|
115 |
+
ns = []
|
116 |
+
do_add = True
|
117 |
+
ws = "▁"
|
118 |
+
for p in sent:
|
119 |
+
if do_add:
|
120 |
+
p = ws + p
|
121 |
+
do_add = False
|
122 |
+
if p == " ":
|
123 |
+
do_add = True
|
124 |
+
else:
|
125 |
+
ns.append(p)
|
126 |
+
return ns
|
127 |
+
|
128 |
+
|
129 |
+
def load_reserve_word(reserve_word):
|
130 |
+
if reserve_word == "":
|
131 |
+
return []
|
132 |
+
with open(reserve_word, "r") as fp:
|
133 |
+
res_wrds = [x.strip().split() for x in fp.readlines() if x.strip() != ""]
|
134 |
+
assert sum([0 if len(x) == 2 else 1 for x in res_wrds]) == 0
|
135 |
+
res_wrds = dict(res_wrds)
|
136 |
+
return res_wrds
|
137 |
+
|
138 |
+
|
139 |
+
def process_sents(sents, args):
|
140 |
+
g2p = G2p()
|
141 |
+
out_sents = []
|
142 |
+
res_wrds = load_reserve_word(args.reserve_word)
|
143 |
+
for sent in sents:
|
144 |
+
col1 = ""
|
145 |
+
if args.reserve_first_column:
|
146 |
+
col1, sent = sent.split(None, 1)
|
147 |
+
sent = process_sent(sent, g2p, res_wrds, args)
|
148 |
+
if args.reserve_first_column and col1 != "":
|
149 |
+
sent = f"{col1} {sent}"
|
150 |
+
out_sents.append(sent)
|
151 |
+
return out_sents
|
152 |
+
|
153 |
+
|
154 |
+
def main():
|
155 |
+
args = parse()
|
156 |
+
out_sents = []
|
157 |
+
with open(args.data_path, "r") as fp:
|
158 |
+
sent_list = [x.strip() for x in fp.readlines()]
|
159 |
+
if args.parallel_process_num > 1:
|
160 |
+
try:
|
161 |
+
import submitit
|
162 |
+
except ImportError:
|
163 |
+
logger.warn(
|
164 |
+
"submitit is not found and only one job is used to process the data"
|
165 |
+
)
|
166 |
+
submitit = None
|
167 |
+
|
168 |
+
if args.parallel_process_num == 1 or submitit is None:
|
169 |
+
out_sents = process_sents(sent_list, args)
|
170 |
+
else:
|
171 |
+
# process sentences with parallel computation
|
172 |
+
lsize = len(sent_list) // args.parallel_process_num + 1
|
173 |
+
executor = submitit.AutoExecutor(folder=args.logdir)
|
174 |
+
executor.update_parameters(timeout_min=1000, cpus_per_task=4)
|
175 |
+
jobs = []
|
176 |
+
for i in range(args.parallel_process_num):
|
177 |
+
job = executor.submit(
|
178 |
+
process_sents, sent_list[lsize * i : lsize * (i + 1)], args
|
179 |
+
)
|
180 |
+
jobs.append(job)
|
181 |
+
is_running = True
|
182 |
+
while is_running:
|
183 |
+
time.sleep(5)
|
184 |
+
is_running = sum([job.done() for job in jobs]) < len(jobs)
|
185 |
+
out_sents = list(itertools.chain.from_iterable([job.result() for job in jobs]))
|
186 |
+
with open(args.out_path, "w") as fp:
|
187 |
+
fp.write("\n".join(out_sents) + "\n")
|
188 |
+
|
189 |
+
|
190 |
+
if __name__ == "__main__":
|
191 |
+
main()
|
fairseq/examples/speech_text_joint_to_text/tasks/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import importlib
|
7 |
+
import os
|
8 |
+
|
fairseq/examples/speech_text_joint_to_text/tasks/pair_denoising.py
ADDED
@@ -0,0 +1,447 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import itertools
|
7 |
+
import logging
|
8 |
+
import os
|
9 |
+
import re
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
|
14 |
+
from examples.speech_text_joint_to_text.data.pair_denoising_dataset import (
|
15 |
+
LanguagePairDenoisingDataset,
|
16 |
+
)
|
17 |
+
from fairseq import utils
|
18 |
+
from fairseq.data import (
|
19 |
+
ConcatDataset,
|
20 |
+
Dictionary,
|
21 |
+
LanguagePairDataset,
|
22 |
+
ResamplingDataset,
|
23 |
+
TransformEosConcatLangPairDataset,
|
24 |
+
TransformEosLangPairDataset,
|
25 |
+
data_utils,
|
26 |
+
indexed_dataset,
|
27 |
+
)
|
28 |
+
from fairseq.data.encoders.utils import get_whole_word_mask
|
29 |
+
from fairseq.tasks import register_task
|
30 |
+
from fairseq.tasks.translation import TranslationTask
|
31 |
+
|
32 |
+
logger = logging.getLogger(__name__)
|
33 |
+
|
34 |
+
|
35 |
+
def gen_whole_word_mask(args, dictionary):
|
36 |
+
def is_beginning_of_word(i):
|
37 |
+
if i < dictionary.nspecial:
|
38 |
+
# special elements are always considered beginnings
|
39 |
+
return True
|
40 |
+
tok = dictionary[i]
|
41 |
+
if tok.startswith("madeupword"):
|
42 |
+
return True
|
43 |
+
|
44 |
+
if tok in ["<unk>", "<s>", "</s>", "<pad>"]:
|
45 |
+
return True
|
46 |
+
return tok.startswith("\u2581")
|
47 |
+
|
48 |
+
if args.use_mask_whole_words:
|
49 |
+
mask_whole_words = torch.ByteTensor(
|
50 |
+
list(map(is_beginning_of_word, range(len(dictionary))))
|
51 |
+
)
|
52 |
+
else:
|
53 |
+
# it will mask every token as word leading token, since no bpe model is loaded for phoneme tokens
|
54 |
+
return get_whole_word_mask(args, dictionary)
|
55 |
+
return mask_whole_words
|
56 |
+
|
57 |
+
|
58 |
+
@register_task("paired_denoising")
|
59 |
+
class PairedDenoisingTask(TranslationTask):
|
60 |
+
|
61 |
+
LANG_TAG_TEMPLATE = "<lang:{}>" # Tag for language (target)
|
62 |
+
|
63 |
+
@staticmethod
|
64 |
+
def add_args(parser):
|
65 |
+
TranslationTask.add_args(parser)
|
66 |
+
# bart setting
|
67 |
+
parser.add_argument(
|
68 |
+
"--mask",
|
69 |
+
default=0.0,
|
70 |
+
type=float,
|
71 |
+
help="fraction of words/subwords that will be masked",
|
72 |
+
)
|
73 |
+
parser.add_argument(
|
74 |
+
"--mask-random",
|
75 |
+
default=0.0,
|
76 |
+
type=float,
|
77 |
+
help="instead of using [MASK], use random token this often",
|
78 |
+
)
|
79 |
+
parser.add_argument(
|
80 |
+
"--insert",
|
81 |
+
default=0.0,
|
82 |
+
type=float,
|
83 |
+
help="insert this percentage of additional random tokens",
|
84 |
+
)
|
85 |
+
parser.add_argument(
|
86 |
+
"--poisson-lambda",
|
87 |
+
default=3.0,
|
88 |
+
type=float,
|
89 |
+
help="randomly shuffle sentences for this proportion of inputs",
|
90 |
+
)
|
91 |
+
parser.add_argument(
|
92 |
+
"--mask-length",
|
93 |
+
default="span-poisson",
|
94 |
+
type=str,
|
95 |
+
choices=["subword", "word", "span-poisson"],
|
96 |
+
help="mask length to choose",
|
97 |
+
)
|
98 |
+
parser.add_argument(
|
99 |
+
"--replace-length",
|
100 |
+
default=1,
|
101 |
+
type=int,
|
102 |
+
help="when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)",
|
103 |
+
)
|
104 |
+
|
105 |
+
# multi-lingual
|
106 |
+
parser.add_argument(
|
107 |
+
"--multilang-sampling-alpha",
|
108 |
+
type=float,
|
109 |
+
default=1.0,
|
110 |
+
help="smoothing alpha for sample ratios across multiple datasets",
|
111 |
+
)
|
112 |
+
parser.add_argument(
|
113 |
+
"--lang-pairs",
|
114 |
+
default="",
|
115 |
+
metavar="PAIRS",
|
116 |
+
help="comma-separated list of language pairs (in training order): phnen-en,phnfr-fr,phnit-it. Do masking",
|
117 |
+
)
|
118 |
+
parser.add_argument(
|
119 |
+
"--lang-pairs-bitext",
|
120 |
+
default="",
|
121 |
+
metavar="PAIRS",
|
122 |
+
help="comma-separated list of language pairs (in training order): en-de,en-fr,de-fr. No masking",
|
123 |
+
)
|
124 |
+
parser.add_argument("--add-src-lang-token", default=False, action="store_true")
|
125 |
+
parser.add_argument("--add-tgt-lang-token", default=False, action="store_true")
|
126 |
+
parser.add_argument(
|
127 |
+
"--no-whole-word-mask-langs",
|
128 |
+
type=str,
|
129 |
+
default="",
|
130 |
+
metavar="N",
|
131 |
+
help="languages without spacing between words dont support whole word masking",
|
132 |
+
)
|
133 |
+
parser.add_argument(
|
134 |
+
"--use-mask-whole-words", default=False, action="store_true"
|
135 |
+
)
|
136 |
+
|
137 |
+
@classmethod
|
138 |
+
def setup_task(cls, args, **kwargs):
|
139 |
+
"""Setup the task."""
|
140 |
+
paths = args.data.split(":")
|
141 |
+
assert len(paths) > 0
|
142 |
+
src_dict = Dictionary.load(
|
143 |
+
os.path.join(paths[0], "src_dict.txt")
|
144 |
+
) # assume all languages share a source dictionary
|
145 |
+
tgt_dict = Dictionary.load(
|
146 |
+
os.path.join(paths[0], "tgt_dict.txt")
|
147 |
+
) # assume all languages share a target dictionary
|
148 |
+
|
149 |
+
lang_pairs = args.lang_pairs + "," + args.lang_pairs_bitext
|
150 |
+
lang_pairs = re.sub(",$", "", re.sub("^,", "", lang_pairs))
|
151 |
+
src_langs = [lp.split("-")[0] for lp in lang_pairs.split(",")]
|
152 |
+
tgt_langs = [lp.split("-")[1] for lp in lang_pairs.split(",")]
|
153 |
+
|
154 |
+
if args.add_src_lang_token:
|
155 |
+
for lang in src_langs:
|
156 |
+
assert (
|
157 |
+
src_dict.index(PairedDenoisingTask.LANG_TAG_TEMPLATE.format(lang))
|
158 |
+
!= src_dict.unk()
|
159 |
+
)
|
160 |
+
if args.add_tgt_lang_token:
|
161 |
+
for lang in tgt_langs:
|
162 |
+
assert (
|
163 |
+
tgt_dict.index(PairedDenoisingTask.LANG_TAG_TEMPLATE.format(lang))
|
164 |
+
!= tgt_dict.unk()
|
165 |
+
)
|
166 |
+
|
167 |
+
logger.info("source dictionary: {} types".format(len(src_dict)))
|
168 |
+
logger.info("target dictionary: {} types".format(len(tgt_dict)))
|
169 |
+
if not hasattr(args, "shuffle_instance"):
|
170 |
+
args.shuffle_instance = False
|
171 |
+
return cls(args, src_dict, tgt_dict)
|
172 |
+
|
173 |
+
def __init__(self, args, src_dict, tgt_dict):
|
174 |
+
super().__init__(args, src_dict, tgt_dict)
|
175 |
+
# check mask token
|
176 |
+
self.mask_idx = self.src_dict.index("<mask>")
|
177 |
+
assert self.mask_idx != self.src_dict.unk()
|
178 |
+
self.lang_pairs = args.lang_pairs
|
179 |
+
self.lang_pairs_bitext = args.lang_pairs_bitext
|
180 |
+
self.args = args
|
181 |
+
|
182 |
+
@classmethod
|
183 |
+
def language_pair_denoising_dataset(
|
184 |
+
cls,
|
185 |
+
data_path,
|
186 |
+
do_mask,
|
187 |
+
split,
|
188 |
+
src,
|
189 |
+
src_dict,
|
190 |
+
tgt,
|
191 |
+
tgt_dict,
|
192 |
+
mask_idx,
|
193 |
+
mask_whole_words,
|
194 |
+
seed,
|
195 |
+
args,
|
196 |
+
dataset_impl,
|
197 |
+
combine=False,
|
198 |
+
left_pad_source=True,
|
199 |
+
left_pad_target=False,
|
200 |
+
max_source_positions=1024,
|
201 |
+
max_target_positions=1024,
|
202 |
+
shuffle=True,
|
203 |
+
src_lang_id=None,
|
204 |
+
tgt_lang_id=None,
|
205 |
+
):
|
206 |
+
def split_exists(split, src, tgt, lang, data_path):
|
207 |
+
filename = os.path.join(
|
208 |
+
data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)
|
209 |
+
)
|
210 |
+
return indexed_dataset.dataset_exists(filename, impl=dataset_impl)
|
211 |
+
|
212 |
+
src_datasets = []
|
213 |
+
tgt_datasets = []
|
214 |
+
|
215 |
+
for k in itertools.count():
|
216 |
+
split_k = split + (str(k) if k > 0 else "")
|
217 |
+
|
218 |
+
# infer langcode
|
219 |
+
if split_exists(split_k, src, tgt, src, data_path):
|
220 |
+
prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
|
221 |
+
elif split_exists(split_k, tgt, src, src, data_path):
|
222 |
+
prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src))
|
223 |
+
else:
|
224 |
+
if k > 0:
|
225 |
+
break
|
226 |
+
else:
|
227 |
+
raise FileNotFoundError(
|
228 |
+
"Dataset not found: {} ({})".format(split, data_path)
|
229 |
+
)
|
230 |
+
|
231 |
+
src_dataset = data_utils.load_indexed_dataset(
|
232 |
+
prefix + src, src_dict, dataset_impl
|
233 |
+
)
|
234 |
+
src_datasets.append(src_dataset)
|
235 |
+
|
236 |
+
tgt_dataset = data_utils.load_indexed_dataset(
|
237 |
+
prefix + tgt, tgt_dict, dataset_impl
|
238 |
+
)
|
239 |
+
if tgt_dataset is not None:
|
240 |
+
tgt_datasets.append(tgt_dataset)
|
241 |
+
|
242 |
+
logger.info(
|
243 |
+
"{} {} {}-{} {} examples".format(
|
244 |
+
data_path, split_k, src, tgt, len(src_datasets[-1])
|
245 |
+
)
|
246 |
+
)
|
247 |
+
|
248 |
+
if not combine:
|
249 |
+
break
|
250 |
+
|
251 |
+
assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0
|
252 |
+
|
253 |
+
if len(src_datasets) == 1:
|
254 |
+
src_dataset = src_datasets[0]
|
255 |
+
tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
|
256 |
+
else:
|
257 |
+
sample_ratios = [1] * len(src_datasets)
|
258 |
+
src_dataset = ConcatDataset(src_datasets, sample_ratios)
|
259 |
+
if len(tgt_datasets) > 0:
|
260 |
+
tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
|
261 |
+
else:
|
262 |
+
tgt_dataset = None
|
263 |
+
|
264 |
+
eos = None
|
265 |
+
|
266 |
+
tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
|
267 |
+
if not do_mask:
|
268 |
+
return LanguagePairDataset(
|
269 |
+
src_dataset,
|
270 |
+
src_dataset.sizes,
|
271 |
+
src_dict,
|
272 |
+
tgt_dataset,
|
273 |
+
tgt_dataset_sizes,
|
274 |
+
tgt_dict,
|
275 |
+
left_pad_source=left_pad_source,
|
276 |
+
left_pad_target=left_pad_target,
|
277 |
+
eos=eos,
|
278 |
+
shuffle=shuffle,
|
279 |
+
src_lang_id=src_lang_id,
|
280 |
+
tgt_lang_id=tgt_lang_id,
|
281 |
+
)
|
282 |
+
|
283 |
+
return LanguagePairDenoisingDataset(
|
284 |
+
src_dataset,
|
285 |
+
src_dataset.sizes,
|
286 |
+
src_dict,
|
287 |
+
tgt_dataset,
|
288 |
+
tgt_dataset_sizes,
|
289 |
+
tgt_dict,
|
290 |
+
mask_idx,
|
291 |
+
mask_whole_words,
|
292 |
+
seed,
|
293 |
+
args,
|
294 |
+
left_pad_source=left_pad_source,
|
295 |
+
left_pad_target=left_pad_target,
|
296 |
+
eos=eos,
|
297 |
+
shuffle=shuffle,
|
298 |
+
src_lang_id=src_lang_id,
|
299 |
+
tgt_lang_id=tgt_lang_id,
|
300 |
+
)
|
301 |
+
|
302 |
+
def _get_sample_prob(self, dataset_lens):
|
303 |
+
"""
|
304 |
+
Get smoothed sampling porbability by languages. This helps low resource
|
305 |
+
languages by upsampling them.
|
306 |
+
"""
|
307 |
+
prob = dataset_lens / dataset_lens.sum()
|
308 |
+
smoothed_prob = prob ** self.args.multilang_sampling_alpha
|
309 |
+
smoothed_prob = smoothed_prob / smoothed_prob.sum()
|
310 |
+
return smoothed_prob
|
311 |
+
|
312 |
+
def resample_datasets(self, lang_datasets, lang_pairs_all, epoch):
|
313 |
+
# For train subset, additionally up or down sample languages.
|
314 |
+
if self.args.multilang_sampling_alpha == 1.0:
|
315 |
+
return lang_datasets
|
316 |
+
|
317 |
+
dataset_lengths = np.array(
|
318 |
+
[len(d) for d in lang_datasets],
|
319 |
+
dtype=float,
|
320 |
+
)
|
321 |
+
sample_probs = self._get_sample_prob(dataset_lengths)
|
322 |
+
logger.info(
|
323 |
+
"Sample probability by language pair: {}".format(
|
324 |
+
{
|
325 |
+
lp: "{0:.4f}".format(sample_probs[id])
|
326 |
+
for id, lp in enumerate(lang_pairs_all)
|
327 |
+
}
|
328 |
+
)
|
329 |
+
)
|
330 |
+
size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
|
331 |
+
logger.info(
|
332 |
+
"Up/Down Sampling ratio by language: {}".format(
|
333 |
+
{
|
334 |
+
lp: "{0:.2f}".format(size_ratio[id])
|
335 |
+
for id, lp in enumerate(lang_pairs_all)
|
336 |
+
}
|
337 |
+
)
|
338 |
+
)
|
339 |
+
|
340 |
+
resampled_lang_datasets = [
|
341 |
+
ResamplingDataset(
|
342 |
+
lang_datasets[i],
|
343 |
+
size_ratio=size_ratio[i],
|
344 |
+
seed=self.args.seed,
|
345 |
+
epoch=epoch,
|
346 |
+
replace=size_ratio[i] >= 1.0,
|
347 |
+
)
|
348 |
+
for i, d in enumerate(lang_datasets)
|
349 |
+
]
|
350 |
+
return resampled_lang_datasets
|
351 |
+
|
352 |
+
def load_dataset_only(
|
353 |
+
self, split, lang_pairs, do_mask=True, epoch=1, combine=False
|
354 |
+
):
|
355 |
+
paths = utils.split_paths(self.args.data)
|
356 |
+
assert len(paths) > 0
|
357 |
+
data_path = paths[(epoch - 1) % len(paths)]
|
358 |
+
|
359 |
+
# TODO unk token will be considered as first word too, though it might be an unknown phoneme within a word
|
360 |
+
# get_whole_word_mask returns a tensor (size V by 1 ) to indicate if a token is a word start token
|
361 |
+
mask_whole_src_words = gen_whole_word_mask(self.args, self.src_dict)
|
362 |
+
language_without_segmentations = self.args.no_whole_word_mask_langs.split(",")
|
363 |
+
lang_datasets = []
|
364 |
+
eos_bos = []
|
365 |
+
lang_pairs = lang_pairs.split(",") if lang_pairs != "" else []
|
366 |
+
assert len(lang_pairs) > 0
|
367 |
+
for lp in lang_pairs:
|
368 |
+
src, tgt = lp.split("-")
|
369 |
+
lang_mask_whole_src_words = (
|
370 |
+
mask_whole_src_words
|
371 |
+
if src not in language_without_segmentations
|
372 |
+
else None
|
373 |
+
)
|
374 |
+
|
375 |
+
end_token = (
|
376 |
+
self.source_dictionary.index(
|
377 |
+
PairedDenoisingTask.LANG_TAG_TEMPLATE.format(src)
|
378 |
+
)
|
379 |
+
if self.args.add_src_lang_token
|
380 |
+
else None
|
381 |
+
)
|
382 |
+
bos_token = (
|
383 |
+
self.target_dictionary.index(
|
384 |
+
PairedDenoisingTask.LANG_TAG_TEMPLATE.format(tgt)
|
385 |
+
)
|
386 |
+
if self.args.add_tgt_lang_token
|
387 |
+
else None
|
388 |
+
)
|
389 |
+
src_lang_id = None
|
390 |
+
|
391 |
+
if self.args.add_src_lang_token or self.args.add_tgt_lang_token:
|
392 |
+
eos_bos.append((end_token, bos_token))
|
393 |
+
|
394 |
+
dataset = PairedDenoisingTask.language_pair_denoising_dataset(
|
395 |
+
data_path,
|
396 |
+
do_mask,
|
397 |
+
split,
|
398 |
+
src,
|
399 |
+
self.source_dictionary,
|
400 |
+
tgt,
|
401 |
+
self.target_dictionary,
|
402 |
+
self.mask_idx,
|
403 |
+
lang_mask_whole_src_words,
|
404 |
+
self.args.seed,
|
405 |
+
self.args,
|
406 |
+
self.args.dataset_impl,
|
407 |
+
combine=combine,
|
408 |
+
left_pad_source=utils.eval_bool(self.args.left_pad_source),
|
409 |
+
left_pad_target=utils.eval_bool(self.args.left_pad_target),
|
410 |
+
max_source_positions=self.args.max_source_positions,
|
411 |
+
max_target_positions=self.args.max_target_positions,
|
412 |
+
src_lang_id=src_lang_id,
|
413 |
+
)
|
414 |
+
|
415 |
+
lang_datasets.append(dataset)
|
416 |
+
|
417 |
+
if len(lang_datasets) == 0:
|
418 |
+
return
|
419 |
+
elif len(lang_datasets) == 1:
|
420 |
+
dataset = lang_datasets[0]
|
421 |
+
if self.args.add_src_lang_token or self.args.add_tgt_lang_token:
|
422 |
+
end_token, bos_token = eos_bos[0]
|
423 |
+
dataset = TransformEosLangPairDataset(
|
424 |
+
dataset,
|
425 |
+
src_eos=self.source_dictionary.eos(),
|
426 |
+
new_src_eos=end_token,
|
427 |
+
tgt_bos=self.target_dictionary.eos(),
|
428 |
+
new_tgt_bos=bos_token,
|
429 |
+
)
|
430 |
+
else:
|
431 |
+
end_tokens = [item[0] for item in eos_bos if item[0] is not None]
|
432 |
+
bos_tokens = [item[1] for item in eos_bos if item[1] is not None]
|
433 |
+
lang_datasets = self.resample_datasets(lang_datasets, lang_pairs, epoch)
|
434 |
+
dataset = TransformEosConcatLangPairDataset(
|
435 |
+
lang_datasets,
|
436 |
+
self.source_dictionary.eos(),
|
437 |
+
self.target_dictionary.eos(),
|
438 |
+
new_src_eos=end_tokens,
|
439 |
+
new_tgt_bos=bos_tokens,
|
440 |
+
)
|
441 |
+
return dataset
|
442 |
+
|
443 |
+
# split in (train, valid, test, ...)
|
444 |
+
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
|
445 |
+
self.datasets[split] = self.load_dataset_only(
|
446 |
+
split, self.lang_pairs, epoch=epoch, combine=combine
|
447 |
+
)
|
fairseq/examples/speech_text_joint_to_text/tasks/speech_text_denoise_pretrain.py
ADDED
@@ -0,0 +1,654 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
import logging
|
6 |
+
import os
|
7 |
+
import re
|
8 |
+
from argparse import Namespace
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
from fairseq.data import ConcatDataset, Dictionary, encoders
|
12 |
+
from fairseq.data.audio.multi_modality_dataset import (
|
13 |
+
FileAudioDatasetWrapper,
|
14 |
+
ModalityDatasetItem,
|
15 |
+
MultiModalityDataset,
|
16 |
+
)
|
17 |
+
from fairseq.data.audio.speech_to_text_joint_dataset import (
|
18 |
+
S2TJointDataConfig,
|
19 |
+
SpeechToTextJointDatasetCreator,
|
20 |
+
)
|
21 |
+
from fairseq.data.iterators import GroupedEpochBatchIterator
|
22 |
+
from fairseq.tasks import register_task
|
23 |
+
|
24 |
+
from .pair_denoising import PairedDenoisingTask
|
25 |
+
|
26 |
+
logger = logging.getLogger(__name__)
|
27 |
+
|
28 |
+
|
29 |
+
@register_task("speech_text_joint_denoising")
|
30 |
+
class SpeechTextJointDenoisingPreTask(PairedDenoisingTask):
|
31 |
+
"""
|
32 |
+
Joint denoising training task for speech and text.
|
33 |
+
"""
|
34 |
+
|
35 |
+
SIL_TOKEN = "sil"
|
36 |
+
|
37 |
+
@classmethod
|
38 |
+
def add_args(cls, parser):
|
39 |
+
PairedDenoisingTask.add_args(parser)
|
40 |
+
# set max tokens and position
|
41 |
+
parser.add_argument(
|
42 |
+
"--max-text-tokens",
|
43 |
+
type=int,
|
44 |
+
metavar="N",
|
45 |
+
default=1024,
|
46 |
+
help="maximum samples for encoder text input ",
|
47 |
+
)
|
48 |
+
parser.add_argument(
|
49 |
+
"--max-speech-tokens",
|
50 |
+
type=int,
|
51 |
+
metavar="N",
|
52 |
+
default=50000,
|
53 |
+
help="maximum samples for encoder speech input ",
|
54 |
+
)
|
55 |
+
parser.add_argument(
|
56 |
+
"--max-speech-positions",
|
57 |
+
type=int,
|
58 |
+
metavar="N",
|
59 |
+
default=400,
|
60 |
+
help="maximum tokens for per encoder text input ",
|
61 |
+
)
|
62 |
+
|
63 |
+
parser.add_argument(
|
64 |
+
"--max-sample-size",
|
65 |
+
type=int,
|
66 |
+
metavar="N",
|
67 |
+
default=32000,
|
68 |
+
help="max sample size to crop to for batching (unsupervised speech) ",
|
69 |
+
)
|
70 |
+
parser.add_argument(
|
71 |
+
"--min-sample-size",
|
72 |
+
type=int,
|
73 |
+
metavar="N",
|
74 |
+
default=4000,
|
75 |
+
help="min sample size to crop to for batching (unsupervised speech) ",
|
76 |
+
)
|
77 |
+
|
78 |
+
# set mini-batch ratio for different modalities/subtasks
|
79 |
+
# s2p
|
80 |
+
parser.add_argument(
|
81 |
+
"--supervised-speech-sample-ratio",
|
82 |
+
default="1",
|
83 |
+
type=str,
|
84 |
+
metavar="N",
|
85 |
+
help="Multiple Ratio for speech dataset with transcripts ",
|
86 |
+
)
|
87 |
+
# s2t
|
88 |
+
parser.add_argument(
|
89 |
+
"--supervised-speech-s2s-sample-ratio",
|
90 |
+
default="1",
|
91 |
+
type=str,
|
92 |
+
metavar="N",
|
93 |
+
help="Multiple Ratio for speech dataset with transcripts ",
|
94 |
+
)
|
95 |
+
# ssl
|
96 |
+
parser.add_argument(
|
97 |
+
"--unsupervised-speech-sample-ratio",
|
98 |
+
default="1",
|
99 |
+
type=str,
|
100 |
+
metavar="N",
|
101 |
+
help="Multiple Ratio for speech dataset without transcripts ",
|
102 |
+
)
|
103 |
+
# t2t with monolingual data (masking)
|
104 |
+
parser.add_argument(
|
105 |
+
"--text-sample-ratio",
|
106 |
+
default="1",
|
107 |
+
type=str,
|
108 |
+
metavar="N",
|
109 |
+
help="Multiple Ratio for text set ",
|
110 |
+
)
|
111 |
+
# t2t with parallel data (no masking)
|
112 |
+
parser.add_argument(
|
113 |
+
"--bitext-sample-ratio",
|
114 |
+
default="1",
|
115 |
+
type=str,
|
116 |
+
metavar="N",
|
117 |
+
help="Multiple Ratio for text set (bitext) ",
|
118 |
+
)
|
119 |
+
# train_subset = "train", 'valid' or so
|
120 |
+
# parallel data is loaded according to string lang_pairs and lang_pairs_no_mask from args.data
|
121 |
+
# (un)supervised speech is loaded from args.(un)sup_speech_{train,valid}_subset
|
122 |
+
parser.add_argument(
|
123 |
+
"--sup-speech-data", default="", help="path to supervised speech data"
|
124 |
+
)
|
125 |
+
parser.add_argument(
|
126 |
+
"--sup-speech-train-subset",
|
127 |
+
default="",
|
128 |
+
help="supervised speech training subsets",
|
129 |
+
)
|
130 |
+
parser.add_argument(
|
131 |
+
"--sup-speech-valid-subset",
|
132 |
+
default="",
|
133 |
+
help="supervised speech validation subsets",
|
134 |
+
)
|
135 |
+
parser.add_argument(
|
136 |
+
"--config-yaml",
|
137 |
+
default="config.yaml",
|
138 |
+
help="supervised speech configuration yaml file",
|
139 |
+
)
|
140 |
+
parser.add_argument(
|
141 |
+
"--sup-speech-s2s-data", default="", help="path to supervised speech data"
|
142 |
+
)
|
143 |
+
parser.add_argument(
|
144 |
+
"--sup-speech-s2s-train-subset",
|
145 |
+
default="",
|
146 |
+
help="supervised speech training subsets",
|
147 |
+
)
|
148 |
+
parser.add_argument(
|
149 |
+
"--sup-speech-s2s-valid-subset",
|
150 |
+
default="",
|
151 |
+
help="supervised speech validation subsets",
|
152 |
+
)
|
153 |
+
parser.add_argument(
|
154 |
+
"--config-s2s-yaml",
|
155 |
+
default="config.yaml",
|
156 |
+
help="supervised speech configuration yaml file",
|
157 |
+
)
|
158 |
+
parser.add_argument(
|
159 |
+
"--unsup-speech-train-data",
|
160 |
+
default="",
|
161 |
+
help="path to unsupervised speech training data (tsv)",
|
162 |
+
)
|
163 |
+
parser.add_argument(
|
164 |
+
"--unsup-speech-valid-data",
|
165 |
+
default="",
|
166 |
+
help="path to unsupervised speech valid data (tsv)",
|
167 |
+
)
|
168 |
+
parser.add_argument(
|
169 |
+
"--sample-rate",
|
170 |
+
type=int,
|
171 |
+
metavar="N",
|
172 |
+
default=16000,
|
173 |
+
help="input audio sampling rate",
|
174 |
+
)
|
175 |
+
parser.add_argument(
|
176 |
+
"--no-emb-update-unsup",
|
177 |
+
default=False,
|
178 |
+
action="store_true",
|
179 |
+
help="no update for output embedding during unsupervised_speech mode",
|
180 |
+
)
|
181 |
+
parser.add_argument("--same-data-update", default=False, action="store_true")
|
182 |
+
|
183 |
+
# used for sup_speech_ali
|
184 |
+
parser.add_argument(
|
185 |
+
"--use-sup-speech-ctc",
|
186 |
+
default=False,
|
187 |
+
action="store_true",
|
188 |
+
help="use speech_sup_ctc instead of speech_sup_ali",
|
189 |
+
)
|
190 |
+
|
191 |
+
@classmethod
|
192 |
+
def setup_task(cls, args, **kwargs):
|
193 |
+
"""Setup the task."""
|
194 |
+
paths = args.data.split(":")
|
195 |
+
assert len(paths) > 0
|
196 |
+
src_dict = Dictionary.load(
|
197 |
+
os.path.join(paths[0], "src_dict.txt")
|
198 |
+
) # assume all languages share a source dictionary
|
199 |
+
tgt_dict = Dictionary.load(
|
200 |
+
os.path.join(paths[0], "tgt_dict.txt")
|
201 |
+
) # assume all languages share a target dictionary
|
202 |
+
|
203 |
+
lang_pairs = args.lang_pairs + "," + args.lang_pairs_bitext
|
204 |
+
lang_pairs = re.sub(",$", "", re.sub("^,", "", lang_pairs))
|
205 |
+
if lang_pairs != "":
|
206 |
+
src_langs = [lp.split("-")[0] for lp in lang_pairs.split(",")]
|
207 |
+
tgt_langs = [lp.split("-")[1] for lp in lang_pairs.split(",")]
|
208 |
+
else:
|
209 |
+
src_langs = []
|
210 |
+
tgt_langs = []
|
211 |
+
|
212 |
+
if args.add_src_lang_token:
|
213 |
+
for lang in src_langs:
|
214 |
+
assert (
|
215 |
+
src_dict.index(PairedDenoisingTask.LANG_TAG_TEMPLATE.format(lang))
|
216 |
+
!= src_dict.unk()
|
217 |
+
)
|
218 |
+
if args.add_tgt_lang_token:
|
219 |
+
for lang in tgt_langs:
|
220 |
+
assert (
|
221 |
+
tgt_dict.index(PairedDenoisingTask.LANG_TAG_TEMPLATE.format(lang))
|
222 |
+
!= tgt_dict.unk()
|
223 |
+
)
|
224 |
+
|
225 |
+
logger.info("source dictionary: {} types".format(len(src_dict)))
|
226 |
+
logger.info("target dictionary: {} types".format(len(tgt_dict)))
|
227 |
+
if not hasattr(args, "shuffle_instance"):
|
228 |
+
args.shuffle_instance = False
|
229 |
+
return cls(args, src_dict, tgt_dict)
|
230 |
+
|
231 |
+
def __init__(self, args, src_dict, tgt_dict):
|
232 |
+
super().__init__(args, src_dict, tgt_dict)
|
233 |
+
self.data_cfg = S2TJointDataConfig(
|
234 |
+
Path(args.sup_speech_data) / args.config_yaml
|
235 |
+
)
|
236 |
+
logger.info(
|
237 |
+
f"load supervised speech data configure from {Path(args.sup_speech_data) / args.config_yaml}"
|
238 |
+
)
|
239 |
+
self.data_s2s_cfg = (
|
240 |
+
S2TJointDataConfig(Path(args.sup_speech_s2s_data) / args.config_s2s_yaml)
|
241 |
+
if args.sup_speech_s2s_train_subset != ""
|
242 |
+
else None
|
243 |
+
)
|
244 |
+
if self.data_s2s_cfg is not None:
|
245 |
+
logger.info(
|
246 |
+
f"load supervised sequece to sequence speech data configure from {Path(args.sup_speech_s2s_data) / args.config_yaml}"
|
247 |
+
)
|
248 |
+
|
249 |
+
def parse_data_ratio(sample_ratio):
|
250 |
+
ratios = sample_ratio.split(",")
|
251 |
+
if len(ratios) == 1:
|
252 |
+
return [float(ratios[0])]
|
253 |
+
epoch_ratios = []
|
254 |
+
for item in ratios:
|
255 |
+
ep, r = item.split(":")
|
256 |
+
ep = int(ep)
|
257 |
+
r = float(r)
|
258 |
+
assert ep > 0 # epoch is 1 based
|
259 |
+
assert ep >= len(epoch_ratios)
|
260 |
+
|
261 |
+
if len(epoch_ratios) == 0:
|
262 |
+
epoch_ratios.append(
|
263 |
+
r
|
264 |
+
) # epoch_ratios[0] is not used, but we still set it to the first value to make thing simple.
|
265 |
+
while len(epoch_ratios) < ep:
|
266 |
+
epoch_ratios.append(epoch_ratios[-1])
|
267 |
+
epoch_ratios.append(r)
|
268 |
+
return epoch_ratios
|
269 |
+
|
270 |
+
self.sup_ratio = parse_data_ratio(args.supervised_speech_sample_ratio)
|
271 |
+
self.sup_s2s_ratio = parse_data_ratio(args.supervised_speech_s2s_sample_ratio)
|
272 |
+
self.text_ratio = parse_data_ratio(args.text_sample_ratio)
|
273 |
+
self.bitext_ratio = parse_data_ratio(args.bitext_sample_ratio)
|
274 |
+
self.unsup_ratio = parse_data_ratio(args.unsupervised_speech_sample_ratio)
|
275 |
+
self.sample_mode = None
|
276 |
+
|
277 |
+
def build_model(self, args):
|
278 |
+
args.input_feat_per_channel = self.data_cfg.input_feat_per_channel
|
279 |
+
args.input_channels = self.data_cfg.input_channels
|
280 |
+
return super().build_model(args)
|
281 |
+
|
282 |
+
def build_tokenizer(self, data_cfg, msg=""):
|
283 |
+
logger.info(f"pre-tokenizer {msg}: {data_cfg.pre_tokenizer}")
|
284 |
+
return encoders.build_tokenizer(Namespace(**data_cfg.pre_tokenizer))
|
285 |
+
|
286 |
+
def build_bpe(self, data_cfg, msg=""):
|
287 |
+
logger.info(f"tokenizer {msg}: {data_cfg.bpe_tokenizer}")
|
288 |
+
return encoders.build_bpe(Namespace(**data_cfg.bpe_tokenizer))
|
289 |
+
|
290 |
+
@classmethod
|
291 |
+
def resolve_data_type(cls, split, use_sup_speech_ctc):
|
292 |
+
if len(split.split("_")) == 1:
|
293 |
+
# default case, train or valid
|
294 |
+
is_train = split
|
295 |
+
dtype = "text"
|
296 |
+
else:
|
297 |
+
is_train, dtype = split.split("_", 1)
|
298 |
+
is_train = True if is_train == "train" else False
|
299 |
+
if dtype == "sup_speech":
|
300 |
+
dtype = "sup_speech_ctc" if use_sup_speech_ctc else "sup_speech_ali"
|
301 |
+
assert dtype in (
|
302 |
+
"text",
|
303 |
+
"bitext",
|
304 |
+
"sup_speech_ali",
|
305 |
+
"sup_speech_s2s",
|
306 |
+
"unsup_speech",
|
307 |
+
"sup_speech_ctc",
|
308 |
+
), f"failed resolving {split} (it resulted into: {dtype} ; is_train={is_train})"
|
309 |
+
return is_train, dtype
|
310 |
+
|
311 |
+
def create_modalitydatasetitem(self, dtype, dataset):
|
312 |
+
dsitem = None
|
313 |
+
if dtype in ("text", "bitext"):
|
314 |
+
dsitem = ModalityDatasetItem(
|
315 |
+
dtype,
|
316 |
+
dataset,
|
317 |
+
(self.args.max_source_positions, self.args.max_target_positions),
|
318 |
+
self.args.max_text_tokens,
|
319 |
+
self.args.batch_size,
|
320 |
+
)
|
321 |
+
elif dtype in ("sup_speech_ctc", "sup_speech_ali", "sup_speech_s2s"):
|
322 |
+
dsitem = ModalityDatasetItem(
|
323 |
+
dtype,
|
324 |
+
dataset,
|
325 |
+
(self.args.max_speech_positions, self.args.max_target_positions),
|
326 |
+
self.args.max_speech_tokens,
|
327 |
+
self.args.batch_size,
|
328 |
+
)
|
329 |
+
elif dtype == "unsup_speech":
|
330 |
+
dsitem = ModalityDatasetItem(
|
331 |
+
dtype, dataset, 1e8, self.args.max_speech_tokens, self.args.batch_size
|
332 |
+
)
|
333 |
+
else:
|
334 |
+
raise ValueError(f"{dtype} is not supported")
|
335 |
+
return dsitem
|
336 |
+
|
337 |
+
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
|
338 |
+
def _get_sup_src_tgt_dict(src_dict, tgt_dict, use_s2s_sup_decoder):
|
339 |
+
if use_s2s_sup_decoder:
|
340 |
+
return None, tgt_dict
|
341 |
+
# use src_dict as tgt_dict here, since we use source dictionary as target for forcealignment
|
342 |
+
return None, src_dict
|
343 |
+
|
344 |
+
is_train, dtype = self.resolve_data_type(split, self.args.use_sup_speech_ctc)
|
345 |
+
|
346 |
+
# Note we use --add-tgt-lang-token instead of data_cfg.prepend_tgt_lang_tag_no_change to set target language tag in the text dataset
|
347 |
+
# Verify add_tgt_lang_token and prepend_tgt_lang_tag_no_change are same
|
348 |
+
|
349 |
+
# Note we use --multilang-sampling-alpha instead of data_cfg.sampling_text_alpha to set text data sampling
|
350 |
+
if is_train:
|
351 |
+
msets = []
|
352 |
+
# train split, load everything into one
|
353 |
+
if self.lang_pairs != "":
|
354 |
+
text_dataset = self.load_dataset_only(
|
355 |
+
"train", self.lang_pairs, epoch=epoch, combine=combine
|
356 |
+
)
|
357 |
+
dsitem = self.create_modalitydatasetitem("text", text_dataset)
|
358 |
+
msets.append(dsitem)
|
359 |
+
if self.lang_pairs_bitext != "": # load bitext
|
360 |
+
bitext_dataset = self.load_dataset_only(
|
361 |
+
"train_bitext",
|
362 |
+
self.lang_pairs_bitext,
|
363 |
+
do_mask=False,
|
364 |
+
epoch=epoch,
|
365 |
+
combine=combine,
|
366 |
+
)
|
367 |
+
dsitem = self.create_modalitydatasetitem("bitext", bitext_dataset)
|
368 |
+
msets.append(dsitem)
|
369 |
+
if self.args.sup_speech_train_subset != "":
|
370 |
+
pre_tokenizer = self.build_tokenizer(self.data_cfg)
|
371 |
+
bpe_tokenizer = self.build_bpe(self.data_cfg)
|
372 |
+
|
373 |
+
append_eos = True
|
374 |
+
sup_speech_type = "sup_speech_ali"
|
375 |
+
if self.args.use_sup_speech_ctc:
|
376 |
+
# CTC mode
|
377 |
+
sup_speech_type = "sup_speech_ctc"
|
378 |
+
append_eos = False # CTC doesn't need eos in the target
|
379 |
+
|
380 |
+
src_dict, tgt_dict = _get_sup_src_tgt_dict(
|
381 |
+
self.src_dict, self.tgt_dict, False
|
382 |
+
)
|
383 |
+
sup_speech_dataset = SpeechToTextJointDatasetCreator.from_tsv(
|
384 |
+
self.args.sup_speech_data,
|
385 |
+
self.data_cfg,
|
386 |
+
self.args.sup_speech_train_subset,
|
387 |
+
tgt_dict=tgt_dict,
|
388 |
+
src_dict=src_dict,
|
389 |
+
pre_tokenizer=pre_tokenizer,
|
390 |
+
bpe_tokenizer=bpe_tokenizer,
|
391 |
+
src_pre_tokenizer=None,
|
392 |
+
src_bpe_tokenizer=None,
|
393 |
+
is_train_split=is_train,
|
394 |
+
epoch=epoch,
|
395 |
+
seed=self.args.seed,
|
396 |
+
append_eos=append_eos,
|
397 |
+
)
|
398 |
+
dsitem = self.create_modalitydatasetitem(
|
399 |
+
sup_speech_type, sup_speech_dataset
|
400 |
+
)
|
401 |
+
msets.append(dsitem)
|
402 |
+
|
403 |
+
if self.args.sup_speech_s2s_train_subset != "":
|
404 |
+
pre_tokenizer = self.build_tokenizer(self.data_s2s_cfg, msg="(s2s)")
|
405 |
+
bpe_tokenizer = self.build_bpe(self.data_s2s_cfg, msg="(s2s)")
|
406 |
+
|
407 |
+
# make sure self.data_cfg.prepend_tgt_lang_tag_no_change == self.args.add_tgt_lang_token
|
408 |
+
src_dict, tgt_dict = _get_sup_src_tgt_dict(
|
409 |
+
self.src_dict, self.tgt_dict, True
|
410 |
+
)
|
411 |
+
sup_speech_s2s_dataset = SpeechToTextJointDatasetCreator.from_tsv(
|
412 |
+
self.args.sup_speech_s2s_data,
|
413 |
+
self.data_s2s_cfg,
|
414 |
+
self.args.sup_speech_s2s_train_subset,
|
415 |
+
tgt_dict=tgt_dict,
|
416 |
+
src_dict=src_dict,
|
417 |
+
pre_tokenizer=pre_tokenizer,
|
418 |
+
bpe_tokenizer=bpe_tokenizer,
|
419 |
+
src_pre_tokenizer=None,
|
420 |
+
src_bpe_tokenizer=None,
|
421 |
+
is_train_split=is_train,
|
422 |
+
epoch=epoch,
|
423 |
+
seed=self.args.seed,
|
424 |
+
)
|
425 |
+
dsitem = self.create_modalitydatasetitem(
|
426 |
+
"sup_speech_s2s", sup_speech_s2s_dataset
|
427 |
+
)
|
428 |
+
msets.append(dsitem)
|
429 |
+
if self.args.unsup_speech_train_data != "":
|
430 |
+
unsup_speech_dataset = FileAudioDatasetWrapper(
|
431 |
+
self.args.unsup_speech_train_data,
|
432 |
+
self.args.sample_rate,
|
433 |
+
max_sample_size=self.args.max_sample_size,
|
434 |
+
min_sample_size=self.args.min_sample_size,
|
435 |
+
normalize=False,
|
436 |
+
)
|
437 |
+
dsitem = self.create_modalitydatasetitem(
|
438 |
+
"unsup_speech", unsup_speech_dataset
|
439 |
+
)
|
440 |
+
msets.append(dsitem)
|
441 |
+
|
442 |
+
pre_train_dataset = MultiModalityDataset(msets)
|
443 |
+
self.datasets[split] = pre_train_dataset
|
444 |
+
else: # validation split, load them for each type of data
|
445 |
+
if dtype == "text":
|
446 |
+
text_dataset = self.load_dataset_only(
|
447 |
+
split, self.lang_pairs, epoch=epoch, combine=combine
|
448 |
+
)
|
449 |
+
dsitem = self.create_modalitydatasetitem("text", text_dataset)
|
450 |
+
self.datasets[split] = MultiModalityDataset([dsitem])
|
451 |
+
elif dtype == "bitext":
|
452 |
+
bitext_dataset = self.load_dataset_only(
|
453 |
+
split,
|
454 |
+
self.lang_pairs_bitext,
|
455 |
+
do_mask=False,
|
456 |
+
epoch=epoch,
|
457 |
+
combine=combine,
|
458 |
+
)
|
459 |
+
dsitem = self.create_modalitydatasetitem("bitext", bitext_dataset)
|
460 |
+
self.datasets[split] = MultiModalityDataset([dsitem])
|
461 |
+
|
462 |
+
elif dtype in ("sup_speech_ctc", "sup_speech_ali"):
|
463 |
+
assert self.args.sup_speech_valid_subset != ""
|
464 |
+
pre_tokenizer = self.build_tokenizer(self.data_cfg)
|
465 |
+
bpe_tokenizer = self.build_bpe(self.data_cfg)
|
466 |
+
append_eos = True
|
467 |
+
if dtype == "sup_speech_ctc":
|
468 |
+
# CTC mode
|
469 |
+
append_eos = False # CTC doesn't need eos
|
470 |
+
assert self.args.use_sup_speech_ctc
|
471 |
+
|
472 |
+
datasets = []
|
473 |
+
for split_name in self.args.sup_speech_valid_subset.split(","):
|
474 |
+
src_dict, tgt_dict = _get_sup_src_tgt_dict(
|
475 |
+
self.src_dict, self.tgt_dict, False
|
476 |
+
)
|
477 |
+
datasets.append(
|
478 |
+
SpeechToTextJointDatasetCreator.from_tsv(
|
479 |
+
self.args.sup_speech_data,
|
480 |
+
self.data_cfg,
|
481 |
+
split_name,
|
482 |
+
tgt_dict=tgt_dict,
|
483 |
+
src_dict=src_dict,
|
484 |
+
pre_tokenizer=pre_tokenizer,
|
485 |
+
bpe_tokenizer=bpe_tokenizer,
|
486 |
+
src_pre_tokenizer=None,
|
487 |
+
src_bpe_tokenizer=None,
|
488 |
+
is_train_split=is_train,
|
489 |
+
epoch=epoch,
|
490 |
+
seed=self.args.seed,
|
491 |
+
append_eos=append_eos,
|
492 |
+
)
|
493 |
+
)
|
494 |
+
|
495 |
+
dset = datasets[0] if len(datasets) == 1 else ConcatDataset(datasets)
|
496 |
+
dsitem = self.create_modalitydatasetitem(dtype, dset)
|
497 |
+
self.datasets[split] = MultiModalityDataset([dsitem])
|
498 |
+
|
499 |
+
elif dtype == "sup_speech_s2s":
|
500 |
+
assert self.args.sup_speech_s2s_valid_subset != ""
|
501 |
+
pre_tokenizer = self.build_tokenizer(self.data_s2s_cfg)
|
502 |
+
bpe_tokenizer = self.build_bpe(self.data_s2s_cfg)
|
503 |
+
datasets = []
|
504 |
+
for split_name in self.args.sup_speech_s2s_valid_subset.split(","):
|
505 |
+
src_dict, tgt_dict = _get_sup_src_tgt_dict(
|
506 |
+
self.src_dict, self.tgt_dict, True
|
507 |
+
)
|
508 |
+
datasets.append(
|
509 |
+
SpeechToTextJointDatasetCreator.from_tsv(
|
510 |
+
self.args.sup_speech_s2s_data,
|
511 |
+
self.data_s2s_cfg,
|
512 |
+
split_name,
|
513 |
+
tgt_dict=tgt_dict,
|
514 |
+
src_dict=src_dict,
|
515 |
+
pre_tokenizer=pre_tokenizer,
|
516 |
+
bpe_tokenizer=bpe_tokenizer,
|
517 |
+
src_pre_tokenizer=None,
|
518 |
+
src_bpe_tokenizer=None,
|
519 |
+
is_train_split=is_train,
|
520 |
+
epoch=epoch,
|
521 |
+
seed=self.args.seed,
|
522 |
+
)
|
523 |
+
)
|
524 |
+
|
525 |
+
dset = datasets[0] if len(datasets) == 1 else ConcatDataset(datasets)
|
526 |
+
dsitem = self.create_modalitydatasetitem("sup_speech_s2s", dset)
|
527 |
+
self.datasets[split] = MultiModalityDataset([dsitem])
|
528 |
+
elif dtype == "unsup_speech":
|
529 |
+
assert self.args.unsup_speech_valid_data != ""
|
530 |
+
unsup_speech_dataset = FileAudioDatasetWrapper(
|
531 |
+
self.args.unsup_speech_valid_data,
|
532 |
+
self.args.sample_rate,
|
533 |
+
max_sample_size=self.args.max_sample_size,
|
534 |
+
min_sample_size=self.args.min_sample_size,
|
535 |
+
normalize=False,
|
536 |
+
)
|
537 |
+
dsitem = self.create_modalitydatasetitem(
|
538 |
+
"unsup_speech", unsup_speech_dataset
|
539 |
+
)
|
540 |
+
self.datasets[split] = MultiModalityDataset([dsitem])
|
541 |
+
else:
|
542 |
+
raise ValueError(f"Unsupported type {dtype}")
|
543 |
+
|
544 |
+
def get_sample_ratio(self, epoch):
|
545 |
+
sup_ratio = (
|
546 |
+
self.sup_ratio[epoch] if len(self.sup_ratio) > epoch else self.sup_ratio[-1]
|
547 |
+
)
|
548 |
+
sup_s2s_ratio = (
|
549 |
+
self.sup_s2s_ratio[epoch]
|
550 |
+
if len(self.sup_s2s_ratio) > epoch
|
551 |
+
else self.sup_s2s_ratio[-1]
|
552 |
+
)
|
553 |
+
unsup_ratio = (
|
554 |
+
self.unsup_ratio[epoch]
|
555 |
+
if len(self.unsup_ratio) > epoch
|
556 |
+
else self.unsup_ratio[-1]
|
557 |
+
)
|
558 |
+
text_ratio = (
|
559 |
+
self.text_ratio[epoch]
|
560 |
+
if len(self.text_ratio) > epoch
|
561 |
+
else self.text_ratio[-1]
|
562 |
+
)
|
563 |
+
bitext_ratio = (
|
564 |
+
self.bitext_ratio[epoch]
|
565 |
+
if len(self.bitext_ratio) > epoch
|
566 |
+
else self.bitext_ratio[-1]
|
567 |
+
)
|
568 |
+
return text_ratio, bitext_ratio, sup_ratio, sup_s2s_ratio, unsup_ratio
|
569 |
+
|
570 |
+
def get_batch_iterator(
|
571 |
+
self,
|
572 |
+
dataset,
|
573 |
+
max_tokens=None,
|
574 |
+
max_sentences=None,
|
575 |
+
max_positions=None,
|
576 |
+
ignore_invalid_inputs=False,
|
577 |
+
required_batch_size_multiple=1,
|
578 |
+
seed=1,
|
579 |
+
num_shards=1,
|
580 |
+
shard_id=0,
|
581 |
+
num_workers=0,
|
582 |
+
epoch=0,
|
583 |
+
data_buffer_size=0,
|
584 |
+
disable_iterator_cache=False,
|
585 |
+
skip_remainder_batch=False,
|
586 |
+
grouped_shuffling=False,
|
587 |
+
update_epoch_batch_itr=False,
|
588 |
+
):
|
589 |
+
|
590 |
+
assert isinstance(dataset, MultiModalityDataset)
|
591 |
+
if len(dataset.id_to_mode) == 1:
|
592 |
+
max_positions = dataset.max_positions[0]
|
593 |
+
max_tokens = dataset.max_tokens[0]
|
594 |
+
max_sentences = dataset.max_sentences[0]
|
595 |
+
return super().get_batch_iterator(
|
596 |
+
dataset,
|
597 |
+
max_tokens,
|
598 |
+
max_sentences,
|
599 |
+
max_positions,
|
600 |
+
ignore_invalid_inputs,
|
601 |
+
required_batch_size_multiple,
|
602 |
+
seed,
|
603 |
+
num_shards,
|
604 |
+
shard_id,
|
605 |
+
num_workers,
|
606 |
+
epoch,
|
607 |
+
data_buffer_size,
|
608 |
+
disable_iterator_cache,
|
609 |
+
skip_remainder_batch=skip_remainder_batch,
|
610 |
+
)
|
611 |
+
|
612 |
+
mult_ratio = []
|
613 |
+
(
|
614 |
+
text_ratio,
|
615 |
+
bitext_ratio,
|
616 |
+
sup_ratio,
|
617 |
+
sup_s2s_ratio,
|
618 |
+
unsup_ratio,
|
619 |
+
) = self.get_sample_ratio(epoch)
|
620 |
+
for mode in dataset.id_to_mode:
|
621 |
+
if mode in ("sup_speech_ctc", "sup_speech_ali"):
|
622 |
+
mult_ratio.append(sup_ratio)
|
623 |
+
elif mode == "sup_speech_s2s":
|
624 |
+
mult_ratio.append(sup_s2s_ratio)
|
625 |
+
elif mode == "text":
|
626 |
+
mult_ratio.append(text_ratio)
|
627 |
+
elif mode == "bitext":
|
628 |
+
mult_ratio.append(bitext_ratio)
|
629 |
+
elif mode == "unsup_speech":
|
630 |
+
mult_ratio.append(unsup_ratio)
|
631 |
+
|
632 |
+
# initialize the dataset with the correct starting epoch
|
633 |
+
dataset.set_epoch(epoch)
|
634 |
+
|
635 |
+
batch_samplers = dataset.get_batch_samplers(
|
636 |
+
mult_ratio, required_batch_size_multiple, seed
|
637 |
+
)
|
638 |
+
|
639 |
+
# return a reusable, sharded iterator
|
640 |
+
epoch_iter = GroupedEpochBatchIterator(
|
641 |
+
dataset=dataset,
|
642 |
+
collate_fn=dataset.collater,
|
643 |
+
batch_samplers=batch_samplers,
|
644 |
+
seed=seed,
|
645 |
+
num_shards=num_shards,
|
646 |
+
shard_id=shard_id,
|
647 |
+
num_workers=num_workers,
|
648 |
+
epoch=epoch,
|
649 |
+
mult_rate=max(self.args.update_freq) if self.args.same_data_update else 1,
|
650 |
+
buffer_size=data_buffer_size,
|
651 |
+
skip_remainder_batch=skip_remainder_batch,
|
652 |
+
)
|
653 |
+
self.dataset_to_epoch_iter[dataset] = {} # refresh it every epoch
|
654 |
+
return epoch_iter
|
fairseq/examples/speech_text_joint_to_text/tasks/speech_text_joint.py
ADDED
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
import logging
|
6 |
+
import os
|
7 |
+
from argparse import Namespace
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from fairseq.data import (
|
12 |
+
encoders,
|
13 |
+
Dictionary,
|
14 |
+
ResamplingDataset,
|
15 |
+
TransformEosLangPairDataset,
|
16 |
+
ConcatDataset,
|
17 |
+
)
|
18 |
+
from fairseq.data.iterators import GroupedEpochBatchIterator
|
19 |
+
from fairseq.data.audio.multi_modality_dataset import (
|
20 |
+
MultiModalityDataset,
|
21 |
+
LangPairMaskDataset,
|
22 |
+
ModalityDatasetItem,
|
23 |
+
)
|
24 |
+
from fairseq.data.audio.speech_to_text_dataset import (
|
25 |
+
SpeechToTextDataset,
|
26 |
+
SpeechToTextDatasetCreator,
|
27 |
+
)
|
28 |
+
from fairseq.data.audio.speech_to_text_joint_dataset import (
|
29 |
+
S2TJointDataConfig,
|
30 |
+
SpeechToTextJointDatasetCreator,
|
31 |
+
)
|
32 |
+
from fairseq.tasks import register_task
|
33 |
+
from fairseq.tasks.speech_to_text import SpeechToTextTask
|
34 |
+
from fairseq.tasks.translation import load_langpair_dataset
|
35 |
+
|
36 |
+
logger = logging.getLogger(__name__)
|
37 |
+
LANG_TAG_TEMPLATE = "<lang:{}>"
|
38 |
+
|
39 |
+
|
40 |
+
@register_task("speech_text_joint_to_text")
|
41 |
+
class SpeechTextJointToTextTask(SpeechToTextTask):
|
42 |
+
"""
|
43 |
+
Task for joint training speech and text to text.
|
44 |
+
"""
|
45 |
+
|
46 |
+
@classmethod
|
47 |
+
def add_args(cls, parser):
|
48 |
+
"""Add task-specific arguments to the parser."""
|
49 |
+
super(SpeechTextJointToTextTask, cls).add_args(parser)
|
50 |
+
###
|
51 |
+
parser.add_argument(
|
52 |
+
"--parallel-text-data",
|
53 |
+
default="",
|
54 |
+
help="path to parallel text data directory",
|
55 |
+
)
|
56 |
+
parser.add_argument(
|
57 |
+
"--max-tokens-text",
|
58 |
+
type=int,
|
59 |
+
metavar="N",
|
60 |
+
help="maximum tokens for encoder text input ",
|
61 |
+
)
|
62 |
+
parser.add_argument(
|
63 |
+
"--max-positions-text",
|
64 |
+
type=int,
|
65 |
+
metavar="N",
|
66 |
+
default=400,
|
67 |
+
help="maximum tokens for per encoder text input ",
|
68 |
+
)
|
69 |
+
parser.add_argument(
|
70 |
+
"--langpairs",
|
71 |
+
default=None,
|
72 |
+
metavar="S",
|
73 |
+
help='language pairs for text training, separated with ","',
|
74 |
+
)
|
75 |
+
parser.add_argument(
|
76 |
+
"--speech-sample-ratio",
|
77 |
+
default=1,
|
78 |
+
type=float,
|
79 |
+
metavar="N",
|
80 |
+
help="Multiple Ratio for speech dataset with transcripts ",
|
81 |
+
)
|
82 |
+
parser.add_argument(
|
83 |
+
"--text-sample-ratio",
|
84 |
+
default=1,
|
85 |
+
type=float,
|
86 |
+
metavar="N",
|
87 |
+
help="Multiple Ratio for text set ",
|
88 |
+
)
|
89 |
+
parser.add_argument(
|
90 |
+
"--update-mix-data",
|
91 |
+
action="store_true",
|
92 |
+
help="use mixed data in one update when update-freq > 1",
|
93 |
+
)
|
94 |
+
parser.add_argument(
|
95 |
+
"--load-speech-only", action="store_true", help="load speech data only",
|
96 |
+
)
|
97 |
+
parser.add_argument(
|
98 |
+
"--mask-text-ratio",
|
99 |
+
type=float,
|
100 |
+
metavar="V",
|
101 |
+
default=0.0,
|
102 |
+
help="mask V source tokens for text only mode",
|
103 |
+
)
|
104 |
+
parser.add_argument(
|
105 |
+
"--mask-text-type",
|
106 |
+
default="random",
|
107 |
+
choices=["random", "tail"],
|
108 |
+
help="mask text typed",
|
109 |
+
)
|
110 |
+
parser.add_argument(
|
111 |
+
"--noise-token",
|
112 |
+
default="",
|
113 |
+
help="noise token for masking src text tokens if mask-text-ratio > 0",
|
114 |
+
)
|
115 |
+
parser.add_argument(
|
116 |
+
"--infer-target-lang",
|
117 |
+
default="",
|
118 |
+
metavar="S",
|
119 |
+
help="target language for inference",
|
120 |
+
)
|
121 |
+
|
122 |
+
def __init__(self, args, src_dict, tgt_dict, infer_tgt_lang_id=None):
|
123 |
+
super().__init__(args, tgt_dict)
|
124 |
+
self.src_dict = src_dict
|
125 |
+
self.data_cfg = S2TJointDataConfig(Path(args.data) / args.config_yaml)
|
126 |
+
assert self.tgt_dict.pad() == self.src_dict.pad()
|
127 |
+
assert self.tgt_dict.eos() == self.src_dict.eos()
|
128 |
+
self.speech_only = args.load_speech_only
|
129 |
+
self._infer_tgt_lang_id = infer_tgt_lang_id
|
130 |
+
|
131 |
+
@classmethod
|
132 |
+
def setup_task(cls, args, **kwargs):
|
133 |
+
"""Setup the task (e.g., load dictionaries)."""
|
134 |
+
data_cfg = S2TJointDataConfig(Path(args.data) / args.config_yaml)
|
135 |
+
tgt_dict_path = Path(args.data) / data_cfg.vocab_filename
|
136 |
+
src_dict_path = Path(args.data) / data_cfg.src_vocab_filename
|
137 |
+
if (not os.path.isfile(src_dict_path)) or (not os.path.isfile(tgt_dict_path)):
|
138 |
+
raise FileNotFoundError("Dict not found: {}".format(args.data))
|
139 |
+
src_dict = Dictionary.load(src_dict_path.as_posix())
|
140 |
+
tgt_dict = Dictionary.load(tgt_dict_path.as_posix())
|
141 |
+
|
142 |
+
print("| src dictionary: {} types".format(len(src_dict)))
|
143 |
+
print("| tgt dictionary: {} types".format(len(tgt_dict)))
|
144 |
+
|
145 |
+
if args.parallel_text_data != "":
|
146 |
+
if not os.path.isabs(args.parallel_text_data):
|
147 |
+
args.parallel_text_data = os.path.join(
|
148 |
+
args.data, args.parallel_text_data
|
149 |
+
)
|
150 |
+
|
151 |
+
if args.langpairs is None:
|
152 |
+
raise Exception(
|
153 |
+
"Could not infer language pair, please provide it explicitly"
|
154 |
+
)
|
155 |
+
infer_tgt_lang_id = None
|
156 |
+
if args.infer_target_lang != "" and data_cfg.prepend_tgt_lang_tag_no_change:
|
157 |
+
tgt_lang_tag = SpeechToTextDataset.LANG_TAG_TEMPLATE.format(
|
158 |
+
args.infer_target_lang
|
159 |
+
)
|
160 |
+
infer_tgt_lang_id = tgt_dict.index(tgt_lang_tag)
|
161 |
+
assert infer_tgt_lang_id != tgt_dict.unk()
|
162 |
+
return cls(args, src_dict, tgt_dict, infer_tgt_lang_id=infer_tgt_lang_id)
|
163 |
+
|
164 |
+
def load_langpair_dataset(
|
165 |
+
self, prepend_tgt_lang_tag=False, sampling_alpha=1.0, epoch=0
|
166 |
+
):
|
167 |
+
lang_pairs = []
|
168 |
+
text_dataset = None
|
169 |
+
split = "train"
|
170 |
+
for lp in self.args.langpairs.split(","):
|
171 |
+
src, tgt = lp.split("-")
|
172 |
+
text_dataset = load_langpair_dataset(
|
173 |
+
self.args.parallel_text_data,
|
174 |
+
split,
|
175 |
+
src,
|
176 |
+
self.src_dict,
|
177 |
+
tgt,
|
178 |
+
self.tgt_dict,
|
179 |
+
combine=True,
|
180 |
+
dataset_impl=None,
|
181 |
+
upsample_primary=1,
|
182 |
+
left_pad_source=False,
|
183 |
+
left_pad_target=False,
|
184 |
+
max_source_positions=self.args.max_positions_text,
|
185 |
+
max_target_positions=self.args.max_target_positions,
|
186 |
+
load_alignments=False,
|
187 |
+
truncate_source=False,
|
188 |
+
)
|
189 |
+
if prepend_tgt_lang_tag:
|
190 |
+
# TODO
|
191 |
+
text_dataset = TransformEosLangPairDataset(
|
192 |
+
text_dataset,
|
193 |
+
src_eos=self.src_dict.eos(),
|
194 |
+
tgt_bos=self.tgt_dict.eos(), # 'prev_output_tokens' starts with eos
|
195 |
+
new_tgt_bos=self.tgt_dict.index(LANG_TAG_TEMPLATE.format(tgt)),
|
196 |
+
)
|
197 |
+
lang_pairs.append(text_dataset)
|
198 |
+
if len(lang_pairs) > 1:
|
199 |
+
if sampling_alpha != 1.0:
|
200 |
+
size_ratios = SpeechToTextDatasetCreator.get_size_ratios(
|
201 |
+
self.args.langpairs.split(","),
|
202 |
+
[len(s) for s in lang_pairs],
|
203 |
+
alpha=sampling_alpha,
|
204 |
+
)
|
205 |
+
lang_pairs = [
|
206 |
+
ResamplingDataset(d, size_ratio=r, epoch=epoch, replace=(r >= 1.0))
|
207 |
+
for d, r in zip(lang_pairs, size_ratios)
|
208 |
+
]
|
209 |
+
return ConcatDataset(lang_pairs)
|
210 |
+
return text_dataset
|
211 |
+
|
212 |
+
def inference_step(
|
213 |
+
self, generator, models, sample, prefix_tokens=None, constraints=None
|
214 |
+
):
|
215 |
+
with torch.no_grad():
|
216 |
+
return generator.generate(
|
217 |
+
models,
|
218 |
+
sample,
|
219 |
+
prefix_tokens=prefix_tokens,
|
220 |
+
constraints=constraints,
|
221 |
+
bos_token=self._infer_tgt_lang_id,
|
222 |
+
)
|
223 |
+
|
224 |
+
def build_src_tokenizer(self, args):
|
225 |
+
logger.info(f"src-pre-tokenizer: {self.data_cfg.src_pre_tokenizer}")
|
226 |
+
return encoders.build_tokenizer(Namespace(**self.data_cfg.src_pre_tokenizer))
|
227 |
+
|
228 |
+
def build_src_bpe(self, args):
|
229 |
+
logger.info(f"tokenizer: {self.data_cfg.src_bpe_tokenizer}")
|
230 |
+
return encoders.build_bpe(Namespace(**self.data_cfg.src_bpe_tokenizer))
|
231 |
+
|
232 |
+
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
|
233 |
+
"""Load a given dataset split.
|
234 |
+
|
235 |
+
Args:
|
236 |
+
split (str): name of the split (e.g., train, valid, test)
|
237 |
+
"""
|
238 |
+
is_train_split = split.startswith("train")
|
239 |
+
pre_tokenizer = self.build_tokenizer(self.args)
|
240 |
+
bpe_tokenizer = self.build_bpe(self.args)
|
241 |
+
src_pre_tokenizer = self.build_src_tokenizer(self.args)
|
242 |
+
src_bpe_tokenizer = self.build_src_bpe(self.args)
|
243 |
+
ast_dataset = SpeechToTextJointDatasetCreator.from_tsv(
|
244 |
+
self.args.data,
|
245 |
+
self.data_cfg,
|
246 |
+
split,
|
247 |
+
self.tgt_dict,
|
248 |
+
src_dict=None if self.speech_only else self.src_dict,
|
249 |
+
pre_tokenizer=pre_tokenizer,
|
250 |
+
bpe_tokenizer=bpe_tokenizer,
|
251 |
+
src_pre_tokenizer=src_pre_tokenizer,
|
252 |
+
src_bpe_tokenizer=src_bpe_tokenizer,
|
253 |
+
is_train_split=is_train_split,
|
254 |
+
epoch=epoch,
|
255 |
+
seed=self.args.seed,
|
256 |
+
)
|
257 |
+
noise_token_id = -1
|
258 |
+
text_dataset = None
|
259 |
+
if self.args.parallel_text_data != "" and is_train_split:
|
260 |
+
text_dataset = self.load_langpair_dataset(
|
261 |
+
self.data_cfg.prepend_tgt_lang_tag_no_change, 1.0, epoch=epoch,
|
262 |
+
)
|
263 |
+
if self.args.mask_text_ratio > 0:
|
264 |
+
# add mask
|
265 |
+
noise_token_id = (
|
266 |
+
self.src_dict.unk()
|
267 |
+
if self.args.noise_token == ""
|
268 |
+
else self.src_dict.index(self.args.noise_token)
|
269 |
+
)
|
270 |
+
text_dataset = LangPairMaskDataset(
|
271 |
+
text_dataset,
|
272 |
+
src_bos=self.src_dict.bos(),
|
273 |
+
src_eos=self.src_dict.eos(),
|
274 |
+
noise_id=noise_token_id,
|
275 |
+
mask_ratio=self.args.mask_text_ratio,
|
276 |
+
mask_type=self.args.mask_text_type,
|
277 |
+
)
|
278 |
+
|
279 |
+
if text_dataset is not None:
|
280 |
+
mdsets = [
|
281 |
+
ModalityDatasetItem(
|
282 |
+
"sup_speech",
|
283 |
+
ast_dataset,
|
284 |
+
(self.args.max_source_positions, self.args.max_target_positions),
|
285 |
+
self.args.max_tokens,
|
286 |
+
self.args.batch_size,
|
287 |
+
),
|
288 |
+
ModalityDatasetItem(
|
289 |
+
"text",
|
290 |
+
text_dataset,
|
291 |
+
(self.args.max_positions_text, self.args.max_target_positions),
|
292 |
+
self.args.max_tokens_text
|
293 |
+
if self.args.max_tokens_text is not None
|
294 |
+
else self.args.max_tokens,
|
295 |
+
self.args.batch_size,
|
296 |
+
),
|
297 |
+
]
|
298 |
+
ast_dataset = MultiModalityDataset(mdsets)
|
299 |
+
self.datasets[split] = ast_dataset
|
300 |
+
|
301 |
+
@property
|
302 |
+
def target_dictionary(self):
|
303 |
+
"""Return the :class:`~fairseq.data.Dictionary` for the language
|
304 |
+
model."""
|
305 |
+
return self.tgt_dict
|
306 |
+
|
307 |
+
@property
|
308 |
+
def source_dictionary(self):
|
309 |
+
"""Return the source :class:`~fairseq.data.Dictionary` (if applicable
|
310 |
+
for this task)."""
|
311 |
+
return None if self.speech_only else self.src_dict
|
312 |
+
|
313 |
+
def get_batch_iterator(
|
314 |
+
self,
|
315 |
+
dataset,
|
316 |
+
max_tokens=None,
|
317 |
+
max_sentences=None,
|
318 |
+
max_positions=None,
|
319 |
+
ignore_invalid_inputs=False,
|
320 |
+
required_batch_size_multiple=1,
|
321 |
+
seed=1,
|
322 |
+
num_shards=1,
|
323 |
+
shard_id=0,
|
324 |
+
num_workers=0,
|
325 |
+
epoch=0,
|
326 |
+
data_buffer_size=0,
|
327 |
+
disable_iterator_cache=False,
|
328 |
+
skip_remainder_batch=False,
|
329 |
+
grouped_shuffling=False,
|
330 |
+
update_epoch_batch_itr=False,
|
331 |
+
):
|
332 |
+
|
333 |
+
if not isinstance(dataset, MultiModalityDataset):
|
334 |
+
return super(SpeechTextJointToTextTask, self).get_batch_iterator(
|
335 |
+
dataset,
|
336 |
+
max_tokens,
|
337 |
+
max_sentences,
|
338 |
+
max_positions,
|
339 |
+
ignore_invalid_inputs,
|
340 |
+
required_batch_size_multiple,
|
341 |
+
seed,
|
342 |
+
num_shards,
|
343 |
+
shard_id,
|
344 |
+
num_workers,
|
345 |
+
epoch,
|
346 |
+
data_buffer_size,
|
347 |
+
disable_iterator_cache,
|
348 |
+
skip_remainder_batch=skip_remainder_batch,
|
349 |
+
update_epoch_batch_itr=update_epoch_batch_itr,
|
350 |
+
)
|
351 |
+
|
352 |
+
mult_ratio = [self.args.speech_sample_ratio, self.args.text_sample_ratio]
|
353 |
+
assert len(dataset.datasets) == 2
|
354 |
+
|
355 |
+
# initialize the dataset with the correct starting epoch
|
356 |
+
dataset.set_epoch(epoch)
|
357 |
+
|
358 |
+
batch_samplers = dataset.get_batch_samplers(
|
359 |
+
mult_ratio, required_batch_size_multiple, seed
|
360 |
+
)
|
361 |
+
|
362 |
+
# return a reusable, sharded iterator
|
363 |
+
epoch_iter = GroupedEpochBatchIterator(
|
364 |
+
dataset=dataset,
|
365 |
+
collate_fn=dataset.collater,
|
366 |
+
batch_samplers=batch_samplers,
|
367 |
+
seed=seed,
|
368 |
+
num_shards=num_shards,
|
369 |
+
shard_id=shard_id,
|
370 |
+
num_workers=num_workers,
|
371 |
+
epoch=epoch,
|
372 |
+
mult_rate=1 if self.args.update_mix_data else max(self.args.update_freq),
|
373 |
+
buffer_size=data_buffer_size,
|
374 |
+
skip_remainder_batch=skip_remainder_batch,
|
375 |
+
)
|
376 |
+
self.dataset_to_epoch_iter[dataset] = {} # refresh it every epoch
|
377 |
+
return epoch_iter
|
fairseq/examples/speech_to_speech/README.md
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Speech to speech translation (S2ST)
|
2 |
+
|
3 |
+
We provide the implementation and resources for the following work on speech-to-speech translation (S2ST):
|
4 |
+
|
5 |
+
* [Direct speech-to-speech translation with discrete units (Lee et al. 2021)](docs/direct_s2st_discrete_units.md)
|
6 |
+
* [Textless Speech-to-Speech Translation on Real Data (Lee et al. 2021)](docs/textless_s2st_real_data.md)
|
7 |
+
* [Enhanced Direct Speech-to-Speech Translation Using Self-supervised Pre-training and Data Augmentation](docs/enhanced_direct_s2st_discrete_units.md)
|
fairseq/examples/speech_to_speech/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from . import unity # noqa
|
fairseq/examples/speech_to_speech/asr_bleu/README.md
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ASR-BLEU evaluation toolkit
|
2 |
+
|
3 |
+
This toolkit provides a set of public ASR models used for evaluation of different speech-to-speech translation systems at FAIR. It enables easier score comparisons between different system's outputs.
|
4 |
+
|
5 |
+
The ASRGenerator wraps different CTC-based ASR models from HuggingFace and fairseq code bases. Torchaudio CTC decoder is built on top of it to decode given audio files.
|
6 |
+
|
7 |
+
Please see `asr_model_cfgs.json` for a list of languages covered currently.
|
8 |
+
|
9 |
+
The high-level pipeline is simple by design: given a lang tag, script loads the ASR model, transcribes model's predicted audio, and computes the BLEU score against provided reference translations using sacrebleu.
|
10 |
+
|
11 |
+
# Dependencies
|
12 |
+
|
13 |
+
Please see `requirements.txt`.
|
14 |
+
|
15 |
+
# Usage examples
|
16 |
+
|
17 |
+
This toolkit have been used with:
|
18 |
+
|
19 |
+
* Speechmatrix project: https://github.com/facebookresearch/fairseq/tree/ust/examples/speech_matrix.
|
20 |
+
|
21 |
+
* Hokkien speech-to-speech translation project: https://github.com/facebookresearch/fairseq/tree/ust/examples/hokkien.
|
22 |
+
|
23 |
+
# Standalone run example
|
24 |
+
|
25 |
+
High-level example, please substitute arguments per your case:
|
26 |
+
|
27 |
+
```bash
|
28 |
+
python compute_asr_bleu.py --lang <LANG> \
|
29 |
+
--audio_dirpath <PATH_TO_AUDIO_DIR> \
|
30 |
+
--reference_path <PATH_TO_REFERENCES_FILE> \
|
31 |
+
--reference_format txt
|
32 |
+
```
|
33 |
+
|
34 |
+
For more details about arguments please see the script argparser help.
|
fairseq/examples/speech_to_speech/asr_bleu/__init__.py
ADDED
File without changes
|
fairseq/examples/speech_to_speech/asr_bleu/asr_model_cfgs.json
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"en": {
|
3 |
+
"oct22": {
|
4 |
+
"desc": "Wav2Vec 2.0 Large (LV-60) + Self Training from https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec#pre-trained-models",
|
5 |
+
"ckpt_path": "https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_960h_pl.pt",
|
6 |
+
"dict_path": "https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt",
|
7 |
+
"model_type": "fairseq",
|
8 |
+
"lang": "en",
|
9 |
+
"post_process": "collapse"
|
10 |
+
}
|
11 |
+
},
|
12 |
+
"hok": {
|
13 |
+
"oct22": {
|
14 |
+
"desc": "Hokkien ASR model, for details check [TODO add paper link]",
|
15 |
+
"ckpt_path": "https://dl.fbaipublicfiles.com/ust_asr/hok/checkpoint_best.pt",
|
16 |
+
"dict_path": "https://dl.fbaipublicfiles.com/ust_asr/hok/dict.ltr.txt",
|
17 |
+
"model_type": "fairseq",
|
18 |
+
"lang": "hok",
|
19 |
+
"post_process": "none"
|
20 |
+
}
|
21 |
+
},
|
22 |
+
"es": {
|
23 |
+
"oct22": {
|
24 |
+
"model_path": "jonatasgrosman/wav2vec2-large-xlsr-53-spanish",
|
25 |
+
"model_type": "hf",
|
26 |
+
"lang": "es",
|
27 |
+
"post_process": "collapse"
|
28 |
+
}
|
29 |
+
},
|
30 |
+
"fr": {
|
31 |
+
"oct22": {
|
32 |
+
"model_path": "jonatasgrosman/wav2vec2-large-fr-voxpopuli-french",
|
33 |
+
"model_type": "hf",
|
34 |
+
"lang": "fr",
|
35 |
+
"post_process": "collapse"
|
36 |
+
}
|
37 |
+
},
|
38 |
+
"zh": {
|
39 |
+
"oct22": {
|
40 |
+
"model_path": "ydshieh/wav2vec2-large-xlsr-53-chinese-zh-cn-gpt",
|
41 |
+
"model_type": "hf",
|
42 |
+
"lang": "zh",
|
43 |
+
"post_process": "collapse"
|
44 |
+
}
|
45 |
+
},
|
46 |
+
"tr": {
|
47 |
+
"oct22": {
|
48 |
+
"model_path": "cahya/wav2vec2-large-xlsr-turkish-artificial-cv",
|
49 |
+
"model_type": "hf",
|
50 |
+
"lang": "tr",
|
51 |
+
"post_process": "collapse"
|
52 |
+
}
|
53 |
+
},
|
54 |
+
"ar": {
|
55 |
+
"oct22": {
|
56 |
+
"model_path": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
|
57 |
+
"model_type": "hf",
|
58 |
+
"lang": "ar",
|
59 |
+
"post_process": "collapse"
|
60 |
+
}
|
61 |
+
},
|
62 |
+
"vi": {
|
63 |
+
"oct22": {
|
64 |
+
"model_path": "not-tanh/wav2vec2-large-xlsr-53-vietnamese",
|
65 |
+
"model_type": "hf",
|
66 |
+
"lang": "vi",
|
67 |
+
"post_process": "collapse"
|
68 |
+
}
|
69 |
+
},
|
70 |
+
"de": {
|
71 |
+
"oct22": {
|
72 |
+
"model_path": "jonatasgrosman/wav2vec2-xls-r-1b-german",
|
73 |
+
"model_type": "hf",
|
74 |
+
"lang": "de",
|
75 |
+
"post_process": "collapse"
|
76 |
+
}
|
77 |
+
},
|
78 |
+
"pl": {
|
79 |
+
"oct22": {
|
80 |
+
"model_path": "jonatasgrosman/wav2vec2-xls-r-1b-polish",
|
81 |
+
"model_type": "hf",
|
82 |
+
"lang": "pl",
|
83 |
+
"post_process": "collapse"
|
84 |
+
}
|
85 |
+
},
|
86 |
+
"it": {
|
87 |
+
"oct22": {
|
88 |
+
"model_path": "jonatasgrosman/wav2vec2-large-xlsr-53-italian",
|
89 |
+
"model_type": "hf",
|
90 |
+
"lang": "it",
|
91 |
+
"post_process": "collapse"
|
92 |
+
}
|
93 |
+
},
|
94 |
+
"pt": {
|
95 |
+
"oct22": {
|
96 |
+
"model_path": "jonatasgrosman/wav2vec2-xls-r-1b-portuguese",
|
97 |
+
"model_type": "hf",
|
98 |
+
"lang": "pt",
|
99 |
+
"post_process": "collapse"
|
100 |
+
}
|
101 |
+
},
|
102 |
+
"ro": {
|
103 |
+
"oct22": {
|
104 |
+
"model_path": "gigant/romanian-wav2vec2",
|
105 |
+
"model_type": "hf",
|
106 |
+
"lang": "ro",
|
107 |
+
"post_process": "collapse"
|
108 |
+
}
|
109 |
+
},
|
110 |
+
"cs": {
|
111 |
+
"oct22": {
|
112 |
+
"model_path": "comodoro/wav2vec2-xls-r-300m-cs-250",
|
113 |
+
"model_type": "hf",
|
114 |
+
"lang": "cs",
|
115 |
+
"post_process": "collapse"
|
116 |
+
}
|
117 |
+
},
|
118 |
+
"sk": {
|
119 |
+
"oct22": {
|
120 |
+
"model_path": "anuragshas/wav2vec2-xls-r-300m-sk-cv8-with-lm",
|
121 |
+
"model_type": "hf",
|
122 |
+
"lang": "sk",
|
123 |
+
"post_process": "collapse"
|
124 |
+
}
|
125 |
+
},
|
126 |
+
"sl": {
|
127 |
+
"oct22": {
|
128 |
+
"model_path": "anuragshas/wav2vec2-xls-r-300m-sl-cv8-with-lm",
|
129 |
+
"model_type": "hf",
|
130 |
+
"lang": "sl",
|
131 |
+
"post_process": "collapse"
|
132 |
+
}
|
133 |
+
},
|
134 |
+
"fi": {
|
135 |
+
"oct22": {
|
136 |
+
"model_path": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish",
|
137 |
+
"model_type": "hf",
|
138 |
+
"lang": "fi",
|
139 |
+
"post_process": "collapse"
|
140 |
+
}
|
141 |
+
},
|
142 |
+
"hu": {
|
143 |
+
"oct22": {
|
144 |
+
"model_path": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
|
145 |
+
"model_type": "hf",
|
146 |
+
"lang": "hu",
|
147 |
+
"post_process": "collapse"
|
148 |
+
}
|
149 |
+
},
|
150 |
+
"et": {
|
151 |
+
"oct22": {
|
152 |
+
"model_path": "RASMUS/wav2vec2-xlsr-1b-et",
|
153 |
+
"model_type": "hf",
|
154 |
+
"lang": "et",
|
155 |
+
"post_process": "collapse"
|
156 |
+
}
|
157 |
+
},
|
158 |
+
"lt": {
|
159 |
+
"oct22": {
|
160 |
+
"model_path": "sammy786/wav2vec2-xlsr-lithuanian",
|
161 |
+
"model_type": "hf",
|
162 |
+
"lang": "lt",
|
163 |
+
"post_process": "collapse"
|
164 |
+
}
|
165 |
+
},
|
166 |
+
"nl": {
|
167 |
+
"oct22": {
|
168 |
+
"model_path": "jonatasgrosman/wav2vec2-xls-r-1b-dutch",
|
169 |
+
"model_type": "hf",
|
170 |
+
"lang": "nl",
|
171 |
+
"post_process": "collapse"
|
172 |
+
}
|
173 |
+
},
|
174 |
+
"lv": {
|
175 |
+
"oct22": {
|
176 |
+
"model_path": "reach-vb/wav2vec2-large-xls-r-1B-common_voice7-lv-ft",
|
177 |
+
"model_type": "hf",
|
178 |
+
"lang": "lv",
|
179 |
+
"post_process": "collapse"
|
180 |
+
}
|
181 |
+
},
|
182 |
+
"sv": {
|
183 |
+
"oct22": {
|
184 |
+
"model_path": "marinone94/xls-r-300m-sv-robust",
|
185 |
+
"model_type": "hf",
|
186 |
+
"lang": "sv",
|
187 |
+
"post_process": "collapse"
|
188 |
+
}
|
189 |
+
},
|
190 |
+
"hr": {
|
191 |
+
"oct22": {
|
192 |
+
"model_path": "classla/wav2vec2-xls-r-parlaspeech-hr",
|
193 |
+
"model_type": "hf",
|
194 |
+
"lang": "hr",
|
195 |
+
"post_process": "collapse"
|
196 |
+
}
|
197 |
+
}
|
198 |
+
}
|
fairseq/examples/speech_to_speech/asr_bleu/compute_asr_bleu.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Dict, List
|
3 |
+
import sacrebleu
|
4 |
+
import pandas as pd
|
5 |
+
from glob import glob
|
6 |
+
from pathlib import Path
|
7 |
+
from utils import retrieve_asr_config, ASRGenerator
|
8 |
+
from tqdm import tqdm
|
9 |
+
from argparse import ArgumentParser
|
10 |
+
|
11 |
+
|
12 |
+
def merge_tailo_init_final(text):
|
13 |
+
"""
|
14 |
+
Hokkien ASR hypothesis post-processing.
|
15 |
+
"""
|
16 |
+
sps = text.strip().split()
|
17 |
+
results = []
|
18 |
+
last_syllable = ""
|
19 |
+
for sp in sps:
|
20 |
+
if sp == "NULLINIT" or sp == "nullinit":
|
21 |
+
continue
|
22 |
+
last_syllable += sp
|
23 |
+
if sp[-1].isnumeric():
|
24 |
+
results.append(last_syllable)
|
25 |
+
last_syllable = ""
|
26 |
+
if last_syllable != "":
|
27 |
+
results.append(last_syllable)
|
28 |
+
return " ".join(results)
|
29 |
+
|
30 |
+
|
31 |
+
def remove_tone(text):
|
32 |
+
"""
|
33 |
+
Used for tone-less evaluation of Hokkien
|
34 |
+
"""
|
35 |
+
return " ".join([t[:-1] for t in text.split()])
|
36 |
+
|
37 |
+
|
38 |
+
def extract_audio_for_eval(audio_dirpath: str, audio_format: str):
|
39 |
+
if audio_format == "n_pred.wav":
|
40 |
+
"""
|
41 |
+
The assumption here is that 0_pred.wav corresponds to the reference at line position 0 from the reference manifest
|
42 |
+
"""
|
43 |
+
audio_list = []
|
44 |
+
audio_fp_list = glob((Path(audio_dirpath) / "*_pred.wav").as_posix())
|
45 |
+
audio_fp_list = sorted(
|
46 |
+
audio_fp_list, key=lambda x: int(os.path.basename(x).split("_")[0])
|
47 |
+
)
|
48 |
+
for i in range(len(audio_fp_list)):
|
49 |
+
try:
|
50 |
+
audio_fp = (Path(audio_dirpath) / f"{i}_pred.wav").as_posix()
|
51 |
+
assert (
|
52 |
+
audio_fp in audio_fp_list
|
53 |
+
), f"{Path(audio_fp).name} does not exist in {audio_dirpath}"
|
54 |
+
except AssertionError:
|
55 |
+
# check the audio with random speaker
|
56 |
+
audio_fp = Path(audio_dirpath) / f"{i}_spk*_pred.wav"
|
57 |
+
audio_fp = glob(
|
58 |
+
audio_fp.as_posix()
|
59 |
+
) # resolve audio filepath with random speaker
|
60 |
+
assert len(audio_fp) == 1
|
61 |
+
audio_fp = audio_fp[0]
|
62 |
+
|
63 |
+
audio_list.append(audio_fp)
|
64 |
+
else:
|
65 |
+
raise NotImplementedError
|
66 |
+
|
67 |
+
return audio_list
|
68 |
+
|
69 |
+
|
70 |
+
def extract_text_for_eval(
|
71 |
+
references_filepath: str, reference_format: str, reference_tsv_column: str = None
|
72 |
+
):
|
73 |
+
if reference_format == "txt":
|
74 |
+
reference_sentences = open(references_filepath, "r").readlines()
|
75 |
+
reference_sentences = [l.strip() for l in reference_sentences]
|
76 |
+
elif reference_format == "tsv":
|
77 |
+
tsv_df = pd.read_csv(references_filepath, sep="\t", quoting=3)
|
78 |
+
reference_sentences = tsv_df[reference_tsv_column].to_list()
|
79 |
+
reference_sentences = [l.strip() for l in reference_sentences]
|
80 |
+
else:
|
81 |
+
raise NotImplementedError
|
82 |
+
|
83 |
+
return reference_sentences
|
84 |
+
|
85 |
+
|
86 |
+
def compose_eval_data(
|
87 |
+
audio_dirpath: str,
|
88 |
+
audio_format: str,
|
89 |
+
references_filepath: str,
|
90 |
+
reference_format: str,
|
91 |
+
reference_tsv_column: str = None,
|
92 |
+
save_manifest_filepath=None,
|
93 |
+
):
|
94 |
+
"""
|
95 |
+
Speech matrix decoding pipeline produces audio with the following mask "N_pred.wav" where N is the order of the corresponding input sample
|
96 |
+
"""
|
97 |
+
|
98 |
+
reference_sentences = extract_text_for_eval(
|
99 |
+
references_filepath, reference_format, reference_tsv_column
|
100 |
+
)
|
101 |
+
predicted_audio_fp_list = extract_audio_for_eval(audio_dirpath, audio_format)
|
102 |
+
assert len(predicted_audio_fp_list) == len(reference_sentences)
|
103 |
+
|
104 |
+
audio_text_pairs = [
|
105 |
+
(audio, reference)
|
106 |
+
for audio, reference in zip(predicted_audio_fp_list, reference_sentences)
|
107 |
+
]
|
108 |
+
|
109 |
+
tsv_manifest = pd.DataFrame(audio_text_pairs, columns=["prediction", "reference"])
|
110 |
+
|
111 |
+
if save_manifest_filepath is not None:
|
112 |
+
tsv_manifest.to_csv(save_manifest_filepath, sep="\t", quoting=3)
|
113 |
+
|
114 |
+
return tsv_manifest
|
115 |
+
|
116 |
+
|
117 |
+
def load_eval_data_from_tsv(eval_data_filepath: str):
|
118 |
+
"""
|
119 |
+
We may load the result of `compose_eval_data` directly if needed
|
120 |
+
"""
|
121 |
+
eval_df = pd.from_csv(eval_data_filepath, sep="\t")
|
122 |
+
|
123 |
+
return eval_df
|
124 |
+
|
125 |
+
|
126 |
+
def run_asr_bleu(args):
|
127 |
+
|
128 |
+
asr_config = retrieve_asr_config(
|
129 |
+
args.lang, args.asr_version, json_path="./asr_model_cfgs.json"
|
130 |
+
)
|
131 |
+
asr_model = ASRGenerator(asr_config)
|
132 |
+
|
133 |
+
eval_manifest = compose_eval_data(
|
134 |
+
audio_dirpath=args.audio_dirpath,
|
135 |
+
audio_format=args.audio_format,
|
136 |
+
references_filepath=args.reference_path,
|
137 |
+
reference_format=args.reference_format,
|
138 |
+
reference_tsv_column=args.reference_tsv_column,
|
139 |
+
save_manifest_filepath=None,
|
140 |
+
)
|
141 |
+
|
142 |
+
prediction_transcripts = []
|
143 |
+
for _, eval_pair in tqdm(
|
144 |
+
eval_manifest.iterrows(),
|
145 |
+
desc="Transcribing predictions",
|
146 |
+
total=len(eval_manifest),
|
147 |
+
):
|
148 |
+
transcription = asr_model.transcribe_audiofile(eval_pair.prediction)
|
149 |
+
prediction_transcripts.append(transcription.lower())
|
150 |
+
|
151 |
+
if args.lang == "hok":
|
152 |
+
prediction_transcripts = [
|
153 |
+
merge_tailo_init_final(text) for text in prediction_transcripts
|
154 |
+
]
|
155 |
+
|
156 |
+
references = eval_manifest["reference"].tolist()
|
157 |
+
bleu_score = sacrebleu.corpus_bleu(prediction_transcripts, [references])
|
158 |
+
|
159 |
+
print(bleu_score)
|
160 |
+
|
161 |
+
return prediction_transcripts, bleu_score
|
162 |
+
|
163 |
+
|
164 |
+
def main():
|
165 |
+
parser = ArgumentParser(
|
166 |
+
description="This script computes the ASR-BLEU metric between model's generated audio and the text reference sequences."
|
167 |
+
)
|
168 |
+
|
169 |
+
parser.add_argument(
|
170 |
+
"--lang",
|
171 |
+
help="The target language used to initialize ASR model, see asr_model_cfgs.json for available languages",
|
172 |
+
type=str,
|
173 |
+
)
|
174 |
+
parser.add_argument(
|
175 |
+
"--asr_version",
|
176 |
+
type=str,
|
177 |
+
default="oct22",
|
178 |
+
help="For future support we add and extra layer of asr versions. The current most recent version is oct22 meaning October 2022",
|
179 |
+
)
|
180 |
+
parser.add_argument(
|
181 |
+
"--audio_dirpath",
|
182 |
+
type=str,
|
183 |
+
help="Path to the directory containing the audio predictions from the translation model",
|
184 |
+
)
|
185 |
+
parser.add_argument(
|
186 |
+
"--reference_path",
|
187 |
+
type=str,
|
188 |
+
help="Path to the file containing reference translations in the form of normalized text (to be compared to ASR predictions",
|
189 |
+
)
|
190 |
+
parser.add_argument(
|
191 |
+
"--reference_format",
|
192 |
+
choices=["txt", "tsv"],
|
193 |
+
help="Format of reference file. Txt means plain text format where each line represents single reference sequence",
|
194 |
+
)
|
195 |
+
parser.add_argument(
|
196 |
+
"--reference_tsv_column",
|
197 |
+
default=None,
|
198 |
+
type=str,
|
199 |
+
help="If format is tsv, then specify the column name which contains reference sequence",
|
200 |
+
)
|
201 |
+
parser.add_argument(
|
202 |
+
"--audio_format",
|
203 |
+
default="n_pred.wav",
|
204 |
+
choices=["n_pred.wav"],
|
205 |
+
help="Audio format n_pred.wav corresponds to names like 94_pred.wav or 94_spk7_pred.wav where spk7 is the speaker id",
|
206 |
+
)
|
207 |
+
parser.add_argument(
|
208 |
+
"--results_dirpath",
|
209 |
+
default=None,
|
210 |
+
type=str,
|
211 |
+
help="If specified, the resulting BLEU score will be written to this file path as txt file",
|
212 |
+
)
|
213 |
+
parser.add_argument(
|
214 |
+
"--transcripts_path",
|
215 |
+
default=None,
|
216 |
+
type=str,
|
217 |
+
help="If specified, the predicted transcripts will be written to this path as a txt file.",
|
218 |
+
)
|
219 |
+
|
220 |
+
args = parser.parse_args()
|
221 |
+
|
222 |
+
prediction_transcripts, bleu_score = run_asr_bleu(args)
|
223 |
+
result_filename = f"{args.reference_format}_{args.lang}_bleu.txt"
|
224 |
+
if args.results_dirpath is not None:
|
225 |
+
if not Path(args.results_dirpath).exists():
|
226 |
+
Path(args.results_dirpath).mkdir(parents=True)
|
227 |
+
with open(Path(args.results_dirpath) / result_filename, "w") as f:
|
228 |
+
f.write(bleu_score.format(width=2))
|
229 |
+
|
230 |
+
if args.transcripts_path is not None:
|
231 |
+
with open(args.transcripts_path, "w") as f:
|
232 |
+
for transcript in prediction_transcripts:
|
233 |
+
f.write(transcript + "\n")
|
234 |
+
|
235 |
+
|
236 |
+
if __name__ == "__main__":
|
237 |
+
main()
|
238 |
+
|
239 |
+
|
240 |
+
"""
|
241 |
+
Example to load Sl audio and references, compute BLEU:
|
242 |
+
|
243 |
+
export lang=fi; split=vp && python compute_asr_bleu.py --lang $lang --audio_dirpath /checkpoint/hygong/S2S/speech_matrix_release_ckpts/generated_waveform_release/en-$lang/test_$split/checkpoint.pt --audio_format n_pred.wav --reference_path /large_experiments/ust/hygong/S2S/SpeechEncoder/manifests/vp-vp/en-$lang/test_$split.$lang --reference_format txt --results_dirpath ./
|
244 |
+
"""
|
fairseq/examples/speech_to_speech/asr_bleu/requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
fairseq==0.12.2
|
2 |
+
pandas==1.4.3
|
3 |
+
sacrebleu==2.2.0
|
4 |
+
torch==1.12.1
|
5 |
+
torchaudio==0.12.1
|
6 |
+
tqdm==4.64.0
|
7 |
+
transformers==4.21.1
|
fairseq/examples/speech_to_speech/asr_bleu/utils.py
ADDED
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import re
|
3 |
+
import urllib.request
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import fairseq
|
7 |
+
import torch
|
8 |
+
from fairseq.data.data_utils import lengths_to_padding_mask
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
try:
|
12 |
+
import torchaudio
|
13 |
+
from torchaudio.models.decoder import ctc_decoder
|
14 |
+
except ImportError:
|
15 |
+
raise ImportError("Upgrade torchaudio to 0.12 to enable CTC decoding")
|
16 |
+
|
17 |
+
|
18 |
+
class DownloadProgressBar(tqdm):
|
19 |
+
"""A class to represent a download progress bar"""
|
20 |
+
|
21 |
+
def update_to(self, b=1, bsize=1, tsize=None) -> None:
|
22 |
+
"""
|
23 |
+
Update the download progress
|
24 |
+
"""
|
25 |
+
if tsize is not None:
|
26 |
+
self.total = tsize
|
27 |
+
self.update(b * bsize - self.n)
|
28 |
+
|
29 |
+
|
30 |
+
def retrieve_asr_config(lang_key: str, asr_version: str, json_path: str) -> dict:
|
31 |
+
"""
|
32 |
+
Retrieve the asr model configs
|
33 |
+
|
34 |
+
Args:
|
35 |
+
lang_key: the lanuage type as the key name
|
36 |
+
json_path: the path of the config json file
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
Dict of all the configs in the json file
|
40 |
+
"""
|
41 |
+
|
42 |
+
with open(json_path, "r") as f:
|
43 |
+
asr_model_cfgs = json.load(f)
|
44 |
+
return asr_model_cfgs[lang_key][asr_version]
|
45 |
+
|
46 |
+
|
47 |
+
class ASRGenerator(object):
|
48 |
+
"""A class to represent a ASR generator"""
|
49 |
+
|
50 |
+
def __init__(
|
51 |
+
self,
|
52 |
+
model_cfg: dict,
|
53 |
+
cache_dirpath: str = (Path.home() / ".cache" / "ust_asr").as_posix(),
|
54 |
+
) -> None:
|
55 |
+
"""
|
56 |
+
Construct all the necessary attributes of the ASRGenerator class
|
57 |
+
|
58 |
+
Args:
|
59 |
+
model_cfg: the dict of the asr model config
|
60 |
+
cache_dirpath: the default cache path is "Path.home()/.cache/ust_asr"
|
61 |
+
"""
|
62 |
+
|
63 |
+
self.cache_dirpath = Path(cache_dirpath) / model_cfg["lang"]
|
64 |
+
self.model_cfg = model_cfg
|
65 |
+
|
66 |
+
self.use_cuda = torch.cuda.is_available()
|
67 |
+
|
68 |
+
torchaudio.set_audio_backend("sox_io")
|
69 |
+
|
70 |
+
if self.model_cfg["model_type"] == "hf":
|
71 |
+
self.prepare_hf_model(self.model_cfg)
|
72 |
+
elif self.model_cfg["model_type"] == "fairseq":
|
73 |
+
self.prepare_fairseq_model(self.model_cfg)
|
74 |
+
else:
|
75 |
+
raise NotImplementedError(
|
76 |
+
f"Model type {self.model_cfg['model_type']} is not supported"
|
77 |
+
)
|
78 |
+
|
79 |
+
if self.model_cfg["post_process"] == "collapse":
|
80 |
+
self.post_process_fn = lambda hypo: "".join(hypo).replace(
|
81 |
+
self.sil_token, " "
|
82 |
+
)
|
83 |
+
elif self.model_cfg["post_process"] == "none":
|
84 |
+
self.post_process_fn = lambda hypo: " ".join(hypo).replace(
|
85 |
+
self.sil_token, " "
|
86 |
+
)
|
87 |
+
else:
|
88 |
+
raise NotImplementedError
|
89 |
+
|
90 |
+
if self.use_cuda:
|
91 |
+
self.model.cuda()
|
92 |
+
self.model.eval()
|
93 |
+
|
94 |
+
self.decoder = ctc_decoder(
|
95 |
+
lexicon=None,
|
96 |
+
tokens=self.tokens,
|
97 |
+
lm=None,
|
98 |
+
nbest=1,
|
99 |
+
beam_size=1,
|
100 |
+
beam_size_token=None,
|
101 |
+
lm_weight=0.0,
|
102 |
+
word_score=0.0,
|
103 |
+
unk_score=float("-inf"),
|
104 |
+
sil_token=self.sil_token,
|
105 |
+
sil_score=0.0,
|
106 |
+
log_add=False,
|
107 |
+
blank_token=self.blank_token,
|
108 |
+
)
|
109 |
+
|
110 |
+
def prepare_hf_model(self, model_cfg: dict) -> None:
|
111 |
+
"""
|
112 |
+
Prepare the huggingface asr model
|
113 |
+
|
114 |
+
Args:
|
115 |
+
model_cfg: dict with the relevant ASR config
|
116 |
+
"""
|
117 |
+
|
118 |
+
def infer_silence_token(vocab: list):
|
119 |
+
"""
|
120 |
+
Different HF checkpoints have different notion of silence token
|
121 |
+
such as | or " " (space)
|
122 |
+
Important: when adding new HF asr model in, check what silence token it uses
|
123 |
+
"""
|
124 |
+
if "|" in vocab:
|
125 |
+
return "|"
|
126 |
+
elif " " in vocab:
|
127 |
+
return " "
|
128 |
+
else:
|
129 |
+
raise RuntimeError("Silence token is not found in the vocabulary")
|
130 |
+
|
131 |
+
try:
|
132 |
+
from transformers import (AutoFeatureExtractor, AutoTokenizer,
|
133 |
+
Wav2Vec2ForCTC, Wav2Vec2Processor)
|
134 |
+
except ImportError:
|
135 |
+
raise ImportError("Install transformers to load HF wav2vec model")
|
136 |
+
|
137 |
+
model_path = model_cfg["model_path"]
|
138 |
+
self.model = Wav2Vec2ForCTC.from_pretrained(model_path)
|
139 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
140 |
+
self.preprocessor = AutoFeatureExtractor.from_pretrained(model_path)
|
141 |
+
self.processor = Wav2Vec2Processor.from_pretrained(model_path)
|
142 |
+
|
143 |
+
# extra unk tokens are there to make some models work e.g. Finnish ASR has some vocab issue
|
144 |
+
vocab_list = [
|
145 |
+
self.tokenizer.decoder.get(i, f"{self.tokenizer.unk_token}1")
|
146 |
+
for i in range(self.tokenizer.vocab_size)
|
147 |
+
]
|
148 |
+
|
149 |
+
self.sampling_rate = self.preprocessor.sampling_rate
|
150 |
+
self.normalize_input = self.preprocessor.do_normalize
|
151 |
+
self.tokens = vocab_list
|
152 |
+
self.sil_token = infer_silence_token(vocab_list)
|
153 |
+
self.blank_token = self.tokenizer.pad_token
|
154 |
+
|
155 |
+
def prepare_fairseq_model(self, model_cfg: dict) -> None:
|
156 |
+
"""
|
157 |
+
Prepare the fairseq asr model
|
158 |
+
|
159 |
+
Args:
|
160 |
+
model_cfg: the specific model config dict must have: (1) ckpt_path, (2) dict_path
|
161 |
+
"""
|
162 |
+
|
163 |
+
def download_file(url: str, cache_dir: Path):
|
164 |
+
download_path = cache_dir / url.split("/")[-1]
|
165 |
+
if not (cache_dir / url.split("/")[-1]).exists():
|
166 |
+
with DownloadProgressBar(
|
167 |
+
unit="B", unit_scale=True, miniters=1, desc=url.split("/")[-1]
|
168 |
+
) as t:
|
169 |
+
cache_dir.mkdir(parents=True, exist_ok=True)
|
170 |
+
urllib.request.urlretrieve(
|
171 |
+
url, filename=download_path.as_posix(), reporthook=t.update_to
|
172 |
+
)
|
173 |
+
else:
|
174 |
+
print(f"'{url}' exists in {cache_dir}")
|
175 |
+
|
176 |
+
return download_path.as_posix()
|
177 |
+
|
178 |
+
try:
|
179 |
+
ckpt_path = model_cfg["ckpt_path"]
|
180 |
+
dict_path = model_cfg["dict_path"]
|
181 |
+
except KeyError:
|
182 |
+
raise KeyError(
|
183 |
+
"Fairseq model cfg must provide (1) ckpt_path, (2) dict_path"
|
184 |
+
)
|
185 |
+
|
186 |
+
if re.search("^https", ckpt_path):
|
187 |
+
ckpt_path = download_file(ckpt_path, self.cache_dirpath)
|
188 |
+
if re.search("^https", dict_path):
|
189 |
+
dict_path = download_file(dict_path, self.cache_dirpath)
|
190 |
+
|
191 |
+
model, saved_cfg, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
192 |
+
[ckpt_path],
|
193 |
+
arg_overrides={
|
194 |
+
"task": "audio_finetuning",
|
195 |
+
"data": self.cache_dirpath.as_posix(),
|
196 |
+
}, # data must have dict in it
|
197 |
+
)
|
198 |
+
|
199 |
+
dict_lines = open(dict_path, "r").readlines()
|
200 |
+
tokens = [l.split()[0] for l in dict_lines]
|
201 |
+
# adding default fairseq special tokens
|
202 |
+
tokens = ["<s>", "<pad>", "</s>", "<unk>"] + tokens
|
203 |
+
|
204 |
+
self.model = model[0]
|
205 |
+
self.tokens = tokens
|
206 |
+
|
207 |
+
if "|" in tokens:
|
208 |
+
self.sil_token = "|"
|
209 |
+
else:
|
210 |
+
self.sil_token = tokens[
|
211 |
+
2
|
212 |
+
] # use eos as silence token if | not presented e.g., Hok ASR model
|
213 |
+
print(f"Inferring silence token from the dict: {self.sil_token}")
|
214 |
+
self.blank_token = self.tokens[0]
|
215 |
+
|
216 |
+
self.sampling_rate = saved_cfg.task.sample_rate
|
217 |
+
self.normalize_input = saved_cfg.task.normalize
|
218 |
+
|
219 |
+
@torch.inference_mode()
|
220 |
+
def load_audiofile(self, audio_path: str) -> torch.Tensor:
|
221 |
+
"""
|
222 |
+
Load the audio files and apply resampling and normalizaion
|
223 |
+
|
224 |
+
Args:
|
225 |
+
audio_path: the audio file path
|
226 |
+
|
227 |
+
Returns:
|
228 |
+
audio_waveform: the audio waveform as a torch.Tensor object
|
229 |
+
"""
|
230 |
+
|
231 |
+
audio_waveform, sampling_rate = torchaudio.load(audio_path)
|
232 |
+
if audio_waveform.dim == 2:
|
233 |
+
audio_waveform = audio_waveform.mean(-1)
|
234 |
+
if self.sampling_rate != sampling_rate:
|
235 |
+
audio_waveform = torchaudio.functional.resample(
|
236 |
+
audio_waveform, sampling_rate, self.sampling_rate
|
237 |
+
)
|
238 |
+
if self.normalize_input:
|
239 |
+
# following fairseq raw audio dataset
|
240 |
+
audio_waveform = torch.nn.functional.layer_norm(
|
241 |
+
audio_waveform, audio_waveform.shape
|
242 |
+
)
|
243 |
+
|
244 |
+
return audio_waveform
|
245 |
+
|
246 |
+
@torch.inference_mode()
|
247 |
+
def compute_emissions(self, audio_input: torch.Tensor) -> torch.Tensor:
|
248 |
+
"""
|
249 |
+
Compute the emissions for either fairseq or huggingface asr model
|
250 |
+
|
251 |
+
Args:
|
252 |
+
audio_path: the input audio waveform
|
253 |
+
|
254 |
+
Returns:
|
255 |
+
emissions: the logits of the encoded prediction.
|
256 |
+
"""
|
257 |
+
|
258 |
+
if self.use_cuda:
|
259 |
+
audio_input = audio_input.to("cuda")
|
260 |
+
if isinstance(self.model, fairseq.models.wav2vec.wav2vec2_asr.Wav2VecCtc):
|
261 |
+
padding_mask = lengths_to_padding_mask(torch.tensor([audio_input.numel()]))
|
262 |
+
emissions = self.model.w2v_encoder(audio_input, padding_mask)[
|
263 |
+
"encoder_out"
|
264 |
+
].transpose(0, 1)
|
265 |
+
else:
|
266 |
+
emissions = self.model(audio_input).logits
|
267 |
+
|
268 |
+
return emissions
|
269 |
+
|
270 |
+
def decode_emissions(self, emissions: torch.Tensor) -> str:
|
271 |
+
"""
|
272 |
+
Decode the emissions and apply post process functions
|
273 |
+
|
274 |
+
Args:
|
275 |
+
emissions: the input Tensor object
|
276 |
+
|
277 |
+
Returns:
|
278 |
+
hypo: the str as the decoded transcriptions
|
279 |
+
"""
|
280 |
+
|
281 |
+
emissions = emissions.cpu()
|
282 |
+
results = self.decoder(emissions)
|
283 |
+
|
284 |
+
# assuming the lexicon-free decoder and working with tokens
|
285 |
+
hypo = self.decoder.idxs_to_tokens(results[0][0].tokens)
|
286 |
+
hypo = self.post_process_fn(hypo)
|
287 |
+
|
288 |
+
return hypo
|
289 |
+
|
290 |
+
def transcribe_audiofile(self, audio_path: str, lower=True) -> str:
|
291 |
+
"""
|
292 |
+
Transcribe the audio into string
|
293 |
+
|
294 |
+
Args:
|
295 |
+
audio_path: the input audio waveform
|
296 |
+
lower: the case of the transcriptions with lowercase as the default
|
297 |
+
|
298 |
+
Returns:
|
299 |
+
hypo: the transcription result
|
300 |
+
"""
|
301 |
+
|
302 |
+
asr_input = self.load_audiofile(audio_path)
|
303 |
+
emissions = self.compute_emissions(asr_input)
|
304 |
+
hypo = self.decode_emissions(emissions)
|
305 |
+
|
306 |
+
return hypo.strip().lower() if lower else hypo.strip()
|
fairseq/examples/speech_to_speech/benchmarking/README.md
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Benchmarking
|
2 |
+
|
3 |
+
## Overview
|
4 |
+
|
5 |
+
The goal of this framework is to support benchmarking various speech to speech translation(S2ST) models in terms of runtime, max-memory consumption and total number of floating point operations(FLOPS). It is a generic framework and can be easily extended to support any fairseq models. To accurately benchmark the performance, core inference modules are re-implemented based on fairseq_cli/generate.py (core.py/Processing) and examples/speech_to_text/generate_waveform.py(core.py/SpeechGeneration. To ensure that the end to end models and cascaded models are compared fairly, for cascaded models we only consider the performance metrics for model inference at all stages ignoring any intermediate data and io processing consumption. We run all the benchmarking runs on CPU as it is generally used in production environment and also due to lack of good benchmarking library support for GPUs.
|
6 |
+
|
7 |
+
1. Runtime: Average time in seconds to run model inference on an example from a given dataset. We use [timeit](https://docs.python.org/3/library/timeit.html) library to measure the runtime.
|
8 |
+
2. Max memory: Maximum memory in MiB averaged over by running the model inference on all examples from the given dataset. We use [memory_profiler](https://pypi.org/project/memory-profiler/) library to gather memory footprints for a code snippet and find the maximum to get the max memory used by the code. For cascaded models, we find the max of all stages to get the overall max_memory footprint.
|
9 |
+
3. FLOPS: We compute the average number of floating point operations needed to run model inference for an example from the given dataset. We use [PAPI library](http://www.bnikolic.co.uk/blog/python/flops/2019/10/01/pytorch-count-flops.html) to benchmark the number of flops.
|
10 |
+
|
11 |
+
## CLI Commands
|
12 |
+
|
13 |
+
```{python}
|
14 |
+
CUBLAS_WORKSPACE_CONFIG=:4096:8 python examples/speech_to_speech/benchmarking/get_metrics.py ‘’ --config $config
|
15 |
+
```
|
16 |
+
|
17 |
+
|
18 |
+
## Note:
|
19 |
+
|
20 |
+
1. The npy dataset is a list of samples saved as a .npy file. Each sample is a dictionary with id, net_input.
|
21 |
+
2. The raw dataset is a list of raw audio paths similar to wav2vec2 input tsv file
|
22 |
+
|
23 |
+
```{python}
|
24 |
+
sample: {
|
25 |
+
"id": xx,
|
26 |
+
"net_input": {
|
27 |
+
"src_tokens": torch.tensor([]),
|
28 |
+
"src_lengths": torch.tensor([])
|
29 |
+
}
|
30 |
+
}
|
31 |
+
```
|
fairseq/examples/speech_to_speech/benchmarking/configs/2StageS2ST.yaml
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
general:
|
2 |
+
dataset_path: $npy_dataset
|
3 |
+
cpu: True
|
4 |
+
model_type: 2StageS2ST
|
5 |
+
dataset_size: 1
|
6 |
+
|
7 |
+
stage1:
|
8 |
+
data: $data_bin_stage1
|
9 |
+
task: speech_to_text
|
10 |
+
path: $checkpoint_stage1
|
11 |
+
config_yaml: config.yaml
|
12 |
+
max_len_a: 2
|
13 |
+
max_len_b: 500
|
14 |
+
|
15 |
+
stage2:
|
16 |
+
data: $data_bin_stage2
|
17 |
+
task: text_to_speech
|
18 |
+
path: $checkpoint_stage2
|
19 |
+
config_yaml: config.yaml
|
fairseq/examples/speech_to_speech/benchmarking/configs/3StageS2ST.yaml
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
general:
|
2 |
+
dataset_path: $npy_dataset
|
3 |
+
cpu: True
|
4 |
+
model_type: 3StageS2ST
|
5 |
+
max_len_a: 2
|
6 |
+
max_len_b: 500
|
7 |
+
dataset_size: 1
|
8 |
+
|
9 |
+
stage1:
|
10 |
+
data: $data_bin_stage1
|
11 |
+
task: speech_to_text
|
12 |
+
path: $checkpoint_stage1
|
13 |
+
config_yaml: config.yaml
|
14 |
+
max_len_a: 2
|
15 |
+
max_len_b: 500
|
16 |
+
|
17 |
+
stage2:
|
18 |
+
data: $data_bin_stage2
|
19 |
+
task: translation
|
20 |
+
path: $checkpoint_stage2
|
21 |
+
config_yaml: config.yaml
|
22 |
+
|
23 |
+
|
24 |
+
stage2:
|
25 |
+
data: $data_bin_stage3
|
26 |
+
task: text_to_speech
|
27 |
+
path: $checkpoint_stage3
|
28 |
+
config_yaml: config.yaml
|
fairseq/examples/speech_to_speech/benchmarking/configs/DirectS2U.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
general:
|
2 |
+
dataset_path: $npy_dataset_path
|
3 |
+
cpu: True
|
4 |
+
model_type: S2UT
|
5 |
+
dataset_size: 5
|
6 |
+
dump_speech_waveforms_dir: $dump_waveforms_dir_path
|
7 |
+
|
8 |
+
stage1:
|
9 |
+
data: $data_bin
|
10 |
+
task: speech_to_speech
|
11 |
+
path: $checkpoint
|
12 |
+
config_yaml: config.yaml
|
13 |
+
max_len_b: 100000
|
14 |
+
beam: 10
|
15 |
+
target_is_code: True
|
16 |
+
max_target_positions: 3000
|
17 |
+
target_code_size: 100
|
18 |
+
|
19 |
+
stage2:
|
20 |
+
vocoder: $vocoder_path
|
21 |
+
vocoder_cfg: $vocoder_cfg_json
|
22 |
+
dur_prediction: True
|
fairseq/examples/speech_to_speech/benchmarking/configs/S2T.yaml
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
general:
|
2 |
+
dataset_path: $npy_dataset
|
3 |
+
cpu: True
|
4 |
+
model_type: S2T
|
5 |
+
dataset_size: 1
|
6 |
+
|
7 |
+
stage1:
|
8 |
+
data: $data_bin
|
9 |
+
task: speech_to_text
|
10 |
+
path: $checkpoint
|
11 |
+
config_yaml: config.yaml
|
12 |
+
max_len_a: 2
|
13 |
+
max_len_b: 500
|
fairseq/examples/speech_to_speech/benchmarking/core.py
ADDED
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import timeit
|
2 |
+
import logging
|
3 |
+
import torch
|
4 |
+
from pypapi import events, papi_high as high
|
5 |
+
from memory_profiler import memory_usage
|
6 |
+
from torch import nn
|
7 |
+
from argparse import Namespace
|
8 |
+
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
9 |
+
from fairseq.data import data_utils as fairseq_data_utils
|
10 |
+
from fairseq import checkpoint_utils, tasks, utils
|
11 |
+
from fairseq.models.text_to_speech.vocoder import CodeHiFiGANVocoder
|
12 |
+
from examples.hubert.simple_kmeans.dump_hubert_feature import HubertFeatureReader
|
13 |
+
from examples.hubert.simple_kmeans.dump_km_label import ApplyKmeans
|
14 |
+
from fairseq_cli.generate import get_symbols_to_strip_from_output
|
15 |
+
import soundfile as sf
|
16 |
+
import ast
|
17 |
+
import json
|
18 |
+
|
19 |
+
logging.basicConfig()
|
20 |
+
logging.root.setLevel(logging.INFO)
|
21 |
+
logging.basicConfig(level=logging.INFO)
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
torch.manual_seed(1)
|
26 |
+
torch.set_deterministic(True)
|
27 |
+
|
28 |
+
|
29 |
+
class BenchmarkingBase(nn.Module):
|
30 |
+
def __init__(self):
|
31 |
+
nn.Module.__init__(self)
|
32 |
+
self.s2x_task = None
|
33 |
+
|
34 |
+
def warm_up(self, sample, repeat):
|
35 |
+
"""Warm up the model"""
|
36 |
+
for _i in range(repeat):
|
37 |
+
self.forward(sample)
|
38 |
+
logger.info(f"Model warmed up by running inference {repeat} times")
|
39 |
+
|
40 |
+
def benchmark_run_time(self, dataset, repeat):
|
41 |
+
"""Benchmark average runtime for the model by calling benchmark_run_time_single_sample function"""
|
42 |
+
logger.info("Starting run time benchmarking")
|
43 |
+
time_elapsed = 0
|
44 |
+
for i, sample in enumerate(dataset):
|
45 |
+
time_elapsed += self.benchmark_run_time_single_sample(sample, repeat=repeat)
|
46 |
+
if i % 100 == 0:
|
47 |
+
logger.info(f"Benchmarked run time for {i}/{len(dataset)} samples")
|
48 |
+
total_time_elapsed = time_elapsed / len(dataset)
|
49 |
+
return total_time_elapsed
|
50 |
+
|
51 |
+
def benchmark_run_time_single_sample(self, sample, repeat):
|
52 |
+
"""Benchmark average runtime for a single sample using timeit library. Units are seconds"""
|
53 |
+
timer = timeit.Timer(lambda: self.forward(sample))
|
54 |
+
time_elapsed = timer.timeit(repeat)
|
55 |
+
return time_elapsed / repeat
|
56 |
+
|
57 |
+
def count_flops(
|
58 |
+
self,
|
59 |
+
dataset,
|
60 |
+
repeat,
|
61 |
+
):
|
62 |
+
"""Use PYPAPI library to count average flops for model inference.
|
63 |
+
Note: It only works if the model is being run on cpu"""
|
64 |
+
logger.info("Starting flop counter")
|
65 |
+
high.start_counters([events.PAPI_DP_OPS])
|
66 |
+
for i, sample in enumerate(dataset):
|
67 |
+
for _r in range(repeat):
|
68 |
+
self.forward(sample)
|
69 |
+
if i % 100 == 0:
|
70 |
+
logger.info(f"Counted flops for {i}/{len(dataset)} samples")
|
71 |
+
flops = high.stop_counters()
|
72 |
+
flops = round(flops[0] / (repeat * len(dataset)))
|
73 |
+
return flops
|
74 |
+
|
75 |
+
def max_memory(self, dataset, repeat):
|
76 |
+
"""Compute average max memory consumed by model inference. Units are MiB"""
|
77 |
+
logger.info("Starting memory benchmarking")
|
78 |
+
total_memory = 0
|
79 |
+
for i, sample in enumerate(dataset):
|
80 |
+
for _r in range(repeat):
|
81 |
+
total_memory += max(memory_usage((self.forward, (sample,), {})))
|
82 |
+
if i % 100 == 0:
|
83 |
+
logger.info(f"Benchmarked memory for {i}/{len(dataset)} samples")
|
84 |
+
total_memory = total_memory / (repeat * len(dataset))
|
85 |
+
return total_memory
|
86 |
+
|
87 |
+
def gather_all_metrics(self, dataset, repeat):
|
88 |
+
run_time = self.benchmark_run_time(dataset, repeat)
|
89 |
+
max_memory = self.max_memory(dataset, repeat)
|
90 |
+
flops = self.count_flops(dataset, repeat)
|
91 |
+
|
92 |
+
return run_time, max_memory, flops
|
93 |
+
|
94 |
+
def dump_final_speech_output(
|
95 |
+
self, dataset, output_dir, resample_fn, sample_rate, prefix=None
|
96 |
+
):
|
97 |
+
|
98 |
+
for i, sample in enumerate(dataset):
|
99 |
+
hypo = self.forward(sample)[0]
|
100 |
+
|
101 |
+
def to_np(x):
|
102 |
+
return x.detach().cpu().numpy()
|
103 |
+
|
104 |
+
try:
|
105 |
+
wave_preds = to_np(resample_fn(hypo["waveform"]))
|
106 |
+
sf.write(
|
107 |
+
f"{output_dir}/{prefix}_{i}_pred.wav",
|
108 |
+
wave_preds,
|
109 |
+
sample_rate,
|
110 |
+
)
|
111 |
+
except Exception as e:
|
112 |
+
raise Exception(
|
113 |
+
f" Encountered {e} - Invalid waveform. Make sure the model outputs a waveform"
|
114 |
+
)
|
115 |
+
|
116 |
+
|
117 |
+
class Processing(BenchmarkingBase):
|
118 |
+
"""Class similar to fairseq_cli/generate.py. Supports ASR, MT and ST model inference"""
|
119 |
+
|
120 |
+
def __init__(self, args):
|
121 |
+
super().__init__()
|
122 |
+
self.use_cuda = not getattr(args, "cpu", False)
|
123 |
+
self.setUp(args)
|
124 |
+
self.training = False
|
125 |
+
self.s2x_task = self.task
|
126 |
+
|
127 |
+
def setUp(self, cfg):
|
128 |
+
if isinstance(cfg, Namespace):
|
129 |
+
cfg = convert_namespace_to_omegaconf(cfg)
|
130 |
+
|
131 |
+
self.task = tasks.setup_task(cfg.task)
|
132 |
+
self.tgt_dict = self.task.target_dictionary
|
133 |
+
|
134 |
+
# Load ensemble
|
135 |
+
logger.info("loading model(s) from {}".format(cfg.common_eval.path))
|
136 |
+
models, _ = checkpoint_utils.load_model_ensemble(
|
137 |
+
utils.split_paths(cfg.common_eval.path),
|
138 |
+
arg_overrides={},
|
139 |
+
task=self.task,
|
140 |
+
suffix=cfg.checkpoint.checkpoint_suffix,
|
141 |
+
strict=False,
|
142 |
+
num_shards=cfg.checkpoint.checkpoint_shard_count,
|
143 |
+
)
|
144 |
+
if len(models) > 1:
|
145 |
+
raise Exception("Currently loading multiple models is not supported")
|
146 |
+
self.model = models[0]
|
147 |
+
|
148 |
+
# Optimize model for generation
|
149 |
+
if cfg.common.fp16:
|
150 |
+
self.model.half()
|
151 |
+
if self.use_cuda:
|
152 |
+
self.model.cuda()
|
153 |
+
self.model.prepare_for_inference_(cfg)
|
154 |
+
|
155 |
+
self.generator = self.task.build_generator(
|
156 |
+
[self.model],
|
157 |
+
cfg.generation,
|
158 |
+
extra_gen_cls_kwargs={},
|
159 |
+
)
|
160 |
+
# Handle tokenization and BPE
|
161 |
+
self.tokenizer = self.task.build_tokenizer(cfg.tokenizer)
|
162 |
+
self.bpe = self.task.build_bpe(cfg.bpe)
|
163 |
+
self.remove_bpe = cfg.common_eval.post_process
|
164 |
+
|
165 |
+
def encode_source(self, src):
|
166 |
+
"""Method to generate source tokens from a string"""
|
167 |
+
if self.tokenizer is not None:
|
168 |
+
src = self.tokenizer.encode(src)
|
169 |
+
if self.bpe is not None:
|
170 |
+
src = self.bpe.encode(src)
|
171 |
+
src_tokens = self.task.source_dictionary.encode_line(src).long()
|
172 |
+
src_lens = src_tokens.size(0)
|
173 |
+
return {
|
174 |
+
"net_input": {
|
175 |
+
"src_tokens": src_tokens.view(1, src_lens),
|
176 |
+
"src_lengths": torch.tensor([src_lens]),
|
177 |
+
}
|
178 |
+
}
|
179 |
+
|
180 |
+
def decode_target(self, hypos):
|
181 |
+
"""Method to decode target string from tokens"""
|
182 |
+
hypo_str = self.tgt_dict.string(
|
183 |
+
hypos[0][0]["tokens"].int().cpu(),
|
184 |
+
self.remove_bpe,
|
185 |
+
get_symbols_to_strip_from_output(self.generator),
|
186 |
+
)
|
187 |
+
if self.bpe is not None:
|
188 |
+
hypo_str = self.bpe.decode(hypo_str)
|
189 |
+
if self.tokenizer is not None:
|
190 |
+
hypo_str = self.tokenizer.decode(hypo_str)
|
191 |
+
return hypo_str
|
192 |
+
|
193 |
+
def forward(self, sample):
|
194 |
+
hypos = self.task.inference_step(
|
195 |
+
self.generator,
|
196 |
+
[self.model],
|
197 |
+
sample,
|
198 |
+
prefix_tokens=None,
|
199 |
+
constraints=None,
|
200 |
+
)
|
201 |
+
return hypos
|
202 |
+
|
203 |
+
|
204 |
+
class GenerateWaveformFromCode(BenchmarkingBase):
|
205 |
+
"""Class to support waveform generation from code. Currently, vocoder only supports single speaker"""
|
206 |
+
|
207 |
+
def __init__(self, args):
|
208 |
+
super().__init__()
|
209 |
+
with open(args.vocoder_cfg) as f:
|
210 |
+
vocoder_cfg = json.load(f)
|
211 |
+
self.dur_prediction = args.dur_prediction
|
212 |
+
self.vocoder = CodeHiFiGANVocoder(args.vocoder, vocoder_cfg)
|
213 |
+
|
214 |
+
def format_units(self, input):
|
215 |
+
code = torch.LongTensor(list(map(int, input.strip().split()))).view(1, -1)
|
216 |
+
return {"code": code}
|
217 |
+
|
218 |
+
def generate_vocoder_input(self, dataset):
|
219 |
+
return [self.format_units(sample) for sample in dataset]
|
220 |
+
|
221 |
+
def forward(self, sample):
|
222 |
+
return [{"waveform": self.vocoder(sample, self.dur_prediction)}]
|
223 |
+
|
224 |
+
|
225 |
+
class HubertUnitExtractor(BenchmarkingBase):
|
226 |
+
def __init__(self, args):
|
227 |
+
self.feature_reader = HubertFeatureReader(
|
228 |
+
args.hubert_ckpt_path, args.hubert_layer
|
229 |
+
)
|
230 |
+
self.kmeans = ApplyKmeans(args.hubert_km_path)
|
231 |
+
|
232 |
+
def forward(self, sample):
|
233 |
+
with torch.no_grad():
|
234 |
+
feat = []
|
235 |
+
for start in range(0, sample.size(1), self.feature_reader.max_chunk):
|
236 |
+
x_chunk = sample[:, start : start + self.max_chunk]
|
237 |
+
feat_chunk, _ = self.feature_reader.model.extract_features(
|
238 |
+
source=x_chunk,
|
239 |
+
padding_mask=None,
|
240 |
+
mask=False,
|
241 |
+
output_layer=self.layer,
|
242 |
+
)
|
243 |
+
feat.append(feat_chunk)
|
244 |
+
torch.cat(feat, 1).squeeze(0)
|
245 |
+
return self.kmeans(feat).tolist()
|
246 |
+
|
247 |
+
|
248 |
+
class SpeechGeneration(BenchmarkingBase):
|
249 |
+
"""Class similar to examples/text_to_speech/generate_waveform.py.
|
250 |
+
Supports models with speech generation as end goal (TTS, Direct S2ST models etc)"""
|
251 |
+
|
252 |
+
def __init__(self, args):
|
253 |
+
super().__init__()
|
254 |
+
self.use_cuda = not getattr(args, "cpu", False)
|
255 |
+
self.setUp(args)
|
256 |
+
self.s2x_task = self.task
|
257 |
+
|
258 |
+
def setUp(self, args):
|
259 |
+
if args.task == "speech_to_speech":
|
260 |
+
args.normalize_waveform = False
|
261 |
+
self.task = tasks.setup_task(args)
|
262 |
+
self.pre_tokenizer = self.task.build_tokenizer(args)
|
263 |
+
self.bpe_tokenizer = self.task.build_bpe(args)
|
264 |
+
try:
|
265 |
+
self.src_dict = self.task.src_dict
|
266 |
+
except Exception:
|
267 |
+
self.src_dict = None
|
268 |
+
ensemble, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
|
269 |
+
[args.path],
|
270 |
+
arg_overrides=ast.literal_eval(args.model_overrides),
|
271 |
+
task=self.task,
|
272 |
+
strict=False,
|
273 |
+
)
|
274 |
+
self.model = ensemble[0]
|
275 |
+
if self.use_cuda:
|
276 |
+
self.model.cuda()
|
277 |
+
# criterion.cuda()
|
278 |
+
self.model.eval()
|
279 |
+
self.generator = self.task.build_generator(
|
280 |
+
[self.model],
|
281 |
+
args,
|
282 |
+
)
|
283 |
+
|
284 |
+
def processTextInput(self, text):
|
285 |
+
"""Generate source tokens from text input"""
|
286 |
+
if self.pre_tokenizer is not None:
|
287 |
+
text = self.pre_tokenizer.encode(text)
|
288 |
+
if self.bpe_tokenizer is not None:
|
289 |
+
text = self.bpe_tokenizer.encode(text)
|
290 |
+
target = self.src_dict.encode_line(
|
291 |
+
text, add_if_not_exist=False, append_eos=True
|
292 |
+
).long()
|
293 |
+
target = fairseq_data_utils.collate_tokens(
|
294 |
+
[target],
|
295 |
+
self.src_dict.pad(),
|
296 |
+
self.src_dict.eos(),
|
297 |
+
left_pad=False,
|
298 |
+
move_eos_to_beginning=False,
|
299 |
+
)
|
300 |
+
src_lengths = torch.tensor([target.size(1)], dtype=torch.long)
|
301 |
+
prev_output_tokens = None
|
302 |
+
sample = {
|
303 |
+
"net_input": {
|
304 |
+
"src_tokens": target,
|
305 |
+
"src_lengths": src_lengths,
|
306 |
+
"prev_output_tokens": prev_output_tokens,
|
307 |
+
}
|
308 |
+
}
|
309 |
+
sample = utils.move_to_cuda(sample) if self.use_cuda else sample
|
310 |
+
return sample
|
311 |
+
|
312 |
+
def forward(self, sample):
|
313 |
+
sample["speaker"] = None
|
314 |
+
output = self.generator.generate(self.model, sample) # , has_targ=False
|
315 |
+
return output
|
316 |
+
|
317 |
+
|
318 |
+
class S2UT(BenchmarkingBase):
|
319 |
+
"""Class to support S2UT models. Also supports generating waveforms from the units predicted"""
|
320 |
+
|
321 |
+
def __init__(self, s2u_args, vocoder_args=None):
|
322 |
+
super().__init__()
|
323 |
+
self.s2u = Processing(s2u_args)
|
324 |
+
self.vocoder = None
|
325 |
+
if vocoder_args:
|
326 |
+
self.vocoder = GenerateWaveformFromCode(vocoder_args)
|
327 |
+
self.vocoder_input = None
|
328 |
+
|
329 |
+
def forward(self, sample):
|
330 |
+
s2u_hypos = self.s2u(sample)
|
331 |
+
s2u_output = self.s2u.decode_target(s2u_hypos)
|
332 |
+
if not self.vocoder:
|
333 |
+
return s2u_output
|
334 |
+
units = self.vocoder.format_units(s2u_output)
|
335 |
+
vocoder_output = self.vocoder(units)
|
336 |
+
return vocoder_output
|
337 |
+
|
338 |
+
def generate_s2u_outputs(self, dataset):
|
339 |
+
return [self.s2u.decode_target(self.s2u(sample)) for sample in dataset]
|
340 |
+
|
341 |
+
def compute_metrics(self, metric_type, dataset, repeat=None):
|
342 |
+
"""Generic function to compute metrics ignoring the io processing time"""
|
343 |
+
if self.vocoder and not self.vocoder_input:
|
344 |
+
self.s2u_output = self.generate_s2u_outputs(dataset)
|
345 |
+
self.vocoder_input = self.vocoder.generate_vocoder_input(self.s2u_output)
|
346 |
+
|
347 |
+
s2u_metrics = getattr(self.s2u, metric_type)(
|
348 |
+
dataset,
|
349 |
+
repeat,
|
350 |
+
)
|
351 |
+
vocoder_metrics = 0
|
352 |
+
if self.vocoder:
|
353 |
+
vocoder_metrics = getattr(self.vocoder, metric_type)(
|
354 |
+
self.vocoder_input,
|
355 |
+
repeat,
|
356 |
+
)
|
357 |
+
print(
|
358 |
+
f"metric_type = {metric_type} s2u_metrics = {s2u_metrics} \t vocoder_metrics = {vocoder_metrics}"
|
359 |
+
)
|
360 |
+
if metric_type == "max_memory":
|
361 |
+
return max(s2u_metrics, vocoder_metrics)
|
362 |
+
else:
|
363 |
+
return s2u_metrics + vocoder_metrics
|
364 |
+
|
365 |
+
def benchmark_run_time(self, dataset, repeat):
|
366 |
+
return self.compute_metrics("benchmark_run_time", dataset, repeat)
|
367 |
+
|
368 |
+
def count_flops(self, dataset, repeat):
|
369 |
+
return self.compute_metrics("count_flops", dataset, repeat)
|
370 |
+
|
371 |
+
def max_memory(self, dataset, repeat):
|
372 |
+
return self.compute_metrics("max_memory", dataset, repeat)
|
373 |
+
|
374 |
+
|
375 |
+
class Cascaded2StageS2ST(BenchmarkingBase):
|
376 |
+
"""ST + TTS"""
|
377 |
+
|
378 |
+
def __init__(self, s2t_args, tts_args):
|
379 |
+
super().__init__()
|
380 |
+
self.s2t = Processing(s2t_args)
|
381 |
+
self.s2x_task = self.s2t.task
|
382 |
+
self.tts = SpeechGeneration(tts_args) if tts_args else None
|
383 |
+
self.training = False
|
384 |
+
self.tts_inputs = None
|
385 |
+
|
386 |
+
def forward(self, sample):
|
387 |
+
if not self.tts:
|
388 |
+
raise Exception(
|
389 |
+
"Forward function is not callable without tts. Reinitialize the class with tts_args"
|
390 |
+
)
|
391 |
+
s2t_hypos = self.s2t(sample)
|
392 |
+
s2t_output = self.s2t.decode_target(s2t_hypos)
|
393 |
+
tts_input = self.tts.processTextInput(s2t_output)
|
394 |
+
tts_output = self.tts(tts_input)
|
395 |
+
return tts_output
|
396 |
+
|
397 |
+
def generate_s2t_outputs(self, dataset):
|
398 |
+
"""Process dataset and generate s2t outputs"""
|
399 |
+
return [self.s2t.decode_target(self.s2t(sample)) for sample in dataset]
|
400 |
+
|
401 |
+
def generate_tts_inputs(self, dataset):
|
402 |
+
"""Process dataset and generate tts inputs"""
|
403 |
+
return [self.tts.processTextInput(sample) for sample in dataset]
|
404 |
+
|
405 |
+
def compute_metrics(self, metric_type, dataset, repeat=None):
|
406 |
+
"""Generic function to compute metrics ignoring the io processing time"""
|
407 |
+
if not self.tts_inputs:
|
408 |
+
s2t_outputs = self.generate_s2t_outputs(dataset)
|
409 |
+
self.tts_inputs = self.generate_tts_inputs(s2t_outputs)
|
410 |
+
|
411 |
+
s2t_metrics = getattr(self.s2t, metric_type)(
|
412 |
+
dataset,
|
413 |
+
repeat,
|
414 |
+
)
|
415 |
+
|
416 |
+
tts_metrics = getattr(self.tts, metric_type)(
|
417 |
+
self.tts_inputs,
|
418 |
+
repeat,
|
419 |
+
)
|
420 |
+
print(
|
421 |
+
f"metric_type = {metric_type} s2t_metrics = {s2t_metrics} \t tts_metrics = {tts_metrics}"
|
422 |
+
)
|
423 |
+
if metric_type == "max_memory":
|
424 |
+
return max(s2t_metrics, tts_metrics)
|
425 |
+
else:
|
426 |
+
return s2t_metrics + tts_metrics
|
427 |
+
|
428 |
+
def benchmark_run_time(self, dataset, repeat):
|
429 |
+
return self.compute_metrics("benchmark_run_time", dataset, repeat)
|
430 |
+
|
431 |
+
def count_flops(self, dataset, repeat):
|
432 |
+
return self.compute_metrics("count_flops", dataset, repeat)
|
433 |
+
|
434 |
+
def max_memory(self, dataset, repeat):
|
435 |
+
return self.compute_metrics("max_memory", dataset, repeat)
|
436 |
+
|
437 |
+
|
438 |
+
class Cascaded3StageS2ST(Cascaded2StageS2ST):
|
439 |
+
"""ASR + MT + TTS"""
|
440 |
+
|
441 |
+
def __init__(self, s2t_args, tts_args, mt_args):
|
442 |
+
super().__init__(s2t_args, tts_args)
|
443 |
+
self.mt = Processing(mt_args)
|
444 |
+
self.mt_inputs = []
|
445 |
+
|
446 |
+
def forward(self, sample):
|
447 |
+
s2t_hypos = self.s2t(sample)
|
448 |
+
s2t_output = self.s2t.decode_target(s2t_hypos)
|
449 |
+
mt_input = self.mt.encode_source(s2t_output)
|
450 |
+
mt_hypos = self.mt(mt_input)
|
451 |
+
mt_output = self.mt.decode_target(mt_hypos)
|
452 |
+
tts_input = self.tts.processTextInput(mt_output)
|
453 |
+
tts_output = self.tts(tts_input)
|
454 |
+
return tts_output
|
455 |
+
|
456 |
+
def generate_mt_inputs(self, dataset):
|
457 |
+
"""Process dataset to generate mt model inputs"""
|
458 |
+
return [self.mt.encode_source(sample) for sample in dataset]
|
459 |
+
|
460 |
+
def generate_mt_outputs(self, dataset):
|
461 |
+
"""Process dataset to generate mt model outputs"""
|
462 |
+
return [self.mt.decode_target(self.mt(sample)) for sample in dataset]
|
463 |
+
|
464 |
+
def compute_metrics(self, metric_type, dataset, repeat=None):
|
465 |
+
"""Generic function to compute metrics ignoring the io processing time"""
|
466 |
+
if not self.tts_inputs:
|
467 |
+
s2t_outputs = self.generate_s2t_outputs(dataset)
|
468 |
+
self.mt_inputs = self.generate_mt_inputs(s2t_outputs)
|
469 |
+
mt_outputs = self.generate_mt_outputs(self.mt_inputs)
|
470 |
+
self.tts_inputs = self.generate_tts_inputs(mt_outputs)
|
471 |
+
|
472 |
+
s2t_metrics = getattr(self.s2t, metric_type)(
|
473 |
+
dataset,
|
474 |
+
repeat,
|
475 |
+
)
|
476 |
+
mt_metrics = getattr(self.mt, metric_type)(self.mt_inputs, repeat)
|
477 |
+
tts_metrics = getattr(self.tts, metric_type)(
|
478 |
+
self.tts_inputs,
|
479 |
+
repeat,
|
480 |
+
)
|
481 |
+
print(
|
482 |
+
f"metric_type = {metric_type} s2t_metrics = {s2t_metrics} \t mt_metrics = {mt_metrics} \t tts_metrics = {tts_metrics}"
|
483 |
+
)
|
484 |
+
if metric_type == "max_memory":
|
485 |
+
return max(s2t_metrics, mt_metrics, tts_metrics)
|
486 |
+
else:
|
487 |
+
return s2t_metrics + mt_metrics + tts_metrics
|
fairseq/examples/speech_to_speech/benchmarking/data_utils.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fairseq import tasks
|
2 |
+
import numpy as np
|
3 |
+
import logging
|
4 |
+
import random
|
5 |
+
from fairseq import options
|
6 |
+
import torch
|
7 |
+
import os
|
8 |
+
import soundfile as sf
|
9 |
+
|
10 |
+
from fairseq.data.audio.audio_utils import (
|
11 |
+
get_waveform,
|
12 |
+
parse_path,
|
13 |
+
)
|
14 |
+
|
15 |
+
logging.basicConfig()
|
16 |
+
logging.root.setLevel(logging.INFO)
|
17 |
+
logging.basicConfig(level=logging.INFO)
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
random.seed(1)
|
21 |
+
np.random.seed(1)
|
22 |
+
random_number_generator = np.random.RandomState(30)
|
23 |
+
|
24 |
+
|
25 |
+
def generate_random_data_sample(T, B=1, D=80):
|
26 |
+
"""Generate random data sample given the T, B, D values"""
|
27 |
+
net_input = {
|
28 |
+
"src_tokens": torch.tensor(random_number_generator.randn(B, T, D)).float(),
|
29 |
+
"src_lengths": torch.tensor([T]),
|
30 |
+
}
|
31 |
+
return {"net_input": net_input}
|
32 |
+
|
33 |
+
|
34 |
+
def generate_random_dataset(T_range_min, T_range_max, B=1, D=80, dataset_size=100):
|
35 |
+
"""Generate random dataset with T values within a given range, B, D"""
|
36 |
+
T_values = [random.randint(T_range_min, T_range_max) for i in range(dataset_size)]
|
37 |
+
dataset = []
|
38 |
+
for t in T_values:
|
39 |
+
dataset.append(generate_random_data_sample(t, B, D))
|
40 |
+
return dataset, sum(T_values) / dataset_size
|
41 |
+
|
42 |
+
|
43 |
+
def load_dataset_npy(file_name, dataset_size=None):
|
44 |
+
"""Load dataset from a .npy file."""
|
45 |
+
data = np.load(file_name, allow_pickle=True)
|
46 |
+
if dataset_size:
|
47 |
+
data = data[:dataset_size]
|
48 |
+
return data
|
49 |
+
|
50 |
+
|
51 |
+
def load_dataset_raw_to_waveforms(
|
52 |
+
file_name,
|
53 |
+
dataset_size=None,
|
54 |
+
need_waveform=True,
|
55 |
+
sample_rate=16000,
|
56 |
+
read_using_soundfile=False,
|
57 |
+
):
|
58 |
+
"""Load raw dataset from w2v tsv file. Optionally get waveforms"""
|
59 |
+
data = []
|
60 |
+
with open(file_name, "r") as fp:
|
61 |
+
lines = fp.readlines()
|
62 |
+
data = [
|
63 |
+
os.path.join(lines[0].strip(), line.strip().split("\t")[0])
|
64 |
+
for line in lines[1:]
|
65 |
+
]
|
66 |
+
|
67 |
+
if dataset_size:
|
68 |
+
data = data[:dataset_size]
|
69 |
+
|
70 |
+
if not need_waveform:
|
71 |
+
return data
|
72 |
+
|
73 |
+
features = []
|
74 |
+
if read_using_soundfile:
|
75 |
+
for _i, d in enumerate(data):
|
76 |
+
wav = sf.read(d)[0]
|
77 |
+
if wav.ndim == 2:
|
78 |
+
wav = wav.mean(-1)
|
79 |
+
features.append(torch.from_numpy(wav).float().view(1, -1))
|
80 |
+
else:
|
81 |
+
for i, d in enumerate(data):
|
82 |
+
_path, slice_ptr = parse_path(d)
|
83 |
+
if len(slice_ptr) == 0:
|
84 |
+
feat = get_waveform(
|
85 |
+
_path, always_2d=True, output_sample_rate=sample_rate
|
86 |
+
)[0]
|
87 |
+
features.append(
|
88 |
+
{
|
89 |
+
"id": i,
|
90 |
+
"net_input": {
|
91 |
+
"src_tokens": torch.tensor(feat),
|
92 |
+
"src_lengths": torch.tensor([feat.shape[1]]),
|
93 |
+
},
|
94 |
+
}
|
95 |
+
)
|
96 |
+
else:
|
97 |
+
raise Exception("Currently unsupported data format")
|
98 |
+
return features
|
99 |
+
|
100 |
+
|
101 |
+
def load_dataset_task(
|
102 |
+
args,
|
103 |
+
batch_size=1,
|
104 |
+
limit_size=None,
|
105 |
+
ref_dataset=None,
|
106 |
+
):
|
107 |
+
"""Loads dataset based on args by creating a task"""
|
108 |
+
if not args.data or not args.subset or not args.task:
|
109 |
+
raise Exception(
|
110 |
+
"Please provide necessary arguments to load the dataset - data, subset and task"
|
111 |
+
)
|
112 |
+
task = tasks.setup_task(args)
|
113 |
+
|
114 |
+
task.load_dataset(args.subset)
|
115 |
+
if not limit_size:
|
116 |
+
limit_size = len(task.dataset(args.subset))
|
117 |
+
|
118 |
+
iter = task.get_batch_iterator(
|
119 |
+
dataset=task.dataset(args.subset), max_sentences=batch_size
|
120 |
+
).next_epoch_itr(shuffle=False)
|
121 |
+
dataset = []
|
122 |
+
for i, sample in enumerate(iter):
|
123 |
+
sample = {
|
124 |
+
"id": task.datasets[args.subset].ids[sample["id"].item()],
|
125 |
+
"net_input": {
|
126 |
+
"src_tokens": sample["net_input"]["src_tokens"],
|
127 |
+
"src_lengths": sample["net_input"]["src_lengths"],
|
128 |
+
},
|
129 |
+
}
|
130 |
+
dataset.append(sample)
|
131 |
+
if i == limit_size - 1:
|
132 |
+
break
|
133 |
+
|
134 |
+
if ref_dataset:
|
135 |
+
try:
|
136 |
+
ids = get_ids_from_dataset(ref_dataset)
|
137 |
+
except Exception as e:
|
138 |
+
raise Exception(f"{e} - Cannot extract ids from reference dataset")
|
139 |
+
|
140 |
+
filtered_dataset = []
|
141 |
+
for sample in dataset:
|
142 |
+
if (
|
143 |
+
sample["id"] in ids
|
144 |
+
or sample["id"][5:] in ids
|
145 |
+
or f"dev_{sample['id']}" in ids
|
146 |
+
):
|
147 |
+
filtered_dataset.append(sample)
|
148 |
+
dataset = filtered_dataset
|
149 |
+
|
150 |
+
max_len, min_len, avg_len = get_dataset_stats(dataset)
|
151 |
+
print(
|
152 |
+
f"{args.subset} dataset stats : num_samples={len(dataset)} max_len = {max_len} min_len = {min_len} avg_len = {avg_len}"
|
153 |
+
)
|
154 |
+
|
155 |
+
return dataset
|
156 |
+
|
157 |
+
|
158 |
+
def randomly_sample_subset(dataset, size=500):
|
159 |
+
"""Randomly sample subset from a dataset"""
|
160 |
+
random_indices = [random.randint(0, len(dataset) - 1) for i in range(size)]
|
161 |
+
return [dataset[i] for i in random_indices]
|
162 |
+
|
163 |
+
|
164 |
+
def get_short_data_subset(dataset, size=500):
|
165 |
+
"""Get a subset of desired size by sorting based on src_lengths"""
|
166 |
+
return sort_dataset(dataset)[:size]
|
167 |
+
|
168 |
+
|
169 |
+
def get_long_data_subset(dataset, size=500):
|
170 |
+
"""Get a subset of desired size by sorting based on src_lengths descending"""
|
171 |
+
return sort_dataset(dataset, reverse=True)[:size]
|
172 |
+
|
173 |
+
|
174 |
+
def sort_dataset(dataset, reverse=False):
|
175 |
+
return sorted(
|
176 |
+
dataset, key=lambda x: x["net_input"]["src_lengths"].item(), reverse=reverse
|
177 |
+
)
|
178 |
+
|
179 |
+
|
180 |
+
def save_dataset_npy(dataset, file_name):
|
181 |
+
"""Save a dataset as .npy file"""
|
182 |
+
np.save(file_name, dataset)
|
183 |
+
|
184 |
+
|
185 |
+
def get_dataset_stats(dataset):
|
186 |
+
"""Get stats about dataset based on src_lengths of samples"""
|
187 |
+
max_len = 0
|
188 |
+
min_len = 100000
|
189 |
+
avg_len = 0
|
190 |
+
for d in dataset:
|
191 |
+
max_len = max(max_len, d["net_input"]["src_lengths"].item())
|
192 |
+
min_len = min(min_len, d["net_input"]["src_lengths"].item())
|
193 |
+
avg_len += d["net_input"]["src_lengths"].item()
|
194 |
+
|
195 |
+
return max_len, min_len, avg_len / len(dataset)
|
196 |
+
|
197 |
+
|
198 |
+
def make_parser():
|
199 |
+
"""
|
200 |
+
Additional args:
|
201 |
+
1. Provide the dataset dir path using --data.
|
202 |
+
2. Loading the dataset doesn't require config, provide --config-yaml to apply additional feature transforms
|
203 |
+
"""
|
204 |
+
parser = options.get_speech_generation_parser()
|
205 |
+
parser.add_argument(
|
206 |
+
"--subset",
|
207 |
+
default=None,
|
208 |
+
type=str,
|
209 |
+
required=True,
|
210 |
+
help="Subset to use for dataset generation",
|
211 |
+
)
|
212 |
+
parser.add_argument(
|
213 |
+
"--dataset-save-dir",
|
214 |
+
default=None,
|
215 |
+
type=str,
|
216 |
+
required=False,
|
217 |
+
help="Dir path in which the datasets are to be saved",
|
218 |
+
)
|
219 |
+
parser.add_argument(
|
220 |
+
"--ref-dataset",
|
221 |
+
default=None,
|
222 |
+
type=str,
|
223 |
+
required=False,
|
224 |
+
help="If provided, the ids in the reference dataset will be used to filter the new dataset generated.",
|
225 |
+
)
|
226 |
+
parser.add_argument("--dataset-save-token", default="", type=str, required=False)
|
227 |
+
|
228 |
+
options.add_generation_args(parser)
|
229 |
+
return parser
|
230 |
+
|
231 |
+
|
232 |
+
def get_ids_from_dataset(dataset):
|
233 |
+
return {sample["id"]: 1 for sample in dataset}
|
234 |
+
|
235 |
+
|
236 |
+
def cli_main():
|
237 |
+
parser = make_parser()
|
238 |
+
args = options.parse_args_and_arch(parser)
|
239 |
+
dataset = load_dataset_task(args)
|
240 |
+
|
241 |
+
random_dataset = randomly_sample_subset(dataset)
|
242 |
+
short_dataset = get_short_data_subset(dataset)
|
243 |
+
long_dataset = get_long_data_subset(dataset)
|
244 |
+
|
245 |
+
if args.dataset_save_token:
|
246 |
+
args.dataset_save_token = f"_{args.dataset_save_token}_"
|
247 |
+
|
248 |
+
if args.dataset_save_dir:
|
249 |
+
save_dataset_npy(
|
250 |
+
random_dataset,
|
251 |
+
f"{args.dataset_save_dir}/random_dataset{args.dataset_save_token}w_ids.npy",
|
252 |
+
)
|
253 |
+
save_dataset_npy(
|
254 |
+
short_dataset,
|
255 |
+
f"{args.dataset_save_dir}/short_dataset{args.dataset_save_token}w_ids.npy",
|
256 |
+
)
|
257 |
+
save_dataset_npy(
|
258 |
+
long_dataset,
|
259 |
+
f"{args.dataset_save_dir}/long_dataset{args.dataset_save_token}w_ids.npy",
|
260 |
+
)
|
261 |
+
|
262 |
+
|
263 |
+
if __name__ == "__main__":
|
264 |
+
cli_main()
|
fairseq/examples/speech_to_speech/benchmarking/get_metrics.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import torch
|
3 |
+
import logging
|
4 |
+
from argparse import Namespace
|
5 |
+
import yaml
|
6 |
+
from fairseq import options
|
7 |
+
from examples.speech_to_speech.benchmarking.core import (
|
8 |
+
Processing,
|
9 |
+
SpeechGeneration,
|
10 |
+
Cascaded2StageS2ST,
|
11 |
+
Cascaded3StageS2ST,
|
12 |
+
S2UT,
|
13 |
+
)
|
14 |
+
from examples.speech_to_speech.benchmarking.data_utils import (
|
15 |
+
load_dataset_npy,
|
16 |
+
load_dataset_raw_to_waveforms,
|
17 |
+
)
|
18 |
+
|
19 |
+
|
20 |
+
logging.basicConfig()
|
21 |
+
logging.root.setLevel(logging.INFO)
|
22 |
+
logging.basicConfig(level=logging.INFO)
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
torch.manual_seed(1)
|
26 |
+
torch.set_deterministic(True)
|
27 |
+
|
28 |
+
|
29 |
+
def make_parser():
|
30 |
+
"""Note: As the names indicate use s2x_args(ex:ST, ASR etc) for models with speech input,
|
31 |
+
x2s_args for models with speech output(ex:TTS) and mt_args for translation models (ex: mt, T2U etc).
|
32 |
+
For direct S2ST models, use x2s_args to provide model details.
|
33 |
+
"""
|
34 |
+
parser = options.get_speech_generation_parser()
|
35 |
+
parser.add_argument("--target-is-code", action="store_true", default=False)
|
36 |
+
parser.add_argument("--config", type=str)
|
37 |
+
parser.add_argument(
|
38 |
+
"--model-type",
|
39 |
+
default="S2U",
|
40 |
+
choices=["S2S", "TTS", "S2UT", "MT", "S2T", "2StageS2ST", "3StageS2ST"],
|
41 |
+
help="Choose one of the models. For model inference implementation, refer to core.py",
|
42 |
+
)
|
43 |
+
parser.add_argument(
|
44 |
+
"--dataset-path",
|
45 |
+
type=str,
|
46 |
+
help="""File to load dataset from. Assumes dataset is a list of samples.
|
47 |
+
Each sample is a dict of format {'net_input':{'src_tokens':torch.tenor(),'src_lengths':torch.tensor()}}""",
|
48 |
+
)
|
49 |
+
parser.add_argument(
|
50 |
+
"--dataset-type",
|
51 |
+
type=str,
|
52 |
+
default="npy",
|
53 |
+
choices=["npy", "raw"],
|
54 |
+
help="""Type of input dataset file""",
|
55 |
+
)
|
56 |
+
parser.add_argument(
|
57 |
+
"--read-using-sf",
|
58 |
+
type=str,
|
59 |
+
default=False,
|
60 |
+
help="""If sound file should be used to read the raw dataset""",
|
61 |
+
)
|
62 |
+
parser.add_argument(
|
63 |
+
"--dataset-size",
|
64 |
+
default=None,
|
65 |
+
type=int,
|
66 |
+
help="Dataset size to use for benchmarking",
|
67 |
+
)
|
68 |
+
parser.add_argument(
|
69 |
+
"--dump-speech-waveforms-dir",
|
70 |
+
default=None,
|
71 |
+
type=str,
|
72 |
+
help="Directory to dump the speech waveforms computed on the dataset.",
|
73 |
+
)
|
74 |
+
parser.add_argument(
|
75 |
+
"--dump-waveform-file-prefix",
|
76 |
+
default="",
|
77 |
+
type=str,
|
78 |
+
help="File name prefix for the saved speech waveforms",
|
79 |
+
)
|
80 |
+
parser.add_argument(
|
81 |
+
"--feat-dim", default=80, type=int, help="Input feature dimension"
|
82 |
+
)
|
83 |
+
parser.add_argument(
|
84 |
+
"--target-sr",
|
85 |
+
default=16000,
|
86 |
+
type=int,
|
87 |
+
help="Target sample rate for dumping waveforms",
|
88 |
+
)
|
89 |
+
|
90 |
+
options.add_generation_args(parser)
|
91 |
+
options.get_interactive_generation_parser(parser)
|
92 |
+
return parser
|
93 |
+
|
94 |
+
|
95 |
+
def cli_main():
|
96 |
+
parser = make_parser()
|
97 |
+
args = options.parse_args_and_arch(parser)
|
98 |
+
|
99 |
+
with open(
|
100 |
+
args.config,
|
101 |
+
"r",
|
102 |
+
) as f:
|
103 |
+
config = yaml.load(f, Loader=yaml.FullLoader)
|
104 |
+
dict_args = vars(args)
|
105 |
+
dict_args.update(config["general"])
|
106 |
+
args = Namespace(**dict_args)
|
107 |
+
|
108 |
+
i = 1
|
109 |
+
stage_args = []
|
110 |
+
while i <= 3:
|
111 |
+
var = f"stage{i}"
|
112 |
+
tmp_args = copy.deepcopy(dict_args)
|
113 |
+
if var in config:
|
114 |
+
tmp_args.update(config[var])
|
115 |
+
stage_args.append(Namespace(**tmp_args))
|
116 |
+
i += 1
|
117 |
+
else:
|
118 |
+
break
|
119 |
+
|
120 |
+
if args.model_type == "S2S" or args.model_type == "TTS":
|
121 |
+
model = SpeechGeneration(stage_args[0])
|
122 |
+
elif args.model_type == "S2UT":
|
123 |
+
model = S2UT(stage_args[0], stage_args[1] if len(stage_args) > 1 else None)
|
124 |
+
elif args.model_type == "MT" or args.model_type == "S2T":
|
125 |
+
model = Processing(stage_args[0])
|
126 |
+
elif args.model_type == "2StageS2ST":
|
127 |
+
model = Cascaded2StageS2ST(stage_args[0], stage_args[1])
|
128 |
+
elif args.model_type == "3StageS2ST":
|
129 |
+
model = Cascaded3StageS2ST(stage_args[0], stage_args[2], stage_args[1])
|
130 |
+
else:
|
131 |
+
raise Exception(f"Currently unsupported model type {args.model_type}")
|
132 |
+
|
133 |
+
print(f"Evaluating on dataset - {args.dataset_path}\n")
|
134 |
+
|
135 |
+
if args.dataset_type == "npy":
|
136 |
+
dataset = load_dataset_npy(args.dataset_path, dataset_size=args.dataset_size)
|
137 |
+
elif args.dataset_type == "raw":
|
138 |
+
dataset = load_dataset_raw_to_waveforms(
|
139 |
+
args.dataset_path,
|
140 |
+
dataset_size=args.dataset_size,
|
141 |
+
read_using_soundfile=args.read_using_sf,
|
142 |
+
)
|
143 |
+
else:
|
144 |
+
raise Exception(f"Invalid dataset type {args.dataset_type}")
|
145 |
+
|
146 |
+
model.warm_up(sample=dataset[0], repeat=2)
|
147 |
+
|
148 |
+
run_time, memory, flops = model.gather_all_metrics(dataset, repeat=1)
|
149 |
+
print(f"run_time = {run_time}sec \tmemory = {memory}MiB \tflops = {flops}")
|
150 |
+
|
151 |
+
if args.dump_speech_waveforms_dir:
|
152 |
+
model.dump_final_speech_output(
|
153 |
+
dataset,
|
154 |
+
args.dump_speech_waveforms_dir,
|
155 |
+
lambda x: x,
|
156 |
+
args.target_sr,
|
157 |
+
prefix=args.dump_waveform_file_prefix,
|
158 |
+
)
|
159 |
+
|
160 |
+
|
161 |
+
if __name__ == "__main__":
|
162 |
+
cli_main()
|
fairseq/examples/speech_to_speech/docs/data_augmentation.md
ADDED
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Noise and audio augmentation techniques
|
2 |
+
|
3 |
+
The noise and data augmentation techniques were written in an effort to understand how augmenatation can affect model robustness and performance in both clean and noisy settings.
|
4 |
+
|
5 |
+
All transforms discussed in this section are subclasses of `AudioFeatureTransform`, `AudioWaveformTransform`, or `AudioDatasetTransform`. Each `Audio*Transform` has unique interaction with the data. If interested in implemented one's own transforms, it is highly advisable to review the differences (see [Adding your own transforms](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/data_augmentation.md#adding-your-own-transforms)). If only applying the in-built transforms, then one only needs to be mindful that the correct kind of transform is listed in the config (see [Using transforms](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/data_augmentation.md#using-transforms)). These transforms can be applied to instances of `SpeechToTextDataset`.
|
6 |
+
|
7 |
+
### Contents
|
8 |
+
[In-built transforms](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/data_augmentation.md#in-built-transforms)
|
9 |
+
|
10 |
+
[Benchmark studies](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/data_augmentation.md#benchmark-studies)
|
11 |
+
|
12 |
+
[Using transforms](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/data_augmentation.md#using-transforms)
|
13 |
+
|
14 |
+
[Adding your own transforms](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/data_augmentation.md#adding-your-own-transforms)
|
15 |
+
|
16 |
+
|
17 |
+
## In-built transforms
|
18 |
+
### 1. Utterance concatenation
|
19 |
+
Utterance concatenation is a data augmenation technique introduced as ConcatAug in [Translatotron 2: High-quality direct speech-to-speech translation
|
20 |
+
with voice preservation](https://arxiv.org/pdf/2107.08661.pdf).
|
21 |
+
With some parameterized probability, samples are concatenated with one other randomly chosen sample from the whole dataset. In the positive (concatenation) case, accessing `dataset[i]` will return a `SpeechToTextDatasetItem` where `source=source[i]+source[j]` and `target=target[i]+target[j]`. In the negative (skip concatenation) case, accessing `dataset[i]` will return a `SpeechToTextDatasetItem` where `source=source[i]` and `target=target[i]` as usual.
|
22 |
+
|
23 |
+
**Usage**: `concataugment` is an `AudioDatasetTransform` and has three configurable hyperparameters:
|
24 |
+
- `rate`: probability that any single access will result in the positive (concatenation) case. Defaults to 0.25.
|
25 |
+
- `max_tokens`: maximum number of tokens allowed for concatenated source sequences. This parameter is meant to limit the length of concatenated samples to avoid out-of-memory errors. Defaults to 300.
|
26 |
+
- `attempts`: maximum number of invalid concatenation attempts before defaulting to the negative (skip concatenation) case. This parameter aims to limit excessive time spent trying to find candidate samples that are short enough to concatenate with. Defaults to 5.
|
27 |
+
|
28 |
+
Please be wary of OOMs while using this augmentation technique; we used smaller batch sizes as a workaround to avoid OOMs. Batch size is determined by update frequency, batch size hyperparameter, and the number of GPU, so you may want to alter these to this end.
|
29 |
+
|
30 |
+
### 2. Noise augmentation suite
|
31 |
+
|
32 |
+
The four noise augmentation methods in this suite adhere to the following principle: with some parameterized probability, samples are overlayed with a noise track. The content of the noise track is specific to the method. Signal-to-noise ratio with which the noise track is overlayed is determined by choosing a value from a random uniform distribution with parameterized endpoints. The first three methods are based off data augmentation methods suggested in Section 3.3 of [X-Vectors: Robust DNN Embeddings for Speaker Recognition](https://danielpovey.com/files/2018_icassp_xvectors.pdf).
|
33 |
+
|
34 |
+
#### 2.1. Music augmentation
|
35 |
+
For music augmentation, the noise track consists of one file uniformly randomly selected from a corpus of music files. The music file is cut to size, including being repeated to fill the original sample length if necessary.
|
36 |
+
|
37 |
+
**Usage**: `musicaugment` is an `AudioWaveformTransform` and has four configurable hyperparameters:
|
38 |
+
- `samples_path`: path where background music files are saved as audios (.wav files). No default.
|
39 |
+
- `rate`: probability that any single access will result in the positive (background music) case. Defaults to 0.25.
|
40 |
+
- `snr_min`: lower endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 5.
|
41 |
+
- `snr_max`: higher endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 15.
|
42 |
+
|
43 |
+
#### 2.2. Babble augmentation
|
44 |
+
For babble augmentation, the noise track consists of multiple audios uniformly randomly selected from a corpus of speech files. The number of speech audios in the background track is chosen randomly with equal probability between 3 and 7 audios.
|
45 |
+
|
46 |
+
**Usage**: `babbleaugment` is an `AudioWaveformTransform` and has four configurable hyperparameters:
|
47 |
+
- `samples_path`: path where background speech files are saved as audios (.wav files). No default.
|
48 |
+
- `rate`: probability that any single access will result in the positive (background speech) case. Defaults to 0.25.
|
49 |
+
- `snr_min`: lower endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 5.
|
50 |
+
- `snr_max`: higher endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 15.
|
51 |
+
|
52 |
+
#### 2.3. Sporadic noise augmentation
|
53 |
+
For sporadic noise augmentation, the noise track is mostly silent except for intermittent short clips of noise which are added at roughly a parameterized frequency. These clips are randomly chosen and cut from a corpus of noise files to lengths according to a parameterized Gaussian distribution.
|
54 |
+
|
55 |
+
**Usage**: `sporadicnoiseaugment` is an `AudioWaveformTransform` and has seven configurable hyperparameters:
|
56 |
+
- `samples_path`: path where background noise files are saved as audios (.wav files). No default.
|
57 |
+
- `rate`: probability that any single access will result in the positive (add a sporadic noise track) case. Defaults to 0.25.
|
58 |
+
- `snr_min`: lower endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 5.
|
59 |
+
- `snr_max`: higher endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 15.
|
60 |
+
- `noise_rate`: rate in noises per second at which noise clip will be added to the original sample
|
61 |
+
- `noise_len_mean`: mean of Gaussian normal distribution from which length of noise clip is chosen
|
62 |
+
- `noise_len_std`: standard deviation of Gaussian normal distribution from which length of noise clip is chosen
|
63 |
+
|
64 |
+
#### 2.4. Background noise augmentation
|
65 |
+
For background noise augmentation, the noise track is a single track uniformly randomly selected from a corpus of noise files. The noise file is cut to size, including being repeated to fill the original sample length if necessary.
|
66 |
+
|
67 |
+
**Usage**: `backgroundnoiseaugment` is an `AudioWaveformTransform` and has four configurable hyperparameters:
|
68 |
+
- `samples_path`: path where background noise files are saved as audios (.wav files). No default.
|
69 |
+
- `rate`: probability that any single access will result in the positive (background noise) case. Defaults to 0.25.
|
70 |
+
- `snr_min`: lower endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 5.
|
71 |
+
- `snr_max`: higher endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 15.
|
72 |
+
|
73 |
+
### 3. Mixed babble and background noise augmentation with recognizable source speaker
|
74 |
+
|
75 |
+
This augmentation technique is based on Algorithm 1 in [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) and is similar to the noise augmentation suite techniques in that it has a background noise track. The noise track consists of either (1) another audio sample from the batch or (2) a background noise track. A key difference is the length of the noise track is chosen from a uniform random distribution between 0 and half of the original sample length.
|
76 |
+
|
77 |
+
**Usage**: `noisyoverlapaugment` is an `AudioDatasetTransform` and has seven configurable hyperparameters:
|
78 |
+
- `noises_path`: path where background noise files are saved as audios (.wav files). No default.
|
79 |
+
- `rate`: probability that any single access will result in the positive (background noise) case. Defaults to 0.25.
|
80 |
+
- `mixing_noise_rate`: probability that in a positive (background noise) case, the noise track will consist of background noise (rather than babble from the batch). Defaults to 0.1.
|
81 |
+
- `noise_snr_min`: lower endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to -5.
|
82 |
+
- `noise_snr_max`: higher endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 5.
|
83 |
+
- `utterance_snr_min`: lower endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add **another audio from the batch** to the original source. Defaults to -5.
|
84 |
+
- `utterance_snr_max`: higher endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add **another audio from the batch** to the original source. Defaults to 5.
|
85 |
+
|
86 |
+
## Benchmark studies
|
87 |
+
### Evaluation on clean data
|
88 |
+
Augmentation in training data|Hyperparameters|Training loss|BLEU (covost)|BLEU (epst)|BLEU (mtedx)
|
89 |
+
---|---|---|---|---|---
|
90 |
+
None||3.954|24.984|23.962|24.448
|
91 |
+
ConcatAugment|rate = 0.25, max_tokens = 3000, attempts = 5|3.940|25.322|26.124|26.19
|
92 |
+
BabbleAugment|rate = 0.25, MUSAN speech, snr_min = (-5), snr_max = 5|3.957|24.226|23.186|22.368|
|
93 |
+
BackgroundNoiseAugment|rate = 0.1, MUSAN noises, snr_min = (-10), snr_max = 10|3.955|24.745|23.513|23.819
|
94 |
+
MusicAugment|rate = 0.25, MUSAN music, snr_min = 0, snr_max = 20|3.954|25.096|24.301|23.341|
|
95 |
+
SporadicNoiseAugment|rate = 0.1, noise_rate = 0.25, MUSAN noises, snr_min = 10, snr_max = 35|3.954|24.924|23.951|23.484|
|
96 |
+
MusicAugment + BabbleAugment + BackgroundNoiseAugment + SporadicNoiseAugment|as above, except limited rates to sum to 0.25: music (0.074), background (0.029), babble (0.074), sporadic (0.029)|3.953|24.874|23.675|24.249|
|
97 |
+
NoisyOverlapAugment|rate = 0.25, mixing_noise_rate = 0.5, MUSAN noises, utterance_snr_min = (-10), utterance_snr_max = 0, noise_snr_min = (-5), noise_snr_max = 20|3.954|24.949|24.015|23.768|
|
98 |
+
|
99 |
+
### Evaluation on data with music noise added at SNR = (-5) - 5
|
100 |
+
Augmentation in training data|Training loss|BLEU (covost)|BLEU (epst)|BLEU (mtedx)
|
101 |
+
---|---|---|---|---
|
102 |
+
None|3.954|15.785|21.105|16.944
|
103 |
+
ConcatAugment|3.940|17.186|23.255|18.24
|
104 |
+
BabbleAugment|3.957|19.158|22.064|17.116
|
105 |
+
BackgroundNoiseAugment|3.955|17.777|22.0|17.535|
|
106 |
+
MusicAugment|3.954|20.345|23.126|19.433|
|
107 |
+
SporadicNoiseAugment|3.954|15.927|21.382|14.736|
|
108 |
+
MusicAugment + BabbleAugment + BackgroundNoiseAugment + SporadicNoiseAugment|3.953|19.724|22.659|17.852|
|
109 |
+
NoisyOverlapAugment|3.954|17.49|22.142|17.207|
|
110 |
+
|
111 |
+
### Evaluation on data with babble noise added at SNR = (-5) - 5
|
112 |
+
Augmentation in training data|Training loss|BLEU (covost)|BLEU (epst)|BLEU (mtedx)
|
113 |
+
---|---|---|---|---
|
114 |
+
None|3.954|4.092|13.514|5.13
|
115 |
+
ConcatAugment|3.940|5.493|15.835|6.893
|
116 |
+
BabbleAugment|3.957|16.12|21.097|13.996
|
117 |
+
BackgroundNoiseAugment|3.955|4.691|15.784|5.982
|
118 |
+
MusicAugment|3.954|8.06|17.764|9.008
|
119 |
+
SporadicNoiseAugment|3.954|4.009|13.935|4.814
|
120 |
+
MusicAugment + BabbleAugment + BackgroundNoiseAugment + SporadicNoiseAugment|3.953|14.692|20.882|14.45
|
121 |
+
NoisyOverlapAugment|3.954|4.032|16.434|7.284
|
122 |
+
|
123 |
+
### Evaluation on data with sporadic noise added at SNR = (-5) - 5
|
124 |
+
Augmentation in training data|Training loss|BLEU (covost)|BLEU (epst)|BLEU (mtedx)
|
125 |
+
---|---|---|---|---
|
126 |
+
None|3.954|23.778|23.745|22.748
|
127 |
+
ConcatAugment|3.940|24.239|25.907|25.723
|
128 |
+
BabbleAugment|3.957|23.42|23.048|21.076
|
129 |
+
BackgroundNoiseAugment|3.955|23.998|23.467|22.494
|
130 |
+
MusicAugment|3.954|24.142|24.181|19.143
|
131 |
+
SporadicNoiseAugment|3.954|23.97|23.894|22.61
|
132 |
+
MusicAugment + BabbleAugment + BackgroundNoiseAugment + SporadicNoiseAugment|3.953|24.118|23.59|23.717
|
133 |
+
NoisyOverlapAugment|3.954|24.265|24.103|23.167
|
134 |
+
|
135 |
+
### Evaluation on data with background noise added at SNR = (-5) - 5
|
136 |
+
Augmentation in training data|Training loss|BLEU (covost)|BLEU (epst)|BLEU (mtedx)
|
137 |
+
---|---|---|---|---
|
138 |
+
None|3.954|20.201|22.525|19.66
|
139 |
+
ConcatAugment|3.940|20.904|24.706|21.353
|
140 |
+
BabbleAugment|3.957|20.687|22.374|18.907
|
141 |
+
BackgroundNoiseAugment|3.955|21.574|22.998|20.043
|
142 |
+
MusicAugment|3.954|21.65|23.529|19.87
|
143 |
+
SporadicNoiseAugment|3.954|20.578|22.577|19.096
|
144 |
+
MusicAugment + BabbleAugment + BackgroundNoiseAugment + SporadicNoiseAugment|3.953|21.811|23.144|20.986
|
145 |
+
NoisyOverlapAugment|3.954|21.312|23.153|20.302
|
146 |
+
|
147 |
+
### Evaluation on data with all four types of noises added at SNR = (-5) - 5, each applied with prob 0.5
|
148 |
+
Augmentation in training data|Training loss|BLEU (covost)|BLEU (epst)|BLEU (mtedx)
|
149 |
+
---|---|---|---|---
|
150 |
+
None|3.954|10.895|19.319|12.748
|
151 |
+
ConcatAugment|3.940|13.517|21.658|15.428
|
152 |
+
BabbleAugment|3.957|18.09|21.384|16.018
|
153 |
+
BackgroundNoiseAugment|3.955|12.837|20.719|13.933
|
154 |
+
MusicAugment|3.954|16.589|21.823|15.927
|
155 |
+
SporadicNoiseAugment|3.954|11.238|19.91|13.31
|
156 |
+
MusicAugment + BabbleAugment + BackgroundNoiseAugment + SporadicNoiseAugment|3.953|18.636|21.935|17.845
|
157 |
+
NoisyOverlapAugment|3.954|12.829|20.856|15.048
|
158 |
+
|
159 |
+
### Evaluation on data with noisy overlap augment
|
160 |
+
Augmentation in training data|Training loss|BLEU (covost)|BLEU (epst)|BLEU (mtedx)
|
161 |
+
---|---|---|---|---
|
162 |
+
None|3.954|21.245|22.24|20.994
|
163 |
+
ConcatAugment|3.940|21.611|24.247|23.068
|
164 |
+
BabbleAugment|3.957|21.867|21.987|20.099|
|
165 |
+
BackgroundNoiseAugment|3.955|21.533|21.806|19.717|
|
166 |
+
MusicAugment|3.954|21.823|22.643|20.847|
|
167 |
+
SporadicNoiseAugment|3.954|21.373|22.381|20.672|
|
168 |
+
MusicAugment + BabbleAugment + BackgroundNoiseAugment + SporadicNoiseAugment|3.953|22.206|22.414|21.375|
|
169 |
+
NoisyOverlapAugment|3.954|23.371|23.396|22.627|
|
170 |
+
|
171 |
+
## Using transforms
|
172 |
+
Transforms are configurable.
|
173 |
+
|
174 |
+
1. Please pay careful attention to the type of transform you are applying.
|
175 |
+
- `concataugment` and `noisyoverlapaugment` are instances of `AudioDatasetTransform` and should be listed in the config under `dataset_transforms`.
|
176 |
+
- `musicaugment`, `babbleaugment`, `sporadicnoiseaugment`, and `backgroundnoiseaugment` are instances of `AudioWaveformTransform` and should be listed under `waveform_transforms`.
|
177 |
+
- Instances of `AudioFeatureTransform` should be listed under `feature_transforms`.
|
178 |
+
2. Feel free to apply these augmentations in different contexts, e.g., you may use a `_train` or `_eval` flag to specify when the transform will be applied. If the dataset at hand contains `train` in its name, those transforms under the `_train` flag will be applied; else, the remaining transforms will be applied.
|
179 |
+
|
180 |
+
For example, you would add this to your config to apply the musicaugment transform to a training dataset:
|
181 |
+
```yaml
|
182 |
+
musicaugment:
|
183 |
+
samples_path: ${MUSIC_PATH}
|
184 |
+
snr_min: 10
|
185 |
+
snr_max: 15
|
186 |
+
rate: 0.25
|
187 |
+
waveform_transforms:
|
188 |
+
_train:
|
189 |
+
- musicaugment
|
190 |
+
```
|
191 |
+
or add this to apply the concataugment transform:
|
192 |
+
```yaml
|
193 |
+
concataugment:
|
194 |
+
rate: 0.25
|
195 |
+
max_tokens: 3000
|
196 |
+
attempts: 5
|
197 |
+
dataset_transforms:
|
198 |
+
_train:
|
199 |
+
- concataugment
|
200 |
+
```
|
201 |
+
You may also want to add multiple of one type of transform; here, we add multiple `AudioWaveformTransform`s:
|
202 |
+
```yaml
|
203 |
+
musicaugment:
|
204 |
+
samples_path: ${MUSIC_PATH}
|
205 |
+
snr_min: 5
|
206 |
+
snr_max: 20
|
207 |
+
rate: 0.25
|
208 |
+
backgroundnoiseaugment:
|
209 |
+
samples_path: ${NOISES_PATH}
|
210 |
+
snr_min: 10
|
211 |
+
snr_max: 20
|
212 |
+
rate: 0.1
|
213 |
+
sporadicnoiseaugment:
|
214 |
+
samples_path: ${NOISES_PATH}
|
215 |
+
snr_min: 5
|
216 |
+
snr_max: 15
|
217 |
+
rate: 0.1
|
218 |
+
noise_rate: 0.25
|
219 |
+
waveform_transforms:
|
220 |
+
_train:
|
221 |
+
- musicaugment
|
222 |
+
- backgroundnoiseaugment
|
223 |
+
- sporadicnoiseaugment
|
224 |
+
```
|
225 |
+
|
226 |
+
## Adding your own transforms
|
227 |
+
Note: We store transform implementations in `fairseq/data/audio/*_transforms` directories. You may refer to these as examples while implementing your own transform.
|
228 |
+
|
229 |
+
### Step 1. Picking the right class for your transform
|
230 |
+
The integration into SpeechToTextDataset is quite different for each kind of transform, so it is important to understand which one is best suited to your purposes.
|
231 |
+
|
232 |
+
**Feature transforms**
|
233 |
+
`AudioFeatureTransform` is a base class which allows **some transform to be applied to audio spectrograms** in the data loading step. One thing to note is that the source data is either saved as `np.ndarrays` or as audio files, and is to be returned either as features (spectrogram) or waveform. If and only if the data is to be returned as a spectrogram, then `AudioFeatureTransform`s will be applied.
|
234 |
+
|
235 |
+
**Waveform transforms**
|
236 |
+
`AudioWaveformTransform` is a base class which allows some **transform to be applied to waveforms** in the data loading step. As mentioned above, there are two source and return types to data loading for this dataset. If and only if the data is saved in audio file format, then `AudioWaveformTransform`s will be applied, whichever return type is used.
|
237 |
+
|
238 |
+
**Dataset transforms**
|
239 |
+
`AudioDatasetTransform` is a base class for transforms **based on more than one item in a dataset**, ex. concatenation of two random samples in a dataset. Rather than being applied in a consistent way, i.e., to all features or to all waveforms, the integration of a dataset transform is entirely specific. Adding a dataset transform requires actually editing the `fairseq/data/audio/speech_to_text_dataset.py` file.
|
240 |
+
|
241 |
+
### Step 2. Setting up your transform (generic to all types of transforms)
|
242 |
+
Now that you know which kind of transform you would like to use, we are ready to implement it. This step is generic for all transform types, i.e., `TRANSFORM_TYPE` may be any of `feature`, `waveform`, or `dataset`. We will show how to build utterance concatenation (an `AudioDatasetTransform`) as an example.
|
243 |
+
|
244 |
+
Import the base class and registration function for your transform.
|
245 |
+
```python
|
246 |
+
from fairseq.data.audio.dataset_transforms import (
|
247 |
+
AudioDatasetTransform,
|
248 |
+
register_audio_dataset_transform
|
249 |
+
)
|
250 |
+
```
|
251 |
+
|
252 |
+
Define the class and register the transform. The name passed into the registration function is how your transform should be named in the config.
|
253 |
+
```python
|
254 |
+
@register_audio_dataset_transform("concataugment")
|
255 |
+
class ConcatAugment(AudioDatasetTransform):
|
256 |
+
```
|
257 |
+
|
258 |
+
We are now ready to add the basic important functions to our new class. In this example, `_DEFAULTS` refers to a dictionary with the default hyperparameter values that we defined. `from_config_dict` is called to instantiate the transform given hyperparameters from the config.
|
259 |
+
```python
|
260 |
+
@classmethod
|
261 |
+
def from_config_dict(cls, config=None):
|
262 |
+
_config = {} if config is None else config
|
263 |
+
return ConcatAugment(
|
264 |
+
_config.get("rate", _DEFAULTS["rate"]),
|
265 |
+
_config.get("max_tokens", _DEFAULTS["max_tokens"]),
|
266 |
+
_config.get("attempts", _DEFAULTS["attempts"]),
|
267 |
+
)
|
268 |
+
```
|
269 |
+
We edit the instantiation function `__init__` to track hyperparameters and do any setup work.
|
270 |
+
```python
|
271 |
+
def __init__(
|
272 |
+
self,
|
273 |
+
rate=_DEFAULTS["rate"],
|
274 |
+
max_tokens=_DEFAULTS["max_tokens"],
|
275 |
+
attempts=_DEFAULTS["attempts"],
|
276 |
+
):
|
277 |
+
self.rate, self.max_tokens, self.attempts = rate, max_tokens, attempts
|
278 |
+
```
|
279 |
+
Lastly `__repr__` gives how the transform will be reported in an output log.
|
280 |
+
```python
|
281 |
+
def __repr__(self):
|
282 |
+
return (
|
283 |
+
self.__class__.__name__
|
284 |
+
+ "("
|
285 |
+
+ ", ".join(
|
286 |
+
[
|
287 |
+
f"rate={self.rate}",
|
288 |
+
f"max_tokens={self.max_tokens}",
|
289 |
+
f"attempts={self.attempts}",
|
290 |
+
]
|
291 |
+
)
|
292 |
+
+ ")"
|
293 |
+
)
|
294 |
+
```
|
295 |
+
|
296 |
+
### Step 3. Adding the transform logic
|
297 |
+
At this point, we are ready to implement the actual transform logic. The flow from here is different for each of the three transforms, so follow the path that is relevant to you.
|
298 |
+
### ...for feature transforms
|
299 |
+
The final step is implementing the `__call__` function, which applies the transform logic and **returns** the spectrogram with transform applied. This supports and should take exactly **two arguments**:
|
300 |
+
- `self`
|
301 |
+
- `x` (np.ndarray): the spectrogram for one source sample. (This is a positional argument, so you can use another parameter name like `spectrogram` instead of `x`.)
|
302 |
+
|
303 |
+
For example, this is the `__call__` function for GlobalCMVN (cepstral mean and variance normalization).
|
304 |
+
```python
|
305 |
+
def __call__(self, x):
|
306 |
+
x = np.subtract(x, self.mean)
|
307 |
+
x = np.divide(x, self.std)
|
308 |
+
return x
|
309 |
+
|
310 |
+
```
|
311 |
+
### ...for waveform transforms
|
312 |
+
The final step is implementing the `__call__` function, which applies the transform logic. This supports and should take exactly **three arguments**:
|
313 |
+
- `self`
|
314 |
+
- `source` (numpy.ndarray or torch.Tensor): source audio 2d waveform (channels x length)
|
315 |
+
- `sample_rate` (optional, defaults to None): sample rate of `source`
|
316 |
+
|
317 |
+
`__call__` **returns**:
|
318 |
+
- transformed audio waveform
|
319 |
+
- sample rate of transformed audio waveform
|
320 |
+
|
321 |
+
For example, this is the `__call__` function for augmentations in the Noise Augmentation Suite.
|
322 |
+
```python
|
323 |
+
def __call__(self, source, sample_rate=None):
|
324 |
+
if np.random.random() > self.rate:
|
325 |
+
return source
|
326 |
+
|
327 |
+
noise = self._get_noise(
|
328 |
+
source.shape, always_2d=True, use_sample_rate=sample_rate
|
329 |
+
)
|
330 |
+
return self._mix(source, noise, rand_uniform(self.snr_min, self.snr_max)), sample_rate
|
331 |
+
```
|
332 |
+
|
333 |
+
### ...for dataset transforms
|
334 |
+
Dataset transforms are extremely flexible, and implementation involves directly integrating them into `fairseq/data/audio/speech_to_text_dataset.py` in transform-specific ways.
|
335 |
+
There are two basic components: (1) check whether or not this transform is part of this dataset instance using `self.dataset_transforms.has_transform(TRANSFORM_CLS)`, and (2) if so, get the transform using `self.dataset_transforms.get_transform(TRANSFORM_CLS)` & apply it.
|
336 |
+
Due to the case-by-case specificity, it is easier to demonstrate this by examples.
|
337 |
+
|
338 |
+
#### Example: NoisyOverlapAugment
|
339 |
+
This transform requires access to multiple items within the same batch at once.
|
340 |
+
|
341 |
+
**Logic**: We still use the transform classes to keep away the transform logic. For example, `__call__` of `NoisyOverlapAugment` class takes a list of source tokens for items in a mini-batch, applies noise/utterance as dictated by the transform, and returns the list of transformed source tokens for items in the mini-batch.
|
342 |
+
|
343 |
+
```python
|
344 |
+
def __call__(self, sources):
|
345 |
+
for i, source in enumerate(sources):
|
346 |
+
if np.random.random() > self.rate:
|
347 |
+
continue
|
348 |
+
|
349 |
+
pri = source.numpy()
|
350 |
+
|
351 |
+
# ... some transform code omitted
|
352 |
+
|
353 |
+
pri[s_source : s_source + l] = np.add(
|
354 |
+
pri[s_source : s_source + l], np.multiply(scl, sec[s_sec : s_sec + l])
|
355 |
+
)
|
356 |
+
sources[i] = torch.from_numpy(pri).float()
|
357 |
+
|
358 |
+
return sources
|
359 |
+
```
|
360 |
+
|
361 |
+
**Integration**: The `collater` function for `SpeechToTextDataset` is responsible for preparing a mini-batch for training, so we integrate NOAug through adding a few lines to the top of this function:
|
362 |
+
```python
|
363 |
+
def collater(
|
364 |
+
self, samples: List[SpeechToTextDatasetItem], return_order: bool = False
|
365 |
+
) -> Dict:
|
366 |
+
if len(samples) == 0:
|
367 |
+
return {}
|
368 |
+
indices = torch.tensor([x.index for x in samples], dtype=torch.long)
|
369 |
+
|
370 |
+
sources = [x.source for x in samples]
|
371 |
+
|
372 |
+
# NOAUG INTEGRATION BLOCK
|
373 |
+
# (1) Check whether or not this transform is part of this dataset instance
|
374 |
+
has_NOAug = self.dataset_transforms.has_transform(NoisyOverlapAugment)
|
375 |
+
# (2) If so, get & apply the transform
|
376 |
+
if has_NOAug and self.cfg.use_audio_input:
|
377 |
+
NOAug = self.dataset_transforms.get_transform(NoisyOverlapAugment)
|
378 |
+
sources = NOAug(sources)
|
379 |
+
|
380 |
+
frames = _collate_frames(sources, self.cfg.use_audio_input)
|
381 |
+
# sort samples by descending number of frames
|
382 |
+
n_frames = torch.tensor([x.size(0) for x in sources], dtype=torch.long)
|
383 |
+
n_frames, order = n_frames.sort(descending=True)
|
384 |
+
indices = indices.index_select(0, order)
|
385 |
+
frames = frames.index_select(0, order)
|
386 |
+
|
387 |
+
# ... rest of function
|
388 |
+
```
|
389 |
+
|
390 |
+
#### Example: ConcatAugment
|
391 |
+
This transform requires access to another item within the dataset at once.
|
392 |
+
|
393 |
+
**Logic**: We abstract the logic for picking indices to concatenate by adding a `find_indices` function to the `ConcatAugment` class, which takes one index in the dataset and finds a compatible second index to concatenate source and target tokens.
|
394 |
+
```python
|
395 |
+
def find_indices(self, index: int, n_frames: List[int], n_samples: int):
|
396 |
+
# skip conditions: application rate, max_tokens limit exceeded
|
397 |
+
if np.random.random() > self.rate:
|
398 |
+
return [index]
|
399 |
+
if self.max_tokens and n_frames[index] > self.max_tokens:
|
400 |
+
return [index]
|
401 |
+
|
402 |
+
# pick second sample to concatenate
|
403 |
+
for _ in range(self.attempts):
|
404 |
+
index2 = np.random.randint(0, n_samples)
|
405 |
+
if index2 != index and (
|
406 |
+
not self.max_tokens
|
407 |
+
or n_frames[index] + n_frames[index2] < self.max_tokens
|
408 |
+
):
|
409 |
+
return [index, index2]
|
410 |
+
|
411 |
+
return [index]
|
412 |
+
```
|
413 |
+
|
414 |
+
**Integration**: `SpeechToTextDataset` uses a custom `__getitem__(self, index)` function (called in the background when you write `dataset[i]`). We edited this function (as well as `_get_source_audio` and `get_tokenized_tgt_text`) to achieve the desired transform effect where accessing `dataset[i]` will return a `SpeechToTextDatasetItem` where `source=source[i]+source[j]` and `target=target[i]+target[j]`.
|
415 |
+
```python
|
416 |
+
def __getitem__(self, index: int) -> SpeechToTextDatasetItem:
|
417 |
+
|
418 |
+
# CONCATAUGMENT INTEGRATION BLOCK
|
419 |
+
# (1) Check whether or not this transform is part of this dataset instance
|
420 |
+
has_concat = self.dataset_transforms.has_transform(ConcatAugment)
|
421 |
+
# (2) If so, get & apply the transform
|
422 |
+
if has_concat:
|
423 |
+
concat = self.dataset_transforms.get_transform(ConcatAugment)
|
424 |
+
indices = concat.find_indices(index, self.n_frames, self.n_samples)
|
425 |
+
|
426 |
+
source = self._get_source_audio(indices if has_concat else index)
|
427 |
+
source = self.pack_frames(source)
|
428 |
+
|
429 |
+
target = None
|
430 |
+
if self.tgt_texts is not None:
|
431 |
+
tokenized = self.get_tokenized_tgt_text(indices if has_concat else index)
|
432 |
+
target = self.tgt_dict.encode_line(
|
433 |
+
|
434 |
+
# ... rest of function
|
435 |
+
```
|
fairseq/examples/speech_to_speech/docs/direct_s2st_discrete_units.md
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Direct speech-to-speech translation with discrete units
|
2 |
+
|
3 |
+
We provide the implementation for speech-to-unit translation (S2UT) proposed in "[Direct speech-to-speech translation with discrete units (Lee et al. 2021)](https://arxiv.org/abs/2107.05604)" and also the transformer-based implementation of the speech-to-spectrogram translation (S2SPECT, or transformer-based [Translatotron](https://arxiv.org/abs/1904.06037)) baseline in the paper.
|
4 |
+
|
5 |
+
## Pretrained Models
|
6 |
+
|
7 |
+
### Unit-based HiFi-GAN Vocoder
|
8 |
+
Unit config | Unit size | Vocoder dataset | Model
|
9 |
+
|---|---|---|---
|
10 |
+
[HuBERT Base, Librispeech](https://github.com/fairinternal/fairseq-py/tree/main/examples/hubert), layer 6 | 100 | [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) | [ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/hubert_base_100_lj/g_00500000), [config](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/hubert_base_100_lj/config.json)
|
11 |
+
|
12 |
+
|
13 |
+
## Data preparation
|
14 |
+
### Target speech
|
15 |
+
0. (optional) To prepare S2S data from a speech-to-text translation (ST) dataset, see [fairseq-S^2](https://github.com/pytorch/fairseq/tree/main/examples/speech_synthesis) for pre-trained TTS models and instructions on how to train and decode TTS models.
|
16 |
+
1. Prepare two folders, `$SRC_AUDIO` and `$TGT_AUDIO`, with `${SPLIT}/${SAMPLE_ID}.wav` for source and target speech under each folder, separately. Note that for S2UT experiments, target audio sampling rate should be in 16,000 Hz, and for S2SPECT experiments, target audio sampling rate is recommended to be in 22,050 Hz.
|
17 |
+
2. To prepare target discrete units for S2UT model training, see [Generative Spoken Language Modeling (speech2unit)](https://github.com/pytorch/fairseq/tree/main/examples/textless_nlp/gslm/speech2unit) for pre-trained k-means models, checkpoints, and instructions on how to decode units from speech. Set the output target unit files (`--out_quantized_file_path`) as `${TGT_AUDIO}/${SPLIT}.txt`. In [Lee et al. 2021](https://arxiv.org/abs/2107.05604), we use 100 units from the sixth layer (`--layer 6`) of the HuBERT Base model.
|
18 |
+
|
19 |
+
### Formatting data
|
20 |
+
**Speech-to-speech data**
|
21 |
+
|
22 |
+
_S2UT_
|
23 |
+
* Set `--reduce-unit` for training S2UT _reduced_ model
|
24 |
+
* Pre-trained vocoder and config (`$VOCODER_CKPT`, `$VOCODER_CFG`) can be downloaded from the **Pretrained Models** section. They are not required if `--eval-inference` is not going to be set during model training.
|
25 |
+
```
|
26 |
+
# $SPLIT1, $SPLIT2, etc. are split names such as train, dev, test, etc.
|
27 |
+
|
28 |
+
python examples/speech_to_speech/preprocessing/prep_s2ut_data.py \
|
29 |
+
--source-dir $SRC_AUDIO --target-dir $TGT_AUDIO --data-split $SPLIT1 $SPLIT2 \
|
30 |
+
--output-root $DATA_ROOT --reduce-unit \
|
31 |
+
--vocoder-checkpoint $VOCODER_CKPT --vocoder-cfg $VOCODER_CFG
|
32 |
+
```
|
33 |
+
|
34 |
+
_S2SPECT_
|
35 |
+
```
|
36 |
+
# $SPLIT1, $SPLIT2, etc. are split names such as train, dev, test, etc.
|
37 |
+
|
38 |
+
python examples/speech_to_speech/preprocessing/prep_s2spect_data.py \
|
39 |
+
--source-dir $SRC_AUDIO --target-dir $TGT_AUDIO --data-split $SPLIT1 $SPLIT2 \
|
40 |
+
--output-root $DATA_ROOT
|
41 |
+
```
|
42 |
+
|
43 |
+
**Multitask data**
|
44 |
+
* For each multitask `$TASK_NAME`, prepare `${DATA_ROOT}/${TASK_NAME}/${SPLIT}.tsv` files for each split following the format below: (Two tab separated columns. The sample_ids should match with the sample_ids for the speech-to-speech data in `${DATA_ROOT}/${SPLIT}.tsv`.)
|
45 |
+
```
|
46 |
+
id tgt_text
|
47 |
+
sample_id_0 token1 token2 token3 ...
|
48 |
+
sample_id_1 token1 token2 token3 ...
|
49 |
+
...
|
50 |
+
```
|
51 |
+
* For each multitask `$TASK_NAME`, prepare `${DATA_ROOT}/${TASK_NAME}/dict.txt`, a dictionary in fairseq format with all tokens for the targets for `$TASK_NAME`.
|
52 |
+
* Create `config_multitask.yaml`. Below is an example of the config used for S2UT _reduced_ with Fisher experiments including two encoder multitasks (`source_letter`, `target_letter`) and one decoder CTC task (`decoder_target_ctc`).
|
53 |
+
```
|
54 |
+
source_letter: # $TASK_NAME
|
55 |
+
decoder_type: transformer
|
56 |
+
dict: ${DATA_ROOT}/source_letter/dict.txt
|
57 |
+
data: ${DATA_ROOT}/source_letter
|
58 |
+
encoder_layer: 6
|
59 |
+
loss_weight: 8.0
|
60 |
+
target_letter:
|
61 |
+
decoder_type: transformer
|
62 |
+
dict: ${DATA_ROOT}/target_letter/dict.txt
|
63 |
+
data: ${DATA_ROOT}/target_letter
|
64 |
+
encoder_layer: 8
|
65 |
+
loss_weight: 8.0
|
66 |
+
decoder_target_ctc:
|
67 |
+
decoder_type: ctc
|
68 |
+
dict: ${DATA_ROOT}/decoder_target_ctc/dict.txt
|
69 |
+
data: ${DATA_ROOT}/decoder_target_ctc
|
70 |
+
decoder_layer: 3
|
71 |
+
loss_weight: 1.6
|
72 |
+
```
|
73 |
+
|
74 |
+
|
75 |
+
## Training
|
76 |
+
|
77 |
+
**Speech-to-unit translation (S2UT)**
|
78 |
+
|
79 |
+
Here's an example for training Fisher S2UT models with 100 discrete units as target:
|
80 |
+
```
|
81 |
+
fairseq-train $DATA_ROOT \
|
82 |
+
--config-yaml config.yaml --multitask-config-yaml config_multitask.yaml \
|
83 |
+
--task speech_to_speech --target-is-code --target-code-size 100 --vocoder code_hifigan \
|
84 |
+
--criterion speech_to_unit --label-smoothing 0.2 \
|
85 |
+
--arch s2ut_transformer_fisher --share-decoder-input-output-embed \
|
86 |
+
--dropout 0.1 --attention-dropout 0.1 --relu-dropout 0.1 \
|
87 |
+
--train-subset train --valid-subset dev \
|
88 |
+
--save-dir ${MODEL_DIR} \
|
89 |
+
--lr 0.0005 --lr-scheduler inverse_sqrt --warmup-init-lr 1e-7 --warmup-updates 10000 \
|
90 |
+
--optimizer adam --adam-betas "(0.9,0.98)" --clip-norm 10.0 \
|
91 |
+
--max-update 400000 --max-tokens 20000 --max-target-positions 3000 --update-freq 4 \
|
92 |
+
--seed 1 --fp16 --num-workers 8
|
93 |
+
```
|
94 |
+
* Adjust `--update-freq` accordingly for different #GPUs. In the above we set `--update-freq 4` to simulate training with 4 GPUs.
|
95 |
+
* Set `--n-frames-per-step 5` to train an S2UT _stacked_ system with reduction ratio r=5. (Use `$DATA_ROOT` prepared without `--reduce-unit`.)
|
96 |
+
* (optional) one can turn on tracking MCD loss during training for checkpoint selection by setting `--eval-inference --eval-args '{"beam": 1, "max_len_a": 1}' --best-checkpoint-metric mcd_loss`. It is recommended to sample a smaller subset as the validation set as MCD loss computation is time-consuming.
|
97 |
+
|
98 |
+
**Speech-to-spectrogram translation (S2SPECT)**
|
99 |
+
|
100 |
+
Here's an example for training Fisher S2SPECT models with reduction ratio r=5:
|
101 |
+
```
|
102 |
+
fairseq-train $DATA_ROOT \
|
103 |
+
--config-yaml config.yaml --multitask-config-yaml config_multitask.yaml \
|
104 |
+
--task speech_to_speech --n-frames-per-step 5 \
|
105 |
+
--criterion speech_to_spectrogram \
|
106 |
+
--arch s2spect_transformer_fisher --decoder-normalize-before \
|
107 |
+
--dropout 0.1 --attention-dropout 0.1 --relu-dropout 0.1 \
|
108 |
+
--train-subset train --valid-subset dev \
|
109 |
+
--save-dir ${MODEL_DIR} \
|
110 |
+
--eval-inference --best-checkpoint-metric mcd_loss \
|
111 |
+
--lr 0.0005 --lr-scheduler inverse_sqrt --warmup-init-lr 1e-7 --warmup-updates 10000 \
|
112 |
+
--optimizer adam --adam-betas "(0.9,0.98)" --clip-norm 10.0 --weight-decay 1e-6 \
|
113 |
+
--max-update 400000 --max-tokens 80000 --max-tokens-valid 30000 --required-batch-size-multiple 1 \
|
114 |
+
--max-target-positions 3000 --update-freq 16 \
|
115 |
+
--seed 1 --fp16 --num-workers 8
|
116 |
+
```
|
117 |
+
* Adjust `--update-freq` accordingly for different #GPUs. In the above we set `--update-freq 16` to simulate training with 16 GPUs.
|
118 |
+
* We recommend turning on MCD loss during training for the best checkpoint selection.
|
119 |
+
|
120 |
+
**Unit-based HiFi-GAN vocoder**
|
121 |
+
|
122 |
+
The vocoder is trained with the [speech-resynthesis repo](https://github.com/facebookresearch/speech-resynthesis). See [here](https://github.com/facebookresearch/speech-resynthesis/tree/main/examples/speech_to_speech_translation) for instructions on how to train the unit-based HiFi-GAN vocoder with duration prediction. The same vocoder can support waveform generation for both _reduced_ unit sequences (with `--dur-prediction` set during inference) and original unit sequences.
|
123 |
+
|
124 |
+
## Inference
|
125 |
+
|
126 |
+
**Speech-to-unit translation (S2UT)**
|
127 |
+
|
128 |
+
1. Follow the same inference process as in [fairseq-S2T](https://github.com/pytorch/fairseq/tree/main/examples/speech_to_text) to generate unit sequences (`${RESULTS_PATH}/generate-${GEN_SUBSET}.txt`).
|
129 |
+
```
|
130 |
+
fairseq-generate $DATA_ROOT \
|
131 |
+
--config-yaml config.yaml --multitask-config-yaml config_multitask.yaml \
|
132 |
+
--task speech_to_speech --target-is-code --target-code-size 100 --vocoder code_hifigan \
|
133 |
+
--path $MODEL_DIR/checkpoint_best.pt --gen-subset $GEN_SUBSET \
|
134 |
+
--max-tokens 50000 \
|
135 |
+
--beam 10 --max-len-a 1 \
|
136 |
+
--results-path ${RESULTS_PATH}
|
137 |
+
```
|
138 |
+
* Set `--beam 1 --n-frames-per-step $r` for decoding with S2UT _stacked_ models.
|
139 |
+
|
140 |
+
2. Convert unit sequences to waveform.
|
141 |
+
```
|
142 |
+
grep "^D\-" ${RESULTS_PATH}/generate-${GEN_SUBSET}.txt | \
|
143 |
+
sed 's/^D-//ig' | sort -nk1 | cut -f3 \
|
144 |
+
> ${RESULTS_PATH}/generate-${GEN_SUBSET}.unit
|
145 |
+
|
146 |
+
python examples/speech_to_speech/generate_waveform_from_code.py \
|
147 |
+
--in-code-file ${RESULTS_PATH}/generate-${GEN_SUBSET}.unit \
|
148 |
+
--vocoder $VOCODER_CKPT --vocoder-cfg $VOCODER_CFG \
|
149 |
+
--results-path ${RESULTS_PATH} --dur-prediction
|
150 |
+
```
|
151 |
+
* Set `--dur-prediction` for generating audio for S2UT _reduced_ models.
|
152 |
+
|
153 |
+
|
154 |
+
**Speech-to-spectrogram translation (S2SPECT)**
|
155 |
+
|
156 |
+
Follow the same inference process as in [fairseq-S^2](https://github.com/pytorch/fairseq/tree/main/examples/speech_synthesis) to generate waveform.
|
157 |
+
|
158 |
+
```
|
159 |
+
# assume using a default Griffin-Lim vocoder
|
160 |
+
|
161 |
+
python examples/speech_synthesis/generate_waveform.py $DATA_ROOT \
|
162 |
+
--config-yaml config.yaml --multitask-config-yaml config_multitask.yaml \
|
163 |
+
--task speech_to_speech --n-frames-per-step 5 \
|
164 |
+
--path $MODEL_DIR/checkpoint_best.pt --gen-subset $GEN_SUBSET \
|
165 |
+
--max-tokens 50000 \
|
166 |
+
--results-path ${RESULTS_PATH} --dump-waveforms --output-sample-rate 16000
|
167 |
+
```
|
168 |
+
|
169 |
+
In addition to using the default Griffin-Lim vocoder, one can also finetune a HiFi-GAN vocoder for the S2SPECT model by following the instructions in the [HiFi-GAN repo](https://github.com/jik876/hifi-gan).
|
170 |
+
|
171 |
+
**Multitask decoding**
|
172 |
+
|
173 |
+
Coming soon.
|
174 |
+
|
175 |
+
## Evaluation
|
176 |
+
|
177 |
+
To evaluate speech translation output, we first apply ASR on the speech output and then compute BLEU score betweent the ASR decoded text and the references using sacreBLEU.
|
178 |
+
|
179 |
+
**En**
|
180 |
+
* ASR: We use the "[Wav2Vec 2.0 Large (LV-60) + Self Training / 960 hours / Libri-Light + Librispeech](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_960h_pl.pt)" En ASR model open-sourced by the [wav2vec](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec) project. See [instructions](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec#evaluating-a-ctc-model) on how to run inference with a wav2vec-based ASR model. The model is also available on [Hugging Face](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self).
|
181 |
+
* Text normalization: We use the text cleaner at [https://github.com/keithito/tacotron](https://github.com/keithito/tacotron) for pre-processing reference English text for ASR BLEU evaluation.
|
fairseq/examples/speech_to_speech/docs/enhanced_direct_s2st_discrete_units.md
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Speech to speech translation (S2ST)
|
2 |
+
|
3 |
+
We provide the implementation for speech-to-unit translation (S2UT) proposed in [Enhanced Direct Speech-to-Speech Translation Using Self-supervised Pre-training and Data Augmentation (Popuri et al. 2022)](https://arxiv.org/abs/2204.02967) and the various pretrained models used.
|
4 |
+
|
5 |
+
## Pretrained Models
|
6 |
+
|
7 |
+
### Unit extraction
|
8 |
+
|
9 |
+
We used the multilingual HuBERT model open sourced in [Textless S2ST with Real Data](textless_s2st_real_data.md)
|
10 |
+
|
11 |
+
### Wav2vec 2.0
|
12 |
+
|
13 |
+
Language | Block type | Model size | Dataset | Model |
|
14 |
+
--- | --- | --- | --- | --- |
|
15 |
+
Es | Transformer | BASE | Voxpopuli | [ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/es/transformer_B.pt) |
|
16 |
+
Es | Transformer | LARGE | Voxpopuli | [ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/es/transformer_L.pt) |
|
17 |
+
Es | Conformer | LARGE | Voxpopuli | [ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/es/conformer_L.pt) |
|
18 |
+
En | Transformer | BASE | Librilight| [ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/en/transformer_B.pt) |
|
19 |
+
En | Conformer | LARGE | Librilight | [ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/en/conformer_L.pt) |
|
20 |
+
|
21 |
+
### Unit mBART
|
22 |
+
|
23 |
+
Unit size | Dataset | Unit config | Model |
|
24 |
+
--- | --- | --- | --- |
|
25 |
+
1000 | [Voxpopuli](https://aclanthology.org/2021.acl-long.80) En, Es unlabelled speech | [mbart_large](https://github.com/pytorch/fairseq/blob/f591cc94caa85098ccf125a4782f91125b6a086d/fairseq/models/bart/model.py#L368) |[ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/unit_mBART/checkpoint.pt) |
|
26 |
+
|
27 |
+
## Data preparation
|
28 |
+
|
29 |
+
1. To prepare data for S2UT finetuning, follow the steps from [Direct S2ST with Discrete Units](./direct_s2st_discrete_units.md) and format the data in the _S2UT_ format. Note that we use 1000 units from the eleventh layer (`--layer 11`) of the multilingual hubert model linked above instead
|
30 |
+
2. Run
|
31 |
+
|
32 |
+
```
|
33 |
+
var="id\taudio\tn_frames\ttgt_text\ttgt_n_frames"
|
34 |
+
sed -i "1s/.*/$var/" ${SPLIT}.tsv
|
35 |
+
```
|
36 |
+
|
37 |
+
## Training
|
38 |
+
|
39 |
+
**Speech-to-unit translation (S2UT)**
|
40 |
+
|
41 |
+
Here's an example for finetuning S2UT models with 1000 discrete units as target. You can download the sample [config](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/config.yaml) file and [vocabulary](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/dict.txt) for Es-En from here:
|
42 |
+
|
43 |
+
```
|
44 |
+
fairseq-train $DATA_ROOT \
|
45 |
+
--config-yaml config.yaml \
|
46 |
+
--task speech_to_text --arch xm_transformer\
|
47 |
+
--criterion l --label-smoothing 0.2 \
|
48 |
+
--share-decoder-input-output-embed --adaptor-n-layers 1 --normalize\
|
49 |
+
--dropout 0.1 --attention-dropout 0.1 --relu-dropout 0.1 \
|
50 |
+
--train-subset train --valid-subset dev \
|
51 |
+
--load-pretrained-decoder-from ${unit_mBART} --w2v-path ${wav2vec2.0} \
|
52 |
+
--mask-prob 0.3 --mask-channel-length 32 --mask-channel-prob 0.25\
|
53 |
+
--save-dir ${MODEL_DIR} --checkpoint-activations --encoder-proj \
|
54 |
+
--lr 0.0005 --dropout 0.1 --attention-dropout 0.1 --lr-scheduler inverse_sqrt\
|
55 |
+
--warmup-init-lr 1e-7 --warmup-updates 10000 \
|
56 |
+
--optimizer adam --adam-betas "(0.9,0.98)" --clip-norm 10.0 \
|
57 |
+
--max-update 20000 --max-tokens 4000 --max-tokens-valid 4000 --max-source-positions 4000 \
|
58 |
+
--max-target-positions 4000 --update-freq 120 \
|
59 |
+
--seed 1 --fp16 --num-workers 1
|
60 |
+
```
|
61 |
+
|
62 |
+
* Adjust `--update-freq` accordingly for different #GPUs. In the above we set `--update-freq 15` to simulate training with 120 GPUs.
|
63 |
+
* In the above setting we finetune the model end to end, corresponding to the full setup in the paper.
|
64 |
+
* To apply LNA-E partial finetuning, add `--finetune-w2v-params layer_norm,self_attn`
|
65 |
+
* For LNA-D partial finetuning add `--finetune-decoder-params encoder_attn,layer_norm,self_attn`. To optionally freeze the encoder by k updates, use `--freeze-finetune-updates ${K}`
|
66 |
+
* For LNA-E,D partial finetuning add both the above options.
|
67 |
+
|
68 |
+
**Unit-based HiFi-GAN vocoder**
|
69 |
+
|
70 |
+
We apply the open-sourced unit-based HiFi-GAN vocoders to convert the predicted unit sequences to waveform. They are open sourced in [Textless S2ST with Real Data](textless_s2st_real_data.md)
|
71 |
+
|
72 |
+
## Inference
|
73 |
+
|
74 |
+
**Speech-to-unit translation (S2UT)**
|
75 |
+
|
76 |
+
1. Follow the same inference process as in [fairseq-S2T](https://github.com/pytorch/fairseq/tree/main/examples/speech_to_text) to generate unit sequences (`${RESULTS_PATH}/generate-${GEN_SUBSET}.txt`).
|
77 |
+
|
78 |
+
```
|
79 |
+
fairseq-generate $DATA_ROOT \
|
80 |
+
--config-yaml config.yaml \
|
81 |
+
--task speech_to_text \
|
82 |
+
--path $MODEL_DIR/checkpoint_best.pt --gen-subset $GEN_SUBSET \
|
83 |
+
--max-tokens 10000 --max-source-positions 10000 --max-target-positions 10000\
|
84 |
+
--beam 10 --max-len-a 1 --max-len-b 200 \
|
85 |
+
--results-path ${RESULTS_PATH}
|
86 |
+
```
|
87 |
+
|
88 |
+
2. Convert unit sequences to waveform.
|
89 |
+
|
90 |
+
```
|
91 |
+
grep "^D\-" ${RESULTS_PATH}/generate-${GEN_SUBSET}.txt | \
|
92 |
+
sed 's/^D-//ig' | sort -nk1 | cut -f3 \
|
93 |
+
> ${RESULTS_PATH}/generate-${GEN_SUBSET}.unit
|
94 |
+
|
95 |
+
python examples/speech_to_speech/generate_waveform_from_code.py \
|
96 |
+
--in-code-file ${RESULTS_PATH}/generate-${GEN_SUBSET}.unit \
|
97 |
+
--vocoder $VOCODER_CKPT --vocoder-cfg $VOCODER_CFG \
|
98 |
+
--results-path ${RESULTS_PATH} --dur-prediction
|
99 |
+
```
|
100 |
+
|
101 |
+
## Evaluation
|
102 |
+
|
103 |
+
To evaluate speech translation output, we first apply ASR on the speech output and then compute BLEU score betweent the ASR decoded text and the references using sacreBLEU.
|
104 |
+
|
105 |
+
* Text normalization: We use the text cleaner at [https://github.com/keithito/tacotron](https://github.com/keithito/tacotron) for pre-processing reference English text for ASR BLEU evaluation. The text cleaner used for Spanish text normalization will be updated here shortly.
|
106 |
+
* En ASR: We use the "[Wav2Vec 2.0 Large (LV-60) + Self Training / 960 hours / Libri-Light + Librispeech](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_960h_pl.pt)" En ASR model open-sourced by the [wav2vec](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec) project. The model is also available on [Hugging Face](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self).
|
107 |
+
* Es ASR: We use the [Wav2Vec2-Large-XLSR-53-Spanish](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) finetuned on spanish Common Voice Es ASR model open-sourced by Jonatasgrosman(<https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-spanish>) on [Hugging Face](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-spanish).
|
108 |
+
* See [instructions](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec#evaluating-a-ctc-model) on how to run inference with a wav2vec-based ASR model.
|
109 |
+
|
110 |
+
|
111 |
+
## Finetuned Model Checkpoints
|
112 |
+
|
113 |
+
ID | En - Es | Es - En |
|
114 |
+
| --- | --- | --- |
|
115 |
+
**S2UT systems without pre-training**
|
116 |
+
S2UT with multitask | [checkpoint](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/en_es//S2UT_w_multitask.pt) | [checkpoint](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/es_en//S2UT_w_multitask.pt) |
|
117 |
+
**S2UT systems with model pre-training**
|
118 |
+
w2v2-L | [checkpoint](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/en_es//w2v2_only.pt ) | [checkpoint](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/es_en//w2v2_only.pt) |
|
119 |
+
w2v2-L + mBART (LNA-E) | [checkpoint](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/en_es//w2v2_mbart_LNE.pt) | [checkpoint](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/es_en//w2v2_mbart_LNE.pt) |
|
120 |
+
w2v2-L + mBART (LNA-D) | [checkpoint](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/en_es//w2v2_mbart_LND.pt) | [checkpoint](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/es_en//w2v2_mbart_LND.pt) |
|
121 |
+
w2v2-L + mBART (LNA-E,D) | [checkpoint](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/en_es//w2v2_mbart_LNED.pt) | [checkpoint](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/es_en//w2v2_mbart_LNED.pt) |
|
122 |
+
**S2UT systems with model pre-training and data augmentation**
|
123 |
+
w2v2-L + mBART (LNA-D) | [checkpoint](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/en_es//w2v2_mbart_LND_w_ASR.pt) | [checkpoint](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/es_en//w2v2_mbart_LND_w_ASR.pt) |
|
124 |
+
|
125 |
+
Note: Some of the tasks use speech_to_text_sharded task which is yet to be open sourced. So make sure to override the task to speech_to_text to use those models.
|
fairseq/examples/speech_to_speech/docs/textless_s2st_real_data.md
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Textless Speech-to-Speech Translation (S2ST) on Real Data
|
2 |
+
|
3 |
+
We provide instructions and pre-trained models for the work "[Textless Speech-to-Speech Translation on Real Data (Lee et al. 2021)](https://arxiv.org/abs/2112.08352)".
|
4 |
+
|
5 |
+
## Pre-trained Models
|
6 |
+
|
7 |
+
### HuBERT
|
8 |
+
Model | Pretraining Data | Model | Quantizer
|
9 |
+
|---|---|---|---
|
10 |
+
mHuBERT Base | [VoxPopuli](https://github.com/facebookresearch/voxpopuli) En, Es, Fr speech from the 100k subset | [download](https://dl.fbaipublicfiles.com/hubert/mhubert_base_vp_en_es_fr_it3.pt) | [L11 km1000](https://dl.fbaipublicfiles.com/hubert/mhubert_base_vp_en_es_fr_it3_L11_km1000.bin)
|
11 |
+
|
12 |
+
|
13 |
+
### Unit-based HiFi-GAN vocoder
|
14 |
+
Unit config | Unit size | Vocoder language | Dataset | Model
|
15 |
+
|---|---|---|---|---
|
16 |
+
mHuBERT, layer 11 | 1000 | En | [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) | [ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_lj/g_00500000), [config](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_lj/config.json)
|
17 |
+
mHuBERT, layer 11 | 1000 | Es | [CSS10](https://github.com/Kyubyong/css10) | [ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_es_css10/g_00500000), [config](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_es_css10/config.json)
|
18 |
+
mHuBERT, layer 11 | 1000 | Fr | [CSS10](https://github.com/Kyubyong/css10) | [ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_fr_css10/g_00500000), [config](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_fr_css10/config.json)
|
19 |
+
|
20 |
+
|
21 |
+
### Speech normalizer
|
22 |
+
Language | Training data | Target unit config | Model
|
23 |
+
|---|---|---|---
|
24 |
+
En | 10 mins | mHuBERT, layer 11, km1000 | [download](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/speech_normalizer/en/en_10min.tar.gz)
|
25 |
+
En | 1 hr | mHuBERT, layer 11, km1000 | [download](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/speech_normalizer/en/en_1h.tar.gz)
|
26 |
+
En | 10 hrs | mHuBERT, layer 11, km1000 | [download](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/speech_normalizer/en/en_10h.tar.gz)
|
27 |
+
Es | 10 mins | mHuBERT, layer 11, km1000 | [download](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/speech_normalizer/es/es_10min.tar.gz)
|
28 |
+
Es | 1 hr | mHuBERT, layer 11, km1000 | [download](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/speech_normalizer/es/es_1h.tar.gz)
|
29 |
+
Es | 10 hrs | mHuBERT, layer 11, km1000 | [download](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/speech_normalizer/es/es_10h.tar.gz)
|
30 |
+
Fr | 10 mins | mHuBERT, layer 11, km1000 | [download](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/speech_normalizer/fr/fr_10min.tar.gz)
|
31 |
+
Fr | 1 hr | mHuBERT, layer 11, km1000 | [download](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/speech_normalizer/fr/fr_1h.tar.gz)
|
32 |
+
Fr | 10 hrs | mHuBERT, layer 11, km1000 | [download](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/speech_normalizer/fr/fr_10h.tar.gz)
|
33 |
+
|
34 |
+
* Refer to the paper for the details of the training data.
|
35 |
+
|
36 |
+
## Inference with Pre-trained Models
|
37 |
+
|
38 |
+
### Speech normalizer
|
39 |
+
1. Download the pre-trained models, including the dictionary, to `DATA_DIR`.
|
40 |
+
2. Format the audio data.
|
41 |
+
```bash
|
42 |
+
# AUDIO_EXT: audio extension, e.g. wav, flac, etc.
|
43 |
+
# Assume all audio files are at ${AUDIO_DIR}/*.${AUDIO_EXT}
|
44 |
+
|
45 |
+
python examples/speech_to_speech/preprocessing/prep_sn_data.py \
|
46 |
+
--audio-dir ${AUDIO_DIR} --ext ${AUIDO_EXT} \
|
47 |
+
--data-name ${GEN_SUBSET} --output-dir ${DATA_DIR} \
|
48 |
+
--for-inference
|
49 |
+
```
|
50 |
+
|
51 |
+
3. Run the speech normalizer and post-process the output.
|
52 |
+
```bash
|
53 |
+
mkdir -p ${RESULTS_PATH}
|
54 |
+
|
55 |
+
python examples/speech_recognition/new/infer.py \
|
56 |
+
--config-dir examples/hubert/config/decode/ \
|
57 |
+
--config-name infer_viterbi \
|
58 |
+
task.data=${DATA_DIR} \
|
59 |
+
task.normalize=false \
|
60 |
+
common_eval.results_path=${RESULTS_PATH}/log \
|
61 |
+
common_eval.path=${DATA_DIR}/checkpoint_best.pt \
|
62 |
+
dataset.gen_subset=${GEN_SUBSET} \
|
63 |
+
'+task.labels=["unit"]' \
|
64 |
+
+decoding.results_path=${RESULTS_PATH} \
|
65 |
+
common_eval.post_process=none \
|
66 |
+
+dataset.batch_size=1 \
|
67 |
+
common_eval.quiet=True
|
68 |
+
|
69 |
+
# Post-process and generate output at ${RESULTS_PATH}/${GEN_SUBSET}.txt
|
70 |
+
python examples/speech_to_speech/preprocessing/prep_sn_output_data.py \
|
71 |
+
--in-unit ${RESULTS_PATH}/hypo.units \
|
72 |
+
--in-audio ${DATA_DIR}/${GEN_SUBSET}.tsv \
|
73 |
+
--output-root ${RESULTS_PATH}
|
74 |
+
```
|
75 |
+
|
76 |
+
|
77 |
+
### Unit-to-waveform conversion with unit vocoder
|
78 |
+
The pre-trained vocoders can support generating audio for both full unit sequences and reduced unit sequences (i.e. duplicating consecutive units removed). Set `--dur-prediction` for generating audio with reduced unit sequences.
|
79 |
+
```bash
|
80 |
+
# IN_CODE_FILE contains one unit sequence per line. Units are separated by space.
|
81 |
+
|
82 |
+
python examples/speech_to_speech/generate_waveform_from_code.py \
|
83 |
+
--in-code-file ${IN_CODE_FILE} \
|
84 |
+
--vocoder ${VOCODER_CKPT} --vocoder-cfg ${VOCODER_CFG} \
|
85 |
+
--results-path ${RESULTS_PATH} --dur-prediction
|
86 |
+
```
|
87 |
+
|
88 |
+
## Training new models
|
89 |
+
To be updated.
|
fairseq/examples/speech_to_speech/generate_waveform_from_code.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
import json
|
8 |
+
import logging
|
9 |
+
from pathlib import Path
|
10 |
+
import random
|
11 |
+
import soundfile as sf
|
12 |
+
import torch
|
13 |
+
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
from fairseq import utils
|
17 |
+
from fairseq.models.text_to_speech.vocoder import CodeHiFiGANVocoder
|
18 |
+
|
19 |
+
|
20 |
+
logging.basicConfig()
|
21 |
+
logging.root.setLevel(logging.INFO)
|
22 |
+
logging.basicConfig(level=logging.INFO)
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
|
26 |
+
def dump_result(args, sample_id, pred_wav, suffix=""):
|
27 |
+
sf.write(
|
28 |
+
f"{args.results_path}/{sample_id}{suffix}_pred.wav",
|
29 |
+
pred_wav.detach().cpu().numpy(),
|
30 |
+
16000,
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
def load_code(in_file):
|
35 |
+
with open(in_file) as f:
|
36 |
+
out = [list(map(int, line.strip().split())) for line in f]
|
37 |
+
return out
|
38 |
+
|
39 |
+
|
40 |
+
def main(args):
|
41 |
+
logger.info(args)
|
42 |
+
|
43 |
+
use_cuda = torch.cuda.is_available() and not args.cpu
|
44 |
+
|
45 |
+
with open(args.vocoder_cfg) as f:
|
46 |
+
vocoder_cfg = json.load(f)
|
47 |
+
vocoder = CodeHiFiGANVocoder(args.vocoder, vocoder_cfg)
|
48 |
+
if use_cuda:
|
49 |
+
vocoder = vocoder.cuda()
|
50 |
+
|
51 |
+
multispkr = vocoder.model.multispkr
|
52 |
+
if multispkr:
|
53 |
+
logger.info("multi-speaker vocoder")
|
54 |
+
num_speakers = vocoder_cfg.get(
|
55 |
+
"num_speakers", 200
|
56 |
+
) # following the default in codehifigan to set to 200
|
57 |
+
assert (
|
58 |
+
args.speaker_id < num_speakers
|
59 |
+
), f"invalid --speaker-id ({args.speaker_id}) with total #speakers = {num_speakers}"
|
60 |
+
|
61 |
+
data = load_code(args.in_code_file)
|
62 |
+
Path(args.results_path).mkdir(exist_ok=True, parents=True)
|
63 |
+
for i, d in tqdm(enumerate(data), total=len(data)):
|
64 |
+
x = {
|
65 |
+
"code": torch.LongTensor(d).view(1, -1),
|
66 |
+
}
|
67 |
+
suffix = ""
|
68 |
+
if multispkr:
|
69 |
+
spk = (
|
70 |
+
random.randint(0, num_speakers - 1)
|
71 |
+
if args.speaker_id == -1
|
72 |
+
else args.speaker_id
|
73 |
+
)
|
74 |
+
suffix = f"_spk{spk}"
|
75 |
+
x["spkr"] = torch.LongTensor([spk]).view(1, 1)
|
76 |
+
|
77 |
+
x = utils.move_to_cuda(x) if use_cuda else x
|
78 |
+
wav = vocoder(x, args.dur_prediction)
|
79 |
+
dump_result(args, i, wav, suffix=suffix)
|
80 |
+
|
81 |
+
|
82 |
+
def cli_main():
|
83 |
+
parser = argparse.ArgumentParser()
|
84 |
+
parser.add_argument(
|
85 |
+
"--in-code-file", type=str, required=True, help="one unit sequence per line"
|
86 |
+
)
|
87 |
+
parser.add_argument(
|
88 |
+
"--vocoder", type=str, required=True, help="path to the CodeHiFiGAN vocoder"
|
89 |
+
)
|
90 |
+
parser.add_argument(
|
91 |
+
"--vocoder-cfg",
|
92 |
+
type=str,
|
93 |
+
required=True,
|
94 |
+
help="path to the CodeHiFiGAN vocoder config",
|
95 |
+
)
|
96 |
+
parser.add_argument("--results-path", type=str, required=True)
|
97 |
+
parser.add_argument(
|
98 |
+
"--dur-prediction",
|
99 |
+
action="store_true",
|
100 |
+
help="enable duration prediction (for reduced/unique code sequences)",
|
101 |
+
)
|
102 |
+
parser.add_argument(
|
103 |
+
"--speaker-id",
|
104 |
+
type=int,
|
105 |
+
default=-1,
|
106 |
+
help="Speaker id (for vocoder that supports multispeaker). Set to -1 to randomly sample speakers.",
|
107 |
+
)
|
108 |
+
parser.add_argument("--cpu", action="store_true", help="run on CPU")
|
109 |
+
|
110 |
+
args = parser.parse_args()
|
111 |
+
|
112 |
+
main(args)
|
113 |
+
|
114 |
+
|
115 |
+
if __name__ == "__main__":
|
116 |
+
cli_main()
|
fairseq/examples/speech_to_speech/preprocessing/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
fairseq/examples/speech_to_speech/preprocessing/data_utils.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import List, Optional
|
8 |
+
|
9 |
+
from examples.speech_to_text.data_utils import S2TDataConfigWriter
|
10 |
+
|
11 |
+
|
12 |
+
def gen_config_yaml(
|
13 |
+
manifest_root: Path,
|
14 |
+
yaml_filename: str = "config.yaml",
|
15 |
+
specaugment_policy: Optional[str] = "lb",
|
16 |
+
feature_transform: Optional[List[str]] = None,
|
17 |
+
input_channels: Optional[int] = 1,
|
18 |
+
input_feat_per_channel: Optional[int] = 80,
|
19 |
+
audio_root: str = "",
|
20 |
+
vocoder_type: Optional[str] = None,
|
21 |
+
vocoder_checkpoint: Optional[str] = None,
|
22 |
+
vocoder_cfg: Optional[str] = None,
|
23 |
+
extra=None,
|
24 |
+
):
|
25 |
+
manifest_root = manifest_root.absolute()
|
26 |
+
writer = S2TDataConfigWriter(manifest_root / yaml_filename)
|
27 |
+
|
28 |
+
if input_channels is not None:
|
29 |
+
writer.set_input_channels(input_channels)
|
30 |
+
if input_feat_per_channel is not None:
|
31 |
+
writer.set_input_feat_per_channel(input_feat_per_channel)
|
32 |
+
specaugment_setters = {
|
33 |
+
"lb": writer.set_specaugment_lb_policy,
|
34 |
+
"ld": writer.set_specaugment_ld_policy,
|
35 |
+
"sm": writer.set_specaugment_sm_policy,
|
36 |
+
"ss": writer.set_specaugment_ss_policy,
|
37 |
+
}
|
38 |
+
specaugment_setter = specaugment_setters.get(specaugment_policy, None)
|
39 |
+
if specaugment_setter is not None:
|
40 |
+
specaugment_setter()
|
41 |
+
|
42 |
+
if feature_transform is None:
|
43 |
+
feature_transform = []
|
44 |
+
else:
|
45 |
+
writer.set_feature_transforms("*", feature_transform)
|
46 |
+
|
47 |
+
if specaugment_policy is not None:
|
48 |
+
writer.set_feature_transforms("_train", feature_transform + ["specaugment"])
|
49 |
+
|
50 |
+
if len(audio_root) > 0:
|
51 |
+
writer.set_audio_root(audio_root)
|
52 |
+
|
53 |
+
if (
|
54 |
+
vocoder_type is not None
|
55 |
+
and vocoder_checkpoint is not None
|
56 |
+
and vocoder_cfg is not None
|
57 |
+
):
|
58 |
+
writer.set_extra(
|
59 |
+
{
|
60 |
+
"vocoder": {
|
61 |
+
"type": vocoder_type,
|
62 |
+
"config": vocoder_cfg,
|
63 |
+
"checkpoint": vocoder_checkpoint,
|
64 |
+
}
|
65 |
+
}
|
66 |
+
)
|
67 |
+
|
68 |
+
if extra is not None:
|
69 |
+
writer.set_extra(extra)
|
70 |
+
writer.flush()
|
71 |
+
|
72 |
+
|
73 |
+
def load_units(in_file):
|
74 |
+
out = {}
|
75 |
+
with open(in_file) as f:
|
76 |
+
for line in f:
|
77 |
+
sample_id, units = line.strip().split("|", 1)
|
78 |
+
out[sample_id] = units.split()
|
79 |
+
|
80 |
+
return out
|
81 |
+
|
82 |
+
|
83 |
+
def process_units(units, reduce=False):
|
84 |
+
if not reduce:
|
85 |
+
return units
|
86 |
+
|
87 |
+
out = [u for i, u in enumerate(units) if i == 0 or u != units[i - 1]]
|
88 |
+
return out
|
fairseq/examples/speech_to_speech/preprocessing/prep_s2spect_data.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the MIT license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import logging
|
9 |
+
import os
|
10 |
+
from pathlib import Path
|
11 |
+
import shutil
|
12 |
+
import torchaudio
|
13 |
+
|
14 |
+
import soundfile as sf
|
15 |
+
from tqdm import tqdm
|
16 |
+
import pandas as pd
|
17 |
+
|
18 |
+
from examples.speech_synthesis.data_utils import extract_logmel_spectrogram
|
19 |
+
from examples.speech_to_speech.preprocessing.data_utils import gen_config_yaml
|
20 |
+
from examples.speech_to_text.data_utils import create_zip, get_zip_manifest, save_df_to_tsv
|
21 |
+
from fairseq.data.audio.audio_utils import convert_waveform
|
22 |
+
|
23 |
+
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
|
26 |
+
MANIFEST_COLUMNS = ["id", "src_audio", "src_n_frames", "tgt_audio", "tgt_n_frames"]
|
27 |
+
|
28 |
+
|
29 |
+
def prepare_target_data(args, tgt_audios):
|
30 |
+
feature_name = "logmelspec80"
|
31 |
+
zip_path = args.output_root / f"{feature_name}.zip"
|
32 |
+
if zip_path.exists():
|
33 |
+
print(f"{zip_path} exists.")
|
34 |
+
return zip_path
|
35 |
+
|
36 |
+
feature_root = args.output_root / feature_name
|
37 |
+
feature_root.mkdir(exist_ok=True)
|
38 |
+
|
39 |
+
print("Extracting Mel spectrogram features...")
|
40 |
+
for tgt_audio in tqdm(tgt_audios):
|
41 |
+
sample_id = tgt_audio.stem
|
42 |
+
waveform, sample_rate = torchaudio.load(tgt_audio.as_posix())
|
43 |
+
waveform, sample_rate = convert_waveform(
|
44 |
+
waveform, sample_rate, normalize_volume=args.normalize_volume,
|
45 |
+
to_sample_rate=args.sample_rate
|
46 |
+
)
|
47 |
+
extract_logmel_spectrogram(
|
48 |
+
waveform, sample_rate, feature_root / f"{sample_id}.npy",
|
49 |
+
win_length=args.win_length, hop_length=args.hop_length,
|
50 |
+
n_fft=args.n_fft, n_mels=args.n_mels, f_min=args.f_min,
|
51 |
+
f_max=args.f_max
|
52 |
+
)
|
53 |
+
print("ZIPing features...")
|
54 |
+
create_zip(feature_root, zip_path)
|
55 |
+
shutil.rmtree(feature_root)
|
56 |
+
|
57 |
+
return zip_path
|
58 |
+
|
59 |
+
|
60 |
+
def process(args):
|
61 |
+
os.makedirs(args.output_root, exist_ok=True)
|
62 |
+
|
63 |
+
manifest = {}
|
64 |
+
tgt_audios = []
|
65 |
+
for split in args.data_split:
|
66 |
+
print(f"Processing {split}...")
|
67 |
+
|
68 |
+
manifest[split] = {c: [] for c in MANIFEST_COLUMNS}
|
69 |
+
missing_tgt_audios = []
|
70 |
+
src_audios = list(args.source_dir.glob(f"{split}/*.wav"))
|
71 |
+
for src_audio in tqdm(src_audios):
|
72 |
+
sample_id = src_audio.stem
|
73 |
+
|
74 |
+
tgt_audio = args.target_dir / split / f"{sample_id}.wav"
|
75 |
+
if not tgt_audio.is_file():
|
76 |
+
missing_tgt_audios.append(sample_id)
|
77 |
+
continue
|
78 |
+
|
79 |
+
tgt_audios.append(tgt_audio)
|
80 |
+
|
81 |
+
src_n_frames = sf.info(src_audio.as_posix()).frames
|
82 |
+
manifest[split]["id"].append(sample_id)
|
83 |
+
manifest[split]["src_audio"].append(src_audio.as_posix())
|
84 |
+
manifest[split]["src_n_frames"].append(
|
85 |
+
src_n_frames // 160
|
86 |
+
) # estimation of 10-ms frame for 16kHz audio
|
87 |
+
|
88 |
+
print(f"Processed {len(manifest[split]['id'])} samples")
|
89 |
+
if len(missing_tgt_audios) > 0:
|
90 |
+
print(
|
91 |
+
f"{len(missing_tgt_audios)} with missing target data (first 3 examples: {', '.join(missing_tgt_audios[:3])})"
|
92 |
+
)
|
93 |
+
|
94 |
+
# Extract features and pack features into ZIP
|
95 |
+
zip_path = prepare_target_data(args, tgt_audios)
|
96 |
+
|
97 |
+
print("Fetching ZIP manifest...")
|
98 |
+
tgt_audio_paths, tgt_audio_lengths = get_zip_manifest(zip_path)
|
99 |
+
|
100 |
+
print("Generating manifest...")
|
101 |
+
for split in args.data_split:
|
102 |
+
print(f"Processing {split}...")
|
103 |
+
|
104 |
+
for sample_id in tqdm(manifest[split]["id"]):
|
105 |
+
manifest[split]["tgt_audio"].append(tgt_audio_paths[sample_id])
|
106 |
+
manifest[split]["tgt_n_frames"].append(tgt_audio_lengths[sample_id])
|
107 |
+
|
108 |
+
out_manifest = args.output_root / f"{split}.tsv"
|
109 |
+
print(f"Writing manifest to {out_manifest}...")
|
110 |
+
save_df_to_tsv(pd.DataFrame.from_dict(manifest[split]), out_manifest)
|
111 |
+
|
112 |
+
# Generate config YAML
|
113 |
+
win_len_t = args.win_length / args.sample_rate
|
114 |
+
hop_len_t = args.hop_length / args.sample_rate
|
115 |
+
extra = {
|
116 |
+
"features": {
|
117 |
+
"type": "spectrogram+melscale+log",
|
118 |
+
"sample_rate": args.sample_rate,
|
119 |
+
"eps": 1e-5, "n_mels": args.n_mels, "n_fft": args.n_fft,
|
120 |
+
"window_fn": "hann", "win_length": args.win_length,
|
121 |
+
"hop_length": args.hop_length,
|
122 |
+
"win_len_t": win_len_t, "hop_len_t": hop_len_t,
|
123 |
+
"f_min": args.f_min, "f_max": args.f_max,
|
124 |
+
"n_stft": args.n_fft // 2 + 1
|
125 |
+
}
|
126 |
+
}
|
127 |
+
gen_config_yaml(
|
128 |
+
args.output_root,
|
129 |
+
audio_root=args.output_root.as_posix(),
|
130 |
+
specaugment_policy="lb",
|
131 |
+
feature_transform=["utterance_cmvn", "delta_deltas"],
|
132 |
+
extra=extra,
|
133 |
+
)
|
134 |
+
|
135 |
+
|
136 |
+
def main():
|
137 |
+
parser = argparse.ArgumentParser()
|
138 |
+
parser.add_argument(
|
139 |
+
"--source-dir", required=True, type=Path, help="source audio directory"
|
140 |
+
)
|
141 |
+
parser.add_argument(
|
142 |
+
"--target-dir", required=True, type=Path, help="target audio directory"
|
143 |
+
)
|
144 |
+
parser.add_argument(
|
145 |
+
"--data-split",
|
146 |
+
default=["train", "valid", "test"],
|
147 |
+
nargs="+",
|
148 |
+
help="data split names",
|
149 |
+
)
|
150 |
+
parser.add_argument(
|
151 |
+
"--output-root", required=True, type=Path, help="output directory"
|
152 |
+
)
|
153 |
+
# target feature related
|
154 |
+
parser.add_argument("--win-length", type=int, default=1024)
|
155 |
+
parser.add_argument("--hop-length", type=int, default=256)
|
156 |
+
parser.add_argument("--n-fft", type=int, default=1024)
|
157 |
+
parser.add_argument("--n-mels", type=int, default=80)
|
158 |
+
parser.add_argument("--f-min", type=int, default=20)
|
159 |
+
parser.add_argument("--f-max", type=int, default=8000)
|
160 |
+
parser.add_argument("--sample-rate", type=int, default=22050)
|
161 |
+
parser.add_argument("--normalize-volume", "-n", action="store_true")
|
162 |
+
|
163 |
+
args = parser.parse_args()
|
164 |
+
|
165 |
+
process(args)
|
166 |
+
|
167 |
+
|
168 |
+
if __name__ == "__main__":
|
169 |
+
main()
|
fairseq/examples/speech_to_speech/preprocessing/prep_s2ut_data.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the MIT license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import logging
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
import soundfile as sf
|
12 |
+
from tqdm import tqdm
|
13 |
+
import pandas as pd
|
14 |
+
|
15 |
+
from examples.speech_to_speech.preprocessing.data_utils import (
|
16 |
+
gen_config_yaml,
|
17 |
+
load_units,
|
18 |
+
process_units,
|
19 |
+
)
|
20 |
+
from examples.speech_to_text.data_utils import save_df_to_tsv
|
21 |
+
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
MANIFEST_COLUMNS = ["id", "src_audio", "src_n_frames", "tgt_audio", "tgt_n_frames"]
|
25 |
+
|
26 |
+
|
27 |
+
def process(args):
|
28 |
+
args.output_root.mkdir(exist_ok=True)
|
29 |
+
|
30 |
+
print("Generating manifest...")
|
31 |
+
for split in args.data_split:
|
32 |
+
print(f"Processing {split}")
|
33 |
+
|
34 |
+
# load target units
|
35 |
+
target_unit_data = load_units(args.target_dir / f"{split}.txt")
|
36 |
+
|
37 |
+
manifest = {c: [] for c in MANIFEST_COLUMNS}
|
38 |
+
missing_tgt_audios = []
|
39 |
+
src_audios = list(args.source_dir.glob(f"{split}/*.wav"))
|
40 |
+
for src_audio in tqdm(src_audios):
|
41 |
+
sample_id = src_audio.stem
|
42 |
+
|
43 |
+
if sample_id not in target_unit_data:
|
44 |
+
missing_tgt_audios.append(sample_id)
|
45 |
+
continue
|
46 |
+
|
47 |
+
src_n_frames = sf.info(src_audio.as_posix()).frames
|
48 |
+
manifest["id"].append(sample_id)
|
49 |
+
manifest["src_audio"].append(src_audio.as_posix())
|
50 |
+
manifest["src_n_frames"].append(
|
51 |
+
src_n_frames // 160
|
52 |
+
) # estimation of 10-ms frame for 16kHz audio
|
53 |
+
|
54 |
+
target_units = process_units(target_unit_data[sample_id], args.reduce_unit)
|
55 |
+
manifest["tgt_audio"].append(" ".join(target_units))
|
56 |
+
manifest["tgt_n_frames"].append(len(target_units))
|
57 |
+
|
58 |
+
print(f"Processed {len(manifest['id'])} samples")
|
59 |
+
if len(missing_tgt_audios) > 0:
|
60 |
+
print(
|
61 |
+
f"{len(missing_tgt_audios)} with missing target data (first 3 examples: {', '.join(missing_tgt_audios[:3])})"
|
62 |
+
)
|
63 |
+
|
64 |
+
out_manifest = args.output_root / f"{split}.tsv"
|
65 |
+
print(f"Writing manifest to {out_manifest}...")
|
66 |
+
save_df_to_tsv(pd.DataFrame.from_dict(manifest), out_manifest)
|
67 |
+
|
68 |
+
# Generate config YAML
|
69 |
+
gen_config_yaml(
|
70 |
+
args.output_root,
|
71 |
+
specaugment_policy="lb",
|
72 |
+
feature_transform=["utterance_cmvn"],
|
73 |
+
vocoder_type="code_hifigan",
|
74 |
+
vocoder_checkpoint=args.vocoder_checkpoint,
|
75 |
+
vocoder_cfg=args.vocoder_cfg,
|
76 |
+
)
|
77 |
+
|
78 |
+
|
79 |
+
def main():
|
80 |
+
parser = argparse.ArgumentParser()
|
81 |
+
parser.add_argument(
|
82 |
+
"--source-dir", required=True, type=Path, help="source audio directory"
|
83 |
+
)
|
84 |
+
parser.add_argument(
|
85 |
+
"--target-dir", required=True, type=Path, help="target audio directory"
|
86 |
+
)
|
87 |
+
parser.add_argument(
|
88 |
+
"--data-split",
|
89 |
+
default=["train", "valid", "test"],
|
90 |
+
nargs="+",
|
91 |
+
help="data split names",
|
92 |
+
)
|
93 |
+
parser.add_argument(
|
94 |
+
"--output-root", required=True, type=Path, help="output directory"
|
95 |
+
)
|
96 |
+
parser.add_argument(
|
97 |
+
"--reduce-unit",
|
98 |
+
action="store_true",
|
99 |
+
help="reduce a target unit sequence to a unique unit sequence, i.e. '1 1 1 2 2' -> '1 2'",
|
100 |
+
)
|
101 |
+
parser.add_argument(
|
102 |
+
"--vocoder-checkpoint", default=None, type=str, help="vocoder checkpoint"
|
103 |
+
)
|
104 |
+
parser.add_argument(
|
105 |
+
"--vocoder-cfg", default=None, type=str, help="vocoder config file"
|
106 |
+
)
|
107 |
+
|
108 |
+
args = parser.parse_args()
|
109 |
+
|
110 |
+
process(args)
|
111 |
+
|
112 |
+
|
113 |
+
if __name__ == "__main__":
|
114 |
+
main()
|
fairseq/examples/speech_to_speech/preprocessing/prep_sn_data.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the MIT license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
#
|
7 |
+
# Adapted from examples/wav2vec/wav2vec_manifest.py
|
8 |
+
"""
|
9 |
+
Data preparation for the speech normalizer
|
10 |
+
"""
|
11 |
+
|
12 |
+
import argparse
|
13 |
+
import glob
|
14 |
+
import os
|
15 |
+
|
16 |
+
import soundfile
|
17 |
+
|
18 |
+
from examples.speech_to_speech.preprocessing.data_utils import load_units, process_units
|
19 |
+
|
20 |
+
|
21 |
+
def process(args):
|
22 |
+
assert (
|
23 |
+
args.for_inference or args.target_unit is not None
|
24 |
+
), "missing --target-unit or --for-inference"
|
25 |
+
|
26 |
+
if not os.path.exists(args.output_dir):
|
27 |
+
os.makedirs(args.output_dir)
|
28 |
+
|
29 |
+
dir_path = os.path.realpath(args.audio_dir)
|
30 |
+
search_path = os.path.join(dir_path, "**/*." + args.ext)
|
31 |
+
|
32 |
+
if args.target_unit:
|
33 |
+
unit_data = load_units(args.target_unit)
|
34 |
+
|
35 |
+
with open(os.path.join(args.output_dir, f"{args.data_name}.tsv"), "w") as o_t, open(
|
36 |
+
os.path.join(args.output_dir, f"{args.data_name}.unit"), "w"
|
37 |
+
) as o_u:
|
38 |
+
print(dir_path, file=o_t)
|
39 |
+
for fname in glob.iglob(search_path, recursive=True):
|
40 |
+
file_path = os.path.realpath(fname)
|
41 |
+
frames = soundfile.info(fname).frames
|
42 |
+
print(
|
43 |
+
"{}\t{}".format(os.path.relpath(file_path, dir_path), frames), file=o_t
|
44 |
+
)
|
45 |
+
|
46 |
+
if args.for_inference:
|
47 |
+
print("0", file=o_u)
|
48 |
+
else:
|
49 |
+
sample_id = os.path.basename(file_path)[: -len(args.ext) - 1]
|
50 |
+
assert (
|
51 |
+
sample_id in unit_data
|
52 |
+
), f'{fname} does not have unit data in {args.target_unit}. Expecting sample_id "{sample_id}".'
|
53 |
+
target_units = process_units(unit_data[sample_id], reduce=True)
|
54 |
+
print(" ".join(target_units), file=o_u)
|
55 |
+
|
56 |
+
|
57 |
+
def main():
|
58 |
+
parser = argparse.ArgumentParser()
|
59 |
+
parser.add_argument("--audio-dir", required=True, type=str, help="audio directory")
|
60 |
+
parser.add_argument("--ext", default="flac", type=str, help="audio extension")
|
61 |
+
parser.add_argument(
|
62 |
+
"--data-name",
|
63 |
+
required=True,
|
64 |
+
type=str,
|
65 |
+
help="dataset name",
|
66 |
+
)
|
67 |
+
parser.add_argument(
|
68 |
+
"--output-dir", required=True, type=str, help="output directory"
|
69 |
+
)
|
70 |
+
parser.add_argument(
|
71 |
+
"--for-inference",
|
72 |
+
action="store_true",
|
73 |
+
help="set this if preparing data for running inference with a speech normalizer",
|
74 |
+
)
|
75 |
+
parser.add_argument(
|
76 |
+
"--target-unit",
|
77 |
+
default=None,
|
78 |
+
type=str,
|
79 |
+
help="a file containing unit sequences in the format: sample_id|u1 u2 ...",
|
80 |
+
)
|
81 |
+
|
82 |
+
args = parser.parse_args()
|
83 |
+
|
84 |
+
process(args)
|
85 |
+
|
86 |
+
|
87 |
+
if __name__ == "__main__":
|
88 |
+
main()
|
fairseq/examples/speech_to_speech/preprocessing/prep_sn_output_data.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the MIT license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
|
13 |
+
def process(args):
|
14 |
+
args.output_root.mkdir(exist_ok=True)
|
15 |
+
|
16 |
+
# load units
|
17 |
+
units = {}
|
18 |
+
with open(args.in_unit) as f:
|
19 |
+
for line in f:
|
20 |
+
unit_seq, utt_id = line.strip().rsplit(" ", 1)
|
21 |
+
utt_id = int(utt_id[6:-1]) # remove "(None-"
|
22 |
+
units[utt_id] = unit_seq
|
23 |
+
|
24 |
+
with open(args.in_audio) as f, open(
|
25 |
+
args.output_root / f"{args.in_audio.stem}.txt", "w"
|
26 |
+
) as o:
|
27 |
+
f.readline()
|
28 |
+
for i, line in enumerate(tqdm(f.readlines())):
|
29 |
+
audio, _ = line.strip().split("\t", 1)
|
30 |
+
sample_id = Path(audio).stem
|
31 |
+
o.write(f"{sample_id}|{units[i]}\n")
|
32 |
+
|
33 |
+
|
34 |
+
def main():
|
35 |
+
parser = argparse.ArgumentParser()
|
36 |
+
parser.add_argument(
|
37 |
+
"--in-unit",
|
38 |
+
required=True,
|
39 |
+
type=Path,
|
40 |
+
help="unit file (output from the speech normalizer)",
|
41 |
+
)
|
42 |
+
parser.add_argument(
|
43 |
+
"--in-audio",
|
44 |
+
required=True,
|
45 |
+
type=Path,
|
46 |
+
help="tsv file (input to the normalizer)",
|
47 |
+
)
|
48 |
+
parser.add_argument(
|
49 |
+
"--output-root", required=True, type=Path, help="output directory"
|
50 |
+
)
|
51 |
+
|
52 |
+
args = parser.parse_args()
|
53 |
+
|
54 |
+
process(args)
|
55 |
+
|
56 |
+
|
57 |
+
if __name__ == "__main__":
|
58 |
+
main()
|
fairseq/examples/speech_to_speech/unity/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from . import sequence_generator # noqa
|
7 |
+
from . import sequence_generator_multi_decoder # noqa
|
fairseq/examples/speech_to_speech/unity/sequence_generator.py
ADDED
@@ -0,0 +1,626 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
import sys
|
8 |
+
from typing import Dict, List, Optional
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch import Tensor
|
12 |
+
|
13 |
+
from fairseq.sequence_generator import EnsembleModel as EnsembleModelBase
|
14 |
+
from fairseq.sequence_generator import SequenceGenerator as SequenceGeneratorBase
|
15 |
+
|
16 |
+
|
17 |
+
class SequenceGenerator(SequenceGeneratorBase):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
models,
|
21 |
+
tgt_dict,
|
22 |
+
beam_size=1,
|
23 |
+
max_len_a=0,
|
24 |
+
max_len_b=200,
|
25 |
+
max_len=0,
|
26 |
+
min_len=1,
|
27 |
+
normalize_scores=True,
|
28 |
+
len_penalty=1.0,
|
29 |
+
unk_penalty=0.0,
|
30 |
+
temperature=1.0,
|
31 |
+
match_source_len=False,
|
32 |
+
no_repeat_ngram_size=0,
|
33 |
+
search_strategy=None,
|
34 |
+
eos=None,
|
35 |
+
symbols_to_strip_from_output=None,
|
36 |
+
lm_model=None,
|
37 |
+
lm_weight=1.0,
|
38 |
+
tokens_to_suppress=(),
|
39 |
+
):
|
40 |
+
"""Generates translations of a given source sentence.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
models (List[~fairseq.models.FairseqModel]): ensemble of models,
|
44 |
+
currently support fairseq.models.TransformerModel for scripting
|
45 |
+
beam_size (int, optional): beam width (default: 1)
|
46 |
+
max_len_a/b (int, optional): generate sequences of maximum length
|
47 |
+
ax + b, where x is the source length
|
48 |
+
max_len (int, optional): the maximum length of the generated output
|
49 |
+
(not including end-of-sentence)
|
50 |
+
min_len (int, optional): the minimum length of the generated output
|
51 |
+
(not including end-of-sentence)
|
52 |
+
normalize_scores (bool, optional): normalize scores by the length
|
53 |
+
of the output (default: True)
|
54 |
+
len_penalty (float, optional): length penalty, where <1.0 favors
|
55 |
+
shorter, >1.0 favors longer sentences (default: 1.0)
|
56 |
+
unk_penalty (float, optional): unknown word penalty, where <0
|
57 |
+
produces more unks, >0 produces fewer (default: 0.0)
|
58 |
+
temperature (float, optional): temperature, where values
|
59 |
+
>1.0 produce more uniform samples and values <1.0 produce
|
60 |
+
sharper samples (default: 1.0)
|
61 |
+
match_source_len (bool, optional): outputs should match the source
|
62 |
+
length (default: False)
|
63 |
+
"""
|
64 |
+
super().__init__(
|
65 |
+
models=models,
|
66 |
+
tgt_dict=tgt_dict,
|
67 |
+
beam_size=beam_size,
|
68 |
+
max_len_a=max_len_a,
|
69 |
+
max_len_b=max_len_b,
|
70 |
+
max_len=max_len,
|
71 |
+
min_len=min_len,
|
72 |
+
normalize_scores=normalize_scores,
|
73 |
+
len_penalty=len_penalty,
|
74 |
+
unk_penalty=unk_penalty,
|
75 |
+
temperature=temperature,
|
76 |
+
match_source_len=match_source_len,
|
77 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
78 |
+
search_strategy=search_strategy,
|
79 |
+
eos=eos,
|
80 |
+
symbols_to_strip_from_output=symbols_to_strip_from_output,
|
81 |
+
lm_model=lm_model,
|
82 |
+
lm_weight=lm_weight,
|
83 |
+
tokens_to_suppress=tokens_to_suppress,
|
84 |
+
)
|
85 |
+
|
86 |
+
if isinstance(models, EnsembleModel):
|
87 |
+
self.model = models
|
88 |
+
else:
|
89 |
+
self.model = EnsembleModel(models)
|
90 |
+
|
91 |
+
self.model.set_decoder_beam_size(self.beam_size)
|
92 |
+
self.model.eval()
|
93 |
+
|
94 |
+
def _generate(
|
95 |
+
self,
|
96 |
+
sample: Dict[str, Dict[str, Tensor]],
|
97 |
+
prefix_tokens: Optional[Tensor] = None,
|
98 |
+
constraints: Optional[Tensor] = None,
|
99 |
+
bos_token: Optional[int] = None,
|
100 |
+
):
|
101 |
+
net_input = sample["net_input"]
|
102 |
+
|
103 |
+
if "src_tokens" in net_input:
|
104 |
+
src_tokens = net_input["src_tokens"]
|
105 |
+
# length of the source text being the character length except EndOfSentence and pad
|
106 |
+
# if src_lengths exists in net_input (speech_to_text dataset case), then use it
|
107 |
+
if "src_lengths" in net_input:
|
108 |
+
src_lengths = net_input["src_lengths"]
|
109 |
+
else:
|
110 |
+
src_lengths = (
|
111 |
+
(src_tokens.ne(self.eos) & src_tokens.ne(self.pad))
|
112 |
+
.long()
|
113 |
+
.sum(dim=1)
|
114 |
+
)
|
115 |
+
elif "source" in net_input:
|
116 |
+
src_tokens = net_input["source"]
|
117 |
+
src_lengths = (
|
118 |
+
net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
|
119 |
+
if net_input["padding_mask"] is not None
|
120 |
+
else torch.tensor(src_tokens.size(-1)).to(src_tokens)
|
121 |
+
)
|
122 |
+
elif "features" in net_input:
|
123 |
+
src_tokens = net_input["features"]
|
124 |
+
src_lengths = (
|
125 |
+
net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
|
126 |
+
if net_input["padding_mask"] is not None
|
127 |
+
else torch.tensor(src_tokens.size(-1)).to(src_tokens)
|
128 |
+
)
|
129 |
+
else:
|
130 |
+
raise Exception(
|
131 |
+
"expected src_tokens or source in net input. input keys: "
|
132 |
+
+ str(net_input.keys())
|
133 |
+
)
|
134 |
+
|
135 |
+
if constraints is not None and not self.search.supports_constraints:
|
136 |
+
raise NotImplementedError(
|
137 |
+
"Target-side constraints were provided, but search method doesn't support them"
|
138 |
+
)
|
139 |
+
|
140 |
+
# Initialize constraints, when active
|
141 |
+
self.search.init_constraints(constraints, self.beam_size)
|
142 |
+
|
143 |
+
# compute the encoder output for each beam
|
144 |
+
with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"):
|
145 |
+
encoder_outs = self.model.forward_encoder(net_input)
|
146 |
+
|
147 |
+
finalized = self.generate_decoder(
|
148 |
+
encoder_outs,
|
149 |
+
src_tokens,
|
150 |
+
src_lengths,
|
151 |
+
sample,
|
152 |
+
prefix_tokens,
|
153 |
+
constraints,
|
154 |
+
bos_token,
|
155 |
+
)
|
156 |
+
return finalized
|
157 |
+
|
158 |
+
def generate_decoder(
|
159 |
+
self,
|
160 |
+
encoder_outs,
|
161 |
+
src_tokens,
|
162 |
+
src_lengths,
|
163 |
+
sample: Dict[str, Dict[str, Tensor]],
|
164 |
+
prefix_tokens: Optional[Tensor] = None,
|
165 |
+
constraints: Optional[Tensor] = None,
|
166 |
+
bos_token: Optional[int] = None,
|
167 |
+
aux_task_name="",
|
168 |
+
encoder_outs_aug: Optional[
|
169 |
+
Tensor
|
170 |
+
] = None, # an additional/augmented encoder_outs
|
171 |
+
):
|
172 |
+
incremental_states = torch.jit.annotate(
|
173 |
+
List[Dict[str, Dict[str, Optional[Tensor]]]],
|
174 |
+
[
|
175 |
+
torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
|
176 |
+
for i in range(self.model.models_size)
|
177 |
+
],
|
178 |
+
)
|
179 |
+
|
180 |
+
# bsz: total number of sentences in beam
|
181 |
+
# Note that src_tokens may have more than 2 dimensions (i.e. audio features)
|
182 |
+
bsz, src_len = src_tokens.size()[:2]
|
183 |
+
beam_size = self.beam_size
|
184 |
+
|
185 |
+
decoder_name = f"{aux_task_name}_decoder" if aux_task_name else "decoder"
|
186 |
+
|
187 |
+
max_len: int = -1
|
188 |
+
if self.match_source_len:
|
189 |
+
max_len = src_lengths.max().item()
|
190 |
+
else:
|
191 |
+
max_len = min(
|
192 |
+
int(self.max_len_a * src_len + self.max_len_b),
|
193 |
+
self.max_len - 1,
|
194 |
+
)
|
195 |
+
assert (
|
196 |
+
self.min_len <= max_len
|
197 |
+
), "min_len cannot be larger than max_len, please adjust these!"
|
198 |
+
|
199 |
+
# placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
|
200 |
+
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
|
201 |
+
new_order = new_order.to(src_tokens.device).long()
|
202 |
+
encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order)
|
203 |
+
# ensure encoder_outs is a List.
|
204 |
+
assert encoder_outs is not None
|
205 |
+
if encoder_outs_aug is not None:
|
206 |
+
encoder_outs_aug = self.model.reorder_encoder_out(
|
207 |
+
encoder_outs_aug, new_order
|
208 |
+
)
|
209 |
+
|
210 |
+
# initialize buffers
|
211 |
+
scores = (
|
212 |
+
torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
|
213 |
+
) # +1 for eos; pad is never chosen for scoring
|
214 |
+
tokens = (
|
215 |
+
torch.zeros(bsz * beam_size, max_len + 2)
|
216 |
+
.to(src_tokens)
|
217 |
+
.long()
|
218 |
+
.fill_(self.pad)
|
219 |
+
) # +2 for eos and pad
|
220 |
+
tokens[:, 0] = self.eos if bos_token is None else bos_token
|
221 |
+
attn: Optional[Tensor] = None
|
222 |
+
|
223 |
+
# A list that indicates candidates that should be ignored.
|
224 |
+
# For example, suppose we're sampling and have already finalized 2/5
|
225 |
+
# samples. Then cands_to_ignore would mark 2 positions as being ignored,
|
226 |
+
# so that we only finalize the remaining 3 samples.
|
227 |
+
cands_to_ignore = (
|
228 |
+
torch.zeros(bsz, beam_size).to(src_tokens).eq(-1)
|
229 |
+
) # forward and backward-compatible False mask
|
230 |
+
|
231 |
+
# list of completed sentences
|
232 |
+
finalized = torch.jit.annotate(
|
233 |
+
List[List[Dict[str, Tensor]]],
|
234 |
+
[torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
|
235 |
+
) # contains lists of dictionaries of infomation about the hypothesis being finalized at each step
|
236 |
+
|
237 |
+
# a boolean array indicating if the sentence at the index is finished or not
|
238 |
+
finished = [False for i in range(bsz)]
|
239 |
+
num_remaining_sent = bsz # number of sentences remaining
|
240 |
+
|
241 |
+
# number of candidate hypos per step
|
242 |
+
cand_size = 2 * beam_size # 2 x beam size in case half are EOS
|
243 |
+
|
244 |
+
# offset arrays for converting between different indexing schemes
|
245 |
+
bbsz_offsets = (
|
246 |
+
(torch.arange(0, bsz) * beam_size)
|
247 |
+
.unsqueeze(1)
|
248 |
+
.type_as(tokens)
|
249 |
+
.to(src_tokens.device)
|
250 |
+
)
|
251 |
+
cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device)
|
252 |
+
|
253 |
+
reorder_state: Optional[Tensor] = None
|
254 |
+
batch_idxs: Optional[Tensor] = None
|
255 |
+
|
256 |
+
original_batch_idxs: Optional[Tensor] = None
|
257 |
+
if "id" in sample and isinstance(sample["id"], Tensor):
|
258 |
+
original_batch_idxs = sample["id"]
|
259 |
+
else:
|
260 |
+
original_batch_idxs = torch.arange(0, bsz).type_as(tokens)
|
261 |
+
|
262 |
+
for step in range(max_len + 1): # one extra step for EOS marker
|
263 |
+
# reorder decoder internal states based on the prev choice of beams
|
264 |
+
if reorder_state is not None:
|
265 |
+
if batch_idxs is not None:
|
266 |
+
# update beam indices to take into account removed sentences
|
267 |
+
corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(
|
268 |
+
batch_idxs
|
269 |
+
)
|
270 |
+
reorder_state.view(-1, beam_size).add_(
|
271 |
+
corr.unsqueeze(-1) * beam_size
|
272 |
+
)
|
273 |
+
original_batch_idxs = original_batch_idxs[batch_idxs]
|
274 |
+
self.model.reorder_incremental_state(
|
275 |
+
incremental_states, reorder_state, decoder_name
|
276 |
+
)
|
277 |
+
encoder_outs = self.model.reorder_encoder_out(
|
278 |
+
encoder_outs, reorder_state
|
279 |
+
)
|
280 |
+
if encoder_outs_aug is not None:
|
281 |
+
encoder_outs_aug = self.model.reorder_encoder_out(
|
282 |
+
encoder_outs_aug, reorder_state
|
283 |
+
)
|
284 |
+
with torch.autograd.profiler.record_function(
|
285 |
+
"EnsembleModel: forward_decoder"
|
286 |
+
):
|
287 |
+
lprobs, avg_attn_scores = self.model.forward_decoder(
|
288 |
+
tokens[:, : step + 1],
|
289 |
+
encoder_outs,
|
290 |
+
incremental_states,
|
291 |
+
self.temperature,
|
292 |
+
decoder_name=decoder_name,
|
293 |
+
encoder_outs_aug=encoder_outs_aug,
|
294 |
+
)
|
295 |
+
|
296 |
+
if self.lm_model is not None and not aux_task_name:
|
297 |
+
lm_out = self.lm_model(tokens[:, : step + 1])
|
298 |
+
probs = self.lm_model.get_normalized_probs(
|
299 |
+
lm_out, log_probs=True, sample=None
|
300 |
+
)
|
301 |
+
probs = probs[:, -1, :] * self.lm_weight
|
302 |
+
lprobs += probs
|
303 |
+
|
304 |
+
lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
|
305 |
+
|
306 |
+
lprobs[:, self.pad] = -math.inf # never select pad
|
307 |
+
lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
|
308 |
+
|
309 |
+
# handle max length constraint
|
310 |
+
if step >= max_len:
|
311 |
+
lprobs[:, : self.eos] = -math.inf
|
312 |
+
lprobs[:, self.eos + 1 :] = -math.inf
|
313 |
+
|
314 |
+
# handle prefix tokens (possibly with different lengths)
|
315 |
+
if (
|
316 |
+
prefix_tokens is not None
|
317 |
+
and step < prefix_tokens.size(1)
|
318 |
+
and step < max_len
|
319 |
+
):
|
320 |
+
lprobs, tokens, scores = self._prefix_tokens(
|
321 |
+
step, lprobs, scores, tokens, prefix_tokens, beam_size
|
322 |
+
)
|
323 |
+
else:
|
324 |
+
if step < self.min_len:
|
325 |
+
# minimum length constraint (does not apply if using prefix_tokens)
|
326 |
+
lprobs[:, self.eos] = -math.inf
|
327 |
+
|
328 |
+
if self.token_indices_to_suppress is not None:
|
329 |
+
lprobs[:, self.token_indices_to_suppress] = -math.inf
|
330 |
+
|
331 |
+
# Record attention scores, only support avg_attn_scores is a Tensor
|
332 |
+
if avg_attn_scores is not None:
|
333 |
+
if attn is None:
|
334 |
+
attn = torch.empty(
|
335 |
+
bsz * beam_size, avg_attn_scores.size(1), max_len + 2
|
336 |
+
).to(scores)
|
337 |
+
attn[:, :, step + 1].copy_(avg_attn_scores)
|
338 |
+
|
339 |
+
scores = scores.type_as(lprobs)
|
340 |
+
eos_bbsz_idx = torch.empty(0).to(
|
341 |
+
tokens
|
342 |
+
) # indices of hypothesis ending with eos (finished sentences)
|
343 |
+
eos_scores = torch.empty(0).to(
|
344 |
+
scores
|
345 |
+
) # scores of hypothesis ending with eos (finished sentences)
|
346 |
+
|
347 |
+
if self.should_set_src_lengths:
|
348 |
+
self.search.set_src_lengths(src_lengths)
|
349 |
+
|
350 |
+
if self.repeat_ngram_blocker is not None:
|
351 |
+
lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step)
|
352 |
+
|
353 |
+
# Shape: (batch, cand_size)
|
354 |
+
cand_scores, cand_indices, cand_beams = self.search.step(
|
355 |
+
step,
|
356 |
+
lprobs.view(bsz, -1, self.vocab_size),
|
357 |
+
scores.view(bsz, beam_size, -1)[:, :, :step],
|
358 |
+
tokens[:, : step + 1],
|
359 |
+
original_batch_idxs,
|
360 |
+
)
|
361 |
+
|
362 |
+
# cand_bbsz_idx contains beam indices for the top candidate
|
363 |
+
# hypotheses, with a range of values: [0, bsz*beam_size),
|
364 |
+
# and dimensions: [bsz, cand_size]
|
365 |
+
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
|
366 |
+
|
367 |
+
# finalize hypotheses that end in eos
|
368 |
+
# Shape of eos_mask: (batch size, beam size)
|
369 |
+
eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
|
370 |
+
eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)
|
371 |
+
|
372 |
+
# only consider eos when it's among the top beam_size indices
|
373 |
+
# Now we know what beam item(s) to finish
|
374 |
+
# Shape: 1d list of absolute-numbered
|
375 |
+
eos_bbsz_idx = torch.masked_select(
|
376 |
+
cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
|
377 |
+
)
|
378 |
+
|
379 |
+
finalized_sents: List[int] = []
|
380 |
+
if eos_bbsz_idx.numel() > 0:
|
381 |
+
eos_scores = torch.masked_select(
|
382 |
+
cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
|
383 |
+
)
|
384 |
+
|
385 |
+
finalized_sents = self.finalize_hypos(
|
386 |
+
step,
|
387 |
+
eos_bbsz_idx,
|
388 |
+
eos_scores,
|
389 |
+
tokens,
|
390 |
+
scores,
|
391 |
+
finalized,
|
392 |
+
finished,
|
393 |
+
beam_size,
|
394 |
+
attn,
|
395 |
+
src_lengths,
|
396 |
+
max_len,
|
397 |
+
)
|
398 |
+
num_remaining_sent -= len(finalized_sents)
|
399 |
+
|
400 |
+
assert num_remaining_sent >= 0
|
401 |
+
if num_remaining_sent == 0:
|
402 |
+
break
|
403 |
+
if self.search.stop_on_max_len and step >= max_len:
|
404 |
+
break
|
405 |
+
assert step < max_len, f"{step} < {max_len}"
|
406 |
+
|
407 |
+
# Remove finalized sentences (ones for which {beam_size}
|
408 |
+
# finished hypotheses have been generated) from the batch.
|
409 |
+
if len(finalized_sents) > 0:
|
410 |
+
new_bsz = bsz - len(finalized_sents)
|
411 |
+
|
412 |
+
# construct batch_idxs which holds indices of batches to keep for the next pass
|
413 |
+
batch_mask = torch.ones(
|
414 |
+
bsz, dtype=torch.bool, device=cand_indices.device
|
415 |
+
)
|
416 |
+
batch_mask[finalized_sents] = False
|
417 |
+
# TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
|
418 |
+
batch_idxs = torch.arange(
|
419 |
+
bsz, device=cand_indices.device
|
420 |
+
).masked_select(batch_mask)
|
421 |
+
|
422 |
+
# Choose the subset of the hypothesized constraints that will continue
|
423 |
+
self.search.prune_sentences(batch_idxs)
|
424 |
+
|
425 |
+
eos_mask = eos_mask[batch_idxs]
|
426 |
+
cand_beams = cand_beams[batch_idxs]
|
427 |
+
bbsz_offsets.resize_(new_bsz, 1)
|
428 |
+
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
|
429 |
+
cand_scores = cand_scores[batch_idxs]
|
430 |
+
cand_indices = cand_indices[batch_idxs]
|
431 |
+
|
432 |
+
if prefix_tokens is not None:
|
433 |
+
prefix_tokens = prefix_tokens[batch_idxs]
|
434 |
+
src_lengths = src_lengths[batch_idxs]
|
435 |
+
cands_to_ignore = cands_to_ignore[batch_idxs]
|
436 |
+
|
437 |
+
scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
|
438 |
+
tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
|
439 |
+
if attn is not None:
|
440 |
+
attn = attn.view(bsz, -1)[batch_idxs].view(
|
441 |
+
new_bsz * beam_size, attn.size(1), -1
|
442 |
+
)
|
443 |
+
bsz = new_bsz
|
444 |
+
else:
|
445 |
+
batch_idxs = None
|
446 |
+
|
447 |
+
# Set active_mask so that values > cand_size indicate eos hypos
|
448 |
+
# and values < cand_size indicate candidate active hypos.
|
449 |
+
# After, the min values per row are the top candidate active hypos
|
450 |
+
|
451 |
+
# Rewrite the operator since the element wise or is not supported in torchscript.
|
452 |
+
|
453 |
+
eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
|
454 |
+
active_mask = torch.add(
|
455 |
+
eos_mask.type_as(cand_offsets) * cand_size,
|
456 |
+
cand_offsets[: eos_mask.size(1)],
|
457 |
+
)
|
458 |
+
|
459 |
+
# get the top beam_size active hypotheses, which are just
|
460 |
+
# the hypos with the smallest values in active_mask.
|
461 |
+
# {active_hypos} indicates which {beam_size} hypotheses
|
462 |
+
# from the list of {2 * beam_size} candidates were
|
463 |
+
# selected. Shapes: (batch size, beam size)
|
464 |
+
new_cands_to_ignore, active_hypos = torch.topk(
|
465 |
+
active_mask, k=beam_size, dim=1, largest=False
|
466 |
+
)
|
467 |
+
|
468 |
+
# update cands_to_ignore to ignore any finalized hypos.
|
469 |
+
cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
|
470 |
+
# Make sure there is at least one active item for each sentence in the batch.
|
471 |
+
assert (~cands_to_ignore).any(dim=1).all()
|
472 |
+
|
473 |
+
# update cands_to_ignore to ignore any finalized hypos
|
474 |
+
|
475 |
+
# {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
|
476 |
+
# can be selected more than once).
|
477 |
+
active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
|
478 |
+
active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)
|
479 |
+
|
480 |
+
active_bbsz_idx = active_bbsz_idx.view(-1)
|
481 |
+
active_scores = active_scores.view(-1)
|
482 |
+
|
483 |
+
# copy tokens and scores for active hypotheses
|
484 |
+
|
485 |
+
# Set the tokens for each beam (can select the same row more than once)
|
486 |
+
tokens[:, : step + 1] = torch.index_select(
|
487 |
+
tokens[:, : step + 1], dim=0, index=active_bbsz_idx
|
488 |
+
)
|
489 |
+
# Select the next token for each of them
|
490 |
+
tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
|
491 |
+
cand_indices, dim=1, index=active_hypos
|
492 |
+
)
|
493 |
+
if step > 0:
|
494 |
+
scores[:, :step] = torch.index_select(
|
495 |
+
scores[:, :step], dim=0, index=active_bbsz_idx
|
496 |
+
)
|
497 |
+
scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
|
498 |
+
cand_scores, dim=1, index=active_hypos
|
499 |
+
)
|
500 |
+
|
501 |
+
# Update constraints based on which candidates were selected for the next beam
|
502 |
+
self.search.update_constraints(active_hypos)
|
503 |
+
|
504 |
+
# copy attention for active hypotheses
|
505 |
+
if attn is not None:
|
506 |
+
attn[:, :, : step + 2] = torch.index_select(
|
507 |
+
attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
|
508 |
+
)
|
509 |
+
|
510 |
+
# reorder incremental state in decoder
|
511 |
+
reorder_state = active_bbsz_idx
|
512 |
+
|
513 |
+
# sort by score descending
|
514 |
+
for sent in range(len(finalized)):
|
515 |
+
scores = torch.tensor(
|
516 |
+
[float(elem["score"].item()) for elem in finalized[sent]]
|
517 |
+
)
|
518 |
+
_, sorted_scores_indices = torch.sort(scores, descending=True)
|
519 |
+
finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
|
520 |
+
finalized[sent] = torch.jit.annotate(
|
521 |
+
List[Dict[str, Tensor]], finalized[sent]
|
522 |
+
)
|
523 |
+
return finalized
|
524 |
+
|
525 |
+
|
526 |
+
class EnsembleModel(EnsembleModelBase):
|
527 |
+
"""A wrapper around an ensemble of models."""
|
528 |
+
|
529 |
+
def __init__(self, models):
|
530 |
+
super().__init__(models)
|
531 |
+
|
532 |
+
@torch.jit.export
|
533 |
+
def forward_decoder(
|
534 |
+
self,
|
535 |
+
tokens,
|
536 |
+
encoder_outs: List[Dict[str, List[Tensor]]],
|
537 |
+
incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
|
538 |
+
temperature: float = 1.0,
|
539 |
+
decoder_name="decoder",
|
540 |
+
encoder_outs_aug: List[Dict[str, List[Tensor]]] = None,
|
541 |
+
):
|
542 |
+
log_probs = []
|
543 |
+
avg_attn: Optional[Tensor] = None
|
544 |
+
encoder_out: Optional[Dict[str, List[Tensor]]] = None
|
545 |
+
encoder_out_aug: Optional[Dict[str, List[Tensor]]] = None
|
546 |
+
for i, model in enumerate(self.models):
|
547 |
+
if self.has_encoder():
|
548 |
+
encoder_out = encoder_outs[i]
|
549 |
+
if encoder_outs_aug is not None:
|
550 |
+
encoder_out_aug = encoder_outs_aug[i]
|
551 |
+
# decode each model
|
552 |
+
if self.has_incremental_states():
|
553 |
+
if encoder_out_aug is not None:
|
554 |
+
decoder_out = getattr(model, decoder_name).forward(
|
555 |
+
tokens,
|
556 |
+
encoder_out=encoder_out,
|
557 |
+
encoder_out_aug=encoder_out_aug,
|
558 |
+
incremental_state=incremental_states[i],
|
559 |
+
)
|
560 |
+
else:
|
561 |
+
decoder_out = getattr(model, decoder_name).forward(
|
562 |
+
tokens,
|
563 |
+
encoder_out=encoder_out,
|
564 |
+
incremental_state=incremental_states[i],
|
565 |
+
)
|
566 |
+
else:
|
567 |
+
if hasattr(model, decoder_name):
|
568 |
+
decoder_out = getattr(model, decoder_name).forward(
|
569 |
+
tokens, encoder_out=encoder_out
|
570 |
+
)
|
571 |
+
else:
|
572 |
+
decoder_out = model.forward(tokens)
|
573 |
+
|
574 |
+
attn: Optional[Tensor] = None
|
575 |
+
decoder_len = len(decoder_out)
|
576 |
+
if decoder_len > 1 and decoder_out[1] is not None:
|
577 |
+
if isinstance(decoder_out[1], Tensor):
|
578 |
+
attn = decoder_out[1]
|
579 |
+
else:
|
580 |
+
attn_holder = decoder_out[1]["attn"]
|
581 |
+
if isinstance(attn_holder, Tensor):
|
582 |
+
attn = attn_holder
|
583 |
+
elif attn_holder is not None:
|
584 |
+
attn = attn_holder[0]
|
585 |
+
if attn is not None:
|
586 |
+
attn = attn[:, -1, :]
|
587 |
+
|
588 |
+
decoder_out_tuple = (
|
589 |
+
decoder_out[0][:, -1:, :].div_(temperature),
|
590 |
+
None if decoder_len <= 1 else decoder_out[1],
|
591 |
+
)
|
592 |
+
probs = getattr(model, decoder_name).get_normalized_probs(
|
593 |
+
decoder_out_tuple, log_probs=True, sample=None
|
594 |
+
)
|
595 |
+
probs = probs[:, -1, :]
|
596 |
+
if self.models_size == 1:
|
597 |
+
return probs, attn
|
598 |
+
|
599 |
+
log_probs.append(probs)
|
600 |
+
if attn is not None:
|
601 |
+
if avg_attn is None:
|
602 |
+
avg_attn = attn
|
603 |
+
else:
|
604 |
+
avg_attn.add_(attn)
|
605 |
+
|
606 |
+
avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(
|
607 |
+
self.models_size
|
608 |
+
)
|
609 |
+
|
610 |
+
if avg_attn is not None:
|
611 |
+
avg_attn.div_(self.models_size)
|
612 |
+
return avg_probs, avg_attn
|
613 |
+
|
614 |
+
@torch.jit.export
|
615 |
+
def reorder_incremental_state(
|
616 |
+
self,
|
617 |
+
incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
|
618 |
+
new_order,
|
619 |
+
decoder_name="decoder",
|
620 |
+
):
|
621 |
+
if not self.has_incremental_states():
|
622 |
+
return
|
623 |
+
for i, model in enumerate(self.models):
|
624 |
+
getattr(model, decoder_name).reorder_incremental_state_scripting(
|
625 |
+
incremental_states[i], new_order
|
626 |
+
)
|
fairseq/examples/speech_to_speech/unity/sequence_generator_multi_decoder.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from typing import Dict, List, Optional
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from torch import Tensor
|
11 |
+
|
12 |
+
from fairseq import search
|
13 |
+
|
14 |
+
|
15 |
+
class MultiDecoderSequenceGenerator(nn.Module):
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
models,
|
19 |
+
tgt_dict,
|
20 |
+
tgt_dict_mt,
|
21 |
+
beam_size=1,
|
22 |
+
beam_size_mt=1,
|
23 |
+
max_len_a=0,
|
24 |
+
max_len_b=200,
|
25 |
+
max_len_a_mt=0,
|
26 |
+
max_len_b_mt=200,
|
27 |
+
max_len=0,
|
28 |
+
min_len=1,
|
29 |
+
normalize_scores=True,
|
30 |
+
len_penalty=1.0,
|
31 |
+
len_penalty_mt=1.0,
|
32 |
+
unk_penalty=0.0,
|
33 |
+
temperature=1.0,
|
34 |
+
match_source_len=False,
|
35 |
+
no_repeat_ngram_size=0,
|
36 |
+
eos=None,
|
37 |
+
eos_mt=None,
|
38 |
+
symbols_to_strip_from_output=None,
|
39 |
+
lm_model=None,
|
40 |
+
lm_weight=1.0,
|
41 |
+
):
|
42 |
+
"""Generates translations of a given source sentence.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
models (List[~fairseq.models.FairseqModel]): ensemble of models,
|
46 |
+
currently support fairseq.models.TransformerModel for scripting
|
47 |
+
beam_size (int, optional): beam width (default: 1)
|
48 |
+
max_len_a/b (int, optional): generate sequences of maximum length
|
49 |
+
ax + b, where x is the source length for the second pass
|
50 |
+
max_len_a_mt/b_mt (int, optional): generate sequences of maximum length
|
51 |
+
ax + b, where x is the source length for the first pass
|
52 |
+
max_len (int, optional): the maximum length of the generated output
|
53 |
+
(not including end-of-sentence)
|
54 |
+
min_len (int, optional): the minimum length of the generated output
|
55 |
+
(not including end-of-sentence)
|
56 |
+
normalize_scores (bool, optional): normalize scores by the length
|
57 |
+
of the output (default: True)
|
58 |
+
len_penalty (float, optional): length penalty in the second pass, where <1.0 favors
|
59 |
+
shorter, >1.0 favors longer sentences (default: 1.0)
|
60 |
+
len_penalty (float, optional): length penalty in the first pass, where <1.0 favors
|
61 |
+
shorter, >1.0 favors longer sentences (default: 1.0)
|
62 |
+
unk_penalty (float, optional): unknown word penalty, where <0
|
63 |
+
produces more unks, >0 produces fewer (default: 0.0)
|
64 |
+
temperature (float, optional): temperature, where values
|
65 |
+
>1.0 produce more uniform samples and values <1.0 produce
|
66 |
+
sharper samples (default: 1.0)
|
67 |
+
match_source_len (bool, optional): outputs should match the source
|
68 |
+
length (default: False)
|
69 |
+
"""
|
70 |
+
super().__init__()
|
71 |
+
|
72 |
+
from examples.speech_to_speech.unity.sequence_generator import SequenceGenerator
|
73 |
+
|
74 |
+
self.generator = SequenceGenerator(
|
75 |
+
models,
|
76 |
+
tgt_dict,
|
77 |
+
beam_size=beam_size,
|
78 |
+
max_len_a=max_len_a,
|
79 |
+
max_len_b=max_len_b,
|
80 |
+
max_len=max_len,
|
81 |
+
min_len=min_len,
|
82 |
+
normalize_scores=normalize_scores,
|
83 |
+
len_penalty=len_penalty,
|
84 |
+
unk_penalty=unk_penalty,
|
85 |
+
temperature=temperature,
|
86 |
+
match_source_len=match_source_len,
|
87 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
88 |
+
search_strategy=search.BeamSearch(tgt_dict),
|
89 |
+
eos=eos,
|
90 |
+
symbols_to_strip_from_output=symbols_to_strip_from_output,
|
91 |
+
lm_model=lm_model,
|
92 |
+
lm_weight=lm_weight,
|
93 |
+
)
|
94 |
+
self.eos = self.generator.eos
|
95 |
+
|
96 |
+
self.generator_mt = SequenceGenerator(
|
97 |
+
models,
|
98 |
+
tgt_dict_mt,
|
99 |
+
beam_size=beam_size_mt,
|
100 |
+
max_len_a=max_len_a_mt,
|
101 |
+
max_len_b=max_len_b_mt,
|
102 |
+
max_len=max_len,
|
103 |
+
min_len=min_len,
|
104 |
+
normalize_scores=normalize_scores,
|
105 |
+
len_penalty=len_penalty_mt,
|
106 |
+
unk_penalty=unk_penalty,
|
107 |
+
temperature=temperature,
|
108 |
+
match_source_len=match_source_len,
|
109 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
110 |
+
search_strategy=search.BeamSearch(tgt_dict_mt),
|
111 |
+
eos=eos_mt,
|
112 |
+
symbols_to_strip_from_output=symbols_to_strip_from_output,
|
113 |
+
)
|
114 |
+
|
115 |
+
@torch.no_grad()
|
116 |
+
def generate(
|
117 |
+
self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs
|
118 |
+
) -> List[List[Dict[str, Tensor]]]:
|
119 |
+
"""Generate translations. Match the api of other fairseq generators.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
models (List[~fairseq.models.FairseqModel]): ensemble of models
|
123 |
+
sample (dict): batch
|
124 |
+
prefix_tokens (torch.LongTensor, optional): force decoder to begin
|
125 |
+
with these tokens
|
126 |
+
constraints (torch.LongTensor, optional): force decoder to include
|
127 |
+
the list of constraints
|
128 |
+
bos_token (int, optional): beginning of sentence token
|
129 |
+
(default: self.eos)
|
130 |
+
"""
|
131 |
+
return self._generate(sample, **kwargs)
|
132 |
+
|
133 |
+
def _generate(
|
134 |
+
self,
|
135 |
+
sample: Dict[str, Dict[str, Tensor]],
|
136 |
+
prefix_tokens: Optional[Tensor] = None,
|
137 |
+
constraints: Optional[Tensor] = None,
|
138 |
+
bos_token: Optional[int] = None,
|
139 |
+
):
|
140 |
+
net_input = sample["net_input"]
|
141 |
+
|
142 |
+
if "src_tokens" in net_input:
|
143 |
+
src_tokens = net_input["src_tokens"]
|
144 |
+
# length of the source text being the character length except EndOfSentence and pad
|
145 |
+
# if src_lengths exists in net_input (speech_to_text dataset case), then use it
|
146 |
+
if "src_lengths" in net_input:
|
147 |
+
src_lengths = net_input["src_lengths"]
|
148 |
+
else:
|
149 |
+
src_lengths = (
|
150 |
+
(
|
151 |
+
src_tokens.ne(self.generator.eos)
|
152 |
+
& src_tokens.ne(self.generator.pad)
|
153 |
+
)
|
154 |
+
.long()
|
155 |
+
.sum(dim=1)
|
156 |
+
)
|
157 |
+
else:
|
158 |
+
raise Exception(
|
159 |
+
"expected src_tokens or source in net input. input keys: "
|
160 |
+
+ str(net_input.keys())
|
161 |
+
)
|
162 |
+
|
163 |
+
if constraints is not None and not self.generator.search.supports_constraints:
|
164 |
+
raise NotImplementedError(
|
165 |
+
"Target-side constraints were provided, but search method doesn't support them"
|
166 |
+
)
|
167 |
+
|
168 |
+
# Initialize constraints, when active
|
169 |
+
self.generator.search.init_constraints(constraints, self.generator.beam_size)
|
170 |
+
self.generator_mt.search.init_constraints(
|
171 |
+
constraints, self.generator_mt.beam_size
|
172 |
+
)
|
173 |
+
|
174 |
+
# compute the encoder output for each beam
|
175 |
+
with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"):
|
176 |
+
encoder_outs = self.generator.model.forward_encoder(net_input)
|
177 |
+
|
178 |
+
single_model = self.generator.model.single_model
|
179 |
+
mt_decoder = getattr(single_model, f"{single_model.mt_task_name}_decoder")
|
180 |
+
|
181 |
+
# 1. MT decoder
|
182 |
+
finalized_mt = self.generator_mt.generate_decoder(
|
183 |
+
encoder_outs,
|
184 |
+
src_tokens,
|
185 |
+
src_lengths,
|
186 |
+
sample,
|
187 |
+
prefix_tokens,
|
188 |
+
constraints,
|
189 |
+
bos_token,
|
190 |
+
aux_task_name=single_model.mt_task_name,
|
191 |
+
)
|
192 |
+
|
193 |
+
# extract decoder output corresponding to the best hypothesis
|
194 |
+
max_tgt_len = max([len(hypo[0]["tokens"]) for hypo in finalized_mt])
|
195 |
+
prev_output_tokens_mt = (
|
196 |
+
src_tokens.new_zeros(src_tokens.shape[0], max_tgt_len)
|
197 |
+
.fill_(mt_decoder.padding_idx)
|
198 |
+
.int()
|
199 |
+
) # B x T
|
200 |
+
for i, hypo in enumerate(finalized_mt):
|
201 |
+
i_beam = 0
|
202 |
+
tmp = hypo[i_beam]["tokens"].int() # hyp + eos
|
203 |
+
prev_output_tokens_mt[i, 0] = self.generator_mt.eos
|
204 |
+
if tmp[-1] == self.generator_mt.eos:
|
205 |
+
tmp = tmp[:-1]
|
206 |
+
prev_output_tokens_mt[i, 1 : len(tmp) + 1] = tmp
|
207 |
+
|
208 |
+
text = "".join([self.generator_mt.tgt_dict[c] for c in tmp])
|
209 |
+
text = text.replace("_", " ")
|
210 |
+
text = text.replace("▁", " ")
|
211 |
+
text = text.replace("<unk>", " ")
|
212 |
+
text = text.replace("<s>", "")
|
213 |
+
text = text.replace("</s>", "")
|
214 |
+
if len(text) > 0 and text[0] == " ":
|
215 |
+
text = text[1:]
|
216 |
+
sample_id = sample["id"].tolist()[i]
|
217 |
+
print("{} (None-{})".format(text, sample_id))
|
218 |
+
|
219 |
+
x = mt_decoder(
|
220 |
+
prev_output_tokens_mt,
|
221 |
+
encoder_out=encoder_outs[0],
|
222 |
+
features_only=True,
|
223 |
+
)[0].transpose(0, 1)
|
224 |
+
|
225 |
+
if getattr(single_model, "proj", None) is not None:
|
226 |
+
x = single_model.proj(x)
|
227 |
+
|
228 |
+
mt_decoder_padding_mask = None
|
229 |
+
if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any():
|
230 |
+
mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx)
|
231 |
+
|
232 |
+
# 2. T2U encoder
|
233 |
+
if getattr(single_model, "synthesizer_encoder", None) is not None:
|
234 |
+
t2u_encoder_out = single_model.synthesizer_encoder(
|
235 |
+
x,
|
236 |
+
mt_decoder_padding_mask,
|
237 |
+
)
|
238 |
+
else:
|
239 |
+
t2u_encoder_out = {
|
240 |
+
"encoder_out": [x], # T x B x C
|
241 |
+
"encoder_padding_mask": [mt_decoder_padding_mask]
|
242 |
+
if mt_decoder_padding_mask is not None
|
243 |
+
else [], # B x T
|
244 |
+
"encoder_embedding": [],
|
245 |
+
"encoder_states": [],
|
246 |
+
"src_tokens": [],
|
247 |
+
"src_lengths": [],
|
248 |
+
}
|
249 |
+
|
250 |
+
if getattr(single_model, "t2u_augmented_cross_attn", False):
|
251 |
+
encoder_outs_aug = [t2u_encoder_out]
|
252 |
+
else:
|
253 |
+
encoder_outs = [t2u_encoder_out]
|
254 |
+
encoder_outs_aug = None
|
255 |
+
|
256 |
+
# 3. T2U decoder
|
257 |
+
finalized = self.generator.generate_decoder(
|
258 |
+
encoder_outs,
|
259 |
+
src_tokens,
|
260 |
+
src_lengths,
|
261 |
+
sample,
|
262 |
+
prefix_tokens,
|
263 |
+
constraints,
|
264 |
+
bos_token,
|
265 |
+
encoder_outs_aug=encoder_outs_aug,
|
266 |
+
)
|
267 |
+
return finalized
|
fairseq/examples/speech_to_text/README.md
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Speech-to-Text (S2T) Modeling
|
2 |
+
|
3 |
+
[https://www.aclweb.org/anthology/2020.aacl-demo.6](https://www.aclweb.org/anthology/2020.aacl-demo.6.pdf)
|
4 |
+
|
5 |
+
Speech recognition (ASR) and speech-to-text translation (ST) with fairseq.
|
6 |
+
|
7 |
+
## Data Preparation
|
8 |
+
S2T modeling data consists of source speech features, target text and other optional information
|
9 |
+
(source text, speaker id, etc.). Fairseq S2T uses per-dataset-split TSV manifest files
|
10 |
+
to store these information. Each data field is represented by a column in the TSV file.
|
11 |
+
|
12 |
+
Unlike text token embeddings, speech features (e.g. log mel-scale filter banks) are usually fixed
|
13 |
+
during model training and can be pre-computed. The manifest file contains the path to
|
14 |
+
either the feature file in NumPy format or the WAV/FLAC audio file. For the latter,
|
15 |
+
features will be extracted on-the-fly by fairseq S2T. Optionally, feature/audio files can be packed
|
16 |
+
into uncompressed ZIP files (then accessed via byte offset and length) to improve I/O performance.
|
17 |
+
|
18 |
+
Fairseq S2T also employs a YAML file for data related configurations: tokenizer type and dictionary path
|
19 |
+
for the target text, feature transforms such as CMVN (cepstral mean and variance normalization) and SpecAugment,
|
20 |
+
temperature-based resampling, etc.
|
21 |
+
|
22 |
+
## Model Training
|
23 |
+
Fairseq S2T uses the unified `fairseq-train` interface for model training. It requires arguments `--task speech_to_text`,
|
24 |
+
`--arch <model architecture in fairseq.models.speech_to_text.*>` and `--config-yaml <config YAML filename>`.
|
25 |
+
|
26 |
+
## Inference & Evaluation
|
27 |
+
Fairseq S2T uses the unified `fairseq-generate`/`fairseq-interactive` interface for inference and evaluation. It
|
28 |
+
requires arguments `--task speech_to_text` and `--config-yaml <config YAML filename>`. The interactive console takes
|
29 |
+
audio paths (one per line) as inputs.
|
30 |
+
|
31 |
+
|
32 |
+
## Examples
|
33 |
+
- [Speech Recognition (ASR) on LibriSpeech](docs/librispeech_example.md)
|
34 |
+
|
35 |
+
- [Speech-to-Text Translation (ST) on MuST-C](docs/mustc_example.md)
|
36 |
+
|
37 |
+
- [Speech-to-Text Translation (ST) on CoVoST 2](docs/covost_example.md)
|
38 |
+
|
39 |
+
- [Speech-to-Text Translation (ST) on Multilingual TEDx](docs/mtedx_example.md)
|
40 |
+
- [Simultaneous Speech-to-Text Translation (SimulST) on MuST-C](docs/simulst_mustc_example.md)
|
41 |
+
|
42 |
+
## Updates
|
43 |
+
- 02/04/2021: Added interactive decoding (`fairseq-interactive`) support. Examples:
|
44 |
+
[ASR (LibriSpeech)](docs/librispeech_example.md#interactive-decoding)
|
45 |
+
and [ST (CoVoST 2)](docs/covost_example.md#interactive-decoding).
|
46 |
+
- 01/08/2021: Several fixes for S2T Transformer model, inference-time de-tokenization, scorer configuration and data
|
47 |
+
preparation scripts. We also add pre-trained models to the examples and revise the instructions.
|
48 |
+
Breaking changes: the data preparation scripts now extract filterbank features without CMVN. CMVN is instead applied
|
49 |
+
on-the-fly (defined in the config YAML).
|
50 |
+
|
51 |
+
## What's Next
|
52 |
+
- We are migrating the old fairseq [ASR example](../speech_recognition) into this S2T framework and
|
53 |
+
merging the features from both sides.
|
54 |
+
- The following papers also base their experiments on fairseq S2T. We are adding more examples for replication.
|
55 |
+
- [Improving Cross-Lingual Transfer Learning for End-to-End Speech Recognition with Speech Translation (Wang et al., 2020)](https://arxiv.org/abs/2006.05474)
|
56 |
+
- [Self-Supervised Representations Improve End-to-End Speech Translation (Wu et al., 2020)](https://arxiv.org/abs/2006.12124)
|
57 |
+
- [Self-Training for End-to-End Speech Translation (Pino et al., 2020)](https://arxiv.org/abs/2006.02490)
|
58 |
+
- [CoVoST: A Diverse Multilingual Speech-To-Text Translation Corpus (Wang et al., 2020)](https://arxiv.org/abs/2002.01320)
|
59 |
+
- [Harnessing Indirect Training Data for End-to-End Automatic Speech Translation: Tricks of the Trade (Pino et al., 2019)](https://arxiv.org/abs/1909.06515)
|
60 |
+
|
61 |
+
## Citation
|
62 |
+
Please cite as:
|
63 |
+
```
|
64 |
+
@inproceedings{wang2020fairseqs2t,
|
65 |
+
title = {fairseq S2T: Fast Speech-to-Text Modeling with fairseq},
|
66 |
+
author = {Changhan Wang and Yun Tang and Xutai Ma and Anne Wu and Dmytro Okhonko and Juan Pino},
|
67 |
+
booktitle = {Proceedings of the 2020 Conference of the Asian Chapter of the Association for Computational Linguistics (AACL): System Demonstrations},
|
68 |
+
year = {2020},
|
69 |
+
}
|
70 |
+
|
71 |
+
@inproceedings{ott2019fairseq,
|
72 |
+
title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
|
73 |
+
author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
|
74 |
+
booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
|
75 |
+
year = {2019},
|
76 |
+
}
|
77 |
+
```
|
fairseq/examples/speech_to_text/data_utils.py
ADDED
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import csv
|
7 |
+
from pathlib import Path
|
8 |
+
import zipfile
|
9 |
+
from functools import reduce
|
10 |
+
from multiprocessing import cpu_count
|
11 |
+
from typing import Any, Dict, List, Optional, Union
|
12 |
+
import io
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import pandas as pd
|
16 |
+
import sentencepiece as sp
|
17 |
+
from fairseq.data.audio.audio_utils import (
|
18 |
+
convert_waveform, _get_kaldi_fbank, _get_torchaudio_fbank, is_npy_data,
|
19 |
+
is_sf_audio_data
|
20 |
+
)
|
21 |
+
import torch
|
22 |
+
import soundfile as sf
|
23 |
+
from tqdm import tqdm
|
24 |
+
|
25 |
+
|
26 |
+
UNK_TOKEN, UNK_TOKEN_ID = "<unk>", 3
|
27 |
+
BOS_TOKEN, BOS_TOKEN_ID = "<s>", 0
|
28 |
+
EOS_TOKEN, EOS_TOKEN_ID = "</s>", 2
|
29 |
+
PAD_TOKEN, PAD_TOKEN_ID = "<pad>", 1
|
30 |
+
|
31 |
+
|
32 |
+
def gen_vocab(
|
33 |
+
input_path: Path, output_path_prefix: Path, model_type="bpe",
|
34 |
+
vocab_size=1000, special_symbols: Optional[List[str]] = None
|
35 |
+
):
|
36 |
+
# Train SentencePiece Model
|
37 |
+
arguments = [
|
38 |
+
f"--input={input_path.as_posix()}",
|
39 |
+
f"--model_prefix={output_path_prefix.as_posix()}",
|
40 |
+
f"--model_type={model_type}",
|
41 |
+
f"--vocab_size={vocab_size}",
|
42 |
+
"--character_coverage=1.0",
|
43 |
+
f"--num_threads={cpu_count()}",
|
44 |
+
f"--unk_id={UNK_TOKEN_ID}",
|
45 |
+
f"--bos_id={BOS_TOKEN_ID}",
|
46 |
+
f"--eos_id={EOS_TOKEN_ID}",
|
47 |
+
f"--pad_id={PAD_TOKEN_ID}",
|
48 |
+
]
|
49 |
+
if special_symbols is not None:
|
50 |
+
_special_symbols = ",".join(special_symbols)
|
51 |
+
arguments.append(f"--user_defined_symbols={_special_symbols}")
|
52 |
+
sp.SentencePieceTrainer.Train(" ".join(arguments))
|
53 |
+
# Export fairseq dictionary
|
54 |
+
spm = sp.SentencePieceProcessor()
|
55 |
+
spm.Load(output_path_prefix.as_posix() + ".model")
|
56 |
+
vocab = {i: spm.IdToPiece(i) for i in range(spm.GetPieceSize())}
|
57 |
+
assert (
|
58 |
+
vocab.get(UNK_TOKEN_ID) == UNK_TOKEN
|
59 |
+
and vocab.get(PAD_TOKEN_ID) == PAD_TOKEN
|
60 |
+
and vocab.get(BOS_TOKEN_ID) == BOS_TOKEN
|
61 |
+
and vocab.get(EOS_TOKEN_ID) == EOS_TOKEN
|
62 |
+
)
|
63 |
+
vocab = {
|
64 |
+
i: s
|
65 |
+
for i, s in vocab.items()
|
66 |
+
if s not in {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN}
|
67 |
+
}
|
68 |
+
with open(output_path_prefix.as_posix() + ".txt", "w") as f_out:
|
69 |
+
for _, s in sorted(vocab.items(), key=lambda x: x[0]):
|
70 |
+
f_out.write(f"{s} 1\n")
|
71 |
+
|
72 |
+
|
73 |
+
def extract_fbank_features(
|
74 |
+
waveform: torch.FloatTensor,
|
75 |
+
sample_rate: int,
|
76 |
+
output_path: Optional[Path] = None,
|
77 |
+
n_mel_bins: int = 80,
|
78 |
+
overwrite: bool = False,
|
79 |
+
):
|
80 |
+
if output_path is not None and output_path.is_file() and not overwrite:
|
81 |
+
return
|
82 |
+
|
83 |
+
_waveform, _ = convert_waveform(waveform, sample_rate, to_mono=True)
|
84 |
+
# Kaldi compliance: 16-bit signed integers
|
85 |
+
_waveform = _waveform * (2 ** 15)
|
86 |
+
_waveform = _waveform.numpy()
|
87 |
+
|
88 |
+
features = _get_kaldi_fbank(_waveform, sample_rate, n_mel_bins)
|
89 |
+
if features is None:
|
90 |
+
features = _get_torchaudio_fbank(_waveform, sample_rate, n_mel_bins)
|
91 |
+
if features is None:
|
92 |
+
raise ImportError(
|
93 |
+
"Please install pyKaldi or torchaudio to enable fbank feature extraction"
|
94 |
+
)
|
95 |
+
|
96 |
+
if output_path is not None:
|
97 |
+
np.save(output_path.as_posix(), features)
|
98 |
+
return features
|
99 |
+
|
100 |
+
|
101 |
+
def create_zip(data_root: Path, zip_path: Path):
|
102 |
+
paths = list(data_root.glob("*.npy"))
|
103 |
+
paths.extend(data_root.glob("*.flac"))
|
104 |
+
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as f:
|
105 |
+
for path in tqdm(paths):
|
106 |
+
f.write(path, arcname=path.name)
|
107 |
+
|
108 |
+
|
109 |
+
def get_zip_manifest(
|
110 |
+
zip_path: Path, zip_root: Optional[Path] = None, is_audio=False
|
111 |
+
):
|
112 |
+
_zip_path = Path.joinpath(zip_root or Path(""), zip_path)
|
113 |
+
with zipfile.ZipFile(_zip_path, mode="r") as f:
|
114 |
+
info = f.infolist()
|
115 |
+
paths, lengths = {}, {}
|
116 |
+
for i in tqdm(info):
|
117 |
+
utt_id = Path(i.filename).stem
|
118 |
+
offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size
|
119 |
+
paths[utt_id] = f"{zip_path.as_posix()}:{offset}:{file_size}"
|
120 |
+
with open(_zip_path, "rb") as f:
|
121 |
+
f.seek(offset)
|
122 |
+
byte_data = f.read(file_size)
|
123 |
+
assert len(byte_data) > 1
|
124 |
+
if is_audio:
|
125 |
+
assert is_sf_audio_data(byte_data), i
|
126 |
+
else:
|
127 |
+
assert is_npy_data(byte_data), i
|
128 |
+
byte_data_fp = io.BytesIO(byte_data)
|
129 |
+
if is_audio:
|
130 |
+
lengths[utt_id] = sf.info(byte_data_fp).frames
|
131 |
+
else:
|
132 |
+
lengths[utt_id] = np.load(byte_data_fp).shape[0]
|
133 |
+
return paths, lengths
|
134 |
+
|
135 |
+
|
136 |
+
def gen_config_yaml(
|
137 |
+
manifest_root: Path,
|
138 |
+
spm_filename: Optional[str] = None,
|
139 |
+
vocab_name: Optional[str] = None,
|
140 |
+
yaml_filename: str = "config.yaml",
|
141 |
+
specaugment_policy: Optional[str] = "lb",
|
142 |
+
prepend_tgt_lang_tag: bool = False,
|
143 |
+
sampling_alpha: Optional[float] = None,
|
144 |
+
input_channels: Optional[int] = 1,
|
145 |
+
input_feat_per_channel: Optional[int] = 80,
|
146 |
+
audio_root: str = "",
|
147 |
+
cmvn_type: str = "utterance",
|
148 |
+
gcmvn_path: Optional[Path] = None,
|
149 |
+
extra=None
|
150 |
+
):
|
151 |
+
manifest_root = manifest_root.absolute()
|
152 |
+
writer = S2TDataConfigWriter(manifest_root / yaml_filename)
|
153 |
+
assert spm_filename is not None or vocab_name is not None
|
154 |
+
vocab_name = spm_filename.replace(".model", ".txt") if vocab_name is None \
|
155 |
+
else vocab_name
|
156 |
+
writer.set_vocab_filename(vocab_name)
|
157 |
+
if input_channels is not None:
|
158 |
+
writer.set_input_channels(input_channels)
|
159 |
+
if input_feat_per_channel is not None:
|
160 |
+
writer.set_input_feat_per_channel(input_feat_per_channel)
|
161 |
+
specaugment_setters = {
|
162 |
+
"lb": writer.set_specaugment_lb_policy,
|
163 |
+
"ld": writer.set_specaugment_ld_policy,
|
164 |
+
"sm": writer.set_specaugment_sm_policy,
|
165 |
+
"ss": writer.set_specaugment_ss_policy,
|
166 |
+
}
|
167 |
+
specaugment_setter = specaugment_setters.get(specaugment_policy, None)
|
168 |
+
if specaugment_setter is not None:
|
169 |
+
specaugment_setter()
|
170 |
+
if spm_filename is not None:
|
171 |
+
writer.set_bpe_tokenizer(
|
172 |
+
{
|
173 |
+
"bpe": "sentencepiece",
|
174 |
+
"sentencepiece_model": (manifest_root / spm_filename).as_posix(),
|
175 |
+
}
|
176 |
+
)
|
177 |
+
if prepend_tgt_lang_tag:
|
178 |
+
writer.set_prepend_tgt_lang_tag(True)
|
179 |
+
if sampling_alpha is not None:
|
180 |
+
writer.set_sampling_alpha(sampling_alpha)
|
181 |
+
|
182 |
+
if cmvn_type not in ["global", "utterance"]:
|
183 |
+
raise NotImplementedError
|
184 |
+
|
185 |
+
if specaugment_policy is not None:
|
186 |
+
writer.set_feature_transforms(
|
187 |
+
"_train", [f"{cmvn_type}_cmvn", "specaugment"]
|
188 |
+
)
|
189 |
+
writer.set_feature_transforms("*", [f"{cmvn_type}_cmvn"])
|
190 |
+
|
191 |
+
if cmvn_type == "global":
|
192 |
+
if gcmvn_path is None:
|
193 |
+
raise ValueError("Please provide path of global cmvn file.")
|
194 |
+
else:
|
195 |
+
writer.set_global_cmvn(gcmvn_path.as_posix())
|
196 |
+
|
197 |
+
if len(audio_root) > 0:
|
198 |
+
writer.set_audio_root(audio_root)
|
199 |
+
|
200 |
+
if extra is not None:
|
201 |
+
writer.set_extra(extra)
|
202 |
+
writer.flush()
|
203 |
+
|
204 |
+
|
205 |
+
def load_df_from_tsv(path: Union[str, Path]) -> pd.DataFrame:
|
206 |
+
_path = path if isinstance(path, str) else path.as_posix()
|
207 |
+
return pd.read_csv(
|
208 |
+
_path,
|
209 |
+
sep="\t",
|
210 |
+
header=0,
|
211 |
+
encoding="utf-8",
|
212 |
+
escapechar="\\",
|
213 |
+
quoting=csv.QUOTE_NONE,
|
214 |
+
na_filter=False,
|
215 |
+
)
|
216 |
+
|
217 |
+
|
218 |
+
def save_df_to_tsv(dataframe, path: Union[str, Path]):
|
219 |
+
_path = path if isinstance(path, str) else path.as_posix()
|
220 |
+
dataframe.to_csv(
|
221 |
+
_path,
|
222 |
+
sep="\t",
|
223 |
+
header=True,
|
224 |
+
index=False,
|
225 |
+
encoding="utf-8",
|
226 |
+
escapechar="\\",
|
227 |
+
quoting=csv.QUOTE_NONE,
|
228 |
+
)
|
229 |
+
|
230 |
+
|
231 |
+
def load_tsv_to_dicts(path: Union[str, Path]) -> List[dict]:
|
232 |
+
with open(path, "r") as f:
|
233 |
+
reader = csv.DictReader(
|
234 |
+
f,
|
235 |
+
delimiter="\t",
|
236 |
+
quotechar=None,
|
237 |
+
doublequote=False,
|
238 |
+
lineterminator="\n",
|
239 |
+
quoting=csv.QUOTE_NONE,
|
240 |
+
)
|
241 |
+
rows = [dict(e) for e in reader]
|
242 |
+
return rows
|
243 |
+
|
244 |
+
|
245 |
+
def filter_manifest_df(
|
246 |
+
df, is_train_split=False, extra_filters=None, min_n_frames=5, max_n_frames=3000
|
247 |
+
):
|
248 |
+
filters = {
|
249 |
+
"no speech": df["audio"] == "",
|
250 |
+
f"short speech (<{min_n_frames} frames)": df["n_frames"] < min_n_frames,
|
251 |
+
"empty sentence": df["tgt_text"] == "",
|
252 |
+
}
|
253 |
+
if is_train_split:
|
254 |
+
filters[f"long speech (>{max_n_frames} frames)"] = df["n_frames"] > max_n_frames
|
255 |
+
if extra_filters is not None:
|
256 |
+
filters.update(extra_filters)
|
257 |
+
invalid = reduce(lambda x, y: x | y, filters.values())
|
258 |
+
valid = ~invalid
|
259 |
+
print(
|
260 |
+
"| "
|
261 |
+
+ ", ".join(f"{n}: {f.sum()}" for n, f in filters.items())
|
262 |
+
+ f", total {invalid.sum()} filtered, {valid.sum()} remained."
|
263 |
+
)
|
264 |
+
return df[valid]
|
265 |
+
|
266 |
+
|
267 |
+
def cal_gcmvn_stats(features_list):
|
268 |
+
features = np.concatenate(features_list)
|
269 |
+
square_sums = (features ** 2).sum(axis=0)
|
270 |
+
mean = features.mean(axis=0)
|
271 |
+
features = np.subtract(features, mean)
|
272 |
+
var = square_sums / features.shape[0] - mean ** 2
|
273 |
+
std = np.sqrt(np.maximum(var, 1e-8))
|
274 |
+
return {"mean": mean.astype("float32"), "std": std.astype("float32")}
|
275 |
+
|
276 |
+
|
277 |
+
class S2TDataConfigWriter(object):
|
278 |
+
DEFAULT_VOCAB_FILENAME = "dict.txt"
|
279 |
+
DEFAULT_INPUT_FEAT_PER_CHANNEL = 80
|
280 |
+
DEFAULT_INPUT_CHANNELS = 1
|
281 |
+
|
282 |
+
def __init__(self, yaml_path: Path):
|
283 |
+
try:
|
284 |
+
import yaml
|
285 |
+
except ImportError:
|
286 |
+
print("Please install PyYAML for S2T data config YAML files")
|
287 |
+
self.yaml = yaml
|
288 |
+
self.yaml_path = yaml_path
|
289 |
+
self.config = {}
|
290 |
+
|
291 |
+
def flush(self):
|
292 |
+
with open(self.yaml_path, "w") as f:
|
293 |
+
self.yaml.dump(self.config, f)
|
294 |
+
|
295 |
+
def set_audio_root(self, audio_root=""):
|
296 |
+
self.config["audio_root"] = audio_root
|
297 |
+
|
298 |
+
def set_vocab_filename(self, vocab_filename: str = "dict.txt"):
|
299 |
+
self.config["vocab_filename"] = vocab_filename
|
300 |
+
|
301 |
+
def set_specaugment(
|
302 |
+
self,
|
303 |
+
time_wrap_w: int,
|
304 |
+
freq_mask_n: int,
|
305 |
+
freq_mask_f: int,
|
306 |
+
time_mask_n: int,
|
307 |
+
time_mask_t: int,
|
308 |
+
time_mask_p: float,
|
309 |
+
):
|
310 |
+
self.config["specaugment"] = {
|
311 |
+
"time_wrap_W": time_wrap_w,
|
312 |
+
"freq_mask_N": freq_mask_n,
|
313 |
+
"freq_mask_F": freq_mask_f,
|
314 |
+
"time_mask_N": time_mask_n,
|
315 |
+
"time_mask_T": time_mask_t,
|
316 |
+
"time_mask_p": time_mask_p,
|
317 |
+
}
|
318 |
+
|
319 |
+
def set_specaugment_lb_policy(self):
|
320 |
+
self.set_specaugment(
|
321 |
+
time_wrap_w=0,
|
322 |
+
freq_mask_n=1,
|
323 |
+
freq_mask_f=27,
|
324 |
+
time_mask_n=1,
|
325 |
+
time_mask_t=100,
|
326 |
+
time_mask_p=1.0,
|
327 |
+
)
|
328 |
+
|
329 |
+
def set_specaugment_ld_policy(self):
|
330 |
+
self.set_specaugment(
|
331 |
+
time_wrap_w=0,
|
332 |
+
freq_mask_n=2,
|
333 |
+
freq_mask_f=27,
|
334 |
+
time_mask_n=2,
|
335 |
+
time_mask_t=100,
|
336 |
+
time_mask_p=1.0,
|
337 |
+
)
|
338 |
+
|
339 |
+
def set_specaugment_sm_policy(self):
|
340 |
+
self.set_specaugment(
|
341 |
+
time_wrap_w=0,
|
342 |
+
freq_mask_n=2,
|
343 |
+
freq_mask_f=15,
|
344 |
+
time_mask_n=2,
|
345 |
+
time_mask_t=70,
|
346 |
+
time_mask_p=0.2,
|
347 |
+
)
|
348 |
+
|
349 |
+
def set_specaugment_ss_policy(self):
|
350 |
+
self.set_specaugment(
|
351 |
+
time_wrap_w=0,
|
352 |
+
freq_mask_n=2,
|
353 |
+
freq_mask_f=27,
|
354 |
+
time_mask_n=2,
|
355 |
+
time_mask_t=70,
|
356 |
+
time_mask_p=0.2,
|
357 |
+
)
|
358 |
+
|
359 |
+
def set_input_channels(self, input_channels: int = 1):
|
360 |
+
self.config["input_channels"] = input_channels
|
361 |
+
|
362 |
+
def set_input_feat_per_channel(self, input_feat_per_channel: int = 80):
|
363 |
+
self.config["input_feat_per_channel"] = input_feat_per_channel
|
364 |
+
|
365 |
+
def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]):
|
366 |
+
self.config["bpe_tokenizer"] = bpe_tokenizer
|
367 |
+
|
368 |
+
def set_global_cmvn(self, stats_npz_path: str):
|
369 |
+
self.config["global_cmvn"] = {"stats_npz_path": stats_npz_path}
|
370 |
+
|
371 |
+
def set_feature_transforms(self, split: str, transforms: List[str]):
|
372 |
+
if "transforms" not in self.config:
|
373 |
+
self.config["transforms"] = {}
|
374 |
+
self.config["transforms"][split] = transforms
|
375 |
+
|
376 |
+
def set_prepend_tgt_lang_tag(self, flag: bool = True):
|
377 |
+
self.config["prepend_tgt_lang_tag"] = flag
|
378 |
+
|
379 |
+
def set_sampling_alpha(self, sampling_alpha: float = 1.0):
|
380 |
+
self.config["sampling_alpha"] = sampling_alpha
|
381 |
+
|
382 |
+
def set_extra(self, data):
|
383 |
+
self.config.update(data)
|
fairseq/examples/speech_to_text/docs/covost_example.md
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[[Back]](..)
|
2 |
+
|
3 |
+
# S2T Example: ST on CoVoST
|
4 |
+
|
5 |
+
We replicate the experiments in
|
6 |
+
[CoVoST 2 and Massively Multilingual Speech-to-Text Translation (Wang et al., 2020)](https://arxiv.org/abs/2007.10310).
|
7 |
+
|
8 |
+
## Data Preparation
|
9 |
+
|
10 |
+
[Download](https://commonvoice.mozilla.org/en/datasets) and unpack Common Voice v4 to a path
|
11 |
+
`${COVOST_ROOT}/${SOURCE_LANG_ID}`, then preprocess it with
|
12 |
+
|
13 |
+
```bash
|
14 |
+
# additional Python packages for S2T data processing/model training
|
15 |
+
pip install pandas torchaudio sentencepiece
|
16 |
+
|
17 |
+
# En ASR
|
18 |
+
python examples/speech_to_text/prep_covost_data.py \
|
19 |
+
--data-root ${COVOST_ROOT} --vocab-type char --src-lang en
|
20 |
+
# ST
|
21 |
+
python examples/speech_to_text/prep_covost_data.py \
|
22 |
+
--data-root ${COVOST_ROOT} --vocab-type char \
|
23 |
+
--src-lang fr --tgt-lang en
|
24 |
+
```
|
25 |
+
|
26 |
+
The generated files (manifest, features, vocabulary and data configuration) will be added to
|
27 |
+
`${COVOST_ROOT}/${SOURCE_LANG_ID}`.
|
28 |
+
|
29 |
+
Download our vocabulary files if you want to use our pre-trained models:
|
30 |
+
|
31 |
+
- ASR: [En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_asr_vocab_char.zip)
|
32 |
+
- ST: [Fr-En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_fr_en_st_vocab_char.zip), [De-En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_de_en_st_vocab_char.zip), [Es-En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_es_en_st_vocab_char.zip), [Ca-En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_ca_en_st_vocab_char.zip), [En-De](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_de_st_vocab_char.zip), [En-Ca](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_ca_st_vocab_char.zip), [En-Fa](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_fa_st_vocab_char.zip), [En-Et](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_et_st_vocab_char.zip)
|
33 |
+
|
34 |
+
## ASR
|
35 |
+
|
36 |
+
#### Training
|
37 |
+
|
38 |
+
We train an En ASR model for encoder pre-training some of the ST models.
|
39 |
+
|
40 |
+
```bash
|
41 |
+
fairseq-train ${COVOST_ROOT}/en \
|
42 |
+
--config-yaml config_asr_en.yaml --train-subset train_asr_en --valid-subset dev_asr_en \
|
43 |
+
--save-dir ${ASR_SAVE_DIR} --num-workers 4 --max-tokens 50000 --max-update 60000 \
|
44 |
+
--task speech_to_text --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
|
45 |
+
--report-accuracy --arch s2t_transformer_s --dropout 0.15 --optimizer adam --lr 2e-3 \
|
46 |
+
--lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \
|
47 |
+
--attn-type None --pos-enc-type ${POS_ENC_TYPE}
|
48 |
+
```
|
49 |
+
|
50 |
+
where `ASR_SAVE_DIR` is the checkpoint root path and `POS_ENC_TYPE` refers to positional encoding to be used in the conformer encoder.
|
51 |
+
Set it to `abs`, `rope` or `rel_pos` to use the absolute positional encoding, rotary positional encoding or relative positional encoding in the conformer layer respectively.
|
52 |
+
Transformer encoder only supports absolute positional encoding and by default, the transformer encoder will be used.
|
53 |
+
To switch to conformer, set `--attn-type espnet` and `--POS_ENC_TYPE`. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to update it accordingly when using more than 1 GPU.
|
54 |
+
|
55 |
+
#### Inference & Evaluation
|
56 |
+
|
57 |
+
```bash
|
58 |
+
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
|
59 |
+
python scripts/average_checkpoints.py \
|
60 |
+
--inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 \
|
61 |
+
--output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}"
|
62 |
+
fairseq-generate ${COVOST_ROOT}/en \
|
63 |
+
--config-yaml config_asr_en.yaml --gen-subset test_asr_en --task speech_to_text \
|
64 |
+
--path ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \
|
65 |
+
--scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct
|
66 |
+
```
|
67 |
+
|
68 |
+
#### Results
|
69 |
+
|
70 |
+
| --arch | --pos-enc-type | Params | En | Model |
|
71 |
+
|---|---|---|---|---|
|
72 |
+
| s2t_transformer_s | - | 31M | 25.6 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_asr_transformer_s.pt) |
|
73 |
+
| s2t_conformer | rel_pos | 42.9M | 23.18| [Download](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_asr/rel_pos_asr_checkpoint_best.pt) |
|
74 |
+
| s2t_conformer | rope | 42.1M | 23.8| [Download](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_asr/rope_pos_asr_checkpoint_best.pt) |
|
75 |
+
| s2t_conformer | abs | 42.1M | 23.8| [Download](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_asr/abs_asr_checkpoint_best.pt) |
|
76 |
+
|
77 |
+
## ST
|
78 |
+
|
79 |
+
#### Training
|
80 |
+
|
81 |
+
Fr-En as example:
|
82 |
+
|
83 |
+
```bash
|
84 |
+
fairseq-train ${COVOST_ROOT}/fr \
|
85 |
+
--config-yaml config_st_fr_en.yaml --train-subset train_st_fr_en --valid-subset dev_st_fr_en \
|
86 |
+
--save-dir ${ST_SAVE_DIR} --num-workers 4 --max-update 30000 --max-tokens 40000 \ # --max-tokens 50000 for en-*
|
87 |
+
--task speech_to_text --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --report-accuracy \
|
88 |
+
--arch s2t_transformer_s --encoder-freezing-updates 1000 --optimizer adam --lr 2e-3 \
|
89 |
+
--lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \
|
90 |
+
--attn-type None --pos-enc-type ${POS_ENC_TYPE} \
|
91 |
+
--load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}
|
92 |
+
```
|
93 |
+
|
94 |
+
where `ST_SAVE_DIR` is the checkpoint root path and `POS_ENC_TYPE` refers to positional encoding to be used in the conformer encoder.
|
95 |
+
Set it to `abs`, `rope` or `rel_pos` to use the absolute positional encoding, rotary positional encoding or relative positional encoding in the conformer layer respectively.
|
96 |
+
Transformer encoder only supports absolute positional encoding and by default, the transformer encoder will be used.
|
97 |
+
To switch to conformer, set `--attn-type espnet` and `--POS_ENC_TYPE`. Optionally load the pre-trained En ASR encoder for faster training and better
|
98 |
+
performance: `--load-pretrained-encoder-from <ASR checkpoint path>`. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU.
|
99 |
+
You may want to update it accordingly when using more than 1 GPU.
|
100 |
+
|
101 |
+
#### Inference & Evaluation
|
102 |
+
|
103 |
+
Average the last 10 checkpoints and evaluate on test split:
|
104 |
+
|
105 |
+
```bash
|
106 |
+
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
|
107 |
+
python scripts/average_checkpoints.py \
|
108 |
+
--inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 \
|
109 |
+
--output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}"
|
110 |
+
fairseq-generate ${COVOST_ROOT}/fr \
|
111 |
+
--config-yaml config_st_fr_en.yaml --gen-subset test_st_fr_en --task speech_to_text \
|
112 |
+
--path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \
|
113 |
+
--max-tokens 50000 --beam 5 --scoring sacrebleu
|
114 |
+
```
|
115 |
+
|
116 |
+
## Interactive Decoding
|
117 |
+
|
118 |
+
Launch the interactive console via
|
119 |
+
|
120 |
+
```bash
|
121 |
+
fairseq-interactive ${COVOST_ROOT}/fr --config-yaml config_st_fr_en.yaml \
|
122 |
+
--task speech_to_text --path ${SAVE_DIR}/${CHECKPOINT_FILENAME} \
|
123 |
+
--max-tokens 50000 --beam 5
|
124 |
+
```
|
125 |
+
|
126 |
+
Type in WAV/FLAC/OGG audio paths (one per line) after the prompt.
|
127 |
+
|
128 |
+
#### Results
|
129 |
+
|
130 |
+
| --arch | --pos-enc-type | Params | ASR PT | Fr-En | De-En | Es-En | Ca-En | En-De | En-Ca | En-Fa | En-Et | Model |
|
131 |
+
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
132 |
+
| s2t_transformer | - | 31M | Yes | [27.2](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_fr_en_st_transformer_s.pt) | [17.7](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_de_en_st_transformer_s.pt) | [23.1](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_es_en_st_transformer_s.pt) | [19.3](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_ca_en_st_transformer_s.pt) | [16.1](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_de_st_transformer_s.pt) | [21.6](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_ca_st_transformer_s.pt) | [12.9](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_fa_st_transformer_s.pt) | [12.8](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_et_st_transformer_s.pt) | (<-Download) |
|
133 |
+
| s2t_conformer | rel_pos | 42.9M | No | [28.32](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/fr_en/rel_pos_from_scratch_avg_last_10_checkpoint.pt) | [18.21](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/de_en/rel_pos_from_scratch_avg_last_10_checkpoint.pt) | [25.98](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/es_en/rel_pos_from_scratch_avg_last_10_checkpoint.pt) | [21.13](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/ca_en/rel_pos_from_scratch_avg_last_10_checkpoint.pt) | [20.37](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_de/rel_pos_from_scratch_avg_last_10_checkpoint.pt) | [25.89](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_ca/rel_pos_from_scratch_avg_last_10_checkpoint.pt) | [15.59](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_fa/rel_pos_from_scratch_avg_last_10_checkpoint.pt) | [14.49](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_et/rel_pos_from_scratch_avg_last_10_checkpoint.pt) | (<-Download) |
|
134 |
+
| s2t_conformer | rel_pos | 42.9M | Yes| [27.15](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/fr_en/rel_pos_asr_pt_avg_last_10_checkpoint.pt) | [18.22](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/de_en/rel_pos_asr_pt_avg_last_10_checkpoint.pt) | [25.14](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/es_en/rel_pos_asr_pt_avg_last_10_checkpoint.pt) | [21.68](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/ca_en/rel_pos_asr_pt_avg_last_10_checkpoint.pt) | [20.35](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_de/rel_pos_asr_pt_avg_last_10_checkpoint.pt) | [25.92](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_ca/rel_pos_asr_pt_avg_last_10_checkpoint.pt) | [15.76](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_fa/rel_pos_asr_pt_avg_last_10_checkpoint.pt) | [16.52](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_et/rel_pos_asr_pt_avg_last_10_checkpoint.pt) | (<-Download) |
|
135 |
+
| s2t_conformer | rope | 42.1M | No | [27.61](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/fr_en/rope_from_scratch_avg_last_10_checkpoint.pt) | [17.6](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/de_en/rope_from_scratch_avg_last_10_checkpoint.pt) | [24.91](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/es_en/rope_from_scratch_avg_last_10_checkpoint.pt) | [20.78](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/ca_en/rope_from_scratch_avg_last_10_checkpoint.pt) | [19.7](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_de/rope_from_scratch_avg_last_10_checkpoint.pt) | [25.13](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_ca/rope_from_scratch_avg_last_10_checkpoint.pt) | [15.22](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_fa/rope_from_scratch_avg_last_10_checkpoint.pt) | [15.87](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_et/rope_from_scratch_avg_last_10_checkpoint.pt) | (<-Download) |
|
136 |
+
| s2t_conformer | rope | 42.1M | Yes | [26.99](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/fr_en/rope_asr_pt_avg_last_10_checkpoint.pt) | [17.71](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/de_en/rope_asr_pt_avg_last_10_checkpoint.pt) | [24.24](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/es_en/rope_asr_pt_avg_last_10_checkpoint.pt) | [21.24](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/ca_en/rope_asr_pt_avg_last_10_checkpoint.pt) | [19.9](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_de/rope_asr_pt_avg_last_10_checkpoint.pt) | [25.25](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_ca/rope_asr_pt_avg_last_10_checkpoint.pt) | [15.58](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_fa/rope_asr_pt_avg_last_10_checkpoint.pt) | [15.97](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_et/rope_asr_pt_avg_last_10_checkpoint.pt) | (<-Download) |
|
137 |
+
| s2t_conformer | abs | 42.1M | No | [27.45](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/fr_en/abs_from_scratch_avg_last_10_checkpoint.pt) | [17.25](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/de_en/abs_from_scratch_avg_last_10_checkpoint.pt) | [25.01](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/es_en/abs_from_scratch_avg_last_10_checkpoint.pt) | [20.26](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/ca_en/abs_from_scratch_avg_last_10_checkpoint.pt) | [19.86](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_de/abs_from_scratch_avg_last_10_checkpoint.pt) | [25.25](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_ca/abs_from_scratch_avg_last_10_checkpoint.pt) | [15.46](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_fa/abs_from_scratch_avg_last_10_checkpoint.pt) | [15.81](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_et/abs_from_scratch_avg_last_10_checkpoint.pt) | (<-Download) |
|
138 |
+
| s2t_conforme | abs | 42.1M | Yes| [26.52](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/fr_en/abs_asr_pt_avg_last_10_checkpoint.pt) | [17.37](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/de_en/abs_asr_pt_avg_last_10_checkpoint.pt) | [25.40](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/es_en/abs_asr_pt_avg_last_10_checkpoint.pt) | [20.45](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/ca_en/abs_asr_pt_avg_last_10_checkpoint.pt) | [19.57](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_de/abs_asr_pt_avg_last_10_checkpoint.pt) | [25.40](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_ca/abs_asr_pt_avg_last_10_checkpoint.pt) | [15.17](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_fa/abs_asr_pt_avg_last_10_checkpoint.pt) | [15.83](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_et/abs_asr_pt_avg_last_10_checkpoint.pt) | (<-Download) |
|
139 |
+
|
140 |
+
[[Back]](..)
|