sql-compute-grad (#2)
- update compute gradient with sql (98667f6d3f61d4fe705a49978307db02586eeb86)
- app.py +12 -6
- box_utils.py +61 -32
- card_model.py +2 -1
- classifier.py +46 -37
- query_model.py +2 -2
app.py
CHANGED

@@ -192,7 +192,7 @@ def submit(meta):
     zip(
         *(
             (
-                v[
+                v[0],
                 st.session_state.text_prompts.index(st.session_state[f"label-{i}"]),
             )
             for i, v in matches.items()

@@ -329,7 +329,7 @@ try:
     matches = st.session_state.matches
     # initialize classifier
    if "clf" not in st.session_state:
-        st.session_state.clf = Classifier(st.session_state.xq)
+        st.session_state.clf = Classifier(st.session_state.index, OBJ_DB_NAME, st.session_state.xq)
         st.session_state.step = 0
     if qtime > 0:
         st.info(

@@ -344,11 +344,13 @@ try:
             ),
         )
     )
+    lnprob = torch.nn.Linear(st.session_state.xq.shape[1], st.session_state.xq.shape[0], bias=False)
+    lnprob.weight = torch.nn.Parameter(st.session_state.clf.weight)

     # export the model into executable ONNX
     st.session_state.dnld_model = BytesIO()
     torch.onnx.export(
-        torch.nn.Sequential(
+        torch.nn.Sequential(lnprob, SplitLayer()),
         torch.zeros([1, len(st.session_state.xq[0])]),
         st.session_state.dnld_model,
         input_names=["input"],

@@ -370,7 +372,9 @@ try:
     with st.expander("Top-K Images"):
         with st.container():
             boxes_w_img, _ = postprocess(
-                o_matches, st.session_state.text_prompts,
+                o_matches, st.session_state.text_prompts, o_matches,
+                agnostic_ratio=1-0.6**(st.session_state.step+1),
+                class_ratio=1-0.2**(st.session_state.step+1)
             )
             boxes_w_img = sorted(boxes_w_img, key=lambda x: x[4], reverse=True)
             for img_id, img_url, img_w, img_h, img_score, boxes in boxes_w_img:

@@ -428,7 +432,9 @@ try:

     # Post processing boxes regarding to their score, intersection
     boxes_w_img, meta = postprocess(
-        matches, st.session_state.text_prompts, img_matches
+        matches, st.session_state.text_prompts, img_matches,
+        agnostic_ratio=1-0.6**(st.session_state.step+1),
+        class_ratio=1-0.2**(st.session_state.step+1)
     )

     # Sort the result according to their relavancy

@@ -452,7 +458,7 @@ try:
     img_row[0].write(card(*args), unsafe_allow_html=True)
     # crop objects out of the original image
     for b in boxes:
-        _id, cx, cy, w, h, label, logit, is_selected
+        _id, cx, cy, w, h, label, logit, is_selected = b[:8]
         with img_row[1 + ind_b % 3].container():
             st.write("{:s}: {:.4f}".format(label, logit))
             # quite hacky: with streamlit components API
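Side note on the export hunk above: because the classifier no longer carries its own torch.nn.Linear (its weights now live in a plain tensor updated through SQL), app.py rebuilds a bias-free Linear from clf.weight and chains it with SplitLayer just for the ONNX download. A minimal, self-contained sketch of that export path; the sizes and the random weight are made up and stand in for st.session_state.clf.weight:

from io import BytesIO

import torch


class SplitLayer(torch.nn.Module):
    # same idea as classifier.SplitLayer: split the (1, num_class) score row
    # into one tensor per class so each prompt gets its own ONNX output
    def forward(self, x):
        return torch.split(x, 1, dim=-1)


num_class, dims = 3, 512                  # hypothetical sizes
weight = torch.randn(num_class, dims)     # stands in for st.session_state.clf.weight

lnprob = torch.nn.Linear(dims, num_class, bias=False)
lnprob.weight = torch.nn.Parameter(weight)

buf = BytesIO()                           # st.session_state.dnld_model in the app
torch.onnx.export(
    torch.nn.Sequential(lnprob, SplitLayer()),
    torch.zeros([1, dims]),               # dummy input with the embedding width
    buf,
    input_names=["input"],
)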
box_utils.py
CHANGED

@@ -2,16 +2,14 @@ import numpy as np


 def cxywh2xywh(cx, cy, w, h):
-    """
-    """
+    """CxCyWH format to XYWH format conversion"""
     x = cx - w / 2
     y = cy - h / 2
     return x, y, w, h


 def cxywh2ltrb(cx, cy, w, h):
-    """CxCyWH format to LeftRightTopBottom format
-    """
+    """CxCyWH format to LeftRightTopBottom format"""
     l = cx - w / 2
     t = cy - h / 2
     r = cx + w / 2

@@ -61,9 +59,16 @@ def nms(cx, cy, w, h, s, iou_thresh=0.3):
         i = sort_ind[0]
         res.append(i)

-        _iou = iou(
-
-
+        _iou = iou(
+            (l[i], t[i], r[i], b[i], areas[i]),
+            (
+                l[sort_ind[1:]],
+                t[sort_ind[1:]],
+                r[sort_ind[1:]],
+                b[sort_ind[1:]],
+                areas[sort_ind[1:]],
+            ),
+        )
         sel_ind = np.where(_iou <= iou_thresh)[0]
         sort_ind = sort_ind[sel_ind + 1]
     return res

@@ -77,43 +82,64 @@ def filter_nonpos(boxes, agnostic_ratio=0.5, class_ratio=0.7):
     """
     ret = []
     labelwise = {}
-    for
+    for b in boxes:
+        _id, cx, cy, w, h, label, logit, is_selected = b[:8]
         if label not in labelwise:
             labelwise[label] = []
         labelwise[label].append(logit)
     labelwise = {l: max(s) for l, s in labelwise.items()}
     agnostic = max([v for _, v in labelwise.items()])
     for b in boxes:
-        _id, cx, cy, w, h, label, logit, is_selected
-        if logit > class_ratio * labelwise[label]
-            and logit > agnostic_ratio * agnostic:
+        _id, cx, cy, w, h, label, logit, is_selected = b[:8]
+        if logit > class_ratio * labelwise[label] and logit > agnostic_ratio * agnostic:
             ret.append(b)
     return ret


-def postprocess(matches, prompt_labels, img_matches=None):
+def postprocess(matches, prompt_labels, img_matches=None, agnostic_ratio=0.4, class_ratio=0.7):
     meta = []
     boxes_w_img = []
-    matches_ = {m[
+    matches_ = {m["img_id"]: m for m in matches}
     if img_matches is not None:
-        img_matches_ = {m[
+        img_matches_ = {m["img_id"]: m for m in img_matches}
     for k in matches_.keys():
         m = matches_[k]
         boxes = []
-        boxes += list(
-
-
-
-
-
+        boxes += list(
+            map(
+                list,
+                zip(
+                    m["box_id"],
+                    m["cx"],
+                    m["cy"],
+                    m["w"],
+                    m["h"],
+                    [prompt_labels[int(l)] for l in m["label"]],
+                    m["logit"],
+                    [1] * len(m["box_id"]),
+                ),
+            )
+        )
         if img_matches is not None and k in img_matches_:
             img_m = img_matches_[k]
             # and also those non-TopK hits and those non-topk are not anticipating training
-            boxes += [
-
-
-
-
+            boxes += [
+                i
+                for i in map(
+                    list,
+                    zip(
+                        img_m["box_id"],
+                        img_m["cx"],
+                        img_m["cy"],
+                        img_m["w"],
+                        img_m["h"],
+                        [prompt_labels[int(l)] for l in img_m["label"]],
+                        img_m["logit"],
+                        [0] * len(img_m["box_id"]),
+                    ),
+                )
+                if i[0] not in [b[0] for b in boxes]
+            ]
         else:
             img_m = None
         # update record metadata after query

@@ -121,16 +147,19 @@ def postprocess(matches, prompt_labels, img_matches=None):
             meta.append(b[0])

         # remove some non-significant boxes
-        boxes = filter_nonpos(
-            boxes, agnostic_ratio=0.4, class_ratio=0.7)
+        boxes = filter_nonpos(boxes, agnostic_ratio=agnostic_ratio, class_ratio=class_ratio)

         # doing non-maximum suppression
-        cx, cy, w, h, s = list(
-
+        cx, cy, w, h, s = list(
+            map(lambda x: np.array(x), list(zip(*[(*b[1:5], b[6]) for b in boxes])))
+        )
         ind = nms(cx, cy, w, h, s, 0.3)
         boxes = [boxes[i] for i in ind]
         if img_m is not None:
-            img_score =
+            img_score = (
+                img_m["img_score"] if img_matches is not None else m["img_score"]
+            )
         boxes_w_img.append(
-            (m["img_id"], m["img_url"], m["img_w"], m["img_h"], img_score, boxes)
-
+            (m["img_id"], m["img_url"], m["img_w"], m["img_h"], img_score, boxes)
+        )
+    return boxes_w_img, meta
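The filter_nonpos/postprocess change above threads two thresholds through the pipeline: a box survives only if its logit is close enough to the best logit of its own label (class_ratio) and to the best logit overall (agnostic_ratio), and app.py now tightens both as the user completes more feedback steps. A rough sketch of the rule and the schedule; keep_box is a hypothetical helper, the real check is written inline in filter_nonpos:

def keep_box(logit, label, best_per_label, best_overall,
             agnostic_ratio=0.4, class_ratio=0.7):
    # survive only if close to the best score of the same label AND
    # close to the best score across all labels
    return (logit > class_ratio * best_per_label[label]
            and logit > agnostic_ratio * best_overall)


best_per_label = {"cat": 0.9, "dog": 0.7}   # invented per-label maxima
print(keep_box(0.75, "cat", best_per_label, best_overall=0.9))   # True
print(keep_box(0.30, "dog", best_per_label, best_overall=0.9))   # False

for step in range(3):
    # thresholds app.py passes; they tighten with every feedback round
    agnostic_ratio = 1 - 0.6 ** (step + 1)   # 0.4, 0.64, 0.784, ...
    class_ratio = 1 - 0.2 ** (step + 1)      # 0.8, 0.96, 0.992, ...
    print(step, round(agnostic_ratio, 3), round(class_ratio, 3))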
card_model.py
CHANGED

@@ -47,7 +47,8 @@ def card(img_url, img_w, img_h, boxes):
     """
     _boxes = ""
     img_url = convert_img_url(img_url)
-    for
+    for b in boxes:
+        _id, cx, cy, w, h, label, logit, is_selected = b[:8]
         x, y, w, h = cxywh2xywh(cx, cy, w, h)
         x = round(img_w * x)
         y = round(img_h * y)
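For context, card() turns each normalized (cx, cy, w, h) box into an absolute pixel rectangle for the HTML overlay. A tiny sketch of that conversion with an invented image size; the hunk above only shows x and y being scaled, so scaling w and h the same way is an assumption:

from box_utils import cxywh2xywh            # repo helper shown in the diff above

img_w, img_h = 640, 480                     # invented image size
cx, cy, w, h = 0.5, 0.5, 0.2, 0.4           # normalized box, invented values

x, y, w, h = cxywh2xywh(cx, cy, w, h)       # -> top-left corner, still normalized
x, y = round(img_w * x), round(img_h * y)   # -> 256, 144
w, h = round(img_w * w), round(img_h * h)   # -> 128, 192 (assumed, not in the hunk)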
classifier.py
CHANGED

@@ -1,7 +1,7 @@
 import torch


-def extract_text_feature(prompt, model, processor, device=
+def extract_text_feature(prompt, model, processor, device="cpu"):
     """Extract text features

     Args:

@@ -10,12 +10,11 @@ def extract_text_feature(prompt, model, processor, device='cpu'):
         processor: OwlViT processor
         device (str, optional): device to run. Defaults to 'cpu'.
     """
-    device =
+    device = "cpu"
     if torch.cuda.is_available():
-        device =
+        device = "cuda"
     with torch.no_grad():
-        input_ids = torch.as_tensor(processor(text=prompt)[
-            'input_ids']).to(device)
+        input_ids = torch.as_tensor(processor(text=prompt)["input_ids"]).to(device)
         print(input_ids.device)
         text_outputs = model.owlvit.text_model(
             input_ids=input_ids,

@@ -32,7 +31,7 @@ def extract_text_feature(prompt, model, processor, device='cpu'):


 def prompt2vec(prompt: str, model, processor):
-    """
+    """Convert prompt into a computational vector

     Args:
         prompt (str): Text to be tokenized

@@ -49,7 +48,7 @@ def prompt2vec(prompt: str, model, processor):


 def tune(clf, X, y, iters=2):
-    """
+    """Train the Zero-shot Classifier

     Args:
         X (numpy.ndarray): Input vectors (retreived vectors)

@@ -62,60 +61,70 @@ def tune(clf, X, y, iters=2):
     # extract new vector
     return clf.get_weights()

-
 class Classifier:
     """Multi-Class Zero-shot Classifier
     This Classifier provides proxy regarding to the user's reaction to the probed images.
     The proxy will replace the original query vector generated by prompted vector and finally
     give the user a satisfying retrieval result.

-    This can be commonly seen in a recommendation system. The classifier will recommend more
+    This can be commonly seen in a recommendation system. The classifier will recommend more
     precise result as it accumulating user's activity.
-
+
     This is a multiclass classifier. For N queries it will set the all queries to the first-N classes
     and the last one takes the negative one.
     """

-    def __init__(self, xq: list):
+    def __init__(self, client, obj_db:str, xq: list):
         init_weight = torch.Tensor(xq)
         self.num_class = xq.shape[0]
-        DIMS = xq.shape[1]
-        # note that the bias is ignored, as we only focus on the inner product result
-        self.model = torch.nn.Linear(DIMS, self.num_class, bias=False)
+        self.DIMS = xq.shape[1]
         # convert initial query `xq` to tensor parameter to init weights
-        self.
-
-        self.
-        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
+        self.weight = init_weight
+        self.client = client
+        self.obj_db = obj_db

     def fit(self, X: list, y: list, iters: int = 5):
         # convert X and y to tensor
-
-
-
-
-
-
-        y[non_ind] = 0
-        for i in range(iters):
+        xq_s = [
+            f"[{', '.join([str(float(fnum)) for fnum in _xq + [1]])}]"
+            for _xq in self.get_weights().tolist()
+        ]
+
+        for _ in range(iters):
             # zero gradients
-
+            grad = []
             # Normalize the weight before inference
             # This will constrain the gradient or you will have an explosion on query vector
-            self.
-
-
-
-
-
-
+            self.weight.data /= torch.norm(
+                self.weight.data, p=2, dim=-1, keepdim=True
+            )
+            for n in range(self.num_class):
+                # select all training sample and create labels
+                labels, objs = list(map(list, zip(*[[1 if y[i]==n else 0, x] for i, x in enumerate(X) if y[i] in [n, self.num_class+1]])))
+
+                # NOTE from @fangruil
+                # Use SQL to calculate the gradient
+                # For binary cross entropy we have
+                # g = (1/(1+\exp(-XW))-Y)^TX
+                # To simplify the query, we separated
+                # the calculation into class numbers
+                grad_q_str = f"""
+                    SELECT sumForEachArray(arrayMap((x,y,gt)->arrayMap(i->i*(y-gt), x), X, Y, GT)) AS grad
+                    FROM (
+                        SELECT groupArray(arrayPopBack(prelogit)) AS X,
+                        groupArray(1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {xq_s[n]}))))) AS Y, {labels} AS GT
+                        FROM {self.obj_db} WHERE obj_id IN {objs})"""
+                grad.append(torch.as_tensor(self.client.fetch(grad_q_str)[0]['grad']))
             # update weights
-
+            grad = torch.stack(grad, dim=0)
+            self.weight -= 0.1 * grad

     def get_weights(self):
-        xq = self.
+        xq = self.weight.detach().numpy()
         return xq
-
+
+
+
 class SplitLayer(torch.nn.Module):
     def forward(self, x):
         return torch.split(x, 1, dim=-1)
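The heart of this commit is Classifier.fit: instead of backpropagating through a torch module, it asks ClickHouse to evaluate the binary-cross-entropy gradient g = (1/(1+exp(-Xw)) - y)^T X over the stored prelogit vectors for each class, and Python only applies the SGD step. A reference sketch of the same formula in plain torch, with toy shapes (the extra bias element handled by `+ [1]` and arrayPopBack in the query is omitted for clarity):

import torch


def bce_grad(X: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """g = (sigmoid(X @ w) - y)^T X, the quantity the SQL query returns as `grad`."""
    p = torch.sigmoid(X @ w)   # the query's Y column: 1 / (1 + exp(-prelogit . w))
    return (p - y) @ X         # sumForEachArray(arrayMap(... x * (y - gt) ...))


# toy shapes: 4 labelled boxes, 3-dim embeddings (real dims come from xq.shape[1])
X = torch.randn(4, 3)                      # stored prelogit vectors for class n
y = torch.tensor([1.0, 0.0, 1.0, 0.0])     # 1 if the box was labelled as class n
w = torch.randn(3)                         # current query/weight vector for class n
w = w - 0.1 * bce_grad(X, y, w)            # same step size as Classifier.fit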
query_model.py
CHANGED

@@ -32,7 +32,7 @@ def topk_obj_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME,
     q_str = f"""
         SELECT img_id, img_url, img_w, img_h, groupArray(obj_id) AS box_id,
         groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
-        groupArray(pred_logit) AS logit, groupArray(l) as label,
+        groupArray(pred_logit) AS logit, groupArray(l) as label,
         {_img_score_q}
         FROM
         ({_subq_str})

@@ -68,7 +68,7 @@ def rev_query(client, xq, img_ids, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08):
     q_str = f"""
         SELECT img_id, groupArray(obj_id) AS box_id, img_url, img_w, img_h,
         groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
-        groupArray(pred_logit) AS logit, groupArray(l) as label,
+        groupArray(pred_logit) AS logit, groupArray(l) as label,
         {_img_score_q}
         FROM
         ({_subq_str})