import logging


def _format_query_vectors(xq):
    """Render each query vector in *xq* as a SQL array literal string.

    Every element of ``vec.tolist()`` is converted with ``float`` and a
    trailing ``1`` is appended (presumably a bias term matching the last
    element of the stored ``prelogit`` column -- confirm against the indexer).
    """
    return [
        f"[{', '.join(str(float(v)) for v in vec.tolist() + [1])}]"
        for vec in xq
    ]


def _sql_string_list(values):
    """Return *values* as comma-separated, single-quoted SQL string literals.

    Backslashes and single quotes are escaped so that an id containing ``'``
    cannot terminate the literal early and alter the query.
    """
    def _quote(value):
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    return ",".join(_quote(v) for v in values)


def _img_score_expr(num_labels):
    """Build the ``img_score`` column expression for *num_labels* query labels.

    For each label index it takes ``maxIf`` of the grouped ``logit`` array over
    entries whose ``label`` equals that index, drops NaN results, and sums
    what remains.
    """
    per_label = ",".join(
        f"arrayReduce('maxIf', logit, arrayMap(x->x={idx}, label))"
        for idx in range(num_labels)
    )
    return (
        f"arraySum(arrayFilter(x->NOT isNaN(x), array({per_label}))) "
        "AS img_score"
    )


def _fetch_rows(client, q_str):
    """Execute *q_str* via ``client.query`` and copy each named row into a dict."""
    return [dict(row.items()) for row in client.query(q_str).named_results()]


def topk_obj_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME, exclude_list=None,
                   topk=10, limit=10):
    """Find the top objects per query vector and aggregate them by image.

    For each query vector (label = its position in *xq*), candidate object ids
    come from a vector search ``distance('topK=<topk>', 'nprobe=32')`` over
    ``prelogit``. Ids in *exclude_list* are removed and at most *limit*
    candidates are kept. Matching objects are joined with their image row and
    scored with ``pred_logit = sigmoid(dot(prelogit, query))``.

    Args:
        client: Object whose ``query(sql)`` returns a result exposing
            ``named_results()`` (e.g. a clickhouse-connect client -- assumed).
        xq: Iterable of query vectors, each providing ``tolist()``.
        IMG_DB_NAME: Image table name. It is interpolated verbatim, so it must
            be trusted.
        OBJ_DB_NAME: Object table name. It is interpolated verbatim, so it must
            be trusted.
        exclude_list: Object ids to exclude. ``None`` means no exclusion.
        topk: ``topK`` parameter passed to the vector-search ``distance``.
        limit: Maximum number of candidate objects kept per query vector.

    Returns:
        A list of dicts, one per image, ordered by ``img_score`` descending.
        Returns ``[]`` without querying when *xq* is empty.
    """
    vectors = _format_query_vectors(xq)
    if not vectors:
        # An empty UNION ALL would render "FROM ()", which is invalid SQL.
        return []

    excluded = _sql_string_list([] if exclude_list is None else exclude_list)
    cond = f"WHERE obj_id NOT IN ({excluded})" if excluded else ""

    subqueries = [
        f"""
        SELECT img_id, img_url, img_w, img_h,
               1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {vec})))) AS pred_logit,
               obj_id, box_cx, box_cy, box_w, box_h, class_embedding,
               {label} AS l
        FROM {OBJ_DB_NAME}
        JOIN {IMG_DB_NAME} ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
        PREWHERE obj_id IN (
            SELECT obj_id FROM (
                SELECT obj_id,
                       distance('topK={topk}', 'nprobe=32')(prelogit, {vec}) AS dist
                FROM {OBJ_DB_NAME}
                ORDER BY dist DESC
            ) {cond}
            LIMIT {limit}
        )
        """
        for label, vec in enumerate(vectors)
    ]

    q_str = f"""
    SELECT img_id, img_url, img_w, img_h,
           groupArray(obj_id) AS box_id,
           groupArray(box_cx) AS cx, groupArray(box_cy) AS cy,
           groupArray(box_w) AS w, groupArray(box_h) AS h,
           groupArray(pred_logit) AS logit,
           groupArray(l) as label,
           {_img_score_expr(len(vectors))}
    FROM ({' UNION ALL '.join(subqueries)})
    GROUP BY img_id, img_url, img_w, img_h
    ORDER BY img_score DESC
    """
    return _fetch_rows(client, q_str)


