PyTorch · ssl-aasist · custom_code

ash56 committed (verified) · Commit 211c22d · 1 parent: a1d9110

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete change set.

Files changed (50)
  1. fairseq/examples/speech_text_joint_to_text/criterions/multi_modality_cross_entropy.py +101 -0
  2. fairseq/examples/speech_text_joint_to_text/criterions/text_guide_cross_entropy_acc.py +224 -0
  3. fairseq/examples/speech_text_joint_to_text/data/pair_denoising_dataset.py +318 -0
  4. fairseq/examples/speech_text_joint_to_text/docs/ende-mustc.md +118 -0
  5. fairseq/examples/speech_text_joint_to_text/docs/iwslt2021.md +76 -0
  6. fairseq/examples/speech_text_joint_to_text/docs/pre-training.md +192 -0
  7. fairseq/examples/speech_text_joint_to_text/models/__init__.py +8 -0
  8. fairseq/examples/speech_text_joint_to_text/models/joint_speech_text_pretrain_transformer.py +698 -0
  9. fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputtransformer.py +1093 -0
  10. fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputwavtransformer.py +526 -0
  11. fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputxmtransformer.py +584 -0
  12. fairseq/examples/speech_text_joint_to_text/scripts/convert_model.py +71 -0
  13. fairseq/examples/speech_text_joint_to_text/scripts/g2p_encode.py +191 -0
  14. fairseq/examples/speech_text_joint_to_text/tasks/__init__.py +8 -0
  15. fairseq/examples/speech_text_joint_to_text/tasks/pair_denoising.py +447 -0
  16. fairseq/examples/speech_text_joint_to_text/tasks/speech_text_denoise_pretrain.py +654 -0
  17. fairseq/examples/speech_text_joint_to_text/tasks/speech_text_joint.py +377 -0
  18. fairseq/examples/speech_to_speech/README.md +7 -0
  19. fairseq/examples/speech_to_speech/__init__.py +6 -0
  20. fairseq/examples/speech_to_speech/asr_bleu/README.md +34 -0
  21. fairseq/examples/speech_to_speech/asr_bleu/__init__.py +0 -0
  22. fairseq/examples/speech_to_speech/asr_bleu/asr_model_cfgs.json +198 -0
  23. fairseq/examples/speech_to_speech/asr_bleu/compute_asr_bleu.py +244 -0
  24. fairseq/examples/speech_to_speech/asr_bleu/requirements.txt +7 -0
  25. fairseq/examples/speech_to_speech/asr_bleu/utils.py +306 -0
  26. fairseq/examples/speech_to_speech/benchmarking/README.md +31 -0
  27. fairseq/examples/speech_to_speech/benchmarking/configs/2StageS2ST.yaml +19 -0
  28. fairseq/examples/speech_to_speech/benchmarking/configs/3StageS2ST.yaml +28 -0
  29. fairseq/examples/speech_to_speech/benchmarking/configs/DirectS2U.yaml +22 -0
  30. fairseq/examples/speech_to_speech/benchmarking/configs/S2T.yaml +13 -0
  31. fairseq/examples/speech_to_speech/benchmarking/core.py +487 -0
  32. fairseq/examples/speech_to_speech/benchmarking/data_utils.py +264 -0
  33. fairseq/examples/speech_to_speech/benchmarking/get_metrics.py +162 -0
  34. fairseq/examples/speech_to_speech/docs/data_augmentation.md +435 -0
  35. fairseq/examples/speech_to_speech/docs/direct_s2st_discrete_units.md +181 -0
  36. fairseq/examples/speech_to_speech/docs/enhanced_direct_s2st_discrete_units.md +125 -0
  37. fairseq/examples/speech_to_speech/docs/textless_s2st_real_data.md +89 -0
  38. fairseq/examples/speech_to_speech/generate_waveform_from_code.py +116 -0
  39. fairseq/examples/speech_to_speech/preprocessing/__init__.py +4 -0
  40. fairseq/examples/speech_to_speech/preprocessing/data_utils.py +88 -0
  41. fairseq/examples/speech_to_speech/preprocessing/prep_s2spect_data.py +169 -0
  42. fairseq/examples/speech_to_speech/preprocessing/prep_s2ut_data.py +114 -0
  43. fairseq/examples/speech_to_speech/preprocessing/prep_sn_data.py +88 -0
  44. fairseq/examples/speech_to_speech/preprocessing/prep_sn_output_data.py +58 -0
  45. fairseq/examples/speech_to_speech/unity/__init__.py +7 -0
  46. fairseq/examples/speech_to_speech/unity/sequence_generator.py +626 -0
  47. fairseq/examples/speech_to_speech/unity/sequence_generator_multi_decoder.py +267 -0
  48. fairseq/examples/speech_to_text/README.md +77 -0
  49. fairseq/examples/speech_to_text/data_utils.py +383 -0
  50. fairseq/examples/speech_to_text/docs/covost_example.md +140 -0
fairseq/examples/speech_text_joint_to_text/criterions/multi_modality_cross_entropy.py ADDED
@@ -0,0 +1,101 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch

from fairseq import utils
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
    LabelSmoothedCrossEntropyCriterion,
    LabelSmoothedCrossEntropyCriterionConfig,
    label_smoothed_nll_loss,
)


@register_criterion(
    "speech_text_pretrain_cross_entropy",
    dataclass=LabelSmoothedCrossEntropyCriterionConfig,
)
class SpeechTextPreTrainCrossEntCriterion(LabelSmoothedCrossEntropyCriterion):
    def __init__(self, task, sentence_avg, label_smoothing, report_accuracy=False):
        super().__init__(
            task, sentence_avg, label_smoothing, report_accuracy=report_accuracy
        )

    def forward(self, model, sample, reduce=True):
        net_output = model(**sample["net_input"])
        loss, nll_loss, nsentences, ntokens, n_correct = self.compute_loss(
            model, net_output, sample, reduce=reduce
        )
        sample_size = nsentences if self.sentence_avg else ntokens
        logging_output = {
            "loss": loss.data,
            "nll_loss": nll_loss.data,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }
        if self.report_accuracy:
            logging_output["n_correct"] = utils.item(n_correct)
            logging_output["total"] = utils.item(ntokens)
        return loss, sample_size, logging_output

    def get_lprobs_and_target(self, model, net_output, sample):
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        target = model.get_targets(sample, net_output)
        assert self.ignore_prefix_size == 0
        if self.ignore_prefix_size > 0:
            if getattr(lprobs, "batch_first", False):
                lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
                target = target[:, self.ignore_prefix_size :].contiguous()
            else:
                lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
                target = target[self.ignore_prefix_size :, :].contiguous()
        return lprobs, target

    def compute_loss(self, model, net_output, sample, reduce=True):
        lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
        n_correct = 0
        if isinstance(target, dict):
            t_lprobs = target["target_logprobs"]

            if not lprobs.batch_first:
                lprobs = lprobs.transpose(0, 1)
                t_lprobs = t_lprobs.transpose(0, 1)
            nsentences, seq_len = lprobs.size()[:2]
            ntokens = nsentences * seq_len
            t_probs = t_lprobs.exp()
            mask_indices = (
                net_output[1]["mask_indices"][0]
                if len(net_output[1]["mask_indices"]) > 0
                else None
            )

            # mask_indices is True for those masking frames
            if mask_indices is not None:  # B X T
                t_probs = t_probs.masked_fill(mask_indices.eq(False).unsqueeze(-1), 0)
                ntokens = mask_indices.int().sum()
            t_probs = t_probs.detach()
            t_lprobs = t_lprobs.detach()
            loss = (
                -(t_probs * (lprobs - t_lprobs)).sum()
                if reduce
                else -(t_probs * (lprobs - t_lprobs)).sum(-1, keepdim=True)
            )
            nll_loss = loss
        else:
            nsentences = target.size(0)
            mask = target.ne(self.padding_idx)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs.view(-1, lprobs.size(-1)),
                target.view(-1),
                self.eps,
                ignore_index=self.padding_idx,
                reduce=reduce,
            )
            n_correct = torch.sum(
                lprobs.argmax(-1).masked_select(mask).eq(target.masked_select(mask))
            )
            ntokens = torch.sum(mask)
        return loss, nll_loss, nsentences, ntokens, n_correct
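
In the dict-target branch of `compute_loss` above, the criterion applies an online self-training (distillation) term only on masked frames: the teacher distribution is zeroed on unmasked positions, detached, and used to weight the student/teacher log-probability gap, and `ntokens` counts only the masked frames. The toy snippet below re-creates just that arithmetic with invented shapes and values; it is an illustration, not part of this commit:

```python
import torch

# Toy dimensions (illustrative only): batch B=2, frames T=4, classes C=5.
torch.manual_seed(0)
student_lprobs = torch.log_softmax(torch.randn(2, 4, 5), dim=-1)
teacher_lprobs = torch.log_softmax(torch.randn(2, 4, 5), dim=-1)
mask_indices = torch.tensor([[True, False, True, True],
                             [False, True, False, True]])  # B x T, True = masked frame

# Zero the teacher distribution on unmasked frames, as compute_loss() does.
t_probs = teacher_lprobs.exp()
t_probs = t_probs.masked_fill(mask_indices.eq(False).unsqueeze(-1), 0).detach()

# KL-style objective: -sum p_teacher * (log p_student - log p_teacher).
loss = -(t_probs * (student_lprobs - teacher_lprobs.detach())).sum()
ntokens = mask_indices.int().sum()  # only masked frames contribute

print(loss.item(), int(ntokens))
```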
fairseq/examples/speech_text_joint_to_text/criterions/text_guide_cross_entropy_acc.py ADDED
@@ -0,0 +1,224 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math

import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss
from fairseq.logging import metrics


@register_criterion("guided_label_smoothed_cross_entropy_with_accuracy")
class GuidedCrossEntAccCriterion(FairseqCriterion):
    def __init__(
        self,
        task,
        sentence_avg,
        guide_alpha,
        text_input_cost_ratio,
        label_smoothing,
        disable_text_guide_update_num=0,
        attentive_cost_regularization=0,
    ):
        """
        guide_alpha: alpha to interpolate nll and kd loss
        text_input_cost_ratio: loss ratio for text only input data
        label_smoothing: label smoothing ratio
        disable_text_guide_update_num: only use nll loss for the first N updates
        attentive_cost_regularization: ratio of attentive cost
        """
        super().__init__(task)
        self.alpha = guide_alpha
        self.attn_beta = attentive_cost_regularization
        self.sentence_avg = sentence_avg
        self.eps = label_smoothing
        self.text_input_cost_ratio = text_input_cost_ratio
        self.disable_update_num = disable_text_guide_update_num
        assert self.alpha >= 0 and self.alpha <= 1.0

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
                            help='epsilon for label smoothing, 0 means no label smoothing')
        # fmt: off
        parser.add_argument('--guide-alpha', default=0., type=float, metavar='D',
                            help='alpha to merge kd cost from text to speech input with ce loss')
        # fmt: off
        parser.add_argument('--disable-text-guide-update-num', default=0, type=int, metavar='D',
                            help='disable guided target from text for the first N updates.')
        parser.add_argument("--attentive-cost-regularization", default=0.0, type=float, metavar='D',
                            help="use encoder attentive loss regularization with cost ratio D")
        parser.add_argument("--attentive-cost-without-normalize", action='store_true',
                            help="Don't do normalization during attentive cost computation")

    def forward(self, model, sample, reduce=True):
        reduction = 'sum' if reduce else 'none'
        net_input = sample["net_input"]
        net_output = model(**net_input)
        attn_cost = None
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        is_dual_input = True if net_input['src_tokens'] is not None and net_input.get('src_txt_tokens') is not None else False
        target = model.get_targets(sample, net_output)
        src_token_num = 0
        if is_dual_input:
            # lprobs_spch from speech encoder and lprobs_text from text encoder
            lprobs_spch, lprobs_text = torch.chunk(lprobs, 2)
            lprobs_spch.batch_first = lprobs.batch_first
            lprobs_text.batch_first = lprobs.batch_first

            speech_loss, speech_nll_loss, speech_correct, speech_total = \
                self.guide_loss_and_acc(model, lprobs_spch, lprobs_text, target, reduce=(reduction == 'sum'))
            text_loss, text_nll_loss, text_correct, text_total = self.compute_loss_and_acc(model, lprobs_text, target, reduction=reduction)
            loss = (speech_loss + text_loss)
            nll_loss = (speech_nll_loss + text_nll_loss)
            correct = speech_correct + text_correct
            total = speech_total + text_total

            attn_cost = net_output[1].get('attn_cost')
            if attn_cost is not None:
                # attn_cost is batch_first and padding tokens have been masked already
                src_token_num = attn_cost.ne(0).sum()
                attn_cost = attn_cost.sum()
                loss = loss + attn_cost * self.attn_beta
            else:
                attn_cost = 0
        else:
            loss, nll_loss, correct, total = self.compute_loss_and_acc(model, lprobs, target, reduction=reduction)
            if sample["net_input"]['src_tokens'] is None:  # text input only
                loss = loss * self.text_input_cost_ratio
            speech_loss = None
            speech_nll_loss = None

        sample_size, logging_output = self.get_logging_output(
            sample, loss, nll_loss, correct, total, src_token_num, speech_loss, speech_nll_loss, attn_cost, is_dual_input
        )
        return loss, sample_size, logging_output

    def compute_loss_and_acc(self, model, lprobs, target, reduction='sum'):
        if not lprobs.batch_first:
            lprobs = lprobs.transpose(0, 1)
        lprobs = lprobs.view(-1, lprobs.size(-1))  # -> (B x T) x C
        target = target.view(-1)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=(reduction == 'sum'),
        )

        mask = target.ne(self.padding_idx)
        correct = torch.sum(lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask)))
        total = torch.sum(mask)
        return loss, nll_loss, correct, total

    def guide_loss_and_acc(self, model, lprobs, lprobs_teacher, target, reduce=True):
        """ lprobs_teacher is used as guide for lprobs """
        if self.alpha == 0.0 or model.num_updates < self.disable_update_num:
            return self.compute_loss_and_acc(model, lprobs, target, reduction=('sum' if reduce else 'none'))
        if not lprobs.batch_first:
            lprobs = lprobs.transpose(0, 1)
            lprobs_teacher = lprobs_teacher.transpose(0, 1)

        lprobs = lprobs.view(-1, lprobs.size(-1)).float()  # -> (B x T) x C
        lprobs_teacher = lprobs_teacher.view(-1, lprobs_teacher.size(-1)).float()  # -> (B x T) x C
        target = target.view(-1)
        loss = F.nll_loss(lprobs, target, ignore_index=self.padding_idx, reduction='sum' if reduce else 'none')
        nll_loss = loss
        probs_teacher = lprobs_teacher.exp().masked_fill_(target.unsqueeze(-1).eq(self.padding_idx), 0)
        probs_teacher = probs_teacher.detach()
        guide_loss = -(probs_teacher*lprobs).sum() if reduce else -(probs_teacher*lprobs).sum(-1, keepdim=True)
        loss = self.alpha*guide_loss + (1.0 - self.alpha)*loss

        mask = target.ne(self.padding_idx)
        correct = torch.sum(lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask)))
        total = torch.sum(mask)
        return loss, nll_loss, correct, total

    def get_logging_output(
        self,
        sample,
        loss,
        nll_loss,
        correct,
        total,
        src_token_num=0,
        speech_loss=None,
        speech_nll_loss=None,
        attn_cost=None,
        is_dual_input=False,
    ):

        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        mul_size = 2 if is_dual_input else 1

        logging_output = {
            "loss": utils.item(loss.data),  # * sample['ntokens'],
            "nll_loss": utils.item(nll_loss.data),  # * sample['ntokens'],
            "ntokens": sample["ntokens"]*mul_size,
            "nsentences": sample["target"].size(0)*mul_size,
            "sample_size": sample_size*mul_size,
            "correct": utils.item(correct.data),
            "total": utils.item(total.data),
            "src_token_num": utils.item(src_token_num.data) if src_token_num > 0 else 0,
            "nframes": torch.sum(sample["net_input"]["src_lengths"]).item(),
        }

        if speech_loss is not None:
            logging_output["speech_loss"] = utils.item(speech_loss.data)
            logging_output["speech_nll_loss"] = utils.item(speech_nll_loss.data)
            logging_output["sample_size_speech_cost"] = sample_size
            logging_output["speech_attn_loss"] = attn_cost

        return sample_size*mul_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        correct_sum = sum(log.get("correct", 0) for log in logging_outputs)
        total_sum = sum(log.get("total", 0) for log in logging_outputs)
        src_token_sum = sum(log.get("src_token_num", 0) for log in logging_outputs)
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        nframes = sum(log.get("nframes", 0) for log in logging_outputs)
        speech_loss_sum = sum(log.get("speech_loss", 0) for log in logging_outputs)
        speech_nll_loss_sum = sum(log.get("speech_nll_loss", 0) for log in logging_outputs)
        speech_attn_loss_sum = sum(log.get("speech_attn_loss", 0) for log in logging_outputs)
        sample_size_speech = sum(log.get("sample_size_speech_cost", 0) for log in logging_outputs)

        agg_output = {
            "loss": loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0,
            "nll_loss": nll_loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0,
            # if args.sentence_avg, then sample_size is nsentences, and loss
            # is per-sentence loss; else sample_size is ntokens, and the loss
            # becomes per-output token loss
            "speech_loss": speech_loss_sum / sample_size_speech / math.log(2) if sample_size_speech > 0 else 0.0,
            "speech_nll_loss": speech_nll_loss_sum / sample_size_speech / math.log(2) if sample_size_speech > 0 else 0.0,
            "speech_attn_loss": speech_attn_loss_sum / src_token_sum / math.log(2) if src_token_sum > 0 else 0.0,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "nframes": nframes,
            "sample_size": sample_size,
            "acc": correct_sum * 100.0 / total_sum if total_sum > 0 else 0.0,
            "correct": correct_sum,
            "total": total_sum,
            "src_token_num": src_token_sum,
            # total is the number of validate tokens
        }
        return agg_output

    @classmethod
    def reduce_metrics(cls, logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
        for k, v in agg_logging_outputs.items():
            if k in {'nsentences', 'ntokens', 'sample_size'}:
                continue
            metrics.log_scalar(k, v, round=3)
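
Written out, the dual-input objective assembled in `forward` above combines a guided (distillation) term on the speech branch, label-smoothed cross-entropy on the text branch, and an optional cross-attentive regularizer. With α = `--guide-alpha`, β = `--attentive-cost-regularization`, speech-branch (student) log-probabilities \(\log p_s\) and text-branch (teacher) probabilities \(p_t\), and up to the sum/none reductions used in the code:

$$
\mathcal{L}_{\mathrm{speech}} = \alpha\Big(-\sum_i p_t(i)\,\log p_s(i)\Big) + (1-\alpha)\,\mathcal{L}_{\mathrm{NLL}}(p_s),
\qquad
\mathcal{L} = \mathcal{L}_{\mathrm{speech}} + \mathcal{L}_{\mathrm{text}} + \beta\,\mathcal{L}_{\mathrm{attn}}
$$

where \(\mathcal{L}_{\mathrm{text}}\) is the label-smoothed cross-entropy of the text branch and \(\mathcal{L}_{\mathrm{attn}}\) is the summed attentive cost returned by the encoder.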
fairseq/examples/speech_text_joint_to_text/data/pair_denoising_dataset.py ADDED
@@ -0,0 +1,318 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
import math
import re

import torch

from fairseq.data import data_utils
from fairseq.data.language_pair_dataset import LanguagePairDataset


# Part of the code is modified from DenoisingDataset
# compared with DenoisingDataset, no permute_sentences or documents (rotate_ratio, permute_sentence_ratio)
class LanguagePairDenoisingDataset(LanguagePairDataset):
    def __init__(
        self,
        src,
        src_sizes,
        src_dict,
        tgt,
        tgt_sizes,
        tgt_dict,
        mask_idx,
        mask_whole_words,
        seed,
        args,
        left_pad_source=True,
        left_pad_target=False,
        shuffle=True,
        input_feeding=True,
        remove_eos_from_source=False,
        append_eos_to_target=False,
        align_dataset=None,
        constraints=None,
        append_bos=False,
        eos=None,
        num_buckets=0,
        src_lang_id=None,
        tgt_lang_id=None,
        pad_to_multiple=1,
    ):
        super().__init__(
            src,
            src_sizes,
            src_dict,
            tgt,
            tgt_sizes,
            tgt_dict,
            left_pad_source,
            left_pad_target,
            shuffle,
            input_feeding,
            remove_eos_from_source,
            append_eos_to_target,
            align_dataset,
            constraints,
            append_bos,
            eos,
            num_buckets,
            src_lang_id,
            tgt_lang_id,
            pad_to_multiple,
        )

        self.mask_idx = mask_idx
        self.mask_whole_word = mask_whole_words
        self.mask_ratio = args.mask
        self.random_ratio = args.mask_random
        self.insert_ratio = args.insert

        self.replace_length = args.replace_length

        if self.replace_length not in [-1, 0, 1]:
            raise ValueError(f"invalid arg: replace_length={self.replace_length}")
        if args.mask_length not in ["subword", "word", "span-poisson"]:
            raise ValueError(f"invalid arg: mask-length={args.mask_length}")
        if args.mask_length == "subword" and args.replace_length not in [0, 1]:
            raise ValueError("if using subwords, use replace-length=1 or 0")

        self.mask_span_distribution = None
        if args.mask_length == "span-poisson":
            # Text infilling: "A number of text spans are sampled, with span lengths drawn from a Poisson distribution (λ = 3). Each span is replaced with a single [MASK] token. 0-length spans correspond to the insertion of [MASK] tokens."
            _lambda = args.poisson_lambda

            lambda_to_the_k = 1
            e_to_the_minus_lambda = math.exp(-_lambda)
            k_factorial = 1
            ps = []
            for k in range(0, 128):
                ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
                lambda_to_the_k *= _lambda
                k_factorial *= k + 1
                if ps[-1] < 0.0000001:
                    break
            ps = torch.FloatTensor(ps)
            self.mask_span_distribution = torch.distributions.Categorical(ps)

        self.epoch = 0
        self.seed = seed

        def _is_phoneme(x):
            if re.search("<lang:", x) or x in (
                "<mask>",
                "<sil>",
                "<pad>",
                "<s>",
                "</s>",
                "<unk>",
            ):
                return False
            return True

        self.voc_valid_ids = torch.LongTensor(
            [i for i, x in enumerate(self.src_dict.symbols) if _is_phoneme(x)]
        )
        self.voc_valid_size = self.voc_valid_ids.size(0)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return False

    def set_epoch(self, epoch, **unused):
        self.epoch = epoch

    def __getitem__(self, index):
        tgt_item = self.tgt[index] if self.tgt is not None else None
        src_item = copy.deepcopy(self.src[index])
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            source = src_item
            assert source[-1] == self.eos
            if self.mask_ratio > 0:
                source = self.add_whole_word_mask(source, self.mask_ratio)

            if self.insert_ratio > 0:
                source = self.add_insertion_noise(source, self.insert_ratio)
            src_item = source

        if self.append_eos_to_target:
            eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
            if self.tgt and self.tgt[index][-1] != eos:
                tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])])

        if self.append_bos:
            bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
            if self.tgt and self.tgt[index][0] != bos:
                tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]])

            bos = self.src_dict.bos()
            if src_item[0] != bos:
                src_item = torch.cat([torch.LongTensor([bos]), src_item])

        if self.remove_eos_from_source:
            eos = self.src_dict.eos()
            if src_item[-1] == eos:
                src_item = src_item[:-1]

        example = {
            "id": index,
            "source": src_item,
            "target": tgt_item,
        }
        if self.align_dataset is not None:
            example["alignment"] = self.align_dataset[index]
        if self.constraints is not None:
            example["constraints"] = self.constraints[index]
        if self.src_lang_id is not None:
            example["src_lang_id"] = self.src_lang_id
        if self.tgt_lang_id is not None:
            example["tgt_lang_id"] = self.tgt_lang_id
        return example

    # following functions are borrowed from denoising_dataset
    def word_starts(self, source):
        if self.mask_whole_word is not None:
            is_word_start = self.mask_whole_word.gather(0, source)
        else:
            is_word_start = torch.ones(source.size())
        is_word_start[0] = 0
        is_word_start[-1] = 0
        return is_word_start

    def add_whole_word_mask(self, source, p):
        is_word_start = self.word_starts(source)
        num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
        num_inserts = 0
        if num_to_mask == 0:
            return source

        if self.mask_span_distribution is not None:
            lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))

            # Make sure we have enough to mask
            cum_length = torch.cumsum(lengths, 0)
            while cum_length[-1] < num_to_mask:
                lengths = torch.cat(
                    [
                        lengths,
                        self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
                    ],
                    dim=0,
                )
                cum_length = torch.cumsum(lengths, 0)

            # Trim to masking budget
            i = 0
            while cum_length[i] < num_to_mask:
                i += 1
            lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
            num_to_mask = i + 1
            lengths = lengths[:num_to_mask]

            # Handle 0-length mask (inserts) separately
            lengths = lengths[lengths > 0]
            num_inserts = num_to_mask - lengths.size(0)
            num_to_mask -= num_inserts
            if num_to_mask == 0:
                return self.add_insertion_noise(source, num_inserts / source.size(0))

            assert (lengths > 0).all()
        else:
            lengths = torch.ones((num_to_mask,)).long()
        assert is_word_start[-1] == 0
        word_starts = is_word_start.nonzero(as_tuple=False)
        indices = word_starts[
            torch.randperm(word_starts.size(0))[:num_to_mask]
        ].squeeze(1)
        mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio

        source_length = source.size(0)
        assert source_length - 1 not in indices
        to_keep = torch.ones(source_length, dtype=torch.bool)
        is_word_start[
            -1
        ] = 255  # acts as a long length, so spans don't go over the end of doc
        if self.replace_length == 0:
            to_keep[indices] = 0
        else:
            # keep index, but replace it with [MASK]
            source[indices] = self.mask_idx
            source[indices[mask_random]] = self.voc_valid_ids[
                torch.randint(0, self.voc_valid_size - 1, size=(mask_random.sum(),))
            ]

        if self.mask_span_distribution is not None:
            assert len(lengths.size()) == 1
            assert lengths.size() == indices.size()
            lengths -= 1
            while indices.size(0) > 0:
                assert lengths.size() == indices.size()
                lengths -= is_word_start[indices + 1].long()
                uncompleted = lengths >= 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                lengths = lengths[uncompleted]
                if self.replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    source[indices] = self.mask_idx
                    source[indices[mask_random]] = self.voc_valid_ids[
                        torch.randint(
                            0, self.voc_valid_size - 1, size=(mask_random.sum(),)
                        )
                    ]
        else:
            # A bit faster when all lengths are 1
            while indices.size(0) > 0:
                uncompleted = is_word_start[indices + 1] == 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                if self.replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    source[indices] = self.mask_idx
                    source[indices[mask_random]] = self.voc_valid_ids[
                        torch.randint(
                            0, self.voc_valid_size - 1, size=(mask_random.sum(),)
                        )
                    ]

                assert source_length - 1 not in indices

        source = source[to_keep]

        if num_inserts > 0:
            source = self.add_insertion_noise(source, num_inserts / source.size(0))

        return source

    def add_insertion_noise(self, tokens, p):
        if p == 0.0:
            return tokens

        num_tokens = len(tokens)
        n = int(math.ceil(num_tokens * p))

        noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
        noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
        noise_mask[noise_indices] = 1
        result = torch.LongTensor(n + len(tokens)).fill_(-1)

        num_random = int(math.ceil(n * self.random_ratio))
        result[noise_indices[num_random:]] = self.mask_idx
        result[noise_indices[:num_random]] = self.voc_valid_ids[
            torch.randint(0, self.voc_valid_size - 1, size=(num_random,))
        ]

        result[~noise_mask] = tokens

        assert (result >= 0).all()
        return result
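
The `span-poisson` branch of the constructor above builds a truncated Poisson(λ) categorical over span lengths by accumulating the terms e^(-λ) λ^k / k! until they drop below 1e-7 (capped at k < 128); length-0 samples later become pure insertions. A standalone sketch of the same construction, with λ = 3.5 chosen arbitrarily for illustration:

```python
import math

import torch


def truncated_poisson(_lambda: float, max_k: int = 128, tol: float = 1e-7):
    """Categorical over span lengths 0..K used for BART-style text infilling."""
    lambda_to_the_k = 1.0
    e_to_the_minus_lambda = math.exp(-_lambda)
    k_factorial = 1.0
    ps = []
    for k in range(max_k):
        ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
        lambda_to_the_k *= _lambda
        k_factorial *= k + 1
        if ps[-1] < tol:
            break
    # Categorical normalizes the (truncated) probabilities internally.
    return torch.distributions.Categorical(torch.FloatTensor(ps))


dist = truncated_poisson(3.5)             # corresponds to --poisson-lambda
lengths = dist.sample(sample_shape=(8,))  # sampled span lengths; 0 => insert a mask
print(lengths)
```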
fairseq/examples/speech_text_joint_to_text/docs/ende-mustc.md ADDED
@@ -0,0 +1,118 @@
[[Back]](..)

# Joint Speech Text Training for the MuST-C English to German Speech Translation task

Joint Training Baseline: it is based on the paper ["A general multi-task learning framework to leverage text data for speech to text tasks"](https://arxiv.org/pdf/2010.11338.pdf)

Enhanced Joint Training: the joint training is enhanced with pre-trained models, cross attentive regularization and online knowledge distillation, based on the paper ["Improving Speech Translation by Understanding and Learning from the Auxiliary Text Translation Task"](https://research.fb.com/publications/improving-speech-translation-by-understanding-and-learning-from-the-auxiliary-text-translation-task)

## Prepare Data
#### Download files
- SentencePiece model [spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/spm.model)
- Dictionary [dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/dict.txt)
- Config [config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/config.yaml)
#### Prepare the MuST-C data set
- Please follow the data preparation in the [S2T example](https://github.com/pytorch/fairseq/blob/main/examples/speech_to_text/docs/mustc_example.md)
- Convert the source text under the "src_text" column in the tsv file into its phoneme representation.
```bash
python examples/speech_text_joint_to_text/scripts/g2p_encode.py \
  --lower-case --do-filter --use-word-start --no-punc \
  --reserve-word examples/speech_text_joint_to_text/configs/mustc_noise.list \
  --data-path ${must_c_en_de_src_text} \
  --out-path ${must_c_en_de_src_text_pho}
```
- Replace the source text under the "src_text" column in the tsv file with the corresponding phoneme representation generated in the step above.
  Below is a snapshot of the MuST-C en-de dev tsv:
```
+ id audio n_frames tgt_text src_text speaker
+ ted_767_0 en-de/flac.zip:10071514743:48445 56160 Heute spreche ich zu Ihnen über Energie und Klima. ▁AY1 M ▁G OW1 IH0 NG ▁T UW1 ▁T AO1 K ▁T AH0 D EY1 ▁AH0 B AW1 T ▁EH1 N ER0 JH IY0 ▁AH0 N D ▁K L AY1 M AH0 T spk.767_
+ ted_767_1 en-de/flac.zip:1214217978:205678 226080 Und das überrascht vielleicht etwas, weil sich meine Vollzeitbeschäftigung bei der Stiftung hauptsächlich um Impfstoffe und Saatgut dreht, um die Dinge, die wir erfinden und liefern müssen um den ärmsten 2 Milliarden ein besseres Leben zu ermöglichen. ▁AH0 N D ▁DH AE1 T ▁M AY1 T ▁S IY1 M ▁AH0 ▁B IH1 T ▁S ER0 P R AY1 Z IH0 NG ▁B IH0 K AO1 Z ▁M AY1 ▁F UH1 L ▁T AY1 M ▁W ER1 K ▁AE1 T ▁DH AH0 ▁F AW0 N D EY1 SH AH0 N ▁IH1 Z ▁M OW1 S T L IY0 ▁AH0 B AW1 T ▁V AE2 K S IY1 N Z ▁AH0 N D ▁S IY1 D Z ▁AH0 B AW1 T ▁DH AH0 ▁TH IH1 NG Z ▁DH AE1 T ▁W IY1 ▁N IY1 D ▁T UW1 ▁IH0 N V EH1 N T ▁AH0 N D ▁D IH0 L IH1 V ER0 ▁T UW1 ▁HH EH1 L P ▁DH AH0 ▁P UH1 R IH0 S T ▁T UW1 ▁B IH1 L Y AH0 N ▁L AY1 V ▁B EH1 T ER0 ▁L IH1 V Z spk.767_
```
- Prepare the phoneme dictionary and save it to $MANIFEST_ROOT as [src_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/src_dict.txt)
#### Prepare WMT text data
- [Download WMT data](https://github.com/pytorch/fairseq/blob/main/examples/translation/prepare-wmt14en2de.sh)
- Convert the source text (English) into its phoneme representation as above
- Generate binary parallel files with fairseq-preprocess for training and validation. The source input is the English phoneme representation and the target input is German SentencePiece tokens. The output is saved under $parallel_text_data

## Training
The model is trained with 8 V100 GPUs.

#### Download pretrained models
- [pretrain_encoder](https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_joint_asr_transformer_m.pt)
- [pretrain_nmt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/checkpoint_mt.pt)

#### Training scripts
- Jointly trained model from scratch
```bash
python train.py ${MANIFEST_ROOT} \
  --save-dir ${save_dir} \
  --num-workers 8 \
  --task speech_text_joint_to_text \
  --arch dualinputs2ttransformer_s \
  --user-dir examples/speech_text_joint_to_text \
  --max-epoch 100 --update-mix-data \
  --optimizer adam --lr-scheduler inverse_sqrt \
  --lr 0.001 --update-freq 4 --clip-norm 10.0 \
  --criterion guided_label_smoothed_cross_entropy_with_accuracy \
  --label-smoothing 0.1 --max-tokens 10000 --max-tokens-text 10000 \
  --max-positions-text 400 --seed 2 --speech-encoder-layers 12 \
  --text-encoder-layers 6 --encoder-shared-layers 6 --decoder-layers 6 \
  --dropout 0.1 --warmup-updates 20000 \
  --text-sample-ratio 0.25 --parallel-text-data ${parallel_text_data} \
  --text-input-cost-ratio 0.5 --enc-grad-mult 2.0 --add-speech-eos \
  --log-format json --langpairs en-de --noise-token '"'"'▁NOISE'"'"' \
  --mask-text-ratio 0.0 --max-tokens-valid 20000 --ddp-backend no_c10d \
  --log-interval 100 --data-buffer-size 50 --config-yaml config.yaml \
  --keep-last-epochs 10
```
- Jointly trained model with good initialization, cross attentive loss and online knowledge distillation
```bash
python train.py ${MANIFEST_ROOT} \
  --save-dir ${save_dir} \
  --num-workers 8 \
  --task speech_text_joint_to_text \
  --arch dualinputs2ttransformer_m \
  --user-dir examples/speech_text_joint_to_text \
  --max-epoch 100 --update-mix-data \
  --optimizer adam --lr-scheduler inverse_sqrt \
  --lr 0.002 --update-freq 4 --clip-norm 10.0 \
  --criterion guided_label_smoothed_cross_entropy_with_accuracy \
  --guide-alpha 0.8 --disable-text-guide-update-num 5000 \
  --label-smoothing 0.1 --max-tokens 10000 --max-tokens-text 10000 \
  --max-positions-text 400 --seed 2 --speech-encoder-layers 12 \
  --text-encoder-layers 6 --encoder-shared-layers 6 --decoder-layers 6 \
  --dropout 0.1 --warmup-updates 20000 --attentive-cost-regularization 0.02 \
  --text-sample-ratio 0.25 --parallel-text-data ${parallel_text_data} \
  --text-input-cost-ratio 0.5 --enc-grad-mult 2.0 --add-speech-eos \
  --log-format json --langpairs en-de --noise-token '"'"'▁NOISE'"'"' \
  --mask-text-ratio 0.0 --max-tokens-valid 20000 --ddp-backend no_c10d \
  --log-interval 100 --data-buffer-size 50 --config-yaml config.yaml \
  --load-pretrain-speech-encoder ${pretrain_encoder} \
  --load-pretrain-decoder ${pretrain_nmt} \
  --load-pretrain-text-encoder-last ${pretrain_nmt} \
  --keep-last-epochs 10
```

## Evaluation
```bash
python ./fairseq_cli/generate.py \
  ${MANIFEST_ROOT} \
  --task speech_text_joint_to_text \
  --max-tokens 25000 \
  --nbest 1 \
  --results-path ${infer_results} \
  --batch-size 512 \
  --path ${model} \
  --gen-subset tst-COMMON_st \
  --config-yaml config.yaml \
  --scoring sacrebleu \
  --beam 5 --lenpen 1.0 \
  --user-dir examples/speech_text_joint_to_text \
  --load-speech-only
```

## Results (Joint training with initialization + CAR + online KD)
|Direction|En-De | En-Es | En-Fr |
|---|---|---|---|
|BLEU|27.4| 31.2 | 37.6 |
|checkpoint | [link](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/checkpoint_ave_10.pt) |[link](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_es/checkpoint_ave_10.pt)|[link](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_fr/checkpoint_ave_10.pt)|
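
The preparation steps above run `g2p_encode.py` and then substitute its output into the "src_text" column of the tsv manifests. A small hypothetical helper for that substitution (not part of this commit; file names are placeholders, and it assumes one phoneme line per tsv row, in the same order) might look like:

```python
def replace_src_text(tsv_in: str, pho_txt: str, tsv_out: str) -> None:
    """Overwrite the 'src_text' column of a MuST-C tsv with phoneme lines."""
    with open(pho_txt, encoding="utf-8") as f:
        phonemes = [line.rstrip("\n") for line in f]
    with open(tsv_in, encoding="utf-8") as fin, open(tsv_out, "w", encoding="utf-8") as fout:
        header = fin.readline().rstrip("\n").split("\t")
        col = header.index("src_text")
        fout.write("\t".join(header) + "\n")
        for row, pho in zip(fin, phonemes):
            fields = row.rstrip("\n").split("\t")
            fields[col] = pho
            fout.write("\t".join(fields) + "\n")


# e.g. replace_src_text("dev_st.tsv", "dev_st.pho", "dev_st_pho.tsv")  # hypothetical file names
```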
fairseq/examples/speech_text_joint_to_text/docs/iwslt2021.md ADDED
@@ -0,0 +1,76 @@
[[Back]](..)

# Joint Speech Text Training for the 2021 IWSLT multilingual speech translation

This directory contains the code from the paper ["FST: the FAIR Speech Translation System for the IWSLT21 Multilingual Shared Task"](https://arxiv.org/pdf/2107.06959.pdf).

## Prepare Data
#### Download files
- SentencePiece model [spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/spm.model)
- Dictionary [tgt_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/dict.txt)
- Config [config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/config.yaml)

#### Prepare
- Please follow the data preparation in [speech-to-text](https://github.com/pytorch/fairseq/blob/main/examples/speech_to_text/docs/mtedx_example.md) with the option "--use-audio-input" for raw audio tsv files.
- Prepare tsv files with phoneme-based source text (under column 'src_text') as in the [MuST-C](ende-mustc.md) example.


## Training

#### Download pretrained models
- [Pretrained mbart model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/mbart.pt)
- [Pretrained w2v model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/xlsr_53_56k.pt)


#### Training scripts

```bash
python train.py ${MANIFEST_ROOT} \
  --save-dir ${save_dir} \
  --user-dir examples/speech_text_joint_to_text \
  --train-subset train_es_en_tedx,train_es_es_tedx,train_fr_en_tedx,train_fr_es_tedx,train_fr_fr_tedx,train_it_it_tedx,train_pt_en_tedx,train_pt_pt_tedx \
  --valid-subset valid_es_en_tedx,valid_es_es_tedx,valid_es_fr_tedx,valid_es_it_tedx,valid_es_pt_tedx,valid_fr_en_tedx,valid_fr_es_tedx,valid_fr_fr_tedx,valid_fr_pt_tedx,valid_it_en_tedx,valid_it_es_tedx,valid_it_it_tedx,valid_pt_en_tedx,valid_pt_es_tedx,valid_pt_pt_tedx \
  --config-yaml config.yaml --ddp-backend no_c10d \
  --num-workers 2 --task speech_text_joint_to_text \
  --criterion guided_label_smoothed_cross_entropy_with_accuracy \
  --label-smoothing 0.3 --guide-alpha 0.8 \
  --disable-text-guide-update-num 5000 --arch dualinputxmtransformer_base \
  --max-tokens 500000 --max-sentences 3 --max-tokens-valid 800000 \
  --max-source-positions 800000 --enc-grad-mult 2.0 \
  --attentive-cost-regularization 0.02 --optimizer adam \
  --clip-norm 1.0 --log-format simple --log-interval 200 \
  --keep-last-epochs 5 --seed 1 \
  --w2v-path ${w2v_path} \
  --load-pretrained-mbart-from ${mbart_path} \
  --max-update 1000000 --update-freq 4 \
  --skip-invalid-size-inputs-valid-test \
  --skip-encoder-projection --save-interval 1 \
  --attention-dropout 0.3 --mbart-dropout 0.3 \
  --finetune-w2v-params all --finetune-mbart-decoder-params all \
  --finetune-mbart-encoder-params all --stack-w2v-mbart-encoder \
  --drop-w2v-layers 12 --normalize \
  --lr 5e-05 --lr-scheduler inverse_sqrt --warmup-updates 5000
```

## Evaluation
```bash
python ./fairseq_cli/generate.py \
  ${MANIFEST_ROOT} \
  --task speech_text_joint_to_text \
  --user-dir ./examples/speech_text_joint_to_text \
  --load-speech-only --gen-subset test_es_en_tedx \
  --path ${model} \
  --max-source-positions 800000 \
  --skip-invalid-size-inputs-valid-test \
  --config-yaml config.yaml \
  --infer-target-lang en \
  --max-tokens 800000 \
  --beam 5 \
  --results-path ${RESULTS_DIR} \
  --scoring sacrebleu
```
The trained model can be downloaded [here](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/iwslt/iwslt_data/checkpoint17.pt)

|direction|es_en|fr_en|pt_en|it_en|fr_es|pt_es|it_es|es_es|fr_fr|pt_pt|it_it|
|---|---|---|---|---|---|---|---|---|---|---|---|
|BLEU|31.62|36.93|35.07|27.12|38.87|35.57|34.13|74.59|74.64|70.84|69.76|
fairseq/examples/speech_text_joint_to_text/docs/pre-training.md ADDED
@@ -0,0 +1,192 @@
[[Back]](..)

# Unified Speech-Text Pre-training for Speech Translation and Recognition

This directory contains the pre-training recipes from the paper ["Unified Speech-Text Pre-training for Speech Translation and Recognition"](https://arxiv.org/abs/2204.05409).

## Librispeech ASR Pre-training
### Prepare Data
#### Download files
#### Prepare pre-training data
- Text to text task (T2T): prepare the binary data following steps similar to [EN_DE Joint training](./ende-mustc.md). The source data is presented as phoneme token sequences and the target data is coded as subword tokens via SentencePiece. The text data is downloaded from [openslr](https://www.openslr.org/12)
- Self-supervised speech learning task (SSL): the data is prepared as in [wav2vec 2.0](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec/README.md)
- Speech to phoneme classification task (S2P): the tsv file contains 5 fields: "id", "audio", "n_frames", "tgt_text", and "align". The "tgt_text" field corresponds to the phoneme-based representation of the speech data, and the "align" field contains the alignment information. The phoneme-level forced alignment for the labelled speech data (i.e. Librispeech) can be obtained via [kaldi](http://kaldi-asr.org) or [MFA](https://montrealcorpustools.github.io/Montreal-Forced-Aligner/). The segmentation information is normalized to 0~1 over the whole utterance. A snapshot of the tsv file is below:
```
+ id audio n_frames tgt_text align
+ 116-288045-0000 /librispeech/dev-other/116/288045/116-288045-0000.flac 170400 <sil> ▁AE1 Z AY1 ▁AH0 P R OW1 CH T ▁DH AH1 ▁S IH1 T IY0 <sil> AY1 ▁HH ER1 D ▁B EH1 L Z ▁R IH1 NG IH0 NG <sil> ▁AE1 N D AH0 ▁L IH1 T AH0 L ▁L EY1 T ER0 AY1 ▁F AW1 N D ▁DH AH0 ▁S T R IY1 T S ▁AH0 S T IH1 R ▁W IH0 TH ▁TH R AO1 NG Z ▁AH0 V ▁W EH1 L ▁D R EH1 S T ▁P IY1 P AH0 L ▁IH1 N ▁F AE1 M L IY0 ▁G R UW1 P S <sil> ▁W EH1 N D IH0 NG ▁DH EH1 R ▁W EY1 <sil> ▁HH IH1 DH ER0 ▁AH0 N D ▁TH IH1 DH ER0 <sil> 0.047977 0.056444 0.064911 0.075259 0.081844 0.089370 0.095014 0.104421 0.109125 0.111947 0.115710 0.120414 0.134525 0.141110 0.143932 0.174036 0.176858 0.190028 0.199436 0.207902 0.218250 0.224835 0.231421 0.242709 0.251176 0.257761 0.263405 0.268109 0.270931 0.290687 0.342427 0.349953 0.353716 0.356538 0.360301 0.363123 0.365945 0.368768 0.371590 0.376294 0.384760 0.394167 0.401693 0.409219 0.419567 0.430856 0.441204 0.444026 0.446849 0.449671 0.456256 0.463782 0.471308 0.477893 0.486359 0.491063 0.494826 0.501411 0.512700 0.517404 0.520226 0.534337 0.540922 0.545626 0.550329 0.559737 0.568203 0.583255 0.592662 0.600188 0.603951 0.611477 0.619003 0.624647 0.634055 0.639699 0.646284 0.653810 0.659454 0.664158 0.670743 0.682032 0.687676 0.692380 0.708373 0.713076 0.719661 0.729069 0.740357 0.744120 0.748824 0.752587 0.761994 0.770461 0.781750 0.790216 0.805268 0.808090 0.823142 0.832549 0.836312 0.840075 0.843838 0.851364 0.854186 0.857008 0.862653 0.878645 0.898401 0.901223 0.906867 0.913452 0.920038 0.926623 0.934149 0.939793 0.942615 0.945437 0.952023 0.957667 0.977422 1.000000
17
+
18
+ ```
19
+ - Speech to text task (S2T): The data preparation follow the steps in [EN_DE Joint training](./ende-mustc.md).
20
+
21
+ #### Prepare fine-tuning data:
22
+ We re-use the data from T2T and S2T tasks in the fine-tuning stage.
23
+
24
+ ### Model Build
25
+ #### Pre-training
26
+ ```
27
+ python train.py $T2T_DATA \
28
+ --save-dir $SAVE_PRE_PATH --user-dir examples/speech_text_joint_to_text --task speech_text_joint_denoising \
29
+ --criterion speech_text_pretrain_cross_entropy --optimizer adam --weight-decay 0.01 --config-yaml config_s2p.yaml --config-s2s-yaml config.yaml --ddp-backend no_c10d \
30
+ --lang-pairs pho-wrd --num-workers 4 --log-interval 500 --save-interval-updates 5000 --keep-interval-updates 1 --no-emb-update-unsup --report-accuracy --lr 0.001 --end-learning-rate 1e-06 \
31
+ --lr-scheduler polynomial_decay --warmup-updates 10000 --total-num-update 800000 --update-freq 6 --validate-interval-updates 10000 --train-subset train \
32
+ --valid-subset valid,valid_sup_speech,valid_sup_speech_s2s,valid_unsup_speech --dataset-impl mmap \
33
+ --sup-speech-data $S2P_DATA_PATH --sup-speech-train-subset train_960.ali --sup-speech-valid-subset dev-clean.ali --sup-speech-s2s-data $S2T_DATA_PATH \
34
+ --sup-speech-s2s-train-subset train --sup-speech-s2s-valid-subset dev-clean --unsup-speech-train-data $SSL_DATA_PATH/train.tsv --unsup-speech-valid-data $SSL_DATA_PATH/valid.tsv \
35
+ --batch-size 200 --batch-size-valid 150 --max-source-positions 1024 --max-target-positions 1024 --max-text-tokens 3072 --max-speech-positions 600000 \
36
+ --max-sample-size 750000 --min-sample-size 64000 --max-speech-tokens 750000 --max-tokens-valid 750000 --skip-invalid-size-inputs-valid-test \
37
+ --unsupervised-speech-sample-ratio 3.0 --supervised-speech-sample-ratio 5 --supervised-speech-s2s-sample-ratio 5 --text-sample-ratio 1.0 --mask 0.3 --mask-random 0.1 \
38
+ --mask-length span-poisson --speech-sup-mask-prob 0.3 --speech-unsup-mask-prob 0.7 --use-mask-whole-words --arch speech_text_pretrain_bart_base_stack \
39
+ --no-scale-feature --activation-fn gelu --speech-extractor-mode default --stacked-encoder all --encoder-normalize-before --decoder-normalize-before \
40
+ --encoder-learned-pos --decoder-learned-pos --dropout 0.1 --load-pretrained-mbart-encoder-from $BART --load-pretrained-mbart-decoder-from $BART
41
+ ```
42
+ The current implementation also supports model pre-training without the forced alignment supervised data. In this case, CTC is used to optimize the S2P task. We need to do following changes for the setting:
43
+ 1. options to be added
44
+ ```
45
+ --use-sup-speech-ctc --criterion speech_text_pretrain_compound
46
+ ```
47
+ 2. options to be deleted
48
+ ```
49
+ --same-data-update --criterion speech_text_pretrain_cross_entropy
50
+ ```
51
+ However, we find the CTC based pre-training is still worse than the forced alignment based setting. It could be partially due to the inferior pre-training setting that we re-use the forced alignment based pre-training setting for the CTC based pre-training.
52
+
53
+ #### Fine-tuning
54
+ ```
55
+ python train.py $S2T_DATA_PATH \
56
+ --save-dir $SAVE_FT_PATH --num-workers 8 --task speech_text_joint_to_text --arch dualinputs2twavtransformer_base_stack \
57
+ --user-dir examples/speech_text_joint_to_text --max-update 100000 --optimizer adam --lr-scheduler inverse_sqrt --lr 0.0003 --update-freq 3 --clip-norm 10.0 \
58
+ --criterion guided_label_smoothed_cross_entropy_with_accuracy --guide-alpha 0.8 --label-smoothing 0.1 --warmup-updates 20000 --attentive-cost-regularization 0.02 \
59
+ --enc-grad-mult 2.0 --max-tokens 800000 --max-source-positions 800000 --max-tokens-text 10000 --max-positions-text 1024 --max-target-positions 1024 --no-scale-feature \
60
+ --activation-fn gelu --load-pretrained-speech-text-encoder $SAVE_PRE_PATH/checkpoint_last.pt --load-pretrained-speech-text-decoder $SAVE_PRE_PATH/checkpoint_last.pt \
61
+ --encoder-normalize-before --decoder-normalize-before --speech-extractor-mode default --speech-mask-channel-length 64 --speech-mask-channel-prob 0.5 \
62
+ --speech-mask-length 10 --speech-mask-prob 0.65 --text-sample-ratio 0.25 --mask-text-ratio 0.3 --mask-text-type random --parallel-text-data text_bin \
63
+ --text-input-cost-ratio 0.5 --langpairs pho-wrd --update-mix-data --log-format json --max-tokens-valid 800000 --ddp-backend no_c10d --log-interval 500 \
64
+ --config-yaml config.yaml --skip-invalid-size-inputs-valid-test --keep-last-epochs 50 --layernorm-embedding --encoder-learned-pos --decoder-learned-pos
65
+ ```
66
+
67
+ ### Evaluation
68
+ The last 10 epoch models from fine-tuning is conducted model average to get $FINAL_MODEL
69
+ ```
70
+ python ./fairseq_cli/generate.py \
71
+ $S2T_DATA_PATH \
72
+ --task speech_text_joint_to_text \
73
+ --max-tokens 800000 \
74
+ --max-source-positions 800000 \
75
+ --nbest 1 \
76
+ --results-path $RESULTS_LOG \
77
+ --batch-size 512 \
78
+ --path $FINAL_MODEL \
79
+ --gen-subset $SUBSET \
80
+ --config-yaml config.yaml \
81
+ --scoring wer \
82
+ --beam 10 --lenpen 1.0 examples/speech_text_joint_to_text \
83
+ --user-dir examples/speech_text_joint_to_text --load-speech-only \
84
+ --model-overrides {'load_pretrained_speech_text_decoder':'','load_pretrained_speech_text_encoder':''}
85
+ ```
86
+
87
+ ### Results and models
88
+ | | dev-clean | dev-other | test-clean | test-other |
89
+ |---|---|---|---|---|
90
+ | WER| 2.0 | 4.4 | 2.1 |4.6 |
91
+
92
+ **Model Links**:
93
+ - [config_s2p.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/pretrain/config_s2p.yaml): Config for S2P
94
+ - [spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned/spm.model): Sentence Piece model
95
+ - [src_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned/src_dict.txt): Source Phoneme Dictionary
96
+ - [tgt_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned/tgt_dict.txt): Target Sentence Piece Dictionary
97
+ - [config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned/config.yaml): Config for S2T
98
+ - [BART](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/pretrain/bart.pt): trained from Librispeech text data
99
+ - [Joint Pre-trained model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/pretrain/checkpoint6.pt): model pre-trained with 960 hours Librispeech data (S2P, S2T) Librispeech text training data (T2T) and Librilight data (SSL)
100
+ - [Fine-tuned model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned/checkpoint_ave_10.pt): the pre-trained model is fined one 960 hours Librispeech speech and text data. (S2T + T2T)
101
+
102
+ ## MuST-C
103
+ ### Prepare Data
104
+ Compared with the ASR Librispeech ASR recipe, the differences are below:
105
+ - Replace the speech data with corresponding MuST-C data
106
+ - Parallel text data from WMT is replaced the Librispeech text data
107
+
108
+ ### Model Build
109
+ #### Pre-training
110
+ EN-DE is used as an example
111
+ ```
112
+ python train.py $TXT_DATA \
113
+ --save-dir $SAVE_PRE_PATH --user-dir examples/speech_text_joint_to_text --task speech_text_joint_denoising --criterion speech_text_pretrain_cross_entropy --optimizer adam --weight-decay 0.01 \
114
+ --config-yaml config_s2p.yaml --config-s2s-yaml config.yaml --ddp-backend no_c10d --lang-pairs-bitext en-fr --num-workers 4 --log-interval 500 --save-interval-updates 5000 --keep-interval-updates 1 \
115
+ --no-emb-update-unsup --use-decoder-output-proj --report-accuracy --lr 0.001 --end-learning-rate 1e-06 --lr-scheduler polynomial_decay --warmup-updates 10000 --total-num-update 800000 \
116
+ --update-freq 8 --validate-interval-updates 10000 --train-subset train --valid-subset valid_sup_speech,valid_sup_speech_s2s,valid_unsup_speech --dataset-impl mmap \
117
+ --sup-speech-data $S2P_DATA_PATH --sup-speech-train-subset train --sup-speech-valid-subset dev --sup-speech-s2s-data $S2T_DATA_PATH --sup-speech-s2s-train-subset train \
118
+ --sup-speech-s2s-valid-subset dev --unsup-speech-train-data $SSL_DATA_PATH/train.tsv --unsup-speech-valid-data $SSL_DATA_PATH/valid.tsv --batch-size 200 --batch-size-valid 100 \
119
+ --max-source-positions 1024 --max-target-positions 1024 --max-text-tokens 2048 --max-speech-positions 600000 --max-sample-size 600000 --min-sample-size 64000 \
120
+ --max-speech-tokens 600000 --max-tokens-valid 600000 --skip-invalid-size-inputs-valid-test --unsupervised-speech-sample-ratio 1.2 --supervised-speech-sample-ratio 10 \
121
+ --supervised-speech-s2s-sample-ratio 10 --bitext-sample-ratio 0.5 --mask 0.3 --mask-random 0.1 --mask-length span-poisson --speech-sup-mask-prob 0.3 \
122
+ --speech-unsup-mask-prob 0.7 --use-mask-whole-words --arch speech_text_pretrain_bart_base_stack --no-scale-feature --activation-fn gelu --speech-extractor-mode default \
123
+ --stacked-encoder s2s --encoder-normalize-before --decoder-normalize-before --encoder-learned-pos --decoder-learned-pos --dropout 0.1 \
124
+ --load-pretrained-mbart-encoder-from $EN_FR_NMT --load-pretrained-mbart-decoder-from $EN_FR_NMT
125
+ ```
126
+ #### Fine-tuning
127
+ ```
128
+ python train.py $S2T_DATA_PATH \
129
+ --save-dir $SAVE_FT_PATH --num-workers 8 --task speech_text_joint_to_text --arch dualinputs2twavtransformer_base_stack --user-dir examples/speech_text_joint_to_text \
130
+ --max-epoch 25 --update-mix-data --optimizer adam --lr-scheduler inverse_sqrt --lr 0.0003 --update-freq 4 --clip-norm 10.0 --warmup-updates 20000 \
131
+ --criterion guided_label_smoothed_cross_entropy_with_accuracy --guide-alpha 0.8 --attentive-cost-regularization 0.02 --enc-grad-mult 2.0 --label-smoothing 0.1 \
132
+ --max-tokens 800000 --max-source-positions 800000 --max-tokens-text 10000 --max-positions-text 1024 --load-pretrained-speech-text-encoder $SAVE_PRE_PATH/checkpoint_last.pt \
133
+ --load-pretrained-speech-text-decoder $SAVE_PRE_PATH/checkpoint_last.pt --speech-mask-channel-length 64 --speech-mask-channel-prob 0.5 --speech-mask-length 10 \
134
+ --speech-mask-prob 0.65 --text-sample-ratio 0.05 --mask-text-ratio 0.3 --mask-text-type random --parallel-text-data data-bin-wt --text-input-cost-ratio 0.5 \
135
+ --langpairs en-fr --log-format json --max-tokens-valid 800000 --ddp-backend no_c10d --log-interval 100 --config-yaml config.yaml --skip-invalid-size-inputs-valid-test \
136
+ --noise-token '▁NOISE' --keep-last-epochs 40 --layernorm-embedding --encoder-learned-pos --decoder-learned-pos --activation-fn gelu \
137
+ --speech-extractor-mode default --max-target-positions 1024 --encoder-normalize-before --decoder-normalize-before
138
+ ```
139
+
140
+ ### Evaluation
141
+ The last 10 epoch models from fine-tuning is conducted model average to get $FINAL_MODEL
+ ```
+ python fairseq_cli/generate.py \
+ $S2T_DATA_PATH \
+ --task speech_text_joint_to_text \
+ --nbest 1 \
+ --max-tokens 800000 \
+ --max-source-positions 800000 \
+ --results-path $RESULTS_LOG \
+ --batch-size 512 \
+ --path $FINAL_MODEL \
+ --gen-subset $SUBSET \
+ --config-yaml config.yaml \
+ --scoring sacrebleu \
+ --beam 10 --lenpen 1.0 \
+ --user-dir examples/speech_text_joint_to_text --load-speech-only \
+ --model-overrides "{'load_pretrained_speech_text_decoder':'','load_pretrained_speech_text_encoder':''}"
+ ```
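+ With --results-path set, generate.py writes its hypotheses and scoring to a log file under $RESULTS_LOG (generate-$SUBSET.txt in recent fairseq versions; the exact file name is an assumption here), and the corpus-level sacrebleu result appears on the final "Generate" line:
+ ```
+ # inspect the corpus-level BLEU reported at the end of the generation log
+ tail -n 1 $RESULTS_LOG/generate-$SUBSET.txt
+ ```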
+
+
+ ### Results and models
+ | | en-fr | en-es | en-de |
+ |---|---|---|---|
+ | BLEU | 39.7 | 33.2 | 29.2 |
+
+
+ **Model Links**:
+ 1. DE
+ - [de config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/config.yaml)
+ - [de src_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/src_dict.txt)
+ - [de tgt_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/tgt_dict.txt)
+ - [de spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/spm.model)
+ - [de pre-trained nmt model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/nmt.pt)
+ - [de pre-trained model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/checkpoint_pretraing.pt)
+ - [de fine-tuned model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/checkpoint_finetune_ave10.pt)
+ 2. ES
+ - [es config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/config.yaml)
+ - [es src_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/src_dict.txt)
+ - [es tgt_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/tgt_dict.txt)
+ - [es spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/spm.model)
+ - [es pre-trained nmt model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/nmt.pt)
+ - [es pre-trained model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/checkpoint_pretraing.pt)
+ - [es fine-tuned model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/checkpoint_finetune_ave10.pt)
+ 3. FR
+ - [fr config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/config.yaml)
+ - [fr src_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/src_dict.txt)
+ - [fr tgt_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/tgt_dict.txt)
+ - [fr spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/spm.model)
+ - [fr pre-trained nmt model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/nmt.pt)
+ - [fr pre-trained model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/checkpoint_pretraing.pt)
+ - [fr fine-tuned model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/checkpoint_finetune_ave10.pt)
+ 4. [config_s2p.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/config_s2p.yaml)
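+ For example, the DE assets listed above can be fetched with wget (a minimal sketch; $DOWNLOAD_DIR is a placeholder for your local target directory):
+ ```
+ # download the German config, dictionaries, sentencepiece model and checkpoints
+ mkdir -p $DOWNLOAD_DIR/de && cd $DOWNLOAD_DIR/de
+ for f in config.yaml src_dict.txt tgt_dict.txt spm.model nmt.pt checkpoint_pretraing.pt checkpoint_finetune_ave10.pt; do
+   wget https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/$f
+ done
+ ```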
fairseq/examples/speech_text_joint_to_text/models/__init__.py ADDED
@@ -0,0 +1,8 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import importlib
+ import os
+
fairseq/examples/speech_text_joint_to_text/models/joint_speech_text_pretrain_transformer.py ADDED
@@ -0,0 +1,698 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import logging
4
+ from collections import OrderedDict, namedtuple
5
+ from typing import Dict, Optional
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch import Tensor
11
+
12
+ from fairseq import checkpoint_utils, utils
13
+ from fairseq.file_io import PathManager
14
+ from fairseq.models import (
15
+ FairseqDecoder,
16
+ FairseqEncoderDecoderModel,
17
+ register_model,
18
+ register_model_architecture,
19
+ )
20
+ from fairseq.models.speech_to_text import (
21
+ MultiInputDecoder,
22
+ MultiModalityEncoder,
23
+ SpeechWavTransformerEncoder,
24
+ StackedSpeechWavTransformerEncoder,
25
+ )
26
+ from fairseq.models.transformer import (
27
+ TransformerDecoder,
28
+ TransformerEncoder,
29
+ TransformerModel,
30
+ )
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ class SpeechTextPreTrainEncoder(MultiModalityEncoder):
36
+ def __init__(
37
+ self,
38
+ dictionary,
39
+ sup_speech_encoder,
40
+ sup_s2s_speech_encoder,
41
+ unsup_speech_encoder,
42
+ text_encoder,
43
+ ):
44
+ super().__init__(dictionary)
45
+ self.sup_speech_encoder = sup_speech_encoder
46
+ self.sup_s2s_speech_encoder = sup_s2s_speech_encoder
47
+ self.unsup_speech_encoder = unsup_speech_encoder
48
+ self.text_encoder = text_encoder
49
+
50
+ @classmethod
51
+ def update_transformer_encoder_cfg(cls, args, update_dict):
52
+ cfg = dict(args._get_kwargs())
53
+ for fkey in update_dict.keys():
54
+ cfg[fkey] = update_dict[fkey]
55
+ cfg.pop("_name", None) # remove the internal "_name" key
56
+ model_args = namedtuple("args", cfg.keys())(*cfg.values())
57
+ return model_args
58
+
59
+ @classmethod
60
+ def build_text_encoder(cls, args, src_dictionary):
61
+ enc_emb = nn.Embedding(
62
+ len(src_dictionary), args.encoder_embed_dim, src_dictionary.pad()
63
+ )
64
+ model_args = cls.update_transformer_encoder_cfg(
65
+ args, {"encoder_layers": args.text_encoder_layers}
66
+ )
67
+ text_encoder = TransformerEncoder(model_args, src_dictionary, enc_emb)
68
+ return text_encoder
69
+
70
+ @classmethod
71
+ def build_speech_encoder(cls, args):
72
+ model_args = cls.update_transformer_encoder_cfg(
73
+ args,
74
+ {
75
+ "encoder_layers": args.speech_encoder_layers,
76
+ "speech_mask_prob": args.speech_sup_mask_prob,
77
+ },
78
+ )
79
+ speech_encoder = SpeechWavTransformerEncoder(model_args)
80
+ return speech_encoder
81
+
82
+ @classmethod
83
+ def share_layers(cls, src_layers, tgt_layers): # share layer but not dropout
84
+ # share parameters in src_layers with tgt_layers
85
+ assert len(src_layers) == len(tgt_layers)
86
+ for i, ly in enumerate(src_layers):
87
+ tly = tgt_layers[i]
88
+ tly.self_attn = ly.self_attn
89
+ tly.self_attn_layer_norm = ly.self_attn_layer_norm
90
+ tly.activation_fn = ly.activation_fn
91
+ tly.normalize_before = ly.normalize_before
92
+ tly.fc1 = ly.fc1
93
+ tly.fc2 = ly.fc2
94
+ tly.final_layer_norm = ly.final_layer_norm
95
+ if hasattr(tly, "encoder_attn"):
96
+ tly.encoder_attn = ly.encoder_attn
97
+ tly.encoder_attn_layer_norm = ly.encoder_attn_layer_norm
98
+ return tgt_layers
99
+
100
+ @classmethod
101
+ def build_unsup_speech_encoder(cls, args, sup_speech_encoder):
102
+ model_args = cls.update_transformer_encoder_cfg(
103
+ args,
104
+ {
105
+ "encoder_layers": args.speech_encoder_layers,
106
+ "speech_mask_prob": args.speech_unsup_mask_prob,
107
+ "encoder_layerdrop": 0.0,
108
+ "decoder_layerdrop": 0.0,
109
+ "dropout": args.speech_unsup_dropout,
110
+ "activation_dropout": args.speech_unsup_dropout,
111
+ "attention_dropout": 0.0,
112
+ "dropout_features": args.speech_unsup_feature_dropout,
113
+ "dropout_input": args.speech_unsup_feature_dropout,
114
+ },
115
+ )
116
+
117
+ unsup_speech_encoder = SpeechWavTransformerEncoder(model_args, alway_mask=True)
118
+ unsup_speech_encoder.layer_norm = sup_speech_encoder.layer_norm
119
+ unsup_speech_encoder.layers = cls.share_layers(
120
+ sup_speech_encoder.layers, unsup_speech_encoder.layers
121
+ )
122
+ unsup_speech_encoder.mask_emb = sup_speech_encoder.mask_emb
123
+ unsup_speech_encoder.embed_positions = sup_speech_encoder.embed_positions
124
+ unsup_speech_encoder.feat_layer_norm = sup_speech_encoder.feat_layer_norm
125
+ unsup_speech_encoder.feat_proj = sup_speech_encoder.feat_proj
126
+ unsup_speech_encoder.subsample = sup_speech_encoder.subsample
127
+ return unsup_speech_encoder
128
+
129
+ @classmethod
130
+ def build_encoder(cls, args, dictionary):
131
+ text_encoder = cls.build_text_encoder(args, dictionary)
132
+ if getattr(args, "load_pretrained_mbart_encoder_from", None):
133
+ text_encoder = checkpoint_utils.load_pretrained_component_from_model(
134
+ component=text_encoder,
135
+ checkpoint=args.load_pretrained_mbart_encoder_from,
136
+ )
137
+ speech_encoder = cls.build_speech_encoder(args)
138
+ if getattr(args, "load_pretrained_feature_extractor_from", None):
139
+
140
+ def load_feature_extractor(component, checkpoint):
141
+ if not PathManager.exists(checkpoint):
142
+ raise IOError("Model file not found: {}".format(checkpoint))
143
+ state = checkpoint_utils.load_checkpoint_to_cpu(checkpoint)
144
+ component_state_dict = OrderedDict()
145
+
146
+ component_prefix = "feature_extractor"
147
+ for key in state["model"].keys():
148
+ if key.startswith(component_prefix):
149
+ component_subkey = key[len(component_prefix) + 1 :]
150
+ component_state_dict[component_subkey] = state["model"][key]
151
+ component.load_state_dict(component_state_dict, strict=True)
152
+ return component
153
+
154
+ speech_encoder.subsample = load_feature_extractor(
155
+ speech_encoder.subsample, args.load_pretrained_feature_extractor_from
156
+ )
157
+ speech_s2s_encoder = speech_encoder
158
+ unsup_speech_encoder = cls.build_unsup_speech_encoder(args, speech_encoder)
159
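+ # "s2s" stacks the text encoder layers only on top of the supervised s2s speech encoder;
+ # "all" additionally stacks them on the supervised and unsupervised speech encoders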
+ if getattr(args, "stacked_encoder", "none") != "none":
160
+ if args.encoder_shared_text_layers_from_begin > 0:
161
+ raise ValueError(
162
+ "We can not stack encoders and share encoders at the same time!"
163
+ )
164
+ speech_s2s_encoder = StackedSpeechWavTransformerEncoder(
165
+ speech_encoder, text_encoder.layers, text_encoder.layer_norm
166
+ )
167
+ if args.stacked_encoder == "all":
168
+ speech_encoder = speech_s2s_encoder
169
+ unsup_speech_encoder = StackedSpeechWavTransformerEncoder(
170
+ unsup_speech_encoder, text_encoder.layers, text_encoder.layer_norm
171
+ )
172
+ else:
173
+ cls.share_speech_text_encoder(
174
+ speech_encoder, text_encoder, args.encoder_shared_text_layers_from_begin
175
+ )
176
+ return SpeechTextPreTrainEncoder(
177
+ dictionary,
178
+ speech_encoder,
179
+ speech_s2s_encoder,
180
+ unsup_speech_encoder,
181
+ text_encoder,
182
+ )
183
+
184
+ @classmethod
185
+ def share_speech_text_encoder(
186
+ cls, speech_encoder, text_encoder, shared_layers_from_begin
187
+ ):
188
+ if shared_layers_from_begin > 0:
189
+ num_text_encoder_layers = len(text_encoder.layers)
190
+ assert len(speech_encoder.layers) >= shared_layers_from_begin
191
+ assert num_text_encoder_layers >= shared_layers_from_begin
192
+ assert len(speech_encoder.layers) >= num_text_encoder_layers
193
+ for i, ly in enumerate(
194
+ speech_encoder.layers[
195
+ -num_text_encoder_layers : -num_text_encoder_layers
196
+ + shared_layers_from_begin
197
+ ]
198
+ ):
199
+ assert isinstance(text_encoder.layers[i], type(ly))
200
+ text_encoder.layers[i] = ly
201
+
202
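+ # pick the encoder matching the batch's data mode (supervised speech, unsupervised speech, or text/bitext)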
+ def select_encoder(self, mode, **kwargs):
203
+ if mode in ("speech", "sup_speech_ctc", "sup_speech_ali", "sup_speech_s2s"):
204
+ kwargs["features_only"] = True
205
+ if mode == "sup_speech_s2s":
206
+ return self.sup_s2s_speech_encoder, kwargs
207
+ return self.sup_speech_encoder, kwargs
208
+ elif mode == "unsup_speech":
209
+ kwargs["features_only"] = False
210
+ return self.unsup_speech_encoder, kwargs
211
+ elif mode in ("text", "bitext"):
212
+ return self.text_encoder, kwargs
213
+ else:
214
+ raise NotImplementedError(f"{mode} is not supported")
215
+ return None, kwargs
216
+
217
+ def forward(self, src_tokens, src_lengths=None, mode="", alignment=None, **kwargs):
218
+ return super().forward(src_tokens, src_lengths, mode, **kwargs)
219
+
220
+
221
+ # SpeechDummyDecoder works as an extension of encoder, so we could fit encoder only training into seq2seq training
222
+ class SpeechDummyDecoder(FairseqDecoder):
223
+ def __init__(
224
+ self,
225
+ dictionary,
226
+ output_embedding,
227
+ no_emb_update_unsup=False,
228
+ use_output_proj=False,
229
+ ):
230
+ super().__init__(dictionary)
231
+ self.output_embedding = output_embedding
232
+ num_embedding, num_dim = self.output_embedding.weight.size()
233
+ self.out_proj = (
234
+ None if use_output_proj is False else nn.Linear(num_dim, num_dim)
235
+ )
236
+ self.no_emb_update_unsup = no_emb_update_unsup
237
+
238
+ def extend_alignment(self, alignment, src_lengths, prev_output_tokens):
239
+ # alignment: B X N
240
+ # src_lengths: B X T
241
+ # prev_output_tokens: B X (N + 1)
242
+ tgt_tokens = prev_output_tokens[
243
+ :, 1:
244
+ ] # remove the leading start of sentence token
245
+ ext_alignment = (
246
+ torch.ones(len(src_lengths), src_lengths.max(), device=src_lengths.device)
247
+ .long()
248
+ .fill_(self.dictionary.pad())
249
+ )
250
+ for bs in range(src_lengths.size(0)):
251
+ tgt_length = tgt_tokens[bs].ne(self.dictionary.pad()).sum().item()
252
+ assert tgt_length == sum(alignment[bs].ne(1)) + 1
253
+ src_st = 0
254
+ for i in range(tgt_length):
255
+ tok = tgt_tokens[bs][i]
256
+ src_ed = (alignment[bs][i] * src_lengths[bs]).int().item()
257
+ ext_alignment[bs][src_st:src_ed].fill_(tok)
258
+ src_st = src_ed
259
+ return ext_alignment
260
+
261
+ def forward(
262
+ self,
263
+ prev_output_tokens,
264
+ encoder_out,
265
+ incremental_state=None,
266
+ mode="speech",
267
+ alignment=None,
268
+ **kwargs,
269
+ ):
270
+ """
271
+ Args:
272
+ prev_output_tokens (LongTensor): previous decoder outputs of shape
273
+ `(batch, tgt_len)`, for teacher forcing
274
+ encoder_out (optional): output from the encoder, used for
275
+ encoder-side attention
276
+ incremental_state (dict): dictionary used for storing state during
277
+ :ref:`Incremental decoding`
278
+ features_only (bool, optional): only return features without
279
+ applying output layer (default: False).
280
+ full_context_alignment (bool, optional): don't apply
281
+ auto-regressive mask to self-attention (default: False).
282
+
283
+ Returns:
284
+ sup_speech_ctc:
285
+ dictionary{"logits": logits, "padding_mask": padding_mask}
286
+ sup_speech_ali and unsup_speech:
287
+ tuple:
288
+ - the decoder's output of shape `(batch, tgt_len, vocab)`
289
+ - a dictionary with any model-specific outputs
290
+ """
291
+ emb_weight = self.output_embedding.weight
292
+ if (
293
+ mode == "unsup_speech" and self.no_emb_update_unsup
294
+ ): # no gradient for embedding here
295
+ emb_weight = emb_weight.detach()
296
+ enc_out = (
297
+ encoder_out["encoder_out"][0]
298
+ if self.out_proj is None
299
+ else self.out_proj(encoder_out["encoder_out"][0])
300
+ )
301
+ logits = F.linear(enc_out, emb_weight, None).transpose(0, 1) # B X T X C
302
+ others = None
303
+ if mode in (
304
+ "speech",
305
+ "sup_speech_ctc",
306
+ ): # speech data with label, do force alignment
307
+ if len(encoder_out["encoder_padding_mask"]) > 0:
308
+ padding_mask = encoder_out["encoder_padding_mask"][0]
309
+ logits = logits.masked_fill(padding_mask, float("-inf"))
310
+ else:
311
+ seq_len, bsz = encoder_out["encoder_out"][0].size()[:2]
312
+ padding_mask = torch.zeros(
313
+ bsz, seq_len, device=encoder_out["encoder_out"][0].device
314
+ ).bool()
315
+ return {"x": logits, "padding_mask": padding_mask}
316
+ elif mode == "sup_speech_ali":
317
+ src_lengths = None
318
+ if len(encoder_out["encoder_padding_mask"]) > 0:
319
+ src_lengths = (1 - encoder_out["encoder_padding_mask"][0].long()).sum(
320
+ -1
321
+ )
322
+ else:
323
+ seq_len, bsz = encoder_out["encoder_out"][0].size()[:2]
324
+ src_lengths = (
325
+ torch.ones(bsz, device=encoder_out["encoder_out"][0].device).long()
326
+ * seq_len
327
+ )
328
+ assert alignment is not None
329
+ alignment = self.extend_alignment(
330
+ alignment, src_lengths, prev_output_tokens
331
+ )
332
+ others = {"pseudo_target_tokens": alignment}
333
+ elif mode == "unsup_speech":
334
+ enc_out_ori = (
335
+ encoder_out["encoder_unmasked_out"][0]
336
+ if self.out_proj is None
337
+ else self.out_proj(encoder_out["encoder_unmasked_out"][0])
338
+ )
339
+ logits_ori = F.linear(enc_out_ori, emb_weight, None).transpose(0, 1)
340
+ if len(encoder_out["encoder_padding_mask"]) > 0:
341
+ encoder_padding_mask = encoder_out["encoder_padding_mask"][0]
342
+ logits_ori = logits_ori.masked_fill(encoder_padding_mask, float("-inf"))
343
+ pseudo_labels = utils.log_softmax(logits_ori, dim=-1)
344
+ others = {
345
+ "pseudo_target_logprobs": pseudo_labels,
346
+ "padding_mask": encoder_out["encoder_padding_mask"], # B X T
347
+ "mask_indices": encoder_out[
348
+ "mask_indices"
349
+ ], # True for masked frames B X T
350
+ }
351
+ return logits, others
352
+
353
+ def get_normalized_probs(
354
+ self,
355
+ net_output: Dict[str, Tensor],
356
+ log_probs: bool,
357
+ sample: Optional[Dict[str, Tensor]] = None,
358
+ ):
359
+ return self.get_normalized_probs_scriptable(
360
+ (net_output["x"], None), log_probs, sample
361
+ )
362
+
363
+
364
+ class SpeechTextPreTrainDecoder(MultiInputDecoder):
365
+ def __init__(self, dictionary, speech_decoder, text_decoder):
366
+ super().__init__(dictionary)
367
+ self.speech_decoder = speech_decoder
368
+ self.text_decoder = text_decoder
369
+
370
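+ # dispatch to the dummy speech decoder (CTC, alignment and unsupervised speech objectives)
+ # or to the text decoder (text, bitext and supervised speech-to-text s2s)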
+ def select_decoder(self, mode, **kwargs):
371
+ if mode == "unsup_speech":
372
+ kwargs["mode"] = mode
373
+ return self.speech_decoder, kwargs
374
+ if mode in ("text", "bitext"):
375
+ return self.text_decoder, kwargs
376
+ if mode in ("speech", "sup_speech_ctc", "sup_speech_ali"):
377
+ kwargs["mode"] = mode
378
+ return self.speech_decoder, kwargs
379
+ if mode in ("speech", "sup_speech_s2s"):
380
+ if "alignment" in kwargs:
381
+ del kwargs["alignment"]
382
+ return self.text_decoder, kwargs
383
+
384
+ raise NotImplementedError(f"{mode} is not supported")
385
+ return None, kwargs
386
+
387
+ def get_normalized_probs(
388
+ self,
389
+ net_output,
390
+ log_probs,
391
+ sample=None,
392
+ ):
393
+ """Get normalized probabilities (or log probs) from a net's output."""
394
+ if isinstance(net_output, dict):
395
+ return self.speech_decoder.get_normalized_probs(
396
+ net_output, log_probs, sample
397
+ )
398
+ return self.text_decoder.get_normalized_probs(net_output, log_probs, sample)
399
+
400
+ @classmethod
401
+ def build_text_decoder(cls, args, tgt_dictionary, dec_emb_share=None):
402
+ dec_emb = (
403
+ nn.Embedding(
404
+ len(tgt_dictionary), args.decoder_embed_dim, tgt_dictionary.pad()
405
+ )
406
+ if dec_emb_share is None
407
+ else dec_emb_share
408
+ )
409
+ text_decoder = TransformerDecoder(args, tgt_dictionary, dec_emb)
410
+ return text_decoder
411
+
412
+ @classmethod
413
+ def build_dummy_speech_decoder(cls, args, dictionary, dec_emb_share=None):
414
+ dec_emb = (
415
+ nn.Embedding(len(dictionary), args.decoder_embed_dim, dictionary.pad())
416
+ if dec_emb_share is None
417
+ else dec_emb_share
418
+ )
419
+ speech_decoder = SpeechDummyDecoder(
420
+ dictionary,
421
+ dec_emb,
422
+ no_emb_update_unsup=getattr(args, "no_emb_update_unsup", False),
423
+ use_output_proj=getattr(args, "use_decoder_output_proj", False),
424
+ )
425
+ return speech_decoder
426
+
427
+ @classmethod
428
+ def build_decoder(
429
+ cls, args, text_dictionary, speech_dictionary, speech_output_embedding
430
+ ):
431
+ text_decoder = cls.build_text_decoder(args, text_dictionary)
432
+ speech_decoder = cls.build_dummy_speech_decoder(
433
+ args, speech_dictionary, speech_output_embedding
434
+ )
435
+ if getattr(args, "load_pretrained_mbart_decoder_from", None):
436
+ text_decoder = checkpoint_utils.load_pretrained_component_from_model(
437
+ component=text_decoder,
438
+ checkpoint=args.load_pretrained_mbart_decoder_from,
439
+ )
440
+ return SpeechTextPreTrainDecoder(text_dictionary, speech_decoder, text_decoder)
441
+
442
+
443
+ @register_model("speech_text_pretrain_bart")
444
+ class SpeechTextPreTrainModel(FairseqEncoderDecoderModel):
445
+ def __init__(self, encoder, decoder):
446
+ super().__init__(encoder, decoder)
447
+ self.num_updates = 0
448
+
449
+ def forward(
450
+ self, src_tokens, src_lengths, prev_output_tokens, src_lang_ids=None, **kwargs
451
+ ):
452
+ if src_lang_ids is not None:
453
+ encoder_out = self.encoder(
454
+ src_tokens, src_lengths=src_lengths, src_lang_ids=src_lang_ids, **kwargs
455
+ )
456
+ else:
457
+ encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
458
+ decoder_out = self.decoder(
459
+ prev_output_tokens, encoder_out=encoder_out, **kwargs
460
+ )
461
+ return decoder_out
462
+
463
+ def max_positions(self):
464
+ return None # it is provided in task
465
+
466
+ def get_targets(self, sample, net_output):
467
+ mode = sample["net_input"]["mode"]
468
+ if mode == "unsup_speech":
469
+ return {"target_logprobs": net_output[1]["pseudo_target_logprobs"]}
470
+ if mode == "sup_speech_ali":
471
+ return net_output[1]["pseudo_target_tokens"]
472
+ return sample["target"]
473
+
474
+ def get_normalized_probs(
475
+ self,
476
+ net_output,
477
+ log_probs,
478
+ sample=None,
479
+ ):
480
+ # net_output['encoder_out'] is a (B, T, D) tensor
481
+ lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
482
+ lprobs.batch_first = True
483
+ return lprobs
484
+
485
+ @staticmethod
486
+ def add_args(parser):
487
+ TransformerModel.add_args(parser)
488
+ SpeechWavTransformerEncoder.add_args(parser)
489
+ parser.add_argument(
490
+ "--speech-sup-mask-prob",
491
+ type=float,
492
+ help="probability of replacing a token with mask (sup-speech)",
493
+ )
494
+ parser.add_argument(
495
+ "--speech-unsup-mask-prob",
496
+ type=float,
497
+ help="probability of replacing a token with mask (unsup-speech)",
498
+ )
499
+ parser.add_argument(
500
+ "--load-pretrained-mbart-encoder-from",
501
+ type=str,
502
+ metavar="STR",
503
+ help="model to take text encoder weights from (for initialization)",
504
+ )
505
+
506
+ parser.add_argument(
507
+ "--load-pretrained-mbart-decoder-from",
508
+ type=str,
509
+ metavar="STR",
510
+ help="model to take text decoder weights from (for initialization)",
511
+ )
512
+
513
+ parser.add_argument(
514
+ "--load-pretrained-feature-extractor-from",
515
+ type=str,
516
+ metavar="STR",
517
+ help="model to take feature extractor weights from (for initialization)",
518
+ )
519
+
520
+ parser.add_argument(
521
+ "--speech-unsup-dropout",
522
+ type=float,
523
+ default=0,
524
+ help="dropout for unsupervised speech encoder",
525
+ )
526
+
527
+ parser.add_argument(
528
+ "--speech-unsup-feature-dropout",
529
+ type=float,
530
+ default=0,
531
+ help="dropout for unsupervised speech feature encoder",
532
+ )
533
+
534
+ parser.add_argument(
535
+ "--encoder-shared-text-layers-from-begin",
536
+ type=int,
537
+ help="number of text encoder layers shared with speech encoder (from first layer)",
538
+ )
539
+
540
+ parser.add_argument(
541
+ "--stacked-encoder",
542
+ default="none",
543
+ choices=["none", "s2s", "all"],
544
+ help="stack speech and text encoders",
545
+ )
546
+
547
+ parser.add_argument("--use-decoder-output-proj", action="store_true")
548
+
549
+ @classmethod
550
+ def build_model(cls, args, task):
551
+ encoder = SpeechTextPreTrainEncoder.build_encoder(args, task.src_dict)
552
+ decoder = SpeechTextPreTrainDecoder.build_decoder(
553
+ args, task.tgt_dict, task.src_dict, encoder.text_encoder.embed_tokens
554
+ )
555
+ model = SpeechTextPreTrainModel(encoder, decoder)
556
+ return model
557
+
558
+ def upgrade_state_dict(self, state_dict):
559
+ """Upgrade old state dicts to work with newer code."""
560
+ if "decoder.speech_decoder.output_projection.weight" in state_dict:
561
+ del state_dict["decoder.speech_decoder.output_projection.weight"]
562
+ self.upgrade_state_dict_named(state_dict, "")
563
+
564
+
565
+ @register_model_architecture(
566
+ "speech_text_pretrain_bart", "speech_text_pretrain_bart_base"
567
+ )
568
+ def speech_text_pretrain_bart_base(args):
569
+ # speech masking
570
+ args.dropout_input = getattr(args, "dropout_input", 0)
571
+ args.dropout_features = getattr(args, "dropout_features", 0)
572
+ args.speech_mask_length = getattr(args, "speech_mask_length", 10)
573
+ args.speech_mask_prob = getattr(args, "speech_mask_prob", 0.65)
574
+ args.speech_sup_mask_prob = getattr(args, "speech_sup_mask_prob", 0.3)
575
+ args.speech_unsup_mask_prob = getattr(
576
+ args, "speech_unsup_mask_prob", args.speech_mask_prob
577
+ )
578
+ args.speech_mask_selection = getattr(args, "speech_mask_selection", "static")
579
+ args.speech_mask_other = getattr(args, "speech_mask_other", 0)
580
+ args.speech_mask_min_space = getattr(args, "speech_mask_min_space", 1)
581
+ args.speech_no_mask_overlap = getattr(args, "speech_no_mask_overlap", False)
582
+
583
+ args.speech_mask_channel_length = getattr(args, "speech_mask_channel_length", 10)
584
+ args.speech_mask_channel_prob = getattr(args, "speech_mask_channel_prob", 0.0)
585
+ args.speech_mask_channel_selection = getattr(
586
+ args, "speech_mask_channel_selection", "static"
587
+ )
588
+ args.speech_mask_channel_other = getattr(args, "speech_mask_channel_other", 0)
589
+ args.speech_mask_channel_min_space = getattr(
590
+ args, "speech_mask_channel_min_space", 1
591
+ )
592
+ args.speech_no_mask_channel_overlap = getattr(
593
+ args, "speech_no_mask_channel_overlap", False
594
+ )
595
+ args.no_scale_feature = getattr(args, "no_scale_feature", False)
596
+ args.feature_grad_mult = getattr(args, "feature_grad_mult", 1.0) # 0.1
597
+
598
+ # Transformer
599
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
600
+ args.encoder_ffn_embed_dim = getattr(
601
+ args, "encoder_ffn_embed_dim", args.encoder_embed_dim * 4
602
+ )
603
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
604
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
605
+ args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
606
+ args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
607
+ args.speech_conv_bias = getattr(args, "speech_conv_bias", False)
608
+
609
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
610
+ args.decoder_ffn_embed_dim = getattr(
611
+ args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
612
+ )
613
+ args.decoder_attention_heads = getattr(
614
+ args, "decoder_attention_heads", args.encoder_attention_heads
615
+ )
616
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
617
+ args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
618
+ args.dropout = getattr(args, "dropout", 0.1)
619
+ args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
620
+ args.activation_dropout = getattr(args, "activation_dropout", 0.0)
621
+ args.activation_fn = getattr(args, "activation_fn", "relu") # gelu?
622
+ args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
623
+ args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
624
+
625
+ args.speech_unsup_dropout = getattr(args, "speech_unsup_dropout", 0)
626
+ args.speech_unsup_feature_dropout = getattr(args, "speech_unsup_feature_dropout", 0)
627
+
628
+ args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
629
+ args.share_decoder_input_output_embed = getattr(
630
+ args, "share_decoder_input_output_embed", False
631
+ )
632
+ args.no_token_positional_embeddings = getattr(
633
+ args, "no_token_positional_embeddings", False
634
+ )
635
+ args.adaptive_input = getattr(args, "adaptive_input", False)
636
+ args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
637
+ args.decoder_output_dim = getattr(
638
+ args, "decoder_output_dim", args.decoder_embed_dim
639
+ )
640
+ args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
641
+ args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
642
+ args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
643
+
644
+ args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 12)
645
+ args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
646
+ args.encoder_shared_text_layers_from_begin = getattr(
647
+ args, "encoder_shared_text_layers_from_begin", 6
648
+ )
649
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
650
+
651
+ args.no_emb_update_unsup = getattr(args, "no_emb_update_unsup", False)
652
+
653
+
654
+ @register_model_architecture(
655
+ "speech_text_pretrain_bart", "speech_text_pretrain_bart_base_stack"
656
+ )
657
+ def speech_text_pretrain_bart_base_stack(args):
658
+ args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 6)
659
+ args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
660
+ args.encoder_shared_text_layers_from_begin = getattr(
661
+ args, "encoder_shared_text_layers_from_begin", 0
662
+ )
663
+ args.stacked_encoder = getattr(args, "stacked_encoder", "all")
664
+ args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
665
+ speech_text_pretrain_bart_base(args)
666
+
667
+
668
+ @register_model_architecture(
669
+ "speech_text_pretrain_bart", "speech_text_pretrain_bart_large"
670
+ )
671
+ def speech_text_pretrain_bart_large(args):
672
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
673
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
674
+ args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 24)
675
+ args.text_encoder_layers = getattr(args, "text_encoder_layers", 12)
676
+ args.encoder_shared_text_layers_from_begin = getattr(
677
+ args, "encoder_shared_text_layers_from_begin", 12
678
+ )
679
+ args.decoder_layers = getattr(args, "decoder_layers", 12)
680
+ args.dropout = getattr(args, "dropout", 0.3)
681
+ speech_text_pretrain_bart_base(args)
682
+
683
+
684
+ @register_model_architecture(
685
+ "speech_text_pretrain_bart", "speech_text_pretrain_bart_large_stack"
686
+ )
687
+ def speech_text_pretrain_bart_large_stack(args):
688
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
689
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
690
+ args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 6)
691
+ args.text_encoder_layers = getattr(args, "text_encoder_layers", 12)
692
+ args.encoder_shared_text_layers_from_begin = getattr(
693
+ args, "encoder_shared_text_layers_from_begin", 0
694
+ )
695
+ args.decoder_layers = getattr(args, "decoder_layers", 12)
696
+ args.stacked_encoder = getattr(args, "stacked_encoder", "s2s")
697
+ args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
698
+ speech_text_pretrain_bart_base(args)
fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputtransformer.py ADDED
@@ -0,0 +1,1093 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from collections import namedtuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from fairseq import checkpoint_utils
12
+ from fairseq import utils
13
+ from fairseq.models import (
14
+ FairseqEncoder,
15
+ FairseqDecoder,
16
+ FairseqEncoderDecoderModel,
17
+ register_model,
18
+ register_model_architecture,
19
+ )
20
+ from fairseq.models.fairseq_encoder import EncoderOut
21
+ from fairseq.models.speech_to_text import (
22
+ TransformerDecoder,
23
+ S2TTransformerEncoder,
24
+ )
25
+ from fairseq.models.transformer import TransformerEncoder
26
+ from fairseq.modules import (
27
+ TransformerEncoderLayer,
28
+ GradMultiply,
29
+ LayerNorm,
30
+ )
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ class SpeechEoSEncoder(FairseqEncoder):
36
+ def __init__(self, encoder, eos_num, feat_dim, adapter_type="None", adapter_dim=0):
37
+ super().__init__(None)
38
+ self.encoder = encoder
39
+ self.eos_num = eos_num # number of EoS frames to append; set from the downsampling rate of the speech input features
40
+ self.eos_emb = (
41
+ nn.Parameter(torch.zeros(1, feat_dim), requires_grad=True)
42
+ if eos_num > 0
43
+ else None
44
+ )
45
+ self.adapter = self.add_adapter(adapter_type, adapter_dim)
46
+
47
+ def add_adapter(self, adapter_type, adapter_dim):
48
+ def _make_identity(linear, eps=1e-5):
49
+ assert isinstance(linear, nn.Linear)
50
+ linear.weight.data.mul_(eps)
51
+ linear.weight.data.fill_diagonal_(1.0)
52
+ if linear.bias is not None:
53
+ linear.bias.data.mul_(eps)
54
+
55
+ adapter = None
56
+ if adapter_type == "Linear":
57
+ assert adapter_dim > 0
58
+ adapter = nn.Sequential(
59
+ nn.Linear(adapter_dim, adapter_dim), LayerNorm(adapter_dim)
60
+ )
61
+ # initialize the adapter as identity matrix first
62
+ _make_identity(adapter[0])
63
+
64
+ elif adapter_type == "MLP":
65
+ assert adapter_dim > 0
66
+ # assume the model is pre-norm model
67
+ adapter = nn.Sequential(
68
+ nn.Linear(adapter_dim, 2 * adapter_dim),
69
+ nn.ReLU(),
70
+ nn.Linear(2 * adapter_dim, adapter_dim),
71
+ LayerNorm(adapter_dim),
72
+ )
73
+ _make_identity(adapter[0])
74
+ _make_identity(adapter[2])
75
+ return adapter
76
+
77
+ def add_eos(self, src_tokens, src_lengths):
78
+ bsz, max_seq_len, fdim = src_tokens.size()
79
+ if self.eos_num > 0:
80
+ src_token_eos = torch.zeros(
81
+ [bsz, max_seq_len + self.eos_num, fdim],
82
+ dtype=src_tokens.dtype,
83
+ device=src_tokens.device,
84
+ )
85
+ src_token_eos[:, :max_seq_len] = src_tokens
86
+ for bi in range(bsz):
87
+ src_token_eos[bi][
88
+ src_lengths[bi] : src_lengths[bi] + self.eos_num
89
+ ] = self.eos_emb.expand(self.eos_num, fdim)
90
+ src_lengths = src_lengths + self.eos_num
91
+ src_tokens = src_token_eos
92
+ return src_tokens, src_lengths
93
+
94
+ def apply_adapter(self, enc_out):
95
+ if self.adapter is None:
96
+ return enc_out
97
+ rst = self.adapter(enc_out.encoder_out)
98
+ if enc_out.encoder_padding_mask is not None:
99
+ rst.masked_fill_(
100
+ enc_out.encoder_padding_mask.transpose(0, 1).unsqueeze(-1), 0
101
+ )
102
+ return EncoderOut(
103
+ encoder_out=rst,
104
+ encoder_padding_mask=enc_out.encoder_padding_mask,
105
+ encoder_embedding=enc_out.encoder_embedding,
106
+ encoder_states=enc_out.encoder_states,
107
+ src_tokens=enc_out.src_tokens,
108
+ src_lengths=enc_out.src_lengths,
109
+ )
110
+
111
+ def forward(self, src_tokens, src_lengths=None, return_all_hiddens=False, **kwargs):
112
+ """
113
+ src_tokens: padded tensor (B, T, C * feat)
114
+ src_lengths: tensor of original lengths of input utterances (B,)
115
+ """
116
+ src_tokens, src_lengths = self.add_eos(src_tokens, src_lengths)
117
+ enc_out = self.encoder(src_tokens, src_lengths, return_all_hiddens)
118
+ enc_out = self.apply_adapter(enc_out)
119
+ return enc_out
120
+
121
+ def reorder_encoder_out(self, encoder_out, new_order):
122
+ return self.encoder.reorder_encoder_out(encoder_out, new_order)
123
+
124
+
125
+ class DualInputEncoder(FairseqEncoder):
126
+ def __init__(
127
+ self,
128
+ args,
129
+ spch_encoder,
130
+ text_encoder,
131
+ dictionary,
132
+ cross_attentive_loss_before_last_layer=-1,
133
+ ):
134
+ super().__init__(dictionary)
135
+
136
+ self.spch_encoder = spch_encoder
137
+ self.text_encoder = text_encoder
138
+ self.enc_grad_mult = args.enc_grad_mult
139
+ self.cross_attentive_loss_before_last_layer = (
140
+ cross_attentive_loss_before_last_layer
141
+ )
142
+ self.use_cross_attentive_loss = (
143
+ False if cross_attentive_loss_before_last_layer <= -1 else True
144
+ )
145
+ self.enc2_along_grad_mult = args.enc2_along_grad_mult
146
+
147
+ @classmethod
148
+ def set_shared_layer(cls, share_level, src_layer, tgt_layer):
149
+ """
150
+ share parameters from tgt_layer to src_layer
151
+ share_level:
152
+ 0: share everything
153
+ 1: share everything but different model
154
+ 2: share weight but not bias, layernorm
155
+ """
156
+ if share_level == 0:
157
+ return tgt_layer
158
+ if isinstance(src_layer, nn.Linear):
159
+ return tgt_layer
160
+ if isinstance(src_layer, TransformerEncoderLayer):
161
+ assert src_layer.embed_dim == tgt_layer.embed_dim
162
+ assert src_layer.normalize_before == tgt_layer.normalize_before
163
+ if share_level == 1:
164
+ src_layer.fc1 = tgt_layer.fc1
165
+ src_layer.fc2 = tgt_layer.fc2
166
+ src_layer.self_attn = tgt_layer.self_attn
167
+ src_layer.final_layer_norm = tgt_layer.final_layer_norm
168
+ src_layer.self_attn_layer_norm = tgt_layer.self_attn_layer_norm
169
+ src_layer.layernorm_embedding = tgt_layer.layernorm_embedding
170
+ else:
171
+ src_layer.fc1.weight = tgt_layer.fc1.weight
172
+ src_layer.fc2.weight = tgt_layer.fc2.weight
173
+ src_layer.self_attn.k_proj.weight = tgt_layer.self_attn.k_proj.weight
174
+ src_layer.self_attn.v_proj.weight = tgt_layer.self_attn.v_proj.weight
175
+ src_layer.self_attn.q_proj.weight = tgt_layer.self_attn.q_proj.weight
176
+ src_layer.self_attn.out_proj.weight = (
177
+ tgt_layer.self_attn.out_proj.weight
178
+ )
179
+ else:
180
+ if share_level == 1:
181
+ return tgt_layer
182
+ return src_layer
183
+
184
+ @classmethod
185
+ def build_spch_encoder(cls, args):
186
+ cfg = {
187
+ "input_feat_per_channel": args.input_feat_per_channel,
188
+ "input_channels": args.input_channels,
189
+ "conv_kernel_sizes": args.conv_kernel_sizes,
190
+ "conv_channels": args.conv_channels,
191
+ "encoder_embed_dim": args.encoder_embed_dim,
192
+ "encoder_ffn_embed_dim": args.encoder_ffn_embed_dim,
193
+ "encoder_layers": args.speech_encoder_layers,
194
+ "encoder_layerdrop": args.encoder_layerdrop,
195
+ "encoder_attention_heads": args.encoder_attention_heads,
196
+ "max_source_positions": args.max_source_positions,
197
+ "dropout": args.dropout,
198
+ "encoder_normalize_before": args.encoder_normalize_before,
199
+ "activation_dropout": args.activation_dropout,
200
+ "attention_dropout": args.attention_dropout,
201
+ "activation_fn": args.activation_fn,
202
+ "layernorm_embedding": args.layernorm_embedding,
203
+ "no_token_positional_embeddings": args.no_token_positional_embeddings,
204
+ "no_scale_embedding": args.no_scale_embedding,
205
+ "quant_noise_pq": args.quant_noise_pq,
206
+ "encoder_freezing_updates": 0,
207
+ }
208
+ model_args = namedtuple("args", cfg.keys())(*cfg.values())
209
+ spch_encoder = S2TTransformerEncoder(model_args)
210
+ if args.add_speech_eos:
211
+ spch_encoder = SpeechEoSEncoder(
212
+ spch_encoder,
213
+ 2 * len(args.conv_kernel_sizes.split(",")),
214
+ args.input_feat_per_channel,
215
+ adapter_type=getattr(args, "speech_encoder_adapter_type", "None"),
216
+ adapter_dim=args.encoder_embed_dim,
217
+ )
218
+ return spch_encoder
219
+
220
+ @classmethod
221
+ def build_text_encoder(cls, args, src_dictionary, spch_encoder):
222
+ if args.encoder_shared_layers > 0:
223
+ mx_shared_layers = (
224
+ args.speech_encoder_layers
225
+ if args.speech_encoder_layers < args.text_encoder_layers
226
+ else args.text_encoder_layers
227
+ )
228
+ args.encoder_shared_layers = (
229
+ args.encoder_shared_layers
230
+ if args.encoder_shared_layers <= mx_shared_layers
231
+ else mx_shared_layers
232
+ )
233
+ cfg = {
234
+ "encoder_embed_dim": args.encoder_text_embed_dim,
235
+ "encoder_ffn_embed_dim": args.encoder_ffn_embed_dim,
236
+ "encoder_layers": args.text_encoder_layers,
237
+ "encoder_layerdrop": args.encoder_layerdrop,
238
+ "encoder_attention_heads": args.encoder_attention_heads,
239
+ "encoder_learned_pos": args.encoder_learned_pos,
240
+ "max_source_positions": args.max_source_positions,
241
+ "dropout": args.dropout,
242
+ "encoder_normalize_before": args.encoder_normalize_before,
243
+ "activation_dropout": args.activation_dropout,
244
+ "attention_dropout": args.attention_dropout,
245
+ "activation_fn": args.activation_fn,
246
+ "adaptive_input": args.adaptive_input,
247
+ "no_token_positional_embeddings": args.no_token_positional_embeddings,
248
+ "no_scale_embedding": args.no_scale_embedding,
249
+ "quant_noise_pq": args.quant_noise_pq,
250
+ }
251
+ model_args = namedtuple("args", cfg.keys())(*cfg.values())
252
+ enc_emb = nn.Embedding(
253
+ len(src_dictionary), model_args.encoder_embed_dim, src_dictionary.pad()
254
+ )
255
+ text_encoder = TransformerEncoder(model_args, src_dictionary, enc_emb)
256
+ if args.add_speech_eos:
257
+ spch_encoder = spch_encoder.encoder
258
+ if args.encoder_shared_layers > 0:
259
+ text_encoder.layer_norm = cls.set_shared_layer(
260
+ args.encoder_shared_layer_level,
261
+ text_encoder.layer_norm,
262
+ spch_encoder.layer_norm,
263
+ )
264
+ for i, ly in enumerate(
265
+ spch_encoder.transformer_layers[-args.encoder_shared_layers :]
266
+ ):
267
+ ly_id = i + args.text_encoder_layers - args.encoder_shared_layers
268
+ if not isinstance(text_encoder.layers[ly_id], type(ly)):
269
+ if text_encoder.layers[ly_id]._get_name() not in ('TransformerEncoderLayerBase', 'TransformerEncoderLayer'):
270
+ raise ValueError("The shared layers are expected from the same class")
271
+ text_encoder.layers[ly_id] = cls.set_shared_layer(
272
+ args.encoder_shared_layer_level,
273
+ text_encoder.layers[ly_id],
274
+ ly,
275
+ )
276
+ return text_encoder
277
+
278
+ def mult_rst_grad(self, rst, ratio):
279
+ assert isinstance(rst, dict) # instead of EncoderOut
280
+ assert len(rst["encoder_out"]) == 1
281
+ rst["encoder_out"][0] = GradMultiply.apply(rst["encoder_out"][0], ratio)
282
+ return rst
283
+
284
+ def process_attentive_loss_states(self, rst, interstates):
285
+ assert isinstance(rst, dict) # instead of EncoderOut
286
+ rst["encoder_states"] = interstates
287
+ return rst
288
+
289
+ def forward(
290
+ self,
291
+ src_tokens,
292
+ src_lengths=None,
293
+ src_txt_tokens=None,
294
+ src_txt_lengths=None,
295
+ **kwargs
296
+ ):
297
+ """
298
+ Args:
299
+ src_tokens: padded tensor (B, T, C * feat)
300
+ src_lengths: tensor of original lengths of input utterances (speech) (B,)
301
+ src_txt_tokens: padded tensor (B, T)
302
+ src_txt_lengths: tensor of original lengths of input utterances (text) (B,)
303
+ """
304
+ # src_tokens only: inference
305
+ # src_tokens, src_lengths: speech only training
306
+ # src_txt_tokens, src_txt_lengths: text only training
307
+ # all valid: speech + text training
308
+
309
+ if src_tokens is None and src_txt_tokens is None:
310
+ raise ValueError(
311
+ "src_tokens and src_txt_tokens cannot be None at the same time"
312
+ )
313
+ ret1 = None
314
+ ret2 = None
315
+ return_all_hiddens = False
316
+ if src_tokens is not None:
317
+ if (
318
+ self.use_cross_attentive_loss and src_txt_tokens is not None
319
+ ): # remove self.training so we can get attn score during validation step
320
+ return_all_hiddens = True
321
+ ret1 = self.spch_encoder(
322
+ src_tokens, src_lengths, return_all_hiddens=return_all_hiddens
323
+ )
324
+
325
+ if self.use_cross_attentive_loss and src_txt_tokens is not None:
326
+ assert self.cross_attentive_loss_before_last_layer < len(
327
+ ret1["encoder_states"]
328
+ )
329
+ ret1 = self.process_attentive_loss_states(
330
+ ret1,
331
+ ret1["encoder_states"][
332
+ -self.cross_attentive_loss_before_last_layer - 1
333
+ ],
334
+ )
335
+
336
+ if src_txt_tokens is not None:
337
+ ret2 = self.text_encoder(
338
+ src_txt_tokens, src_txt_lengths, return_all_hiddens=return_all_hiddens
339
+ )
340
+ if return_all_hiddens:
341
+ if self.cross_attentive_loss_before_last_layer == len(
342
+ self.text_encoder.layers
343
+ ):
344
+ text_embedding, _ = self.text_encoder.forward_embedding(
345
+ src_txt_tokens
346
+ )
347
+ text_embedding = text_embedding.transpose(0, 1)
348
+ ret2 = self.process_attentive_loss_states(ret2, text_embedding)
349
+ else:
350
+ assert self.cross_attentive_loss_before_last_layer < len(
351
+ self.text_encoder.layers
352
+ )
353
+ ret2 = self.process_attentive_loss_states(
354
+ ret2,
355
+ ret2["encoder_states"][
356
+ -self.cross_attentive_loss_before_last_layer - 1
357
+ ],
358
+ )
359
+
360
+ def merge_output(rst1, rst2):
361
+ if rst1 is None:
362
+ if not (self.enc2_along_grad_mult == 1.0 or self.training):
363
+ rst2 = self.mult_rst_grad(rst2, self.enc2_along_grad_mult)
364
+ return rst2
365
+ if rst2 is None:
366
+ return rst1
367
+ if self.enc_grad_mult != 1.0 and self.training:
368
+ rst1 = self.mult_rst_grad(rst1, self.enc_grad_mult)
369
+ rst2 = self.mult_rst_grad(rst2, self.enc_grad_mult)
370
+ rst = (rst1, rst2)
371
+ return rst
372
+
373
+ return merge_output(ret1, ret2)
374
+
375
+ def reorder_encoder_out(self, encoder_out, new_order):
376
+ assert self.training is False # used for inference only
377
+ return self.spch_encoder.reorder_encoder_out(encoder_out, new_order)
378
+
379
+
380
+ # TransformerMultiInputDecoder: take one or two encoder inputs
381
+ class TransformerMultiInputDecoder(FairseqDecoder):
382
+ def __init__(
383
+ self,
384
+ dictionary,
385
+ spch_decoder,
386
+ text_decoder,
387
+ compute_cross_attentive_loss=False,
388
+ cross_attentive_loss_with_norm=True,
389
+ cross_attentive_loss_reverse=False,
390
+ ):
391
+
392
+ super().__init__(dictionary)
393
+ self.spch_decoder = spch_decoder
394
+ self.text_decoder = text_decoder
395
+ self.compute_cross_attentive_loss = compute_cross_attentive_loss
396
+ self.cross_attentive_loss_with_norm = cross_attentive_loss_with_norm
397
+ self.cross_attentive_loss_reverse = cross_attentive_loss_reverse
398
+
399
+ @classmethod
400
+ def share_spchdecoder(cls, task_args, text_decoder, spch_decoder):
401
+ if task_args.decoder_shared_layer_level == 0:
402
+ return text_decoder
403
+ assert text_decoder.embed_tokens == spch_decoder.embed_tokens
404
+ spch_decoder.project_in_dim = text_decoder.project_in_dim
405
+ spch_decoder.embed_positions = text_decoder.embed_positions
406
+ spch_decoder.layernorm_embedding = text_decoder.layernorm_embedding
407
+ spch_decoder.project_out_dim = text_decoder.project_out_dim
408
+ spch_decoder.adaptive_softmax = text_decoder.adaptive_softmax
409
+ if task_args.decoder_shared_layer_level == 1:
410
+ spch_decoder.output_projection = text_decoder.output_projection
411
+ spch_decoder.layer_norm = text_decoder.layer_norm
412
+ else: # 2
413
+ spch_decoder.output_projection.weight = (
414
+ text_decoder.output_projection.weight
415
+ )
416
+ for i, ly in enumerate(text_decoder.layers):
417
+ sly = spch_decoder.layers[i]
418
+ sly.self_attn = ly.self_attn
419
+ sly.self_attn_layer_norm = ly.self_attn_layer_norm
420
+ # sly.encoder_attn = ly.encoder_attn
421
+ if (
422
+ task_args.decoder_shared_layer_level == 1
423
+ ): # share everything, but under different models
424
+ sly.encoder_attn = ly.encoder_attn
425
+ sly.encoder_attn_layer_norm = ly.encoder_attn_layer_norm
426
+ sly.fc1 = ly.fc1
427
+ sly.fc2 = ly.fc2
428
+ sly.final_layer_norm = ly.final_layer_norm
429
+ else: # decoder_shared_layer_level == 2: keep separate encoder_attn_layer_norm and bias
430
+ sly.encoder_attn.k_proj.weight = ly.encoder_attn.k_proj.weight
431
+ sly.encoder_attn.v_proj.weight = ly.encoder_attn.v_proj.weight
432
+ sly.encoder_attn.q_proj.weight = ly.encoder_attn.q_proj.weight
433
+ sly.encoder_attn.out_proj.weight = ly.encoder_attn.out_proj.weight
434
+ sly.fc1.weight = ly.fc1.weight
435
+ sly.fc2.weight = ly.fc2.weight
436
+
437
+ return spch_decoder
438
+
439
+ def cross_attentive_loss(
440
+ self, teacher_states, student_states, teacher_masking, student_masking, eps=1e-6
441
+ ):
442
+ x = teacher_states.transpose(0, 1) # from T X B X D to B X T X D
443
+ y = student_states.transpose(0, 1)
444
+ if self.cross_attentive_loss_with_norm:
445
+ x = x / (x.norm(dim=2, keepdim=True) + eps)
446
+ y = y / (y.norm(dim=2, keepdim=True) + eps)
447
+ dim = x.size(-1)
448
+ # lengths: batch X seqLen
449
+ sim_scores_xy = torch.bmm(x, y.transpose(1, 2)) # batch X lenx X leny
450
+ if y.dtype == torch.float16:
451
+ sim_scores_xy = sim_scores_xy.float()
452
+ y = y.float()
453
+ x = x.float()
454
+ if teacher_masking != []:
455
+ assert len(teacher_masking) == 1
456
+ sim_scores_xy = sim_scores_xy.masked_fill(
457
+ teacher_masking[0].unsqueeze(-1), float("-inf")
458
+ )
459
+ if student_masking != []:
460
+ sim_scores_xy = sim_scores_xy.masked_fill(
461
+ student_masking[0].unsqueeze(1), float("-inf")
462
+ )
463
+ # do masking
464
+ y_weights = utils.softmax(sim_scores_xy, dim=-1)
465
+ if teacher_masking != []:
466
+ y_weights = y_weights.masked_fill(teacher_masking[0].unsqueeze(-1), 0)
467
+ x_reconstruct_from_y = torch.bmm(y_weights, y)
468
+
469
+ sim_scores_xx = torch.bmm(x, x.transpose(1, 2)) # batch X lenx X lenx
470
+ x_weights = utils.softmax(sim_scores_xx, dim=-1)
471
+ if teacher_masking != []:
472
+ x_weights = x_weights.masked_fill(teacher_masking[0].unsqueeze(-1), 0)
473
+
474
+ # no gradient for teacher state
475
+ x_reconstruct_from_x = torch.bmm(x_weights, x).detach()
476
+ cost = (x_reconstruct_from_x - x_reconstruct_from_y).norm(dim=2)
477
+ if teacher_masking != []:
478
+ cost = cost.masked_fill(teacher_masking[0], 0)
479
+
480
+ if not self.cross_attentive_loss_with_norm:
481
+ cost = cost / dim
482
+ return cost
483
+
484
+ def forward(
485
+ self,
486
+ prev_output_tokens,
487
+ encoder_out,
488
+ incremental_state=None,
489
+ has_txt_input=False,
490
+ **kwargs
491
+ ):
492
+ """
493
+ Args:
494
+ prev_output_tokens (LongTensor): previous decoder outputs of shape
495
+ `(batch, tgt_len)`, for input feeding/teacher forcing. If there are
496
+ two or more input during training, they will share the same prev_output_tokens
497
+ encoder_out (tuple[Tensor]): output from the encoder, used for
498
+ encoder-side attention. It will be tuple if there are more inputs, but a tensor
499
+ if only one input
500
+ incremental_state ([dict]): dictionary used for storing state during
501
+ :ref:`Incremental decoding`. It is only valid for inference, only from single
502
+ input
503
+ Returns:
504
+ tuple:
505
+ - the last decoder layer's output of shape `(batch, tgt_len,
506
+ vocab)`. If there are N inputs, batch will be N bigger than a single input
507
+ - the last decoder layer's attention weights of shape `(batch,
508
+ tgt_len, src_len)`
509
+ """
510
+ assert not isinstance(encoder_out, EncoderOut)
511
+ if isinstance(encoder_out, tuple): # training with multiple inputs
512
+ rst = []
513
+ assert len(encoder_out) == 2
514
+ for i, eo in enumerate(encoder_out):
515
+ assert incremental_state is None
516
+ if i == 0:
517
+ rst.append(
518
+ self.spch_decoder(prev_output_tokens, eo, incremental_state)
519
+ )
520
+ else:
521
+ rst.append(
522
+ self.text_decoder(prev_output_tokens, eo, incremental_state)
523
+ )
524
+ dec_out = torch.cat([r[0] for r in rst], dim=0)
525
+ attn_cost = None
526
+ if self.compute_cross_attentive_loss:
527
+ assert isinstance(encoder_out[0], dict)
528
+ if self.cross_attentive_loss_reverse:
529
+ attn_cost = self.cross_attentive_loss(
530
+ teacher_states=encoder_out[1]["encoder_states"], # text_states
531
+ student_states=encoder_out[0]["encoder_states"], # spch_states
532
+ teacher_masking=encoder_out[1]["encoder_padding_mask"],
533
+ student_masking=encoder_out[0]["encoder_padding_mask"],
534
+ )
535
+ else:
536
+ attn_cost = self.cross_attentive_loss(
537
+ teacher_states=encoder_out[0]["encoder_states"], # spch_states
538
+ student_states=encoder_out[1]["encoder_states"], # text_states
539
+ teacher_masking=encoder_out[0]["encoder_padding_mask"],
540
+ student_masking=encoder_out[1]["encoder_padding_mask"],
541
+ )
542
+
543
+ return (dec_out, {"attn_cost": attn_cost})
544
+ else: # inference or training with one input
545
+ if has_txt_input:
546
+ return self.text_decoder(
547
+ prev_output_tokens, encoder_out, incremental_state
548
+ )
549
+ return self.spch_decoder(prev_output_tokens, encoder_out, incremental_state)
550
+
551
+
552
+ # Note:
553
+ # dual input transformer:
554
+ # encoder: S2TTransformerEncoder for speech + TransformerEncoder for text
555
+ # decoder: TransformerDecoder for text
556
+ @register_model("dual_input_s2t_transformer")
557
+ class DualInputS2TTransformerModel(FairseqEncoderDecoderModel):
558
+ def __init__(self, encoder, decoder):
559
+ super().__init__(encoder, decoder)
560
+ self.num_updates = 0
561
+
562
+ def max_positions(self):
563
+ return None # it is provided in task
564
+
565
+ @staticmethod
566
+ def add_args(parser):
567
+ """Add model-specific arguments to the parser."""
568
+ # encoder 1: S2TTransformerEncoder for speech
569
+ parser.add_argument(
570
+ "--conv-kernel-sizes",
571
+ type=str,
572
+ metavar="N",
573
+ help="kernel sizes of Conv1d subsampling layers",
574
+ )
575
+ parser.add_argument(
576
+ "--conv-channels",
577
+ type=int,
578
+ metavar="N",
579
+ help="# of channels in Conv1d subsampling layers",
580
+ )
581
+ parser.add_argument(
582
+ "--enc-output-dim",
583
+ type=int,
584
+ metavar="N",
585
+ help="""
586
+ encoder output dimension, can be None. If specified, projecting the
587
+ transformer output to the specified dimension""",
588
+ )
589
+ # standard Transformer
590
+ parser.add_argument(
591
+ "--activation-fn",
592
+ type=str,
593
+ default="relu",
594
+ choices=utils.get_available_activation_fns(),
595
+ help="activation function to use",
596
+ )
597
+ parser.add_argument(
598
+ "--dropout", type=float, metavar="D", help="dropout probability"
599
+ )
600
+ parser.add_argument(
601
+ "--attention-dropout",
602
+ type=float,
603
+ metavar="D",
604
+ help="dropout probability for attention weights",
605
+ )
606
+ parser.add_argument(
607
+ "--activation-dropout",
608
+ "--relu-dropout",
609
+ type=float,
610
+ metavar="D",
611
+ help="dropout probability after activation in FFN.",
612
+ )
613
+ parser.add_argument(
614
+ "--encoder-embed-dim",
615
+ type=int,
616
+ metavar="N",
617
+ help="encoder embedding dimension",
618
+ )
619
+ parser.add_argument(
620
+ "--encoder-text-embed-dim",
621
+ type=int,
622
+ metavar="N",
623
+ help="encoder text embedding dimension",
624
+ )
625
+ parser.add_argument(
626
+ "--encoder-ffn-embed-dim",
627
+ type=int,
628
+ metavar="N",
629
+ help="encoder embedding dimension for FFN",
630
+ )
631
+ parser.add_argument(
632
+ "--encoder-attention-heads",
633
+ type=int,
634
+ metavar="N",
635
+ help="num encoder attention heads",
636
+ )
637
+ parser.add_argument(
638
+ "--decoder-embed-dim",
639
+ type=int,
640
+ metavar="N",
641
+ help="decoder embedding dimension",
642
+ )
643
+ parser.add_argument(
644
+ "--decoder-ffn-embed-dim",
645
+ type=int,
646
+ metavar="N",
647
+ help="decoder embedding dimension for FFN",
648
+ )
649
+ parser.add_argument(
650
+ "--decoder-layers", type=int, metavar="N", help="num decoder layers"
651
+ )
652
+ parser.add_argument(
653
+ "--decoder-attention-heads",
654
+ type=int,
655
+ metavar="N",
656
+ help="num decoder attention heads",
657
+ )
658
+ parser.add_argument(
659
+ "--layernorm-embedding",
660
+ action="store_true",
661
+ help="add layernorm to embedding",
662
+ )
663
+ parser.add_argument(
664
+ "--no-scale-embedding",
665
+ action="store_true",
666
+ help="if True, don't scale embeddings",
667
+ )
668
+ # non-standard transformer parameters
669
+ parser.add_argument(
670
+ "--speech-encoder-layers",
671
+ type=int,
672
+ metavar="N",
673
+ help="num speech encoder layers",
674
+ )
675
+ parser.add_argument(
676
+ "--text-encoder-layers",
677
+ type=int,
678
+ metavar="N",
679
+ help="num text encoder layers",
680
+ )
681
+ parser.add_argument(
682
+ "--encoder-shared-layers",
683
+ type=int,
684
+ metavar="N",
685
+ help="num shared encoder layers",
686
+ )
687
+ parser.add_argument(
688
+ "--encoder-shared-layer-level",
689
+ type=int,
690
+ metavar="N",
691
+ default=0,
692
+ choices=[0, 1, 2],
693
+ help="share layer level: 0: all share; 1: all share with separate model; 2: share weight but not bias and layernorm",
694
+ )
695
+
696
+ parser.add_argument(
697
+ "--decoder-shared-layer-level",
698
+ default=0,
699
+ choices=[0, 1, 2],
700
+ type=int,
701
+ metavar="N",
702
+ help="0: share everything; 1: share everything with a different model; 2: do not share layer_norm and bias",
703
+ )
704
+ ###
705
+ parser.add_argument(
706
+ "--text-input-cost-ratio",
707
+ type=float,
708
+ default=1.0,
709
+ metavar="V",
710
+ help="text input cost ratio relative to speech input cost",
711
+ )
712
+ parser.add_argument(
713
+ "--init-scale",
714
+ type=float,
715
+ default=1.0,
716
+ metavar="V",
717
+ help="scale the initial weight by given factor",
718
+ )
719
+ parser.add_argument(
720
+ "--enc-grad-mult",
721
+ type=float,
722
+ metavar="V",
723
+ default=1.0,
724
+ help="multiply enc1 and enc2 gradient by V",
725
+ )
726
+ parser.add_argument(
727
+ "--enc2-along-grad-mult",
728
+ type=float,
729
+ metavar="V",
730
+ default=1.0,
731
+ help="multiply enc2 gradient by V if only enc2 is used",
732
+ )
733
+ parser.add_argument(
734
+ "--load-pretrain-encoder",
735
+ type=str,
736
+ default="",
737
+ metavar="EXPR",
738
+ help=""" path to the pretrained encoder """,
739
+ )
740
+ parser.add_argument(
741
+ "--load-pretrain-speech-encoder",
742
+ type=str,
743
+ default="",
744
+ metavar="EXPR",
745
+ help=""" path to the pretrained speech encoder """,
746
+ )
747
+ parser.add_argument(
748
+ "--load-pretrain-text-encoder",
749
+ type=str,
750
+ default="",
751
+ metavar="EXPR",
752
+ help=""" path to the pretrained text encoder """,
753
+ )
754
+ parser.add_argument(
755
+ "--load-pretrain-text-encoder-last",
756
+ type=str,
757
+ default="",
758
+ metavar="EXPR",
759
+ help=""" path to the pretrained text encoder """,
760
+ )
761
+ parser.add_argument(
762
+ "--load-pretrain-decoder",
763
+ type=str,
764
+ metavar="EXPR",
765
+ default="",
766
+ help=""" path to the pretrained decoder """,
767
+ )
768
+ parser.add_argument(
769
+ "--add-speech-eos",
770
+ action="store_true",
771
+ help="add eos token at the end of input feature",
772
+ )
773
+ parser.add_argument(
774
+ "--speech-encoder-adapter-type",
775
+ type=str,
776
+ metavar="EXPR",
777
+ default="None",
778
+ choices=["None", "Linear", "MLP"],
779
+ help="add speech encoder adapter",
780
+ )
781
+
782
+ @classmethod
783
+ def build_encoder(cls, args, task):
784
+ spch_encoder = DualInputEncoder.build_spch_encoder(args)
785
+ text_encoder = DualInputEncoder.build_text_encoder(
786
+ args, task.src_dict, spch_encoder
787
+ )
788
+ cross_attentive_loss_before_last_layer = (
789
+ 0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1
790
+ )
791
+ encoder = DualInputEncoder(
792
+ args,
793
+ spch_encoder,
794
+ text_encoder,
795
+ task.src_dict,
796
+ cross_attentive_loss_before_last_layer,
797
+ )
798
+ if args.init_scale != 1.0:
799
+ with torch.no_grad():
800
+ for param in encoder.parameters():
801
+ param.data.mul_(args.init_scale)
802
+ if args.load_pretrain_text_encoder != "":
803
+ checkpoint_utils.load_pretrained_component_from_model(
804
+ text_encoder, args.load_pretrain_text_encoder
805
+ )
806
+ if args.load_pretrain_speech_encoder != "":
807
+ if hasattr(spch_encoder, "encoder"):
808
+ checkpoint_utils.load_pretrained_component_from_model(
809
+ spch_encoder.encoder, args.load_pretrain_speech_encoder
810
+ )
811
+ else:
812
+ checkpoint_utils.load_pretrained_component_from_model(
813
+ spch_encoder, args.load_pretrain_speech_encoder
814
+ )
815
+ if (
816
+ args.load_pretrain_text_encoder_last != ""
817
+ ): # if the encoder is shared, the speech encoder parameters will be used;
818
+ # loading this last gives a chance to use a pre-trained MT encoder instead
819
+ checkpoint_utils.load_pretrained_component_from_model(
820
+ text_encoder, args.load_pretrain_text_encoder_last
821
+ )
822
+
823
+ if args.load_pretrain_encoder != "":
824
+ checkpoint_utils.load_pretrained_component_from_model(
825
+ encoder, args.load_pretrain_encoder
826
+ )
827
+ return encoder
828
+
829
+ @classmethod
830
+ def build_decoder(cls, args, task):
831
+ dec_cfg = {
832
+ "decoder_layerdrop": args.decoder_layerdrop,
833
+ "share_decoder_input_output_embed": args.share_decoder_input_output_embed,
834
+ "decoder_embed_dim": args.decoder_embed_dim,
835
+ "max_target_positions": args.max_target_positions,
836
+ "dropout": args.dropout,
837
+ "encoder_learned_pos": args.encoder_learned_pos,
838
+ "decoder_learned_pos": args.decoder_learned_pos,
839
+ "layernorm_embedding": args.layernorm_embedding,
840
+ "decoder_normalize_before": args.decoder_normalize_before,
841
+ "activation_dropout": args.activation_dropout,
842
+ "attention_dropout": args.attention_dropout,
843
+ "decoder_ffn_embed_dim": args.decoder_ffn_embed_dim,
844
+ "decoder_layers": args.decoder_layers,
845
+ "decoder_attention_heads": args.decoder_attention_heads,
846
+ "decoder_output_dim": args.decoder_embed_dim,
847
+ "no_scale_embedding": args.no_scale_embedding,
848
+ "adaptive_input": args.adaptive_input,
849
+ "quant_noise_pq": args.quant_noise_pq,
850
+ "adaptive_softmax_cutoff": args.adaptive_softmax_cutoff,
851
+ "tie_adaptive_weights": args.tie_adaptive_weights,
852
+ "no_token_positional_embeddings": args.no_token_positional_embeddings,
853
+ "encoder": {"embed_dim": args.encoder_embed_dim}
854
+ }
855
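+ # pack the config dict into a namedtuple so TransformerDecoder can read it via attribute access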
+ dec_cfg = namedtuple("args", dec_cfg.keys())(*dec_cfg.values())
856
+ dec_emb = nn.Embedding(
857
+ len(task.target_dictionary),
858
+ args.decoder_embed_dim,
859
+ task.target_dictionary.pad(),
860
+ )
861
+ compute_cross_attentive_loss = (
862
+ True if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else False
863
+ )
864
+ cross_attentive_loss_without_norm = getattr(
865
+ args, "attentive_cost_without_normalize", False
866
+ )
867
+ cross_attentive_loss_reverse = (
868
+ False # getattr(args, "attentive_cost_reverse", False)
869
+ )
870
+
871
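+ # two decoders (text and speech paths) are built from the same config and embedding table;
+ # share_spchdecoder then ties their parameters according to --decoder-shared-layer-level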
+ text_decoder = TransformerDecoder(dec_cfg, task.target_dictionary, dec_emb)
872
+ spch_decoder = TransformerDecoder(dec_cfg, task.target_dictionary, dec_emb)
873
+ spch_decoder = TransformerMultiInputDecoder.share_spchdecoder(
874
+ args, text_decoder, spch_decoder
875
+ )
876
+ decoder = TransformerMultiInputDecoder(
877
+ dictionary=task.target_dictionary,
878
+ spch_decoder=spch_decoder,
879
+ text_decoder=text_decoder,
880
+ compute_cross_attentive_loss=compute_cross_attentive_loss,
881
+ cross_attentive_loss_with_norm=True
882
+ if not cross_attentive_loss_without_norm
883
+ else False,
884
+ cross_attentive_loss_reverse=cross_attentive_loss_reverse,
885
+ )
886
+ if args.init_scale != 1.0:
887
+ with torch.no_grad():
888
+ for param in decoder.parameters():
889
+ param.data.mul_(args.init_scale)
890
+ if args.load_pretrain_decoder != "":
891
+ try:
892
+ checkpoint_utils.load_pretrained_component_from_model(
893
+ decoder, args.load_pretrain_decoder
894
+ )
895
+ except RuntimeError:
896
+ checkpoint_utils.load_pretrained_component_from_model(
897
+ decoder.text_decoder, args.load_pretrain_decoder
898
+ )
899
+ if args.decoder_shared_layer_level > 0:
900
+ checkpoint_utils.load_pretrained_component_from_model(
901
+ decoder.spch_decoder, args.load_pretrain_decoder
902
+ )
903
+
904
+ return decoder
905
+
906
+ @classmethod
907
+ def build_model(cls, args, task):
908
+ """Build a new model instance."""
909
+ # make sure that all args are properly defaulted
910
+ # (in case there are any new ones)
911
+ dualinputs2ttransformer_base(args)
912
+
913
+ encoder = cls.build_encoder(args, task)
914
+ decoder = cls.build_decoder(args, task)
915
+ return cls(encoder, decoder)
916
+
917
+ def get_normalized_probs(self, net_output, log_probs, sample=None):
918
+ # net_output['encoder_out'] is a (B, T, D) tensor
919
+ lprobs = super().get_normalized_probs(net_output, log_probs, sample)
920
+ lprobs.batch_first = True
921
+ return lprobs
922
+
923
+ def set_num_updates(self, num_updates):
924
+ """Set the number of parameters updates."""
925
+ super().set_num_updates(num_updates)
926
+ self.num_updates = num_updates
927
+
928
+ def forward(
929
+ self,
930
+ src_tokens,
931
+ src_lengths,
932
+ prev_output_tokens,
933
+ use_encoder_outputs=False,
934
+ src_txt_tokens=None,
935
+ src_txt_lengths=None,
936
+ mode="sup_speech",
937
+ **kwargs
938
+ ):
939
+ """
940
+ Run the forward pass for an encoder-decoder model.
941
+
942
+ First feed a batch of source tokens through the encoder. Then, feed the
943
+ encoder output and previous decoder outputs (i.e., teacher forcing) to
944
+ the decoder to produce the next outputs::
945
+
946
+ encoder_out = self.encoder(src_tokens, src_lengths)
947
+ return self.decoder(prev_output_tokens, encoder_out)
948
+
949
+ Args:
950
+ src_tokens (LongTensor): tokens in the source language of shape
951
+ `(batch, src_len)`
952
+ src_lengths (LongTensor): source sentence lengths of shape `(batch)`
953
+ prev_output_tokens (LongTensor): previous decoder outputs of shape
954
+ `(batch, tgt_len)`, for teacher forcing
955
+ mode = 'sup_speech' or 'text'
956
+
957
+ Returns:
958
+ tuple:
959
+ - the decoder's output of shape `(batch, tgt_len, vocab)`
960
+ - a dictionary with any model-specific outputs
961
+ """
962
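+ # in text-only mode the batch carries text in src_tokens; move it to the text inputs and clear the
+ # speech inputs so only the text branch of the encoder runs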
+ if mode == "text":
963
+ assert src_txt_tokens is None
964
+ src_txt_tokens = src_tokens
965
+ src_txt_lengths = src_lengths
966
+ src_tokens = None
967
+ src_lengths = None
968
+ encoder_out = self.encoder(
969
+ src_tokens,
970
+ src_lengths=src_lengths,
971
+ src_txt_tokens=src_txt_tokens,
972
+ src_txt_lengths=src_txt_lengths,
973
+ **kwargs
974
+ )
975
+ has_txt_input = True if src_txt_tokens is not None else False
976
+ decoder_out = self.decoder(
977
+ prev_output_tokens,
978
+ encoder_out=encoder_out,
979
+ has_txt_input=has_txt_input,
980
+ **kwargs
981
+ )
982
+ if use_encoder_outputs:
983
+ return decoder_out, encoder_out
984
+ return decoder_out
985
+
986
+
987
+ @register_model_architecture(
988
+ "dual_input_s2t_transformer", "dualinputs2ttransformer_base"
989
+ )
990
+ def dualinputs2ttransformer_base(args):
991
+ args.encoder_freezing_updates = getattr(args, "encoder_freezing_updates", 0)
992
+ # Convolutional subsampler
993
+ args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
994
+ args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
995
+ args.conv_channels = getattr(args, "conv_channels", 1024)
996
+ # Transformer
997
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
998
+ args.encoder_text_embed_dim = getattr(
999
+ args, "encoder_text_embed_dim", args.encoder_embed_dim
1000
+ )
1001
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
1002
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
1003
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
1004
+ args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
1005
+ args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
1006
+
1007
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
1008
+ args.decoder_ffn_embed_dim = getattr(
1009
+ args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
1010
+ )
1011
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
1012
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
1013
+ args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
1014
+ args.dropout = getattr(args, "dropout", 0.1)
1015
+ args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
1016
+ args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
1017
+ args.activation_fn = getattr(args, "activation_fn", "relu")
1018
+ args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
1019
+ args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
1020
+ args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
1021
+ args.share_decoder_input_output_embed = getattr(
1022
+ args, "share_decoder_input_output_embed", False
1023
+ )
1024
+ args.no_token_positional_embeddings = getattr(
1025
+ args, "no_token_positional_embeddings", False
1026
+ )
1027
+ args.adaptive_input = getattr(args, "adaptive_input", False)
1028
+ args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
1029
+ args.decoder_output_dim = getattr(
1030
+ args, "decoder_output_dim", args.decoder_embed_dim
1031
+ )
1032
+ args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
1033
+ args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
1034
+ args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
1035
+
1036
+ args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 10)
1037
+ args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
1038
+ args.encoder_shared_layers = getattr(args, "encoder_shared_layers", 0)
1039
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
1040
+
1041
+ args.add_speech_eos = getattr(args, "add_speech_eos", False)
1042
+
1043
+
1044
+ @register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_s")
1045
+ def dualinputs2ttransformer_s(args):
1046
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
1047
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 4)
1048
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
1049
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
1050
+ args.dropout = getattr(args, "dropout", 0.1)
1051
+ args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 7)
1052
+ args.text_encoder_layers = getattr(args, "text_encoder_layers", 7)
1053
+ args.decoder_layers = getattr(args, "decoder_layers", 7)
1054
+ dualinputs2ttransformer_base(args)
1055
+
1056
+
1057
+ @register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_m")
1058
+ def dualinputs2ttransformer_m(args):
1059
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
1060
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4)
1061
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
1062
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
1063
+ args.dropout = getattr(args, "dropout", 0.15)
1064
+ args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 10)
1065
+ args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
1066
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
1067
+ dualinputs2ttransformer_base(args)
1068
+
1069
+
1070
+ @register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_b")
1071
+ def dualinputs2ttransformer_b(args):
1072
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
1073
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 768 * 4)
1074
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
1075
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12)
1076
+ args.dropout = getattr(args, "dropout", 0.15)
1077
+ args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 12)
1078
+ args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
1079
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
1080
+ dualinputs2ttransformer_base(args)
1081
+
1082
+
1083
+ @register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_l")
1084
+ def dualinputs2ttransformer_l(args):
1085
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
1086
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024 * 4)
1087
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
1088
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
1089
+ args.dropout = getattr(args, "dropout", 0.2)
1090
+ args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 12)
1091
+ args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
1092
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
1093
+ dualinputs2ttransformer_base(args)
fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputwavtransformer.py ADDED
@@ -0,0 +1,526 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from collections import OrderedDict, namedtuple
8
+
9
+ import torch.nn as nn
10
+
11
+ from fairseq import checkpoint_utils, utils
12
+ from fairseq.checkpoint_utils import load_checkpoint_to_cpu
13
+ from fairseq.file_io import PathManager
14
+ from fairseq.models import register_model, register_model_architecture
15
+ from fairseq.models.speech_to_text import (
16
+ SpeechWavTransformerEncoder,
17
+ StackedSpeechWavTransformerEncoder,
18
+ TransformerDecoder,
19
+ )
20
+ from fairseq.models.transformer import TransformerEncoder
21
+
22
+ from .s2t_dualinputtransformer import (
23
+ DualInputEncoder,
24
+ DualInputS2TTransformerModel,
25
+ TransformerMultiInputDecoder,
26
+ )
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ @register_model("dual_input_wav_transformer")
32
+ class DualInputWavTransformerModel(DualInputS2TTransformerModel):
33
+ def __init__(self, encoder, decoder):
34
+ super().__init__(encoder, decoder)
35
+
36
+ @staticmethod
37
+ def add_args(parser):
38
+ def add_transformer_args(parser):
39
+ # We can't use TransformerModel.add_args(parser), since it defines max-source-positions which is duplicated with tasks/speech_to_text.py
40
+ # Transformer
41
+ parser.add_argument(
42
+ "--activation-fn",
43
+ type=str,
44
+ default="relu",
45
+ choices=utils.get_available_activation_fns(),
46
+ help="activation function to use",
47
+ )
48
+ parser.add_argument(
49
+ "--dropout", type=float, metavar="D", help="dropout probability"
50
+ )
51
+ parser.add_argument(
52
+ "--attention-dropout",
53
+ type=float,
54
+ metavar="D",
55
+ help="dropout probability for attention weights",
56
+ )
57
+ parser.add_argument(
58
+ "--activation-dropout",
59
+ "--relu-dropout",
60
+ type=float,
61
+ metavar="D",
62
+ help="dropout probability after activation in FFN.",
63
+ )
64
+ parser.add_argument(
65
+ "--encoder-embed-dim",
66
+ type=int,
67
+ metavar="N",
68
+ help="encoder embedding dimension",
69
+ )
70
+ parser.add_argument(
71
+ "--encoder-ffn-embed-dim",
72
+ type=int,
73
+ metavar="N",
74
+ help="encoder embedding dimension for FFN",
75
+ )
76
+ parser.add_argument(
77
+ "--encoder-layers", type=int, metavar="N", help="num encoder layers"
78
+ )
79
+ parser.add_argument(
80
+ "--encoder-attention-heads",
81
+ type=int,
82
+ metavar="N",
83
+ help="num encoder attention heads",
84
+ )
85
+ parser.add_argument(
86
+ "--encoder-normalize-before",
87
+ action="store_true",
88
+ help="apply layernorm before each encoder block",
89
+ )
90
+ parser.add_argument(
91
+ "--decoder-embed-dim",
92
+ type=int,
93
+ metavar="N",
94
+ help="decoder embedding dimension",
95
+ )
96
+ parser.add_argument(
97
+ "--decoder-ffn-embed-dim",
98
+ type=int,
99
+ metavar="N",
100
+ help="decoder embedding dimension for FFN",
101
+ )
102
+ parser.add_argument(
103
+ "--decoder-layers", type=int, metavar="N", help="num decoder layers"
104
+ )
105
+ parser.add_argument(
106
+ "--decoder-attention-heads",
107
+ type=int,
108
+ metavar="N",
109
+ help="num decoder attention heads",
110
+ )
111
+ parser.add_argument(
112
+ "--decoder-normalize-before",
113
+ action="store_true",
114
+ help="apply layernorm before each decoder block",
115
+ )
116
+ parser.add_argument(
117
+ "--share-decoder-input-output-embed",
118
+ action="store_true",
119
+ help="share decoder input and output embeddings",
120
+ )
121
+ parser.add_argument(
122
+ "--layernorm-embedding",
123
+ action="store_true",
124
+ help="add layernorm to embedding",
125
+ )
126
+ parser.add_argument(
127
+ "--no-scale-embedding",
128
+ action="store_true",
129
+ help="if True, don't scale embeddings",
130
+ )
131
+
132
+ parser.add_argument(
133
+ "--encoder-learned-pos",
134
+ action="store_true",
135
+ help="use learned positional embeddings",
136
+ )
137
+ parser.add_argument(
138
+ "--decoder-learned-pos",
139
+ action="store_true",
140
+ help="use learned positional embeddings",
141
+ )
142
+
143
+ add_transformer_args(parser)
144
+ SpeechWavTransformerEncoder.add_args(parser)
145
+ parser.add_argument(
146
+ "--load-pretrained-speech-text-encoder",
147
+ type=str,
148
+ default="",
149
+ metavar="EXPR",
150
+ help=""" path to the pretrained speech text encoder from SpeechTextPreTrainModel """,
151
+ )
152
+ parser.add_argument(
153
+ "--load-pretrained-wav2vec-encoder",
154
+ type=str,
155
+ default="",
156
+ metavar="EXPR",
157
+ help=""" path to the pretrained speech encoder from wav2vec """,
158
+ )
159
+
160
+ parser.add_argument(
161
+ "--load-pretrained-speech-text-decoder",
162
+ type=str,
163
+ default="",
164
+ metavar="EXPR",
165
+ help=""" path to the pretrained speech text decoder from SpeechTextPreTrainModel """,
166
+ )
167
+ parser.add_argument(
168
+ "--load-pretrained-text-decoder",
169
+ type=str,
170
+ default="",
171
+ metavar="EXPR",
172
+ help=""" path to the pretrained text decoder """,
173
+ )
174
+ parser.add_argument(
175
+ "--load-init-encoder",
176
+ type=str,
177
+ default="",
178
+ metavar="EXPR",
179
+ help=""" path to load seed encoder model """,
180
+ )
181
+ parser.add_argument(
182
+ "--load-init-decoder",
183
+ type=str,
184
+ default="",
185
+ metavar="EXPR",
186
+ help=""" path to load seed decoder model """,
187
+ )
188
+
189
+ parser.add_argument(
190
+ "--text-input-cost-ratio",
191
+ type=float,
192
+ default=1.0,
193
+ metavar="V",
194
+ help="text input cost ratio relative to speech input cost",
195
+ )
196
+ parser.add_argument(
197
+ "--enc-grad-mult",
198
+ type=float,
199
+ metavar="V",
200
+ default=1.0,
201
+ help="multiply enc1 and enc2 gradient by V",
202
+ )
203
+ parser.add_argument(
204
+ "--enc2-along-grad-mult",
205
+ type=float,
206
+ metavar="V",
207
+ default=1.0,
208
+ help="multiply enc2 gradient by V if only enc2 is used",
209
+ )
210
+ parser.add_argument(
211
+ "--no-strict-check-pretrain-model",
212
+ action="store_true",
213
+ help="Don't apply strict model check for the pretrained model",
214
+ )
215
+
216
+ parser.add_argument(
217
+ "--stacked-encoder",
218
+ action="store_true",
219
+ help="stack speech and text encoders",
220
+ )
221
+
222
+ @classmethod
223
+ def update_transformer_encoder_cfg(cls, args, update_dict):
224
+ cfg = dict(args._get_kwargs())
225
+ for fkey in update_dict.keys():
226
+ cfg[fkey] = update_dict[fkey]
227
+ cfg.pop("_name", None) # drop the auto-generated "_name" key
228
+ model_args = namedtuple("args", cfg.keys())(*cfg.values())
229
+ return model_args
230
+
231
+ @classmethod
232
+ def build_text_encoder(cls, args, src_dictionary):
233
+ enc_emb = nn.Embedding(
234
+ len(src_dictionary), args.encoder_embed_dim, src_dictionary.pad()
235
+ )
236
+ model_args = cls.update_transformer_encoder_cfg(
237
+ args,
238
+ {
239
+ "encoder_layers": args.text_encoder_layers,
240
+ "max_source_positions": args.max_positions_text,
241
+ },
242
+ )
243
+ text_encoder = TransformerEncoder(model_args, src_dictionary, enc_emb)
244
+ return text_encoder
245
+
246
+ @classmethod
247
+ def build_speech_encoder(cls, args):
248
+ model_args = cls.update_transformer_encoder_cfg(
249
+ args, {"encoder_layers": args.speech_encoder_layers}
250
+ )
251
+ speech_encoder = SpeechWavTransformerEncoder(model_args)
252
+ return speech_encoder
253
+
254
+ @classmethod
255
+ def check_args(cls, condition, is_strict, msg):
256
+ if condition:
257
+ return
258
+ if is_strict:
259
+ raise ValueError(msg)
260
+ logger.warning(msg)
261
+
262
+ @classmethod
263
+ def build_encoder(cls, args, task):
264
+ # text_encoder = cls.build_text_encoder(args, task.source_dictionary )
265
+ text_encoder = cls.build_text_encoder(args, task.src_dict)
266
+ speech_encoder = cls.build_speech_encoder(args)
267
+ if args.load_pretrained_wav2vec_encoder:
268
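+ # map wav2vec checkpoint prefixes to the matching submodules of SpeechWavTransformerEncoder
+ # so each component can be loaded individually below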
+ component_pairs = (
269
+ ("feature_extractor", speech_encoder.subsample),
270
+ ("post_extract_proj", speech_encoder.feat_proj),
271
+ ("layer_norm", speech_encoder.feat_layer_norm),
272
+ ("encoder.pos_conv", speech_encoder.embed_positions),
273
+ ("encoder.layers", speech_encoder.layers),
274
+ ("encoder.layer_norm", speech_encoder.layer_norm),
275
+ ("mask_emb", speech_encoder.mask_emb),
276
+ )
277
+ state = cls.load_pretrained_speech_text_components(
278
+ args.load_pretrained_wav2vec_encoder, component_pairs
279
+ )
280
+ cls.check_args(
281
+ args.encoder_normalize_before
282
+ == state["cfg"]["model"]["layer_norm_first"],
283
+ not args.no_strict_check_pretrain_model,
284
+ f"encoder_normalize_before {args.encoder_normalize_before} doesn't match with the pretrained model",
285
+ )
286
+ cls.check_args(
287
+ args.activation_fn == state["cfg"]["model"]["activation_fn"],
288
+ not args.no_strict_check_pretrain_model,
289
+ f"activation_fn {args.activation_fn} doesn't match with the pretrained model",
290
+ )
291
+
292
+ if getattr(args, "stacked_encoder", False):
293
+ if args.encoder_shared_text_layers_from_begin > 0:
294
+ raise ValueError(
295
+ "We can not stack encoders and share encoders at the same time!"
296
+ )
297
+ speech_encoder = StackedSpeechWavTransformerEncoder(
298
+ speech_encoder, text_encoder.layers, text_encoder.layer_norm
299
+ )
300
+ else:
301
+ cls.share_speech_text_encoder(
302
+ speech_encoder, text_encoder, args.encoder_shared_text_layers_from_begin
303
+ )
304
+
305
+ cross_attentive_loss_before_last_layer = (
306
+ 0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1
307
+ )
308
+ encoder = DualInputEncoder(
309
+ args,
310
+ speech_encoder,
311
+ text_encoder,
312
+ task.src_dict,
313
+ cross_attentive_loss_before_last_layer,
314
+ )
315
+ if args.load_pretrained_speech_text_encoder:
316
+ component_pairs = (
317
+ ("encoder.sup_s2s_speech_encoder", encoder.spch_encoder),
318
+ ("encoder.text_encoder", encoder.text_encoder),
319
+ )
320
+ cls.load_pretrained_speech_text_components(
321
+ args.load_pretrained_speech_text_encoder, component_pairs
322
+ )
323
+ if getattr(args, "load_init_encoder", "") != "":
324
+ checkpoint_utils.load_pretrained_component_from_model(
325
+ encoder, args.load_init_encoder
326
+ )
327
+ return encoder
328
+
329
+ @classmethod
330
+ def build_text_decoder(cls, args, tgt_dictionary, dec_emb_share=None):
331
+ dec_emb = (
332
+ nn.Embedding(
333
+ len(tgt_dictionary), args.decoder_embed_dim, tgt_dictionary.pad()
334
+ )
335
+ if dec_emb_share is None
336
+ else dec_emb_share
337
+ )
338
+ text_decoder = TransformerDecoder(args, tgt_dictionary, dec_emb)
339
+ return text_decoder
340
+
341
+ @classmethod
342
+ def build_decoder(cls, args, task):
343
+ text_decoder = cls.build_text_decoder(args, task.target_dictionary)
344
+ compute_cross_attentive_loss = (
345
+ True if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else False
346
+ )
347
+ cross_attentive_loss_without_norm = getattr(
348
+ args, "attentive_cost_without_normalize", False
349
+ )
350
+ cross_attentive_loss_reverse = (
351
+ False # getattr(args, "attentive_cost_reverse", False)
352
+ )
353
+ if getattr(args, "load_pretrained_text_decoder", "") != "":
354
+ checkpoint_utils.load_pretrained_component_from_model(
355
+ text_decoder, args.load_pretrained_text_decoder
356
+ )
357
+
358
+ if args.load_pretrained_speech_text_decoder:
359
+ component_pairs = (("decoder.text_decoder", text_decoder),)
360
+ cls.load_pretrained_speech_text_components(
361
+ args.load_pretrained_speech_text_decoder, component_pairs
362
+ )
363
+
364
+ decoder = TransformerMultiInputDecoder(
365
+ dictionary=task.target_dictionary,
366
+ spch_decoder=text_decoder,
367
+ text_decoder=text_decoder,
368
+ compute_cross_attentive_loss=compute_cross_attentive_loss,
369
+ cross_attentive_loss_with_norm=True
370
+ if not cross_attentive_loss_without_norm
371
+ else False,
372
+ cross_attentive_loss_reverse=cross_attentive_loss_reverse,
373
+ )
374
+ if getattr(args, "load_init_decoder", "") != "":
375
+ checkpoint_utils.load_pretrained_component_from_model(
376
+ decoder, args.load_init_decoder
377
+ )
378
+ return decoder
379
+
380
+ @classmethod
381
+ def load_pretrained_speech_text_components(cls, checkpoint, component_pairs):
382
+ if not PathManager.exists(checkpoint):
383
+ raise IOError("Model file not found: {}".format(checkpoint))
384
+ state = load_checkpoint_to_cpu(checkpoint)
385
+ for component_type, component in component_pairs:
386
+ if isinstance(component, nn.parameter.Parameter):
387
+ component.data.copy_(state["model"][component_type])
388
+ else:
389
+ component_state_dict = OrderedDict()
390
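+ # gather parameters whose names start with the component prefix, stripping the prefix and the
+ # trailing dot so they match the submodule's own parameter names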
+ for key in state["model"].keys():
391
+ if key.startswith(component_type):
392
+ component_subkey = key[len(component_type) + 1 :]
393
+ component_state_dict[component_subkey] = state["model"][key]
394
+ component.load_state_dict(component_state_dict, strict=True)
395
+ return state
396
+
397
+ @classmethod
398
+ def share_speech_text_encoder(
399
+ cls, speech_encoder, text_encoder, shared_layers_from_begin
400
+ ):
401
+ if shared_layers_from_begin > 0:
402
+ num_text_encoder_layers = len(text_encoder.layers)
403
+ assert len(speech_encoder.layers) >= shared_layers_from_begin
404
+ assert num_text_encoder_layers >= shared_layers_from_begin
405
+ assert len(speech_encoder.layers) >= num_text_encoder_layers
406
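+ # tie the first shared_layers_from_begin text-encoder layers to the corresponding block of
+ # speech-encoder layers (the last num_text_encoder_layers speech layers line up with the text layers)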
+ for i, ly in enumerate(
407
+ speech_encoder.layers[
408
+ -num_text_encoder_layers : -num_text_encoder_layers
409
+ + shared_layers_from_begin
410
+ ]
411
+ ):
412
+ assert isinstance(text_encoder.layers[i], type(ly))
413
+ text_encoder.layers[i] = ly
414
+
415
+
416
+ @register_model_architecture(
417
+ "dual_input_wav_transformer", "dualinputs2twavtransformer_base"
418
+ )
419
+ def dualinputs2twavtransformer_base(args):
420
+ # speech masking
421
+ args.dropout_input = getattr(args, "dropout_input", 0)
422
+ args.dropout_features = getattr(args, "dropout_features", 0)
423
+ args.speech_mask_length = getattr(args, "speech_mask_length", 10)
424
+ args.speech_mask_prob = getattr(args, "speech_mask_prob", 0.65)
425
+ args.speech_mask_selection = getattr(args, "speech_mask_selection", "static")
426
+ args.speech_mask_other = getattr(args, "speech_mask_other", 0)
427
+ args.speech_mask_min_space = getattr(args, "speech_mask_min_space", 1)
428
+ args.speech_no_mask_overlap = getattr(args, "speech_no_mask_overlap", False)
429
+ args.speech_conv_bias = getattr(args, "speech_conv_bias", False)
430
+ args.speech_extractor_mode = getattr(args, "speech_extractor_mode", "default")
431
+ args.no_strict_check_pretrain_model = getattr(
432
+ args, "no_strict_check_pretrain_model", False
433
+ )
434
+
435
+ args.speech_mask_channel_length = getattr(args, "speech_mask_channel_length", 10)
436
+ args.speech_mask_channel_prob = getattr(args, "speech_mask_channel_prob", 0.0)
437
+ args.speech_mask_channel_selection = getattr(
438
+ args, "speech_mask_channel_selection", "static"
439
+ )
440
+ args.speech_mask_channel_other = getattr(args, "speech_mask_channel_other", 0)
441
+ args.speech_mask_channel_min_space = getattr(
442
+ args, "speech_mask_channel_min_space", 1
443
+ )
444
+ args.speech_no_mask_channel_overlap = getattr(
445
+ args, "speech_no_mask_channel_overlap", False
446
+ )
447
+ args.no_scale_feature = getattr(args, "no_scale_feature", False)
448
+ args.feature_grad_mult = getattr(args, "feature_grad_mult", 0.0) # 0.1
449
+
450
+ # Transformer
451
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
452
+ args.encoder_ffn_embed_dim = getattr(
453
+ args, "encoder_ffn_embed_dim", args.encoder_embed_dim * 4
454
+ )
455
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
456
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
457
+ args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.1)
458
+ args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
459
+
460
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
461
+ args.decoder_ffn_embed_dim = getattr(
462
+ args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
463
+ )
464
+ args.decoder_attention_heads = getattr(
465
+ args, "decoder_attention_heads", args.encoder_attention_heads
466
+ )
467
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
468
+ args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
469
+ args.dropout = getattr(args, "dropout", 0.1)
470
+ args.attention_dropout = getattr(args, "attention_dropout", 0)
471
+ args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
472
+ args.activation_fn = getattr(args, "activation_fn", "relu") # gelu?
473
+ args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
474
+ args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
475
+ args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
476
+ args.share_decoder_input_output_embed = getattr(
477
+ args, "share_decoder_input_output_embed", False
478
+ )
479
+ args.no_token_positional_embeddings = getattr(
480
+ args, "no_token_positional_embeddings", False
481
+ )
482
+ args.adaptive_input = getattr(args, "adaptive_input", False)
483
+ args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
484
+ args.decoder_output_dim = getattr(
485
+ args, "decoder_output_dim", args.decoder_embed_dim
486
+ )
487
+ args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
488
+ args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
489
+ args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
490
+
491
+ args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 12)
492
+ args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
493
+ args.encoder_shared_text_layers_from_begin = getattr(
494
+ args, "encoder_shared_text_layers_from_begin", 6
495
+ )
496
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
497
+
498
+
499
+ @register_model_architecture(
500
+ "dual_input_wav_transformer", "dualinputs2twavtransformer_base_stack"
501
+ )
502
+ def dualinputs2twavtransformer_base_stack(args):
503
+ args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 6)
504
+ args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
505
+ args.encoder_shared_text_layers_from_begin = getattr(
506
+ args, "encoder_shared_text_layers_from_begin", 0
507
+ )
508
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
509
+ args.stacked_encoder = getattr(args, "stacked_encoder", True)
510
+ args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
511
+ dualinputs2twavtransformer_base(args)
512
+
513
+
514
+ @register_model_architecture(
515
+ "dual_input_wav_transformer", "dualinputs2twavtransformer_large"
516
+ )
517
+ def dualinputs2twavtransformer_large(args):
518
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
519
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
520
+ args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 24)
521
+ args.text_encoder_layers = getattr(args, "text_encoder_layers", 12)
522
+ args.encoder_shared_text_layers_from_begin = getattr(
523
+ args, "encoder_shared_text_layers_from_begin", 12
524
+ )
525
+ args.decoder_layers = getattr(args, "decoder_layers", 12)
526
+ dualinputs2twavtransformer_base(args)
fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputxmtransformer.py ADDED
@@ -0,0 +1,584 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import copy
7
+
8
+ import torch.nn as nn
9
+ from fairseq import checkpoint_utils
10
+ from fairseq import utils
11
+ from fairseq.data.data_utils import lengths_to_padding_mask
12
+ from fairseq.models import (
13
+ register_model,
14
+ register_model_architecture,
15
+ FairseqEncoder,
16
+ )
17
+ from fairseq.models.speech_to_text import Wav2VecEncoderWithAdaptor
18
+ from fairseq.models.speech_to_text.xm_transformer import (
19
+ set_default_adaptor_args,
20
+ set_default_w2v_encoder_args,
21
+ need_finetuning
22
+ )
23
+ from fairseq.models.transformer import TransformerEncoder, TransformerDecoder
24
+ from fairseq.models.wav2vec import TransformerSentenceEncoderLayer
25
+ from fairseq.utils import safe_hasattr
26
+
27
+ from .s2t_dualinputtransformer import (
28
+ DualInputS2TTransformerModel,
29
+ TransformerMultiInputDecoder,
30
+ DualInputEncoder,
31
+ )
32
+
33
+
34
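+ # wraps a pretrained wav2vec TransformerSentenceEncoderLayer, reusing its submodules but returning
+ # only the hidden states so it matches the interface expected of the mbart encoder layers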
+ class TransformerSentenceEncoderLayerStd(TransformerSentenceEncoderLayer):
35
+ def __init__(self, sent_enc_layer):
36
+ super(TransformerSentenceEncoderLayer, self).__init__()
37
+ self.embedding_dim = sent_enc_layer.embedding_dim
38
+ self.dropout = sent_enc_layer.dropout
39
+ self.activation_dropout = sent_enc_layer.activation_dropout
40
+
41
+ # Initialize blocks
42
+ self.activation_fn = sent_enc_layer.activation_fn
43
+ self.self_attn = sent_enc_layer.self_attn
44
+
45
+ self.dropout1 = sent_enc_layer.dropout1
46
+ self.dropout2 = sent_enc_layer.dropout2
47
+ self.dropout3 = sent_enc_layer.dropout3
48
+
49
+ self.layer_norm_first = sent_enc_layer.layer_norm_first
50
+
51
+ # layer norm associated with the self attention layer
52
+ self.self_attn_layer_norm = sent_enc_layer.self_attn_layer_norm
53
+ self.fc1 = sent_enc_layer.fc1
54
+ self.fc2 = sent_enc_layer.fc2
55
+
56
+ # layer norm associated with the position wise feed-forward NN
57
+ self.final_layer_norm = sent_enc_layer.final_layer_norm
58
+
59
+ def forward(
60
+ self,
61
+ x,
62
+ self_attn_mask=None,
63
+ self_attn_padding_mask=None,
64
+ need_weights=None,
65
+ att_args=None,
66
+ ):
67
+ x, attn = super().forward(
68
+ x, self_attn_mask, self_attn_padding_mask, need_weights, att_args
69
+ )
70
+ return x
71
+
72
+
73
+ # TODO retire SharedEncoder
74
+ class SharedEncoder(FairseqEncoder):
75
+ def __init__(self, wav2vec_enc, mbart_enc, adaptor, shared_layers):
76
+ super().__init__(None)
77
+ self.w2v_encoder = wav2vec_enc
78
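+ # detach the top shared_layers wav2vec layers from the speech encoder; they are kept here and
+ # substituted into the mbart encoder below so both modalities run through the same modules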
+ self.shared_layers = self.w2v_encoder.w2v_model.encoder.layers[-shared_layers:]
79
+ self.w2v_encoder.w2v_model.encoder.layers = (
80
+ self.w2v_encoder.w2v_model.encoder.layers[:-shared_layers]
81
+ )
82
+ self.adaptor = adaptor
83
+ if self.shared_layers[-1].layer_norm_first:
84
+ self.final_layer_norm = mbart_enc.layer_norm
85
+ else:
86
+ mbart_enc.layer_norm = None
87
+ self.final_layer_norm = None
88
+ shared_layer_from = len(mbart_enc.layers) - shared_layers
89
+ if shared_layer_from < 0:
90
+ shared_layer_from = 0
91
+ for layer_id, layer in enumerate(self.shared_layers):
92
+ mbart_enc.layers[
93
+ shared_layer_from + layer_id
94
+ ] = TransformerSentenceEncoderLayerStd(layer)
95
+
96
+ def forward(self, src_tokens, src_lengths=None, **kwargs):
97
+ padding_mask = lengths_to_padding_mask(src_lengths)
98
+ if not padding_mask.any():
99
+ padding_mask = None
100
+
101
+ out = self.w2v_encoder.forward(src_tokens, padding_mask, tbc=True)
102
+ x = out["encoder_out"]
103
+ enc_padding_mask = None
104
+ if out["encoder_padding_mask"] is not None:
105
+ enc_padding_mask = out["encoder_padding_mask"].transpose(
106
+ 0, 1
107
+ ) # T X B --> B X T
108
+
109
+ x, enc_padding_mask = self.adaptor(x, enc_padding_mask)
110
+ for layer in self.shared_layers:
111
+ x, _ = layer(x, enc_padding_mask)
112
+ if self.final_layer_norm is not None:
113
+ x = self.final_layer_norm(x)
114
+
115
+ return {
116
+ "encoder_out": [x], # T x B x C
117
+ "encoder_padding_mask": [enc_padding_mask]
118
+ if enc_padding_mask is not None
119
+ else [], # B x T
120
+ "encoder_embedding": [], # B x T x C
121
+ "encoder_states": [], # List[T x B x C]
122
+ "src_tokens": [],
123
+ "src_lengths": [],
124
+ }
125
+
126
+
127
+ class StackedWav2VecEncoderWithAdaptor(FairseqEncoder):
128
+ def __init__(
129
+ self,
130
+ wav2vec_enc,
131
+ mbart_enc_layers,
132
+ mbart_layer_norm,
133
+ adaptor,
134
+ drop_w2v_layers=0,
135
+ ):
136
+ super().__init__(None)
137
+ self.w2v_encoder = wav2vec_enc
138
+ self.adaptor = adaptor
139
+ self.mbart_encoder_layers = mbart_enc_layers
140
+ self.final_layer_norm = mbart_layer_norm
141
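+ # optionally drop the top wav2vec layers before stacking the mbart encoder layers on top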
+ if drop_w2v_layers > 0:
142
+ self.w2v_encoder.w2v_model.encoder.layers = (
143
+ self.w2v_encoder.w2v_model.encoder.layers[:-drop_w2v_layers]
144
+ )
145
+
146
+ def forward(self, src_tokens, src_lengths=None, return_all_hiddens=False, **kwargs):
147
+ padding_mask = lengths_to_padding_mask(src_lengths)
148
+ if not padding_mask.any():
149
+ padding_mask = None
150
+
151
+ out = self.w2v_encoder.forward(src_tokens, padding_mask, tbc=True)
152
+ x = out["encoder_out"]
153
+ enc_padding_mask = None
154
+ if out["padding_mask"] is not None:
155
+ enc_padding_mask = out["padding_mask"] # B X T
156
+
157
+ x, enc_padding_mask = self.adaptor(x, enc_padding_mask)
158
+ encoder_states = []
159
+ for layer in self.mbart_encoder_layers:
160
+ x = layer(x, enc_padding_mask)
161
+ if return_all_hiddens:
162
+ encoder_states.append(x)
163
+ if self.final_layer_norm is not None:
164
+ x = self.final_layer_norm(x)
165
+
166
+ return {
167
+ "encoder_out": [x], # T x B x C
168
+ "encoder_padding_mask": [enc_padding_mask]
169
+ if enc_padding_mask is not None
170
+ else [], # B x T
171
+ "encoder_embedding": [], # B x T x C
172
+ "encoder_states": encoder_states, # List[T x B x C]
173
+ "src_tokens": [],
174
+ "src_lengths": [],
175
+ }
176
+
177
+ def reorder_encoder_out(self, encoder_out, new_order):
178
+ new_encoder_out = (
179
+ []
180
+ if len(encoder_out["encoder_out"]) == 0
181
+ else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
182
+ )
183
+
184
+ new_encoder_padding_mask = (
185
+ []
186
+ if len(encoder_out["encoder_padding_mask"]) == 0
187
+ else [
188
+ x.index_select(0, new_order)
189
+ for x in encoder_out["encoder_padding_mask"]
190
+ ]
191
+ )
192
+
193
+ new_encoder_embedding = (
194
+ []
195
+ if len(encoder_out["encoder_embedding"]) == 0
196
+ else [
197
+ x.index_select(0, new_order) for x in encoder_out["encoder_embedding"]
198
+ ]
199
+ )
200
+
201
+ encoder_states = encoder_out["encoder_states"]
202
+ if len(encoder_states) > 0:
203
+ for idx, state in enumerate(encoder_states):
204
+ encoder_states[idx] = state.index_select(1, new_order)
205
+
206
+ return {
207
+ "encoder_out": new_encoder_out, # T x B x C
208
+ "encoder_padding_mask": new_encoder_padding_mask, # B x T
209
+ "encoder_embedding": new_encoder_embedding, # B x T x C
210
+ "encoder_states": encoder_states, # List[T x B x C]
211
+ "src_tokens": [], # B x T
212
+ "src_lengths": [], # B x 1
213
+ }
214
+
215
+
216
+ # Note:
217
+ # dual input transformer:
218
+ # encoder: wav2vec for speech + mbart encoder for text
219
+ # decoder: mbart decoder for text
220
+ @register_model("dual_input_xm_transformer")
221
+ class DualInputXMTransformerModel(DualInputS2TTransformerModel):
222
+ def __init__(self, encoder, decoder):
223
+ super().__init__(encoder, decoder)
224
+
225
+ @staticmethod
226
+ def add_args(parser):
227
+ """Add model-specific arguments to the parser."""
228
+ # wav2vec encoder
229
+ Wav2VecEncoderWithAdaptor.add_args(parser)
230
+ # add_decoder_args(parser)
231
+ # mbart Transformer
232
+ parser.add_argument(
233
+ "--activation-fn",
234
+ type=str,
235
+ default="relu",
236
+ choices=utils.get_available_activation_fns(),
237
+ help="activation function to use",
238
+ )
239
+
240
+ parser.add_argument(
241
+ "--mbart-dropout", type=float, metavar="D", help="dropout probability"
242
+ )
243
+ parser.add_argument(
244
+ "--mbart-attention-dropout",
245
+ type=float,
246
+ metavar="D",
247
+ help="dropout probability for attention weights",
248
+ )
249
+ parser.add_argument(
250
+ "--mbart-activation-dropout",
251
+ type=float,
252
+ metavar="D",
253
+ help="dropout probability after activation in FFN.",
254
+ )
255
+
256
+ parser.add_argument(
257
+ "--encoder-embed-dim",
258
+ type=int,
259
+ metavar="N",
260
+ help="encoder embedding dimension",
261
+ )
262
+ parser.add_argument(
263
+ "--encoder-ffn-embed-dim",
264
+ type=int,
265
+ metavar="N",
266
+ help="encoder embedding dimension for FFN",
267
+ )
268
+ parser.add_argument(
269
+ "--encoder-layers", type=int, metavar="N", help="num encoder layers"
270
+ )
271
+ parser.add_argument(
272
+ "--encoder-attention-heads",
273
+ type=int,
274
+ metavar="N",
275
+ help="num encoder attention heads",
276
+ )
277
+ parser.add_argument(
278
+ "--encoder-normalize-before",
279
+ action="store_true",
280
+ help="apply layernorm before each encoder block",
281
+ )
282
+
283
+ parser.add_argument(
284
+ "--decoder-embed-dim",
285
+ type=int,
286
+ metavar="N",
287
+ help="decoder embedding dimension",
288
+ )
289
+ parser.add_argument(
290
+ "--decoder-ffn-embed-dim",
291
+ type=int,
292
+ metavar="N",
293
+ help="decoder embedding dimension for FFN",
294
+ )
295
+ parser.add_argument(
296
+ "--decoder-layers", type=int, metavar="N", help="num decoder layers"
297
+ )
298
+ parser.add_argument(
299
+ "--decoder-attention-heads",
300
+ type=int,
301
+ metavar="N",
302
+ help="num decoder attention heads",
303
+ )
304
+ parser.add_argument(
305
+ "--decoder-normalize-before",
306
+ action="store_true",
307
+ help="apply layernorm before each decoder block",
308
+ )
309
+ parser.add_argument(
310
+ "--layernorm-embedding",
311
+ action="store_true",
312
+ help="add layernorm to embedding",
313
+ )
314
+ parser.add_argument(
315
+ "--no-scale-embedding",
316
+ action="store_true",
317
+ help="if True, dont scale embeddings",
318
+ )
319
+ parser.add_argument(
320
+ "--load-pretrained-mbart-from",
321
+ type=str,
322
+ metavar="STR",
323
+ help="model to take text encoder decoder weights from (for initialization)",
324
+ )
325
+ # parser.add_argument("--finetune-w2v-params", type=str, metavar="STR",
326
+ # help="comma-separated param strings to finetune.")
327
+ parser.add_argument(
328
+ "--finetune-mbart-decoder-params",
329
+ type=str,
330
+ metavar="STR",
331
+ help="comma-separated param strings to finetune.",
332
+ )
333
+ parser.add_argument(
334
+ "--finetune-mbart-encoder-params",
335
+ type=str,
336
+ metavar="STR",
337
+ help="comma-separated param strings to finetune.",
338
+ )
339
+ parser.add_argument(
340
+ "--skip-encoder-projection",
341
+ action="store_true",
342
+ help="skip the projection layer in encoder",
343
+ )
344
+
345
+ parser.add_argument(
346
+ "--enc-grad-mult",
347
+ type=float,
348
+ metavar="V",
349
+ default=1.0,
350
+ help="multiply enc1 and enc2 gradient by V",
351
+ )
352
+ parser.add_argument(
353
+ "--enc2-along-grad-mult",
354
+ type=float,
355
+ metavar="V",
356
+ default=1.0,
357
+ help="multiply enc2 gradient by V if only enc2 is used",
358
+ )
359
+ parser.add_argument(
360
+ "--text-input-cost-ratio",
361
+ type=float,
362
+ default=1.0,
363
+ metavar="V",
364
+ help="text input cost ratio relative to speech input cost",
365
+ )
366
+ parser.add_argument(
367
+ "--stack-w2v-mbart-encoder",
368
+ action="store_true",
369
+ help="stack w2v and mbart encoder",
370
+ )
371
+ parser.add_argument(
372
+ "--stack-w2v-mbart-nonorm-encoder",
373
+ action="store_true",
374
+ help="stack w2v and mbart encoder",
375
+ )
376
+ parser.add_argument(
377
+ "--no-final-norm-decoder", action="store_true", help="no layer norm"
378
+ )
379
+ parser.add_argument(
380
+ "--drop-w2v-layers",
381
+ type=int,
382
+ default=0,
383
+ metavar="N",
384
+ help="drop w2v encoder layers",
385
+ )
386
+
387
+ parser.add_argument(
388
+ "--share-w2v-text-encoder",
389
+ action="store_true",
390
+ help="share w2v encoder layers with text encoder",
391
+ )
392
+ parser.add_argument(
393
+ "--shared-w2v-layers",
394
+ type=int,
395
+ default=0,
396
+ metavar="N",
397
+ help="shared encoder layers from w2v encoder",
398
+ )
399
+
400
+ @classmethod
401
+ def build_encoder(cls, args, task):
402
+ _args = copy.deepcopy(args)
403
+ _args.dropout = args.mbart_dropout
404
+ _args.attention_dropout = args.mbart_attention_dropout
405
+ _args.activation_dropout = args.mbart_activation_dropout
406
+ _args.max_source_positions = 1024
407
+ enc_emb = nn.Embedding(
408
+ len(task.src_dict), _args.encoder_embed_dim, task.src_dict.pad()
409
+ )
410
+ text_encoder = TransformerEncoder(_args, task.src_dict, enc_emb)
411
+ spch_encoder = Wav2VecEncoderWithAdaptor(args)
412
+ if getattr(args, "load_pretrained_mbart_from", None):
413
+ text_encoder = checkpoint_utils.load_pretrained_component_from_model(
414
+ component=text_encoder, checkpoint=args.load_pretrained_mbart_from
415
+ )
416
+ if getattr(args, "stack_w2v_mbart_encoder", False):
417
+ assert getattr(args, "share_w2v_text_encoder", False) is False
418
+ spch_encoder = StackedWav2VecEncoderWithAdaptor(
419
+ spch_encoder.w2v_encoder,
420
+ text_encoder.layers,
421
+ text_encoder.layer_norm,
422
+ spch_encoder.adaptor,
423
+ args.drop_w2v_layers,
424
+ )
425
+ elif getattr(args, "stack_w2v_mbart_nonorm_encoder", False):
426
+ text_encoder.layer_norm = None
427
+ spch_encoder = StackedWav2VecEncoderWithAdaptor(
428
+ spch_encoder.w2v_encoder,
429
+ text_encoder.layers,
430
+ text_encoder.layer_norm,
431
+ spch_encoder.adaptor,
432
+ args.drop_w2v_layers,
433
+ )
434
+ elif getattr(args, "share_w2v_text_encoder", False):
435
+ spch_encoder = SharedEncoder(
436
+ spch_encoder.w2v_encoder,
437
+ text_encoder,
438
+ spch_encoder.adaptor,
439
+ args.shared_w2v_layers,
440
+ )
441
+
442
+ for k, p in spch_encoder.named_parameters():
443
+ # Freeze pretrained models by default
444
+ if safe_hasattr(
445
+ args, "finetune_w2v_params"
446
+ ) and need_finetuning(args.finetune_w2v_params, k):
447
+ p.requires_grad = True
448
+ else:
449
+ p.requires_grad = False
450
+ for k, p in text_encoder.named_parameters():
451
+ # Freeze pretrained models by default
452
+ if safe_hasattr(
453
+ args, "finetune_mbart_encoder_params"
454
+ ) and need_finetuning(
455
+ args.finetune_mbart_encoder_params, k
456
+ ):
457
+ p.requires_grad = True
458
+ else:
459
+ p.requires_grad = False
460
+ cross_attentive_loss_before_last_layer = (
461
+ 0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1
462
+ )
463
+ encoder = DualInputEncoder(
464
+ args,
465
+ spch_encoder,
466
+ text_encoder,
467
+ task.src_dict,
468
+ cross_attentive_loss_before_last_layer,
469
+ )
470
+ return encoder
471
+
472
+ @classmethod
473
+ def build_decoder(cls, args, task):
474
+ _args = copy.deepcopy(args)
475
+ _args.dropout = args.mbart_dropout
476
+ _args.attention_dropout = args.mbart_attention_dropout
477
+ _args.activation_dropout = args.mbart_activation_dropout
478
+ _args.max_target_positions = 1024
479
+ dec_emb = nn.Embedding(
480
+ len(task.tgt_dict), _args.encoder_embed_dim, task.tgt_dict.pad()
481
+ )
482
+ decoder = TransformerDecoder(_args, task.tgt_dict, dec_emb)
483
+ if getattr(args, "load_pretrained_mbart_from", None):
484
+ decoder = checkpoint_utils.load_pretrained_component_from_model(
485
+ component=decoder, checkpoint=args.load_pretrained_mbart_from
486
+ )
487
+ if getattr(args, "no_final_norm_decoder", False):
488
+ decoder.layer_norm = None
489
+ for k, p in decoder.named_parameters():
490
+ # Freeze pretrained models by default
491
+ if safe_hasattr(
492
+ args, "finetune_mbart_decoder_params"
493
+ ) and need_finetuning(
494
+ args.finetune_mbart_decoder_params, k
495
+ ):
496
+ p.requires_grad = True
497
+ else:
498
+ p.requires_grad = False
499
+
500
+ compute_cross_attentive_loss = (
501
+ True if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else False
502
+ )
503
+ cross_attentive_loss_without_norm = getattr(
504
+ args, "attentive_cost_without_normalize", False
505
+ )
506
+ cross_attentive_loss_reverse = (
507
+ False # getattr(args, "attentive_cost_reverse", False)
508
+ )
509
+ decoder = TransformerMultiInputDecoder(
510
+ dictionary=task.target_dictionary,
511
+ spch_decoder=decoder,
512
+ text_decoder=decoder,
513
+ compute_cross_attentive_loss=compute_cross_attentive_loss,
514
+ cross_attentive_loss_with_norm=True
515
+ if not cross_attentive_loss_without_norm
516
+ else False,
517
+ cross_attentive_loss_reverse=cross_attentive_loss_reverse,
518
+ )
519
+ return decoder
520
+
521
+ @classmethod
522
+ def build_model(cls, args, task):
523
+ """Build a new model instance."""
524
+ # make sure that all args are properly defaulted
525
+ # (in case there are any new ones)
526
+ dualinputxmtransformer_base(args)
527
+
528
+ encoder = cls.build_encoder(args, task)
529
+ decoder = cls.build_decoder(args, task)
530
+ return cls(encoder, decoder)
531
+
532
+
533
+ @register_model_architecture("dual_input_xm_transformer", "dualinputxmtransformer_base")
534
+ def dualinputxmtransformer_base(args):
535
+ # wav2vec encoder
536
+ set_default_w2v_encoder_args(args)
537
+ set_default_adaptor_args(args)
538
+
539
+ # mbart model
540
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
541
+ args.encoder_ffn_embed_dim = getattr(
542
+ args, "encoder_ffn_embed_dim", 4 * args.encoder_embed_dim
543
+ )
544
+ args.encoder_layers = getattr(args, "encoder_layers", 12)
545
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
546
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
547
+ args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
548
+ args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True)
549
+
550
+ args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
551
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
552
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4 * 1024)
553
+ args.decoder_layers = getattr(args, "decoder_layers", 12)
554
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
555
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
556
+ args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True)
557
+ args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
558
+
559
+ args.adaptive_input = getattr(args, "adaptive_input", False)
560
+
561
+ args.mbart_attention_dropout = getattr(args, "mbart_attention_dropout", 0.0)
562
+ args.mbart_activation_dropout = getattr(args, "mbart_activation_dropout", 0.0)
563
+ args.mbart_dropout = getattr(args, "mbart_dropout", 0.1)
564
+ args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
565
+ args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
566
+ args.share_decoder_input_output_embed = getattr(
567
+ args, "share_decoder_input_output_embed", True
568
+ )
569
+ args.no_token_positional_embeddings = getattr(
570
+ args, "no_token_positional_embeddings", False
571
+ )
572
+
573
+ args.decoder_output_dim = getattr(
574
+ args, "decoder_output_dim", args.decoder_embed_dim
575
+ )
576
+ args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
577
+
578
+ args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
579
+ args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
580
+ args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
581
+
582
+ args.activation_fn = getattr(args, "activation_fn", "gelu")
583
+ args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
584
+ args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
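The encoder and decoder builders above freeze every pretrained parameter and then selectively re-enable gradients through a name filter. Below is a minimal, self-contained sketch of that pattern; the real `need_finetuning` helper is defined elsewhere in this example, so the regex-matching behaviour assumed here is illustrative only.

# Illustration only (not part of the upload): freeze-by-default with selective unfreezing.
import re
import torch.nn as nn

def need_finetuning_sketch(finetune_params, param_name):
    # assumed behaviour: comma-separated regular expressions matched against parameter names
    return any(re.search(pattern, param_name) for pattern in finetune_params.split(","))

toy_module = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))  # stand-in for a pretrained encoder
finetune_w2v_params = r"1\."  # hypothetical --finetune-w2v-params value: unfreeze only module 1

for name, p in toy_module.named_parameters():
    # pretrained weights stay frozen unless the name filter selects them
    p.requires_grad = need_finetuning_sketch(finetune_w2v_params, name)

print({name: p.requires_grad for name, p in toy_module.named_parameters()})
# {'0.weight': False, '0.bias': False, '1.weight': True, '1.bias': True}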
fairseq/examples/speech_text_joint_to_text/scripts/convert_model.py ADDED
@@ -0,0 +1,71 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import re
9
+ from collections import OrderedDict
10
+
11
+ import torch
12
+
13
+ from fairseq.file_io import PathManager
14
+
15
+
16
+ def is_update(param_name, module_name):
17
+ if module_name in param_name:
18
+ return True
19
+ return False
20
+
21
+
22
+ def load_checkpoint(src_cpt):
23
+
24
+ with PathManager.open(src_cpt, "rb") as f:
25
+ state_src = torch.load(
26
+ f,
27
+ map_location=(
28
+ lambda s, _: torch.serialization.default_restore_location(s, "cpu")
29
+ ),
30
+ )
31
+
32
+ return state_src
33
+
34
+
35
+ def save_checkpoint(tgt_cpt, states):
36
+
37
+ with PathManager.open(tgt_cpt, "wb") as f:
38
+ torch.save(
39
+ states,
40
+ f,
41
+ )
42
+
43
+
44
+ # convert the joint pre-trained checkpoint into a BART-style (text-only encoder/decoder) model
45
+ def main():
46
+ parser = argparse.ArgumentParser()
47
+ # fmt: off
48
+ parser.add_argument('--input-model', required=True,
49
+ help='Input checkpoint file path.')
50
+ parser.add_argument('--output-model', required=True,
51
+ help='Output checkpoint file path.')
52
+ # fmt: on
53
+ args = parser.parse_args()
54
+ print(args)
55
+
56
+ states = load_checkpoint(args.input_model)
57
+ model = states["model"]
58
+ new_model = OrderedDict()
59
+ for key in model.keys():
60
+ if re.search("^encoder.text_encoder", key):
61
+ new_key = re.sub("encoder.text_encoder", "encoder", key)
62
+ new_model[new_key] = model[key]
63
+ elif re.search("^decoder.text_decoder", key):
64
+ new_key = re.sub("decoder.text_decoder", "decoder", key)
65
+ new_model[new_key] = model[key]
66
+ states["model"] = new_model
67
+ save_checkpoint(args.output_model, states)
68
+
69
+
70
+ if __name__ == "__main__":
71
+ main()
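The loop above keeps only the text branch of the joint checkpoint and renames its keys so the result looks like a plain encoder/decoder (BART-style) model; it is invoked as `python convert_model.py --input-model <in.pt> --output-model <out.pt>` (paths illustrative). A toy run of the renaming logic, using made-up keys:

import re
from collections import OrderedDict

# made-up keys mimicking the joint checkpoint layout handled above
model = OrderedDict(
    [
        ("encoder.text_encoder.layers.0.fc1.weight", 1),
        ("decoder.text_decoder.layers.0.fc1.weight", 2),
        ("encoder.spch_encoder.w2v_encoder.conv.weight", 3),  # matches neither pattern -> dropped
    ]
)
new_model = OrderedDict()
for key in model.keys():
    if re.search("^encoder.text_encoder", key):
        new_model[re.sub("encoder.text_encoder", "encoder", key)] = model[key]
    elif re.search("^decoder.text_decoder", key):
        new_model[re.sub("decoder.text_decoder", "decoder", key)] = model[key]

print(list(new_model.keys()))
# ['encoder.layers.0.fc1.weight', 'decoder.layers.0.fc1.weight']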
fairseq/examples/speech_text_joint_to_text/scripts/g2p_encode.py ADDED
@@ -0,0 +1,191 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import itertools
8
+ import logging
9
+ import re
10
+ import time
11
+
12
+ from g2p_en import G2p
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ FAIL_SENT = "FAILED_SENTENCE"
17
+
18
+
19
+ def parse():
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument("--data-path", type=str, required=True)
22
+ parser.add_argument("--out-path", type=str, required=True)
23
+ parser.add_argument("--lower-case", action="store_true")
24
+ parser.add_argument("--do-filter", action="store_true")
25
+ parser.add_argument("--use-word-start", action="store_true")
26
+ parser.add_argument("--dup-vowel", default=1, type=int)
27
+ parser.add_argument("--dup-consonant", default=1, type=int)
28
+ parser.add_argument("--no-punc", action="store_true")
29
+ parser.add_argument("--reserve-word", type=str, default="")
30
+ parser.add_argument(
31
+ "--reserve-first-column",
32
+ action="store_true",
33
+ help="first column is sentence id",
34
+ )
35
+ ###
36
+ parser.add_argument("--parallel-process-num", default=1, type=int)
37
+ parser.add_argument("--logdir", default="")
38
+ args = parser.parse_args()
39
+ return args
40
+
41
+
42
+ def process_sent(sent, g2p, res_wrds, args):
43
+ sents = pre_process_sent(sent, args.do_filter, args.lower_case, res_wrds)
44
+ pho_seqs = [do_g2p(g2p, s, res_wrds, i == 0) for i, s in enumerate(sents)]
45
+ pho_seq = (
46
+ [FAIL_SENT]
47
+ if [FAIL_SENT] in pho_seqs
48
+ else list(itertools.chain.from_iterable(pho_seqs))
49
+ )
50
+ if args.no_punc:
51
+ pho_seq = remove_punc(pho_seq)
52
+ if args.dup_vowel > 1 or args.dup_consonant > 1:
53
+ pho_seq = dup_pho(pho_seq, args.dup_vowel, args.dup_consonant)
54
+ if args.use_word_start:
55
+ pho_seq = add_word_start(pho_seq)
56
+ return " ".join(pho_seq)
57
+
58
+
59
+ def remove_punc(sent):
60
+ ns = []
61
+ regex = re.compile("[^a-zA-Z0-9 ]")
62
+ for p in sent:
63
+ if (not regex.search(p)) or p == FAIL_SENT:
64
+ if p == " " and (len(ns) == 0 or ns[-1] == " "):
65
+ continue
66
+ ns.append(p)
67
+ return ns
68
+
69
+
70
+ def do_g2p(g2p, sent, res_wrds, is_first_sent):
71
+ if sent in res_wrds:
72
+ pho_seq = [res_wrds[sent]]
73
+ else:
74
+ pho_seq = g2p(sent)
75
+ if not is_first_sent:
76
+ pho_seq = [" "] + pho_seq # add space to separate
77
+ return pho_seq
78
+
79
+
80
+ def pre_process_sent(sent, do_filter, lower_case, res_wrds):
81
+ if do_filter:
82
+ sent = re.sub("-", " ", sent)
83
+ sent = re.sub("—", " ", sent)
84
+ if len(res_wrds) > 0:
85
+ wrds = sent.split()
86
+ wrds = ["SPLIT_ME " + w + " SPLIT_ME" if w in res_wrds else w for w in wrds]
87
+ sents = [x.strip() for x in " ".join(wrds).split("SPLIT_ME") if x.strip() != ""]
88
+ else:
89
+ sents = [sent]
90
+ if lower_case:
91
+ sents = [s.lower() if s not in res_wrds else s for s in sents]
92
+ return sents
93
+
94
+
95
+ def dup_pho(sent, dup_v_num, dup_c_num):
96
+ """
97
+ duplicate phonemes, as defined in cmudict
98
+ http://www.speech.cs.cmu.edu/cgi-bin/cmudict
99
+ """
100
+ if dup_v_num == 1 and dup_c_num == 1:
101
+ return sent
102
+ ns = []
103
+ for p in sent:
104
+ ns.append(p)
105
+ if re.search(r"\d$", p):
106
+ for i in range(1, dup_v_num):
107
+ ns.append(f"{p}-{i}P")
108
+ elif re.search(r"\w", p):
109
+ for i in range(1, dup_c_num):
110
+ ns.append(f"{p}-{i}P")
111
+ return ns
112
+
113
+
114
+ def add_word_start(sent):
115
+ ns = []
116
+ do_add = True
117
+ ws = "▁"
118
+ for p in sent:
119
+ if do_add:
120
+ p = ws + p
121
+ do_add = False
122
+ if p == " ":
123
+ do_add = True
124
+ else:
125
+ ns.append(p)
126
+ return ns
127
+
128
+
129
+ def load_reserve_word(reserve_word):
130
+ if reserve_word == "":
131
+ return []
132
+ with open(reserve_word, "r") as fp:
133
+ res_wrds = [x.strip().split() for x in fp.readlines() if x.strip() != ""]
134
+ assert sum([0 if len(x) == 2 else 1 for x in res_wrds]) == 0
135
+ res_wrds = dict(res_wrds)
136
+ return res_wrds
137
+
138
+
139
+ def process_sents(sents, args):
140
+ g2p = G2p()
141
+ out_sents = []
142
+ res_wrds = load_reserve_word(args.reserve_word)
143
+ for sent in sents:
144
+ col1 = ""
145
+ if args.reserve_first_column:
146
+ col1, sent = sent.split(None, 1)
147
+ sent = process_sent(sent, g2p, res_wrds, args)
148
+ if args.reserve_first_column and col1 != "":
149
+ sent = f"{col1} {sent}"
150
+ out_sents.append(sent)
151
+ return out_sents
152
+
153
+
154
+ def main():
155
+ args = parse()
156
+ out_sents = []
157
+ with open(args.data_path, "r") as fp:
158
+ sent_list = [x.strip() for x in fp.readlines()]
159
+ if args.parallel_process_num > 1:
160
+ try:
161
+ import submitit
162
+ except ImportError:
163
+ logger.warning(
164
+ "submitit is not found and only one job is used to process the data"
165
+ )
166
+ submitit = None
167
+
168
+ if args.parallel_process_num == 1 or submitit is None:
169
+ out_sents = process_sents(sent_list, args)
170
+ else:
171
+ # process sentences with parallel computation
172
+ lsize = len(sent_list) // args.parallel_process_num + 1
173
+ executor = submitit.AutoExecutor(folder=args.logdir)
174
+ executor.update_parameters(timeout_min=1000, cpus_per_task=4)
175
+ jobs = []
176
+ for i in range(args.parallel_process_num):
177
+ job = executor.submit(
178
+ process_sents, sent_list[lsize * i : lsize * (i + 1)], args
179
+ )
180
+ jobs.append(job)
181
+ is_running = True
182
+ while is_running:
183
+ time.sleep(5)
184
+ is_running = sum([job.done() for job in jobs]) < len(jobs)
185
+ out_sents = list(itertools.chain.from_iterable([job.result() for job in jobs]))
186
+ with open(args.out_path, "w") as fp:
187
+ fp.write("\n".join(out_sents) + "\n")
188
+
189
+
190
+ if __name__ == "__main__":
191
+ main()
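For reference, the phoneme post-processing helpers above behave as follows on a hand-written CMUdict-style sequence. This assumes the script is importable as `g2p_encode` (e.g. when run from its own directory); the phoneme list itself is illustrative.

from g2p_encode import add_word_start, dup_pho

phones = ["HH", "AH0", "L", "OW1", " ", "W", "ER1", "L", "D"]  # "hello world" in CMUdict phonemes

print(dup_pho(phones, dup_v_num=2, dup_c_num=1))
# vowels (phonemes ending in a stress digit) receive one extra "-1P" copy:
# ['HH', 'AH0', 'AH0-1P', 'L', 'OW1', 'OW1-1P', ' ', 'W', 'ER1', 'ER1-1P', 'L', 'D']

print(add_word_start(phones))
# spaces are dropped and each word-initial phoneme is prefixed with the word-start marker:
# ['▁HH', 'AH0', 'L', 'OW1', '▁W', 'ER1', 'L', 'D']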
fairseq/examples/speech_text_joint_to_text/tasks/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import importlib
7
+ import os
8
+
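As uploaded, this `__init__.py` only carries the license header plus `import importlib` and `import os`. In fairseq example packages these imports usually drive an auto-import loop so every task module's `@register_task` decorator runs on package import; the sketch below shows that conventional loop as an assumption, it is not part of the uploaded file.

import importlib
import os

# conventional fairseq auto-registration loop (assumed, not present in this upload)
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        task_name = file[: file.find(".py")]
        importlib.import_module("examples.speech_text_joint_to_text.tasks." + task_name)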
fairseq/examples/speech_text_joint_to_text/tasks/pair_denoising.py ADDED
@@ -0,0 +1,447 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import itertools
7
+ import logging
8
+ import os
9
+ import re
10
+
11
+ import numpy as np
12
+ import torch
13
+
14
+ from examples.speech_text_joint_to_text.data.pair_denoising_dataset import (
15
+ LanguagePairDenoisingDataset,
16
+ )
17
+ from fairseq import utils
18
+ from fairseq.data import (
19
+ ConcatDataset,
20
+ Dictionary,
21
+ LanguagePairDataset,
22
+ ResamplingDataset,
23
+ TransformEosConcatLangPairDataset,
24
+ TransformEosLangPairDataset,
25
+ data_utils,
26
+ indexed_dataset,
27
+ )
28
+ from fairseq.data.encoders.utils import get_whole_word_mask
29
+ from fairseq.tasks import register_task
30
+ from fairseq.tasks.translation import TranslationTask
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ def gen_whole_word_mask(args, dictionary):
36
+ def is_beginning_of_word(i):
37
+ if i < dictionary.nspecial:
38
+ # special elements are always considered beginnings
39
+ return True
40
+ tok = dictionary[i]
41
+ if tok.startswith("madeupword"):
42
+ return True
43
+
44
+ if tok in ["<unk>", "<s>", "</s>", "<pad>"]:
45
+ return True
46
+ return tok.startswith("\u2581")
47
+
48
+ if args.use_mask_whole_words:
49
+ mask_whole_words = torch.ByteTensor(
50
+ list(map(is_beginning_of_word, range(len(dictionary))))
51
+ )
52
+ else:
53
+ # without a bpe model for phoneme tokens, every token will be treated as a word-leading token
54
+ return get_whole_word_mask(args, dictionary)
55
+ return mask_whole_words
56
+
57
+
58
+ @register_task("paired_denoising")
59
+ class PairedDenoisingTask(TranslationTask):
60
+
61
+ LANG_TAG_TEMPLATE = "<lang:{}>" # Tag for language (target)
62
+
63
+ @staticmethod
64
+ def add_args(parser):
65
+ TranslationTask.add_args(parser)
66
+ # bart setting
67
+ parser.add_argument(
68
+ "--mask",
69
+ default=0.0,
70
+ type=float,
71
+ help="fraction of words/subwords that will be masked",
72
+ )
73
+ parser.add_argument(
74
+ "--mask-random",
75
+ default=0.0,
76
+ type=float,
77
+ help="instead of using [MASK], use random token this often",
78
+ )
79
+ parser.add_argument(
80
+ "--insert",
81
+ default=0.0,
82
+ type=float,
83
+ help="insert this percentage of additional random tokens",
84
+ )
85
+ parser.add_argument(
86
+ "--poisson-lambda",
87
+ default=3.0,
88
+ type=float,
89
+ help="randomly shuffle sentences for this proportion of inputs",
90
+ )
91
+ parser.add_argument(
92
+ "--mask-length",
93
+ default="span-poisson",
94
+ type=str,
95
+ choices=["subword", "word", "span-poisson"],
96
+ help="mask length to choose",
97
+ )
98
+ parser.add_argument(
99
+ "--replace-length",
100
+ default=1,
101
+ type=int,
102
+ help="when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)",
103
+ )
104
+
105
+ # multi-lingual
106
+ parser.add_argument(
107
+ "--multilang-sampling-alpha",
108
+ type=float,
109
+ default=1.0,
110
+ help="smoothing alpha for sample ratios across multiple datasets",
111
+ )
112
+ parser.add_argument(
113
+ "--lang-pairs",
114
+ default="",
115
+ metavar="PAIRS",
116
+ help="comma-separated list of language pairs (in training order): phnen-en,phnfr-fr,phnit-it. Do masking",
117
+ )
118
+ parser.add_argument(
119
+ "--lang-pairs-bitext",
120
+ default="",
121
+ metavar="PAIRS",
122
+ help="comma-separated list of language pairs (in training order): en-de,en-fr,de-fr. No masking",
123
+ )
124
+ parser.add_argument("--add-src-lang-token", default=False, action="store_true")
125
+ parser.add_argument("--add-tgt-lang-token", default=False, action="store_true")
126
+ parser.add_argument(
127
+ "--no-whole-word-mask-langs",
128
+ type=str,
129
+ default="",
130
+ metavar="N",
131
+ help="languages without spacing between words dont support whole word masking",
132
+ )
133
+ parser.add_argument(
134
+ "--use-mask-whole-words", default=False, action="store_true"
135
+ )
136
+
137
+ @classmethod
138
+ def setup_task(cls, args, **kwargs):
139
+ """Setup the task."""
140
+ paths = args.data.split(":")
141
+ assert len(paths) > 0
142
+ src_dict = Dictionary.load(
143
+ os.path.join(paths[0], "src_dict.txt")
144
+ ) # assume all languages share a source dictionary
145
+ tgt_dict = Dictionary.load(
146
+ os.path.join(paths[0], "tgt_dict.txt")
147
+ ) # assume all languages share a target dictionary
148
+
149
+ lang_pairs = args.lang_pairs + "," + args.lang_pairs_bitext
150
+ lang_pairs = re.sub(",$", "", re.sub("^,", "", lang_pairs))
151
+ src_langs = [lp.split("-")[0] for lp in lang_pairs.split(",")]
152
+ tgt_langs = [lp.split("-")[1] for lp in lang_pairs.split(",")]
153
+
154
+ if args.add_src_lang_token:
155
+ for lang in src_langs:
156
+ assert (
157
+ src_dict.index(PairedDenoisingTask.LANG_TAG_TEMPLATE.format(lang))
158
+ != src_dict.unk()
159
+ )
160
+ if args.add_tgt_lang_token:
161
+ for lang in tgt_langs:
162
+ assert (
163
+ tgt_dict.index(PairedDenoisingTask.LANG_TAG_TEMPLATE.format(lang))
164
+ != tgt_dict.unk()
165
+ )
166
+
167
+ logger.info("source dictionary: {} types".format(len(src_dict)))
168
+ logger.info("target dictionary: {} types".format(len(tgt_dict)))
169
+ if not hasattr(args, "shuffle_instance"):
170
+ args.shuffle_instance = False
171
+ return cls(args, src_dict, tgt_dict)
172
+
173
+ def __init__(self, args, src_dict, tgt_dict):
174
+ super().__init__(args, src_dict, tgt_dict)
175
+ # check mask token
176
+ self.mask_idx = self.src_dict.index("<mask>")
177
+ assert self.mask_idx != self.src_dict.unk()
178
+ self.lang_pairs = args.lang_pairs
179
+ self.lang_pairs_bitext = args.lang_pairs_bitext
180
+ self.args = args
181
+
182
+ @classmethod
183
+ def language_pair_denoising_dataset(
184
+ cls,
185
+ data_path,
186
+ do_mask,
187
+ split,
188
+ src,
189
+ src_dict,
190
+ tgt,
191
+ tgt_dict,
192
+ mask_idx,
193
+ mask_whole_words,
194
+ seed,
195
+ args,
196
+ dataset_impl,
197
+ combine=False,
198
+ left_pad_source=True,
199
+ left_pad_target=False,
200
+ max_source_positions=1024,
201
+ max_target_positions=1024,
202
+ shuffle=True,
203
+ src_lang_id=None,
204
+ tgt_lang_id=None,
205
+ ):
206
+ def split_exists(split, src, tgt, lang, data_path):
207
+ filename = os.path.join(
208
+ data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)
209
+ )
210
+ return indexed_dataset.dataset_exists(filename, impl=dataset_impl)
211
+
212
+ src_datasets = []
213
+ tgt_datasets = []
214
+
215
+ for k in itertools.count():
216
+ split_k = split + (str(k) if k > 0 else "")
217
+
218
+ # infer langcode
219
+ if split_exists(split_k, src, tgt, src, data_path):
220
+ prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
221
+ elif split_exists(split_k, tgt, src, src, data_path):
222
+ prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src))
223
+ else:
224
+ if k > 0:
225
+ break
226
+ else:
227
+ raise FileNotFoundError(
228
+ "Dataset not found: {} ({})".format(split, data_path)
229
+ )
230
+
231
+ src_dataset = data_utils.load_indexed_dataset(
232
+ prefix + src, src_dict, dataset_impl
233
+ )
234
+ src_datasets.append(src_dataset)
235
+
236
+ tgt_dataset = data_utils.load_indexed_dataset(
237
+ prefix + tgt, tgt_dict, dataset_impl
238
+ )
239
+ if tgt_dataset is not None:
240
+ tgt_datasets.append(tgt_dataset)
241
+
242
+ logger.info(
243
+ "{} {} {}-{} {} examples".format(
244
+ data_path, split_k, src, tgt, len(src_datasets[-1])
245
+ )
246
+ )
247
+
248
+ if not combine:
249
+ break
250
+
251
+ assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0
252
+
253
+ if len(src_datasets) == 1:
254
+ src_dataset = src_datasets[0]
255
+ tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
256
+ else:
257
+ sample_ratios = [1] * len(src_datasets)
258
+ src_dataset = ConcatDataset(src_datasets, sample_ratios)
259
+ if len(tgt_datasets) > 0:
260
+ tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
261
+ else:
262
+ tgt_dataset = None
263
+
264
+ eos = None
265
+
266
+ tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
267
+ if not do_mask:
268
+ return LanguagePairDataset(
269
+ src_dataset,
270
+ src_dataset.sizes,
271
+ src_dict,
272
+ tgt_dataset,
273
+ tgt_dataset_sizes,
274
+ tgt_dict,
275
+ left_pad_source=left_pad_source,
276
+ left_pad_target=left_pad_target,
277
+ eos=eos,
278
+ shuffle=shuffle,
279
+ src_lang_id=src_lang_id,
280
+ tgt_lang_id=tgt_lang_id,
281
+ )
282
+
283
+ return LanguagePairDenoisingDataset(
284
+ src_dataset,
285
+ src_dataset.sizes,
286
+ src_dict,
287
+ tgt_dataset,
288
+ tgt_dataset_sizes,
289
+ tgt_dict,
290
+ mask_idx,
291
+ mask_whole_words,
292
+ seed,
293
+ args,
294
+ left_pad_source=left_pad_source,
295
+ left_pad_target=left_pad_target,
296
+ eos=eos,
297
+ shuffle=shuffle,
298
+ src_lang_id=src_lang_id,
299
+ tgt_lang_id=tgt_lang_id,
300
+ )
301
+
302
+ def _get_sample_prob(self, dataset_lens):
303
+ """
304
+ Get smoothed sampling probability by language. This helps low-resource
305
+ languages by upsampling them.
306
+ """
307
+ prob = dataset_lens / dataset_lens.sum()
308
+ smoothed_prob = prob ** self.args.multilang_sampling_alpha
309
+ smoothed_prob = smoothed_prob / smoothed_prob.sum()
310
+ return smoothed_prob
311
+
312
+ def resample_datasets(self, lang_datasets, lang_pairs_all, epoch):
313
+ # For train subset, additionally up or down sample languages.
314
+ if self.args.multilang_sampling_alpha == 1.0:
315
+ return lang_datasets
316
+
317
+ dataset_lengths = np.array(
318
+ [len(d) for d in lang_datasets],
319
+ dtype=float,
320
+ )
321
+ sample_probs = self._get_sample_prob(dataset_lengths)
322
+ logger.info(
323
+ "Sample probability by language pair: {}".format(
324
+ {
325
+ lp: "{0:.4f}".format(sample_probs[id])
326
+ for id, lp in enumerate(lang_pairs_all)
327
+ }
328
+ )
329
+ )
330
+ size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
331
+ logger.info(
332
+ "Up/Down Sampling ratio by language: {}".format(
333
+ {
334
+ lp: "{0:.2f}".format(size_ratio[id])
335
+ for id, lp in enumerate(lang_pairs_all)
336
+ }
337
+ )
338
+ )
339
+
340
+ resampled_lang_datasets = [
341
+ ResamplingDataset(
342
+ lang_datasets[i],
343
+ size_ratio=size_ratio[i],
344
+ seed=self.args.seed,
345
+ epoch=epoch,
346
+ replace=size_ratio[i] >= 1.0,
347
+ )
348
+ for i, d in enumerate(lang_datasets)
349
+ ]
350
+ return resampled_lang_datasets
351
+
352
+ def load_dataset_only(
353
+ self, split, lang_pairs, do_mask=True, epoch=1, combine=False
354
+ ):
355
+ paths = utils.split_paths(self.args.data)
356
+ assert len(paths) > 0
357
+ data_path = paths[(epoch - 1) % len(paths)]
358
+
359
+ # TODO unk token will be considered as first word too, though it might be an unknown phoneme within a word
360
+ # get_whole_word_mask returns a tensor (size V x 1) indicating whether each token is a word-start token
361
+ mask_whole_src_words = gen_whole_word_mask(self.args, self.src_dict)
362
+ language_without_segmentations = self.args.no_whole_word_mask_langs.split(",")
363
+ lang_datasets = []
364
+ eos_bos = []
365
+ lang_pairs = lang_pairs.split(",") if lang_pairs != "" else []
366
+ assert len(lang_pairs) > 0
367
+ for lp in lang_pairs:
368
+ src, tgt = lp.split("-")
369
+ lang_mask_whole_src_words = (
370
+ mask_whole_src_words
371
+ if src not in language_without_segmentations
372
+ else None
373
+ )
374
+
375
+ end_token = (
376
+ self.source_dictionary.index(
377
+ PairedDenoisingTask.LANG_TAG_TEMPLATE.format(src)
378
+ )
379
+ if self.args.add_src_lang_token
380
+ else None
381
+ )
382
+ bos_token = (
383
+ self.target_dictionary.index(
384
+ PairedDenoisingTask.LANG_TAG_TEMPLATE.format(tgt)
385
+ )
386
+ if self.args.add_tgt_lang_token
387
+ else None
388
+ )
389
+ src_lang_id = None
390
+
391
+ if self.args.add_src_lang_token or self.args.add_tgt_lang_token:
392
+ eos_bos.append((end_token, bos_token))
393
+
394
+ dataset = PairedDenoisingTask.language_pair_denoising_dataset(
395
+ data_path,
396
+ do_mask,
397
+ split,
398
+ src,
399
+ self.source_dictionary,
400
+ tgt,
401
+ self.target_dictionary,
402
+ self.mask_idx,
403
+ lang_mask_whole_src_words,
404
+ self.args.seed,
405
+ self.args,
406
+ self.args.dataset_impl,
407
+ combine=combine,
408
+ left_pad_source=utils.eval_bool(self.args.left_pad_source),
409
+ left_pad_target=utils.eval_bool(self.args.left_pad_target),
410
+ max_source_positions=self.args.max_source_positions,
411
+ max_target_positions=self.args.max_target_positions,
412
+ src_lang_id=src_lang_id,
413
+ )
414
+
415
+ lang_datasets.append(dataset)
416
+
417
+ if len(lang_datasets) == 0:
418
+ return
419
+ elif len(lang_datasets) == 1:
420
+ dataset = lang_datasets[0]
421
+ if self.args.add_src_lang_token or self.args.add_tgt_lang_token:
422
+ end_token, bos_token = eos_bos[0]
423
+ dataset = TransformEosLangPairDataset(
424
+ dataset,
425
+ src_eos=self.source_dictionary.eos(),
426
+ new_src_eos=end_token,
427
+ tgt_bos=self.target_dictionary.eos(),
428
+ new_tgt_bos=bos_token,
429
+ )
430
+ else:
431
+ end_tokens = [item[0] for item in eos_bos if item[0] is not None]
432
+ bos_tokens = [item[1] for item in eos_bos if item[1] is not None]
433
+ lang_datasets = self.resample_datasets(lang_datasets, lang_pairs, epoch)
434
+ dataset = TransformEosConcatLangPairDataset(
435
+ lang_datasets,
436
+ self.source_dictionary.eos(),
437
+ self.target_dictionary.eos(),
438
+ new_src_eos=end_tokens,
439
+ new_tgt_bos=bos_tokens,
440
+ )
441
+ return dataset
442
+
443
+ # split in (train, valid, test, ...)
444
+ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
445
+ self.datasets[split] = self.load_dataset_only(
446
+ split, self.lang_pairs, epoch=epoch, combine=combine
447
+ )
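The `--multilang-sampling-alpha` smoothing implemented by `_get_sample_prob` and `resample_datasets` above is a temperature-style re-weighting of dataset sizes. A standalone numpy sketch with two toy language pairs (sizes and alpha are made up):

import numpy as np

dataset_lengths = np.array([1_000_000.0, 10_000.0])  # toy sizes: high- vs low-resource pair
alpha = 0.3  # --multilang-sampling-alpha < 1 flattens the distribution

prob = dataset_lengths / dataset_lengths.sum()
smoothed = prob ** alpha
smoothed = smoothed / smoothed.sum()
size_ratio = (smoothed * dataset_lengths.sum()) / dataset_lengths

print(np.round(smoothed, 3))    # sampling probability per pair: [0.799 0.201]
print(np.round(size_ratio, 2))  # up/down-sampling ratio fed to ResamplingDataset: [ 0.81 20.28]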
fairseq/examples/speech_text_joint_to_text/tasks/speech_text_denoise_pretrain.py ADDED
@@ -0,0 +1,654 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import logging
6
+ import os
7
+ import re
8
+ from argparse import Namespace
9
+ from pathlib import Path
10
+
11
+ from fairseq.data import ConcatDataset, Dictionary, encoders
12
+ from fairseq.data.audio.multi_modality_dataset import (
13
+ FileAudioDatasetWrapper,
14
+ ModalityDatasetItem,
15
+ MultiModalityDataset,
16
+ )
17
+ from fairseq.data.audio.speech_to_text_joint_dataset import (
18
+ S2TJointDataConfig,
19
+ SpeechToTextJointDatasetCreator,
20
+ )
21
+ from fairseq.data.iterators import GroupedEpochBatchIterator
22
+ from fairseq.tasks import register_task
23
+
24
+ from .pair_denoising import PairedDenoisingTask
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ @register_task("speech_text_joint_denoising")
30
+ class SpeechTextJointDenoisingPreTask(PairedDenoisingTask):
31
+ """
32
+ Joint denoising training task for speech and text.
33
+ """
34
+
35
+ SIL_TOKEN = "sil"
36
+
37
+ @classmethod
38
+ def add_args(cls, parser):
39
+ PairedDenoisingTask.add_args(parser)
40
+ # set max tokens and position
41
+ parser.add_argument(
42
+ "--max-text-tokens",
43
+ type=int,
44
+ metavar="N",
45
+ default=1024,
46
+ help="maximum samples for encoder text input ",
47
+ )
48
+ parser.add_argument(
49
+ "--max-speech-tokens",
50
+ type=int,
51
+ metavar="N",
52
+ default=50000,
53
+ help="maximum samples for encoder speech input ",
54
+ )
55
+ parser.add_argument(
56
+ "--max-speech-positions",
57
+ type=int,
58
+ metavar="N",
59
+ default=400,
60
+ help="maximum tokens for per encoder text input ",
61
+ )
62
+
63
+ parser.add_argument(
64
+ "--max-sample-size",
65
+ type=int,
66
+ metavar="N",
67
+ default=32000,
68
+ help="max sample size to crop to for batching (unsupervised speech) ",
69
+ )
70
+ parser.add_argument(
71
+ "--min-sample-size",
72
+ type=int,
73
+ metavar="N",
74
+ default=4000,
75
+ help="min sample size to crop to for batching (unsupervised speech) ",
76
+ )
77
+
78
+ # set mini-batch ratio for different modalities/subtasks
79
+ # s2p
80
+ parser.add_argument(
81
+ "--supervised-speech-sample-ratio",
82
+ default="1",
83
+ type=str,
84
+ metavar="N",
85
+ help="Multiple Ratio for speech dataset with transcripts ",
86
+ )
87
+ # s2t
88
+ parser.add_argument(
89
+ "--supervised-speech-s2s-sample-ratio",
90
+ default="1",
91
+ type=str,
92
+ metavar="N",
93
+ help="Multiple Ratio for speech dataset with transcripts ",
94
+ )
95
+ # ssl
96
+ parser.add_argument(
97
+ "--unsupervised-speech-sample-ratio",
98
+ default="1",
99
+ type=str,
100
+ metavar="N",
101
+ help="Multiple Ratio for speech dataset without transcripts ",
102
+ )
103
+ # t2t with monolingual data (masking)
104
+ parser.add_argument(
105
+ "--text-sample-ratio",
106
+ default="1",
107
+ type=str,
108
+ metavar="N",
109
+ help="Multiple Ratio for text set ",
110
+ )
111
+ # t2t with parallel data (no masking)
112
+ parser.add_argument(
113
+ "--bitext-sample-ratio",
114
+ default="1",
115
+ type=str,
116
+ metavar="N",
117
+ help="Multiple Ratio for text set (bitext) ",
118
+ )
119
+ # train_subset is "train", "valid", etc.
120
+ # parallel data is loaded according to the lang_pairs and lang_pairs_bitext strings from args.data
121
+ # (un)supervised speech is loaded from args.(un)sup_speech_{train,valid}_subset
122
+ parser.add_argument(
123
+ "--sup-speech-data", default="", help="path to supervised speech data"
124
+ )
125
+ parser.add_argument(
126
+ "--sup-speech-train-subset",
127
+ default="",
128
+ help="supervised speech training subsets",
129
+ )
130
+ parser.add_argument(
131
+ "--sup-speech-valid-subset",
132
+ default="",
133
+ help="supervised speech validation subsets",
134
+ )
135
+ parser.add_argument(
136
+ "--config-yaml",
137
+ default="config.yaml",
138
+ help="supervised speech configuration yaml file",
139
+ )
140
+ parser.add_argument(
141
+ "--sup-speech-s2s-data", default="", help="path to supervised speech data"
142
+ )
143
+ parser.add_argument(
144
+ "--sup-speech-s2s-train-subset",
145
+ default="",
146
+ help="supervised speech training subsets",
147
+ )
148
+ parser.add_argument(
149
+ "--sup-speech-s2s-valid-subset",
150
+ default="",
151
+ help="supervised speech validation subsets",
152
+ )
153
+ parser.add_argument(
154
+ "--config-s2s-yaml",
155
+ default="config.yaml",
156
+ help="supervised speech configuration yaml file",
157
+ )
158
+ parser.add_argument(
159
+ "--unsup-speech-train-data",
160
+ default="",
161
+ help="path to unsupervised speech training data (tsv)",
162
+ )
163
+ parser.add_argument(
164
+ "--unsup-speech-valid-data",
165
+ default="",
166
+ help="path to unsupervised speech valid data (tsv)",
167
+ )
168
+ parser.add_argument(
169
+ "--sample-rate",
170
+ type=int,
171
+ metavar="N",
172
+ default=16000,
173
+ help="input audio sampling rate",
174
+ )
175
+ parser.add_argument(
176
+ "--no-emb-update-unsup",
177
+ default=False,
178
+ action="store_true",
179
+ help="no update for output embedding during unsupervised_speech mode",
180
+ )
181
+ parser.add_argument("--same-data-update", default=False, action="store_true")
182
+
183
+ # used for sup_speech_ali
184
+ parser.add_argument(
185
+ "--use-sup-speech-ctc",
186
+ default=False,
187
+ action="store_true",
188
+ help="use speech_sup_ctc instead of speech_sup_ali",
189
+ )
190
+
191
+ @classmethod
192
+ def setup_task(cls, args, **kwargs):
193
+ """Setup the task."""
194
+ paths = args.data.split(":")
195
+ assert len(paths) > 0
196
+ src_dict = Dictionary.load(
197
+ os.path.join(paths[0], "src_dict.txt")
198
+ ) # assume all languages share a source dictionary
199
+ tgt_dict = Dictionary.load(
200
+ os.path.join(paths[0], "tgt_dict.txt")
201
+ ) # assume all languages share a target dictionary
202
+
203
+ lang_pairs = args.lang_pairs + "," + args.lang_pairs_bitext
204
+ lang_pairs = re.sub(",$", "", re.sub("^,", "", lang_pairs))
205
+ if lang_pairs != "":
206
+ src_langs = [lp.split("-")[0] for lp in lang_pairs.split(",")]
207
+ tgt_langs = [lp.split("-")[1] for lp in lang_pairs.split(",")]
208
+ else:
209
+ src_langs = []
210
+ tgt_langs = []
211
+
212
+ if args.add_src_lang_token:
213
+ for lang in src_langs:
214
+ assert (
215
+ src_dict.index(PairedDenoisingTask.LANG_TAG_TEMPLATE.format(lang))
216
+ != src_dict.unk()
217
+ )
218
+ if args.add_tgt_lang_token:
219
+ for lang in tgt_langs:
220
+ assert (
221
+ tgt_dict.index(PairedDenoisingTask.LANG_TAG_TEMPLATE.format(lang))
222
+ != tgt_dict.unk()
223
+ )
224
+
225
+ logger.info("source dictionary: {} types".format(len(src_dict)))
226
+ logger.info("target dictionary: {} types".format(len(tgt_dict)))
227
+ if not hasattr(args, "shuffle_instance"):
228
+ args.shuffle_instance = False
229
+ return cls(args, src_dict, tgt_dict)
230
+
231
+ def __init__(self, args, src_dict, tgt_dict):
232
+ super().__init__(args, src_dict, tgt_dict)
233
+ self.data_cfg = S2TJointDataConfig(
234
+ Path(args.sup_speech_data) / args.config_yaml
235
+ )
236
+ logger.info(
237
+ f"load supervised speech data configure from {Path(args.sup_speech_data) / args.config_yaml}"
238
+ )
239
+ self.data_s2s_cfg = (
240
+ S2TJointDataConfig(Path(args.sup_speech_s2s_data) / args.config_s2s_yaml)
241
+ if args.sup_speech_s2s_train_subset != ""
242
+ else None
243
+ )
244
+ if self.data_s2s_cfg is not None:
245
+ logger.info(
246
+ f"load supervised sequece to sequence speech data configure from {Path(args.sup_speech_s2s_data) / args.config_yaml}"
247
+ )
248
+
249
+ def parse_data_ratio(sample_ratio):
250
+ ratios = sample_ratio.split(",")
251
+ if len(ratios) == 1:
252
+ return [float(ratios[0])]
253
+ epoch_ratios = []
254
+ for item in ratios:
255
+ ep, r = item.split(":")
256
+ ep = int(ep)
257
+ r = float(r)
258
+ assert ep > 0 # epoch is 1 based
259
+ assert ep >= len(epoch_ratios)
260
+
261
+ if len(epoch_ratios) == 0:
262
+ epoch_ratios.append(
263
+ r
264
+ ) # epoch_ratios[0] is not used, but we still set it to the first value to keep things simple.
265
+ while len(epoch_ratios) < ep:
266
+ epoch_ratios.append(epoch_ratios[-1])
267
+ epoch_ratios.append(r)
268
+ return epoch_ratios
269
+
270
+ self.sup_ratio = parse_data_ratio(args.supervised_speech_sample_ratio)
271
+ self.sup_s2s_ratio = parse_data_ratio(args.supervised_speech_s2s_sample_ratio)
272
+ self.text_ratio = parse_data_ratio(args.text_sample_ratio)
273
+ self.bitext_ratio = parse_data_ratio(args.bitext_sample_ratio)
274
+ self.unsup_ratio = parse_data_ratio(args.unsupervised_speech_sample_ratio)
275
+ self.sample_mode = None
276
+
277
+ def build_model(self, args):
278
+ args.input_feat_per_channel = self.data_cfg.input_feat_per_channel
279
+ args.input_channels = self.data_cfg.input_channels
280
+ return super().build_model(args)
281
+
282
+ def build_tokenizer(self, data_cfg, msg=""):
283
+ logger.info(f"pre-tokenizer {msg}: {data_cfg.pre_tokenizer}")
284
+ return encoders.build_tokenizer(Namespace(**data_cfg.pre_tokenizer))
285
+
286
+ def build_bpe(self, data_cfg, msg=""):
287
+ logger.info(f"tokenizer {msg}: {data_cfg.bpe_tokenizer}")
288
+ return encoders.build_bpe(Namespace(**data_cfg.bpe_tokenizer))
289
+
290
+ @classmethod
291
+ def resolve_data_type(cls, split, use_sup_speech_ctc):
292
+ if len(split.split("_")) == 1:
293
+ # default case, train or valid
294
+ is_train = split
295
+ dtype = "text"
296
+ else:
297
+ is_train, dtype = split.split("_", 1)
298
+ is_train = True if is_train == "train" else False
299
+ if dtype == "sup_speech":
300
+ dtype = "sup_speech_ctc" if use_sup_speech_ctc else "sup_speech_ali"
301
+ assert dtype in (
302
+ "text",
303
+ "bitext",
304
+ "sup_speech_ali",
305
+ "sup_speech_s2s",
306
+ "unsup_speech",
307
+ "sup_speech_ctc",
308
+ ), f"failed resolving {split} (it resulted into: {dtype} ; is_train={is_train})"
309
+ return is_train, dtype
310
+
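# Editorial illustration (not part of the uploaded file): how split names flow through
# resolve_data_type above, e.g. with use_sup_speech_ctc=False:
#   "train_sup_speech"     -> (True,  "sup_speech_ali")   # "sup_speech_ctc" when --use-sup-speech-ctc is set
#   "train_bitext"         -> (True,  "bitext")
#   "valid_sup_speech_s2s" -> (False, "sup_speech_s2s")
#   "valid_unsup_speech"   -> (False, "unsup_speech")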
311
+ def create_modalitydatasetitem(self, dtype, dataset):
312
+ dsitem = None
313
+ if dtype in ("text", "bitext"):
314
+ dsitem = ModalityDatasetItem(
315
+ dtype,
316
+ dataset,
317
+ (self.args.max_source_positions, self.args.max_target_positions),
318
+ self.args.max_text_tokens,
319
+ self.args.batch_size,
320
+ )
321
+ elif dtype in ("sup_speech_ctc", "sup_speech_ali", "sup_speech_s2s"):
322
+ dsitem = ModalityDatasetItem(
323
+ dtype,
324
+ dataset,
325
+ (self.args.max_speech_positions, self.args.max_target_positions),
326
+ self.args.max_speech_tokens,
327
+ self.args.batch_size,
328
+ )
329
+ elif dtype == "unsup_speech":
330
+ dsitem = ModalityDatasetItem(
331
+ dtype, dataset, 1e8, self.args.max_speech_tokens, self.args.batch_size
332
+ )
333
+ else:
334
+ raise ValueError(f"{dtype} is not supported")
335
+ return dsitem
336
+
337
+ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
338
+ def _get_sup_src_tgt_dict(src_dict, tgt_dict, use_s2s_sup_decoder):
339
+ if use_s2s_sup_decoder:
340
+ return None, tgt_dict
341
+ # use src_dict as tgt_dict here, since we use the source dictionary as the target for force alignment
342
+ return None, src_dict
343
+
344
+ is_train, dtype = self.resolve_data_type(split, self.args.use_sup_speech_ctc)
345
+
346
+ # Note we use --add-tgt-lang-token instead of data_cfg.prepend_tgt_lang_tag_no_change to set target language tag in the text dataset
347
+ # Verify add_tgt_lang_token and prepend_tgt_lang_tag_no_change are same
348
+
349
+ # Note we use --multilang-sampling-alpha instead of data_cfg.sampling_text_alpha to set text data sampling
350
+ if is_train:
351
+ msets = []
352
+ # train split, load everything into one
353
+ if self.lang_pairs != "":
354
+ text_dataset = self.load_dataset_only(
355
+ "train", self.lang_pairs, epoch=epoch, combine=combine
356
+ )
357
+ dsitem = self.create_modalitydatasetitem("text", text_dataset)
358
+ msets.append(dsitem)
359
+ if self.lang_pairs_bitext != "": # load bitext
360
+ bitext_dataset = self.load_dataset_only(
361
+ "train_bitext",
362
+ self.lang_pairs_bitext,
363
+ do_mask=False,
364
+ epoch=epoch,
365
+ combine=combine,
366
+ )
367
+ dsitem = self.create_modalitydatasetitem("bitext", bitext_dataset)
368
+ msets.append(dsitem)
369
+ if self.args.sup_speech_train_subset != "":
370
+ pre_tokenizer = self.build_tokenizer(self.data_cfg)
371
+ bpe_tokenizer = self.build_bpe(self.data_cfg)
372
+
373
+ append_eos = True
374
+ sup_speech_type = "sup_speech_ali"
375
+ if self.args.use_sup_speech_ctc:
376
+ # CTC mode
377
+ sup_speech_type = "sup_speech_ctc"
378
+ append_eos = False # CTC doesn't need eos in the target
379
+
380
+ src_dict, tgt_dict = _get_sup_src_tgt_dict(
381
+ self.src_dict, self.tgt_dict, False
382
+ )
383
+ sup_speech_dataset = SpeechToTextJointDatasetCreator.from_tsv(
384
+ self.args.sup_speech_data,
385
+ self.data_cfg,
386
+ self.args.sup_speech_train_subset,
387
+ tgt_dict=tgt_dict,
388
+ src_dict=src_dict,
389
+ pre_tokenizer=pre_tokenizer,
390
+ bpe_tokenizer=bpe_tokenizer,
391
+ src_pre_tokenizer=None,
392
+ src_bpe_tokenizer=None,
393
+ is_train_split=is_train,
394
+ epoch=epoch,
395
+ seed=self.args.seed,
396
+ append_eos=append_eos,
397
+ )
398
+ dsitem = self.create_modalitydatasetitem(
399
+ sup_speech_type, sup_speech_dataset
400
+ )
401
+ msets.append(dsitem)
402
+
403
+ if self.args.sup_speech_s2s_train_subset != "":
404
+ pre_tokenizer = self.build_tokenizer(self.data_s2s_cfg, msg="(s2s)")
405
+ bpe_tokenizer = self.build_bpe(self.data_s2s_cfg, msg="(s2s)")
406
+
407
+ # make sure self.data_cfg.prepend_tgt_lang_tag_no_change == self.args.add_tgt_lang_token
408
+ src_dict, tgt_dict = _get_sup_src_tgt_dict(
409
+ self.src_dict, self.tgt_dict, True
410
+ )
411
+ sup_speech_s2s_dataset = SpeechToTextJointDatasetCreator.from_tsv(
412
+ self.args.sup_speech_s2s_data,
413
+ self.data_s2s_cfg,
414
+ self.args.sup_speech_s2s_train_subset,
415
+ tgt_dict=tgt_dict,
416
+ src_dict=src_dict,
417
+ pre_tokenizer=pre_tokenizer,
418
+ bpe_tokenizer=bpe_tokenizer,
419
+ src_pre_tokenizer=None,
420
+ src_bpe_tokenizer=None,
421
+ is_train_split=is_train,
422
+ epoch=epoch,
423
+ seed=self.args.seed,
424
+ )
425
+ dsitem = self.create_modalitydatasetitem(
426
+ "sup_speech_s2s", sup_speech_s2s_dataset
427
+ )
428
+ msets.append(dsitem)
429
+ if self.args.unsup_speech_train_data != "":
430
+ unsup_speech_dataset = FileAudioDatasetWrapper(
431
+ self.args.unsup_speech_train_data,
432
+ self.args.sample_rate,
433
+ max_sample_size=self.args.max_sample_size,
434
+ min_sample_size=self.args.min_sample_size,
435
+ normalize=False,
436
+ )
437
+ dsitem = self.create_modalitydatasetitem(
438
+ "unsup_speech", unsup_speech_dataset
439
+ )
440
+ msets.append(dsitem)
441
+
442
+ pre_train_dataset = MultiModalityDataset(msets)
443
+ self.datasets[split] = pre_train_dataset
444
+ else: # validation split, load them for each type of data
445
+ if dtype == "text":
446
+ text_dataset = self.load_dataset_only(
447
+ split, self.lang_pairs, epoch=epoch, combine=combine
448
+ )
449
+ dsitem = self.create_modalitydatasetitem("text", text_dataset)
450
+ self.datasets[split] = MultiModalityDataset([dsitem])
451
+ elif dtype == "bitext":
452
+ bitext_dataset = self.load_dataset_only(
453
+ split,
454
+ self.lang_pairs_bitext,
455
+ do_mask=False,
456
+ epoch=epoch,
457
+ combine=combine,
458
+ )
459
+ dsitem = self.create_modalitydatasetitem("bitext", bitext_dataset)
460
+ self.datasets[split] = MultiModalityDataset([dsitem])
461
+
462
+ elif dtype in ("sup_speech_ctc", "sup_speech_ali"):
463
+ assert self.args.sup_speech_valid_subset != ""
464
+ pre_tokenizer = self.build_tokenizer(self.data_cfg)
465
+ bpe_tokenizer = self.build_bpe(self.data_cfg)
466
+ append_eos = True
467
+ if dtype == "sup_speech_ctc":
468
+ # CTC mode
469
+ append_eos = False # CTC doesn't need eos
470
+ assert self.args.use_sup_speech_ctc
471
+
472
+ datasets = []
473
+ for split_name in self.args.sup_speech_valid_subset.split(","):
474
+ src_dict, tgt_dict = _get_sup_src_tgt_dict(
475
+ self.src_dict, self.tgt_dict, False
476
+ )
477
+ datasets.append(
478
+ SpeechToTextJointDatasetCreator.from_tsv(
479
+ self.args.sup_speech_data,
480
+ self.data_cfg,
481
+ split_name,
482
+ tgt_dict=tgt_dict,
483
+ src_dict=src_dict,
484
+ pre_tokenizer=pre_tokenizer,
485
+ bpe_tokenizer=bpe_tokenizer,
486
+ src_pre_tokenizer=None,
487
+ src_bpe_tokenizer=None,
488
+ is_train_split=is_train,
489
+ epoch=epoch,
490
+ seed=self.args.seed,
491
+ append_eos=append_eos,
492
+ )
493
+ )
494
+
495
+ dset = datasets[0] if len(datasets) == 1 else ConcatDataset(datasets)
496
+ dsitem = self.create_modalitydatasetitem(dtype, dset)
497
+ self.datasets[split] = MultiModalityDataset([dsitem])
498
+
499
+ elif dtype == "sup_speech_s2s":
500
+ assert self.args.sup_speech_s2s_valid_subset != ""
501
+ pre_tokenizer = self.build_tokenizer(self.data_s2s_cfg)
502
+ bpe_tokenizer = self.build_bpe(self.data_s2s_cfg)
503
+ datasets = []
504
+ for split_name in self.args.sup_speech_s2s_valid_subset.split(","):
505
+ src_dict, tgt_dict = _get_sup_src_tgt_dict(
506
+ self.src_dict, self.tgt_dict, True
507
+ )
508
+ datasets.append(
509
+ SpeechToTextJointDatasetCreator.from_tsv(
510
+ self.args.sup_speech_s2s_data,
511
+ self.data_s2s_cfg,
512
+ split_name,
513
+ tgt_dict=tgt_dict,
514
+ src_dict=src_dict,
515
+ pre_tokenizer=pre_tokenizer,
516
+ bpe_tokenizer=bpe_tokenizer,
517
+ src_pre_tokenizer=None,
518
+ src_bpe_tokenizer=None,
519
+ is_train_split=is_train,
520
+ epoch=epoch,
521
+ seed=self.args.seed,
522
+ )
523
+ )
524
+
525
+ dset = datasets[0] if len(datasets) == 1 else ConcatDataset(datasets)
526
+ dsitem = self.create_modalitydatasetitem("sup_speech_s2s", dset)
527
+ self.datasets[split] = MultiModalityDataset([dsitem])
528
+ elif dtype == "unsup_speech":
529
+ assert self.args.unsup_speech_valid_data != ""
530
+ unsup_speech_dataset = FileAudioDatasetWrapper(
531
+ self.args.unsup_speech_valid_data,
532
+ self.args.sample_rate,
533
+ max_sample_size=self.args.max_sample_size,
534
+ min_sample_size=self.args.min_sample_size,
535
+ normalize=False,
536
+ )
537
+ dsitem = self.create_modalitydatasetitem(
538
+ "unsup_speech", unsup_speech_dataset
539
+ )
540
+ self.datasets[split] = MultiModalityDataset([dsitem])
541
+ else:
542
+ raise ValueError(f"Unsupported type {dtype}")
543
+
544
+ def get_sample_ratio(self, epoch):
545
+ sup_ratio = (
546
+ self.sup_ratio[epoch] if len(self.sup_ratio) > epoch else self.sup_ratio[-1]
547
+ )
548
+ sup_s2s_ratio = (
549
+ self.sup_s2s_ratio[epoch]
550
+ if len(self.sup_s2s_ratio) > epoch
551
+ else self.sup_s2s_ratio[-1]
552
+ )
553
+ unsup_ratio = (
554
+ self.unsup_ratio[epoch]
555
+ if len(self.unsup_ratio) > epoch
556
+ else self.unsup_ratio[-1]
557
+ )
558
+ text_ratio = (
559
+ self.text_ratio[epoch]
560
+ if len(self.text_ratio) > epoch
561
+ else self.text_ratio[-1]
562
+ )
563
+ bitext_ratio = (
564
+ self.bitext_ratio[epoch]
565
+ if len(self.bitext_ratio) > epoch
566
+ else self.bitext_ratio[-1]
567
+ )
568
+ return text_ratio, bitext_ratio, sup_ratio, sup_s2s_ratio, unsup_ratio
569
+
570
+ def get_batch_iterator(
571
+ self,
572
+ dataset,
573
+ max_tokens=None,
574
+ max_sentences=None,
575
+ max_positions=None,
576
+ ignore_invalid_inputs=False,
577
+ required_batch_size_multiple=1,
578
+ seed=1,
579
+ num_shards=1,
580
+ shard_id=0,
581
+ num_workers=0,
582
+ epoch=0,
583
+ data_buffer_size=0,
584
+ disable_iterator_cache=False,
585
+ skip_remainder_batch=False,
586
+ grouped_shuffling=False,
587
+ update_epoch_batch_itr=False,
588
+ ):
589
+
590
+ assert isinstance(dataset, MultiModalityDataset)
591
+ if len(dataset.id_to_mode) == 1:
592
+ max_positions = dataset.max_positions[0]
593
+ max_tokens = dataset.max_tokens[0]
594
+ max_sentences = dataset.max_sentences[0]
595
+ return super().get_batch_iterator(
596
+ dataset,
597
+ max_tokens,
598
+ max_sentences,
599
+ max_positions,
600
+ ignore_invalid_inputs,
601
+ required_batch_size_multiple,
602
+ seed,
603
+ num_shards,
604
+ shard_id,
605
+ num_workers,
606
+ epoch,
607
+ data_buffer_size,
608
+ disable_iterator_cache,
609
+ skip_remainder_batch=skip_remainder_batch,
610
+ )
611
+
612
+ mult_ratio = []
613
+ (
614
+ text_ratio,
615
+ bitext_ratio,
616
+ sup_ratio,
617
+ sup_s2s_ratio,
618
+ unsup_ratio,
619
+ ) = self.get_sample_ratio(epoch)
620
+ for mode in dataset.id_to_mode:
621
+ if mode in ("sup_speech_ctc", "sup_speech_ali"):
622
+ mult_ratio.append(sup_ratio)
623
+ elif mode == "sup_speech_s2s":
624
+ mult_ratio.append(sup_s2s_ratio)
625
+ elif mode == "text":
626
+ mult_ratio.append(text_ratio)
627
+ elif mode == "bitext":
628
+ mult_ratio.append(bitext_ratio)
629
+ elif mode == "unsup_speech":
630
+ mult_ratio.append(unsup_ratio)
631
+
632
+ # initialize the dataset with the correct starting epoch
633
+ dataset.set_epoch(epoch)
634
+
635
+ batch_samplers = dataset.get_batch_samplers(
636
+ mult_ratio, required_batch_size_multiple, seed
637
+ )
638
+
639
+ # return a reusable, sharded iterator
640
+ epoch_iter = GroupedEpochBatchIterator(
641
+ dataset=dataset,
642
+ collate_fn=dataset.collater,
643
+ batch_samplers=batch_samplers,
644
+ seed=seed,
645
+ num_shards=num_shards,
646
+ shard_id=shard_id,
647
+ num_workers=num_workers,
648
+ epoch=epoch,
649
+ mult_rate=max(self.args.update_freq) if self.args.same_data_update else 1,
650
+ buffer_size=data_buffer_size,
651
+ skip_remainder_batch=skip_remainder_batch,
652
+ )
653
+ self.dataset_to_epoch_iter[dataset] = {} # refresh it every epoch
654
+ return epoch_iter
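The per-modality sample ratios parsed in `__init__` (`parse_data_ratio`) and consumed per epoch in `get_sample_ratio` accept either a single value or an epoch-keyed schedule such as "1:0.5,3:1.0". A standalone restatement of that parsing, to make the format concrete (the schedule string here is illustrative):

def parse_data_ratio(sample_ratio):
    ratios = sample_ratio.split(",")
    if len(ratios) == 1:
        return [float(ratios[0])]  # constant ratio for every epoch
    epoch_ratios = []
    for item in ratios:
        ep, r = item.split(":")
        ep, r = int(ep), float(r)
        if not epoch_ratios:
            epoch_ratios.append(r)  # index 0 is unused: epochs are 1-based
        while len(epoch_ratios) < ep:
            epoch_ratios.append(epoch_ratios[-1])
        epoch_ratios.append(r)
    return epoch_ratios

ratios = parse_data_ratio("1:0.5,3:1.0")
print(ratios)  # [0.5, 0.5, 0.5, 1.0]
for epoch in (1, 2, 3, 5):
    # same fallback as get_sample_ratio: reuse the last value once the schedule runs out
    print(epoch, ratios[epoch] if len(ratios) > epoch else ratios[-1])
# epoch 1 -> 0.5, epoch 2 -> 0.5, epoch 3 -> 1.0, epoch 5 -> 1.0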
fairseq/examples/speech_text_joint_to_text/tasks/speech_text_joint.py ADDED
@@ -0,0 +1,377 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import logging
6
+ import os
7
+ from argparse import Namespace
8
+ from pathlib import Path
9
+
10
+ import torch
11
+ from fairseq.data import (
12
+ encoders,
13
+ Dictionary,
14
+ ResamplingDataset,
15
+ TransformEosLangPairDataset,
16
+ ConcatDataset,
17
+ )
18
+ from fairseq.data.iterators import GroupedEpochBatchIterator
19
+ from fairseq.data.audio.multi_modality_dataset import (
20
+ MultiModalityDataset,
21
+ LangPairMaskDataset,
22
+ ModalityDatasetItem,
23
+ )
24
+ from fairseq.data.audio.speech_to_text_dataset import (
25
+ SpeechToTextDataset,
26
+ SpeechToTextDatasetCreator,
27
+ )
28
+ from fairseq.data.audio.speech_to_text_joint_dataset import (
29
+ S2TJointDataConfig,
30
+ SpeechToTextJointDatasetCreator,
31
+ )
32
+ from fairseq.tasks import register_task
33
+ from fairseq.tasks.speech_to_text import SpeechToTextTask
34
+ from fairseq.tasks.translation import load_langpair_dataset
35
+
36
+ logger = logging.getLogger(__name__)
37
+ LANG_TAG_TEMPLATE = "<lang:{}>"
38
+
39
+
40
+ @register_task("speech_text_joint_to_text")
41
+ class SpeechTextJointToTextTask(SpeechToTextTask):
42
+ """
43
+ Task for joint training speech and text to text.
44
+ """
45
+
46
+ @classmethod
47
+ def add_args(cls, parser):
48
+ """Add task-specific arguments to the parser."""
49
+ super(SpeechTextJointToTextTask, cls).add_args(parser)
50
+ ###
51
+ parser.add_argument(
52
+ "--parallel-text-data",
53
+ default="",
54
+ help="path to parallel text data directory",
55
+ )
56
+ parser.add_argument(
57
+ "--max-tokens-text",
58
+ type=int,
59
+ metavar="N",
60
+ help="maximum tokens for encoder text input ",
61
+ )
62
+ parser.add_argument(
63
+ "--max-positions-text",
64
+ type=int,
65
+ metavar="N",
66
+ default=400,
67
+ help="maximum tokens for per encoder text input ",
68
+ )
69
+ parser.add_argument(
70
+ "--langpairs",
71
+ default=None,
72
+ metavar="S",
73
+ help='language pairs for text training, separated with ","',
74
+ )
75
+ parser.add_argument(
76
+ "--speech-sample-ratio",
77
+ default=1,
78
+ type=float,
79
+ metavar="N",
80
+ help="Multiple Ratio for speech dataset with transcripts ",
81
+ )
82
+ parser.add_argument(
83
+ "--text-sample-ratio",
84
+ default=1,
85
+ type=float,
86
+ metavar="N",
87
+ help="Multiple Ratio for text set ",
88
+ )
89
+ parser.add_argument(
90
+ "--update-mix-data",
91
+ action="store_true",
92
+ help="use mixed data in one update when update-freq > 1",
93
+ )
94
+ parser.add_argument(
95
+ "--load-speech-only", action="store_true", help="load speech data only",
96
+ )
97
+ parser.add_argument(
98
+ "--mask-text-ratio",
99
+ type=float,
100
+ metavar="V",
101
+ default=0.0,
102
+ help="mask V source tokens for text only mode",
103
+ )
104
+ parser.add_argument(
105
+ "--mask-text-type",
106
+ default="random",
107
+ choices=["random", "tail"],
108
+ help="mask text typed",
109
+ )
110
+ parser.add_argument(
111
+ "--noise-token",
112
+ default="",
113
+ help="noise token for masking src text tokens if mask-text-ratio > 0",
114
+ )
115
+ parser.add_argument(
116
+ "--infer-target-lang",
117
+ default="",
118
+ metavar="S",
119
+ help="target language for inference",
120
+ )
121
+
122
+ def __init__(self, args, src_dict, tgt_dict, infer_tgt_lang_id=None):
123
+ super().__init__(args, tgt_dict)
124
+ self.src_dict = src_dict
125
+ self.data_cfg = S2TJointDataConfig(Path(args.data) / args.config_yaml)
126
+ assert self.tgt_dict.pad() == self.src_dict.pad()
127
+ assert self.tgt_dict.eos() == self.src_dict.eos()
128
+ self.speech_only = args.load_speech_only
129
+ self._infer_tgt_lang_id = infer_tgt_lang_id
130
+
131
+ @classmethod
132
+ def setup_task(cls, args, **kwargs):
133
+ """Setup the task (e.g., load dictionaries)."""
134
+ data_cfg = S2TJointDataConfig(Path(args.data) / args.config_yaml)
135
+ tgt_dict_path = Path(args.data) / data_cfg.vocab_filename
136
+ src_dict_path = Path(args.data) / data_cfg.src_vocab_filename
137
+ if (not os.path.isfile(src_dict_path)) or (not os.path.isfile(tgt_dict_path)):
138
+ raise FileNotFoundError("Dict not found: {}".format(args.data))
139
+ src_dict = Dictionary.load(src_dict_path.as_posix())
140
+ tgt_dict = Dictionary.load(tgt_dict_path.as_posix())
141
+
142
+ print("| src dictionary: {} types".format(len(src_dict)))
143
+ print("| tgt dictionary: {} types".format(len(tgt_dict)))
144
+
145
+ if args.parallel_text_data != "":
146
+ if not os.path.isabs(args.parallel_text_data):
147
+ args.parallel_text_data = os.path.join(
148
+ args.data, args.parallel_text_data
149
+ )
150
+
151
+ if args.langpairs is None:
152
+ raise Exception(
153
+ "Could not infer language pair, please provide it explicitly"
154
+ )
155
+ infer_tgt_lang_id = None
156
+ if args.infer_target_lang != "" and data_cfg.prepend_tgt_lang_tag_no_change:
157
+ tgt_lang_tag = SpeechToTextDataset.LANG_TAG_TEMPLATE.format(
158
+ args.infer_target_lang
159
+ )
160
+ infer_tgt_lang_id = tgt_dict.index(tgt_lang_tag)
161
+ assert infer_tgt_lang_id != tgt_dict.unk()
162
+ return cls(args, src_dict, tgt_dict, infer_tgt_lang_id=infer_tgt_lang_id)
163
+
164
+ def load_langpair_dataset(
165
+ self, prepend_tgt_lang_tag=False, sampling_alpha=1.0, epoch=0
166
+ ):
167
+ lang_pairs = []
168
+ text_dataset = None
169
+ split = "train"
170
+ for lp in self.args.langpairs.split(","):
171
+ src, tgt = lp.split("-")
172
+ text_dataset = load_langpair_dataset(
173
+ self.args.parallel_text_data,
174
+ split,
175
+ src,
176
+ self.src_dict,
177
+ tgt,
178
+ self.tgt_dict,
179
+ combine=True,
180
+ dataset_impl=None,
181
+ upsample_primary=1,
182
+ left_pad_source=False,
183
+ left_pad_target=False,
184
+ max_source_positions=self.args.max_positions_text,
185
+ max_target_positions=self.args.max_target_positions,
186
+ load_alignments=False,
187
+ truncate_source=False,
188
+ )
189
+ if prepend_tgt_lang_tag:
190
+ # TODO
191
+ text_dataset = TransformEosLangPairDataset(
192
+ text_dataset,
193
+ src_eos=self.src_dict.eos(),
194
+ tgt_bos=self.tgt_dict.eos(), # 'prev_output_tokens' starts with eos
195
+ new_tgt_bos=self.tgt_dict.index(LANG_TAG_TEMPLATE.format(tgt)),
196
+ )
197
+ lang_pairs.append(text_dataset)
198
+ if len(lang_pairs) > 1:
199
+ if sampling_alpha != 1.0:
200
+ size_ratios = SpeechToTextDatasetCreator.get_size_ratios(
201
+ self.args.langpairs.split(","),
202
+ [len(s) for s in lang_pairs],
203
+ alpha=sampling_alpha,
204
+ )
205
+ lang_pairs = [
206
+ ResamplingDataset(d, size_ratio=r, epoch=epoch, replace=(r >= 1.0))
207
+ for d, r in zip(lang_pairs, size_ratios)
208
+ ]
209
+ return ConcatDataset(lang_pairs)
210
+ return text_dataset
211
+
212
+ def inference_step(
213
+ self, generator, models, sample, prefix_tokens=None, constraints=None
214
+ ):
215
+ with torch.no_grad():
216
+ return generator.generate(
217
+ models,
218
+ sample,
219
+ prefix_tokens=prefix_tokens,
220
+ constraints=constraints,
221
+ bos_token=self._infer_tgt_lang_id,
222
+ )
223
+
224
+ def build_src_tokenizer(self, args):
225
+ logger.info(f"src-pre-tokenizer: {self.data_cfg.src_pre_tokenizer}")
226
+ return encoders.build_tokenizer(Namespace(**self.data_cfg.src_pre_tokenizer))
227
+
228
+ def build_src_bpe(self, args):
229
+ logger.info(f"tokenizer: {self.data_cfg.src_bpe_tokenizer}")
230
+ return encoders.build_bpe(Namespace(**self.data_cfg.src_bpe_tokenizer))
231
+
232
+ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
233
+ """Load a given dataset split.
234
+
235
+ Args:
236
+ split (str): name of the split (e.g., train, valid, test)
237
+ """
238
+ is_train_split = split.startswith("train")
239
+ pre_tokenizer = self.build_tokenizer(self.args)
240
+ bpe_tokenizer = self.build_bpe(self.args)
241
+ src_pre_tokenizer = self.build_src_tokenizer(self.args)
242
+ src_bpe_tokenizer = self.build_src_bpe(self.args)
243
+ ast_dataset = SpeechToTextJointDatasetCreator.from_tsv(
244
+ self.args.data,
245
+ self.data_cfg,
246
+ split,
247
+ self.tgt_dict,
248
+ src_dict=None if self.speech_only else self.src_dict,
249
+ pre_tokenizer=pre_tokenizer,
250
+ bpe_tokenizer=bpe_tokenizer,
251
+ src_pre_tokenizer=src_pre_tokenizer,
252
+ src_bpe_tokenizer=src_bpe_tokenizer,
253
+ is_train_split=is_train_split,
254
+ epoch=epoch,
255
+ seed=self.args.seed,
256
+ )
257
+ noise_token_id = -1
258
+ text_dataset = None
259
+ if self.args.parallel_text_data != "" and is_train_split:
260
+ text_dataset = self.load_langpair_dataset(
261
+ self.data_cfg.prepend_tgt_lang_tag_no_change, 1.0, epoch=epoch,
262
+ )
263
+ if self.args.mask_text_ratio > 0:
264
+ # add mask
265
+ noise_token_id = (
266
+ self.src_dict.unk()
267
+ if self.args.noise_token == ""
268
+ else self.src_dict.index(self.args.noise_token)
269
+ )
270
+ text_dataset = LangPairMaskDataset(
271
+ text_dataset,
272
+ src_bos=self.src_dict.bos(),
273
+ src_eos=self.src_dict.eos(),
274
+ noise_id=noise_token_id,
275
+ mask_ratio=self.args.mask_text_ratio,
276
+ mask_type=self.args.mask_text_type,
277
+ )
278
+
279
+ if text_dataset is not None:
280
+ mdsets = [
281
+ ModalityDatasetItem(
282
+ "sup_speech",
283
+ ast_dataset,
284
+ (self.args.max_source_positions, self.args.max_target_positions),
285
+ self.args.max_tokens,
286
+ self.args.batch_size,
287
+ ),
288
+ ModalityDatasetItem(
289
+ "text",
290
+ text_dataset,
291
+ (self.args.max_positions_text, self.args.max_target_positions),
292
+ self.args.max_tokens_text
293
+ if self.args.max_tokens_text is not None
294
+ else self.args.max_tokens,
295
+ self.args.batch_size,
296
+ ),
297
+ ]
298
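+ # Wrap the supervised speech and parallel text datasets into one MultiModalityDataset,
+ # so the batch iterator below can sample batches from each modality separately.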
+ ast_dataset = MultiModalityDataset(mdsets)
299
+ self.datasets[split] = ast_dataset
300
+
301
+ @property
302
+ def target_dictionary(self):
303
+ """Return the :class:`~fairseq.data.Dictionary` for the language
304
+ model."""
305
+ return self.tgt_dict
306
+
307
+ @property
308
+ def source_dictionary(self):
309
+ """Return the source :class:`~fairseq.data.Dictionary` (if applicable
310
+ for this task)."""
311
+ return None if self.speech_only else self.src_dict
312
+
313
+ def get_batch_iterator(
314
+ self,
315
+ dataset,
316
+ max_tokens=None,
317
+ max_sentences=None,
318
+ max_positions=None,
319
+ ignore_invalid_inputs=False,
320
+ required_batch_size_multiple=1,
321
+ seed=1,
322
+ num_shards=1,
323
+ shard_id=0,
324
+ num_workers=0,
325
+ epoch=0,
326
+ data_buffer_size=0,
327
+ disable_iterator_cache=False,
328
+ skip_remainder_batch=False,
329
+ grouped_shuffling=False,
330
+ update_epoch_batch_itr=False,
331
+ ):
332
+
333
+ if not isinstance(dataset, MultiModalityDataset):
334
+ return super(SpeechTextJointToTextTask, self).get_batch_iterator(
335
+ dataset,
336
+ max_tokens,
337
+ max_sentences,
338
+ max_positions,
339
+ ignore_invalid_inputs,
340
+ required_batch_size_multiple,
341
+ seed,
342
+ num_shards,
343
+ shard_id,
344
+ num_workers,
345
+ epoch,
346
+ data_buffer_size,
347
+ disable_iterator_cache,
348
+ skip_remainder_batch=skip_remainder_batch,
349
+ update_epoch_batch_itr=update_epoch_batch_itr,
350
+ )
351
+
352
+ mult_ratio = [self.args.speech_sample_ratio, self.args.text_sample_ratio]
353
+ assert len(dataset.datasets) == 2
354
+
355
+ # initialize the dataset with the correct starting epoch
356
+ dataset.set_epoch(epoch)
357
+
358
+ batch_samplers = dataset.get_batch_samplers(
359
+ mult_ratio, required_batch_size_multiple, seed
360
+ )
361
+
362
+ # return a reusable, sharded iterator
363
+ epoch_iter = GroupedEpochBatchIterator(
364
+ dataset=dataset,
365
+ collate_fn=dataset.collater,
366
+ batch_samplers=batch_samplers,
367
+ seed=seed,
368
+ num_shards=num_shards,
369
+ shard_id=shard_id,
370
+ num_workers=num_workers,
371
+ epoch=epoch,
372
+ mult_rate=1 if self.args.update_mix_data else max(self.args.update_freq),
373
+ buffer_size=data_buffer_size,
374
+ skip_remainder_batch=skip_remainder_batch,
375
+ )
376
+ self.dataset_to_epoch_iter[dataset] = {} # refresh it every epoch
377
+ return epoch_iter
fairseq/examples/speech_to_speech/README.md ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Speech to speech translation (S2ST)
2
+
3
+ We provide the implementation and resources for the following work on speech-to-speech translation (S2ST):
4
+
5
+ * [Direct speech-to-speech translation with discrete units (Lee et al. 2021)](docs/direct_s2st_discrete_units.md)
6
+ * [Textless Speech-to-Speech Translation on Real Data (Lee et al. 2021)](docs/textless_s2st_real_data.md)
7
+ * [Enhanced Direct Speech-to-Speech Translation Using Self-supervised Pre-training and Data Augmentation](docs/enhanced_direct_s2st_discrete_units.md)
fairseq/examples/speech_to_speech/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from . import unity # noqa
fairseq/examples/speech_to_speech/asr_bleu/README.md ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ASR-BLEU evaluation toolkit
2
+
3
+ This toolkit provides a set of public ASR models used for the evaluation of different speech-to-speech translation systems at FAIR. It enables easier score comparisons between different systems' outputs.
4
+
5
+ The ASRGenerator wraps CTC-based ASR models from the HuggingFace and fairseq code bases. A torchaudio CTC decoder is built on top of it to decode the given audio files.
6
+
7
+ Please see `asr_model_cfgs.json` for a list of languages covered currently.
8
+
9
+ The high-level pipeline is simple by design: given a language tag, the script loads the corresponding ASR model, transcribes the model's predicted audio, and computes the BLEU score against the provided reference translations using sacrebleu. A minimal programmatic sketch of this pipeline is shown below.
10
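+
+ The same pipeline can also be driven programmatically. A minimal sketch, assuming a language key, generated audio files, and reference strings of your own (the values below are placeholders; `retrieve_asr_config` and `ASRGenerator` come from `utils.py` in this directory):
+
+ ```python
+ import sacrebleu
+ from utils import retrieve_asr_config, ASRGenerator
+
+ # Placeholder inputs: replace with your language key, generated audio and references.
+ lang = "es"
+ audio_files = ["generated/0_pred.wav", "generated/1_pred.wav"]
+ references = ["primera referencia", "segunda referencia"]
+
+ # Load the ASR model for the target language and transcribe each predicted audio file.
+ asr_config = retrieve_asr_config(lang, "oct22", json_path="./asr_model_cfgs.json")
+ asr_model = ASRGenerator(asr_config)
+ hypotheses = [asr_model.transcribe_audiofile(path) for path in audio_files]
+
+ # Compute corpus BLEU between the ASR transcripts and the reference translations.
+ print(sacrebleu.corpus_bleu(hypotheses, [references]))
+ ```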
+
11
+ # Dependencies
12
+
13
+ Please see `requirements.txt`.
14
+
15
+ # Usage examples
16
+
17
+ This toolkit has been used with:
18
+
19
+ * Speechmatrix project: https://github.com/facebookresearch/fairseq/tree/ust/examples/speech_matrix.
20
+
21
+ * Hokkien speech-to-speech translation project: https://github.com/facebookresearch/fairseq/tree/ust/examples/hokkien.
22
+
23
+ # Standalone run example
24
+
25
+ A high-level example; substitute the arguments for your setup:
26
+
27
+ ```bash
28
+ python compute_asr_bleu.py --lang <LANG> \
29
+ --audio_dirpath <PATH_TO_AUDIO_DIR> \
30
+ --reference_path <PATH_TO_REFERENCES_FILE> \
31
+ --reference_format txt
32
+ ```
33
+
34
+ For more details about the arguments, please see the script's argparse help.
fairseq/examples/speech_to_speech/asr_bleu/__init__.py ADDED
File without changes
fairseq/examples/speech_to_speech/asr_bleu/asr_model_cfgs.json ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "en": {
3
+ "oct22": {
4
+ "desc": "Wav2Vec 2.0 Large (LV-60) + Self Training from https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec#pre-trained-models",
5
+ "ckpt_path": "https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_960h_pl.pt",
6
+ "dict_path": "https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt",
7
+ "model_type": "fairseq",
8
+ "lang": "en",
9
+ "post_process": "collapse"
10
+ }
11
+ },
12
+ "hok": {
13
+ "oct22": {
14
+ "desc": "Hokkien ASR model, for details check [TODO add paper link]",
15
+ "ckpt_path": "https://dl.fbaipublicfiles.com/ust_asr/hok/checkpoint_best.pt",
16
+ "dict_path": "https://dl.fbaipublicfiles.com/ust_asr/hok/dict.ltr.txt",
17
+ "model_type": "fairseq",
18
+ "lang": "hok",
19
+ "post_process": "none"
20
+ }
21
+ },
22
+ "es": {
23
+ "oct22": {
24
+ "model_path": "jonatasgrosman/wav2vec2-large-xlsr-53-spanish",
25
+ "model_type": "hf",
26
+ "lang": "es",
27
+ "post_process": "collapse"
28
+ }
29
+ },
30
+ "fr": {
31
+ "oct22": {
32
+ "model_path": "jonatasgrosman/wav2vec2-large-fr-voxpopuli-french",
33
+ "model_type": "hf",
34
+ "lang": "fr",
35
+ "post_process": "collapse"
36
+ }
37
+ },
38
+ "zh": {
39
+ "oct22": {
40
+ "model_path": "ydshieh/wav2vec2-large-xlsr-53-chinese-zh-cn-gpt",
41
+ "model_type": "hf",
42
+ "lang": "zh",
43
+ "post_process": "collapse"
44
+ }
45
+ },
46
+ "tr": {
47
+ "oct22": {
48
+ "model_path": "cahya/wav2vec2-large-xlsr-turkish-artificial-cv",
49
+ "model_type": "hf",
50
+ "lang": "tr",
51
+ "post_process": "collapse"
52
+ }
53
+ },
54
+ "ar": {
55
+ "oct22": {
56
+ "model_path": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
57
+ "model_type": "hf",
58
+ "lang": "ar",
59
+ "post_process": "collapse"
60
+ }
61
+ },
62
+ "vi": {
63
+ "oct22": {
64
+ "model_path": "not-tanh/wav2vec2-large-xlsr-53-vietnamese",
65
+ "model_type": "hf",
66
+ "lang": "vi",
67
+ "post_process": "collapse"
68
+ }
69
+ },
70
+ "de": {
71
+ "oct22": {
72
+ "model_path": "jonatasgrosman/wav2vec2-xls-r-1b-german",
73
+ "model_type": "hf",
74
+ "lang": "de",
75
+ "post_process": "collapse"
76
+ }
77
+ },
78
+ "pl": {
79
+ "oct22": {
80
+ "model_path": "jonatasgrosman/wav2vec2-xls-r-1b-polish",
81
+ "model_type": "hf",
82
+ "lang": "pl",
83
+ "post_process": "collapse"
84
+ }
85
+ },
86
+ "it": {
87
+ "oct22": {
88
+ "model_path": "jonatasgrosman/wav2vec2-large-xlsr-53-italian",
89
+ "model_type": "hf",
90
+ "lang": "it",
91
+ "post_process": "collapse"
92
+ }
93
+ },
94
+ "pt": {
95
+ "oct22": {
96
+ "model_path": "jonatasgrosman/wav2vec2-xls-r-1b-portuguese",
97
+ "model_type": "hf",
98
+ "lang": "pt",
99
+ "post_process": "collapse"
100
+ }
101
+ },
102
+ "ro": {
103
+ "oct22": {
104
+ "model_path": "gigant/romanian-wav2vec2",
105
+ "model_type": "hf",
106
+ "lang": "ro",
107
+ "post_process": "collapse"
108
+ }
109
+ },
110
+ "cs": {
111
+ "oct22": {
112
+ "model_path": "comodoro/wav2vec2-xls-r-300m-cs-250",
113
+ "model_type": "hf",
114
+ "lang": "cs",
115
+ "post_process": "collapse"
116
+ }
117
+ },
118
+ "sk": {
119
+ "oct22": {
120
+ "model_path": "anuragshas/wav2vec2-xls-r-300m-sk-cv8-with-lm",
121
+ "model_type": "hf",
122
+ "lang": "sk",
123
+ "post_process": "collapse"
124
+ }
125
+ },
126
+ "sl": {
127
+ "oct22": {
128
+ "model_path": "anuragshas/wav2vec2-xls-r-300m-sl-cv8-with-lm",
129
+ "model_type": "hf",
130
+ "lang": "sl",
131
+ "post_process": "collapse"
132
+ }
133
+ },
134
+ "fi": {
135
+ "oct22": {
136
+ "model_path": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish",
137
+ "model_type": "hf",
138
+ "lang": "fi",
139
+ "post_process": "collapse"
140
+ }
141
+ },
142
+ "hu": {
143
+ "oct22": {
144
+ "model_path": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
145
+ "model_type": "hf",
146
+ "lang": "hu",
147
+ "post_process": "collapse"
148
+ }
149
+ },
150
+ "et": {
151
+ "oct22": {
152
+ "model_path": "RASMUS/wav2vec2-xlsr-1b-et",
153
+ "model_type": "hf",
154
+ "lang": "et",
155
+ "post_process": "collapse"
156
+ }
157
+ },
158
+ "lt": {
159
+ "oct22": {
160
+ "model_path": "sammy786/wav2vec2-xlsr-lithuanian",
161
+ "model_type": "hf",
162
+ "lang": "lt",
163
+ "post_process": "collapse"
164
+ }
165
+ },
166
+ "nl": {
167
+ "oct22": {
168
+ "model_path": "jonatasgrosman/wav2vec2-xls-r-1b-dutch",
169
+ "model_type": "hf",
170
+ "lang": "nl",
171
+ "post_process": "collapse"
172
+ }
173
+ },
174
+ "lv": {
175
+ "oct22": {
176
+ "model_path": "reach-vb/wav2vec2-large-xls-r-1B-common_voice7-lv-ft",
177
+ "model_type": "hf",
178
+ "lang": "lv",
179
+ "post_process": "collapse"
180
+ }
181
+ },
182
+ "sv": {
183
+ "oct22": {
184
+ "model_path": "marinone94/xls-r-300m-sv-robust",
185
+ "model_type": "hf",
186
+ "lang": "sv",
187
+ "post_process": "collapse"
188
+ }
189
+ },
190
+ "hr": {
191
+ "oct22": {
192
+ "model_path": "classla/wav2vec2-xls-r-parlaspeech-hr",
193
+ "model_type": "hf",
194
+ "lang": "hr",
195
+ "post_process": "collapse"
196
+ }
197
+ }
198
+ }
fairseq/examples/speech_to_speech/asr_bleu/compute_asr_bleu.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict, List
3
+ import sacrebleu
4
+ import pandas as pd
5
+ from glob import glob
6
+ from pathlib import Path
7
+ from utils import retrieve_asr_config, ASRGenerator
8
+ from tqdm import tqdm
9
+ from argparse import ArgumentParser
10
+
11
+
12
+ def merge_tailo_init_final(text):
13
+ """
14
+ Hokkien ASR hypothesis post-processing.
15
+ """
16
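+ # Merge syllable initials/finals: pieces accumulate until one ends in a tone digit,
+ # e.g. "li 2 ho 2" -> "li2 ho2" (illustrative example); "NULLINIT" markers are dropped.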
+ sps = text.strip().split()
17
+ results = []
18
+ last_syllable = ""
19
+ for sp in sps:
20
+ if sp == "NULLINIT" or sp == "nullinit":
21
+ continue
22
+ last_syllable += sp
23
+ if sp[-1].isnumeric():
24
+ results.append(last_syllable)
25
+ last_syllable = ""
26
+ if last_syllable != "":
27
+ results.append(last_syllable)
28
+ return " ".join(results)
29
+
30
+
31
+ def remove_tone(text):
32
+ """
33
+ Used for tone-less evaluation of Hokkien
34
+ """
35
+ return " ".join([t[:-1] for t in text.split()])
36
+
37
+
38
+ def extract_audio_for_eval(audio_dirpath: str, audio_format: str):
39
+ if audio_format == "n_pred.wav":
40
+ """
41
+ The assumption here is that 0_pred.wav corresponds to the reference at line position 0 from the reference manifest
42
+ """
43
+ audio_list = []
44
+ audio_fp_list = glob((Path(audio_dirpath) / "*_pred.wav").as_posix())
45
+ audio_fp_list = sorted(
46
+ audio_fp_list, key=lambda x: int(os.path.basename(x).split("_")[0])
47
+ )
48
+ for i in range(len(audio_fp_list)):
49
+ try:
50
+ audio_fp = (Path(audio_dirpath) / f"{i}_pred.wav").as_posix()
51
+ assert (
52
+ audio_fp in audio_fp_list
53
+ ), f"{Path(audio_fp).name} does not exist in {audio_dirpath}"
54
+ except AssertionError:
55
+ # check the audio with random speaker
56
+ audio_fp = Path(audio_dirpath) / f"{i}_spk*_pred.wav"
57
+ audio_fp = glob(
58
+ audio_fp.as_posix()
59
+ ) # resolve audio filepath with random speaker
60
+ assert len(audio_fp) == 1
61
+ audio_fp = audio_fp[0]
62
+
63
+ audio_list.append(audio_fp)
64
+ else:
65
+ raise NotImplementedError
66
+
67
+ return audio_list
68
+
69
+
70
+ def extract_text_for_eval(
71
+ references_filepath: str, reference_format: str, reference_tsv_column: str = None
72
+ ):
73
+ if reference_format == "txt":
74
+ reference_sentences = open(references_filepath, "r").readlines()
75
+ reference_sentences = [l.strip() for l in reference_sentences]
76
+ elif reference_format == "tsv":
77
+ tsv_df = pd.read_csv(references_filepath, sep="\t", quoting=3)
78
+ reference_sentences = tsv_df[reference_tsv_column].to_list()
79
+ reference_sentences = [l.strip() for l in reference_sentences]
80
+ else:
81
+ raise NotImplementedError
82
+
83
+ return reference_sentences
84
+
85
+
86
+ def compose_eval_data(
87
+ audio_dirpath: str,
88
+ audio_format: str,
89
+ references_filepath: str,
90
+ reference_format: str,
91
+ reference_tsv_column: str = None,
92
+ save_manifest_filepath=None,
93
+ ):
94
+ """
95
+ Speech matrix decoding pipeline produces audio with the following mask "N_pred.wav" where N is the order of the corresponding input sample
96
+ """
97
+
98
+ reference_sentences = extract_text_for_eval(
99
+ references_filepath, reference_format, reference_tsv_column
100
+ )
101
+ predicted_audio_fp_list = extract_audio_for_eval(audio_dirpath, audio_format)
102
+ assert len(predicted_audio_fp_list) == len(reference_sentences)
103
+
104
+ audio_text_pairs = [
105
+ (audio, reference)
106
+ for audio, reference in zip(predicted_audio_fp_list, reference_sentences)
107
+ ]
108
+
109
+ tsv_manifest = pd.DataFrame(audio_text_pairs, columns=["prediction", "reference"])
110
+
111
+ if save_manifest_filepath is not None:
112
+ tsv_manifest.to_csv(save_manifest_filepath, sep="\t", quoting=3)
113
+
114
+ return tsv_manifest
115
+
116
+
117
+ def load_eval_data_from_tsv(eval_data_filepath: str):
118
+ """
119
+ We may load the result of `compose_eval_data` directly if needed
120
+ """
121
+ eval_df = pd.read_csv(eval_data_filepath, sep="\t")
122
+
123
+ return eval_df
124
+
125
+
126
+ def run_asr_bleu(args):
127
+
128
+ asr_config = retrieve_asr_config(
129
+ args.lang, args.asr_version, json_path="./asr_model_cfgs.json"
130
+ )
131
+ asr_model = ASRGenerator(asr_config)
132
+
133
+ eval_manifest = compose_eval_data(
134
+ audio_dirpath=args.audio_dirpath,
135
+ audio_format=args.audio_format,
136
+ references_filepath=args.reference_path,
137
+ reference_format=args.reference_format,
138
+ reference_tsv_column=args.reference_tsv_column,
139
+ save_manifest_filepath=None,
140
+ )
141
+
142
+ prediction_transcripts = []
143
+ for _, eval_pair in tqdm(
144
+ eval_manifest.iterrows(),
145
+ desc="Transcribing predictions",
146
+ total=len(eval_manifest),
147
+ ):
148
+ transcription = asr_model.transcribe_audiofile(eval_pair.prediction)
149
+ prediction_transcripts.append(transcription.lower())
150
+
151
+ if args.lang == "hok":
152
+ prediction_transcripts = [
153
+ merge_tailo_init_final(text) for text in prediction_transcripts
154
+ ]
155
+
156
+ references = eval_manifest["reference"].tolist()
157
+ bleu_score = sacrebleu.corpus_bleu(prediction_transcripts, [references])
158
+
159
+ print(bleu_score)
160
+
161
+ return prediction_transcripts, bleu_score
162
+
163
+
164
+ def main():
165
+ parser = ArgumentParser(
166
+ description="This script computes the ASR-BLEU metric between model's generated audio and the text reference sequences."
167
+ )
168
+
169
+ parser.add_argument(
170
+ "--lang",
171
+ help="The target language used to initialize ASR model, see asr_model_cfgs.json for available languages",
172
+ type=str,
173
+ )
174
+ parser.add_argument(
175
+ "--asr_version",
176
+ type=str,
177
+ default="oct22",
178
+ help="For future support we add and extra layer of asr versions. The current most recent version is oct22 meaning October 2022",
179
+ )
180
+ parser.add_argument(
181
+ "--audio_dirpath",
182
+ type=str,
183
+ help="Path to the directory containing the audio predictions from the translation model",
184
+ )
185
+ parser.add_argument(
186
+ "--reference_path",
187
+ type=str,
188
+ help="Path to the file containing reference translations in the form of normalized text (to be compared to ASR predictions",
189
+ )
190
+ parser.add_argument(
191
+ "--reference_format",
192
+ choices=["txt", "tsv"],
193
+ help="Format of reference file. Txt means plain text format where each line represents single reference sequence",
194
+ )
195
+ parser.add_argument(
196
+ "--reference_tsv_column",
197
+ default=None,
198
+ type=str,
199
+ help="If format is tsv, then specify the column name which contains reference sequence",
200
+ )
201
+ parser.add_argument(
202
+ "--audio_format",
203
+ default="n_pred.wav",
204
+ choices=["n_pred.wav"],
205
+ help="Audio format n_pred.wav corresponds to names like 94_pred.wav or 94_spk7_pred.wav where spk7 is the speaker id",
206
+ )
207
+ parser.add_argument(
208
+ "--results_dirpath",
209
+ default=None,
210
+ type=str,
211
+ help="If specified, the resulting BLEU score will be written to this file path as txt file",
212
+ )
213
+ parser.add_argument(
214
+ "--transcripts_path",
215
+ default=None,
216
+ type=str,
217
+ help="If specified, the predicted transcripts will be written to this path as a txt file.",
218
+ )
219
+
220
+ args = parser.parse_args()
221
+
222
+ prediction_transcripts, bleu_score = run_asr_bleu(args)
223
+ result_filename = f"{args.reference_format}_{args.lang}_bleu.txt"
224
+ if args.results_dirpath is not None:
225
+ if not Path(args.results_dirpath).exists():
226
+ Path(args.results_dirpath).mkdir(parents=True)
227
+ with open(Path(args.results_dirpath) / result_filename, "w") as f:
228
+ f.write(bleu_score.format(width=2))
229
+
230
+ if args.transcripts_path is not None:
231
+ with open(args.transcripts_path, "w") as f:
232
+ for transcript in prediction_transcripts:
233
+ f.write(transcript + "\n")
234
+
235
+
236
+ if __name__ == "__main__":
237
+ main()
238
+
239
+
240
+ """
241
+ Example: load generated audio and references, and compute ASR-BLEU:
242
+
243
+ export lang=fi; split=vp && python compute_asr_bleu.py --lang $lang --audio_dirpath /checkpoint/hygong/S2S/speech_matrix_release_ckpts/generated_waveform_release/en-$lang/test_$split/checkpoint.pt --audio_format n_pred.wav --reference_path /large_experiments/ust/hygong/S2S/SpeechEncoder/manifests/vp-vp/en-$lang/test_$split.$lang --reference_format txt --results_dirpath ./
244
+ """
fairseq/examples/speech_to_speech/asr_bleu/requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ fairseq==0.12.2
2
+ pandas==1.4.3
3
+ sacrebleu==2.2.0
4
+ torch==1.12.1
5
+ torchaudio==0.12.1
6
+ tqdm==4.64.0
7
+ transformers==4.21.1
fairseq/examples/speech_to_speech/asr_bleu/utils.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ import urllib.request
4
+ from pathlib import Path
5
+
6
+ import fairseq
7
+ import torch
8
+ from fairseq.data.data_utils import lengths_to_padding_mask
9
+ from tqdm import tqdm
10
+
11
+ try:
12
+ import torchaudio
13
+ from torchaudio.models.decoder import ctc_decoder
14
+ except ImportError:
15
+ raise ImportError("Upgrade torchaudio to 0.12 to enable CTC decoding")
16
+
17
+
18
+ class DownloadProgressBar(tqdm):
19
+ """A class to represent a download progress bar"""
20
+
21
+ def update_to(self, b=1, bsize=1, tsize=None) -> None:
22
+ """
23
+ Update the download progress
24
+ """
25
+ if tsize is not None:
26
+ self.total = tsize
27
+ self.update(b * bsize - self.n)
28
+
29
+
30
+ def retrieve_asr_config(lang_key: str, asr_version: str, json_path: str) -> dict:
31
+ """
32
+ Retrieve the asr model configs
33
+
34
+ Args:
35
+ lang_key: the lanuage type as the key name
36
+ json_path: the path of the config json file
37
+
38
+ Returns:
39
+ Dict of all the configs in the json file
40
+ """
41
+
42
+ with open(json_path, "r") as f:
43
+ asr_model_cfgs = json.load(f)
44
+ return asr_model_cfgs[lang_key][asr_version]
45
+
46
+
47
+ class ASRGenerator(object):
48
+ """A class to represent a ASR generator"""
49
+
50
+ def __init__(
51
+ self,
52
+ model_cfg: dict,
53
+ cache_dirpath: str = (Path.home() / ".cache" / "ust_asr").as_posix(),
54
+ ) -> None:
55
+ """
56
+ Construct all the necessary attributes of the ASRGenerator class
57
+
58
+ Args:
59
+ model_cfg: the dict of the asr model config
60
+ cache_dirpath: the default cache path is "Path.home()/.cache/ust_asr"
61
+ """
62
+
63
+ self.cache_dirpath = Path(cache_dirpath) / model_cfg["lang"]
64
+ self.model_cfg = model_cfg
65
+
66
+ self.use_cuda = torch.cuda.is_available()
67
+
68
+ torchaudio.set_audio_backend("sox_io")
69
+
70
+ if self.model_cfg["model_type"] == "hf":
71
+ self.prepare_hf_model(self.model_cfg)
72
+ elif self.model_cfg["model_type"] == "fairseq":
73
+ self.prepare_fairseq_model(self.model_cfg)
74
+ else:
75
+ raise NotImplementedError(
76
+ f"Model type {self.model_cfg['model_type']} is not supported"
77
+ )
78
+
79
+ if self.model_cfg["post_process"] == "collapse":
80
+ self.post_process_fn = lambda hypo: "".join(hypo).replace(
81
+ self.sil_token, " "
82
+ )
83
+ elif self.model_cfg["post_process"] == "none":
84
+ self.post_process_fn = lambda hypo: " ".join(hypo).replace(
85
+ self.sil_token, " "
86
+ )
87
+ else:
88
+ raise NotImplementedError
89
+
90
+ if self.use_cuda:
91
+ self.model.cuda()
92
+ self.model.eval()
93
+
94
+ self.decoder = ctc_decoder(
95
+ lexicon=None,
96
+ tokens=self.tokens,
97
+ lm=None,
98
+ nbest=1,
99
+ beam_size=1,
100
+ beam_size_token=None,
101
+ lm_weight=0.0,
102
+ word_score=0.0,
103
+ unk_score=float("-inf"),
104
+ sil_token=self.sil_token,
105
+ sil_score=0.0,
106
+ log_add=False,
107
+ blank_token=self.blank_token,
108
+ )
109
+
110
+ def prepare_hf_model(self, model_cfg: dict) -> None:
111
+ """
112
+ Prepare the huggingface asr model
113
+
114
+ Args:
115
+ model_cfg: dict with the relevant ASR config
116
+ """
117
+
118
+ def infer_silence_token(vocab: list):
119
+ """
120
+ Different HF checkpoints have different notion of silence token
121
+ such as | or " " (space)
122
+ Important: when adding new HF asr model in, check what silence token it uses
123
+ """
124
+ if "|" in vocab:
125
+ return "|"
126
+ elif " " in vocab:
127
+ return " "
128
+ else:
129
+ raise RuntimeError("Silence token is not found in the vocabulary")
130
+
131
+ try:
132
+ from transformers import (AutoFeatureExtractor, AutoTokenizer,
133
+ Wav2Vec2ForCTC, Wav2Vec2Processor)
134
+ except ImportError:
135
+ raise ImportError("Install transformers to load HF wav2vec model")
136
+
137
+ model_path = model_cfg["model_path"]
138
+ self.model = Wav2Vec2ForCTC.from_pretrained(model_path)
139
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
140
+ self.preprocessor = AutoFeatureExtractor.from_pretrained(model_path)
141
+ self.processor = Wav2Vec2Processor.from_pretrained(model_path)
142
+
143
+ # extra unk tokens are there to make some models work e.g. Finnish ASR has some vocab issue
144
+ vocab_list = [
145
+ self.tokenizer.decoder.get(i, f"{self.tokenizer.unk_token}1")
146
+ for i in range(self.tokenizer.vocab_size)
147
+ ]
148
+
149
+ self.sampling_rate = self.preprocessor.sampling_rate
150
+ self.normalize_input = self.preprocessor.do_normalize
151
+ self.tokens = vocab_list
152
+ self.sil_token = infer_silence_token(vocab_list)
153
+ self.blank_token = self.tokenizer.pad_token
154
+
155
+ def prepare_fairseq_model(self, model_cfg: dict) -> None:
156
+ """
157
+ Prepare the fairseq asr model
158
+
159
+ Args:
160
+ model_cfg: the specific model config dict must have: (1) ckpt_path, (2) dict_path
161
+ """
162
+
163
+ def download_file(url: str, cache_dir: Path):
164
+ download_path = cache_dir / url.split("/")[-1]
165
+ if not (cache_dir / url.split("/")[-1]).exists():
166
+ with DownloadProgressBar(
167
+ unit="B", unit_scale=True, miniters=1, desc=url.split("/")[-1]
168
+ ) as t:
169
+ cache_dir.mkdir(parents=True, exist_ok=True)
170
+ urllib.request.urlretrieve(
171
+ url, filename=download_path.as_posix(), reporthook=t.update_to
172
+ )
173
+ else:
174
+ print(f"'{url}' exists in {cache_dir}")
175
+
176
+ return download_path.as_posix()
177
+
178
+ try:
179
+ ckpt_path = model_cfg["ckpt_path"]
180
+ dict_path = model_cfg["dict_path"]
181
+ except KeyError:
182
+ raise KeyError(
183
+ "Fairseq model cfg must provide (1) ckpt_path, (2) dict_path"
184
+ )
185
+
186
+ if re.search("^https", ckpt_path):
187
+ ckpt_path = download_file(ckpt_path, self.cache_dirpath)
188
+ if re.search("^https", dict_path):
189
+ dict_path = download_file(dict_path, self.cache_dirpath)
190
+
191
+ model, saved_cfg, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
192
+ [ckpt_path],
193
+ arg_overrides={
194
+ "task": "audio_finetuning",
195
+ "data": self.cache_dirpath.as_posix(),
196
+ }, # data must have dict in it
197
+ )
198
+
199
+ dict_lines = open(dict_path, "r").readlines()
200
+ tokens = [l.split()[0] for l in dict_lines]
201
+ # adding default fairseq special tokens
202
+ tokens = ["<s>", "<pad>", "</s>", "<unk>"] + tokens
203
+
204
+ self.model = model[0]
205
+ self.tokens = tokens
206
+
207
+ if "|" in tokens:
208
+ self.sil_token = "|"
209
+ else:
210
+ self.sil_token = tokens[
211
+ 2
212
+ ] # use eos as silence token if | not presented e.g., Hok ASR model
213
+ print(f"Inferring silence token from the dict: {self.sil_token}")
214
+ self.blank_token = self.tokens[0]
215
+
216
+ self.sampling_rate = saved_cfg.task.sample_rate
217
+ self.normalize_input = saved_cfg.task.normalize
218
+
219
+ @torch.inference_mode()
220
+ def load_audiofile(self, audio_path: str) -> torch.Tensor:
221
+ """
222
+ Load the audio file and apply resampling and normalization
223
+
224
+ Args:
225
+ audio_path: the audio file path
226
+
227
+ Returns:
228
+ audio_waveform: the audio waveform as a torch.Tensor object
229
+ """
230
+
231
+ audio_waveform, sampling_rate = torchaudio.load(audio_path)
232
+ # torchaudio.load returns (channels, frames); collapse multi-channel audio to mono
+ if audio_waveform.size(0) > 1:
233
+ audio_waveform = audio_waveform.mean(0, keepdim=True)
234
+ if self.sampling_rate != sampling_rate:
235
+ audio_waveform = torchaudio.functional.resample(
236
+ audio_waveform, sampling_rate, self.sampling_rate
237
+ )
238
+ if self.normalize_input:
239
+ # following fairseq raw audio dataset
240
+ audio_waveform = torch.nn.functional.layer_norm(
241
+ audio_waveform, audio_waveform.shape
242
+ )
243
+
244
+ return audio_waveform
245
+
246
+ @torch.inference_mode()
247
+ def compute_emissions(self, audio_input: torch.Tensor) -> torch.Tensor:
248
+ """
249
+ Compute the emissions for either fairseq or huggingface asr model
250
+
251
+ Args:
252
+ audio_input: the input audio waveform
253
+
254
+ Returns:
255
+ emissions: the logits of the encoded prediction.
256
+ """
257
+
258
+ if self.use_cuda:
259
+ audio_input = audio_input.to("cuda")
260
+ if isinstance(self.model, fairseq.models.wav2vec.wav2vec2_asr.Wav2VecCtc):
261
+ padding_mask = lengths_to_padding_mask(torch.tensor([audio_input.numel()]))
262
+ emissions = self.model.w2v_encoder(audio_input, padding_mask)[
263
+ "encoder_out"
264
+ ].transpose(0, 1)
265
+ else:
266
+ emissions = self.model(audio_input).logits
267
+
268
+ return emissions
269
+
270
+ def decode_emissions(self, emissions: torch.Tensor) -> str:
271
+ """
272
+ Decode the emissions and apply post process functions
273
+
274
+ Args:
275
+ emissions: the input Tensor object
276
+
277
+ Returns:
278
+ hypo: the str as the decoded transcriptions
279
+ """
280
+
281
+ emissions = emissions.cpu()
282
+ results = self.decoder(emissions)
283
+
284
+ # assuming the lexicon-free decoder and working with tokens
285
+ hypo = self.decoder.idxs_to_tokens(results[0][0].tokens)
286
+ hypo = self.post_process_fn(hypo)
287
+
288
+ return hypo
289
+
290
+ def transcribe_audiofile(self, audio_path: str, lower=True) -> str:
291
+ """
292
+ Transcribe the audio into string
293
+
294
+ Args:
295
+ audio_path: the input audio file path
296
+ lower: the case of the transcriptions with lowercase as the default
297
+
298
+ Returns:
299
+ hypo: the transcription result
300
+ """
301
+
302
+ asr_input = self.load_audiofile(audio_path)
303
+ emissions = self.compute_emissions(asr_input)
304
+ hypo = self.decode_emissions(emissions)
305
+
306
+ return hypo.strip().lower() if lower else hypo.strip()
fairseq/examples/speech_to_speech/benchmarking/README.md ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Benchmarking
2
+
3
+ ## Overview
4
+
5
+ The goal of this framework is to support benchmarking various speech-to-speech translation (S2ST) models in terms of runtime, maximum memory consumption, and total number of floating-point operations (FLOPS). It is a generic framework and can easily be extended to support any fairseq model. To benchmark performance accurately, the core inference modules are re-implemented based on fairseq_cli/generate.py (core.py/Processing) and examples/speech_to_text/generate_waveform.py (core.py/SpeechGeneration). To ensure that end-to-end and cascaded models are compared fairly, for cascaded models we only consider the model inference metrics at each stage, ignoring any intermediate data and I/O processing overhead. We run all benchmarks on CPU, since CPU is generally what is used in production environments and benchmarking library support for GPUs is limited. A minimal sketch of how these metrics are measured follows the list below.
6
+
7
+ 1. Runtime: Average time in seconds to run model inference on an example from a given dataset. We use the [timeit](https://docs.python.org/3/library/timeit.html) library to measure the runtime.
8
+ 2. Max memory: Maximum memory in MiB, averaged over model inference runs on all examples from the given dataset. We use the [memory_profiler](https://pypi.org/project/memory-profiler/) library to gather the memory footprint of a code snippet and take its maximum as the memory used by the code. For cascaded models, we take the maximum across all stages as the overall max-memory footprint.
9
+ 3. FLOPS: We compute the average number of floating-point operations needed to run model inference on an example from the given dataset. We use the [PAPI library](http://www.bnikolic.co.uk/blog/python/flops/2019/10/01/pytorch-count-flops.html) to count the number of FLOPS.
10
+
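+ A minimal sketch of how these three metrics can be measured around a single `forward(sample)` call (the `model` and `sample` objects below are placeholders; `core.py` in this directory implements the full benchmarking logic):
+
+ ```{python}
+ import timeit
+ from memory_profiler import memory_usage
+ from pypapi import events, papi_high as high
+
+ def measure_metrics(model, sample, repeat=3):
+     """Hypothetical helper mirroring core.py; model.forward(sample) must run on CPU."""
+     # 1. Runtime: average seconds per inference call, measured with timeit.
+     run_time = timeit.Timer(lambda: model.forward(sample)).timeit(repeat) / repeat
+
+     # 2. Max memory: peak MiB observed while running one inference call.
+     max_mem = max(memory_usage((model.forward, (sample,), {})))
+
+     # 3. FLOPS: double-precision operation count averaged over `repeat` calls.
+     high.start_counters([events.PAPI_DP_OPS])
+     for _ in range(repeat):
+         model.forward(sample)
+     flops = round(high.stop_counters()[0] / repeat)
+
+     return run_time, max_mem, flops
+ ```
+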
11
+ ## CLI Commands
12
+
13
+ ```{python}
14
+ CUBLAS_WORKSPACE_CONFIG=:4096:8 python examples/speech_to_speech/benchmarking/get_metrics.py '' --config $config
15
+ ```
16
+
17
+
18
+ ## Note:
19
+
20
+ 1. The npy dataset is a list of samples saved as a .npy file. Each sample is a dictionary with `id` and `net_input`; a sketch of how to create such a file follows the example below.
21
+ 2. The raw dataset is a list of raw audio paths, similar to the wav2vec2 input TSV file.
22
+
23
+ ```{python}
24
+ sample: {
25
+ "id": xx,
26
+ "net_input": {
27
+ "src_tokens": torch.tensor([]),
28
+ "src_lengths": torch.tensor([])
29
+ }
30
+ }
31
+ ```
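+
+ As a minimal sketch, a dataset in this format can be produced and reloaded with the helpers in `data_utils.py` (the file name below is a placeholder, and the import assumes the fairseq repository root is on `PYTHONPATH`):
+
+ ```{python}
+ import numpy as np
+ from examples.speech_to_speech.benchmarking.data_utils import (
+     generate_random_dataset,
+     load_dataset_npy,
+ )
+
+ # Build a small random dataset of 80-dim filterbank-like inputs and save it as .npy.
+ dataset, avg_len = generate_random_dataset(
+     T_range_min=100, T_range_max=500, B=1, D=80, dataset_size=10
+ )
+ np.save("benchmark_dataset.npy", dataset, allow_pickle=True)
+
+ # Reload it in the sample format expected by the benchmarking classes.
+ samples = load_dataset_npy("benchmark_dataset.npy", dataset_size=10)
+ ```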
fairseq/examples/speech_to_speech/benchmarking/configs/2StageS2ST.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ general:
2
+ dataset_path: $npy_dataset
3
+ cpu: True
4
+ model_type: 2StageS2ST
5
+ dataset_size: 1
6
+
7
+ stage1:
8
+ data: $data_bin_stage1
9
+ task: speech_to_text
10
+ path: $checkpoint_stage1
11
+ config_yaml: config.yaml
12
+ max_len_a: 2
13
+ max_len_b: 500
14
+
15
+ stage2:
16
+ data: $data_bin_stage2
17
+ task: text_to_speech
18
+ path: $checkpoint_stage2
19
+ config_yaml: config.yaml
fairseq/examples/speech_to_speech/benchmarking/configs/3StageS2ST.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ general:
2
+ dataset_path: $npy_dataset
3
+ cpu: True
4
+ model_type: 3StageS2ST
5
+ max_len_a: 2
6
+ max_len_b: 500
7
+ dataset_size: 1
8
+
9
+ stage1:
10
+ data: $data_bin_stage1
11
+ task: speech_to_text
12
+ path: $checkpoint_stage1
13
+ config_yaml: config.yaml
14
+ max_len_a: 2
15
+ max_len_b: 500
16
+
17
+ stage2:
18
+ data: $data_bin_stage2
19
+ task: translation
20
+ path: $checkpoint_stage2
21
+ config_yaml: config.yaml
22
+
23
+
24
+ stage3:
25
+ data: $data_bin_stage3
26
+ task: text_to_speech
27
+ path: $checkpoint_stage3
28
+ config_yaml: config.yaml
fairseq/examples/speech_to_speech/benchmarking/configs/DirectS2U.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ general:
2
+ dataset_path: $npy_dataset_path
3
+ cpu: True
4
+ model_type: S2UT
5
+ dataset_size: 5
6
+ dump_speech_waveforms_dir: $dump_waveforms_dir_path
7
+
8
+ stage1:
9
+ data: $data_bin
10
+ task: speech_to_speech
11
+ path: $checkpoint
12
+ config_yaml: config.yaml
13
+ max_len_b: 100000
14
+ beam: 10
15
+ target_is_code: True
16
+ max_target_positions: 3000
17
+ target_code_size: 100
18
+
19
+ stage2:
20
+ vocoder: $vocoder_path
21
+ vocoder_cfg: $vocoder_cfg_json
22
+ dur_prediction: True
fairseq/examples/speech_to_speech/benchmarking/configs/S2T.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ general:
2
+ dataset_path: $npy_dataset
3
+ cpu: True
4
+ model_type: S2T
5
+ dataset_size: 1
6
+
7
+ stage1:
8
+ data: $data_bin
9
+ task: speech_to_text
10
+ path: $checkpoint
11
+ config_yaml: config.yaml
12
+ max_len_a: 2
13
+ max_len_b: 500
fairseq/examples/speech_to_speech/benchmarking/core.py ADDED
@@ -0,0 +1,487 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import timeit
2
+ import logging
3
+ import torch
4
+ from pypapi import events, papi_high as high
5
+ from memory_profiler import memory_usage
6
+ from torch import nn
7
+ from argparse import Namespace
8
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
9
+ from fairseq.data import data_utils as fairseq_data_utils
10
+ from fairseq import checkpoint_utils, tasks, utils
11
+ from fairseq.models.text_to_speech.vocoder import CodeHiFiGANVocoder
12
+ from examples.hubert.simple_kmeans.dump_hubert_feature import HubertFeatureReader
13
+ from examples.hubert.simple_kmeans.dump_km_label import ApplyKmeans
14
+ from fairseq_cli.generate import get_symbols_to_strip_from_output
15
+ import soundfile as sf
16
+ import ast
17
+ import json
18
+
19
+ logging.basicConfig()
20
+ logging.root.setLevel(logging.INFO)
21
+ logging.basicConfig(level=logging.INFO)
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ torch.manual_seed(1)
26
+ torch.use_deterministic_algorithms(True)  # replaces the removed torch.set_deterministic API
27
+
28
+
29
+ class BenchmarkingBase(nn.Module):
30
+ def __init__(self):
31
+ nn.Module.__init__(self)
32
+ self.s2x_task = None
33
+
34
+ def warm_up(self, sample, repeat):
35
+ """Warm up the model"""
36
+ for _i in range(repeat):
37
+ self.forward(sample)
38
+ logger.info(f"Model warmed up by running inference {repeat} times")
39
+
40
+ def benchmark_run_time(self, dataset, repeat):
41
+ """Benchmark average runtime for the model by calling benchmark_run_time_single_sample function"""
42
+ logger.info("Starting run time benchmarking")
43
+ time_elapsed = 0
44
+ for i, sample in enumerate(dataset):
45
+ time_elapsed += self.benchmark_run_time_single_sample(sample, repeat=repeat)
46
+ if i % 100 == 0:
47
+ logger.info(f"Benchmarked run time for {i}/{len(dataset)} samples")
48
+ total_time_elapsed = time_elapsed / len(dataset)
49
+ return total_time_elapsed
50
+
51
+ def benchmark_run_time_single_sample(self, sample, repeat):
52
+ """Benchmark average runtime for a single sample using timeit library. Units are seconds"""
53
+ timer = timeit.Timer(lambda: self.forward(sample))
54
+ time_elapsed = timer.timeit(repeat)
55
+ return time_elapsed / repeat
56
+
57
+ def count_flops(
58
+ self,
59
+ dataset,
60
+ repeat,
61
+ ):
62
+ """Use PYPAPI library to count average flops for model inference.
63
+ Note: It only works if the model is being run on cpu"""
64
+ logger.info("Starting flop counter")
65
+ high.start_counters([events.PAPI_DP_OPS])
66
+ for i, sample in enumerate(dataset):
67
+ for _r in range(repeat):
68
+ self.forward(sample)
69
+ if i % 100 == 0:
70
+ logger.info(f"Counted flops for {i}/{len(dataset)} samples")
71
+ flops = high.stop_counters()
72
+ flops = round(flops[0] / (repeat * len(dataset)))
73
+ return flops
74
+
75
+ def max_memory(self, dataset, repeat):
76
+ """Compute average max memory consumed by model inference. Units are MiB"""
77
+ logger.info("Starting memory benchmarking")
78
+ total_memory = 0
79
+ for i, sample in enumerate(dataset):
80
+ for _r in range(repeat):
81
+ total_memory += max(memory_usage((self.forward, (sample,), {})))
82
+ if i % 100 == 0:
83
+ logger.info(f"Benchmarked memory for {i}/{len(dataset)} samples")
84
+ total_memory = total_memory / (repeat * len(dataset))
85
+ return total_memory
86
+
87
+ def gather_all_metrics(self, dataset, repeat):
88
+ run_time = self.benchmark_run_time(dataset, repeat)
89
+ max_memory = self.max_memory(dataset, repeat)
90
+ flops = self.count_flops(dataset, repeat)
91
+
92
+ return run_time, max_memory, flops
93
+
94
+ def dump_final_speech_output(
95
+ self, dataset, output_dir, resample_fn, sample_rate, prefix=None
96
+ ):
97
+
98
+ for i, sample in enumerate(dataset):
99
+ hypo = self.forward(sample)[0]
100
+
101
+ def to_np(x):
102
+ return x.detach().cpu().numpy()
103
+
104
+ try:
105
+ wave_preds = to_np(resample_fn(hypo["waveform"]))
106
+ sf.write(
107
+ f"{output_dir}/{prefix}_{i}_pred.wav",
108
+ wave_preds,
109
+ sample_rate,
110
+ )
111
+ except Exception as e:
112
+ raise Exception(
113
+ f" Encountered {e} - Invalid waveform. Make sure the model outputs a waveform"
114
+ )
115
+
116
+
117
+ class Processing(BenchmarkingBase):
118
+ """Class similar to fairseq_cli/generate.py. Supports ASR, MT and ST model inference"""
119
+
120
+ def __init__(self, args):
121
+ super().__init__()
122
+ self.use_cuda = not getattr(args, "cpu", False)
123
+ self.setUp(args)
124
+ self.training = False
125
+ self.s2x_task = self.task
126
+
127
+ def setUp(self, cfg):
128
+ if isinstance(cfg, Namespace):
129
+ cfg = convert_namespace_to_omegaconf(cfg)
130
+
131
+ self.task = tasks.setup_task(cfg.task)
132
+ self.tgt_dict = self.task.target_dictionary
133
+
134
+ # Load ensemble
135
+ logger.info("loading model(s) from {}".format(cfg.common_eval.path))
136
+ models, _ = checkpoint_utils.load_model_ensemble(
137
+ utils.split_paths(cfg.common_eval.path),
138
+ arg_overrides={},
139
+ task=self.task,
140
+ suffix=cfg.checkpoint.checkpoint_suffix,
141
+ strict=False,
142
+ num_shards=cfg.checkpoint.checkpoint_shard_count,
143
+ )
144
+ if len(models) > 1:
145
+ raise Exception("Currently loading multiple models is not supported")
146
+ self.model = models[0]
147
+
148
+ # Optimize model for generation
149
+ if cfg.common.fp16:
150
+ self.model.half()
151
+ if self.use_cuda:
152
+ self.model.cuda()
153
+ self.model.prepare_for_inference_(cfg)
154
+
155
+ self.generator = self.task.build_generator(
156
+ [self.model],
157
+ cfg.generation,
158
+ extra_gen_cls_kwargs={},
159
+ )
160
+ # Handle tokenization and BPE
161
+ self.tokenizer = self.task.build_tokenizer(cfg.tokenizer)
162
+ self.bpe = self.task.build_bpe(cfg.bpe)
163
+ self.remove_bpe = cfg.common_eval.post_process
164
+
165
+ def encode_source(self, src):
166
+ """Method to generate source tokens from a string"""
167
+ if self.tokenizer is not None:
168
+ src = self.tokenizer.encode(src)
169
+ if self.bpe is not None:
170
+ src = self.bpe.encode(src)
171
+ src_tokens = self.task.source_dictionary.encode_line(src).long()
172
+ src_lens = src_tokens.size(0)
173
+ return {
174
+ "net_input": {
175
+ "src_tokens": src_tokens.view(1, src_lens),
176
+ "src_lengths": torch.tensor([src_lens]),
177
+ }
178
+ }
179
+
180
+ def decode_target(self, hypos):
181
+ """Method to decode target string from tokens"""
182
+ hypo_str = self.tgt_dict.string(
183
+ hypos[0][0]["tokens"].int().cpu(),
184
+ self.remove_bpe,
185
+ get_symbols_to_strip_from_output(self.generator),
186
+ )
187
+ if self.bpe is not None:
188
+ hypo_str = self.bpe.decode(hypo_str)
189
+ if self.tokenizer is not None:
190
+ hypo_str = self.tokenizer.decode(hypo_str)
191
+ return hypo_str
192
+
193
+ def forward(self, sample):
194
+ hypos = self.task.inference_step(
195
+ self.generator,
196
+ [self.model],
197
+ sample,
198
+ prefix_tokens=None,
199
+ constraints=None,
200
+ )
201
+ return hypos
202
+
203
+
204
+ class GenerateWaveformFromCode(BenchmarkingBase):
205
+ """Class to support waveform generation from code. Currently, vocoder only supports single speaker"""
206
+
207
+ def __init__(self, args):
208
+ super().__init__()
209
+ with open(args.vocoder_cfg) as f:
210
+ vocoder_cfg = json.load(f)
211
+ self.dur_prediction = args.dur_prediction
212
+ self.vocoder = CodeHiFiGANVocoder(args.vocoder, vocoder_cfg)
213
+
214
+ def format_units(self, input):
215
+ code = torch.LongTensor(list(map(int, input.strip().split()))).view(1, -1)
216
+ return {"code": code}
217
+
218
+ def generate_vocoder_input(self, dataset):
219
+ return [self.format_units(sample) for sample in dataset]
220
+
221
+ def forward(self, sample):
222
+ return [{"waveform": self.vocoder(sample, self.dur_prediction)}]
223
+
224
+
225
+ class HubertUnitExtractor(BenchmarkingBase):
226
+ def __init__(self, args):
227
+ super().__init__()
+ self.feature_reader = HubertFeatureReader(
228
+ args.hubert_ckpt_path, args.hubert_layer
229
+ )
230
+ self.kmeans = ApplyKmeans(args.hubert_km_path)
231
+
232
+ def forward(self, sample):
233
+ with torch.no_grad():
234
+ feat = []
235
+ for start in range(0, sample.size(1), self.feature_reader.max_chunk):
236
+ x_chunk = sample[:, start : start + self.feature_reader.max_chunk]
237
+ feat_chunk, _ = self.feature_reader.model.extract_features(
238
+ source=x_chunk,
239
+ padding_mask=None,
240
+ mask=False,
241
+ output_layer=self.feature_reader.layer,
242
+ )
243
+ feat.append(feat_chunk)
244
+ feat = torch.cat(feat, 1).squeeze(0)
245
+ return self.kmeans(feat).tolist()
246
+
247
+
248
+ class SpeechGeneration(BenchmarkingBase):
249
+ """Class similar to examples/text_to_speech/generate_waveform.py.
250
+ Supports models with speech generation as end goal (TTS, Direct S2ST models etc)"""
251
+
252
+ def __init__(self, args):
253
+ super().__init__()
254
+ self.use_cuda = not getattr(args, "cpu", False)
255
+ self.setUp(args)
256
+ self.s2x_task = self.task
257
+
258
+ def setUp(self, args):
259
+ if args.task == "speech_to_speech":
260
+ args.normalize_waveform = False
261
+ self.task = tasks.setup_task(args)
262
+ self.pre_tokenizer = self.task.build_tokenizer(args)
263
+ self.bpe_tokenizer = self.task.build_bpe(args)
264
+ try:
265
+ self.src_dict = self.task.src_dict
266
+ except Exception:
267
+ self.src_dict = None
268
+ ensemble, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
269
+ [args.path],
270
+ arg_overrides=ast.literal_eval(args.model_overrides),
271
+ task=self.task,
272
+ strict=False,
273
+ )
274
+ self.model = ensemble[0]
275
+ if self.use_cuda:
276
+ self.model.cuda()
277
+ # criterion.cuda()
278
+ self.model.eval()
279
+ self.generator = self.task.build_generator(
280
+ [self.model],
281
+ args,
282
+ )
283
+
284
+ def processTextInput(self, text):
285
+ """Generate source tokens from text input"""
286
+ if self.pre_tokenizer is not None:
287
+ text = self.pre_tokenizer.encode(text)
288
+ if self.bpe_tokenizer is not None:
289
+ text = self.bpe_tokenizer.encode(text)
290
+ target = self.src_dict.encode_line(
291
+ text, add_if_not_exist=False, append_eos=True
292
+ ).long()
293
+ target = fairseq_data_utils.collate_tokens(
294
+ [target],
295
+ self.src_dict.pad(),
296
+ self.src_dict.eos(),
297
+ left_pad=False,
298
+ move_eos_to_beginning=False,
299
+ )
300
+ src_lengths = torch.tensor([target.size(1)], dtype=torch.long)
301
+ prev_output_tokens = None
302
+ sample = {
303
+ "net_input": {
304
+ "src_tokens": target,
305
+ "src_lengths": src_lengths,
306
+ "prev_output_tokens": prev_output_tokens,
307
+ }
308
+ }
309
+ sample = utils.move_to_cuda(sample) if self.use_cuda else sample
310
+ return sample
311
+
312
+ def forward(self, sample):
313
+ sample["speaker"] = None
314
+ output = self.generator.generate(self.model, sample) # , has_targ=False
315
+ return output
316
+
317
+
318
+ class S2UT(BenchmarkingBase):
319
+ """Class to support S2UT models. Also supports generating waveforms from the units predicted"""
320
+
321
+ def __init__(self, s2u_args, vocoder_args=None):
322
+ super().__init__()
323
+ self.s2u = Processing(s2u_args)
324
+ self.vocoder = None
325
+ if vocoder_args:
326
+ self.vocoder = GenerateWaveformFromCode(vocoder_args)
327
+ self.vocoder_input = None
328
+
329
+ def forward(self, sample):
330
+ s2u_hypos = self.s2u(sample)
331
+ s2u_output = self.s2u.decode_target(s2u_hypos)
332
+ if not self.vocoder:
333
+ return s2u_output
334
+ units = self.vocoder.format_units(s2u_output)
335
+ vocoder_output = self.vocoder(units)
336
+ return vocoder_output
337
+
338
+ def generate_s2u_outputs(self, dataset):
339
+ return [self.s2u.decode_target(self.s2u(sample)) for sample in dataset]
340
+
341
+ def compute_metrics(self, metric_type, dataset, repeat=None):
342
+ """Generic function to compute metrics ignoring the io processing time"""
343
+ if self.vocoder and not self.vocoder_input:
344
+ self.s2u_output = self.generate_s2u_outputs(dataset)
345
+ self.vocoder_input = self.vocoder.generate_vocoder_input(self.s2u_output)
346
+
347
+ s2u_metrics = getattr(self.s2u, metric_type)(
348
+ dataset,
349
+ repeat,
350
+ )
351
+ vocoder_metrics = 0
352
+ if self.vocoder:
353
+ vocoder_metrics = getattr(self.vocoder, metric_type)(
354
+ self.vocoder_input,
355
+ repeat,
356
+ )
357
+ print(
358
+ f"metric_type = {metric_type} s2u_metrics = {s2u_metrics} \t vocoder_metrics = {vocoder_metrics}"
359
+ )
360
+ if metric_type == "max_memory":
361
+ return max(s2u_metrics, vocoder_metrics)
362
+ else:
363
+ return s2u_metrics + vocoder_metrics
364
+
365
+ def benchmark_run_time(self, dataset, repeat):
366
+ return self.compute_metrics("benchmark_run_time", dataset, repeat)
367
+
368
+ def count_flops(self, dataset, repeat):
369
+ return self.compute_metrics("count_flops", dataset, repeat)
370
+
371
+ def max_memory(self, dataset, repeat):
372
+ return self.compute_metrics("max_memory", dataset, repeat)
373
+
374
+
375
+ class Cascaded2StageS2ST(BenchmarkingBase):
376
+ """ST + TTS"""
377
+
378
+ def __init__(self, s2t_args, tts_args):
379
+ super().__init__()
380
+ self.s2t = Processing(s2t_args)
381
+ self.s2x_task = self.s2t.task
382
+ self.tts = SpeechGeneration(tts_args) if tts_args else None
383
+ self.training = False
384
+ self.tts_inputs = None
385
+
386
+ def forward(self, sample):
387
+ if not self.tts:
388
+ raise Exception(
389
+ "Forward function is not callable without tts. Reinitialize the class with tts_args"
390
+ )
391
+ s2t_hypos = self.s2t(sample)
392
+ s2t_output = self.s2t.decode_target(s2t_hypos)
393
+ tts_input = self.tts.processTextInput(s2t_output)
394
+ tts_output = self.tts(tts_input)
395
+ return tts_output
396
+
397
+ def generate_s2t_outputs(self, dataset):
398
+ """Process dataset and generate s2t outputs"""
399
+ return [self.s2t.decode_target(self.s2t(sample)) for sample in dataset]
400
+
401
+ def generate_tts_inputs(self, dataset):
402
+ """Process dataset and generate tts inputs"""
403
+ return [self.tts.processTextInput(sample) for sample in dataset]
404
+
405
+ def compute_metrics(self, metric_type, dataset, repeat=None):
406
+ """Generic function to compute metrics ignoring the io processing time"""
407
+ if not self.tts_inputs:
408
+ s2t_outputs = self.generate_s2t_outputs(dataset)
409
+ self.tts_inputs = self.generate_tts_inputs(s2t_outputs)
410
+
411
+ s2t_metrics = getattr(self.s2t, metric_type)(
412
+ dataset,
413
+ repeat,
414
+ )
415
+
416
+ tts_metrics = getattr(self.tts, metric_type)(
417
+ self.tts_inputs,
418
+ repeat,
419
+ )
420
+ print(
421
+ f"metric_type = {metric_type} s2t_metrics = {s2t_metrics} \t tts_metrics = {tts_metrics}"
422
+ )
423
+ if metric_type == "max_memory":
424
+ return max(s2t_metrics, tts_metrics)
425
+ else:
426
+ return s2t_metrics + tts_metrics
427
+
428
+ def benchmark_run_time(self, dataset, repeat):
429
+ return self.compute_metrics("benchmark_run_time", dataset, repeat)
430
+
431
+ def count_flops(self, dataset, repeat):
432
+ return self.compute_metrics("count_flops", dataset, repeat)
433
+
434
+ def max_memory(self, dataset, repeat):
435
+ return self.compute_metrics("max_memory", dataset, repeat)
436
+
437
+
438
+ class Cascaded3StageS2ST(Cascaded2StageS2ST):
439
+ """ASR + MT + TTS"""
440
+
441
+ def __init__(self, s2t_args, tts_args, mt_args):
442
+ super().__init__(s2t_args, tts_args)
443
+ self.mt = Processing(mt_args)
444
+ self.mt_inputs = []
445
+
446
+ def forward(self, sample):
447
+ s2t_hypos = self.s2t(sample)
448
+ s2t_output = self.s2t.decode_target(s2t_hypos)
449
+ mt_input = self.mt.encode_source(s2t_output)
450
+ mt_hypos = self.mt(mt_input)
451
+ mt_output = self.mt.decode_target(mt_hypos)
452
+ tts_input = self.tts.processTextInput(mt_output)
453
+ tts_output = self.tts(tts_input)
454
+ return tts_output
455
+
456
+ def generate_mt_inputs(self, dataset):
457
+ """Process dataset to generate mt model inputs"""
458
+ return [self.mt.encode_source(sample) for sample in dataset]
459
+
460
+ def generate_mt_outputs(self, dataset):
461
+ """Process dataset to generate mt model outputs"""
462
+ return [self.mt.decode_target(self.mt(sample)) for sample in dataset]
463
+
464
+ def compute_metrics(self, metric_type, dataset, repeat=None):
465
+ """Generic function to compute metrics ignoring the io processing time"""
466
+ if not self.tts_inputs:
467
+ s2t_outputs = self.generate_s2t_outputs(dataset)
468
+ self.mt_inputs = self.generate_mt_inputs(s2t_outputs)
469
+ mt_outputs = self.generate_mt_outputs(self.mt_inputs)
470
+ self.tts_inputs = self.generate_tts_inputs(mt_outputs)
471
+
472
+ s2t_metrics = getattr(self.s2t, metric_type)(
473
+ dataset,
474
+ repeat,
475
+ )
476
+ mt_metrics = getattr(self.mt, metric_type)(self.mt_inputs, repeat)
477
+ tts_metrics = getattr(self.tts, metric_type)(
478
+ self.tts_inputs,
479
+ repeat,
480
+ )
481
+ print(
482
+ f"metric_type = {metric_type} s2t_metrics = {s2t_metrics} \t mt_metrics = {mt_metrics} \t tts_metrics = {tts_metrics}"
483
+ )
484
+ if metric_type == "max_memory":
485
+ return max(s2t_metrics, mt_metrics, tts_metrics)
486
+ else:
487
+ return s2t_metrics + mt_metrics + tts_metrics
fairseq/examples/speech_to_speech/benchmarking/data_utils.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fairseq import tasks
2
+ import numpy as np
3
+ import logging
4
+ import random
5
+ from fairseq import options
6
+ import torch
7
+ import os
8
+ import soundfile as sf
9
+
10
+ from fairseq.data.audio.audio_utils import (
11
+ get_waveform,
12
+ parse_path,
13
+ )
14
+
15
+ logging.basicConfig()
16
+ logging.root.setLevel(logging.INFO)
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
+
20
+ random.seed(1)
21
+ np.random.seed(1)
22
+ random_number_generator = np.random.RandomState(30)
23
+
24
+
25
+ def generate_random_data_sample(T, B=1, D=80):
26
+ """Generate random data sample given the T, B, D values"""
27
+ net_input = {
28
+ "src_tokens": torch.tensor(random_number_generator.randn(B, T, D)).float(),
29
+ "src_lengths": torch.tensor([T]),
30
+ }
31
+ return {"net_input": net_input}
32
+
33
+
34
+ def generate_random_dataset(T_range_min, T_range_max, B=1, D=80, dataset_size=100):
35
+ """Generate random dataset with T values within a given range, B, D"""
36
+ T_values = [random.randint(T_range_min, T_range_max) for i in range(dataset_size)]
37
+ dataset = []
38
+ for t in T_values:
39
+ dataset.append(generate_random_data_sample(t, B, D))
40
+ return dataset, sum(T_values) / dataset_size
41
+
42
+
43
+ def load_dataset_npy(file_name, dataset_size=None):
44
+ """Load dataset from a .npy file."""
45
+ data = np.load(file_name, allow_pickle=True)
46
+ if dataset_size:
47
+ data = data[:dataset_size]
48
+ return data
49
+
50
+
51
+ def load_dataset_raw_to_waveforms(
52
+ file_name,
53
+ dataset_size=None,
54
+ need_waveform=True,
55
+ sample_rate=16000,
56
+ read_using_soundfile=False,
57
+ ):
58
+ """Load raw dataset from w2v tsv file. Optionally get waveforms"""
59
+ data = []
60
+ with open(file_name, "r") as fp:
61
+ lines = fp.readlines()
62
+ data = [
63
+ os.path.join(lines[0].strip(), line.strip().split("\t")[0])
64
+ for line in lines[1:]
65
+ ]
66
+
67
+ if dataset_size:
68
+ data = data[:dataset_size]
69
+
70
+ if not need_waveform:
71
+ return data
72
+
73
+ features = []
74
+ if read_using_soundfile:
75
+ for _i, d in enumerate(data):
76
+ wav = sf.read(d)[0]
77
+ if wav.ndim == 2:
78
+ wav = wav.mean(-1)
79
+ features.append(torch.from_numpy(wav).float().view(1, -1))
80
+ else:
81
+ for i, d in enumerate(data):
82
+ _path, slice_ptr = parse_path(d)
83
+ if len(slice_ptr) == 0:
84
+ feat = get_waveform(
85
+ _path, always_2d=True, output_sample_rate=sample_rate
86
+ )[0]
87
+ features.append(
88
+ {
89
+ "id": i,
90
+ "net_input": {
91
+ "src_tokens": torch.tensor(feat),
92
+ "src_lengths": torch.tensor([feat.shape[1]]),
93
+ },
94
+ }
95
+ )
96
+ else:
97
+ raise Exception("Currently unsupported data format")
98
+ return features
99
+
100
+
101
+ def load_dataset_task(
102
+ args,
103
+ batch_size=1,
104
+ limit_size=None,
105
+ ref_dataset=None,
106
+ ):
107
+ """Loads dataset based on args by creating a task"""
108
+ if not args.data or not args.subset or not args.task:
109
+ raise Exception(
110
+ "Please provide necessary arguments to load the dataset - data, subset and task"
111
+ )
112
+ task = tasks.setup_task(args)
113
+
114
+ task.load_dataset(args.subset)
115
+ if not limit_size:
116
+ limit_size = len(task.dataset(args.subset))
117
+
118
+ iter = task.get_batch_iterator(
119
+ dataset=task.dataset(args.subset), max_sentences=batch_size
120
+ ).next_epoch_itr(shuffle=False)
121
+ dataset = []
122
+ for i, sample in enumerate(iter):
123
+ sample = {
124
+ "id": task.datasets[args.subset].ids[sample["id"].item()],
125
+ "net_input": {
126
+ "src_tokens": sample["net_input"]["src_tokens"],
127
+ "src_lengths": sample["net_input"]["src_lengths"],
128
+ },
129
+ }
130
+ dataset.append(sample)
131
+ if i == limit_size - 1:
132
+ break
133
+
134
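+ # Optionally keep only samples whose ids appear in a reference dataset; ids may
+ # differ by a split prefix (e.g. a leading "dev_"), hence the extra membership checks.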
+ if ref_dataset:
135
+ try:
136
+ ids = get_ids_from_dataset(ref_dataset)
137
+ except Exception as e:
138
+ raise Exception(f"{e} - Cannot extract ids from reference dataset")
139
+
140
+ filtered_dataset = []
141
+ for sample in dataset:
142
+ if (
143
+ sample["id"] in ids
144
+ or sample["id"][5:] in ids
145
+ or f"dev_{sample['id']}" in ids
146
+ ):
147
+ filtered_dataset.append(sample)
148
+ dataset = filtered_dataset
149
+
150
+ max_len, min_len, avg_len = get_dataset_stats(dataset)
151
+ print(
152
+ f"{args.subset} dataset stats : num_samples={len(dataset)} max_len = {max_len} min_len = {min_len} avg_len = {avg_len}"
153
+ )
154
+
155
+ return dataset
156
+
157
+
158
+ def randomly_sample_subset(dataset, size=500):
159
+ """Randomly sample subset from a dataset"""
160
+ random_indices = [random.randint(0, len(dataset) - 1) for i in range(size)]
161
+ return [dataset[i] for i in random_indices]
162
+
163
+
164
+ def get_short_data_subset(dataset, size=500):
165
+ """Get a subset of desired size by sorting based on src_lengths"""
166
+ return sort_dataset(dataset)[:size]
167
+
168
+
169
+ def get_long_data_subset(dataset, size=500):
170
+ """Get a subset of desired size by sorting based on src_lengths descending"""
171
+ return sort_dataset(dataset, reverse=True)[:size]
172
+
173
+
174
+ def sort_dataset(dataset, reverse=False):
175
+ return sorted(
176
+ dataset, key=lambda x: x["net_input"]["src_lengths"].item(), reverse=reverse
177
+ )
178
+
179
+
180
+ def save_dataset_npy(dataset, file_name):
181
+ """Save a dataset as .npy file"""
182
+ np.save(file_name, dataset)
183
+
184
+
185
+ def get_dataset_stats(dataset):
186
+ """Get stats about dataset based on src_lengths of samples"""
187
+ max_len = 0
188
+ min_len = 100000
189
+ avg_len = 0
190
+ for d in dataset:
191
+ max_len = max(max_len, d["net_input"]["src_lengths"].item())
192
+ min_len = min(min_len, d["net_input"]["src_lengths"].item())
193
+ avg_len += d["net_input"]["src_lengths"].item()
194
+
195
+ return max_len, min_len, avg_len / len(dataset)
196
+
197
+
198
+ def make_parser():
199
+ """
200
+ Additional args:
201
+ 1. Provide the dataset dir path using --data.
202
+ 2. Loading the dataset doesn't require config, provide --config-yaml to apply additional feature transforms
203
+ """
204
+ parser = options.get_speech_generation_parser()
205
+ parser.add_argument(
206
+ "--subset",
207
+ default=None,
208
+ type=str,
209
+ required=True,
210
+ help="Subset to use for dataset generation",
211
+ )
212
+ parser.add_argument(
213
+ "--dataset-save-dir",
214
+ default=None,
215
+ type=str,
216
+ required=False,
217
+ help="Dir path in which the datasets are to be saved",
218
+ )
219
+ parser.add_argument(
220
+ "--ref-dataset",
221
+ default=None,
222
+ type=str,
223
+ required=False,
224
+ help="If provided, the ids in the reference dataset will be used to filter the new dataset generated.",
225
+ )
226
+ parser.add_argument("--dataset-save-token", default="", type=str, required=False)
227
+
228
+ options.add_generation_args(parser)
229
+ return parser
230
+
231
+
232
+ def get_ids_from_dataset(dataset):
233
+ return {sample["id"]: 1 for sample in dataset}
234
+
235
+
236
+ def cli_main():
237
+ parser = make_parser()
238
+ args = options.parse_args_and_arch(parser)
239
+ dataset = load_dataset_task(args)
240
+
241
+ random_dataset = randomly_sample_subset(dataset)
242
+ short_dataset = get_short_data_subset(dataset)
243
+ long_dataset = get_long_data_subset(dataset)
244
+
245
+ if args.dataset_save_token:
246
+ args.dataset_save_token = f"_{args.dataset_save_token}_"
247
+
248
+ if args.dataset_save_dir:
249
+ save_dataset_npy(
250
+ random_dataset,
251
+ f"{args.dataset_save_dir}/random_dataset{args.dataset_save_token}w_ids.npy",
252
+ )
253
+ save_dataset_npy(
254
+ short_dataset,
255
+ f"{args.dataset_save_dir}/short_dataset{args.dataset_save_token}w_ids.npy",
256
+ )
257
+ save_dataset_npy(
258
+ long_dataset,
259
+ f"{args.dataset_save_dir}/long_dataset{args.dataset_save_token}w_ids.npy",
260
+ )
261
+
262
+
263
+ if __name__ == "__main__":
264
+ cli_main()
fairseq/examples/speech_to_speech/benchmarking/get_metrics.py ADDED
@@ -0,0 +1,162 @@
1
+ import copy
2
+ import torch
3
+ import logging
4
+ from argparse import Namespace
5
+ import yaml
6
+ from fairseq import options
7
+ from examples.speech_to_speech.benchmarking.core import (
8
+ Processing,
9
+ SpeechGeneration,
10
+ Cascaded2StageS2ST,
11
+ Cascaded3StageS2ST,
12
+ S2UT,
13
+ )
14
+ from examples.speech_to_speech.benchmarking.data_utils import (
15
+ load_dataset_npy,
16
+ load_dataset_raw_to_waveforms,
17
+ )
18
+
19
+
20
+ logging.basicConfig()
21
+ logging.root.setLevel(logging.INFO)
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
+ torch.manual_seed(1)
26
+ torch.set_deterministic(True)
27
+
28
+
29
+ def make_parser():
30
+ """Note: As the names indicate use s2x_args(ex:ST, ASR etc) for models with speech input,
31
+ x2s_args for models with speech output(ex:TTS) and mt_args for translation models (ex: mt, T2U etc).
32
+ For direct S2ST models, use x2s_args to provide model details.
33
+ """
34
+ parser = options.get_speech_generation_parser()
35
+ parser.add_argument("--target-is-code", action="store_true", default=False)
36
+ parser.add_argument("--config", type=str)
37
+ parser.add_argument(
38
+ "--model-type",
39
+ default="S2U",
40
+ choices=["S2S", "TTS", "S2UT", "MT", "S2T", "2StageS2ST", "3StageS2ST"],
41
+ help="Choose one of the models. For model inference implementation, refer to core.py",
42
+ )
43
+ parser.add_argument(
44
+ "--dataset-path",
45
+ type=str,
46
+ help="""File to load dataset from. Assumes dataset is a list of samples.
47
+ Each sample is a dict of format {'net_input':{'src_tokens':torch.tenor(),'src_lengths':torch.tensor()}}""",
48
+ )
49
+ parser.add_argument(
50
+ "--dataset-type",
51
+ type=str,
52
+ default="npy",
53
+ choices=["npy", "raw"],
54
+ help="""Type of input dataset file""",
55
+ )
56
+ parser.add_argument(
57
+ "--read-using-sf",
58
+ type=str,
59
+ default=False,
60
+ help="""If sound file should be used to read the raw dataset""",
61
+ )
62
+ parser.add_argument(
63
+ "--dataset-size",
64
+ default=None,
65
+ type=int,
66
+ help="Dataset size to use for benchmarking",
67
+ )
68
+ parser.add_argument(
69
+ "--dump-speech-waveforms-dir",
70
+ default=None,
71
+ type=str,
72
+ help="Directory to dump the speech waveforms computed on the dataset.",
73
+ )
74
+ parser.add_argument(
75
+ "--dump-waveform-file-prefix",
76
+ default="",
77
+ type=str,
78
+ help="File name prefix for the saved speech waveforms",
79
+ )
80
+ parser.add_argument(
81
+ "--feat-dim", default=80, type=int, help="Input feature dimension"
82
+ )
83
+ parser.add_argument(
84
+ "--target-sr",
85
+ default=16000,
86
+ type=int,
87
+ help="Target sample rate for dumping waveforms",
88
+ )
89
+
90
+ options.add_generation_args(parser)
91
+ options.get_interactive_generation_parser(parser)
92
+ return parser
93
+
94
+
95
+ def cli_main():
96
+ parser = make_parser()
97
+ args = options.parse_args_and_arch(parser)
98
+
99
+ with open(
100
+ args.config,
101
+ "r",
102
+ ) as f:
103
+ config = yaml.load(f, Loader=yaml.FullLoader)
104
+ dict_args = vars(args)
105
+ dict_args.update(config["general"])
106
+ args = Namespace(**dict_args)
107
+
108
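+ # Build one Namespace per cascade stage: start from the shared/general args and
+ # overlay the stage-specific config section (stage1, stage2, stage3) when present.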
+ i = 1
109
+ stage_args = []
110
+ while i <= 3:
111
+ var = f"stage{i}"
112
+ tmp_args = copy.deepcopy(dict_args)
113
+ if var in config:
114
+ tmp_args.update(config[var])
115
+ stage_args.append(Namespace(**tmp_args))
116
+ i += 1
117
+ else:
118
+ break
119
+
120
+ if args.model_type == "S2S" or args.model_type == "TTS":
121
+ model = SpeechGeneration(stage_args[0])
122
+ elif args.model_type == "S2UT":
123
+ model = S2UT(stage_args[0], stage_args[1] if len(stage_args) > 1 else None)
124
+ elif args.model_type == "MT" or args.model_type == "S2T":
125
+ model = Processing(stage_args[0])
126
+ elif args.model_type == "2StageS2ST":
127
+ model = Cascaded2StageS2ST(stage_args[0], stage_args[1])
128
+ elif args.model_type == "3StageS2ST":
129
+ model = Cascaded3StageS2ST(stage_args[0], stage_args[2], stage_args[1])
130
+ else:
131
+ raise Exception(f"Currently unsupported model type {args.model_type}")
132
+
133
+ print(f"Evaluating on dataset - {args.dataset_path}\n")
134
+
135
+ if args.dataset_type == "npy":
136
+ dataset = load_dataset_npy(args.dataset_path, dataset_size=args.dataset_size)
137
+ elif args.dataset_type == "raw":
138
+ dataset = load_dataset_raw_to_waveforms(
139
+ args.dataset_path,
140
+ dataset_size=args.dataset_size,
141
+ read_using_soundfile=args.read_using_sf,
142
+ )
143
+ else:
144
+ raise Exception(f"Invalid dataset type {args.dataset_type}")
145
+
146
+ model.warm_up(sample=dataset[0], repeat=2)
147
+
148
+ run_time, memory, flops = model.gather_all_metrics(dataset, repeat=1)
149
+ print(f"run_time = {run_time}sec \tmemory = {memory}MiB \tflops = {flops}")
150
+
151
+ if args.dump_speech_waveforms_dir:
152
+ model.dump_final_speech_output(
153
+ dataset,
154
+ args.dump_speech_waveforms_dir,
155
+ lambda x: x,
156
+ args.target_sr,
157
+ prefix=args.dump_waveform_file_prefix,
158
+ )
159
+
160
+
161
+ if __name__ == "__main__":
162
+ cli_main()
fairseq/examples/speech_to_speech/docs/data_augmentation.md ADDED
@@ -0,0 +1,435 @@
1
+ # Noise and audio augmentation techniques
2
+
3
+ The noise and data augmentation techniques were written in an effort to understand how augmentation can affect model robustness and performance in both clean and noisy settings.
4
+
5
+ All transforms discussed in this section are subclasses of `AudioFeatureTransform`, `AudioWaveformTransform`, or `AudioDatasetTransform`. Each `Audio*Transform` interacts with the data differently. If you are interested in implementing your own transforms, it is highly advisable to review the differences (see [Adding your own transforms](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/data_augmentation.md#adding-your-own-transforms)). If you are only applying the in-built transforms, you only need to make sure that the correct kind of transform is listed in the config (see [Using transforms](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/data_augmentation.md#using-transforms)). These transforms can be applied to instances of `SpeechToTextDataset`.
6
+
7
+ ### Contents
8
+ [In-built transforms](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/data_augmentation.md#in-built-transforms)
9
+
10
+ [Benchmark studies](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/data_augmentation.md#benchmark-studies)
11
+
12
+ [Using transforms](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/data_augmentation.md#using-transforms)
13
+
14
+ [Adding your own transforms](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/data_augmentation.md#adding-your-own-transforms)
15
+
16
+
17
+ ## In-built transforms
18
+ ### 1. Utterance concatenation
19
+ Utterance concatenation is a data augmentation technique introduced as ConcatAug in [Translatotron 2: High-quality direct speech-to-speech translation
20
+ with voice preservation](https://arxiv.org/pdf/2107.08661.pdf).
21
+ With some parameterized probability, samples are concatenated with one other randomly chosen sample from the whole dataset. In the positive (concatenation) case, accessing `dataset[i]` will return a `SpeechToTextDatasetItem` where `source=source[i]+source[j]` and `target=target[i]+target[j]`. In the negative (skip concatenation) case, accessing `dataset[i]` will return a `SpeechToTextDatasetItem` where `source=source[i]` and `target=target[i]` as usual.
22
+
23
+ **Usage**: `concataugment` is an `AudioDatasetTransform` and has three configurable hyperparameters:
24
+ - `rate`: probability that any single access will result in the positive (concatenation) case. Defaults to 0.25.
25
+ - `max_tokens`: maximum number of tokens allowed for concatenated source sequences. This parameter is meant to limit the length of concatenated samples to avoid out-of-memory errors. Defaults to 3000.
26
+ - `attempts`: maximum number of invalid concatenation attempts before defaulting to the negative (skip concatenation) case. This parameter aims to limit excessive time spent trying to find candidate samples that are short enough to concatenate with. Defaults to 5.
27
+
28
+ Please be wary of OOMs while using this augmentation technique; we used smaller batch sizes as a workaround. The effective batch size is determined by the update frequency, the batch size hyperparameter, and the number of GPUs, so you may want to adjust these accordingly.
29
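+
+ As a toy illustration of the access semantics above (plain Python, no fairseq dependency; all names are made up for this sketch):
+ ```python
+ import random
+
+ def concat_access(sources, targets, i, rate=0.25, max_tokens=3000, attempts=5):
+     """Return (source, target) for index i, possibly concatenated with a random j."""
+     if random.random() < rate:  # positive (concatenation) case
+         for _ in range(attempts):
+             j = random.randrange(len(sources))
+             if j != i and len(sources[i]) + len(sources[j]) <= max_tokens:
+                 return sources[i] + sources[j], targets[i] + targets[j]
+     return sources[i], targets[i]  # negative (skip concatenation) case
+ ```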
+
30
+ ### 2. Noise augmentation suite
31
+
32
+ The four noise augmentation methods in this suite adhere to the following principle: with some parameterized probability, samples are overlaid with a noise track. The content of the noise track is specific to the method. The signal-to-noise ratio at which the noise track is overlaid is determined by drawing a value from a uniform distribution with parameterized endpoints. The first three methods are based on data augmentation methods suggested in Section 3.3 of [X-Vectors: Robust DNN Embeddings for Speaker Recognition](https://danielpovey.com/files/2018_icassp_xvectors.pdf).
33
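+
+ The SNR arithmetic shared by all four methods can be sketched as follows (a standalone numpy illustration with synthetic signals, not the library's exact implementation):
+ ```python
+ import numpy as np
+
+ def mix_at_snr(source, noise, snr_db):
+     """Overlay `noise` on `source` at the requested signal-to-noise ratio (dB)."""
+     reps = int(np.ceil(len(source) / len(noise)))
+     noise = np.tile(noise, reps)[: len(source)]  # repeat/trim the noise track to length
+     source_power = np.mean(source ** 2) + 1e-10
+     noise_power = np.mean(noise ** 2) + 1e-10
+     scale = np.sqrt(source_power / (noise_power * 10 ** (snr_db / 10)))
+     return source + scale * noise
+
+ rng = np.random.default_rng(0)
+ src, noi = rng.standard_normal(16000), rng.standard_normal(8000)
+ mixed = mix_at_snr(src, noi, snr_db=rng.uniform(5, 15))  # SNR drawn from [snr_min, snr_max]
+ ```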
+
34
+ #### 2.1. Music augmentation
35
+ For music augmentation, the noise track consists of one file uniformly randomly selected from a corpus of music files. The music file is cut to size, including being repeated to fill the original sample length if necessary.
36
+
37
+ **Usage**: `musicaugment` is an `AudioWaveformTransform` and has four configurable hyperparameters:
38
+ - `samples_path`: path where background music files are saved as audios (.wav files). No default.
39
+ - `rate`: probability that any single access will result in the positive (background music) case. Defaults to 0.25.
40
+ - `snr_min`: lower endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 5.
41
+ - `snr_max`: higher endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 15.
42
+
43
+ #### 2.2. Babble augmentation
44
+ For babble augmentation, the noise track consists of multiple audios uniformly randomly selected from a corpus of speech files. The number of speech audios in the background track is chosen randomly with equal probability between 3 and 7 audios.
45
+
46
+ **Usage**: `babbleaugment` is an `AudioWaveformTransform` and has four configurable hyperparameters:
47
+ - `samples_path`: path where background speech files are saved as audios (.wav files). No default.
48
+ - `rate`: probability that any single access will result in the positive (background speech) case. Defaults to 0.25.
49
+ - `snr_min`: lower endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 5.
50
+ - `snr_max`: higher endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 15.
51
+
52
+ #### 2.3. Sporadic noise augmentation
53
+ For sporadic noise augmentation, the noise track is mostly silent except for intermittent short clips of noise which are added at roughly a parameterized frequency. These clips are randomly chosen and cut from a corpus of noise files to lengths according to a parameterized Gaussian distribution.
54
+
55
+ **Usage**: `sporadicnoiseaugment` is an `AudioWaveformTransform` and has seven configurable hyperparameters:
56
+ - `samples_path`: path where background noise files are saved as audios (.wav files). No default.
57
+ - `rate`: probability that any single access will result in the positive (add a sporadic noise track) case. Defaults to 0.25.
58
+ - `snr_min`: lower endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 5.
59
+ - `snr_max`: higher endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 15.
60
+ - `noise_rate`: rate in noises per second at which noise clip will be added to the original sample
61
+ - `noise_len_mean`: mean of Gaussian normal distribution from which length of noise clip is chosen
62
+ - `noise_len_std`: standard deviation of Gaussian normal distribution from which length of noise clip is chosen
63
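+
+ A rough sketch of how such a mostly-silent noise track could be assembled (illustrative only; the defaults below and the assumption that clip lengths are given in seconds are ours, not the library's):
+ ```python
+ import numpy as np
+
+ def sporadic_noise_track(num_samples, sample_rate, noise_pool, noise_rate=0.25,
+                          noise_len_mean=0.2, noise_len_std=0.05, rng=None):
+     """Place short noise clips on an otherwise silent track at ~noise_rate clips/sec."""
+     rng = rng or np.random.default_rng()
+     track = np.zeros(num_samples)
+     n_clips = max(1, int(noise_rate * num_samples / sample_rate))
+     for _ in range(n_clips):
+         clip_len = max(1, int(rng.normal(noise_len_mean, noise_len_std) * sample_rate))
+         src = noise_pool[rng.integers(len(noise_pool))]
+         clip = src[: min(clip_len, num_samples)]
+         start = rng.integers(0, max(1, num_samples - len(clip)))
+         track[start : start + len(clip)] += clip
+     return track  # mixed into the source at a sampled SNR, as in the other methods
+ ```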
+
64
+ #### 2.4. Background noise augmentation
65
+ For background noise augmentation, the noise track is a single track uniformly randomly selected from a corpus of noise files. The noise file is cut to size, including being repeated to fill the original sample length if necessary.
66
+
67
+ **Usage**: `backgroundnoiseaugment` is an `AudioWaveformTransform` and has four configurable hyperparameters:
68
+ - `samples_path`: path where background noise files are saved as audios (.wav files). No default.
69
+ - `rate`: probability that any single access will result in the positive (background noise) case. Defaults to 0.25.
70
+ - `snr_min`: lower endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 5.
71
+ - `snr_max`: higher endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 15.
72
+
73
+ ### 3. Mixed babble and background noise augmentation with recognizable source speaker
74
+
75
+ This augmentation technique is based on Algorithm 1 in [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) and is similar to the noise augmentation suite techniques in that it has a background noise track. The noise track consists of either (1) another audio sample from the batch or (2) a background noise track. A key difference is that the length of the noise track is chosen from a uniform random distribution between 0 and half of the original sample length.
76
+
77
+ **Usage**: `noisyoverlapaugment` is an `AudioDatasetTransform` and has seven configurable hyperparameters:
78
+ - `noises_path`: path where background noise files are saved as audios (.wav files). No default.
79
+ - `rate`: probability that any single access will result in the positive (background noise) case. Defaults to 0.25.
80
+ - `mixing_noise_rate`: probability that in a positive (background noise) case, the noise track will consist of background noise (rather than babble from the batch). Defaults to 0.1.
81
+ - `noise_snr_min`: lower endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to -5.
82
+ - `noise_snr_max`: higher endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 5.
83
+ - `utterance_snr_min`: lower endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add **another audio from the batch** to the original source. Defaults to -5.
84
+ - `utterance_snr_max`: higher endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add **another audio from the batch** to the original source. Defaults to 5.
85
+
86
+ ## Benchmark studies
87
+ ### Evaluation on clean data
88
+ Augmentation in training data|Hyperparameters|Training loss|BLEU (covost)|BLEU (epst)|BLEU (mtedx)
89
+ ---|---|---|---|---|---
90
+ None||3.954|24.984|23.962|24.448
91
+ ConcatAugment|rate = 0.25, max_tokens = 3000, attempts = 5|3.940|25.322|26.124|26.19
92
+ BabbleAugment|rate = 0.25, MUSAN speech, snr_min = (-5), snr_max = 5|3.957|24.226|23.186|22.368|
93
+ BackgroundNoiseAugment|rate = 0.1, MUSAN noises, snr_min = (-10), snr_max = 10|3.955|24.745|23.513|23.819
94
+ MusicAugment|rate = 0.25, MUSAN music, snr_min = 0, snr_max = 20|3.954|25.096|24.301|23.341|
95
+ SporadicNoiseAugment|rate = 0.1, noise_rate = 0.25, MUSAN noises, snr_min = 10, snr_max = 35|3.954|24.924|23.951|23.484|
96
+ MusicAugment + BabbleAugment + BackgroundNoiseAugment + SporadicNoiseAugment|as above, except limited rates to sum to 0.25: music (0.074), background (0.029), babble (0.074), sporadic (0.029)|3.953|24.874|23.675|24.249|
97
+ NoisyOverlapAugment|rate = 0.25, mixing_noise_rate = 0.5, MUSAN noises, utterance_snr_min = (-10), utterance_snr_max = 0, noise_snr_min = (-5), noise_snr_max = 20|3.954|24.949|24.015|23.768|
98
+
99
+ ### Evaluation on data with music noise added at SNR = (-5) - 5
100
+ Augmentation in training data|Training loss|BLEU (covost)|BLEU (epst)|BLEU (mtedx)
101
+ ---|---|---|---|---
102
+ None|3.954|15.785|21.105|16.944
103
+ ConcatAugment|3.940|17.186|23.255|18.24
104
+ BabbleAugment|3.957|19.158|22.064|17.116
105
+ BackgroundNoiseAugment|3.955|17.777|22.0|17.535|
106
+ MusicAugment|3.954|20.345|23.126|19.433|
107
+ SporadicNoiseAugment|3.954|15.927|21.382|14.736|
108
+ MusicAugment + BabbleAugment + BackgroundNoiseAugment + SporadicNoiseAugment|3.953|19.724|22.659|17.852|
109
+ NoisyOverlapAugment|3.954|17.49|22.142|17.207|
110
+
111
+ ### Evaluation on data with babble noise added at SNR = (-5) - 5
112
+ Augmentation in training data|Training loss|BLEU (covost)|BLEU (epst)|BLEU (mtedx)
113
+ ---|---|---|---|---
114
+ None|3.954|4.092|13.514|5.13
115
+ ConcatAugment|3.940|5.493|15.835|6.893
116
+ BabbleAugment|3.957|16.12|21.097|13.996
117
+ BackgroundNoiseAugment|3.955|4.691|15.784|5.982
118
+ MusicAugment|3.954|8.06|17.764|9.008
119
+ SporadicNoiseAugment|3.954|4.009|13.935|4.814
120
+ MusicAugment + BabbleAugment + BackgroundNoiseAugment + SporadicNoiseAugment|3.953|14.692|20.882|14.45
121
+ NoisyOverlapAugment|3.954|4.032|16.434|7.284
122
+
123
+ ### Evaluation on data with sporadic noise added at SNR = (-5) - 5
124
+ Augmentation in training data|Training loss|BLEU (covost)|BLEU (epst)|BLEU (mtedx)
125
+ ---|---|---|---|---
126
+ None|3.954|23.778|23.745|22.748
127
+ ConcatAugment|3.940|24.239|25.907|25.723
128
+ BabbleAugment|3.957|23.42|23.048|21.076
129
+ BackgroundNoiseAugment|3.955|23.998|23.467|22.494
130
+ MusicAugment|3.954|24.142|24.181|19.143
131
+ SporadicNoiseAugment|3.954|23.97|23.894|22.61
132
+ MusicAugment + BabbleAugment + BackgroundNoiseAugment + SporadicNoiseAugment|3.953|24.118|23.59|23.717
133
+ NoisyOverlapAugment|3.954|24.265|24.103|23.167
134
+
135
+ ### Evaluation on data with background noise added at SNR = (-5) - 5
136
+ Augmentation in training data|Training loss|BLEU (covost)|BLEU (epst)|BLEU (mtedx)
137
+ ---|---|---|---|---
138
+ None|3.954|20.201|22.525|19.66
139
+ ConcatAugment|3.940|20.904|24.706|21.353
140
+ BabbleAugment|3.957|20.687|22.374|18.907
141
+ BackgroundNoiseAugment|3.955|21.574|22.998|20.043
142
+ MusicAugment|3.954|21.65|23.529|19.87
143
+ SporadicNoiseAugment|3.954|20.578|22.577|19.096
144
+ MusicAugment + BabbleAugment + BackgroundNoiseAugment + SporadicNoiseAugment|3.953|21.811|23.144|20.986
145
+ NoisyOverlapAugment|3.954|21.312|23.153|20.302
146
+
147
+ ### Evaluation on data with all four types of noises added at SNR = (-5) - 5, each applied with prob 0.5
148
+ Augmentation in training data|Training loss|BLEU (covost)|BLEU (epst)|BLEU (mtedx)
149
+ ---|---|---|---|---
150
+ None|3.954|10.895|19.319|12.748
151
+ ConcatAugment|3.940|13.517|21.658|15.428
152
+ BabbleAugment|3.957|18.09|21.384|16.018
153
+ BackgroundNoiseAugment|3.955|12.837|20.719|13.933
154
+ MusicAugment|3.954|16.589|21.823|15.927
155
+ SporadicNoiseAugment|3.954|11.238|19.91|13.31
156
+ MusicAugment + BabbleAugment + BackgroundNoiseAugment + SporadicNoiseAugment|3.953|18.636|21.935|17.845
157
+ NoisyOverlapAugment|3.954|12.829|20.856|15.048
158
+
159
+ ### Evaluation on data with noisy overlap augment
160
+ Augmentation in training data|Training loss|BLEU (covost)|BLEU (epst)|BLEU (mtedx)
161
+ ---|---|---|---|---
162
+ None|3.954|21.245|22.24|20.994
163
+ ConcatAugment|3.940|21.611|24.247|23.068
164
+ BabbleAugment|3.957|21.867|21.987|20.099|
165
+ BackgroundNoiseAugment|3.955|21.533|21.806|19.717|
166
+ MusicAugment|3.954|21.823|22.643|20.847|
167
+ SporadicNoiseAugment|3.954|21.373|22.381|20.672|
168
+ MusicAugment + BabbleAugment + BackgroundNoiseAugment + SporadicNoiseAugment|3.953|22.206|22.414|21.375|
169
+ NoisyOverlapAugment|3.954|23.371|23.396|22.627|
170
+
171
+ ## Using transforms
172
+ Transforms are configurable.
173
+
174
+ 1. Please pay careful attention to the type of transform you are applying.
175
+ - `concataugment` and `noisyoverlapaugment` are instances of `AudioDatasetTransform` and should be listed in the config under `dataset_transforms`.
176
+ - `musicaugment`, `babbleaugment`, `sporadicnoiseaugment`, and `backgroundnoiseaugment` are instances of `AudioWaveformTransform` and should be listed under `waveform_transforms`.
177
+ - Instances of `AudioFeatureTransform` should be listed under `feature_transforms`.
178
+ 2. Feel free to apply these augmentations in different contexts, e.g., you may use a `_train` or `_eval` flag to specify when the transform will be applied. If the dataset at hand contains `train` in its name, those transforms under the `_train` flag will be applied; else, the remaining transforms will be applied.
179
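+
+ The split-based selection described in point 2 amounts to something like the following (an illustrative sketch, not the exact fairseq code):
+ ```python
+ def pick_transforms(section_cfg: dict, split_name: str) -> list:
+     """Return the transform names to apply for a given data split."""
+     key = "_train" if "train" in split_name else "_eval"
+     return section_cfg.get(key, [])
+
+ # e.g. pick_transforms({"_train": ["musicaugment"]}, "train_st") -> ["musicaugment"]
+ ```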
+
180
+ For example, you would add this to your config to apply the musicaugment transform to a training dataset:
181
+ ```yaml
182
+ musicaugment:
183
+ samples_path: ${MUSIC_PATH}
184
+ snr_min: 10
185
+ snr_max: 15
186
+ rate: 0.25
187
+ waveform_transforms:
188
+ _train:
189
+ - musicaugment
190
+ ```
191
+ or add this to apply the concataugment transform:
192
+ ```yaml
193
+ concataugment:
194
+ rate: 0.25
195
+ max_tokens: 3000
196
+ attempts: 5
197
+ dataset_transforms:
198
+ _train:
199
+ - concataugment
200
+ ```
201
+ You may also want to add multiple of one type of transform; here, we add multiple `AudioWaveformTransform`s:
202
+ ```yaml
203
+ musicaugment:
204
+ samples_path: ${MUSIC_PATH}
205
+ snr_min: 5
206
+ snr_max: 20
207
+ rate: 0.25
208
+ backgroundnoiseaugment:
209
+ samples_path: ${NOISES_PATH}
210
+ snr_min: 10
211
+ snr_max: 20
212
+ rate: 0.1
213
+ sporadicnoiseaugment:
214
+ samples_path: ${NOISES_PATH}
215
+ snr_min: 5
216
+ snr_max: 15
217
+ rate: 0.1
218
+ noise_rate: 0.25
219
+ waveform_transforms:
220
+ _train:
221
+ - musicaugment
222
+ - backgroundnoiseaugment
223
+ - sporadicnoiseaugment
224
+ ```
225
+
226
+ ## Adding your own transforms
227
+ Note: We store transform implementations in `fairseq/data/audio/*_transforms` directories. You may refer to these as examples while implementing your own transform.
228
+
229
+ ### Step 1. Picking the right class for your transform
230
+ The integration into `SpeechToTextDataset` is quite different for each kind of transform, so it is important to understand which one is best suited to your purposes.
231
+
232
+ **Feature transforms**
233
+ `AudioFeatureTransform` is a base class which allows **some transform to be applied to audio spectrograms** in the data loading step. One thing to note is that the source data is either saved as `np.ndarrays` or as audio files, and is to be returned either as features (spectrogram) or waveform. If and only if the data is to be returned as a spectrogram, then `AudioFeatureTransform`s will be applied.
234
+
235
+ **Waveform transforms**
236
+ `AudioWaveformTransform` is a base class which allows some **transform to be applied to waveforms** in the data loading step. As mentioned above, there are two source and return types to data loading for this dataset. If and only if the data is saved in audio file format, then `AudioWaveformTransform`s will be applied, whichever return type is used.
237
+
238
+ **Dataset transforms**
239
+ `AudioDatasetTransform` is a base class for transforms **based on more than one item in a dataset**, ex. concatenation of two random samples in a dataset. Rather than being applied in a consistent way, i.e., to all features or to all waveforms, the integration of a dataset transform is entirely specific. Adding a dataset transform requires actually editing the `fairseq/data/audio/speech_to_text_dataset.py` file.
240
+
241
+ ### Step 2. Setting up your transform (generic to all types of transforms)
242
+ Now that you know which kind of transform you would like to use, we are ready to implement it. This step is generic for all transform types, i.e., `TRANSFORM_TYPE` may be any of `feature`, `waveform`, or `dataset`. We will show how to build utterance concatenation (an `AudioDatasetTransform`) as an example.
243
+
244
+ Import the base class and registration function for your transform.
245
+ ```python
246
+ from fairseq.data.audio.dataset_transforms import (
247
+ AudioDatasetTransform,
248
+ register_audio_dataset_transform
249
+ )
250
+ ```
251
+
252
+ Define the class and register the transform. The name passed into the registration function is how your transform should be named in the config.
253
+ ```python
254
+ @register_audio_dataset_transform("concataugment")
255
+ class ConcatAugment(AudioDatasetTransform):
256
+ ```
257
+
258
+ We are now ready to add the basic important functions to our new class. In this example, `_DEFAULTS` refers to a dictionary with the default hyperparameter values that we defined. `from_config_dict` is called to instantiate the transform given hyperparameters from the config.
259
+ ```python
260
+ @classmethod
261
+ def from_config_dict(cls, config=None):
262
+ _config = {} if config is None else config
263
+ return ConcatAugment(
264
+ _config.get("rate", _DEFAULTS["rate"]),
265
+ _config.get("max_tokens", _DEFAULTS["max_tokens"]),
266
+ _config.get("attempts", _DEFAULTS["attempts"]),
267
+ )
268
+ ```
269
+ We edit the instantiation function `__init__` to track hyperparameters and do any setup work.
270
+ ```python
271
+ def __init__(
272
+ self,
273
+ rate=_DEFAULTS["rate"],
274
+ max_tokens=_DEFAULTS["max_tokens"],
275
+ attempts=_DEFAULTS["attempts"],
276
+ ):
277
+ self.rate, self.max_tokens, self.attempts = rate, max_tokens, attempts
278
+ ```
279
+ Lastly `__repr__` gives how the transform will be reported in an output log.
280
+ ```python
281
+ def __repr__(self):
282
+ return (
283
+ self.__class__.__name__
284
+ + "("
285
+ + ", ".join(
286
+ [
287
+ f"rate={self.rate}",
288
+ f"max_tokens={self.max_tokens}",
289
+ f"attempts={self.attempts}",
290
+ ]
291
+ )
292
+ + ")"
293
+ )
294
+ ```
295
+
296
+ ### Step 3. Adding the transform logic
297
+ At this point, we are ready to implement the actual transform logic. The flow from here is different for each of the three transforms, so follow the path that is relevant to you.
298
+ ### ...for feature transforms
299
+ The final step is implementing the `__call__` function, which applies the transform logic and **returns** the spectrogram with the transform applied. It should take exactly **two arguments**:
300
+ - `self`
301
+ - `x` (np.ndarray): the spectrogram for one source sample. (This is a positional argument, so you can use another parameter name like `spectrogram` instead of `x`.)
302
+
303
+ For example, this is the `__call__` function for GlobalCMVN (cepstral mean and variance normalization).
304
+ ```python
305
+ def __call__(self, x):
306
+ x = np.subtract(x, self.mean)
307
+ x = np.divide(x, self.std)
308
+ return x
309
+
310
+ ```
311
+ ### ...for waveform transforms
312
+ The final step is implementing the `__call__` function, which applies the transform logic. It should take exactly **three arguments**:
313
+ - `self`
314
+ - `source` (numpy.ndarray or torch.Tensor): source audio 2d waveform (channels x length)
315
+ - `sample_rate` (optional, defaults to None): sample rate of `source`
316
+
317
+ `__call__` **returns**:
318
+ - transformed audio waveform
319
+ - sample rate of transformed audio waveform
320
+
321
+ For example, this is the `__call__` function for augmentations in the Noise Augmentation Suite.
322
+ ```python
323
+ def __call__(self, source, sample_rate=None):
324
+ if np.random.random() > self.rate:
325
+ return source, sample_rate
326
+
327
+ noise = self._get_noise(
328
+ source.shape, always_2d=True, use_sample_rate=sample_rate
329
+ )
330
+ return self._mix(source, noise, rand_uniform(self.snr_min, self.snr_max)), sample_rate
331
+ ```
332
+
333
+ ### ...for dataset transforms
334
+ Dataset transforms are extremely flexible, and implementation involves directly integrating them into `fairseq/data/audio/speech_to_text_dataset.py` in transform-specific ways.
335
+ There are two basic components: (1) check whether or not this transform is part of this dataset instance using `self.dataset_transforms.has_transform(TRANSFORM_CLS)`, and (2) if so, get the transform using `self.dataset_transforms.get_transform(TRANSFORM_CLS)` & apply it.
336
+ Due to the case-by-case specificity, it is easier to demonstrate this by examples.
337
+
338
+ #### Example: NoisyOverlapAugment
339
+ This transform requires access to multiple items within the same batch at once.
340
+
341
+ **Logic**: We still use the transform classes to keep away the transform logic. For example, `__call__` of `NoisyOverlapAugment` class takes a list of source tokens for items in a mini-batch, applies noise/utterance as dictated by the transform, and returns the list of transformed source tokens for items in the mini-batch.
342
+
343
+ ```python
344
+ def __call__(self, sources):
345
+ for i, source in enumerate(sources):
346
+ if np.random.random() > self.rate:
347
+ continue
348
+
349
+ pri = source.numpy()
350
+
351
+ # ... some transform code omitted
352
+
353
+ pri[s_source : s_source + l] = np.add(
354
+ pri[s_source : s_source + l], np.multiply(scl, sec[s_sec : s_sec + l])
355
+ )
356
+ sources[i] = torch.from_numpy(pri).float()
357
+
358
+ return sources
359
+ ```
360
+
361
+ **Integration**: The `collater` function for `SpeechToTextDataset` is responsible for preparing a mini-batch for training, so we integrate NOAug through adding a few lines to the top of this function:
362
+ ```python
363
+ def collater(
364
+ self, samples: List[SpeechToTextDatasetItem], return_order: bool = False
365
+ ) -> Dict:
366
+ if len(samples) == 0:
367
+ return {}
368
+ indices = torch.tensor([x.index for x in samples], dtype=torch.long)
369
+
370
+ sources = [x.source for x in samples]
371
+
372
+ # NOAUG INTEGRATION BLOCK
373
+ # (1) Check whether or not this transform is part of this dataset instance
374
+ has_NOAug = self.dataset_transforms.has_transform(NoisyOverlapAugment)
375
+ # (2) If so, get & apply the transform
376
+ if has_NOAug and self.cfg.use_audio_input:
377
+ NOAug = self.dataset_transforms.get_transform(NoisyOverlapAugment)
378
+ sources = NOAug(sources)
379
+
380
+ frames = _collate_frames(sources, self.cfg.use_audio_input)
381
+ # sort samples by descending number of frames
382
+ n_frames = torch.tensor([x.size(0) for x in sources], dtype=torch.long)
383
+ n_frames, order = n_frames.sort(descending=True)
384
+ indices = indices.index_select(0, order)
385
+ frames = frames.index_select(0, order)
386
+
387
+ # ... rest of function
388
+ ```
389
+
390
+ #### Example: ConcatAugment
391
+ This transform requires access to another item within the dataset at once.
392
+
393
+ **Logic**: We abstract the logic for picking indices to concatenate by adding a `find_indices` function to the `ConcatAugment` class, which takes one index in the dataset and finds a compatible second index to concatenate source and target tokens.
394
+ ```python
395
+ def find_indices(self, index: int, n_frames: List[int], n_samples: int):
396
+ # skip conditions: application rate, max_tokens limit exceeded
397
+ if np.random.random() > self.rate:
398
+ return [index]
399
+ if self.max_tokens and n_frames[index] > self.max_tokens:
400
+ return [index]
401
+
402
+ # pick second sample to concatenate
403
+ for _ in range(self.attempts):
404
+ index2 = np.random.randint(0, n_samples)
405
+ if index2 != index and (
406
+ not self.max_tokens
407
+ or n_frames[index] + n_frames[index2] < self.max_tokens
408
+ ):
409
+ return [index, index2]
410
+
411
+ return [index]
412
+ ```
413
+
414
+ **Integration**: `SpeechToTextDataset` uses a custom `__getitem__(self, index)` function (called in the background when you write `dataset[i]`). We edited this function (as well as `_get_source_audio` and `get_tokenized_tgt_text`) to achieve the desired transform effect where accessing `dataset[i]` will return a `SpeechToTextDatasetItem` where `source=source[i]+source[j]` and `target=target[i]+target[j]`.
415
+ ```python
416
+ def __getitem__(self, index: int) -> SpeechToTextDatasetItem:
417
+
418
+ # CONCATAUGMENT INTEGRATION BLOCK
419
+ # (1) Check whether or not this transform is part of this dataset instance
420
+ has_concat = self.dataset_transforms.has_transform(ConcatAugment)
421
+ # (2) If so, get & apply the transform
422
+ if has_concat:
423
+ concat = self.dataset_transforms.get_transform(ConcatAugment)
424
+ indices = concat.find_indices(index, self.n_frames, self.n_samples)
425
+
426
+ source = self._get_source_audio(indices if has_concat else index)
427
+ source = self.pack_frames(source)
428
+
429
+ target = None
430
+ if self.tgt_texts is not None:
431
+ tokenized = self.get_tokenized_tgt_text(indices if has_concat else index)
432
+ target = self.tgt_dict.encode_line(
433
+
434
+ # ... rest of function
435
+ ```
fairseq/examples/speech_to_speech/docs/direct_s2st_discrete_units.md ADDED
@@ -0,0 +1,181 @@
1
+ # Direct speech-to-speech translation with discrete units
2
+
3
+ We provide the implementation for speech-to-unit translation (S2UT) proposed in "[Direct speech-to-speech translation with discrete units (Lee et al. 2021)](https://arxiv.org/abs/2107.05604)" and also the transformer-based implementation of the speech-to-spectrogram translation (S2SPECT, or transformer-based [Translatotron](https://arxiv.org/abs/1904.06037)) baseline in the paper.
4
+
5
+ ## Pretrained Models
6
+
7
+ ### Unit-based HiFi-GAN Vocoder
8
+ Unit config | Unit size | Vocoder dataset | Model
9
+ |---|---|---|---
10
+ [HuBERT Base, Librispeech](https://github.com/fairinternal/fairseq-py/tree/main/examples/hubert), layer 6 | 100 | [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) | [ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/hubert_base_100_lj/g_00500000), [config](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/hubert_base_100_lj/config.json)
11
+
12
+
13
+ ## Data preparation
14
+ ### Target speech
15
+ 0. (optional) To prepare S2S data from a speech-to-text translation (ST) dataset, see [fairseq-S^2](https://github.com/pytorch/fairseq/tree/main/examples/speech_synthesis) for pre-trained TTS models and instructions on how to train and decode TTS models.
16
+ 1. Prepare two folders, `$SRC_AUDIO` and `$TGT_AUDIO`, with `${SPLIT}/${SAMPLE_ID}.wav` for source and target speech under each folder, separately. Note that for S2UT experiments, the target audio sampling rate should be 16,000 Hz, and for S2SPECT experiments, the target audio sampling rate is recommended to be 22,050 Hz.
17
+ 2. To prepare target discrete units for S2UT model training, see [Generative Spoken Language Modeling (speech2unit)](https://github.com/pytorch/fairseq/tree/main/examples/textless_nlp/gslm/speech2unit) for pre-trained k-means models, checkpoints, and instructions on how to decode units from speech. Set the output target unit files (`--out_quantized_file_path`) as `${TGT_AUDIO}/${SPLIT}.txt`. In [Lee et al. 2021](https://arxiv.org/abs/2107.05604), we use 100 units from the sixth layer (`--layer 6`) of the HuBERT Base model.
18
+
19
+ ### Formatting data
20
+ **Speech-to-speech data**
21
+
22
+ _S2UT_
23
+ * Set `--reduce-unit` for training S2UT _reduced_ model
24
+ * Pre-trained vocoder and config (`$VOCODER_CKPT`, `$VOCODER_CFG`) can be downloaded from the **Pretrained Models** section. They are not required if `--eval-inference` is not going to be set during model training.
25
+ ```
26
+ # $SPLIT1, $SPLIT2, etc. are split names such as train, dev, test, etc.
27
+
28
+ python examples/speech_to_speech/preprocessing/prep_s2ut_data.py \
29
+ --source-dir $SRC_AUDIO --target-dir $TGT_AUDIO --data-split $SPLIT1 $SPLIT2 \
30
+ --output-root $DATA_ROOT --reduce-unit \
31
+ --vocoder-checkpoint $VOCODER_CKPT --vocoder-cfg $VOCODER_CFG
32
+ ```
33
+
34
+ _S2SPECT_
35
+ ```
36
+ # $SPLIT1, $SPLIT2, etc. are split names such as train, dev, test, etc.
37
+
38
+ python examples/speech_to_speech/preprocessing/prep_s2spect_data.py \
39
+ --source-dir $SRC_AUDIO --target-dir $TGT_AUDIO --data-split $SPLIT1 $SPLIT2 \
40
+ --output-root $DATA_ROOT
41
+ ```
42
+
43
+ **Multitask data**
44
+ * For each multitask `$TASK_NAME`, prepare `${DATA_ROOT}/${TASK_NAME}/${SPLIT}.tsv` files for each split following the format below: (Two tab separated columns. The sample_ids should match with the sample_ids for the speech-to-speech data in `${DATA_ROOT}/${SPLIT}.tsv`.)
45
+ ```
46
+ id tgt_text
47
+ sample_id_0 token1 token2 token3 ...
48
+ sample_id_1 token1 token2 token3 ...
49
+ ...
50
+ ```
51
+ * For each multitask `$TASK_NAME`, prepare `${DATA_ROOT}/${TASK_NAME}/dict.txt`, a dictionary in fairseq format with all tokens for the targets for `$TASK_NAME`.
52
+ * Create `config_multitask.yaml`. Below is an example of the config used for S2UT _reduced_ with Fisher experiments including two encoder multitasks (`source_letter`, `target_letter`) and one decoder CTC task (`decoder_target_ctc`).
53
+ ```
54
+ source_letter: # $TASK_NAME
55
+ decoder_type: transformer
56
+ dict: ${DATA_ROOT}/source_letter/dict.txt
57
+ data: ${DATA_ROOT}/source_letter
58
+ encoder_layer: 6
59
+ loss_weight: 8.0
60
+ target_letter:
61
+ decoder_type: transformer
62
+ dict: ${DATA_ROOT}/target_letter/dict.txt
63
+ data: ${DATA_ROOT}/target_letter
64
+ encoder_layer: 8
65
+ loss_weight: 8.0
66
+ decoder_target_ctc:
67
+ decoder_type: ctc
68
+ dict: ${DATA_ROOT}/decoder_target_ctc/dict.txt
69
+ data: ${DATA_ROOT}/decoder_target_ctc
70
+ decoder_layer: 3
71
+ loss_weight: 1.6
72
+ ```
73
+
74
+
75
+ ## Training
76
+
77
+ **Speech-to-unit translation (S2UT)**
78
+
79
+ Here's an example for training Fisher S2UT models with 100 discrete units as target:
80
+ ```
81
+ fairseq-train $DATA_ROOT \
82
+ --config-yaml config.yaml --multitask-config-yaml config_multitask.yaml \
83
+ --task speech_to_speech --target-is-code --target-code-size 100 --vocoder code_hifigan \
84
+ --criterion speech_to_unit --label-smoothing 0.2 \
85
+ --arch s2ut_transformer_fisher --share-decoder-input-output-embed \
86
+ --dropout 0.1 --attention-dropout 0.1 --relu-dropout 0.1 \
87
+ --train-subset train --valid-subset dev \
88
+ --save-dir ${MODEL_DIR} \
89
+ --lr 0.0005 --lr-scheduler inverse_sqrt --warmup-init-lr 1e-7 --warmup-updates 10000 \
90
+ --optimizer adam --adam-betas "(0.9,0.98)" --clip-norm 10.0 \
91
+ --max-update 400000 --max-tokens 20000 --max-target-positions 3000 --update-freq 4 \
92
+ --seed 1 --fp16 --num-workers 8
93
+ ```
94
+ * Adjust `--update-freq` accordingly for different #GPUs. In the above we set `--update-freq 4` to simulate training with 4 GPUs.
95
+ * Set `--n-frames-per-step 5` to train an S2UT _stacked_ system with reduction ratio r=5. (Use `$DATA_ROOT` prepared without `--reduce-unit`.)
96
+ * (optional) one can turn on tracking MCD loss during training for checkpoint selection by setting `--eval-inference --eval-args '{"beam": 1, "max_len_a": 1}' --best-checkpoint-metric mcd_loss`. It is recommended to sample a smaller subset as the validation set as MCD loss computation is time-consuming.
97
+
98
+ **Speech-to-spectrogram translation (S2SPECT)**
99
+
100
+ Here's an example for training Fisher S2SPECT models with reduction ratio r=5:
101
+ ```
102
+ fairseq-train $DATA_ROOT \
103
+ --config-yaml config.yaml --multitask-config-yaml config_multitask.yaml \
104
+ --task speech_to_speech --n-frames-per-step 5 \
105
+ --criterion speech_to_spectrogram \
106
+ --arch s2spect_transformer_fisher --decoder-normalize-before \
107
+ --dropout 0.1 --attention-dropout 0.1 --relu-dropout 0.1 \
108
+ --train-subset train --valid-subset dev \
109
+ --save-dir ${MODEL_DIR} \
110
+ --eval-inference --best-checkpoint-metric mcd_loss \
111
+ --lr 0.0005 --lr-scheduler inverse_sqrt --warmup-init-lr 1e-7 --warmup-updates 10000 \
112
+ --optimizer adam --adam-betas "(0.9,0.98)" --clip-norm 10.0 --weight-decay 1e-6 \
113
+ --max-update 400000 --max-tokens 80000 --max-tokens-valid 30000 --required-batch-size-multiple 1 \
114
+ --max-target-positions 3000 --update-freq 16 \
115
+ --seed 1 --fp16 --num-workers 8
116
+ ```
117
+ * Adjust `--update-freq` accordingly for different #GPUs. In the above we set `--update-freq 16` to simulate training with 16 GPUs.
118
+ * We recommend turning on MCD loss during training for the best checkpoint selection.
119
+
120
+ **Unit-based HiFi-GAN vocoder**
121
+
122
+ The vocoder is trained with the [speech-resynthesis repo](https://github.com/facebookresearch/speech-resynthesis). See [here](https://github.com/facebookresearch/speech-resynthesis/tree/main/examples/speech_to_speech_translation) for instructions on how to train the unit-based HiFi-GAN vocoder with duration prediction. The same vocoder can support waveform generation for both _reduced_ unit sequences (with `--dur-prediction` set during inference) and original unit sequences.
123
+
124
+ ## Inference
125
+
126
+ **Speech-to-unit translation (S2UT)**
127
+
128
+ 1. Follow the same inference process as in [fairseq-S2T](https://github.com/pytorch/fairseq/tree/main/examples/speech_to_text) to generate unit sequences (`${RESULTS_PATH}/generate-${GEN_SUBSET}.txt`).
129
+ ```
130
+ fairseq-generate $DATA_ROOT \
131
+ --config-yaml config.yaml --multitask-config-yaml config_multitask.yaml \
132
+ --task speech_to_speech --target-is-code --target-code-size 100 --vocoder code_hifigan \
133
+ --path $MODEL_DIR/checkpoint_best.pt --gen-subset $GEN_SUBSET \
134
+ --max-tokens 50000 \
135
+ --beam 10 --max-len-a 1 \
136
+ --results-path ${RESULTS_PATH}
137
+ ```
138
+ * Set `--beam 1 --n-frames-per-step $r` for decoding with S2UT _stacked_ models.
139
+
140
+ 2. Convert unit sequences to waveform.
141
+ ```
142
+ grep "^D\-" ${RESULTS_PATH}/generate-${GEN_SUBSET}.txt | \
143
+ sed 's/^D-//ig' | sort -nk1 | cut -f3 \
144
+ > ${RESULTS_PATH}/generate-${GEN_SUBSET}.unit
145
+
146
+ python examples/speech_to_speech/generate_waveform_from_code.py \
147
+ --in-code-file ${RESULTS_PATH}/generate-${GEN_SUBSET}.unit \
148
+ --vocoder $VOCODER_CKPT --vocoder-cfg $VOCODER_CFG \
149
+ --results-path ${RESULTS_PATH} --dur-prediction
150
+ ```
151
+ * Set `--dur-prediction` for generating audio for S2UT _reduced_ models.
152
+
153
+
154
+ **Speech-to-spectrogram translation (S2SPECT)**
155
+
156
+ Follow the same inference process as in [fairseq-S^2](https://github.com/pytorch/fairseq/tree/main/examples/speech_synthesis) to generate waveform.
157
+
158
+ ```
159
+ # assume using a default Griffin-Lim vocoder
160
+
161
+ python examples/speech_synthesis/generate_waveform.py $DATA_ROOT \
162
+ --config-yaml config.yaml --multitask-config-yaml config_multitask.yaml \
163
+ --task speech_to_speech --n-frames-per-step 5 \
164
+ --path $MODEL_DIR/checkpoint_best.pt --gen-subset $GEN_SUBSET \
165
+ --max-tokens 50000 \
166
+ --results-path ${RESULTS_PATH} --dump-waveforms --output-sample-rate 16000
167
+ ```
168
+
169
+ In addition to using the default Griffin-Lim vocoder, one can also finetune a HiFi-GAN vocoder for the S2SPECT model by following the instructions in the [HiFi-GAN repo](https://github.com/jik876/hifi-gan).
170
+
171
+ **Multitask decoding**
172
+
173
+ Coming soon.
174
+
175
+ ## Evaluation
176
+
177
+ To evaluate speech translation output, we first apply ASR to the speech output and then compute the BLEU score between the ASR-decoded text and the references using sacreBLEU.
178
+
179
+ **En**
180
+ * ASR: We use the "[Wav2Vec 2.0 Large (LV-60) + Self Training / 960 hours / Libri-Light + Librispeech](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_960h_pl.pt)" En ASR model open-sourced by the [wav2vec](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec) project. See [instructions](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec#evaluating-a-ctc-model) on how to run inference with a wav2vec-based ASR model. The model is also available on [Hugging Face](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self).
181
+ * Text normalization: We use the text cleaner at [https://github.com/keithito/tacotron](https://github.com/keithito/tacotron) for pre-processing reference English text for ASR BLEU evaluation.
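+
+ A minimal sketch of the final scoring step, assuming `asr_hyps` holds the (normalized) ASR transcripts of the generated speech and `refs` the corresponding reference translations:
+ ```python
+ import sacrebleu  # pip install sacrebleu
+
+ asr_hyps = ["hello how are you", "see you tomorrow"]      # placeholder ASR outputs
+ refs = ["hello how are you", "see you tomorrow morning"]  # placeholder references
+
+ # sacreBLEU takes a list of hypotheses and a list of reference streams.
+ bleu = sacrebleu.corpus_bleu(asr_hyps, [refs])
+ print(bleu.score)
+ ```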
fairseq/examples/speech_to_speech/docs/enhanced_direct_s2st_discrete_units.md ADDED
@@ -0,0 +1,125 @@
1
+ # Speech to speech translation (S2ST)
2
+
3
+ We provide the implementation for speech-to-unit translation (S2UT) proposed in [Enhanced Direct Speech-to-Speech Translation Using Self-supervised Pre-training and Data Augmentation (Popuri et al. 2022)](https://arxiv.org/abs/2204.02967) and the various pretrained models used.
4
+
5
+ ## Pretrained Models
6
+
7
+ ### Unit extraction
8
+
9
+ We used the multilingual HuBERT model open sourced in [Textless S2ST with Real Data](textless_s2st_real_data.md)
10
+
11
+ ### Wav2vec 2.0
12
+
13
+ Language | Block type | Model size | Dataset | Model |
14
+ --- | --- | --- | --- | --- |
15
+ Es | Transformer | BASE | Voxpopuli | [ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/es/transformer_B.pt) |
16
+ Es | Transformer | LARGE | Voxpopuli | [ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/es/transformer_L.pt) |
17
+ Es | Conformer | LARGE | Voxpopuli | [ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/es/conformer_L.pt) |
18
+ En | Transformer | BASE | Librilight| [ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/en/transformer_B.pt) |
19
+ En | Conformer | LARGE | Librilight | [ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/en/conformer_L.pt) |
20
+
21
+ ### Unit mBART
22
+
23
+ Unit size | Dataset | Unit config | Model |
24
+ --- | --- | --- | --- |
25
+ 1000 | [Voxpopuli](https://aclanthology.org/2021.acl-long.80) En, Es unlabelled speech | [mbart_large](https://github.com/pytorch/fairseq/blob/f591cc94caa85098ccf125a4782f91125b6a086d/fairseq/models/bart/model.py#L368) |[ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/unit_mBART/checkpoint.pt) |
26
+
27
+ ## Data preparation
28
+
29
+ 1. To prepare data for S2UT finetuning, follow the steps from [Direct S2ST with Discrete Units](./direct_s2st_discrete_units.md) and format the data in the _S2UT_ format. Note that we use 1000 units from the eleventh layer (`--layer 11`) of the multilingual HuBERT model linked above, instead of the 100 units from layer 6 used in that document.
30
+ 2. Run the following to rewrite the TSV header so the column names match the manifest format expected by the `speech_to_text` task:
31
+
32
+ ```
33
+ var="id\taudio\tn_frames\ttgt_text\ttgt_n_frames"
34
+ sed -i "1s/.*/$var/" ${SPLIT}.tsv
35
+ ```
36
+
37
+ ## Training
38
+
39
+ **Speech-to-unit translation (S2UT)**
40
+
41
+ Here is an example of finetuning an S2UT model with 1000 discrete units as the target. Sample [config](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/config.yaml) and [vocabulary](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/dict.txt) files for Es-En are available for download:
42
+
43
+ ```
44
+ fairseq-train $DATA_ROOT \
45
+ --config-yaml config.yaml \
46
+ --task speech_to_text --arch xm_transformer\
47
+ --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
48
+ --share-decoder-input-output-embed --adaptor-n-layers 1 --normalize\
49
+ --dropout 0.1 --attention-dropout 0.1 --relu-dropout 0.1 \
50
+ --train-subset train --valid-subset dev \
51
+ --load-pretrained-decoder-from ${unit_mBART} --w2v-path ${wav2vec2.0} \
52
+ --mask-prob 0.3 --mask-channel-length 32 --mask-channel-prob 0.25\
53
+ --save-dir ${MODEL_DIR} --checkpoint-activations --encoder-proj \
54
+ --lr 0.0005 --lr-scheduler inverse_sqrt \
55
+ --warmup-init-lr 1e-7 --warmup-updates 10000 \
56
+ --optimizer adam --adam-betas "(0.9,0.98)" --clip-norm 10.0 \
57
+ --max-update 20000 --max-tokens 4000 --max-tokens-valid 4000 --max-source-positions 4000 \
58
+ --max-target-positions 4000 --update-freq 120 \
59
+ --seed 1 --fp16 --num-workers 1
60
+ ```
61
+
62
+ * Adjust `--update-freq` according to the number of GPUs available so that #GPUs x `--update-freq` matches the effective setup above (e.g. `--update-freq 120` on 1 GPU, or `--update-freq 15` on 8 GPUs, simulates training with 120 GPUs).
63
+ * In the above setting we finetune the model end to end, corresponding to the full setup in the paper.
64
+ * To apply LNA-E partial finetuning, add `--finetune-w2v-params layer_norm,self_attn`
65
+ * For LNA-D partial finetuning, add `--finetune-decoder-params encoder_attn,layer_norm,self_attn`. To optionally freeze the encoder for the first `${K}` updates, use `--freeze-finetune-updates ${K}`.
66
+ * For LNA-E,D partial finetuning add both the above options.
67
+
68
+ **Unit-based HiFi-GAN vocoder**
69
+
70
+ We apply the unit-based HiFi-GAN vocoders open-sourced in [Textless S2ST with Real Data](textless_s2st_real_data.md) to convert the predicted unit sequences to waveform.
71
+
72
+ ## Inference
73
+
74
+ **Speech-to-unit translation (S2UT)**
75
+
76
+ 1. Follow the same inference process as in [fairseq-S2T](https://github.com/pytorch/fairseq/tree/main/examples/speech_to_text) to generate unit sequences (`${RESULTS_PATH}/generate-${GEN_SUBSET}.txt`).
77
+
78
+ ```
79
+ fairseq-generate $DATA_ROOT \
80
+ --config-yaml config.yaml \
81
+ --task speech_to_text \
82
+ --path $MODEL_DIR/checkpoint_best.pt --gen-subset $GEN_SUBSET \
83
+ --max-tokens 10000 --max-source-positions 10000 --max-target-positions 10000\
84
+ --beam 10 --max-len-a 1 --max-len-b 200 \
85
+ --results-path ${RESULTS_PATH}
86
+ ```
87
+
88
+ 2. Convert unit sequences to waveform.
89
+
90
+ ```
91
+ grep "^D\-" ${RESULTS_PATH}/generate-${GEN_SUBSET}.txt | \
92
+ sed 's/^D-//ig' | sort -nk1 | cut -f3 \
93
+ > ${RESULTS_PATH}/generate-${GEN_SUBSET}.unit
94
+
95
+ python examples/speech_to_speech/generate_waveform_from_code.py \
96
+ --in-code-file ${RESULTS_PATH}/generate-${GEN_SUBSET}.unit \
97
+ --vocoder $VOCODER_CKPT --vocoder-cfg $VOCODER_CFG \
98
+ --results-path ${RESULTS_PATH} --dur-prediction
99
+ ```
100
+
101
+ ## Evaluation
102
+
103
+ To evaluate speech translation output, we first apply ASR to the speech output and then compute the BLEU score between the ASR-decoded text and the references using sacreBLEU.
104
+
105
+ * Text normalization: We use the text cleaner at [https://github.com/keithito/tacotron](https://github.com/keithito/tacotron) for pre-processing the reference English text for ASR BLEU evaluation (a minimal normalization sketch follows this list). The text cleaner used for Spanish text normalization will be added here shortly.
106
+ * En ASR: We use the "[Wav2Vec 2.0 Large (LV-60) + Self Training / 960 hours / Libri-Light + Librispeech](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_960h_pl.pt)" En ASR model open-sourced by the [wav2vec](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec) project. The model is also available on [Hugging Face](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self).
107
+ * Es ASR: We use the [Wav2Vec2-Large-XLSR-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) model finetuned on Spanish Common Voice, open-sourced by jonatasgrosman on [Hugging Face](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-spanish).
108
+ * See [instructions](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec#evaluating-a-ctc-model) on how to run inference with a wav2vec-based ASR model.
109
+
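+ As a rough sketch of the English reference pre-processing mentioned above, the snippet below applies the `english_cleaners` function from the keithito/tacotron repository to each reference line before scoring with sacreBLEU. It assumes that repository is cloned and its `text` package (plus its dependencies) is importable, and the file names are placeholders; the exact cleaning used for the reported numbers may differ slightly.
+
+ ```python
+ # Reference text normalization sketch using the keithito/tacotron cleaner.
+ # Assumes https://github.com/keithito/tacotron is on PYTHONPATH.
+ from text.cleaners import english_cleaners  # lowercases, expands numbers/abbreviations
+
+ with open("ref.en.txt") as f:            # raw English references (placeholder path)
+     normalized = [english_cleaners(line.strip()) for line in f]
+
+ with open("ref.norm.en.txt", "w") as f:  # normalized references, fed to sacreBLEU
+     f.write("\n".join(normalized) + "\n")
+ ```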
110
+
111
+ ## Finetuned Model Checkpoints
112
+
113
+ ID | En - Es | Es - En |
114
+ --- | --- | --- |
115
+ **S2UT systems without pre-training** | | |
116
+ S2UT with multitask | [checkpoint](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/en_es//S2UT_w_multitask.pt) | [checkpoint](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/es_en//S2UT_w_multitask.pt) |
117
+ **S2UT systems with model pre-training** | | |
118
+ w2v2-L | [checkpoint](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/en_es//w2v2_only.pt ) | [checkpoint](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/es_en//w2v2_only.pt) |
119
+ w2v2-L + mBART (LNA-E) | [checkpoint](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/en_es//w2v2_mbart_LNE.pt) | [checkpoint](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/es_en//w2v2_mbart_LNE.pt) |
120
+ w2v2-L + mBART (LNA-D) | [checkpoint](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/en_es//w2v2_mbart_LND.pt) | [checkpoint](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/es_en//w2v2_mbart_LND.pt) |
121
+ w2v2-L + mBART (LNA-E,D) | [checkpoint](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/en_es//w2v2_mbart_LNED.pt) | [checkpoint](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/es_en//w2v2_mbart_LNED.pt) |
122
+ **S2UT systems with model pre-training and data augmentation** | | |
123
+ w2v2-L + mBART (LNA-D) | [checkpoint](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/en_es//w2v2_mbart_LND_w_ASR.pt) | [checkpoint](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/es_en//w2v2_mbart_LND_w_ASR.pt) |
124
+
125
+ Note: some of these checkpoints were trained with the `speech_to_text_sharded` task, which has not been open-sourced yet, so make sure to override the task to `speech_to_text` when using those models.
fairseq/examples/speech_to_speech/docs/textless_s2st_real_data.md ADDED
@@ -0,0 +1,89 @@
1
+ # Textless Speech-to-Speech Translation (S2ST) on Real Data
2
+
3
+ We provide instructions and pre-trained models for the work "[Textless Speech-to-Speech Translation on Real Data (Lee et al. 2021)](https://arxiv.org/abs/2112.08352)".
4
+
5
+ ## Pre-trained Models
6
+
7
+ ### HuBERT
8
+ Model | Pretraining data | Checkpoint | Quantizer
9
+ |---|---|---|---
10
+ mHuBERT Base | [VoxPopuli](https://github.com/facebookresearch/voxpopuli) En, Es, Fr speech from the 100k subset | [download](https://dl.fbaipublicfiles.com/hubert/mhubert_base_vp_en_es_fr_it3.pt) | [L11 km1000](https://dl.fbaipublicfiles.com/hubert/mhubert_base_vp_en_es_fr_it3_L11_km1000.bin)
11
+
12
+
13
+ ### Unit-based HiFi-GAN vocoder
14
+ Unit config | Unit size | Vocoder language | Dataset | Model
15
+ |---|---|---|---|---
16
+ mHuBERT, layer 11 | 1000 | En | [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) | [ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_lj/g_00500000), [config](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_lj/config.json)
17
+ mHuBERT, layer 11 | 1000 | Es | [CSS10](https://github.com/Kyubyong/css10) | [ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_es_css10/g_00500000), [config](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_es_css10/config.json)
18
+ mHuBERT, layer 11 | 1000 | Fr | [CSS10](https://github.com/Kyubyong/css10) | [ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_fr_css10/g_00500000), [config](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_fr_css10/config.json)
19
+
20
+
21
+ ### Speech normalizer
22
+ Language | Training data | Target unit config | Model
23
+ |---|---|---|---
24
+ En | 10 mins | mHuBERT, layer 11, km1000 | [download](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/speech_normalizer/en/en_10min.tar.gz)
25
+ En | 1 hr | mHuBERT, layer 11, km1000 | [download](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/speech_normalizer/en/en_1h.tar.gz)
26
+ En | 10 hrs | mHuBERT, layer 11, km1000 | [download](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/speech_normalizer/en/en_10h.tar.gz)
27
+ Es | 10 mins | mHuBERT, layer 11, km1000 | [download](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/speech_normalizer/es/es_10min.tar.gz)
28
+ Es | 1 hr | mHuBERT, layer 11, km1000 | [download](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/speech_normalizer/es/es_1h.tar.gz)
29
+ Es | 10 hrs | mHuBERT, layer 11, km1000 | [download](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/speech_normalizer/es/es_10h.tar.gz)
30
+ Fr | 10 mins | mHuBERT, layer 11, km1000 | [download](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/speech_normalizer/fr/fr_10min.tar.gz)
31
+ Fr | 1 hr | mHuBERT, layer 11, km1000 | [download](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/speech_normalizer/fr/fr_1h.tar.gz)
32
+ Fr | 10 hrs | mHuBERT, layer 11, km1000 | [download](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/speech_normalizer/fr/fr_10h.tar.gz)
33
+
34
+ * Refer to the paper for the details of the training data.
35
+
36
+ ## Inference with Pre-trained Models
37
+
38
+ ### Speech normalizer
39
+ 1. Download the pre-trained models, including the dictionary, to `DATA_DIR`.
40
+ 2. Format the audio data.
41
+ ```bash
42
+ # AUDIO_EXT: audio extension, e.g. wav, flac, etc.
43
+ # Assume all audio files are at ${AUDIO_DIR}/*.${AUDIO_EXT}
44
+
45
+ python examples/speech_to_speech/preprocessing/prep_sn_data.py \
46
+ --audio-dir ${AUDIO_DIR} --ext ${AUDIO_EXT} \
47
+ --data-name ${GEN_SUBSET} --output-dir ${DATA_DIR} \
48
+ --for-inference
49
+ ```
50
+
51
+ 3. Run the speech normalizer and post-process the output.
52
+ ```bash
53
+ mkdir -p ${RESULTS_PATH}
54
+
55
+ python examples/speech_recognition/new/infer.py \
56
+ --config-dir examples/hubert/config/decode/ \
57
+ --config-name infer_viterbi \
58
+ task.data=${DATA_DIR} \
59
+ task.normalize=false \
60
+ common_eval.results_path=${RESULTS_PATH}/log \
61
+ common_eval.path=${DATA_DIR}/checkpoint_best.pt \
62
+ dataset.gen_subset=${GEN_SUBSET} \
63
+ '+task.labels=["unit"]' \
64
+ +decoding.results_path=${RESULTS_PATH} \
65
+ common_eval.post_process=none \
66
+ +dataset.batch_size=1 \
67
+ common_eval.quiet=True
68
+
69
+ # Post-process and generate output at ${RESULTS_PATH}/${GEN_SUBSET}.txt
70
+ python examples/speech_to_speech/preprocessing/prep_sn_output_data.py \
71
+ --in-unit ${RESULTS_PATH}/hypo.units \
72
+ --in-audio ${DATA_DIR}/${GEN_SUBSET}.tsv \
73
+ --output-root ${RESULTS_PATH}
74
+ ```
75
+
76
+
77
+ ### Unit-to-waveform conversion with unit vocoder
78
+ The pre-trained vocoders can generate audio for both full unit sequences and reduced unit sequences (i.e. with duplicate consecutive units removed; see the short sketch after the command below). Set `--dur-prediction` to generate audio from reduced unit sequences.
79
+ ```bash
80
+ # IN_CODE_FILE contains one unit sequence per line. Units are separated by space.
81
+
82
+ python examples/speech_to_speech/generate_waveform_from_code.py \
83
+ --in-code-file ${IN_CODE_FILE} \
84
+ --vocoder ${VOCODER_CKPT} --vocoder-cfg ${VOCODER_CFG} \
85
+ --results-path ${RESULTS_PATH} --dur-prediction
86
+ ```
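+
+ For reference, "reducing" a unit sequence simply collapses consecutive duplicate units, matching `process_units(..., reduce=True)` in `examples/speech_to_speech/preprocessing/data_utils.py`. A minimal sketch:
+
+ ```python
+ # Collapse consecutive duplicate units, e.g. "1 1 1 2 2 1" -> "1 2 1".
+ # Mirrors process_units(units, reduce=True) in
+ # examples/speech_to_speech/preprocessing/data_utils.py.
+ def reduce_units(units):
+     return [u for i, u in enumerate(units) if i == 0 or u != units[i - 1]]
+
+ assert reduce_units("1 1 1 2 2 1".split()) == ["1", "2", "1"]
+ ```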
87
+
88
+ ## Training new models
89
+ To be updated.
fairseq/examples/speech_to_speech/generate_waveform_from_code.py ADDED
@@ -0,0 +1,116 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import json
8
+ import logging
9
+ from pathlib import Path
10
+ import random
11
+ import soundfile as sf
12
+ import torch
13
+
14
+ from tqdm import tqdm
15
+
16
+ from fairseq import utils
17
+ from fairseq.models.text_to_speech.vocoder import CodeHiFiGANVocoder
18
+
19
+
20
+ logging.basicConfig()
21
+ logging.root.setLevel(logging.INFO)
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ def dump_result(args, sample_id, pred_wav, suffix=""):
27
+ sf.write(
28
+ f"{args.results_path}/{sample_id}{suffix}_pred.wav",
29
+ pred_wav.detach().cpu().numpy(),
30
+ 16000,
31
+ )
32
+
33
+
34
+ def load_code(in_file):
35
+ with open(in_file) as f:
36
+ out = [list(map(int, line.strip().split())) for line in f]
37
+ return out
38
+
39
+
40
+ def main(args):
41
+ logger.info(args)
42
+
43
+ use_cuda = torch.cuda.is_available() and not args.cpu
44
+
45
+ with open(args.vocoder_cfg) as f:
46
+ vocoder_cfg = json.load(f)
47
+ vocoder = CodeHiFiGANVocoder(args.vocoder, vocoder_cfg)
48
+ if use_cuda:
49
+ vocoder = vocoder.cuda()
50
+
51
+ multispkr = vocoder.model.multispkr
52
+ if multispkr:
53
+ logger.info("multi-speaker vocoder")
54
+ num_speakers = vocoder_cfg.get(
55
+ "num_speakers", 200
56
+ ) # following the default in codehifigan to set to 200
57
+ assert (
58
+ args.speaker_id < num_speakers
59
+ ), f"invalid --speaker-id ({args.speaker_id}) with total #speakers = {num_speakers}"
60
+
61
+ data = load_code(args.in_code_file)
62
+ Path(args.results_path).mkdir(exist_ok=True, parents=True)
63
+ for i, d in tqdm(enumerate(data), total=len(data)):
64
+ x = {
65
+ "code": torch.LongTensor(d).view(1, -1),
66
+ }
67
+ suffix = ""
68
+ if multispkr:
69
+ spk = (
70
+ random.randint(0, num_speakers - 1)
71
+ if args.speaker_id == -1
72
+ else args.speaker_id
73
+ )
74
+ suffix = f"_spk{spk}"
75
+ x["spkr"] = torch.LongTensor([spk]).view(1, 1)
76
+
77
+ x = utils.move_to_cuda(x) if use_cuda else x
78
+ wav = vocoder(x, args.dur_prediction)
79
+ dump_result(args, i, wav, suffix=suffix)
80
+
81
+
82
+ def cli_main():
83
+ parser = argparse.ArgumentParser()
84
+ parser.add_argument(
85
+ "--in-code-file", type=str, required=True, help="one unit sequence per line"
86
+ )
87
+ parser.add_argument(
88
+ "--vocoder", type=str, required=True, help="path to the CodeHiFiGAN vocoder"
89
+ )
90
+ parser.add_argument(
91
+ "--vocoder-cfg",
92
+ type=str,
93
+ required=True,
94
+ help="path to the CodeHiFiGAN vocoder config",
95
+ )
96
+ parser.add_argument("--results-path", type=str, required=True)
97
+ parser.add_argument(
98
+ "--dur-prediction",
99
+ action="store_true",
100
+ help="enable duration prediction (for reduced/unique code sequences)",
101
+ )
102
+ parser.add_argument(
103
+ "--speaker-id",
104
+ type=int,
105
+ default=-1,
106
+ help="Speaker id (for vocoder that supports multispeaker). Set to -1 to randomly sample speakers.",
107
+ )
108
+ parser.add_argument("--cpu", action="store_true", help="run on CPU")
109
+
110
+ args = parser.parse_args()
111
+
112
+ main(args)
113
+
114
+
115
+ if __name__ == "__main__":
116
+ cli_main()
fairseq/examples/speech_to_speech/preprocessing/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
fairseq/examples/speech_to_speech/preprocessing/data_utils.py ADDED
@@ -0,0 +1,88 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from pathlib import Path
7
+ from typing import List, Optional
8
+
9
+ from examples.speech_to_text.data_utils import S2TDataConfigWriter
10
+
11
+
12
+ def gen_config_yaml(
13
+ manifest_root: Path,
14
+ yaml_filename: str = "config.yaml",
15
+ specaugment_policy: Optional[str] = "lb",
16
+ feature_transform: Optional[List[str]] = None,
17
+ input_channels: Optional[int] = 1,
18
+ input_feat_per_channel: Optional[int] = 80,
19
+ audio_root: str = "",
20
+ vocoder_type: Optional[str] = None,
21
+ vocoder_checkpoint: Optional[str] = None,
22
+ vocoder_cfg: Optional[str] = None,
23
+ extra=None,
24
+ ):
25
+ manifest_root = manifest_root.absolute()
26
+ writer = S2TDataConfigWriter(manifest_root / yaml_filename)
27
+
28
+ if input_channels is not None:
29
+ writer.set_input_channels(input_channels)
30
+ if input_feat_per_channel is not None:
31
+ writer.set_input_feat_per_channel(input_feat_per_channel)
32
+ specaugment_setters = {
33
+ "lb": writer.set_specaugment_lb_policy,
34
+ "ld": writer.set_specaugment_ld_policy,
35
+ "sm": writer.set_specaugment_sm_policy,
36
+ "ss": writer.set_specaugment_ss_policy,
37
+ }
38
+ specaugment_setter = specaugment_setters.get(specaugment_policy, None)
39
+ if specaugment_setter is not None:
40
+ specaugment_setter()
41
+
42
+ if feature_transform is None:
43
+ feature_transform = []
44
+ else:
45
+ writer.set_feature_transforms("*", feature_transform)
46
+
47
+ if specaugment_policy is not None:
48
+ writer.set_feature_transforms("_train", feature_transform + ["specaugment"])
49
+
50
+ if len(audio_root) > 0:
51
+ writer.set_audio_root(audio_root)
52
+
53
+ if (
54
+ vocoder_type is not None
55
+ and vocoder_checkpoint is not None
56
+ and vocoder_cfg is not None
57
+ ):
58
+ writer.set_extra(
59
+ {
60
+ "vocoder": {
61
+ "type": vocoder_type,
62
+ "config": vocoder_cfg,
63
+ "checkpoint": vocoder_checkpoint,
64
+ }
65
+ }
66
+ )
67
+
68
+ if extra is not None:
69
+ writer.set_extra(extra)
70
+ writer.flush()
71
+
72
+
73
+ def load_units(in_file):
74
+ out = {}
75
+ with open(in_file) as f:
76
+ for line in f:
77
+ sample_id, units = line.strip().split("|", 1)
78
+ out[sample_id] = units.split()
79
+
80
+ return out
81
+
82
+
83
+ def process_units(units, reduce=False):
84
+ if not reduce:
85
+ return units
86
+
87
+ out = [u for i, u in enumerate(units) if i == 0 or u != units[i - 1]]
88
+ return out
fairseq/examples/speech_to_speech/preprocessing/prep_s2spect_data.py ADDED
@@ -0,0 +1,169 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import logging
9
+ import os
10
+ from pathlib import Path
11
+ import shutil
12
+ import torchaudio
13
+
14
+ import soundfile as sf
15
+ from tqdm import tqdm
16
+ import pandas as pd
17
+
18
+ from examples.speech_synthesis.data_utils import extract_logmel_spectrogram
19
+ from examples.speech_to_speech.preprocessing.data_utils import gen_config_yaml
20
+ from examples.speech_to_text.data_utils import create_zip, get_zip_manifest, save_df_to_tsv
21
+ from fairseq.data.audio.audio_utils import convert_waveform
22
+
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ MANIFEST_COLUMNS = ["id", "src_audio", "src_n_frames", "tgt_audio", "tgt_n_frames"]
27
+
28
+
29
+ def prepare_target_data(args, tgt_audios):
30
+ feature_name = "logmelspec80"
31
+ zip_path = args.output_root / f"{feature_name}.zip"
32
+ if zip_path.exists():
33
+ print(f"{zip_path} exists.")
34
+ return zip_path
35
+
36
+ feature_root = args.output_root / feature_name
37
+ feature_root.mkdir(exist_ok=True)
38
+
39
+ print("Extracting Mel spectrogram features...")
40
+ for tgt_audio in tqdm(tgt_audios):
41
+ sample_id = tgt_audio.stem
42
+ waveform, sample_rate = torchaudio.load(tgt_audio.as_posix())
43
+ waveform, sample_rate = convert_waveform(
44
+ waveform, sample_rate, normalize_volume=args.normalize_volume,
45
+ to_sample_rate=args.sample_rate
46
+ )
47
+ extract_logmel_spectrogram(
48
+ waveform, sample_rate, feature_root / f"{sample_id}.npy",
49
+ win_length=args.win_length, hop_length=args.hop_length,
50
+ n_fft=args.n_fft, n_mels=args.n_mels, f_min=args.f_min,
51
+ f_max=args.f_max
52
+ )
53
+ print("ZIPing features...")
54
+ create_zip(feature_root, zip_path)
55
+ shutil.rmtree(feature_root)
56
+
57
+ return zip_path
58
+
59
+
60
+ def process(args):
61
+ os.makedirs(args.output_root, exist_ok=True)
62
+
63
+ manifest = {}
64
+ tgt_audios = []
65
+ for split in args.data_split:
66
+ print(f"Processing {split}...")
67
+
68
+ manifest[split] = {c: [] for c in MANIFEST_COLUMNS}
69
+ missing_tgt_audios = []
70
+ src_audios = list(args.source_dir.glob(f"{split}/*.wav"))
71
+ for src_audio in tqdm(src_audios):
72
+ sample_id = src_audio.stem
73
+
74
+ tgt_audio = args.target_dir / split / f"{sample_id}.wav"
75
+ if not tgt_audio.is_file():
76
+ missing_tgt_audios.append(sample_id)
77
+ continue
78
+
79
+ tgt_audios.append(tgt_audio)
80
+
81
+ src_n_frames = sf.info(src_audio.as_posix()).frames
82
+ manifest[split]["id"].append(sample_id)
83
+ manifest[split]["src_audio"].append(src_audio.as_posix())
84
+ manifest[split]["src_n_frames"].append(
85
+ src_n_frames // 160
86
+ ) # estimation of 10-ms frame for 16kHz audio
87
+
88
+ print(f"Processed {len(manifest[split]['id'])} samples")
89
+ if len(missing_tgt_audios) > 0:
90
+ print(
91
+ f"{len(missing_tgt_audios)} with missing target data (first 3 examples: {', '.join(missing_tgt_audios[:3])})"
92
+ )
93
+
94
+ # Extract features and pack features into ZIP
95
+ zip_path = prepare_target_data(args, tgt_audios)
96
+
97
+ print("Fetching ZIP manifest...")
98
+ tgt_audio_paths, tgt_audio_lengths = get_zip_manifest(zip_path)
99
+
100
+ print("Generating manifest...")
101
+ for split in args.data_split:
102
+ print(f"Processing {split}...")
103
+
104
+ for sample_id in tqdm(manifest[split]["id"]):
105
+ manifest[split]["tgt_audio"].append(tgt_audio_paths[sample_id])
106
+ manifest[split]["tgt_n_frames"].append(tgt_audio_lengths[sample_id])
107
+
108
+ out_manifest = args.output_root / f"{split}.tsv"
109
+ print(f"Writing manifest to {out_manifest}...")
110
+ save_df_to_tsv(pd.DataFrame.from_dict(manifest[split]), out_manifest)
111
+
112
+ # Generate config YAML
113
+ win_len_t = args.win_length / args.sample_rate
114
+ hop_len_t = args.hop_length / args.sample_rate
115
+ extra = {
116
+ "features": {
117
+ "type": "spectrogram+melscale+log",
118
+ "sample_rate": args.sample_rate,
119
+ "eps": 1e-5, "n_mels": args.n_mels, "n_fft": args.n_fft,
120
+ "window_fn": "hann", "win_length": args.win_length,
121
+ "hop_length": args.hop_length,
122
+ "win_len_t": win_len_t, "hop_len_t": hop_len_t,
123
+ "f_min": args.f_min, "f_max": args.f_max,
124
+ "n_stft": args.n_fft // 2 + 1
125
+ }
126
+ }
127
+ gen_config_yaml(
128
+ args.output_root,
129
+ audio_root=args.output_root.as_posix(),
130
+ specaugment_policy="lb",
131
+ feature_transform=["utterance_cmvn", "delta_deltas"],
132
+ extra=extra,
133
+ )
134
+
135
+
136
+ def main():
137
+ parser = argparse.ArgumentParser()
138
+ parser.add_argument(
139
+ "--source-dir", required=True, type=Path, help="source audio directory"
140
+ )
141
+ parser.add_argument(
142
+ "--target-dir", required=True, type=Path, help="target audio directory"
143
+ )
144
+ parser.add_argument(
145
+ "--data-split",
146
+ default=["train", "valid", "test"],
147
+ nargs="+",
148
+ help="data split names",
149
+ )
150
+ parser.add_argument(
151
+ "--output-root", required=True, type=Path, help="output directory"
152
+ )
153
+ # target feature related
154
+ parser.add_argument("--win-length", type=int, default=1024)
155
+ parser.add_argument("--hop-length", type=int, default=256)
156
+ parser.add_argument("--n-fft", type=int, default=1024)
157
+ parser.add_argument("--n-mels", type=int, default=80)
158
+ parser.add_argument("--f-min", type=int, default=20)
159
+ parser.add_argument("--f-max", type=int, default=8000)
160
+ parser.add_argument("--sample-rate", type=int, default=22050)
161
+ parser.add_argument("--normalize-volume", "-n", action="store_true")
162
+
163
+ args = parser.parse_args()
164
+
165
+ process(args)
166
+
167
+
168
+ if __name__ == "__main__":
169
+ main()
fairseq/examples/speech_to_speech/preprocessing/prep_s2ut_data.py ADDED
@@ -0,0 +1,114 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import logging
9
+ from pathlib import Path
10
+
11
+ import soundfile as sf
12
+ from tqdm import tqdm
13
+ import pandas as pd
14
+
15
+ from examples.speech_to_speech.preprocessing.data_utils import (
16
+ gen_config_yaml,
17
+ load_units,
18
+ process_units,
19
+ )
20
+ from examples.speech_to_text.data_utils import save_df_to_tsv
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ MANIFEST_COLUMNS = ["id", "src_audio", "src_n_frames", "tgt_audio", "tgt_n_frames"]
25
+
26
+
27
+ def process(args):
28
+ args.output_root.mkdir(exist_ok=True)
29
+
30
+ print("Generating manifest...")
31
+ for split in args.data_split:
32
+ print(f"Processing {split}")
33
+
34
+ # load target units
35
+ target_unit_data = load_units(args.target_dir / f"{split}.txt")
36
+
37
+ manifest = {c: [] for c in MANIFEST_COLUMNS}
38
+ missing_tgt_audios = []
39
+ src_audios = list(args.source_dir.glob(f"{split}/*.wav"))
40
+ for src_audio in tqdm(src_audios):
41
+ sample_id = src_audio.stem
42
+
43
+ if sample_id not in target_unit_data:
44
+ missing_tgt_audios.append(sample_id)
45
+ continue
46
+
47
+ src_n_frames = sf.info(src_audio.as_posix()).frames
48
+ manifest["id"].append(sample_id)
49
+ manifest["src_audio"].append(src_audio.as_posix())
50
+ manifest["src_n_frames"].append(
51
+ src_n_frames // 160
52
+ ) # estimation of 10-ms frame for 16kHz audio
53
+
54
+ target_units = process_units(target_unit_data[sample_id], args.reduce_unit)
55
+ manifest["tgt_audio"].append(" ".join(target_units))
56
+ manifest["tgt_n_frames"].append(len(target_units))
57
+
58
+ print(f"Processed {len(manifest['id'])} samples")
59
+ if len(missing_tgt_audios) > 0:
60
+ print(
61
+ f"{len(missing_tgt_audios)} with missing target data (first 3 examples: {', '.join(missing_tgt_audios[:3])})"
62
+ )
63
+
64
+ out_manifest = args.output_root / f"{split}.tsv"
65
+ print(f"Writing manifest to {out_manifest}...")
66
+ save_df_to_tsv(pd.DataFrame.from_dict(manifest), out_manifest)
67
+
68
+ # Generate config YAML
69
+ gen_config_yaml(
70
+ args.output_root,
71
+ specaugment_policy="lb",
72
+ feature_transform=["utterance_cmvn"],
73
+ vocoder_type="code_hifigan",
74
+ vocoder_checkpoint=args.vocoder_checkpoint,
75
+ vocoder_cfg=args.vocoder_cfg,
76
+ )
77
+
78
+
79
+ def main():
80
+ parser = argparse.ArgumentParser()
81
+ parser.add_argument(
82
+ "--source-dir", required=True, type=Path, help="source audio directory"
83
+ )
84
+ parser.add_argument(
85
+ "--target-dir", required=True, type=Path, help="target audio directory"
86
+ )
87
+ parser.add_argument(
88
+ "--data-split",
89
+ default=["train", "valid", "test"],
90
+ nargs="+",
91
+ help="data split names",
92
+ )
93
+ parser.add_argument(
94
+ "--output-root", required=True, type=Path, help="output directory"
95
+ )
96
+ parser.add_argument(
97
+ "--reduce-unit",
98
+ action="store_true",
99
+ help="reduce a target unit sequence to a unique unit sequence, i.e. '1 1 1 2 2' -> '1 2'",
100
+ )
101
+ parser.add_argument(
102
+ "--vocoder-checkpoint", default=None, type=str, help="vocoder checkpoint"
103
+ )
104
+ parser.add_argument(
105
+ "--vocoder-cfg", default=None, type=str, help="vocoder config file"
106
+ )
107
+
108
+ args = parser.parse_args()
109
+
110
+ process(args)
111
+
112
+
113
+ if __name__ == "__main__":
114
+ main()
fairseq/examples/speech_to_speech/preprocessing/prep_sn_data.py ADDED
@@ -0,0 +1,88 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # Adapted from examples/wav2vec/wav2vec_manifest.py
8
+ """
9
+ Data preparation for the speech normalizer
10
+ """
11
+
12
+ import argparse
13
+ import glob
14
+ import os
15
+
16
+ import soundfile
17
+
18
+ from examples.speech_to_speech.preprocessing.data_utils import load_units, process_units
19
+
20
+
21
+ def process(args):
22
+ assert (
23
+ args.for_inference or args.target_unit is not None
24
+ ), "missing --target-unit or --for-inference"
25
+
26
+ if not os.path.exists(args.output_dir):
27
+ os.makedirs(args.output_dir)
28
+
29
+ dir_path = os.path.realpath(args.audio_dir)
30
+ search_path = os.path.join(dir_path, "**/*." + args.ext)
31
+
32
+ if args.target_unit:
33
+ unit_data = load_units(args.target_unit)
34
+
35
+ with open(os.path.join(args.output_dir, f"{args.data_name}.tsv"), "w") as o_t, open(
36
+ os.path.join(args.output_dir, f"{args.data_name}.unit"), "w"
37
+ ) as o_u:
38
+ print(dir_path, file=o_t)
39
+ for fname in glob.iglob(search_path, recursive=True):
40
+ file_path = os.path.realpath(fname)
41
+ frames = soundfile.info(fname).frames
42
+ print(
43
+ "{}\t{}".format(os.path.relpath(file_path, dir_path), frames), file=o_t
44
+ )
45
+
46
+ if args.for_inference:
47
+ print("0", file=o_u)
48
+ else:
49
+ sample_id = os.path.basename(file_path)[: -len(args.ext) - 1]
50
+ assert (
51
+ sample_id in unit_data
52
+ ), f'{fname} does not have unit data in {args.target_unit}. Expecting sample_id "{sample_id}".'
53
+ target_units = process_units(unit_data[sample_id], reduce=True)
54
+ print(" ".join(target_units), file=o_u)
55
+
56
+
57
+ def main():
58
+ parser = argparse.ArgumentParser()
59
+ parser.add_argument("--audio-dir", required=True, type=str, help="audio directory")
60
+ parser.add_argument("--ext", default="flac", type=str, help="audio extension")
61
+ parser.add_argument(
62
+ "--data-name",
63
+ required=True,
64
+ type=str,
65
+ help="dataset name",
66
+ )
67
+ parser.add_argument(
68
+ "--output-dir", required=True, type=str, help="output directory"
69
+ )
70
+ parser.add_argument(
71
+ "--for-inference",
72
+ action="store_true",
73
+ help="set this if preparing data for running inference with a speech normalizer",
74
+ )
75
+ parser.add_argument(
76
+ "--target-unit",
77
+ default=None,
78
+ type=str,
79
+ help="a file containing unit sequences in the format: sample_id|u1 u2 ...",
80
+ )
81
+
82
+ args = parser.parse_args()
83
+
84
+ process(args)
85
+
86
+
87
+ if __name__ == "__main__":
88
+ main()
fairseq/examples/speech_to_speech/preprocessing/prep_sn_output_data.py ADDED
@@ -0,0 +1,58 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ from pathlib import Path
9
+
10
+ from tqdm import tqdm
11
+
12
+
13
+ def process(args):
14
+ args.output_root.mkdir(exist_ok=True)
15
+
16
+ # load units
17
+ units = {}
18
+ with open(args.in_unit) as f:
19
+ for line in f:
20
+ unit_seq, utt_id = line.strip().rsplit(" ", 1)
21
+ utt_id = int(utt_id[6:-1])  # strip the "(None-" prefix and the trailing ")"
22
+ units[utt_id] = unit_seq
23
+
24
+ with open(args.in_audio) as f, open(
25
+ args.output_root / f"{args.in_audio.stem}.txt", "w"
26
+ ) as o:
27
+ f.readline()
28
+ for i, line in enumerate(tqdm(f.readlines())):
29
+ audio, _ = line.strip().split("\t", 1)
30
+ sample_id = Path(audio).stem
31
+ o.write(f"{sample_id}|{units[i]}\n")
32
+
33
+
34
+ def main():
35
+ parser = argparse.ArgumentParser()
36
+ parser.add_argument(
37
+ "--in-unit",
38
+ required=True,
39
+ type=Path,
40
+ help="unit file (output from the speech normalizer)",
41
+ )
42
+ parser.add_argument(
43
+ "--in-audio",
44
+ required=True,
45
+ type=Path,
46
+ help="tsv file (input to the normalizer)",
47
+ )
48
+ parser.add_argument(
49
+ "--output-root", required=True, type=Path, help="output directory"
50
+ )
51
+
52
+ args = parser.parse_args()
53
+
54
+ process(args)
55
+
56
+
57
+ if __name__ == "__main__":
58
+ main()
fairseq/examples/speech_to_speech/unity/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from . import sequence_generator # noqa
7
+ from . import sequence_generator_multi_decoder # noqa
fairseq/examples/speech_to_speech/unity/sequence_generator.py ADDED
@@ -0,0 +1,626 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import sys
8
+ from typing import Dict, List, Optional
9
+
10
+ import torch
11
+ from torch import Tensor
12
+
13
+ from fairseq.sequence_generator import EnsembleModel as EnsembleModelBase
14
+ from fairseq.sequence_generator import SequenceGenerator as SequenceGeneratorBase
15
+
16
+
17
+ class SequenceGenerator(SequenceGeneratorBase):
18
+ def __init__(
19
+ self,
20
+ models,
21
+ tgt_dict,
22
+ beam_size=1,
23
+ max_len_a=0,
24
+ max_len_b=200,
25
+ max_len=0,
26
+ min_len=1,
27
+ normalize_scores=True,
28
+ len_penalty=1.0,
29
+ unk_penalty=0.0,
30
+ temperature=1.0,
31
+ match_source_len=False,
32
+ no_repeat_ngram_size=0,
33
+ search_strategy=None,
34
+ eos=None,
35
+ symbols_to_strip_from_output=None,
36
+ lm_model=None,
37
+ lm_weight=1.0,
38
+ tokens_to_suppress=(),
39
+ ):
40
+ """Generates translations of a given source sentence.
41
+
42
+ Args:
43
+ models (List[~fairseq.models.FairseqModel]): ensemble of models,
44
+ currently support fairseq.models.TransformerModel for scripting
45
+ beam_size (int, optional): beam width (default: 1)
46
+ max_len_a/b (int, optional): generate sequences of maximum length
47
+ ax + b, where x is the source length
48
+ max_len (int, optional): the maximum length of the generated output
49
+ (not including end-of-sentence)
50
+ min_len (int, optional): the minimum length of the generated output
51
+ (not including end-of-sentence)
52
+ normalize_scores (bool, optional): normalize scores by the length
53
+ of the output (default: True)
54
+ len_penalty (float, optional): length penalty, where <1.0 favors
55
+ shorter, >1.0 favors longer sentences (default: 1.0)
56
+ unk_penalty (float, optional): unknown word penalty, where <0
57
+ produces more unks, >0 produces fewer (default: 0.0)
58
+ temperature (float, optional): temperature, where values
59
+ >1.0 produce more uniform samples and values <1.0 produce
60
+ sharper samples (default: 1.0)
61
+ match_source_len (bool, optional): outputs should match the source
62
+ length (default: False)
63
+ """
64
+ super().__init__(
65
+ models=models,
66
+ tgt_dict=tgt_dict,
67
+ beam_size=beam_size,
68
+ max_len_a=max_len_a,
69
+ max_len_b=max_len_b,
70
+ max_len=max_len,
71
+ min_len=min_len,
72
+ normalize_scores=normalize_scores,
73
+ len_penalty=len_penalty,
74
+ unk_penalty=unk_penalty,
75
+ temperature=temperature,
76
+ match_source_len=match_source_len,
77
+ no_repeat_ngram_size=no_repeat_ngram_size,
78
+ search_strategy=search_strategy,
79
+ eos=eos,
80
+ symbols_to_strip_from_output=symbols_to_strip_from_output,
81
+ lm_model=lm_model,
82
+ lm_weight=lm_weight,
83
+ tokens_to_suppress=tokens_to_suppress,
84
+ )
85
+
86
+ if isinstance(models, EnsembleModel):
87
+ self.model = models
88
+ else:
89
+ self.model = EnsembleModel(models)
90
+
91
+ self.model.set_decoder_beam_size(self.beam_size)
92
+ self.model.eval()
93
+
94
+ def _generate(
95
+ self,
96
+ sample: Dict[str, Dict[str, Tensor]],
97
+ prefix_tokens: Optional[Tensor] = None,
98
+ constraints: Optional[Tensor] = None,
99
+ bos_token: Optional[int] = None,
100
+ ):
101
+ net_input = sample["net_input"]
102
+
103
+ if "src_tokens" in net_input:
104
+ src_tokens = net_input["src_tokens"]
105
+ # length of the source text, i.e. the number of source tokens excluding end-of-sentence and padding
106
+ # if src_lengths exists in net_input (speech_to_text dataset case), then use it
107
+ if "src_lengths" in net_input:
108
+ src_lengths = net_input["src_lengths"]
109
+ else:
110
+ src_lengths = (
111
+ (src_tokens.ne(self.eos) & src_tokens.ne(self.pad))
112
+ .long()
113
+ .sum(dim=1)
114
+ )
115
+ elif "source" in net_input:
116
+ src_tokens = net_input["source"]
117
+ src_lengths = (
118
+ net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
119
+ if net_input["padding_mask"] is not None
120
+ else torch.tensor(src_tokens.size(-1)).to(src_tokens)
121
+ )
122
+ elif "features" in net_input:
123
+ src_tokens = net_input["features"]
124
+ src_lengths = (
125
+ net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
126
+ if net_input["padding_mask"] is not None
127
+ else torch.tensor(src_tokens.size(-1)).to(src_tokens)
128
+ )
129
+ else:
130
+ raise Exception(
131
+ "expected src_tokens or source in net input. input keys: "
132
+ + str(net_input.keys())
133
+ )
134
+
135
+ if constraints is not None and not self.search.supports_constraints:
136
+ raise NotImplementedError(
137
+ "Target-side constraints were provided, but search method doesn't support them"
138
+ )
139
+
140
+ # Initialize constraints, when active
141
+ self.search.init_constraints(constraints, self.beam_size)
142
+
143
+ # compute the encoder output for each beam
144
+ with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"):
145
+ encoder_outs = self.model.forward_encoder(net_input)
146
+
147
+ finalized = self.generate_decoder(
148
+ encoder_outs,
149
+ src_tokens,
150
+ src_lengths,
151
+ sample,
152
+ prefix_tokens,
153
+ constraints,
154
+ bos_token,
155
+ )
156
+ return finalized
157
+
158
+ def generate_decoder(
159
+ self,
160
+ encoder_outs,
161
+ src_tokens,
162
+ src_lengths,
163
+ sample: Dict[str, Dict[str, Tensor]],
164
+ prefix_tokens: Optional[Tensor] = None,
165
+ constraints: Optional[Tensor] = None,
166
+ bos_token: Optional[int] = None,
167
+ aux_task_name="",
168
+ encoder_outs_aug: Optional[
169
+ Tensor
170
+ ] = None, # an additional/augmented encoder_outs
171
+ ):
172
+ incremental_states = torch.jit.annotate(
173
+ List[Dict[str, Dict[str, Optional[Tensor]]]],
174
+ [
175
+ torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
176
+ for i in range(self.model.models_size)
177
+ ],
178
+ )
179
+
180
+ # bsz: total number of sentences in beam
181
+ # Note that src_tokens may have more than 2 dimensions (i.e. audio features)
182
+ bsz, src_len = src_tokens.size()[:2]
183
+ beam_size = self.beam_size
184
+
185
+ decoder_name = f"{aux_task_name}_decoder" if aux_task_name else "decoder"
186
+
187
+ max_len: int = -1
188
+ if self.match_source_len:
189
+ max_len = src_lengths.max().item()
190
+ else:
191
+ max_len = min(
192
+ int(self.max_len_a * src_len + self.max_len_b),
193
+ self.max_len - 1,
194
+ )
195
+ assert (
196
+ self.min_len <= max_len
197
+ ), "min_len cannot be larger than max_len, please adjust these!"
198
+
199
+ # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
200
+ new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
201
+ new_order = new_order.to(src_tokens.device).long()
202
+ encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order)
203
+ # ensure encoder_outs is a List.
204
+ assert encoder_outs is not None
205
+ if encoder_outs_aug is not None:
206
+ encoder_outs_aug = self.model.reorder_encoder_out(
207
+ encoder_outs_aug, new_order
208
+ )
209
+
210
+ # initialize buffers
211
+ scores = (
212
+ torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
213
+ ) # +1 for eos; pad is never chosen for scoring
214
+ tokens = (
215
+ torch.zeros(bsz * beam_size, max_len + 2)
216
+ .to(src_tokens)
217
+ .long()
218
+ .fill_(self.pad)
219
+ ) # +2 for eos and pad
220
+ tokens[:, 0] = self.eos if bos_token is None else bos_token
221
+ attn: Optional[Tensor] = None
222
+
223
+ # A list that indicates candidates that should be ignored.
224
+ # For example, suppose we're sampling and have already finalized 2/5
225
+ # samples. Then cands_to_ignore would mark 2 positions as being ignored,
226
+ # so that we only finalize the remaining 3 samples.
227
+ cands_to_ignore = (
228
+ torch.zeros(bsz, beam_size).to(src_tokens).eq(-1)
229
+ ) # forward and backward-compatible False mask
230
+
231
+ # list of completed sentences
232
+ finalized = torch.jit.annotate(
233
+ List[List[Dict[str, Tensor]]],
234
+ [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
235
+ ) # contains lists of dictionaries of information about the hypothesis being finalized at each step
236
+
237
+ # a boolean array indicating if the sentence at the index is finished or not
238
+ finished = [False for i in range(bsz)]
239
+ num_remaining_sent = bsz # number of sentences remaining
240
+
241
+ # number of candidate hypos per step
242
+ cand_size = 2 * beam_size # 2 x beam size in case half are EOS
243
+
244
+ # offset arrays for converting between different indexing schemes
245
+ bbsz_offsets = (
246
+ (torch.arange(0, bsz) * beam_size)
247
+ .unsqueeze(1)
248
+ .type_as(tokens)
249
+ .to(src_tokens.device)
250
+ )
251
+ cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device)
252
+
253
+ reorder_state: Optional[Tensor] = None
254
+ batch_idxs: Optional[Tensor] = None
255
+
256
+ original_batch_idxs: Optional[Tensor] = None
257
+ if "id" in sample and isinstance(sample["id"], Tensor):
258
+ original_batch_idxs = sample["id"]
259
+ else:
260
+ original_batch_idxs = torch.arange(0, bsz).type_as(tokens)
261
+
262
+ for step in range(max_len + 1): # one extra step for EOS marker
263
+ # reorder decoder internal states based on the prev choice of beams
264
+ if reorder_state is not None:
265
+ if batch_idxs is not None:
266
+ # update beam indices to take into account removed sentences
267
+ corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(
268
+ batch_idxs
269
+ )
270
+ reorder_state.view(-1, beam_size).add_(
271
+ corr.unsqueeze(-1) * beam_size
272
+ )
273
+ original_batch_idxs = original_batch_idxs[batch_idxs]
274
+ self.model.reorder_incremental_state(
275
+ incremental_states, reorder_state, decoder_name
276
+ )
277
+ encoder_outs = self.model.reorder_encoder_out(
278
+ encoder_outs, reorder_state
279
+ )
280
+ if encoder_outs_aug is not None:
281
+ encoder_outs_aug = self.model.reorder_encoder_out(
282
+ encoder_outs_aug, reorder_state
283
+ )
284
+ with torch.autograd.profiler.record_function(
285
+ "EnsembleModel: forward_decoder"
286
+ ):
287
+ lprobs, avg_attn_scores = self.model.forward_decoder(
288
+ tokens[:, : step + 1],
289
+ encoder_outs,
290
+ incremental_states,
291
+ self.temperature,
292
+ decoder_name=decoder_name,
293
+ encoder_outs_aug=encoder_outs_aug,
294
+ )
295
+
296
+ if self.lm_model is not None and not aux_task_name:
297
+ lm_out = self.lm_model(tokens[:, : step + 1])
298
+ probs = self.lm_model.get_normalized_probs(
299
+ lm_out, log_probs=True, sample=None
300
+ )
301
+ probs = probs[:, -1, :] * self.lm_weight
302
+ lprobs += probs
303
+
304
+ lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
305
+
306
+ lprobs[:, self.pad] = -math.inf # never select pad
307
+ lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
308
+
309
+ # handle max length constraint
310
+ if step >= max_len:
311
+ lprobs[:, : self.eos] = -math.inf
312
+ lprobs[:, self.eos + 1 :] = -math.inf
313
+
314
+ # handle prefix tokens (possibly with different lengths)
315
+ if (
316
+ prefix_tokens is not None
317
+ and step < prefix_tokens.size(1)
318
+ and step < max_len
319
+ ):
320
+ lprobs, tokens, scores = self._prefix_tokens(
321
+ step, lprobs, scores, tokens, prefix_tokens, beam_size
322
+ )
323
+ else:
324
+ if step < self.min_len:
325
+ # minimum length constraint (does not apply if using prefix_tokens)
326
+ lprobs[:, self.eos] = -math.inf
327
+
328
+ if self.token_indices_to_suppress is not None:
329
+ lprobs[:, self.token_indices_to_suppress] = -math.inf
330
+
331
+ # Record attention scores, only support avg_attn_scores is a Tensor
332
+ if avg_attn_scores is not None:
333
+ if attn is None:
334
+ attn = torch.empty(
335
+ bsz * beam_size, avg_attn_scores.size(1), max_len + 2
336
+ ).to(scores)
337
+ attn[:, :, step + 1].copy_(avg_attn_scores)
338
+
339
+ scores = scores.type_as(lprobs)
340
+ eos_bbsz_idx = torch.empty(0).to(
341
+ tokens
342
+ ) # indices of hypothesis ending with eos (finished sentences)
343
+ eos_scores = torch.empty(0).to(
344
+ scores
345
+ ) # scores of hypothesis ending with eos (finished sentences)
346
+
347
+ if self.should_set_src_lengths:
348
+ self.search.set_src_lengths(src_lengths)
349
+
350
+ if self.repeat_ngram_blocker is not None:
351
+ lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step)
352
+
353
+ # Shape: (batch, cand_size)
354
+ cand_scores, cand_indices, cand_beams = self.search.step(
355
+ step,
356
+ lprobs.view(bsz, -1, self.vocab_size),
357
+ scores.view(bsz, beam_size, -1)[:, :, :step],
358
+ tokens[:, : step + 1],
359
+ original_batch_idxs,
360
+ )
361
+
362
+ # cand_bbsz_idx contains beam indices for the top candidate
363
+ # hypotheses, with a range of values: [0, bsz*beam_size),
364
+ # and dimensions: [bsz, cand_size]
365
+ cand_bbsz_idx = cand_beams.add(bbsz_offsets)
366
+
367
+ # finalize hypotheses that end in eos
368
+ # Shape of eos_mask: (batch size, beam size)
369
+ eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
370
+ eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)
371
+
372
+ # only consider eos when it's among the top beam_size indices
373
+ # Now we know what beam item(s) to finish
374
+ # Shape: 1d list of absolute-numbered
375
+ eos_bbsz_idx = torch.masked_select(
376
+ cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
377
+ )
378
+
379
+ finalized_sents: List[int] = []
380
+ if eos_bbsz_idx.numel() > 0:
381
+ eos_scores = torch.masked_select(
382
+ cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
383
+ )
384
+
385
+ finalized_sents = self.finalize_hypos(
386
+ step,
387
+ eos_bbsz_idx,
388
+ eos_scores,
389
+ tokens,
390
+ scores,
391
+ finalized,
392
+ finished,
393
+ beam_size,
394
+ attn,
395
+ src_lengths,
396
+ max_len,
397
+ )
398
+ num_remaining_sent -= len(finalized_sents)
399
+
400
+ assert num_remaining_sent >= 0
401
+ if num_remaining_sent == 0:
402
+ break
403
+ if self.search.stop_on_max_len and step >= max_len:
404
+ break
405
+ assert step < max_len, f"{step} < {max_len}"
406
+
407
+ # Remove finalized sentences (ones for which {beam_size}
408
+ # finished hypotheses have been generated) from the batch.
409
+ if len(finalized_sents) > 0:
410
+ new_bsz = bsz - len(finalized_sents)
411
+
412
+ # construct batch_idxs which holds indices of batches to keep for the next pass
413
+ batch_mask = torch.ones(
414
+ bsz, dtype=torch.bool, device=cand_indices.device
415
+ )
416
+ batch_mask[finalized_sents] = False
417
+ # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
418
+ batch_idxs = torch.arange(
419
+ bsz, device=cand_indices.device
420
+ ).masked_select(batch_mask)
421
+
422
+ # Choose the subset of the hypothesized constraints that will continue
423
+ self.search.prune_sentences(batch_idxs)
424
+
425
+ eos_mask = eos_mask[batch_idxs]
426
+ cand_beams = cand_beams[batch_idxs]
427
+ bbsz_offsets.resize_(new_bsz, 1)
428
+ cand_bbsz_idx = cand_beams.add(bbsz_offsets)
429
+ cand_scores = cand_scores[batch_idxs]
430
+ cand_indices = cand_indices[batch_idxs]
431
+
432
+ if prefix_tokens is not None:
433
+ prefix_tokens = prefix_tokens[batch_idxs]
434
+ src_lengths = src_lengths[batch_idxs]
435
+ cands_to_ignore = cands_to_ignore[batch_idxs]
436
+
437
+ scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
438
+ tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
439
+ if attn is not None:
440
+ attn = attn.view(bsz, -1)[batch_idxs].view(
441
+ new_bsz * beam_size, attn.size(1), -1
442
+ )
443
+ bsz = new_bsz
444
+ else:
445
+ batch_idxs = None
446
+
447
+ # Set active_mask so that values > cand_size indicate eos hypos
448
+ # and values < cand_size indicate candidate active hypos.
449
+ # After, the min values per row are the top candidate active hypos
450
+
451
+ # Rewrite the operator since the element wise or is not supported in torchscript.
452
+
453
+ eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
454
+ active_mask = torch.add(
455
+ eos_mask.type_as(cand_offsets) * cand_size,
456
+ cand_offsets[: eos_mask.size(1)],
457
+ )
458
+
459
+ # get the top beam_size active hypotheses, which are just
460
+ # the hypos with the smallest values in active_mask.
461
+ # {active_hypos} indicates which {beam_size} hypotheses
462
+ # from the list of {2 * beam_size} candidates were
463
+ # selected. Shapes: (batch size, beam size)
464
+ new_cands_to_ignore, active_hypos = torch.topk(
465
+ active_mask, k=beam_size, dim=1, largest=False
466
+ )
467
+
468
+ # update cands_to_ignore to ignore any finalized hypos.
469
+ cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
470
+ # Make sure there is at least one active item for each sentence in the batch.
471
+ assert (~cands_to_ignore).any(dim=1).all()
472
+
473
+ # update cands_to_ignore to ignore any finalized hypos
474
+
475
+ # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
476
+ # can be selected more than once).
477
+ active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
478
+ active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)
479
+
480
+ active_bbsz_idx = active_bbsz_idx.view(-1)
481
+ active_scores = active_scores.view(-1)
482
+
483
+ # copy tokens and scores for active hypotheses
484
+
485
+ # Set the tokens for each beam (can select the same row more than once)
486
+ tokens[:, : step + 1] = torch.index_select(
487
+ tokens[:, : step + 1], dim=0, index=active_bbsz_idx
488
+ )
489
+ # Select the next token for each of them
490
+ tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
491
+ cand_indices, dim=1, index=active_hypos
492
+ )
493
+ if step > 0:
494
+ scores[:, :step] = torch.index_select(
495
+ scores[:, :step], dim=0, index=active_bbsz_idx
496
+ )
497
+ scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
498
+ cand_scores, dim=1, index=active_hypos
499
+ )
500
+
501
+ # Update constraints based on which candidates were selected for the next beam
502
+ self.search.update_constraints(active_hypos)
503
+
504
+ # copy attention for active hypotheses
505
+ if attn is not None:
506
+ attn[:, :, : step + 2] = torch.index_select(
507
+ attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
508
+ )
509
+
510
+ # reorder incremental state in decoder
511
+ reorder_state = active_bbsz_idx
512
+
513
+ # sort by score descending
514
+ for sent in range(len(finalized)):
515
+ scores = torch.tensor(
516
+ [float(elem["score"].item()) for elem in finalized[sent]]
517
+ )
518
+ _, sorted_scores_indices = torch.sort(scores, descending=True)
519
+ finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
520
+ finalized[sent] = torch.jit.annotate(
521
+ List[Dict[str, Tensor]], finalized[sent]
522
+ )
523
+ return finalized
524
+
525
+
526
+ class EnsembleModel(EnsembleModelBase):
527
+ """A wrapper around an ensemble of models."""
528
+
529
+ def __init__(self, models):
530
+ super().__init__(models)
531
+
532
+ @torch.jit.export
533
+ def forward_decoder(
534
+ self,
535
+ tokens,
536
+ encoder_outs: List[Dict[str, List[Tensor]]],
537
+ incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
538
+ temperature: float = 1.0,
539
+ decoder_name="decoder",
540
+ encoder_outs_aug: List[Dict[str, List[Tensor]]] = None,
541
+ ):
542
+ log_probs = []
543
+ avg_attn: Optional[Tensor] = None
544
+ encoder_out: Optional[Dict[str, List[Tensor]]] = None
545
+ encoder_out_aug: Optional[Dict[str, List[Tensor]]] = None
546
+ for i, model in enumerate(self.models):
547
+ if self.has_encoder():
548
+ encoder_out = encoder_outs[i]
549
+ if encoder_outs_aug is not None:
550
+ encoder_out_aug = encoder_outs_aug[i]
551
+ # decode each model
552
+ if self.has_incremental_states():
553
+ if encoder_out_aug is not None:
554
+ decoder_out = getattr(model, decoder_name).forward(
555
+ tokens,
556
+ encoder_out=encoder_out,
557
+ encoder_out_aug=encoder_out_aug,
558
+ incremental_state=incremental_states[i],
559
+ )
560
+ else:
561
+ decoder_out = getattr(model, decoder_name).forward(
562
+ tokens,
563
+ encoder_out=encoder_out,
564
+ incremental_state=incremental_states[i],
565
+ )
566
+ else:
567
+ if hasattr(model, decoder_name):
568
+ decoder_out = getattr(model, decoder_name).forward(
569
+ tokens, encoder_out=encoder_out
570
+ )
571
+ else:
572
+ decoder_out = model.forward(tokens)
573
+
574
+ attn: Optional[Tensor] = None
575
+ decoder_len = len(decoder_out)
576
+ if decoder_len > 1 and decoder_out[1] is not None:
577
+ if isinstance(decoder_out[1], Tensor):
578
+ attn = decoder_out[1]
579
+ else:
580
+ attn_holder = decoder_out[1]["attn"]
581
+ if isinstance(attn_holder, Tensor):
582
+ attn = attn_holder
583
+ elif attn_holder is not None:
584
+ attn = attn_holder[0]
585
+ if attn is not None:
586
+ attn = attn[:, -1, :]
587
+
588
+ decoder_out_tuple = (
589
+ decoder_out[0][:, -1:, :].div_(temperature),
590
+ None if decoder_len <= 1 else decoder_out[1],
591
+ )
592
+ probs = getattr(model, decoder_name).get_normalized_probs(
593
+ decoder_out_tuple, log_probs=True, sample=None
594
+ )
595
+ probs = probs[:, -1, :]
596
+ if self.models_size == 1:
597
+ return probs, attn
598
+
599
+ log_probs.append(probs)
600
+ if attn is not None:
601
+ if avg_attn is None:
602
+ avg_attn = attn
603
+ else:
604
+ avg_attn.add_(attn)
605
+
606
+ avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(
607
+ self.models_size
608
+ )
609
+
610
+ if avg_attn is not None:
611
+ avg_attn.div_(self.models_size)
612
+ return avg_probs, avg_attn
613
+
614
+ @torch.jit.export
615
+ def reorder_incremental_state(
616
+ self,
617
+ incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
618
+ new_order,
619
+ decoder_name="decoder",
620
+ ):
621
+ if not self.has_incremental_states():
622
+ return
623
+ for i, model in enumerate(self.models):
624
+ getattr(model, decoder_name).reorder_incremental_state_scripting(
625
+ incremental_states[i], new_order
626
+ )
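A note on the ensemble averaging in `forward_decoder` above: combining the per-model log-probabilities with `torch.logsumexp(...) - math.log(models_size)` is exactly the log of the arithmetic mean of the probabilities. A minimal, self-contained check (illustrative only, not part of fairseq):

```python
import math

import torch

# Average per-model log-probs the way forward_decoder does, then confirm that it
# equals the log of the mean probability across the ensemble.
log_probs = [torch.log_softmax(torch.randn(2, 8), dim=-1) for _ in range(3)]
avg_log_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(len(log_probs))
mean_probs = torch.stack([lp.exp() for lp in log_probs], dim=0).mean(dim=0)
assert torch.allclose(avg_log_probs.exp(), mean_probs, atol=1e-6)
```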
fairseq/examples/speech_to_speech/unity/sequence_generator_multi_decoder.py ADDED
@@ -0,0 +1,267 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Dict, List, Optional
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch import Tensor
11
+
12
+ from fairseq import search
13
+
14
+
15
+ class MultiDecoderSequenceGenerator(nn.Module):
16
+ def __init__(
17
+ self,
18
+ models,
19
+ tgt_dict,
20
+ tgt_dict_mt,
21
+ beam_size=1,
22
+ beam_size_mt=1,
23
+ max_len_a=0,
24
+ max_len_b=200,
25
+ max_len_a_mt=0,
26
+ max_len_b_mt=200,
27
+ max_len=0,
28
+ min_len=1,
29
+ normalize_scores=True,
30
+ len_penalty=1.0,
31
+ len_penalty_mt=1.0,
32
+ unk_penalty=0.0,
33
+ temperature=1.0,
34
+ match_source_len=False,
35
+ no_repeat_ngram_size=0,
36
+ eos=None,
37
+ eos_mt=None,
38
+ symbols_to_strip_from_output=None,
39
+ lm_model=None,
40
+ lm_weight=1.0,
41
+ ):
42
+ """Generates translations of a given source sentence.
43
+
44
+ Args:
45
+ models (List[~fairseq.models.FairseqModel]): ensemble of models,
46
+ currently supports fairseq.models.TransformerModel for scripting
47
+ beam_size (int, optional): beam width (default: 1)
48
+ max_len_a/b (int, optional): generate sequences of maximum length
49
+ ax + b, where x is the source length for the second pass
50
+ max_len_a_mt/b_mt (int, optional): generate sequences of maximum length
51
+ ax + b, where x is the source length for the first pass
52
+ max_len (int, optional): the maximum length of the generated output
53
+ (not including end-of-sentence)
54
+ min_len (int, optional): the minimum length of the generated output
55
+ (not including end-of-sentence)
56
+ normalize_scores (bool, optional): normalize scores by the length
57
+ of the output (default: True)
58
+ len_penalty (float, optional): length penalty in the second pass, where <1.0 favors
59
+ shorter, >1.0 favors longer sentences (default: 1.0)
60
+ len_penalty_mt (float, optional): length penalty in the first pass, where <1.0 favors
61
+ shorter, >1.0 favors longer sentences (default: 1.0)
62
+ unk_penalty (float, optional): unknown word penalty, where <0
63
+ produces more unks, >0 produces fewer (default: 0.0)
64
+ temperature (float, optional): temperature, where values
65
+ >1.0 produce more uniform samples and values <1.0 produce
66
+ sharper samples (default: 1.0)
67
+ match_source_len (bool, optional): outputs should match the source
68
+ length (default: False)
69
+ """
70
+ super().__init__()
71
+
72
+ from examples.speech_to_speech.unity.sequence_generator import SequenceGenerator
73
+
74
+ self.generator = SequenceGenerator(
75
+ models,
76
+ tgt_dict,
77
+ beam_size=beam_size,
78
+ max_len_a=max_len_a,
79
+ max_len_b=max_len_b,
80
+ max_len=max_len,
81
+ min_len=min_len,
82
+ normalize_scores=normalize_scores,
83
+ len_penalty=len_penalty,
84
+ unk_penalty=unk_penalty,
85
+ temperature=temperature,
86
+ match_source_len=match_source_len,
87
+ no_repeat_ngram_size=no_repeat_ngram_size,
88
+ search_strategy=search.BeamSearch(tgt_dict),
89
+ eos=eos,
90
+ symbols_to_strip_from_output=symbols_to_strip_from_output,
91
+ lm_model=lm_model,
92
+ lm_weight=lm_weight,
93
+ )
94
+ self.eos = self.generator.eos
95
+
96
+ self.generator_mt = SequenceGenerator(
97
+ models,
98
+ tgt_dict_mt,
99
+ beam_size=beam_size_mt,
100
+ max_len_a=max_len_a_mt,
101
+ max_len_b=max_len_b_mt,
102
+ max_len=max_len,
103
+ min_len=min_len,
104
+ normalize_scores=normalize_scores,
105
+ len_penalty=len_penalty_mt,
106
+ unk_penalty=unk_penalty,
107
+ temperature=temperature,
108
+ match_source_len=match_source_len,
109
+ no_repeat_ngram_size=no_repeat_ngram_size,
110
+ search_strategy=search.BeamSearch(tgt_dict_mt),
111
+ eos=eos_mt,
112
+ symbols_to_strip_from_output=symbols_to_strip_from_output,
113
+ )
114
+
115
+ @torch.no_grad()
116
+ def generate(
117
+ self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs
118
+ ) -> List[List[Dict[str, Tensor]]]:
119
+ """Generate translations. Matches the API of other fairseq generators.
120
+
121
+ Args:
122
+ models (List[~fairseq.models.FairseqModel]): ensemble of models
123
+ sample (dict): batch
124
+ prefix_tokens (torch.LongTensor, optional): force decoder to begin
125
+ with these tokens
126
+ constraints (torch.LongTensor, optional): force decoder to include
127
+ the list of constraints
128
+ bos_token (int, optional): beginning of sentence token
129
+ (default: self.eos)
130
+ """
131
+ return self._generate(sample, **kwargs)
132
+
133
+ def _generate(
134
+ self,
135
+ sample: Dict[str, Dict[str, Tensor]],
136
+ prefix_tokens: Optional[Tensor] = None,
137
+ constraints: Optional[Tensor] = None,
138
+ bos_token: Optional[int] = None,
139
+ ):
140
+ net_input = sample["net_input"]
141
+
142
+ if "src_tokens" in net_input:
143
+ src_tokens = net_input["src_tokens"]
144
+ # length of the source text: the number of tokens, excluding end-of-sentence and pad
145
+ # if src_lengths exists in net_input (speech_to_text dataset case), then use it
146
+ if "src_lengths" in net_input:
147
+ src_lengths = net_input["src_lengths"]
148
+ else:
149
+ src_lengths = (
150
+ (
151
+ src_tokens.ne(self.generator.eos)
152
+ & src_tokens.ne(self.generator.pad)
153
+ )
154
+ .long()
155
+ .sum(dim=1)
156
+ )
157
+ else:
158
+ raise Exception(
159
+ "expected src_tokens or source in net input. input keys: "
160
+ + str(net_input.keys())
161
+ )
162
+
163
+ if constraints is not None and not self.generator.search.supports_constraints:
164
+ raise NotImplementedError(
165
+ "Target-side constraints were provided, but search method doesn't support them"
166
+ )
167
+
168
+ # Initialize constraints, when active
169
+ self.generator.search.init_constraints(constraints, self.generator.beam_size)
170
+ self.generator_mt.search.init_constraints(
171
+ constraints, self.generator_mt.beam_size
172
+ )
173
+
174
+ # compute the encoder output for each beam
175
+ with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"):
176
+ encoder_outs = self.generator.model.forward_encoder(net_input)
177
+
178
+ single_model = self.generator.model.single_model
179
+ mt_decoder = getattr(single_model, f"{single_model.mt_task_name}_decoder")
180
+
181
+ # 1. MT decoder
182
+ finalized_mt = self.generator_mt.generate_decoder(
183
+ encoder_outs,
184
+ src_tokens,
185
+ src_lengths,
186
+ sample,
187
+ prefix_tokens,
188
+ constraints,
189
+ bos_token,
190
+ aux_task_name=single_model.mt_task_name,
191
+ )
192
+
193
+ # extract decoder output corresponding to the best hypothesis
194
+ max_tgt_len = max([len(hypo[0]["tokens"]) for hypo in finalized_mt])
195
+ prev_output_tokens_mt = (
196
+ src_tokens.new_zeros(src_tokens.shape[0], max_tgt_len)
197
+ .fill_(mt_decoder.padding_idx)
198
+ .int()
199
+ ) # B x T
200
+ for i, hypo in enumerate(finalized_mt):
201
+ i_beam = 0
202
+ tmp = hypo[i_beam]["tokens"].int() # hyp + eos
203
+ prev_output_tokens_mt[i, 0] = self.generator_mt.eos
204
+ if tmp[-1] == self.generator_mt.eos:
205
+ tmp = tmp[:-1]
206
+ prev_output_tokens_mt[i, 1 : len(tmp) + 1] = tmp
207
+
208
+ text = "".join([self.generator_mt.tgt_dict[c] for c in tmp])
209
+ text = text.replace("_", " ")
210
+ text = text.replace("▁", " ")
211
+ text = text.replace("<unk>", " ")
212
+ text = text.replace("<s>", "")
213
+ text = text.replace("</s>", "")
214
+ if len(text) > 0 and text[0] == " ":
215
+ text = text[1:]
216
+ sample_id = sample["id"].tolist()[i]
217
+ print("{} (None-{})".format(text, sample_id))
218
+
219
+ x = mt_decoder(
220
+ prev_output_tokens_mt,
221
+ encoder_out=encoder_outs[0],
222
+ features_only=True,
223
+ )[0].transpose(0, 1)
224
+
225
+ if getattr(single_model, "proj", None) is not None:
226
+ x = single_model.proj(x)
227
+
228
+ mt_decoder_padding_mask = None
229
+ if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any():
230
+ mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx)
231
+
232
+ # 2. T2U encoder
233
+ if getattr(single_model, "synthesizer_encoder", None) is not None:
234
+ t2u_encoder_out = single_model.synthesizer_encoder(
235
+ x,
236
+ mt_decoder_padding_mask,
237
+ )
238
+ else:
239
+ t2u_encoder_out = {
240
+ "encoder_out": [x], # T x B x C
241
+ "encoder_padding_mask": [mt_decoder_padding_mask]
242
+ if mt_decoder_padding_mask is not None
243
+ else [], # B x T
244
+ "encoder_embedding": [],
245
+ "encoder_states": [],
246
+ "src_tokens": [],
247
+ "src_lengths": [],
248
+ }
249
+
250
+ if getattr(single_model, "t2u_augmented_cross_attn", False):
251
+ encoder_outs_aug = [t2u_encoder_out]
252
+ else:
253
+ encoder_outs = [t2u_encoder_out]
254
+ encoder_outs_aug = None
255
+
256
+ # 3. T2U decoder
257
+ finalized = self.generator.generate_decoder(
258
+ encoder_outs,
259
+ src_tokens,
260
+ src_lengths,
261
+ sample,
262
+ prefix_tokens,
263
+ constraints,
264
+ bos_token,
265
+ encoder_outs_aug=encoder_outs_aug,
266
+ )
267
+ return finalized
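To make the first-pass/second-pass handoff in `_generate` above concrete, here is a self-contained toy rendering of how the best MT hypotheses are turned into `prev_output_tokens_mt`, which is then fed back through the MT decoder (with `features_only=True`) to produce inputs for the text-to-unit pass: each hypothesis is prefixed with EOS (acting as BOS), its trailing EOS is stripped, and the batch is right-padded. The token values and the EOS/pad indices below are made up for illustration.

```python
import torch

eos, pad = 2, 1  # toy indices; in practice they come from the MT target dictionary
best_hypos = [torch.tensor([5, 6, 7, eos]), torch.tensor([8, 9, eos])]  # best beam per sentence

max_tgt_len = max(len(h) for h in best_hypos)
prev_output_tokens_mt = torch.full((len(best_hypos), max_tgt_len), pad, dtype=torch.int)
for i, tmp in enumerate(best_hypos):
    prev_output_tokens_mt[i, 0] = eos      # EOS doubles as the BOS of the second pass
    if tmp[-1] == eos:
        tmp = tmp[:-1]                     # drop the trailing EOS
    prev_output_tokens_mt[i, 1 : len(tmp) + 1] = tmp

print(prev_output_tokens_mt)
# tensor([[2, 5, 6, 7],
#         [2, 8, 9, 1]], dtype=torch.int32)
```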
fairseq/examples/speech_to_text/README.md ADDED
@@ -0,0 +1,77 @@
1
+ # Speech-to-Text (S2T) Modeling
2
+
3
+ [https://www.aclweb.org/anthology/2020.aacl-demo.6](https://www.aclweb.org/anthology/2020.aacl-demo.6.pdf)
4
+
5
+ Speech recognition (ASR) and speech-to-text translation (ST) with fairseq.
6
+
7
+ ## Data Preparation
8
+ S2T modeling data consists of source speech features, target text and other optional information
9
+ (source text, speaker id, etc.). Fairseq S2T uses per-dataset-split TSV manifest files
10
+ to store this information. Each data field is represented by a column in the TSV file.
11
+
12
+ Unlike text token embeddings, speech features (e.g. log mel-scale filter banks) are usually fixed
13
+ during model training and can be pre-computed. The manifest file contains the path to
14
+ either the feature file in NumPy format or the WAV/FLAC audio file. For the latter,
15
+ features will be extracted on-the-fly by fairseq S2T. Optionally, feature/audio files can be packed
16
+ into uncompressed ZIP files (then accessed via byte offset and length) to improve I/O performance.
17
+
18
+ Fairseq S2T also employs a YAML file for data related configurations: tokenizer type and dictionary path
19
+ for the target text, feature transforms such as CMVN (cepstral mean and variance normalization) and SpecAugment,
20
+ temperature-based resampling, etc.
21
+
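For concreteness, here is a minimal sketch of a single manifest row, written with the `save_df_to_tsv` helper from `examples/speech_to_text/data_utils.py` (added later in this commit). The exact column set depends on the prep script; `audio`, `n_frames` and `tgt_text` are the columns the filtering utilities expect, while `id` and `speaker` are typical additional fields.

```python
import pandas as pd

from examples.speech_to_text.data_utils import save_df_to_tsv

# One row per utterance: a feature/audio path (ZIP entries are addressed as
# "archive.zip:<byte offset>:<length>"), the number of frames, and the target text.
manifest = {
    "id": ["utt_0001"],
    "audio": ["fbank80.zip:1024:61440"],
    "n_frames": [384],
    "tgt_text": ["hello world"],
    "speaker": ["spk_01"],
}
save_df_to_tsv(pd.DataFrame.from_dict(manifest), "train_example.tsv")
```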
22
+ ## Model Training
23
+ Fairseq S2T uses the unified `fairseq-train` interface for model training. It requires arguments `--task speech_to_text`,
24
+ `--arch <model architecture in fairseq.models.speech_to_text.*>` and `--config-yaml <config YAML filename>`.
25
+
26
+ ## Inference & Evaluation
27
+ Fairseq S2T uses the unified `fairseq-generate`/`fairseq-interactive` interface for inference and evaluation. It
28
+ requires arguments `--task speech_to_text` and `--config-yaml <config YAML filename>`. The interactive console takes
29
+ audio paths (one per line) as inputs.
30
+
31
+
32
+ ## Examples
33
+ - [Speech Recognition (ASR) on LibriSpeech](docs/librispeech_example.md)
34
+
35
+ - [Speech-to-Text Translation (ST) on MuST-C](docs/mustc_example.md)
36
+
37
+ - [Speech-to-Text Translation (ST) on CoVoST 2](docs/covost_example.md)
38
+
39
+ - [Speech-to-Text Translation (ST) on Multilingual TEDx](docs/mtedx_example.md)
40
+ - [Simultaneous Speech-to-Text Translation (SimulST) on MuST-C](docs/simulst_mustc_example.md)
41
+
42
+ ## Updates
43
+ - 02/04/2021: Added interactive decoding (`fairseq-interactive`) support. Examples:
44
+ [ASR (LibriSpeech)](docs/librispeech_example.md#interactive-decoding)
45
+ and [ST (CoVoST 2)](docs/covost_example.md#interactive-decoding).
46
+ - 01/08/2021: Several fixes for S2T Transformer model, inference-time de-tokenization, scorer configuration and data
47
+ preparation scripts. We also add pre-trained models to the examples and revise the instructions.
48
+ Breaking changes: the data preparation scripts now extract filterbank features without CMVN. CMVN is instead applied
49
+ on-the-fly (defined in the config YAML).
50
+
51
+ ## What's Next
52
+ - We are migrating the old fairseq [ASR example](../speech_recognition) into this S2T framework and
53
+ merging the features from both sides.
54
+ - The following papers also base their experiments on fairseq S2T. We are adding more examples for replication.
55
+ - [Improving Cross-Lingual Transfer Learning for End-to-End Speech Recognition with Speech Translation (Wang et al., 2020)](https://arxiv.org/abs/2006.05474)
56
+ - [Self-Supervised Representations Improve End-to-End Speech Translation (Wu et al., 2020)](https://arxiv.org/abs/2006.12124)
57
+ - [Self-Training for End-to-End Speech Translation (Pino et al., 2020)](https://arxiv.org/abs/2006.02490)
58
+ - [CoVoST: A Diverse Multilingual Speech-To-Text Translation Corpus (Wang et al., 2020)](https://arxiv.org/abs/2002.01320)
59
+ - [Harnessing Indirect Training Data for End-to-End Automatic Speech Translation: Tricks of the Trade (Pino et al., 2019)](https://arxiv.org/abs/1909.06515)
60
+
61
+ ## Citation
62
+ Please cite as:
63
+ ```
64
+ @inproceedings{wang2020fairseqs2t,
65
+ title = {fairseq S2T: Fast Speech-to-Text Modeling with fairseq},
66
+ author = {Changhan Wang and Yun Tang and Xutai Ma and Anne Wu and Dmytro Okhonko and Juan Pino},
67
+ booktitle = {Proceedings of the 2020 Conference of the Asian Chapter of the Association for Computational Linguistics (AACL): System Demonstrations},
68
+ year = {2020},
69
+ }
70
+
71
+ @inproceedings{ott2019fairseq,
72
+ title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
73
+ author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
74
+ booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
75
+ year = {2019},
76
+ }
77
+ ```
fairseq/examples/speech_to_text/data_utils.py ADDED
@@ -0,0 +1,383 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import csv
7
+ from pathlib import Path
8
+ import zipfile
9
+ from functools import reduce
10
+ from multiprocessing import cpu_count
11
+ from typing import Any, Dict, List, Optional, Union
12
+ import io
13
+
14
+ import numpy as np
15
+ import pandas as pd
16
+ import sentencepiece as sp
17
+ from fairseq.data.audio.audio_utils import (
18
+ convert_waveform, _get_kaldi_fbank, _get_torchaudio_fbank, is_npy_data,
19
+ is_sf_audio_data
20
+ )
21
+ import torch
22
+ import soundfile as sf
23
+ from tqdm import tqdm
24
+
25
+
26
+ UNK_TOKEN, UNK_TOKEN_ID = "<unk>", 3
27
+ BOS_TOKEN, BOS_TOKEN_ID = "<s>", 0
28
+ EOS_TOKEN, EOS_TOKEN_ID = "</s>", 2
29
+ PAD_TOKEN, PAD_TOKEN_ID = "<pad>", 1
30
+
31
+
32
+ def gen_vocab(
33
+ input_path: Path, output_path_prefix: Path, model_type="bpe",
34
+ vocab_size=1000, special_symbols: Optional[List[str]] = None
35
+ ):
36
+ # Train SentencePiece Model
37
+ arguments = [
38
+ f"--input={input_path.as_posix()}",
39
+ f"--model_prefix={output_path_prefix.as_posix()}",
40
+ f"--model_type={model_type}",
41
+ f"--vocab_size={vocab_size}",
42
+ "--character_coverage=1.0",
43
+ f"--num_threads={cpu_count()}",
44
+ f"--unk_id={UNK_TOKEN_ID}",
45
+ f"--bos_id={BOS_TOKEN_ID}",
46
+ f"--eos_id={EOS_TOKEN_ID}",
47
+ f"--pad_id={PAD_TOKEN_ID}",
48
+ ]
49
+ if special_symbols is not None:
50
+ _special_symbols = ",".join(special_symbols)
51
+ arguments.append(f"--user_defined_symbols={_special_symbols}")
52
+ sp.SentencePieceTrainer.Train(" ".join(arguments))
53
+ # Export fairseq dictionary
54
+ spm = sp.SentencePieceProcessor()
55
+ spm.Load(output_path_prefix.as_posix() + ".model")
56
+ vocab = {i: spm.IdToPiece(i) for i in range(spm.GetPieceSize())}
57
+ assert (
58
+ vocab.get(UNK_TOKEN_ID) == UNK_TOKEN
59
+ and vocab.get(PAD_TOKEN_ID) == PAD_TOKEN
60
+ and vocab.get(BOS_TOKEN_ID) == BOS_TOKEN
61
+ and vocab.get(EOS_TOKEN_ID) == EOS_TOKEN
62
+ )
63
+ vocab = {
64
+ i: s
65
+ for i, s in vocab.items()
66
+ if s not in {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN}
67
+ }
68
+ with open(output_path_prefix.as_posix() + ".txt", "w") as f_out:
69
+ for _, s in sorted(vocab.items(), key=lambda x: x[0]):
70
+ f_out.write(f"{s} 1\n")
71
+
72
+
73
+ def extract_fbank_features(
74
+ waveform: torch.FloatTensor,
75
+ sample_rate: int,
76
+ output_path: Optional[Path] = None,
77
+ n_mel_bins: int = 80,
78
+ overwrite: bool = False,
79
+ ):
80
+ if output_path is not None and output_path.is_file() and not overwrite:
81
+ return
82
+
83
+ _waveform, _ = convert_waveform(waveform, sample_rate, to_mono=True)
84
+ # Kaldi compliance: 16-bit signed integers
85
+ _waveform = _waveform * (2 ** 15)
86
+ _waveform = _waveform.numpy()
87
+
88
+ features = _get_kaldi_fbank(_waveform, sample_rate, n_mel_bins)
89
+ if features is None:
90
+ features = _get_torchaudio_fbank(_waveform, sample_rate, n_mel_bins)
91
+ if features is None:
92
+ raise ImportError(
93
+ "Please install pyKaldi or torchaudio to enable fbank feature extraction"
94
+ )
95
+
96
+ if output_path is not None:
97
+ np.save(output_path.as_posix(), features)
98
+ return features
99
+
100
+
101
+ def create_zip(data_root: Path, zip_path: Path):
102
+ paths = list(data_root.glob("*.npy"))
103
+ paths.extend(data_root.glob("*.flac"))
104
+ with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as f:
105
+ for path in tqdm(paths):
106
+ f.write(path, arcname=path.name)
107
+
108
+
109
+ def get_zip_manifest(
110
+ zip_path: Path, zip_root: Optional[Path] = None, is_audio=False
111
+ ):
112
+ _zip_path = Path.joinpath(zip_root or Path(""), zip_path)
113
+ with zipfile.ZipFile(_zip_path, mode="r") as f:
114
+ info = f.infolist()
115
+ paths, lengths = {}, {}
116
+ for i in tqdm(info):
117
+ utt_id = Path(i.filename).stem
118
+ offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size
119
+ paths[utt_id] = f"{zip_path.as_posix()}:{offset}:{file_size}"
120
+ with open(_zip_path, "rb") as f:
121
+ f.seek(offset)
122
+ byte_data = f.read(file_size)
123
+ assert len(byte_data) > 1
124
+ if is_audio:
125
+ assert is_sf_audio_data(byte_data), i
126
+ else:
127
+ assert is_npy_data(byte_data), i
128
+ byte_data_fp = io.BytesIO(byte_data)
129
+ if is_audio:
130
+ lengths[utt_id] = sf.info(byte_data_fp).frames
131
+ else:
132
+ lengths[utt_id] = np.load(byte_data_fp).shape[0]
133
+ return paths, lengths
134
+
135
+
136
+ def gen_config_yaml(
137
+ manifest_root: Path,
138
+ spm_filename: Optional[str] = None,
139
+ vocab_name: Optional[str] = None,
140
+ yaml_filename: str = "config.yaml",
141
+ specaugment_policy: Optional[str] = "lb",
142
+ prepend_tgt_lang_tag: bool = False,
143
+ sampling_alpha: Optional[float] = None,
144
+ input_channels: Optional[int] = 1,
145
+ input_feat_per_channel: Optional[int] = 80,
146
+ audio_root: str = "",
147
+ cmvn_type: str = "utterance",
148
+ gcmvn_path: Optional[Path] = None,
149
+ extra=None
150
+ ):
151
+ manifest_root = manifest_root.absolute()
152
+ writer = S2TDataConfigWriter(manifest_root / yaml_filename)
153
+ assert spm_filename is not None or vocab_name is not None
154
+ vocab_name = spm_filename.replace(".model", ".txt") if vocab_name is None \
155
+ else vocab_name
156
+ writer.set_vocab_filename(vocab_name)
157
+ if input_channels is not None:
158
+ writer.set_input_channels(input_channels)
159
+ if input_feat_per_channel is not None:
160
+ writer.set_input_feat_per_channel(input_feat_per_channel)
161
+ specaugment_setters = {
162
+ "lb": writer.set_specaugment_lb_policy,
163
+ "ld": writer.set_specaugment_ld_policy,
164
+ "sm": writer.set_specaugment_sm_policy,
165
+ "ss": writer.set_specaugment_ss_policy,
166
+ }
167
+ specaugment_setter = specaugment_setters.get(specaugment_policy, None)
168
+ if specaugment_setter is not None:
169
+ specaugment_setter()
170
+ if spm_filename is not None:
171
+ writer.set_bpe_tokenizer(
172
+ {
173
+ "bpe": "sentencepiece",
174
+ "sentencepiece_model": (manifest_root / spm_filename).as_posix(),
175
+ }
176
+ )
177
+ if prepend_tgt_lang_tag:
178
+ writer.set_prepend_tgt_lang_tag(True)
179
+ if sampling_alpha is not None:
180
+ writer.set_sampling_alpha(sampling_alpha)
181
+
182
+ if cmvn_type not in ["global", "utterance"]:
183
+ raise NotImplementedError
184
+
185
+ if specaugment_policy is not None:
186
+ writer.set_feature_transforms(
187
+ "_train", [f"{cmvn_type}_cmvn", "specaugment"]
188
+ )
189
+ writer.set_feature_transforms("*", [f"{cmvn_type}_cmvn"])
190
+
191
+ if cmvn_type == "global":
192
+ if gcmvn_path is None:
193
+ raise ValueError("Please provide path of global cmvn file.")
194
+ else:
195
+ writer.set_global_cmvn(gcmvn_path.as_posix())
196
+
197
+ if len(audio_root) > 0:
198
+ writer.set_audio_root(audio_root)
199
+
200
+ if extra is not None:
201
+ writer.set_extra(extra)
202
+ writer.flush()
203
+
204
+
205
+ def load_df_from_tsv(path: Union[str, Path]) -> pd.DataFrame:
206
+ _path = path if isinstance(path, str) else path.as_posix()
207
+ return pd.read_csv(
208
+ _path,
209
+ sep="\t",
210
+ header=0,
211
+ encoding="utf-8",
212
+ escapechar="\\",
213
+ quoting=csv.QUOTE_NONE,
214
+ na_filter=False,
215
+ )
216
+
217
+
218
+ def save_df_to_tsv(dataframe, path: Union[str, Path]):
219
+ _path = path if isinstance(path, str) else path.as_posix()
220
+ dataframe.to_csv(
221
+ _path,
222
+ sep="\t",
223
+ header=True,
224
+ index=False,
225
+ encoding="utf-8",
226
+ escapechar="\\",
227
+ quoting=csv.QUOTE_NONE,
228
+ )
229
+
230
+
231
+ def load_tsv_to_dicts(path: Union[str, Path]) -> List[dict]:
232
+ with open(path, "r") as f:
233
+ reader = csv.DictReader(
234
+ f,
235
+ delimiter="\t",
236
+ quotechar=None,
237
+ doublequote=False,
238
+ lineterminator="\n",
239
+ quoting=csv.QUOTE_NONE,
240
+ )
241
+ rows = [dict(e) for e in reader]
242
+ return rows
243
+
244
+
245
+ def filter_manifest_df(
246
+ df, is_train_split=False, extra_filters=None, min_n_frames=5, max_n_frames=3000
247
+ ):
248
+ filters = {
249
+ "no speech": df["audio"] == "",
250
+ f"short speech (<{min_n_frames} frames)": df["n_frames"] < min_n_frames,
251
+ "empty sentence": df["tgt_text"] == "",
252
+ }
253
+ if is_train_split:
254
+ filters[f"long speech (>{max_n_frames} frames)"] = df["n_frames"] > max_n_frames
255
+ if extra_filters is not None:
256
+ filters.update(extra_filters)
257
+ invalid = reduce(lambda x, y: x | y, filters.values())
258
+ valid = ~invalid
259
+ print(
260
+ "| "
261
+ + ", ".join(f"{n}: {f.sum()}" for n, f in filters.items())
262
+ + f", total {invalid.sum()} filtered, {valid.sum()} remained."
263
+ )
264
+ return df[valid]
265
+
266
+
267
+ def cal_gcmvn_stats(features_list):
268
+ features = np.concatenate(features_list)
269
+ square_sums = (features ** 2).sum(axis=0)
270
+ mean = features.mean(axis=0)
271
+ features = np.subtract(features, mean)
272
+ var = square_sums / features.shape[0] - mean ** 2
273
+ std = np.sqrt(np.maximum(var, 1e-8))
274
+ return {"mean": mean.astype("float32"), "std": std.astype("float32")}
275
+
276
+
277
+ class S2TDataConfigWriter(object):
278
+ DEFAULT_VOCAB_FILENAME = "dict.txt"
279
+ DEFAULT_INPUT_FEAT_PER_CHANNEL = 80
280
+ DEFAULT_INPUT_CHANNELS = 1
281
+
282
+ def __init__(self, yaml_path: Path):
283
+ try:
284
+ import yaml
285
+ except ImportError:
286
+ print("Please install PyYAML for S2T data config YAML files")
287
+ self.yaml = yaml
288
+ self.yaml_path = yaml_path
289
+ self.config = {}
290
+
291
+ def flush(self):
292
+ with open(self.yaml_path, "w") as f:
293
+ self.yaml.dump(self.config, f)
294
+
295
+ def set_audio_root(self, audio_root=""):
296
+ self.config["audio_root"] = audio_root
297
+
298
+ def set_vocab_filename(self, vocab_filename: str = "dict.txt"):
299
+ self.config["vocab_filename"] = vocab_filename
300
+
301
+ def set_specaugment(
302
+ self,
303
+ time_wrap_w: int,
304
+ freq_mask_n: int,
305
+ freq_mask_f: int,
306
+ time_mask_n: int,
307
+ time_mask_t: int,
308
+ time_mask_p: float,
309
+ ):
310
+ self.config["specaugment"] = {
311
+ "time_wrap_W": time_wrap_w,
312
+ "freq_mask_N": freq_mask_n,
313
+ "freq_mask_F": freq_mask_f,
314
+ "time_mask_N": time_mask_n,
315
+ "time_mask_T": time_mask_t,
316
+ "time_mask_p": time_mask_p,
317
+ }
318
+
319
+ def set_specaugment_lb_policy(self):
320
+ self.set_specaugment(
321
+ time_wrap_w=0,
322
+ freq_mask_n=1,
323
+ freq_mask_f=27,
324
+ time_mask_n=1,
325
+ time_mask_t=100,
326
+ time_mask_p=1.0,
327
+ )
328
+
329
+ def set_specaugment_ld_policy(self):
330
+ self.set_specaugment(
331
+ time_wrap_w=0,
332
+ freq_mask_n=2,
333
+ freq_mask_f=27,
334
+ time_mask_n=2,
335
+ time_mask_t=100,
336
+ time_mask_p=1.0,
337
+ )
338
+
339
+ def set_specaugment_sm_policy(self):
340
+ self.set_specaugment(
341
+ time_wrap_w=0,
342
+ freq_mask_n=2,
343
+ freq_mask_f=15,
344
+ time_mask_n=2,
345
+ time_mask_t=70,
346
+ time_mask_p=0.2,
347
+ )
348
+
349
+ def set_specaugment_ss_policy(self):
350
+ self.set_specaugment(
351
+ time_wrap_w=0,
352
+ freq_mask_n=2,
353
+ freq_mask_f=27,
354
+ time_mask_n=2,
355
+ time_mask_t=70,
356
+ time_mask_p=0.2,
357
+ )
358
+
359
+ def set_input_channels(self, input_channels: int = 1):
360
+ self.config["input_channels"] = input_channels
361
+
362
+ def set_input_feat_per_channel(self, input_feat_per_channel: int = 80):
363
+ self.config["input_feat_per_channel"] = input_feat_per_channel
364
+
365
+ def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]):
366
+ self.config["bpe_tokenizer"] = bpe_tokenizer
367
+
368
+ def set_global_cmvn(self, stats_npz_path: str):
369
+ self.config["global_cmvn"] = {"stats_npz_path": stats_npz_path}
370
+
371
+ def set_feature_transforms(self, split: str, transforms: List[str]):
372
+ if "transforms" not in self.config:
373
+ self.config["transforms"] = {}
374
+ self.config["transforms"][split] = transforms
375
+
376
+ def set_prepend_tgt_lang_tag(self, flag: bool = True):
377
+ self.config["prepend_tgt_lang_tag"] = flag
378
+
379
+ def set_sampling_alpha(self, sampling_alpha: float = 1.0):
380
+ self.config["sampling_alpha"] = sampling_alpha
381
+
382
+ def set_extra(self, data):
383
+ self.config.update(data)
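An illustrative sketch of how the helpers above are typically chained by the prep scripts in this directory (the paths, vocabulary size and SpecAugment policy below are placeholder choices, not a prescribed recipe): train a SentencePiece model on the target text with `gen_vocab`, then write the `config.yaml` that `--config-yaml` points to with `gen_config_yaml`.

```python
from pathlib import Path

from examples.speech_to_text.data_utils import gen_config_yaml, gen_vocab

root = Path("/path/to/manifest_root")  # placeholder: where the TSV manifests live

# 1) Learn a target-text vocabulary and export the fairseq dictionary (*.txt).
gen_vocab(
    root / "train_text.txt",       # one target sentence per line
    root / "spm_unigram1000",      # output prefix for the .model/.txt files
    model_type="unigram",
    vocab_size=1000,
)

# 2) Emit the data config consumed via --config-yaml (vocab, tokenizer, SpecAugment, CMVN).
gen_config_yaml(
    root,
    spm_filename="spm_unigram1000.model",
    yaml_filename="config.yaml",
    specaugment_policy="lb",
)
```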
fairseq/examples/speech_to_text/docs/covost_example.md ADDED
@@ -0,0 +1,140 @@
1
+ [[Back]](..)
2
+
3
+ # S2T Example: ST on CoVoST
4
+
5
+ We replicate the experiments in
6
+ [CoVoST 2 and Massively Multilingual Speech-to-Text Translation (Wang et al., 2020)](https://arxiv.org/abs/2007.10310).
7
+
8
+ ## Data Preparation
9
+
10
+ [Download](https://commonvoice.mozilla.org/en/datasets) and unpack Common Voice v4 to a path
11
+ `${COVOST_ROOT}/${SOURCE_LANG_ID}`, then preprocess it with
12
+
13
+ ```bash
14
+ # additional Python packages for S2T data processing/model training
15
+ pip install pandas torchaudio sentencepiece
16
+
17
+ # En ASR
18
+ python examples/speech_to_text/prep_covost_data.py \
19
+ --data-root ${COVOST_ROOT} --vocab-type char --src-lang en
20
+ # ST
21
+ python examples/speech_to_text/prep_covost_data.py \
22
+ --data-root ${COVOST_ROOT} --vocab-type char \
23
+ --src-lang fr --tgt-lang en
24
+ ```
25
+
26
+ The generated files (manifest, features, vocabulary and data configuration) will be added to
27
+ `${COVOST_ROOT}/${SOURCE_LANG_ID}`.
28
+
29
+ Download our vocabulary files if you want to use our pre-trained models:
30
+
31
+ - ASR: [En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_asr_vocab_char.zip)
32
+ - ST: [Fr-En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_fr_en_st_vocab_char.zip), [De-En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_de_en_st_vocab_char.zip), [Es-En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_es_en_st_vocab_char.zip), [Ca-En](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_ca_en_st_vocab_char.zip), [En-De](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_de_st_vocab_char.zip), [En-Ca](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_ca_st_vocab_char.zip), [En-Fa](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_fa_st_vocab_char.zip), [En-Et](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_et_st_vocab_char.zip)
33
+
34
+ ## ASR
35
+
36
+ #### Training
37
+
38
+ We train an En ASR model that is used to pre-train the encoder of some of the ST models.
39
+
40
+ ```bash
41
+ fairseq-train ${COVOST_ROOT}/en \
42
+ --config-yaml config_asr_en.yaml --train-subset train_asr_en --valid-subset dev_asr_en \
43
+ --save-dir ${ASR_SAVE_DIR} --num-workers 4 --max-tokens 50000 --max-update 60000 \
44
+ --task speech_to_text --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
45
+ --report-accuracy --arch s2t_transformer_s --dropout 0.15 --optimizer adam --lr 2e-3 \
46
+ --lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \
47
+ --attn-type None --pos-enc-type ${POS_ENC_TYPE}
48
+ ```
49
+
50
+ where `ASR_SAVE_DIR` is the checkpoint root path and `POS_ENC_TYPE` is the positional encoding used in the conformer encoder.
51
+ Set it to `abs`, `rope` or `rel_pos` for absolute, rotary or relative positional encoding in the conformer layers, respectively.
52
+ The transformer encoder only supports absolute positional encoding and is used by default.
53
+ To switch to the conformer encoder, set `--attn-type espnet` and `--pos-enc-type ${POS_ENC_TYPE}`. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to adjust it accordingly when using more than one GPU.
54
+
55
+ #### Inference & Evaluation
56
+
57
+ ```bash
58
+ CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
59
+ python scripts/average_checkpoints.py \
60
+ --inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 \
61
+ --output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}"
62
+ fairseq-generate ${COVOST_ROOT}/en \
63
+ --config-yaml config_asr_en.yaml --gen-subset test_asr_en --task speech_to_text \
64
+ --path ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \
65
+ --scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct
66
+ ```
67
+
68
+ #### Results
69
+
70
+ | --arch | --pos-enc-type | Params | En | Model |
71
+ |---|---|---|---|---|
72
+ | s2t_transformer_s | - | 31M | 25.6 | [Download](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_asr_transformer_s.pt) |
73
+ | s2t_conformer | rel_pos | 42.9M | 23.18| [Download](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_asr/rel_pos_asr_checkpoint_best.pt) |
74
+ | s2t_conformer | rope | 42.1M | 23.8| [Download](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_asr/rope_pos_asr_checkpoint_best.pt) |
75
+ | s2t_conformer | abs | 42.1M | 23.8| [Download](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_asr/abs_asr_checkpoint_best.pt) |
76
+
77
+ ## ST
78
+
79
+ #### Training
80
+
81
+ Fr-En as example:
82
+
83
+ ```bash
84
+ fairseq-train ${COVOST_ROOT}/fr \
85
+ --config-yaml config_st_fr_en.yaml --train-subset train_st_fr_en --valid-subset dev_st_fr_en \
86
+ --save-dir ${ST_SAVE_DIR} --num-workers 4 --max-update 30000 --max-tokens 40000 \ # --max-tokens 50000 for en-*
87
+ --task speech_to_text --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --report-accuracy \
88
+ --arch s2t_transformer_s --encoder-freezing-updates 1000 --optimizer adam --lr 2e-3 \
89
+ --lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \
90
+ --attn-type None --pos-enc-type ${POS_ENC_TYPE} \
91
+ --load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}
92
+ ```
93
+
94
+ where `ST_SAVE_DIR` is the checkpoint root path and `POS_ENC_TYPE` is the positional encoding used in the conformer encoder.
95
+ Set it to `abs`, `rope` or `rel_pos` for absolute, rotary or relative positional encoding in the conformer layers, respectively.
96
+ The transformer encoder only supports absolute positional encoding and is used by default.
97
+ To switch to the conformer encoder, set `--attn-type espnet` and `--pos-enc-type ${POS_ENC_TYPE}`. Optionally, load the pre-trained En ASR encoder for faster training and better
98
+ performance: `--load-pretrained-encoder-from <ASR checkpoint path>`. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU.
99
+ You may want to adjust it accordingly when using more than one GPU.
100
+
101
+ #### Inference & Evaluation
102
+
103
+ Average the last 10 checkpoints and evaluate on test split:
104
+
105
+ ```bash
106
+ CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
107
+ python scripts/average_checkpoints.py \
108
+ --inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 \
109
+ --output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}"
110
+ fairseq-generate ${COVOST_ROOT}/fr \
111
+ --config-yaml config_st_fr_en.yaml --gen-subset test_st_fr_en --task speech_to_text \
112
+ --path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \
113
+ --max-tokens 50000 --beam 5 --scoring sacrebleu
114
+ ```
115
+
116
+ ## Interactive Decoding
117
+
118
+ Launch the interactive console via
119
+
120
+ ```bash
121
+ fairseq-interactive ${COVOST_ROOT}/fr --config-yaml config_st_fr_en.yaml \
122
+ --task speech_to_text --path ${SAVE_DIR}/${CHECKPOINT_FILENAME} \
123
+ --max-tokens 50000 --beam 5
124
+ ```
125
+
126
+ Type in WAV/FLAC/OGG audio paths (one per line) after the prompt.
127
+
128
+ #### Results
129
+
130
+ | --arch | --pos-enc-type | Params | ASR PT | Fr-En | De-En | Es-En | Ca-En | En-De | En-Ca | En-Fa | En-Et | Model |
131
+ |---|---|---|---|---|---|---|---|---|---|---|---|---|
132
+ | s2t_transformer | - | 31M | Yes | [27.2](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_fr_en_st_transformer_s.pt) | [17.7](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_de_en_st_transformer_s.pt) | [23.1](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_es_en_st_transformer_s.pt) | [19.3](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_ca_en_st_transformer_s.pt) | [16.1](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_de_st_transformer_s.pt) | [21.6](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_ca_st_transformer_s.pt) | [12.9](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_fa_st_transformer_s.pt) | [12.8](https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_et_st_transformer_s.pt) | (<-Download) |
133
+ | s2t_conformer | rel_pos | 42.9M | No | [28.32](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/fr_en/rel_pos_from_scratch_avg_last_10_checkpoint.pt) | [18.21](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/de_en/rel_pos_from_scratch_avg_last_10_checkpoint.pt) | [25.98](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/es_en/rel_pos_from_scratch_avg_last_10_checkpoint.pt) | [21.13](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/ca_en/rel_pos_from_scratch_avg_last_10_checkpoint.pt) | [20.37](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_de/rel_pos_from_scratch_avg_last_10_checkpoint.pt) | [25.89](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_ca/rel_pos_from_scratch_avg_last_10_checkpoint.pt) | [15.59](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_fa/rel_pos_from_scratch_avg_last_10_checkpoint.pt) | [14.49](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_et/rel_pos_from_scratch_avg_last_10_checkpoint.pt) | (<-Download) |
134
+ | s2t_conformer | rel_pos | 42.9M | Yes| [27.15](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/fr_en/rel_pos_asr_pt_avg_last_10_checkpoint.pt) | [18.22](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/de_en/rel_pos_asr_pt_avg_last_10_checkpoint.pt) | [25.14](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/es_en/rel_pos_asr_pt_avg_last_10_checkpoint.pt) | [21.68](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/ca_en/rel_pos_asr_pt_avg_last_10_checkpoint.pt) | [20.35](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_de/rel_pos_asr_pt_avg_last_10_checkpoint.pt) | [25.92](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_ca/rel_pos_asr_pt_avg_last_10_checkpoint.pt) | [15.76](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_fa/rel_pos_asr_pt_avg_last_10_checkpoint.pt) | [16.52](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_et/rel_pos_asr_pt_avg_last_10_checkpoint.pt) | (<-Download) |
135
+ | s2t_conformer | rope | 42.1M | No | [27.61](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/fr_en/rope_from_scratch_avg_last_10_checkpoint.pt) | [17.6](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/de_en/rope_from_scratch_avg_last_10_checkpoint.pt) | [24.91](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/es_en/rope_from_scratch_avg_last_10_checkpoint.pt) | [20.78](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/ca_en/rope_from_scratch_avg_last_10_checkpoint.pt) | [19.7](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_de/rope_from_scratch_avg_last_10_checkpoint.pt) | [25.13](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_ca/rope_from_scratch_avg_last_10_checkpoint.pt) | [15.22](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_fa/rope_from_scratch_avg_last_10_checkpoint.pt) | [15.87](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_et/rope_from_scratch_avg_last_10_checkpoint.pt) | (<-Download) |
136
+ | s2t_conformer | rope | 42.1M | Yes | [26.99](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/fr_en/rope_asr_pt_avg_last_10_checkpoint.pt) | [17.71](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/de_en/rope_asr_pt_avg_last_10_checkpoint.pt) | [24.24](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/es_en/rope_asr_pt_avg_last_10_checkpoint.pt) | [21.24](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/ca_en/rope_asr_pt_avg_last_10_checkpoint.pt) | [19.9](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_de/rope_asr_pt_avg_last_10_checkpoint.pt) | [25.25](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_ca/rope_asr_pt_avg_last_10_checkpoint.pt) | [15.58](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_fa/rope_asr_pt_avg_last_10_checkpoint.pt) | [15.97](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_et/rope_asr_pt_avg_last_10_checkpoint.pt) | (<-Download) |
137
+ | s2t_conformer | abs | 42.1M | No | [27.45](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/fr_en/abs_from_scratch_avg_last_10_checkpoint.pt) | [17.25](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/de_en/abs_from_scratch_avg_last_10_checkpoint.pt) | [25.01](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/es_en/abs_from_scratch_avg_last_10_checkpoint.pt) | [20.26](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/ca_en/abs_from_scratch_avg_last_10_checkpoint.pt) | [19.86](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_de/abs_from_scratch_avg_last_10_checkpoint.pt) | [25.25](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_ca/abs_from_scratch_avg_last_10_checkpoint.pt) | [15.46](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_fa/abs_from_scratch_avg_last_10_checkpoint.pt) | [15.81](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_et/abs_from_scratch_avg_last_10_checkpoint.pt) | (<-Download) |
138
+ | s2t_conformer | abs | 42.1M | Yes | [26.52](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/fr_en/abs_asr_pt_avg_last_10_checkpoint.pt) | [17.37](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/de_en/abs_asr_pt_avg_last_10_checkpoint.pt) | [25.40](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/es_en/abs_asr_pt_avg_last_10_checkpoint.pt) | [20.45](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/ca_en/abs_asr_pt_avg_last_10_checkpoint.pt) | [19.57](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_de/abs_asr_pt_avg_last_10_checkpoint.pt) | [25.40](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_ca/abs_asr_pt_avg_last_10_checkpoint.pt) | [15.17](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_fa/abs_asr_pt_avg_last_10_checkpoint.pt) | [15.83](https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_et/abs_asr_pt_avg_last_10_checkpoint.pt) | (<-Download) |
139
+
140
+ [[Back]](..)