mpsk commited on
Commit
2cff12a
·
1 Parent(s): 511c1dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +322 -241
app.py CHANGED
@@ -8,18 +8,18 @@ import torch
8
  import logging
9
  from os import environ
10
  from transformers import OwlViTProcessor, OwlViTForObjectDetection
11
-
12
  from myscaledb import Client
13
  from classifier import Classifier, prompt2vec, tune, SplitLayer
14
  from query_model import simple_query, topk_obj_query, rev_query
15
  from card_model import card, obj_card, style
16
  from box_utils import postprocess
17
 
18
- environ['TOKENIZERS_PARALLELISM'] = 'true'
19
 
20
  OBJ_DB_NAME = "mqdb_demo.coco_owl_vit_b_32_objects"
21
  IMG_DB_NAME = "mqdb_demo.coco_owl_vit_b_32_images"
22
- MODEL_ID = 'google/owlvit-base-patch32'
23
  DIMS = 512
24
 
25
  qtime = 0
@@ -34,9 +34,9 @@ def build_model(name="google/owlvit-base-patch32"):
34
  Returns:
35
  (model, processor): OwlViT model and its processor for both image and text
36
  """
37
- device = 'cpu'
38
  if torch.cuda.is_available():
39
- device = 'cuda'
40
  model = OwlViTForObjectDetection.from_pretrained(name).to(device)
41
  processor = OwlViTProcessor.from_pretrained(name)
42
  return model, processor
@@ -44,7 +44,7 @@ def build_model(name="google/owlvit-base-patch32"):
44
 
45
  @st.experimental_singleton(show_spinner=False)
46
  def init_owlvit():
47
- """ Initialize OwlViT Model
48
 
49
  Returns:
50
  model, processor
@@ -55,7 +55,7 @@ def init_owlvit():
55
 
56
  @st.experimental_singleton(show_spinner=False)
57
  def init_db():
58
- """ Initialize the Database Connection
59
 
60
  Returns:
61
  meta_field: Meta field that records if an image is viewed or not
@@ -63,15 +63,15 @@ def init_db():
63
  """
64
  meta = []
65
  client = Client(
66
- url=st.secrets["DB_URL"], user=st.secrets["USER"], password=st.secrets["PASSWD"])
 
67
  # We can check if the connection is alive
68
  assert client.is_alive()
69
  return meta, client
70
 
71
 
72
  def refresh_index():
73
- """ Clean the session
74
- """
75
  del st.session_state["meta"]
76
  st.session_state.meta = []
77
  st.session_state.query_num = 0
@@ -80,16 +80,16 @@ def refresh_index():
80
  init_db.clear()
81
  # refresh session states
82
  st.session_state.meta, st.session_state.index = init_db()
83
- if 'clf' in st.session_state:
84
  del st.session_state.clf
85
- if 'xq' in st.session_state:
86
  del st.session_state.xq
87
- if 'topk_img_id' in st.session_state:
88
  del st.session_state.topk_img_id
89
 
90
 
91
  def query(xq, exclude_list=None):
