|
"""VIP.""" |
|
|
|
import json |
|
import re |
|
|
|
import cv2 |
|
from tqdm import trange |
|
import numpy as np |
|
import vip |
|
|
|
|
|
def make_prompt(description, top_n=3):
  """Build the VLM instruction prompt for one selection round.

  Args:
    description: Natural-language description of the target to locate; it is
      inserted verbatim after ``DESCRIPTION:``.
    top_n: How many numbered circles the model is asked to choose.

  Returns:
    The prompt text with surrounding whitespace stripped. It ends with
    ``IMAGE:`` so the annotated image can follow it in the query sequence, and
    it asks the model to answer with JSON of the form ``{"points": [...]}``.
  """
  # ``{{`` / ``}}`` are escaped braces so the JSON example survives the f-string.
  return f"""
INSTRUCTIONS:
You are tasked to locate an object, region, or point in space in the given annotated image according to a description.
The image is annotated with numbered circles.
Choose the top {top_n} circles that have the most overlap with and/or are closest to what the description is describing in the image.
You are a five-time world champion in this game.
Give a one sentence analysis of why you chose those points.
Provide your answer at the end in a valid JSON of this format:

{{"points": []}}

DESCRIPTION: {description}
IMAGE:
""".strip()
|
|
|
|
|
def _last_json_object_with_key(text, key):
  """Return the last JSON object embedded in ``text`` that contains ``key``.

  Candidate objects are decoded starting at each ``{``, scanning from the end
  of ``text`` backwards. Returns ``None`` if no such object exists.
  """
  decoder = json.JSONDecoder()
  start = text.rfind("{")
  while start != -1:
    try:
      candidate, _ = decoder.raw_decode(text, start)
    except json.JSONDecodeError:
      candidate = None
    if isinstance(candidate, dict) and key in candidate:
      return candidate
    start = text.rfind("{", 0, start)
  return None


def extract_json(response, key):
  """Extract ``key`` from the JSON object embedded in a VLM text response.

  The widest ``{...}`` span of ``response`` is parsed first. If that span is
  not valid JSON (for example because the free-text analysis before the answer
  also contains braces), the last embedded JSON object that contains ``key``
  is used instead.

  Args:
    response: Raw text returned by the VLM.
    key: Top-level key to read from the parsed JSON object.

  Returns:
    The value stored under ``key``.

  Raises:
    KeyError: No JSON object was found, or the parsed object lacks ``key``.
    json.JSONDecodeError: A brace-delimited span exists but no valid JSON
      object containing ``key`` could be decoded from ``response``.
  """
  json_part = re.search(r"\{.*\}", response, re.DOTALL)
  if not json_part:
    print("No JSON data found ******\n", response)
    raise KeyError(key)

  try:
    parsed_json = json.loads(json_part.group())
  except json.JSONDecodeError:
    # The greedy span runs from the first "{" to the last "}", so stray braces
    # in the model's prose break it; retry on individual candidate objects.
    parsed_json = _last_json_object_with_key(response, key)
    if parsed_json is None:
      raise
  return parsed_json[key]
|
|
|
|
|
def vip_perform_selection(prompter, vlm, im, desc, arm_coord, samples, top_n):
  """Run a single VLM selection round over a set of candidate samples.

  Every sample is drawn onto ``im`` via the prompter, the overlay is
  PNG-encoded, and the VLM is asked which ``top_n`` markers best match
  ``desc``.

  Args:
    prompter: Object providing ``add_arrow_overlay_plt``.
    vlm: Object whose ``query`` method receives ``[prompt_text, png_buffer]``.
    im: Image the samples are drawn on.
    desc: Natural-language description of the target.
    arm_coord: Arm origin passed through to the overlay as ``arm_xy``.
    samples: Candidate samples to draw and offer to the VLM.
    top_n: Number of markers the VLM is asked to choose.

  Returns:
    ``(arrow_ids, image_circles_np)``: the ``"points"`` value parsed from the
    VLM reply (``[]`` if parsing failed for any reason) and the overlay image.
  """
  annotated = prompter.add_arrow_overlay_plt(
      image=im, samples=samples, arm_xy=arm_coord
  )
  png_buffer = cv2.imencode(".png", annotated)[1]

  reply = vlm.query([make_prompt(desc, top_n=top_n), png_buffer])

  # Best effort: an unparseable reply simply selects nothing this round.
  try:
    chosen = extract_json(reply, "points")
  except Exception as err:
    print(err)
    chosen = []
  return chosen, annotated
|
|
|
|
|
def _valid_sample_ids(arrow_ids, candidates):
  """Return ``arrow_ids`` cast to int, dropping ids that don't index ``candidates``.

  VLM output is untrusted: ids may arrive as strings, or be out of range or
  negative (which Python would silently wrap around to the end of the list).
  """
  valid_ids = []
  for raw_id in arrow_ids:
    try:
      idx = int(raw_id)
    except (TypeError, ValueError):
      idx = None
    if idx is None or not 0 <= idx < len(candidates):
      print(f"Ignoring invalid point id from VLM: {raw_id!r}")
      continue
    valid_ids.append(idx)
  return valid_ids


