eliphatfs commited on
Commit
a886fec
Β·
1 Parent(s): ff2e0a9

Merged super demo.

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +169 -10
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: OpenShape Classification Demo
3
  emoji: πŸŒ–
4
  colorFrom: red
5
  colorTo: purple
 
1
  ---
2
+ title: OpenShape Demo
3
  emoji: πŸŒ–
4
  colorFrom: red
5
  colorTo: purple
app.py CHANGED
@@ -1,30 +1,67 @@
 
1
  import streamlit as st
2
  from huggingface_hub import HfFolder, snapshot_download
3
- HfFolder().save_token(st.secrets['etoken'])
4
- snapshot_download("OpenShape/openshape-demo-support", local_dir='.')
 
 
 
 
 
 
 
5
 
6
 
7
  import numpy
 
8
  import openshape
9
- from openshape.demo import misc_utils, classification
10
-
11
 
12
  @st.cache_resource
13
  def load_openshape(name):
14
  return openshape.load_pc_encoder(name)
15
 
16
 
 
 
 
 
 
 
 
 
 
17
  f32 = numpy.float32
18
- # clip_model, clip_prep = load_openclip()
19
- model_g14 = openshape.load_pc_encoder('openshape-pointbert-vitg14-rgb')
 
 
 
 
 
20
 
 
21
 
22
  st.title("OpenShape Demo")
23
- load_data = misc_utils.input_3d_shape()
24
  prog = st.progress(0.0, "Idle")
 
 
 
 
 
 
 
 
25
 
26
 
27
- try:
 
 
 
 
 
 
28
  if st.button("Run Classification on LVIS Categories"):
29
  pc = load_data(prog)
30
  col2 = misc_utils.render_pc(pc)
@@ -35,5 +72,127 @@ try:
35
  st.text(cat)
36
  st.caption("Similarity %.4f" % sim)
37
  prog.progress(1.0, "Idle")
38
- except Exception as exc:
39
- st.error(repr(exc))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
  import streamlit as st
3
  from huggingface_hub import HfFolder, snapshot_download
4
+
5
+
6
+ @st.cache_data
7
+ def load_support():
8
+ HfFolder().save_token(st.secrets['etoken'])
9
+ sys.path.append(snapshot_download("OpenShape/openshape-demo-support"))
10
+
11
+
12
+ # load_support()
13
 
14
 
15
  import numpy
16
+ import torch
17
  import openshape
18
+ import transformers
19
+ from PIL import Image
20
 
21
  @st.cache_resource
22
  def load_openshape(name):
23
  return openshape.load_pc_encoder(name)
24
 
25
 
26
+ @st.cache_resource
27
+ def load_openclip():
28
+ return transformers.CLIPModel.from_pretrained(
29
+ "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
30
+ low_cpu_mem_usage=True, torch_dtype=half,
31
+ offload_state_dict=True
32
+ ), transformers.CLIPProcessor.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
33
+
34
+
35
  f32 = numpy.float32
36
+ half = torch.float16 if torch.cuda.is_available() else torch.bfloat16
37
+ # clip_model, clip_prep = None, None
38
+ clip_model, clip_prep = load_openclip()
39
+ model_b32 = load_openshape('openshape-pointbert-vitb32-rgb').cpu()
40
+ model_l14 = load_openshape('openshape-pointbert-vitl14-rgb')
41
+ model_g14 = load_openshape('openshape-pointbert-vitg14-rgb')
42
+ torch.set_grad_enabled(False)
43
 
44
+ from openshape.demo import misc_utils, classification, caption, sd_pc2img, retrieval
45
 
46
  st.title("OpenShape Demo")
 
47
  prog = st.progress(0.0, "Idle")
48
+ tab_cls, tab_text, tab_img, tab_pc, tab_sd, tab_cap = st.tabs([
49
+ "Classification",
50
+ "Retrieval from Text",
51
+ "Retrieval from Image",
52
+ "Retrieval from 3D Shape",
53
+ "Image Generation",
54
+ "Captioning",
55
+ ])
56
 
57
 
58
+ def demo_classification():
59
+ load_data = misc_utils.input_3d_shape('cls')
60
+ cats = st.text_input("Custom Categories (64 max, separated with comma)")
61
+ cats = [a.strip() for a in cats.split(',')]
62
+ if len(cats) > 64:
63
+ st.error('Maximum 64 custom categories supported in the demo')
64
+ return
65
  if st.button("Run Classification on LVIS Categories"):
66
  pc = load_data(prog)
67
  col2 = misc_utils.render_pc(pc)
 
72
  st.text(cat)
73
  st.caption("Similarity %.4f" % sim)
74
  prog.progress(1.0, "Idle")
