# Uploaded by lmzjms ("Upload 1162 files", commit 0b32ad6).
import os
import torch
import random
import argparse
import numpy as np
from s3prl import hub
SAMPLE_RATE = 16000
BATCH_SIZE = 8


def build_parser():
    """Create the command-line parser for this script.

    Flags (unchanged from the original script):
        --upstream/-u  name of an entry point looked up on ``s3prl.hub`` (required)
        --output/-o    destination path handed to ``torch.save`` (required)
        --device       torch device string; defaults to "cuda" when available,
                       otherwise "cpu" (the original always defaulted to "cuda",
                       which crashed on CPU-only hosts)
        --seed         integer seed for Python, NumPy and PyTorch RNGs
    """
    parser = argparse.ArgumentParser(
        description=(
            "Run an s3prl upstream on seeded random waveforms and save "
            "its last hidden state."
        )
    )
    parser.add_argument(
        "--upstream", "-u", required=True,
        help="name of an upstream entry point in s3prl.hub",
    )
    parser.add_argument(
        "--output", "-o", required=True,
        help="path the hidden-state tensor is saved to with torch.save",
    )
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    parser.add_argument("--seed", type=int, default=0)
    return parser


def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: integer seed applied to every RNG.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # These flags only affect cuDNN kernels; setting them without CUDA is a no-op.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def make_random_wavs(batch_size=BATCH_SIZE, sample_rate=SAMPLE_RATE, device="cpu",
                     min_seconds=1, max_seconds=15):
    """Return ``batch_size`` 1-D Gaussian-noise tensors of random length.

    Each length is drawn uniformly (inclusive) from
    ``[sample_rate * min_seconds, sample_rate * max_seconds]`` samples.

    Args:
        batch_size: number of waveforms to create.
        sample_rate: samples per second used to convert seconds to samples.
        device: device the tensors are moved to after creation.
        min_seconds: shortest allowed duration in seconds.
        max_seconds: longest allowed duration in seconds.

    Returns:
        A list of 1-D float tensors on ``device``.
    """
    # The length (Python RNG) and the samples (torch CPU RNG) are drawn in
    # turn for each waveform, and the noise is generated on CPU before being
    # moved. Keeping this order and generator makes seeded outputs
    # identical regardless of the target device.
    return [
        torch.randn(
            random.randint(sample_rate * min_seconds, sample_rate * max_seconds)
        ).to(device)
        for _ in range(batch_size)
    ]


def main(argv=None):
    """Run the selected upstream on random waveforms and save the result.

    The upstream's ``"last_hidden_state"`` output is moved to CPU and written to
    ``--output`` with ``torch.save``. The parent directory is created if needed.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    # Fail fast with a readable CLI error instead of an AttributeError later.
    upstream_fn = getattr(hub, args.upstream, None)
    if upstream_fn is None:
        parser.error(f"unknown upstream {args.upstream!r}: not found in s3prl.hub")

    seed_everything(args.seed)

    # Draw inputs before building the model. Building the upstream may itself
    # consume torch RNG state, so this order keeps the inputs reproducible and
    # identical to the original script's.
    wavs = make_random_wavs(BATCH_SIZE, SAMPLE_RATE, args.device)

    upstream = upstream_fn().to(args.device)
    upstream.eval()
    with torch.no_grad():
        # No autograd graph is built under no_grad, so no detach() is needed.
        hidden = upstream(wavs)["last_hidden_state"].cpu()

    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(hidden, args.output)


if __name__ == "__main__":
    main()