"""Gradio app for text-to-image search over AI-generated images.

A text prompt is embedded with CLIP's text encoder, the nearest vectors are
fetched from a Pinecone index, and the matching PNGs are downloaded from a
Google Cloud Storage bucket and shown in a gallery.
"""
import io
import json
import os

import gradio as gr
import pinecone
import torch
from cryptography.fernet import Fernet
from google.cloud import storage
from PIL import Image
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast

# decrypt Storage Cloud credentials
fernet = Fernet(os.environ['DECRYPTION_KEY'])
with open('cloud-storage.encrypted', 'rb') as fp:
    encrypted = fp.read()
creds = json.loads(fernet.decrypt(encrypted).decode())
# then save creds to file
# NOTE(review): this writes the decrypted service-account key to disk in
# plaintext; storage.Client.from_service_account_info(creds) would avoid that
# if nothing else relies on GOOGLE_APPLICATION_CREDENTIALS — confirm.
with open('cloud-storage.json', 'w', encoding='utf-8') as fp:
    fp.write(json.dumps(creds, indent=4))
# connect to Cloud Storage
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'cloud-storage.json'
storage_client = storage.Client()
bucket = storage_client.get_bucket('diffusion-search')

# get api key for pinecone auth
PINECONE_KEY = os.environ['PINECONE_KEY']
PINECONE_ENV = "us-west1-gcp"
index_id = "diffusion-search"
# init connection to pinecone
pinecone.init(api_key=PINECONE_KEY, environment=PINECONE_ENV)
if index_id not in pinecone.list_indexes():
    raise ValueError(f"Index '{index_id}' not found")
index = pinecone.Index(index_id)

# use the GPU when CUDA is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using '{device}' device...")

model_id = "openai/clip-vit-base-patch32"
# only the text side of CLIP is used here: a tokenizer and the model itself
tokenizer = CLIPTokenizerFast.from_pretrained(model_id)
model = CLIPModel.from_pretrained(model_id).to(device)

# placeholder shown in the gallery when a stored image cannot be used
missing_im = Image.open('missing.png')
threshold = 0.85  # NOTE(review): not referenced anywhere in this file


def encode_text(text: str):
    """Embed *text* with CLIP's text encoder.

    Returns:
        The embedding as a nested Python list of floats, as produced by
        ``Tensor.tolist()`` (a batch of one row for a single string).
    """
    # create transformer-readable tokens; truncate so over-long prompts do not
    # exceed the tokenizer's model_max_length and crash the text encoder
    inputs = tokenizer(text, truncation=True, return_tensors="pt").to(device)
    # inference only: skip building the autograd graph
    with torch.no_grad():
        text_emb = model.get_text_features(**inputs)
    return text_emb.cpu().tolist()


def _query_index(embeds, top_k: int, **kwargs):
    """Query the Pinecone index, reconnecting once if the first attempt fails.

    The reconnected handle replaces the module-level ``index`` so later calls
    do not keep hitting a stale connection.

    Args:
        embeds: query vector(s) as returned by ``encode_text``.
        top_k: number of matches to return.
        **kwargs: extra arguments for ``Index.query`` (e.g. ``filter``).

    Raises:
        ValueError: if the query still fails after reconnecting.
    """
    global index
    try:
        return index.query(embeds, top_k=top_k, include_metadata=True, **kwargs)
    except Exception as e:
        print(f"Error during query: {e}")
    # reinitialize connection and keep the fresh handle for future calls
    pinecone.init(api_key=PINECONE_KEY, environment=PINECONE_ENV)
    index = pinecone.Index(index_id)
    try:
        xc = index.query(embeds, top_k=top_k, include_metadata=True, **kwargs)
    except Exception as e:
        raise ValueError(e) from e
    print("Reinitialized query successful")
    return xc


def prompt_query(text: str):
    """Return the ids of the 30 index entries closest to *text*."""
    print(f"Running prompt_query('{text}')")
    embeds = encode_text(text)
    xc = _query_index(embeds, top_k=30)
    return [match['id'] for match in xc['matches']]


def get_image(url: str):
    """Download the blob at *url* from the bucket and open it as a PIL image."""
    blob = bucket.blob(url).download_as_bytes()
    return Image.open(io.BytesIO(blob))


def test_image(_id, image):
    """Check that *image* decodes fully; purge it from storage if it does not.

    Re-encoding forces PIL to read all pixel data, which surfaces truncated or
    corrupt files. Encoding into memory (rather than a shared 'tmp.png' on
    disk) avoids races between concurrent requests and keeps disk errors from
    being mistaken for a corrupt image, which matters because a failure here
    permanently deletes both the vector and the blob.

    Returns:
        True if the image decoded; False if it was corrupt and was deleted.
    """
    try:
        image.save(io.BytesIO(), format='PNG')
    except OSError:
        # delete corrupted file from pinecone and cloud
        index.delete(ids=[_id])
        bucket.blob(f"images/{_id}.png").delete()
        return False
    return True


def prompt_image(text: str):
    """Find the 9 closest images to *text* whose ``image_nsfw`` score is < 0.5.

    Returns:
        (images, scores): PIL images in match order and the corresponding
        similarity scores. An entry whose blob cannot be downloaded/opened or
        is corrupt is replaced by ``missing_im`` so both lists stay aligned.
    """
    embeds = encode_text(text)
    xc = _query_index(embeds, top_k=9, filter={"image_nsfw": {"$lt": 0.5}})
    matches = xc['matches']
    scores = [match['score'] for match in matches]
    images = []
    for match in matches:
        _id = match['id']
        image_url = f"images/{_id}.png"
        try:
            im = get_image(image_url)
            images.append(im if test_image(_id, im) else missing_im)
        except (ValueError, OSError) as e:
            # OSError covers PIL.UnidentifiedImageError from Image.open
            print(f"{type(e).__name__}: '{image_url}'")
            images.append(missing_im)
    return images, scores


# __APP FUNCTIONS__
def set_suggestion(text: str):
    """Return a TextArea update whose value is ``text[0]``."""
    # NOTE(review): for a plain string text[0] is its first character; this
    # presumably expects a list-like value (e.g. a gr.Dataset sample) — confirm.
    return gr.TextArea.update(value=text[0])


def set_images(text: str):
    """Search for *text* and return a Gallery update with the result images."""
    images, _scores = prompt_image(text)
    return gr.Gallery.update(value=images)


# __CREATE APP__
demo = gr.Blocks()
# __APP LAYOUT__
with demo:
    gr.HTML(
        """
"""
    )
    with gr.Row():
        with gr.Column():
            prompt = gr.TextArea(
                value="space dogs",
                placeholder="Something cool to search for...",
                interactive=True
            )
            search = gr.Button(value="Search!")
            gr.Markdown(
                """
#### Search through 10K images generated by AI

This app demonstrates the idea of text-to-image search. The search process uses an AI model that understands the *meaning* of text and images to identify images that best align to a search prompt.

🪄 [*Built with the OP Stack*](https://gkogan.notion.site/gkogan/The-OP-Stack-aafcab0005e3445a8ad8491aac80446c)
"""
            )
        # results column
        with gr.Column():
            pics = gr.Gallery()
            pics.style(grid=3)
    # search event listening: click() only registers the handler, so errors
    # raised inside set_images surface when the button is pressed, not here;
    # a try/except around the registration would never catch them
    search.click(set_images, prompt, pics)

demo.launch()