|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import gc |
|
import os |
|
import os.path as osp |
|
import random |
|
import numpy as np |
|
import tqdm |
|
import torch |
|
|
|
from collections import namedtuple |
|
|
|
import faiss |
|
|
|
import fairseq |
|
import soundfile as sf |
|
|
|
|
|
def get_parser(): |
|
parser = argparse.ArgumentParser( |
|
description="compute kmeans codebook from kaldi-computed feats" |
|
) |
|
|
|
parser.add_argument('data', help='location of tsv files') |
|
parser.add_argument('--save-dir', help='where to save the output', required=True) |
|
parser.add_argument('--checkpoint', type=str, help='checkpoint for wav2vec model (if using wav2vec features)', required=True) |
|
parser.add_argument('--sample-pct', '-r', type=float, help='percentage of timesteps to sample', default=0) |
|
parser.add_argument('--layer', '-l', type=int, help='which layer to read', default=14) |
|
parser.add_argument('--faiss-specs', '-f', type=str, |
|
help='faiss index specs; separated by space ' |
|
'format is: PCAx_NORM_CLUSx_SPHERICAL -> ' |
|
'PCAx if exists first apply PCA ' |
|
'NORM if exists, normalize the vector by L2 norm ' |
|
'CLUSx must exist, cluster to x clusters ' |
|
'SPEHRICAL if exists, apply spherical kmeans', |
|
default='l2') |
|
|
|
|
|
return parser |
|
|
|
|
|
# Parsed form of one faiss index spec string (see --faiss-specs help).
faiss_spec = namedtuple("faiss_spec", ["pca", "norm", "n_clus", "sphere", "spec_str"])


def parse_faiss_specs(specs_str):
    """Parse a space-separated list of faiss index spec strings.

    Each spec is a string of underscore-separated components:
      PCAx      - apply PCA down to x dimensions first
      NORM      - L2-normalize the vectors
      CLUSx     - cluster into x clusters (required)
      SPHERICAL - use spherical k-means

    Returns:
        list of faiss_spec namedtuples, one per spec string.

    Raises:
        ValueError: if a spec lacks a CLUSx component with x > 0.
    """
    specs = []
    for ss in specs_str.split():
        pca = 0
        norm = False
        n_clus = 0
        sphere = False
        for comp in ss.split("_"):
            if comp.startswith("PCA"):
                pca = int(comp[3:])
            elif comp == "NORM":
                norm = True
            elif comp.startswith("CLUS"):
                n_clus = int(comp[4:])
            elif comp == "SPHERICAL":
                sphere = True
        # Was `assert n_clus > 0`: asserts are stripped under `python -O`,
        # so validate bad user input explicitly instead.
        if n_clus <= 0:
            raise ValueError(
                f"faiss spec {ss!r} must contain a CLUSx component with x > 0"
            )
        specs.append(
            faiss_spec(pca=pca, norm=norm, n_clus=n_clus, sphere=sphere, spec_str=ss)
        )
    return specs
|
|
|
|
|
class Wav2VecFeatureReader(object):
    """Loads a wav2vec checkpoint and extracts per-frame features from one layer."""

    def __init__(self, cp_file, layer):
        # Checkpoint is loaded on CPU; the model is moved to GPU afterwards.
        state = fairseq.checkpoint_utils.load_checkpoint_to_cpu(cp_file)
        self.layer = layer

        # Newer checkpoints carry a structured "cfg"; older ones a flat "args".
        if "cfg" in state:
            cfg = state["cfg"]
            task = fairseq.tasks.setup_task(cfg.task)
            model = task.build_model(cfg.model)
        else:
            cfg = state["args"]
            task = fairseq.tasks.setup_task(cfg)
            model = task.build_model(cfg)
        model.load_state_dict(state["model"], strict=True)
        model.eval()
        model.cuda()
        self.model = model

    def read_audio(self, fname):
        """Load an audio file and return its PCM samples (must be 16 kHz)."""
        wav, sr = sf.read(fname)
        assert sr == 16e3
        return wav

    def get_feats(self, loc):
        """Run the model on the audio at ``loc`` and return the configured
        layer's activations as a (time, dim) tensor."""
        audio = self.read_audio(loc)
        with torch.no_grad():
            source = torch.from_numpy(audio).view(1, -1).float().cuda()
            res = self.model(
                source=source, mask=False, features_only=True, layer=self.layer
            )
            return res["layer_results"][self.layer][0].squeeze(1)
