import streamlit as st
import numpy as np
import pinecone
import base64
from io import BytesIO
from transformers import CLIPTokenizerFast, CLIPModel
import torch
from datetime import datetime
import logging
from urllib3.exceptions import ProtocolError
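# assumption: configure the root logger so the logging.info/error calls below are visible
logging.basicConfig(level=logging.INFO)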
PINECONE_API_KEY = st.secrets["PINECONE_API_KEY"] # app.pinecone.io
INDEX = "unsplash-25k-clip"
MODEL_ID = "openai/clip-vit-base-patch32"
DIMS = 512
@st.experimental_singleton(show_spinner=False)
def init_clip():
tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)
clip = CLIPModel.from_pretrained(MODEL_ID)
return tokenizer, clip
@st.experimental_singleton(show_spinner=False)
def init_db():
pinecone.init(
api_key=PINECONE_API_KEY,
environment="us-west1-gcp"
)
meta_field = datetime.now().isoformat()
return meta_field, pinecone.Index(INDEX)
@st.experimental_singleton(show_spinner=False)
def init_query_num():
print("init query_num")
return 0
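# query Pinecone with up to three attempts, re-initializing the connection
# (e.g. after a dropped connection / urllib3 ProtocolError) between retries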
def query(xq, top_k=10, include_values=True, include_metadata=True, filter=None):
logging.info(f"Query to Pinecone with '{st.session_state.meta}'")
attempt = 0
while attempt < 3:
try:
xc = st.session_state.index.query(
xq,
top_k=top_k,
include_values=include_values,
include_metadata=include_metadata,
filter=filter
)
matches = xc['matches']
break
        except ProtocolError:
            # connection dropped; force a re-initialization and retry
            pinecone.init(api_key=PINECONE_API_KEY, environment="us-west1-gcp")
            st.session_state.index = pinecone.Index(INDEX)
            attempt += 1
            matches = []
if len(matches) == 0:
logging.error(f"No matches found for '{st.session_state.meta}'")
return matches
@st.experimental_singleton(show_spinner=False)
def init_random_query():
xq = np.random.rand(DIMS).tolist()
return xq, xq.copy()
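# The classifier is a single linear layer whose weight vector doubles as the
# query vector: training on scored results nudges the weights (and therefore
# the next vector search) towards positive examples and away from negatives.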
class Classifier:
def __init__(self, xq: list):
# initialize model with DIMS input size and 1 output
self.model = torch.nn.Linear(DIMS, 1)
# convert initial query `xq` to tensor parameter to init weights
init_weight = torch.Tensor(xq).reshape(1, -1)
self.model.weight = torch.nn.Parameter(init_weight)
# init loss and optimizer
self.loss = torch.nn.BCEWithLogitsLoss()
self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
def fit(self, X: list, y: list, iters: int = 5):
# convert X and y to tensor
X = torch.Tensor(X)
y = torch.Tensor(y).reshape(-1, 1)
for i in range(iters):
# zero gradients
self.optimizer.zero_grad()
# forward pass
out = self.model(X)
# compute loss
loss = self.loss(out, y)
# backward pass
loss.backward()
# update weights
self.optimizer.step()
def get_weights(self):
xq = self.model.weight.detach().numpy()[0].tolist()
return xq
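# Usage sketch (hypothetical vectors, not executed): given a 512-d query
# vector `xq`, scored embeddings X, and labels y in [-1, 1]:
#   clf = Classifier(xq)
#   clf.fit(X, y, iters=5)
#   xq = clf.get_weights()  # the tuned weights become the next query vector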
def prompt2vec(prompt: str):
inputs = tokenizer(prompt, return_tensors='pt')
out = clip.get_text_features(**inputs)
xq = out.squeeze(0).cpu().detach().numpy().tolist()
return xq
def pil_to_bytes(img):
with BytesIO() as buf:
img.save(buf, format='jpeg')
img_bin = buf.getvalue()
img_bin = base64.b64encode(img_bin).decode('utf-8')
return img_bin
def card(i, url):
    # an image card rendered as raw HTML via st.markdown(unsafe_allow_html=True)
    return f'<img id="img{i}" src="{url}" width="135px">'
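# retrieve the top_k nearest image embeddings, excluding records already
# marked as seen ({meta: 1}) in this session via a metadata filter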
def get_top_k(xq, top_k=9):
matches = query(
xq, top_k=top_k, include_values=True, include_metadata=True,
filter={st.session_state.meta: {"$ne": 1}}
)
return matches
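# one tuning round: fit the classifier on the scored vectors, then adopt its
# updated weights as the new query vector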
def tune(X, y, iters=5):
    # train the classifier
    logging.debug(f"Training labels: {y}")
    st.session_state.clf.fit(X, y, iters=iters)
# extract new vector
st.session_state.xq = st.session_state.clf.get_weights()
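# reset the session: mark all previously-seen records ({meta: 1}) back to 0 in
# batches of 100, then drop the classifier and query vector from session state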
def refresh_index():
logging.info(f"Refresh for '{st.session_state.meta}'")
st.session_state.query_num = 0
xq = st.session_state.xq
    if not isinstance(xq, list):
        xq = xq.tolist()
while True:
matches = query(xq, top_k=100, filter={st.session_state.meta: 1})
id_vals = [match['id'] for match in matches]
        if not id_vals:
            break
for i in id_vals:
st.session_state.index.update(str(i), set_metadata={st.session_state.meta: 0})
# refresh session states
del st.session_state.clf, st.session_state.xq
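# L2 distance between the current query vector and the starting vector; a
# rough measure of how far tuning has moved us through the vector space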
def calc_dist():
xq = np.array(st.session_state.xq)
orig_xq = np.array(st.session_state.orig_xq)
return np.linalg.norm(xq - orig_xq)
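# form callback: read the relevance sliders, train on the scored embeddings,
# then flag the rated records ({meta: 1}) so they are filtered from the next
# round of results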
def submit():
st.session_state.query_num += 1
matches = st.session_state.matches
    velocity = 2  # st.session_state.velocity (slider currently disabled)
scores = {}
states = [
st.session_state[f"input{i}"] for i in range(len(matches))
]
    for i, match in enumerate(matches):
        scores[match['id']] = float(states[i])
    # reset sliders to the neutral value for the next round
    for i in range(len(matches)):
        st.session_state[f"input{i}"] = 0.0
# get training data and labels
    X = [match['values'] for match in matches]
y = list(scores.values())
tune(X, y, iters=velocity)
# update record metadata after training
for match in matches:
st.session_state.index.update(str(match['id']), set_metadata={st.session_state.meta: 1})
def delete_element(element):
    # removes only the local reference; currently unused
    del element
messages = [
f"""
Welcome to the semantic query trainer app! Here we will demo how to efficiently train
a classifier to *very accurately* classify images based on their semantic content.
    First, we need to initialize the classifier with a simple prompt. Try writing something
    similar to what you're looking for, or, if you want a challenge, try something completely
    different.
""",
f"""
    With the first query we have initialized the classifier weights (a 512-d vector)
    and used those weights to perform a *vector search*, finding image embeddings (also
    512-d vectors) that closely match the classifier weights.
    These are essentially the images that the classifier would currently classify as "positive".
    Based on your *target class* for the classifier, decide how relevant each of the images
    below is, rating it from -1 (completely irrelevant) to +1 (a perfect match).
""",
f"""
Each of the image embeddings is paired with the *score* that you just gave it. These are
all fed into the classifier and used to train it. The classifier learns to *move* towards
the positively scored images, and to *avoid* the negatively scored images.
""",
f"""
As we repeat the process, the classifier rapidly learns the target space of our intended
class.
    Typically, we don't train classifiers like this; instead, we label a huge dataset and train
    the classifier across all images and their labels, which is massively inefficient. Here we
    save annotation and compute time by using vector search to identify and focus on the images
    that make the *biggest* difference to classifier performance.
""",
f"""
    We shouldn't need to repeat this process many times before our classifier converges on the
    target space. Once it begins returning only relevant images, we can stop training.
    *(In this demo, you can try changing your target and 'traversing' the vector space
    to the new target space.)*
""",
f"""
    The app uses the [Pinecone vector database](https://pinecone.io/) to store and query images
    using vector search. All images are sourced from the [Unsplash Lite dataset](https://github.com/unsplash/datasets) and encoded
    using [OpenAI's CLIP](https://huggingface.co/openai/clip-vit-base-patch32). We explain how
it all works [here](https://classifier-train-vector-search--optimistic-curran-b817a8.netlify.app/learn/classifier-train-vector-search/).
"""
]
with st.spinner("Initializing everything..."):
st.session_state.meta, st.session_state.index = init_db()
if 'xq' not in st.session_state:
tokenizer, clip = init_clip()
st.session_state.query_num = 0
if 'xq' not in st.session_state:
if st.session_state.query_num < len(messages):
msg = messages[st.session_state.query_num]
else:
msg = messages[-1]
start = [st.empty(), st.empty(), st.empty(), st.empty(), st.empty()]
start[0].info(msg)
prompt = start[1].text_input("Prompt:", value="")
prompt_xq = start[2].button("Prompt", disabled=len(prompt) == 0)
random_xq = start[3].button("Random", disabled=len(prompt) != 0)
start[4].markdown('Not sure what to write? Try **"dogs in the snow"**, **"close up of a dog"**, **"sony radio"**, or click **Random**.')
if random_xq:
print("r_xq")
xq, orig_xq = init_random_query()
st.session_state.xq = xq
st.session_state.orig_xq = orig_xq
_ = [elem.empty() for elem in start]
elif prompt_xq:
print("p_xq")
xq = prompt2vec(prompt)
st.session_state.xq = xq
st.session_state.orig_xq = xq
_ = [elem.empty() for elem in start]
if 'xq' in st.session_state:
if st.session_state.query_num+1 < len(messages):
msg = messages[st.session_state.query_num+1]
else:
msg = messages[-1]
# initialize classifier
if 'clf' not in st.session_state:
st.session_state.clf = Classifier(st.session_state.xq)
refresh = st.button("Refresh")
if refresh:
# we use this to remove filters in index, refresh models etc
refresh_index()
else:
# if we want to display images we end up here
st.info(msg)
# first retrieve images from pinecone
st.session_state.matches = get_top_k(st.session_state.xq, top_k=9)
# once retrieved, display them alongside checkboxes in a form
with st.form("my_form", clear_on_submit=False):
st.form_submit_button("Tune", on_click=submit)
#velocity = st.slider("Velocity", 1, 8, 2, key="velocity")
# we have three columns in the form
cols = st.columns(3)
for i, match in enumerate(st.session_state.matches):
# find good url
loc = match["metadata"].get("good_url")
if loc:
url = match["metadata"][loc]
if loc == "photo_url":
url += "/download?force=true&w=640"
disabled = False
                else:
                    # no usable image URL; fall back to the photo page and disable rating
                    url = match["metadata"]["photo_url"]
                    disabled = True
                # each card shows an image with a relevance slider beneath it
                cols[i%3].markdown(card(i, url), unsafe_allow_html=True)
                # we access the slider values via st.session_state[f"input{i}"]
cols[i%3].slider(
"Relevance",
min_value=-1.0,
max_value=1.0,
value=0.0,
step=0.1,
key=f"input{i}",
disabled=disabled
)