Fangrui Liu committed
Commit 98667f6
1 Parent(s): 1439326

update compute gradient with sql
Files changed:
- app.py +12 -6
- box_utils.py +61 -32
- card_model.py +2 -1
- classifier.py +46 -37
- query_model.py +2 -2
app.py
CHANGED
@@ -192,7 +192,7 @@ def submit(meta):
                 zip(
                     *(
                         (
-                            v[
+                            v[0],
                             st.session_state.text_prompts.index(st.session_state[f"label-{i}"]),
                         )
                         for i, v in matches.items()
@@ -329,7 +329,7 @@ try:
         matches = st.session_state.matches
         # initialize classifier
         if "clf" not in st.session_state:
-            st.session_state.clf = Classifier(st.session_state.xq)
+            st.session_state.clf = Classifier(st.session_state.index, OBJ_DB_NAME, st.session_state.xq)
             st.session_state.step = 0
         if qtime > 0:
             st.info(
@@ -344,11 +344,13 @@ try:
                     ),
                 )
             )
+            lnprob = torch.nn.Linear(st.session_state.xq.shape[1], st.session_state.xq.shape[0], bias=False)
+            lnprob.weight = torch.nn.Parameter(st.session_state.clf.weight)

            # export the model into executable ONNX
            st.session_state.dnld_model = BytesIO()
            torch.onnx.export(
-                torch.nn.Sequential(
+                torch.nn.Sequential(lnprob, SplitLayer()),
                torch.zeros([1, len(st.session_state.xq[0])]),
                st.session_state.dnld_model,
                input_names=["input"],
@@ -370,7 +372,9 @@ try:
             with st.expander("Top-K Images"):
                 with st.container():
                     boxes_w_img, _ = postprocess(
-                        o_matches, st.session_state.text_prompts,
+                        o_matches, st.session_state.text_prompts, o_matches,
+                        agnostic_ratio=1-0.6**(st.session_state.step+1),
+                        class_ratio=1-0.2**(st.session_state.step+1)
                     )
                     boxes_w_img = sorted(boxes_w_img, key=lambda x: x[4], reverse=True)
                     for img_id, img_url, img_w, img_h, img_score, boxes in boxes_w_img:
@@ -428,7 +432,9 @@ try:

            # Post processing boxes regarding to their score, intersection
            boxes_w_img, meta = postprocess(
-                matches, st.session_state.text_prompts, img_matches
+                matches, st.session_state.text_prompts, img_matches,
+                agnostic_ratio=1-0.6**(st.session_state.step+1),
+                class_ratio=1-0.2**(st.session_state.step+1)
            )

            # Sort the result according to their relavancy
@@ -452,7 +458,7 @@ try:
                     img_row[0].write(card(*args), unsafe_allow_html=True)
                     # crop objects out of the original image
                     for b in boxes:
-                        _id, cx, cy, w, h, label, logit, is_selected
+                        _id, cx, cy, w, h, label, logit, is_selected = b[:8]
                         with img_row[1 + ind_b % 3].container():
                             st.write("{:s}: {:.4f}".format(label, logit))
                             # quite hacky: with streamlit components API
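The interesting part of this hunk is the export path: the tuned query vectors become the weights of a bias-free linear layer, and SplitLayer (defined in classifier.py) splits the resulting scores into one output per prompt. A minimal standalone sketch of the same idea, where the shapes (2 prompts, 512-dim embeddings) are made-up assumptions, not values taken from this repo:

from io import BytesIO

import torch


class SplitLayer(torch.nn.Module):
    # same role as classifier.SplitLayer: one output tensor per class
    def forward(self, x):
        return torch.split(x, 1, dim=-1)


# hypothetical shapes: 2 text prompts, 512-dim embeddings
xq = torch.randn(2, 512)

# bias-free linear head whose rows are the (tuned) query vectors
lnprob = torch.nn.Linear(xq.shape[1], xq.shape[0], bias=False)
lnprob.weight = torch.nn.Parameter(xq)

buf = BytesIO()
torch.onnx.export(
    torch.nn.Sequential(lnprob, SplitLayer()),  # model to export
    torch.zeros([1, xq.shape[1]]),              # dummy input for tracing
    buf,                                        # export target (file-like)
    input_names=["input"],
)

Keeping the head bias-free means the exported graph computes plain inner products against the query vectors, matching the earlier note in classifier.py that the bias is ignored because only the inner product matters.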
box_utils.py
CHANGED
@@ -2,16 +2,14 @@ import numpy as np


 def cxywh2xywh(cx, cy, w, h):
-    """
-    """
+    """CxCyWH format to XYWH format conversion"""
     x = cx - w / 2
     y = cy - h / 2
     return x, y, w, h


 def cxywh2ltrb(cx, cy, w, h):
-    """CxCyWH format to LeftRightTopBottom format
-    """
+    """CxCyWH format to LeftRightTopBottom format"""
     l = cx - w / 2
     t = cy - h / 2
     r = cx + w / 2
@@ -61,9 +59,16 @@ def nms(cx, cy, w, h, s, iou_thresh=0.3):
         i = sort_ind[0]
         res.append(i)

-        _iou = iou(
-
-
+        _iou = iou(
+            (l[i], t[i], r[i], b[i], areas[i]),
+            (
+                l[sort_ind[1:]],
+                t[sort_ind[1:]],
+                r[sort_ind[1:]],
+                b[sort_ind[1:]],
+                areas[sort_ind[1:]],
+            ),
+        )
         sel_ind = np.where(_iou <= iou_thresh)[0]
         sort_ind = sort_ind[sel_ind + 1]
     return res
@@ -77,43 +82,64 @@ def filter_nonpos(boxes, agnostic_ratio=0.5, class_ratio=0.7):
     """
     ret = []
     labelwise = {}
-    for
+    for b in boxes:
+        _id, cx, cy, w, h, label, logit, is_selected = b[:8]
         if label not in labelwise:
             labelwise[label] = []
         labelwise[label].append(logit)
     labelwise = {l: max(s) for l, s in labelwise.items()}
     agnostic = max([v for _, v in labelwise.items()])
     for b in boxes:
-        _id, cx, cy, w, h, label, logit, is_selected
-        if logit > class_ratio * labelwise[label]
-        and logit > agnostic_ratio * agnostic:
+        _id, cx, cy, w, h, label, logit, is_selected = b[:8]
+        if logit > class_ratio * labelwise[label] and logit > agnostic_ratio * agnostic:
             ret.append(b)
     return ret


-def postprocess(matches, prompt_labels, img_matches=None):
+def postprocess(matches, prompt_labels, img_matches=None, agnostic_ratio=0.4, class_ratio=0.7):
     meta = []
     boxes_w_img = []
-    matches_ = {m[
+    matches_ = {m["img_id"]: m for m in matches}
     if img_matches is not None:
-        img_matches_ = {m[
+        img_matches_ = {m["img_id"]: m for m in img_matches}
     for k in matches_.keys():
         m = matches_[k]
         boxes = []
-        boxes += list(
-
-
-
-
-
+        boxes += list(
+            map(
+                list,
+                zip(
+                    m["box_id"],
+                    m["cx"],
+                    m["cy"],
+                    m["w"],
+                    m["h"],
+                    [prompt_labels[int(l)] for l in m["label"]],
+                    m["logit"],
+                    [1] * len(m["box_id"]),
+                ),
+            )
+        )
         if img_matches is not None and k in img_matches_:
             img_m = img_matches_[k]
             # and also those non-TopK hits and those non-topk are not anticipating training
-            boxes += [
-
-
-
-
+            boxes += [
+                i
+                for i in map(
+                    list,
+                    zip(
+                        img_m["box_id"],
+                        img_m["cx"],
+                        img_m["cy"],
+                        img_m["w"],
+                        img_m["h"],
+                        [prompt_labels[int(l)] for l in img_m["label"]],
+                        img_m["logit"],
+                        [0] * len(img_m["box_id"]),
+                    ),
+                )
+                if i[0] not in [b[0] for b in boxes]
+            ]
         else:
             img_m = None
         # update record metadata after query
@@ -121,16 +147,19 @@ def postprocess(matches, prompt_labels, img_matches=None):
             meta.append(b[0])

         # remove some non-significant boxes
-        boxes = filter_nonpos(
-            boxes, agnostic_ratio=0.4, class_ratio=0.7)
+        boxes = filter_nonpos(boxes, agnostic_ratio=agnostic_ratio, class_ratio=class_ratio)

         # doing non-maximum suppression
-        cx, cy, w, h, s = list(
-
+        cx, cy, w, h, s = list(
+            map(lambda x: np.array(x), list(zip(*[(*b[1:5], b[6]) for b in boxes])))
+        )
         ind = nms(cx, cy, w, h, s, 0.3)
         boxes = [boxes[i] for i in ind]
         if img_m is not None:
-            img_score =
+            img_score = (
+                img_m["img_score"] if img_matches is not None else m["img_score"]
+            )
         boxes_w_img.append(
-            (m["img_id"], m["img_url"], m["img_w"], m["img_h"], img_score, boxes)
-
+            (m["img_id"], m["img_url"], m["img_w"], m["img_h"], img_score, boxes)
+        )
+    return boxes_w_img, meta
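The new agnostic_ratio/class_ratio arguments let app.py tighten filter_nonpos as user feedback accumulates. A small self-contained sketch (box values are invented and the filter rule is re-stated here only for illustration) of how the annealed ratios from app.py, 1 - 0.6**(step+1) and 1 - 0.2**(step+1), progressively drop weaker boxes:

# a box is (_id, cx, cy, w, h, label, logit, is_selected); only label/logit matter here
boxes = [
    (0, 0.5, 0.5, 0.2, 0.2, "cat", 0.92, 1),
    (1, 0.3, 0.4, 0.1, 0.1, "cat", 0.40, 1),
    (2, 0.7, 0.6, 0.3, 0.3, "dog", 0.55, 0),
]

def keep_significant(boxes, agnostic_ratio, class_ratio):
    """Keep a box only if its logit is close to the best logit of its own class
    AND close to the best logit over all classes (the filter_nonpos rule)."""
    best_per_label = {}
    for _id, cx, cy, w, h, label, logit, is_selected in boxes:
        best_per_label[label] = max(best_per_label.get(label, logit), logit)
    best_overall = max(best_per_label.values())
    return [
        b for b in boxes
        if b[6] > class_ratio * best_per_label[b[5]]
        and b[6] > agnostic_ratio * best_overall
    ]

# app.py now anneals both ratios with the feedback step, so the cut gets stricter:
for step in range(3):
    agnostic_ratio = 1 - 0.6 ** (step + 1)   # 0.4, 0.64, 0.784, ...
    class_ratio = 1 - 0.2 ** (step + 1)      # 0.8, 0.96, 0.992, ...
    kept = keep_significant(boxes, agnostic_ratio, class_ratio)
    print(step, [b[0] for b in kept])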
card_model.py
CHANGED
@@ -47,7 +47,8 @@ def card(img_url, img_w, img_h, boxes):
     """
     _boxes = ""
     img_url = convert_img_url(img_url)
-    for
+    for b in boxes:
+        _id, cx, cy, w, h, label, logit, is_selected = b[:8]
         x, y, w, h = cxywh2xywh(cx, cy, w, h)
         x = round(img_w * x)
         y = round(img_h * y)
classifier.py
CHANGED
@@ -1,7 +1,7 @@
 import torch


-def extract_text_feature(prompt, model, processor, device=
+def extract_text_feature(prompt, model, processor, device="cpu"):
     """Extract text features

     Args:
@@ -10,12 +10,11 @@ def extract_text_feature(prompt, model, processor, device='cpu'):
         processor: OwlViT processor
         device (str, optional): device to run. Defaults to 'cpu'.
     """
-    device =
+    device = "cpu"
     if torch.cuda.is_available():
-        device =
+        device = "cuda"
     with torch.no_grad():
-        input_ids = torch.as_tensor(processor(text=prompt)[
-            'input_ids']).to(device)
+        input_ids = torch.as_tensor(processor(text=prompt)["input_ids"]).to(device)
         print(input_ids.device)
         text_outputs = model.owlvit.text_model(
             input_ids=input_ids,
@@ -32,7 +31,7 @@ def extract_text_feature(prompt, model, processor, device='cpu'):


 def prompt2vec(prompt: str, model, processor):
-    """
+    """Convert prompt into a computational vector

     Args:
         prompt (str): Text to be tokenized
@@ -49,7 +48,7 @@ def prompt2vec(prompt: str, model, processor):


 def tune(clf, X, y, iters=2):
-    """
+    """Train the Zero-shot Classifier

     Args:
         X (numpy.ndarray): Input vectors (retreived vectors)
@@ -62,60 +61,70 @@ def tune(clf, X, y, iters=2):
     # extract new vector
     return clf.get_weights()

-
 class Classifier:
     """Multi-Class Zero-shot Classifier
     This Classifier provides proxy regarding to the user's reaction to the probed images.
     The proxy will replace the original query vector generated by prompted vector and finally
     give the user a satisfying retrieval result.

-    This can be commonly seen in a recommendation system. The classifier will recommend more
+    This can be commonly seen in a recommendation system. The classifier will recommend more
     precise result as it accumulating user's activity.
-
+
     This is a multiclass classifier. For N queries it will set the all queries to the first-N classes
     and the last one takes the negative one.
     """

-    def __init__(self, xq: list):
+    def __init__(self, client, obj_db:str, xq: list):
         init_weight = torch.Tensor(xq)
         self.num_class = xq.shape[0]
-        DIMS = xq.shape[1]
-        # note that the bias is ignored, as we only focus on the inner product result
-        self.model = torch.nn.Linear(DIMS, self.num_class, bias=False)
+        self.DIMS = xq.shape[1]
         # convert initial query `xq` to tensor parameter to init weights
-        self.
-
-        self.
-        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
+        self.weight = init_weight
+        self.client = client
+        self.obj_db = obj_db

     def fit(self, X: list, y: list, iters: int = 5):
         # convert X and y to tensor
-
-
-
-
-
-
-        y[non_ind] = 0
-        for i in range(iters):
+        xq_s = [
+            f"[{', '.join([str(float(fnum)) for fnum in _xq + [1]])}]"
+            for _xq in self.get_weights().tolist()
+        ]
+
+        for _ in range(iters):
             # zero gradients
-
+            grad = []
             # Normalize the weight before inference
             # This will constrain the gradient or you will have an explosion on query vector
-            self.
-
-
-
-
-
-
+            self.weight.data /= torch.norm(
+                self.weight.data, p=2, dim=-1, keepdim=True
+            )
+            for n in range(self.num_class):
+                # select all training sample and create labels
+                labels, objs = list(map(list, zip(*[[1 if y[i]==n else 0, x] for i, x in enumerate(X) if y[i] in [n, self.num_class+1]])))
+
+                # NOTE from @fangruil
+                # Use SQL to calculate the gradient
+                # For binary cross entropy we have
+                # g = (1/(1+\exp(-XW))-Y)^TX
+                # To simplify the query, we separated
+                # the calculation into class numbers
+                grad_q_str = f"""
+                    SELECT sumForEachArray(arrayMap((x,y,gt)->arrayMap(i->i*(y-gt), x), X, Y, GT)) AS grad
+                    FROM (
+                        SELECT groupArray(arrayPopBack(prelogit)) AS X,
+                               groupArray(1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {xq_s[n]}))))) AS Y, {labels} AS GT
+                        FROM {self.obj_db} WHERE obj_id IN {objs})"""
+                grad.append(torch.as_tensor(self.client.fetch(grad_q_str)[0]['grad']))
             # update weights
-
+            grad = torch.stack(grad, dim=0)
+            self.weight -= 0.1 * grad

     def get_weights(self):
-        xq = self.
+        xq = self.weight.detach().numpy()
         return xq
-
+
+
+
 class SplitLayer(torch.nn.Module):
     def forward(self, x):
         return torch.split(x, 1, dim=-1)
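What the SQL inside fit computes, per class n, is the binary-cross-entropy gradient from the code comment, g = (1/(1+exp(-X·w)) - Y)^T X, evaluated over the user-labelled objects stored in obj_db. A minimal NumPy sketch of the same computation (random data and shapes are assumptions; the real query also carries an extra bias element that arrayPopBack strips):

import numpy as np

rng = np.random.default_rng(0)

# hypothetical setup: 8 labelled objects with 512-dim features, 2 prompt classes
X = rng.normal(size=(8, 512))          # object embeddings (the `prelogit` column, bias dropped)
y = rng.integers(0, 2, size=(8, 2))    # one-vs-rest labels per class (the GT array)
W = rng.normal(size=(2, 512))          # current query vectors, one row per class

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# what the ClickHouse query returns per class n:
#   Y = sigmoid(X @ w_n)              -- per-object prediction
#   g = sum_i x_i * (Y_i - gt_i)      -- BCE gradient w.r.t. w_n
grad = np.stack([X.T @ (sigmoid(X @ W[n]) - y[:, n]) for n in range(W.shape[0])])

# the SGD step then mirrors `self.weight -= 0.1 * grad`
W = W - 0.1 * grad
print(grad.shape)  # (2, 512)

Pushing this aggregation into the database means only one gradient vector per class travels back to the app on each iteration, rather than the raw object embeddings.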
query_model.py
CHANGED
@@ -32,7 +32,7 @@ def topk_obj_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME,
     q_str = f"""
             SELECT img_id, img_url, img_w, img_h, groupArray(obj_id) AS box_id,
             groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
-            groupArray(pred_logit) AS logit, groupArray(l) as label,
+            groupArray(pred_logit) AS logit, groupArray(l) as label,
             {_img_score_q}
             FROM
             ({_subq_str})
@@ -68,7 +68,7 @@ def rev_query(client, xq, img_ids, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08):
     q_str = f"""
             SELECT img_id, groupArray(obj_id) AS box_id, img_url, img_w, img_h,
             groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
-            groupArray(pred_logit) AS logit, groupArray(l) as label,
+            groupArray(pred_logit) AS logit, groupArray(l) as label,
             {_img_score_q}
             FROM
             ({_subq_str})
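Both queries still return one row per image, with every per-box column collected into a parallel array via groupArray; postprocess in box_utils.py then zips those arrays back into per-box lists. A made-up example of the row shape the downstream code expects (field values are illustrative only):

# one match row per image; every per-box field is a parallel array
match = {
    "img_id": "img_0001",
    "img_url": "https://example.com/img_0001.jpg",
    "img_w": 640,
    "img_h": 480,
    "box_id": ["obj_1", "obj_2"],
    "cx": [0.52, 0.31],
    "cy": [0.48, 0.66],
    "w": [0.20, 0.15],
    "h": [0.25, 0.10],
    "logit": [0.91, 0.47],
    "label": [0, 1],          # indices into the text prompts
    "img_score": 1.38,
}

# box_utils.postprocess re-zips the arrays into per-box lists:
#   [box_id, cx, cy, w, h, prompt_labels[label], logit, is_selected]
prompt_labels = ["cat", "dog"]
boxes = list(
    map(
        list,
        zip(
            match["box_id"], match["cx"], match["cy"], match["w"], match["h"],
            [prompt_labels[int(l)] for l in match["label"]],
            match["logit"],
            [1] * len(match["box_id"]),
        ),
    )
)
print(boxes[0])  # ['obj_1', 0.52, 0.48, 0.2, 0.25, 'cat', 0.91, 1]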