ssl-aasist · custom_code · PyTorch

ash56 committed 010952f (verified) · 1 parent: 6789f6f

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +1 -0
  2. fairseq/docs/fairseq.gif +3 -0
  3. fairseq/examples/data2vec/scripts/text/finetune_all_large_fair_aws_local_lr.sh +17 -0
  4. fairseq/examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh +26 -0
  5. fairseq/examples/data2vec/scripts/text/finetune_all_large_fair_nodep_aws_local_lr.sh +15 -0
  6. fairseq/examples/data2vec/scripts/text/finetune_sst2_qnli_sweep_fair_nodep.sh +20 -0
  7. fairseq/examples/data2vec/scripts/text/glue.py +34 -0
  8. fairseq/examples/data2vec/scripts/text/glue_lr.py +143 -0
  9. fairseq/examples/data2vec/tasks/audio_classification.py +167 -0
  10. fairseq/examples/data2vec/tasks/image_classification.py +129 -0
  11. fairseq/examples/data2vec/tasks/image_pretraining.py +110 -0
  12. fairseq/examples/data2vec/tasks/mae_image_pretraining.py +119 -0
  13. fairseq/examples/emotion_conversion/emotion_models/__init__.py +0 -0
  14. fairseq/examples/emotion_conversion/emotion_models/duration_predictor.py +243 -0
  15. fairseq/examples/emotion_conversion/emotion_models/duration_predictor.yaml +48 -0
  16. fairseq/examples/emotion_conversion/emotion_models/pitch_predictor.py +559 -0
  17. fairseq/examples/emotion_conversion/emotion_models/utils.py +78 -0
  18. fairseq/examples/emotion_conversion/fairseq_models/__init__.py +226 -0
  19. fairseq/examples/emotion_conversion/preprocess/__init__.py +0 -0
  20. fairseq/examples/emotion_conversion/preprocess/build_hifigan_manifest.py +38 -0
  21. fairseq/examples/emotion_conversion/preprocess/build_translation_manifests.py +258 -0
  22. fairseq/examples/emotion_conversion/preprocess/create_core_manifest.py +91 -0
  23. fairseq/examples/emotion_conversion/preprocess/extract_f0.py +57 -0
  24. fairseq/examples/emotion_conversion/preprocess/process_km.py +40 -0
  25. fairseq/examples/emotion_conversion/preprocess/split_emov_km_tsv_by_uttid.py +70 -0
  26. fairseq/examples/emotion_conversion/preprocess/split_km.py +50 -0
  27. fairseq/examples/emotion_conversion/preprocess/split_km_tsv.py +65 -0
  28. fairseq/examples/fast_noisy_channel/README.md +345 -0
  29. fairseq/examples/fast_noisy_channel/__init__.py +8 -0
  30. fairseq/examples/fast_noisy_channel/noisy_channel_beam_search.py +71 -0
  31. fairseq/examples/fast_noisy_channel/noisy_channel_sequence_generator.py +842 -0
  32. fairseq/examples/fast_noisy_channel/noisy_channel_translation.py +127 -0
  33. fairseq/examples/flores101/README.md +223 -0
  34. fairseq/examples/flores101/flores_logo.png +0 -0
  35. fairseq/examples/fully_sharded_data_parallel/README.md +177 -0
  36. fairseq/examples/gottbert/README.md +64 -0
  37. fairseq/examples/hubert/README.md +116 -0
  38. fairseq/examples/hubert/config/decode/ax_sweep/ngram.yaml +33 -0
  39. fairseq/examples/hubert/config/decode/ax_sweep/transformer.yaml +33 -0
  40. fairseq/examples/hubert/config/decode/infer_fsqlm.yaml +36 -0
  41. fairseq/examples/hubert/config/decode/infer_kenlm.yaml +36 -0
  42. fairseq/examples/hubert/config/decode/infer_viterbi.yaml +29 -0
  43. fairseq/examples/hubert/config/decode/run/submitit_slurm.yaml +17 -0
  44. fairseq/examples/hubert/config/decode/run/submitit_slurm_8gpu.yaml +17 -0
  45. fairseq/examples/hubert/config/finetune/base_10h.yaml +100 -0
  46. fairseq/examples/hubert/config/finetune/ckpt/it1.yaml +7 -0
  47. fairseq/examples/hubert/config/finetune/lm/ls_4gram.yaml +7 -0
  48. fairseq/examples/hubert/config/finetune/run/submitit_reg.yaml +20 -0
  49. fairseq/examples/hubert/config/pretrain/data/iter1.yaml +8 -0
  50. fairseq/examples/hubert/config/pretrain/data/iter2.yaml +8 -0
.gitattributes CHANGED
@@ -37,3 +37,4 @@ fairseq/examples/MMPT/vlm.png filter=lfs diff=lfs merge=lfs -text
 fairseq/examples/MMPT/videoclip.png filter=lfs diff=lfs merge=lfs -text
 fairseq/alignment_train_cuda_binding.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
 fairseq/alignment_train_cpu_binding.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+fairseq/docs/fairseq.gif filter=lfs diff=lfs merge=lfs -text
fairseq/docs/fairseq.gif ADDED

Git LFS Details

  • SHA256: b7551e9682c816fca1fa00458f3b657177c8d90d2e87db31e42197cb3ae80fca
  • Pointer size: 132 Bytes
  • Size of remote file: 2.66 MB
fairseq/examples/data2vec/scripts/text/finetune_all_large_fair_aws_local_lr.sh ADDED
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+set -eu
+
+job_id="$1"
+task_id="$2"
+dir="$3"
+
+echo "job_id: $job_id, task_id: $task_id, dir: $dir"
+
+mkdir -p "$dir/log"
+sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
+sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
+sbatch_args="$sbatch_args -d afterok:$job_id -o $dir/log/decode_sweep_%A.out"
+sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err"
+
+sbatch $sbatch_args examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh $dir
fairseq/examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh ADDED
@@ -0,0 +1,26 @@
+#!/usr/bin/env zsh
+
+dir="$1"
+cp="$dir/checkpoints/checkpoint_last.pt"
+
+echo "dir: $dir"
+
+declare -A tasks
+tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
+tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
+tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
+tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
+tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"
+
+lrs=(5e-6 8e-6 1e-5 2e-5)
+
+for task data_path in ${(kv)tasks}; do
+  for lr in $lrs; do
+    echo $lr $task
+    PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" \
+    python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
+    --config-name $task +run_config=local task.data="$data_path" common.log_interval=200 dataset.num_workers=1 \
+    checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]" \
+    model._name=roberta_large
+  done
+done
fairseq/examples/data2vec/scripts/text/finetune_all_large_fair_nodep_aws_local_lr.sh ADDED
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+set -eu
+
+dir="$1"
+
+echo "dir: $dir"
+
+mkdir -p "$dir/log"
+sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
+sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
+sbatch_args="$sbatch_args -o $dir/log/decode_sweep_%A.out"
+sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err"
+
+sbatch $sbatch_args examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh $dir
fairseq/examples/data2vec/scripts/text/finetune_sst2_qnli_sweep_fair_nodep.sh ADDED
@@ -0,0 +1,20 @@
+#!/usr/bin/env zsh
+
+dir="$1"
+cp="$dir/checkpoints/checkpoint_last.pt"
+
+echo "dir: $dir"
+
+declare -A tasks
+tasks[qnli]="/private/home/jgu/data/GLUE/QNLI-bin"
+tasks[sst_2]="/private/home/jgu/data/GLUE/SST-2-bin"
+
+lrs="5e-6 1e-5 2e-5 5e-5 1e-4 2e-4 5e-4 1e-3"
+
+for task data_path in ${(kv)tasks}; do
+  for lr in $(echo "$lrs"); do
+    PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
+    --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \
+    checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune_sweep/$task/lr_$lr" "optimization.lr=[${lr}]" &
+  done
+done
fairseq/examples/data2vec/scripts/text/glue.py ADDED
@@ -0,0 +1,34 @@
+from valids import parser, main as valids_main
+import os.path as osp
+
+
+args = parser.parse_args()
+args.target = "valid_accuracy"
+args.best_biggest = True
+args.best = True
+args.last = 0
+args.path_contains = None
+
+res = valids_main(args, print_output=False)
+
+grouped = {}
+for k, v in res.items():
+    k = osp.dirname(k)
+    run = osp.dirname(k)
+    task = osp.basename(k)
+    val = v["valid_accuracy"]
+
+    if run not in grouped:
+        grouped[run] = {}
+
+    grouped[run][task] = val
+
+for run, tasks in grouped.items():
+    print(run)
+    avg = sum(float(v) for v in tasks.values()) / len(tasks)
+    avg_norte = sum(float(v) for k, v in tasks.items() if k != 'rte') / (len(tasks) - 1)
+    try:
+        print(f"{tasks['cola']}\t{tasks['qnli']}\t{tasks['mrpc']}\t{tasks['rte']}\t{tasks['sst_2']}\t{avg:.2f}\t{avg_norte:.2f}")
+    except:
+        print(tasks)
+    print()
fairseq/examples/data2vec/scripts/text/glue_lr.py ADDED
@@ -0,0 +1,143 @@
+import os.path as osp
+import re
+from collections import defaultdict
+
+from valids import parser, main as valids_main
+
+
+TASK_TO_METRIC = {
+    "cola": "mcc",
+    "qnli": "accuracy",
+    "mrpc": "acc_and_f1",
+    "rte": "accuracy",
+    "sst_2": "accuracy",
+    "mnli": "accuracy",
+    "qqp": "acc_and_f1",
+    "sts_b": "pearson_and_spearman",
+}
+TASKS = ["cola", "qnli", "mrpc", "rte", "sst_2", "mnli", "qqp", "sts_b"]
+
+
+def get_best_stat_str(task_vals, show_subdir):
+    task_to_best_val = {}
+    task_to_best_dir = {}
+    for task, subdir_to_val in task_vals.items():
+        task_to_best_val[task] = max(subdir_to_val.values())
+        task_to_best_dir[task] = max(subdir_to_val.keys(), key=lambda x: subdir_to_val[x])
+
+    # import pdb; pdb.set_trace()
+    N1 = len(task_to_best_val)
+    N2 = len([k for k in task_to_best_val if k != "rte"])
+    avg1 = sum(task_to_best_val.values()) / N1
+    avg2 = sum(v for task, v in task_to_best_val.items() if task != "rte") / N2
+
+    try:
+        msg = ""
+        for task in TASKS:
+            dir = task_to_best_dir.get(task, 'null')
+            val = task_to_best_val.get(task, -100)
+            msg += f"({dir}, {val})\t" if show_subdir else f"{val}\t"
+        msg += f"{avg1:.2f}\t{avg2:.2f}"
+    except Exception as e:
+        msg = str(e)
+        msg += str(sorted(task_vals.items()))
+    return msg
+
+def get_all_stat_str(task_vals):
+    msg = ""
+    for task in [task for task in TASKS if task in task_vals]:
+        msg += f"=== {task}\n"
+        for subdir in sorted(task_vals[task].keys()):
+            msg += f"\t{subdir}\t{task_vals[task][subdir]}\n"
+    return msg
+
+def get_tabular_stat_str(task_vals):
+    """assume subdir is <param>/run_*/0"""
+    msg = ""
+    for task in [task for task in TASKS if task in task_vals]:
+        msg += f"=== {task}\n"
+        param_to_runs = defaultdict(dict)
+        for subdir in task_vals[task]:
+            match = re.match("(.*)/(run_.*)/0", subdir)
+            assert match, "subdir"
+            param, run = match.groups()
+            param_to_runs[param][run] = task_vals[task][subdir]
+        params = sorted(param_to_runs, key=lambda x: float(x))
+        runs = sorted(set(run for runs in param_to_runs.values() for run in runs))
+        msg += ("runs:" + "\t".join(runs) + "\n")
+        msg += ("params:" + "\t".join(params) + "\n")
+        for param in params:
+            msg += "\t".join([str(param_to_runs[param].get(run, None)) for run in runs])
+            msg += "\n"
+        # for subdir in sorted(task_vals[task].keys()):
+        #     msg += f"\t{subdir}\t{task_vals[task][subdir]}\n"
+    return msg
+
+
+
+def main():
+    parser.add_argument("--show_glue", action="store_true", help="show glue metric for each task instead of accuracy")
+    parser.add_argument("--print_mode", default="best", help="best|all|tabular")
+    parser.add_argument("--show_subdir", action="store_true", help="print the subdir that has the best results for each run")
+    parser.add_argument("--override_target", default="valid_accuracy", help="override target")
+
+    args = parser.parse_args()
+    args.target = args.override_target
+    args.best_biggest = True
+    args.best = True
+    args.last = 0
+    args.path_contains = None
+
+    res = valids_main(args, print_output=False)
+    grouped_acc = {}
+    grouped_met = {}  # use official metric for each task
+    for path, v in res.items():
+        path = "/".join([args.base, path])
+        path = re.sub("//*", "/", path)
+        match = re.match("(.*)finetune[^/]*/([^/]*)/(.*)", path)
+        if not match:
+            continue
+        run, task, subdir = match.groups()
+
+        if run not in grouped_acc:
+            grouped_acc[run] = {}
+            grouped_met[run] = {}
+        if task not in grouped_acc[run]:
+            grouped_acc[run][task] = {}
+            grouped_met[run][task] = {}
+
+        if v is not None:
+            grouped_acc[run][task][subdir] = float(v.get("valid_accuracy", -100))
+            grouped_met[run][task][subdir] = float(v.get(f"valid_{TASK_TO_METRIC[task]}", -100))
+        else:
+            print(f"{path} has None return")
+
+    header = "\t".join(TASKS)
+    for run in sorted(grouped_acc):
+        print(run)
+        if args.print_mode == "all":
+            if args.show_glue:
+                print("===== GLUE =====")
+                print(get_all_stat_str(grouped_met[run]))
+            else:
+                print("===== ACC =====")
+                print(get_all_stat_str(grouped_acc[run]))
+        elif args.print_mode == "best":
+            print(f" {header}")
+            if args.show_glue:
+                print(f"GLEU: {get_best_stat_str(grouped_met[run], args.show_subdir)}")
+            else:
+                print(f"ACC: {get_best_stat_str(grouped_acc[run], args.show_subdir)}")
+        elif args.print_mode == "tabular":
+            if args.show_glue:
+                print("===== GLUE =====")
+                print(get_tabular_stat_str(grouped_met[run]))
+            else:
+                print("===== ACC =====")
+                print(get_tabular_stat_str(grouped_acc[run]))
+        else:
+            raise ValueError(args.print_mode)
+        print()
+
+if __name__ == "__main__":
+    main()
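
Editorial note: the grouping in `main()` above hinges on one regex over sweep result paths; a minimal standalone sketch of how it splits a path into run, task, and subdir (the example path is made up):

import re

# Hypothetical sweep output path shaped like <run>/finetune*/<task>/<subdir>
path = "/exp/run1/finetune_lr/qnli/1e-5/run_0/0"
match = re.match("(.*)finetune[^/]*/([^/]*)/(.*)", path)
if match:
    run, task, subdir = match.groups()
    print(run, task, subdir)  # "/exp/run1/", "qnli", "1e-5/run_0/0"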
fairseq/examples/data2vec/tasks/audio_classification.py ADDED
@@ -0,0 +1,167 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+import logging
+import os
+import numpy as np
+import math
+import torch
+
+from sklearn import metrics as sklearn_metrics
+from dataclasses import dataclass
+
+from fairseq.tasks.audio_pretraining import AudioPretrainingTask, AudioPretrainingConfig
+from fairseq.tasks import register_task
+from fairseq.logging import metrics
+
+from ..data.add_class_target_dataset import AddClassTargetDataset
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class AudioClassificationConfig(AudioPretrainingConfig):
+    label_descriptors: str = "label_descriptors.csv"
+    labels: str = "lbl"
+
+
+@register_task("audio_classification", dataclass=AudioClassificationConfig)
+class AudioClassificationTask(AudioPretrainingTask):
+    """ """
+
+    cfg: AudioClassificationConfig
+
+    def __init__(
+        self,
+        cfg: AudioClassificationConfig,
+    ):
+        super().__init__(cfg)
+
+        self.state.add_factory("labels", self.load_labels)
+
+    def load_labels(self):
+        labels = {}
+        path = os.path.join(self.cfg.data, self.cfg.label_descriptors)
+        with open(path, "r") as ldf:
+            for line in ldf:
+                if line.strip() == "":
+                    continue
+                items = line.split(",")
+                idx = items[0]
+                lbl = items[1]
+                assert lbl not in labels, lbl
+                labels[lbl] = idx
+        return labels
+
+    @property
+    def labels(self):
+        return self.state.labels
+
+    def load_dataset(
+        self, split: str, task_cfg: AudioClassificationConfig = None, **kwargs
+    ):
+        super().load_dataset(split, task_cfg, **kwargs)
+
+        task_cfg = task_cfg or self.cfg
+
+        data_path = self.cfg.data
+        label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
+        skipped_indices = getattr(self.datasets[split], "skipped_indices", set())
+        labels = []
+        with open(label_path, "r") as f:
+            for i, line in enumerate(f):
+                if i not in skipped_indices:
+                    lbl_items = line.rstrip().split("\t")
+                    labels.append([int(x) for x in lbl_items[2].split(",")])
+
+        assert len(labels) == len(self.datasets[split]), (
+            f"labels length ({len(labels)}) and dataset length "
+            f"({len(self.datasets[split])}) do not match"
+        )
+
+        self.datasets[split] = AddClassTargetDataset(
+            self.datasets[split],
+            labels,
+            multi_class=True,
+            add_to_input=True,
+            num_classes=len(self.labels),
+        )
+
+    def calculate_stats(self, output, target):
+
+        classes_num = target.shape[-1]
+        stats = []
+
+        # Accuracy, only used for single-label classification such as esc-50, not for multiple label one such as AudioSet
+        # acc = sklearn_metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1))
+
+        # Class-wise statistics
+        for k in range(classes_num):
+            # Average precision
+            avg_precision = sklearn_metrics.average_precision_score(
+                target[:, k], output[:, k], average=None
+            )
+
+            dict = {
+                "AP": avg_precision,
+            }
+
+            # # AUC
+            # try:
+            #     auc = sklearn_metrics.roc_auc_score(target[:, k], output[:, k], average=None)
+            # except:
+            #     auc = 0
+            #
+            # # Precisions, recalls
+            # (precisions, recalls, thresholds) = sklearn_metrics.precision_recall_curve(
+            #     target[:, k], output[:, k]
+            # )
+            #
+            # # FPR, TPR
+            # (fpr, tpr, thresholds) = sklearn_metrics.roc_curve(target[:, k], output[:, k])
+            #
+            # save_every_steps = 1000  # Sample statistics to reduce size
+            # dict = {
+            #     "precisions": precisions[0::save_every_steps],
+            #     "recalls": recalls[0::save_every_steps],
+            #     "AP": avg_precision,
+            #     "fpr": fpr[0::save_every_steps],
+            #     "fnr": 1.0 - tpr[0::save_every_steps],
+            #     "auc": auc,
+            #     # note acc is not class-wise, this is just to keep consistent with other metrics
+            #     "acc": acc,
+            # }
+            stats.append(dict)
+
+        return stats
+
+    def valid_step(self, sample, model, criterion):
+        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
+        return loss, sample_size, logging_output
+
+    def reduce_metrics(self, logging_outputs, criterion):
+        super().reduce_metrics(logging_outputs, criterion)
+        if "_predictions" in logging_outputs[0]:
+            metrics.log_concat_tensor(
+                "_predictions",
+                torch.cat([l["_predictions"].cpu() for l in logging_outputs], dim=0),
+            )
+            metrics.log_concat_tensor(
+                "_targets",
+                torch.cat([l["_targets"].cpu() for l in logging_outputs], dim=0),
+            )
+
+            def compute_stats(meters):
+                if meters["_predictions"].tensor.shape[0] < 100:
+                    return 0
+                stats = self.calculate_stats(
+                    meters["_predictions"].tensor, meters["_targets"].tensor
+                )
+                return np.nanmean([stat["AP"] for stat in stats])
+
+            metrics.log_derived("mAP", compute_stats)
+
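
Editorial note: the mAP reported by this task is simply the macro average of per-class average precision; a minimal sketch with dummy arrays (shapes and data are illustrative, not from the task):

import numpy as np
from sklearn import metrics as sklearn_metrics

# Dummy multi-label targets/scores of shape (num_examples, num_classes)
rng = np.random.default_rng(0)
target = rng.integers(0, 2, size=(8, 3))
output = rng.random(size=(8, 3))

# Per-class average precision, then macro-averaged ignoring NaNs, mirroring reduce_metrics above
ap_per_class = [
    sklearn_metrics.average_precision_score(target[:, k], output[:, k])
    for k in range(target.shape[-1])
]
print(np.nanmean(ap_per_class))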
fairseq/examples/data2vec/tasks/image_classification.py ADDED
@@ -0,0 +1,129 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+import os.path as osp
+import logging
+
+from dataclasses import dataclass
+import torch
+from torchvision import transforms
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.tasks import register_task
+from fairseq.logging import metrics
+
+try:
+    from ..data import ImageDataset
+except:
+    import sys
+
+    sys.path.append("..")
+    from data import ImageDataset
+
+from .image_pretraining import (
+    ImagePretrainingConfig,
+    ImagePretrainingTask,
+    IMG_EXTENSIONS,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ImageClassificationConfig(ImagePretrainingConfig):
+    pass
+
+
+@register_task("image_classification", dataclass=ImageClassificationConfig)
+class ImageClassificationTask(ImagePretrainingTask):
+
+    cfg: ImageClassificationConfig
+
+    @classmethod
+    def setup_task(cls, cfg: ImageClassificationConfig, **kwargs):
+        return cls(cfg)
+
+    def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
+        data_path = self.cfg.data
+        cfg = task_cfg or self.cfg
+
+        path_with_split = osp.join(data_path, split)
+        if osp.exists(path_with_split):
+            data_path = path_with_split
+
+        from timm.data import create_transform
+
+        if split == "train":
+            # this should always dispatch to transforms_imagenet_train
+            transform = create_transform(
+                input_size=cfg.input_size,
+                is_training=True,
+                auto_augment="rand-m9-mstd0.5-inc1",
+                interpolation="bicubic",
+                re_prob=0.25,
+                re_mode="pixel",
+                re_count=1,
+                mean=cfg.normalization_mean,
+                std=cfg.normalization_std,
+            )
+            if not cfg.input_size > 32:
+                transform.transforms[0] = transforms.RandomCrop(
+                    cfg.input_size, padding=4
+                )
+        else:
+            t = []
+            if cfg.input_size > 32:
+                crop_pct = 1
+                if cfg.input_size < 384:
+                    crop_pct = 224 / 256
+                size = int(cfg.input_size / crop_pct)
+                t.append(
+                    transforms.Resize(
+                        size, interpolation=3
+                    ),  # to maintain same ratio w.r.t. 224 images
+                )
+                t.append(transforms.CenterCrop(cfg.input_size))
+
+            t.append(transforms.ToTensor())
+            t.append(
+                transforms.Normalize(cfg.normalization_mean, cfg.normalization_std)
+            )
+            transform = transforms.Compose(t)
+            logger.info(transform)
+
+        self.datasets[split] = ImageDataset(
+            root=data_path,
+            extensions=IMG_EXTENSIONS,
+            load_classes=True,
+            transform=transform,
+        )
+        for k in self.datasets.keys():
+            if k != split:
+                assert self.datasets[k].classes == self.datasets[split].classes
+
+    def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False):
+        model = super().build_model(model_cfg, from_checkpoint)
+
+        actualized_cfg = getattr(model, "cfg", None)
+        if actualized_cfg is not None:
+            if hasattr(actualized_cfg, "pretrained_model_args"):
+                model_cfg.pretrained_model_args = actualized_cfg.pretrained_model_args
+
+        return model
+
+    def reduce_metrics(self, logging_outputs, criterion):
+        super().reduce_metrics(logging_outputs, criterion)
+
+        if "correct" in logging_outputs[0]:
+            zero = torch.scalar_tensor(0.0)
+            correct = sum(log.get("correct", zero) for log in logging_outputs)
+            metrics.log_scalar_sum("_correct", correct)
+
+            metrics.log_derived(
+                "accuracy",
+                lambda meters: 100 * meters["_correct"].sum / meters["sample_size"].sum,
+            )
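
Editorial note: the eval branch above follows the common 224/256 center-crop convention (and `interpolation=3` is PIL's BICUBIC). A small sketch of the same arithmetic, with illustrative values in place of `cfg.*`:

from torchvision import transforms

input_size = 224  # illustrative; the task reads this from cfg.input_size
crop_pct = 224 / 256 if input_size < 384 else 1.0
resize_size = int(input_size / crop_pct)  # 256 when cropping to 224

eval_transform = transforms.Compose([
    transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])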
fairseq/examples/data2vec/tasks/image_pretraining.py ADDED
@@ -0,0 +1,110 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+import logging
+import sys
+import os.path as osp
+
+from dataclasses import dataclass, field
+from typing import List
+from omegaconf import MISSING
+
+import torch
+from torchvision import transforms
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.tasks import FairseqTask, register_task
+
+try:
+    from ..data import ImageDataset
+except:
+    sys.path.append("..")
+    from data import ImageDataset
+
+logger = logging.getLogger(__name__)
+
+IMG_EXTENSIONS = {
+    ".jpg",
+    ".jpeg",
+    ".png",
+    ".ppm",
+    ".bmp",
+    ".pgm",
+    ".tif",
+    ".tiff",
+    ".webp",
+}
+
+
+@dataclass
+class ImagePretrainingConfig(FairseqDataclass):
+    data: str = field(default=MISSING, metadata={"help": "path to data directory"})
+    input_size: int = 224
+    normalization_mean: List[float] = (0.485, 0.456, 0.406)
+    normalization_std: List[float] = (0.229, 0.224, 0.225)
+
+
+@register_task("image_pretraining", dataclass=ImagePretrainingConfig)
+class ImagePretrainingTask(FairseqTask):
+    """ """
+
+    cfg: ImagePretrainingConfig
+
+    @classmethod
+    def setup_task(cls, cfg: ImagePretrainingConfig, **kwargs):
+        """Setup the task (e.g., load dictionaries).
+
+        Args:
+            cfg (AudioPretrainingConfig): configuration of this task
+        """
+
+        return cls(cfg)
+
+    def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
+        data_path = self.cfg.data
+        cfg = task_cfg or self.cfg
+
+        path_with_split = osp.join(data_path, split)
+        if osp.exists(path_with_split):
+            data_path = path_with_split
+
+        transform = transforms.Compose(
+            [
+                transforms.ColorJitter(0.4, 0.4, 0.4),
+                transforms.RandomHorizontalFlip(p=0.5),
+                transforms.RandomResizedCrop(
+                    size=cfg.input_size,
+                    interpolation=transforms.InterpolationMode.BICUBIC,
+                ),
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=torch.tensor(cfg.normalization_mean),
+                    std=torch.tensor(cfg.normalization_std),
+                ),
+            ]
+        )
+
+        logger.info(transform)
+
+        self.datasets[split] = ImageDataset(
+            root=data_path,
+            extensions=IMG_EXTENSIONS,
+            load_classes=False,
+            transform=transform,
+        )
+
+    @property
+    def source_dictionary(self):
+        return None
+
+    @property
+    def target_dictionary(self):
+        return None
+
+    def max_positions(self):
+        """Maximum input length supported by the encoder."""
+        return sys.maxsize, sys.maxsize
fairseq/examples/data2vec/tasks/mae_image_pretraining.py ADDED
@@ -0,0 +1,119 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+import logging
+import sys
+
+from typing import Optional, List
+from dataclasses import dataclass, field
+from omegaconf import MISSING, II
+
+from fairseq.data import SubsampleDataset
+from fairseq.dataclass import FairseqDataclass
+from fairseq.tasks import FairseqTask, register_task
+
+try:
+    from ..data import MaeImageDataset
+except:
+    sys.path.append("..")
+    from data import MaeImageDataset
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ImageMaskingConfig:
+    patch_size: int = II("model.modalities.image.patch_size")
+    mask_prob: float = II("model.modalities.image.mask_prob")
+    mask_prob_adjust: float = II("model.modalities.image.mask_prob_adjust")
+    mask_length: int = II("model.modalities.image.mask_length")
+    inverse_mask: bool = II("model.modalities.image.inverse_mask")
+    mask_dropout: float = II("model.modalities.image.mask_dropout")
+    clone_batch: int = II("model.clone_batch")
+    expand_adjacent: bool = False
+    non_overlapping: bool = False
+
+
+@dataclass
+class MaeImagePretrainingConfig(FairseqDataclass):
+    data: str = field(default=MISSING, metadata={"help": "path to data directory"})
+    multi_data: Optional[List[str]] = None
+    input_size: int = 224
+    local_cache_path: Optional[str] = None
+    key: str = "imgs"
+
+    beit_transforms: bool = False
+    target_transform: bool = False
+    no_transform: bool = False
+
+    rebuild_batches: bool = True
+
+    precompute_mask_config: Optional[ImageMaskingConfig] = None
+
+    subsample: float = 1
+    seed: int = II("common.seed")
+    dataset_type: str = "imagefolder"
+
+
+@register_task("mae_image_pretraining", dataclass=MaeImagePretrainingConfig)
+class MaeImagePretrainingTask(FairseqTask):
+    """ """
+
+    cfg: MaeImagePretrainingConfig
+
+    @classmethod
+    def setup_task(cls, cfg: MaeImagePretrainingConfig, **kwargs):
+        """Setup the task (e.g., load dictionaries).
+
+        Args:
+            cfg (AudioPretrainingConfig): configuration of this task
+        """
+
+        return cls(cfg)
+
+    def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
+        data_path = self.cfg.data
+        cfg = task_cfg or self.cfg
+
+        compute_mask = cfg.precompute_mask_config is not None
+        mask_args = {}
+        if compute_mask:
+            mask_args = cfg.precompute_mask_config
+
+        self.datasets[split] = MaeImageDataset(
+            root=data_path if cfg.multi_data is None else cfg.multi_data,
+            split=split,
+            input_size=cfg.input_size,
+            local_cache_path=cfg.local_cache_path,
+            key=cfg.key,
+            beit_transforms=cfg.beit_transforms,
+            target_transform=cfg.target_transform,
+            no_transform=cfg.no_transform,
+            compute_mask=compute_mask,
+            dataset_type=cfg.dataset_type,
+            **mask_args,
+        )
+
+        if cfg.subsample < 1:
+            self.datasets[split] = SubsampleDataset(
+                self.datasets[split],
+                cfg.subsample,
+                shuffle=True,
+                seed=cfg.seed,
+            )
+
+    @property
+    def source_dictionary(self):
+        return None
+
+    @property
+    def target_dictionary(self):
+        return None
+
+    def max_positions(self):
+        """Maximum input length supported by the encoder."""
+        return sys.maxsize, sys.maxsize
fairseq/examples/emotion_conversion/emotion_models/__init__.py ADDED
File without changes
fairseq/examples/emotion_conversion/emotion_models/duration_predictor.py ADDED
@@ -0,0 +1,243 @@
1
+ import logging
2
+ import os
3
+
4
+ import hydra
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops.layers.torch import Rearrange
9
+ from torch.utils.data import DataLoader, Dataset
10
+
11
+ from .utils import Accuracy
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ def save_ckpt(model, path, model_class):
17
+ ckpt = {
18
+ "state_dict": model.state_dict(),
19
+ "padding_token": model.padding_token,
20
+ "model_class": model_class,
21
+ }
22
+ torch.save(ckpt, path)
23
+
24
+
25
+ def load_ckpt(path):
26
+ ckpt = torch.load(path)
27
+ ckpt["model_class"]["_target_"] = "emotion_models.duration_predictor.CnnPredictor"
28
+ model = hydra.utils.instantiate(ckpt["model_class"])
29
+ model.load_state_dict(ckpt["state_dict"])
30
+ model.padding_token = ckpt["padding_token"]
31
+ model = model.cpu()
32
+ model.eval()
33
+ return model
34
+
35
+
36
+ class Collator:
37
+ def __init__(self, padding_idx):
38
+ self.padding_idx = padding_idx
39
+
40
+ def __call__(self, batch):
41
+ x = [item[0] for item in batch]
42
+ lengths = [len(item) for item in x]
43
+ x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=self.padding_idx)
44
+ y = [item[1] for item in batch]
45
+ y = torch.nn.utils.rnn.pad_sequence(y, batch_first=True, padding_value=self.padding_idx)
46
+ mask = (x != self.padding_idx)
47
+ return x, y, mask, lengths
48
+
49
+
50
+ class Predictor(nn.Module):
51
+ def __init__(self, n_tokens, emb_dim):
52
+ super(Predictor, self).__init__()
53
+ self.n_tokens = n_tokens
54
+ self.emb_dim = emb_dim
55
+ self.padding_token = n_tokens
56
+ # add 1 extra embedding for padding token, set the padding index to be the last token
57
+ # (tokens from the clustering start at index 0)
58
+ self.emb = nn.Embedding(n_tokens + 1, emb_dim, padding_idx=self.padding_token)
59
+
60
+ def inflate_input(self, batch):
61
+ """ get a sequence of tokens, predict their durations
62
+ and inflate them accordingly """
63
+ batch_durs = self.forward(batch)
64
+ batch_durs = torch.exp(batch_durs) - 1
65
+ batch_durs = batch_durs.round()
66
+ output = []
67
+ for seq, durs in zip(batch, batch_durs):
68
+ inflated_seq = []
69
+ for token, n in zip(seq, durs):
70
+ if token == self.padding_token:
71
+ break
72
+ n = int(n.item())
73
+ token = int(token.item())
74
+ inflated_seq.extend([token for _ in range(n)])
75
+ output.append(inflated_seq)
76
+ output = torch.LongTensor(output)
77
+ return output
78
+
79
+
80
+ class CnnPredictor(Predictor):
81
+ def __init__(self, n_tokens, emb_dim, channels, kernel, output_dim, dropout, n_layers):
82
+ super(CnnPredictor, self).__init__(n_tokens=n_tokens, emb_dim=emb_dim)
83
+ layers = [
84
+ Rearrange("b t c -> b c t"),
85
+ nn.Conv1d(emb_dim, channels, kernel_size=kernel, padding=(kernel - 1) // 2),
86
+ Rearrange("b c t -> b t c"),
87
+ nn.ReLU(),
88
+ nn.LayerNorm(channels),
89
+ nn.Dropout(dropout),
90
+ ]
91
+ for _ in range(n_layers-1):
92
+ layers += [
93
+ Rearrange("b t c -> b c t"),
94
+ nn.Conv1d(channels, channels, kernel_size=kernel, padding=(kernel - 1) // 2),
95
+ Rearrange("b c t -> b t c"),
96
+ nn.ReLU(),
97
+ nn.LayerNorm(channels),
98
+ nn.Dropout(dropout),
99
+ ]
100
+ self.conv_layer = nn.Sequential(*layers)
101
+ self.proj = nn.Linear(channels, output_dim)
102
+
103
+ def forward(self, x):
104
+ x = self.emb(x)
105
+ x = self.conv_layer(x)
106
+ x = self.proj(x)
107
+ x = x.squeeze(-1)
108
+ return x
109
+
110
+
111
+ def l2_log_loss(input, target):
112
+ return F.mse_loss(
113
+ input=input.float(),
114
+ target=torch.log(target.float() + 1),
115
+ reduce=False
116
+ )
117
+
118
+
119
+ class DurationDataset(Dataset):
120
+ def __init__(self, tsv_path, km_path, substring=""):
121
+ lines = open(tsv_path, "r").readlines()
122
+ self.root, self.tsv = lines[0], lines[1:]
123
+ self.km = open(km_path, "r").readlines()
124
+ logger.info(f"loaded {len(self.km)} files")
125
+
126
+ if substring != "":
127
+ tsv, km = [], []
128
+ for tsv_line, km_line in zip(self.tsv, self.km):
129
+ if substring.lower() in tsv_line.lower():
130
+ tsv.append(tsv_line)
131
+ km.append(km_line)
132
+ self.tsv, self.km = tsv, km
133
+ logger.info(f"after filtering: {len(self.km)} files")
134
+
135
+ def __len__(self):
136
+ return len(self.km)
137
+
138
+ def __getitem__(self, i):
139
+ x = self.km[i]
140
+ x = x.split(" ")
141
+ x = list(map(int, x))
142
+
143
+ y = []
144
+ xd = []
145
+ count = 1
146
+ for x1, x2 in zip(x[:-1], x[1:]):
147
+ if x1 == x2:
148
+ count += 1
149
+ continue
150
+ else:
151
+ y.append(count)
152
+ xd.append(x1)
153
+ count = 1
154
+
155
+ xd = torch.LongTensor(xd)
156
+ y = torch.LongTensor(y)
157
+ return xd, y
158
+
159
+
160
+ def train(cfg):
161
+ device = "cuda:0"
162
+ model = hydra.utils.instantiate(cfg[cfg.model]).to(device)
163
+ optimizer = hydra.utils.instantiate(cfg.optimizer, model.parameters())
164
+ # add 1 extra embedding for padding token, set the padding index to be the last token
165
+ # (tokens from the clustering start at index 0)
166
+ collate_fn = Collator(padding_idx=model.padding_token)
167
+ logger.info(f"data: {cfg.train_tsv}")
168
+ train_ds = DurationDataset(cfg.train_tsv, cfg.train_km, substring=cfg.substring)
169
+ valid_ds = DurationDataset(cfg.valid_tsv, cfg.valid_km, substring=cfg.substring)
170
+ train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, collate_fn=collate_fn)
171
+ valid_dl = DataLoader(valid_ds, batch_size=32, shuffle=False, collate_fn=collate_fn)
172
+
173
+ best_loss = float("inf")
174
+ for epoch in range(cfg.epochs):
175
+ train_loss, train_loss_scaled = train_epoch(model, train_dl, l2_log_loss, optimizer, device)
176
+ valid_loss, valid_loss_scaled, *acc = valid_epoch(model, valid_dl, l2_log_loss, device)
177
+ acc0, acc1, acc2, acc3 = acc
178
+ if valid_loss_scaled < best_loss:
179
+ path = f"{os.getcwd()}/{cfg.substring}.ckpt"
180
+ save_ckpt(model, path, cfg[cfg.model])
181
+ best_loss = valid_loss_scaled
182
+ logger.info(f"saved checkpoint: {path}")
183
+ logger.info(f"[epoch {epoch}] train loss: {train_loss:.3f}, train scaled: {train_loss_scaled:.3f}")
184
+ logger.info(f"[epoch {epoch}] valid loss: {valid_loss:.3f}, valid scaled: {valid_loss_scaled:.3f}")
185
+ logger.info(f"acc: {acc0,acc1,acc2,acc3}")
186
+
187
+
188
+ def train_epoch(model, loader, criterion, optimizer, device):
189
+ model.train()
190
+ epoch_loss = 0
191
+ epoch_loss_scaled = 0
192
+ for x, y, mask, _ in loader:
193
+ x, y, mask = x.to(device), y.to(device), mask.to(device)
194
+ yhat = model(x)
195
+ loss = criterion(yhat, y) * mask
196
+ loss = torch.mean(loss)
197
+ loss.backward()
198
+ nn.utils.clip_grad_norm_(model.parameters(), 1.0)
199
+ optimizer.step()
200
+ epoch_loss += loss.item()
201
+ # get normal scale loss
202
+ yhat_scaled = torch.exp(yhat) - 1
203
+ yhat_scaled = torch.round(yhat_scaled)
204
+ scaled_loss = torch.mean(torch.abs(yhat_scaled - y) * mask)
205
+ epoch_loss_scaled += scaled_loss.item()
206
+ return epoch_loss / len(loader), epoch_loss_scaled / len(loader)
207
+
208
+
209
+ def valid_epoch(model, loader, criterion, device):
210
+ model.eval()
211
+ epoch_loss = 0
212
+ epoch_loss_scaled = 0
213
+ acc = Accuracy()
214
+ for x, y, mask, _ in loader:
215
+ x, y, mask = x.to(device), y.to(device), mask.to(device)
216
+ yhat = model(x)
217
+ loss = criterion(yhat, y) * mask
218
+ loss = torch.mean(loss)
219
+ epoch_loss += loss.item()
220
+ # get normal scale loss
221
+ yhat_scaled = torch.exp(yhat) - 1
222
+ yhat_scaled = torch.round(yhat_scaled)
223
+ scaled_loss = torch.sum(torch.abs(yhat_scaled - y) * mask) / mask.sum()
224
+ acc.update(yhat_scaled[mask].view(-1).float(), y[mask].view(-1).float())
225
+ epoch_loss_scaled += scaled_loss.item()
226
+ logger.info(f"example y: {y[0, :10].tolist()}")
227
+ logger.info(f"example yhat: {yhat_scaled[0, :10].tolist()}")
228
+ acc0 = acc.acc(tol=0)
229
+ acc1 = acc.acc(tol=1)
230
+ acc2 = acc.acc(tol=2)
231
+ acc3 = acc.acc(tol=3)
232
+ logger.info(f"accs: {acc0,acc1,acc2,acc3}")
233
+ return epoch_loss / len(loader), epoch_loss_scaled / len(loader), acc0, acc1, acc2, acc3
234
+
235
+
236
+ @hydra.main(config_path=".", config_name="duration_predictor.yaml")
237
+ def main(cfg):
238
+ logger.info(f"{cfg}")
239
+ train(cfg)
240
+
241
+
242
+ if __name__ == "__main__":
243
+ main()
fairseq/examples/emotion_conversion/emotion_models/duration_predictor.yaml ADDED
@@ -0,0 +1,48 @@
+train_tsv: "<your-processed-data>/denoising/emov/train.tsv"
+train_km: "<your-processed-data>/denoising/emov/train.km"
+valid_tsv: "<your-processed-data>/denoising/emov/valid.tsv"
+valid_km: "<your-processed-data>/denoising/emov/valid.km"
+
+n_tokens: 200
+batch_size: 32
+lr: 0.0001
+epochs: 300
+model: "cnn"
+substring: ""
+
+rnn:
+  _target_: emotion_models.duration_predictor.RnnPredictor
+  n_tokens: ${n_tokens}
+  emb_dim: 128
+  rnn_hidden: 128
+  output_dim: 1
+  dropout: 0
+  n_layers: 1
+
+optimizer:
+  _target_: torch.optim.Adam
+  lr: ${lr}
+  betas: [0.9, 0.98]
+  eps: 0.000000001
+  weight_decay: 0
+
+cnn:
+  _target_: emotion_models.duration_predictor.CnnPredictor
+  n_tokens: ${n_tokens}
+  emb_dim: 128
+  channels: 256
+  kernel: 3
+  output_dim: 1
+  dropout: 0.5
+  n_layers: 1
+
+hydra:
+  run:
+    dir: /checkpoint/felixkreuk/experiments/duration_predictor/${hydra.job.override_dirname}
+  job:
+    config:
+      # configuration for the ${hydra.job.override_dirname} runtime variable
+      override_dirname:
+        kv_sep: '='
+        item_sep: ','
+        exclude_keys: ['train_tsv', 'train_km', 'valid_tsv', 'valid_km']
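
Editorial note: the `_target_` entries in this config are resolved with Hydra's object instantiation; a minimal sketch of how the `cnn` node becomes a model object, with `${n_tokens}` already resolved and assuming the `emotion_models` package from this commit is importable:

import hydra
from omegaconf import OmegaConf

# Mirrors the `cnn` node of duration_predictor.yaml
cfg = OmegaConf.create({
    "_target_": "emotion_models.duration_predictor.CnnPredictor",
    "n_tokens": 200,
    "emb_dim": 128,
    "channels": 256,
    "kernel": 3,
    "output_dim": 1,
    "dropout": 0.5,
    "n_layers": 1,
})
model = hydra.utils.instantiate(cfg)  # builds CnnPredictor(n_tokens=200, ...)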
fairseq/examples/emotion_conversion/emotion_models/pitch_predictor.py ADDED
@@ -0,0 +1,559 @@
1
+ import logging
2
+ import os
3
+ import random
4
+ import sys
5
+ from collections import defaultdict
6
+
7
+ import hydra
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange
13
+ from einops.layers.torch import Rearrange
14
+ from scipy.io.wavfile import read
15
+ from scipy.ndimage import gaussian_filter1d
16
+ from torch.utils.data import DataLoader, Dataset
17
+ from tqdm import tqdm
18
+
19
+ dir_path = os.path.dirname(__file__)
20
+ resynth_path = os.path.dirname(dir_path) + "/speech-resynthesis"
21
+ sys.path.append(resynth_path)
22
+ from dataset import parse_speaker, parse_style
23
+ from .utils import F0Stat
24
+
25
+ MAX_WAV_VALUE = 32768.0
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ def quantize_f0(speaker_to_f0, nbins, normalize, log):
30
+ f0_all = []
31
+ for speaker, f0 in speaker_to_f0.items():
32
+ f0 = f0.raw_data
33
+ if log:
34
+ f0 = f0.log()
35
+ mean = speaker_to_f0[speaker].mean_log if log else speaker_to_f0[speaker].mean
36
+ std = speaker_to_f0[speaker].std_log if log else speaker_to_f0[speaker].std
37
+ if normalize == "mean":
38
+ f0 = f0 - mean
39
+ elif normalize == "meanstd":
40
+ f0 = (f0 - mean) / std
41
+ f0_all.extend(f0.tolist())
42
+
43
+ hist, bin_x = np.histogram(f0_all, 100000)
44
+ cum_hist = np.cumsum(hist) / len(f0_all) * 100
45
+
46
+ bin_offset = []
47
+ bin_size = 100 / nbins
48
+ threshold = bin_size
49
+ for i in range(nbins - 1):
50
+ index = (np.abs(cum_hist - threshold)).argmin()
51
+ bin_offset.append(bin_x[index])
52
+ threshold += bin_size
53
+ bins = np.array(bin_offset)
54
+ bins = torch.FloatTensor(bins)
55
+
56
+ return bins
57
+
58
+
59
+ def save_ckpt(model, path, model_class, f0_min, f0_max, f0_bins, speaker_stats):
60
+ ckpt = {
61
+ "state_dict": model.state_dict(),
62
+ "padding_token": model.padding_token,
63
+ "model_class": model_class,
64
+ "speaker_stats": speaker_stats,
65
+ "f0_min": f0_min,
66
+ "f0_max": f0_max,
67
+ "f0_bins": f0_bins,
68
+ }
69
+ torch.save(ckpt, path)
70
+
71
+
72
+ def load_ckpt(path):
73
+ ckpt = torch.load(path)
74
+ ckpt["model_class"]["_target_"] = "emotion_models.pitch_predictor.CnnPredictor"
75
+ model = hydra.utils.instantiate(ckpt["model_class"])
76
+ model.load_state_dict(ckpt["state_dict"])
77
+ model.setup_f0_stats(
78
+ ckpt["f0_min"],
79
+ ckpt["f0_max"],
80
+ ckpt["f0_bins"],
81
+ ckpt["speaker_stats"],
82
+ )
83
+ return model
84
+
85
+
86
+ def freq2bin(f0, f0_min, f0_max, bins):
87
+ f0 = f0.clone()
88
+ f0[f0 < f0_min] = f0_min
89
+ f0[f0 > f0_max] = f0_max
90
+ f0 = torch.bucketize(f0, bins)
91
+ return f0
92
+
93
+
94
+ def bin2freq(x, f0_min, f0_max, bins, mode):
95
+ n_bins = len(bins) + 1
96
+ assert x.shape[-1] == n_bins
97
+ bins = torch.cat([torch.tensor([f0_min]), bins]).to(x.device)
98
+ if mode == "mean":
99
+ f0 = (x * bins).sum(-1, keepdims=True) / x.sum(-1, keepdims=True)
100
+ elif mode == "argmax":
101
+ idx = F.one_hot(x.argmax(-1), num_classes=n_bins)
102
+ f0 = (idx * bins).sum(-1, keepdims=True)
103
+ else:
104
+ raise NotImplementedError()
105
+ return f0[..., 0]
106
+
107
+
108
+ def load_wav(full_path):
109
+ sampling_rate, data = read(full_path)
110
+ return data, sampling_rate
111
+
112
+
113
+ def l1_loss(input, target):
114
+ return F.l1_loss(input=input.float(), target=target.float(), reduce=False)
115
+
116
+
117
+ def l2_loss(input, target):
118
+ return F.mse_loss(input=input.float(), target=target.float(), reduce=False)
119
+
120
+
121
+ class Collator:
122
+ def __init__(self, padding_idx):
123
+ self.padding_idx = padding_idx
124
+
125
+ def __call__(self, batch):
126
+ tokens = [item[0] for item in batch]
127
+ lengths = [len(item) for item in tokens]
128
+ tokens = torch.nn.utils.rnn.pad_sequence(
129
+ tokens, batch_first=True, padding_value=self.padding_idx
130
+ )
131
+ f0 = [item[1] for item in batch]
132
+ f0 = torch.nn.utils.rnn.pad_sequence(
133
+ f0, batch_first=True, padding_value=self.padding_idx
134
+ )
135
+ f0_raw = [item[2] for item in batch]
136
+ f0_raw = torch.nn.utils.rnn.pad_sequence(
137
+ f0_raw, batch_first=True, padding_value=self.padding_idx
138
+ )
139
+ spk = [item[3] for item in batch]
140
+ spk = torch.LongTensor(spk)
141
+ gst = [item[4] for item in batch]
142
+ gst = torch.LongTensor(gst)
143
+ mask = tokens != self.padding_idx
144
+ return tokens, f0, f0_raw, spk, gst, mask, lengths
145
+
146
+
147
+ class CnnPredictor(nn.Module):
148
+ def __init__(
149
+ self,
150
+ n_tokens,
151
+ emb_dim,
152
+ channels,
153
+ kernel,
154
+ dropout,
155
+ n_layers,
156
+ spk_emb,
157
+ gst_emb,
158
+ n_bins,
159
+ f0_pred,
160
+ f0_log,
161
+ f0_norm,
162
+ ):
163
+ super(CnnPredictor, self).__init__()
164
+ self.n_tokens = n_tokens
165
+ self.emb_dim = emb_dim
166
+ self.f0_log = f0_log
167
+ self.f0_pred = f0_pred
168
+ self.padding_token = n_tokens
169
+ self.f0_norm = f0_norm
170
+ # add 1 extra embedding for padding token, set the padding index to be the last token
171
+ # (tokens from the clustering start at index 0)
172
+ self.token_emb = nn.Embedding(
173
+ n_tokens + 1, emb_dim, padding_idx=self.padding_token
174
+ )
175
+
176
+ self.spk_emb = spk_emb
177
+ self.gst_emb = nn.Embedding(20, gst_emb)
178
+ self.setup = False
179
+
180
+ feats = emb_dim + gst_emb
181
+ # feats = emb_dim + gst_emb + (256 if spk_emb else 0)
182
+ layers = [
183
+ nn.Sequential(
184
+ Rearrange("b t c -> b c t"),
185
+ nn.Conv1d(
186
+ feats, channels, kernel_size=kernel, padding=(kernel - 1) // 2
187
+ ),
188
+ Rearrange("b c t -> b t c"),
189
+ nn.ReLU(),
190
+ nn.LayerNorm(channels),
191
+ nn.Dropout(dropout),
192
+ )
193
+ ]
194
+ for _ in range(n_layers - 1):
195
+ layers += [
196
+ nn.Sequential(
197
+ Rearrange("b t c -> b c t"),
198
+ nn.Conv1d(
199
+ channels,
200
+ channels,
201
+ kernel_size=kernel,
202
+ padding=(kernel - 1) // 2,
203
+ ),
204
+ Rearrange("b c t -> b t c"),
205
+ nn.ReLU(),
206
+ nn.LayerNorm(channels),
207
+ nn.Dropout(dropout),
208
+ )
209
+ ]
210
+ self.conv_layer = nn.ModuleList(layers)
211
+ self.proj = nn.Linear(channels, n_bins)
212
+
213
+ def forward(self, x, gst=None):
214
+ x = self.token_emb(x)
215
+ feats = [x]
216
+
217
+ if gst is not None:
218
+ gst = self.gst_emb(gst)
219
+ gst = rearrange(gst, "b c -> b c 1")
220
+ gst = F.interpolate(gst, x.shape[1])
221
+ gst = rearrange(gst, "b c t -> b t c")
222
+ feats.append(gst)
223
+
224
+ x = torch.cat(feats, dim=-1)
225
+
226
+ for i, conv in enumerate(self.conv_layer):
227
+ if i != 0:
228
+ x = conv(x) + x
229
+ else:
230
+ x = conv(x)
231
+
232
+ x = self.proj(x)
233
+ x = x.squeeze(-1)
234
+
235
+ if self.f0_pred == "mean":
236
+ x = torch.sigmoid(x)
237
+ elif self.f0_pred == "argmax":
238
+ x = torch.softmax(x, dim=-1)
239
+ else:
240
+ raise NotImplementedError
241
+ return x
242
+
243
+ def setup_f0_stats(self, f0_min, f0_max, f0_bins, speaker_stats):
244
+ self.f0_min = f0_min
245
+ self.f0_max = f0_max
246
+ self.f0_bins = f0_bins
247
+ self.speaker_stats = speaker_stats
248
+ self.setup = True
249
+
250
+ def inference(self, x, spk_id=None, gst=None):
251
+ assert (
252
+ self.setup == True
253
+ ), "make sure that `setup_f0_stats` was called before inference!"
254
+ probs = self(x, gst)
255
+ f0 = bin2freq(probs, self.f0_min, self.f0_max, self.f0_bins, self.f0_pred)
256
+ for i in range(f0.shape[0]):
257
+ mean = (
258
+ self.speaker_stats[spk_id[i].item()].mean_log
259
+ if self.f0_log
260
+ else self.speaker_stats[spk_id[i].item()].mean
261
+ )
262
+ std = (
263
+ self.speaker_stats[spk_id[i].item()].std_log
264
+ if self.f0_log
265
+ else self.speaker_stats[spk_id[i].item()].std
266
+ )
267
+ if self.f0_norm == "mean":
268
+ f0[i] = f0[i] + mean
269
+ if self.f0_norm == "meanstd":
270
+ f0[i] = (f0[i] * std) + mean
271
+ if self.f0_log:
272
+ f0 = f0.exp()
273
+ return f0
274
+
275
+
276
+ class PitchDataset(Dataset):
277
+ def __init__(
278
+ self,
279
+ tsv_path,
280
+ km_path,
281
+ substring,
282
+ spk,
283
+ spk2id,
284
+ gst,
285
+ gst2id,
286
+ f0_bins,
287
+ f0_bin_type,
288
+ f0_smoothing,
289
+ f0_norm,
290
+ f0_log,
291
+ ):
292
+ lines = open(tsv_path, "r").readlines()
293
+ self.root, self.tsv = lines[0], lines[1:]
294
+ self.root = self.root.strip()
295
+ self.km = open(km_path, "r").readlines()
296
+ print(f"loaded {len(self.km)} files")
297
+
298
+ self.spk = spk
299
+ self.spk2id = spk2id
300
+ self.gst = gst
301
+ self.gst2id = gst2id
302
+
303
+ self.f0_bins = f0_bins
304
+ self.f0_smoothing = f0_smoothing
305
+ self.f0_norm = f0_norm
306
+ self.f0_log = f0_log
307
+
308
+ if substring != "":
309
+ tsv, km = [], []
310
+ for tsv_line, km_line in zip(self.tsv, self.km):
311
+ if substring.lower() in tsv_line.lower():
312
+ tsv.append(tsv_line)
313
+ km.append(km_line)
314
+ self.tsv, self.km = tsv, km
315
+ print(f"after filtering: {len(self.km)} files")
316
+
317
+ self.speaker_stats = self._compute_f0_stats()
318
+ self.f0_min, self.f0_max = self._compute_f0_minmax()
319
+ if f0_bin_type == "adaptive":
320
+ self.f0_bins = quantize_f0(
321
+ self.speaker_stats, self.f0_bins, self.f0_norm, self.f0_log
322
+ )
323
+ elif f0_bin_type == "uniform":
324
+ self.f0_bins = torch.linspace(self.f0_min, self.f0_max, self.f0_bins + 1)[
325
+ 1:-1
326
+ ]
327
+ else:
328
+ raise NotImplementedError
329
+ print(f"f0 min: {self.f0_min}, f0 max: {self.f0_max}")
330
+ print(f"bins: {self.f0_bins} (shape: {self.f0_bins.shape})")
331
+
332
+ def __len__(self):
333
+ return len(self.km)
334
+
335
+ def _load_f0(self, tsv_line):
336
+ tsv_line = tsv_line.split("\t")[0]
337
+ f0 = self.root + "/" + tsv_line.replace(".wav", ".yaapt.f0.npy")
338
+ f0 = np.load(f0)
339
+ f0 = torch.FloatTensor(f0)
340
+ return f0
341
+
342
+ def _preprocess_f0(self, f0, spk):
343
+ mask = f0 != -999999 # process all frames
344
+ # mask = (f0 != 0) # only process voiced frames
345
+ mean = (
346
+ self.speaker_stats[spk].mean_log
347
+ if self.f0_log
348
+ else self.speaker_stats[spk].mean
349
+ )
350
+ std = (
351
+ self.speaker_stats[spk].std_log
352
+ if self.f0_log
353
+ else self.speaker_stats[spk].std
354
+ )
355
+ if self.f0_log:
356
+ f0[f0 == 0] = 1e-5
357
+ f0[mask] = f0[mask].log()
358
+ if self.f0_norm == "mean":
359
+ f0[mask] = f0[mask] - mean
360
+ if self.f0_norm == "meanstd":
361
+ f0[mask] = (f0[mask] - mean) / std
362
+ return f0
363
+
364
+ def _compute_f0_minmax(self):
365
+ f0_min, f0_max = float("inf"), -float("inf")
366
+ for tsv_line in tqdm(self.tsv, desc="computing f0 minmax"):
367
+ spk = self.spk2id[parse_speaker(tsv_line, self.spk)]
368
+ f0 = self._load_f0(tsv_line)
369
+ f0 = self._preprocess_f0(f0, spk)
370
+ f0_min = min(f0_min, f0.min().item())
371
+ f0_max = max(f0_max, f0.max().item())
372
+ return f0_min, f0_max
373
+
374
+ def _compute_f0_stats(self):
375
+ from functools import partial
376
+
377
+ speaker_stats = defaultdict(partial(F0Stat, True))
378
+ for tsv_line in tqdm(self.tsv, desc="computing speaker stats"):
379
+ spk = self.spk2id[parse_speaker(tsv_line, self.spk)]
380
+ f0 = self._load_f0(tsv_line)
381
+ mask = f0 != 0
382
+ f0 = f0[mask] # compute stats only on voiced parts
383
+ speaker_stats[spk].update(f0)
384
+ return speaker_stats
385
+
386
+ def __getitem__(self, i):
387
+ x = self.km[i]
388
+ x = x.split(" ")
389
+ x = list(map(int, x))
390
+ x = torch.LongTensor(x)
391
+
392
+ gst = parse_style(self.tsv[i], self.gst)
393
+ gst = self.gst2id[gst]
394
+ spk = parse_speaker(self.tsv[i], self.spk)
395
+ spk = self.spk2id[spk]
396
+
397
+ f0_raw = self._load_f0(self.tsv[i])
398
+ f0 = self._preprocess_f0(f0_raw.clone(), spk)
399
+
400
+ f0 = F.interpolate(f0.unsqueeze(0).unsqueeze(0), x.shape[0])[0, 0]
401
+ f0_raw = F.interpolate(f0_raw.unsqueeze(0).unsqueeze(0), x.shape[0])[0, 0]
402
+
403
+ f0 = freq2bin(f0, f0_min=self.f0_min, f0_max=self.f0_max, bins=self.f0_bins)
404
+ f0 = F.one_hot(f0.long(), num_classes=len(self.f0_bins) + 1).float()
405
+ if self.f0_smoothing > 0:
406
+ f0 = torch.tensor(
407
+ gaussian_filter1d(f0.float().numpy(), sigma=self.f0_smoothing)
408
+ )
409
+ return x, f0, f0_raw, spk, gst
410
+
411
+
412
+ def train(cfg):
413
+ device = "cuda:0"
414
+ # add 1 extra embedding for padding token, set the padding index to be the last token
415
+ # (tokens from the clustering start at index 0)
416
+ padding_token = cfg.n_tokens
417
+ collate_fn = Collator(padding_idx=padding_token)
418
+ train_ds = PitchDataset(
419
+ cfg.train_tsv,
420
+ cfg.train_km,
421
+ substring=cfg.substring,
422
+ spk=cfg.spk,
423
+ spk2id=cfg.spk2id,
424
+ gst=cfg.gst,
425
+ gst2id=cfg.gst2id,
426
+ f0_bins=cfg.f0_bins,
427
+ f0_bin_type=cfg.f0_bin_type,
428
+ f0_smoothing=cfg.f0_smoothing,
429
+ f0_norm=cfg.f0_norm,
430
+ f0_log=cfg.f0_log,
431
+ )
432
+ valid_ds = PitchDataset(
433
+ cfg.valid_tsv,
434
+ cfg.valid_km,
435
+ substring=cfg.substring,
436
+ spk=cfg.spk,
437
+ spk2id=cfg.spk2id,
438
+ gst=cfg.gst,
439
+ gst2id=cfg.gst2id,
440
+ f0_bins=cfg.f0_bins,
441
+ f0_bin_type=cfg.f0_bin_type,
442
+ f0_smoothing=cfg.f0_smoothing,
443
+ f0_norm=cfg.f0_norm,
444
+ f0_log=cfg.f0_log,
445
+ )
446
+ train_dl = DataLoader(
447
+ train_ds,
448
+ num_workers=0,
449
+ batch_size=cfg.batch_size,
450
+ shuffle=True,
451
+ collate_fn=collate_fn,
452
+ )
453
+ valid_dl = DataLoader(
454
+ valid_ds, num_workers=0, batch_size=16, shuffle=False, collate_fn=collate_fn
455
+ )
456
+
457
+ f0_min = train_ds.f0_min
458
+ f0_max = train_ds.f0_max
459
+ f0_bins = train_ds.f0_bins
460
+ speaker_stats = train_ds.speaker_stats
461
+
462
+ model = hydra.utils.instantiate(cfg["model"]).to(device)
463
+ model.setup_f0_stats(f0_min, f0_max, f0_bins, speaker_stats)
464
+
465
+ optimizer = hydra.utils.instantiate(cfg.optimizer, model.parameters())
466
+
467
+ best_loss = float("inf")
468
+ for epoch in range(cfg.epochs):
469
+ train_loss, train_l2_loss, train_l2_voiced_loss = run_epoch(
470
+ model, train_dl, optimizer, device, cfg, mode="train"
471
+ )
472
+ valid_loss, valid_l2_loss, valid_l2_voiced_loss = run_epoch(
473
+ model, valid_dl, None, device, cfg, mode="valid"
474
+ )
475
+ print(
476
+ f"[epoch {epoch}] train loss: {train_loss:.3f}, l2 loss: {train_l2_loss:.3f}, l2 voiced loss: {train_l2_voiced_loss:.3f}"
477
+ )
478
+ print(
479
+ f"[epoch {epoch}] valid loss: {valid_loss:.3f}, l2 loss: {valid_l2_loss:.3f}, l2 voiced loss: {valid_l2_voiced_loss:.3f}"
480
+ )
481
+ if valid_l2_voiced_loss < best_loss:
482
+ path = f"{os.getcwd()}/pitch_predictor.ckpt"
483
+ save_ckpt(model, path, cfg["model"], f0_min, f0_max, f0_bins, speaker_stats)
484
+ best_loss = valid_l2_voiced_loss
485
+ print(f"saved checkpoint: {path}")
486
+ print(f"[epoch {epoch}] best loss: {best_loss:.3f}")
487
+
488
+
489
+ def run_epoch(model, loader, optimizer, device, cfg, mode):
490
+ if mode == "train":
491
+ model.train()
492
+ else:
493
+ model.eval()
494
+
495
+ epoch_loss = 0
496
+ l1 = 0
497
+ l1_voiced = 0
498
+ for x, f0_bin, f0_raw, spk_id, gst, mask, _ in tqdm(loader):
499
+ x, f0_bin, f0_raw, spk_id, gst, mask = (
500
+ x.to(device),
501
+ f0_bin.to(device),
502
+ f0_raw.to(device),
503
+ spk_id.to(device),
504
+ gst.to(device),
505
+ mask.to(device),
506
+ )
507
+ b, t, n_bins = f0_bin.shape
508
+ yhat = model(x, gst)
509
+ nonzero_mask = (f0_raw != 0).logical_and(mask)
510
+ yhat_raw = model.inference(x, spk_id, gst)
511
+ expanded_mask = mask.unsqueeze(-1).expand(-1, -1, n_bins)
512
+ if cfg.f0_pred == "mean":
513
+ loss = F.binary_cross_entropy(
514
+ yhat[expanded_mask], f0_bin[expanded_mask]
515
+ ).mean()
516
+ elif cfg.f0_pred == "argmax":
517
+ loss = F.cross_entropy(
518
+ rearrange(yhat, "b t d -> (b t) d"),
519
+ rearrange(f0_bin.argmax(-1), "b t -> (b t)"),
520
+ reduce=False,
521
+ )
522
+ loss = rearrange(loss, "(b t) -> b t", b=b, t=t)
523
+ loss = (loss * mask).sum() / mask.float().sum()
524
+ else:
525
+ raise NotImplementedError
526
+ l1 += F.l1_loss(yhat_raw[mask], f0_raw[mask]).item()
527
+ l1_voiced += F.l1_loss(yhat_raw[nonzero_mask], f0_raw[nonzero_mask]).item()
528
+ epoch_loss += loss.item()
529
+
530
+ if mode == "train":
531
+ loss.backward()
532
+ nn.utils.clip_grad_norm_(model.parameters(), 1.0)
533
+ optimizer.step()
534
+
535
+ print(f"{mode} example y: {f0_bin.argmax(-1)[0, 50:60].tolist()}")
536
+ print(f"{mode} example yhat: {yhat.argmax(-1)[0, 50:60].tolist()}")
537
+ print(f"{mode} example y: {f0_raw[0, 50:60].round().tolist()}")
538
+ print(f"{mode} example yhat: {yhat_raw[0, 50:60].round().tolist()}")
539
+ return epoch_loss / len(loader), l1 / len(loader), l1_voiced / len(loader)
540
+
541
+
542
+ @hydra.main(config_path=dir_path, config_name="pitch_predictor.yaml")
543
+ def main(cfg):
544
+ np.random.seed(1)
545
+ random.seed(1)
546
+ torch.manual_seed(1)
547
+ from hydra.core.hydra_config import HydraConfig
548
+
549
+ overrides = {
550
+ x.split("=")[0]: x.split("=")[1]
551
+ for x in HydraConfig.get().overrides.task
552
+ if "/" not in x
553
+ }
554
+ print(f"{cfg}")
555
+ train(cfg)
556
+
557
+
558
+ if __name__ == "__main__":
559
+ main()
fairseq/examples/emotion_conversion/emotion_models/utils.py ADDED
@@ -0,0 +1,78 @@
1
+ import torch
2
+
3
+
4
+ class Stat:
5
+ def __init__(self, keep_raw=False):
6
+ self.x = 0.0
7
+ self.x2 = 0.0
8
+ self.z = 0.0 # z = logx
9
+ self.z2 = 0.0
10
+ self.n = 0.0
11
+ self.u = 0.0
12
+ self.keep_raw = keep_raw
13
+ self.raw = []
14
+
15
+ def update(self, new_x):
16
+ new_z = new_x.log()
17
+
18
+ self.x += new_x.sum()
19
+ self.x2 += (new_x**2).sum()
20
+ self.z += new_z.sum()
21
+ self.z2 += (new_z**2).sum()
22
+ self.n += len(new_x)
23
+ self.u += 1
24
+
25
+ if self.keep_raw:
26
+ self.raw.append(new_x)
27
+
28
+ @property
29
+ def mean(self):
30
+ return self.x / self.n
31
+
32
+ @property
33
+ def std(self):
34
+ return (self.x2 / self.n - self.mean**2) ** 0.5
35
+
36
+ @property
37
+ def mean_log(self):
38
+ return self.z / self.n
39
+
40
+ @property
41
+ def std_log(self):
42
+ return (self.z2 / self.n - self.mean_log**2) ** 0.5
43
+
44
+ @property
45
+ def n_frms(self):
46
+ return self.n
47
+
48
+ @property
49
+ def n_utts(self):
50
+ return self.u
51
+
52
+ @property
53
+ def raw_data(self):
54
+ assert self.keep_raw, "does not support storing raw data!"
55
+ return torch.cat(self.raw)
56
+
57
+
58
+ class F0Stat(Stat):
59
+ def update(self, new_x):
60
+ # assume unvoiced frames are 0 and consider only voiced frames
61
+ if new_x is not None:
62
+ super().update(new_x[new_x != 0])
63
+
64
+
65
+ class Accuracy:
66
+ def __init__(self):
67
+ self.y, self.yhat = [], []
68
+
69
+ def update(self, yhat, y):
70
+ self.yhat.append(yhat)
71
+ self.y.append(y)
72
+
73
+ def acc(self, tol):
74
+ yhat = torch.cat(self.yhat)
75
+ y = torch.cat(self.y)
76
+ acc = torch.abs(yhat - y) <= tol
77
+ acc = acc.float().mean().item()
78
+ return acc
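A small usage sketch (an editor's addition, not part of the diff) for the helpers above, assuming `utils.py` is importable. `F0Stat` keeps running sums of x, x², log x and (log x)², so speaker-level mean/std can be read out without storing every frame; unvoiced frames (f0 == 0) are skipped by its `update`.

```python
import torch
from utils import F0Stat, Accuracy  # the module defined above

stat = F0Stat(keep_raw=False)
stat.update(torch.tensor([0.0, 120.0, 125.0, 0.0, 130.0]))  # utterance 1 (zeros = unvoiced)
stat.update(torch.tensor([118.0, 0.0, 122.0]))              # utterance 2
print(stat.mean, stat.std, stat.mean_log, stat.std_log)     # linear- and log-domain stats
print(stat.n_frms, stat.n_utts)                             # 5 voiced frames, 2 utterances

acc = Accuracy()
acc.update(yhat=torch.tensor([100.0, 205.0]), y=torch.tensor([102.0, 200.0]))
print(acc.acc(tol=5))  # fraction of predictions within 5 Hz of the target -> 1.0
```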
fairseq/examples/emotion_conversion/fairseq_models/__init__.py ADDED
@@ -0,0 +1,226 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from fairseq import utils
7
+ from fairseq.models import (
8
+ FairseqMultiModel,
9
+ register_model,
10
+ register_model_architecture,
11
+ )
12
+ from fairseq.models.transformer import (
13
+ Embedding,
14
+ base_architecture,
15
+ )
16
+ from fairseq.models.multilingual_transformer import (
17
+ MultilingualTransformerModel,
18
+ base_multilingual_architecture,
19
+ )
20
+ from fairseq.utils import safe_hasattr
21
+ from collections import OrderedDict
22
+
23
+
24
+ @register_model("multilingual_transformer_from_mbart")
25
+ class MultilingualTransformerModelFromMbart(MultilingualTransformerModel):
26
+ @classmethod
27
+ def build_model(cls, args, task):
28
+ """Build a new model instance."""
29
+ from fairseq.tasks.multilingual_translation import MultilingualTranslationTask
30
+
31
+ assert isinstance(task, MultilingualTranslationTask)
32
+
33
+ # make sure all arguments are present in older models
34
+ base_multilingual_architecture(args)
35
+
36
+ if not safe_hasattr(args, "max_source_positions"):
37
+ args.max_source_positions = 1024
38
+ if not safe_hasattr(args, "max_target_positions"):
39
+ args.max_target_positions = 1024
40
+
41
+ src_langs = [lang_pair.split("-")[0] for lang_pair in task.model_lang_pairs]
42
+ tgt_langs = [lang_pair.split("-")[1] for lang_pair in task.model_lang_pairs]
43
+
44
+ if args.share_encoders:
45
+ args.share_encoder_embeddings = True
46
+ if args.share_decoders:
47
+ args.share_decoder_embeddings = True
48
+
49
+ def build_embedding(dictionary, embed_dim, path=None):
50
+ num_embeddings = len(dictionary)
51
+ padding_idx = dictionary.pad()
52
+ emb = Embedding(num_embeddings, embed_dim, padding_idx)
53
+ # if provided, load from preloaded dictionaries
54
+ if path:
55
+ embed_dict = utils.parse_embedding(path)
56
+ utils.load_embedding(embed_dict, dictionary, emb)
57
+ return emb
58
+
59
+ # build shared embeddings (if applicable)
60
+ shared_encoder_embed_tokens, shared_decoder_embed_tokens = None, None
61
+ if args.share_all_embeddings:
62
+ if args.encoder_embed_dim != args.decoder_embed_dim:
63
+ raise ValueError(
64
+ "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
65
+ )
66
+ if args.decoder_embed_path and (
67
+ args.decoder_embed_path != args.encoder_embed_path
68
+ ):
69
+ raise ValueError(
70
+ "--share-all-embeddings not compatible with --decoder-embed-path"
71
+ )
72
+ shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
73
+ dicts=task.dicts,
74
+ langs=task.langs,
75
+ embed_dim=args.encoder_embed_dim,
76
+ build_embedding=build_embedding,
77
+ pretrained_embed_path=args.encoder_embed_path,
78
+ )
79
+ shared_decoder_embed_tokens = shared_encoder_embed_tokens
80
+ args.share_decoder_input_output_embed = True
81
+ else:
82
+ if args.share_encoder_embeddings:
83
+ shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
84
+ dicts=task.dicts,
85
+ langs=src_langs,
86
+ embed_dim=args.encoder_embed_dim,
87
+ build_embedding=build_embedding,
88
+ pretrained_embed_path=args.encoder_embed_path,
89
+ )
90
+ if args.share_decoder_embeddings:
91
+ shared_decoder_embed_tokens = FairseqMultiModel.build_shared_embeddings(
92
+ dicts=task.dicts,
93
+ langs=tgt_langs,
94
+ embed_dim=args.decoder_embed_dim,
95
+ build_embedding=build_embedding,
96
+ pretrained_embed_path=args.decoder_embed_path,
97
+ )
98
+
99
+ # encoders/decoders for each language
100
+ lang_encoders, lang_decoders = {}, {}
101
+
102
+ def get_encoder(lang):
103
+ if lang not in lang_encoders:
104
+ if shared_encoder_embed_tokens is not None:
105
+ encoder_embed_tokens = shared_encoder_embed_tokens
106
+ else:
107
+ encoder_embed_tokens = build_embedding(
108
+ task.dicts[lang],
109
+ args.encoder_embed_dim,
110
+ args.encoder_embed_path,
111
+ )
112
+ lang_encoders[lang] = MultilingualTransformerModel._get_module_class(
113
+ True, args, task.dicts[lang], encoder_embed_tokens, src_langs
114
+ )
115
+ return lang_encoders[lang]
116
+
117
+ def get_decoder(lang):
118
+ if lang not in lang_decoders:
119
+ if shared_decoder_embed_tokens is not None:
120
+ decoder_embed_tokens = shared_decoder_embed_tokens
121
+ else:
122
+ decoder_embed_tokens = build_embedding(
123
+ task.dicts[lang],
124
+ args.decoder_embed_dim,
125
+ args.decoder_embed_path,
126
+ )
127
+ lang_decoders[lang] = MultilingualTransformerModel._get_module_class(
128
+ False, args, task.dicts[lang], decoder_embed_tokens, tgt_langs
129
+ )
130
+ return lang_decoders[lang]
131
+
132
+ # shared encoders/decoders (if applicable)
133
+ shared_encoder, shared_decoder = None, None
134
+ if args.share_encoders:
135
+ shared_encoder = get_encoder(src_langs[0])
136
+ if args.share_decoders:
137
+ shared_decoder = get_decoder(tgt_langs[0])
138
+
139
+ encoders, decoders = OrderedDict(), OrderedDict()
140
+ for lang_pair, src, tgt in zip(task.model_lang_pairs, src_langs, tgt_langs):
141
+ encoders[lang_pair] = (
142
+ shared_encoder if shared_encoder is not None else get_encoder(src)
143
+ )
144
+ decoders[lang_pair] = (
145
+ shared_decoder if shared_decoder is not None else get_decoder(tgt)
146
+ )
147
+
148
+ return MultilingualTransformerModelFromMbart(encoders, decoders)
149
+
150
+ def load_state_dict(self, state_dict, strict=True, model_cfg=None):
151
+ state_dict_subset = state_dict.copy()
152
+ lang_pairs = set([x.split(".")[1] for x in state_dict.keys()])
153
+ finetune_mode = not any("neutral" in lp for lp in lang_pairs)
154
+
155
+ if finetune_mode:
156
+ # load a pre-trained mBART/BART model
157
+ # we need this code because mBART/BART are not of type FairseqMultiModel but FairseqModel
158
+ # so we hackishly load the weights by replicating them for all lang pairs
159
+ print("loading pre-trained BART")
160
+ self_state_dict = self.state_dict()
161
+ for k, v in state_dict.items():
162
+ for lang_pair in self.models:
163
+ new_key = k if "models." in k else f"models.{lang_pair}.{k}"
164
+ # print(new_key)
165
+ if self_state_dict[new_key].shape == v.shape:
166
+ state_dict_subset[new_key] = v
167
+ elif any(
168
+ w in k
169
+ for w in [
170
+ "encoder.embed_tokens.weight",
171
+ "decoder.embed_tokens.weight",
172
+ "decoder.output_projection.weight",
173
+ ]
174
+ ):
175
+ # why vocab_size - 5? because there are `vocab_size` tokens from the language
176
+ # and 5 additional tokens in the denoising task: eos,bos,pad,unk,mask.
177
+ # but in the translation task there are only `vocab_size` + 4 (no mask).
178
+ print(
179
+ f"{k}: {self_state_dict[new_key].shape} != {v.shape}",
180
+ end="",
181
+ flush=True,
182
+ )
183
+ vocab_size = v.shape[0] - 5
184
+ state_dict_subset[new_key] = self_state_dict[new_key]
185
+ state_dict_subset[new_key] = v[: vocab_size + 4]
186
+ print(f" => fixed by using first {vocab_size + 4} dims")
187
+ else:
188
+ raise ValueError("unable to load model due to mimatched dims!")
189
+ del state_dict_subset[k]
190
+ else:
191
+ print("loading pre-trained emotion translation model")
192
+ for k, _ in state_dict.items():
193
+ assert k.startswith("models.")
194
+ lang_pair = k.split(".")[1]
195
+ if lang_pair not in self.models:
196
+ del state_dict_subset[k]
197
+
198
+ super().load_state_dict(state_dict_subset, strict=strict, model_cfg=model_cfg)
199
+
200
+
201
+ @register_model_architecture("transformer", "transformer_small")
202
+ def transformer_small(args):
203
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
204
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512)
205
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
206
+ args.encoder_layers = getattr(args, "encoder_layers", 3)
207
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
208
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 512)
209
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
210
+ args.decoder_layers = getattr(args, "decoder_layers", 3)
211
+ base_architecture(args)
212
+
213
+
214
+ @register_model_architecture(
215
+ "multilingual_transformer_from_mbart", "multilingual_small"
216
+ )
217
+ def multilingual_small(args):
218
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
219
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512)
220
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
221
+ args.encoder_layers = getattr(args, "encoder_layers", 3)
222
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
223
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 512)
224
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
225
+ args.decoder_layers = getattr(args, "decoder_layers", 3)
226
+ base_multilingual_architecture(args)
fairseq/examples/emotion_conversion/preprocess/__init__.py ADDED
File without changes
fairseq/examples/emotion_conversion/preprocess/build_hifigan_manifest.py ADDED
@@ -0,0 +1,38 @@
1
+ import torchaudio
2
+ import argparse
3
+ import json
4
+
5
+ def main():
6
+ parser = argparse.ArgumentParser(description="example: python create_hifigan_manifest.py --tsv /checkpoint/felixkreuk/datasets/vctk/splits/vctk_16khz/train.tsv --km /checkpoint/felixkreuk/experiments/hubert/hubert_feats/vctk_16khz_km_100/train.km --km_type hubert_100km > ~/tmp/tmp_mani.txt")
7
+ parser.add_argument("--tsv", required=True, help="path to fairseq tsv file")
8
+ parser.add_argument("--km", required=True, help="path to a km file generated by HuBERT clustering")
9
+ parser.add_argument("--km_type", required=True, help="name of the codes in the output json (for example: 'cpc_100km')")
10
+ args = parser.parse_args()
11
+
12
+ km_lines = open(args.km, "r").readlines()
13
+ tsv_lines = open(args.tsv, "r").readlines()
14
+ assert len(km_lines) == len(tsv_lines) - 1, "tsv and km files are not of the same length!"
15
+
16
+ wav_root = tsv_lines[0].strip()
17
+ tsv_lines = tsv_lines[1:]
18
+
19
+ for tsv_line, km_line in zip(tsv_lines, km_lines):
20
+ tsv_line, km_line = tsv_line.strip(), km_line.strip()
21
+ wav_basename, wav_num_frames = tsv_line.split("\t")
22
+ wav_path = wav_root + "/" + wav_basename
23
+ wav_info = torchaudio.info(wav_path)
24
+ assert int(wav_num_frames) == wav_info.num_frames, "tsv duration and actual duration don't match!"
25
+ wav_duration = wav_info.num_frames / wav_info.sample_rate
26
+ manifest_line = {"audio": wav_path, "duration": wav_duration, args.km_type: km_line}
27
+ print(json.dumps(manifest_line))
28
+
29
+ if __name__ == "__main__":
30
+ """
31
+ usage:
32
+ python build_hifigan_manifest.py \
33
+ --tsv /checkpoint/felixkreuk/datasets/vctk/manifests/vctk_16khz/valid.tsv \
34
+ --km /checkpoint/felixkreuk/datasets/vctk/manifests/vctk_16khz/hubert_km_100/valid.km \
35
+ --km_type hubert \
36
+ > /checkpoint/felixkreuk/datasets/vctk/manifests/vctk_16khz/hubert_km_100/hifigan_valid_manifest.txt
37
+ """
38
+ main()
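Each line the script above prints is a standalone JSON object keyed by `audio`, `duration` and the value of `--km_type`. A hedged example of what one line could look like (the path, duration and code string below are invented):

```python
# Hypothetical manifest line, mirroring the dict built in build_hifigan_manifest.main().
import json

line = {
    "audio": "/data/vctk_16khz/p225/p225_001.wav",  # assumed wav layout
    "duration": 2.34,                               # num_frames / sample_rate, in seconds
    "hubert_100km": "5 5 5 12 12 7 7 7 7 93",       # key comes from --km_type
}
print(json.dumps(line))
```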
fairseq/examples/emotion_conversion/preprocess/build_translation_manifests.py ADDED
@@ -0,0 +1,258 @@
1
+ from glob import glob
2
+ import argparse
3
+ from collections import defaultdict, Counter
4
+ from itertools import combinations, product, groupby
5
+ from pathlib import Path
6
+ import os
7
+ from sklearn.utils import shuffle
8
+ import numpy as np
9
+ import random
10
+ from shutil import copy
11
+ from subprocess import check_call
12
+
13
+ np.random.seed(42)
14
+ random.seed(42)
15
+
16
+
17
+ def get_fname(s):
18
+ return s.split("\t")[0]
19
+
20
+ def get_emotion(s):
21
+ return get_fname(s).split("_")[0].split("/")[1].lower()
22
+
23
+ def get_utt_id(s):
24
+ return get_fname(s).split(".")[0].split("_")[-1]
25
+
26
+ def dedup(seq):
27
+ """ >> remove_repetitions("1 2 2 3 100 2 2 1")
28
+ '1 2 3 100 2 1' """
29
+ seq = seq.strip().split(" ")
30
+ result = seq[:1]
31
+ reps = []
32
+ rep_counter = 1
33
+ for k in seq[1:]:
34
+ if k != result[-1]:
35
+ result += [k]
36
+ reps += [rep_counter]
37
+ rep_counter = 1
38
+ else:
39
+ rep_counter += 1
40
+ reps += [rep_counter]
41
+ assert len(reps) == len(result) and sum(reps) == len(seq)
42
+ return " ".join(result) + "\n" #, reps
43
+
44
+ def remove_under_k(seq, k):
45
+ """ remove tokens that repeat less then k times in a row
46
+ >> remove_under_k("a a a a b c c c", 1) ==> a a a a c c c """
47
+ seq = seq.strip().split(" ")
48
+ result = []
49
+
50
+ freqs = [(k,len(list(g))) for k, g in groupby(seq)]
51
+ for c, f in freqs:
52
+ if f > k:
53
+ result += [c for _ in range(f)]
54
+ return " ".join(result) + "\n" #, reps
55
+
56
+
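A quick sanity check (an editor's sketch, not part of the diff) of the two helpers above on toy unit sequences:

```python
# Assuming dedup and remove_under_k from build_translation_manifests.py are in scope.
# dedup collapses consecutive repeats; remove_under_k drops tokens whose runs are
# not longer than k. Both return the cleaned sequence with a trailing newline.
print(dedup("1 2 2 3 100 2 2 1"))            # -> "1 2 3 100 2 1\n"
print(remove_under_k("a a a a b c c c", 1))  # -> "a a a a c c c\n"
```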
57
+ def call(cmd):
58
+ print(cmd)
59
+ check_call(cmd, shell=True)
60
+
61
+
62
+ def denoising_preprocess(path, lang, dict):
63
+ bin = 'fairseq-preprocess'
64
+ cmd = [
65
+ bin,
66
+ f'--trainpref {path}/train.{lang} --validpref {path}/valid.{lang} --testpref {path}/test.{lang}',
67
+ f'--destdir {path}/tokenized/{lang}',
68
+ '--only-source',
69
+ '--task multilingual_denoising',
70
+ '--workers 40',
71
+ ]
72
+ if dict != "":
73
+ cmd += [f'--srcdict {dict}']
74
+ cmd = " ".join(cmd)
75
+ call(cmd)
76
+
77
+
78
+ def translation_preprocess(path, src_lang, trg_lang, dict, only_train=False):
79
+ bin = 'fairseq-preprocess'
80
+ cmd = [
81
+ bin,
82
+ f'--source-lang {src_lang} --target-lang {trg_lang}',
83
+ f'--trainpref {path}/train',
84
+ f'--destdir {path}/tokenized',
85
+ '--workers 40',
86
+ ]
87
+ if not only_train:
88
+ cmd += [f'--validpref {path}/valid --testpref {path}/test']
89
+ if dict != "":
90
+ cmd += [
91
+ f'--srcdict {dict}',
92
+ f'--tgtdict {dict}',
93
+ ]
94
+ cmd = " ".join(cmd)
95
+ call(cmd)
96
+
97
+
98
+ def load_tsv_km(tsv_path, km_path):
99
+ assert tsv_path.exists() and km_path.exists()
100
+ tsv_lines = open(tsv_path, "r").readlines()
101
+ root, tsv_lines = tsv_lines[0], tsv_lines[1:]
102
+ km_lines = open(km_path, "r").readlines()
103
+ assert len(tsv_lines) == len(km_lines), ".tsv and .km should be the same length!"
104
+ return root, tsv_lines, km_lines
105
+
106
+
107
+ def main():
108
+ desc = """
109
+ this script takes as input .tsv and .km files for the EMOV dataset, and a pair of emotions.
110
+ it generates parallel .tsv and .km files for these emotions. for example:
111
+ ❯ python build_translation_manifests.py \
112
+ /checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz/train.tsv \
113
+ /checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz/emov_16khz_km_100/train.km \
114
+ ~/tmp/emov_pairs \
115
+ --src-emotion amused --trg-emotion neutral \
116
+ --dedup --shuffle --cross-speaker --dry-run
117
+ """
118
+ parser = argparse.ArgumentParser(description=desc)
119
+ parser.add_argument("data", type=Path, help="path to a dir containing .tsv and .km files containing emov dataset")
120
+ parser.add_argument("output_path", type=Path, help="output directory with the manifests will be created")
121
+ parser.add_argument("-cs", "--cross-speaker", action='store_true', help="if set then translation will occur also between speakers, meaning the same sentence can be translated between different speakers (default: false)")
122
+ parser.add_argument("-dd", "--dedup", action='store_true', help="remove repeated tokens (example: 'aaabc=>abc')")
123
+ parser.add_argument("-sh", "--shuffle", action='store_true', help="shuffle the data")
124
+ parser.add_argument("-ae", "--autoencode", action='store_true', help="include training pairs from the same emotion (this includes examples of the same sentence uttered by different people and examples where the src and trg are the exact same seq)")
125
+ parser.add_argument("-dr", "--dry-run", action='store_true', help="don't write anything to disk")
126
+ parser.add_argument("-zs", "--zero-shot", action='store_true', help="if true, the denoising task will train on the same splits as the translation task (split by utterance id). if false, the denoising task will train on randomly sampled splits (not split by utterance id)")
127
+ parser.add_argument("--km-ext", default="km", help="")
128
+ parser.add_argument("--dict", default="/checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz/fairseq.dict.txt", help="")
129
+ args = parser.parse_args()
130
+ SPEAKERS = ["bea", "jenie", "josh", "sam", "SAME"]
131
+ EMOTIONS = ['neutral', 'amused', 'angry', 'disgusted', 'sleepy']
132
+
133
+ suffix = ""
134
+ if args.cross_speaker: suffix += "_cross-speaker"
135
+ if args.dedup: suffix += "_dedup"
136
+ translation_suffix = ""
137
+ if args.autoencode: translation_suffix += "_autoencode"
138
+ denoising_suffix = ""
139
+ denoising_suffix += "_zeroshot" if args.zero_shot else "_nonzeroshot"
140
+
141
+ translation_dir = Path(args.output_path) / ("emov_multilingual_translation" + suffix + translation_suffix)
142
+ os.makedirs(translation_dir, exist_ok=True)
143
+ denoising_dir = Path(args.output_path) / ("emov_multilingual_denoising" + suffix + denoising_suffix)
144
+ os.makedirs(denoising_dir, exist_ok=True)
145
+
146
+ denoising_data = [p.name for p in (args.data / "denoising").glob("*") if "emov" not in p.name]
147
+
148
+ for split in ["train", "valid", "test"]:
149
+ root, tsv_lines, km_lines = load_tsv_km(
150
+ tsv_path = args.data / "denoising" / "emov" / f"{split}.tsv",
151
+ km_path = args.data / "denoising" / "emov" / f"{split}.{args.km_ext}"
152
+ )
153
+
154
+ # generate data for the multilingual denoising task
155
+ for EMOTION in EMOTIONS:
156
+ print("---")
157
+ print(split)
158
+ print(f"denoising: {EMOTION}")
159
+ emotion_tsv, emotion_km = [], []
160
+ for tsv_line, km_line in zip(tsv_lines, km_lines):
161
+ if EMOTION.lower() in tsv_line.lower():
162
+ km_line = km_line if not args.dedup else dedup(km_line)
163
+ emotion_tsv.append(tsv_line)
164
+ emotion_km.append(km_line)
165
+ print(f"{len(emotion_km)} samples")
166
+ open(denoising_dir / f"files.{split}.{EMOTION}", "w").writelines([root] + emotion_tsv)
167
+ open(denoising_dir / f"{split}.{EMOTION}", "w").writelines(emotion_km)
168
+
169
+ for data in denoising_data:
170
+ with open(args.data / "denoising" / data / f"{split}.{args.km_ext}", "r") as f1:
171
+ with open(denoising_dir / f"{split}.{data}", "w") as f2:
172
+ f2.writelines([l if not args.dedup else dedup(l) for l in f1.readlines()])
173
+
174
+ # start of translation preprocessing
175
+ root, tsv_lines, km_lines = load_tsv_km(
176
+ tsv_path = args.data / "translation" / f"{split}.tsv",
177
+ km_path = args.data / "translation" / f"{split}.{args.km_ext}"
178
+ )
179
+
180
+ # generate data for the multilingual translation task
181
+ for SRC_EMOTION in EMOTIONS:
182
+ TRG_EMOTIONS = EMOTIONS if args.autoencode else set(EMOTIONS) - set([SRC_EMOTION])
183
+ for TRG_EMOTION in TRG_EMOTIONS:
184
+ # when translating back to the same emotion - we don't want these emotion
185
+ # pairs to be part of the validation/test sets (because it's not really emotion conversion)
186
+ # if SRC_EMOTION == TRG_EMOTION and split in ["valid", "test"]: continue
187
+ print("---")
188
+ print(split)
189
+ print(f"src emotions: {SRC_EMOTION}\ntrg emotions: {TRG_EMOTION}")
190
+
191
+ # create a dictionary with the following structure:
192
+ # output[SPEAKER][UTT_ID] = list with indexes of line from the tsv file
193
+ # that match the speaker and utterance id. for example:
194
+ # output = {'sam': {'0493': [875, 1608, 1822], ...}, ...}
195
+ # meaning, for speaker 'sam', utterance id '0493', the indexes in tsv_lines
196
+ # are 875, 1608, 1822
197
+ spkr2utts = defaultdict(lambda: defaultdict(list))
198
+ for i, tsv_line in enumerate(tsv_lines):
199
+ speaker = tsv_line.split("/")[0]
200
+ if args.cross_speaker: speaker = "SAME"
201
+ assert speaker in SPEAKERS, "unknown speaker! make sure the .tsv contains EMOV data"
202
+ utt_id = get_utt_id(tsv_line)
203
+ spkr2utts[speaker][utt_id].append(i)
204
+
205
+ # create a tsv and km files with all the combinations for translation
206
+ src_tsv, trg_tsv, src_km, trg_km = [], [], [], []
207
+ for speaker, utt_ids in spkr2utts.items():
208
+ for utt_id, indices in utt_ids.items():
209
+ # generate all pairs
210
+ pairs = [(x,y) for x in indices for y in indices]
211
+ # self-translation
212
+ if SRC_EMOTION == TRG_EMOTION:
213
+ pairs = [(x,y) for (x,y) in pairs if x == y]
214
+ # filter according to src and trg emotions
215
+ pairs = [(x,y) for (x,y) in pairs
216
+ if get_emotion(tsv_lines[x]) == SRC_EMOTION and get_emotion(tsv_lines[y]) == TRG_EMOTION]
217
+
218
+ for idx1, idx2 in pairs:
219
+ assert get_utt_id(tsv_lines[idx1]) == get_utt_id(tsv_lines[idx2])
220
+ src_tsv.append(tsv_lines[idx1])
221
+ trg_tsv.append(tsv_lines[idx2])
222
+ km_line_idx1 = km_lines[idx1]
223
+ km_line_idx2 = km_lines[idx2]
224
+ km_line_idx1 = km_line_idx1 if not args.dedup else dedup(km_line_idx1)
225
+ km_line_idx2 = km_line_idx2 if not args.dedup else dedup(km_line_idx2)
226
+ src_km.append(km_line_idx1)
227
+ trg_km.append(km_line_idx2)
228
+ assert len(src_tsv) == len(trg_tsv) == len(src_km) == len(trg_km)
229
+ print(f"{len(src_tsv)} pairs")
230
+
231
+ if len(src_tsv) == 0:
232
+ raise Exception("ERROR: generated 0 pairs!")
233
+
234
+ if args.dry_run: continue
235
+
236
+ # create files
237
+ os.makedirs(translation_dir / f"{SRC_EMOTION}-{TRG_EMOTION}", exist_ok=True)
238
+ open(translation_dir / f"{SRC_EMOTION}-{TRG_EMOTION}" / f"files.{split}.{SRC_EMOTION}", "w").writelines([root] + src_tsv)
239
+ open(translation_dir / f"{SRC_EMOTION}-{TRG_EMOTION}" / f"files.{split}.{TRG_EMOTION}", "w").writelines([root] + trg_tsv)
240
+ open(translation_dir / f"{SRC_EMOTION}-{TRG_EMOTION}" / f"{split}.{SRC_EMOTION}", "w").writelines(src_km)
241
+ open(translation_dir / f"{SRC_EMOTION}-{TRG_EMOTION}" / f"{split}.{TRG_EMOTION}", "w").writelines(trg_km)
242
+
243
+
244
+ # fairseq-preprocess the denoising data
245
+ for EMOTION in EMOTIONS + denoising_data:
246
+ denoising_preprocess(denoising_dir, EMOTION, args.dict)
247
+ os.system(f"cp {args.dict} {denoising_dir}/tokenized/dict.txt")
248
+
249
+ # fairseq-preprocess the translation data
250
+ os.makedirs(translation_dir / "tokenized", exist_ok=True)
251
+ for SRC_EMOTION in EMOTIONS:
252
+ TRG_EMOTIONS = EMOTIONS if args.autoencode else set(EMOTIONS) - set([SRC_EMOTION])
253
+ for TRG_EMOTION in TRG_EMOTIONS:
254
+ translation_preprocess(translation_dir / f"{SRC_EMOTION}-{TRG_EMOTION}", SRC_EMOTION, TRG_EMOTION, args.dict)#, only_train=SRC_EMOTION==TRG_EMOTION)
255
+ os.system(f"cp -rf {translation_dir}/**/tokenized/* {translation_dir}/tokenized")
256
+
257
+ if __name__ == "__main__":
258
+ main()
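To make the pairing logic in `main()` above concrete, here is a toy sketch (an editor's addition; indices and emotions are invented) of how a `spkr2utts` entry turns into (source, target) index pairs for one `SRC_EMOTION`/`TRG_EMOTION` combination. In the real script, `get_emotion` applied to the tsv lines plays the role of the `emotions` lookup:

```python
# Toy illustration of the pair construction in build_translation_manifests.main().
spkr2utts = {"sam": {"0493": [875, 1608, 1822]}}            # utt 0493 uttered 3 times by 'sam'
emotions = {875: "amused", 1608: "neutral", 1822: "angry"}  # stand-in for get_emotion(...)

SRC_EMOTION, TRG_EMOTION = "amused", "neutral"
pairs = [
    (x, y)
    for utt_indices in spkr2utts["sam"].values()
    for x in utt_indices
    for y in utt_indices
    if emotions[x] == SRC_EMOTION and emotions[y] == TRG_EMOTION
]
print(pairs)  # [(875, 1608)]: the amused rendition paired with its neutral counterpart
```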
fairseq/examples/emotion_conversion/preprocess/create_core_manifest.py ADDED
@@ -0,0 +1,91 @@
1
+ from pathlib import Path
2
+ import os
3
+ import sys
4
+ import subprocess
5
+ import argparse
6
+ from datetime import datetime
7
+ import logging
8
+
9
+ logging.basicConfig(
10
+ level=logging.INFO,
11
+ format='%(asctime)s [%(levelname)s] %(message)s',
12
+ handlers=[logging.FileHandler('debug.log'), logging.StreamHandler()]
13
+ )
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ def verify_dict_size(km, dict):
18
+ logger.info(f"verifying: {km}")
19
+ dict_size = len(open(dict, "r").readlines())
20
+ km_vocab = set(open(km, "r").read().replace("\n", " ").split(" "))
21
+ if "" in km_vocab: km_vocab.remove("")
22
+ km_vocab_size = len(km_vocab)
23
+ return dict_size == km_vocab_size
24
+
25
+
26
+ def verify_files_exist(l):
27
+ for f in l:
28
+ if not f.exists():
29
+ logging.error(f"{f} doesn't exist!")
30
+ return False
31
+ return True
32
+
33
+
34
+ def run_cmd(cmd, print_output=True):
35
+ try:
36
+ out = subprocess.check_output(cmd, stderr=subprocess.STDOUT, universal_newlines=True, shell=True)
37
+ if print_output:
38
+ logger.info(f"command output:\n{out}")
39
+ return out
40
+ except subprocess.CalledProcessError as grepexc:
41
+ logger.info(f"error executing command!:\n{cmd}")
42
+ logger.info(grepexc.output)
43
+
44
+ def main():
45
+ parser = argparse.ArgumentParser()
46
+ parser.add_argument("--tsv", default="/checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz/data.tsv", type=Path)
47
+ parser.add_argument("--emov-km", required=True, type=Path)
48
+ parser.add_argument("--km", nargs='+', required=True, type=Path)
49
+ parser.add_argument("--seed", type=int, default=1)
50
+ parser.add_argument("--dict", default="/checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz/fairseq.dict.txt")
51
+ parser.add_argument("--manifests-dir", type=Path, default="/checkpoint/felixkreuk/datasets/emov/manifests/emov_16khz")
52
+ args = parser.parse_args()
53
+
54
+ manifests_dir = args.manifests_dir
55
+ date = datetime.now().strftime('%d%m%y')
56
+ outdir = manifests_dir / f"{date}"
57
+
58
+ # verify input and create folders
59
+ all_kms = args.km + [args.emov_km]
60
+ assert verify_files_exist(all_kms), "make sure the km dir contains: train-clean-all.km, blizzard2013.km, data.km"
61
+ for codes in all_kms:
62
+ assert verify_dict_size(codes, args.dict), "dict argument doesn't match the vocabulary of the km file!"
63
+ assert not outdir.exists(), "data dir already exists!"
64
+ outdir.mkdir(parents=True, exist_ok=True)
65
+
66
+ logger.info("generating denoising split (emov)")
67
+ run_cmd(f"python preprocess/split_km_tsv.py {args.tsv} {args.emov_km} --destdir {outdir}/denoising/emov -sh --seed {args.seed}")
68
+ for codes in args.km:
69
+ codes_name = os.path.basename(codes)
70
+ run_cmd(f"python preprocess/split_km.py {codes} --destdir {outdir}/denoising/{codes_name} -sh --seed {args.seed}")
71
+
72
+ logger.info("generating translation split")
73
+ run_cmd(f"python preprocess/split_emov_km_tsv_by_uttid.py {args.tsv} {args.emov_km} --destdir {outdir}/translation --seed {args.seed}")
74
+
75
+ emov_code_name = os.path.basename(args.emov_km)
76
+ logger.info("generating hifigan split")
77
+ run_cmd(
78
+ f"mkdir -p {outdir}/hifigan &&"
79
+ f"python preprocess/build_hifigan_manifest.py --km_type hubert --tsv {outdir}/denoising/emov/train.tsv --km {outdir}/denoising/emov/train.km > {outdir}/hifigan/train.txt &&"
80
+ f"python preprocess/build_hifigan_manifest.py --km_type hubert --tsv {outdir}/denoising/emov/valid.tsv --km {outdir}/denoising/emov/valid.km > {outdir}/hifigan/valid.txt &&"
81
+ f"python preprocess/build_hifigan_manifest.py --km_type hubert --tsv {outdir}/denoising/emov/test.tsv --km {outdir}/denoising/emov/test.km > {outdir}/hifigan/test.txt"
82
+ )
83
+
84
+ logger.info("generating fairseq manifests")
85
+ run_cmd(f"python preprocess/build_translation_manifests.py {outdir} {outdir}/fairseq-data -dd -cs --dict {args.dict}")
86
+
87
+ logger.info(f"finished processing data at:\n{outdir}")
88
+
89
+
90
+ if __name__ == "__main__":
91
+ main()
fairseq/examples/emotion_conversion/preprocess/extract_f0.py ADDED
@@ -0,0 +1,57 @@
1
+ import argparse
2
+ from tqdm import tqdm
3
+ from multiprocessing import Manager, Pool
4
+
5
+ from scipy.io.wavfile import read
6
+ from librosa.util import normalize
7
+ import numpy as np
8
+ import amfm_decompy.pYAAPT as pYAAPT
9
+ import amfm_decompy.basic_tools as basic
10
+
11
+ MAX_WAV_VALUE = 32768.0
12
+
13
+ parser = argparse.ArgumentParser(description="")
14
+ parser.add_argument("tsv", help="")
15
+ parser.add_argument("--extractor", choices=["crepe", "pyaapt"], default="pyaapt", help="")
16
+ parser.add_argument("--interp", action="store_true", help="")
17
+ parser.add_argument("--n_workers", type=int, default=40, help="")
18
+ args = parser.parse_args()
19
+
20
+ tsv_lines = open(args.tsv, "r").readlines()
21
+ root, tsv_lines = tsv_lines[0].strip(), tsv_lines[1:]
22
+
23
+
24
+ def extract_f0(tsv_line):
25
+ wav_path, _ = tsv_line.split("\t")
26
+ wav_path = root.strip() + "/" + wav_path
27
+ sr, wav = read(wav_path)
28
+ wav = wav / MAX_WAV_VALUE
29
+ wav = normalize(wav) * 0.95
30
+
31
+ if args.extractor == "pyaapt":
32
+ frame_length = 20.0
33
+ pad = int(frame_length / 1000 * sr) // 2
34
+ wav = np.pad(wav.squeeze(), (pad, pad), "constant", constant_values=0)
35
+ signal = basic.SignalObj(wav, sr)
36
+ pitch = pYAAPT.yaapt(
37
+ signal,
38
+ **{
39
+ 'frame_length': frame_length,
40
+ 'frame_space': 5.0,
41
+ 'nccf_thresh1': 0.25,
42
+ 'tda_frame_length': 25.0
43
+ })
44
+ pitch = pitch.samp_interp[None, None, :] if args.interp else pitch.samp_values[None, None, :]
45
+ pitch = pitch[0, 0]
46
+ f0_path = wav_path.replace(".wav", ".yaapt")
47
+ f0_path += ".interp.f0" if args.interp else ".f0"
48
+ np.save(f0_path, pitch)
49
+
50
+
51
+ def main():
52
+ with Pool(args.n_workers) as p:
53
+ r = list(tqdm(p.imap(extract_f0, tsv_lines), total=len(tsv_lines)))
54
+
55
+
56
+ if __name__ == "__main__":
57
+ main()
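Note that `np.save` appends a `.npy` suffix, so the contour extracted for `foo.wav` ends up at `foo.yaapt.f0.npy` (or `foo.yaapt.interp.f0.npy` with `--interp`). A minimal sketch (editor's addition; the wav path below is made up) of reading one back:

```python
import numpy as np

f0 = np.load("/data/emov/sam/amused_1-15_0001.yaapt.f0.npy")  # hypothetical file
voiced = f0[f0 != 0]            # without --interp, unvoiced frames are stored as 0
print(f0.shape, voiced.mean())  # one value every 5 ms (frame_space=5.0 above)
```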
fairseq/examples/emotion_conversion/preprocess/process_km.py ADDED
@@ -0,0 +1,40 @@
1
+ import sys
2
+ import argparse
3
+ from tqdm import tqdm
4
+ from build_translation_manifests import dedup, remove_under_k
5
+
6
+
7
+ if __name__ == "__main__":
8
+ """
9
+ this is a standalone script to process a km file
10
+ specifically, to dedup or remove tokens that repeat less
11
+ than k times in a row
12
+ """
13
+ parser = argparse.ArgumentParser(description="")
14
+ parser.add_argument("km", type=str, help="path to km file")
15
+ parser.add_argument("--dedup", action='store_true')
16
+ parser.add_argument("--remove-under-k", type=int, default=0)
17
+ parser.add_argument("--output", default=None)
18
+ args = parser.parse_args()
19
+
20
+ if not args.dedup and args.remove_under_k == 0:
21
+ print("nothing to do! quitting...")
22
+ sys.exit(0)
23
+
24
+ km = open(args.km, "r").readlines()
25
+ out = []
26
+ for line in tqdm(km):
27
+ if args.remove_under_k > 0:
28
+ line = remove_under_k(line, args.remove_under_k)
29
+ if args.dedup:
30
+ line = dedup(line)
31
+ out.append(line)
32
+
33
+ path = args.km if args.output is None else args.output
34
+ if args.remove_under_k > 0:
35
+ path = path.replace(".km", f"-k{args.remove_under_k}.km")
36
+ if args.dedup:
37
+ path = path.replace(".km", f"-deduped.km")
38
+
39
+ open(path, "w").writelines(out)
40
+ print(f"written to {path}")
fairseq/examples/emotion_conversion/preprocess/split_emov_km_tsv_by_uttid.py ADDED
@@ -0,0 +1,70 @@
1
+ from pathlib import Path
2
+ import os
3
+ import sys
4
+ import argparse
5
+ import random
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ from sklearn.model_selection import train_test_split
9
+ from build_translation_manifests import get_utt_id
10
+
11
+
12
+ def train_val_test_split(tsv_lines, km_lines, valid_percent, test_percent, seed=42):
13
+ utt_ids = list(sorted(set([get_utt_id(x) for x in tsv_lines])))
14
+ utt_ids, valid_utt_ids, _, _ = train_test_split(utt_ids, utt_ids, test_size=valid_percent, shuffle=True, random_state=seed)
15
+ train_utt_ids, test_utt_ids, _, _ = train_test_split(utt_ids, utt_ids, test_size=test_percent, shuffle=True, random_state=seed)
16
+
17
+ train_idx = [i for i, line in enumerate(tsv_lines) if get_utt_id(line) in train_utt_ids]
18
+ valid_idx = [i for i, line in enumerate(tsv_lines) if get_utt_id(line) in valid_utt_ids]
19
+ test_idx = [i for i, line in enumerate(tsv_lines) if get_utt_id(line) in test_utt_ids]
20
+
21
+ train_tsv, train_km = [tsv_lines[i] for i in train_idx], [km_lines[i] for i in train_idx]
22
+ valid_tsv, valid_km = [tsv_lines[i] for i in valid_idx], [km_lines[i] for i in valid_idx]
23
+ test_tsv, test_km = [tsv_lines[i] for i in test_idx], [km_lines[i] for i in test_idx]
24
+
25
+ print(f"train {len(train_km)}")
26
+ print(f"valid {len(valid_km)}")
27
+ print(f"test {len(test_km)}")
28
+
29
+ return train_tsv, train_km, valid_tsv, valid_km, test_tsv, test_km
30
+
31
+
32
+ if __name__ == "__main__":
33
+ """
34
+ this is a standalone script that splits parallel emov .tsv and .km files
35
+ into train/valid/test sets by utterance id, so that the same utterance
36
+ never appears in more than one split
37
+ """
38
+ parser = argparse.ArgumentParser(description="")
39
+ parser.add_argument("tsv", type=str, help="path to tsv file")
40
+ parser.add_argument("km", type=str, help="path to km file")
41
+ parser.add_argument("--destdir", required=True, type=str)
42
+ parser.add_argument("--valid-percent", type=float, default=0.05, help="percent to allocate to validation set")
43
+ parser.add_argument("--test-percent", type=float, default=0.05, help="percent to allocate to test set")
44
+ parser.add_argument("--seed", type=int, default=42, help="")
45
+ args = parser.parse_args()
46
+
47
+ np.random.seed(args.seed)
48
+ random.seed(args.seed)
49
+
50
+ os.makedirs(args.destdir, exist_ok=True)
51
+ km = open(args.km, "r").readlines()
52
+ tsv = open(args.tsv, "r").readlines()
53
+ root, tsv = tsv[0], tsv[1:]
54
+
55
+ assert args.tsv.endswith(".tsv") and args.km.endswith(".km")
56
+ assert len(tsv) == len(km)
57
+
58
+ train_tsv, train_km, valid_tsv, valid_km, test_tsv, test_km = train_val_test_split(tsv, km, args.valid_percent, args.test_percent, args.seed)
59
+
60
+ assert len(train_tsv) + len(valid_tsv) + len(test_tsv) == len(tsv)
61
+ assert len(train_tsv) == len(train_km) and len(valid_tsv) == len(valid_km) and len(test_tsv) == len(test_km)
62
+
63
+ dir = Path(args.destdir)
64
+ open(dir / f"train.tsv", "w").writelines([root] + train_tsv)
65
+ open(dir / f"valid.tsv", "w").writelines([root] + valid_tsv)
66
+ open(dir / f"test.tsv", "w").writelines([root] + test_tsv)
67
+ open(dir / f"train.km", "w").writelines(train_km)
68
+ open(dir / f"valid.km", "w").writelines(valid_km)
69
+ open(dir / f"test.km", "w").writelines(test_km)
70
+ print("done")
fairseq/examples/emotion_conversion/preprocess/split_km.py ADDED
@@ -0,0 +1,50 @@
1
+ from pathlib import Path
2
+ import os
3
+ import argparse
4
+ import random
5
+ import numpy as np
6
+ from sklearn.utils import shuffle
7
+
8
+
9
+ if __name__ == "__main__":
10
+ """
11
+ this is a standalone script that splits a .km file into
12
+ train/valid/test sets according to the given percentages,
13
+ optionally shuffling the lines first
14
+ """
15
+ parser = argparse.ArgumentParser(description="")
16
+ parser.add_argument("km", type=str, help="path to km file")
17
+ parser.add_argument("--destdir", required=True, type=str)
18
+ parser.add_argument("--valid-percent", type=float, default=0.05, help="percent to allocate to validation set")
19
+ parser.add_argument("--test-percent", type=float, default=0.05, help="percent to allocate to test set")
20
+ parser.add_argument("-sh", "--shuffle", action="store_true", help="path to km file")
21
+ parser.add_argument("--seed", type=int, default=42, help="")
22
+ args = parser.parse_args()
23
+
24
+ np.random.seed(args.seed)
25
+ random.seed(args.seed)
26
+
27
+ os.makedirs(args.destdir, exist_ok=True)
28
+ km = open(args.km, "r").readlines()
29
+
30
+ if args.shuffle:
31
+ km = shuffle(km)
32
+ print(f"shuffled")
33
+
34
+ N = len(km)
35
+ N_tt = int(N * args.test_percent)
36
+ N_cv = int(N * args.valid_percent)
37
+ N_tr = N - N_tt - N_cv
38
+
39
+ train_km = km[:N_tr]
40
+ valid_km = km[N_tr:N_tr + N_cv]
41
+ test_km = km[N_tr + N_cv:]
42
+
43
+ dir = Path(args.destdir)
44
+ open(dir / f"train.km", "w").writelines(train_km)
45
+ open(dir / f"valid.km", "w").writelines(valid_km)
46
+ open(dir / f"test.km", "w").writelines(test_km)
47
+ print(f"train: {len(train_km)}")
48
+ print(f"valid: {len(valid_km)}")
49
+ print(f"test: {len(test_km)}")
50
+ print("done")
fairseq/examples/emotion_conversion/preprocess/split_km_tsv.py ADDED
@@ -0,0 +1,65 @@
1
+ from pathlib import Path
2
+ import os
3
+ import argparse
4
+ import random
5
+ import numpy as np
6
+ from sklearn.utils import shuffle
7
+
8
+
9
+ if __name__ == "__main__":
10
+ """
11
+ this is a standalone script that splits parallel .tsv and .km files
12
+ into train/valid/test sets according to the given percentages,
13
+ optionally shuffling the paired lines together first
14
+ """
15
+ parser = argparse.ArgumentParser(description="")
16
+ parser.add_argument("tsv", type=str, help="path to tsv file")
17
+ parser.add_argument("km", type=str, help="path to km file")
18
+ parser.add_argument("--destdir", required=True, type=str)
19
+ parser.add_argument("--valid-percent", type=float, default=0.05, help="percent to allocate to validation set")
20
+ parser.add_argument("--test-percent", type=float, default=0.05, help="percent to allocate to test set")
21
+ parser.add_argument("-sh", "--shuffle", action="store_true", help="path to km file")
22
+ parser.add_argument("--seed", type=int, default=42, help="")
23
+ args = parser.parse_args()
24
+
25
+ np.random.seed(args.seed)
26
+ random.seed(args.seed)
27
+
28
+ os.makedirs(args.destdir, exist_ok=True)
29
+ km = open(args.km, "r").readlines()
30
+ tsv = open(args.tsv, "r").readlines()
31
+ root, tsv = tsv[0], tsv[1:]
32
+
33
+ assert args.tsv.endswith(".tsv") and args.km.endswith(".km")
34
+ assert len(tsv) == len(km)
35
+
36
+ if args.shuffle:
37
+ tsv, km = shuffle(tsv, km)
38
+ print(f"shuffled")
39
+
40
+ N = len(tsv)
41
+ N_tt = int(N * args.test_percent)
42
+ N_cv = int(N * args.valid_percent)
43
+ N_tr = N - N_tt - N_cv
44
+
45
+ train_tsv = tsv[:N_tr]
46
+ valid_tsv = tsv[N_tr:N_tr + N_cv]
47
+ test_tsv = tsv[N_tr + N_cv:]
48
+ train_km = km[:N_tr]
49
+ valid_km = km[N_tr:N_tr + N_cv]
50
+ test_km = km[N_tr + N_cv:]
51
+
52
+ assert len(train_tsv) + len(valid_tsv) + len(test_tsv) == len(tsv)
53
+ assert len(train_tsv) == len(train_km) and len(valid_tsv) == len(valid_km) and len(test_tsv) == len(test_km)
54
+
55
+ dir = Path(args.destdir)
56
+ open(dir / f"train.tsv", "w").writelines([root] + train_tsv)
57
+ open(dir / f"valid.tsv", "w").writelines([root] + valid_tsv)
58
+ open(dir / f"test.tsv", "w").writelines([root] + test_tsv)
59
+ open(dir / f"train.km", "w").writelines(train_km)
60
+ open(dir / f"valid.km", "w").writelines(valid_km)
61
+ open(dir / f"test.km", "w").writelines(test_km)
62
+ print(f"train: {len(train_km)}")
63
+ print(f"valid: {len(valid_km)}")
64
+ print(f"test: {len(test_km)}")
65
+ print("done")
fairseq/examples/fast_noisy_channel/README.md ADDED
@@ -0,0 +1,345 @@
1
+ # Language Models not just for Pre-training: Fast Online Neural Noisy Channel Modeling
2
+
3
+ ## Introduction
4
+ - [Yee et al. (2019)](https://www.aclweb.org/anthology/D19-1571.pdf) introduce a simple and effective noisy channel modeling approach for neural machine translation. However, the noisy channel online decoding approach introduced in this paper is too slow to be practical.
5
+ - To address this, [Bhosale et al. (2020)](http://www.statmt.org/wmt20/pdf/2020.wmt-1.68.pdf) introduce 3 simple approximations to make this approach very fast and practical without much loss in accuracy.
6
+ - This README provides instructions on how to run online decoding or generation with the noisy channel modeling approach, including ways to make it very fast without much loss in accuracy.
7
+
8
+ ## Noisy Channel Modeling
9
+
10
+ [Yee et al. (2019)](https://www.aclweb.org/anthology/D19-1571.pdf) applies the Bayes Rule to predict `P(y|x)`, the probability of the target `y` given the source `x`.
11
+ ```P(y|x) = P(x|y) * P(y) / P(x)```
12
+ - `P(x|y)` predicts the source `x` given the target `y` and is referred to as the **channel model**
13
+ - `P(y)` is a **language model** over the target `y`
14
+ - `P(x)` is generally not modeled since it is constant for all `y`.
15
+
16
+ We use Transformer models to parameterize the direct model `P(y|x)`, the channel model `P(x|y)` and the language model `P(y)`.
17
+
18
+ During online decoding with beam search, we generate the top `K2` candidates per beam and score them with the following linear combination of the channel model, the language model as well as the direct model scores.
19
+
20
+ ```(1 / t) * log(P(y|x)) + (1 / s) * ( λ1 * log(P(x|y)) + λ2 * log(P(y)) )```
21
+ - `t` - Target Prefix Length
22
+ - `s` - Source Length
23
+ - `λ1` - Channel Model Weight
24
+ - `λ2` - Language Model Weight
25
+
26
+ The top `beam_size` candidates based on the above combined scores are chosen to continue the beams in beam search. In beam search with a direct model alone, the scores from the direct model `P(y|x)` are used to choose the top candidates in beam search.
27
+
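As a concrete illustration of the combined score above (an editor's sketch; the function below is not part of fairseq), scoring one candidate continuation might look like this:

```python
# Illustrative only: (1/t) * log P(y|x) + (1/s) * (lambda1 * log P(x|y) + lambda2 * log P(y))
def combined_score(log_p_direct, log_p_channel, log_p_lm,
                   tgt_prefix_len, src_len, ch_weight, lm_weight):
    return (1.0 / tgt_prefix_len) * log_p_direct + \
           (1.0 / src_len) * (ch_weight * log_p_channel + lm_weight * log_p_lm)

# During decoding, each of the K2 expansions of a beam is scored this way and the
# top beam_size candidates are kept.
```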
28
+ This framework provides a great way to utilize strong target language models trained on large amounts of unlabeled data. Language models can prefer targets unrelated to the source, so we also need a channel model whose role is to ensure that the target preferred by the language model also translates back to the source.
29
+
30
+ ### Training Translation Models and Language Models
31
+
32
+ For training Transformer models in fairseq for machine translation, refer to instructions [here](https://github.com/pytorch/fairseq/tree/main/examples/translation)
33
+
34
+ For training Transformer models in fairseq for language modeling, refer to instructions [here](https://github.com/pytorch/fairseq/tree/main/examples/language_model)
35
+
36
+ ### Generation with Language Model for German-English translation with fairseq
37
+
38
+ Here are instructions to generate using a direct model and a target-side language model.
39
+
40
+ Note:
41
+ - Download and install fairseq as per instructions [here](https://github.com/pytorch/fairseq)
42
+ - Preprocess and binarize the dataset as per instructions in section [Test Data Preprocessing](#test-data-preprocessing)
43
+
44
+ ```sh
45
+ binarized_data=data_dir/binarized
46
+ direct_model=de_en_seed4.pt
47
+ lm_model=en_lm.pt
48
+ lm_data=lm_data
49
+ wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt -O ${direct_model}
50
+ wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt -O ${lm_model}
51
+ mkdir -p ${lm_data}
52
+ wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/dict.txt -O ${lm_data}/dict.txt
53
+
54
+ k2=10
55
+ lenpen=0.16
56
+ lm_wt=0.14
57
+ fairseq-generate ${binarized_data} \
58
+ --user-dir examples/fast_noisy_channel \
59
+ --beam 5 \
60
+ --path ${direct_model} \
61
+ --lm-model ${lm_model} \
62
+ --lm-data ${lm_data} \
63
+ --k2 ${k2} \
64
+ --combine-method lm_only \
65
+ --task noisy_channel_translation \
66
+ --lenpen ${lenpen} \
67
+ --lm-wt ${lm_wt} \
68
+ --gen-subset valid \
69
+ --remove-bpe \
70
+ --fp16 \
71
+ --batch-size 10
72
+ ```
73
+ ### Noisy Channel Generation for German-English translation with fairseq
74
+
75
+ Here are instructions for noisy channel generation with a direct model, channel model and language model as explained in section [Noisy Channel Modeling](#noisy-channel-modeling).
76
+
77
+ Note:
78
+ - Download and install fairseq as per instructions [here](https://github.com/pytorch/fairseq)
79
+ - Preprocess and binarize the dataset as per instructions in section [Test Data Preprocessing](#test-data-preprocessing)
80
+
81
+ ```sh
82
+ binarized_data=data_dir/binarized
83
+ direct_model=de_en_seed4.pt
84
+ lm_model=en_lm.pt
85
+ lm_data=lm_data
86
+ ch_model=en_de.big.seed4.pt
87
+ wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt -O ${direct_model}
88
+ wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt -O ${lm_model}
89
+ mkdir -p ${lm_data}
90
+ wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/dict.txt -O ${lm_data}/dict.txt
91
+ wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed4.pt -O ${ch_model}
92
+
93
+ k2=10
94
+ lenpen=0.21
95
+ lm_wt=0.50
96
+ bw_wt=0.30
97
+ fairseq-generate ${binarized_data} \
98
+ --user-dir examples/fast_noisy_channel \
99
+ --beam 5 \
100
+ --path ${direct_model} \
101
+ --lm-model ${lm_model} \
102
+ --lm-data ${lm_data} \
103
+ --channel-model ${ch_model} \
104
+ --k2 ${k2} \
105
+ --combine-method noisy_channel \
106
+ --task noisy_channel_translation \
107
+ --lenpen ${lenpen} \
108
+ --lm-wt ${lm_wt} \
109
+ --ch-wt ${bw_wt} \
110
+ --gen-subset test \
111
+ --remove-bpe \
112
+ --fp16 \
113
+ --batch-size 1
114
+ ```
115
+ ## Fast Noisy Channel Modeling
116
+
117
+ [Bhosale et al. (2020)](http://www.statmt.org/wmt20/pdf/2020.wmt-1.68.pdf) introduce 3 approximations that speed up online noisy channel decoding -
118
+ - Smaller channel models (`Transformer Base` with 1 encoder and decoder layer each vs. `Transformer Big`)
119
+ - This involves training a channel model that is possibly smaller and less accurate in terms of BLEU than a channel model of the same size as the direct model.
120
+ - Since the role of the channel model is mainly to assign low scores to generations from the language model if they don't translate back to the source, we may not need the most accurate channel model for this purpose.
121
+ - Smaller output vocabulary size for the channel model (~30,000 -> ~1000)
122
+ - The channel model doesn't need to score the full output vocabulary, it just needs to score the source tokens, which are completely known.
123
+ - This is specified using the arguments `--channel-scoring-type src_vocab --top-k-vocab 500`
124
+ - This means that the output vocabulary for the channel model will be the source tokens for all examples in the batch and the top-K most frequent tokens in the vocabulary
125
+ - This reduces the memory consumption needed to store channel model scores significantly
126
+ - Smaller number of candidates (`k2`) scored per beam
127
+ - This is specified by reducing the argument `--k2`
128
+
129
+
130
+ ### Fast Noisy Channel Generation for German-English translation with fairseq
131
+
132
+ Here are instructions for **fast** noisy channel generation with a direct model, channel model and language model as explained in section [Fast Noisy Channel Modeling](#fast-noisy-channel-modeling). The main differences are that we use a smaller channel model, reduce `--k2`, set `--channel-scoring-type src_vocab --top-k-vocab 500` and increase the `--batch-size`.
133
+
134
+ Note:
135
+ - Download and install fairseq as per instructions [here](https://github.com/pytorch/fairseq)
136
+ - Preprocess and binarize the dataset as per instructions in section [Test Data Preprocessing](#test-data-preprocessing)
137
+
138
+ ```sh
139
+ binarized_data=data_dir/binarized
140
+ direct_model=de_en_seed4.pt
141
+ lm_model=en_lm.pt
142
+ lm_data=lm_data
143
+ small_ch_model=en_de.base_1_1.seed4.pt
144
+ wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt -O ${direct_model}
145
+ wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt -O ${lm_model}
146
+ mkdir -p ${lm_data}
147
+ wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/dict.txt -O ${lm_data}/dict.txt
148
+ wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed4.pt -O ${small_ch_model}
149
+
150
+ k2=3
151
+ lenpen=0.23
152
+ lm_wt=0.58
153
+ bw_wt=0.26
154
+ fairseq-generate ${binarized_data} \
155
+ --user-dir examples/fast_noisy_channel \
156
+ --beam 5 \
157
+ --path ${direct_model} \
158
+ --lm-model ${lm_model} \
159
+ --lm-data ${lm_data} \
160
+ --channel-model ${small_ch_model} \
161
+ --k2 ${k2} \
162
+ --combine-method noisy_channel \
163
+ --task noisy_channel_translation \
164
+ --lenpen ${lenpen} \
165
+ --lm-wt ${lm_wt} \
166
+ --ch-wt ${bw_wt} \
167
+ --gen-subset test \
168
+ --remove-bpe \
169
+ --fp16 \
170
+ --batch-size 50 \
171
+ --channel-scoring-type src_vocab --top-k-vocab 500
172
+ ```
173
+
174
+ ## Test Data Preprocessing
175
+
176
+ For preprocessing and binarizing the test sets for Romanian-English and German-English translation, we use the following script -
177
+
178
+ ```sh
179
+ FAIRSEQ=/path/to/fairseq
180
+ cd $FAIRSEQ
181
+ SCRIPTS=$FAIRSEQ/mosesdecoder/scripts
182
+ if [ ! -d "${SCRIPTS}" ]; then
183
+ echo 'Cloning Moses github repository (for tokenization scripts)...'
184
+ git clone https://github.com/moses-smt/mosesdecoder.git
185
+ fi
186
+ TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
187
+ NORMALIZE=$SCRIPTS/tokenizer/normalize-punctuation.perl
188
+
189
+ s=de
190
+ t=en
191
+ test=wmt18
192
+
193
+ mkdir -p data_dir
194
+
195
+ # Tokenization
196
+ if [ $s == "ro" ] ; then
197
+ # Note: Get normalise-romanian.py and remove-diacritics.py from
198
+ # https://github.com/rsennrich/wmt16-scripts/tree/master/preprocess
199
+ sacrebleu -t $test -l $s-$t --echo src | \
200
+ $NORMALIZE -l $s | \
201
+ python normalise-romanian.py | \
202
+ python remove-diacritics.py | \
203
+ $TOKENIZER -l $s -a -q > data_dir/$test.$s-$t.$s
204
+ else
205
+ sacrebleu -t $test -l $s-$t --echo src | perl $NORMALIZE -l $s | perl $TOKENIZER -threads 8 -a -l $s > data_dir/$test.$s-$t.$s
206
+ fi
207
+
208
+ sacrebleu -t $test -l $s-$t --echo ref | perl $NORMALIZE -l $t | perl $TOKENIZER -threads 8 -a -l $t > data_dir/$test.$s-$t.$t
209
+
210
+
211
+ # Applying BPE
212
+ src_bpe_code=/path/to/source/language/bpe/code
213
+ tgt_bpe_code=/path/to/target/language/bpe/code
214
+ src_dict=/path/to/source/language/dict
215
+ tgt_dict=/path/to/target/language/dict
216
+
217
+ FASTBPE=$FAIRSEQ/fastBPE
218
+ if [ ! -d "${FASTBPE}" ] ; then
219
+ git clone https://github.com/glample/fastBPE.git
220
+ # Follow compilation instructions at https://github.com/glample/fastBPE
221
+ g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast
222
+ fi
223
+
224
+ ${FASTBPE}/fast applybpe data_dir/bpe.$test.$s-$t.$s data_dir/$test.$s-$t.$s ${src_bpe_code}
225
+ ${FASTBPE}/fast applybpe data_dir/bpe.$test.$s-$t.$t data_dir/$test.$s-$t.$t ${tgt_bpe_code}
226
+
227
+ fairseq-preprocess -s $s -t $t \
228
+ --testpref data_dir/bpe.$test.$s-$t \
229
+ --destdir data_dir/binarized \
230
+ --srcdict ${src_dict} \
231
+ --tgtdict ${tgt_dict}
232
+ ```
233
+
234
+ ## Calculating BLEU
235
+
236
+ ```sh
237
+ DETOKENIZER=$SCRIPTS/tokenizer/detokenizer.perl
238
+ cat ${generation_output} | grep -P "^H" | sort -V | cut -f 3- | $DETOKENIZER -l $t -q -a | sacrebleu -t $test -l $s-$t
239
+ ```
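+
+ Here, `$DETOKENIZER` comes from the Moses scripts cloned in [Test Data Preprocessing](#test-data-preprocessing), `$s`, `$t` and `$test` are the same language/test-set variables used there, and `${generation_output}` is simply the file that the console output of `fairseq-generate` was saved to. A minimal sketch of how these variables might be set (the file name is arbitrary):
+
+ ```sh
+ SCRIPTS=$FAIRSEQ/mosesdecoder/scripts
+ s=de
+ t=en
+ test=wmt18
+
+ # Save the generation output by appending `| tee ${generation_output}`
+ # (or `> ${generation_output}`) to the fairseq-generate command above.
+ generation_output=data_dir/gen.$test.$s-$t.out
+ ```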
240
+
241
+
242
+ ## Romanian-English Translation
243
+
244
+ The direct and channel models are trained on bitext data (WMT16) combined with backtranslated data. The monolingual data used for backtranslation comes from http://data.statmt.org/rsennrich/wmt16_backtranslations/ (Sennrich et al., 2016c).
245
+
246
+ The backtranslated data is generated using an ensemble of 3 English-Romanian models trained on bitext training data (WMT16) with unrestricted sampling.
247
+
248
+ ### BPE Codes and Dictionary
249
+
250
+ We learn a joint BPE vocabulary of 18K types on the bitext training data which is used for both the source and target.
251
+ | | Path |
252
+ |----------|------|
253
+ | BPE Code | [joint_bpe_18k](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/bpe_18k) |
254
+ | Dictionary | [dict](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/dict) |
255
+
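+
+ As a convenience, these files can be plugged into the [Test Data Preprocessing](#test-data-preprocessing) script above (with `s=ro`, `t=en` and e.g. `test=wmt16`). Since the BPE code and dictionary are joint, the source and target variables point to the same files; the local file names below are arbitrary:
+
+ ```sh
+ wget https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/bpe_18k -O ro_en.bpe_18k
+ wget https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/dict -O ro_en.dict
+
+ src_bpe_code=ro_en.bpe_18k
+ tgt_bpe_code=ro_en.bpe_18k
+ src_dict=ro_en.dict
+ tgt_dict=ro_en.dict
+ ```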
256
+ ### Direct Models
257
+ For Ro-En with backtranslation, the direct and channel models use a Transformer-Big architecture.
258
+
259
+ | Seed | Model |
260
+ |----|----|
261
+ | 2 | [ro_en_seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/direct_models/seed2.pt)
262
+ | 4 | [ro_en_seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/direct_models/seed4.pt)
263
+ | 6 | [ro_en_seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/direct_models/seed6.pt)
264
+
265
+ ### Channel Models
266
+ For the channel models, we follow the same steps as for the direct models, but the backtranslated data is generated in the opposite direction using [this Romanian monolingual data](http://data.statmt.org/rsennrich/wmt16_backtranslations/).
267
+ The best lenpen, LM weight and CH weight are obtained by sweeping over the validation set (wmt16/dev) using beam 5; a sketch of such a sweep follows the table below.
268
+ | Model Size | Lenpen | LM Weight | CH Weight | Seed 2 | Seed 4 | Seed 6 |
269
+ |----|----|----|----|----|----|----|
270
+ | `big` | 0.84 | 0.64 | 0.56 | [big.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/big.seed2.pt) | [big.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/big.seed2.pt) | [big.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/big.seed2.pt) |
271
+ | `base_1_1` | 0.63 | 0.40 | 0.37 | [base_1_1.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/base_1_1.seed2.pt) | [base_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/base_1_1.seed4.pt) | [base_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/base_1_1.seed6.pt) |
272
+
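+
+ A minimal sketch of such a sweep, assuming the dev set has been binarized analogously to [Test Data Preprocessing](#test-data-preprocessing) (with a `--validpref`) and that `${binarized_data}`, `${direct_model}`, `${channel_model}`, `${lm_model}` and `${lm_data}` point to the binarized data and the downloaded checkpoints; the grid values below are illustrative, not the grid used in the paper:
+
+ ```sh
+ mkdir -p sweep
+ for lenpen in 0.2 0.4 0.6 0.8 1.0; do
+   for lm_wt in 0.3 0.4 0.5 0.6 0.7; do
+     for ch_wt in 0.3 0.4 0.5 0.6 0.7; do
+       out=sweep/gen.lenpen${lenpen}.lm${lm_wt}.ch${ch_wt}.out
+       fairseq-generate ${binarized_data} \
+         --user-dir examples/fast_noisy_channel \
+         --task noisy_channel_translation \
+         --combine-method noisy_channel \
+         --path ${direct_model} \
+         --channel-model ${channel_model} \
+         --lm-model ${lm_model} \
+         --lm-data ${lm_data} \
+         --beam 5 \
+         --lenpen ${lenpen} --lm-wt ${lm_wt} --ch-wt ${ch_wt} \
+         --gen-subset valid --remove-bpe --batch-size 8 > ${out}
+       # score ${out} with the command from "Calculating BLEU" and keep the best setting
+     done
+   done
+ done
+ ```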
273
+ ### Language Model
274
+ The model is trained on de-duplicated English Newscrawl data from 2007-2018 comprising 186 million sentences or 4.5B words after normalization and tokenization.
275
+ | | Path |
276
+ |----|----|
277
+ | `--lm-model` | [transformer_en_lm](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/lm_model/transformer_lm.pt) |
278
+ | `--lm-data` | [lm_data](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/lm_model/lm_dict)
279
+
280
+ ## German-English Translation
281
+
282
+ ### BPE Codes and Dictionaries
283
+
284
+ | | Path|
285
+ |----------|------|
286
+ | Source BPE Code | [de_bpe_code_24K](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/de_bpe_code_24K) |
287
+ | Target BPE Code | [en_bpe_code_24K](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/en_bpe_code_24K)
288
+ | Source Dictionary | [de_dict](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/de_dict) |
289
+ | Target Dictionary | [en_dict](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/en_dict) |
290
+
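+
+ These can be plugged into the [Test Data Preprocessing](#test-data-preprocessing) script above as follows (a sketch; the local file names are arbitrary):
+
+ ```sh
+ wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/de_bpe_code_24K -O de_bpe_code_24K
+ wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/en_bpe_code_24K -O en_bpe_code_24K
+ wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/de_dict -O de_dict
+ wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/en_dict -O en_dict
+
+ src_bpe_code=de_bpe_code_24K
+ tgt_bpe_code=en_bpe_code_24K
+ src_dict=de_dict
+ tgt_dict=en_dict
+ ```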
291
+ ### Direct Models
292
+ We train on WMT’19 training data. Following [Ng et al., 2019](http://statmt.org/wmt19/pdf/53/WMT33.pdf), we apply language identification filtering and remove sentences longer than 250 tokens as well as sentence pairs with a source/target length ratio exceeding 1.5. This results in 26.8M sentence pairs.
293
+ We use the Transformer-Big architecture for the direct model.
294
+
295
+ | Seed | Model |
296
+ |:----:|----|
297
+ | 4 | [de_en_seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt)
298
+ | 5 | [de_en_seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed5.pt)
299
+ | 6 | [de_en_seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed6.pt)
300
+
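+
+ For reference, a direct model on its own can be decoded with standard beam search. A minimal sketch, assuming the test set has been binarized as in [Test Data Preprocessing](#test-data-preprocessing) and `de_en_seed4.pt` has been downloaded as above:
+
+ ```sh
+ fairseq-generate data_dir/binarized \
+   --path de_en_seed4.pt \
+   --beam 5 \
+   --lenpen 1.0 \
+   --gen-subset test \
+   --remove-bpe \
+   --batch-size 50 \
+   --fp16
+ ```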
301
+ ### Channel Models
302
+
303
+ We train on WMT’19 training data. Following [Ng et al., 2019](http://statmt.org/wmt19/pdf/53/WMT33.pdf), we apply language identification filtering and remove sentences longer than 250 tokens as well as sentence pairs with a source/target length ratio exceeding 1.5. This results in 26.8M sentence pairs.
304
+
305
+ | Model Size | Seed 4 | Seed 5 | Seed 6 |
306
+ |----|----|----|----|
307
+ | `big` | [big.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed4.pt) | [big.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed5.pt) | [big.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed6.pt) |
308
+ | `big_1_1` | [big_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big_1_1.seed4.pt) | [big_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big_1_1.seed5.pt) | [big_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big_1_1.seed6.pt) |
309
+ | `base` | [base.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base.seed4.pt) | [base.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base.seed5.pt) | [base.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base.seed6.pt) |
310
+ | `base_1_1` | [base_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed4.pt) | [base_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed5.pt) | [base_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed6.pt) |
311
+ | `half` | [half.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half.seed4.pt) | [half.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half.seed5.pt) | [half.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half.seed6.pt) |
312
+ | `half_1_1` | [half_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half_1_1.seed4.pt) | [half_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half_1_1.seed5.pt) | [half_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half_1_1.seed6.pt) |
313
+ | `quarter` | [quarter.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter.seed4.pt) | [quarter.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter.seed5.pt) | [quarter.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter.seed6.pt) |
314
+ | `quarter_1_1` | [quarter_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter_1_1.seed4.pt) | [quarter_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter_1_1.seed5.pt) | [quarter_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter_1_1.seed6.pt) |
315
+ | `8th` | [8th.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th.seed4.pt) | [8th.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th.seed5.pt) | [8th.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th.seed6.pt) |
316
+ | `8th_1_1` | [8th_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th_1_1.seed4.pt) | [8th_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th_1_1.seed5.pt) | [8th_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th_1_1.seed6.pt) |
317
+ | `16th` | [16th.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th.seed4.pt) | [16th.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th.seed5.pt) | [16th.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th.seed6.pt) |
318
+ | `16th_1_1` | [16th_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th_1_1.seed4.pt) | [16th_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th_1_1.seed5.pt) | [16th_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th_1_1.seed6.pt) |
319
+
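+
+ Any of these checkpoints can be used with the fast generation command above by pointing `--channel-model` at it. For example, a sketch using the smaller `8th_1_1` channel model (note that the tuned lenpen and model weights differ per channel model size, so the values from the example above may be suboptimal here):
+
+ ```sh
+ wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th_1_1.seed4.pt -O en_de.8th_1_1.seed4.pt
+ small_ch_model=en_de.8th_1_1.seed4.pt
+ # then re-run the fairseq-generate command from the
+ # "Fast Noisy Channel Generation for German-English translation with fairseq" section
+ ```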
320
+ ### Language Model
321
+ The model is trained on de-duplicated English Newscrawl data from 2007-2018 comprising 186 million sentences or 4.5B words after normalization and tokenization.
322
+ | | Path |
323
+ |----|----|
324
+ | `--lm-model` | [transformer_en_lm](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt) |
325
+ | `--lm-data` | [lm_data](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/)
326
+
327
+
328
+ ## Citation
329
+
330
+ ```bibtex
331
+ @inproceedings{bhosale2020language,
332
+ title={Language Models not just for Pre-training: Fast Online Neural Noisy Channel Modeling},
333
+ author={Shruti Bhosale and Kyra Yee and Sergey Edunov and Michael Auli},
334
+ booktitle={Proceedings of the Fifth Conference on Machine Translation (WMT)},
335
+ year={2020},
336
+ }
337
+
338
+ @inproceedings{yee2019simple,
339
+ title={Simple and Effective Noisy Channel Modeling for Neural Machine Translation},
340
+ author={Yee, Kyra and Dauphin, Yann and Auli, Michael},
341
+ booktitle={Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)},
342
+ pages={5700--5705},
343
+ year={2019}
344
+ }
345
+ ```
fairseq/examples/fast_noisy_channel/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from . import noisy_channel_translation # noqa
7
+ from . import noisy_channel_sequence_generator # noqa
8
+ from . import noisy_channel_beam_search # noqa
fairseq/examples/fast_noisy_channel/noisy_channel_beam_search.py ADDED
@@ -0,0 +1,71 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from fairseq.search import Search
8
+
9
+
10
+ class NoisyChannelBeamSearch(Search):
11
+
12
+ def __init__(self, tgt_dict):
13
+ super().__init__(tgt_dict)
14
+ self.fw_scores_buf = None
15
+ self.lm_scores_buf = None
16
+
17
+ def _init_buffers(self, t):
18
+ # super()._init_buffers(t)
19
+ if self.fw_scores_buf is None:
20
+ self.scores_buf = t.new()
21
+ self.indices_buf = torch.LongTensor().to(device=t.device)
22
+ self.beams_buf = torch.LongTensor().to(device=t.device)
23
+ self.fw_scores_buf = t.new()
24
+ self.lm_scores_buf = t.new()
25
+
26
+ def combine_fw_bw(self, combine_method, fw_cum, bw, step):
27
+ if combine_method == "noisy_channel":
28
+ fw_norm = fw_cum.div(step + 1)
29
+ lprobs = bw + fw_norm
30
+ elif combine_method == "lm_only":
31
+ lprobs = bw + fw_cum
32
+
33
+ return lprobs
34
+
35
+ def step(self, step, fw_lprobs, scores, bw_lprobs, lm_lprobs, combine_method):
36
+ self._init_buffers(fw_lprobs)
37
+ bsz, beam_size, vocab_size = fw_lprobs.size()
38
+
39
+ if step == 0:
40
+ # at the first step all hypotheses are equally likely, so use
41
+ # only the first beam
42
+ fw_lprobs = fw_lprobs[:, ::beam_size, :].contiguous()
43
+ bw_lprobs = bw_lprobs[:, ::beam_size, :].contiguous()
44
+ # nothing to add since we are at the first step
45
+ fw_lprobs_cum = fw_lprobs
46
+
47
+ else:
48
+ # make probs contain cumulative scores for each hypothesis
49
+ raw_scores = (scores[:, :, step - 1].unsqueeze(-1))
50
+ fw_lprobs_cum = (fw_lprobs.add(raw_scores))
51
+
52
+ combined_lprobs = self.combine_fw_bw(combine_method, fw_lprobs_cum, bw_lprobs, step)
53
+
54
+ # choose the top k according to the combined noisy channel model score
55
+ torch.topk(
56
+ combined_lprobs.view(bsz, -1),
57
+ k=min(
58
+ # Take the best 2 x beam_size predictions. We'll choose the first
59
+ # beam_size of these which don't predict eos to continue with.
60
+ beam_size * 2,
61
+ combined_lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad
62
+ ),
63
+ out=(self.scores_buf, self.indices_buf),
64
+ )
65
+ # save corresponding fw and lm scores
66
+ self.fw_scores_buf = torch.gather(fw_lprobs_cum.view(bsz, -1), 1, self.indices_buf)
67
+ self.lm_scores_buf = torch.gather(lm_lprobs.view(bsz, -1), 1, self.indices_buf)
68
+ # Project back into relative indices and beams
69
+ self.beams_buf = self.indices_buf // vocab_size
70
+ self.indices_buf.fmod_(vocab_size)
71
+ return self.scores_buf, self.fw_scores_buf, self.lm_scores_buf, self.indices_buf, self.beams_buf
fairseq/examples/fast_noisy_channel/noisy_channel_sequence_generator.py ADDED
@@ -0,0 +1,842 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Dict, List, Optional
7
+
8
+ import math
9
+ import numpy as np
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import Tensor
14
+
15
+ from .noisy_channel_beam_search import NoisyChannelBeamSearch
16
+ from fairseq.sequence_generator import EnsembleModel
17
+
18
+
19
+ class NoisyChannelSequenceGenerator(object):
20
+ def __init__(
21
+ self,
22
+ combine_method,
23
+ tgt_dict,
24
+ src_dict=None,
25
+ beam_size=1,
26
+ max_len_a=0,
27
+ max_len_b=200,
28
+ min_len=1,
29
+ len_penalty=1.0,
30
+ unk_penalty=0.0,
31
+ retain_dropout=False,
32
+ temperature=1.0,
33
+ match_source_len=False,
34
+ no_repeat_ngram_size=0,
35
+ normalize_scores=True,
36
+ channel_models=None,
37
+ k2=10,
38
+ ch_weight=1.0,
39
+ channel_scoring_type='log_norm',
40
+ top_k_vocab=0,
41
+ lm_models=None,
42
+ lm_dict=None,
43
+ lm_weight=1.0,
44
+ normalize_lm_scores_by_tgt_len=False,
45
+ ):
46
+ """Generates translations of a given source sentence,
47
+ using beam search with noisy channel decoding.
48
+
49
+ Args:
50
+ combine_method (string, optional): Method to combine direct, LM and
51
+ channel model scores (default: None)
52
+ tgt_dict (~fairseq.data.Dictionary): target dictionary
53
+ src_dict (~fairseq.data.Dictionary): source dictionary
54
+ beam_size (int, optional): beam width (default: 1)
55
+ max_len_a/b (int, optional): generate sequences of maximum length
56
+ ax + b, where x is the source length
57
+ min_len (int, optional): the minimum length of the generated output
58
+ (not including end-of-sentence)
59
+ len_penalty (float, optional): length penalty, where <1.0 favors
60
+ shorter, >1.0 favors longer sentences (default: 1.0)
61
+ unk_penalty (float, optional): unknown word penalty, where <0
62
+ produces more unks, >0 produces fewer (default: 0.0)
63
+ retain_dropout (bool, optional): use dropout when generating
64
+ (default: False)
65
+ temperature (float, optional): temperature, where values
66
+ >1.0 produce more uniform samples and values <1.0 produce
67
+ sharper samples (default: 1.0)
68
+ match_source_len (bool, optional): outputs should match the source
69
+ length (default: False)
70
+ no_repeat_ngram_size (int, optional): Size of n-grams that we avoid
71
+ repeating in the generation (default: 0)
72
+ normalize_scores (bool, optional): normalize scores by the length
73
+ of the output (default: True)
74
+ channel_models (List[~fairseq.models.FairseqModel]): ensemble of models
75
+ translating from the target to the source
76
+ k2 (int, optional): Top K2 candidates to score per beam at each step (default:10)
77
+ ch_weight (int, optional): Weight associated with the channel model score
78
+ assuming that the direct model score has weight 1.0 (default: 1.0)
79
+ channel_scoring_type (str, optional): String specifying how to score
80
+ the channel model (default: 'log_norm')
81
+ top_k_vocab (int, optional): If `channel_scoring_type` is `'src_vocab'` or
82
+ `'src_vocab_batched'`, then this parameter specifies the number of
83
+ most frequent tokens to include in the channel model output vocabulary,
84
+ in addition to the source tokens in the input batch (default: 0)
85
+ lm_models (List[~fairseq.models.FairseqModel]): ensemble of models
86
+ generating text in the target language
87
+ lm_dict (~fairseq.data.Dictionary): LM Model dictionary
88
+ lm_weight (int, optional): Weight associated with the LM model score
89
+ assuming that the direct model score has weight 1.0 (default: 1.0)
90
+ normalize_lm_scores_by_tgt_len (bool, optional): Should we normalize LM scores
91
+ by the target length? By default, we normalize the combination of
92
+ LM and channel model scores by the source length
93
+ """
94
+ self.pad = tgt_dict.pad()
95
+ self.unk = tgt_dict.unk()
96
+ self.eos = tgt_dict.eos()
97
+ self.vocab_size = len(tgt_dict)
98
+ self.beam_size = beam_size
99
+ # the max beam size is the dictionary size - 1, since we never select pad
100
+ self.beam_size = min(beam_size, self.vocab_size - 1)
101
+ self.max_len_a = max_len_a
102
+ self.max_len_b = max_len_b
103
+ self.min_len = min_len
104
+ self.normalize_scores = normalize_scores
105
+ self.len_penalty = len_penalty
106
+ self.unk_penalty = unk_penalty
107
+ self.retain_dropout = retain_dropout
108
+ self.temperature = temperature
109
+ self.match_source_len = match_source_len
110
+ self.no_repeat_ngram_size = no_repeat_ngram_size
111
+ self.channel_models = channel_models
112
+ self.src_dict = src_dict
113
+ self.tgt_dict = tgt_dict
114
+ self.combine_method = combine_method
115
+ self.k2 = k2
116
+ self.ch_weight = ch_weight
117
+ self.channel_scoring_type = channel_scoring_type
118
+ self.top_k_vocab = top_k_vocab
119
+ self.lm_models = lm_models
120
+ self.lm_dict = lm_dict
121
+ self.lm_weight = lm_weight
122
+ self.log_softmax_fn = torch.nn.LogSoftmax(dim=1)
123
+ self.normalize_lm_scores_by_tgt_len = normalize_lm_scores_by_tgt_len
124
+
125
+ self.share_tgt_dict = (self.lm_dict == self.tgt_dict)
126
+ self.tgt_to_lm = make_dict2dict(tgt_dict, lm_dict)
127
+
128
+ self.ch_scoring_bsz = 3072
129
+
130
+ assert temperature > 0, '--temperature must be greater than 0'
131
+
132
+ self.search = NoisyChannelBeamSearch(tgt_dict)
133
+
134
+ @torch.no_grad()
135
+ def generate(
136
+ self,
137
+ models,
138
+ sample,
139
+ prefix_tokens=None,
140
+ bos_token=None,
141
+ **kwargs
142
+ ):
143
+ """Generate a batch of translations.
144
+ Args:
145
+ models (List[~fairseq.models.FairseqModel]): ensemble of models
146
+ sample (dict): batch
147
+ prefix_tokens (torch.LongTensor, optional): force decoder to begin
148
+ with these tokens
149
+ """
150
+ model = EnsembleModel(models)
151
+ incremental_states = torch.jit.annotate(
152
+ List[Dict[str, Dict[str, Optional[Tensor]]]],
153
+ [
154
+ torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
155
+ for i in range(model.models_size)
156
+ ],
157
+ )
158
+ if not self.retain_dropout:
159
+ model.eval()
160
+
161
+ # model.forward normally channels prev_output_tokens into the decoder
162
+ # separately, but SequenceGenerator directly calls model.encoder
163
+ encoder_input = {
164
+ k: v for k, v in sample['net_input'].items()
165
+ if k != 'prev_output_tokens'
166
+ }
167
+ src_tokens = encoder_input['src_tokens']
168
+ src_lengths_no_eos = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
169
+ input_size = src_tokens.size()
170
+ # batch dimension goes first followed by source lengths
171
+ bsz = input_size[0]
172
+ src_len = input_size[1]
173
+ beam_size = self.beam_size
174
+
175
+ if self.match_source_len:
176
+ max_len = src_lengths_no_eos.max().item()
177
+ else:
178
+ max_len = min(
179
+ int(self.max_len_a * src_len + self.max_len_b),
180
+ # exclude the EOS marker
181
+ model.max_decoder_positions() - 1,
182
+ )
183
+
184
+ # compute the encoder output for each beam
185
+ encoder_outs = model.forward_encoder(encoder_input)
186
+ new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
187
+ new_order = new_order.to(src_tokens.device).long()
188
+ encoder_outs = model.reorder_encoder_out(encoder_outs, new_order)
189
+
190
+ src_lengths = encoder_input['src_lengths']
191
+ # initialize buffers
192
+ scores = src_tokens.new(bsz * beam_size, max_len + 1).float().fill_(0)
193
+ lm_prefix_scores = src_tokens.new(bsz * beam_size).float().fill_(0)
194
+
195
+ scores_buf = scores.clone()
196
+ tokens = src_tokens.new(bsz * beam_size, max_len + 2).long().fill_(self.pad)
197
+ tokens_buf = tokens.clone()
198
+ tokens[:, 0] = self.eos if bos_token is None else bos_token
199
+
200
+ # reorder source tokens so they may be used as a reference in generating P(S|T)
201
+ src_tokens = reorder_all_tokens(src_tokens, src_lengths, self.src_dict.eos_index)
202
+
203
+ src_tokens = src_tokens.repeat(1, beam_size).view(-1, src_len)
204
+ src_lengths = src_lengths.view(bsz, -1).repeat(1, beam_size).view(bsz*beam_size, -1)
205
+
206
+ attn, attn_buf = None, None
207
+ nonpad_idxs = None
208
+
209
+ # The cands_to_ignore indicates candidates that should be ignored.
210
+ # For example, suppose we're sampling and have already finalized 2/5
211
+ # samples. Then the cands_to_ignore would mark 2 positions as being ignored,
212
+ # so that we only finalize the remaining 3 samples.
213
+ cands_to_ignore = src_tokens.new_zeros(bsz, beam_size).eq(-1) # forward and backward-compatible False mask
214
+
215
+ # list of completed sentences
216
+ finalized = [[] for i in range(bsz)]
217
+ finished = [False for i in range(bsz)]
218
+ num_remaining_sent = bsz
219
+
220
+ # number of candidate hypos per step
221
+ cand_size = 2 * beam_size # 2 x beam size in case half are EOS
222
+
223
+ # offset arrays for converting between different indexing schemes
224
+ bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
225
+ cand_offsets = torch.arange(0, cand_size).type_as(tokens)
226
+
227
+ # helper function for allocating buffers on the fly
228
+ buffers = {}
229
+
230
+ def buffer(name, type_of=tokens): # noqa
231
+ if name not in buffers:
232
+ buffers[name] = type_of.new()
233
+ return buffers[name]
234
+
235
+ def is_finished(sent, step, unfin_idx):
236
+ """
237
+ Check whether we've finished generation for a given sentence, by
238
+ comparing the worst score among finalized hypotheses to the best
239
+ possible score among unfinalized hypotheses.
240
+ """
241
+ assert len(finalized[sent]) <= beam_size
242
+ if len(finalized[sent]) == beam_size:
243
+ return True
244
+ return False
245
+
246
+ def finalize_hypos(step, bbsz_idx, eos_scores, combined_noisy_channel_eos_scores):
247
+ """
248
+ Finalize the given hypotheses at this step, while keeping the total
249
+ number of finalized hypotheses per sentence <= beam_size.
250
+
251
+ Note: the input must be in the desired finalization order, so that
252
+ hypotheses that appear earlier in the input are preferred to those
253
+ that appear later.
254
+
255
+ Args:
256
+ step: current time step
257
+ bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
258
+ indicating which hypotheses to finalize
259
+ eos_scores: A vector of the same size as bbsz_idx containing
260
+ fw scores for each hypothesis
261
+ combined_noisy_channel_eos_scores: A vector of the same size as bbsz_idx containing
262
+ combined noisy channel scores for each hypothesis
263
+ """
264
+ assert bbsz_idx.numel() == eos_scores.numel()
265
+
266
+ # clone relevant token and attention tensors
267
+ tokens_clone = tokens.index_select(0, bbsz_idx)
268
+ tokens_clone = tokens_clone[:, 1:step + 2] # skip the first index, which is EOS
269
+ assert not tokens_clone.eq(self.eos).any()
270
+ tokens_clone[:, step] = self.eos
271
+ attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2] if attn is not None else None
272
+
273
+ # compute scores per token position
274
+ pos_scores = scores.index_select(0, bbsz_idx)[:, :step+1]
275
+ pos_scores[:, step] = eos_scores
276
+ # convert from cumulative to per-position scores
277
+ pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
278
+
279
+ # normalize sentence-level scores
280
+ if self.normalize_scores:
281
+ combined_noisy_channel_eos_scores /= (step + 1) ** self.len_penalty
282
+
283
+ cum_unfin = []
284
+ prev = 0
285
+ for f in finished:
286
+ if f:
287
+ prev += 1
288
+ else:
289
+ cum_unfin.append(prev)
290
+
291
+ sents_seen = set()
292
+ for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), combined_noisy_channel_eos_scores.tolist())):
293
+ unfin_idx = idx // beam_size
294
+ sent = unfin_idx + cum_unfin[unfin_idx]
295
+
296
+ sents_seen.add((sent, unfin_idx))
297
+
298
+ if self.match_source_len and step > src_lengths_no_eos[unfin_idx]:
299
+ score = -math.inf
300
+
301
+ def get_hypo():
302
+
303
+ if attn_clone is not None:
304
+ # remove padding tokens from attn scores
305
+ hypo_attn = attn_clone[i][nonpad_idxs[sent]]
306
+ _, alignment = hypo_attn.max(dim=0)
307
+ else:
308
+ hypo_attn = None
309
+ alignment = None
310
+
311
+ return {
312
+ 'tokens': tokens_clone[i],
313
+ 'score': score,
314
+ 'attention': hypo_attn, # src_len x tgt_len
315
+ 'alignment': alignment,
316
+ 'positional_scores': pos_scores[i],
317
+ }
318
+
319
+ if len(finalized[sent]) < beam_size:
320
+ finalized[sent].append(get_hypo())
321
+
322
+ newly_finished = []
323
+ for sent, unfin_idx in sents_seen:
324
+ # check termination conditions for this sentence
325
+ if not finished[sent] and is_finished(sent, step, unfin_idx):
326
+ finished[sent] = True
327
+ newly_finished.append(unfin_idx)
328
+ return newly_finished
329
+
330
+ def noisy_channel_rescoring(lprobs, beam_size, bsz, src_tokens, tokens, k):
331
+ """Rescore the top k hypothesis from each beam using noisy channel modeling
332
+ Returns:
333
+ new_fw_lprobs: the direct model probabilities after pruning the top k
334
+ new_ch_lm_lprobs: the combined channel and language model probabilities
335
+ new_lm_lprobs: the language model probabilities after pruning the top k
336
+ """
337
+ with torch.no_grad():
338
+ lprobs_size = lprobs.size()
339
+ if prefix_tokens is not None and step < prefix_tokens.size(1):
340
+ probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :]
341
+ cand_scores = torch.gather(
342
+ probs_slice, dim=1,
343
+ index=prefix_tokens[:, step].view(-1, 1).data
344
+ ).expand(-1, beam_size).contiguous().view(bsz*beam_size, 1)
345
+ cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, beam_size).data.contiguous().view(bsz*beam_size, 1)
346
+
347
+ # need to calculate and save fw and lm probs for prefix tokens
348
+ fw_top_k = cand_scores
349
+ fw_top_k_idx = cand_indices
350
+ k = 1
351
+ else:
352
+ # take the top k best words for every sentence in batch*beam
353
+ fw_top_k, fw_top_k_idx = torch.topk(lprobs.view(beam_size*bsz, -1), k=k)
354
+ eos_idx = torch.nonzero(fw_top_k_idx.view(bsz*beam_size*k, -1) == self.eos)[:, 0]
355
+ ch_scores = fw_top_k.new_full((beam_size*bsz*k, ), 0)
356
+ src_size = torch.sum(src_tokens[:, :] != self.src_dict.pad_index, dim=1, keepdim=True, dtype=fw_top_k.dtype)
357
+
358
+ if self.combine_method != "lm_only":
359
+ temp_src_tokens_full = src_tokens[:, :].repeat(1, k).view(bsz*beam_size*k, -1)
360
+ not_padding = temp_src_tokens_full[:, 1:] != self.src_dict.pad_index
361
+ cur_tgt_size = step+2
362
+
363
+ # add eos to all candidate sentences except those that already end in eos
364
+ eos_tokens = tokens[:, 0].repeat(1, k).view(-1, 1)
365
+ eos_tokens[eos_idx] = self.tgt_dict.pad_index
366
+
367
+ if step == 0:
368
+ channel_input = torch.cat((fw_top_k_idx.view(-1, 1), eos_tokens), 1)
369
+ else:
370
+ # move eos from beginning to end of target sentence
371
+ channel_input = torch.cat((tokens[:, 1:step + 1].repeat(1, k).view(-1, step), fw_top_k_idx.view(-1, 1), eos_tokens), 1)
372
+
373
+ ch_input_lengths = torch.tensor(np.full(channel_input.size(0), cur_tgt_size))
374
+ ch_input_lengths[eos_idx] = cur_tgt_size-1
375
+ if self.channel_scoring_type == "unnormalized":
376
+ ch_encoder_output = channel_model.encoder(channel_input, src_lengths=ch_input_lengths)
377
+ ch_decoder_output, _ = channel_model.decoder(temp_src_tokens_full, encoder_out=ch_encoder_output, features_only=True)
378
+ del ch_encoder_output
379
+ ch_intermed_scores = channel_model.decoder.unnormalized_scores_given_target(ch_decoder_output, target_ids=temp_src_tokens_full[:, 1:])
380
+ ch_intermed_scores = ch_intermed_scores.float()
381
+ ch_intermed_scores *= not_padding.float()
382
+ ch_scores = torch.sum(ch_intermed_scores, dim=1)
383
+ elif self.channel_scoring_type == "k2_separate":
384
+ for k_idx in range(k):
385
+ k_eos_tokens = eos_tokens[k_idx::k, :]
386
+ if step == 0:
387
+ k_ch_input = torch.cat((fw_top_k_idx[:, k_idx:k_idx+1], k_eos_tokens), 1)
388
+ else:
389
+ # move eos from beginning to end of target sentence
390
+ k_ch_input = torch.cat((tokens[:, 1:step + 1], fw_top_k_idx[:, k_idx:k_idx+1], k_eos_tokens), 1)
391
+ k_ch_input_lengths = ch_input_lengths[k_idx::k]
392
+ k_ch_output = channel_model(k_ch_input, k_ch_input_lengths, src_tokens)
393
+ k_ch_lprobs = channel_model.get_normalized_probs(k_ch_output, log_probs=True)
394
+ k_ch_intermed_scores = torch.gather(k_ch_lprobs[:, :-1, :], 2, src_tokens[:, 1:].unsqueeze(2)).squeeze(2)
395
+ k_ch_intermed_scores *= not_padding.float()
396
+ ch_scores[k_idx::k] = torch.sum(k_ch_intermed_scores, dim=1)
397
+ elif self.channel_scoring_type == "src_vocab":
398
+ ch_encoder_output = channel_model.encoder(channel_input, src_lengths=ch_input_lengths)
399
+ ch_decoder_output, _ = channel_model.decoder(temp_src_tokens_full, encoder_out=ch_encoder_output, features_only=True)
400
+
401
+ del ch_encoder_output
402
+ ch_lprobs = normalized_scores_with_batch_vocab(
403
+ channel_model.decoder,
404
+ ch_decoder_output, src_tokens, k, bsz, beam_size,
405
+ self.src_dict.pad_index, top_k=self.top_k_vocab)
406
+ ch_scores = torch.sum(ch_lprobs, dim=1)
407
+ elif self.channel_scoring_type == "src_vocab_batched":
408
+ ch_bsz_size = temp_src_tokens_full.shape[0]
409
+ ch_lprobs_list = [None] * len(range(0, ch_bsz_size, self.ch_scoring_bsz))
410
+ for i, start_idx in enumerate(range(0, ch_bsz_size, self.ch_scoring_bsz)):
411
+ end_idx = min(start_idx + self.ch_scoring_bsz, ch_bsz_size)
412
+ temp_src_tokens_full_batch = temp_src_tokens_full[start_idx:end_idx, :]
413
+ channel_input_batch = channel_input[start_idx:end_idx, :]
414
+ ch_input_lengths_batch = ch_input_lengths[start_idx:end_idx]
415
+ ch_encoder_output_batch = channel_model.encoder(channel_input_batch, src_lengths=ch_input_lengths_batch)
416
+ ch_decoder_output_batch, _ = channel_model.decoder(temp_src_tokens_full_batch, encoder_out=ch_encoder_output_batch, features_only=True)
417
+ ch_lprobs_list[i] = normalized_scores_with_batch_vocab(
418
+ channel_model.decoder,
419
+ ch_decoder_output_batch, src_tokens, k, bsz, beam_size,
420
+ self.src_dict.pad_index, top_k=self.top_k_vocab,
421
+ start_idx=start_idx, end_idx=end_idx)
422
+ ch_lprobs = torch.cat(ch_lprobs_list, dim=0)
423
+ ch_scores = torch.sum(ch_lprobs, dim=1)
424
+ else:
425
+ ch_output = channel_model(channel_input, ch_input_lengths, temp_src_tokens_full)
426
+ ch_lprobs = channel_model.get_normalized_probs(ch_output, log_probs=True)
427
+ ch_intermed_scores = torch.gather(ch_lprobs[:, :-1, :], 2, temp_src_tokens_full[:, 1:].unsqueeze(2)).squeeze().view(bsz*beam_size*k, -1)
428
+ ch_intermed_scores *= not_padding.float()
429
+ ch_scores = torch.sum(ch_intermed_scores, dim=1)
430
+
431
+ else:
432
+ cur_tgt_size = 0
433
+ ch_scores = ch_scores.view(bsz*beam_size, k)
434
+ expanded_lm_prefix_scores = lm_prefix_scores.unsqueeze(1).expand(-1, k).flatten()
435
+
436
+ if self.share_tgt_dict:
437
+ lm_scores = get_lm_scores(lm, tokens[:, :step + 1].view(-1, step+1), lm_incremental_states, fw_top_k_idx.view(-1, 1), torch.tensor(np.full(tokens.size(0), step+1)), k)
438
+ else:
439
+ new_lm_input = dict2dict(tokens[:, :step + 1].view(-1, step+1), self.tgt_to_lm)
440
+ new_cands = dict2dict(fw_top_k_idx.view(-1, 1), self.tgt_to_lm)
441
+ lm_scores = get_lm_scores(lm, new_lm_input, lm_incremental_states, new_cands, torch.tensor(np.full(tokens.size(0), step+1)), k)
442
+
443
+ lm_scores.add_(expanded_lm_prefix_scores)
444
+ ch_lm_scores = combine_ch_lm(self.combine_method, ch_scores, lm_scores, src_size, cur_tgt_size)
445
+ # initialize all as min value
446
+ new_fw_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view(bsz*beam_size, -1)
447
+ new_ch_lm_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view(bsz*beam_size, -1)
448
+ new_lm_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view(bsz*beam_size, -1)
449
+ new_fw_lprobs[:, self.pad] = -math.inf
450
+ new_ch_lm_lprobs[:, self.pad] = -math.inf
451
+ new_lm_lprobs[:, self.pad] = -math.inf
452
+
453
+ new_fw_lprobs.scatter_(1, fw_top_k_idx, fw_top_k)
454
+ new_ch_lm_lprobs.scatter_(1, fw_top_k_idx, ch_lm_scores)
455
+ new_lm_lprobs.scatter_(1, fw_top_k_idx, lm_scores.view(-1, k))
456
+ return new_fw_lprobs, new_ch_lm_lprobs, new_lm_lprobs
457
+
458
+ def combine_ch_lm(combine_type, ch_scores, lm_scores1, src_size, tgt_size):
459
+ if self.channel_scoring_type == "unnormalized":
460
+ ch_scores = self.log_softmax_fn(
461
+ ch_scores.view(-1, self.beam_size * self.k2)
462
+ ).view(ch_scores.shape)
463
+ ch_scores = ch_scores * self.ch_weight
464
+ lm_scores1 = lm_scores1 * self.lm_weight
465
+
466
+ if combine_type == "lm_only":
467
+ # log P(T|S) + log P(T)
468
+ ch_scores = lm_scores1.view(ch_scores.size())
469
+ elif combine_type == "noisy_channel":
470
+ # 1/t log P(T|S) + 1/s log P(S|T) + 1/t log P(T)
471
+ if self.normalize_lm_scores_by_tgt_len:
472
+ ch_scores.div_(src_size)
473
+ lm_scores_norm = lm_scores1.view(ch_scores.size()).div(tgt_size)
474
+ ch_scores.add_(lm_scores_norm)
475
+ # 1/t log P(T|S) + 1/s log P(S|T) + 1/s log P(T)
476
+ else:
477
+ ch_scores.add_(lm_scores1.view(ch_scores.size()))
478
+ ch_scores.div_(src_size)
479
+
480
+ return ch_scores
481
+
482
+ if self.channel_models is not None:
483
+ channel_model = self.channel_models[0] # assume only one channel_model model
484
+ else:
485
+ channel_model = None
486
+
487
+ lm = EnsembleModel(self.lm_models)
488
+ lm_incremental_states = torch.jit.annotate(
489
+ List[Dict[str, Dict[str, Optional[Tensor]]]],
490
+ [
491
+ torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
492
+ for i in range(lm.models_size)
493
+ ],
494
+ )
495
+
496
+ reorder_state = None
497
+ batch_idxs = None
498
+ for step in range(max_len + 1): # one extra step for EOS marker
499
+ # reorder decoder internal states based on the prev choice of beams
500
+ if reorder_state is not None:
501
+ if batch_idxs is not None:
502
+ # update beam indices to take into account removed sentences
503
+ corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
504
+ reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
505
+ model.reorder_incremental_state(incremental_states, reorder_state)
506
+ encoder_outs = model.reorder_encoder_out(encoder_outs, reorder_state)
507
+
508
+ lm.reorder_incremental_state(lm_incremental_states, reorder_state)
509
+
510
+ fw_lprobs, avg_attn_scores = model.forward_decoder(
511
+ tokens[:, :step + 1], encoder_outs, incremental_states, temperature=self.temperature,
512
+ )
513
+
514
+ fw_lprobs[:, self.pad] = -math.inf # never select pad
515
+ fw_lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
516
+ fw_lprobs, ch_lm_lprobs, lm_lprobs = noisy_channel_rescoring(fw_lprobs, beam_size, bsz, src_tokens, tokens, self.k2)
517
+
518
+ # handle min and max length constraints
519
+ if step >= max_len:
520
+ fw_lprobs[:, :self.eos] = -math.inf
521
+ fw_lprobs[:, self.eos + 1:] = -math.inf
522
+ elif step < self.min_len:
523
+ fw_lprobs[:, self.eos] = -math.inf
524
+
525
+ # handle prefix tokens (possibly with different lengths)
526
+ if prefix_tokens is not None and step < prefix_tokens.size(1):
527
+ prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
528
+ prefix_mask = prefix_toks.ne(self.pad)
529
+
530
+ prefix_fw_lprobs = fw_lprobs.gather(-1, prefix_toks.unsqueeze(-1))
531
+ fw_lprobs[prefix_mask] = -math.inf
532
+ fw_lprobs[prefix_mask] = fw_lprobs[prefix_mask].scatter_(
533
+ -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_fw_lprobs
534
+ )
535
+
536
+ prefix_ch_lm_lprobs = ch_lm_lprobs.gather(-1, prefix_toks.unsqueeze(-1))
537
+ ch_lm_lprobs[prefix_mask] = -math.inf
538
+ ch_lm_lprobs[prefix_mask] = ch_lm_lprobs[prefix_mask].scatter_(
539
+ -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_ch_lm_lprobs
540
+ )
541
+
542
+ prefix_lm_lprobs = lm_lprobs.gather(-1, prefix_toks.unsqueeze(-1))
543
+ lm_lprobs[prefix_mask] = -math.inf
544
+ lm_lprobs[prefix_mask] = lm_lprobs[prefix_mask].scatter_(
545
+ -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lm_lprobs
546
+ )
547
+
548
+ # if prefix includes eos, then we should make sure tokens and
549
+ # scores are the same across all beams
550
+ eos_mask = prefix_toks.eq(self.eos)
551
+ if eos_mask.any():
552
+ # validate that the first beam matches the prefix
553
+ first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[:, 0, 1:step + 1]
554
+ eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
555
+ target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
556
+ assert (first_beam == target_prefix).all()
557
+
558
+ def replicate_first_beam(tensor, mask):
559
+ tensor = tensor.view(-1, beam_size, tensor.size(-1))
560
+ tensor[mask] = tensor[mask][:, :1, :]
561
+ return tensor.view(-1, tensor.size(-1))
562
+
563
+ # copy tokens, scores and lprobs from the first beam to all beams
564
+ tokens = replicate_first_beam(tokens, eos_mask_batch_dim)
565
+ scores = replicate_first_beam(scores, eos_mask_batch_dim)
566
+
567
+ fw_lprobs = replicate_first_beam(fw_lprobs, eos_mask_batch_dim)
568
+ ch_lm_lprobs = replicate_first_beam(ch_lm_lprobs, eos_mask_batch_dim)
569
+ lm_lprobs = replicate_first_beam(lm_lprobs, eos_mask_batch_dim)
570
+
571
+ if self.no_repeat_ngram_size > 0:
572
+ # for each beam and batch sentence, generate a list of previous ngrams
573
+ gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)]
574
+ for bbsz_idx in range(bsz * beam_size):
575
+ gen_tokens = tokens[bbsz_idx].tolist()
576
+ for ngram in zip(*[gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]):
577
+ gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \
578
+ gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]]
579
+
580
+ # Record attention scores
581
+ if avg_attn_scores is not None:
582
+ if attn is None:
583
+ attn = scores.new(bsz * beam_size, src_tokens.size(1), max_len + 2)
584
+ attn_buf = attn.clone()
585
+ nonpad_idxs = src_tokens.ne(self.pad)
586
+ attn[:, :, step + 1].copy_(avg_attn_scores)
587
+
588
+ scores = scores.type_as(fw_lprobs)
589
+ scores_buf = scores_buf.type_as(fw_lprobs)
590
+
591
+ self.search.set_src_lengths(src_lengths_no_eos)
592
+
593
+ if self.no_repeat_ngram_size > 0:
594
+ def calculate_banned_tokens(bbsz_idx):
595
+ # before decoding the next token, prevent decoding of ngrams that have already appeared
596
+ ngram_index = tuple(tokens[bbsz_idx, step + 2 - self.no_repeat_ngram_size:step + 1].tolist())
597
+ return gen_ngrams[bbsz_idx].get(ngram_index, [])
598
+
599
+ if step + 2 - self.no_repeat_ngram_size >= 0:
600
+ # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
601
+ banned_tokens = [calculate_banned_tokens(bbsz_idx) for bbsz_idx in range(bsz * beam_size)]
602
+ else:
603
+ banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)]
604
+
605
+ for bbsz_idx in range(bsz * beam_size):
606
+ fw_lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf
607
+
608
+ combined_noisy_channel_scores, fw_lprobs_top_k, lm_lprobs_top_k, cand_indices, cand_beams = self.search.step(
609
+ step,
610
+ fw_lprobs.view(bsz, -1, self.vocab_size),
611
+ scores.view(bsz, beam_size, -1)[:, :, :step], ch_lm_lprobs.view(bsz, -1, self.vocab_size),
612
+ lm_lprobs.view(bsz, -1, self.vocab_size), self.combine_method
613
+ )
614
+
615
+ # cand_bbsz_idx contains beam indices for the top candidate
616
+ # hypotheses, with a range of values: [0, bsz*beam_size),
617
+ # and dimensions: [bsz, cand_size]
618
+ cand_bbsz_idx = cand_beams.add(bbsz_offsets)
619
+
620
+ # finalize hypotheses that end in eos (except for candidates to be ignored)
621
+ eos_mask = cand_indices.eq(self.eos)
622
+ eos_mask[:, :beam_size] &= ~cands_to_ignore
623
+
624
+ # only consider eos when it's among the top beam_size indices
625
+ eos_bbsz_idx = torch.masked_select(
626
+ cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
627
+ )
628
+
629
+ finalized_sents = set()
630
+ if eos_bbsz_idx.numel() > 0:
631
+ eos_scores = torch.masked_select(
632
+ fw_lprobs_top_k[:, :beam_size], mask=eos_mask[:, :beam_size]
633
+ )
634
+ combined_noisy_channel_eos_scores = torch.masked_select(
635
+ combined_noisy_channel_scores[:, :beam_size],
636
+ mask=eos_mask[:, :beam_size],
637
+ )
638
+
639
+ # finalize hypo using channel model score
640
+ finalized_sents = finalize_hypos(
641
+ step, eos_bbsz_idx, eos_scores, combined_noisy_channel_eos_scores)
642
+
643
+ num_remaining_sent -= len(finalized_sents)
644
+
645
+ assert num_remaining_sent >= 0
646
+ if num_remaining_sent == 0:
647
+ break
648
+
649
+ if len(finalized_sents) > 0:
650
+ new_bsz = bsz - len(finalized_sents)
651
+
652
+ # construct batch_idxs which holds indices of batches to keep for the next pass
653
+ batch_mask = cand_indices.new_ones(bsz)
654
+ batch_mask[cand_indices.new(finalized_sents)] = 0
655
+ batch_idxs = torch.nonzero(batch_mask).squeeze(-1)
656
+
657
+ eos_mask = eos_mask[batch_idxs]
658
+ cand_beams = cand_beams[batch_idxs]
659
+ bbsz_offsets.resize_(new_bsz, 1)
660
+ cand_bbsz_idx = cand_beams.add(bbsz_offsets)
661
+
662
+ lm_lprobs_top_k = lm_lprobs_top_k[batch_idxs]
663
+
664
+ fw_lprobs_top_k = fw_lprobs_top_k[batch_idxs]
665
+ cand_indices = cand_indices[batch_idxs]
666
+ if prefix_tokens is not None:
667
+ prefix_tokens = prefix_tokens[batch_idxs]
668
+ src_lengths_no_eos = src_lengths_no_eos[batch_idxs]
669
+ cands_to_ignore = cands_to_ignore[batch_idxs]
670
+
671
+ scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
672
+ scores_buf.resize_as_(scores)
673
+ tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
674
+ tokens_buf.resize_as_(tokens)
675
+ src_tokens = src_tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
676
+ src_lengths = src_lengths.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
677
+ lm_prefix_scores = lm_prefix_scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1).squeeze()
678
+
679
+ if attn is not None:
680
+ attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
681
+ attn_buf.resize_as_(attn)
682
+ bsz = new_bsz
683
+ else:
684
+ batch_idxs = None
685
+
686
+ # Set active_mask so that values > cand_size indicate eos or
687
+ # ignored hypos and values < cand_size indicate candidate
688
+ # active hypos. After this, the min values per row are the top
689
+ # candidate active hypos.
690
+ eos_mask[:, :beam_size] |= cands_to_ignore
691
+ active_mask = torch.add(
692
+ eos_mask.type_as(cand_offsets) * cand_size,
693
+ cand_offsets[: eos_mask.size(1)],
694
+ )
695
+
696
+ # get the top beam_size active hypotheses, which are just the hypos
697
+ # with the smallest values in active_mask
698
+ active_hypos, new_cands_to_ignore = buffer('active_hypos'), buffer('new_cands_to_ignore')
699
+ torch.topk(
700
+ active_mask, k=beam_size, dim=1, largest=False,
701
+ out=(new_cands_to_ignore, active_hypos)
702
+ )
703
+
704
+ # update cands_to_ignore to ignore any finalized hypos
705
+ cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
706
+ assert (~cands_to_ignore).any(dim=1).all()
707
+
708
+ active_bbsz_idx = buffer('active_bbsz_idx')
709
+ torch.gather(
710
+ cand_bbsz_idx, dim=1, index=active_hypos,
711
+ out=active_bbsz_idx,
712
+ )
713
+ active_scores = torch.gather(
714
+ fw_lprobs_top_k, dim=1, index=active_hypos,
715
+ out=scores[:, step].view(bsz, beam_size),
716
+ )
717
+
718
+ active_bbsz_idx = active_bbsz_idx.view(-1)
719
+ active_scores = active_scores.view(-1)
720
+
721
+ # copy tokens and scores for active hypotheses
722
+ torch.index_select(
723
+ tokens[:, :step + 1], dim=0, index=active_bbsz_idx,
724
+ out=tokens_buf[:, :step + 1],
725
+ )
726
+ torch.gather(
727
+ cand_indices, dim=1, index=active_hypos,
728
+ out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
729
+ )
730
+ if step > 0:
731
+ torch.index_select(
732
+ scores[:, :step], dim=0, index=active_bbsz_idx,
733
+ out=scores_buf[:, :step],
734
+ )
735
+ torch.gather(
736
+ fw_lprobs_top_k, dim=1, index=active_hypos,
737
+ out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
738
+ )
739
+ torch.gather(
740
+ lm_lprobs_top_k, dim=1, index=active_hypos,
741
+ out=lm_prefix_scores.view(bsz, beam_size)
742
+ )
743
+
744
+ # copy attention for active hypotheses
745
+ if attn is not None:
746
+ torch.index_select(
747
+ attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
748
+ out=attn_buf[:, :, :step + 2],
749
+ )
750
+
751
+ # swap buffers
752
+ tokens, tokens_buf = tokens_buf, tokens
753
+ scores, scores_buf = scores_buf, scores
754
+ if attn is not None:
755
+ attn, attn_buf = attn_buf, attn
756
+
757
+ # reorder incremental state in decoder
758
+ reorder_state = active_bbsz_idx
759
+
760
+ # sort by score descending
761
+ for sent in range(len(finalized)):
762
+ finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)
763
+
764
+ return finalized
765
+
766
+
767
+ def get_lm_scores(model, input_tokens, incremental_states, cand_tokens, input_len, k):
768
+ with torch.no_grad():
769
+ lm_lprobs, avg_attn_scores = model.forward_decoder(
770
+ input_tokens, encoder_outs=None, incremental_states=incremental_states,
771
+ )
772
+
773
+ lm_lprobs_size = lm_lprobs.size(0)
774
+ probs_next_wrd = torch.gather(lm_lprobs.repeat(1, k).view(lm_lprobs_size*k, -1), 1, cand_tokens).squeeze().view(-1)
775
+
776
+ return probs_next_wrd
777
+
778
+
779
+ def make_dict2dict(old_dict, new_dict):
780
+ dict2dict_map = {}
781
+ for sym in old_dict.symbols:
782
+ dict2dict_map[old_dict.index(sym)] = new_dict.index(sym)
783
+ return dict2dict_map
784
+
785
+
786
+ def dict2dict(tokens, dict2dict_map):
787
+ if tokens.device == torch.device('cpu'):
788
+ tokens_tmp = tokens
789
+ else:
790
+ tokens_tmp = tokens.cpu()
791
+ return tokens_tmp.map_(
792
+ tokens_tmp,
793
+ lambda _, val, dict2dict_map=dict2dict_map : dict2dict_map[float(val)]
794
+ ).to(tokens.device)
795
+
796
+
797
+ def reorder_tokens(tokens, lengths, eos):
798
+ # reorder source tokens so they may be used as reference for P(S|T)
799
+ return torch.cat((tokens.new([eos]), tokens[-lengths:-1], tokens[:-lengths]), 0)
800
+
801
+
802
+ def reorder_all_tokens(tokens, lengths, eos):
803
+ # used to reorder src tokens from [<pad> <w1> <w2> .. <eos>] to [<eos> <w1> <w2>...<pad>]
804
+ # so source tokens can be used to predict P(S|T)
805
+ return torch.stack([reorder_tokens(token, length, eos) for token, length in zip(tokens, lengths)])
806
+
807
+
808
+ def normalized_scores_with_batch_vocab(
809
+ model_decoder, features, target_ids, k, bsz, beam_size,
810
+ pad_idx, top_k=0, vocab_size_meter=None, start_idx=None,
811
+ end_idx=None, **kwargs):
812
+ """
813
+ Get normalized probabilities (or log probs) from a net's output
814
+ w.r.t. vocab consisting of target IDs in the batch
815
+ """
816
+ if model_decoder.adaptive_softmax is None:
817
+ weight = model_decoder.output_projection.weight
818
+ vocab_ids = torch.unique(
819
+ torch.cat(
820
+ (torch.unique(target_ids), torch.arange(top_k, device=target_ids.device))
821
+ )
822
+ )
823
+ id_map = dict(zip(vocab_ids.tolist(), range(len(vocab_ids))))
824
+ mapped_target_ids = target_ids.cpu().apply_(
825
+ lambda x, id_map=id_map: id_map[x]
826
+ ).to(target_ids.device)
827
+ expanded_target_ids = mapped_target_ids[:, :].repeat(1, k).view(bsz*beam_size*k, -1)
828
+ if start_idx is not None and end_idx is not None:
829
+ expanded_target_ids = expanded_target_ids[start_idx:end_idx, :]
830
+ logits = F.linear(features, weight[vocab_ids, :])
831
+ log_softmax = F.log_softmax(logits, dim=-1, dtype=torch.float32)
832
+ intermed_scores = torch.gather(
833
+ log_softmax[:, :-1, :],
834
+ 2,
835
+ expanded_target_ids[:, 1:].unsqueeze(2),
836
+ ).squeeze()
837
+ not_padding = expanded_target_ids[:, 1:] != pad_idx
838
+ intermed_scores *= not_padding.float()
839
+ return intermed_scores
840
+ else:
841
+ raise ValueError("adaptive softmax doesn't work with " +
842
+ "`normalized_scores_with_batch_vocab()`")
fairseq/examples/fast_noisy_channel/noisy_channel_translation.py ADDED
@@ -0,0 +1,127 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from fairseq.tasks.translation import TranslationTask
7
+ from fairseq.tasks.language_modeling import LanguageModelingTask
8
+ from fairseq import checkpoint_utils
9
+ import argparse
10
+ from fairseq.tasks import register_task
11
+ import torch
12
+
13
+
14
+ @register_task("noisy_channel_translation")
15
+ class NoisyChannelTranslation(TranslationTask):
16
+ """
17
+ Rescore the top k candidates from each beam using noisy channel modeling
18
+ """
19
+
20
+ @staticmethod
21
+ def add_args(parser):
22
+ """Add task-specific arguments to the parser."""
23
+ TranslationTask.add_args(parser)
24
+ # fmt: off
25
+ parser.add_argument('--channel-model', metavar='FILE',
26
+ help='path to P(S|T) model. P(S|T) and P(T|S) must share source and target dictionaries.')
27
+ parser.add_argument('--combine-method', default='lm_only',
28
+ choices=['lm_only', 'noisy_channel'],
29
+ help="""method for combining direct and channel model scores.
30
+ lm_only: decode with P(T|S)P(T)
31
+ noisy_channel: decode with 1/t P(T|S) + 1/s(P(S|T)P(T))""")
32
+ parser.add_argument('--normalize-lm-scores-by-tgt-len', action='store_true', default=False,
33
+ help='normalize lm score by target length instead of source length')
34
+ parser.add_argument('--channel-scoring-type', default='log_norm', choices=['unnormalized', 'log_norm', 'k2_separate', 'src_vocab', 'src_vocab_batched'],
35
+ help="Normalize bw scores with log softmax or return bw scores without log softmax")
36
+ parser.add_argument('--top-k-vocab', default=0, type=int,
37
+ help='top k vocab IDs to use with `src_vocab` in channel model scoring')
38
+ parser.add_argument('--k2', default=50, type=int,
39
+ help='the top k2 candidates to rescore with the noisy channel model for each beam')
40
+ parser.add_argument('--ch-wt', default=1, type=float,
41
+ help='weight for the channel model')
42
+ parser.add_argument('--lm-model', metavar='FILE',
43
+ help='path to lm model file, to model P(T). P(T) must share the same vocab as the direct model on the target side')
44
+ parser.add_argument('--lm-data', metavar='FILE',
45
+ help='path to lm model training data for target language, used to properly load LM with correct dictionary')
46
+ parser.add_argument('--lm-wt', default=1, type=float,
47
+ help='the weight of the lm in joint decoding')
48
+ # fmt: on
49
+
50
+ def build_generator(
51
+ self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None
52
+ ):
53
+ if getattr(args, "score_reference", False):
54
+ raise NotImplementedError()
55
+ else:
56
+ from .noisy_channel_sequence_generator import NoisyChannelSequenceGenerator
57
+ use_cuda = torch.cuda.is_available() and not self.args.cpu
58
+ assert self.args.lm_model is not None, '--lm-model required for noisy channel generation!'
59
+ assert self.args.lm_data is not None, '--lm-data required for noisy channel generation to map between LM and bitext vocabs'
60
+ if self.args.channel_model is not None:
61
+ import copy
62
+ ch_args_task = copy.deepcopy(self.args)
63
+ tmp = ch_args_task.source_lang
64
+ ch_args_task.source_lang = ch_args_task.target_lang
65
+ ch_args_task.target_lang = tmp
66
+ ch_args_task._name = 'translation'
67
+ channel_task = TranslationTask.setup_task(ch_args_task)
68
+
69
+ arg_dict = {}
70
+ arg_dict['task'] = 'language_modeling'
71
+ arg_dict['sample_break_mode'] = 'eos'
72
+ arg_dict['data'] = self.args.lm_data
73
+ arg_dict['output_dictionary_size'] = -1
74
+ lm_args = argparse.Namespace(**arg_dict)
75
+ lm_task = LanguageModelingTask.setup_task(lm_args)
76
+ lm_dict = lm_task.output_dictionary
77
+
78
+ if self.args.channel_model is not None:
79
+ channel_models, _ = checkpoint_utils.load_model_ensemble(self.args.channel_model.split(':'), task=channel_task)
80
+
81
+ for model in channel_models:
82
+ model.make_generation_fast_(
83
+ beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
84
+ need_attn=args.print_alignment,
85
+ )
86
+ if self.args.fp16:
87
+ model.half()
88
+ if use_cuda:
89
+ model.cuda()
90
+ else:
91
+ channel_models = None
92
+
93
+ lm_models, _ = checkpoint_utils.load_model_ensemble(self.args.lm_model.split(':'), task=lm_task)
94
+
95
+ for model in lm_models:
96
+ model.make_generation_fast_(
97
+ beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
98
+ need_attn=args.print_alignment,
99
+ )
100
+ if self.args.fp16:
101
+ model.half()
102
+ if use_cuda:
103
+ model.cuda()
104
+ return NoisyChannelSequenceGenerator(
105
+ combine_method=self.args.combine_method,
106
+ tgt_dict=self.target_dictionary,
107
+ src_dict=self.source_dictionary,
108
+ beam_size=getattr(args, 'beam', 5),
109
+ max_len_a=getattr(args, 'max_len_a', 0),
110
+ max_len_b=getattr(args, 'max_len_b', 200),
111
+ min_len=getattr(args, 'min_len', 1),
112
+ len_penalty=getattr(args, 'lenpen', 1),
113
+ unk_penalty=getattr(args, 'unkpen', 0),
114
+ temperature=getattr(args, 'temperature', 1.),
115
+ match_source_len=getattr(args, 'match_source_len', False),
116
+ no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0),
117
+ normalize_scores=(not getattr(args, 'unnormalized', False)),
118
+ channel_models=channel_models,
119
+ k2=getattr(self.args, 'k2', 50),
120
+ ch_weight=getattr(self.args, 'ch_wt', 1),
121
+ channel_scoring_type=self.args.channel_scoring_type,
122
+ top_k_vocab=self.args.top_k_vocab,
123
+ lm_models=lm_models,
124
+ lm_dict=lm_dict,
125
+ lm_weight=getattr(self.args, 'lm_wt', 1),
126
+ normalize_lm_scores_by_tgt_len=getattr(self.args, 'normalize_lm_scores_by_tgt_len', False),
127
+ )
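+
+ # Example usage (a sketch; the dataset, checkpoint paths and weights below are
+ # placeholders rather than recommended values):
+ #
+ #   fairseq-generate data-bin/wmt \
+ #       --task noisy_channel_translation \
+ #       --path direct_model.pt \
+ #       --channel-model channel_model.pt \
+ #       --lm-model lm.pt --lm-data lm-data-bin \
+ #       --combine-method noisy_channel \
+ #       --k2 10 --ch-wt 1 --lm-wt 1 --beam 5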
fairseq/examples/flores101/README.md ADDED
@@ -0,0 +1,223 @@
1
+ <p align="center">
2
+ <img src="flores_logo.png" width="500">
3
+ </p>
4
+
5
+ # Flores101: Large-Scale Multilingual Machine Translation
6
+
7
+ ## Introduction
8
+
9
+ Baseline pretrained models for the small and large tracks of the WMT 21 Large-Scale Multilingual Machine Translation competition.
10
+
11
+ Flores Task at WMT 21: http://www.statmt.org/wmt21/large-scale-multilingual-translation-task.html
12
+
13
+ Flores announcement blog post: https://ai.facebook.com/blog/flores-researchers-kick-off-multilingual-translation-challenge-at-wmt-and-call-for-compute-grants/
14
+
15
+
16
+
17
+ ## Pretrained models
18
+
19
+ Model | Num layers | Embed dimension | FFN dimension | Vocab Size | #params | Download
20
+ ---|---|---|---|---|---|---
21
+ `flores101_mm100_615M` | 12 | 1024 | 4096 | 256,000 | 615M | https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_615M.tar.gz
22
+ `flores101_mm100_175M` | 6 | 512 | 2048 | 256,000 | 175M | https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_175M.tar.gz
23
+
24
+
25
+ These models are trained similarly to [M2M-100](https://arxiv.org/abs/2010.11125), with additional support for the languages that are part of the WMT Large-Scale Multilingual Machine Translation track. The full list of languages can be found at the bottom.
26
+
27
+
28
+ ## Example Generation code
29
+
30
+ ### Download the model and SentencePiece vocab
31
+
32
+ ```bash
33
+ fairseq=/path/to/fairseq
34
+ cd $fairseq
35
+
36
+ # Download 615M param model.
37
+ wget https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_615M.tar.gz
38
+
39
+ # Extract
40
+ tar -xvzf flores101_mm100_615M.tar.gz
41
+ ```
42
+
43
+ ### Encode using our SentencePiece Model
44
+ Note: Install SentencePiece from [here](https://github.com/google/sentencepiece)
45
+
46
+
47
+ ```bash
48
+ fairseq=/path/to/fairseq
49
+ cd $fairseq
50
+
51
+ # Download an example dataset (German to French)
52
+ sacrebleu --echo src -l de-fr -t wmt19 | head -n 20 > raw_input.de-fr.de
53
+ sacrebleu --echo ref -l de-fr -t wmt19 | head -n 20 > raw_input.de-fr.fr
54
+
55
+ for lang in de fr ; do
56
+ python scripts/spm_encode.py \
57
+ --model flores101_mm100_615M/sentencepiece.bpe.model \
58
+ --output_format=piece \
59
+ --inputs=raw_input.de-fr.${lang} \
60
+ --outputs=spm.de-fr.${lang}
61
+ done
62
+ ```
63
+
64
+ ### Binarization
65
+
66
+ ```bash
67
+ fairseq-preprocess \
68
+ --source-lang de --target-lang fr \
69
+ --testpref spm.de-fr \
70
+ --thresholdsrc 0 --thresholdtgt 0 \
71
+ --destdir data_bin \
72
+ --srcdict flores101_mm100_615M/dict.txt --tgtdict flores101_mm100_615M/dict.txt
73
+ ```
74
+
75
+ ### Generation
76
+
77
+
78
+ ```bash
79
+ fairseq-generate \
80
+ data_bin \
81
+ --batch-size 1 \
82
+ --path flores101_mm100_615M/model.pt \
83
+ --fixed-dictionary flores101_mm100_615M/dict.txt \
84
+ -s de -t fr \
85
+ --remove-bpe 'sentencepiece' \
86
+ --beam 5 \
87
+ --task translation_multi_simple_epoch \
88
+ --lang-pairs flores101_mm100_615M/language_pairs.txt \
89
+ --decoder-langtok --encoder-langtok src \
90
+ --gen-subset test \
91
+ --fp16 \
92
+ --dataset-impl mmap \
93
+ --distributed-world-size 1 --distributed-no-spawn
94
+ ```
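+
+ To score the generated hypotheses with BLEU, one option is the sketch below. It
+ assumes the output of the `fairseq-generate` command above was saved to
+ `generation.out` (hypothesis lines start with `H-`), and scores against the raw
+ references downloaded earlier:
+
+ ```bash
+ # Hypotheses are tab-separated as id, score, text on the H-* lines
+ grep '^H-' generation.out | sort -V | cut -f3 > hyp.de-fr.fr
+
+ # Score against the raw French references
+ sacrebleu raw_input.de-fr.fr -i hyp.de-fr.fr
+ ```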
95
+
96
+ ### Supported Languages and lang code
97
+
98
+ Language | lang code
99
+ ---|---
100
+ Afrikaans | af
101
+ Amharic | am
102
+ Arabic | ar
103
+ Assamese | as
104
+ Asturian | ast
105
+ Aymara | ay
106
+ Azerbaijani | az
107
+ Bashkir | ba
108
+ Belarusian | be
109
+ Bulgarian | bg
110
+ Bengali | bn
111
+ Breton | br
112
+ Bosnian | bs
113
+ Catalan | ca
114
+ Cebuano | ceb
115
+ Chokwe | cjk
116
+ Czech | cs
117
+ Welsh | cy
118
+ Danish | da
119
+ German | de
120
+ Dyula | dyu
121
+ Greek | el
122
+ English | en
123
+ Spanish | es
124
+ Estonian | et
125
+ Persian | fa
126
+ Fulah | ff
127
+ Finnish | fi
128
+ French | fr
129
+ Western Frisian | fy
130
+ Irish | ga
131
+ Scottish Gaelic | gd
132
+ Galician | gl
133
+ Gujarati | gu
134
+ Hausa | ha
135
+ Hebrew | he
136
+ Hindi | hi
137
+ Croatian | hr
138
+ Haitian Creole | ht
139
+ Hungarian | hu
140
+ Armenian | hy
141
+ Indonesian | id
142
+ Igbo | ig
143
+ Iloko | ilo
144
+ Icelandic | is
145
+ Italian | it
146
+ Japanese | ja
147
+ Javanese | jv
148
+ Georgian | ka
149
+ Kachin | kac
150
+ Kamba | kam
151
+ Kabuverdianu | kea
152
+ Kongo | kg
153
+ Kazakh | kk
154
+ Central Khmer | km
155
+ Kimbundu | kmb
156
+ Northern Kurdish | kmr
157
+ Kannada | kn
158
+ Korean | ko
159
+ Kurdish | ku
160
+ Kyrgyz | ky
161
+ Luxembourgish | lb
162
+ Ganda | lg
163
+ Lingala | ln
164
+ Lao | lo
165
+ Lithuanian | lt
166
+ Luo | luo
167
+ Latvian | lv
168
+ Malagasy | mg
169
+ Maori | mi
170
+ Macedonian | mk
171
+ Malayalam | ml
172
+ Mongolian | mn
173
+ Marathi | mr
174
+ Malay | ms
175
+ Maltese | mt
176
+ Burmese | my
177
+ Nepali | ne
178
+ Dutch | nl
179
+ Norwegian | no
180
+ Northern Sotho | ns
181
+ Nyanja | ny
182
+ Occitan | oc
183
+ Oromo | om
184
+ Oriya | or
185
+ Punjabi | pa
186
+ Polish | pl
187
+ Pashto | ps
188
+ Portuguese | pt
189
+ Quechua | qu
190
+ Romanian | ro
191
+ Russian | ru
192
+ Sindhi | sd
193
+ Shan | shn
194
+ Sinhala | si
195
+ Slovak | sk
196
+ Slovenian | sl
197
+ Shona | sn
198
+ Somali | so
199
+ Albanian | sq
200
+ Serbian | sr
201
+ Swati | ss
202
+ Sundanese | su
203
+ Swedish | sv
204
+ Swahili | sw
205
+ Tamil | ta
206
+ Telugu | te
207
+ Tajik | tg
208
+ Thai | th
209
+ Tigrinya | ti
210
+ Tagalog | tl
211
+ Tswana | tn
212
+ Turkish | tr
213
+ Ukrainian | uk
214
+ Umbundu | umb
215
+ Urdu | ur
216
+ Uzbek | uz
217
+ Vietnamese | vi
218
+ Wolof | wo
219
+ Xhosa | xh
220
+ Yiddish | yi
221
+ Yoruba | yo
222
+ Chinese | zh
223
+ Zulu | zu
fairseq/examples/flores101/flores_logo.png ADDED
fairseq/examples/fully_sharded_data_parallel/README.md ADDED
@@ -0,0 +1,177 @@
1
+ # Fully Sharded Data Parallel (FSDP)
2
+
3
+ ## Overview
4
+ Recent work by [Microsoft](https://arxiv.org/abs/1910.02054) and
5
+ [Google](https://arxiv.org/abs/2004.13336) has shown that data parallel
6
+ training can be made significantly more efficient by sharding the model
7
+ parameters and optimizer state across data parallel workers. These ideas are
8
+ encapsulated in the new **`FullyShardedDataParallel` (FSDP)** wrapper provided
9
+ by [fairscale](https://github.com/facebookresearch/fairscale/).
10
+
11
+ Compared to PyTorch DDP:
12
+ * FSDP produces identical results as PyTorch DDP (it's still synchronous data parallel training)
13
+ * FSDP shards parameters (FP16 + FP32) and optimizer state across data parallel GPUs
14
+ * FSDP is faster than PyTorch DDP because the optimizer step is sharded, and the communication can be overlapped with the forward pass
15
+ * FSDP enables training 13B parameter models on 8 GPUs and 175B parameter models on 128 GPUs
16
+
17
+ FSDP is fully supported in fairseq via the following new arguments:
18
+ * `--ddp-backend=fully_sharded`: enables full sharding via FSDP
19
+ * `--cpu-offload`: offloads the optimizer state and FP32 model copy to CPU (combine with `--optimizer=cpu_adam`)
20
+ * `--no-reshard-after-forward`: increases training speed for large models (1B+ params) and is similar to ZeRO stage 2
21
+ * other popular options (`--fp16`, `--update-freq`, `--checkpoint-activations`, `--offload-activations`, etc.) continue to work as normal
22
+
23
+ <details><summary>Limitations</summary><p>
24
+
25
+ FSDP currently has several limitations compared to fairseq's default DDP backend (PyTorch DDP):
26
+ * while FSDP is full compatible with pointwise Optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.), it is not currently compatible with non-pointwise Optimizers (e.g., Adagrad, Adafactor, LAMB, etc.)
27
+ * FSDP depends on flattening the parameters, so models that currently require `--fp16-no-flatten-grads` may not be supported
28
+
29
+ See the [fairscale docs](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html) for a more detailed
30
+ explanation of these and other limitations.
31
+
32
+ </p></details>
33
+
34
+ <details><summary>How it works</summary><p>
35
+
36
+ <img width="800" alt="Fully Sharded Data Parallel" src="https://user-images.githubusercontent.com/231798/110406775-c2de0000-8050-11eb-9718-fbfc4510a76a.png">
37
+
38
+ See the [fairscale docs](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html) for a more detailed
39
+ explanation of how FSDP works.
40
+
41
+ </p></details>
42
+
43
+ ## Example usage
44
+
45
+ The following examples illustrate how to train a very large language model with
46
+ 13 billion parameters on 1 GPU by offloading parameters and optimizer states to
47
+ CPU, or on 8 GPUs by fully sharding the params and optimizer states across GPUs.
48
+
49
+ These examples use the WikiText-103 dataset for demonstration purposes, but
50
+ in practice a much larger dataset will be needed to achieve good results.
51
+ Follow the [instructions here](https://github.com/pytorch/fairseq/blob/main/examples/roberta/README.pretraining.md#1-preprocess-the-data)
52
+ to preprocess the WikiText-103 dataset using the GPT-2/RoBERTa vocabulary.
53
+
54
+ ### 13B params on 1 V100 GPU (with CPU offloading)
55
+
56
+ The following command trains a 13B parameter GPT-3 model on a single V100 GPU
57
+ using the `--cpu-offload` feature to offload parameters and optimizer states to
58
+ CPU. In this setting, the optimizer step (Adam) happens on CPU. We also use the
59
+ `--checkpoint-activations` feature (sometimes called [gradient checkpointing](https://pytorch.org/docs/stable/checkpoint.html)),
60
+ which further saves memory in exchange for a small increase in computation.
61
+
62
+ **Requirements:**
63
+ - Install the latest master version of fairscale: `pip install git+https://github.com/facebookresearch/fairscale.git@master`
64
+ - You'll need 32GB of GPU memory and ~256GB of system memory to train the 13B param model.
65
+ - If you have less system memory, the 6.7B param model can be trained with ~128GB of system memory, just set `--arch transformer_lm_gpt3_6_7`
66
+ - We use the CPU Adam optimizer from [DeepSpeed](https://github.com/microsoft/DeepSpeed), so you'll need to `pip install deepspeed` before running the command.
67
+
68
+ **Notes:**
69
+ - The command will take ~5 minutes to start training, during which time it will appear to be hung, since randomly initializing 13B weights can be slow.
70
+ - The `--cpu-offload` feature requires training in mixed precision (`--fp16`).
71
+ - Tune the `OMP_NUM_THREADS` env variable for best performance with CPU offloading.
72
+ - The example command below stops training after 10 steps (`--max-update 10`) and does not save checkpoints (`--no-save`).
73
+
74
+ ```bash
75
+ OMP_NUM_THREADS=20 CUDA_VISIBLE_DEVICES=0 \
76
+ fairseq-train data-bin/wikitext-103-roberta-bpe-bin \
77
+ --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \
78
+ --cpu-offload --checkpoint-activations \
79
+ --task language_modeling --tokens-per-sample 2048 --batch-size 8 \
80
+ --arch transformer_lm_gpt3_13 \
81
+ --optimizer cpu_adam --adam-betas "(0.9,0.98)" \
82
+ --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \
83
+ --max-update 10 --no-save --log-format json --log-interval 1
84
+ ```
85
+
86
+ <details><summary>Example output</summary><p>
87
+
88
+ ```
89
+ (...)
90
+ 2021-03-08 12:29:51 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920)
91
+ (...)
92
+ 2021-03-08 12:29:51 | INFO | fairseq_cli.train | training on 1 devices (GPUs/TPUs)
93
+ 2021-03-08 12:29:51 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8
94
+ (...)
95
+ Adam Optimizer #0 is created with AVX2 arithmetic capability.
96
+ Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1
97
+ (...)
98
+ 2021-03-08 12:31:36 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.475", "ppl": "91120.8", "wps": "0", "ups": "0", "wpb": "16384", "bsz": "8", "num_updates": "1", "lr": "2e-05", "gnorm": "20.751", "loss_scale": "4", "train_wall": "99", "gb_free": "9.3", "wall": "105"}
99
+ 2021-03-08 12:32:33 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.446", "ppl": "89281.6", "wps": "288.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "2", "lr": "4e-05", "gnorm": "19.777", "loss_scale": "4", "train_wall": "57", "gb_free": "9.3", "wall": "161"}
100
+ 2021-03-08 12:33:12 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0
101
+ 2021-03-08 12:33:51 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0
102
+ 2021-03-08 12:34:45 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "25.22", "ppl": "3.90691e+07", "wps": "123.4", "ups": "0.01", "wpb": "16384", "bsz": "8", "num_updates": "3", "lr": "6e-05", "gnorm": "131.281", "loss_scale": "1", "train_wall": "133", "gb_free": "9.3", "wall": "294"}
103
+ 2021-03-08 12:35:43 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.079", "ppl": "276809", "wps": "285.5", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "4", "lr": "8e-05", "gnorm": "13.776", "loss_scale": "1", "train_wall": "57", "gb_free": "9.3", "wall": "351"}
104
+ 2021-03-08 12:36:35 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "23.729", "ppl": "1.39088e+07", "wps": "316.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "5", "lr": "0.0001", "gnorm": "72.774", "loss_scale": "1", "train_wall": "52", "gb_free": "9.3", "wall": "403"}
105
+ 2021-03-08 12:37:28 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "20.429", "ppl": "1.41203e+06", "wps": "307.6", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "6", "lr": "8e-05", "gnorm": "60.846", "loss_scale": "1", "train_wall": "53", "gb_free": "9.3", "wall": "456"}
106
+ 2021-03-08 12:38:27 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.965", "ppl": "511684", "wps": "279.4", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "7", "lr": "6e-05", "gnorm": "22.687", "loss_scale": "1", "train_wall": "59", "gb_free": "9.3", "wall": "515"}
107
+ 2021-03-08 12:39:18 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.345", "ppl": "332887", "wps": "319.1", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "8", "lr": "4e-05", "gnorm": "8.451", "loss_scale": "1", "train_wall": "51", "gb_free": "9.3", "wall": "566"}
108
+ 2021-03-08 12:40:11 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "18.262", "ppl": "314336", "wps": "305.9", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "9", "lr": "2e-05", "gnorm": "6.457", "loss_scale": "1", "train_wall": "54", "gb_free": "9.3", "wall": "620"}
109
+ 2021-03-08 12:41:04 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "17.556", "ppl": "192686", "wps": "311.8", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "10", "lr": "0", "gnorm": "5.796", "loss_scale": "1", "train_wall": "53", "gb_free": "9.3", "wall": "673"}
110
+ 2021-03-08 12:41:04 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10
111
+ 2021-03-08 12:41:04 | INFO | fairseq_cli.train | begin validation on "valid" subset
112
+ 2021-03-08 12:43:15 | INFO | valid | {"epoch": 1, "valid_loss": "17.953", "valid_ppl": "253807", "valid_wps": "1868.4", "valid_wpb": "15400.2", "valid_bsz": "7.6", "valid_num_updates": "10"}
113
+ 2021-03-08 12:43:15 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below)
114
+ 2021-03-08 12:43:15 | INFO | train | {"epoch": 1, "train_loss": "19.351", "train_ppl": "668509", "train_wps": "210.9", "train_ups": "0.01", "train_wpb": "16384", "train_bsz": "8", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "36.26", "train_loss_scale": "1", "train_train_wall": "667", "train_gb_free": "9.3", "train_wall": "804"}
115
+ 2021-03-08 12:43:15 | INFO | fairseq_cli.train | done training in 798.6 seconds
116
+ ```
117
+
118
+ </p></details>
119
+
120
+ ### 13B params on 8 V100 GPUs (with full parameter + optimizer state sharding)
121
+
122
+ FSDP can also shard the parameters and optimizer states across multiple GPUs,
123
+ reducing memory requirements significantly. On 8 x 32GB GPUs, sharding enables
124
+ training the same 13B parameter model *without offloading the parameters to
125
+ CPU*. However, without CPU offloading we'd only be able to fit a batch size of
126
+ 1 per GPU, which would cause training speed to suffer.
127
+
128
+ We obtain the best performance on 8 GPUs by combining full sharding and CPU
129
+ offloading. The following command trains the same 13B parameter GPT-3 model as
130
+ before on 8 x 32GB V100 GPUs; training speed increases superlinearly from ~310
131
+ words per second to ~3200 words per second.
132
+
133
+ ```bash
134
+ OMP_NUM_THREADS=20 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
135
+ fairseq-train data-bin/wikitext-103-roberta-bpe-bin \
136
+ --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \
137
+ --cpu-offload --checkpoint-activations \
138
+ --task language_modeling --tokens-per-sample 2048 --batch-size 8 \
139
+ --arch transformer_lm_gpt3_13 \
140
+ --optimizer cpu_adam --adam-betas "(0.9,0.98)" \
141
+ --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \
142
+ --max-update 10 --no-save --log-format json --log-interval 1
143
+ ```
144
+
145
+ <details><summary>Example output</summary><p>
146
+
147
+ ```
148
+ (...)
149
+ 2021-03-08 18:04:09 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920)
150
+ (...)
151
+ 2021-03-08 18:04:09 | INFO | fairseq_cli.train | training on 8 devices (GPUs/TPUs)
152
+ 2021-03-08 18:04:09 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8
153
+ (...)
154
+ Adam Optimizer #0 is created with AVX2 arithmetic capability.
155
+ Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1
156
+ (...)
157
+ 2021-03-08 18:05:06 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "16.408", "ppl": "86945.6", "wps": "0", "ups": "0", "wpb": "131072", "bsz": "64", "num_updates": "1", "lr": "2e-05", "gnorm": "18.27", "loss_scale": "4", "train_wall": "47", "gb_free": "9.3", "wall": "56"}
158
+ 2021-03-08 18:05:45 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "16.352", "ppl": "83644.3", "wps": "3283.4", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "2", "lr": "4e-05", "gnorm": "18.411", "loss_scale": "4", "train_wall": "40", "gb_free": "9.3", "wall": "96"}
159
+ 2021-03-08 18:06:21 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0
160
+ 2021-03-08 18:06:56 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0
161
+ 2021-03-08 18:07:37 | INFO | train_inner | {"epoch": 1, "update": 0.006, "loss": "23.682", "ppl": "1.34537e+07", "wps": "1176.6", "ups": "0.01", "wpb": "131072", "bsz": "64", "num_updates": "3", "lr": "6e-05", "gnorm": "119.682", "loss_scale": "1", "train_wall": "111", "gb_free": "9.3", "wall": "208"}
162
+ 2021-03-08 18:08:18 | INFO | train_inner | {"epoch": 1, "update": 0.007, "loss": "18.988", "ppl": "519921", "wps": "3189.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "4", "lr": "8e-05", "gnorm": "14.934", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "249"}
163
+ 2021-03-08 18:08:59 | INFO | train_inner | {"epoch": 1, "update": 0.008, "loss": "20.08", "ppl": "1.10798e+06", "wps": "3223.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "5", "lr": "0.0001", "gnorm": "59.92", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "289"}
164
+ 2021-03-08 18:09:39 | INFO | train_inner | {"epoch": 1, "update": 0.009, "loss": "18.323", "ppl": "327980", "wps": "3256.6", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "6", "lr": "8e-05", "gnorm": "37.425", "loss_scale": "1", "train_wall": "40", "gb_free": "9.3", "wall": "330"}
165
+ 2021-03-08 18:10:20 | INFO | train_inner | {"epoch": 1, "update": 0.01, "loss": "17.264", "ppl": "157354", "wps": "3188.7", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "7", "lr": "6e-05", "gnorm": "10.824", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "371"}
166
+ 2021-03-08 18:11:01 | INFO | train_inner | {"epoch": 1, "update": 0.011, "loss": "16.794", "ppl": "113647", "wps": "3230", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "8", "lr": "4e-05", "gnorm": "5.616", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "411"}
167
+ 2021-03-08 18:11:39 | INFO | train_inner | {"epoch": 1, "update": 0.012, "loss": "16.706", "ppl": "106938", "wps": "3384", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "9", "lr": "2e-05", "gnorm": "5.318", "loss_scale": "1", "train_wall": "39", "gb_free": "9.3", "wall": "450"}
168
+ 2021-03-08 18:12:19 | INFO | train_inner | {"epoch": 1, "update": 0.013, "loss": "16.548", "ppl": "95796.2", "wps": "3274.4", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "10", "lr": "0", "gnorm": "5.22", "loss_scale": "1", "train_wall": "40", "gb_free": "9.3", "wall": "490"}
169
+ 2021-03-08 18:12:19 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10
170
+ 2021-03-08 18:12:19 | INFO | fairseq_cli.train | begin validation on "valid" subset
171
+ 2021-03-08 18:12:45 | INFO | valid | {"epoch": 1, "valid_loss": "16.624", "valid_ppl": "101000", "valid_wps": "10855.9", "valid_wpb": "123202", "valid_bsz": "60.5", "valid_num_updates": "10"}
172
+ 2021-03-08 18:12:45 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below)
173
+ 2021-03-08 18:12:45 | INFO | train | {"epoch": 1, "train_loss": "18.114", "train_ppl": "283776", "train_wps": "2567.8", "train_ups": "0.02", "train_wpb": "131072", "train_bsz": "64", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "29.562", "train_loss_scale": "1", "train_train_wall": "480", "train_gb_free": "9.3", "train_wall": "516"}
174
+ 2021-03-08 18:12:45 | INFO | fairseq_cli.train | done training in 509.9 seconds
175
+ ```
176
+
177
+ </p></details>
fairseq/examples/gottbert/README.md ADDED
@@ -0,0 +1,64 @@
1
+ # GottBERT: a pure German language model
2
+
3
+ ## Introduction
4
+
5
+ [GottBERT](http://arxiv.org/abs/2012.02110) is a RoBERTa-based language model pretrained on 145GB of German text.
6
+
7
+ ## Example usage
8
+
9
+ ### fairseq
10
+ ##### Load GottBERT from torch.hub (PyTorch >= 1.1):
11
+ ```python
12
+ import torch
13
+ gottbert = torch.hub.load('pytorch/fairseq', 'gottbert-base')
14
+ gottbert.eval() # disable dropout (or leave in train mode to finetune)
15
+ ```
16
+
17
+ ##### Load GottBERT (for PyTorch 1.0 or custom models):
18
+ ```bash
+ # Download the GottBERT model
+ wget https://dl.gottbert.de/fairseq/models/gottbert-base.tar.gz
+ tar -xzvf gottbert-base.tar.gz
+ ```
+
+ ```python
23
+ # Load the model in fairseq
24
+ from fairseq.models.roberta import GottbertModel
25
+ gottbert = GottbertModel.from_pretrained('/path/to/gottbert')
26
+ gottbert.eval() # disable dropout (or leave in train mode to finetune)
27
+ ```
28
+
29
+ ##### Filling masks:
30
+ ```python
31
+ masked_line = 'Gott ist <mask> ! :)'
32
+ gottbert.fill_mask(masked_line, topk=3)
33
+ # [('Gott ist gut ! :)', 0.3642110526561737, ' gut'),
34
+ # ('Gott ist überall ! :)', 0.06009674072265625, ' überall'),
35
+ # ('Gott ist großartig ! :)', 0.0370681993663311, ' großartig')]
36
+ ```
37
+
38
+ ##### Extract features from GottBERT
39
+
40
+ ```python
41
+ # Extract the last layer's features
42
+ line = "Der erste Schluck aus dem Becher der Naturwissenschaft macht atheistisch , aber auf dem Grunde des Bechers wartet Gott !"
43
+ tokens = gottbert.encode(line)
44
+ last_layer_features = gottbert.extract_features(tokens)
45
+ assert last_layer_features.size() == torch.Size([1, 27, 768])
46
+
47
+ # Extract all layer's features (layer 0 is the embedding layer)
48
+ all_layers = gottbert.extract_features(tokens, return_all_hiddens=True)
49
+ assert len(all_layers) == 13
50
+ assert torch.all(all_layers[-1] == last_layer_features)
51
+ ```
52
+ ## Citation
53
+ If you use our work, please cite:
54
+
55
+ ```bibtex
56
+ @misc{scheible2020gottbert,
57
+ title={GottBERT: a pure German Language Model},
58
+ author={Raphael Scheible and Fabian Thomczyk and Patric Tippmann and Victor Jaravine and Martin Boeker},
59
+ year={2020},
60
+ eprint={2012.02110},
61
+ archivePrefix={arXiv},
62
+ primaryClass={cs.CL}
63
+ }
64
+ ```
fairseq/examples/hubert/README.md ADDED
@@ -0,0 +1,116 @@
1
+ # HuBERT
2
+
3
+ ## Pre-trained and fine-tuned (ASR) models
4
+ Model | Pretraining Data | Finetuning Dataset | Download | Quantizer
5
+ |---|---|---|---|---
6
+ HuBERT Base (~95M params) | [Librispeech](http://www.openslr.org/12) 960 hr | No finetuning (Pretrained Model) | [download](https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt) | [L9 km500](https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960_L9_km500.bin)
7
+ HuBERT Large (~316M params) | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | No finetuning (Pretrained Model) | [download](https://dl.fbaipublicfiles.com/hubert/hubert_large_ll60k.pt)
8
+ HuBERT Extra Large (~1B params) | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | No finetuning (Pretrained Model) | [download](https://dl.fbaipublicfiles.com/hubert/hubert_xtralarge_ll60k.pt)
9
+ HuBERT Large | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/hubert/hubert_large_ll60k_finetune_ls960.pt)
10
+ HuBERT Extra Large | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/hubert/hubert_xtralarge_ll60k_finetune_ls960.pt)
11
+
12
+ ## Load a model
13
+ ```python
+ import fairseq.checkpoint_utils
+
+ ckpt_path = "/path/to/the/checkpoint.pt"
15
+ models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
16
+ model = models[0]
17
+ ```
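+
+ The loaded model (assuming a pretrained rather than CTC fine-tuned checkpoint)
+ can then be used as a feature extractor. The sketch below assumes the standard
+ fairseq HuBERT `extract_features` API and feeds a random tensor as a
+ placeholder for real 16 kHz audio:
+
+ ```python
+ import torch
+
+ model.eval()
+ wav = torch.randn(1, 16000)  # (batch, samples): one second of placeholder 16 kHz audio
+ with torch.no_grad():
+     features, padding_mask = model.extract_features(wav, padding_mask=None, mask=False)
+ print(features.shape)  # (batch, frames, dim); roughly 50 frames per second
+ ```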
18
+
19
+ ## Train a new model
20
+
21
+ ### Data preparation
22
+
23
+ Follow the steps in `./simple_kmeans` to create:
24
+ - `{train,valid}.tsv` waveform list files
+ - `{train,valid}.km` frame-aligned pseudo label files
+ - `dict.km.txt` a dummy dictionary (a sketch of how to generate one is shown below)
+
+ The `label_rate` is the same as the feature frame rate used for clustering,
+ which is 100Hz for MFCC features and 50Hz for HuBERT features by default.
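+
+ For reference, the dummy dictionary simply lists every cluster id with a
+ placeholder count, one entry per line. A minimal sketch (assuming 500 clusters,
+ mirroring the `simple_kmeans` instructions):
+
+ ```sh
+ n_clusters=500
+ for x in $(seq 0 $((n_clusters - 1))); do
+   echo "$x 1"
+ done > dict.km.txt
+ ```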
29
+
30
+ ### Pre-train a HuBERT model
31
+
32
+ Suppose `{train,valid}.tsv` are saved at `/path/to/data`, `{train,valid}.km`
33
+ are saved at `/path/to/labels`, and the label rate is 100Hz.
34
+
35
+ To train a base model (12 layer transformer), run:
36
+ ```sh
37
+ $ python fairseq_cli/hydra_train.py \
38
+ --config-dir /path/to/fairseq-py/examples/hubert/config/pretrain \
39
+ --config-name hubert_base_librispeech \
40
+ task.data=/path/to/data task.label_dir=/path/to/labels task.labels='["km"]' model.label_rate=100
41
+ ```
42
+
43
+ ### Fine-tune a HuBERT model with a CTC loss
44
+
45
+ Suppose `{train,valid}.tsv` are saved at `/path/to/data`, and their
46
+ corresponding character transcripts `{train,valid}.ltr` are saved at
47
+ `/path/to/trans`.
48
+
49
+ To fine-tune a pre-trained HuBERT model at `/path/to/checkpoint`, run
50
+ ```sh
51
+ $ python fairseq_cli/hydra_train.py \
52
+ --config-dir /path/to/fairseq-py/examples/hubert/config/finetune \
53
+ --config-name base_10h \
54
+ task.data=/path/to/data task.label_dir=/path/to/trans \
55
+ model.w2v_path=/path/to/checkpoint
56
+ ```
57
+
58
+ ### Decode a HuBERT model
59
+
60
+ Suppose the `test.tsv` and `test.ltr` are the waveform list and transcripts of
61
+ the split to be decoded, saved at `/path/to/data`, and the fine-tuned model is
62
+ saved at `/path/to/checkpoint`. We support three decoding modes:
63
+ - Viterbi decoding: greedy decoding without a language model
64
+ - KenLM decoding: decoding with an arpa-format KenLM n-gram language model
65
+ - Fairseq-LM decoding: decoding with a Fairseq neural language model
66
+
67
+
68
+ #### Viterbi decoding
69
+
70
+ `task.normalize` needs to be consistent with the value used during fine-tuning.
71
+ Decoding results will be saved at
72
+ `/path/to/experiment/directory/decode/viterbi/test`.
73
+
74
+ ```sh
75
+ $ python examples/speech_recognition/new/infer.py \
76
+ --config-dir /path/to/fairseq-py/examples/hubert/config/decode \
77
+ --config-name infer_viterbi \
78
+ task.data=/path/to/data \
79
+ task.normalize=[true|false] \
80
+ decoding.exp_dir=/path/to/experiment/directory \
81
+ common_eval.path=/path/to/checkpoint \
+ dataset.gen_subset=test
83
+ ```
84
+
85
+ #### KenLM / Fairseq-LM decoding
86
+
87
+ Suppose the pronunciation lexicon and the n-gram LM are saved at
88
+ `/path/to/lexicon` and `/path/to/arpa`, respectively. Decoding results will be
89
+ saved at `/path/to/experiment/directory/decode/kenlm/test`.
90
+
91
+ ```sh
92
+ $ python examples/speech_recognition/new/infer.py \
93
+ --config-dir /path/to/fairseq-py/examples/hubert/config/decode \
94
+ --config-name infer_kenlm \
95
+ task.data=/path/to/data \
96
+ task.normalize=[true|false] \
97
+ decoding.exp_dir=/path/to/experiment/directory \
98
+ common_eval.path=/path/to/checkpoint \
99
+ dataset.gen_subset=test \
100
+ decoding.decoder.lexicon=/path/to/lexicon \
101
+ decoding.decoder.lmpath=/path/to/arpa
102
+ ```
103
+
104
+ The command above uses the default decoding hyperparameters, which can be found
+ in `examples/speech_recognition/hydra/decoder.py`. These parameters can be
+ overridden from the command line. For example, to search with a beam size of
+ 500, append `decoding.decoder.beam=500` to the command above; a combined
+ override is sketched after the list below.
108
+ Important parameters include:
109
+ - decoding.decoder.beam
110
+ - decoding.decoder.beamthreshold
111
+ - decoding.decoder.lmweight
112
+ - decoding.decoder.wordscore
113
+ - decoding.decoder.silweight
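+
+ A KenLM run that overrides several of these at once might look like the
+ following (a sketch; the weight values are illustrative only):
+
+ ```sh
+ $ python examples/speech_recognition/new/infer.py \
+ --config-dir /path/to/fairseq-py/examples/hubert/config/decode \
+ --config-name infer_kenlm \
+ task.data=/path/to/data \
+ task.normalize=[true|false] \
+ decoding.exp_dir=/path/to/experiment/directory \
+ common_eval.path=/path/to/checkpoint \
+ dataset.gen_subset=test \
+ decoding.decoder.lexicon=/path/to/lexicon \
+ decoding.decoder.lmpath=/path/to/arpa \
+ decoding.decoder.beam=500 \
+ decoding.decoder.lmweight=2.5 \
+ decoding.decoder.wordscore=-0.5 \
+ decoding.decoder.silweight=0
+ ```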
114
+
115
+ To decode with a Fairseq LM, use `--config-name infer_fsqlm` instead, and
116
+ change the path of lexicon and LM accordingly.
fairseq/examples/hubert/config/decode/ax_sweep/ngram.yaml ADDED
@@ -0,0 +1,33 @@
1
+ # @package _global_
2
+
3
+ common_eval:
4
+ results_path: ${decoding.exp_dir}/decode/${decoding.decoder.name}_ax/${dataset.gen_subset}
5
+
6
+ hydra:
7
+ sweeper:
8
+ ax_config:
9
+ max_trials: 60
10
+ early_stop:
11
+ minimize: true
12
+ max_epochs_without_improvement: 10
13
+ epsilon: 0.025
14
+ experiment:
15
+ name: ${dataset.gen_subset}
16
+ objective_name: wer
17
+ minimize: true
18
+ parameter_constraints: null
19
+ outcome_constraints: null
20
+ status_quo: null
21
+ client:
22
+ verbose_logging: false
23
+ random_seed: null
24
+ params:
25
+ decoding.decoder.lmweight:
26
+ type: range
27
+ bounds: [0.0, 8.0]
28
+ decoding.decoder.wordscore:
29
+ type: range
30
+ bounds: [-5.0, 5.0]
31
+ decoding.decoder.silweight:
32
+ type: range
33
+ bounds: [-10.0, 0.0]
fairseq/examples/hubert/config/decode/ax_sweep/transformer.yaml ADDED
@@ -0,0 +1,33 @@
1
+ # @package _global_
2
+
3
+ common_eval:
4
+ results_path: ${decoding.exp_dir}/decode/${decoding.decoder.name}_ax/${dataset.gen_subset}
5
+
6
+ hydra:
7
+ sweeper:
8
+ ax_config:
9
+ max_trials: 60
10
+ early_stop:
11
+ minimize: true
12
+ max_epochs_without_improvement: 10
13
+ epsilon: 0.025
14
+ experiment:
15
+ name: ${dataset.gen_subset}
16
+ objective_name: wer
17
+ minimize: true
18
+ parameter_constraints: null
19
+ outcome_constraints: null
20
+ status_quo: null
21
+ client:
22
+ verbose_logging: false
23
+ random_seed: null
24
+ params:
25
+ decoding.decoder.lmweight:
26
+ type: range
27
+ bounds: [0.0, 4.0]
28
+ decoding.decoder.wordscore:
29
+ type: range
30
+ bounds: [-5.0, 5.0]
31
+ decoding.decoder.silweight:
32
+ type: range
33
+ bounds: [-8.0, 0.0]
fairseq/examples/hubert/config/decode/infer_fsqlm.yaml ADDED
@@ -0,0 +1,36 @@
1
+ # @package _group_
2
+
3
+ defaults:
4
+ - model: null
5
+
6
+ hydra:
7
+ run:
8
+ dir: ${common_eval.results_path}/beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight}
9
+ sweep:
10
+ dir: ${common_eval.results_path}
11
+ subdir: beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight}
12
+
13
+ task:
14
+ _name: hubert_pretraining
15
+ single_target: true
16
+ fine_tuning: true
17
+ data: ???
18
+ normalize: ???
19
+
20
+ decoding:
21
+ type: fairseqlm
22
+ lexicon: ???
23
+ lmpath: ???
24
+ beamthreshold: 25
25
+ beam: 500
26
+ lmweight: 2
27
+ wordscore: -1
28
+ silweight: 0
29
+ unique_wer_file: true
30
+ common_eval:
31
+ results_path: ???
32
+ path: ???
33
+ post_process: letter
34
+ dataset:
35
+ max_tokens: 1100000
36
+ gen_subset: ???
fairseq/examples/hubert/config/decode/infer_kenlm.yaml ADDED
@@ -0,0 +1,36 @@
1
+ # @package _group_
2
+
3
+ defaults:
4
+ - model: null
5
+
6
+ hydra:
7
+ run:
8
+ dir: ${common_eval.results_path}/beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight}
9
+ sweep:
10
+ dir: ${common_eval.results_path}
11
+ subdir: beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight}
12
+
13
+ task:
14
+ _name: hubert_pretraining
15
+ single_target: true
16
+ fine_tuning: true
17
+ data: ???
18
+ normalize: ???
19
+
20
+ decoding:
21
+ type: kenlm
22
+ lexicon: ???
23
+ lmpath: ???
24
+ beamthreshold: 100
25
+ beam: 500
26
+ lmweight: 2
27
+ wordscore: -1
28
+ silweight: 0
29
+ unique_wer_file: true
30
+ common_eval:
31
+ results_path: ???
32
+ path: ???
33
+ post_process: letter
34
+ dataset:
35
+ max_tokens: 1100000
36
+ gen_subset: ???
fairseq/examples/hubert/config/decode/infer_viterbi.yaml ADDED
@@ -0,0 +1,29 @@
1
+ # @package _group_
2
+
3
+ defaults:
4
+ - model: null
5
+
6
+ hydra:
7
+ run:
8
+ dir: ${common_eval.results_path}/viterbi
9
+ sweep:
10
+ dir: ${common_eval.results_path}
11
+ subdir: viterbi
12
+
13
+ task:
14
+ _name: hubert_pretraining
15
+ single_target: true
16
+ fine_tuning: true
17
+ data: ???
18
+ normalize: ???
19
+
20
+ decoding:
21
+ type: viterbi
22
+ unique_wer_file: true
23
+ common_eval:
24
+ results_path: ???
25
+ path: ???
26
+ post_process: letter
27
+ dataset:
28
+ max_tokens: 1100000
29
+ gen_subset: ???
fairseq/examples/hubert/config/decode/run/submitit_slurm.yaml ADDED
@@ -0,0 +1,17 @@
1
+ # @package _global_
2
+ hydra:
3
+ launcher:
4
+ cpus_per_task: ${distributed_training.distributed_world_size}
5
+ gpus_per_node: ${distributed_training.distributed_world_size}
6
+ tasks_per_node: ${hydra.launcher.gpus_per_node}
7
+ nodes: 1
8
+ mem_gb: 200
9
+ timeout_min: 4320
10
+ max_num_timeout: 50
11
+ name: ${hydra.job.config_name}
12
+ submitit_folder: ${hydra.sweep.dir}/submitit
13
+
14
+ distributed_training:
15
+ distributed_world_size: 1
16
+ distributed_no_spawn: true
17
+ distributed_port: 29761
fairseq/examples/hubert/config/decode/run/submitit_slurm_8gpu.yaml ADDED
@@ -0,0 +1,17 @@
1
+ # @package _global_
2
+ hydra:
3
+ launcher:
4
+ cpus_per_task: ${distributed_training.distributed_world_size}
5
+ gpus_per_node: ${distributed_training.distributed_world_size}
6
+ tasks_per_node: ${hydra.launcher.gpus_per_node}
7
+ nodes: 1
8
+ mem_gb: 200
9
+ timeout_min: 4320
10
+ max_num_timeout: 50
11
+ name: ${hydra.job.config_name}
12
+ submitit_folder: ${hydra.sweep.dir}/submitit
13
+
14
+ distributed_training:
15
+ distributed_world_size: 8
16
+ distributed_no_spawn: true
17
+ distributed_port: 29761
fairseq/examples/hubert/config/finetune/base_10h.yaml ADDED
@@ -0,0 +1,100 @@
1
+ # @package _group_
2
+
3
+ common:
4
+ fp16: true
5
+ log_format: json
6
+ log_interval: 200
7
+ tensorboard_logdir: tblog
8
+ seed: 1337
9
+
10
+ checkpoint:
11
+ save_interval: 5
12
+ keep_interval_updates: 1
13
+ no_epoch_checkpoints: true
14
+ best_checkpoint_metric: wer
15
+
16
+ distributed_training:
17
+ ddp_backend: c10d
18
+ find_unused_parameters: true
19
+ distributed_world_size: 1
20
+ distributed_port: 29671
21
+ nprocs_per_node: 8
22
+
23
+ task:
24
+ _name: hubert_pretraining
25
+ data: ???
26
+ fine_tuning: true
27
+ label_dir: ???
28
+ normalize: false # must be consistent with pre-training
29
+ labels: ["ltr"]
30
+ single_target: true
31
+
32
+ dataset:
33
+ num_workers: 0
34
+ max_tokens: 3200000
35
+ validate_after_updates: ${model.freeze_finetune_updates}
36
+ validate_interval: 5
37
+ train_subset: train
38
+ valid_subset: valid
39
+
40
+ criterion:
41
+ _name: ctc
42
+ zero_infinity: true
43
+
44
+ optimization:
45
+ max_update: 25000
46
+ lr: [2e-5]
47
+ sentence_avg: true
48
+ update_freq: [1]
49
+
50
+ optimizer:
51
+ _name: adam
52
+ adam_betas: (0.9,0.98)
53
+ adam_eps: 1e-08
54
+
55
+ lr_scheduler:
56
+ _name: tri_stage
57
+ warmup_steps: 8000
58
+ hold_steps: 0
59
+ decay_steps: 72000
60
+ final_lr_scale: 0.05
61
+
62
+ model:
63
+ _name: hubert_ctc
64
+ w2v_path: ???
65
+ apply_mask: true
66
+ mask_selection: static
67
+ mask_length: 10
68
+ mask_other: 0
69
+ mask_prob: 0.75
70
+ mask_channel_selection: static
71
+ mask_channel_length: 64
72
+ mask_channel_other: 0
73
+ mask_channel_prob: 0.5
74
+ layerdrop: 0.1
75
+ dropout: 0.0
76
+ activation_dropout: 0.1
77
+ attention_dropout: 0.0
78
+ feature_grad_mult: 0.0
79
+ freeze_finetune_updates: 10000
80
+
81
+ hydra:
82
+ job:
83
+ config:
84
+ override_dirname:
85
+ kv_sep: '-'
86
+ item_sep: '__'
87
+ exclude_keys:
88
+ - run
89
+ - task.data
90
+ - task.label_dir
91
+ - model.w2v_path
92
+ - dataset.train_subset
93
+ - dataset.valid_subset
94
+ - criterion.wer_kenlm_model
95
+ - criterion.wer_lexicon
96
+ run:
97
+ dir: ???
98
+ sweep:
99
+ dir: ???
100
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
fairseq/examples/hubert/config/finetune/ckpt/it1.yaml ADDED
@@ -0,0 +1,7 @@
1
+ # @package _global_
2
+
3
+ task:
4
+ normalize: false
5
+
6
+ model:
7
+ w2v_path: /checkpoint/wnhsu/w2v/hubert_final/iter1/hubert.km.randcrop.pmw1_0.puw0_0.grpnorm.ml10.mp0_8.untie.mxsz250000.ufreq1.maxtok1400000.MU400k.s1337.ngpu32/checkpoint_last.pt
fairseq/examples/hubert/config/finetune/lm/ls_4gram.yaml ADDED
@@ -0,0 +1,7 @@
1
+ # @package _global_
2
+
3
+ criterion:
4
+ wer_kenlm_model: /checkpoint/abdo/old_checkpoint02/datasets/librispeech/4-gram.bin
5
+ wer_lexicon: /checkpoint/abdo/old_checkpoint02/datasets/librispeech/10h/raw/lexicon_ltr.lst
6
+ wer_lm_weight: 2.0
7
+ wer_word_score: -1.0
fairseq/examples/hubert/config/finetune/run/submitit_reg.yaml ADDED
@@ -0,0 +1,20 @@
1
+ # @package _global_
2
+
3
+ hydra:
4
+ launcher:
5
+ cpus_per_task: 8
6
+ gpus_per_node: 8
7
+ tasks_per_node: ${hydra.launcher.gpus_per_node}
8
+ nodes: 1
9
+ comment: null
10
+ mem_gb: 384
11
+ timeout_min: 4320
12
+ max_num_timeout: 100
13
+ constraint: volta32gb
14
+ name: ${hydra.job.config_name}/${hydra.job.override_dirname}
15
+ submitit_folder: ${hydra.sweep.dir}/submitit/%j
16
+
17
+ distributed_training:
18
+ distributed_world_size: 8
19
+ distributed_port: 29671
20
+ nprocs_per_node: 8
fairseq/examples/hubert/config/pretrain/data/iter1.yaml ADDED
@@ -0,0 +1,8 @@
1
+ # @package _global_
2
+
3
+ task:
4
+ label_dir: ???
5
+ labels: ["km"]
6
+
7
+ model:
8
+ label_rate: 100
fairseq/examples/hubert/config/pretrain/data/iter2.yaml ADDED
@@ -0,0 +1,8 @@
1
+ # @package _global_
2
+
3
+ task:
4
+ label_dir: ???
5
+ labels: ["km"]
6
+
7
+ model:
8
+ label_rate: 50