jamescalam committed on
Commit
41b0565
1 Parent(s): be029bd

added app and requirements

Files changed (2)
  1. app.py +213 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,213 @@
+ import streamlit as st
+ from datasets import load_dataset
+ import numpy as np
+ import pinecone
+ import base64
+ from io import BytesIO
+ from transformers import CLIPTokenizerFast, CLIPModel
+ import torch
+ from typing import List
+
+ PINECONE_API_KEY = st.secrets["PINECONE_API_KEY"]  # app.pinecone.io
+ INDEX = "imagenet-query-trainer-clip"
+ MODEL_ID = "openai/clip-vit-base-patch32"
+ DIMS = 512
+
+ @st.experimental_singleton(show_spinner=False)
+ def init_dataset():
+     return load_dataset(
+         'frgfm/imagenette',
+         'full_size',
+         split='train',
+         ignore_verifications=False  # set to True if seeing splits Error
+     )
+
+ @st.experimental_singleton(show_spinner=False)
+ def init_clip():
+     tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)
+     clip = CLIPModel.from_pretrained(MODEL_ID)
+     return tokenizer, clip
+
+ @st.experimental_singleton(show_spinner=False)
+ def init_db():
+     pinecone.init(
+         api_key=PINECONE_API_KEY,  # use the secret loaded above rather than a hard-coded key
+         environment="us-west1-gcp"
+     )
+     return pinecone.Index(INDEX)
+
+ @st.experimental_singleton(show_spinner=False)
+ def init_random_query():
+     xq = np.random.rand(DIMS)
+     return xq, xq.copy()
+
+ class Classifier:
+     def __init__(self, xq: list):
+         # initialize model with DIMS input size and 1 output
+         self.model = torch.nn.Linear(DIMS, 1)
+         # convert initial query `xq` to tensor parameter to init weights
+         init_weight = torch.Tensor(xq).reshape(1, -1)
+         self.model.weight = torch.nn.Parameter(init_weight)
+         # init loss and optimizer
+         self.loss = torch.nn.BCEWithLogitsLoss()
+         self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.2)
+
+     def fit(self, X: list, y: list, iters: int = 20):
+         # convert X and y to tensor
+         X = torch.Tensor(X)
+         y = torch.Tensor(y).reshape(-1, 1)
+         for i in range(iters):
+             # zero gradients
+             self.optimizer.zero_grad()
+             # forward pass
+             out = self.model(X)
+             # compute loss
+             loss = self.loss(out, y)
+             # backward pass
+             loss.backward()
+             # update weights
+             self.optimizer.step()
+
+     def get_weights(self):
+         xq = self.model.weight.detach().numpy()[0].tolist()
+         return xq
+
+ def prompt2vec(prompt: str):
+     inputs = tokenizer(prompt, return_tensors='pt')
+     out = clip.get_text_features(**inputs)
+     xq = out.squeeze(0).cpu().detach().numpy().tolist()
+     return xq
+
+ def pil_to_bytes(img):
+     with BytesIO() as buf:
+         img.save(buf, format='jpeg')
+         img_bin = buf.getvalue()
+         img_bin = base64.b64encode(img_bin).decode('utf-8')
+     return img_bin
+
+ def card(i):
+     img = imagenet[int(i)]['image']
+     img_bin = pil_to_bytes(img)
+     return f'<img id="img{i}" src="data:image/jpeg;base64,{img_bin}" width="200px;">'
+
+ def get_top_k(xq, top_k=10):
+     xc = index.query(
+         xq,
+         top_k=top_k,
+         include_values=True,
+         filter={"seen": 0}
+     )
+     matches = {match['id']: match['values'] for match in xc['matches']}
+     return matches
+
+ def tune(matches, inputs):
+     positive_idx = [idx for idx, val in inputs.items() if val == 1]
+     negatives = [match for match in matches.items() if match[0] not in positive_idx]
+     negative_idx = [match[0] for match in negatives]
+     negative_vectors = [match[1] for match in negatives]
+     positive_vectors = [match[1] for match in matches.items() if match[0] in positive_idx]
+     # prep training data
+     y = [1] * len(positive_idx) + [0] * len(negative_idx)
+     X = positive_vectors + negative_vectors
+     # train the classifier
+     st.session_state.clf.fit(X, y)
+     # extract new vector
+     st.session_state.xq = st.session_state.clf.get_weights()
+     # update one record at a time
+     for i in positive_idx + negative_idx:
+         index.update(str(i), set_metadata={"seen": 1})
+     # return
+     #return clf, xq
+
+ def refresh_index():
+     xq = st.session_state.xq
+     if type(xq) is not list:
+         xq = xq.tolist()
+     while True:
+         xc = index.query(xq, top_k=100, filter={"seen": 1})
+         idx = [match['id'] for match in xc['matches']]
+         if len(idx) == 0: break
+         for i in idx:
+             index.update(str(i), set_metadata={"seen": 0})
+     # refresh session states
+     del st.session_state.clf, st.session_state.xq, st.session_state.show_images
+
+ def calc_dist():
+     xq = np.array(st.session_state.xq)
+     orig_xq = np.array(st.session_state.orig_xq)
+     return np.linalg.norm(xq - orig_xq)
+
+ def submit():
+     matches = st.session_state.matches
+     inputs = {}
+     states = [
+         st.session_state[f"input{i}"] for i in range(len(matches))
+     ]
+     for i, idx in enumerate(matches.keys()):
+         inputs[idx] = int(states[i])
+         states[i] = False
+     tune(matches, inputs)
+     #st.session_state.show_images = False
+
+ def set_tuner_true():
+     st.session_state.tuner = True
+
+ st.markdown("""
+ <link
+   rel="stylesheet"
+   href="https://fonts.googleapis.com/css?family=Roboto:300,400,500,700&display=swap"
+ />
+ """, unsafe_allow_html=True)
+
+ with st.spinner("Initializing everything..."):
+     imagenet = init_dataset()
+     index = init_db()
+     if 'xq' not in st.session_state:
+         tokenizer, clip = init_clip()
+     if 'show_images' not in st.session_state:
+         st.session_state.show_images = False
+         st.session_state.tuner = False
+
+ if 'xq' not in st.session_state:
+     prompt = st.text_input("Prompt:", value="")
+     random_xq = st.button("Random")
+     prompt_xq = st.button("Prompt", disabled=len(prompt) == 0)
+     if random_xq:
+         xq, orig_xq = init_random_query()
+         st.session_state.xq = xq
+         st.session_state.orig_xq = orig_xq
+         st.session_state.show_images = True
+     elif prompt_xq:
+         xq = prompt2vec(prompt)
+         st.session_state.xq = xq
+         st.session_state.orig_xq = xq
+         st.session_state.show_images = True
+
+ else:
+     # initialize classifier
+     if 'clf' not in st.session_state:
+         st.session_state.clf = Classifier(st.session_state.xq)
+
+     new_results = st.button("Search", disabled=st.session_state.show_images)
+     if new_results:
+         st.session_state.show_images = True
+
+     refresh = st.button("Refresh")
+     if refresh:
+         # we use this to remove filters in index, refresh models etc
+         refresh_index()
+     elif st.session_state.show_images:
+         # if we want to display images we end up here
+         st.markdown(f"Distance travelled: *{round(calc_dist(), 4)}*")
+         # first retrieve images from pinecone
+         st.session_state.matches = get_top_k(st.session_state.xq, top_k=10)
+         # once retrieved, display them alongside checkboxes in a form
+         with st.form("my_form", clear_on_submit=True):
+             # we have three columns in the form
+             cols = st.columns(3)
+             for i, idx in enumerate(st.session_state.matches.keys()):
+                 # the card shows an image and a checkbox
+                 cols[i%3].markdown(card(idx), unsafe_allow_html=True)
+                 # we access the values of the checkbox via st.session_state[f"input{i}"]
+                 cols[i%3].checkbox("Relevant", key=f"input{i}")
+             st.form_submit_button("Tune", on_click=submit)
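
The core idea in app.py is that the query vector itself is the trainable object: Classifier wraps a single torch.nn.Linear(DIMS, 1) whose weight row is initialized from the current query, gets fit against the checkbox feedback (1 = relevant, 0 = not), and its updated weights become the next Pinecone query. A minimal standalone sketch of that loop, using random stand-in vectors rather than real Pinecone matches, looks like this:

    import numpy as np
    # assumes the Classifier class from app.py above; vectors here are dummies
    xq = np.random.rand(512).tolist()                      # current query vector
    clf = Classifier(xq)                                   # weight row == query vector
    X = [np.random.rand(512).tolist() for _ in range(10)]  # vectors returned by the index
    y = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]                     # user feedback per vector
    clf.fit(X, y)                                          # a few SGD steps on BCE loss
    new_xq = clf.get_weights()                             # query to send back to Pinecone
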
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ datasets
+ pinecone-client
+ numpy
+ transformers
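
app.py also imports streamlit and torch directly. If the Space image does not already provide them (an assumption, since this commit only lists the four packages above), the full unpinned dependency list would be:

    datasets
    pinecone-client
    numpy
    transformers
    torch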