Spaces:
Runtime error
Runtime error
import logging | |
def topk_obj_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME,
                   exclude_list=None, topk=10):
    """Run a vector search per query vector and group the hits by image.

    For each vector in ``xq`` a sub-query picks the objects whose ``obj_id``
    is returned by a ``distance('topK=...')`` search over ``prelogit`` and
    scores them with a sigmoid of the dot product between ``prelogit`` and
    the query vector. The sub-queries are combined with ``UNION ALL`` and
    grouped by image; ``img_score`` is the sum, over query indices, of the
    maximum ``pred_logit`` for that index (NaN entries are filtered out).

    Args:
        client: Object exposing ``fetch(sql_string)``.
        xq: Iterable of vectors; each must provide ``tolist()`` returning a
            list of numbers. A constant ``1`` is appended to every vector.
        IMG_DB_NAME: Image table name, interpolated verbatim into the SQL.
        OBJ_DB_NAME: Object table name, interpolated verbatim into the SQL.
        exclude_list: Optional object ids to drop from the search results.
        topk: Number of neighbours requested from the vector search and the
            row limit applied after the exclusion filter.

    Returns:
        Whatever ``client.fetch`` returns for the assembled query.

    Note:
        An empty ``xq`` produces a query with an empty ``FROM ()`` clause.
    """
    # Render each query vector as a SQL array literal. The trailing 1 is
    # presumably a bias term matching the stored `prelogit` layout
    # -- TODO confirm against the ingestion code.
    vector_literals = [
        "[" + ", ".join(str(float(v)) for v in vec.tolist() + [1]) + "]"
        for vec in xq
    ]

    # `None` replaces the former mutable `[]` default.
    excluded = [] if exclude_list is None else exclude_list
    # Escape backslashes and single quotes so an id cannot terminate the
    # SQL string literal it is embedded in.
    quoted_ids = [
        "'" + str(obj_id).replace("\\", "\\\\").replace("'", "\\'") + "'"
        for obj_id in excluded
    ]
    exclude_cond = (
        f"WHERE obj_id NOT IN ({','.join(quoted_ids)})" if quoted_ids else ""
    )

    subqueries = []
    score_terms = []
    for label, vec_sql in enumerate(vector_literals):
        # Per-image maximum logit among the boxes matched by this query.
        score_terms.append(
            f"arrayReduce('maxIf', logit, arrayMap(x->x={label}, label))")
        # The row limit follows `topk` instead of a fixed 10 so that asking
        # for more neighbours actually yields more rows.
        subqueries.append(f"""
            SELECT img_id, img_url, img_w, img_h, 1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {vec_sql})))) AS pred_logit,
                obj_id, box_cx, box_cy, box_w, box_h, class_embedding, {label} AS l
            FROM {OBJ_DB_NAME}
            JOIN {IMG_DB_NAME}
            ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
            PREWHERE obj_id IN (
                SELECT obj_id FROM (
                    SELECT obj_id, distance('topK={topk}', 'nprobe=32')(prelogit, {vec_sql}) AS dist FROM {OBJ_DB_NAME}
                    ORDER BY dist DESC
                    ) {exclude_cond} LIMIT {topk}
                )
            """)

    union_sql = ' UNION ALL '.join(subqueries)
    img_score_sql = (
        "arraySum(arrayFilter(x->NOT isNaN(x), "
        f"array({','.join(score_terms)}))) AS img_score"
    )
    q_str = f"""
        SELECT img_id, img_url, img_w, img_h, groupArray(obj_id) AS box_id,
            groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
            groupArray(pred_logit) AS logit, groupArray(l) as label, groupArray(class_embedding) AS cls_emb,
            {img_score_sql}
        FROM
            ({union_sql})
        GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
        """
    return client.fetch(q_str)
