Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse | |
import os | |
import os.path as osp | |
import numpy as np | |
import tqdm | |
import torch | |
import sys | |
import faiss | |
import torch.nn.functional as F | |
from wav2vec_cluster_faiss import parse_faiss_specs, Wav2VecFeatureReader | |
def get_parser():
    """Build the command-line parser for the "apply clusters" script.

    Returns:
        argparse.ArgumentParser: a parser with one positional argument
        (``data``), three required options (``--split``, ``--path``,
        ``--checkpoint``) and three optional ones (``--labels``,
        ``--layer``/``-l``, ``--max-tsz``).
    """
    parser = argparse.ArgumentParser(description="apply clusters")
    # fmt: off
    parser.add_argument('data', help='location of tsv files')
    parser.add_argument('--split', help='split to process', required=True)
    # The value is used as a file extension: labels are read from
    # <data>/<split>.<labels> and written to <path>/<split>.<labels>.
    parser.add_argument('--labels', help='extension of the label file that accompanies the split', default="phn")
    parser.add_argument('--path', help='path to pca and centroids', required=True)
    parser.add_argument('--checkpoint', type=str, help='checkpoint for wav2vec model (if using wav2vec features)', required=True)
    parser.add_argument('--layer', '-l', type=int, help='which layer to read', default=14)
    # NOTE(review): --max-tsz is accepted but not read anywhere in the visible
    # part of this file; kept so existing command lines keep working.
    parser.add_argument('--max-tsz', type=int, help='batch kmeans up to this much', default=14)
    # fmt: on

    return parser
def get_iterator(args):
    """Prepare a lazy feature iterator over the files listed in a split manifest.

    The manifest ``<args.data>/<args.split>.tsv`` is read in full: its first
    line is taken as the root directory, and the first tab-separated field of
    every following non-empty line is joined onto that root to form a file
    path. If ``<args.data>/<args.split>.<args.labels>`` exists, its lines are
    paired with the manifest entries in order; otherwise every label is
    ``None``.

    Args:
        args: parsed command-line namespace; reads ``data``, ``split``,
            ``labels``, ``checkpoint`` and ``layer``.

    Returns:
        tuple: ``(iterate, num, root)`` where ``iterate`` is a zero-argument
        generator function yielding ``(feats.data, fname, lbl)`` per file,
        ``num`` is the number of manifest entries and ``root`` is the root
        directory string from the manifest's first line.
    """
    label_path = osp.join(args.data, f"{args.split}.{args.labels}")

    with open(osp.join(args.data, f"{args.split}.tsv"), "r") as fp:
        lines = fp.read().split("\n")

    root = lines.pop(0).strip()
    files = [line.rstrip() for line in lines if len(line) > 0]

    if osp.exists(label_path):
        # Read labels eagerly inside a context manager so the handle is
        # closed right away instead of leaking for the life of the process.
        with open(label_path, "r") as lp:
            lbls = [line.rstrip() for line in lp]
    else:
        lbls = [None] * len(files)

    num = len(files)
    reader = Wav2VecFeatureReader(args.checkpoint, args.layer)

    def iterate():
        # zip() stops at the shorter sequence, so a label file with fewer
        # lines than the manifest yields fewer than `num` items.
        for fname, lbl in zip(files, lbls):
            file = osp.join(root, fname.split("\t")[0])
            feats = reader.get_feats(file)
            yield feats.data, fname, lbl

    return iterate, num, root
def main():
    """Assign every frame of a split to its nearest trained centroid.

    Parses the command line, loads the optional PCA matrices and the
    centroids stored under ``--path``, builds a flat faiss index from the
    centroids on GPU 0, and then, for each file produced by
    ``get_iterator``, writes into the ``--path`` directory:

    * ``<split>.src`` - one line per file with space-separated ids taken
      from the index search result,
    * ``<split>.tsv`` - the root directory followed by the manifest lines,
    * ``<split>.<labels>`` - the labels, when the input split had any
      (the file is removed again if no label was written).

    Raises:
        Exception: whatever ``parse_faiss_specs`` raises for a directory
            name it cannot parse; the offending name is printed first.
    """
    parser = get_parser()
    args = parser.parse_args()

    # Strip trailing slashes *before* taking the basename: basename("dir/")
    # is "", so stripping afterwards would leave an empty spec string.
    spec = osp.basename(args.path.rstrip("/"))

    try:
        faiss_spec = parse_faiss_specs(spec)[0]
    except Exception:
        # Show which directory name failed to parse, then propagate.
        print(spec)
        raise

    print("Faiss Spec:", faiss_spec, file=sys.stderr)

    if faiss_spec.pca:
        A = torch.from_numpy(np.load(osp.join(args.path, "pca_A.npy"))).cuda()
        b = torch.from_numpy(np.load(osp.join(args.path, "pca_b.npy"))).cuda()
        print("Loaded PCA", file=sys.stderr)

    centroids = np.load(osp.join(args.path, "centroids.npy"))
    print("Loaded centroids", centroids.shape, file=sys.stderr)

    res = faiss.StandardGpuResources()
    # Inner-product index when the spec asks for "sphere", L2 otherwise.
    index_flat = (
        faiss.IndexFlatL2(centroids.shape[1])
        if not faiss_spec.sphere
        else faiss.IndexFlatIP(centroids.shape[1])
    )
    faiss_index = faiss.index_cpu_to_gpu(res, 0, index_flat)
    faiss_index.add(centroids)

    generator, num, root = get_iterator(args)
    iterator = generator()

    had_labels = False
    label_path = osp.join(args.path, f"{args.split}.{args.labels}")

    with torch.no_grad():
        with open(osp.join(args.path, f"{args.split}.src"), "w") as fp, open(
            osp.join(args.path, f"{args.split}.tsv"), "w"
        ) as pp, open(label_path, "w") as lp:
            print(root, file=pp)
            for f, fname, lbl in tqdm.tqdm(iterator, total=num):
                # Apply the same preprocessing the spec prescribes before
                # looking up centroids: optional PCA, then optional L2 norm.
                if faiss_spec.pca:
                    f = torch.mm(f, A) + b
                if faiss_spec.norm:
                    f = F.normalize(f, p=2, dim=-1)

                f = f.cpu().numpy()

                # k=1: keep only the single best match per row of `f`.
                _, z = faiss_index.search(f, 1)

                print(" ".join(str(x.item()) for x in z), file=fp)
                print(fname, file=pp)

                if lbl is not None:
                    print(lbl, file=lp)
                    had_labels = True

    # The label file is opened unconditionally above; drop it again (after it
    # has been closed) when the input split carried no labels.
    if not had_labels:
        os.remove(label_path)
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()