Spaces:
Running
on
L4
Running
on
L4
import argparse | |
import functools | |
import glob | |
import os | |
import random | |
import string | |
import json | |
import sys | |
sys.path.append('../') | |
from tqdm import tqdm | |
import yaml | |
from collections import defaultdict | |
import io | |
import warnings | |
import subprocess | |
import pickle | |
import numpy as np | |
import torch | |
from data.data import get_audiotext_dataloader | |
from src.factory import create_model_and_transforms | |
from train.train_utils import Dict2Class, get_autocast, get_cast_dtype | |
def _decode_answer(tokenizer, token_ids):
    """Decode ``token_ids`` and return the text after the last separator token.

    The end-of-sequence token, the padding token and the literal
    ``<|endofchunk|>`` marker are removed from the returned text.
    """
    text = tokenizer.decode(token_ids).split(tokenizer.sep_token)[-1]
    for marker in (tokenizer.eos_token, tokenizer.pad_token, '<|endofchunk|>'):
        text = text.replace(marker, '')
    return text


def _save_progress(tmp_file, tmp_data):
    """Write the resume state to ``tmp_file`` without leaving a truncated file.

    The state is pickled to a sibling file first and then moved over
    ``tmp_file`` with ``os.replace``, so an interruption in the middle of the
    write cannot corrupt the previously saved state.
    """
    partial_file = tmp_file + '.partial'
    with open(partial_file, 'wb') as pickle_file:
        pickle.dump(tmp_data, pickle_file)
    os.replace(partial_file, tmp_file)


