# pivot-prompt-demo / vip_utils.py
# pivot-prompt's picture
# Add application file
# 660daa9
# raw
# history blame contribute delete
# No virus
# 3.82 kB
"""Utils for visual iterative prompting.
A number of utility functions for VIP.
"""
import re
import matplotlib.pyplot as plt
import numpy as np
import scipy.spatial.distance as distance
def min_dist(coord, coords):
    """Return the smallest Euclidean distance from `coord` to any of `coords`.

    Args:
      coord: object exposing an `xy` attribute holding its position.
      coords: collection of objects that each expose an `xy` attribute.

    Returns:
      The minimum distance, or `np.inf` when `coords` is empty (falsy).
    """
    if coords:
        target = np.asarray(coord.xy)
        others = np.asarray([other.xy for other in coords])
        # One norm per row of `others`, then keep the closest.
        return np.linalg.norm(others - target, axis=-1).min()
    return np.inf
def coord_outside_image(coord, image, radius):
    """Check whether a coordinate lies within `2 * radius` of the image border.

    Args:
      coord: object whose `xy` attribute unpacks to `(x, y)`.
      image: object with a `shape` of the form `(height, width, ...)`. Only
        the first two entries are read, so both multi-channel and
        single-channel images are accepted.
      radius: the margin kept free along every edge is twice this value.

    Returns:
      A truthy value if `x` or `y` falls inside the border margin (or beyond
      the image), a falsy value otherwise.
    """
    # Slice instead of a fixed 3-way unpack so a (height, width) shape works.
    height, image_width = image.shape[:2]
    margin = 2 * radius
    x, y = coord.xy
    x_outside = x > image_width - margin or x < margin
    y_outside = y > height - margin or y < margin
    return x_outside or y_outside
def is_invalid_coord(coord, coords, radius, image):
    """Tell whether `coord` should be rejected.

    A coordinate is rejected when its distance to the nearest entry of
    `coords` is below `1.5 * radius`, or when `coord_outside_image` reports
    it as outside the usable image area.

    Args:
      coord: candidate coordinate (object with an `xy` attribute).
      coords: previously accepted coordinates.
      radius: radius used for both the spacing and the border checks.
      image: image passed through to `coord_outside_image`.

    Returns:
      A truthy value if the coordinate is invalid, a falsy value otherwise.
    """
    min_separation = 1.5 * radius
    return (
        min_dist(coord, coords) < min_separation
        or coord_outside_image(coord, image, radius)
    )
def angle_mag_2_x_y(angle, mag, arm_coord, is_circle=False, radius=40):
    """Convert a polar offset from `arm_coord` into an `(x, y)` position.

    Args:
      angle: direction of the offset, in radians (passed to `np.cos`/`np.sin`).
      mag: length of the offset; may be negative.
      arm_coord: `(x, y)` origin the offset is applied to.
      is_circle: when true, the point is pushed a further `radius` along the
        same direction, following the sign of `mag`.
      radius: extra distance used when `is_circle` is set.

    Returns:
      The resulting `(x, y)` tuple. Each offset term is truncated with `int`
      before being added.
    """
    base_x, base_y = arm_coord
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    new_x = base_x + int(cos_a * mag)
    new_y = base_y + int(sin_a * mag)
    if not is_circle:
        return new_x, new_y
    direction = np.sign(mag)
    # Truncate the extra term on its own so rounding matches a two-step add.
    new_x += int(cos_a * radius * direction)
    new_y += int(sin_a * radius * direction)
    return new_x, new_y
def coord_to_text_coord(coord, arm_coord, radius):
    """Compute a position `radius` beyond `coord`, moving away from `arm_coord`.

    Args:
      coord: object whose `xy` attribute holds its position.
      arm_coord: reference position the offset points away from.
      radius: distance to move past `coord` along the `arm_coord -> coord`
        direction.

    Returns:
      An `(x, y)` tuple of ints, or `arm_coord` itself (unchanged) when
      `coord` coincides with `arm_coord` and no direction is defined.
    """
    offset = np.asarray(coord.xy) - arm_coord
    length = np.linalg.norm(offset)
    if length == 0:
        return arm_coord
    text_x = int(coord.xy[0] + radius * offset[0] / length)
    text_y = int(coord.xy[1] + radius * offset[1] / length)
    return text_x, text_y
