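"""Streamlit demo: reverse image search over pathology tweets with PLIP.

An uploaded image is embedded with the PLIP dual-encoder model (vinid/plip)
and matched against precomputed tweet-image embeddings by cosine similarity;
the most similar tweet is rendered inline.
"""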
import streamlit as st
import streamlit.components.v1 as components
import pandas as pd
import numpy as np
import tokenizers
import torch
from PIL import Image
from transformers import CLIPModel, AutoProcessor


def embed_images(model, images, processor):
    # Preprocess the PIL images and embed them with the image tower.
    inputs = processor(images=images, return_tensors="pt")

    with torch.no_grad():
        embeddings = model.get_image_features(pixel_values=inputs["pixel_values"])
    return embeddings
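

# Sketch of the text-side counterpart, for symmetry with embed_images. This
# helper is an illustrative assumption: it is not wired into the app, and it
# assumes the same AutoProcessor handles tokenization (PLIP is a CLIP-style
# dual encoder, so get_text_features is available).
def embed_texts(model, texts, processor):
    # Tokenize the query strings and embed them with the text tower.
    inputs = processor(text=texts, return_tensors="pt", padding=True)

    with torch.no_grad():
        embeddings = model.get_text_features(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
    return embeddings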

@st.cache
def load_embeddings(embeddings_path):
    # Precomputed tweet-image embeddings; assumed to be unit-norm, since the
    # dot product in app() is treated as cosine similarity.
    print("loading embeddings")
    return np.load(embeddings_path)

@st.cache(
    hash_funcs={
        # Model parameters and tokenizer internals are not hashable by
        # st.cache, so skip them; the loaded objects are cached by reference.
        torch.nn.parameter.Parameter: lambda _: None,
        tokenizers.Tokenizer: lambda _: None,
        tokenizers.AddedToken: lambda _: None
    }
)
def load_path_clip():
    model = CLIPModel.from_pretrained("vinid/plip")
    processor = AutoProcessor.from_pretrained("vinid/plip")
    return model, processor


def app():
    st.title('PLIP Image Search')

    # Per-tweet metadata aligned row-for-row with the precomputed embeddings;
    # plip_imgURL is loaded for completeness but not used below.
    plip_imgURL = pd.read_csv("tweet_eval_retrieval.tsv", sep="\t")
    plip_weblink = pd.read_csv("tweet_eval_retrieval_twlnk.tsv", sep="\t")

    model, processor = load_path_clip()

    image_embedding = load_embeddings("tweet_eval_embeddings.npy")

    query = st.file_uploader("Choose a file")

    if query:
        image = Image.open(query)
        single_image = embed_images(model, [image], processor)[0].detach().cpu().numpy()

        # L2-normalize the query so the dot product below is cosine similarity.
        single_image = single_image / np.linalg.norm(single_image)

        # Rank corpus images by cosine similarity, highest first.
        similarity_scores = single_image.dot(image_embedding.T)
        id_sorted = np.argsort(similarity_scores)[::-1]

        best_id = id_sorted[0]
        score = similarity_scores[best_id]

        target_weblink = plip_weblink.iloc[best_id]["weblink"]

        st.caption('Most relevant image (similarity = %.4f)' % score)

        # Render the best-matching tweet inline via Twitter's widgets.js embed.
        components.html('''
            <blockquote class="twitter-tweet">
                <a href="%s"></a>
            </blockquote>
            <script async src="https://platform.twitter.com/widgets.js" charset="utf-8">
            </script>
            ''' % target_weblink,
        height=600)
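

# This module only defines app(); assuming the file is used as the Space's
# entry point (e.g. `streamlit run app.py`), call it when executed directly.
if __name__ == "__main__":
    app()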