eliphatfs commited on
Commit
31070ee
1 Parent(s): 3fbe09c
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
app.py CHANGED
@@ -42,7 +42,12 @@ model_b32 = load_openshape('openshape-pointbert-vitb32-rgb').cpu()
42
  model_l14 = load_openshape('openshape-pointbert-vitl14-rgb')
43
  model_g14 = load_openshape('openshape-pointbert-vitg14-rgb')
44
  torch.set_grad_enabled(False)
 
 
 
45
 
 
 
46
  from openshape.demo import misc_utils, classification, caption, sd_pc2img, retrieval
47
 
48
 
@@ -59,6 +64,67 @@ tab_cls, tab_img, tab_text, tab_pc, tab_sd, tab_cap = st.tabs([
59
  ])
60
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def demo_classification():
63
  load_data = misc_utils.input_3d_shape('cls')
64
  cats = st.text_input("Custom Categories (64 max, separated with comma)")
@@ -68,7 +134,7 @@ def demo_classification():
68
  return
69
  lvis_run = st.button("Run Classification on LVIS Categories")
70
  custom_run = st.button("Run Classification on Custom Categories")
71
- if lvis_run:
72
  pc = load_data(prog)
73
  col2 = misc_utils.render_pc(pc)
74
  prog.progress(0.5, "Running Classification")
@@ -92,31 +158,35 @@ def demo_classification():
92
  st.text(cat)
93
  st.caption("Similarity %.4f" % sim)
94
  prog.progress(1.0, "Idle")
 
 
95
 
96
 
97
  def demo_captioning():
98
  with st.form("capform"):
99
  load_data = misc_utils.input_3d_shape('cap')
100
- cond_scale = st.slider('Conditioning Scale', 0.0, 4.0, 2.0)
101
- if st.form_submit_button("Generate a Caption"):
102
  pc = load_data(prog)
103
  col2 = misc_utils.render_pc(pc)
104
  prog.progress(0.5, "Running Generation")
105
  cap = caption.pc_caption(model_b32, pc, cond_scale)
106
  st.text(cap)
107
  prog.progress(1.0, "Idle")
 
 
108
 
109
 
110
  def demo_pc2img():
111
  with st.form("sdform"):
112
  load_data = misc_utils.input_3d_shape('sd')
113
- prompt = st.text_input("Prompt (Optional)")
114
  noise_scale = st.slider('Variation Level', 0, 5, 1)
115
  cfg_scale = st.slider('Guidance Scale', 0.0, 30.0, 10.0)
116
  steps = st.slider('Diffusion Steps', 8, 50, 25)
117
  width = 640 # st.slider('Width', 480, 640, step=32)
118
  height = 640 # st.slider('Height', 480, 640, step=32)
119
- if st.form_submit_button("Generate"):
120
  pc = load_data(prog)
121
  col2 = misc_utils.render_pc(pc)
122
  prog.progress(0.49, "Running Generation")
@@ -131,6 +201,8 @@ def demo_pc2img():
131
  with col2:
132
  st.image(img)
133
  prog.progress(1.0, "Idle")
 
 
134
 
135
 
136
  def retrieval_results(results):
@@ -155,35 +227,44 @@ def demo_retrieval():
155
  with st.form("rtextform"):
156
  k = st.slider("# Shapes to Retrieve", 1, 100, 16, key='rtext')
157
  text = st.text_input("Input Text")
 
158
  if st.form_submit_button("Run with Text"):
159
  prog.progress(0.49, "Computing Embeddings")
160
  device = clip_model.device
161
- tn = clip_prep(text=[text], return_tensors='pt', truncation=True, max_length=76).to(device)
 
 
162
  enc = clip_model.get_text_features(**tn).float().cpu()
163
  prog.progress(0.7, "Running Retrieval")
164
  retrieval_results(retrieval.retrieve(enc, k))
165
  prog.progress(1.0, "Idle")
166
 
167
  with tab_img:
 
168
  with st.form("rimgform"):
169
  k = st.slider("# Shapes to Retrieve", 1, 100, 16, key='rimage')
170
- pic = st.file_uploader("Upload an Image")
171
  if st.form_submit_button("Run with Image"):
