yuanze1024 commited on
Commit
04ea559
1 Parent(s): 55e7aed

update app

Browse files
Files changed (2) hide show
  1. Dockerfile +1 -1
  2. app.py +31 -21
Dockerfile CHANGED
@@ -22,7 +22,7 @@ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
22
  # note that you may need to modify the TORCH_CUDA_ARCH_LIST in the setup.py file
23
  ENV TORCH_CUDA_ARCH_LIST="8.6"
24
 
25
- # Install Pointnet2_PyTorch
26
  RUN git clone https://github.com/yuanze1024/Pointnet2_PyTorch.git && cd Pointnet2_PyTorch/pointnet2_ops_lib && pip install .
27
 
28
  COPY --chown=user:user . /code
 
22
  # note that you may need to modify the TORCH_CUDA_ARCH_LIST in the setup.py file
23
  ENV TORCH_CUDA_ARCH_LIST="8.6"
24
 
25
+ # Install Pointnet2_PyTorch, pip install git+ won't work for unknown reason
26
  RUN git clone https://github.com/yuanze1024/Pointnet2_PyTorch.git && cd Pointnet2_PyTorch/pointnet2_ops_lib && pip install .
27
 
28
  COPY --chown=user:user . /code
app.py CHANGED
@@ -89,13 +89,12 @@ def retrieve_3D_models(textual_query, top_k, modality_list):
89
  indices = _retrieve_3D_models(textual_query, top_k, modality_list)
90
  return [get_image_and_id(index) for index in indices]
91
 
92
- def get_sub_dataset(sub_dataset_id):
93
  """
94
  get sub-dataset by sub_dataset_id [1, 1000]
95
 
96
  Returns:
97
  caption: str
98
- difficulty: str
99
  images: list of tuple (PIL.Image, str)
100
  """
101
  rel = relation[sub_dataset_id - 1]
@@ -111,18 +110,23 @@ def get_sub_dataset(sub_dataset_id):
111
  return new_image
112
 
113
  results = []
114
- for gt_id in GT_ids:
115
- image, source_id = get_image_and_id(source_to_id[gt_id])
116
- results.append((handle_image(image, True), source_id))
117
- for neg_id in negative_ids:
118
- image, source_id = get_image_and_id(source_to_id[neg_id])
119
- results.append((handle_image(image, False), source_id))
 
 
 
 
 
120
 
121
- return caption, difficulty, results
122
 
123
- def feel_lucky():
124
  sub_dataset_id = random.randint(1, 1000)
125
- return sub_dataset_id, *get_sub_dataset(sub_dataset_id)
126
 
127
  def launch():
128
  with gr.Blocks() as demo: # https://sketchfab.com/3d-models/fd30f87848c9454c9225eccc39726787
@@ -131,14 +135,17 @@ def launch():
131
  with gr.Tab("Retrieval Visualization"):
132
  with gr.Row():
133
  md2 = gr.Markdown(r"""### Visualization for Text-Based-3D Model Retrieval
134
- We build a visualization demo to demonstrate the text-based-3D model retrievals. Due to the memory limitation of HF Space, we only support the [Uni3D](https://github.com/baaivision/Uni3D) which has shown an excellent performance in our benchmark.
 
 
135
 
136
  **Note**:
137
 
138
- The *Modality List* refers to the features ensembled by the retrieval methods. According to our experiment results, basically the more modalities, the better performance the methods gets.""")
 
 
139
  with gr.Row():
140
- textual_query = gr.Textbox(label="Textual Query", autofocus=True,
141
- placeholder="A chair with a wooden frame and a cushioned seat")
142
  modality_list = gr.CheckboxGroup(label="Modality List", value=[],
143
  choices=["text", "front", "back", "left", "right", "above",
144
  "below", "diag_above", "diag_below", "3D"])
@@ -173,19 +180,22 @@ Here is a visualization of the dataset.
173
 
174
  **Note:**
175
 
176
- The *Query* is used in this sub-dataset. The *Difficulty* is a coarse label for the textual query, which is divided into **easy**, **medium**, and **hard**, basically submit to the rule in our paper.
177
- The color surrounding the 3D model indicates whether it is a good fit for the textual query. A **<span style="color:#00FF00">green</span>** color suggests a Ground Truth, while a **<span style="color:#FF0000">red</span>** color indicates a mismatch.""")
 
178
  with gr.Row():
179
  lucky = gr.Button("I'm Feeling Lucky !", scale=1, variant='primary')
180
- query_id = gr.Number(label="Sub-dataset ID", scale=1, minimum=1, maximum=1000, step=1, interactive=True)
 
181
  query = gr.Textbox(label="Textual Query", scale=3, interactive=False)
182
- difficulty = gr.Textbox(label="Query Difficulty", scale=1, interactive=False)
183
  # model3d = gr.Model3D(interactive=False, scale=1)
184
  with gr.Row():
185
  output2 = gr.Gallery(format="webp", label="3D Models in Sub-dataset", columns=5, type="pil", interactive=False)
186
 
187
- lucky.click(feel_lucky, outputs=[query_id, query, difficulty, output2])
188
- query_id.submit(get_sub_dataset, query_id, [query, difficulty, output2])
 
189
 
190
  demo.queue(max_size=10)
191
  demo.launch(server_name='0.0.0.0')
 
89
  indices = _retrieve_3D_models(textual_query, top_k, modality_list)
90
  return [get_image_and_id(index) for index in indices]
91
 
92
+ def get_sub_dataset(sub_dataset_id, sorted=False):
93
  """