def parse_response(response, answer_key='Arrow: ['):
    """Extract the non-negative integers contained in a text response.

    If `answer_key` occurs in `response`, only the text following its last
    occurrence, up to the next `]`, is scanned. Otherwise every run of digits
    in the whole response is collected. A message naming the chosen path is
    printed in both cases.

    Args:
      response: text to parse.
      answer_key: marker that introduces the bracketed answer list.

    Returns:
      A list of ints in order of appearance. Only digit runs are matched, so
      any sign in front of a number is ignored.
    """
    if answer_key in response:
        print('parse_response from answer_key')
        after_key = response.split(answer_key)[-1]
        text_to_scan = after_key.partition(']')[0]
    else:
        print('parse_response for all ints')
        text_to_scan = response
    return [int(token) for token in re.findall(r'\d+', text_to_scan)]
def compute_errors(action, true_action, verbose=False):
    """Compute error metrics between a predicted action and the true action.

    The metrics are the L2 distance and cosine similarity over the full
    vectors, the same two metrics restricted to the last two components
    (reported under the `xy` names), and the absolute difference of the
    first component (reported as `z_error`).

    Args:
      action: predicted action; must support array arithmetic and slicing.
      true_action: reference action with the same layout as `action`.
      verbose: when true, print both actions and every metric.

    Returns:
      A dict with keys 'l2', 'cos_sim', 'l2_xy_error', 'cos_xy_sim' and
      'z_error'.
    """
    pred_xy, true_xy = action[-2:], true_action[-2:]
    errors = {
        'l2': np.linalg.norm(action - true_action),
        'cos_sim': 1 - distance.cosine(action, true_action),
        'l2_xy_error': np.linalg.norm(pred_xy - true_xy),
        'cos_xy_sim': 1 - distance.cosine(pred_xy, true_xy),
        'z_error': np.abs(action[0] - true_action[0]),
    }
    if verbose:
        print('action: \t', [f'{a:.3f}' for a in action])
        print('true_action \t', [f'{a:.3f}' for a in true_action])
        # (label, key) pairs in the order the metrics are reported.
        report = (
            ('l2: \t\t', 'l2'),
            ('l2_xy_error: \t', 'l2_xy_error'),
            ('cos_sim: \t', 'cos_sim'),
            ('cos_xy_sim: \t', 'cos_xy_sim'),
            ('z_error: \t', 'z_error'),
        )
        for label, key in report:
            print(f'{label}{errors[key]:.3f}')
    return errors
def plot_errors(all_errors, error_types=None):
    """Plot the mean of each error type per iteration, one subplot per type.

    Args:
      all_errors: iterable with one entry per call; each entry is indexed by
        iteration and yields a mapping from error-type name to value.
      error_types: names of the error types to plot. Defaults to 'l2',
        'l2_xy_error', 'z_error', 'cos_sim' and 'cos_xy_sim'. Subplots are
        laid out on a fixed 2x3 grid.
    """
    if error_types is None:
        error_types = ['l2', 'l2_xy_error', 'z_error', 'cos_sim', 'cos_xy_sim']
    _, axs = plt.subplots(2, 3, figsize=(15, 8))
    for idx, error_type in enumerate(error_types):
        # Gather this error type's values from every call, keyed by iteration.
        errors_per_iter = {}
        for run_errors in all_errors:
            for itr in run_errors:
                errors_per_iter.setdefault(itr, []).append(
                    run_errors[itr][error_type]
                )
        mean_per_iter = [np.mean(vals) for vals in errors_per_iter.values()]
        # divmod gives the (row, column) of the subplot on the 3-wide grid.
        ax = axs[divmod(idx, 3)]
        ax.plot(errors_per_iter.keys(), mean_per_iter)
        ax.set_title(error_type)
    plt.show()