# Spaces: Running on L4
import json | |
import logging | |
import math | |
import os | |
import time | |
from contextlib import suppress | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
try: | |
import wandb | |
except ImportError: | |
wandb = None | |
from open_clip import ClipLoss, gather_features | |
from .distributed import is_master | |
from .zero_shot import zero_shot_eval | |
class AverageMeter(object):
    """Track the most recent value and the running weighted average of a metric.

    Attributes (all plain numbers, reset to 0 by ``reset``):
        val:   the value passed to the latest ``update`` call.
        sum:   weighted sum ``val * n`` accumulated since the last reset.
        count: total weight ``n`` accumulated since the last reset.
        avg:   ``sum / count``; stays 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. ``update(x, n=0)`` directly after a reset)
        # used to raise ZeroDivisionError; report an average of 0 instead.
        self.avg = self.sum / self.count if self.count else 0
def unwrap_model(model):
    """Return ``model.module`` if ``model`` has a ``module`` attribute, otherwise ``model`` itself."""
    return model.module if hasattr(model, "module") else model
def train_one_epoch(
    model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None
):
    """Train ``model`` for one pass over ``data["train"]``.

    ``optimizer`` and ``scheduler`` may each be a single object or a dict of
    them; every entry is stepped on every batch. When ``scaler`` is not None
    the loss is scaled through it (with the synchronize / unscale handshake
    when ``args.horovod`` is set). After each optimizer step ``logit_scale_a``
    (and ``logit_scale_t`` when ``args.clap_mlploss`` is set) is clamped to
    ``[0, ln(100)]``.

    On the master process, progress is logged every 100 batches and on the
    last batch of the epoch; the same scalars are forwarded to ``tb_writer``
    (when given) and to wandb (when ``args.wandb`` is set).

    Returns:
        None.
    """
    device = torch.device(args.device)
    autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
    model.train()
    loss = ClipLoss(
        local_loss=args.local_loss,
        gather_with_grad=args.gather_with_grad,
        cache_labels=True,
        rank=args.rank,
        world_size=args.world_size,
        use_horovod=args.horovod,
        mlp_loss=args.clap_mlploss,
        weight_loss_kappa=args.kappa,
    )

    dataloader, sampler = data["train"].dataloader, data["train"].sampler
    if args.distributed and sampler is not None:
        sampler.set_epoch(epoch)
    num_batches_per_epoch = dataloader.num_batches
    sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))

    # for toy dataset
    if args.dataset_type == "toy":
        dataloader.dataset.generate_queue()

    # Treat the single-optimizer case as a one-element collection so the
    # zero_grad / step logic below is written only once.
    multi_optimizer = isinstance(optimizer, dict)
    optimizers = list(optimizer.values()) if multi_optimizer else [optimizer]

    loss_m = AverageMeter()
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()

    for i, batch in enumerate(dataloader):
        step = num_batches_per_epoch * epoch + i
        if isinstance(scheduler, dict):
            for s in scheduler.values():
                s(step)
        else:
            scheduler(step)

        audios = batch  # contains mel_spec, wavform, and longer list
        texts = batch["text"]

        data_time_m.update(time.time() - end)
        for opt in optimizers:
            opt.zero_grad()

        with autocast():
            (
                audio_features,
                text_features,
                audio_features_mlp,
                text_features_mlp,
                logit_scale_a,
                logit_scale_t,
            ) = model(audios, texts, device)

            if args.clap_mlploss:
                total_loss = loss(
                    audio_features=audio_features,
                    text_features=text_features,
                    logit_scale_a=logit_scale_a,
                    logit_scale_t=logit_scale_t,
                    audio_features_mlp=audio_features_mlp,
                    text_features_mlp=text_features_mlp,
                )
            else:
                total_loss = loss(
                    audio_features=audio_features,
                    text_features=text_features,
                    logit_scale_a=logit_scale_a,
                )

        if scaler is not None:
            scaler.scale(total_loss).backward()
            for opt in optimizers:
                if args.horovod:
                    opt.synchronize()
                    scaler.unscale_(opt)
                    with opt.skip_synchronize():
                        scaler.step(opt)
                else:
                    scaler.step(opt)
            # One scale update per iteration, after every optimizer stepped.
            scaler.update()
        else:
            total_loss.backward()
            for opt in optimizers:
                opt.step()

        # Note: we clamp to 4.6052 = ln(100), as in the original paper.
        with torch.no_grad():
            unwrap_model(model).logit_scale_a.clamp_(0, math.log(100))
            if args.clap_mlploss:
                unwrap_model(model).logit_scale_t.clamp_(0, math.log(100))

        batch_time_m.update(time.time() - end)
        end = time.time()
        batch_count = i + 1

        if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch):
            if isinstance(audios, dict):
                batch_size = len(audios["waveform"])
            else:
                batch_size = len(audios)
            num_samples = batch_count * batch_size * args.world_size
            samples_per_epoch = dataloader.num_samples
            percent_complete = 100.0 * batch_count / num_batches_per_epoch

            # NOTE loss is coarsely sampled, just master node and per log update
            loss_m.update(total_loss.item(), batch_size)
            logit_scale_scalar_a = logit_scale_a.item()
            logit_scale_scalar_t = logit_scale_t.item()

            # A dict of optimizers reports a list of learning rates; a single
            # optimizer reports one float (formatted with the original ":5f").
            if multi_optimizer:
                lr = [opt.param_groups[0]["lr"] for opt in optimizers]
                lr_text = f"{lr}"
            else:
                lr = optimizer.param_groups[0]["lr"]
                lr_text = f"{lr:5f}"

            message = (
                f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
                f"Data (t): {data_time_m.avg:.3f} "
                f"Batch (t): {batch_time_m.avg:.3f} "
                f"LR: {lr_text} "
                f"Logit Scale Audio: {logit_scale_scalar_a:.3f}"
            )
            # Save train loss / etc. Using non avg meter values as loggers have their own smoothing
            log_data = {
                "loss": loss_m.val,
                "data_time": data_time_m.val,
                "batch_time": batch_time_m.val,
                "scale_audio": logit_scale_scalar_a,
            }
            if args.clap_mlploss:
                # The leading space keeps the two scale fields from running
                # together in the log line.
                message += f" Logit Scale Text: {logit_scale_scalar_t:.3f}"
                log_data["scale_text"] = logit_scale_scalar_t
            log_data["lr"] = lr
            logging.info(message)

            for name, val in log_data.items():
                name = "train/" + name
                if tb_writer is not None:
                    tb_writer.add_scalar(name, val, step)
                if args.wandb:
                    assert wandb is not None, "Please install wandb."
                    wandb.log({name: val, "step": step})

            # resetting batch / data time meters per log window
            batch_time_m.reset()
            data_time_m.reset()
    # end for
def _new_eval_record(with_mlp):
    """Create an empty per-dataset accumulator for evaluation features."""
    record = {
        "cumulative_loss": 0.0,
        "num_samples": 0,
        "all_audio_features": [],
        "all_text_features": [],
    }
    if with_mlp:
        record["all_audio_features_mlp"] = []
        record["all_text_features_mlp"] = []
    return record


def evaluate(model, data, epoch, args, tb_writer=None):
    """Evaluate ``model`` on ``data["val"]`` and log / persist the metrics.

    Two paths exist:

    * ``args.val_dataset_names == ["Clotho", "audiocaps"]``: delegates to
      ``evaluate_clotho_audiocaps`` (5 captions per audio) and tracks the best
      epoch via ``select_top_metric_clotho_audiocaps``. Raises
      ``NotImplementedError`` when ``args.parallel_eval`` is set.
    * otherwise, when ``"val"`` is in ``data`` and the epoch matches
      ``args.val_frequency`` (or is the final epoch): features are collected
      per dataset (name derived from each sample's ``__url__``) plus an
      ``"all"`` bucket, and scored with ``get_metrics``.

    On the master process the metrics are logged, written to TensorBoard and
    ``results.jsonl`` (when ``args.save_logs``), and sent to wandb (when
    ``args.wandb``).

    Returns:
        dict mapping metric name to value; empty when nothing was evaluated or
        when a non-master process returns early.
    """
    metrics = {}
    if not args.parallel_eval:
        if not is_master(args):
            return metrics
    device = torch.device(args.device)
    model.eval()

    if is_master(args):
        print("Evaluating...")
    autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
    if args.val_dataset_names == ["Clotho", "audiocaps"]:
        # if only clotho and audiocaps are used, then we will use a different evaluation function.
        # This is because in the Clotho and audiocaps valid and test set, there are 5 text for 1 audio.
        if args.parallel_eval:
            # (yusong): just a hack here. Don't use parallel eval when evaluating only clotho and audiocaps.
            raise NotImplementedError(
                "Parallel evaluation not supported for eval only Clotho and audiocaps."
            )
        val_metrics_per_dataset = evaluate_clotho_audiocaps(
            model, data, epoch, args, autocast, device, tb_writer
        )
        for m in val_metrics_per_dataset.values():
            metrics.update(m)
        if "epoch" not in metrics:
            metrics.update({"epoch": epoch})
        metrics = select_top_metric_clotho_audiocaps(
            metrics, val_metrics_per_dataset, args
        )
    elif "val" in data and (
        args.val_frequency
        and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)
    ):
        dataloader = data["val"].dataloader
        num_samples = 0
        samples_per_val = dataloader.num_samples

        # FIXME this does not scale past small eval datasets
        # all_audio_features @ all_text_features will blow up memory and compute very quickly
        with_mlp = args.clap_mlploss
        eval_info = {"all": _new_eval_record(with_mlp)}
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                audios = batch  # contains mel_spec, wavform, and longer list
                texts = batch["text"]

                # Dataset name of every sample, computed once per batch.
                batch_names = ["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]]
                batch_names_arr = np.array(batch_names)
                all_names = list(set(batch_names))
                for name in all_names:
                    if name not in eval_info:
                        eval_info[name] = _new_eval_record(with_mlp)

                with autocast():
                    (
                        audio_features,
                        text_features,
                        audio_features_mlp,
                        text_features_mlp,
                        logit_scale_a,
                        logit_scale_t,
                    ) = model(audios, texts, device)

                    if args.parallel_eval:
                        # multi-GPU eval
                        if with_mlp:
                            (
                                audio_features,
                                text_features,
                                audio_features_mlp,
                                text_features_mlp,
                            ) = gather_features(
                                audio_features=audio_features,
                                text_features=text_features,
                                audio_features_mlp=audio_features_mlp,
                                text_features_mlp=text_features_mlp,
                                local_loss=False,
                                gather_with_grad=False,
                                rank=args.rank,
                                world_size=args.world_size,
                                use_horovod=args.horovod,
                                mlp_loss=args.clap_mlploss,
                            )
                        else:
                            (audio_features, text_features,) = gather_features(
                                audio_features=audio_features,
                                text_features=text_features,
                                local_loss=False,
                                gather_with_grad=False,
                                rank=args.rank,
                                world_size=args.world_size,
                                use_horovod=args.horovod,
                                mlp_loss=args.clap_mlploss,
                            )

                    if is_master(args):
                        num_samples += audio_features.shape[0]
                        # Move each feature tensor to the CPU once per batch
                        # instead of once per dataset name.
                        cpu_features = {
                            "all_audio_features": audio_features.cpu(),
                            "all_text_features": text_features.cpu(),
                        }
                        if with_mlp:
                            cpu_features["all_audio_features_mlp"] = audio_features_mlp.cpu()
                            cpu_features["all_text_features_mlp"] = text_features_mlp.cpu()

                        # NOTE(review): with parallel_eval the gathered
                        # features presumably hold rows from every rank while
                        # the row indices below come from the local batch's
                        # URLs only -- confirm before relying on per-dataset
                        # numbers in that mode.
                        for n in [*all_names, "all"]:
                            if n == "all":
                                for key, feats in cpu_features.items():
                                    eval_info[n][key].append(feats)
                            else:
                                idx = np.where(batch_names_arr == n)[0]
                                rows = torch.tensor(idx).long()
                                for key, feats in cpu_features.items():
                                    eval_info[n][key].append(
                                        feats.index_select(0, rows)
                                    )

                if is_master(args) and (i % 100) == 0:  # and i != 0:
                    logging.info(
                        f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]"
                    )

            if is_master(args):
                val_metrics_per_dataset = {}
                for n, record in eval_info.items():
                    # The logit scales come from the last evaluated batch.
                    metric_kwargs = {
                        "audio_features": torch.cat(record["all_audio_features"]),
                        "text_features": torch.cat(record["all_text_features"]),
                        "logit_scale_a": logit_scale_a.cpu(),
                        "mlp_loss": args.clap_mlploss,
                    }
                    if with_mlp:
                        metric_kwargs["audio_features_mlp"] = torch.cat(
                            record["all_audio_features_mlp"]
                        )
                        metric_kwargs["text_features_mlp"] = torch.cat(
                            record["all_text_features_mlp"]
                        )
                        metric_kwargs["logit_scale_t"] = logit_scale_t.cpu()
                    metrics_single_dataset = get_metrics(**metric_kwargs)
                    val_metrics_per_dataset[n] = {
                        n + "/" + k: v for k, v in metrics_single_dataset.items()
                    }
                    metrics.update(val_metrics_per_dataset[n])
                    if "epoch" not in metrics:
                        metrics.update({"epoch": epoch})

    # Only the master process with something to report logs and persists.
    if not is_master(args) or not metrics:
        return metrics

    logging.info(
        f"Eval Epoch: {epoch} "
        + "\n".join(
            [
                "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in m.items()])
                for m in val_metrics_per_dataset.values()
            ]
        )
    )

    if args.save_logs:
        for name, val in metrics.items():
            if tb_writer is not None:
                tb_writer.add_scalar(f"val/{name}", val, epoch)

        with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
            f.write(json.dumps(metrics))
            f.write("\n")

    if args.wandb:
        assert wandb is not None, "Please install wandb."
        for name, val in metrics.items():
            wandb.log({f"val/{name}": val, "epoch": epoch})

    return metrics
def get_metrics(
    audio_features,
    text_features,
    logit_scale_a,
    audio_features_mlp=None,
    text_features_mlp=None,
    logit_scale_t=None,
    mlp_loss=False,
):
    """Compute contrastive loss and retrieval metrics for paired features.

    Row ``i`` of ``audio_features`` is treated as the match of row ``i`` of
    ``text_features``.

    Args:
        audio_features: 2-D tensor of audio embeddings, one row per sample.
        text_features: 2-D tensor of text embeddings, one row per sample.
        logit_scale_a: scale applied to the audio/text similarity matrix.
        audio_features_mlp, text_features_mlp, logit_scale_t: extra inputs
            used only (and then required) when ``mlp_loss`` is True.
        mlp_loss: when True, the loss is the mean of four cross-entropy terms
            and ranking uses the average of the two similarity matrices;
            otherwise two terms and a single similarity matrix are used.

    Returns:
        dict with ``cumulative_loss``, ``num_samples`` and, for both
        ``audio_to_text`` and ``text_to_audio``: ``mean_rank``,
        ``median_rank`` (both 1-based), ``R@1``, ``R@5``, ``R@10`` and
        ``mAP@10``.

    Raises:
        ValueError: if ``mlp_loss`` is True but an MLP input is missing.
    """
    metrics = {}
    if mlp_loss:
        if audio_features_mlp is None or text_features_mlp is None or logit_scale_t is None:
            raise ValueError(
                "mlp_loss=True requires audio_features_mlp, text_features_mlp and logit_scale_t"
            )
        # Set up the audio-to-text and text-to-audio similarity matrices.
        a_logits_per_audio = (
            (logit_scale_a * audio_features @ text_features_mlp.t()).detach().cpu()
        )
        a_logits_per_text = a_logits_per_audio.t()
        t_logits_per_audio = (
            (logit_scale_t * audio_features_mlp @ text_features.t()).detach().cpu()
        )
        t_logits_per_text = t_logits_per_audio.t()
        # 2x2 combined CE loss: four terms instead of two.
        loss_terms = [
            a_logits_per_audio,
            a_logits_per_text,
            t_logits_per_audio,
            t_logits_per_text,
        ]
        logits = {
            "audio_to_text": (a_logits_per_audio + t_logits_per_audio) / 2,
            "text_to_audio": (a_logits_per_text + t_logits_per_text) / 2,
        }
    else:
        logits_per_audio = (
            (logit_scale_a * audio_features @ text_features.t()).detach().cpu()
        )
        logits_per_text = logits_per_audio.t()
        loss_terms = [logits_per_audio, logits_per_text]
        logits = {"audio_to_text": logits_per_audio, "text_to_audio": logits_per_text}

    labels = torch.arange(audio_features.shape[0]).long()
    total_loss = sum(F.cross_entropy(term, labels) for term in loss_terms) / len(
        loss_terms
    )
    metrics["cumulative_loss"] = total_loss.item()
    metrics["num_samples"] = audio_features.shape[0]

    ground_truth = torch.arange(len(text_features)).view(-1, 1)
    for name, logit in logits.items():
        ranking = torch.argsort(logit, descending=True)
        # Column at which each row's ground-truth index appears in its ranking
        # (0-based rank). (yusong) this line is slow because it uses single thread
        preds = torch.where(ranking == ground_truth)[1]
        preds = preds.detach().cpu().numpy()
        metrics[f"{name}_mean_rank"] = preds.mean() + 1
        metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
        for k in [1, 5, 10]:
            metrics[f"{name}_R@{k}"] = np.mean(preds < k)
        # map@10
        metrics[f"{name}_mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0))
    return metrics
def evaluate_clotho_audiocaps(
    model, data, epoch, args, autocast, device, tb_writer=None
):
    """
    Evaluate retrieval on datasets that provide 5 captions per audio clip.

    Adapted from https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py.
    1. for text-to-audio retrieval, do 5 times and average the results
    2. for R@1, R@5, R@10 in audio-to-text retrieval, take the best rank among 5 text
    3. for map@10 in audio-to-text retrieval:
        3.1: sort the rank of 5 text
        3.2: exclude the rank >=10 (0-index)
        3.3: compute the map regarding the remaining ranks: np.mean(np.arange(1, len(ranks)+1) / ranks).
        (3.3) That is, take the top ranks of 5 text that is < 10, and assign the descending number as ground truth.
        (3.3) E.g.: the ground truth of first rank of the 5 text should be 1, the second rank should be 2, etc.

    ``epoch`` and ``tb_writer`` are accepted for signature compatibility and
    are not used here.

    Returns:
        dict mapping each dataset name (derived from the samples' ``__url__``)
        to a dict of ``"<dataset>/<metric>"`` values.
    """
    # TODO: (yusong) only support single GPU evaluation and only support non-mlp case for now.
    texts_per_audio = 5
    dataloader = data["val"].dataloader
    with torch.no_grad():
        eval_info = {}
        for batch in dataloader:
            audios = batch  # contains mel_spec, wavform, and longer list

            # each item in the list has 5 texts
            if args.tmodel == "transformer":
                from open_clip import tokenize

                texts = [tokenize(t) for t in batch["full_text"]]
                texts = torch.cat(texts)
            else:
                from .data import tokenizer

                texts = [
                    tokenizer(t) for t in batch["full_text"]
                ]  # 5 texts for each audio
                texts = {
                    k: torch.cat([t[k] for t in texts]) for k in texts[0].keys()
                }  # 5 x batch

            # Dataset name of every sample, computed once per batch.
            batch_names = ["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]]
            batch_names_arr = np.array(batch_names)
            all_names = list(set(batch_names))
            for name in all_names:
                if name not in eval_info:
                    # we will not use mlp outputs even if args.clap_mlploss=True
                    eval_info[name] = {
                        "cumulative_loss": 0.0,
                        "num_samples": 0,
                        "all_audio_features": [],
                        "all_text_features": [],
                    }
            with autocast():
                audio_features = model(audios, None, device)
                text_features = model(None, texts, device)
                audio_features = F.normalize(audio_features, dim=-1)
                text_features = F.normalize(text_features, dim=-1)

                audio_cpu = audio_features.cpu()
                # Group the 5 consecutive caption rows of each audio so one
                # sample index selects all of its captions at once.
                text_cpu = text_features.cpu().reshape(
                    [-1, texts_per_audio, text_features.shape[1]]
                )
                for n in all_names:
                    idx = np.where(batch_names_arr == n)[0]
                    rows = torch.tensor(idx).long()
                    eval_info[n]["all_audio_features"].append(
                        audio_cpu.index_select(0, rows)
                    )
                    eval_info[n]["all_text_features"].append(
                        text_cpu.index_select(0, rows).reshape(
                            [-1, text_features.shape[1]]
                        )
                    )

        val_metrics_all = {}

        for n in eval_info.keys():
            logit_scale_a, logit_scale_t = model(None, None, device)
            logit_scale_a = logit_scale_a.cpu()

            audio_features = torch.cat(eval_info[n]["all_audio_features"], dim=0)
            text_features = torch.cat(eval_info[n]["all_text_features"], dim=0)

            logits_per_audio = (
                (logit_scale_a * audio_features @ text_features.t()).detach().cpu()
            )
            logits_per_text = logits_per_audio.t().detach().cpu()

            # logits_per_audio shape: [num_samples, num_samples*5]
            # logits_per_text shape: [num_samples*5, num_samples]

            logging.info(
                f"dataset {n}, logits_per_audio shape: {logits_per_audio.shape}, "
                f"logits_per_text shape: {logits_per_text.shape}"
            )

            metrics = {}
            num_samples = audio_features.shape[0]
            metrics["num_samples"] = num_samples

            # (yusong) the following code is very important, please double-check:
            # logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d]
            # logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :]
            # Those two are retrieving one of the 5 text for each audio.
            audio_view = logits_per_audio.reshape(
                num_samples, num_samples, texts_per_audio
            )
            text_view = logits_per_text.reshape(
                num_samples, texts_per_audio, num_samples
            )
            labels = torch.arange(num_samples).long()
            audio_to_text_loss = torch.stack(
                [
                    F.cross_entropy(audio_view[:, :, d], labels)
                    for d in range(texts_per_audio)
                ]
            )
            text_to_audio_loss = torch.stack(
                [
                    F.cross_entropy(text_view[:, d, :], labels)
                    for d in range(texts_per_audio)
                ]
            )
            # Average with torch rather than np.mean over a list of tensors.
            total_loss = (audio_to_text_loss.mean() + text_to_audio_loss.mean()) / 2

            metrics["cumulative_loss"] = total_loss.item()

            # text to audio: do 5 times
            pred_text = []
            for d in range(texts_per_audio):
                logit = text_view[:, d, :]
                ground_truth = torch.arange(len(logit)).view(-1, 1)
                ranking = torch.argsort(
                    logit, descending=True
                )  # [num_samples, num_samples]
                preds = torch.where(ranking == ground_truth)[1]
                pred_text.append(preds.detach().cpu().numpy())
            pred_text_concat = np.concatenate(pred_text, axis=0)  # [5*num_samples]
            metrics["text_to_audio_mean_rank"] = pred_text_concat.mean() + 1
            metrics["text_to_audio_median_rank"] = (
                np.floor(np.median(pred_text_concat)) + 1
            )
            for k in [1, 5, 10]:
                metrics[f"text_to_audio_R@{k}"] = np.mean(pred_text_concat < k)
            # map@10
            metrics["text_to_audio_mAP@10"] = np.mean(
                np.where(pred_text_concat < 10, 1 / (pred_text_concat + 1), 0.0)
            )

            # audio to text: take the best result
            # for audio to text map 10, sort and assign descending ground truth.
            # see https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py#L103
            # map@10
            map_all = []
            pred_audio_all = []
            for audio_idx in range(num_samples):
                # logits_per_audio: [num_samples, num_samples*5]
                logit_single = logits_per_audio[audio_idx, :]  # [5*num_samples]
                # Ground-truth index: [d*5, d*5+1, d*5+2, d*5+3, d*5+4]
                ranking = torch.argsort(
                    logit_single, descending=True
                )  # [5*num_samples]
                # ranking: the index of first match, second match, ...
                first_caption = audio_idx * texts_per_audio
                ground_truth = torch.arange(
                    first_caption, first_caption + texts_per_audio
                )
                # 0-based rank of each of the 5 captions, in caption order
                # (NOT in rank order).
                all_pred = torch.where(
                    ranking[None] == ground_truth.view(-1, 1)
                )[1]
                pred_audio_all.append(torch.min(all_pred).item())
                # Step 3.1 of the docstring: the ranks must be ascending so
                # the k-th retrieved caption is paired with hit count k;
                # without the sort the per-query precision could exceed 1.
                all_pred_filter = np.sort(
                    all_pred[all_pred < 10].detach().cpu().numpy()
                )
                # /5 because we have 5 text, so it means for the text rank >=10 we count as 0.
                map_single = (
                    np.sum(
                        (np.arange(1, len(all_pred_filter) + 1) / (all_pred_filter + 1))
                    )
                    / texts_per_audio
                )
                map_all.append(map_single)
            metrics["audio_to_text_mAP@10"] = np.mean(map_all)
            for k in [1, 5, 10]:
                metrics[f"audio_to_text_R@{k}"] = np.mean(np.array(pred_audio_all) < k)

            val_metrics_all[n] = {n + "/" + k: v for k, v in metrics.items()}
    return val_metrics_all
def calculate_selection_performance_clotho_audiocaps(val_metrics_per_dataset):
    """
    Score a checkpoint for Clotho+AudioCaps model selection.

    For every dataset the audio-to-text and text-to-audio mAP@10 values are
    averaged; the result is the mean of those per-dataset averages.
    """
    per_dataset_scores = [
        (
            dataset_metrics[f"{dataset}/audio_to_text_mAP@10"]
            + dataset_metrics[f"{dataset}/text_to_audio_mAP@10"]
        )
        / 2
        for dataset, dataset_metrics in val_metrics_per_dataset.items()
    ]
    return np.mean(per_dataset_scores)
def select_top_metric_clotho_audiocaps(metrics, val_metrics_per_dataset, args):
    """Merge the best-so-far ("-top") metrics into ``metrics`` and return it.

    val_metrics_per_dataset: dict, key: dataset name, value: dict, key: metric name, value: metric value
    metrics: dict, key: metric name, value: metric value

    Hack: the best selection performance and its metric snapshot are stored on
    ``args`` (``top_selection_performance`` / ``top_metric``). The snapshot is
    replaced on the first call and whenever the current selection performance
    is strictly higher than the stored one; otherwise the stored snapshot is
    merged into ``metrics`` unchanged.
    """
    current_performance = calculate_selection_performance_clotho_audiocaps(
        val_metrics_per_dataset
    )
    first_evaluation = not hasattr(args, "top_selection_performance")
    if first_evaluation or current_performance > args.top_selection_performance:
        snapshot = {}
        for dataset_metrics in val_metrics_per_dataset.values():
            for key, value in dataset_metrics.items():
                parts = key.split("/")
                # "<dataset>/<metric>" becomes "<dataset>-top/<metric>"
                snapshot[parts[0] + "-top" + "/" + parts[1]] = value
        snapshot["top_selection_performance"] = current_performance
        snapshot["top-selection-epoch"] = metrics["epoch"]
        metrics.update(snapshot)
        args.top_metric = snapshot
        args.top_selection_performance = current_performance
    else:
        metrics.update(args.top_metric)
    return metrics