Spaces:
Runtime error
Runtime error
File size: 3,879 Bytes
d553e7f ff372e6 d553e7f ff372e6 d553e7f ff372e6 d553e7f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import os
import pickle
from io import BytesIO
import pandas as pd
import requests
import streamlit as st
from inference import retrieve, rerank
def get_data(results: pd.DataFrame, data: pd.DataFrame, reranked=False):
    """Collect the fields the UI needs for each retrieved series.

    Args:
        results: Search results with a ``corpus_id`` column (positional row
            indices into ``data``) plus a ``score`` column and, after
            re-ranking, a ``cross-score`` column.
        data: Corpus table with ``romaji``, ``cover`` and ``url`` columns.
        reranked: When true, report ``cross-score`` instead of ``score``.

    Returns:
        A tuple of four parallel lists: ``(titles, scores, covers, urls)``,
        ordered like the rows of ``results``.
    """
    score_column = "cross-score" if reranked else "score"
    corpus_ids = results["corpus_id"].tolist()
    all_scores = results[score_column].tolist()

    titles, scores, covers, urls = [], [], [], []
    for corpus_id, score in zip(corpus_ids, all_scores):
        # Look the corpus row up once and read every display field from it.
        row = data.iloc[corpus_id]
        titles.append(row.romaji)
        scores.append(score)
        covers.append(row.cover)
        urls.append(row.url)
    return titles, scores, covers, urls


def add_descriptions_to_results(
    results: pd.DataFrame, corpus: "pd.DataFrame | None" = None
):
    """Attach each retrieved series' description to the retrieval results.

    Args:
        results: Retrieval output with a ``corpus_id`` column of positional
            row indices into the corpus table. A ``desc`` column is added to
            this frame in place.
        corpus: Corpus table whose ``input`` column holds the description
            text. When omitted, the module-level ``data`` table (loaded by the
            search handler) is used, which keeps the original one-argument
            call working.

    Returns:
        The same ``results`` frame, now carrying a ``desc`` column.

    Raises:
        NameError: If ``corpus`` is omitted and no module-level ``data`` has
            been loaded yet.
    """
    if corpus is None:
        # Backward-compatible fallback: resolve the module-level table lazily.
        corpus = data
    idxs = results["corpus_id"].tolist()
    # Assign a plain list so values line up with ``results`` by position; the
    # rows returned by ``iloc`` carry the corpus index labels, not ours.
    results["desc"] = corpus.iloc[idxs]["input"].tolist()
    return results


# Input UI
st.title("Manga Semantic Search")
st.markdown(
    """
An application to search for manga series using text descriptions of its content. Find the name of a series based on a vague recollection about its story.

Performs a semantic retrieve and re-rank search using Sentence Transformers models.

Source code for this application can be found in the repo [here](https://github.com/bwconrad/manga-semantic-search).

__Note__: The current database only includes the top-1000 manga series on AniList.
"""
)
query = st.text_input(
    "Enter a description of the manga you are searching for:",
    value="",
)

# Offer only files from which a model name can be derived below (at least one
# "."), skip hidden files, and sort so the default selection is deterministic.
try:
    embedding_files = sorted(
        name
        for name in os.listdir("embeddings")
        if "." in name and not name.startswith(".")
    )
except OSError:
    embedding_files = []
if not embedding_files:
    # Without this guard selectbox returns None and deriving the model name
    # from it crashes the whole page with an IndexError.
    st.error("No embedding files were found in the 'embeddings' directory.")
    st.stop()
embeddings_file = st.selectbox("Embeddings Corpus", embedding_files)

top_k = st.number_input(
    "Number of results", value=5, min_value=1, max_value=100, step=1
)
do_rerank = st.checkbox("Re-Rank", value=True)
k_retrieve = None
if do_rerank:
    k_retrieve = st.number_input(
        "Number of initially retrieved series",
        value=50,
        min_value=1,
        max_value=500,
        step=1,
    )

# Convert UI values into the correct function argument values.
# The model name is the second-to-last dot-separated part of the file name
# (presumably "<prefix>.<model>.<ext>" — TODO confirm the naming convention).
model_name = str(embeddings_file).split(".")[-2]
embeddings_path = os.path.join("embeddings", str(embeddings_file))

# Output UI
if st.button("Search"):
    # Without re-ranking, retrieve exactly as many series as will be shown.
    if not k_retrieve:
        k_retrieve = top_k

    if not query:
        st.write("Please enter a query")
    # Re-ranking cannot return more series than were initially retrieved.
    elif top_k > k_retrieve:
        st.write(
            "'Number of results' should be less than or equal to 'Number of initially retrieved series'"
        )
    else:
        # Load embeddings and corresponding data table.
        # NOTE(review): pickle.load can execute arbitrary code; only ship
        # trusted files in embeddings/. The unpacking below assumes the pickle
        # holds a mapping with exactly two values in the order (data table,
        # corpus embeddings) — TODO confirm against the script that writes it.
        with open(embeddings_path, "rb") as f:
            data, corpus_embeddings = pickle.load(f).values()

        # Retrieve most similar series
        results = retrieve(
            query,
            corpus_embeddings=corpus_embeddings,
            model_name=model_name,
            top_k=int(k_retrieve),
        )

        # Re-rank the retrieved series
        if do_rerank:
            # add_descriptions_to_results reads the module-level `data`
            # loaded just above.
            results = add_descriptions_to_results(results)
            results = rerank(query, results, top_k=int(top_k))

        # Display results
        titles, scores, covers, urls = get_data(results, data, do_rerank)
        for title, score, cover, url in zip(titles, scores, covers, urls):
            with st.container():
                col1, col2 = st.columns(2)
                with col1:
                    st.markdown(f"## [{title}]({url})\n\nScore: {score:.2f}")
                with col2:
                    # A timeout keeps one slow image host from hanging the
                    # page; a failed fetch degrades to a note instead of an
                    # exception that aborts the remaining results.
                    try:
                        response = requests.get(cover, timeout=10)
                        response.raise_for_status()
                    except requests.RequestException:
                        st.write("Cover image unavailable")
                    else:
                        st.image(BytesIO(response.content), width=200)
|