92
- """ Query matched w.r.t a given vector
93
 
94
  In this part, we will retrieve A LOT OF data from the server,
95
  including TopK boxes and their embeddings, the counterpart of non-TopK boxes in TopK images.
@@ -98,7 +98,7 @@ def query(xq, exclude_list=None):
98
  xq (numpy.ndarray or list of floats): Query vector
99
 
100
  Returns:
101
- matches: list of Records object. Keys referrring to selected columns group by images.
102
  Exclude the user's viewlist.
103
  img_matches: list of Records object. Containing other non-TopK but hit objects among TopK images.
104
  side_matches: list of Records object. Containing REAL TopK objects disregard the user's view history
@@ -112,27 +112,47 @@ def query(xq, exclude_list=None):
112
  while attempt < 3:
113
  try:
114
  matches = topk_obj_query(
115
- st.session_state.index, xq, IMG_DB_NAME, OBJ_DB_NAME,
116
- exclude_list=exclude_list, topk=5000)
117
- img_ids = [r['img_id'] for r in matches]
118
- if 'topk_img_id' not in st.session_state:
 
 
 
 
 
119
  st.session_state.topk_img_id = img_ids
120
  status_bar[0].write("Retrieving TopK Images...")
121
  pbar.progress(25)
122
  o_matches = rev_query(
123
- st.session_state.index, xq, st.session_state.topk_img_id,
124
- IMG_DB_NAME, OBJ_DB_NAME, thresh=0.1)
 
 
 
 
 
125
  status_bar[0].write("Retrieving TopKs Objects...")
126
  pbar.progress(50)
127
- side_matches = simple_query(st.session_state.index, xq, IMG_DB_NAME, OBJ_DB_NAME,
128
- thresh=-1, topk=5000)
129
- status_bar[0].write(
130
- "Retrieving Non-TopK in Another TopK Images...")
 
 
 
 
 
131
  pbar.progress(75)
132
  if len(img_ids) > 0:
133
  img_matches = rev_query(
134
- st.session_state.index, xq, img_ids, IMG_DB_NAME, OBJ_DB_NAME,
135
- thresh=0.1)
 
 
 
 
 
136
  else:
137
  img_matches = []
138
  status_bar[0].write("DONE!")
@@ -163,22 +183,31 @@ def init_random_query():
163
 
164
 
165
  def submit(meta):
166
- """ Tune the model w.r.t given score from user.
167
- """
168
  # Only updating the meta if the train button is pressed
169
  st.session_state.meta.extend(meta)
170
  st.session_state.step += 1
171
  matches = st.session_state.matched_boxes
172
- X, y = list(zip(*((v[-1],
173
- st.session_state.text_prompts.index(
174
- st.session_state[f"label-{i}"])) for i, v in matches.items())))
175
- st.session_state.xq = tune(st.session_state.clf,
176
- X, y, iters=int(st.session_state.iters))
177
- st.session_state.matches, \
178
- st.session_state.img_matches, \
179
- st.session_state.side_matches, \
180
- st.session_state.o_matches = query(
181
- st.session_state.xq, st.session_state.meta)
 
 
 
 
 
 
 
 
 
 
182
 
183
 
184
  # st.set_page_config(layout="wide")
@@ -186,210 +215,262 @@ def submit(meta):
186
  # Boxes are drawn in SVGs.
187
  st.write(style(), unsafe_allow_html=True)
188
 
189
- with st.spinner("Connecting DB..."):
190
- st.session_state.meta, st.session_state.index = init_db()
191
-
192
- with st.spinner("Loading Models..."):
193
- # Initialize model
194
- model, tokenizer = init_owlvit()
195
-
196
- # If its a fresh start... (query not set)
197
- if 'xq' not in st.session_state:
198
- with st.container():
199
- st.title('Object Detection Safari')
200
- start = [st.empty() for _ in range(8)]
201
- start[0].info("""
202
- We extracted boxes from **287,104** images in COCO Dataset, including its train / val / test /
203
- unlabeled images, collecting **165,371,904 boxes** which are then filtered with common prompts.
204
- You can search with almost any words or phrases you can think of. Please enjoy your journey of
205
- an adventure to COCO.
206
- """)
207
- prompt = start[1].text_input(
208
- "Prompt:", value="", placeholder="Examples: football, billboard, stop sign, watermark ...",)
209
- with start[2].container():
210
- st.write(
211
- 'You can search with multiple keywords. Plese separate with commas but with no space.')
212
- st.write('For example: `cat,dog,tree`')
213
- st.markdown('''
214
- <p style="color:gray;"> Don\'t know what to search? Try <b>Random</b>!</p>
215
- ''',
216
- unsafe_allow_html=True)
217
-
218
- upld_model = start[4].file_uploader(
219
- "Or you can upload your previous run!", type='onnx')
220
- upld_btn = start[5].button(
221
- "Use Loaded Weights", disabled=upld_model is None, on_click=refresh_index)
222
-
223
- with start[3]:
224
- col = st.columns(8)
225
- has_no_prompt = (len(prompt) == 0 and upld_model is None)
226
- prompt_xq = col[6].button("Prompt", disabled=len(
227
- prompt) == 0, on_click=refresh_index)
228
- random_xq = col[7].button(
229
- "Random", disabled=not has_no_prompt, on_click=refresh_index)
230
- matches = []
231
- img_matches = []
232
- if random_xq:
233
- xq = init_random_query()
234
- st.session_state.xq = xq
235
- prompt = 'unknown'
236
- st.session_state.text_prompts = prompt.split(',') + ['none']
237
- _ = [elem.empty() for elem in start]
238
- t0 = time()
239
- st.session_state.matches, \
240
- st.session_state.img_matches, \
241
- st.session_state.side_matches, \
242
- st.session_state.o_matches = query(
243
- st.session_state.xq, st.session_state.meta)
244
- t1 = time()
245
- qtime = (t1-t0) * 1000
246
- elif prompt_xq or upld_btn:
247
- if upld_model is not None:
248
- import onnx
249
- from onnx import numpy_helper
250
- _model = onnx.load(upld_model)
251
- st.session_state.text_prompts = [
252
- node.name for node in _model.graph.output] + ['none']
253
- weights = _model.graph.initializer
254
- xq = numpy_helper.to_array(weights[0]).T
255
- assert xq.shape[0] == len(
256
- st.session_state.text_prompts)-1 and xq.shape[1] == DIMS
257
- st.session_state.xq = xq
258
- _ = [elem.empty() for elem in start]
259
- else:
260
- logging.info(f"Input prompt is {prompt}")
261
- st.session_state.text_prompts = prompt.split(',') + ['none']
262
- input_ids, xq = prompt2vec(
263
- st.session_state.text_prompts[:-1], model, tokenizer)
264
  st.session_state.xq = xq
 
 
265
  _ = [elem.empty() for elem in start]
266
- t0 = time()
267
- st.session_state.matches, \
268
- st.session_state.img_matches, \
269
- st.session_state.side_matches, \
270
- st.session_state.o_matches = query(
271
- st.session_state.xq, st.session_state.meta)
272
- t1 = time()
273
- qtime = (t1-t0) * 1000
274
-
275
- # If its not a fresh start (query is set)
276
- if 'xq' in st.session_state:
277
- o_matches = st.session_state.o_matches
278
- side_matches = st.session_state.side_matches
279
- img_matches = st.session_state.img_matches
280
- matches = st.session_state.matches
281
- # initialize classifier
282
- if 'clf' not in st.session_state:
283
- st.session_state.clf = Classifier(st.session_state.xq)
284
- st.session_state.step = 0
285
- if qtime > 0:
286
- st.info("Query done in {0:.2f} ms and returned {1:d} images with {2:d} boxes".format(
287
- qtime, len(matches), sum([len(m["box_id"]) + len(im["box_id"]) for m, im in zip(matches, img_matches)])))
288
-
289
- # export the model into executable ONNX
290
- st.session_state.dnld_model = BytesIO()
291
- torch.onnx.export(torch.nn.Sequential(st.session_state.clf.model, SplitLayer()),
292
- torch.zeros([1, len(st.session_state.xq[0])]),
293
- st.session_state.dnld_model,
294
- input_names=['input'],
295
- output_names=st.session_state.text_prompts[:-1])
296
-
297
- dnld_nam = st.text_input('Download Name:',
298
- f'{("_".join([i.replace(" ", "-") for i in st.session_state.text_prompts[:-1]]) if "text_prompts" in st.session_state else "model")}.onnx',
299
- max_chars=50)
300
- dnld_btn = st.download_button('Download your classifier!',
301
- st.session_state.dnld_model,
302
- dnld_nam)
303
- # build up a sidebar to display REAL TopK in DB
304
- # this will change during user's finetune. But sometime it would lead to bad results
305
- side_bar_len = min(240 // len(st.session_state.text_prompts), 120)
306
- with st.sidebar:
307
- with st.expander("Top-K Images"):
308
- with st.container():
309
- boxes_w_img, _ = postprocess(o_matches, st.session_state.text_prompts,
310
- o_matches)
311
- boxes_w_img = sorted(
312
- boxes_w_img, key=lambda x: x[4], reverse=True)
313
- for img_id, img_url, img_w, img_h, img_score, boxes in boxes_w_img:
314
- args = img_url, img_w, img_h, boxes
315
- st.write(card(*args), unsafe_allow_html=True)
316
-
317
- with st.expander("Top-K Objects", expanded=True):
318
- side_cols = st.columns(
319
- len(st.session_state.text_prompts[:-1]))
320
- for _cols, m in zip(side_cols, side_matches):
321
- with _cols.container():
322
- for cx, cy, w, h, logit, img_url, img_w, img_h \
323
- in zip(m['cx'], m['cy'], m['w'], m['h'], m['logit'],
324
- m['img_url'], m['img_w'], m['img_h']):
325
- st.write("{:s}: {:.4f}".format(
326
- st.session_state.text_prompts[m['label']], logit))
327
- _html = obj_card(
328
- img_url, img_w, img_h, cx, cy, w, h, dst_len=side_bar_len)
329
- components.html(
330
- _html, side_bar_len, side_bar_len)
331
- with st.container():
332
- # Here let the user interact with batch labeling
333
- with st.form("batch", clear_on_submit=False):
334
- col = st.columns([1, 9])
335
-
336
- # If there is nothing to show about
337
- if len(matches) <= 0:
338
- st.warning(
339
- 'Oops! We didn\'t find anything relevant to your query! Pleas try another one :/')
340
- else:
341
- st.session_state.iters = st.slider(
342
- "Number of Iterations to Update", min_value=0, max_value=10, step=1, value=2)
343
- # No matter what happened the user wants a way back
344
- col[1].form_submit_button(
345
- "Choose a new prompt", on_click=refresh_index)
346
-
347
- # If there are things to show
348
- if len(matches) > 0:
 
 
 
 
 
 
 
 
349
  with st.container():
350
- prompt_labels = st.session_state.text_prompts
351
-
352
- # Post processing boxes regarding to their score, intersection
353
- boxes_w_img, meta = postprocess(matches, st.session_state.text_prompts,
354
- img_matches)
355
-
356
- # Sort the result according to their relavancy
357
- boxes_w_img = sorted(
358
- boxes_w_img, key=lambda x: x[4], reverse=True)
359
-
360
- st.session_state.matched_boxes = {}
361
- # For each images in the retrieved images, DISPLAY
362
  for img_id, img_url, img_w, img_h, img_score, boxes in boxes_w_img:
363
-
364
- # prepare inputs for training
365
- st.session_state.matched_boxes.update(
366
- {b[0]: b for b in boxes})
367
  args = img_url, img_w, img_h, boxes
368
-
369
- # display boxes
370
- with st.expander("{:s}: {:.4f}".format(img_id, img_score), expanded=True):
371
- ind_b = 0
372
- # 4 columns: (img, obj, obj, obj)
373
- img_row = st.columns([4, 2, 2, 2])
374
- img_row[0].write(
375
- card(*args), unsafe_allow_html=True)
376
- # crop objects out of the original image
377
- for b in boxes:
378
- _id, cx, cy, w, h, label, logit, is_selected, _ = b
379
- with img_row[1 + ind_b % 3].container():
380
- st.write(
381
- "{:s}: {:.4f}".format(label, logit))
382
- # quite hacky: with streamlit components API
383
- _html = \
384
- obj_card(img_url, img_w, img_h,
385
- *b[1:5], dst_len=120)
386
- components.html(_html, 120, 120)
387
- # the user will choose the right label of the given object
388
- st.selectbox(
389
- "Class",
390
- prompt_labels,
391
- index=prompt_labels.index(label),
392
- key=f"label-{_id}")
393
- ind_b += 1
394
- col[0].form_submit_button(
395
- "Train!", on_click=lambda: submit(meta))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  import logging
9
  from os import environ
10
  from transformers import OwlViTProcessor, OwlViTForObjectDetection
11
+ from bot import Bot, Message
12
  from myscaledb import Client
13
  from classifier import Classifier, prompt2vec, tune, SplitLayer
14
  from query_model import simple_query, topk_obj_query, rev_query
15
  from card_model import card, obj_card, style
16
  from box_utils import postprocess
17
 
18
+ environ["TOKENIZERS_PARALLELISM"] = "true"
19
 
20
  OBJ_DB_NAME = "mqdb_demo.coco_owl_vit_b_32_objects"
21
  IMG_DB_NAME = "mqdb_demo.coco_owl_vit_b_32_images"
22
+ MODEL_ID = "google/owlvit-base-patch32"
23
  DIMS = 512
24
 
25
  qtime = 0
 
34
  Returns:
35
  (model, processor): OwlViT model and its processor for both image and text
36
  """
