import functools
import itertools
import logging
from argparse import ArgumentParser
from multiprocessing import Pool
import multiprocessing as mp

import numpy as np
import torch
import torchvision
import transformers
from decord import VideoReader, cpu
from PIL import Image
from tqdm import tqdm

from tasks.eval.model_utils import load_pllava, pllava_answer
from tasks.eval.eval_utils import conv_templates
from tasks.eval.mvbench import (
    MVBenchDataset,
    check_ans,
    save_results,
    load_results,
)

logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

RESOLUTION = 672


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default='llava-hf/llava-1.5-7b-hf',
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default='./test_results/test_llava_mvbench',
    )
    parser.add_argument(
        "--num_frames",
        type=int,
        default=4,
    )
    parser.add_argument(
        "--use_lora",
        action='store_true',
    )
    parser.add_argument(
        "--lora_alpha",
        type=int,
        required=False,
        default=32,
    )
    parser.add_argument(
        "--weight_dir",
        type=str,
        required=False,
        default=None,
    )
    parser.add_argument(
        "--conv_mode",
        type=str,
        required=False,
        default='eval_mvbench',
    )
    parser.add_argument(
        "--pooling_shape",
        type=str,
        required=False,
        default=None,
    )
    args = parser.parse_args()
    return args
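
# Example invocation (a sketch; the script path and argument values are
# illustrative, not fixed by this file):
#   python run_mvbench.py \
#       --pretrained_model_name_or_path llava-hf/llava-1.5-7b-hf \
#       --save_path ./test_results/test_llava_mvbench \
#       --num_frames 16 \
#       --use_lora --lora_alpha 32 \
#       --pooling_shape 16-12-12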


def load_model_and_dataset(rank, world_size, pretrained_model_name_or_path, num_frames,
                           use_lora, lora_alpha, weight_dir, pooling_shape=(16, 12, 12)):
    # Note: larger models (30B+) can exhaust GPU memory and may even bring down nodes.
    model, processor = load_pllava(
        pretrained_model_name_or_path,
        num_frames=num_frames,
        use_lora=use_lora,
        weight_dir=weight_dir,
        lora_alpha=lora_alpha,
        pooling_shape=pooling_shape,
    )
    logger.info('done loading llava')

    # Move the model to this rank's GPU and switch to inference mode.
    model = model.to(torch.device(rank))
    model = model.eval()

    dataset = MVBenchDataset(num_segments=num_frames)
    dataset.set_rank_and_world_size(rank, world_size)
    return model, processor, dataset
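
# set_rank_and_world_size is assumed to shard MVBench across processes, so each
# rank iterates over a disjoint slice of roughly len(dataset) / world_size samples
# and the per-rank result lists can simply be concatenated in main().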


def infer_mvbench(
    model,
    processor,
    data_sample,
    conv_mode,
    pre_query_prompt=None,   # prepended to the question
    post_query_prompt=None,  # appended to the question
    answer_prompt=None,      # prepended to the assistant's answer
    return_prompt=None,      # prepended to the returned message
    print_res=False,
):
    video_list = data_sample["video_pils"]
    conv = conv_templates[conv_mode].copy()
    conv.user_query(data_sample['question'], pre_query_prompt, post_query_prompt, is_mm=True)
    if answer_prompt is not None:
        conv.assistant_response(answer_prompt)

    llm_message, conv = pllava_answer(
        conv=conv,
        model=model,
        processor=processor,
        img_list=video_list,
        max_new_tokens=100,
        do_sample=False,
        print_res=print_res,
    )

    if answer_prompt is not None:
        # Drop everything up to and including the forced answer prefix.
        llm_message = ''.join(llm_message.split(answer_prompt)[1:])
    if return_prompt is not None:
        llm_message = return_prompt + llm_message
    return llm_message
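
# A minimal sketch of the post-processing above (strings illustrative): with
# answer_prompt="Best option:(" and return_prompt="(", a raw decode such as
# "Best option:(A) stretching" becomes "A) stretching" after the split/join,
# and return_prompt restores the leading "(" so check_ans sees "(A) stretching".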


def single_test(model, processor, vid_path, num_frames=4, conv_mode="plain"):
    def get_index(num_frames, num_segments):
        # Sample num_segments indices evenly over the video, offset by half a
        # segment so each sample sits near the center of its segment.
        seg_size = float(num_frames - 1) / num_segments
        start = int(seg_size / 2)
        offsets = np.array([
            start + int(np.round(seg_size * idx)) for idx in range(num_segments)
        ])
        return offsets
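
    # Worked example (values follow directly from the formula above):
    #   get_index(num_frames=100, num_segments=8)
    #   -> array([ 6, 18, 31, 43, 56, 68, 80, 93])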
    def load_video(video_path, num_segments=8, return_msg=False, resolution=336):
        transforms = torchvision.transforms.Resize(size=resolution)
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
        total_frames = len(vr)
        frame_indices = get_index(total_frames, num_segments)
        images_group = list()
        for frame_index in frame_indices:
            img = Image.fromarray(vr[frame_index].asnumpy())
            images_group.append(transforms(img))
        if return_msg:
            fps = float(vr.get_avg_fps())
            sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
            # The message should be wrapped with a space at the start and end
            # when it is spliced into a prompt.
            msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
            return images_group, msg
        else:
            return images_group
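
    # Sketch of expected usage (the path is illustrative):
    #   frames, msg = load_video("./example/yoga.mp4", num_segments=16,
    #                            return_msg=True, resolution=RESOLUTION)
    #   len(frames) == 16, and each entry is a resized PIL.Image.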

    if num_frames != 0:
        vid, msg = load_video(vid_path, num_segments=num_frames, return_msg=True, resolution=RESOLUTION)
    else:
        vid, msg = None, 'num_frames is 0, not inputting images'
    img_list = vid

    conv = conv_templates[conv_mode].copy()
    conv.user_query("Describe the video in detail.", is_mm=True)
    llm_response, conv = pllava_answer(
        conv=conv,
        model=model,
        processor=processor,
        do_sample=False,
        img_list=img_list,
        max_new_tokens=256,
        print_res=True,
    )


def run(rank, args, world_size):
    if rank != 0:
        transformers.utils.logging.set_verbosity_error()
        logger.setLevel(transformers.logging.ERROR)

    print_res = False
    conv_mode = args.conv_mode
    pre_query_prompt = None
    post_query_prompt = "\nOnly give the best option."

    # Default matches load_model_and_dataset; --pooling_shape, when given, overrides it.
    pooling_shape = (16, 12, 12)
    if args.pooling_shape is not None:
        pooling_shape = tuple(int(x) for x in args.pooling_shape.split("-"))
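
    # e.g. --pooling_shape 16-12-12 parses to (16, 12, 12); presumably the pooled
    # video-token grid as (frames, height, width) in the PLLaVA pooling setup.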

    logger.info(f'loading model and constructing dataset to gpu {rank}...')
    model, processor, dataset = load_model_and_dataset(
        rank,
        world_size,
        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
        num_frames=args.num_frames,
        use_lora=args.use_lora,
        lora_alpha=args.lora_alpha,
        weight_dir=args.weight_dir,
        pooling_shape=pooling_shape,
    )
    logger.info('done loading model and constructing dataset')

    logger.info('running single test...')
    vid_path = "./example/yoga.mp4"
    # vid_path = "./example/jesse_dance.mp4"
    if rank == 0:
        single_test(model,
                    processor,
                    vid_path,
                    num_frames=args.num_frames,
                    conv_mode=args.conv_mode)
        logger.info('single test done...')

    tbar = tqdm(total=len(dataset))

    correct = 0
    total = 0
    result_list = []
    acc_dict = {}
    done_count = 0
    for example in dataset:
        task_type = example['task_type']
        if task_type not in acc_dict:
            acc_dict[task_type] = [0, 0]  # correct, total
        acc_dict[task_type][1] += 1
        total += 1
        pred = infer_mvbench(
            model,
            processor,
            example,
            conv_mode=conv_mode,
            pre_query_prompt=pre_query_prompt,
            post_query_prompt=post_query_prompt,
            answer_prompt="Best option:(",
            return_prompt='(',
            print_res=print_res,
        )
        gt = example['answer']
        result_list.append({
            'pred': pred,
            'gt': gt,
            'task_type': task_type,
            'video_path': example['video_path'],
            'question': example['question'],
        })
        if check_ans(pred=pred, gt=gt):
            acc_dict[task_type][0] += 1
            correct += 1
        if rank == 0:
            tbar.update(len(result_list) - done_count)
            tbar.set_description_str(
                f"One Chunk--Task Type: {task_type}, Chunk Part Acc: {acc_dict[task_type][0] / acc_dict[task_type][1] * 100:.2f}%;"
                f" Chunk Total Acc: {correct / total * 100:.2f}%"
            )
            done_count = len(result_list)
    return result_list
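
# Each element of result_list is a plain dict (values illustrative):
#   {'pred': '(A) stretching', 'gt': '(A) stretching',
#    'task_type': 'Action Sequence', 'video_path': '...', 'question': '...'}
# main() below either merges these lists across ranks or reloads them from disk.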


def main():
    multiprocess = True
    mp.set_start_method('spawn')
    args = parse_args()
    save_path = args.save_path
    json_data = load_results(save_path)
    if json_data is None:
        if multiprocess:
            logger.info(f'started benchmarking, saving to: {save_path}')
            n_gpus = torch.cuda.device_count()
            # assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
            world_size = n_gpus
            with Pool(world_size) as pool:
                func = functools.partial(run, args=args, world_size=world_size)
                result_lists = pool.map(func, range(world_size))

            logger.info('finished running')
            result_list = list(itertools.chain(*result_lists))
        else:
            result_list = run(0, world_size=1, args=args)  # debug
    else:
        logger.info(f'loaded results from {save_path}')
        result_list = json_data
    save_results(result_list, save_path)


if __name__ == "__main__":
    main()