File size: 7,327 Bytes
56317d8
 
 
 
 
 
574c554
56317d8
 
 
 
 
 
8ecc4c6
56317d8
 
 
 
 
 
 
 
a6aa0d1
 
 
 
7e978a1
a6aa0d1
56317d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf96618
 
 
 
 
 
 
a6aa0d1
56317d8
 
a6aa0d1
 
 
 
56317d8
 
a6aa0d1
 
 
56317d8
a6aa0d1
 
 
 
addfd62
88070e7
 
 
 
4ca2a76
88070e7
56317d8
addfd62
56317d8
 
 
34c3b4c
4f9fd57
4ca2a76
56317d8
7e1cabd
bf96618
 
 
 
 
 
 
 
 
 
 
 
 
7e1cabd
7b617af
 
 
 
bf96618
7b617af
bf96618
 
7b617af
34c3b4c
 
 
b01002a
56317d8
 
4f9fd57
bf96618
81827f6
56317d8
3ef73e4
12bca30
3cfb917
 
10264fb
56317d8
3ef73e4
 
8ace82b
56317d8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import base64
import html
import os
import pickle

import gradio as gr
import requests
from PIL import Image
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
#from langchain.llms import OpenAI
from langchain.vectorstores.faiss import FAISS

# Read at startup; a missing variable raises KeyError here. The value is not
# referenced anywhere else in this file.
HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]
model_name = "sentence-transformers/all-mpnet-base-v2"
# NOTE(review): `hf` is not referenced elsewhere in this file; presumably kept so
# the embedding model is available for the unpickled indexes — TODO confirm.
hf = HuggingFaceEmbeddings(model_name=model_name)