37
+ device = "cpu"
38
  if torch.cuda.is_available():
39
+ device = "cuda"
40
  model = OwlViTForObjectDetection.from_pretrained(name).to(device)
41
  processor = OwlViTProcessor.from_pretrained(name)
42
  return model, processor
 
44
 
45
  @st.experimental_singleton(show_spinner=False)
46
  def init_owlvit():
47
+ """Initialize OwlViT Model
48
 
49
  Returns:
50
  model, processor
 
55
 
56
  @st.experimental_singleton(show_spinner=False)
57
  def init_db():
58
+ """Initialize the Database Connection
59
 
60
  Returns:
61
  meta_field: Meta field that records if an image is viewed or not
 
63
  """
64
  meta = []
65
  client = Client(
66
+ url=st.secrets["DB_URL"], user=st.secrets["USER"], password=st.secrets["PASSWD"]
67
+ )
68
  # We can check if the connection is alive
69
  assert client.is_alive()
70
  return meta, client
71
 
72
 
73
  def refresh_index():
74
+ """Clean the session"""
 
75
  del st.session_state["meta"]
76
  st.session_state.meta = []
77
  st.session_state.query_num = 0
 
80
  init_db.clear()
81
  # refresh session states
82
  st.session_state.meta, st.session_state.index = init_db()
83
+ if "clf" in st.session_state:
84
  del st.session_state.clf
85
+ if "xq" in st.session_state:
86
  del st.session_state.xq
87
+ if "topk_img_id" in st.session_state:
88
  del st.session_state.topk_img_id
89
 
90
 
91
  def query(xq, exclude_list=None):
92
+ """Query matched w.r.t a given vector
93
 
