File size: 3,879 Bytes
d553e7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff372e6
 
 
 
 
 
 
 
 
 
d553e7f
 
 
 
ff372e6
 
d553e7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff372e6
d553e7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import pickle
from io import BytesIO

import pandas as pd
import requests
import streamlit as st

from inference import retrieve, rerank


def get_data(results: pd.DataFrame, data: pd.DataFrame, reranked=False):
    """Collect the fields the UI needs for each retrieved series.

    Args:
        results: Retrieval (or re-ranking) output. Must contain a
            ``corpus_id`` column of positional row indices into ``data``,
            plus a ``score`` column, or a ``cross-score`` column when
            ``reranked`` is true.
        data: Corpus table with ``romaji``, ``cover`` and ``url`` columns.
        reranked: If true, read scores from ``cross-score`` instead of
            ``score``.

    Returns:
        Four parallel lists ``(titles, scores, covers, urls)`` ordered like
        the rows of ``results``. All are empty when ``results`` is empty.
    """
    score_column = "cross-score" if reranked else "score"
    scores = results[score_column].tolist()

    # One positional selection for all hits, instead of materialising a row
    # Series three times per result.
    rows = data.iloc[results["corpus_id"].tolist()]
    titles = rows["romaji"].tolist()
    covers = rows["cover"].tolist()
    urls = rows["url"].tolist()

    return titles, scores, covers, urls


def add_descriptions_to_results(results: pd.DataFrame, corpus=None):
    """Add the corresponding description to the retrieval results.

    Args:
        results: Retrieval output with a ``corpus_id`` column of positional
            row indices into the corpus table. It is modified in place: a
            ``desc`` column is added (or overwritten).
        corpus: Corpus table whose ``input`` column supplies the
            descriptions. When omitted, falls back to the module-level
            ``data`` table, which preserves the original call form
            ``add_descriptions_to_results(results)``.

    Returns:
        The same ``results`` object, now including the ``desc`` column.

    Raises:
        NameError: If ``corpus`` is omitted and no module-level ``data``
            has been assigned yet.
    """
    if corpus is None:
        # Backward-compatible fallback to the global table assigned by the
        # search handler before this function is called.
        corpus = data
    idxs = results["corpus_id"].tolist()
    # Positional assignment of a plain list, so the index labels of
    # ``results`` do not need to line up with those of ``corpus``.
    results["desc"] = corpus.iloc[idxs]["input"].tolist()
    return results


# Input UI: page header, app description and the free-text query box.
st.title("Manga Semantic Search")
st.markdown(
    """
    An application to search for manga series using text descriptions of their content. Find the name of a series based on a vague recollection about its story.
    Performs a semantic retrieve and re-rank search using Sentence Transformers models.
    Source code for this application can be found in the repo [here](https://github.com/bwconrad/manga-semantic-search).

    __Note__: The current database only includes the top-1000 manga series on AniList.
    """
)

query = st.text_input(
    "Enter a description of the manga you are searching for:",
    value="",
)


# Corpus selection. Sorting makes the dropdown order (and the default choice)
# stable across platforms; hidden files such as ".DS_Store" are not corpora.
embedding_files = sorted(
    name for name in os.listdir("embeddings") if not name.startswith(".")
)
if not embedding_files:
    # Without this guard selectbox returns None and the model-name parsing
    # below fails with an IndexError.
    st.error("No embedding files found in the 'embeddings' directory.")
    st.stop()

embeddings_path = st.selectbox("Embeddings Corpus", embedding_files)
top_k = st.number_input(
    "Number of results", value=5, min_value=1, max_value=100, step=1
)
do_rerank = st.checkbox("Re-Rank", value=True)
k_retrieve = None
if do_rerank:
    k_retrieve = st.number_input(
        "Number of initially retrieved series",
        value=50,
        min_value=1,
        max_value=500,
        step=1,
    )


# Convert UI values into the correct function argument values.
# The model name is the second-to-last dot-separated part of the file name
# (assumed naming scheme "<...>.<model_name>.<ext>" -- confirm with the
# script that writes these files).
model_name = str(embeddings_path).split(".")[-2]
embeddings_path = os.path.join("embeddings", str(embeddings_path))


# Output UI
if st.button("Search"):
    # Without re-ranking, retrieve exactly the number of results to show.
    if not k_retrieve:
        k_retrieve = top_k

    # Check that query is not empty
    if not query:
        st.write("Please enter a query")
    # Check that top_k is not > k_retrieve
    elif top_k > k_retrieve:
        st.write(
            "'Number of results' should be less than or equal to 'Number of initially retrieved series'"
        )
    else:
        # Load embeddings and corresponding data table. The pickled object is
        # unpacked from its .values(), so it is expected to hold exactly two
        # entries in the order (data table, embeddings).
        # NOTE: pickle.load can execute arbitrary code -- only trusted files
        # may be placed in the "embeddings" directory.
        with open(embeddings_path, "rb") as f:
            data, corpus_embeddings = pickle.load(f).values()

        # Retrieve most similar series
        results = retrieve(
            query,
            corpus_embeddings=corpus_embeddings,
            model_name=model_name,
            top_k=int(k_retrieve),
        )
        # Re-rank the retrieved series
        if do_rerank:
            results = add_descriptions_to_results(results)
            results = rerank(query, results, top_k=int(top_k))

        # Display results
        titles, scores, covers, urls = get_data(results, data, do_rerank)
        for title, score, cover, url in zip(titles, scores, covers, urls):
            with st.container():
                col1, col2 = st.columns(2)
                with col1:
                    st.markdown(
                        f"""
                            ## [{title}]({url})
                            Score: {score:.2f}
                        """
                    )
                with col2:
                    # A slow or broken cover URL must not hang or abort the
                    # rendering of the remaining results.
                    try:
                        response = requests.get(cover, timeout=10)
                        response.raise_for_status()
                    except requests.RequestException:
                        st.write("Cover image unavailable")
                    else:
                        st.image(BytesIO(response.content), width=200)