ShiromiyaGamer commited on
Commit
339a64c
1 Parent(s): 682794f

Upload dump_w2v2_feature.py

Browse files
Files changed (1) hide show
  1. dump_w2v2_feature.py +94 -0
dump_w2v2_feature.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import os
8
+ import sys
9
+
10
+ import fairseq
11
+ import soundfile as sf
12
+ import torch
13
+ import torch.nn.functional as F
14
+
15
+ from feature_utils import get_path_iterator, dump_feature
16
+
17
+
18
+ logging.basicConfig(
19
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
20
+ datefmt="%Y-%m-%d %H:%M:%S",
21
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
22
+ stream=sys.stdout,
23
+ )
24
+ logger = logging.getLogger("dump_w2v2_feature")
25
+
26
+
27
+ class Wav2Vec2FeatureReader(object):
28
+ def __init__(self, ckpt_path, layer, max_chunk=1600000):
29
+ (
30
+ model,
31
+ cfg,
32
+ task,
33
+ ) = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
34
+ self.model = model[0].eval().cuda()
35
+ self.task = task
36
+ self.layer = layer # assume this is 1-based like HuBERT
37
+ self.max_chunk = max_chunk
38
+ logger.info(f"TASK CONFIG:\n{self.task.cfg}")
39
+ logger.info(f" max_chunk = {self.max_chunk}")
40
+ logger.info(f" model:\n{self.model}")
41
+
42
+ def read_audio(self, path, ref_len=None):
43
+ wav, sr = sf.read(path)
44
+ assert sr == self.task.cfg.sample_rate, sr
45
+ if wav.ndim == 2:
46
+ wav = wav.mean(-1)
47
+ assert wav.ndim == 1, wav.ndim
48
+ if ref_len is not None and abs(ref_len - len(wav)) > 160:
49
+ logging.warning(f"ref {ref_len} != read {len(wav)} ({path})")
50
+ return wav
51
+
52
+ def get_feats(self, path, ref_len=None):
53
+ x = self.read_audio(path, ref_len)
54
+ with torch.no_grad():
55
+ x = torch.from_numpy(x).float().cuda()
56
+ if self.task.cfg.normalize:
57
+ x = F.layer_norm(x, x.shape)
58
+ x = x.view(1, -1)
59
+
60
+ feat = []
61
+ for start in range(0, x.size(1), self.max_chunk):
62
+ x_chunk = x[:, start: start + self.max_chunk]
63
+ res = self.model.extract_features(
64
+ source=x_chunk,
65
+ padding_mask=None,
66
+ mask=False,
67
+ )
68
+ feat_chunk = res["x"]
69
+ feat.append(feat_chunk)
70
+ return torch.cat(feat, 1).squeeze(0)
71
+
72
+
73
+ def main(tsv_dir, split, ckpt_path, layer, nshard, rank, feat_dir, max_chunk):
74
+ reader = Wav2Vec2FeatureReader(ckpt_path, layer, max_chunk)
75
+ generator, num = get_path_iterator(f"{tsv_dir}/{split}.tsv", nshard, rank)
76
+ dump_feature(reader, generator, num, split, nshard, rank, feat_dir)
77
+
78
+
79
+ if __name__ == "__main__":
80
+ import argparse
81
+
82
+ parser = argparse.ArgumentParser()
83
+ parser.add_argument("tsv_dir")
84
+ parser.add_argument("split")
85
+ parser.add_argument("ckpt_path")
86
+ parser.add_argument("layer", type=int)
87
+ parser.add_argument("nshard", type=int)
88
+ parser.add_argument("rank", type=int)
89
+ parser.add_argument("feat_dir")
90
+ parser.add_argument("--max_chunk", type=int, default=1600000)
91
+ args = parser.parse_args()
92
+ logger.info(args)
93
+
94
+ main(**vars(args))