pivot-iterative-visual-optimization committed on
Commit 5c80958 · verified · 1 Parent(s): afdb372

Upload 6 files

Files changed (6)
  1. app.py +182 -0
  2. requirements.txt +6 -0
  3. vip.py +462 -0
  4. vip_runner.py +163 -0
  5. vip_utils.py +130 -0
  6. vlms.py +33 -0
app.py ADDED
@@ -0,0 +1,182 @@
1
+ """Visual Iterative Prompting Demo."""
2
+
3
+ import gradio as gr
4
+ import numpy as np
5
+ from vip_runner import vip_runner
6
+ from vlms import GPT4V
7
+
8
+ # Adjust radius of annotations based on size of the image
9
+ radius_per_pixel = 0.05
10
+
11
+
12
+ def run_vip(
13
+ im,
14
+ query,
15
+ n_samples_init,
16
+ n_samples_opt,
17
+ n_iters,
18
+ n_recursion,
19
+ openai_api_key,
20
+ progress=gr.Progress(track_tqdm=True),
21
+ ):
22
+
23
+ if not openai_api_key:
24
+ return [], 'Must provide OpenAI API Key'
25
+ if im is None:
26
+ return [], 'Must specify image'
27
+ if not query:
28
+ return [], 'Must specify description'
29
+
30
+ img_size = np.min(im.shape[:2])
31
+ print(int(img_size * radius_per_pixel))
32
+ # add some action spec
33
+ style = {
34
+ 'num_samples': 12,
35
+ 'circle_alpha': 0.6,
36
+ 'alpha': 0.8,
37
+ 'arrow_alpha': 0.0,
38
+ 'radius': int(img_size * radius_per_pixel),
39
+ 'thickness': 2,
40
+ 'fontsize': int(img_size * radius_per_pixel),
41
+ 'rgb_scale': 255,
42
+ 'focal_offset': 1, # camera distance / std of action in z
43
+ }
44
+
45
+ action_spec = {
46
+ 'loc': [0, 0, 0],
47
+ 'scale': [0.0, 100, 100],
48
+ 'min_scale': [0.0, 30, 30],
49
+ 'min': [0, -300.0, -300],
50
+ 'max': [0, 300, 300],
51
+ 'action_to_coord': 250,
52
+ 'robot': 'meta',
53
+ }
54
+
55
+ vlm = GPT4V(openai_api_key=openai_api_key)
56
+ ims, center, _ = vip_runner(
57
+ vlm,
58
+ im,
59
+ query,
60
+ style,
61
+ action_spec,
62
+ n_samples_init=n_samples_init,
63
+ n_samples_opt=n_samples_opt,
64
+ n_iters=n_iters,
65
+ recursion_level=n_recursion,
66
+ )
67
+ return ims, f'Final selected coordinate: {np.round(center, decimals=0)}'
68
+
69
+
70
+ examples = [
71
+ {
72
+ 'im_path': 'ims/aloha.png',
73
+ 'desc': 'a point between the fork and the cup',
74
+ },
75
+ {
76
+ 'im_path': 'ims/robot.png',
77
+ 'desc': 'the toy in the middle of the table',
78
+ },
79
+ {
80
+ 'im_path': 'ims/parking.jpg',
81
+ 'desc': 'a place to park if I am handicapped',
82
+ },
83
+ {
84
+ 'im_path': 'ims/tools.png',
85
+ 'desc': 'what should I use to pull a nail',
86
+ },
87
+ ]
88
+
89
+
90
+ with gr.Blocks() as demo:
91
+ gr.Markdown("""
92
+ # Visual Iterative Prompting Demo
93
+ The demo below showcases the Visual Iterative Prompting (VIP) algorithm.
94
+ Given an image and a description of an object or region,
95
+ VIP leverages a Vision-Language Model (VLM) to iteratively search for the point in the image that best corresponds to the description.
96
+ This is done through visual prompting, where instead of reasoning with text, the VLM reasons over images annotated with sampled points,
97
+ in order to pick the best points.
98
+ In each iteration, we take the points previously selected by the VLM, resample new points around their mean, and repeat the process.
99
+
100
+ To get started, you can use the provided example image and query pairs, or
101
+ upload your own images.
102
+ This demo uses GPT-4V, so it requires an OpenAI API key.
103
+
104
+ To use the provided example images, you can right click on the image -> copy image, then click the clipboard icon in the Input Image box.
105
+
106
+ Hyperparameters to set:
107
+ * N Samples for Initialization - how many initial points are sampled for the first VIP iteration.
108
+ * N Samples for Optimization - how many points are sampled for subsequent iterations.
109
+ * N Iterations - how many optimization iterations to perform.
110
+ * N Ensemble Recursions - how many recursive ensemble rounds to run (0 disables recursion).
111
+
112
+ Note that each iteration takes roughly 10 seconds, and each additional ensemble recursion multiplies the total number of iterations.
113
+
114
+ After VIP finishes, the image gallery below will visualize VIP results throughout all the iterations.
115
+ There are two images for each iteration - the first shows all the sampled points, and the second shows the ones VIP picked.
116
+ The Info textbox will show the final selected pixel coordinate that VIP converged to.
117
+ """.strip())
118
+
119
+ gr.Markdown(
120
+ '## Example Images and Queries\n Drag images into the image box below'
121
+ )
122
+ with gr.Row(equal_height=True):
123
+ for example in examples:
124
+ gr.Image(value=example['im_path'], label=example['desc'])
125
+
126
+ gr.Markdown('## New Query')
127
+ with gr.Row():
128
+ with gr.Column():
129
+ inp_im = gr.Image(label='Input Image', type='numpy', show_label=True)
130
+ inp_query = gr.Textbox(label='Description', lines=1)
131
+
132
+ with gr.Column():
133
+ inp_openai_api_key = gr.Textbox(
134
+ label='OpenAI API Key (not saved)', lines=1
135
+ )
136
+ with gr.Group():
137
+ inp_n_samples_init = gr.Slider(
138
+ label='N Samples for Initialization',
139
+ minimum=10,
140
+ maximum=40,
141
+ value=25,
142
+ step=1,
143
+ )
144
+ inp_n_samples_opt = gr.Slider(
145
+ label='N Samples for Optimization',
146
+ minimum=3,
147
+ maximum=20,
148
+ value=10,
149
+ step=1,
150
+ )
151
+ inp_n_iters = gr.Slider(
152
+ label='N Iterations', minimum=1, maximum=5, value=3, step=1
153
+ )
154
+ inp_n_recursions = gr.Slider(
155
+ label='N Ensemble Recursions', minimum=0, maximum=3, value=0, step=1
156
+ )
157
+ btn_run = gr.Button('Run')
158
+
159
+ with gr.Group():
160
+ out_ims = gr.Gallery(
161
+ label='Images with Sampled and Chosen Points',
162
+ columns=4,
163
+ rows=1,
164
+ interactive=False,
165
+ )
166
+ out_info = gr.Textbox(label='Info', lines=1)
167
+
168
+ btn_run.click(
169
+ run_vip,
170
+ inputs=[
171
+ inp_im,
172
+ inp_query,
173
+ inp_n_samples_init,
174
+ inp_n_samples_opt,
175
+ inp_n_iters,
176
+ inp_n_recursions,
177
+ inp_openai_api_key,
178
+ ],
179
+ outputs=[out_ims, out_info],
180
+ )
181
+
182
+ demo.launch()
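For reference, a rough headless sketch of the same pipeline without the Gradio UI (the image path and API key are placeholders; `style` and `action_spec` mirror `run_vip` above):

```python
import cv2
import numpy as np
from vip_runner import vip_runner
from vlms import GPT4V

im = cv2.cvtColor(cv2.imread('ims/robot.png'), cv2.COLOR_BGR2RGB)  # placeholder image path
img_size = np.min(im.shape[:2])
style = {
    'num_samples': 12, 'circle_alpha': 0.6, 'alpha': 0.8, 'arrow_alpha': 0.0,
    'radius': int(img_size * 0.05), 'thickness': 2, 'fontsize': int(img_size * 0.05),
    'rgb_scale': 255, 'focal_offset': 1,
}
action_spec = {
    'loc': [0, 0, 0], 'scale': [0.0, 100, 100], 'min_scale': [0.0, 30, 30],
    'min': [0, -300.0, -300], 'max': [0, 300, 300], 'action_to_coord': 250, 'robot': 'meta',
}
vlm = GPT4V(openai_api_key='sk-...')  # placeholder key
ims, center, _ = vip_runner(vlm, im, 'the toy in the middle of the table',
                            style, action_spec, n_samples_init=25, n_samples_opt=10, n_iters=3)
print('Final selected coordinate:', center)
```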
requirements.txt ADDED
@@ -0,0 +1,6 @@
1
+ numpy
2
+ matplotlib
3
+ opencv-python
4
+ openai
5
+ gradio
6
+ scipy
vip.py ADDED
@@ -0,0 +1,462 @@
1
+ # pylint: disable=line-too-long
2
+ """Visual Iterative Prompting functions.
3
+
4
+ Copied from experimental/users/ichter/vip/vip.py
5
+
6
+ Code to implement visual iterative prompting, an approach for querying VLMs.
7
+ See go/visual-iterative-prompting for more information.
8
+
9
+ These are used within Colabs such as:
10
+ *
11
+ https://colab.corp.google.com/drive/1GnO-1urDCETWo3M3PpQKQ8TqT1Ql_jiS#scrollTo=5dUSoiz6Hplv
12
+ *
13
+ https://colab.corp.google.com/drive/14AYsa4W68NnsaREFTUX7lTkSxpD5eHCO#scrollTo=qA2A_oTcGTzN
14
+ *
15
+ https://colab.corp.google.com/drive/11H-WtHNYzBkr_lQpaa4ASeYy0HD29EXe#scrollTo=HapF0UIxdJM6
16
+ """
17
+
18
+ import copy
19
+ import dataclasses
20
+ import enum
21
+ import io
22
+ from typing import Optional, Tuple
23
+ import cv2
24
+ import matplotlib.pyplot as plt
25
+ import numpy as np
26
+ import scipy.stats
27
+ import vip_utils
28
+
29
+
30
+ @enum.unique
31
+ class SupportedEmbodiments(str, enum.Enum):
32
+ """Embodiments supported by VIP."""
33
+
34
+ META_MANIPULATION = 'meta_manipulation'
35
+ ALOHA_MANIPULATION = 'aloha_manipulation'
36
+ META_NAVIGATION = 'meta_navigation'
37
+
38
+
39
+ @dataclasses.dataclass()
40
+ class Coordinate:
41
+ """Coordinate with necessary information for visualizing annotation."""
42
+
43
+ # 2D image coordinates for the target annotation
44
+ xy: Tuple[int, int]
45
+ # Color and style of the coord.
46
+ color: Optional[float] = None
47
+ radius: Optional[int] = None
48
+
49
+
50
+ @dataclasses.dataclass()
51
+ class Sample:
52
+ """Single Sample mapping actions to Coordinates."""
53
+
54
+ # 2D or 3D action
55
+ action: np.ndarray
56
+ # Coordinates for the main annotation
57
+ coord: Coordinate
58
+ # Coordinates for the text label
59
+ text_coord: Coordinate
60
+ # Label to display in the text label
61
+ label: str
62
+
63
+
64
+ class VisualIterativePrompter:
65
+ """Visual Iterative Prompting class."""
66
+
67
+ def __init__(self, style, action_spec, embodiment):
68
+ self.embodiment = embodiment
69
+ self.style = style
70
+ self.action_spec = action_spec
71
+ self.fig_scale_size = None
72
+ # image preparer
73
+ # robot_to_image_canonical_coords
74
+
75
+ def action_to_coord(self, action, image, arm_xy, do_project=False):
76
+ """Converts candidate action to image coordinate."""
77
+ if (self.embodiment == SupportedEmbodiments.META_MANIPULATION or
78
+ self.embodiment == SupportedEmbodiments.ALOHA_MANIPULATION):
79
+ return self.manipulation_action_to_coord(
80
+ action=action, image=image, arm_xy=arm_xy, do_project=do_project
81
+ )
82
+ elif self.embodiment == SupportedEmbodiments.META_NAVIGATION:
83
+ return self.navigation_action_to_coord(
84
+ action=action, image=image, center_xy=arm_xy, do_project=do_project
85
+ )
86
+ else:
87
+ raise NotImplementedError('Embodiment not supported.')
88
+
89
+ def manipulation_action_to_coord(
90
+ self, action, image, arm_xy, do_project=False
91
+ ):
92
+ """Converts a ZXY or XY action to an image coordinate.
93
+
94
+ Conversion is done based on style['focal_offset'] and action_spec['scale'].
95
+
96
+ Args:
97
+ action: z, y, x action in robot action space
98
+ image: image
99
+ arm_xy: x, y in image space
100
+ do_project: whether or not to project actions sampled outside the image to
101
+ the edge of the image
102
+
103
+ Returns:
104
+ Dict coordinate with image x, y, arrow color, and circle radius.
105
+ """
106
+ # TODO(tedxiao): Refactor into common utility fns, add embodiment specific
107
+ # logic.
108
+ if self.action_spec['scale'][0] == 0: # no z dimension
109
+ norm_action = [(action[d] - self.action_spec['loc'][d]) /
110
+ (2 * self.action_spec['scale'][d]) for d in range(1, 3)]
111
+ norm_action_y, norm_action_x = norm_action
112
+ norm_action_z = 0
113
+ else:
114
+ norm_action = [(action[d] - self.action_spec['loc'][d]) /
115
+ (2 * self.action_spec['scale'][d]) for d in range(3)]
116
+ norm_action_z, norm_action_y, norm_action_x = norm_action
117
+ focal_length = np.max(
118
+ [0.2, # positive focal lengths only
119
+ self.style['focal_offset'] / (self.style['focal_offset'] + norm_action_z)])
120
+ image_x = arm_xy[0] - (
121
+ self.action_spec['action_to_coord'] * norm_action_x * focal_length
122
+ )
123
+ image_y = arm_xy[1] - (
124
+ self.action_spec['action_to_coord'] * norm_action_y * focal_length
125
+ )
126
+ if vip_utils.coord_outside_image(
127
+ coord=Coordinate(xy=(int(image_x), int(image_y))),
128
+ image=image,
129
+ radius=self.style['radius']) and do_project:
130
+ # project the arrow to the edge of the image if too large
131
+ height, width, _ = image.shape
132
+ max_x = (
133
+ width - arm_xy[0] - 2 * self.style['radius']
134
+ if norm_action_x < 0
135
+ else arm_xy[0] - 2 * self.style['radius']
136
+ )
137
+ max_y = (
138
+ height - arm_xy[1] - 2 * self.style['radius']
139
+ if norm_action_y < 0
140
+ else arm_xy[1] - 2 * self.style['radius']
141
+ )
142
+ rescale_ratio = min(np.abs([
143
+ max_x / (self.action_spec['action_to_coord'] * norm_action_x),
144
+ max_y / (self.action_spec['action_to_coord'] * norm_action_y)]))
145
+ image_x = (
146
+ arm_xy[0]
147
+ - self.action_spec['action_to_coord'] * norm_action_x * rescale_ratio
148
+ )
149
+ image_y = (
150
+ arm_xy[1]
151
+ - self.action_spec['action_to_coord'] * norm_action_y * rescale_ratio
152
+ )
153
+
154
+ # blue is out of the page, red is into the page
155
+ red_z = self.style['rgb_scale'] * ((norm_action[0] + 1) / 2)
156
+ blue_z = self.style['rgb_scale'] * (1 - (norm_action[0] + 1) / 2)
157
+ color_z = np.clip(
158
+ (red_z, 0, blue_z),
159
+ 0, self.style['rgb_scale'])
160
+ radius_z = int(np.clip((0.75 - norm_action_z / 4) * self.style['radius'],
161
+ 0.5 * self.style['radius'], self.style['radius']))
162
+ return Coordinate(
163
+ xy=(int(image_x), int(image_y)),
164
+ color=color_z,
165
+ radius=radius_z,
166
+ )
167
+
168
+ def navigation_action_to_coord(
169
+ self, action, image, center_xy, do_project=False
170
+ ):
171
+ """Converts a ZXY or XY action to an image coordinate.
172
+
173
+ Conversion is done based on style['focal_offset'] and action_spec['scale'].
174
+
175
+ Args:
176
+ action: z, y, x action in robot action space
177
+ image: image
178
+ center_xy: x, y in image space
179
+ do_project: whether or not to project actions sampled outside the image to
180
+ the edge of the image
181
+
182
+ Returns:
183
+ Dict coordinate with image x, y, arrow color, and circle radius.
184
+ """
185
+ # TODO(tedxiao): Refactor into common utility fns, add embodiment specific
186
+ # logic.
187
+ if self.action_spec['scale'][0] == 0: # no z dimension
188
+ norm_action = [(action[d] - self.action_spec['loc'][d]) /
189
+ (2 * self.action_spec['scale'][d]) for d in range(1, 3)]
190
+ norm_action_y, norm_action_x = norm_action
191
+ norm_action_z = 0
192
+ else:
193
+ norm_action = [(action[d] - self.action_spec['loc'][d]) /
194
+ (2 * self.action_spec['scale'][d]) for d in range(3)]
195
+ norm_action_z, norm_action_y, norm_action_x = norm_action
196
+ focal_length = np.max(
197
+ [0.2, # positive focal lengths only
198
+ self.style['focal_offset'] / (self.style['focal_offset'] + norm_action_z)])
199
+ image_x = center_xy[0] - (
200
+ self.action_spec['action_to_coord'] * norm_action_x * focal_length
201
+ )
202
+ image_y = center_xy[1] - (
203
+ self.action_spec['action_to_coord'] * norm_action_y * focal_length
204
+ )
205
+ if (
206
+ vip_utils.coord_outside_image(
207
+ Coordinate(xy=(image_x, image_y)), image, self.style['radius']
208
+ )
209
+ and do_project
210
+ ):
211
+ # project the arrow to the edge of the image if too large
212
+ height, width, _ = image.shape
213
+ max_x = (
214
+ width - center_xy[0] - 2 * self.style['radius']
215
+ if norm_action_x < 0
216
+ else center_xy[0] - 2 * self.style['radius']
217
+ )
218
+ max_y = (
219
+ height - center_xy[1] - 2 * self.style['radius']
220
+ if norm_action_y < 0
221
+ else center_xy[1] - 2 * self.style['radius']
222
+ )
223
+ rescale_ratio = min(np.abs([
224
+ max_x / (self.action_spec['action_to_coord'] * norm_action_x),
225
+ max_y / (self.action_spec['action_to_coord'] * norm_action_y)]))
226
+ image_x = (
227
+ center_xy[0]
228
+ - self.action_spec['action_to_coord'] * norm_action_x * rescale_ratio
229
+ )
230
+ image_y = (
231
+ center_xy[1]
232
+ - self.action_spec['action_to_coord'] * norm_action_y * rescale_ratio
233
+ )
234
+
235
+ return Coordinate(
236
+ xy=(int(image_x), int(image_y)),
237
+ color=0.1 * self.style['rgb_scale'],
238
+ radius=int(self.style['radius']),
239
+ )
240
+
241
+ def sample_actions(
242
+ self, image, arm_xy, loc, scale, true_action=None, max_itrs=1000
243
+ ):
244
+ """Sample actions from distribution.
245
+
246
+ Args:
247
+ image: image
248
+ arm_xy: x, y in image space of arm
249
+ loc: action distribution mean to sample from
250
+ scale: action distribution variance to sample from
251
+ true_action: action taken in demonstration if available
252
+ max_itrs: number of tries to get a valid sample
253
+
254
+ Returns:
255
+ samples: Samples with associated actions, coords, text_coords, labels.
256
+ """
257
+ image = copy.deepcopy(image)
258
+
259
+ samples = []
260
+ actions = []
261
+ coords = []
262
+ text_coords = []
263
+ labels = []
264
+
265
+ # Keep track of oracle action if available.
266
+ true_label = None
267
+ if true_action is not None:
268
+ actions.append(true_action)
269
+ coord = self.action_to_coord(true_action, image, arm_xy)
270
+ coords.append(coord)
271
+ text_coords.append(
272
+ vip_utils.coord_to_text_coord(coords[-1], arm_xy, coord.radius)
273
+ )
274
+ true_label = np.random.randint(self.style['num_samples'])
275
+ # labels.append(str(true_label) + '*')
276
+ labels.append(str(true_label))
277
+
278
+ # Generate all action samples.
279
+ for i in range(self.style['num_samples']):
280
+ if i == true_label:
281
+ continue
282
+ itrs = 0
283
+
284
+ # Generate action scaled appropriately.
285
+ action = np.clip(np.random.normal(loc, scale),
286
+ self.action_spec['min'], self.action_spec['max'])
287
+
288
+ # Convert sampled action to image coordinates.
289
+ coord = self.action_to_coord(action, image, arm_xy)
290
+
291
+ # Resample action if it results in invalid image annotation.
292
+ adjusted_scale = np.array(scale)
293
+ while ((vip_utils.is_invalid_coord(coord, coords, self.style['radius']*1.5, image)
294
+ or vip_utils.coord_outside_image(coord, image, self.style['radius']))
295
+ and itrs < max_itrs):
296
+ action = np.clip(np.random.normal(loc, adjusted_scale),
297
+ self.action_spec['min'], self.action_spec['max'])
298
+ coord = self.action_to_coord(action, image, arm_xy)
299
+ itrs += 1
300
+ # increase sampling range slightly if not finding a good sample
301
+ adjusted_scale *= 1.1
302
+ if itrs == max_itrs:
303
+ # If the final iteration results in invalid annotation, just clip
304
+ # to edge of image.
305
+ coord = self.action_to_coord(action, image, arm_xy, do_project=True)
306
+
307
+ # Compute image coordinates of text labels.
308
+ radius = coord.radius
309
+ text_coord = Coordinate(
310
+ xy=vip_utils.coord_to_text_coord(coord, arm_xy, radius)
311
+ )
312
+
313
+ actions.append(action)
314
+ coords.append(coord)
315
+ text_coords.append(text_coord)
316
+ labels.append(str(i))
317
+
318
+ for i in range(len(actions)):
319
+ sample = Sample(
320
+ action=actions[i],
321
+ coord=coords[i],
322
+ text_coord=text_coords[i],
323
+ label=str(i),
324
+ )
325
+ samples.append(sample)
326
+ return samples
327
+
328
+ def add_arrow_overlay_plt(self, image, samples, arm_xy, log_image=False):
329
+ """Add arrows and circles to the image.
330
+
331
+ Args:
332
+ image: image
333
+ samples: Samples to visualize.
334
+ arm_xy: x, y image coordinates for EEF center.
335
+ log_image: Boolean for whether to save to CNS.
336
+
337
+ Returns:
338
+ image: image with visual prompts.
339
+ """
340
+ # Add transparent arrows and circles
341
+ overlay = image.copy()
342
+ (original_image_height, original_image_width, _) = image.shape
343
+
344
+ white = (
345
+ self.style['rgb_scale'],
346
+ self.style['rgb_scale'],
347
+ self.style['rgb_scale'],
348
+ )
349
+
350
+ # Add arrows.
351
+ for sample in samples:
352
+ color = sample.coord.color
353
+ cv2.arrowedLine(
354
+ overlay, arm_xy, sample.coord.xy, color, self.style['thickness']
355
+ )
356
+ image = cv2.addWeighted(overlay, self.style['arrow_alpha'],
357
+ image, 1 - self.style['arrow_alpha'], 0)
358
+
359
+ overlay = image.copy()
360
+ # Add circles.
361
+ for sample in samples:
362
+ color = sample.coord.color
363
+ radius = sample.coord.radius
364
+ cv2.circle(
365
+ overlay,
366
+ sample.text_coord.xy,
367
+ radius,
368
+ color,
369
+ self.style['thickness'] + 1,
370
+ )
371
+ cv2.circle(overlay, sample.text_coord.xy, radius, white, -1)
372
+ image = cv2.addWeighted(overlay, self.style['circle_alpha'],
373
+ image, 1 - self.style['circle_alpha'], 0)
374
+
375
+ dpi = plt.rcParams['figure.dpi']
376
+ if self.fig_scale_size is None:
377
+ # test saving a figure to decide size for text figure
378
+ fig_size = (original_image_width / dpi, original_image_height / dpi)
379
+ plt.subplots(1, figsize=fig_size)
380
+ plt.imshow(image, cmap='binary')
381
+ plt.axis('off')
382
+ fig = plt.gcf()
383
+ fig.tight_layout(pad=0)
384
+ buf = io.BytesIO()
385
+ plt.savefig(buf, format='png')
386
+ plt.close()
387
+ buf.seek(0)
388
+ test_image = cv2.imdecode(
389
+ np.frombuffer(buf.getvalue(), dtype=np.uint8), cv2.IMREAD_COLOR)
390
+ self.fig_scale_size = original_image_width / test_image.shape[1]
391
+
392
+ # Add text to figure.
393
+ fig_size = (self.fig_scale_size * original_image_width / dpi,
394
+ self.fig_scale_size * original_image_height / dpi)
395
+ plt.subplots(1, figsize=fig_size)
396
+ plt.imshow(image, cmap='binary')
397
+ for sample in samples:
398
+ plt.text(
399
+ sample.text_coord.xy[0],
400
+ sample.text_coord.xy[1],
401
+ sample.label,
402
+ ha='center',
403
+ va='center',
404
+ color='k',
405
+ fontsize=self.style['fontsize'],
406
+ )
407
+
408
+ # Compile image.
409
+ plt.axis('off')
410
+ fig = plt.gcf()
411
+ fig.tight_layout(pad=0)
412
+ buf = io.BytesIO()
413
+ plt.savefig(buf, format='png')
414
+ plt.close()
415
+ image = cv2.imdecode(np.frombuffer(buf.getvalue(), dtype=np.uint8),
416
+ cv2.IMREAD_COLOR)
417
+
418
+ image = cv2.resize(image, (original_image_width, original_image_height))
419
+ image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
420
+
421
+ # Optionally log images to CNS.
422
+ if log_image:
423
+ raise NotImplementedError('TODO: log image to CNS')
424
+ return image
425
+
426
+ def fit(self, values, samples):
427
+ """Fit a loc and scale to selected actions.
428
+
429
+ Args:
430
+ values: list of selected labels
431
+ samples: list of all Samples
432
+
433
+ Returns:
434
+ loc: mean of selected distribution
435
+ scale: variance of selected distribution
436
+ """
437
+ actions = [sample.action for sample in samples]
438
+ labels = [sample.label for sample in samples]
439
+
440
+ if not values: # revert to initial distribution
441
+ print('GPT failed to return integer arrows')
442
+ loc = self.action_spec['loc']
443
+ scale = self.action_spec['scale']
444
+ elif len(values) == 1: # single response, add a distribution over it
445
+ index = np.where([label == str(values[-1]) for label in labels])[0][0]
446
+ action = actions[index]
447
+ print('action', action)
448
+ loc = action
449
+ scale = self.action_spec['min_scale']
450
+ else: # fit distribution
451
+ selected_actions = []
452
+ for value in values:
453
+ idx = np.where([label == str(value) for label in labels])[0][0]
454
+ selected_actions.append(actions[idx])
455
+ print('selected_actions', selected_actions)
456
+
457
+ loc_scale = [scipy.stats.norm.fit([action[d] for action in selected_actions]) for d in range(3)]
458
+ loc = [loc_scale[d][0] for d in range(3)]
459
+ scale = np.clip([loc_scale[d][1] for d in range(3)], self.action_spec['min_scale'], None)
460
+ print('loc', loc, '\nscale', scale)
461
+
462
+ return loc, scale
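As a quick illustration of the fitting step at the end of `fit` above, this is the per-dimension Gaussian fit plus min-scale clipping on three made-up selected actions:

```python
import numpy as np
import scipy.stats

# three hypothetical selected (z, y, x) actions
selected_actions = [np.array([0.0, 50.0, -20.0]),
                    np.array([0.0, 70.0, -10.0]),
                    np.array([0.0, 60.0, 0.0])]
min_scale = [0.0, 30, 30]  # same role as action_spec['min_scale'] in the demo

loc_scale = [scipy.stats.norm.fit([a[d] for a in selected_actions]) for d in range(3)]
loc = [loc_scale[d][0] for d in range(3)]  # per-dimension mean of the selected actions
scale = np.clip([loc_scale[d][1] for d in range(3)], min_scale, None)  # keep a minimum spread
print(loc, scale)  # the next iteration samples around this mean / std
```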
vip_runner.py ADDED
@@ -0,0 +1,163 @@
1
+ """VIP."""
2
+
3
+ import json
4
+ import re
5
+
6
+ import cv2
7
+ from tqdm import trange
8
+ import vip
9
+
10
+
11
+ def make_prompt(description, top_n=3):
12
+ return f"""
13
+ INSTRUCTIONS:
14
+ You are tasked to locate an object, region, or point in space in the given annotated image according to a description.
15
+ The image is annotated with numbered circles.
16
+ Choose the top {top_n} circles that have the most overlap with and/or are closest to what the description refers to in the image.
17
+ You are a five-time world champion in this game.
18
+ Give a one sentence analysis of why you chose those points.
19
+ Provide your answer at the end in a valid JSON of this format:
20
+
21
+ {{"points": []}}
22
+
23
+ DESCRIPTION: {description}
24
+ IMAGE:
25
+ """.strip()
26
+
27
+
28
+ def extract_json(response, key):
29
+ json_part = re.search(r"\{.*\}", response, re.DOTALL)
30
+ parsed_json = {}
31
+ if json_part:
32
+ json_data = json_part.group()
33
+ # Parse the JSON data
34
+ parsed_json = json.loads(json_data)
35
+ else:
36
+ print("No JSON data found ******\n", response)
37
+ return parsed_json.get(key, [])  # return an empty selection instead of raising KeyError when no JSON is found
38
+
39
+
40
+ def vip_perform_selection(prompter, vlm, im, desc, arm_coord, samples, top_n):
41
+ """Perform one selection pass given samples."""
42
+ image_circles_np = prompter.add_arrow_overlay_plt(
43
+ image=im, samples=samples, arm_xy=arm_coord, log_image=False
44
+ )
45
+
46
+ _, encoded_image_circles = cv2.imencode(".png", image_circles_np)
47
+
48
+ prompt_seq = [make_prompt(desc, top_n=top_n), encoded_image_circles]
49
+ response = vlm.query(prompt_seq)
50
+
51
+ arrow_ids = extract_json(response, "points")
52
+ return arrow_ids, image_circles_np
53
+
54
+
55
+ def vip_runner(
56
+ vlm,
57
+ im,
58
+ desc,
59
+ style,
60
+ action_spec,
61
+ n_samples_init=25,
62
+ n_samples_opt=10,
63
+ n_iters=3,
64
+ recursion_level=0,
65
+ ):
66
+ """VIP."""
67
+
68
+ prompter = vip.VisualIterativePrompter(
69
+ style, action_spec, vip.SupportedEmbodiments.META_NAVIGATION
70
+ )
71
+
72
+ output_ims = []
73
+ arm_coord = (int(im.shape[1] / 2), int(im.shape[0] / 2))
74
+
75
+ if recursion_level == 0:
76
+ center_mean = action_spec["loc"]
77
+ center_std = action_spec["scale"]
78
+ selected_samples = []
79
+ for itr in trange(n_iters):
80
+ if itr == 0:
81
+ style["num_samples"] = n_samples_init
82
+ else:
83
+ style["num_samples"] = n_samples_opt
84
+ samples = prompter.sample_actions(im, arm_coord, center_mean, center_std)
85
+ arrow_ids, image_circles_np = vip_perform_selection(
86
+ prompter, vlm, im, desc, arm_coord, samples, top_n=3
87
+ )
88
+
89
+ # plot sampled circles as red
90
+ selected_samples = []
91
+ for selected_id in arrow_ids:
92
+ sample = samples[selected_id]
93
+ sample.coord.color = (255, 0, 0)
94
+ selected_samples.append(sample)
95
+ image_circles_marked_np = prompter.add_arrow_overlay_plt(
96
+ image_circles_np, selected_samples, arm_coord
97
+ )
98
+ output_ims.append(image_circles_marked_np)
99
+
100
+ # if at last iteration, pick one answer out of the selected ones
101
+ if itr == n_iters - 1:
102
+ arrow_ids, _ = vip_perform_selection(
103
+ prompter, vlm, im, desc, arm_coord, selected_samples, top_n=1
104
+ )
105
+
106
+ selected_samples = []
107
+ for selected_id in arrow_ids:
108
+ sample = samples[selected_id]
109
+ sample.coord.color = (255, 0, 0)
110
+ selected_samples.append(sample)
111
+ image_circles_marked_np = prompter.add_arrow_overlay_plt(
112
+ im, selected_samples, arm_coord
113
+ )
114
+ output_ims.append(image_circles_marked_np)
115
+ center_mean, center_std = prompter.fit(arrow_ids, samples)
116
+
117
+ if output_ims:
118
+ return (
119
+ output_ims,
120
+ prompter.action_to_coord(center_mean, im, arm_coord).xy,
121
+ selected_samples,
122
+ )
123
+ else:
124
+ new_samples = []
125
+ for i in range(3):
126
+ out_ims, _, cur_samples = vip_runner(
127
+ vlm=vlm,
128
+ im=im,
129
+ desc=desc,
130
+ style=style,
131
+ action_spec=action_spec,
132
+ n_samples_init=n_samples_init,
133
+ n_samples_opt=n_samples_opt,
134
+ n_iters=n_iters,
135
+ recursion_level=recursion_level - 1,
136
+ )
137
+ output_ims += out_ims
138
+ new_samples += cur_samples
139
+ # adjust sample label to avoid duplications
140
+ for sample_id in range(len(new_samples)):
141
+ new_samples[sample_id].label = str(sample_id)
142
+ arrow_ids, _ = vip_perform_selection(
143
+ prompter, vlm, im, desc, arm_coord, new_samples, top_n=1
144
+ )
145
+
146
+ selected_samples = []
147
+ for selected_id in arrow_ids:
148
+ sample = new_samples[selected_id]
149
+ sample.coord.color = (255, 0, 0)
150
+ selected_samples.append(sample)
151
+ image_circles_marked_np = prompter.add_arrow_overlay_plt(
152
+ im, selected_samples, arm_coord
153
+ )
154
+ output_ims.append(image_circles_marked_np)
155
+ center_mean, _ = prompter.fit(arrow_ids, new_samples)
156
+
157
+ if output_ims:
158
+ return (
159
+ output_ims,
160
+ prompter.action_to_coord(center_mean, im, arm_coord).xy,
161
+ selected_samples,
162
+ )
163
+ return [], "Unable to understand query"
vip_utils.py ADDED
@@ -0,0 +1,130 @@
1
+ # pylint: disable=line-too-long
2
+ """Utils for visual iterative prompting.
3
+
4
+ A number of utility functions for VIP.
5
+ """
6
+
7
+ import copy
8
+ import re
9
+
10
+ import numpy as np
11
+ import scipy.spatial.distance as distance
12
+ import matplotlib.pyplot as plt
13
+
14
+
15
+ def min_dist(coord, coords):
16
+ if not coords:
17
+ return np.inf
18
+ xys = np.asarray([c.xy for c in coords])
19
+ return np.linalg.norm(xys - np.asarray(coord.xy), axis=-1).min()
20
+
21
+
22
+ def coord_outside_image(coord, image, radius):
23
+ (height, image_width, _) = image.shape
24
+ x, y = coord.xy
25
+ x_outside = x > image_width - 2 * radius or x < 2 * radius
26
+ y_outside = y > height - 2 * radius or y < 2 * radius
27
+ return x_outside or y_outside
28
+
29
+
30
+ def is_invalid_coord(coord, coords, radius, image):
31
+ # invalid if too close to others or outside of the image
32
+ pos_overlaps = min_dist(coord, coords) < 1.5 * radius
33
+ return pos_overlaps or coord_outside_image(coord, image, radius)
34
+
35
+
36
+ def angle_mag_2_x_y(angle, mag, arm_coord, is_circle=False, radius=40):
37
+ x, y = arm_coord
38
+ x += int(np.cos(angle) * mag)
39
+ y += int(np.sin(angle) * mag)
40
+ if is_circle:
41
+ x += int(np.cos(angle) * radius * np.sign(mag))
42
+ y += int(np.sin(angle) * radius * np.sign(mag))
43
+ return x, y
44
+
45
+
46
+ def coord_to_text_coord(coord, arm_coord, radius):
47
+ delta_coord = np.asarray(coord.xy) - arm_coord
48
+ if np.linalg.norm(delta_coord) == 0:
49
+ return arm_coord
50
+ return (
51
+ int(coord.xy[0] + radius * delta_coord[0] / np.linalg.norm(delta_coord)),
52
+ int(coord.xy[1] + radius * delta_coord[1] / np.linalg.norm(delta_coord)))
53
+
54
+
55
+ def prep_aloha_frames(real_frame):
56
+ """Prepare collage of ALOHA view frames."""
57
+ markup_frame = copy.deepcopy(real_frame)
58
+ top_frame = copy.deepcopy(markup_frame[
59
+ :int(markup_frame.shape[0] / 2), :int(markup_frame.shape[1] / 2)])
60
+ side_frame = copy.deepcopy(markup_frame[
61
+ int(markup_frame.shape[0] / 2):, :int(markup_frame.shape[1] / 2)])
62
+ right_frame = copy.deepcopy(markup_frame[
63
+ int(markup_frame.shape[0] / 2):, int(markup_frame.shape[1] / 2):])
64
+ left_frame = copy.deepcopy(markup_frame[
65
+ :int(markup_frame.shape[0] / 2), int(markup_frame.shape[1] / 2):])
66
+ markup_frame[int(markup_frame.shape[0] / 2):, :int(markup_frame.shape[1] / 2)] = left_frame
67
+ markup_frame[:int(markup_frame.shape[0] / 2), int(markup_frame.shape[1] / 2):] = side_frame
68
+ return markup_frame, right_frame, left_frame
69
+
70
+
71
+ def parse_response(response, answer_key='Arrow: ['):
72
+ values = []
73
+ if answer_key in response:
74
+ print('parse_response from answer_key')
75
+ arrow_response = response.split(answer_key)[-1].split(']')[0]
76
+ for val in map(int, re.findall(r'\d+', arrow_response)):
77
+ values.append(val)
78
+ else:
79
+ print('parse_response for all ints')
80
+ for val in map(int, re.findall(r'\d+', response)):
81
+ values.append(val)
82
+ return values
83
+
84
+
85
+ # TODO(ichter): normalize values by std
86
+ def compute_errors(action, true_action, verbose=False):
87
+ """Compute errors between a predicted action and true action."""
88
+ l2_error = np.linalg.norm(action - true_action)
89
+ cos_sim = 1 - distance.cosine(action, true_action)
90
+ l2_xy_error = np.linalg.norm(action[-2:] - true_action[-2:])
91
+ cos_xy_sim = 1 - distance.cosine(action[-2:], true_action[-2:])
92
+ z_error = np.abs(action[0] - true_action[0])
93
+ errors = {'l2': l2_error,
94
+ 'cos_sim': cos_sim,
95
+ 'l2_xy_error': l2_xy_error,
96
+ 'cos_xy_sim': cos_xy_sim,
97
+ 'z_error': z_error}
98
+
99
+ if verbose:
100
+ print('action: \t', [f'{a:.3f}' for a in action])
101
+ print('true_action \t', [f'{a:.3f}' for a in true_action])
102
+ print(f'l2: \t\t{l2_error:.3f}')
103
+ print(f'l2_xy_error: \t{l2_xy_error:.3f}')
104
+ print(f'cos_sim: \t{cos_sim:.3f}')
105
+ print(f'cos_xy_sim: \t{cos_xy_sim:.3f}')
106
+ print(f'z_error: \t{z_error:.3f}')
107
+
108
+ return errors
109
+
110
+
111
+ def plot_errors(all_errors, error_types=None):
112
+ """Plot errors across iterations."""
113
+ if error_types is None:
114
+ error_types = ['l2', 'l2_xy_error', 'z_error', 'cos_sim', 'cos_xy_sim',]
115
+
116
+ _, axs = plt.subplots(2, 3, figsize=(15, 8))
117
+ for i, error_type in enumerate(error_types): # go through each error type
118
+ all_iter_errors = {}
119
+ for error_by_iter in all_errors: # go through each call
120
+ for itr in error_by_iter: # go through each iteration
121
+ if itr in all_iter_errors: # add error to the iteration it happened
122
+ all_iter_errors[itr].append(error_by_iter[itr][error_type])
123
+ else:
124
+ all_iter_errors[itr] = [error_by_iter[itr][error_type]]
125
+
126
+ mean_iter_errors = [np.mean(all_iter_errors[itr]) for itr in all_iter_errors]
127
+
128
+ axs[i // 3, i % 3].plot(all_iter_errors.keys(), mean_iter_errors)
129
+ axs[i // 3, i % 3].set_title(error_type)
130
+ plt.show()
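For reference, a minimal call of `compute_errors` above with made-up (z, y, x) actions:

```python
import numpy as np
from vip_utils import compute_errors

predicted = np.array([0.0, 10.0, 20.0])      # hypothetical predicted action
ground_truth = np.array([0.0, 12.0, 18.0])   # hypothetical true action
errors = compute_errors(predicted, ground_truth, verbose=True)
print(errors['l2'], errors['cos_xy_sim'])
```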
vlms.py ADDED
@@ -0,0 +1,33 @@
1
+ """VLM Helper Functions."""
2
+ import base64
3
+ import numpy as np
4
+ from openai import OpenAI
5
+
6
+
7
+ class GPT4V:
8
+ """GPT4V VLM."""
9
+
10
+ def __init__(self, openai_api_key):
11
+ self.client = OpenAI(api_key=openai_api_key)
12
+
13
+ def query(self, prompt_seq, temperature=0, max_tokens=512):
14
+ """Queries GPT-4V."""
15
+ content = []
16
+ for elem in prompt_seq:
17
+ if isinstance(elem, str):
18
+ content.append({'type': 'text', 'text': elem})
19
+ elif isinstance(elem, np.ndarray):
20
+ base64_image_str = base64.b64encode(elem).decode('utf-8')
21
+ image_url = f'data:image/jpeg;base64,{base64_image_str}'
22
+ content.append({'type': 'image_url', 'image_url': {'url': image_url}})
23
+
24
+ messages = [{'role': 'user', 'content': content}]
25
+
26
+ response = self.client.chat.completions.create(
27
+ model='gpt-4-vision-preview',
28
+ messages=messages,
29
+ temperature=temperature,
30
+ max_tokens=max_tokens
31
+ )
32
+
33
+ return response.choices[0].message.content
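A minimal usage sketch for the `GPT4V` wrapper above (the key and image path are placeholders; the array passed in the prompt sequence must already be an encoded image buffer, as produced by `cv2.imencode` in vip_runner.py):

```python
import cv2
from vlms import GPT4V

vlm = GPT4V(openai_api_key='sk-...')        # placeholder key
image = cv2.imread('ims/robot.png')         # placeholder image
_, encoded = cv2.imencode('.png', image)    # encoded image bytes as an np.ndarray
print(vlm.query(['Describe what is on the table.', encoded]))
```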