# vstar / visual_search.py
# Author: Penghao Wu
# Commit b11ae09: "init"
import argparse
import os
import sys
import json
import tqdm
import copy
from queue import PriorityQueue
import functools
import spacy
nlp = spacy.load("en_core_web_sm")
import cv2
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, CLIPImageProcessor
from transformers import OwlViTProcessor
from VisualSearch.model.VSM import VSMForCausalLM
from VisualSearch.model.llava import conversation as conversation_lib
from VisualSearch.model.llava.mm_utils import tokenizer_image_token
from VisualSearch.utils.utils import expand2square
from VisualSearch.utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
def parse_args(args):
    """Parse command-line options for the visual search evaluation.

    Args:
        args: list of argument strings, e.g. ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="Visual Search Evaluation")
    parser.add_argument("--version", default="craigwu/seal_vsm_7b")
    parser.add_argument("--benchmark-folder", default="vstar_bench", type=str)
    parser.add_argument("--visualization", action="store_true", default=False)
    parser.add_argument("--output_path", default="", type=str)
    parser.add_argument("--confidence_low", default=0.3, type=float)
    parser.add_argument("--confidence_high", default=0.5, type=float)
    parser.add_argument("--target_cue_threshold", default=6.0, type=float)
    parser.add_argument("--target_cue_threshold_decay", default=0.7, type=float)
    parser.add_argument("--target_cue_threshold_minimum", default=3.0, type=float)
    parser.add_argument("--minimum_size_scale", default=4.0, type=float)
    parser.add_argument("--minimum_size", default=224, type=int)
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
    # A store_true flag whose default is already True can never be switched off.
    # This companion flag writes False to the same destination so the setting is
    # actually controllable; the default (True) is unchanged.
    parser.add_argument(
        "--no_use_mm_start_end",
        dest="use_mm_start_end",
        action="store_false",
        default=True,
        help="Do not wrap the image token with start/end tokens.",
    )
    parser.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    return parser.parse_args(args)
def tranverse(token):
    """Return the (smallest, largest) token index found in *token*'s subtree.

    The subtree is *token* itself plus every node reachable through
    ``.children``; nodes only need ``.i`` and ``.children`` attributes.
    """
    lowest = highest = token.i
    pending = list(token.children)
    while pending:
        node = pending.pop()
        lowest = min(lowest, node.i)
        highest = max(highest, node.i)
        pending.extend(node.children)
    return lowest, highest
def get_noun_chunks(token):
    """Return the inclusive (start, end) token span of the phrase headed by *token*.

    Walking outward from the head, the span grows left over the closest
    children whose dependency label is 'amod', 'compound' or 'poss', and right
    over children labelled 'relcl' or 'prep'. In each direction it stops at the
    first child with any other label. An absorbed child contributes its whole
    subtree (via ``tranverse``).
    """
    head = token.i
    kids = list(token.children)
    before_head = [kid for kid in kids if kid.i < head]
    after_head = [kid for kid in kids if kid.i >= head]

    start = head
    for kid in reversed(before_head):
        if kid.dep_ not in ['amod', 'compound', 'poss']:
            break
        start = tranverse(kid)[0]

    end = head
    for kid in after_head:
        if kid.dep_ not in ['relcl', 'prep']:
            break
        end = tranverse(kid)[1]

    return start, end
def filter_chunk_list(chunks):
    """Greedily keep the longest non-overlapping inclusive spans.

    Spans are (start, end) pairs with inclusive ends. Longer spans are
    considered first (ties keep input order). A span is dropped if it shares
    at least one index with an already kept span. The survivors are returned
    ordered by start index.
    """
    kept = []
    # Ascending (start - end) == descending length; sorting is stable for ties.
    for candidate in sorted(chunks, key=lambda span: span[0] - span[1]):
        disjoint = all(
            min(other[1], candidate[1]) < max(other[0], candidate[0])
            for other in kept
        )
        if disjoint:
            kept.append(candidate)
    kept.sort(key=lambda span: span[0])
    return kept
def extract_noun_chunks(expression):
    """Return the text of non-overlapping noun phrases found in *expression*.

    Each token tagged NOUN or PRON by the module-level spaCy pipeline ``nlp``
    seeds a candidate span (``get_noun_chunks``). Overlaps are resolved by
    ``filter_chunk_list``, and the surviving spans are returned in order.
    """
    doc = nlp(expression)
    spans = [get_noun_chunks(tok) for tok in doc if tok.pos_ in ["NOUN", "PRON"]]
    return [doc[first:last + 1].text for first, last in filter_chunk_list(spans)]
def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Standardise *x* per channel, then zero-pad its bottom/right edges to img_size.

    ``pixel_mean``/``pixel_std`` are shaped (C, 1, 1), so they broadcast along
    the third-from-last dimension of *x*. The last two dimensions are treated
    as (height, width) and padded up to ``img_size``.
    """
    standardized = (x - pixel_mean) / pixel_std
    height, width = standardized.shape[-2:]
    # F.pad lists the last dimension first: (left, right, top, bottom).
    padding = (0, img_size - width, 0, img_size - height)
    return F.pad(standardized, padding)
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to (x_min, y_min, x_max, y_max).

    The four box components are read from dimension 1 of *x* (e.g. an
    (N, 4) tensor), and the result is stacked back along that dimension.
    """
    center_x, center_y, width, height = x.unbind(1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = (center_x - half_w, center_y - half_h, center_x + half_w, center_y + half_h)
    return torch.stack(corners, dim=1)
def rescale_bboxes(out_bbox, size):
    """Convert (cx, cy, w, h) boxes to corner format and scale them to pixels.

    Args:
        out_bbox: boxes accepted by ``box_cxcywh_to_xyxy``.
        size: pair unpacked as (width, height); x coordinates are scaled by
            the width and y coordinates by the height.
    """
    width, height = size
    scale = torch.tensor([width, height, width, height], dtype=torch.float32)
    return box_cxcywh_to_xyxy(out_bbox) * scale
class VSM:
    """Wrapper around VSMForCausalLM together with its tokenizer and image processors.

    The model is loaded in bfloat16 with device_map='cuda'. ``inference``
    answers one prompt about one image in 'segmentation', 'vqa' or
    'detection' mode.
    """

    def __init__(self, args):
        """Load the tokenizer, model and processors.

        Reads args.version (checkpoint for tokenizer and model),
        args.model_max_length, args.vision_tower, args.conv_type and
        args.use_mm_start_end.
        """
        # Keyword arguments forwarded to VSMForCausalLM.from_pretrained.
        kwargs = {}
        kwargs['torch_dtype'] = torch.bfloat16
        kwargs['device_map'] = 'cuda'
        kwargs['is_eval'] = True
        vsm_tokenizer = AutoTokenizer.from_pretrained(
            args.version,
            cache_dir=None,
            model_max_length=args.model_max_length,
            padding_side="right",
            use_fast=False,
        )
        # Padding reuses the tokenizer's unknown token.
        vsm_tokenizer.pad_token = vsm_tokenizer.unk_token
        # Id of the first sub-token of "[LOC]"; handed to the model constructor.
        loc_token_idx = vsm_tokenizer("[LOC]", add_special_tokens=False).input_ids[0]
        vsm_model = VSMForCausalLM.from_pretrained(
            args.version, low_cpu_mem_usage=True, vision_tower=args.vision_tower, loc_token_idx=loc_token_idx, **kwargs
        )
        vsm_model.get_model().initialize_vision_modules(vsm_model.get_model().config)
        vision_tower = vsm_model.get_model().get_vision_tower().cuda().to(dtype=torch.bfloat16)
        # NOTE(review): stored on self but not used by any method visible in this file.
        vsm_image_processor = vision_tower.image_processor
        vsm_model.eval()
        # Preprocessor for the CLIP image input (image_clip in inference()).
        clip_image_processor = CLIPImageProcessor.from_pretrained(vsm_model.config.vision_tower)
        # OWL-ViT processor that prepares the second image input passed to model.inference.
        transform = OwlViTProcessor.from_pretrained("google/owlvit-base-patch16")
        self.model = vsm_model
        self.vsm_tokenizer = vsm_tokenizer
        self.vsm_image_processor = vsm_image_processor
        self.clip_image_processor = clip_image_processor
        self.transform = transform
        self.conv_type = args.conv_type
        self.use_mm_start_end = args.use_mm_start_end

    @torch.inference_mode()
    def inference(self, image, question, mode='segmentation'):
        """Run the model on one image/question pair.

        Args:
            image: PIL image (padded with expand2square, then converted with np.array).
            question: user prompt; the image placeholder token is prepended to it.
            mode: 'segmentation', 'vqa' or 'detection'.

        Returns:
            'segmentation': last predicted mask of the first batch element, negatives clamped to 0.
            'vqa': generated answer text with newlines removed and whitespace stripped.
            'detection': (pred_boxes[0] on CPU, sigmoid(pred_logits[0]) on CPU, mask as in 'segmentation').
            Any other mode matches no branch and returns None.
        """
        # Single-turn conversation: the user prompt, then an empty assistant turn to be generated.
        conv = conversation_lib.conv_templates[self.conv_type].copy()
        conv.messages = []
        prompt = DEFAULT_IMAGE_TOKEN + "\n" + question
        if self.use_mm_start_end:
            # Surround the image placeholder with explicit start/end tokens.
            replace_token = ( DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN)
            prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], "")
        prompt = conv.get_prompt()
        # Pad to a square filled with the CLIP mean colour (scaled to 0-255) before CLIP preprocessing.
        background_color = tuple(int(x*255) for x in self.clip_image_processor.image_mean)
        image_clip = self.clip_image_processor.preprocess(expand2square(image, background_color), return_tensors="pt")["pixel_values"][0].unsqueeze(0).cuda()
        image_clip = image_clip.bfloat16()
        image = np.array(image)
        # (height, width) of the original, unpadded image.
        original_size_list = [image.shape[:2]]
        image = self.transform(images=image, return_tensors="pt")['pixel_values'].cuda()
        # NOTE(review): `image` is now the processor's pixel_values tensor (presumably
        # (batch, channels, H, W)), so shape[:2] is not a spatial size — confirm this is
        # what self.model.inference expects for resize_list.
        resize_list = [image.shape[:2]]
        image = image.bfloat16()
        input_ids = tokenizer_image_token(prompt, self.vsm_tokenizer, return_tensors="pt")
        input_ids = input_ids.unsqueeze(0).cuda()
        output_ids, pred_masks, det_result = self.model.inference(
            image_clip,
            image,
            input_ids,
            resize_list,
            original_size_list,
            max_new_tokens=100,
            tokenizer=self.vsm_tokenizer,
            mode = mode
        )
        if mode == 'segmentation':
            pred_mask = pred_masks[0]
            pred_mask = torch.clamp(pred_mask, min=0)
            return pred_mask[-1]
        elif mode == 'vqa':
            input_token_len = input_ids.shape[1]
            # Sanity check: the generated sequence is expected to start with the prompt tokens.
            n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
            if n_diff_input_output > 0:
                print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
            # Decode only the tokens generated after the prompt.
            text_output = self.vsm_tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
            # NOTE(review): as written, .replace(" ", " ") is a no-op; it was probably meant
            # to collapse double spaces — confirm.
            text_output = text_output.replace("\n", "").replace(" ", " ").strip()
            return text_output
        elif mode == 'detection':
            pred_mask = pred_masks[0]
            pred_mask = torch.clamp(pred_mask, min=0)
            return det_result['pred_boxes'][0].cpu(), det_result['pred_logits'][0].sigmoid().cpu(), pred_mask[-1]
def refine_bbox(bbox, image_width, image_height):
    """Clip an [x, y, w, h] box to the image area, in place.

    The top-left corner is clamped to be non-negative. The width and height are
    recomputed from the clipped right/bottom edges, so a box that starts
    outside the image is trimmed rather than shifted. A box lying entirely
    outside the image ends up with zero width/height.

    Args:
        bbox: mutable sequence [x, y, w, h]; it is modified in place.
        image_width: image width in pixels.
        image_height: image height in pixels.

    Returns:
        The same ``bbox`` object after clipping.
    """
    left = max(0, bbox[0])
    top = max(0, bbox[1])
    # Compute the far edges from the *original* origin so that clamping a
    # negative x/y shrinks the box instead of moving it.
    right = min(bbox[0] + bbox[2], image_width)
    bottom = min(bbox[1] + bbox[3], image_height)
    bbox[0] = left
    bbox[1] = top
    bbox[2] = max(0, right - left)
    bbox[3] = max(0, bottom - top)
    return bbox
def split_4subpatches(current_patch_bbox):
    """Pick a (columns, rows) grid that divides an [x, y, w, h] patch into four.

    A height/width ratio of at least 2 gives a 1x4 vertical stack, a ratio of
    at most 0.5 gives a 4x1 horizontal strip, and anything else gives 2x2.
    """
    aspect = current_patch_bbox[3] / current_patch_bbox[2]
    if aspect >= 2:
        return 1, 4
    if aspect <= 0.5:
        return 4, 1
    return 2, 2
def get_sub_patches(current_patch_bbox, num_of_width_patches, num_of_height_patches):
    """Tile an [x, y, w, h] patch into a num_of_width_patches x num_of_height_patches grid.

    Tile strides are floor(w / columns) and floor(h / rows). The last column
    and last row absorb any remainder, so the tiles exactly cover the patch.

    Returns:
        (sub_patches, width_stride, height_stride), where sub_patches is a
        row-major list of [x, y, w, h] boxes in the input's coordinate frame.
    """
    origin_x, origin_y = current_patch_bbox[0], current_patch_bbox[1]
    width, height = current_patch_bbox[2], current_patch_bbox[3]
    # Both strides use floor division so rows and columns are computed the same way.
    width_stride = int(width // num_of_width_patches)
    height_stride = int(height // num_of_height_patches)
    sub_patches = []
    for j in range(num_of_height_patches):
        is_last_row = j == num_of_height_patches - 1
        sub_patch_height = height - j * height_stride if is_last_row else height_stride
        for i in range(num_of_width_patches):
            is_last_col = i == num_of_width_patches - 1
            sub_patch_width = width - i * width_stride if is_last_col else width_stride
            sub_patches.append([
                origin_x + i * width_stride,
                origin_y + j * height_stride,
                sub_patch_width,
                sub_patch_height,
            ])
    return sub_patches, width_stride, height_stride
def get_subpatch_scores(score_heatmap, current_patch_bbox, sub_patches):
    """Return each sub-patch's share of the parent patch's heatmap mass.

    Args:
        score_heatmap: array/tensor indexed [row, col, ...] that covers exactly
            ``current_patch_bbox``.
        current_patch_bbox: parent patch [x, y, w, h] in absolute pixels.
        sub_patches: boxes [x, y, w, h] in the same absolute frame.

    Returns:
        A list with, for each sub-patch, (sum inside it) / (sum of whole heatmap).
        Every entry is zero when the whole heatmap sums to a non-positive value.
    """
    area = current_patch_bbox[2] * current_patch_bbox[3]
    origin_x = current_patch_bbox[0]
    origin_y = current_patch_bbox[1]
    total = (score_heatmap / area).sum()

    def _share(box):
        # Translate the sub-patch into the parent heatmap's local coordinates.
        col = box[0] - origin_x
        row = box[1] - origin_y
        part = (score_heatmap[row:row + box[3], col:col + box[2]] / area).sum()
        return part / total if total > 0 else part * 0

    return [_share(box) for box in sub_patches]
def normalize_score(score_heatmap):
    """Min-max rescale *score_heatmap* to [0, 1].

    A constant heatmap (max == min) cannot be rescaled and is returned as all
    zeros with the same shape.
    """
    lowest = score_heatmap.min()
    highest = score_heatmap.max()
    if highest == lowest:
        return score_heatmap * 0
    return (score_heatmap - lowest) / (highest - lowest)
def iou(bbox1, bbox2):
    """Intersection-over-union of two [x, y, w, h] boxes.

    Works with plain numbers or tensor elements. Raises ZeroDivisionError for
    plain numbers when both boxes have zero area.
    """
    left = max(bbox1[0], bbox2[0])
    top = max(bbox1[1], bbox2[1])
    right = min(bbox1[0] + bbox1[2], bbox2[0] + bbox2[2])
    bottom = min(bbox1[1] + bbox1[3], bbox2[1] + bbox2[3])
    overlap = max(0, right - left) * max(0, bottom - top)
    area_sum = bbox1[2] * bbox1[3] + bbox2[2] * bbox2[3]
    return overlap / (area_sum - overlap)
# Default drawing colours for visualize_bbox. Images are drawn on as RGB arrays
# (np.array of PIL RGB images) and converted with COLOR_RGB2BGR only right
# before cv2.imwrite, so these tuples are in RGB order.
BOX_COLOR = (255, 0, 0) # Red
TEXT_COLOR = (255, 255, 255) # White
import cv2
from matplotlib import pyplot as plt
def visualize_bbox(img, bbox, class_name, color=BOX_COLOR, thickness=2):
    """Draw one labelled bounding box onto *img* and return it.

    OpenCV draws in place, so *img* itself is modified; callers that need the
    original should pass a copy.

    Args:
        img: image array to draw on.
        bbox: (x_min, y_min, width, height); components are converted with int().
        class_name: caption drawn just above the box's top-left corner.
        color: colour of the box outline and of the filled caption background.
        thickness: outline thickness in pixels.

    Returns:
        The same ``img`` array.
    """
    x_min, y_min, w, h = bbox
    x_min, x_max, y_min, y_max = int(x_min), int(x_min + w), int(y_min), int(y_min + h)
    cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=color, thickness=thickness)
    ((text_width, text_height), _) = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
    # Caption background uses the same colour as its box so each label matches
    # the box it belongs to (thickness -1 fills the rectangle).
    cv2.rectangle(img, (x_min, y_min - int(1.3 * text_height)), (x_min + text_width, y_min), color, -1)
    cv2.putText(
        img,
        text=class_name,
        org=(x_min, y_min - int(0.3 * text_height)),
        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
        fontScale=0.5,
        color=TEXT_COLOR,
        lineType=cv2.LINE_AA,
    )
    return img
def show_heatmap_on_image(img: np.ndarray,
                          mask: np.ndarray,
                          use_rgb: bool = False,
                          colormap: int = cv2.COLORMAP_JET,
                          image_weight: float = 0.5) -> np.ndarray:
    """Blend a colour-mapped heatmap over an image.

    Args:
        img: float image with values in [0, 1].
        mask: heatmap; values are clipped to [0, 1] before colour mapping.
        use_rgb: convert the colour-mapped heatmap from OpenCV's BGR order to RGB.
        colormap: OpenCV colour-map id.
        image_weight: weight of *img* in the blend; the heatmap gets 1 - image_weight.

    Returns:
        uint8 image of the blend, rescaled so its maximum value maps to 255.

    Raises:
        ValueError: if ``img`` has values above 1 or ``image_weight`` is outside [0, 1].
            (ValueError subclasses Exception, so existing ``except Exception`` handlers still work.)
    """
    # Validate inputs before doing any colour-mapping work.
    if np.max(img) > 1:
        raise ValueError("The input image should be np.float32 in the range [0, 1]")
    if image_weight < 0 or image_weight > 1:
        raise ValueError(f"image_weight should be in the range [0, 1]. Got: {image_weight}")
    mask = np.clip(mask, 0, 1)
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255
    cam = (1 - image_weight) * heatmap + image_weight * img
    peak = np.max(cam)
    # An all-black blend would otherwise divide by zero and yield NaNs.
    if peak > 0:
        cam = cam / peak
    return np.uint8(255 * cam)
def vis_heatmap(image, heatmap, use_rgb=False):
    """Overlay *heatmap* on a 0-255 *image* and return the uint8 blend.

    The heatmap is min-max normalised first, unless it is constant, in which
    case it is passed through unchanged. The image is scaled to [0, 1] for
    ``show_heatmap_on_image``.
    """
    low, high = np.min(heatmap), np.max(heatmap)
    if high != low:
        heatmap = (heatmap - low) / (high - low)
    scaled_image = image.astype(float) / 255.
    return show_heatmap_on_image(scaled_image, heatmap, use_rgb=use_rgb)
def visualize_search_path(image, search_path, search_length, target_bbox, label, save_path):
    """Write per-step images of a visual-search run into *save_path*.

    Files written:
      * whole_image.jpg: the input image;
      * step_<k>.jpg: the full image with the ground-truth box (if given) and the
        patch examined at step k;
      * step_<k>_heatmap.jpg: the step-k patch blended with its 'final_heatmap';
      * final_patch_image.jpg / search_result.jpg: the patch that carries a
        'detection_result', without and with the detected box drawn;
      * context_cue.txt: one line per step that recorded a 'context_cue'.

    Args:
        image: PIL image that was searched.
        search_path: list of patch dicts with 'bbox' ([x, y, w, h]) and optionally
            'detection_result', 'context_cue' and 'final_heatmap'.
        search_length: number of leading steps to draw.
        target_bbox: ground-truth [x, y, w, h], or None to skip drawing it.
        label: object name used in the ground-truth caption.
        save_path: output directory (created if missing).
    """
    os.makedirs(save_path, exist_ok=True)
    image.save(os.path.join(save_path, 'whole_image.jpg'))
    whole_image = np.array(image)
    if target_bbox is not None:
        whole_image = visualize_bbox(whole_image.copy(), target_bbox, class_name="gt: "+label, color=(255,0,0))

    context_cue_list = []
    for step_i, node in enumerate(search_path):
        if step_i + 1 > search_length:
            break
        step_no = step_i + 1
        current_patch_box = node['bbox']
        # Crop once; the same patch serves the detection and heatmap outputs.
        patch_image = image.crop((current_patch_box[0], current_patch_box[1], current_patch_box[0]+current_patch_box[2], current_patch_box[1]+current_patch_box[3]))
        if 'detection_result' in node:
            patch_image.save(os.path.join(save_path, 'final_patch_image.jpg'))
            final_search_result = visualize_bbox(np.array(patch_image), node['detection_result'], class_name='search result', color=(255,0,0))
            final_search_result = cv2.cvtColor(final_search_result, cv2.COLOR_RGB2BGR)
            cv2.imwrite(os.path.join(save_path, 'search_result.jpg'), final_search_result)
        cur_whole_image = visualize_bbox(whole_image.copy(), current_patch_box, class_name="step-{}".format(step_no), color=(0,0,255))
        # Arrays are RGB (from PIL); cv2.imwrite expects BGR.
        cur_whole_image = cv2.cvtColor(cur_whole_image, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(save_path, 'step_{}.jpg'.format(step_no)), cur_whole_image)
        if 'context_cue' in node:
            context_cue_list.append('step{}: {}'.format(step_no, node['context_cue'])+'\n')
        if 'final_heatmap' in node:
            score_map = vis_heatmap(np.array(patch_image), node['final_heatmap'], use_rgb=True)
            score_map = cv2.cvtColor(score_map, cv2.COLOR_RGB2BGR)
            cv2.imwrite(os.path.join(save_path, 'step_{}_heatmap.jpg'.format(step_no)), score_map)
    # Context cues are model-generated text; write UTF-8 explicitly so the
    # platform's default encoding cannot raise UnicodeEncodeError.
    with open(os.path.join(save_path, 'context_cue.txt'), "w", encoding="utf-8") as f:
        f.writelines(context_cue_list)
@functools.total_ordering
class Prioritize:
    """Pair an *item* with a *priority* so heap-based queues order by priority alone.

    Comparisons use only ``priority``; ``item`` (e.g. a patch dict) is never
    compared, so unorderable payloads are fine. Smaller priorities sort first,
    matching the min-heap behaviour of ``queue.PriorityQueue``. Because
    ``__eq__`` is defined without ``__hash__``, instances are unhashable.
    """

    def __init__(self, priority, item):
        self.priority, self.item = priority, item

    def __lt__(self, other):
        return self.priority < other.priority

    def __eq__(self, other):
        return self.priority == other.priority
def visual_search_queue(vsm, image, target_object_name, current_patch, search_path, queue, smallest_size=224, confidence_high=0.5, target_cue_threshold=6.0, target_cue_threshold_decay=0.7, target_cue_threshold_minimum=3.0):
    """Examine one patch and, unless the target is found there, continue best-first search.

    The caller must already have appended *current_patch* to *search_path*: this
    function annotates ``search_path[-1]`` and uses ``len(search_path) - 1`` as
    the current patch's index.

    Steps:
      1. Ask *vsm* to detect *target_object_name* in the patch. If the best score
         exceeds *confidence_high*, store the box as
         ``search_path[-1]['detection_result']`` and return success. Otherwise
         store (score, box) as ``'temp_detection_result'``.
      2. Return failure for this branch if the patch's shorter side is already
         <= *smallest_size*.
      3. Build a heatmap. Use the target-cue heatmap from step 1 if its max
         exceeds a scale-dependent threshold. Otherwise ask the model (VQA)
         where the target is likely to be and segment that context region.
      4. Split the patch into four sub-patches, score them using the heatmaps
         of this patch and all its ancestors, and push them onto *queue*.
      5. Pop the best patch from the shared queue (it may belong to any branch
         expanded so far) and recurse, until one succeeds or the queue is empty.

    Args:
        vsm: model wrapper exposing ``inference(image, question, mode)``.
        image: full PIL image being searched.
        target_object_name: name of the object to find.
        current_patch: dict with 'bbox' ([x, y, w, h] in full-image pixels),
            'scale_level', 'score' and 'parent_index' (-1 for the root).
        search_path: list of visited patch dicts, annotated in place.
        queue: PriorityQueue of ``Prioritize`` wrappers, shared by the whole search.
        smallest_size: patches whose shorter side is <= this are not split further.
        confidence_high: detection score a box must exceed to be accepted.
        target_cue_threshold, target_cue_threshold_decay, target_cue_threshold_minimum:
            the cue threshold is max(minimum, threshold * decay ** (scale_level - 1)).

    Returns:
        (success, search_path, all_valid_boxes). all_valid_boxes is non-None only
        when the detection succeeds while search_path has a single entry (the
        whole image when called from visual_search); it then holds every box
        whose score exceeds 0.5.
    """
    current_patch_bbox = current_patch['bbox']
    current_patch_scale_level = current_patch['scale_level']
    # Crop the [x, y, w, h] patch out of the full image.
    image_patch = image.crop((int(current_patch_bbox[0]), int(current_patch_bbox[1]), int(current_patch_bbox[0]+current_patch_bbox[2]), int(current_patch_bbox[1]+current_patch_bbox[3])))
    # whether we can detect the target object on the current image patch
    question = "Please locate the {} in this image.".format(target_object_name)
    pred_bboxes, pred_logits, target_cue_heatmap = vsm.inference(copy.deepcopy(image_patch), question, mode='detection')
    if len(pred_logits) > 0:
        top_index = pred_logits.view(-1).argmax()
        top_logit = pred_logits.view(-1).max()
        final_bbox = pred_bboxes[top_index].view(4)
        # Scale to patch pixels, then move the first two coordinates back by half the
        # box size, giving [x, y, w, h] relative to the patch (this presumes the model
        # returns normalised centre-format boxes — TODO confirm).
        final_bbox = final_bbox * torch.Tensor([image_patch.width, image_patch.height, image_patch.width, image_patch.height])
        final_bbox[:2] -= final_bbox[2:] / 2
        if top_logit > confidence_high:
            search_path[-1]['detection_result'] = final_bbox
            # only return multiple detected instances on the whole image
            if len(search_path) == 1:
                # NOTE(review): a fixed 0.5 cut-off is used here, not confidence_high.
                all_valid_boxes = pred_bboxes[pred_logits.view(-1)>0.5].view(-1, 4)
                all_valid_boxes = all_valid_boxes * torch.Tensor([[image_patch.width, image_patch.height, image_patch.width, image_patch.height]])
                all_valid_boxes[:, :2] -= all_valid_boxes[:, 2:] / 2
                return True, search_path, all_valid_boxes
            return True, search_path, None
        else:
            # Not confident enough: keep the best candidate so visual_search can fall back to it.
            search_path[-1]['temp_detection_result'] = (top_logit, final_bbox)
    ### current patch is already the smallest unit
    if min(current_patch_bbox[2], current_patch_bbox[3]) <= smallest_size:
        return False, search_path, None
    # Reshape the detection call's heatmap to the patch's (height, width, 1) grid.
    target_cue_heatmap = target_cue_heatmap.view(current_patch_bbox[3], current_patch_bbox[2], 1)
    score_max = target_cue_heatmap.max().item()
    # check whether the target cue is prominent; the threshold is multiplied by the
    # decay once per scale level below the root and floored at the minimum
    threshold = max(target_cue_threshold_minimum, target_cue_threshold*(target_cue_threshold_decay)**(current_patch_scale_level-1))
    if score_max > threshold:
        target_cue_heatmap = normalize_score(target_cue_heatmap)
        final_heatmap = target_cue_heatmap
    else:
        # Target cue too weak: ask where the object is likely to be and use a
        # segmentation of that context region as the heatmap instead.
        question = "According to the common sense knowledge and possible visual cues, what is the most likely location of the {} in the image?".format(target_object_name)
        vqa_results = vsm.inference(copy.deepcopy(image_patch), question, mode='vqa')
        # Keep the text after the last "most likely to appear", drop a trailing
        # period, then keep only the text after the last mention of the target name.
        possible_location_phrase = vqa_results.split('most likely to appear')[-1].strip()
        if possible_location_phrase.endswith('.'):
            possible_location_phrase = possible_location_phrase[:-1]
        possible_location_phrase = possible_location_phrase.split(target_object_name)[-1]
        noun_chunks = extract_noun_chunks(possible_location_phrase)
        # Exactly one noun chunk is used as-is; otherwise prefix the raw phrase with "region".
        if len(noun_chunks) == 1:
            possible_location_phrase = noun_chunks[0]
        else:
            possible_location_phrase = "region {}".format(possible_location_phrase)
        question = "Please locate the {} in this image.".format(possible_location_phrase)
        context_cue_heatmap = vsm.inference(copy.deepcopy(image_patch), question, mode='segmentation').view(current_patch_bbox[3], current_patch_bbox[2], 1)
        context_cue_heatmap = normalize_score(context_cue_heatmap)
        final_heatmap = context_cue_heatmap
    # current_patch is the last entry of search_path (see docstring).
    current_patch_index = len(search_path)-1
    if score_max <= threshold:
        # vqa_results / possible_location_phrase exist only when the else-branch above ran.
        search_path[current_patch_index]['context_cue'] = vqa_results + "#" + possible_location_phrase
    search_path[current_patch_index]['final_heatmap'] = final_heatmap.cpu().numpy()
    ### split the current patch into 4 sub-patches
    # (sub_patch_width / sub_patch_height are the returned strides; they are not used below.)
    basic_sub_patches, sub_patch_width, sub_patch_height = get_sub_patches(current_patch_bbox, *split_4subpatches(current_patch_bbox))
    # Score each sub-patch by summing, over this patch and every ancestor up to the
    # root, the fraction of that patch's heatmap mass inside the sub-patch,
    # weighted by 1 / 4**scale_level.
    tmp_patch = current_patch
    basic_sub_scores = [0]*len(basic_sub_patches)
    while True:
        tmp_score_heatmap = tmp_patch['final_heatmap']
        tmp_sub_scores = get_subpatch_scores(tmp_score_heatmap, tmp_patch['bbox'], basic_sub_patches)
        basic_sub_scores = [basic_sub_scores[patch_i]+tmp_sub_scores[patch_i]/(4**tmp_patch['scale_level']) for patch_i in range(len(basic_sub_scores))]
        if tmp_patch['parent_index'] == -1:
            break
        else:
            tmp_patch = search_path[tmp_patch['parent_index']]
    sub_patches = basic_sub_patches
    sub_scores = basic_sub_scores
    for sub_patch, sub_score in zip(sub_patches, sub_scores):
        new_patch_info = dict()
        new_patch_info['bbox'] = sub_patch
        new_patch_info['scale_level'] = current_patch_scale_level + 1
        new_patch_info['score'] = sub_score
        new_patch_info['parent_index'] = current_patch_index
        # PriorityQueue pops the smallest priority first; negate so the highest score is expanded first.
        queue.put(Prioritize(-new_patch_info['score'], new_patch_info))
    # Best-first over the shared queue: the next patch may come from any branch.
    # NOTE(review): each popped patch adds a recursion frame, so recursion depth grows
    # with the number of patches explored without success, not with tree depth.
    while(not queue.empty()):
        patch_chosen = queue.get().item
        search_path.append(patch_chosen)
        success, search_path, all_valid_boxes = visual_search_queue(vsm, image, target_object_name, patch_chosen, search_path, queue, smallest_size=smallest_size, confidence_high=confidence_high, target_cue_threshold=target_cue_threshold, target_cue_threshold_decay=target_cue_threshold_decay, target_cue_threshold_minimum=target_cue_threshold_minimum)
        if success:
            return success, search_path, all_valid_boxes
    return False, search_path, None
def visual_search(vsm, image, target_object_name, target_bbox, smallest_size, confidence_high=0.5, confidence_low=0.3, target_cue_threshold=6.0, target_cue_threshold_decay=0.7, target_cue_threshold_minimum=3.0, visualize=False, save_path=None):
    """Search *image* for *target_object_name*, starting from the whole image.

    If no patch yields a detection above *confidence_high*, the best
    lower-confidence candidate recorded during the search is used instead. It
    counts as a success only if its score is at least *confidence_low*.

    Args:
        vsm: model wrapper passed through to visual_search_queue.
        image: PIL image to search.
        target_object_name: name of the object to find.
        target_bbox: ground-truth [x, y, w, h]; only used for visualisation.
        smallest_size: minimum patch side that is still split further.
        confidence_high / confidence_low: acceptance thresholds described above.
        target_cue_threshold, target_cue_threshold_decay, target_cue_threshold_minimum:
            forwarded to visual_search_queue.
        visualize: write per-step images with visualize_search_path.
        save_path: output directory; required when *visualize* is True.

    Returns:
        (final_step, path_length, search_successful, all_valid_boxes).
        final_step is the patch dict holding 'detection_result' (box relative to
        that patch), or None when no step produced any detection candidate.
        path_length is the 1-based step index of final_step (0 if None).
    """
    if visualize:
        assert save_path is not None
    init_patch = dict()
    init_patch['bbox'] = [0, 0, image.width, image.height]
    init_patch['scale_level'] = 1
    init_patch['score'] = None
    init_patch['parent_index'] = -1
    search_path = [init_patch]
    queue = PriorityQueue()
    search_successful, search_path, all_valid_boxes = visual_search_queue(vsm, image, target_object_name, init_patch, search_path, queue, smallest_size=smallest_size, confidence_high=confidence_high, target_cue_threshold=target_cue_threshold, target_cue_threshold_decay=target_cue_threshold_decay, target_cue_threshold_minimum=target_cue_threshold_minimum)
    path_length = len(search_path)
    final_step = search_path[-1]
    if not search_successful:
        # if no target is found with confidence passing confidence_high, select the
        # target with the highest confidence during search and compare its confidence
        # with confidence_low
        max_logit = 0
        final_step = None
        path_length = 0
        for i, search_step in enumerate(search_path):
            if 'temp_detection_result' in search_step and search_step['temp_detection_result'][0] > max_logit:
                max_logit = search_step['temp_detection_result'][0]
                final_step = search_step
                path_length = i + 1
        # No step produced a candidate (e.g. every detection call returned no
        # logits): there is nothing to fall back to, and the search stays failed.
        if final_step is not None:
            final_step['detection_result'] = final_step['temp_detection_result'][1]
            if max_logit >= confidence_low:
                search_successful = True
    if visualize:
        vis_path_length = path_length if search_successful else len(search_path)
        visualize_search_path(image, search_path, vis_path_length, target_bbox, target_object_name, save_path)
    return final_step, path_length, search_successful, all_valid_boxes
def main(args):
    """Evaluate visual search on the V* benchmark and print summary metrics.

    Every non-JSON file in ``<benchmark_folder>/direct_attributes`` and
    ``<benchmark_folder>/relative_position`` is treated as an image. Its
    same-named ``.json`` annotation lists target objects ('target_object') and
    ground-truth boxes ('bbox'). Each object is searched for independently, and
    a result counts as correct when the final box has IoU > 0.5 with the
    ground truth.

    Args:
        args: command-line argument list (e.g. ``sys.argv[1:]``), parsed by parse_args.
    """
    args = parse_args(args)
    vsm = VSM(args)
    benchmark_folder = args.benchmark_folder
    acc_list = []
    search_path_length_list = []
    for test_type in ['direct_attributes', 'relative_position']:
        folder = os.path.join(benchmark_folder, test_type)
        output_folder = None
        if args.visualization:
            output_folder = os.path.join(args.output_path, test_type)
            os.makedirs(output_folder, exist_ok=True)
        # Sorted so the evaluation order (and progress output) is deterministic.
        image_files = sorted(file for file in os.listdir(folder) if '.json' not in file)
        for image_file in tqdm.tqdm(image_files):
            image_path = os.path.join(folder, image_file)
            # splitext strips only the final extension of the file name, so dots in
            # the folder path (e.g. "./vstar_bench") or in the name are preserved.
            image_stem = os.path.splitext(image_file)[0]
            annotation_path = os.path.join(folder, image_stem + '.json')
            with open(annotation_path) as f:
                annotation = json.load(f)
            bboxs = annotation['bbox']
            object_names = annotation['target_object']
            # The search only crops/copies the image, so load it once per file.
            image = Image.open(image_path).convert('RGB')
            smallest_size = max(int(np.ceil(min(image.width, image.height)/args.minimum_size_scale)), args.minimum_size)
            for i, (gt_bbox, object_name) in enumerate(zip(bboxs, object_names)):
                vis_path = os.path.join(output_folder, "{}_{}".format(image_stem, i)) if args.visualization else None
                final_step, path_length, search_successful, all_valid_boxes = visual_search(vsm, image, object_name, target_bbox=gt_bbox, smallest_size=smallest_size, confidence_high=args.confidence_high, confidence_low=args.confidence_low, target_cue_threshold=args.target_cue_threshold, target_cue_threshold_decay=args.target_cue_threshold_decay, target_cue_threshold_minimum=args.target_cue_threshold_minimum, save_path=vis_path, visualize=args.visualization)
                if search_successful:
                    # detection_result is relative to its patch; shift it into full-image coordinates.
                    search_bbox = final_step['detection_result']
                    search_final_patch = final_step['bbox']
                    search_bbox[0] += search_final_patch[0]
                    search_bbox[1] += search_final_patch[1]
                    iou_i = iou(search_bbox, gt_bbox).item()
                    det_acc = 1.0 if iou_i > 0.5 else 0.0
                    acc_list.append(det_acc)
                    search_path_length_list.append(path_length)
                else:
                    acc_list.append(0)
                    search_path_length_list.append(0)
    # Average path length is reported over correctly localised objects only.
    correct_lengths = [length for length, acc in zip(search_path_length_list, acc_list) if acc]
    # np.mean([]) warns and returns nan; report nan directly when nothing qualifies.
    print('Avg search path length:', np.mean(correct_lengths) if correct_lengths else float('nan'))
    print('Top 1 Acc:', np.mean(acc_list) if acc_list else float('nan'))
if __name__ == "__main__":
    # Script entry point: pass the command-line arguments (without the program name) to main.
    main(sys.argv[1:])