ysharma HF staff commited on
Commit
56317d8
1 Parent(s): 5850b39

create app file

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from PIL import Image
4
+ import base64
5
+ import requests
6
+ from langchain.embeddings import HuggingFaceEmbeddings
7
+ from langchain.llms import OpenAI
8
+ from langchain.chains.qa_with_sources import load_qa_with_sources_chain
9
+ from langchain.docstore.document import Document
10
+ from langchain.embeddings.openai import OpenAIEmbeddings
11
+ from langchain.vectorstores.faiss import FAISS
12
+ import pickle
13
+
14
+ os.environ["HUGGINGFACEHUB_API_TOKEN"] = HUGGINGFACEHUB_API_TOKEN
15
+
16
+ model_name = "sentence-transformers/all-mpnet-base-v2"
17
+ hf = HuggingFaceEmbeddings(model_name=model_name)
18
+
19
+ #Loading FAISS search index from disk
20
+ #This is a vector space of embeddings from one-tenth of PlaygrondAI image-prompts
21
+ #PlaygrondAI open-sourced dataset is a collection of around 1.3 mil generated images and caption pairs
22
+ with open("search_index0.pickle", "rb") as f:
23
+ search_index = pickle.load(f)
24
+
25
+ #Defining methods for inference
26
+ def encode(img):
27
+ #Encode source image file to base64 string
28
+ with open(img, "rb") as image_file:
29
+ encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
30
+ #Returning image as encoded string
31
+ return encoded_string
32
+
33
+ def get_caption(image_in):
34
+ #Sending requests to BLIP2 Gradio-space API
35
+ BLIP2_GRADIO_API_URL = "https://nielsr-comparing-captioning-models.hf.space/run/predict"
36
+ response = requests.post(BLIP2_GRADIO_API_URL, json={
37
+ "data": ["data:image/jpg;base64," + encode(image_in) ]
38
+ }).json()
39
+ data = response["data"][-1]
40
+ return data
41
+
42
+ def Image_similarity_search(image_in):
43
+ #Get image caption from Bip2 Gradio space
44
+ img_caption = get_caption(image_in)
45
+ print(f"Image caption from Blip2 Gradio Space is - {img_caption}")
46
+ #Searching the vector space
47
+ search_result = search_index.similarity_search(img_caption)[0]
48
+ #Formatting the search results
49
+ pai_prompt = list(search_result)[0][1]
50
+ pai_img_link = list(search_result)[-2][-1]['source']
51
+ #formatting html output for displaying image
52
+ html_tag = "<img src='"+pai_img_link+"' alt='"+img_caption+"' height='512' style='display: block; margin: auto;'>"
53
+ return pai_prompt, html_tag
54
+
55
+ #Defining Gradio Blocks
56
+ with gr.Blocks(css = """#label_mid {padding-top: 2px; padding-bottom: 2px;}
57
+ #label_results {padding-top: 5px; padding-bottom: 1px;}
58
+ """) as demo:
59
+ with gr.Column(scale=2):
60
+ pass
61
+ with gr.Column(scale=1):
62
+ label_top = gr.HTML(value= "<center>🖼️Upload an Image for your search📷</center>", elem_id="label_top")
63
+ image_in = gr.Image(label="Upoload an Image for search", type='filepath', elem_id="image_in")
64
+ label_mid = gr.HTML(value= "<p style='text-align: center; color: red;'>Or</center></p>", elem_id='label_mid')
65
+ label_bottom = gr.HTML(value= "<center>🔍Type in your serch query and press Enter</center>", elem_id="label_bottom")
66
+ search_query = gr.Textbox(placeholder="Example: A small cat sitting", label="", elem_id="search_query")
67
+ #b1 = gr.Button("Search").style(full_width=False)
68
+ label_results = gr.HTML(value= "<p style='text-align: center; color: blue; font-weight: bold;'>Search results from PlaygroundAI</center></p>", elem_id="label_results")
69
+ img_search = gr.HTML(label = 'Image search results from PlaygroundAI dataset', elem_id="img_search")
70
+ pai_prompt = gr.Textbox(label="Image prompt from PlaygroundAI dataset", elem_id="pai_prompt")
71
+ with gr.Column(scale=2):
72
+ pass
73
+
74
+ image_in.change(Image_similarity_search, image_in, [pai_prompt, img_search] )
75
+ #b1.click(Image_similarity_search, image_in, [pai_prompt, img_search] )
76
+
77
+ demo.launch(debug=True)