94
  In this part, we will retrieve A LOT OF data from the server,
95
  including TopK boxes and their embeddings, the counterpart of non-TopK boxes in TopK images.
 
98
  xq (numpy.ndarray or list of floats): Query vector
99
 
100
  Returns:
101
+ matches: list of Records object. Keys referrring to selected columns group by images.
102
  Exclude the user's viewlist.
103
  img_matches: list of Records object. Containing other non-TopK but hit objects among TopK images.
104
  side_matches: list of Records object. Containing REAL TopK objects disregard the user's view history
 
112
  while attempt < 3:
113
  try:
114
  matches = topk_obj_query(
115
+ st.session_state.index,
116
+ xq,
117
+ IMG_DB_NAME,
118
+ OBJ_DB_NAME,
119
+ exclude_list=exclude_list,
120
+ topk=5000,
121
+ )
122
+ img_ids = [r["img_id"] for r in matches]
123
+ if "topk_img_id" not in st.session_state:
124
  st.session_state.topk_img_id = img_ids
125
  status_bar[0].write("Retrieving TopK Images...")
126
  pbar.progress(25)
127
  o_matches = rev_query(
128
+ st.session_state.index,
129
+ xq,
130
+ st.session_state.topk_img_id,
131
+ IMG_DB_NAME,
132
+ OBJ_DB_NAME,
133
+ thresh=0.1,
134
+ )
135
  status_bar[0].write("Retrieving TopKs Objects...")
