File size: 1,705 Bytes
af14fe8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d8c3351
af14fe8
 
 
 
d8c3351
af14fe8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import streamlit as st
import pandas as pd
from plip_support import embed_text
import numpy as np
from PIL import Image
import requests
from io import BytesIO
import streamlit as st
import clip
import torch
from transformers import (
    VisionTextDualEncoderModel,
    AutoFeatureExtractor,
    AutoTokenizer
)
from transformers import AutoProcessor


def embed_texts(model, texts, processor):
    """Encode a batch of strings into text-feature vectors.

    Args:
        model: Dual-encoder model exposing ``get_text_features``.
        texts: List of query strings to encode.
        processor: Callable that tokenizes ``texts`` and returns a mapping
            with ``"input_ids"`` and ``"attention_mask"`` entries.

    Returns:
        Whatever ``model.get_text_features`` returns for the batch, computed
        with autograd disabled.
    """
    # Pad every sequence to the longest one in the batch so the lists form
    # rectangular tensors.
    encoded = processor(text=texts, padding="longest")
    model_inputs = {
        field: torch.tensor(encoded[field])
        for field in ("input_ids", "attention_mask")
    }

    # Inference only: no gradient graph is needed.
    with torch.no_grad():
        return model.get_text_features(**model_inputs)

@st.cache_resource
def load_embeddings(embeddings_path):
    """Read the precomputed embedding array stored at *embeddings_path*.

    Wrapped in ``st.cache_resource`` so the array is shared across Streamlit
    reruns; the message below is only printed when the cache is populated.
    """
    print("loading embeddings")
    embeddings = np.load(file=embeddings_path)
    return embeddings

@st.cache_resource
def load_path_clip():
    """Load the dual-encoder model and its matching processor.

    Cached with ``st.cache_resource`` so the weights are loaded once and
    reused across Streamlit reruns.

    Returns:
        A ``(model, processor)`` tuple for the ``clip-italian/clip-italian``
        checkpoint.
    """
    checkpoint = "clip-italian/clip-italian"
    dual_encoder = VisionTextDualEncoderModel.from_pretrained(checkpoint)
    text_image_processor = AutoProcessor.from_pretrained(checkpoint)
    return dual_encoder, text_image_processor

st.title('PLIP Image Search')

# Metadata for the searchable images. Row i is looked up with the index of
# the best-scoring embedding row, so the TSV and the .npy file are assumed
# to share the same row order (NOTE(review): confirm against the script that
# produced the embeddings).
plip_dataset = pd.read_csv("tweet_eval_retrieval.tsv", sep="\t")

model, processor = load_path_clip()

image_embedding = load_embeddings("tweet_eval_embeddings.npy")

query = st.text_input('Search Query', '')

# Seconds to wait for the image host before giving up, so an unresponsive
# server cannot hang the Streamlit script indefinitely.
IMAGE_FETCH_TIMEOUT = 10

# Skip empty or whitespace-only input instead of embedding a blank query.
if query.strip():

    text_embedding = embed_texts(model, [query], processor)[0].detach().cpu().numpy()

    text_embedding = text_embedding / np.linalg.norm(text_embedding)

    # Score every image by dot product with the query and keep the best one.
    # NOTE(review): only the text vector is normalised here; unless the stored
    # image embeddings are unit-length, this ranks by raw dot product rather
    # than cosine similarity — confirm how the .npy file was produced.
    best_id = np.argmax(text_embedding.dot(image_embedding.T))
    url = plip_dataset.iloc[best_id]["imageURL"]

    try:
        response = requests.get(url, timeout=IMAGE_FETCH_TIMEOUT)
        # An HTTP error page is not an image; fail here with a clear message.
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
        # Image.open is lazy; decode now so corrupt data is caught below.
        img.load()
    except (requests.RequestException, OSError) as err:
        st.error(f"Could not load image from {url}: {err}")
    else:
        st.image(img)