import glob
import logging
import os.path
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F

try:
    import wandb
except ImportError:
    wandb = None  # checked with an assert before use, so a missing install fails with a clear message

from clap_module import create_model
from clap_module.utils import get_tar_path_from_dataset_name, dataset_split
from training.data import get_data
from training.logger import setup_logging
from training.params import parse_args


def find_params_value(file, key):
    """Return the value recorded for `key` in a params file, or None if absent."""
    with open(file, 'r') as f:
        for line in f:
            if key + ': ' in line:
                return line.split(': ', 1)[1].strip()
    return None
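# Usage sketch (assuming a params.txt written at training time, with lines like
# "amodel: HTSAT-tiny"; the path below is hypothetical):
#   find_params_value('logs/my_run/params.txt', 'amodel')  # -> 'HTSAT-tiny'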


def evaluate_zeroshot(model, data, start_epoch, args, writer):
    dataloader = data["val"].dataloader
    metrics = {}
    device = torch.device(args.device)
    model.eval()
    metrics.update({"epoch": start_epoch})
    all_audio_features = []
    all_class_labels = []
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            audios = batch  # the batch dict contains mel_spec, waveform, and the `longer` list
            # calling the model with text=None returns audio embeddings only
            audio_features = model(audios, None, device)
            audio_features = F.normalize(audio_features, dim=-1)
            all_audio_features.append(audio_features.detach().cpu())
            all_class_labels.append(torch.argmax(batch["class_label"], 1).long())
    all_audio_features = torch.cat(all_audio_features, dim=0)
    all_class_labels = torch.cat(all_class_labels, dim=0)
    metrics["num_samples"] = all_audio_features.shape[0]
    # get text features: one prompt per class name
    all_texts = ["This is a sound of " + t for t in args.class_index_dict.keys()]
    # (yusong): a hack to pick the matching tokenizer; could be made cleaner
    if args.tmodel == "transformer":
        from clap_module.tokenizer import tokenize
        all_texts = tokenize(all_texts)
    else:
        from training.data import tokenizer
        all_texts = tokenizer(all_texts)
    # calling the model with audio=None returns text embeddings only
    all_text_features = model(None, all_texts, device)
    all_text_features = F.normalize(all_text_features, dim=-1).detach().cpu()
    # compute similarity; calling the model with both inputs None returns the logit scales
    logit_scale_a, logit_scale_t = model(None, None, device)
    logit_scale_a = logit_scale_a.cpu()
    logits_per_audio = (logit_scale_a * all_audio_features @ all_text_features.t()).detach().cpu()
    logits_per_text = logits_per_audio.t().detach().cpu()
    ground_truth = all_class_labels.view(-1, 1)
    logit = logits_per_audio
    ranking = torch.argsort(logit, descending=True)
    # preds[i] is the 0-indexed rank of the ground-truth class for sample i
    preds = torch.where(ranking == ground_truth)[1]  # (yusong) this line is slow because it uses a single thread
    preds = preds.detach().cpu().numpy()
metrics[f"{args.datasetnames[0]}_mean_rank"] = preds.mean() + 1 | |
metrics[f"{args.datasetnames[0]}_median_rank"] = np.floor(np.median(preds)) + 1 | |
for k in [1, 5, 10]: | |
metrics[f"{args.datasetnames[0]}_R@{k}"] = np.mean(preds < k) | |
# map@10 | |
metrics[f"{args.datasetnames[0]}_mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0)) | |
    logging.info(
        f"Eval Epoch: {start_epoch} "
        + "\t".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
    )
    if args.wandb:
        assert wandb is not None, "Please install wandb."
        for name, val in metrics.items():
            wandb.log({f"val/{name}": val, "epoch": start_epoch})


if __name__ == '__main__':
    # (yusong) repeated runs may yield slightly different metrics,
    # because a random 10s crop is selected from each audio clip.
    args = parse_args()

    if os.path.isdir(args.pretrained):
        log_dir = os.path.dirname(args.pretrained)
    else:
        log_dir = os.path.dirname(os.path.dirname(args.pretrained))

    args.log_level = logging.DEBUG if args.debug else logging.INFO
    log_path = os.path.join(log_dir, 'out.log')
    setup_logging(log_path, args.log_level)
    params_file = os.path.join(log_dir, 'params.txt')
    seed = 3407
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    cudnn.benchmark = True
    cudnn.deterministic = False
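    # cudnn.benchmark=True with deterministic=False trades bit-exact reproducibility
    # for speed; together with the random 10s crops noted above, this is why
    # repeated runs can differ slightly.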
    pretrained = 'openai'  # overwritten below by load_state_dict from the evaluated checkpoint
    amodel = find_params_value(params_file, 'amodel')
    tmodel = find_params_value(params_file, 'tmodel')
    if amodel is None or tmodel is None:
        raise ValueError(f'model type not found in params file: {params_file}')
    # set up dummy values for args
    args.parallel_eval = False
    args.rank = 0
    args.local_rank = 0
    args.world_size = 1
    args.val_frequency = 1
    args.epochs = 1
    args.precision = 'fp32'
    args.save_logs = True
    args.wandb = args.report_to == 'wandb'
    args.class_index_dict = None

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    args.device = device
    if args.remotedata:
        for dataset_name in args.datasetnames:
            for split in dataset_split[dataset_name]:
                os.makedirs(f"./json_files/{dataset_name}/{split}", exist_ok=True)
                os.system(
                    f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json"
                )
    if args.datasetinfos is None:
        args.datasetinfos = ["train", "unbalanced_train", "balanced_train"]
    if args.dataset_type == "webdataset":
        args.train_data = get_tar_path_from_dataset_name(
            args.datasetnames,
            args.datasetinfos,
            islocal=not args.remotedata,
            proportion=args.dataset_proportion,
            dataset_path=args.datasetpath,
        )
        args.val_data = get_tar_path_from_dataset_name(
            args.datasetnames,
            ["valid", "test", "eval"],
            islocal=not args.remotedata,
            proportion=1,
            dataset_path=args.datasetpath,
        )
    model, model_cfg = create_model(
        amodel,
        tmodel,
        pretrained,
        precision='fp32',
        device=device,
        jit=False,
        force_quick_gelu=False,
        openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir),
        skip_params=False,
        enable_fusion=args.enable_fusion,
        fusion_type=args.fusion_type
    )  # a hack: this first create_model call only serves to obtain model_cfg

    data = get_data(args, model_cfg=model_cfg)  # (yusong) hack: the data pipeline only needs model_cfg, not the loaded model

    writer = None  # if using tensorboard, initialize the writer here
    if args.wandb:
        assert wandb is not None, "Please install wandb."
        # # find the line with "wandb_notes" and get the value
        # wandb_notes = find_params_value(params_file, 'wandb_notes')
        # if wandb_notes is None:
        #     print(f'wandb_notes not found in params file: {params_file}, set to timestamp.')
        #     wandb_notes = f'experiment_{time.strftime("%Y%m%d-%H%M%S")}'
        # wandb_notes = wandb_notes + '-eval-retrieval'
        wandb_notes = args.wandb_notes

        logging.debug("Starting wandb.")
        args.train_sz = data["train"].dataloader.num_samples
        if args.val_data is not None:
            args.val_sz = data["val"].dataloader.num_samples
        # you will have to configure this for your project!
        if args.wandb_id is not None:
            wandb.init(
                project="clap",
                id=args.wandb_id,
                resume=True
            )
        else:
            wandb.init(
                project="clap",
                notes=wandb_notes,
                name=wandb_notes,
                tags=[],
                config=vars(args),
            )
        logging.debug("Finished loading wandb.")
    if os.path.isdir(args.pretrained):
        # evaluate every checkpoint in the directory, oldest first
        all_model_checkpoints = sorted(glob.glob(os.path.join(log_dir, 'checkpoints', '*.pt')), key=os.path.getmtime)
    else:
        all_model_checkpoints = [args.pretrained]
    for model_path in all_model_checkpoints:
        args.checkpoint_path = os.path.dirname(model_path)
        # re-create the model for each checkpoint before loading its weights
        model, model_cfg = create_model(
            amodel,
            tmodel,
            pretrained,
            precision='fp32',
            device=device,
            jit=False,
            force_quick_gelu=False,
            openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir),
            skip_params=False,
            enable_fusion=args.enable_fusion,
            fusion_type=args.fusion_type
        )
        # load the checkpoint weights
        checkpoint = torch.load(model_path, map_location=device)
        if "epoch" in checkpoint:
            # resuming a train checkpoint w/ epoch and optimizer state
            start_epoch = checkpoint["epoch"]
            sd = checkpoint["state_dict"]
            if next(iter(sd.items()))[0].startswith("module"):
                # strip the DistributedDataParallel "module." prefix
                sd = {k[len("module."):]: v for k, v in sd.items()}
            model.load_state_dict(sd)
            logging.info(
                f"=> resuming checkpoint '{model_path}' (epoch {start_epoch})"
            )
        else:
            # loading a bare (model only) checkpoint for fine-tune or evaluation
            model.load_state_dict(checkpoint)
            start_epoch = 0
        model.to(device)
        model.eval()
        for param in model.parameters():
            param.requires_grad = False

        evaluate_zeroshot(model, data, start_epoch, args, writer)
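
# Example invocation (a sketch only; the script name and exact flag spellings are
# assumptions inferred from the args read above via parse_args() -- check
# training/params.py for the real definitions):
#   python evaluate_zeroshot.py --pretrained /path/to/run/checkpoints \
#       --datasetnames esc50 --datasetpath /data/webdataset_tar --dataset-type webdataset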