Fangrui Liu committed on
Commit
5ebcc54
•
1 Parent(s): eb05b74
Files changed (1)
  1. app.py +398 -0
app.py ADDED
@@ -0,0 +1,398 @@
import streamlit as st
import numpy as np
import base64
from io import BytesIO
from multilingual_clip import pt_multilingual_clip
from transformers import AutoTokenizer
import torch
import logging
from os import environ
environ['TOKENIZERS_PARALLELISM'] = 'true'

from myscaledb import Client

DB_NAME = "mqdb_demo.unsplash_25k_clip_indexer"
MODEL_ID = 'M-CLIP/XLM-Roberta-Large-Vit-B-32'
DIMS = 512
# Ignore some bad links (already broken in the dataset)
BAD_IDS = {'9_9hzZVjV8s', 'RDs0THr4lGs', 'vigsqYux_-8', 'rsJtMXn3p_c', 'AcG-unN00gw', 'r1R_0ZNUcx0'}

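# Note: the table behind DB_NAME is not created in this file. Based on the queries
# below, it is assumed to expose at least an `id` string, a `url` string and a
# DIMS-dimensional `vector` column with a vector index supporting MyScale's
# distance('topK=N')(vector, [...]) syntax.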
@st.experimental_singleton(show_spinner=False)
def init_clip():
    """ Initialize the multilingual CLIP model

    Returns:
        tokenizer: tokenizer that converts text into token ids for the model
        clip: multilingual CLIP text encoder that maps prompts to embeddings
    """
    clip = pt_multilingual_clip.MultilingualCLIP.from_pretrained(MODEL_ID)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    return tokenizer, clip

@st.experimental_singleton(show_spinner=False)
def init_db():
    """ Initialize the Database Connection

    Returns:
        meta_field: Meta field that records if an image is viewed or not
        client: Database connection object
    """
    client = Client(url=st.secrets["DB_URL"], user=st.secrets["USER"], password=st.secrets["PASSWD"])
    # We can check if the connection is alive
    assert client.is_alive()
    meta_field = {}
    return meta_field, client

@st.experimental_singleton(show_spinner=False)
def init_query_num():
    print("init query_num")
    return 0

def query(xq, top_k=10):
    """ Query the top-k matches w.r.t. a given vector

    Args:
        xq (numpy.ndarray or list of floats): Query vector
        top_k (int, optional): Number of matched vectors. Defaults to 10.

    Returns:
        matches: list of Record objects. Keys refer to the selected columns
    """
    attempt = 0
    xq = xq / np.linalg.norm(xq)
    while attempt < 3:
        try:
            xq_s = f"[{', '.join([str(float(fnum)) for fnum in list(xq)])}]"

            print('Excluded pre:', st.session_state.meta)
            if len(st.session_state.meta) > 0:
                exclude_list = ','.join([f'\'{i}\'' for i, v in st.session_state.meta.items() if v >= 1])
                print("Excluded:", exclude_list)
                # Using PREWHERE allows you to filter on columns before the vector search
                xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
                    distance('topK={top_k}')(vector, {xq_s}) AS dist\
                    FROM {DB_NAME} PREWHERE id NOT IN ({exclude_list})")
            else:
                xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
                    distance('topK={top_k}')(vector, {xq_s}) AS dist\
                    FROM {DB_NAME}")
            # real_xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
            #     1 - arraySum(arrayMap((x, y) -> x * y, {xq_s}, vector)) AS dist\
            #     FROM {DB_NAME} ORDER BY dist LIMIT {top_k}")
            # FIXME: This is causing freezing on DB
            real_xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
                distance('topK={top_k}')(vector, {xq_s}) AS dist\
                FROM {DB_NAME}")
            # keep the unfiltered result as well; it is shown in the sidebar
            top_k = real_xc
            xc = [xi for xi in xc if xi['id'] not in st.session_state.meta or \
                st.session_state.meta[xi['id']] < 1]
            logging.info(f'{len(xc)} records returned, {[_i["id"] for _i in xc]}')
            matches = xc
            break
        except Exception as e:
            # force a reconnect if the connection (or anything else) fails
            logging.warning(str(e))
            _, st.session_state.index = init_db()
            attempt += 1
            matches = []
    if len(matches) == 0:
        logging.error(f"No matches found for '{DB_NAME}'")
    return matches, top_k

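# Illustrative note: each record returned by `fetch()` above is assumed to be a
# dict-like object keyed by the selected columns, roughly
#     {'id': '<image id>', 'url': 'https://images.unsplash.com/...',
#      'vector': [0.013, -0.051, ...], 'dist': 0.27}
# (values made up), which is why callers read match['id'], match['url'] and match['vector'].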
@st.experimental_singleton(show_spinner=False)
def init_random_query():
    xq = np.random.rand(DIMS).tolist()
    return xq, xq.copy()

class Classifier:
    """ Zero-shot Classifier

    This classifier acts as a proxy for the user's reactions to the probed images.
    Its weight vector replaces the original query vector derived from the prompt,
    so that retrieval gradually converges to results the user is satisfied with.

    This pattern is common in recommendation systems: the classifier recommends
    more precisely as it accumulates user activity.
    """
    def __init__(self, xq: list):
        # initialize model with DIMS input size and 1 output
        # note that the bias is ignored, as we only focus on the inner product result
        self.model = torch.nn.Linear(DIMS, 1, bias=False)
        # convert initial query `xq` to a tensor parameter to init the weights
        init_weight = torch.Tensor(xq).reshape(1, -1)
        self.model.weight = torch.nn.Parameter(init_weight)
        # init loss and optimizer
        self.loss = torch.nn.BCEWithLogitsLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)

    def fit(self, X: list, y: list, iters: int = 5):
        # convert X and y to tensors
        X = torch.Tensor(X)
        y = torch.Tensor(y).reshape(-1, 1)
        for i in range(iters):
            # zero gradients
            self.optimizer.zero_grad()
            # Normalize the weight before inference
            # This constrains the gradient; otherwise the query vector will explode
            self.model.weight.data = self.model.weight.data / torch.norm(self.model.weight.data, p=2, dim=-1)
            # forward pass
            out = self.model(X)
            # compute loss
            loss = self.loss(out, y)
            # backward pass
            loss.backward()
            # update weights
            self.optimizer.step()

    def get_weights(self):
        xq = self.model.weight.detach().numpy()[0].tolist()
        return xq

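# Minimal usage sketch (illustrative; this mirrors how `tune()` and `submit()` below
# drive the classifier, with placeholder vectors and scores):
#     clf = Classifier(xq)          # xq: DIMS-dimensional query vector from the prompt
#     clf.fit(X=[v1, v2, v3], y=[1.0, 0.0, 0.5], iters=2)
#     xq = clf.get_weights()        # updated query vector for the next search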
def prompt2vec(prompt: str):
    """ Convert a prompt into a query vector

    Args:
        prompt (str): Text to be encoded

    Returns:
        xq: embedding of the prompt produced by the CLIP text encoder
    """
    # uses the global `clip` and `tokenizer` returned by init_clip()
    # inputs = tokenizer(prompt, return_tensors='pt')
    # out = clip.get_text_features(**inputs)
    out = clip.forward(prompt, tokenizer)
    xq = out.squeeze(0).cpu().detach().numpy().tolist()
    return xq

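# The text encoder behind MODEL_ID is expected to produce a DIMS(=512)-dimensional
# embedding, which is what allows `xq` to initialize the Classifier above and to be
# compared directly against the stored image vectors.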
def pil_to_bytes(img):
    """ Convert a Pillow image into a base64-encoded JPEG string

    Args:
        img (PIL.Image): Pillow (PIL) Image

    Returns:
        img_bin: image encoded as a base64 string
    """
    with BytesIO() as buf:
        img.save(buf, format='jpeg')
        img_bin = buf.getvalue()
        img_bin = base64.b64encode(img_bin).decode('utf-8')
    return img_bin

def card(i, url):
    return f'<img id="img{i}" src="{url}" width="200px;">'

def card_with_conf(i, conf, url):
    conf = "%.4f" % (conf)
    return f'<img id="img{i}" src="{url}" width="200px;" style="margin:50px 50px"><b>Relevance: {conf}</b>'

def get_top_k(xq, top_k=9):
    """ Wrapper function around `query()`

    Args:
        xq (numpy.ndarray or list of floats): Query vector
        top_k (int, optional): Number of returned vectors. Defaults to 9.

    Returns:
        matches: See `query()`
    """
    matches = query(
        xq, top_k=top_k
    )
    return matches

def tune(X, y, iters=2):
    """ Train the Zero-shot Classifier

    Args:
        X (numpy.ndarray): Input vectors (retrieved vectors)
        y (list of floats or numpy.ndarray): Scores given by the user
        iters (int, optional): Number of update iterations to run
    """
    # train the classifier
    st.session_state.clf.fit(X, y, iters=iters)
    # extract the new query vector
    st.session_state.xq = st.session_state.clf.get_weights()

def refresh_index():
    """ Clean the session
    """
    del st.session_state["meta"]
    st.session_state.meta = {}
    st.session_state.query_num = 0
    logging.info(f"Refresh for '{st.session_state.meta}'")
    init_db.clear()
    # refresh session states
    st.session_state.meta, st.session_state.index = init_db()
    del st.session_state.clf, st.session_state.xq

def calc_dist():
    xq = np.array(st.session_state.xq)
    orig_xq = np.array(st.session_state.orig_xq)
    return np.linalg.norm(xq - orig_xq)

def submit():
    """ Tune the model with the relevance scores given by the user.
    """
    st.session_state.query_num += 1
    matches = st.session_state.matches
    velocity = 1  # st.session_state.velocity
    scores = {}
    states = [
        st.session_state[f"input{i}"] for i in range(len(matches))
    ]
    for i, match in enumerate(matches):
        scores[match['id']] = float(states[i])
    # reset states to 1.0
    for i in range(len(matches)):
        st.session_state[f"input{i}"] = 1.0
    # get training data and labels
    X = list([match['vector'] for match in matches])
    y = [v for v in list(scores.values())]
    tune(X, y, iters=int(st.session_state.iters))
    # update record metadata after training
    for match in matches:
        st.session_state.meta[match['id']] = 1
    logging.info(f"Exclude List: {st.session_state.meta}")

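# Note on labels: the relevance slider values collected above double as targets for
# the classifier's BCE loss, where 1.0 means "relevant", 0.0 means "irrelevant",
# and intermediate values act as soft labels.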
def delete_element(element):
    del element

st.markdown("""
<link
  rel="stylesheet"
  href="https://fonts.googleapis.com/css?family=Roboto:300,400,500,700&display=swap"
/>
""", unsafe_allow_html=True)

messages = [
    f"""
    Find the most relevant examples in a large visual dataset by combining a text query with few-shot learning.
    """,
    f"""
    Then you can adjust the weight on each image. Those weights should **represent how well the image
    meets your preference**. You can either reward the images that match your prompt or change
    your mind.

    You might also notice the iteration slider above the retrieved images. It controls how fast the
    query vector changes: more **iterations** move the vector faster, while fewer **iterations**
    make the retrieval smoother.
    """,
    f"""
    This example trains a classifier to distinguish between samples you want and samples
    you don't. By initializing its weights from the prompt, you already get a classifier that is good
    enough to cluster the images you are searching for. If the result is not as good as you expected,
    you can further supervise the classifier with the **Relevance** annotation. If you cannot see any
    difference in the top-k retrieved results, try increasing the **Number of Iterations**.
    """,
    # TODO @ fangruil: fill the link with our tech blog
    f"""
    The app uses [MyScale](http://mqdb.page.moqi.ai/mqdb-docs/) to store and query images
    with vector search. All images are sourced from the
    [Unsplash Lite dataset](https://unsplash-datasets.s3.amazonaws.com/lite/latest/unsplash-research-dataset-lite-latest.zip)
    and encoded using [OpenAI's CLIP](https://huggingface.co/openai/clip-vit-base-patch32). We explain how
    it all works [here]().
    """
]

with st.spinner("Connecting DB..."):
    st.session_state.meta, st.session_state.index = init_db()

with st.spinner("Loading Models..."):
    # Initialize CLIP model
    if 'xq' not in st.session_state:
        tokenizer, clip = init_clip()
        st.session_state.query_num = 0

if 'xq' not in st.session_state:
    # If it's a fresh start
    if st.session_state.query_num < len(messages):
        msg = messages[st.session_state.query_num]
    else:
        msg = messages[-1]

    # Basic Layout

    with st.container():
        st.title("Visual Dataset Explorer")
        start = [st.empty(), st.empty(), st.empty(), st.empty(), st.empty()]
        start[0].info(msg)
        prompt = start[1].text_input("Prompt:", value="", placeholder="Examples: a photo of white dogs, cats in the snow, a house by the lake")
        start[2].markdown(
            '<p style="color:gray;"> Don\'t know what to search? Try <b>Random</b>!</p>',
            unsafe_allow_html=True)
        with start[3]:
            col = st.columns(8)
            prompt_xq = col[6].button("Prompt", disabled=len(prompt) == 0)
            random_xq = col[7].button("Random", disabled=len(prompt) != 0)
        if random_xq:
            # Randomly pick a vector to query
            xq, orig_xq = init_random_query()
            st.session_state.xq = xq
            st.session_state.orig_xq = orig_xq
            _ = [elem.empty() for elem in start]
        elif prompt_xq:
            print(f"Input prompt is {prompt}")
            # Encode the prompt into a query vector
            xq = prompt2vec(prompt)
            st.session_state.xq = xq
            st.session_state.orig_xq = xq
            _ = [elem.empty() for elem in start]

if 'xq' in st.session_state:
    # If it is not a fresh start
    if st.session_state.query_num+1 < len(messages):
        msg = messages[st.session_state.query_num+1]
    else:
        msg = messages[-1]
    # initialize classifier
    if 'clf' not in st.session_state:
        st.session_state.clf = Classifier(st.session_state.xq)

    # if we want to display images we end up here
    st.info(msg)
    # first retrieve images from the vector database
    st.session_state.matches, st.session_state.top_k = get_top_k(st.session_state.clf.get_weights(), top_k=9)
    with st.container():
        with st.sidebar:
            with st.container():
                st.header("Top K Nearest in Database")
                for i, k in enumerate(st.session_state.top_k):
                    url = k["url"]
                    url += "?q=75&fm=jpg&w=200&fit=max"
                    if k["id"] not in BAD_IDS:
                        disabled = False
                    else:
                        disabled = True
                    # inner product between the normalized query vector and the stored vector
                    dist = np.matmul(st.session_state.clf.get_weights() / np.linalg.norm(st.session_state.clf.get_weights()),
                                     np.array(k["vector"]).T)
                    st.markdown(card_with_conf(i, dist, url), unsafe_allow_html=True)

        # once retrieved, display them alongside relevance sliders in a form
        with st.form("batch", clear_on_submit=False):
            st.session_state.iters = st.slider("Number of Iterations to Update", min_value=0, max_value=10, step=1, value=2)
            col = st.columns([1,9])
            col[0].form_submit_button("Train!", on_click=submit)
            col[1].form_submit_button("Choose a new prompt", on_click=refresh_index)
            # we have three columns in the form
            cols = st.columns(3)
            for i, match in enumerate(st.session_state.matches):
                # find good url
                url = match["url"]
                url += "?q=75&fm=jpg&w=200&fit=max"
                if match["id"] not in BAD_IDS:
                    disabled = False
                else:
                    disabled = True
                # the card shows the image
                cols[i%3].markdown(card(i, url), unsafe_allow_html=True)
                # we read each slider's value back via st.session_state[f"input{i}"]
                cols[i%3].slider(
                    "Relevance",
                    min_value=0.0,
                    max_value=1.0,
                    value=1.0,
                    step=0.05,
                    key=f"input{i}",
                    disabled=disabled
                )