relation-query

#3
by mpsk - opened
Files changed (5)
  1. .gitignore +0 -2
  2. app.py +7 -8
  3. classifier.py +6 -9
  4. query_model.py +10 -10
  5. requirements.txt +1 -3
.gitignore DELETED
@@ -1,2 +0,0 @@
-.streamlit/
-__pycache__

app.py CHANGED
@@ -9,8 +9,7 @@ import logging
 from os import environ
 from transformers import OwlViTProcessor, OwlViTForObjectDetection
 from bot import Bot, Message
-from parse import parse
-from clickhouse_connect import get_client
+from myscaledb import Client
 from classifier import Classifier, prompt2vec, tune, SplitLayer
 from query_model import simple_query, topk_obj_query, rev_query
 from card_model import card, obj_card, style
@@ -63,11 +62,11 @@ def init_db():
         client: Database connection object
     """
     meta = []
-    r = parse("{http_pre}://{host}:{port}", st.secrets["DB_URL"])
-    client = get_client(
-        host=r['host'], port=r['port'], user=st.secrets["USER"], password=st.secrets["PASSWD"],
-        interface=r['http_pre'],
+    client = Client(
+        url=st.secrets["DB_URL"], user=st.secrets["USER"], password=st.secrets["PASSWD"]
     )
+    # We can check if the connection is alive
+    assert client.is_alive()
     return meta, client
 
 
@@ -118,7 +117,7 @@ def query(xq, exclude_list=None):
         IMG_DB_NAME,
         OBJ_DB_NAME,
         exclude_list=exclude_list,
-        topk=10,
+        topk=5000,
     )
     img_ids = [r["img_id"] for r in matches]
     if "topk_img_id" not in st.session_state:
@@ -141,7 +140,7 @@ def query(xq, exclude_list=None):
         IMG_DB_NAME,
         OBJ_DB_NAME,
         thresh=-1,
-        topk=10,
+        topk=5000,
     )
     status_bar[0].write("Retrieving Non-TopK in Another TopK Images...")
     pbar.progress(75)
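
Note on the connection change above: app.py now builds the database client with myscaledb-client instead of clickhouse-connect plus parse. A minimal standalone sketch of that pattern, with placeholder URL and credentials standing in for the st.secrets values used in app.py:

    from myscaledb import Client

    # Placeholders below; in app.py these come from st.secrets["DB_URL"] / ["USER"] / ["PASSWD"]
    client = Client(
        url="https://your-host:8443",
        user="default",
        password="********",
    )
    # is_alive() gives a cheap connectivity check before any query is issued
    assert client.is_alive()

    # fetch() returns dict-like rows keyed by column name
    rows = client.fetch("SELECT 1 AS ok")
    print(rows[0]["ok"])
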
classifier.py CHANGED
@@ -95,8 +95,8 @@ class Classifier:
         grad = []
         # Normalize the weight before inference
         # This will constrain the gradient or you will have an explosion on query vector
-        self.weight /= torch.norm(
-            self.weight, p=2, dim=-1, keepdim=True
+        self.weight.data /= torch.norm(
+            self.weight.data, p=2, dim=-1, keepdim=True
         )
         for n in range(self.num_class):
             # select all training sample and create labels
@@ -109,25 +109,22 @@ class Classifier:
             # To simplify the query, we separated
             # the calculation into class numbers
             grad_q_str = f"""
-            SELECT avgForEachArray(arrayMap((x,y,gt)->arrayMap(i->i*(y-gt), x), X, Y, GT)) AS grad
+            SELECT sumForEachArray(arrayMap((x,y,gt)->arrayMap(i->i*(y-gt), x), X, Y, GT)) AS grad
             FROM (
             SELECT groupArray(arrayPopBack(prelogit)) AS X,
             groupArray(1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {xq_s[n]}))))) AS Y, {labels} AS GT
             FROM {self.obj_db} WHERE obj_id IN {objs})"""
-            grad_ = [r['grad'] for r in self.client.query(grad_q_str).named_results()][0]
-            grad.append(torch.as_tensor(grad_))
+            grad.append(torch.as_tensor(self.client.fetch(grad_q_str)[0]['grad']))
         # update weights
         grad = torch.stack(grad, dim=0)
-        self.weight -= 0.01 * grad
-        self.weight /= torch.norm(
-            self.weight, p=2, dim=-1, keepdim=True
-        )
+        self.weight -= 0.1 * grad
 
     def get_weights(self):
         xq = self.weight.detach().numpy()
         return xq
 
 
+
 class SplitLayer(torch.nn.Module):
     def forward(self, x):
         return torch.split(x, 1, dim=-1)
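
For context on the training-loop change above: the query now returns sumForEachArray rather than avgForEachArray, i.e. the summed per-class logistic-regression gradient, and the step size moves from 0.01 to 0.1 with no re-normalization after the update. A rough torch equivalent of what one class's query computes (X, gt, and w are illustrative names; the bias handling via arrayPopBack is omitted):

    import torch

    def grad_for_class(X: torch.Tensor, gt: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        # y = 1 / (1 + exp(-x.w)), matching the arraySum/exp expression in the SQL
        y = torch.sigmoid(X @ w)
        # summed (not averaged) gradient: sum_i x_i * (y_i - gt_i)
        return (X * (y - gt).unsqueeze(-1)).sum(dim=0)

    # the weight update then mirrors `self.weight -= 0.1 * grad`
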
query_model.py CHANGED
@@ -19,11 +19,11 @@ def topk_obj_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME,
             FROM {OBJ_DB_NAME}
             JOIN {IMG_DB_NAME}
             ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
-            WHERE obj_id IN (
+            PREWHERE obj_id IN (
                 SELECT obj_id FROM (
-                    SELECT obj_id, distance(prelogit, {_xq}) AS dist FROM {OBJ_DB_NAME}
-                    ORDER BY dist DESC LIMIT 5000
-                ) {_cond} LIMIT {topk}
+                    SELECT obj_id, distance('topK={topk}', 'nprobe=32')(prelogit, {_xq}) AS dist FROM {OBJ_DB_NAME}
+                    ORDER BY dist DESC
+                ) {_cond} LIMIT 10
             )
             """)
     _subq_str = ' UNION ALL '.join(_subq_str)
@@ -38,7 +38,7 @@ def topk_obj_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME,
         ({_subq_str})
         GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
         """
-    xc = [{k: v for k, v in r.items()} for r in client.query(q_str).named_results()]
+    xc = client.fetch(q_str)
     return xc
 
 
@@ -74,7 +74,7 @@ def rev_query(client, xq, img_ids, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08):
         ({_subq_str})
         GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
         """
-    xc = [{k: v for k, v in r.items()} for r in client.query(q_str).named_results()]
+    xc = client.fetch(q_str)
     return xc
 
 
@@ -88,11 +88,11 @@ def simple_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08, topk=10):
     subq_str.append(
         f"""
         SELECT {OBJ_DB_NAME}.img_id AS img_id, img_url, img_w, img_h, prelogit,
-        obj_id, box_cx, box_cy, box_w, box_h, {_l} AS l, distance(prelogit, {_xq}) AS dist
+        obj_id, box_cx, box_cy, box_w, box_h, {_l} AS l, distance('topK={topk}', 'nprobe=32')(prelogit, {_xq}) AS dist
         FROM {OBJ_DB_NAME}
         JOIN {IMG_DB_NAME}
         ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
-        {_thresh} ORDER BY dist DESC LIMIT {topk}
+        {_thresh} LIMIT 10
         """)
     subq_str = " UNION ALL ".join(subq_str)
     q_str = f"""
@@ -104,5 +104,5 @@ def simple_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08, topk=10):
         )
         GROUP BY l
         """
-    xc = [{k: v for k, v in r.items()} for r in client.query(q_str).named_results()]
-    return xc
+    res = client.fetch(q_str)
+    return res
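
The query changes above swap the plain distance(prelogit, …) call and WHERE filter for MyScale's parameterized distance('topK=…', 'nprobe=32')(…) form behind a PREWHERE. Roughly, a standalone query in that shape looks like the sketch below (the table name and vector literal are placeholders for OBJ_DB_NAME and the formatted {_xq}):

    q_str = """
    SELECT obj_id, distance('topK=10', 'nprobe=32')(prelogit, [0.1, 0.2, 0.3]) AS dist
    FROM mydb.obj_table
    ORDER BY dist DESC
    """
    top_objects = client.fetch(q_str)   # e.g. [{'obj_id': ..., 'dist': ...}, ...]
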
requirements.txt CHANGED
@@ -1,9 +1,7 @@
 transformers
 tqdm
-clickhouse-connect
-parse
+myscaledb-client==1.1.7
 streamlit
-altair < 5
 numpy
 torch
 onnx
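
After this dependency swap, a fresh environment only needs the updated requirements file:

    pip install -r requirements.txt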