"""Visual Iterative Prompting functions.

Code to implement visual iterative prompting, an approach for querying VLMs.
"""
import copy | |
import dataclasses | |
import enum | |
import io | |
from typing import Optional, Tuple | |
import cv2 | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import scipy.stats | |
import vip_utils | |
class SupportedEmbodiments(str, enum.Enum):
    """Names of the embodiments that VIP can be run with.

    Members are also `str` instances, so they compare equal to their raw
    string values.
    """

    # Hugging Face demo embodiment.
    HF_DEMO = 'hf_demo'
@dataclasses.dataclass
class Coordinate:
    """Coordinate with necessary information for visualizing annotation.

    Declared as a dataclass so that it can be constructed with keyword
    arguments, e.g. `Coordinate(xy=(x, y), color=c, radius=r)`.
    """

    # 2D image coordinates (x, y) for the target annotation.
    xy: Tuple[int, int]
    # Color and style of the coord; both are optional and default to None.
    color: Optional[float] = None
    radius: Optional[int] = None
@dataclasses.dataclass
class Sample:
    """Single Sample mapping actions to Coordinates.

    Declared as a dataclass so that it can be constructed with keyword
    arguments (action=..., coord=..., text_coord=..., label=...).
    """

    # 2D or 3D action.
    action: np.ndarray
    # Coordinates for the main annotation.
    coord: Coordinate
    # Coordinates for the text label.
    text_coord: Coordinate
    # Label to display in the text label.
    label: str
class VisualIterativePrompter: | |
"""Visual Iterative Prompting class.""" | |
def __init__(self, style, action_spec, embodiment):
    """Stores the prompting configuration on the instance.

    Args:
      style: visual style settings; stored as-is and read by key elsewhere
        in the class (e.g. 'radius', 'num_samples', 'rgb_scale').
      action_spec: action distribution settings; stored as-is and read by key
        elsewhere in the class (e.g. 'loc', 'scale', 'min', 'max').
      embodiment: embodiment identifier; stored as-is.
    """
    self.style = style
    self.action_spec = action_spec
    self.embodiment = embodiment
    # Unknown until add_arrow_overlay_plt measures it on its first call.
    self.fig_scale_size = None
# image preparer | |
# robot_to_image_canonical_coords | |
def action_to_coord(self, action, image, arm_xy, do_project=False):
    """Maps a candidate action to its image coordinate.

    Delegates to navigation_action_to_coord, using the arm position as the
    center the action is drawn from.

    Args:
      action: candidate action to convert.
      image: image the coordinate refers to.
      arm_xy: x, y in image space of the arm; forwarded as center_xy.
      do_project: forwarded unchanged to navigation_action_to_coord.

    Returns:
      Whatever navigation_action_to_coord returns for these arguments.
    """
    convert = self.navigation_action_to_coord
    return convert(
        action=action,
        image=image,
        center_xy=arm_xy,
        do_project=do_project,
    )
def navigation_action_to_coord(
    self, action, image, center_xy, do_project=False
):
    """Converts a ZYX action to an image coordinate.

    Each action component is normalized as (action - loc) / (2 * scale) using
    action_spec['loc'] and action_spec['scale'], converted to pixels with
    action_spec['action_to_coord'] and subtracted from center_xy. The z
    component scales the x/y offset through a perspective factor derived from
    style['focal_offset'].

    Args:
      action: z, y, x action in robot action space (indices 0, 1, 2).
      image: image; its shape is read when projecting, and it is passed to
        vip_utils.coord_outside_image.
      center_xy: x, y in image space the action is drawn from.
      do_project: whether or not to project actions sampled outside the image
        to the edge of the image.

    Returns:
      Coordinate with integer image x, y, arrow color, and circle radius.
    """
    spec = self.action_spec
    radius = self.style['radius']
    focal_offset = self.style['focal_offset']

    def normalized(d):
        return (action[d] - spec['loc'][d]) / (2 * spec['scale'][d])

    # A zero z scale means there is no z dimension: skip it (and the division
    # by its zero scale) entirely.
    norm_action_z = 0 if spec['scale'][0] == 0 else normalized(0)
    norm_action_y = normalized(1)
    norm_action_x = normalized(2)

    focal_length = np.max([
        0.2,  # positive focal lengths only
        focal_offset / (focal_offset + norm_action_z),
    ])
    # Pixel offsets before the perspective factor is applied.
    offset_x = spec['action_to_coord'] * norm_action_x
    offset_y = spec['action_to_coord'] * norm_action_y
    image_x = center_xy[0] - offset_x * focal_length
    image_y = center_xy[1] - offset_y * focal_length

    if (
        vip_utils.coord_outside_image(
            Coordinate(xy=(image_x, image_y)), image, radius
        )
        and do_project
    ):
        # Project the arrow to the edge of the image if too large.
        # NOTE: the perspective factor is not used in this branch; the offset
        # is rescaled directly.
        height, width, _ = image.shape
        max_x = (
            width - center_xy[0] - 2 * radius
            if norm_action_x < 0
            else center_xy[0] - 2 * radius
        )
        max_y = (
            height - center_xy[1] - 2 * radius
            if norm_action_y < 0
            else center_xy[1] - 2 * radius
        )
        # An axis with no offset cannot leave the image, so it imposes no
        # limit. Skipping it also avoids dividing by zero.
        ratios = [
            abs(limit / offset)
            for limit, offset in ((max_x, offset_x), (max_y, offset_y))
            if offset != 0
        ]
        rescale_ratio = min(ratios) if ratios else 0
        image_x = center_xy[0] - offset_x * rescale_ratio
        image_y = center_xy[1] - offset_y * rescale_ratio

    return Coordinate(
        xy=(int(image_x), int(image_y)),
        color=0.1 * self.style['rgb_scale'],
        radius=int(radius),
    )