def inference_this(
    args, data_config, clap_config, model_config, test_dataset_name, tmp_file,
    temperature=1.0, num_beams=3, ckpt=-1, end_batch_idx=-2, verbose=False,
):
    """Generate model outputs for one test dataset, resuming from ``tmp_file``.

    Args:
        args: Training-config object. The attributes read here are
            ``offline``, ``gradient_checkpointing``, ``freeze_lm_embeddings``,
            ``expdir``, ``run_name``, ``precision``, ``fsdp`` and
            ``batch_size``.
        data_config: Data configuration dict. NOTE: its
            ``"valid_dataset_config"`` entry is overwritten in place so that
            it only contains ``test_dataset_name``.
        clap_config: Passed through to ``create_model_and_transforms`` and
            ``get_audiotext_dataloader``.
        model_config: Keyword arguments for ``create_model_and_transforms``.
        test_dataset_name: Name of the test set to run.
        tmp_file: Path of the pickle file holding intermediate results. It is
            read at start-up if it exists and rewritten after every batch.
        temperature: Forwarded to ``model.generate``.
        num_beams: Accepted for interface compatibility; it is currently not
            forwarded to ``model.generate``.
        ckpt: Checkpoint index to load; ``-1`` picks the checkpoint with the
            highest index found under ``{args.expdir}/{args.run_name}``.
        end_batch_idx: When positive, batches with an index greater than or
            equal to this value are not processed.
        verbose: Accepted for interface compatibility; currently unused.

    Returns:
        A list of ``(filename, prompt, ground_truth, output)`` tuples. For the
        captioning tasks listed in ``deduplicate_tasks`` there is one tuple
        per ``(filename, prompt)`` pair, with all ground truths joined by
        ``'|'``.

    Raises:
        FileNotFoundError: If ``ckpt == -1`` and no checkpoint file matches.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disable the tokenizer parallelism warning

    model, tokenizer = create_model_and_transforms(
        **model_config,
        clap_config=clap_config,
        use_local_files=args.offline,
        gradient_checkpointing=args.gradient_checkpointing,
        freeze_lm_embeddings=args.freeze_lm_embeddings,
    )

    device_id = 0
    model = model.to(device_id)
    model.eval()

    if ckpt == -1:
        checkpoint_pattern = f"{args.expdir}/{args.run_name}/checkpoint_*.pt"
        checkpoint_list = glob.glob(checkpoint_pattern)
        if not checkpoint_list:
            raise FileNotFoundError(
                "no checkpoint matching {}".format(checkpoint_pattern)
            )
        # Order by the integer between the last "_" and the following ".".
        resume_from_checkpoint = sorted(
            checkpoint_list,
            key=lambda x: int(x.split("_")[-1].split(".")[0]),
        )[-1]
    else:
        resume_from_checkpoint = f"{args.expdir}/{args.run_name}/checkpoint_{ckpt}.pt"

    checkpoint = torch.load(resume_from_checkpoint, map_location="cpu")
    msd = checkpoint["model_state_dict"]
    # Drop "module." from parameter names (presumably a distributed-wrapper
    # prefix -- confirm against the training code).
    msd = {k.replace("module.", ""): v for k, v in msd.items()}
    # Non-strict load; print what did not match instead of failing.
    missing_keys, unexpected_keys = model.load_state_dict(msd, False)
    print(missing_keys)
    print(unexpected_keys)

    # NOTE(review): ``autocast`` is built but never entered below; confirm
    # whether generation is meant to run inside it.
    autocast = get_autocast(
        args.precision, cache_enabled=(not args.fsdp)
    )
    cast_dtype = get_cast_dtype(args.precision)

    # Restrict the dataset config to the requested test set only.
    if test_dataset_name in data_config["valid_dataset_config"]:
        data_config["valid_dataset_config"] = {
            test_dataset_name: data_config["valid_dataset_config"][test_dataset_name]
        }
    else:
        data_config["valid_dataset_config"] = {test_dataset_name: True}

    all_test_AudioTextDataInfo = get_audiotext_dataloader(
        data_config, clap_config, tokenizer, args.batch_size, split='test'
    )
    assert test_dataset_name in list(all_test_AudioTextDataInfo.keys()), "{} not a test set".format(test_dataset_name)
    dataloader = all_test_AudioTextDataInfo[test_dataset_name].dataloader

    # For these tasks, repeated (filename, prompt) pairs are merged: the model
    # is run once and all ground truths are collected for that pair.
    deduplicate_tasks = (
        "Clotho-v2-AudioCaptioning",
        "audiocaps-AudioCaptioning",
        "MACS-AudioCaptioning",
        "LP-MusicCaps-MSD-AudioCaptioning",
        "LP-MusicCaps-MC-AudioCaptioning",
    )
    deduplicate = test_dataset_name.startswith(deduplicate_tasks)

    if os.path.exists(tmp_file):
        # SECURITY: pickle.load executes arbitrary code; tmp_file must be a
        # file written by a previous run of this script, never external data.
        with open(tmp_file, 'rb') as pickle_file:
            tmp_data = pickle.load(pickle_file)
        results_dic = tmp_data['results_dic']
        results = tmp_data['results']
        finished_batches = tmp_data['finished_batches']
        print('reading tmp data from {}: {} batches already computed'.format(tmp_file, finished_batches+1))
    else:
        tmp_data = {}
        results_dic = {}  # for deduplicate
        results = []  # for non-deduplicate
        finished_batches = -1
        print('no tmp data found; will store tmp data to {}'.format(tmp_file))

    from itertools import islice

    sep_token_id = tokenizer.sep_token_id

    # Skip the batches that are already stored in tmp_file. islice() rejects
    # negative indices, so the start must be ``finished_batches + 1`` (0 on a
    # fresh run) rather than ``finished_batches`` (-1 on a fresh run).
    first_batch = finished_batches + 1
    remaining_batches = islice(dataloader, first_batch, None)

    for batch_idx, batch in tqdm(enumerate(remaining_batches, start=first_batch)):
        # ``>=`` so that a resumed run starting past end_batch_idx also stops.
        if end_batch_idx > 0 and batch_idx >= end_batch_idx:
            break

        audio_clips = batch["audio_clips"].to(device_id, dtype=cast_dtype, non_blocking=True)
        audio_embed_mask = batch["audio_embed_mask"].to(device_id, dtype=cast_dtype, non_blocking=True)
        input_ids = batch["input_ids"].to(device_id, non_blocking=True)
        filenames = batch["filenames"]

        for idx in range(input_ids.shape[0]):
            filename = filenames[idx]
            if isinstance(filename, list):
                # interleaved data: keep the last filename
                filename = filename[-1]

            input_id = input_ids[idx]
            # Find the last <SEP>; if there is none the scan ends at index 0.
            for sep_location in range(len(input_id) - 1, -1, -1):
                if input_id[sep_location] == sep_token_id:
                    break

            # The prompt is everything up to and including the last <SEP>.
            prompt = input_id[:sep_location + 1]
            prompt_decoded = tokenizer.decode(prompt).replace(tokenizer.sep_token, '')
            ground_truth_decoded = _decode_answer(tokenizer, input_id)

            key = (filename, prompt_decoded)
            if deduplicate and key in results_dic:
                # Output already generated for this pair: only record the
                # additional ground truth.
                results_dic[key]['ground_truth'].append(ground_truth_decoded)
                continue

            output = model.generate(
                audio_x=audio_clips[idx].unsqueeze(0),
                audio_x_mask=audio_embed_mask[idx].unsqueeze(0),
                lang_x=prompt.unsqueeze(0),
                eos_token_id=tokenizer.eos_token_id,
                max_new_tokens=256,
                temperature=temperature,
            )[0]
            output_decoded = _decode_answer(tokenizer, output)

            if deduplicate:
                results_dic[key] = {
                    'ground_truth': [ground_truth_decoded],
                    'output': output_decoded
                }
            else:
                results.append((filename, prompt_decoded, ground_truth_decoded, output_decoded))

        # Persist progress after every batch so an interrupted run can resume.
        tmp_data['results_dic'] = results_dic
        tmp_data['results'] = results
        tmp_data['finished_batches'] = batch_idx
        _save_progress(tmp_file, tmp_data)

    if deduplicate:
        # Flatten the merged entries into the common tuple format.
        for (filename, prompt), entry in results_dic.items():
            ground_truth = '|'.join(entry['ground_truth'])
            results.append((filename, prompt, ground_truth, entry['output']))

    return results
def main():
    """Parse the command line, load the YAML config and run inference.

    The per-batch resume file is placed under
    ``../outputs/<task>/<config-stem>-ckpt<ckpt>-sft.tmp.pickle``; the
    directory is created if it does not exist yet.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='../config/config.yaml', help='yaml config path')
    # The task name is used to build the output path, so it must be given.
    parser.add_argument('-t', '--task', type=str, required=True, help='which task to inference')
    parser.add_argument('-temp', '--temperature', type=float, default=1.0, help='temperature')
    parser.add_argument('-nb', '--num_beams', type=int, default=1, help='num beams for beam search')
    parser.add_argument('--ckpt', type=int, default=-1, help='checkpoint idx, -1 means latest')
    parsed_args = parser.parse_args()

    print(parsed_args)

    test_dataset_name = parsed_args.task

    # Config file name without its extension (works for .yaml and .yml alike).
    config_stem = os.path.splitext(os.path.basename(parsed_args.config))[0]
    output_file = os.path.join(
        '../outputs/',
        parsed_args.task.replace('/', '-'),
        '{}-ckpt{}-{}.log'.format(
            config_stem,
            parsed_args.ckpt,
            "sft"
        )
    )
    tmp_file = os.path.splitext(output_file)[0] + '.tmp.pickle'
    # inference_this() writes tmp_file after every batch; make sure its
    # directory exists before the first write.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    print('output file:', output_file)

    print('no previous log file; generating samples')

    # NOTE: yaml.FullLoader is only appropriate for trusted config files.
    with open(parsed_args.config) as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    data_config = config['data_config']
    model_config = config['model_config']
    print(model_config)
    clap_config = config['clap_config']
    args = Dict2Class(config['train_config'])

    # The returned results are not used here; intermediate results are
    # persisted to tmp_file by inference_this().
    inference_this(
        args, data_config, clap_config, model_config, test_dataset_name,
        temperature=float(parsed_args.temperature),
        num_beams=int(parsed_args.num_beams),
        ckpt=parsed_args.ckpt,
        verbose=True,
        tmp_file=tmp_file,
    )
# Script entry point: run the command-line interface only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()