def _load_index(path):
    """Unpickle and return the FAISS search index stored in the file *path*.

    NOTE(review): pickle.load can execute arbitrary code embedded in the file,
    so only point this at index files produced by a trusted source.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Load the FAISS search indexes from disk, in order search_index0 .. search_index3.
# Each is a vector space of embeddings of PlaygroundAI image prompts (roughly
# one-tenth of the prompts, per the original author's note). The PlaygroundAI
# open-sourced dataset is a collection of around 1.3 million generated images
# and caption pairs.
search_index, search_index1, search_index2, search_index3 = [
    _load_index(f"search_index{shard}.pickle") for shard in range(4)
]

#Defining methods for inference
def encode(img):
    """Read the file at path *img* and return its bytes as a base64 string.

    The base64 output is decoded to ``str`` (UTF-8) so callers can splice it
    directly into a data URI.
    """
    with open(img, "rb") as image_file:
        raw_bytes = image_file.read()
    b64_bytes = base64.b64encode(raw_bytes)
    return b64_bytes.decode('utf-8')

def get_caption(image_in, timeout=120):
    """Fetch a caption for a local image from the BLIP2 Gradio Space API.

    The image file is base64-encoded with ``encode``, wrapped in a data URI and
    POSTed to the Space's ``/run/predict`` endpoint.

    Args:
        image_in: Path to the image file on disk.
        timeout: Seconds to wait for the remote Space before giving up. Without
            a timeout a stalled Space would block this call indefinitely.

    Returns:
        The last element of the ``"data"`` list in the JSON response
        (presumably the BLIP2 caption — TODO confirm on the Space's "view api" page).

    Raises:
        requests.HTTPError: if the Space answers with a 4xx/5xx status.
        requests.Timeout: if no response arrives within *timeout* seconds.
    """
    BLIP2_GRADIO_API_URL = "https://nielsr-comparing-captioning-models.hf.space/run/predict"
    payload = {"data": ["data:image/jpg;base64," + encode(image_in)]}
    response = requests.post(BLIP2_GRADIO_API_URL, json=payload, timeout=timeout)
    # Surface a clear HTTP error instead of a confusing JSON-decode/KeyError
    # when the Space is sleeping, rebuilding, or rejecting the request.
    response.raise_for_status()
    return response.json()["data"][-1]

def Image_similarity_search(image_in, search_query, indexes=None):
    """Find PlaygroundAI images matching a text query or an uploaded image.

    A non-blank *search_query* takes precedence. Otherwise the uploaded image is
    captioned with ``get_caption`` and that caption becomes the query. The
    query is run against every search index and the top hit of each is
    rendered as an ``<img>`` inside a horizontally scrolling flex row.

    Args:
        image_in: Filepath of the uploaded image, or ``None`` when no image is
            present (e.g. after the image has been cleared).
        search_query: Free-text query from the textbox; may be empty or None.
        indexes: Optional iterable of vector stores exposing
            ``similarity_search(query)`` that returns documents with a
            ``metadata["source"]`` image link. Defaults to the four
            module-level search indexes.

    Returns:
        An HTML string with one ``<img>`` per index that returned a hit, or
        ``""`` when there is neither a query nor an image to search with.
    """
    query = (search_query or "").strip()
    if query:
        img_caption = query
    elif image_in:
        # Get an image caption from the BLIP2 Gradio Space.
        img_caption = get_caption(image_in)
    else:
        # Nothing to search for; calling get_caption(None) would fail in open().
        return ""
    print(f"Image caption from Blip2 Gradio Space or the search_query is - {img_caption}")

    if indexes is None:
        indexes = (search_index, search_index1, search_index2, search_index3)

    # Escape everything interpolated into markup: the query is user input, and
    # the caption and links come from external sources.
    alt_text = html.escape(str(img_caption), quote=True)
    img_tags = []
    for index in indexes:
        hits = index.similarity_search(img_caption)
        if not hits:
            continue
        # Read the Document attribute directly instead of positional indexing
        # into list(document), which depends on the model's field order.
        img_link = html.escape(str(hits[0].metadata["source"]), quote=True)
        img_tags.append(
            f"<img src='{img_link}' alt='{alt_text}' style='display: block; margin: auto;'>"
        )

    gallery = "\n".join(img_tags)
    return (
        '<div style="display: flex; flex-direction: row; overflow-x: auto;">\n'
        f"{gallery}\n"
        "</div>"
    )

                        
# Defining the Gradio Blocks UI
with gr.Blocks(css = """#label_mid {padding-top: 2px; padding-bottom: 2px;}
                        #label_results {padding-top: 5px; padding-bottom: 1px;}
                        #col-container {max-width: 580px; margin-left: auto; margin-right: auto;}
                        #accordion {max-width: 580px; margin-left: auto; margin-right: auto;}
                        #img_search img {margin: 10px; max-width: 300px; max-height: 300px;}
                        """) as demo:
    # Page header. The "Niels space" link points at the Hugging Face page of the
    # captioning Space whose API is called for BLIP2 captions.
    gr.HTML("""<div style="text-align: center; max-width: 700px; margin: 0 auto;">
        <div
        style="
            display: inline-flex;
            align-items: center;
            gap: 0.8rem;
            font-size: 1.75rem;
        "
        >
        <h1 style="font-weight: 900; margin-bottom: 7px; margin-top: 5px;">
            Using Gradio Demos as API</h1><br></div>
        <div><h4 style="font-weight: 500; margin-bottom: 7px; margin-top: 5px;">
            Get BLIP2 captions from <a href="https://huggingface.co/spaces/nielsr/comparing-captioning-models" target="_blank">Niels space</a> via API call,<br>
            Use LangChain to create vector space with PlaygroundAI prompts</h4><br>
        </div></div>""")
    with gr.Accordion(label="Details about the working:", open=False, elem_id='accordion'):
        gr.HTML("""
        <p style="margin-bottom: 10px; font-size: 90%"><br>
        ▶️Do you see the "view api" link located in the footer of this application?
        By clicking on this link, a page will open which provides documentation on the REST API that developers can use to query the Interface function / Block events.<br>
        ▶️In this demo, the first step involves making an API call to the BLIP2 Gradio demo to retrieve image captions.
        Next, LangChain is used to create an embedding and vector space for the image prompts and their respective "source" from the PlaygroundAI dataset.
        Finally, a similarity search is performed over the vector space and the top result from each search index is returned.
        </p>""")
    with gr.Column(elem_id="col-container"):
        label_top = gr.HTML(value="<center>🖼️ Please upload an Image here👇 that will be used as your search query</center>", elem_id="label_top")
        image_in = gr.Image(label="Upload an Image for search", type='filepath', elem_id="image_in")
        label_mid = gr.HTML(value="<p style='text-align: center; color: red;'>Or</p>", elem_id='label_mid')
        label_bottom = gr.HTML(value="<center>🔍Type in your search query and press Enter 👇</center>", elem_id="label_bottom")
        search_query = gr.Textbox(placeholder="Example: A small cat sitting", label="", elem_id="search_query", value='')
        label_results = gr.HTML(value="<p style='text-align: center; color: blue; font-weight: bold;'>👇These Search results are from PlaygroundAI 'Liked_Images' dataset available on <a href='https://github.com/playgroundai/liked_images' target='_blank'>github</a></p>", elem_id="label_results")
        img_search = gr.HTML(label='Image search results from PlaygroundAI dataset', elem_id="img_search")

    gr.HTML('''<center><a href="https://huggingface.co/spaces/ysharma/Blip_PlaygroundAI?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a></center>''')

    # Both events run the same search. The api_name values name the endpoints
    # documented on the "view api" page, so keep them stable for external callers.
    image_in.change(Image_similarity_search, [image_in, search_query], [img_search], api_name="PlaygroundAI_image_search")
    search_query.submit(Image_similarity_search, [image_in, search_query], [img_search], api_name='PlaygroundAI_text_search')

demo.launch(debug=True)