def sample_actions(
    self, image, arm_xy, loc, scale, true_action=None, max_itrs=1000
):
    """Sample actions from distribution.

    Candidates are drawn from a normal distribution, clipped to
    action_spec['min'] / action_spec['max'], and redrawn (with a slowly
    widening scale) while their annotation is invalid or outside the image.

    Args:
      image: image
      arm_xy: x, y in image space of arm
      loc: action distribution mean to sample from
      scale: action distribution standard deviation to sample from (passed
        as `scale` to np.random.normal)
      true_action: action taken in demonstration if available
      max_itrs: number of tries to get a valid sample

    Returns:
      samples: list of Samples with associated actions, coords, text_coords
        and labels. A Sample's label is its position in the list; when
        true_action is given it is always the first entry (label '0').
    """
    image = copy.deepcopy(image)
    num_samples = self.style['num_samples']
    radius = self.style['radius']
    action_min = self.action_spec['min']
    action_max = self.action_spec['max']

    actions = []
    coords = []
    text_coords = []

    # Keep track of oracle action if available.
    true_label = None
    if true_action is not None:
        coord = self.action_to_coord(true_action, image, arm_xy)
        actions.append(true_action)
        coords.append(coord)
        # Wrapped in a Coordinate, exactly like the sampled candidates below,
        # so that every Sample.text_coord exposes `.xy`.
        text_coords.append(
            Coordinate(
                xy=vip_utils.coord_to_text_coord(coord, arm_xy, coord.radius)
            )
        )
        # One random index is skipped below so that the total number of
        # samples (true action included) stays at num_samples.
        true_label = np.random.randint(num_samples)

    # Generate all action samples.
    for i in range(num_samples):
        if i == true_label:
            continue
        itrs = 0
        # Generate action scaled appropriately.
        action = np.clip(
            np.random.normal(loc, scale), action_min, action_max
        )
        # Convert sampled action to image coordinates.
        coord = self.action_to_coord(action, image, arm_xy)

        # Resample action if it results in invalid image annotation.
        # Float copy: the in-place widening below would fail on an integer
        # array, and the caller's `scale` must not be modified.
        adjusted_scale = np.array(scale, dtype=float)
        while (
            vip_utils.is_invalid_coord(coord, coords, radius * 1.5, image)
            or vip_utils.coord_outside_image(coord, image, radius)
        ) and itrs < max_itrs:
            action = np.clip(
                np.random.normal(loc, adjusted_scale), action_min, action_max
            )
            coord = self.action_to_coord(action, image, arm_xy)
            itrs += 1
            # Increase sampling range slightly if not finding a good sample.
            adjusted_scale *= 1.1
        if itrs == max_itrs:
            # If the final iteration results in invalid annotation, just clip
            # to edge of image.
            coord = self.action_to_coord(
                action, image, arm_xy, do_project=True
            )

        # Compute image coordinates of text labels.
        text_coord = Coordinate(
            xy=vip_utils.coord_to_text_coord(coord, arm_xy, coord.radius)
        )
        actions.append(action)
        coords.append(coord)
        text_coords.append(text_coord)

    return [
        Sample(action=action, coord=coord, text_coord=text_coord, label=str(i))
        for i, (action, coord, text_coord) in enumerate(
            zip(actions, coords, text_coords)
        )
    ]