136
  pbar.progress(50)
137
+ side_matches = simple_query(
138
+ st.session_state.index,
139
+ xq,
140
+ IMG_DB_NAME,
141
+ OBJ_DB_NAME,
142
+ thresh=-1,
143
+ topk=5000,
144
+ )
145
+ status_bar[0].write("Retrieving Non-TopK in Another TopK Images...")
146
  pbar.progress(75)
147
  if len(img_ids) > 0:
148
  img_matches = rev_query(
149
+ st.session_state.index,
150
+ xq,
151
+ img_ids,
152
+ IMG_DB_NAME,
153
+ OBJ_DB_NAME,
154
+ thresh=0.1,
155
+ )
156
  else:
157
  img_matches = []
158
  status_bar[0].write("DONE!")
 
183
 
184
 
185
  def submit(meta):
186
+ """Tune the model w.r.t given score from user."""
 
187
  # Only updating the meta if the train button is pressed
188
  st.session_state.meta.extend(meta)
189
  st.session_state.step += 1
190
  matches = st.session_state.matched_boxes
191
+ X, y = list(
192
+ zip(
193
+ *(
194
+ (
195
+ v[-1],
196
+ st.session_state.text_prompts.index(st.session_state[f"label-{i}"]),
197
+ )
198
+ for i, v in matches.items()
199
+ )
200
+ )
201
+ )
202
+ st.session_state.xq = tune(
203
+ st.session_state.clf, X, y, iters=int(st.session_state.iters)
204
+ )
205
+ (
206
+ st.session_state.matches,
207
+ st.session_state.img_matches,
208
+ st.session_state.side_matches,
209
+ st.session_state.o_matches,
210
+ ) = query(st.session_state.xq, st.session_state.meta)
211
 