def rev_query(client, xq, img_ids, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08):
    """Score every object of the given images against each query vector.

    For each vector in ``xq`` a sub-query joins the object and image tables,
    restricts rows to ``img_ids`` and computes ``pred_logit`` as a sigmoid of
    the dot product between ``prelogit`` and the query vector. Sub-queries
    are combined with ``UNION ALL`` and grouped by image; ``img_score`` is
    the sum, over query indices, of the maximum ``pred_logit`` for that
    index (NaN entries are filtered out).

    Args:
        client: Object exposing ``fetch(sql_string)``.
        xq: Iterable of vectors; each must provide ``tolist()`` returning a
            list of numbers. A constant ``1`` is appended to every vector.
        img_ids: Image ids to restrict the search to. An empty collection
            renders an empty ``IN ()`` list, so pass at least one id.
        IMG_DB_NAME: Image table name, interpolated verbatim into the SQL.
        OBJ_DB_NAME: Object table name, interpolated verbatim into the SQL.
        thresh: Minimum ``pred_logit`` to keep; values ``<= 0`` disable the
            filter entirely.

    Returns:
        Whatever ``client.fetch`` returns for the assembled query.
    """
    # Render each query vector as a SQL array literal. The trailing 1 is
    # presumably a bias term matching the stored `prelogit` layout
    # -- TODO confirm against the ingestion code.
    vector_literals = [
        "[" + ", ".join(str(float(v)) for v in vec.tolist() + [1]) + "]"
        for vec in xq
    ]

    # Escape backslashes and single quotes so an id cannot terminate the
    # SQL string literal it is embedded in.
    image_list = ','.join(
        "'" + str(img_id).replace("\\", "\\\\").replace("'", "\\'") + "'"
        for img_id in img_ids
    )
    thresh_cond = f"WHERE pred_logit > {thresh}" if thresh > 0 else ""

    subqueries = []
    score_terms = []
    for label, vec_sql in enumerate(vector_literals):
        # Per-image maximum logit among the boxes matched by this query.
        score_terms.append(
            f"arrayReduce('maxIf', logit, arrayMap(x->x={label}, label))")
        subqueries.append(f"""
            SELECT {OBJ_DB_NAME}.img_id AS img_id, img_url, img_w, img_h,
                (1 / (1 + exp(-(arraySum(arrayMap((x,y)->x*y, prelogit, {vec_sql})))))) AS pred_logit,
                obj_id, box_cx, box_cy, box_w, box_h, class_embedding, {label} AS l
            FROM {OBJ_DB_NAME}
            JOIN {IMG_DB_NAME}
            ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
            PREWHERE img_id IN ({image_list})
            {thresh_cond}
            """)

    union_sql = ' UNION ALL '.join(subqueries)
    img_score_sql = (
        "arraySum(arrayFilter(x->NOT isNaN(x), "
        f"array({','.join(score_terms)}))) AS img_score"
    )
    q_str = f"""
        SELECT img_id, groupArray(obj_id) AS box_id, img_url, img_w, img_h,
            groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
            groupArray(pred_logit) AS logit, groupArray(l) as label, groupArray(class_embedding) AS cls_emb,
            {img_score_sql}
        FROM
            ({union_sql})
        GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
        """
    return client.fetch(q_str)
def simple_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08, topk=10):
    """Run one vector search per query vector and group the hits by query.

    For each vector in ``xq`` a sub-query joins the object and image tables
    and computes ``dist`` with a ``distance('topK=...')`` search over
    ``prelogit``. Sub-queries are combined with ``UNION ALL`` and grouped by
    the query index ``l``; the outer query also reports a sigmoid of
    ``dist`` as ``logit``.

    Args:
        client: Object exposing ``fetch(sql_string)``.
        xq: Iterable of vectors; each must provide ``tolist()`` returning a
            list of numbers. A constant ``1`` is appended to every vector.
        IMG_DB_NAME: Image table name, interpolated verbatim into the SQL.
        OBJ_DB_NAME: Object table name, interpolated verbatim into the SQL.
        thresh: Minimum ``pred_logit`` to keep; values ``<= 0`` disable the
            filter entirely.
        topk: Number of neighbours requested from the vector search and the
            per-query row limit.

    Returns:
        Whatever ``client.fetch`` returns for the assembled query.
    """
    # Render each query vector as a SQL array literal. The trailing 1 is
    # presumably a bias term matching the stored `prelogit` layout
    # -- TODO confirm against the ingestion code.
    vector_literals = [
        "[" + ", ".join(str(float(v)) for v in vec.tolist() + [1]) + "]"
        for vec in xq
    ]

    # NOTE(review): the sub-query below does not select a `pred_logit`
    # column, so this filter only works if the table itself has one
    # -- confirm against the schema.
    thresh_cond = f"WHERE pred_logit > {thresh}" if thresh > 0 else ""

    # The row limit follows `topk` instead of a fixed 10 so that asking for
    # more neighbours actually yields more rows.
    subqueries = [
        f"""
            SELECT {OBJ_DB_NAME}.img_id AS img_id, img_url, img_w, img_h, prelogit,
                obj_id, box_cx, box_cy, box_w, box_h, {label} AS l, distance('topK={topk}', 'nprobe=32')(prelogit, {vec_sql}) AS dist
            FROM {OBJ_DB_NAME}
            JOIN {IMG_DB_NAME}
            ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
            {thresh_cond} LIMIT {topk}
            """
        for label, vec_sql in enumerate(vector_literals)
    ]
    union_sql = " UNION ALL ".join(subqueries)

    q_str = f"""
        SELECT groupArray(img_url) AS img_url, groupArray(img_w) AS img_w, groupArray(img_h) AS img_h,
            groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
            l AS label, groupArray(dist) as d,
            groupArray(1 / (1 + exp(-dist))) AS logit FROM (
                {union_sql}
            )
        GROUP BY l
        """
    return client.fetch(q_str)