# lmzjms's picture
# Upload 46 files
# 8121fee
# raw
# history blame
# No virus
# 1.63 kB
import json
import random
import argparse
import numpy as np
from tqdm import tqdm
from h5py import File
import sklearn.metrics
# Seed Python's RNG so the random caption choice made later in the script
# is reproducible from run to run.
random.seed(1)

# Command-line interface: four required positional string arguments,
# registered in the order they must be given on the command line.
parser = argparse.ArgumentParser()
for positional in (
    "train_feature",
    "train_corpus",
    "pred_feature",
    "output_json",
):
    parser.add_argument(positional, type=str)
args = parser.parse_args()
# Read every entry of the training-feature file opened through h5py.
# Embeddings and their keys go into two parallel lists, so position i in
# one list corresponds to position i in the other.
train_embs, train_idx_to_audioid = [], []
with File(args.train_feature, "r") as train_store:
    for key, dataset in tqdm(train_store.items(), ascii=True):
        # Indexing with an empty tuple loads the entry's full contents.
        train_embs.append(dataset[()])
        train_idx_to_audioid.append(key)
# Load the caption annotations of the training corpus.  The file is opened
# in a `with` block so its handle is closed deterministically, and decoded
# explicitly as UTF-8 rather than with the platform's locale encoding.
with open(args.train_corpus, "r", encoding="utf-8") as corpus_file:
    train_annotation = json.load(corpus_file)["audios"]

# Map each audio_id to the list of "tokens" values of all its captions.
# If an audio_id occurs more than once, the last occurrence wins.
train_audioid_to_tokens = {
    item["audio_id"]: [cap_item["tokens"] for cap_item in item["captions"]]
    for item in train_annotation
}

# Stack the collected training embeddings along a new first axis, so that
# row i of the array corresponds to train_idx_to_audioid[i].
train_embs = np.stack(train_embs)
# Prediction records are accumulated in this list further down the script.
pred_data = []

# Read every entry of the prediction-feature file opened through h5py,
# keeping embeddings and their keys in two parallel lists.
pred_embs, pred_idx_to_audioids = [], []
with File(args.pred_feature, "r") as pred_store:
    for key, dataset in tqdm(pred_store.items(), ascii=True):
        # Indexing with an empty tuple loads the entry's full contents.
        pred_embs.append(dataset[()])
        pred_idx_to_audioids.append(key)

# Stack along a new first axis: row i corresponds to pred_idx_to_audioids[i].
pred_embs = np.stack(pred_embs)
# Cosine similarity between every prediction embedding (rows) and every
# training embedding (columns).
similarity = sklearn.metrics.pairwise.cosine_similarity(pred_embs, train_embs)

# For each prediction item, find the training item with the highest
# similarity and reuse one of its captions, picked at random, as the
# predicted tokens.
nearest_train_idx = similarity.argmax(axis=1)
for pred_id, best_idx in zip(pred_idx_to_audioids, nearest_train_idx):
    candidates = train_audioid_to_tokens[train_idx_to_audioid[best_idx]]
    pred_data.append({
        "filename": pred_id,
        "tokens": random.choice(candidates),
    })
# Write the predictions as indented JSON.  The `with` block guarantees the
# file is flushed and closed, and UTF-8 is requested explicitly because
# ensure_ascii=False writes non-ASCII characters unescaped.
with open(args.output_json, "w", encoding="utf-8") as output_file:
    json.dump(
        {"predictions": pred_data},
        output_file,
        ensure_ascii=False,
        indent=4,
    )