pivot-prompt committed on
Commit 660daa9
1 Parent(s): cd8d52a

Add application file

Files changed (10)
  1. app.py +193 -0
  2. ims/aloha.png +0 -0
  3. ims/parking.jpg +0 -0
  4. ims/robot.png +0 -0
  5. ims/tools.png +0 -0
  6. requirements.txt +6 -0
  7. vip.py +397 -0
  8. vip_runner.py +153 -0
  9. vip_utils.py +122 -0
  10. vlms.py +33 -0
app.py ADDED
@@ -0,0 +1,193 @@
"""PIVOT Demo."""

import gradio as gr
import numpy as np
from vip_runner import vip_runner
from vlms import GPT4V

# Adjust radius of annotations based on the size of the image.
radius_per_pixel = 0.05


def run_vip(
    im,
    query,
    n_samples_init,
    n_samples_opt,
    n_iters,
    n_parallel_trials,
    openai_api_key,
    progress=gr.Progress(track_tqdm=False),
):
  # run_vip is a generator (it yields below), so input errors are also yielded
  # rather than returned, otherwise they would never reach the UI.
  if not openai_api_key:
    yield [], 'Must provide OpenAI API Key'
    return
  if im is None:
    yield [], 'Must specify image'
    return
  if not query:
    yield [], 'Must specify description'
    return

  img_size = np.min(im.shape[:2])
  print(int(img_size * radius_per_pixel))  # annotation radius in pixels
  # add some action spec
  style = {
      'num_samples': 12,
      'circle_alpha': 0.6,
      'alpha': 0.8,
      'arrow_alpha': 0.0,
      'radius': int(img_size * radius_per_pixel),
      'thickness': 2,
      'fontsize': int(img_size * radius_per_pixel),
      'rgb_scale': 255,
      'focal_offset': 1,  # camera distance / std of action in z
  }

  action_spec = {
      'loc': [0, 0, 0],
      'scale': [0.0, 100, 100],
      'min_scale': [0.0, 30, 30],
      'min': [0, -300.0, -300],
      'max': [0, 300, 300],
      'action_to_coord': 250,
      'robot': None,
  }

  vlm = GPT4V(openai_api_key=openai_api_key)
  vip_gen = vip_runner(
      vlm,
      im,
      query,
      style,
      action_spec,
      n_samples_init=n_samples_init,
      n_samples_opt=n_samples_opt,
      n_iters=n_iters,
      n_parallel_trials=n_parallel_trials,
  )
  for rst in vip_gen:
    yield rst


examples = [
    {
        'im_path': 'ims/aloha.png',
        'desc': 'a point between the fork and the cup',
    },
    {
        'im_path': 'ims/robot.png',
        'desc': 'the toy in the middle of the table',
    },
    {
        'im_path': 'ims/parking.jpg',
        'desc': 'a place to park if I am handicapped',
    },
    {
        'im_path': 'ims/tools.png',
        'desc': 'what should I use to pull a nail',
    },
]


with gr.Blocks() as demo:
  gr.Markdown("""
# PIVOT: Prompting with Iterative Visual Optimization
The demo below showcases a version of the PIVOT algorithm, which uses iterative visual prompts to optimize and guide the reasoning of Vision-Language Models (VLMs).
Given an image and a description of an object or region,
PIVOT iteratively searches for the point in the image that best corresponds to the description.
This is done through visual prompting: instead of reasoning in text, the VLM reasons over images annotated with sampled points and picks the best ones.
In each iteration, we take the points previously selected by the VLM, resample new points around their mean, and repeat the process.

To get started, you can use the provided example image and query pairs, or
upload your own images.
This demo uses GPT-4V, so it requires an OpenAI API key.

Hyperparameters to set:
* N Samples for Initialization - how many initial points are sampled for the first PIVOT iteration.
* N Samples for Optimization - how many points are sampled for each subsequent iteration.
* N Iterations - how many optimization iterations to perform.
* N Parallel Trials - how many PIVOT runs to perform in parallel and then ensemble (recursive PIVOT).

Note that each iteration takes roughly 10 seconds, and each additional parallel trial adds another full set of N Iterations.

After PIVOT finishes, the image gallery below visualizes the results from every iteration.
There are two images per iteration: the first shows all the sampled points, and the second shows the ones PIVOT picked.
The Info textbox shows the final pixel coordinate that PIVOT converged to.

**To use the example images, right click on the image -> copy image, then click the clipboard icon in the Input Image box.**
""".strip())

  gr.Markdown(
      '## Example Images and Queries\n Drag images into the image box below (try Safari on Mac if dragging does not work).'
  )
  with gr.Row(equal_height=True):
    for example in examples:
      gr.Image(value=example['im_path'], type='numpy', label=example['desc'])

  gr.Markdown('## New Query')
  with gr.Row():
    with gr.Column():
      inp_im = gr.Image(
          label='Input Image',
          type='numpy',
          show_label=True,
          value=examples[0]['im_path'],
      )
      inp_query = gr.Textbox(
          label='Description',
          lines=1,
          placeholder=examples[0]['desc'],
      )

    with gr.Column():
      inp_openai_api_key = gr.Textbox(
          label='OpenAI API Key (not saved)', lines=1
      )
      with gr.Group():
        inp_n_samples_init = gr.Slider(
            label='N Samples for Initialization',
            minimum=10,
            maximum=40,
            value=25,
            step=1,
        )
        inp_n_samples_opt = gr.Slider(
            label='N Samples for Optimization',
            minimum=3,
            maximum=20,
            value=10,
            step=1,
        )
        inp_n_iters = gr.Slider(
            label='N Iterations', minimum=1, maximum=5, value=3, step=1
        )
        inp_n_parallel_trials = gr.Slider(
            label='N Parallel Trials', minimum=1, maximum=3, value=1, step=1
        )
      btn_run = gr.Button('Run')

  with gr.Group():
    out_ims = gr.Gallery(
        label='Images with Sampled and Chosen Points',
        columns=4,
        rows=1,
        interactive=False,
        object_fit='contain',
        height='auto',
    )
    out_info = gr.Textbox(label='Info', lines=1)

  btn_run.click(
      run_vip,
      inputs=[
          inp_im,
          inp_query,
          inp_n_samples_init,
          inp_n_samples_opt,
          inp_n_iters,
          inp_n_parallel_trials,
          inp_openai_api_key,
      ],
      outputs=[out_ims, out_info],
  )

demo.launch()
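For reference, here is a minimal headless sketch of the same pipeline without the Gradio UI (not part of this commit). It reuses the style and action_spec values from run_vip above; the API key, the query, and the output filename are placeholders, while the example image ships with this commit.

# Headless usage sketch; 'sk-...' and 'pivot_result.png' are placeholders.
import cv2
import numpy as np
from vip_runner import vip_runner
from vlms import GPT4V

im = cv2.cvtColor(cv2.imread('ims/aloha.png'), cv2.COLOR_BGR2RGB)
img_size = np.min(im.shape[:2])
style = {
    'num_samples': 12, 'circle_alpha': 0.6, 'alpha': 0.8, 'arrow_alpha': 0.0,
    'radius': int(img_size * 0.05), 'thickness': 2,
    'fontsize': int(img_size * 0.05), 'rgb_scale': 255, 'focal_offset': 1,
}
action_spec = {
    'loc': [0, 0, 0], 'scale': [0.0, 100, 100], 'min_scale': [0.0, 30, 30],
    'min': [0, -300.0, -300], 'max': [0, 300, 300], 'action_to_coord': 250,
    'robot': None,
}
vlm = GPT4V(openai_api_key='sk-...')  # placeholder key
query = 'a point between the fork and the cup'
for images, info in vip_runner(vlm, im, query, style, action_spec,
                               n_samples_init=25, n_samples_opt=10, n_iters=3):
  print(info)
# The last yielded image shows the final pick; channel order may need a swap
# before saving with OpenCV.
cv2.imwrite('pivot_result.png', images[-1])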
ims/aloha.png ADDED
ims/parking.jpg ADDED
ims/robot.png ADDED
ims/tools.png ADDED
requirements.txt ADDED
@@ -0,0 +1,6 @@
numpy
matplotlib
opencv-python
openai
gradio
scipy
vip.py ADDED
@@ -0,0 +1,397 @@
"""Visual Iterative Prompting functions.

Code to implement visual iterative prompting, an approach for querying VLMs.
"""

import copy
import dataclasses
import enum
import io
from typing import Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats

import vip_utils


@enum.unique
class SupportedEmbodiments(str, enum.Enum):
  """Embodiments supported by VIP."""

  HF_DEMO = 'hf_demo'


@dataclasses.dataclass()
class Coordinate:
  """Coordinate with necessary information for visualizing annotation."""

  # 2D image coordinates for the target annotation
  xy: Tuple[int, int]
  # Color and style of the coord.
  color: Optional[float] = None
  radius: Optional[int] = None


@dataclasses.dataclass()
class Sample:
  """Single Sample mapping an action to Coordinates."""

  # 2D or 3D action
  action: np.ndarray
  # Coordinates for the main annotation
  coord: Coordinate
  # Coordinates for the text label
  text_coord: Coordinate
  # Label to display in the text label
  label: str


class VisualIterativePrompter:
  """Visual Iterative Prompting class."""

  def __init__(self, style, action_spec, embodiment):
    self.embodiment = embodiment
    self.style = style
    self.action_spec = action_spec
    self.fig_scale_size = None
    # image preparer
    # robot_to_image_canonical_coords

  def action_to_coord(self, action, image, arm_xy, do_project=False):
    """Converts a candidate action to an image coordinate."""
    return self.navigation_action_to_coord(
        action=action, image=image, center_xy=arm_xy, do_project=do_project
    )

  def navigation_action_to_coord(
      self, action, image, center_xy, do_project=False
  ):
    """Converts a ZYX or YX action to an image coordinate.

    Conversion is done based on style['focal_offset'] and action_spec['scale'].

    Args:
      action: z, y, x action in robot action space
      image: image
      center_xy: x, y in image space
      do_project: whether or not to project actions sampled outside the image
        to the edge of the image

    Returns:
      Coordinate with image x, y, arrow color, and circle radius.
    """
    if self.action_spec['scale'][0] == 0:  # no z dimension
      norm_action = [
          (action[d] - self.action_spec['loc'][d])
          / (2 * self.action_spec['scale'][d])
          for d in range(1, 3)
      ]
      norm_action_y, norm_action_x = norm_action
      norm_action_z = 0
    else:
      norm_action = [
          (action[d] - self.action_spec['loc'][d])
          / (2 * self.action_spec['scale'][d])
          for d in range(3)
      ]
      norm_action_z, norm_action_y, norm_action_x = norm_action
    focal_length = np.max([
        0.2,  # positive focal lengths only
        self.style['focal_offset']
        / (self.style['focal_offset'] + norm_action_z),
    ])
    image_x = center_xy[0] - (
        self.action_spec['action_to_coord'] * norm_action_x * focal_length
    )
    image_y = center_xy[1] - (
        self.action_spec['action_to_coord'] * norm_action_y * focal_length
    )
    if (
        vip_utils.coord_outside_image(
            Coordinate(xy=(image_x, image_y)), image, self.style['radius']
        )
        and do_project
    ):
      # project the arrow to the edge of the image if too large
      height, width, _ = image.shape
      max_x = (
          width - center_xy[0] - 2 * self.style['radius']
          if norm_action_x < 0
          else center_xy[0] - 2 * self.style['radius']
      )
      max_y = (
          height - center_xy[1] - 2 * self.style['radius']
          if norm_action_y < 0
          else center_xy[1] - 2 * self.style['radius']
      )
      rescale_ratio = min(
          np.abs([
              max_x / (self.action_spec['action_to_coord'] * norm_action_x),
              max_y / (self.action_spec['action_to_coord'] * norm_action_y),
          ])
      )
      image_x = (
          center_xy[0]
          - self.action_spec['action_to_coord'] * norm_action_x * rescale_ratio
      )
      image_y = (
          center_xy[1]
          - self.action_spec['action_to_coord'] * norm_action_y * rescale_ratio
      )

    return Coordinate(
        xy=(int(image_x), int(image_y)),
        color=0.1 * self.style['rgb_scale'],
        radius=int(self.style['radius']),
    )

  def sample_actions(
      self, image, arm_xy, loc, scale, true_action=None, max_itrs=1000
  ):
    """Samples actions from a distribution.

    Args:
      image: image
      arm_xy: x, y in image space of arm
      loc: mean of the action distribution to sample from
      scale: standard deviation of the action distribution to sample from
      true_action: action taken in demonstration if available
      max_itrs: number of tries to get a valid sample

    Returns:
      samples: Samples with associated actions, coords, text_coords, labels.
    """
    image = copy.deepcopy(image)

    samples = []
    actions = []
    coords = []
    text_coords = []
    labels = []

    # Keep track of oracle action if available.
    true_label = None
    if true_action is not None:
      actions.append(true_action)
      coord = self.action_to_coord(true_action, image, arm_xy)
      coords.append(coord)
      text_coords.append(
          vip_utils.coord_to_text_coord(coords[-1], arm_xy, coord.radius)
      )
      true_label = np.random.randint(self.style['num_samples'])
      labels.append(str(true_label))

    # Generate all action samples.
    for i in range(self.style['num_samples']):
      if i == true_label:
        continue
      itrs = 0

      # Generate action scaled appropriately.
      action = np.clip(
          np.random.normal(loc, scale),
          self.action_spec['min'],
          self.action_spec['max'],
      )

      # Convert sampled action to image coordinates.
      coord = self.action_to_coord(action, image, arm_xy)

      # Resample the action if it results in an invalid image annotation.
      adjusted_scale = np.array(scale)
      while (
          vip_utils.is_invalid_coord(
              coord, coords, self.style['radius'] * 1.5, image
          )
          or vip_utils.coord_outside_image(coord, image, self.style['radius'])
      ) and itrs < max_itrs:
        action = np.clip(
            np.random.normal(loc, adjusted_scale),
            self.action_spec['min'],
            self.action_spec['max'],
        )
        coord = self.action_to_coord(action, image, arm_xy)
        itrs += 1
        # increase sampling range slightly if not finding a good sample
        adjusted_scale *= 1.1
      if itrs == max_itrs:
        # If the final iteration still results in an invalid annotation, just
        # clip to the edge of the image.
        coord = self.action_to_coord(action, image, arm_xy, do_project=True)

      # Compute image coordinates of text labels.
      radius = coord.radius
      text_coord = Coordinate(
          xy=vip_utils.coord_to_text_coord(coord, arm_xy, radius)
      )

      actions.append(action)
      coords.append(coord)
      text_coords.append(text_coord)
      labels.append(str(i))

    for i in range(len(actions)):
      sample = Sample(
          action=actions[i],
          coord=coords[i],
          text_coord=text_coords[i],
          label=str(i),
      )
      samples.append(sample)
    return samples

  def add_arrow_overlay_plt(self, image, samples, arm_xy):
    """Adds arrows and circles to the image.

    Args:
      image: image
      samples: Samples to visualize.
      arm_xy: x, y image coordinates for the end-effector center.

    Returns:
      image: image with visual prompts.
    """
    # Add transparent arrows and circles
    overlay = image.copy()
    (original_image_height, original_image_width, _) = image.shape

    white = (
        self.style['rgb_scale'],
        self.style['rgb_scale'],
        self.style['rgb_scale'],
    )

    # Add arrows.
    for sample in samples:
      color = sample.coord.color
      cv2.arrowedLine(
          overlay, arm_xy, sample.coord.xy, color, self.style['thickness']
      )
    image = cv2.addWeighted(
        overlay,
        self.style['arrow_alpha'],
        image,
        1 - self.style['arrow_alpha'],
        0,
    )

    overlay = image.copy()
    # Add circles.
    for sample in samples:
      color = sample.coord.color
      radius = sample.coord.radius
      cv2.circle(
          overlay,
          sample.text_coord.xy,
          radius,
          color,
          self.style['thickness'] + 1,
      )
      cv2.circle(overlay, sample.text_coord.xy, radius, white, -1)
    image = cv2.addWeighted(
        overlay,
        self.style['circle_alpha'],
        image,
        1 - self.style['circle_alpha'],
        0,
    )

    dpi = plt.rcParams['figure.dpi']
    if self.fig_scale_size is None:
      # test saving a figure to decide size for text figure
      fig_size = (original_image_width / dpi, original_image_height / dpi)
      plt.subplots(1, figsize=fig_size)
      plt.imshow(image, cmap='binary')
      plt.axis('off')
      fig = plt.gcf()
      fig.tight_layout(pad=0)
      buf = io.BytesIO()
      plt.savefig(buf, format='png')
      plt.close()
      buf.seek(0)
      test_image = cv2.imdecode(
          np.frombuffer(buf.getvalue(), dtype=np.uint8), cv2.IMREAD_COLOR
      )
      self.fig_scale_size = original_image_width / test_image.shape[1]

    # Add text to figure.
    fig_size = (
        self.fig_scale_size * original_image_width / dpi,
        self.fig_scale_size * original_image_height / dpi,
    )
    plt.subplots(1, figsize=fig_size)
    plt.imshow(image, cmap='binary')
    for sample in samples:
      plt.text(
          sample.text_coord.xy[0],
          sample.text_coord.xy[1],
          sample.label,
          ha='center',
          va='center',
          color='k',
          fontsize=self.style['fontsize'],
      )

    # Compile image.
    plt.axis('off')
    fig = plt.gcf()
    fig.tight_layout(pad=0)
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    image = cv2.imdecode(
        np.frombuffer(buf.getvalue(), dtype=np.uint8), cv2.IMREAD_COLOR
    )

    image = cv2.resize(image, (original_image_width, original_image_height))
    image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    return image

  def fit(self, values, samples):
    """Fits a loc and scale to the selected actions.

    Args:
      values: list of selected labels
      samples: list of all Samples

    Returns:
      loc: mean of the fitted distribution
      scale: standard deviation of the fitted distribution
    """
    actions = [sample.action for sample in samples]
    labels = [sample.label for sample in samples]

    if not values:  # revert to initial distribution
      print('GPT failed to return integer arrows')
      loc = self.action_spec['loc']
      scale = self.action_spec['scale']
    elif len(values) == 1:  # single response, add a distribution over it
      index = np.where([label == str(values[-1]) for label in labels])[0][0]
      action = actions[index]
      print('action', action)
      loc = action
      scale = self.action_spec['min_scale']
    else:  # fit distribution
      selected_actions = []
      for value in values:
        idx = np.where([label == str(value) for label in labels])[0][0]
        selected_actions.append(actions[idx])
      print('selected_actions', selected_actions)

      loc_scale = [
          scipy.stats.norm.fit([action[d] for action in selected_actions])
          for d in range(3)
      ]
      loc = [loc_scale[d][0] for d in range(3)]
      scale = np.clip(
          [loc_scale[d][1] for d in range(3)],
          self.action_spec['min_scale'],
          None,
      )
      print('loc', loc, '\nscale', scale)

    return loc, scale
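The prompter can also be exercised on its own, without a VLM in the loop. Below is a small sketch (not part of the commit) that samples candidate actions on a blank canvas, renders the annotated image, and refits the distribution from a hand-picked set of labels; all parameter values are made up but mirror app.py.

# Standalone sketch of the prompter; the "selected" labels here are invented,
# not produced by a VLM.
import numpy as np
import vip

style = {
    'num_samples': 8, 'circle_alpha': 0.6, 'alpha': 0.8, 'arrow_alpha': 0.0,
    'radius': 20, 'thickness': 2, 'fontsize': 20, 'rgb_scale': 255,
    'focal_offset': 1,
}
action_spec = {
    'loc': [0, 0, 0], 'scale': [0.0, 100, 100], 'min_scale': [0.0, 30, 30],
    'min': [0, -300.0, -300], 'max': [0, 300, 300], 'action_to_coord': 250,
    'robot': None,
}
prompter = vip.VisualIterativePrompter(
    style, action_spec, vip.SupportedEmbodiments.HF_DEMO)

image = np.full((480, 640, 3), 255, dtype=np.uint8)  # blank white canvas
center = (320, 240)
samples = prompter.sample_actions(image, center, action_spec['loc'],
                                  action_spec['scale'])
annotated = prompter.add_arrow_overlay_plt(image, samples, center)
# Pretend the VLM picked the first three labels and refit the distribution.
loc, scale = prompter.fit([s.label for s in samples[:3]], samples)
print('new distribution:', loc, scale)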
vip_runner.py ADDED
@@ -0,0 +1,153 @@
"""VIP."""

import json
import re

import cv2
import numpy as np
from tqdm import trange

import vip


def make_prompt(description, top_n=3):
  return f"""
INSTRUCTIONS:
You are tasked to locate an object, region, or point in space in the given annotated image according to a description.
The image is annotated with numbered circles.
Choose the top {top_n} circles that have the most overlap with and/or are closest to what the description is referring to in the image.
You are a five-time world champion in this game.
Give a one sentence analysis of why you chose those points.
Provide your answer at the end in a valid JSON of this format:

{{"points": []}}

DESCRIPTION: {description}
IMAGE:
""".strip()


def extract_json(response, key):
  json_part = re.search(r"\{.*\}", response, re.DOTALL)
  parsed_json = {}
  if json_part:
    json_data = json_part.group()
    # Parse the JSON data
    parsed_json = json.loads(json_data)
  else:
    print("No JSON data found ******\n", response)
  return parsed_json[key]


def vip_perform_selection(prompter, vlm, im, desc, arm_coord, samples, top_n):
  """Performs one selection pass given samples."""
  image_circles_np = prompter.add_arrow_overlay_plt(
      image=im, samples=samples, arm_xy=arm_coord
  )

  _, encoded_image_circles = cv2.imencode(".png", image_circles_np)

  prompt_seq = [make_prompt(desc, top_n=top_n), encoded_image_circles]
  response = vlm.query(prompt_seq)

  try:
    arrow_ids = extract_json(response, "points")
  except Exception as e:
    print(e)
    arrow_ids = []
  return arrow_ids, image_circles_np


def vip_runner(
    vlm,
    im,
    desc,
    style,
    action_spec,
    n_samples_init=25,
    n_samples_opt=10,
    n_iters=3,
    n_parallel_trials=1,
):
  """VIP."""

  prompter = vip.VisualIterativePrompter(
      style, action_spec, vip.SupportedEmbodiments.HF_DEMO
  )

  output_ims = []
  arm_coord = (int(im.shape[1] / 2), int(im.shape[0] / 2))

  new_samples = []
  center_mean = action_spec["loc"]
  for i in range(n_parallel_trials):
    center_mean = action_spec["loc"]
    center_std = action_spec["scale"]
    for itr in trange(n_iters):
      if itr == 0:
        style["num_samples"] = n_samples_init
      else:
        style["num_samples"] = n_samples_opt
      samples = prompter.sample_actions(im, arm_coord, center_mean, center_std)
      arrow_ids, image_circles_np = vip_perform_selection(
          prompter, vlm, im, desc, arm_coord, samples, top_n=3
      )

      # plot sampled circles as red
      selected_samples = []
      for selected_id in arrow_ids:
        sample = samples[selected_id]
        sample.coord.color = (255, 0, 0)
        selected_samples.append(sample)
      image_circles_marked_np = prompter.add_arrow_overlay_plt(
          image_circles_np, selected_samples, arm_coord
      )
      output_ims.append(image_circles_marked_np)
      yield output_ims, f"Image generated for parallel sample {i+1}/{n_parallel_trials} iteration {itr+1}/{n_iters}. Still working..."

      # if at the last iteration, pick one answer out of the selected ones
      if itr == n_iters - 1:
        arrow_ids, _ = vip_perform_selection(
            prompter, vlm, im, desc, arm_coord, selected_samples, top_n=1
        )

        selected_samples = []
        for selected_id in arrow_ids:
          sample = samples[selected_id]
          sample.coord.color = (255, 0, 0)
          selected_samples.append(sample)
        image_circles_marked_np = prompter.add_arrow_overlay_plt(
            im, selected_samples, arm_coord
        )
        output_ims.append(image_circles_marked_np)
        new_samples += selected_samples
        yield output_ims, f"Image generated for parallel sample {i+1}/{n_parallel_trials} last iteration. Still working..."
      center_mean, center_std = prompter.fit(arrow_ids, samples)

  if n_parallel_trials > 1:
    # adjust sample labels to avoid duplicates across parallel trials
    for sample_id in range(len(new_samples)):
      new_samples[sample_id].label = str(sample_id)
    arrow_ids, _ = vip_perform_selection(
        prompter, vlm, im, desc, arm_coord, new_samples, top_n=1
    )

    selected_samples = []
    for selected_id in arrow_ids:
      sample = new_samples[selected_id]
      sample.coord.color = (255, 0, 0)
      selected_samples.append(sample)
    image_circles_marked_np = prompter.add_arrow_overlay_plt(
        im, selected_samples, arm_coord
    )
    output_ims.append(image_circles_marked_np)
    center_mean, _ = prompter.fit(arrow_ids, new_samples)

  if output_ims:
    yield (
        output_ims,
        (
            "Final selected coordinate:"
            f" {np.round(prompter.action_to_coord(center_mean, im, arm_coord).xy, decimals=0)}"
        ),
    )
  else:
    # vip_runner is a generator, so failures must be yielded to reach the UI.
    yield [], "Unable to understand query"
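extract_json tolerates prose around the JSON block, which is how GPT-4V tends to answer the prompt above. A quick illustration with an invented reply:

# Invented VLM reply; only the JSON object is parsed.
reply = (
    'Circles 3 and 7 sit right on the described region, and 5 is nearby. '
    '{"points": [3, 7, 5]}'
)
print(extract_json(reply, 'points'))  # -> [3, 7, 5]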
vip_utils.py ADDED
@@ -0,0 +1,122 @@
"""Utils for visual iterative prompting.

A number of utility functions for VIP.
"""

import re

import matplotlib.pyplot as plt
import numpy as np
import scipy.spatial.distance as distance


def min_dist(coord, coords):
  if not coords:
    return np.inf
  xys = np.asarray([c.xy for c in coords])
  return np.linalg.norm(xys - np.asarray(coord.xy), axis=-1).min()


def coord_outside_image(coord, image, radius):
  (height, width, _) = image.shape
  x, y = coord.xy
  x_outside = x > width - 2 * radius or x < 2 * radius
  y_outside = y > height - 2 * radius or y < 2 * radius
  return x_outside or y_outside


def is_invalid_coord(coord, coords, radius, image):
  # invalid if too close to others or outside of the image
  pos_overlaps = min_dist(coord, coords) < 1.5 * radius
  return pos_overlaps or coord_outside_image(coord, image, radius)


def angle_mag_2_x_y(angle, mag, arm_coord, is_circle=False, radius=40):
  x, y = arm_coord
  x += int(np.cos(angle) * mag)
  y += int(np.sin(angle) * mag)
  if is_circle:
    x += int(np.cos(angle) * radius * np.sign(mag))
    y += int(np.sin(angle) * radius * np.sign(mag))
  return x, y


def coord_to_text_coord(coord, arm_coord, radius):
  delta_coord = np.asarray(coord.xy) - arm_coord
  if np.linalg.norm(delta_coord) == 0:
    return arm_coord
  return (
      int(coord.xy[0] + radius * delta_coord[0] / np.linalg.norm(delta_coord)),
      int(coord.xy[1] + radius * delta_coord[1] / np.linalg.norm(delta_coord)),
  )


def parse_response(response, answer_key='Arrow: ['):
  values = []
  if answer_key in response:
    print('parse_response from answer_key')
    arrow_response = response.split(answer_key)[-1].split(']')[0]
    for val in map(int, re.findall(r'\d+', arrow_response)):
      values.append(val)
  else:
    print('parse_response for all ints')
    for val in map(int, re.findall(r'\d+', response)):
      values.append(val)
  return values


def compute_errors(action, true_action, verbose=False):
  """Computes errors between a predicted action and the true action."""
  l2_error = np.linalg.norm(action - true_action)
  cos_sim = 1 - distance.cosine(action, true_action)
  l2_xy_error = np.linalg.norm(action[-2:] - true_action[-2:])
  cos_xy_sim = 1 - distance.cosine(action[-2:], true_action[-2:])
  z_error = np.abs(action[0] - true_action[0])
  errors = {
      'l2': l2_error,
      'cos_sim': cos_sim,
      'l2_xy_error': l2_xy_error,
      'cos_xy_sim': cos_xy_sim,
      'z_error': z_error,
  }

  if verbose:
    print('action: \t', [f'{a:.3f}' for a in action])
    print('true_action \t', [f'{a:.3f}' for a in true_action])
    print(f'l2: \t\t{l2_error:.3f}')
    print(f'l2_xy_error: \t{l2_xy_error:.3f}')
    print(f'cos_sim: \t{cos_sim:.3f}')
    print(f'cos_xy_sim: \t{cos_xy_sim:.3f}')
    print(f'z_error: \t{z_error:.3f}')

  return errors


def plot_errors(all_errors, error_types=None):
  """Plots errors across iterations."""
  if error_types is None:
    error_types = [
        'l2',
        'l2_xy_error',
        'z_error',
        'cos_sim',
        'cos_xy_sim',
    ]

  _, axs = plt.subplots(2, 3, figsize=(15, 8))
  for i, error_type in enumerate(error_types):  # go through each error type
    all_iter_errors = {}
    for error_by_iter in all_errors:  # go through each call
      for itr in error_by_iter:  # go through each iteration
        if itr in all_iter_errors:  # add error to the iteration it happened
          all_iter_errors[itr].append(error_by_iter[itr][error_type])
        else:
          all_iter_errors[itr] = [error_by_iter[itr][error_type]]

    mean_iter_errors = [
        np.mean(all_iter_errors[itr]) for itr in all_iter_errors
    ]

    axs[i // 3, i % 3].plot(list(all_iter_errors.keys()), mean_iter_errors)
    axs[i // 3, i % 3].set_title(error_type)
  plt.show()
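A quick sanity check of the geometry helpers (not part of the commit); the numbers are arbitrary, and Coordinate comes from vip.py above.

# Geometry helper sketch with arbitrary coordinates.
import numpy as np
from vip import Coordinate
from vip_utils import coord_outside_image, coord_to_text_coord, min_dist

image = np.zeros((480, 640, 3), dtype=np.uint8)
arm = (320, 240)
c = Coordinate(xy=(400, 300))
print(coord_outside_image(c, image, radius=20))  # False: well inside the frame
print(coord_to_text_coord(c, arm, radius=20))    # (416, 312): pushed past the circle, away from the arm
print(min_dist(c, [Coordinate(xy=(410, 300))]))  # 10.0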
vlms.py ADDED
@@ -0,0 +1,33 @@
"""VLM Helper Functions."""

import base64

import numpy as np
from openai import OpenAI


class GPT4V:
  """GPT-4V VLM wrapper."""

  def __init__(self, openai_api_key):
    self.client = OpenAI(api_key=openai_api_key)

  def query(self, prompt_seq, temperature=0, max_tokens=512):
    """Queries GPT-4V with an interleaved sequence of text and encoded images."""
    content = []
    for elem in prompt_seq:
      if isinstance(elem, str):
        content.append({'type': 'text', 'text': elem})
      elif isinstance(elem, np.ndarray):
        # elem is expected to be an already-encoded image buffer
        # (e.g. the array returned by cv2.imencode).
        base64_image_str = base64.b64encode(elem).decode('utf-8')
        image_url = f'data:image/jpeg;base64,{base64_image_str}'
        content.append({'type': 'image_url', 'image_url': {'url': image_url}})

    messages = [{'role': 'user', 'content': content}]

    response = self.client.chat.completions.create(
        model='gpt-4-vision-preview',
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )

    return response.choices[0].message.content
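Typical usage (not part of the commit): query takes an interleaved list of strings and already-encoded image buffers such as the array returned by cv2.imencode. The API key below is a placeholder and the call requires network access to OpenAI.

# Usage sketch with a placeholder key and one of the bundled example images.
import cv2
from vlms import GPT4V

vlm = GPT4V(openai_api_key='sk-...')  # placeholder key
ok, encoded = cv2.imencode('.png', cv2.imread('ims/robot.png'))
assert ok
print(vlm.query(['How many objects are on the table?', encoded]))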