172
- img = Image.open(pic)
173
- st.image(img)
174
- prog.progress(0.49, "Computing Embeddings")
175
- device = clip_model.device
176
- tn = clip_prep(images=[img], return_tensors="pt").to(device)
177
- enc = clip_model.get_image_features(pixel_values=tn['pixel_values'].type(half)).float().cpu()
178
- prog.progress(0.7, "Running Retrieval")
179
- retrieval_results(retrieval.retrieve(enc, k))
180
- prog.progress(1.0, "Idle")
 
 
 
 
 
181
 
182
  with tab_pc:
183
  with st.form("rpcform"):
184
  k = st.slider("# Shapes to Retrieve", 1, 100, 16, key='rpc')
185
  load_data = misc_utils.input_3d_shape('retpc')
186
- if st.form_submit_button("Run with Shape"):
187
  pc = load_data(prog)
188
  col2 = misc_utils.render_pc(pc)
189
  prog.progress(0.49, "Computing Embeddings")
@@ -192,6 +273,8 @@ def demo_retrieval():
192
  prog.progress(0.7, "Running Retrieval")
193
  retrieval_results(retrieval.retrieve(enc, k))
194
  prog.progress(1.0, "Idle")
 
 
195
 
196
 
197
  try:
 
42
  model_l14 = load_openshape('openshape-pointbert-vitl14-rgb')
43
  model_g14 = load_openshape('openshape-pointbert-vitg14-rgb')
44
  torch.set_grad_enabled(False)
45
+ for kc, vc in st.session_state.get('state_queue', []):
46
+ st.session_state[kc] = vc
47
+ st.session_state.state_queue = []
48
 
49
+
50
+ import samples_index
51
  from openshape.demo import misc_utils, classification, caption, sd_pc2img, retrieval
52
 
53
 
 
64
  ])
65
 
66
 
