Spaces: Runtime error
| #!/usr/bin/env python3 -u | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import argparse | |
| import os | |
| import os.path as osp | |
| import math | |
| import numpy as np | |
| import tqdm | |
| import torch | |
| import torch.nn.functional as F | |
| from shutil import copyfile | |
| from npy_append_array import NpyAppendArray | |
def get_parser():
    """Build the command-line parser for the mean-pooling script.

    Returns:
        argparse.ArgumentParser with a positional ``source`` directory and the
        ``--split``, ``--save-dir``, ``--subsample-rate`` and ``--remove-extra``
        options.
    """
    parser = argparse.ArgumentParser(
        description="mean pools representations by compressing uniform splits of the data"
    )
    # fmt: off
    parser.add_argument('source', help='directory with features')
    parser.add_argument('--split', help='which split to read', required=True)
    parser.add_argument('--save-dir', help='where to save the output', required=True)
    # The rate is a fraction of frames, not an absolute size.
    parser.add_argument('--subsample-rate', type=float, default=0.5,
                        help='approximate fraction of frames to keep per utterance (e.g. 0.5 halves the length)')
    # The non-removal path repeats the last real frame rather than zero-padding.
    parser.add_argument('--remove-extra', action='store_true',
                        help="if true, drops trailing frames that don't fill a complete pooling window, "
                             "otherwise pads the last window by repeating the final frame")
    # fmt: on
    return parser
def _mean_pool(x, subsample_rate, remove_extra, fsz):
    """Mean-pool consecutive rows (frames) of ``x`` into uniform groups.

    ``target_num = ceil(len(x) * subsample_rate)`` groups are requested. When
    the rows split evenly, each group has ``len(x) // target_num`` rows.
    Otherwise every group spans ``ceil(len(x) / target_num)`` rows and the last
    group would be incomplete. That partial group is either dropped
    (``remove_extra``) or completed by repeating the last real row.

    Args:
        x: tensor whose first dimension is time. It is reshaped to
            ``(groups, -1, fsz)``, so it is presumably ``(frames, fsz)`` --
            confirm against the feature extractor.
        subsample_rate: requested fraction of frames to keep; must be > 0.
        remove_extra: drop the incomplete trailing group instead of padding it.
        fsz: size of the last (feature) dimension.

    Returns:
        Tuple ``(pooled, target_num)``: the pooled tensor of shape
        ``(target_num, fsz)`` and the number of pooled frames.
    """
    length = x.shape[0]
    target_num = math.ceil(length * subsample_rate)
    group, rem = divmod(length, target_num)
    if rem > 0:
        # Uneven split: each group covers one row more than length // target_num.
        group += 1
        if remove_extra:
            # Keep only complete groups. The previous formula (drop
            # target_num - rem rows, keep target_num - 1 groups) was not always
            # evenly divisible and crashed in view() (e.g. length=10, rate=0.4).
            # For subsample_rate == 0.5 both formulas give the same result.
            target_num = length // group
            x = x[: target_num * group]
        else:
            # Pad the time axis up to target_num * group rows and fill the
            # padding with copies of the last real row.
            to_add = target_num * group - length
            x = F.pad(x, [0, 0, 0, to_add])
            x[-to_add:] = x[-to_add - 1]
    pooled = x.reshape(target_num, -1, fsz).mean(dim=-2)
    return pooled, target_num


def main():
    """Mean-pool the features of one split and write them to ``--save-dir``.

    Reads ``<source>/<split>.npy`` (memory-mapped, read-only) and
    ``<source>/<split>.lengths`` (one frame count per line).

    For each utterance it pools the frames with ``_mean_pool`` and appends the
    result to ``<save-dir>/<split>.npy``. It writes the pooled frame counts to
    ``<save-dir>/<split>.lengths``.

    ``<split>.tsv`` is always copied. ``<split>.phn``, ``<split>.wrd`` and
    ``dict.phn.txt`` are copied when present.
    """
    parser = get_parser()
    args = parser.parse_args()
    if args.subsample_rate <= 0:
        # A non-positive rate yields zero/negative group counts below.
        parser.error("--subsample-rate must be positive")

    source_path = osp.join(args.source, args.split)
    print(f"data path: {source_path}")
    features = np.load(source_path + ".npy", mmap_mode="r")

    os.makedirs(args.save_dir, exist_ok=True)
    save_path = osp.join(args.save_dir, args.split)
    copyfile(source_path + ".tsv", save_path + ".tsv")
    for ext in (".phn", ".wrd"):
        if osp.exists(source_path + ext):
            copyfile(source_path + ext, save_path + ext)
    dict_path = osp.join(args.source, "dict.phn.txt")
    if osp.exists(dict_path):
        copyfile(dict_path, osp.join(args.save_dir, "dict.phn.txt"))

    if osp.exists(save_path + ".npy"):
        # NpyAppendArray appends to an existing file, so start fresh on re-runs.
        os.remove(save_path + ".npy")
    npaa = NpyAppendArray(save_path + ".npy")

    with open(source_path + ".lengths", "r") as lf:
        # Skip blank lines (e.g. a trailing empty line), which int() rejects.
        lengths = [int(line) for line in lf if line.strip()]

    fsz = features.shape[-1]
    # Fall back to CPU so the script also runs on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    start = 0
    with torch.no_grad():
        with open(save_path + ".lengths", "w") as lengths_out:
            for length in tqdm.tqdm(lengths):
                feats = features[start:start + length]
                start += length
                if length == 0:
                    # ceil(0 * rate) == 0 would divide by zero; keep one output
                    # line per utterance but append no frames.
                    print(0, file=lengths_out)
                    continue
                # Copy out of the read-only memmap; torch.from_numpy warns on
                # non-writable arrays.
                x = torch.from_numpy(np.array(feats)).to(device)
                pooled, target_num = _mean_pool(
                    x, args.subsample_rate, args.remove_extra, fsz
                )
                print(target_num, file=lengths_out)
                npaa.append(pooled.cpu().numpy())
if __name__ == "__main__":
    # Run the pooling pipeline only when executed as a script, not when imported.
    main()