# object-detection-safari / query_model.py
# Author: Fangrui Liu -- commit 3f1124e ("init repo"), 4.96 kB
import logging
def topk_obj_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME,
                   exclude_list=None, topk=10, limit=10):
    """Retrieve images whose objects best match each query vector.

    For every query vector in ``xq``, a sub-query runs a vector search
    (``distance('topK=...', 'nprobe=32')``) over ``prelogit`` in
    ``OBJ_DB_NAME``. It then drops any ``obj_id`` listed in ``exclude_list``
    and keeps at most ``limit`` object ids. The matching objects are joined
    with ``IMG_DB_NAME`` and scored with
    ``1 / (1 + exp(-<prelogit, query>))``. The per-query rows are combined
    with ``UNION ALL`` and grouped per image.

    Args:
        client: Database client. Only its ``fetch(query_str)`` method is used.
        xq: Iterable of query vectors, each exposing ``tolist()`` (presumably
            numpy arrays -- confirm against callers). A trailing ``1`` is
            appended to every vector before it is embedded in the SQL.
        IMG_DB_NAME: Name of the image table.
        OBJ_DB_NAME: Name of the object table.
        exclude_list: Object ids to leave out of the results. ``None`` (the
            default) excludes nothing.
        topk: ``topK`` value passed to the vector-search ``distance`` call.
        limit: Maximum number of object ids kept per query vector after
            exclusion. Defaults to 10, the previously hard-coded value.

    Returns:
        Whatever ``client.fetch`` returns for the assembled query. Each row
        is one image, with grouped box arrays, per-box ``logit`` / ``label``
        arrays, and an ``img_score`` column. Rows are ordered by
        ``img_score`` descending.
    """
    def _quote(value):
        # Build a ClickHouse single-quoted string literal. Backslashes are
        # escaped first so the escape added for quotes is not doubled.
        text = f"{value}".replace("\\", "\\\\").replace("'", "\\'")
        return f"'{text}'"

    # None replaces the former mutable [] default.
    exclude_ids = [] if exclude_list is None else list(exclude_list)

    # Vector literals such as "[0.1, 0.2, 1.0]". The appended 1 is presumably
    # a bias term matching the stored ``prelogit`` layout -- TODO confirm.
    xq_s = [
        f"[{', '.join(str(float(fnum)) for fnum in _xq.tolist() + [1])}]"
        for _xq in xq
    ]

    exclude_list_str = ','.join(_quote(i) for i in exclude_ids)
    _cond = f"WHERE obj_id NOT IN ({exclude_list_str})" if exclude_ids else ""

    subqueries = []
    score_terms = []
    for _l, _xq in enumerate(xq_s):
        # Highest box logit among the boxes tagged with query label _l.
        score_terms.append(
            f"arrayReduce('maxIf', logit, arrayMap(x->x={_l}, label))")
        subqueries.append(f"""
        SELECT img_id, img_url, img_w, img_h, 1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {_xq})))) AS pred_logit,
            obj_id, box_cx, box_cy, box_w, box_h, class_embedding, {_l} AS l
        FROM {OBJ_DB_NAME}
        JOIN {IMG_DB_NAME}
        ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
        PREWHERE obj_id IN (
            SELECT obj_id FROM (
                SELECT obj_id, distance('topK={topk}', 'nprobe=32')(prelogit, {_xq}) AS dist FROM {OBJ_DB_NAME}
                ORDER BY dist DESC
            ) {_cond} LIMIT {limit}
        )
        """)

    union_sql = ' UNION ALL '.join(subqueries)
    # Image score: sum over query labels of each label's best box logit,
    # with NaN entries filtered out.
    img_score_sql = (
        f"arraySum(arrayFilter(x->NOT isNaN(x), array({','.join(score_terms)}))) AS img_score"
    )
    q_str = f"""
    SELECT img_id, img_url, img_w, img_h, groupArray(obj_id) AS box_id,
        groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
        groupArray(pred_logit) AS logit, groupArray(l) as label, groupArray(class_embedding) AS cls_emb,
        {img_score_sql}
    FROM
    ({union_sql})
    GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
    """
    return client.fetch(q_str)