def rev_query(client, xq, img_ids, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08):
    """Score every object of the given images against each query vector.

    For each query vector (label = its position in *xq*), objects whose
    ``img_id`` is in *img_ids* are scored with
    ``pred_logit = sigmoid(dot(prelogit, query))``. When ``thresh > 0``, only
    objects with ``pred_logit > thresh`` are kept. The results are grouped per
    image.

    Args:
        client: Object whose ``query(sql)`` returns a result exposing
            ``named_results()``.
        xq: Iterable of query vectors, each providing ``tolist()``.
        img_ids: Image ids to restrict the search to. They are escaped as SQL
            string literals.
        IMG_DB_NAME: Image table name, interpolated verbatim (must be trusted).
        OBJ_DB_NAME: Object table name, interpolated verbatim (must be trusted).
        thresh: Minimum ``pred_logit``. A value ``<= 0`` disables the filter.

    Returns:
        A list of dicts, one per image, ordered by ``img_score`` descending.
        Returns ``[]`` without querying when *xq* or *img_ids* is empty.
    """
    vectors = _format_query_vectors(xq)
    image_list = _sql_string_list(img_ids)
    if not vectors or not image_list:
        # "IN ()" or "FROM ()" would be invalid SQL, and nothing can match.
        return []

    thresh_clause = f"WHERE pred_logit > {thresh}" if thresh > 0 else ""

    subqueries = [
        f"""
        SELECT {OBJ_DB_NAME}.img_id AS img_id, img_url, img_w, img_h,
               (1 / (1 + exp(-(arraySum(arrayMap((x,y)->x*y, prelogit, {vec})))))) AS pred_logit,
               obj_id, box_cx, box_cy, box_w, box_h, class_embedding,
               {label} AS l
        FROM {OBJ_DB_NAME}
        JOIN {IMG_DB_NAME} ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
        PREWHERE img_id IN ({image_list})
        {thresh_clause}
        """
        for label, vec in enumerate(vectors)
    ]

    q_str = f"""
    SELECT img_id, groupArray(obj_id) AS box_id, img_url, img_w, img_h,
           groupArray(box_cx) AS cx, groupArray(box_cy) AS cy,
           groupArray(box_w) AS w, groupArray(box_h) AS h,
           groupArray(pred_logit) AS logit,
           groupArray(l) as label,
           {_img_score_expr(len(vectors))}
    FROM ({' UNION ALL '.join(subqueries)})
    GROUP BY img_id, img_url, img_w, img_h
    ORDER BY img_score DESC
    """
    return _fetch_rows(client, q_str)


def simple_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08, topk=10,
                 limit=10):
    """Run a plain vector search per query vector and group results by label.

    Each query vector (label = its position in *xq*) is matched with
    ``distance('topK=<topk>', 'nprobe=32')(prelogit, query)``, and at most
    *limit* rows are kept per vector. Rows are then grouped by label, with
    ``logit = sigmoid(dist)``.

    Args:
        client: Object whose ``query(sql)`` returns a result exposing
            ``named_results()``.
        xq: Iterable of query vectors, each providing ``tolist()``.
        IMG_DB_NAME: Image table name, interpolated verbatim (must be trusted).
        OBJ_DB_NAME: Object table name, interpolated verbatim (must be trusted).
        thresh: Threshold for ``pred_logit``. A value ``<= 0`` disables the
            filter.
        topk: ``topK`` parameter passed to the vector-search ``distance``.
        limit: Maximum number of rows kept per query vector.

    Returns:
        A list of dicts, one per label. Returns ``[]`` without querying when
        *xq* is empty.
    """
    vectors = _format_query_vectors(xq)
    if not vectors:
        # An empty UNION ALL would render "FROM ()", which is invalid SQL.
        return []

    # NOTE(review): this subquery never selects a `pred_logit` alias, so the
    # filter only works if OBJ_DB_NAME has such a column. It was possibly meant
    # to filter on 1/(1+exp(-dist)) -- confirm.
    thresh_clause = f"WHERE pred_logit > {thresh}" if thresh > 0 else ""

    subqueries = [
        f"""
        SELECT {OBJ_DB_NAME}.img_id AS img_id, img_url, img_w, img_h, prelogit,
               obj_id, box_cx, box_cy, box_w, box_h, {label} AS l,
               distance('topK={topk}', 'nprobe=32')(prelogit, {vec}) AS dist
        FROM {OBJ_DB_NAME}
        JOIN {IMG_DB_NAME} ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
        {thresh_clause}
        LIMIT {limit}
        """
        for label, vec in enumerate(vectors)
    ]

    q_str = f"""
    SELECT groupArray(img_url) AS img_url, groupArray(img_w) AS img_w,
           groupArray(img_h) AS img_h,
           groupArray(box_cx) AS cx, groupArray(box_cy) AS cy,
           groupArray(box_w) AS w, groupArray(box_h) AS h,
           l AS label, groupArray(dist) as d,
           groupArray(1 / (1 + exp(-dist))) AS logit
    FROM ( {' UNION ALL '.join(subqueries)} )
    GROUP BY l
    """
    return _fetch_rows(client, q_str)