67
+ def sq(kc, vc):
68
+ st.session_state.state_queue.append((kc, vc))
69
+
70
+
71
+ def reset_3d_shape_input(key):
72
+ objaid_key = key + "_objaid"
73
+ model_key = key + "_model"
74
+ npy_key = key + "_npy"
75
+ swap_key = key + "_swap"
76
+ sq(objaid_key, "")
77
+ sq(model_key, None)
78
+ sq(npy_key, None)
79
+ sq(swap_key, "Y is up (for most Objaverse shapes)")
80
+
81
+
82
+ def auto_submit(key):
83
+ if st.session_state.get(key):
84
+ st.session_state[key] = False
85
+ return True
86
+ return False
87
+
88
+
89
+ def queue_auto_submit(key):
90
+ st.session_state[key] = True
91
+ st.experimental_rerun()
92
+
93
+
94
+ img_example_counter = 0
95
+
96
+
97
+ def image_examples(samples, ncols, return_key=None):
98
+ global img_example_counter
99
+ trigger = False
100
+ with st.expander("Examples", True):
101
+ for i in range(len(samples) // ncols):
102
+ cols = st.columns(ncols)
103
+ for j in range(ncols):
104
+ idx = i * ncols + j
105
+ if idx >= len(samples):
106
+ continue
107
+ entry = samples[idx]
108
+ with cols[j]:
109
+ st.image(entry['dispi'])
110
+ img_example_counter += 1
111
+ with st.columns(5)[2]:
112
+ this_trigger = st.button('\+', key='imgexuse%d' % img_example_counter)
113
+ trigger = trigger or this_trigger
114
+ if this_trigger:
115
+ if return_key is None:
116
+ for k, v in entry.items():
117
+ if not k.startswith('disp'):
118
+ sq(k, v)
119
+ else:
120
+ trigger = entry[return_key]
121
+ return trigger
122
+
123
+
124
+ def text_examples(samples):
125
+ return st.selectbox("Or pick an example", samples)
126
+
127
+
128
  def demo_classification():
129
  load_data = misc_utils.input_3d_shape('cls')
130
  cats = st.text_input("Custom Categories (64 max, separated with comma)")
 
134
  return
135
  lvis_run = st.button("Run Classification on LVIS Categories")
136
  custom_run = st.button("Run Classification on Custom Categories")
137
+ if lvis_run or auto_submit("clsauto"):
138
  pc = load_data(prog)
139
  col2 = misc_utils.render_pc(pc)
140
  prog.progress(0.5, "Running Classification")
 
158
  st.text(cat)
159
  st.caption("Similarity %.4f" % sim)
160
  prog.progress(1.0, "Idle")
161
+ if image_examples(samples_index.classification, 3):
162
+ queue_auto_submit("clsauto")
163
 
164
 
165
  def demo_captioning():
166
  with st.form("capform"):
167
  load_data = misc_utils.input_3d_shape('cap')
168
+ cond_scale = st.slider('Conditioning Scale', 0.0, 4.0, 2.0, 0.1, key='capcondscl')
169
+ if st.form_submit_button("Generate a Caption") or auto_submit("capauto"):
170
  pc = load_data(prog)
171
  col2 = misc_utils.render_pc(pc)
172
  prog.progress(0.5, "Running Generation")
173
  cap = caption.pc_caption(model_b32, pc, cond_scale)
174
  st.text(cap)
175
  prog.progress(1.0, "Idle")
176
+ if image_examples(samples_index.cap, 3):
177
+ queue_auto_submit("capauto")
178
 
179
 
180
  def demo_pc2img():
181
  with st.form("sdform"):
182
  load_data = misc_utils.input_3d_shape('sd')
183
+ prompt = st.text_input("Prompt (Optional)", key='sdtprompt')
184
  noise_scale = st.slider('Variation Level', 0, 5, 1)
185
  cfg_scale = st.slider('Guidance Scale', 0.0, 30.0, 10.0)
186
  steps = st.slider('Diffusion Steps', 8, 50, 25)
187
  width = 640 # st.slider('Width', 480, 640, step=32)
188
  height = 640 # st.slider('Height', 480, 640, step=32)
189
+ if st.form_submit_button("Generate") or auto_submit("sdauto"):
190
  pc = load_data(prog)
191
  col2 = misc_utils.render_pc(pc)
192
  prog.progress(0.49, "Running Generation")
 
201
  with col2:
202
  st.image(img)
203
  prog.progress(1.0, "Idle")
204
+ if image_examples(samples_index.sd, 3):
205
+ queue_auto_submit("sdauto")
206
 
207
 
208
  def retrieval_results(results):
 
227
  with st.form("rtextform"):
228
  k = st.slider("# Shapes to Retrieve", 1, 100, 16, key='rtext')
229
  text = st.text_input("Input Text")
230
+ picked_sample = text_examples(samples_index.retrieval_texts)
231
  if st.form_submit_button("Run with Text"):
232
  prog.progress(0.49, "Computing Embeddings")
233
  device = clip_model.device
234
+ tn = clip_prep(
235
+ text=[text or picked_sample], return_tensors='pt', truncation=True, max_length=76
236
+ ).to(device)
237
  enc = clip_model.get_text_features(**tn).float().cpu()
238
  prog.progress(0.7, "Running Retrieval")
239
  retrieval_results(retrieval.retrieve(enc, k))
240
  prog.progress(1.0, "Idle")
241
 
242
  with tab_img:
243
+ submit = False
244
  with st.form("rimgform"):
245
  k = st.slider("# Shapes to Retrieve", 1, 100, 16, key='rimage')
246
+ pic = st.file_uploader("Upload an Image", key='rimageinput')
247
  if st.form_submit_button("Run with Image"):
248
+ submit = True
249
+ sample_got = image_examples(samples_index.iret, 4, 'rimageinput')
250
+ if sample_got:
251
+ pic = sample_got
252
+ if sample_got or submit:
253
+ img = Image.open(pic)
254
+ st.image(img)
255
+ prog.progress(0.49, "Computing Embeddings")
256
+ device = clip_model.device
257
+ tn = clip_prep(images=[img], return_tensors="pt").to(device)
258
+ enc = clip_model.get_image_features(pixel_values=tn['pixel_values'].type(half)).float().cpu()
259
+ prog.progress(0.7, "Running Retrieval")
260
+ retrieval_results(retrieval.retrieve(enc, k))
261
+ prog.progress(1.0, "Idle")
262
 
263
  with tab_pc:
264
  with st.form("rpcform"):
265
  k = st.slider("# Shapes to Retrieve", 1, 100, 16, key='rpc')
266
  load_data = misc_utils.input_3d_shape('retpc')
267
+ if st.form_submit_button("Run with Shape") or auto_submit('rpcauto'):
268
  pc = load_data(prog)
269
  col2 = misc_utils.render_pc(pc)
270
  prog.progress(0.49, "Computing Embeddings")
 
273
  prog.progress(0.7, "Running Retrieval")
274
  retrieval_results(retrieval.retrieve(enc, k))
275
  prog.progress(1.0, "Idle")
276
+ if image_examples(samples_index.pret, 3):
277
+ queue_auto_submit("rpcauto")
278
 
279
 
280
  try:
samples/retrieval-img/img4.jpg DELETED
Binary file (48.3 kB)
 
samples/retrieval-img/img6.jpg CHANGED
samples/retrieval-img/img7.jpg CHANGED
samples/retrieval-img/img8.jpg CHANGED
samples/retrieval-img/img9.jpg ADDED
samples/retrieval-text.txt DELETED
@@ -1,16 +0,0 @@
1
- shark
2
- swordfish
3
- dolphin
4
- goldfish
5
- high heels
6
- boots
7
- slippers
8
- sneakers
9
- tiki mug
10
- viking mug
11
- animal-shaped mug
12
- travel mug
13
- white conical mug
14
- green cubic mug
15
- blue spherical mug
16
- orange cylinder mug
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
samples/sd-text.txt DELETED
@@ -1,3 +0,0 @@
1
- b8db8dc5caad4fa5842a9ed6dbd2e9d6,falcon
2
- ff2875fb1a5b4771805a5fd35c8fe7bb,in the woods
3
- tpvzmLUXAURQ7ZxccJIBZvcIDlr,above the fields
 
 
 
 
samples_index.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+
4
+ cap_base = 'samples/caption'
5
+ cap = [
6
+ dict(cap_objaid=os.path.splitext(x)[0], dispi=os.path.join(cap_base, x))
7
+ for x in sorted(os.listdir(cap_base))
8
+ ]
9
+
10
+ cls_base = 'samples/classification'
11
+ classification = [
12
+ dict(cls_objaid=os.path.splitext(x)[0], dispi=os.path.join(cls_base, x))
13
+ for x in sorted(os.listdir(cls_base))
14
+ ]
15
+
16
+ sd_base = 'samples/sd'
17
+ sd_texts = {
18
+ 'b8db8dc5caad4fa5842a9ed6dbd2e9d6': 'falcon',
19
+ 'ff2875fb1a5b4771805a5fd35c8fe7bb': 'in the woods',
20
+ 'tpvzmLUXAURQ7ZxccJIBZvcIDlr': 'above the fields'
21
+ }
22
+ sd = [
23
+ dict(
24
+ sd_objaid=os.path.splitext(x)[0],
25
+ dispi=os.path.join(sd_base, x),
26
+ sdtprompt=sd_texts.get(os.path.splitext(x)[0], '')
27
+ )
28
+ for x in sorted(os.listdir(sd_base))
29
+ ]
30
+
31
+ retrieval_texts = """
32
+ shark
33
+ swordfish
34
+ dolphin
35
+ goldfish
36
+ high heels
37
+ boots
38
+ slippers
39
+ sneakers
40
+ tiki mug
41
+ viking mug
42
+ animal-shaped mug
43
+ travel mug
44
+ white conical mug
45
+ green cubic mug
46
+ blue spherical mug
47
+ orange cylinder mug
48
+ """.splitlines()
49
+ retrieval_texts = [x.strip() for x in retrieval_texts if x.strip()]
50
+
51
+ pret_base = 'samples/retrieval-pc'
52
+ pret = [
53
+ dict(retpc_objaid=os.path.splitext(x)[0], dispi=os.path.join(pret_base, x))
54
+ for x in sorted(os.listdir(pret_base))
55
+ ]
56
+
57
+ iret_base = 'samples/retrieval-img'
58
+ iret = [
59
+ dict(rimageinput=os.path.join(iret_base, x), dispi=os.path.join(iret_base, x))
60
+ for x in sorted(os.listdir(iret_base))
61
+ ]