def rev_query(client, xq, img_ids, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08):
    """Score every object of a fixed set of images against the query vectors.

    For each query vector in ``xq``, every object in ``OBJ_DB_NAME`` whose
    image is listed in ``img_ids`` is joined with ``IMG_DB_NAME`` and scored
    as ``pred_logit = 1 / (1 + exp(-<prelogit, query>))``. When ``thresh`` is
    positive, only objects with ``pred_logit > thresh`` are kept. The
    per-query rows are combined with ``UNION ALL`` and grouped per image.

    Args:
        client: Database client. Only its ``fetch(query_str)`` method is used.
        xq: Iterable of query vectors, each exposing ``tolist()`` (presumably
            numpy arrays -- confirm against callers). A trailing ``1`` is
            appended to every vector before it is embedded in the SQL.
        img_ids: Image ids to restrict the search to. Each id is rendered as
            an escaped SQL string literal.
        IMG_DB_NAME: Name of the image table.
        OBJ_DB_NAME: Name of the object table.
        thresh: Minimum ``pred_logit`` for an object to be returned. A value
            ``<= 0`` disables the filter.

    Returns:
        Whatever ``client.fetch`` returns for the assembled query. Each row
        is one image, with grouped box arrays, per-box ``logit`` / ``label``
        arrays, and an ``img_score`` column. Rows are ordered by
        ``img_score`` descending.
    """
    def _quote(value):
        # Build a ClickHouse single-quoted string literal. Backslashes are
        # escaped first so the escape added for quotes is not doubled.
        text = f"{value}".replace("\\", "\\\\").replace("'", "\\'")
        return f"'{text}'"

    # Vector literals such as "[0.1, 0.2, 1.0]". The appended 1 is presumably
    # a bias term matching the stored ``prelogit`` layout -- TODO confirm.
    xq_s = [
        f"[{', '.join(str(float(fnum)) for fnum in _xq.tolist() + [1])}]"
        for _xq in xq
    ]
    image_list = ','.join(_quote(i) for i in img_ids)
    _thresh = f"WHERE pred_logit > {thresh}" if thresh > 0 else ""

    subqueries = []
    score_terms = []
    for _l, _xq in enumerate(xq_s):
        # Highest box logit among the boxes tagged with query label _l.
        score_terms.append(
            f"arrayReduce('maxIf', logit, arrayMap(x->x={_l}, label))")
        subqueries.append(f"""
        SELECT {OBJ_DB_NAME}.img_id AS img_id, img_url, img_w, img_h,
            (1 / (1 + exp(-(arraySum(arrayMap((x,y)->x*y, prelogit, {_xq})))))) AS pred_logit,
            obj_id, box_cx, box_cy, box_w, box_h, class_embedding, {_l} AS l
        FROM {OBJ_DB_NAME}
        JOIN {IMG_DB_NAME}
        ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
        PREWHERE img_id IN ({image_list})
        {_thresh}
        """)

    union_sql = ' UNION ALL '.join(subqueries)
    # Image score: sum over query labels of each label's best box logit,
    # with NaN entries filtered out.
    img_score_sql = (
        f"arraySum(arrayFilter(x->NOT isNaN(x), array({','.join(score_terms)}))) AS img_score"
    )
    q_str = f"""
    SELECT img_id, groupArray(obj_id) AS box_id, img_url, img_w, img_h,
        groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
        groupArray(pred_logit) AS logit, groupArray(l) as label, groupArray(class_embedding) AS cls_emb,
        {img_score_sql}
    FROM
    ({union_sql})
    GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
    """
    return client.fetch(q_str)
def simple_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08, topk=10,
                 limit=10):
    """Run a plain vector search per query vector and group the hits by label.

    For every query vector in ``xq``, a sub-query selects objects from
    ``OBJ_DB_NAME`` joined with ``IMG_DB_NAME``, together with the
    vector-search distance ``distance('topK=...', 'nprobe=32')(prelogit, q)``.
    Each sub-query is tagged with the vector's index as ``l`` and limited to
    ``limit`` rows. The sub-queries are combined with ``UNION ALL`` and
    grouped by ``l``.

    Args:
        client: Database client. Only its ``fetch(query_str)`` method is used.
        xq: Iterable of query vectors, each exposing ``tolist()`` (presumably
            numpy arrays -- confirm against callers). A trailing ``1`` is
            appended to every vector before it is embedded in the SQL.
        IMG_DB_NAME: Name of the image table.
        OBJ_DB_NAME: Name of the object table.
        thresh: When positive, adds ``WHERE pred_logit > thresh`` to each
            sub-query (see the review note below).
        topk: ``topK`` value passed to the vector-search ``distance`` call.
        limit: Maximum number of rows per query vector. Defaults to 10, the
            previously hard-coded value.

    Returns:
        Whatever ``client.fetch`` returns for the assembled query. Each row
        is one label ``l``, with arrays of image url/size, box coordinates,
        raw distances ``d``, and ``logit = 1 / (1 + exp(-dist))``.
    """
    # Vector literals such as "[0.1, 0.2, 1.0]". The appended 1 is presumably
    # a bias term matching the stored ``prelogit`` layout -- TODO confirm.
    xq_s = [
        f"[{', '.join(str(float(fnum)) for fnum in _xq.tolist() + [1])}]"
        for _xq in xq
    ]
    # NOTE(review): unlike topk_obj_query / rev_query, the sub-query below
    # never defines a ``pred_logit`` alias. With thresh > 0 this filter only
    # works if OBJ_DB_NAME has a real ``pred_logit`` column -- confirm
    # against the table schema.
    _thresh = f"WHERE pred_logit > {thresh}" if thresh > 0 else ""

    subqueries = []
    for _l, _xq in enumerate(xq_s):
        subqueries.append(f"""
        SELECT {OBJ_DB_NAME}.img_id AS img_id, img_url, img_w, img_h, prelogit,
            obj_id, box_cx, box_cy, box_w, box_h, {_l} AS l, distance('topK={topk}', 'nprobe=32')(prelogit, {_xq}) AS dist
        FROM {OBJ_DB_NAME}
        JOIN {IMG_DB_NAME}
        ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
        {_thresh} LIMIT {limit}
        """)
    union_sql = " UNION ALL ".join(subqueries)

    q_str = f"""
    SELECT groupArray(img_url) AS img_url, groupArray(img_w) AS img_w, groupArray(img_h) AS img_h,
        groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
        l AS label, groupArray(dist) as d,
        groupArray(1 / (1 + exp(-dist))) AS logit FROM (
        {union_sql}
    )
    GROUP BY l
    """
    return client.fetch(q_str)