import argparse
import os
import random

import numpy as np
import torch
from s3prl import hub
# Multiplier used to size the synthetic waveforms (samples per second).
SAMPLE_RATE = 16000
# Number of synthetic waveforms fed to the upstream in one call.
BATCH_SIZE = 8


def build_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description=(
            "Run an s3prl upstream on seeded random waveforms and save its "
            "'last_hidden_state' output."
        )
    )
    parser.add_argument(
        "--upstream", "-u", required=True,
        help="name of the entry to look up on s3prl.hub",
    )
    parser.add_argument(
        "--output", "-o", required=True,
        help="path the resulting tensor is written to with torch.save",
    )
    parser.add_argument(
        "--device", default=None,
        help="torch device; defaults to 'cuda' when available, else 'cpu'",
    )
    parser.add_argument(
        "--seed", type=int, default=0,
        help="seed applied to random, numpy and torch (default: 0)",
    )
    return parser


def resolve_device(requested=None):
    """Return *requested* if given, else 'cuda' when available, else 'cpu'."""
    if requested:
        return requested
    return "cuda" if torch.cuda.is_available() else "cpu"


def seed_everything(seed):
    """Seed the Python, NumPy and torch RNGs and request deterministic cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # These flags only influence cuDNN kernels; setting them is inert on CPU.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def make_wavs(batch_size=BATCH_SIZE, min_seconds=1, max_seconds=15, device="cpu"):
    """Return *batch_size* random 1-D tensors of random length.

    Each length is drawn with ``random.randint`` from
    ``[SAMPLE_RATE * min_seconds, SAMPLE_RATE * max_seconds]`` (inclusive).
    Samples are drawn on the CPU and then moved to *device*, so the values
    depend only on the seeded CPU generators, not on the target device.
    """
    return [
        torch.randn(
            random.randint(SAMPLE_RATE * min_seconds, SAMPLE_RATE * max_seconds)
        ).to(device)
        for _ in range(batch_size)
    ]


def main(argv=None):
    """Parse *argv*, run the upstream on the random batch and save the output.

    *argv* defaults to ``sys.argv[1:]`` (argparse behaviour) when ``None``.
    """
    args = build_parser().parse_args(argv)
    device = resolve_device(args.device)

    # Order matters for reproducibility: seed first, then draw the inputs,
    # and only then build the model (construction may consume RNG state).
    seed_everything(args.seed)
    wavs = make_wavs(device=device)

    upstream = getattr(hub, args.upstream)().to(device)
    upstream.eval()

    with torch.no_grad():
        hidden = upstream(wavs)["last_hidden_state"].detach().cpu()

    # Create the target directory up front so a missing folder does not
    # discard the result of the forward pass.
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(hidden, args.output)


if __name__ == "__main__":
    main()