212
 
213
  # st.set_page_config(layout="wide")
 
215
  # Boxes are drawn in SVGs.
216
  st.write(style(), unsafe_allow_html=True)
217
 
218
+ bot = Bot(app_name="HF OwlViT", enabled=True, bot_key=st.secrets['BOT_KEY'])
219
+ try:
220
+ with st.spinner("Connecting DB..."):
221
+ st.session_state.meta, st.session_state.index = init_db()
222
+
223
+ with st.spinner("Loading Models..."):
224
+ # Initialize model
225
+ model, tokenizer = init_owlvit()
226
+ # If its a fresh start... (query not set)
227
+ if "xq" not in st.session_state:
228
+ with st.container():
229
+ st.title("Object Detection Safari")
230
+ start = [st.empty() for _ in range(8)]
231
+ start[0].info(
232
+ """
233
+ We extracted boxes from **287,104** images in COCO Dataset, including its train / val / test /
234
+ unlabeled images, collecting **165,371,904 boxes** which are then filtered with common prompts.
235
+ You can search with almost any words or phrases you can think of. Please enjoy your journey of
236
+ an adventure to COCO.
237
+ """
238
+ )
239
+ prompt = start[1].text_input(
240
+ "Prompt:",
241
+ value="",
242
+ placeholder="Examples: football, billboard, stop sign, watermark ...",
243
+ )
244
+ with start[2].container():
245
+ st.write(
246
+ "You can search with multiple keywords. Plese separate with commas but with no space."
247
+ )
248
+ st.write("For example: `cat,dog,tree`")
249
+ st.markdown(
250
+ """
251
+ <p style="color:gray;"> Don\'t know what to search? Try <b>Random</b>!</p>
252
+ """,
253
+ unsafe_allow_html=True,
254
+ )
255
+
256
+ upld_model = start[4].file_uploader(
257
+ "Or you can upload your previous run!", type="onnx"
258
+ )
259
+ upld_btn = start[5].button(
260
+ "Use Loaded Weights", disabled=upld_model is None, on_click=refresh_index
261
+ )
262
+
263
+ with start[3]:
264
+ col = st.columns(8)
265
+ has_no_prompt = len(prompt) == 0 and upld_model is None
266
+ prompt_xq = col[6].button(
267
+ "Prompt", disabled=len(prompt) == 0, on_click=refresh_index
268
+ )
269
+ random_xq = col[7].button(
270
+ "Random", disabled=not has_no_prompt, on_click=refresh_index
271
+ )
272
+ matches = []
273
+ img_matches = []
274
+ if random_xq:
275
+ xq = init_random_query()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  st.session_state.xq = xq
277
+ prompt = "unknown"
278
+ st.session_state.text_prompts = prompt.split(",") + ["none"]
279
  _ = [elem.empty() for elem in start]