def _mark_selected(prompter, image, candidates, arrow_ids, arm_coord):
  """Recolor ``candidates[i]`` for each id in ``arrow_ids`` and overlay them.

  Returns:
    ``(selected_samples, overlay_image)``.
  """
  selected_samples = []
  for selected_id in arrow_ids:
    sample = candidates[selected_id]
    # (255, 0, 0): presumably red in the prompter's RGB drawing convention.
    sample.coord.color = (255, 0, 0)
    selected_samples.append(sample)
  overlay = prompter.add_arrow_overlay_plt(image, selected_samples, arm_coord)
  return selected_samples, overlay


def vip_runner(
    vlm,
    im,
    desc,
    style,
    action_spec,
    n_samples_init=25,
    n_samples_opt=10,
    n_iters=3,
    n_parallel_trials=1,
):
  """Iteratively narrow down the location in ``im`` described by ``desc``.

  Each trial samples candidate actions from a distribution initialised with
  ``action_spec["loc"]`` / ``action_spec["scale"]``. The VLM picks the top 3,
  the distribution is refit to those picks, and this repeats for ``n_iters``
  rounds. On the last round, the VLM picks a single winner among the top 3.
  When there are several trials, the VLM finally chooses among the per-trial
  winners.

  Args:
    vlm: Object with a ``query`` method returning the model's text reply.
    im: Input image; the arm origin is fixed at
      ``(int(im.shape[1] / 2), int(im.shape[0] / 2))``.
    desc: Natural-language description of the target.
    style: Style dict given to the prompter. Its ``"num_samples"`` entry is
      overwritten in place every round (the caller's dict is mutated).
    action_spec: Dict providing at least ``"loc"`` and ``"scale"``.
    n_samples_init: Number of samples requested in the first round.
    n_samples_opt: Number of samples requested in later rounds.
    n_iters: Rounds per trial.
    n_parallel_trials: Independent trials whose winners are compared at the end.

  Yields:
    ``(output_ims, status)`` tuples with all annotated images so far and a
    progress message. The last tuple reports the final selected coordinate, or
    is ``([], "Unable to understand query")`` when no image was produced.

  Returns:
    ``([], "Unable to understand query")`` as the generator's return value,
    kept for callers that consume this generator with ``yield from``.
  """
  prompter = vip.VisualIterativePrompter(
      style, action_spec, vip.SupportedEmbodiments.HF_DEMO
  )

  output_ims = []
  arm_coord = (int(im.shape[1] / 2), int(im.shape[0] / 2))

  new_samples = []  # final winner(s) of every trial
  center_mean = action_spec["loc"]
  for i in range(n_parallel_trials):
    center_mean = action_spec["loc"]
    center_std = action_spec["scale"]
    for itr in trange(n_iters):
      # NOTE(review): the sample count is presumably read by the prompter via
      # the ``style`` dict it was built with — confirm in ``vip``.
      if itr == 0:
        style["num_samples"] = n_samples_init
      else:
        style["num_samples"] = n_samples_opt
      samples = prompter.sample_actions(im, arm_coord, center_mean, center_std)
      arrow_ids, image_circles_np = vip_perform_selection(
          prompter, vlm, im, desc, arm_coord, samples, top_n=3
      )

      # Highlight the VLM's picks on top of the numbered-circle image.
      arrow_ids = _valid_sample_ids(arrow_ids, samples)
      selected_samples, image_circles_marked_np = _mark_selected(
          prompter, image_circles_np, samples, arrow_ids, arm_coord
      )
      output_ims.append(image_circles_marked_np)
      yield output_ims, f"Image generated for parallel sample {i+1}/{n_parallel_trials} iteration {itr+1}/{n_iters}. Still working..."

      if itr == n_iters - 1:
        # Last round: ask the VLM for a single winner among its top picks.
        arrow_ids, _ = vip_perform_selection(
            prompter, vlm, im, desc, arm_coord, selected_samples, top_n=1
        )
        # The re-shown samples presumably keep their original labels, so the
        # returned ids index into ``samples`` (as the original code assumed).
        arrow_ids = _valid_sample_ids(arrow_ids, samples)
        selected_samples, image_circles_marked_np = _mark_selected(
            prompter, im, samples, arrow_ids, arm_coord
        )
        output_ims.append(image_circles_marked_np)
        new_samples += selected_samples
        yield output_ims, f"Image generated for parallel sample {i+1}/{n_parallel_trials} last iteration. Still working..."
      center_mean, center_std = prompter.fit(arrow_ids, samples)

  if n_parallel_trials > 1:
    # Winners from different trials can share labels; relabel them 0..N-1 so
    # the VLM's answer indexes ``new_samples`` unambiguously.
    for sample_id, sample in enumerate(new_samples):
      sample.label = str(sample_id)
    arrow_ids, _ = vip_perform_selection(
        prompter, vlm, im, desc, arm_coord, new_samples, top_n=1
    )
    arrow_ids = _valid_sample_ids(arrow_ids, new_samples)
    _, image_circles_marked_np = _mark_selected(
        prompter, im, new_samples, arrow_ids, arm_coord
    )
    output_ims.append(image_circles_marked_np)
    center_mean, _ = prompter.fit(arrow_ids, new_samples)

  if output_ims:
    yield (
        output_ims,
        (
            "Final selected coordinate:"
            f" {np.round(prompter.action_to_coord(center_mean, im, arm_coord).xy, decimals=0)}"
        ),
    )
  else:
    # A generator's return value is invisible to a consumer that just
    # iterates it, so surface the failure as a yielded update as well.
    yield [], "Unable to understand query"
  return [], "Unable to understand query"
|
|