94
  get sub-dataset by sub_dataset_id [1, 1000]
95
 
96
  Returns:
97
  caption: str
 
98
  images: list of tuple (PIL.Image, str)
99
  """
100
  rel = relation[sub_dataset_id - 1]
 
110
  return new_image
111
 
112
  results = []
113
+ if not sorted:
114
+ for ind in target_ids:
115
+ image, source_id = get_image_and_id(source_to_id[ind])
116
+ results.append((handle_image(image, True if ind in GT_ids else False), source_id))
117
+ else:
118
+ for gt_id in GT_ids:
119
+ image, source_id = get_image_and_id(source_to_id[gt_id])
120
+ results.append((handle_image(image, True), source_id))
121
+ for neg_id in negative_ids:
122
+ image, source_id = get_image_and_id(source_to_id[neg_id])
123
+ results.append((handle_image(image, False), source_id))
124
 
125
+ return caption, results
126
 
127
+ def feel_lucky(is_sorted):
128
  sub_dataset_id = random.randint(1, 1000)
129
+ return sub_dataset_id, *get_sub_dataset(sub_dataset_id, is_sorted)
130
 
131
  def launch():
132
  with gr.Blocks() as demo: # https://sketchfab.com/3d-models/fd30f87848c9454c9225eccc39726787
 
135
  with gr.Tab("Retrieval Visualization"):
136
  with gr.Row():
137
  md2 = gr.Markdown(r"""### Visualization for Text-Based-3D Model Retrieval
138
+ We build a visualization demo to demonstrate the text-based-3D model retrievals. Due to the memory limitation of HF Space,
139
+ we only support the [Uni3D](https://github.com/baaivision/Uni3D) which has shown an excellent performance in our benchmark.
140
+ What's more, **we only search in a subset of Objaverse, which contains 89K 3D models**.
141
 
142
  **Note**:
143
 
144
+ The *Modality List* refers to the features ensembled by the retrieval methods. According to our experiment results, basically the more modalities, the better performance the methods gets.
145
+
146
+ Also, you may want to ckeck the 3D model in a 3D model viewer, in that case, you can visit [Objaverse](https://objaverse.allenai.org/explore) for exploration.""")
147
  with gr.Row():
148
+ textual_query = gr.Textbox(label="Textual Query", autofocus=True, value="Super Mario")
 
149
  modality_list = gr.CheckboxGroup(label="Modality List", value=[],
150
  choices=["text", "front", "back", "left", "right", "above",
151
  "below", "diag_above", "diag_below", "3D"])
 
180
 
181
  **Note:**
182
 
183
+ The *Query* is used in this sub-dataset. The *Sorted* will put the Ground Truths in the front of the results.
184
+ The color surrounding the 3D model indicates whether it is a good fit for the textual query.
185
+ A **<span style="color:#00FF00">green</span>** color suggests a Ground Truth, while a **<span style="color:#FF0000">red</span>** color indicates a mismatch.""")
186
  with gr.Row():
187
  lucky = gr.Button("I'm Feeling Lucky !", scale=1, variant='primary')
188
+ query_id = gr.Number(label="Sub-dataset ID", scale=1, minimum=1, maximum=1000, step=1, interactive=True, value=986)
189
+ is_sorted = gr.Checkbox(value=False, label="", scale=1, info="Sorted")
190
  query = gr.Textbox(label="Textual Query", scale=3, interactive=False)
191
+ # difficulty = gr.Textbox(label="Query Difficulty", scale=1, interactive=False)
192
  # model3d = gr.Model3D(interactive=False, scale=1)
193
  with gr.Row():
194
  output2 = gr.Gallery(format="webp", label="3D Models in Sub-dataset", columns=5, type="pil", interactive=False)
195
 
196
+ lucky.click(feel_lucky, inputs=is_sorted, outputs=[query_id, query, output2])
197
+ query_id.submit(get_sub_dataset, [query_id, is_sorted], [query, output2])
198
+ is_sorted.change(get_sub_dataset, [query_id, is_sorted], [query, output2])
199
 
200
  demo.queue(max_size=10)
201
  demo.launch(server_name='0.0.0.0')