280
+ t0 = time()
281
+ (
282
+ st.session_state.matches,
283
+ st.session_state.img_matches,
284
+ st.session_state.side_matches,
285
+ st.session_state.o_matches,
286
+ ) = query(st.session_state.xq, st.session_state.meta)
287
+ t1 = time()
288
+ qtime = (t1 - t0) * 1000
289
+ elif prompt_xq or upld_btn:
290
+ if upld_model is not None:
291
+ import onnx
292
+ from onnx import numpy_helper
293
+
294
+ _model = onnx.load(upld_model)
295
+ st.session_state.text_prompts = [
296
+ node.name for node in _model.graph.output
297
+ ] + ["none"]
298
+ weights = _model.graph.initializer
299
+ xq = numpy_helper.to_array(weights[0]).T
300
+ assert (
301
+ xq.shape[0] == len(st.session_state.text_prompts) - 1
302
+ and xq.shape[1] == DIMS
303
+ )
304
+ st.session_state.xq = xq
305
+ _ = [elem.empty() for elem in start]
306
+ else:
307
+ logging.info(f"Input prompt is {prompt}")
308
+ st.session_state.text_prompts = prompt.split(",") + ["none"]
309
+ input_ids, xq = prompt2vec(
310
+ st.session_state.text_prompts[:-1], model, tokenizer
311
+ )
312
+ st.session_state.xq = xq
313
+ _ = [elem.empty() for elem in start]
314
+ t0 = time()
315
+ (
316
+ st.session_state.matches,
317
+ st.session_state.img_matches,
318
+ st.session_state.side_matches,
319
+ st.session_state.o_matches,
320
+ ) = query(st.session_state.xq, st.session_state.meta)
321
+ t1 = time()
322
+ qtime = (t1 - t0) * 1000
323
+
324
+ # If its not a fresh start (query is set)
325
+ if "xq" in st.session_state:
326
+ o_matches = st.session_state.o_matches
327
+ side_matches = st.session_state.side_matches
328
+ img_matches = st.session_state.img_matches
329
+ matches = st.session_state.matches
330
+ # initialize classifier
331
+ if "clf" not in st.session_state:
332
+ st.session_state.clf = Classifier(st.session_state.xq)
333
+ st.session_state.step = 0
334
+ if qtime > 0:
335
+ st.info(
336
+ "Query done in {0:.2f} ms and returned {1:d} images with {2:d} boxes".format(
337
+ qtime,
338
+ len(matches),
339
+ sum(
340
+ [
341
+ len(m["box_id"]) + len(im["box_id"])
342
+ for m, im in zip(matches, img_matches)
343
+ ]
344
+ ),
345
+ )
346
+ )
347
+
348
+ # export the model into executable ONNX
349
+ st.session_state.dnld_model = BytesIO()
350
+ torch.onnx.export(
351
+ torch.nn.Sequential(st.session_state.clf.model, SplitLayer()),
352
+ torch.zeros([1, len(st.session_state.xq[0])]),
353
+ st.session_state.dnld_model,
354
+ input_names=["input"],
355
+ output_names=st.session_state.text_prompts[:-1],
356
+ )
357
+
358
+ dnld_nam = st.text_input(
359
+ "Download Name:",
360
+ f'{("_".join([i.replace(" ", "-") for i in st.session_state.text_prompts[:-1]]) if "text_prompts" in st.session_state else "model")}.onnx',
361
+ max_chars=50,
362
+ )
363
+ dnld_btn = st.download_button(
364
+ "Download your classifier!", st.session_state.dnld_model, dnld_nam
365
+ )
366
+ # build up a sidebar to display REAL TopK in DB
367
+ # this will change during user's finetune. But sometime it would lead to bad results
368
+ side_bar_len = min(240 // len(st.session_state.text_prompts), 120)
369
+ with st.sidebar:
370
+ with st.expander("Top-K Images"):
371
  with st.container():
372
+ boxes_w_img, _ = postprocess(
373
+ o_matches, st.session_state.text_prompts, None
374
+ )
375
+ boxes_w_img = sorted(boxes_w_img, key=lambda x: x[4], reverse=True)
 
 
 
 
 
 
 
 
376
  for img_id, img_url, img_w, img_h, img_score, boxes in boxes_w_img:
 
 
 
 
377
  args = img_url, img_w, img_h, boxes
378
+ st.write(card(*args), unsafe_allow_html=True)
379
+
380
+ with st.expander("Top-K Objects", expanded=True):
381
+ side_cols = st.columns(len(st.session_state.text_prompts[:-1]))
382
+ for _cols, m in zip(side_cols, side_matches):
383
+ with _cols.container():
384
+ for cx, cy, w, h, logit, img_url, img_w, img_h in zip(
385
+ m["cx"],
386
+ m["cy"],
387
+ m["w"],
388
+ m["h"],
389
+ m["logit"],
390
+ m["img_url"],
391
+ m["img_w"],
392
+ m["img_h"],
393
+ ):
394
+ st.write(
395
+ "{:s}: {:.4f}".format(
396
+ st.session_state.text_prompts[m["label"]], logit
397
+ )
398
+ )
399
+ _html = obj_card(
400
+ img_url, img_w, img_h, cx, cy, w, h, dst_len=side_bar_len
401
+ )
402
+ components.html(_html, side_bar_len, side_bar_len)
403
+ with st.container():
404
+ # Here let the user interact with batch labeling
405
+ with st.form("batch", clear_on_submit=False):
406
+ col = st.columns([1, 9])
407
+
408
+ # If there is nothing to show about
409
+ if len(matches) <= 0:
410
+ st.warning(
411
+ "Oops! We didn't find anything relevant to your query! Pleas try another one :/"
412
+ )
413
+ else:
414
+ st.session_state.iters = st.slider(
415
+ "Number of Iterations to Update",
416
+ min_value=0,
417
+ max_value=10,
418
+ step=1,
419
+ value=2,
420
+ )
421
+ # No matter what happened the user wants a way back
422
+ col[1].form_submit_button("Choose a new prompt", on_click=refresh_index)
423
+
424
+ # If there are things to show
425
+ if len(matches) > 0:
426
+ with st.container():
427
+ prompt_labels = st.session_state.text_prompts
428
+
429
+ # Post processing boxes regarding to their score, intersection
430
+ boxes_w_img, meta = postprocess(
431
+ matches, st.session_state.text_prompts, img_matches
432
+ )
433
+
434
+ # Sort the result according to their relavancy
435
+ boxes_w_img = sorted(boxes_w_img, key=lambda x: x[4], reverse=True)
436
+
437
+ st.session_state.matched_boxes = {}
438
+ # For each images in the retrieved images, DISPLAY
439
+ for img_id, img_url, img_w, img_h, img_score, boxes in boxes_w_img:
440
+
441
+ # prepare inputs for training
442
+ st.session_state.matched_boxes.update({b[0]: b for b in boxes})
443
+ args = img_url, img_w, img_h, boxes
444
+
445
+ # display boxes
446
+ with st.expander(
447
+ "{:s}: {:.4f}".format(img_id, img_score), expanded=True
448
+ ):
449
+ ind_b = 0
450
+ # 4 columns: (img, obj, obj, obj)
451
+ img_row = st.columns([4, 2, 2, 2])
452
+ img_row[0].write(card(*args), unsafe_allow_html=True)
453
+ # crop objects out of the original image
454
+ for b in boxes:
455
+ _id, cx, cy, w, h, label, logit, is_selected, _ = b
456
+ with img_row[1 + ind_b % 3].container():
457
+ st.write("{:s}: {:.4f}".format(label, logit))
458
+ # quite hacky: with streamlit components API
459
+ _html = obj_card(
460
+ img_url, img_w, img_h, *b[1:5], dst_len=120
461
+ )
462
+ components.html(_html, 120, 120)
463
+ # the user will choose the right label of the given object
464
+ st.selectbox(
465
+ "Class",
466
+ prompt_labels,
467
+ index=prompt_labels.index(label),
468
+ key=f"label-{_id}",
469
+ )
470
+ ind_b += 1
471
+ col[0].form_submit_button("Train!", on_click=lambda: submit(meta))
472
+ except Exception as e:
473
+ msg = Message()
474
+ msg.content = str(e.with_traceback(None))
475
+ msg.type_hint = str(type(e).__name__)
476
+ bot.incident(msg)