import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from open_clip import create_model, tokenize
import glob
import json
import librosa
from tqdm import tqdm
import numpy as np
import os
from laion_clap.training.params import parse_args
def get_output_from_single_audio(audio, text, model, device):
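    """Embed one audio clip and its captions with the CLAP model.

    The model forward is called once per modality (the other input passed as None);
    features are L2-normalized and then projected through the audio/text MLP heads.
    Returns both feature sets, their MLP projections, and the two learned logit scales.
    """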
# audio_embedding = model.audio_infer(audio, hopsize=5 * 48000, key="embedding", device=device)['embedding']
# if audio_embedding.ndim > 1:
# audio_embedding = audio_embedding.mean(dim=0, keepdim=True)
# else:
# audio_embedding = audio_embedding.unsqueeze(0)
audio_features = model(audio, None, device)
audio_features = F.normalize(audio_features, dim=-1)
text_features = model(None, text, device=device)
text_features = F.normalize(text_features, dim=-1)
    # TODO: decide whether the MLP projections should receive features before or after normalization
audio_features_mlp = model.audio_transform(audio_features)
text_features_mlp = model.text_transform(text_features)
return audio_features, text_features, audio_features_mlp, text_features_mlp, model.logit_scale_a.exp(), model.logit_scale_t.exp()
def get_metrics(text_to_audio_logits):
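    """Compute text-to-audio retrieval metrics (mean/median rank, R@k, mAP@10).

    `text_to_audio_logits` has shape (num_texts, num_audios); every block of 5
    consecutive captions belongs to the same Clotho audio clip.
    """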
metrics = {}
    # Clotho provides 5 captions per audio clip, so the ground-truth audio index repeats 5 times
    ground_truth = torch.repeat_interleave(torch.arange(text_to_audio_logits.shape[0] // 5), 5).view(-1, 1)
ranking = torch.argsort(text_to_audio_logits, descending=True)
preds = torch.where(ranking == ground_truth)[1] # (yusong) this line is slow because it uses single thread
preds = preds.detach().cpu().numpy()
    metrics["mean_rank"] = preds.mean() + 1
    metrics["median_rank"] = np.floor(np.median(preds)) + 1
for k in [1, 5, 10]:
metrics[f"R@{k}"] = np.mean(preds < k)
    # mAP@10: reciprocal rank when the correct audio is within the top 10, else 0
    metrics["mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0))
return metrics
if __name__ == '__main__':
args = parse_args()
model_path = args.pretrained
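    # author-specific path to the preprocessed Clotho test split
    # (one .flac per clip plus a matching .json caption file)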
clotho_test_preprocessed_dir = "/fsx/yusong/clotho_test_set/test"
cudnn.benchmark = True
cudnn.deterministic = False
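    # accumulators named for checkpoint ensembling; they are not used further in this script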
audio_features_ensemble_all = []
text_features_ensemble_all = []
audio_features_mlp_ensemble_all = []
text_features_mlp_ensemble_all = []
logit_scale_a_ensemble_all = []
logit_scale_t_ensemble_all = []
device = torch.device('cuda')
model, clap_model_cfg = create_model(
args.amodel,
args.tmodel,
args.pretrained,
precision=args.precision,
device=device,
jit=args.torchscript,
force_quick_gelu=args.force_quick_gelu,
openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir),
skip_params=False
)
# load model
checkpoint = torch.load(model_path, map_location=device)
if "epoch" in checkpoint:
# resuming a train checkpoint w/ epoch and optimizer state
start_epoch = checkpoint["epoch"]
sd = checkpoint["state_dict"]
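        # checkpoints saved under DistributedDataParallel carry a "module." key prefix that must be stripped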
        if next(iter(sd.items()))[0].startswith("module"):
            sd = {k[len("module."):]: v for k, v in sd.items()}
model.load_state_dict(sd)
else:
# loading a bare (model only) checkpoint for fine-tune or evaluation
model.load_state_dict(checkpoint)
model.to(device)
model.eval()
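    # inference only: freeze every parameter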
for param in model.parameters():
param.requires_grad = False
    # each Clotho test clip has one .flac file and a .json file holding its 5 captions
test_file_list = sorted(glob.glob(f"{clotho_test_preprocessed_dir}/*.flac"))
audio_features_all = []
text_features_all = []
audio_features_mlp_all = []
text_features_mlp_all = []
logit_scale_a_all = []
logit_scale_t_all = []
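    # embed every test clip and its captions once, without tracking gradients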
with torch.no_grad():
for file_path in tqdm(test_file_list):
json_path = file_path.replace(".flac", ".json")
with open(json_path, "r") as f:
json_data = json.load(f)
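            # load mono audio at 48 kHz and pack it into the waveform dict consumed by the model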
audio, sr = librosa.load(file_path, sr=48000, mono=True)
audio = torch.from_numpy(audio).to(device)
audio = {'waveform': audio.unsqueeze(0), 'sample_rate': sr}
text = json_data["text"]
if args.tmodel == "transformer":
from open_clip import tokenize
text = tokenize(text)
else:
from laion_clap.training.data import tokenizer
text = tokenizer(text, tmodel=args.tmodel) # 5 texts for each audio
audio_features, text_features, audio_features_mlp, text_features_mlp, logit_scale_a, logit_scale_t = \
get_output_from_single_audio(audio, text, model, device)
audio_features_all.append(audio_features.detach().cpu())
text_features_all.append(text_features.detach().cpu())
audio_features_mlp_all.append(audio_features_mlp.detach().cpu())
text_features_mlp_all.append(text_features_mlp.detach().cpu())
logit_scale_a_all.append(logit_scale_a.detach().cpu())
logit_scale_t_all.append(logit_scale_t.detach().cpu())
audio_features = torch.cat(audio_features_all)
text_features = torch.cat(text_features_all)
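    # the logit scale is a single learned scalar shared by all samples, so take the first entry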
logit_scale_a = logit_scale_a_all[0]
logits_per_audio = (logit_scale_a * audio_features @ text_features.t()).detach().cpu()
logits_per_text = logits_per_audio.t().detach().cpu()
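    # rank all audio clips for every caption (text-to-audio retrieval)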
metrics = get_metrics(
logits_per_text
)
print(metrics)