import streamlit as st
from datasets import load_dataset
import numpy as np
import pinecone
import base64
from io import BytesIO
from transformers import CLIPTokenizerFast, CLIPModel
import torch
from datetime import datetime
import logging
from urllib3.exceptions import ProtocolError

PINECONE_API_KEY = st.secrets["PINECONE_API_KEY"]  # app.pinecone.io

INDEX = "unsplash-25k-clip"
MODEL_ID = "openai/clip-vit-base-patch32"
DIMS = 512


@st.experimental_singleton(show_spinner=False)
def init_clip():
    tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)
    clip = CLIPModel.from_pretrained(MODEL_ID)
    return tokenizer, clip


@st.experimental_singleton(show_spinner=False)
def init_db():
    pinecone.init(
        api_key=PINECONE_API_KEY,
        environment="us-west1-gcp"
    )
    meta_field = datetime.now().isoformat()
    return meta_field, pinecone.Index(INDEX)


@st.experimental_singleton(show_spinner=False)
def init_query_num():
    print("init query_num")
    return 0


def query(xq, top_k=10, include_values=True, include_metadata=True, filter=None):
    logging.info(f"Query to Pinecone with '{st.session_state.meta}'")
    attempt = 0
    while attempt < 3:
        try:
            xc = st.session_state.index.query(
                xq,
                top_k=top_k,
                include_values=include_values,
                include_metadata=include_metadata,
                filter=filter
            )
            matches = xc['matches']
            break
        except:
            # force reload
            pinecone.init(api_key=PINECONE_API_KEY, environment="us-west1-gcp")
            st.session_state.index = pinecone.Index(INDEX)
            attempt += 1
            matches = []
    if len(matches) == 0:
        logging.error(f"No matches found for '{st.session_state.meta}'")
    return matches


@st.experimental_singleton(show_spinner=False)
def init_random_query():
    xq = np.random.rand(DIMS).tolist()
    return xq, xq.copy()


class Classifier:
    def __init__(self, xq: list):
        # initialize model with DIMS input size and 1 output
        self.model = torch.nn.Linear(DIMS, 1)
        # convert initial query `xq` to tensor parameter to init weights
        init_weight = torch.Tensor(xq).reshape(1, -1)
        self.model.weight = torch.nn.Parameter(init_weight)
        # init loss and optimizer
        self.loss = torch.nn.BCEWithLogitsLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)

    def fit(self, X: list, y: list, iters: int = 5):
        # convert X and y to tensor
        X = torch.Tensor(X)
        y = torch.Tensor(y).reshape(-1, 1)
        for i in range(iters):
            # zero gradients
            self.optimizer.zero_grad()
            # forward pass
            out = self.model(X)
            # compute loss
            loss = self.loss(out, y)
            # backward pass
            loss.backward()
            # update weights
            self.optimizer.step()

    def get_weights(self):
        xq = self.model.weight.detach().numpy()[0].tolist()
        return xq
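
# A quick standalone sketch of how the Classifier is used (illustrative only;
# the names below are hypothetical and nothing here is executed by the app,
# which wires fit()/get_weights() through the tune() and submit() callbacks
# further down):
#
#   clf = Classifier(np.random.rand(DIMS).tolist())  # init from a query vector
#   X = np.random.rand(2, DIMS).tolist()             # two "image embeddings"
#   clf.fit(X, [1.0, -1.0], iters=5)                 # user-assigned relevance scores
#   new_xq = clf.get_weights()                       # 512-d query, moved toward X[0]
#                                                    # and away from X[1]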


def prompt2vec(prompt: str):
    inputs = tokenizer(prompt, return_tensors='pt')
    out = clip.get_text_features(**inputs)
    xq = out.squeeze(0).cpu().detach().numpy().tolist()
    return xq


def pil_to_bytes(img):
    with BytesIO() as buf:
        img.save(buf, format='jpeg')
        img_bin = buf.getvalue()
        img_bin = base64.b64encode(img_bin).decode('utf-8')
    return img_bin


def card(i, url):
    # the original markup was lost; a minimal image card (the exact HTML here
    # is an assumption) rendered via st.markdown(..., unsafe_allow_html=True)
    return f'<img id="img{i}" src="{url}" width="100%">'


def get_top_k(xq, top_k=9):
    matches = query(
        xq,
        top_k=top_k,
        include_values=True,
        include_metadata=True,
        filter={st.session_state.meta: {"$ne": 1}}
    )
    return matches


def tune(X, y, iters=5):
    # train the classifier
    print(y)
    st.session_state.clf.fit(X, y, iters=iters)
    # extract new vector
    st.session_state.xq = st.session_state.clf.get_weights()


def refresh_index():
    logging.info(f"Refresh for '{st.session_state.meta}'")
    st.session_state.query_num = 0
    xq = st.session_state.xq
    if type(xq) is not list:
        xq = xq.tolist()
    while True:
        matches = query(xq, top_k=100, filter={st.session_state.meta: 1})
        id_vals = [match['id'] for match in matches]
        if len(id_vals) == 0:
            break
        for i in id_vals:
            st.session_state.index.update(str(i), set_metadata={st.session_state.meta: 0})
    # refresh session states
    del st.session_state.clf, st.session_state.xq


def calc_dist():
    xq = np.array(st.session_state.xq)
    orig_xq = np.array(st.session_state.orig_xq)
    return np.linalg.norm(xq - orig_xq)


def submit():
    st.session_state.query_num += 1
    matches = st.session_state.matches
    velocity = 2  # st.session_state.velocity
    scores = {}
    states = [
        st.session_state[f"input{i}"] for i in range(len(matches))
    ]
    for i, match in enumerate(matches):
        scores[match['id']] = float(states[i])
        states[i] = False  # reset local copy of the states
    # reset the sliders back to the neutral score
    for i in range(len(matches)):
        st.session_state[f"input{i}"] = 0.0
    # get training data and labels
    X = list([match['values'] for match in matches])
    y = list(scores.values())
    tune(X, y, iters=velocity)
    # update record metadata after training
    for match in matches:
        st.session_state.index.update(str(match['id']), set_metadata={st.session_state.meta: 1})


def delete_element(element):
    del element


st.markdown(""" """, unsafe_allow_html=True)

messages = [
    f"""
Welcome to the semantic query trainer app! Here we will demo how to efficiently
train a classifier to *very accurately* classify images based on their semantic
content.

First, we need to initialize the classifier with a simple prompt. Try writing
something similar to what you're looking for, or, if you want a challenge, try
something completely different.
""",
    f"""
With the first query we have initialized the classifier weights (they're a
512-d vector) and used those weights to perform a *vector search* to find image
embeddings (also 512-d vectors) that closely match the classifier weights.
These are essentially the images that the classifier would currently classify
as "positive".

Based on your *target class* for the classifier, decide how relevant each of
the images below is, rating them from -1 (completely irrelevant) to +1 (a
perfect match).
""",
    f"""
Each of the image embeddings is paired with the *score* that you just gave it.
These are all fed into the classifier and used to train it. The classifier
learns to *move* towards the positively scored images, and to *avoid* the
negatively scored images.
""",
    f"""
As we repeat the process, the classifier rapidly learns the target space of our
intended class. Typically, we don't train classifiers like this; instead, we
label a huge dataset and train the classifier across all images and their
labels, which is massively inefficient. Here we save annotation and compute
time by using vector search to identify and focus on the images that make the
*biggest* difference in classifier performance.
""",
    f"""
We shouldn't need to repeat this process many times before our classifier
converges on our target space. Once we begin returning only relevant images, we
can stop training the classifier.

*(In this demo, you can try changing your target space and 'traversing' the
vector space to the new target space)*
""",
    f"""
The app uses the [Pinecone vector database](https://pinecone.io/) to store and
query images using vector search. All images are sourced from the
[Unsplash Lite dataset](https://unsplash.com/data) and encoded using
[OpenAI's CLIP](https://huggingface.co/openai/clip-vit-base-patch32). We explain
how it all works [here](https://classifier-train-vector-search--optimistic-curran-b817a8.netlify.app/learn/classifier-train-vector-search/).
"""
]
""" ] with st.spinner("Initializing everything..."): st.session_state.meta, st.session_state.index = init_db() if 'xq' not in st.session_state: tokenizer, clip = init_clip() st.session_state.query_num = 0 if 'xq' not in st.session_state: if st.session_state.query_num < len(messages): msg = messages[st.session_state.query_num] else: msg = messages[-1] start = [st.empty(), st.empty(), st.empty(), st.empty(), st.empty()] start[0].info(msg) prompt = start[1].text_input("Prompt:", value="") prompt_xq = start[2].button("Prompt", disabled=len(prompt) == 0) random_xq = start[3].button("Random", disabled=len(prompt) != 0) start[4].markdown('Not sure what to write? Try **"dogs in the snow"**, **"close up of a dog"**, **"sony radio"**, or click **Random**.') if random_xq: print("r_xq") xq, orig_xq = init_random_query() st.session_state.xq = xq st.session_state.orig_xq = orig_xq _ = [elem.empty() for elem in start] elif prompt_xq: print("p_xq") xq = prompt2vec(prompt) st.session_state.xq = xq st.session_state.orig_xq = xq _ = [elem.empty() for elem in start] if 'xq' in st.session_state: if st.session_state.query_num+1 < len(messages): msg = messages[st.session_state.query_num+1] else: msg = messages[-1] # initialize classifier if 'clf' not in st.session_state: st.session_state.clf = Classifier(st.session_state.xq) refresh = st.button("Refresh") if refresh: # we use this to remove filters in index, refresh models etc refresh_index() else: # if we want to display images we end up here st.info(msg) # first retrieve images from pinecone st.session_state.matches = get_top_k(st.session_state.xq, top_k=9) # once retrieved, display them alongside checkboxes in a form with st.form("my_form", clear_on_submit=False): st.form_submit_button("Tune", on_click=submit) #velocity = st.slider("Velocity", 1, 8, 2, key="velocity") # we have three columns in the form cols = st.columns(3) for i, match in enumerate(st.session_state.matches): # find good url loc = match["metadata"].get("good_url") if loc: url = match["metadata"][loc] if loc == "photo_url": url += "/download?force=true&w=640" disabled = False else: # will show no image, but not sure what else to place here url = match["metadata"]["photo_url"] disabled=True # the card shows an image and a checkbox cols[i%3].markdown(card(i, url), unsafe_allow_html=True) # we access the values of the checkbox via st.session_state[f"input{i}"] cols[i%3].slider( "Relevance", min_value=-1.0, max_value=1.0, value=0.0, step=0.1, key=f"input{i}", disabled=disabled )