def add_arrow_overlay_plt(self, image, samples, arm_xy):
    """Draws arrows, label circles and label text for the samples.

    Arrows and circles are drawn with OpenCV on copies of the image and
    alpha-blended in; the text labels are added by plotting the result with
    matplotlib and decoding the saved PNG. On the first call a label-free
    render is made to measure how the saved figure width relates to the image
    width; the ratio is cached in self.fig_scale_size.

    Args:
      image: image
      samples: Samples to visualize.
      arm_xy: x, y image coordinates for EEF center.

    Returns:
      image: image with visual prompts, resized to the input width and height
        and passed through cv2.COLOR_RGB2BGR.
    """
    style = self.style
    img_height, img_width, _ = image.shape
    fill_color = (style['rgb_scale'],) * 3

    def render_png(canvas, figsize, labelled=()):
        """Plots canvas (plus labels, if any) and decodes the saved PNG."""
        plt.subplots(1, figsize=figsize)
        plt.imshow(canvas, cmap='binary')
        for item in labelled:
            plt.text(
                item.text_coord.xy[0],
                item.text_coord.xy[1],
                item.label,
                ha='center',
                va='center',
                color='k',
                fontsize=style['fontsize'],
            )
        plt.axis('off')
        plt.gcf().tight_layout(pad=0)
        png = io.BytesIO()
        plt.savefig(png, format='png')
        plt.close()
        return cv2.imdecode(
            np.frombuffer(png.getvalue(), dtype=np.uint8), cv2.IMREAD_COLOR
        )

    # Transparent arrows from the arm to every candidate.
    arrow_layer = image.copy()
    for sample in samples:
        cv2.arrowedLine(
            arrow_layer,
            arm_xy,
            sample.coord.xy,
            sample.coord.color,
            style['thickness'],
        )
    image = cv2.addWeighted(
        arrow_layer, style['arrow_alpha'], image, 1 - style['arrow_alpha'], 0
    )

    # Transparent circles at the label positions: outline, then filled disc.
    circle_layer = image.copy()
    for sample in samples:
        center = sample.text_coord.xy
        cv2.circle(
            circle_layer,
            center,
            sample.coord.radius,
            sample.coord.color,
            style['thickness'] + 1,
        )
        cv2.circle(circle_layer, center, sample.coord.radius, fill_color, -1)
    image = cv2.addWeighted(
        circle_layer, style['circle_alpha'], image, 1 - style['circle_alpha'], 0
    )

    dpi = plt.rcParams['figure.dpi']
    if self.fig_scale_size is None:
        # Test saving a figure to decide size for text figure.
        probe = render_png(image, (img_width / dpi, img_height / dpi))
        self.fig_scale_size = img_width / probe.shape[1]

    # Add text to figure and compile image.
    scale = self.fig_scale_size
    labelled_image = render_png(
        image,
        (scale * img_width / dpi, scale * img_height / dpi),
        samples,
    )
    resized = cv2.resize(labelled_image, (img_width, img_height))
    return cv2.cvtColor(np.array(resized), cv2.COLOR_RGB2BGR)
def fit(self, values, samples):
    """Fit a loc and scale to selected actions.

    Selected values are matched against sample labels by their string form.
    Values that match no sample are reported and skipped instead of raising.

    Args:
      values: list of selected labels
      samples: list of all Samples

    Returns:
      loc: mean of selected distribution. With no usable selection this is
        action_spec['loc']; with exactly one it is that sample's action.
      scale: standard deviation of selected distribution, never below
        action_spec['min_scale']. With no usable selection this is
        action_spec['scale']; with exactly one it is action_spec['min_scale'].
    """
    spec = self.action_spec

    # First sample wins for a repeated label, i.e. a first-match lookup.
    action_by_label = {}
    for sample in samples:
        action_by_label.setdefault(sample.label, sample.action)

    selected_actions = []
    for value in values or ():
        key = str(value)
        if key in action_by_label:
            selected_actions.append(action_by_label[key])
        else:
            print('ignoring unknown label', value)

    if not selected_actions:  # revert to initial distribution
        print('GPT failed to return integer arrows')
        loc = spec['loc']
        scale = spec['scale']
    elif len(selected_actions) == 1:
        # Single response, add a distribution over it.
        action = selected_actions[0]
        print('action', action)
        loc = action
        scale = spec['min_scale']
    else:  # fit distribution
        print('selected_actions', selected_actions)
        # Independent normal fit per action dimension: (mean, std) pairs.
        loc_scale = [
            scipy.stats.norm.fit([action[d] for action in selected_actions])
            for d in range(3)
        ]
        loc = [mean for mean, _ in loc_scale]
        scale = np.clip(
            [std for _, std in loc_scale], spec['min_scale'], None
        )
    print('loc', loc, '\nscale', scale)
    return loc, scale