import argparse
import os
import json
from tqdm import tqdm
from collections import defaultdict
from copy import deepcopy
from PIL import Image
import numpy as np
import torch
from torch.nn import CrossEntropyLoss

from LLaVA.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from LLaVA.llava.conversation import conv_templates, SeparatorStyle
from LLaVA.llava.model.builder import load_pretrained_model
from LLaVA.llava.utils import disable_torch_init
from LLaVA.llava.mm_utils import get_model_name_from_path, KeywordsStoppingCriteria, tokenizer_image_object_token
from visual_search import parse_args, VSM, visual_search
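
# NOTE (assumption): this script is expected to run from the repository root, with the
# LLaVA fork importable as the `LLaVA` package and `visual_search.py` (providing
# parse_args, VSM, and visual_search) importable from the same directory.
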
def normalize_bbox(bbox, image_width, image_height):
    normalized_bbox = [bbox[0]/image_width, bbox[1]/image_height, (bbox[0]+bbox[2])/image_width, (bbox[1]+bbox[3])/image_height]
    normalized_bbox = [np.clip(_, 0, 1) for _ in normalized_bbox]
    return normalized_bbox
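
# Worked example (illustrative values): a pixel-space box [x, y, w, h] = [100, 50, 30, 40]
# in a 640x480 image maps to the normalized corner form
# [100/640, 50/480, 130/640, 90/480] ~= [0.156, 0.104, 0.203, 0.188].
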
def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img, 0, 0
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result, 0, (width - height) // 2
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result, (height - width) // 2, 0
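
# Example (illustrative): a 640x480 image is padded to 640x640 with the original content
# pasted at vertical offset (640 - 480) // 2 = 80, so the call returns (padded_image, 0, 80);
# callers use the returned (left, top) offsets to shift pixel-space bounding boxes into the
# padded coordinate frame.
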
class VQA_LLM:
    def __init__(self, args):
        disable_torch_init()
        model_path = args.vqa_model_path
        model_name = get_model_name_from_path(model_path)
        # ensure the name contains 'llava' so the builder loads the LLaVA-style architecture
        model_name += 'llava'
        model_base = args.vqa_model_base
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(model_path, model_base, model_name)
        self.conv_type = args.conv_type
    def get_patch(self, bbox, image_width, image_height, patch_size=224, patch_scale=None):
        # build a crop window [left, top, right, bottom] centered on the box, at least
        # patch_size wide/tall (or scaled by patch_scale) and clamped to the image bounds
        object_width = int(np.ceil(bbox[2]))
        object_height = int(np.ceil(bbox[3]))
        object_center_x = int(bbox[0] + bbox[2]/2)
        object_center_y = int(bbox[1] + bbox[3]/2)
        if patch_scale is None:
            patch_width = max(object_width, patch_size)
            patch_height = max(object_height, patch_size)
        else:
            patch_width = int(object_width*patch_scale)
            patch_height = int(object_height*patch_scale)
        left = max(0, object_center_x-patch_width//2)
        right = min(left+patch_width, image_width)
        top = max(0, object_center_y-patch_height//2)
        bottom = min(top+patch_height, image_height)
        return [left, top, right, bottom]
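
    # Example (illustrative): for bbox [100, 50, 30, 40] in a 640x480 image with the
    # default patch_size=224, the window is centered at (115, 70) and clamped to the
    # image, giving [left, top, right, bottom] = [3, 0, 227, 224].
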
    def get_object_crop(self, image, bbox, patch_scale):
        resized_bbox = self.get_patch(bbox, image.width, image.height, patch_scale=patch_scale)
        object_crop = image.crop((resized_bbox[0], resized_bbox[1], resized_bbox[2], resized_bbox[3]))
        object_crop = object_crop.resize((self.image_processor.crop_size['width'], self.image_processor.crop_size['height']))
        object_crop = self.image_processor.preprocess(object_crop, return_tensors='pt')['pixel_values'][0]
        return object_crop
    def free_form_inference(self, image, question, temperature=0, top_p=None, num_beams=1, max_new_tokens=200, object_crops=None, images_long=None, objects_long=None):
        conv = conv_templates[self.conv_type].copy()
        qs = DEFAULT_IMAGE_TOKEN + '\n' + question
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        input_ids = tokenizer_image_object_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
        image_tensor = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
        output_ids = self.model.generate(
            input_ids,
            images=image_tensor.unsqueeze(0).half().cuda(),
            object_features=object_crops.half().cuda() if object_crops is not None else None,
            images_long=images_long,
            objects_long=objects_long,
            do_sample=temperature > 0,
            num_beams=num_beams,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            stopping_criteria=[stopping_criteria])
        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()
        return outputs
    def multiple_choices_inference(self, image, question, options, object_crops=None, images_long=None, objects_long=None):
        conv = conv_templates[self.conv_type].copy()
        qs = DEFAULT_IMAGE_TOKEN + '\n' + question
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        question_input_ids = tokenizer_image_object_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
        image_tensor = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        # run the question once and cache its key/value states so each option only needs
        # a forward pass over its own answer tokens
        output_question = self.model(
            question_input_ids,
            use_cache=True,
            images=image_tensor.unsqueeze(0).half().cuda(),
            object_features=object_crops.half().cuda() if object_crops is not None else None,
            images_long=images_long,
            objects_long=objects_long)
        question_logits = output_question.logits
        question_past_key_values = output_question.past_key_values
        loss_list = []
        for option in options:
            conv = conv_templates[self.conv_type].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], option)
            full_prompt = conv.get_prompt()
            full_input_ids = tokenizer_image_object_token(full_prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
            option_answer_input_ids = full_input_ids[:, question_input_ids.shape[1]:]
            output_option = self.model(input_ids=option_answer_input_ids,
                                       use_cache=True,
                                       attention_mask=torch.ones(1, question_logits.shape[1]+option_answer_input_ids.shape[1], device=full_input_ids.device),
                                       past_key_values=question_past_key_values)
            # score the option by the cross-entropy of its answer tokens; the logits are
            # shifted by one position so each token is predicted from its prefix
            logits = torch.cat([question_logits[:, -1:], output_option.logits[:, :-1]], 1)
            loss_fct = CrossEntropyLoss()
            logits = logits.view(-1, self.model.config.vocab_size)
            labels = option_answer_input_ids.view(-1)
            loss = loss_fct(logits, labels)
            loss_list.append(loss)
        # the option with the lowest loss (highest likelihood) is chosen
        option_chosen = torch.stack(loss_list).argmin()
        return option_chosen.cpu().item()
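
# Usage sketch (hypothetical call, not part of the evaluation loop below):
#   idx = vqa_llm.multiple_choices_inference(image, "Where is the cat?",
#                                            ["on the sofa", "on the floor"])
# returns the index of the option whose answer tokens have the lowest cross-entropy.
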
def eval_model(args):
    # init VQA LLM
    vqa_llm = VQA_LLM(args)
    # init VSM
    vsm_args = parse_args({})
    vsm_args.version = args.vsm_model_path
    vsm = VSM(vsm_args)

    results = {}
    per_type_acc = defaultdict(list)
    all_acc = []

    missing_objects_msg = "Sorry, I can not answer the question. Some visual information about the following objects is missing or unclear:"
    focus_msg = "Additional visual information to focus on: "
    for test_type in ['direct_attributes', 'relative_position']:
        results[test_type] = []
        folder = os.path.join(args.benchmark_folder, test_type)
        image_files = list(filter(lambda file: '.json' not in file, os.listdir(folder)))
        for image_file in tqdm(image_files):
            result_single_sample = {}
            image_path = os.path.join(folder, image_file)
            annotation_path = os.path.splitext(image_path)[0] + '.json'
            image = Image.open(image_path).convert('RGB')
            with open(annotation_path) as f:
                annotation = json.load(f)
            image, _, _ = expand2square(image, tuple(int(x*255) for x in vqa_llm.image_processor.image_mean))
            question = annotation['question']
            # generate free-form response to check whether visual search needs to be activated
            prediction = vqa_llm.free_form_inference(image, question)
            missing_objects = []
            if missing_objects_msg in prediction:
                missing_objects = prediction.split(missing_objects_msg)[-1]
                if missing_objects.endswith('.'):
                    missing_objects = missing_objects[:-1]
                missing_objects = missing_objects.split(',')
                missing_objects = [missing_object.strip() for missing_object in missing_objects]

            search_result = []
            if len(missing_objects) > 0:
                # visual search
                for object_name in missing_objects:
                    image = Image.open(image_path).convert('RGB')
                    smallest_size = max(int(np.ceil(min(image.width, image.height)/args.minimum_size_scale)), args.minimum_size)
                    final_step, path_length, search_successful, all_valid_boxes = visual_search(vsm, image, object_name, target_bbox=None, smallest_size=smallest_size)
                    if all_valid_boxes is not None:
                        # there might be multiple target instances
                        for search_bbox in all_valid_boxes:
                            search_final_patch = final_step['bbox']
                            search_bbox[0] += search_final_patch[0]
                            search_bbox[1] += search_final_patch[1]
                            search_result.append({'bbox': search_bbox.tolist(), 'name': object_name})
                    else:
                        search_bbox = final_step['detection_result']
                        search_final_patch = final_step['bbox']
                        search_bbox[0] += search_final_patch[0]
                        search_bbox[1] += search_final_patch[1]
                        search_result.append({'bbox': search_bbox.tolist(), 'name': object_name})
            # predict the multiple-choice option
            options = annotation['options']
            image = Image.open(image_path).convert('RGB')
            if len(missing_objects) > 0:
                object_names = [_['name'] for _ in search_result]
                bboxs = deepcopy([_['bbox'] for _ in search_result])
                if len(object_names) <= 2:
                    images_long = [False]
                    objects_long = [True]*len(object_names)
                else:
                    images_long = [False]
                    objects_long = [False]*len(object_names)
                object_crops = []
                for bbox in bboxs:
                    object_crop = vqa_llm.get_object_crop(image, bbox, patch_scale=1.2)
                    object_crops.append(object_crop)
                object_crops = torch.stack(object_crops, 0)
                image, left, top = expand2square(image, tuple(int(x*255) for x in vqa_llm.image_processor.image_mean))
                bbox_list = []
                for bbox in bboxs:
                    bbox[0] += left
                    bbox[1] += top
                    bbox_list.append(bbox)
                bbox_list = [normalize_bbox(bbox, image.width, image.height) for bbox in bbox_list]
                cur_focus_msg = focus_msg
                for i, (object_name, bbox) in enumerate(zip(object_names, bbox_list)):
                    cur_focus_msg = cur_focus_msg + "{} <object> at location [{:.3f},{:.3f},{:.3f},{:.3f}]".format(object_name, bbox[0], bbox[1], bbox[2], bbox[3])
                    if i != len(bbox_list)-1:
                        cur_focus_msg = cur_focus_msg + "; "
                    else:
                        cur_focus_msg = cur_focus_msg + '.'
                question_with_focus = cur_focus_msg + "\n" + question
                option_chosen = vqa_llm.multiple_choices_inference(image, question_with_focus, options, object_crops, images_long=images_long, objects_long=objects_long)
            else:
                option_chosen = vqa_llm.multiple_choices_inference(image, question, options)
            correct = 1 if option_chosen == 0 else 0
            per_type_acc[test_type].append(correct)
            all_acc.append(correct)

            result_single_sample['question'] = question
            result_single_sample['options'] = options
            result_single_sample['image'] = image_file
            result_single_sample['prediction_freeform'] = prediction
            result_single_sample['missing_objects'] = missing_objects
            result_single_sample['search_result'] = search_result
            result_single_sample['option_chosen'] = option_chosen
            result_single_sample['correct'] = correct
            results[test_type].append(result_single_sample)

        print(test_type, 'accuracy:', np.mean(per_type_acc[test_type]))

    print('overall accuracy:', np.mean(all_acc))

    with open(args.output_path, 'w') as f:
        json.dump(results, f, indent=4)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--vqa-model-path", type=str, default="craigwu/seal_vqa_7b")
    parser.add_argument("--vqa-model-base", type=str, default=None)
    parser.add_argument("--conv_type", default="v1", type=str)
    parser.add_argument("--benchmark-folder", type=str, default="vstar_bench")
    parser.add_argument("--vsm-model-path", type=str, default="craigwu/seal_vsm_7b")
    parser.add_argument("--output-path", type=str, default="eval_result.json")
    parser.add_argument("--minimum_size_scale", default=4.0, type=float, help="minimum sub-image scale for the termination of search")
    parser.add_argument("--minimum_size", default=224, type=int, help="minimum sub-image size for the termination of search")
    args = parser.parse_args()
    eval_model(args)
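
# Example invocation (file name hypothetical; assumes the V* benchmark has been
# downloaded to ./vstar_bench and the default checkpoints above are reachable):
#   python vstar_bench_eval.py --benchmark-folder vstar_bench --output-path eval_result.json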