|
|
|
|
|
def get_iterator(args):
    """Build a feature generator over the tsv manifest at ``args.data``.

    Returns a ``(generator_fn, num_files)`` pair; calling ``generator_fn()``
    yields one numpy feature matrix per audio file.
    """
    with open(args.data, "r") as fp:
        manifest = fp.read().split("\n")

    # The first manifest line is the audio root directory; remaining rows are
    # tab-separated, with the relative file path in the first column.
    root = manifest.pop(0).strip()
    paths = [osp.join(root, row.split("\t")[0]) for row in manifest if row]

    if getattr(args, "sample_pct", 0) > 0:
        paths = random.sample(paths, int(args.sample_pct * len(paths)))
    num = len(paths)
    reader = Wav2VecFeatureReader(args.checkpoint, args.layer)

    def iterate():
        for path in paths:
            yield reader.get_feats(path).cpu().numpy()

    return iterate, num
|
|
|
|
|
def main():
    """Extract (or load cached) wav2vec features and run faiss clustering
    once per requested spec, saving PCA matrices and centroids per spec."""
    parser = get_parser()
    args = parser.parse_args()

    faiss_specs = parse_faiss_specs(args.faiss_specs)
    print("Faiss Specs:", faiss_specs)

    feat_path = osp.join(args.save_dir, "features")
    if osp.exists(feat_path + ".npy"):
        feats = np.load(feat_path + ".npy")
    else:
        generator, num = get_iterator(args)
        iterator = generator()

        feats = []
        for f in tqdm.tqdm(iterator, total=num):
            feats.append(f)

        del iterator
        del generator

        feats = np.concatenate(feats)

        print(feats.shape)

        os.makedirs(args.save_dir, exist_ok=True)
        # Bug fix: persist the raw features. The per-spec loop below reloads
        # them from disk after in-place L2 normalization; without this save,
        # the reload branch crashed with a missing-file error on a fresh run.
        np.save(feat_path, feats)

        gc.collect()
        torch.cuda.empty_cache()

    reload = False
    for spec in faiss_specs:
        print("Processing spec", spec)

        if reload:
            # faiss.normalize_L2 mutated `feats` in place on the previous
            # iteration; restore the untouched features from disk.
            print("Reloading...")
            del feats
            gc.collect()
            feats = np.load(feat_path + ".npy")

        save_path = osp.join(args.save_dir, spec.spec_str)
        os.makedirs(save_path, exist_ok=True)
        d = feats.shape[-1]
        x = feats
        if spec.pca > 0:
            print("Computing PCA")
            pca = faiss.PCAMatrix(d, spec.pca)
            pca.train(x)
            d = spec.pca
            b = faiss.vector_to_array(pca.b)
            A = faiss.vector_to_array(pca.A).reshape(pca.d_out, pca.d_in)
            # Saved transposed so downstream code can compute feats @ pca_A + pca_b.
            np.save(osp.join(save_path, "pca_A"), A.T)
            np.save(osp.join(save_path, "pca_b"), b)
            print("Applying PCA")
            x = pca.apply_py(x)

        if spec.norm:
            # If PCA ran, `x` is a fresh array and `feats` is untouched;
            # otherwise normalization mutates `feats`, so reload next time.
            reload = spec.pca <= 0
            print("Normalizing")
            faiss.normalize_L2(x)

        print("Computing kmeans")
        kmeans = faiss.Kmeans(
            d,
            spec.n_clus,
            niter=50,
            verbose=True,
            spherical=spec.sphere,
            max_points_per_centroid=feats.shape[0],
            gpu=True,
            nredo=3,
        )
        kmeans.train(x)
        np.save(osp.join(save_path, "centroids"), kmeans.centroids)
        del kmeans
        del x
        gc.collect()


if __name__ == "__main__":
    main()
|
|