vinid committed
Commit 7987133
1 Parent(s): dc3cb2a

image2image search

Files changed (3):
  1. app.py +5 -19
  2. home.py +1 -1
  3. image2image.py +109 -0
app.py CHANGED
@@ -1,31 +1,17 @@
-import streamlit as st
-import pandas as pd
 import home
-import numpy as np
-from PIL import Image
-import requests
-import transformers
 import text2image
-import zeroshot
-import tokenizers
-from io import BytesIO
+import image2image
 import streamlit as st
-from transformers import CLIPModel
-import clip
-import torch
-from transformers import (
-    VisionTextDualEncoderModel,
-    AutoFeatureExtractor,
-    AutoTokenizer
-)
-from transformers import AutoProcessor
+
+
+
 
 st.sidebar.title("Explore our PLIP Demo")
 
 PAGES = {
     "Introduction": home,
     "Text to Image": text2image,
-    "Image Prediction": zeroshot,
+    "Image to Image": image2image,
 }
 
 page = st.sidebar.radio("", list(PAGES.keys()))
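After this change, routing in app.py reduces to a dict from sidebar label to page module, with each module expected to expose an app() entry point. The dispatch call itself sits outside this hunk, so the final line below is an assumption about the surrounding code, not part of the commit; a minimal sketch of the pattern:

import streamlit as st
import home
import text2image
import image2image

# Map sidebar labels to page modules; each module defines app().
PAGES = {
    "Introduction": home,
    "Text to Image": text2image,
    "Image to Image": image2image,
}

# The radio returns the selected label, which indexes back into PAGES.
page = st.sidebar.radio("", list(PAGES.keys()))
PAGES[page].app()  # assumed dispatch; not shown in this hunk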
home.py CHANGED
@@ -11,7 +11,7 @@ def app():
     intro_markdown = read_markdown_file("introduction.md")
     st.markdown(intro_markdown, unsafe_allow_html=True)
 
-    st.text('An example of twitter:')
+    st.text('An example of tweet:')
     components.html('''
     <blockquote class="twitter-tweet">
     <a href="https://twitter.com/xxx/status/1580753362059788288"></a>
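home.py and the new image2image.py now render tweets through the same blockquote-plus-widgets.js snippet. A sketch of what a shared helper could look like; render_tweet is a hypothetical name, not something this commit defines:

import streamlit.components.v1 as components

def render_tweet(url, height=600):
    # Hypothetical helper: widgets.js upgrades the bare <blockquote>
    # into a fully rendered embedded tweet inside the iframe.
    components.html('''
        <blockquote class="twitter-tweet">
        <a href="%s"></a>
        </blockquote>
        <script async src="https://platform.twitter.com/widgets.js" charset="utf-8"></script>
    ''' % url, height=height)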
image2image.py ADDED
@@ -0,0 +1,109 @@
+import streamlit as st
+import pandas as pd
+from plip_support import embed_text
+import numpy as np
+from PIL import Image
+import requests
+import tokenizers
+from io import BytesIO
+import torch
+from transformers import (
+    VisionTextDualEncoderModel,
+    AutoFeatureExtractor,
+    AutoTokenizer,
+    CLIPModel,
+    AutoProcessor
+)
+import streamlit.components.v1 as components
+
+
+def embed_images(model, images, processor):
+    inputs = processor(images=images)
+    pixel_values = torch.tensor(np.array(inputs["pixel_values"]))
+
+    with torch.no_grad():
+        embeddings = model.get_image_features(pixel_values=pixel_values)
+    return embeddings
+
+@st.cache
+def load_embeddings(embeddings_path):
+    print("loading embeddings")
+    return np.load(embeddings_path)
+
+@st.cache(
+    hash_funcs={
+        torch.nn.parameter.Parameter: lambda _: None,
+        tokenizers.Tokenizer: lambda _: None,
+        tokenizers.AddedToken: lambda _: None
+    }
+)
+def load_path_clip():
+    model = CLIPModel.from_pretrained("vinid/plip")
+    processor = AutoProcessor.from_pretrained("vinid/plip")
+    return model, processor
+
+
+def app():
+    st.title('PLIP Image Search')
+
+    plip_imgURL = pd.read_csv("tweet_eval_retrieval.tsv", sep="\t")
+    plip_weblink = pd.read_csv("tweet_eval_retrieval_twlnk.tsv", sep="\t")
+
+    model, processor = load_path_clip()
+
+    image_embedding = load_embeddings("tweet_eval_embeddings.npy")
+
+    query = st.file_uploader("Choose a file")
+
+
+    if query:
+        image = Image.open(query)
+        single_image = embed_images(model, [image], processor)[0].detach().cpu().numpy()
+
+        single_image = single_image/np.linalg.norm(single_image)
+
+        # Sort IDs by cosine-similarity from high to low
+        similarity_scores = single_image.dot(image_embedding.T)
+        id_sorted = np.argsort(similarity_scores)[::-1]
+
+
+        best_id = id_sorted[0]
+        score = similarity_scores[best_id]
+
+        target_weblink = plip_weblink.iloc[best_id]["weblink"]
+
+        st.caption('Most relevant image (similarity = %.4f)' % score)
+
+        components.html('''
+        <blockquote class="twitter-tweet">
+        <a href="%s"></a>
+        </blockquote>
+        <script async src="https://platform.twitter.com/widgets.js" charset="utf-8">
+        </script>
+        ''' % target_weblink,
+        height=600)
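Note that app() normalizes only the query vector before the dot product, so the scores are true cosine similarities only if the rows of tweet_eval_embeddings.npy were L2-normalized when they were precomputed offline; the commit relies on that. A standalone sketch of the retrieval step that normalizes both sides, so it is correct regardless of how the corpus was saved:

import numpy as np

def top_k_by_cosine(query_vec, corpus, k=5):
    # Normalize the query and every corpus row so the dot product
    # is exactly cosine similarity, whatever the input scale.
    q = query_vec / np.linalg.norm(query_vec)
    c = corpus / np.linalg.norm(corpus, axis=1, keepdims=True)
    scores = c @ q
    top = np.argsort(scores)[::-1][:k]  # best match first
    return top, scores[top]

Against the names in this file, usage would be, e.g., top_k_by_cosine(single_image, image_embedding, k=5).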
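The hash_funcs block on load_path_clip exists because legacy @st.cache hashes a cached function's arguments and referenced objects, and torch parameters and tokenizers objects are unhashable; mapping their types to lambda _: None tells Streamlit to skip them. On Streamlit 1.18+, the idiomatic replacement is st.cache_resource, which never hashes the returned object; a sketch assuming a newer Streamlit than this commit targets:

import streamlit as st
from transformers import CLIPModel, AutoProcessor

@st.cache_resource  # load once per process; no hash_funcs workaround needed
def load_path_clip():
    model = CLIPModel.from_pretrained("vinid/plip")
    processor = AutoProcessor.from_pretrained("vinid/plip")
    return model, processor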