75
+ if st.button("Run Classification on Custom Categories"):
76
+ pc = load_data(prog)
77
+ col2 = misc_utils.render_pc(pc)
78
+ prog.progress(0.5, "Computing Category Embeddings")
79
+ device = clip_model.device
80
+ tn = clip_prep(text=cats, return_tensors='pt', truncation=True, max_length=76).to(device)
81
+ feats = clip_model.get_text_features(**tn).float().cpu()
82
+ prog.progress(0.5, "Running Classification")
83
+ pred = classification.pred_custom_sims(model_g14, pc, cats, feats)
84
+ with col2:
85
+ for i, (cat, sim) in zip(range(5), pred.items()):
86
+ st.text(cat)
87
+ st.caption("Similarity %.4f" % sim)
88
+ prog.progress(1.0, "Idle")
89
+
90
+
91
+ def demo_captioning():
92
+ load_data = misc_utils.input_3d_shape('cap')
93
+ cond_scale = st.slider('Conditioning Scale', 0.0, 4.0, 2.0)
94
+ if st.button("Generate a Caption"):
95
+ pc = load_data(prog)
96
+ col2 = misc_utils.render_pc(pc)
97
+ prog.progress(0.5, "Running Generation")
98
+ cap = caption.pc_caption(model_b32, pc, cond_scale)
99
+ st.text(cap)
100
+ prog.progress(1.0, "Idle")
101
+
102
+
103
+ def demo_pc2img():
104
+ load_data = misc_utils.input_3d_shape('sd')
105
+ prompt = st.text_input("Prompt (Optional)")
106
+ noise_scale = st.slider('Variation Level', 0, 5, 1)
107
+ cfg_scale = st.slider('Guidance Scale', 0.0, 30.0, 10.0)
108
+ steps = st.slider('Diffusion Steps', 8, 50, 25)
109
+ width = 640 # st.slider('Width', 480, 640, step=32)
110
+ height = 640 # st.slider('Height', 480, 640, step=32)
111
+ if st.button("Generate"):
112
+ pc = load_data(prog)
113
+ col2 = misc_utils.render_pc(pc)
114
+ prog.progress(0.49, "Running Generation")
115
+ if torch.cuda.is_available():
116
+ clip_model.cpu()
117
+ img = sd_pc2img.pc_to_image(
118
+ model_l14, pc, prompt, noise_scale, width, height, cfg_scale, steps,
119
+ lambda i, t, _: prog.progress(0.49 + i / (steps + 1) / 2, "Running Diffusion Step %d" % i)
120
+ )
121
+ if torch.cuda.is_available():
122
+ clip_model.cuda()
123
+ with col2:
124
+ st.image(img)
125
+ prog.progress(1.0, "Idle")
126
+
127
+
128
+ def retrieval_results(results):
129
+ for i in range(len(results) // 4):
130
+ cols = st.columns(4)
131
+ for j in range(4):
132
+ idx = i * 4 + j
133
+ if idx >= len(results):
134
+ continue
135
+ entry = results[idx]
136
+ with cols[j]:
137
+ ext_link = f"https://objaverse.allenai.org/explore/?query={entry['u']}"
138
+ st.image(entry['img'])
139
+ # st.markdown(f"[![thumbnail {entry['desc'].replace('\n', ' ')}]({entry['img']})]({ext_link})")
140
+ # st.text(entry['name'])
141
+ quote_name = entry['name'].replace('[', '\\[').replace(']', '\\]').replace('\n', ' ')
142
+ st.markdown(f"[{quote_name}]({ext_link})")
143
+
144
+
145
+ def demo_retrieval():
146
+ with tab_text:
147
+ k = st.slider("# Shapes to Retrieve", 1, 100, 16, key='rtext')
148
+ text = st.text_input("Input Text")
149
+ if st.button("Run with Text"):
150
+ prog.progress(0.49, "Computing Embeddings")
151
+ device = clip_model.device
152
+ tn = clip_prep(text=[text], return_tensors='pt', truncation=True, max_length=76).to(device)
153
+ enc = clip_model.get_text_features(**tn).float().cpu()
154
+ prog.progress(0.7, "Running Retrieval")
155
+ retrieval_results(retrieval.retrieve(enc, k))
156
+ prog.progress(1.0, "Idle")
157
+
158
+ with tab_img:
159
+ k = st.slider("# Shapes to Retrieve", 1, 100, 16, key='rimage')
160
+ pic = st.file_uploader("Upload an Image")
161
+ if st.button("Run with Image"):
162
+ img = Image.open(pic)
163
+ st.image(img)
164
+ prog.progress(0.49, "Computing Embeddings")
165
+ device = clip_model.device
166
+ tn = clip_prep(images=[img], return_tensors="pt").to(device)
167
+ enc = clip_model.get_image_features(pixel_values=tn['pixel_values'].type(half)).float().cpu()
168
+ prog.progress(0.7, "Running Retrieval")
169
+ retrieval_results(retrieval.retrieve(enc, k))
170
+ prog.progress(1.0, "Idle")
171
+
172
+ with tab_pc:
173
+ k = st.slider("# Shapes to Retrieve", 1, 100, 16, key='rpc')
174
+ load_data = misc_utils.input_3d_shape('retpc')
175
+ if st.button("Run with Shape"):
176
+ pc = load_data(prog)
177
+ col2 = misc_utils.render_pc(pc)
178
+ prog.progress(0.49, "Computing Embeddings")
179
+ ref_dev = next(model_g14.parameters()).device
180
+ enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev)).cpu()
181
+ prog.progress(0.7, "Running Retrieval")
182
+ retrieval_results(retrieval.retrieve(enc, k))
183
+ prog.progress(1.0, "Idle")
184
+
185
+
186
+ try:
187
+ if torch.cuda.is_available():
188
+ clip_model.cuda()
189
+ with tab_cls:
190
+ demo_classification()
191
+ with tab_cap:
192
+ demo_captioning()
193
+ with tab_sd:
194
+ demo_pc2img()
195
+ demo_retrieval()
196
+ except Exception:
197
+ import traceback
198
+ st.error(traceback.format_exc().replace("\n", " \n"))