adityas committed
Commit 9798f42 • 1 Parent(s): c883b62

add demo app code

Files changed (4)
  1. README.md +9 -5
  2. app.py +288 -0
  3. system_template.txt +32 -0
  4. user_template.txt +2 -0
README.md CHANGED
@@ -1,13 +1,17 @@
 ---
 title: OCTO
-emoji: 📚
-colorFrom: green
-colorTo: blue
+emoji: 🐙
+colorFrom: pink
+colorTo: purple
 sdk: gradio
-sdk_version: 4.16.0
+sdk_version: 4.14.0
 app_file: app.py
 pinned: false
 license: mit
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# OCTO+: A Suite for Automatic Open-Vocabulary Object Placement in Mixed Reality
+
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/octo-pearl/octo-pearl/blob/main/demo.ipynb) [![Homepage](https://img.shields.io/badge/🌐-Homepage-blue)](https://octo-pearl.github.io/) [![arXiv](https://img.shields.io/badge/📖-arXiv-b31b1b)](https://octo-pearl.github.io/)
+
+This repo contains the code and data for the paper "[OCTO+: A Suite for Automatic Open-Vocabulary Object Placement in Mixed Reality](https://octo-pearl.github.io/)".
app.py ADDED
@@ -0,0 +1,288 @@
+import os
+
+if not os.path.isdir("weights"):
+    os.mkdir("weights")
+
+os.system("python -m pip install --upgrade pip")
+os.system(
+    "wget https://raw.githubusercontent.com/asharma381/cs291I/main/backend/original_images/000749.png"
+)
+os.system(
+    "wget -q -O weights/sam_vit_h_4b8939.pth https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
+)
+os.system(
+    "wget -q -O weights/ram_plus_swin_large_14m.pth https://huggingface.co/xinyu1205/recognize-anything-plus-model/resolve/main/ram_plus_swin_large_14m.pth"
+)
+os.system(
+    "wget -q -O weights/groundingdino_swint_ogc.pth https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth"
+)
+os.system("pip install git+https://github.com/xinyu1205/recognize-anything.git")
+os.system("pip install git+https://github.com/IDEA-Research/GroundingDINO.git")
+os.system("pip install git+https://github.com/facebookresearch/segment-anything.git")
+os.system("pip install openai==0.27.4")
+os.system("pip install tenacity")
+
+
+from typing import List, Tuple
+
+import cv2
+import gradio as gr
+import groundingdino.config.GroundingDINO_SwinT_OGC
+import numpy as np
+import openai
+import torch
+from groundingdino.util.inference import Model
+from PIL import Image, ImageDraw
+from ram import get_transform
+from ram import inference_ram as inference
+from ram.models import ram_plus
+from scipy.spatial.distance import cdist
+from segment_anything import SamPredictor, sam_model_registry
+from supervision import Detections
+from tenacity import retry, wait_fixed
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Models are loaded lazily on first use and cached in these globals.
+ram_model = None
+ram_threshold_multiplier = 1
+gdino_model = None
+sam_model = None
+sam_predictor = None
+
+print("CUDA Available:", torch.cuda.is_available())
+
+
+def get_tags_ram(
+    image: Image.Image, threshold_multiplier=0.8, weights_folder="weights"
+) -> List[str]:
+    global ram_model, ram_threshold_multiplier
+    if ram_model is None:
+        print("Loading RAM++ Model...")
+        ram_model = ram_plus(
+            pretrained=f"{weights_folder}/ram_plus_swin_large_14m.pth",
+            vit="swin_l",
+            image_size=384,
+        )
+        ram_model.eval()
+        ram_model = ram_model.to(device)
+
+    # Rescale the per-class thresholds relative to the multiplier currently applied.
+    ram_model.class_threshold *= threshold_multiplier / ram_threshold_multiplier
+    ram_threshold_multiplier = threshold_multiplier
+    transform = get_transform()
+
+    image = transform(image).unsqueeze(0).to(device)
+    res = inference(image, ram_model)
+    return [s.strip() for s in res[0].split("|")]
+
+
+def get_gdino_result(
+    image: Image.Image,
+    classes: List[str],
+    box_threshold: float = 0.25,
+    weights_folder="weights",
+) -> Tuple[Detections, List[str]]:
+    global gdino_model
+
+    if gdino_model is None:
+        print("Loading GroundingDINO Model...")
+        config_path = groundingdino.config.GroundingDINO_SwinT_OGC.__file__
+        gdino_model = Model(
+            model_config_path=config_path,
+            model_checkpoint_path=f"{weights_folder}/groundingdino_swint_ogc.pth",
+            device=device,
+        )
+
+    detections, phrases = gdino_model.predict_with_caption(
+        image=np.array(image),
+        caption=", ".join(classes),
+        box_threshold=box_threshold,
+        text_threshold=0.25,
+    )
+
+    return detections, phrases
+
+
+def get_sam_model(weights_folder="weights"):
+    global sam_model
+    if sam_model is None:
+        sam_checkpoint = f"{weights_folder}/sam_vit_h_4b8939.pth"
+        sam_model = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
+        sam_model.to(device=device)
+    return sam_model
+
+
+def filter_tags_gdino(image: Image.Image, tags: List[str]) -> List[str]:
+    # Keep a tag only if GroundingDINO grounds it in the image with a box
+    # that covers less than 90% of the image area.
+    detections, phrases = get_gdino_result(image, tags)
+    filtered_tags = []
+    for tag in tags:
+        for phrase, area in zip(phrases, detections.area):
+            if area < 0.9 * image.size[0] * image.size[1] and tag in phrase:
+                filtered_tags.append(tag)
+                break
+    return filtered_tags
+
+
+def read_file_to_string(file_path: str) -> str:
+    content = ""
+
+    try:
+        with open(file_path, "r", encoding="utf8") as file:
+            content = file.read()
+    except FileNotFoundError:
+        print(f"The file {file_path} was not found.")
+    except Exception as e:
+        print(f"An error occurred while reading {file_path}: {e}")
+
+    return content
+
+
+@retry(wait=wait_fixed(2))
+def completion_with_backoff(**kwargs):
+    return openai.ChatCompletion.create(**kwargs)
+
+
+def gpt4(
+    usr_prompt: str, sys_prompt: str = "", api_key: str = "", model: str = "gpt-4"
+) -> str:
+    openai.api_key = api_key
+
+    message = [
+        {"role": "system", "content": sys_prompt},
+        {"role": "user", "content": usr_prompt},
+    ]
+
+    response = completion_with_backoff(
+        model=model,
+        messages=message,
+        temperature=0.2,
+        max_tokens=1000,
+        frequency_penalty=0.0,
+    )
+
+    return response["choices"][0]["message"]["content"]
+
+
+def select_best_tag(
+    filtered_tags: List[str], object_to_place: str, api_key: str = ""
+) -> str:
+    user_template = read_file_to_string("user_template.txt").format(
+        object=object_to_place
+    )
+    user_prompt = user_template + "\n".join(filtered_tags)
+    system_prompt = read_file_to_string("system_template.txt")
+    return gpt4(user_prompt, system_prompt, api_key=api_key)
+
+
+def get_location_gsam(
+    image: Image.Image, prompt: str, weights_folder="weights"
+) -> Tuple[int, int]:
+    global sam_predictor
+
+    BOX_THRESHOLD = 0.25
+    RESIZE_RATIO = 3
+
+    detections, phrases = get_gdino_result(
+        image=image,
+        classes=[prompt],
+        box_threshold=BOX_THRESHOLD,
+    )
+
+    # Lower the box threshold until at least one detection is found.
+    while len(detections.xyxy) == 0:
+        BOX_THRESHOLD -= 0.02
+        detections, phrases = get_gdino_result(
+            image=image,
+            classes=[prompt],
+            box_threshold=BOX_THRESHOLD,
+        )
+
+    sam_model = get_sam_model(weights_folder)
+
+    if sam_predictor is None:
+        print("Loading SAM Model...")
+        sam_predictor = SamPredictor(sam_model)
+
+    sam_predictor.set_image(np.array(image))
+    result_masks = []
+    for box in detections.xyxy:
+        masks, scores, logits = sam_predictor.predict(box=box, multimask_output=True)
+        index = np.argmax(scores)
+        result_masks.append(masks[index])
+    detections.mask = np.array(result_masks)
+
+    # Union all instance masks into one binary mask, downsampled for speed.
+    combined_mask = detections.mask[0]
+    for mask in detections.mask[1:]:
+        combined_mask += mask
+    combined_mask[combined_mask > 1] = 1
+    mask = cv2.resize(
+        combined_mask.astype("uint8"),
+        (
+            combined_mask.shape[1] // RESIZE_RATIO,
+            combined_mask.shape[0] // RESIZE_RATIO,
+        ),
+    )
+
+    mask_2_pad = np.pad(mask, pad_width=2, mode="constant", constant_values=0)
+    mask_1_pad = np.pad(mask, pad_width=1, mode="constant", constant_values=0)
+
+    # Relabel background pixels whose 3x3 neighborhood is entirely background as 2,
+    # so only background adjacent to the mask boundary keeps the value 0.
+    windows = np.lib.stride_tricks.sliding_window_view(mask_2_pad, (3, 3))
+    windows_all_zero = (windows == 0).all(axis=(2, 3))
+
+    # Pick the mask pixel farthest from the boundary (most interior point).
+    result = np.where(windows_all_zero, 2, mask_1_pad)
+    mask_0_coordinates = np.argwhere(result == 0)
+    mask_1_coordinates = np.argwhere(result == 1)
+    distances = cdist(mask_1_coordinates, mask_0_coordinates, "euclidean")
+    max_min_distance_index = np.argmax(np.min(distances, axis=1))
+    y, x = mask_1_coordinates[max_min_distance_index]
+
+    return int(x) * RESIZE_RATIO, int(y) * RESIZE_RATIO
+
+
+def run_octo_pipeline(input_image, object, api_key):
+    print("Inside run_octo_pipeline with input_image=", input_image, "object=", object)
+
+    print("Loading Image...")
+    image = input_image.convert("RGB")
+
+    print("Stage 1...")
+    tags = get_tags_ram(image, threshold_multiplier=0.8)
+    print("RAM++ Tags", tags)
+    filtered_tags = filter_tags_gdino(image, tags)
+    print("Filtered Tags", filtered_tags)
+
+    print("Stage 2...")
+    selected_tag = select_best_tag(filtered_tags, object, api_key=api_key)
+    print("GPT-4 Selected Tag", selected_tag)
+
+    print("Stage 3...")
+    x, y = get_location_gsam(image, selected_tag)
+    print("G-SAM Location", "(" + str(x) + "," + str(y) + ")")
+
+    # Mark the chosen placement point with a red dot.
+    draw = ImageDraw.Draw(image)
+    radius = 10
+    bbox = (x - radius, y - radius, x + radius, y + radius)
+    draw.ellipse(bbox, fill="red")
+    return [image]
+
+
+block = gr.Blocks()
+
+with block:
+    with gr.Row():
+        with gr.Column():
+            input_image = gr.Image(type="pil", value="000749.png")
+            object = gr.Textbox(label="Object", placeholder="Enter an object")
+            api_key = gr.Textbox(label="OpenAI API Key", placeholder="Enter OpenAI API Key")
+
+        with gr.Column():
+            gallery = gr.Gallery(
+                label="Output",
+                show_label=False,
+                elem_id="gallery",
+                preview=True,
+                object_fit="scale-down",
+            )
+
+iface = gr.Interface(
+    fn=run_octo_pipeline, inputs=[input_image, object, api_key], outputs=gallery
+)
+iface.launch()
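
For orientation, the three stages above compose linearly: RAM++ proposes tags, GroundingDINO filters them, GPT-4 selects a placement surface, and SAM localizes a pixel on it. The following is a minimal sketch, not part of the commit, of exercising these functions outside the Gradio UI; it assumes the weight files and template files above are already in place, and "sk-..." is a placeholder for a real OpenAI API key.

# Sketch only: calls the pipeline functions defined in app.py directly.
from PIL import Image

image = Image.open("000749.png").convert("RGB")            # demo image fetched above
tags = get_tags_ram(image, threshold_multiplier=0.8)       # Stage 1a: RAM++ tagging
tags = filter_tags_gdino(image, tags)                      # Stage 1b: GroundingDINO filter
tag = select_best_tag(tags, "banana", api_key="sk-...")    # Stage 2: GPT-4 picks a surface
x, y = get_location_gsam(image, tag)                       # Stage 3: placement point via SAM
print(tag, (x, y))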
system_template.txt ADDED
@@ -0,0 +1,32 @@
+You are an expert in determining where objects should be placed in a scene. You will be given a list of objects in a scene, and the name of a new object to be placed in the scene. Your task is to select the most natural location for the new object to be placed out of the options provided. Choose one of the possible answers and write it exactly, character for character, as it appears in the list. Provide a one-word response. Here are some examples.
+Question: Where would be the most natural location for a banana to be placed?
+Possible Answers:
+floor
+table
+computer
+sink
+couch
+
+table
+
+Question: Where would be the most natural location for a marker to be placed?
+Possible Answers:
+bed
+counter
+computer
+sink
+desk
+couch
+
+desk
+
+Question: Where would be the most natural location for a suitcase to be placed?
+Possible Answers:
+desk
+floor
+table
+sink
+computer
+couch
+
+floor
user_template.txt ADDED
@@ -0,0 +1,2 @@
+Question: Where would be the most natural location for a {object} to be placed?
+Possible Answers:
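
For reference, select_best_tag in app.py fills this template and appends the filtered tags, one per line. An illustrative rendering, assuming the object is "banana" and the filtered tags are floor, table, and sink:

Question: Where would be the most natural location for a banana to be placed?
Possible Answers:
floor
table
sink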