vinid commited on
Commit
af14fe8
1 Parent(s): 3c34cd0

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +63 -0
  2. tweet_eval_embeddings.npy +3 -0
  3. tweet_eval_retrieval.tsv +0 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from plip_support import embed_text
4
+ import numpy as np
5
+ from PIL import Image
6
+ import requests
7
+ from io import BytesIO
8
+ import streamlit as st
9
+ import clip
10
+ import torch
11
+ from transformers import (
12
+ VisionTextDualEncoderModel,
13
+ AutoFeatureExtractor,
14
+ AutoTokenizer
15
+ )
16
+ from transformers import AutoProcessor
17
+
18
+
19
+ def embed_texts(model, texts, processor):
20
+ inputs = processor(text=texts, padding="longest")
21
+ input_ids = torch.tensor(inputs["input_ids"])
22
+ attention_mask = torch.tensor(inputs["attention_mask"])
23
+
24
+ with torch.no_grad():
25
+ embeddings = model.get_text_features(
26
+ input_ids=input_ids, attention_mask=attention_mask
27
+ )
28
+ return embeddings
29
+
30
+ @st.cache_resource
31
+ def load_embeddings(embeddings_path):
32
+ print("loading embeddings")
33
+ return np.load(embeddings_path)
34
+
35
+ @st.cache_resource
36
+ def load_path_clip():
37
+ model = VisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian")
38
+ processor = AutoProcessor.from_pretrained("clip-italian/clip-italian")
39
+ return model, processor
40
+
41
+ st.title('PLIP Image Search')
42
+
43
+ plip_dataset = pd.read_csv("tweet_eval_retrieval.tsv", sep="\t")
44
+
45
+ model, processor = load_path_clip()
46
+
47
+ image_embedding = load_embeddings("tweet_eval_embeddings.npy")
48
+
49
+ query = st.text_input('Search Query', '')
50
+
51
+
52
+ if query:
53
+
54
+ text_embedding = embed_texts(model, [query], processor)[0].detach().cpu().numpy()
55
+
56
+ text_embedding = text_embedding/np.linalg.norm(text_embedding)
57
+
58
+ best_id = np.argmax(text_embedding.dot(image_embedding.T))
59
+ url = (plip_dataset.iloc[best_id]["imageURL"])
60
+
61
+ response = requests.get(url)
62
+ img = Image.open(BytesIO(response.content))
63
+ st.image(img)
tweet_eval_embeddings.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36e445b069b1d937a0a780ddeab9239df5fd13264e8cd1f6cf033be3210352e1
3
+ size 2401408
tweet_eval_retrieval.tsv ADDED
The diff for this file is too large to render. See raw diff