# Spaces: Running on Zero
import os
import pathlib
import sys

# Make the current working directory importable so the local
# `wav_evaluation` package resolves when this file is run as a script.
directory = pathlib.Path(os.getcwd())
sys.path.append(str(directory))

import argparse
import csv
import json

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from wav_evaluation.models.CLAPWrapper import CLAPWrapper
def parse_args():
    """Parse command-line arguments for the CLAP-score evaluation script.

    Returns:
        argparse.Namespace with csv_path (str, '' means "derive from wavsdir"),
        wavsdir (str or None), mean (bool), ckpt_path (str).
    """
    def str2bool(value):
        # BUG FIX: argparse `type=bool` treats ANY non-empty string as True
        # (bool("False") is True), so `--mean False` silently became True.
        # Parse the common spellings explicitly instead.
        if isinstance(value, bool):
            return value
        lowered = value.lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(f'boolean value expected, got {value!r}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--csv_path', type=str, default='')
    parser.add_argument('--wavsdir', type=str)
    parser.add_argument('--mean', type=str2bool, default=True)
    parser.add_argument('--ckpt_path', default="useful_ckpts/CLAP")
    args = parser.parse_args()
    return args
def add_audio_path(df):
    """Add an 'audio_path' column derived from 'mel_path' ('.npy' -> '.wav').

    Mutates *df* in place and also returns it for convenient chaining.
    """
    df['audio_path'] = df['mel_path'].map(lambda path: path.replace('.npy', '.wav'))
    return df
def build_csv_from_wavs(root_dir):
    """Build a tab-separated (audio_path, caption) csv for wavs under root_dir/fake_class.

    Captions are looked up in ldm/data/audiocaps_fn2cap.json (relative to cwd)
    using a key rebuilt from the generated-sample filename. Returns the path of
    the csv that was written (inside the fake_class directory).
    """
    with open('ldm/data/audiocaps_fn2cap.json', 'r') as f:
        fn2cap = json.load(f)
    wavs_root = os.path.join(root_dir, 'fake_class')
    # Keep generated samples only: must be .wav and not a ground-truth clip
    # (ground-truth files end in 'gt.wav').
    wavfiles = [name for name in os.listdir(wavs_root)
                if name.endswith('.wav') and name[-6:-4] != 'gt']
    print(len(wavfiles))
    records = []
    for name in wavfiles:
        pieces = name.rsplit('_sample')
        # Caption key = stem before '_sample' + first two chars after it
        # (matches how the generation script named its outputs).
        key = pieces[0] + pieces[1][:2]
        records.append({
            'audio_path': os.path.join(wavs_root, name),
            'caption': fn2cap[key],
        })
    df = pd.DataFrame.from_dict(records)
    csv_path = os.path.join(wavs_root, f'{os.path.basename(root_dir)}.csv')
    df.to_csv(csv_path, sep='\t', index=False)
    return csv_path
def cal_score_by_csv(csv_path, clap_model):  # CLAP score of audiocaps val ground-truth audio computes to 0.479077
    """Return the mean CLAP similarity over a tab-separated (audio_path, caption) csv.

    Rows are scored in batches of 20 for throughput; the similarity matrix
    diagonal pairs each audio with its own caption.

    Args:
        csv_path: tab-separated csv with 'caption' and either 'audio_path'
            or 'mel_path' (converted via add_audio_path) columns.
        clap_model: object exposing get_text_embeddings / get_audio_embeddings /
            compute_similarity (presumably a CLAPWrapper — confirm at call site).
    """
    df = pd.read_csv(csv_path, sep='\t')
    if 'audio_path' not in df.columns:
        df = add_audio_path(df)
    clap_scores = []
    caption_list, audio_list = [], []

    def _flush():
        # Score the accumulated batch, then reset the buffers in place.
        if not caption_list:
            return
        text_embeddings = clap_model.get_text_embeddings(caption_list)  # normalized embeddings
        # This step is the slow part: loads each audio file and resamples to 44100.
        audio_embeddings = clap_model.get_audio_embeddings(audio_list, resample=True)
        score_mat = clap_model.compute_similarity(audio_embeddings, text_embeddings, use_logit_scale=False)
        clap_scores.append(score_mat.diagonal().cpu().numpy())
        caption_list.clear()
        audio_list.clear()

    with torch.no_grad():
        for idx, t in enumerate(tqdm(df.itertuples()), start=1):
            caption_list.append(getattr(t, 'caption'))
            audio_list.append(getattr(t, 'audio_path'))
            if idx % 20 == 0:
                _flush()
        # BUG FIX: rows past the last full batch of 20 (len(df) % 20 != 0)
        # were previously never scored, silently skewing the mean.
        _flush()
    if not clap_scores:
        return float('nan')
    # np.concatenate handles the ragged final batch; np.array(...).flatten()
    # would produce an object array when batch sizes differ.
    return float(np.mean(np.concatenate(clap_scores)))
def add_clap_score_to_csv(csv_path, clap_model):
    """Score every (caption, audio_path) row and write '<csv_path stem>_clap.csv'.

    The output csv is the input plus a per-row 'clap_score' float column.

    Args:
        csv_path: tab-separated csv containing 'caption' and 'audio_path' columns.
        clap_model: object exposing get_text_embeddings / get_audio_embeddings /
            compute_similarity (presumably a CLAPWrapper — confirm at call site).
    """
    df = pd.read_csv(csv_path, sep='\t')
    # BUG FIX: previously scores were collected in a dict keyed from 1
    # (enumerate start=1) and assigned to the 0-indexed DataFrame, misaligning
    # every row. An ordered list stays aligned with df's rows by construction.
    clap_scores = []
    with torch.no_grad():
        for t in tqdm(df.itertuples()):
            text_embeddings = clap_model.get_text_embeddings([getattr(t, 'caption')])  # normalized embeddings
            audio_embeddings = clap_model.get_audio_embeddings([getattr(t, 'audio_path')], resample=True)
            score = clap_model.compute_similarity(audio_embeddings, text_embeddings, use_logit_scale=False)
            # Store a plain float so the csv cell is "0.47..." rather than "[[0.47]]".
            clap_scores.append(float(score.squeeze().cpu().numpy()))
    df['clap_score'] = clap_scores
    df.to_csv(csv_path[:-4] + '_clap.csv', sep='\t', index=False)
if __name__ == '__main__':
    args = parse_args()
    if args.csv_path:
        csv_path = args.csv_path
    else:
        # No csv supplied: expect one next to the generated wavs.
        csv_path = os.path.join(args.wavsdir, 'fake_class/result.csv')
    if not os.path.exists(csv_path):
        print("result csv not exist,build for it")
        csv_path = build_csv_from_wavs(args.wavsdir)
    clap_model = CLAPWrapper(
        os.path.join(args.ckpt_path, 'CLAP_weights_2022.pth'),
        os.path.join(args.ckpt_path, 'config.yml'),
        use_cuda=True,
    )
    clap_score = cal_score_by_csv(csv_path, clap_model)
    out = args.wavsdir if args.wavsdir else args.csv_path
    # BUG FIX: the result line was copy-pasted and printed three times.
    print(f"clap_score for {out} is:{clap_score}")
    # os.remove(csv_path)