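# Web-page header text captured together with this file (not program code;
# apparently an avatar caption, a commit message and a short commit hash):
#   bwconrad's picture
#   Add description
#   ff372e6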
import os
import pickle
from io import BytesIO
import pandas as pd
import requests
import streamlit as st
from inference import retrieve, rerank
def get_data(results: pd.DataFrame, data: pd.DataFrame, reranked=False):
    """Collect the fields the UI displays for each search hit, in result order.

    Args:
        results: Search hits. Needs a ``corpus_id`` column holding positional
            row indices into ``data``, plus a ``score`` column (or a
            ``cross-score`` column when ``reranked`` is true).
        data: Corpus table with ``romaji``, ``cover`` and ``url`` columns.
        reranked: When true, report ``cross-score`` instead of ``score``.

    Returns:
        A ``(titles, scores, covers, urls)`` tuple of equal-length lists with
        one entry per row of ``results``.
    """
    score_column = "cross-score" if reranked else "score"
    scores = results[score_column].tolist()

    # Select every hit's corpus row with a single positional lookup instead of
    # materialising the same row three times per hit inside a Python loop.
    hits = data.iloc[results["corpus_id"].tolist()]

    return (
        hits["romaji"].tolist(),
        scores,
        hits["cover"].tolist(),
        hits["url"].tolist(),
    )
def add_descriptions_to_results(results: pd.DataFrame, corpus_data=None):
    """Attach each hit's corpus description to the retrieval results.

    Args:
        results: Retrieval results with a ``corpus_id`` column holding
            positional row indices into the corpus table.
        corpus_data: Corpus table with an ``input`` column. When omitted, the
            module-level ``data`` table is used, which keeps existing
            one-argument calls working; that global must already be bound.

    Returns:
        The same ``results`` object, modified in place with a new ``desc``
        column.

    Raises:
        NameError: If ``corpus_data`` is omitted and no module-level ``data``
            has been defined yet.
    """
    # Only touch the global when the caller did not supply a table, so the
    # function can be used (and tested) without the script's global state.
    table = data if corpus_data is None else corpus_data

    corpus_ids = results["corpus_id"].tolist()
    results["desc"] = table.iloc[corpus_ids]["input"].tolist()
    return results
# Input UI
EMBEDDINGS_DIR = "embeddings"
TOP_K_LABEL = "Number of results"
K_RETRIEVE_LABEL = "Number of initially retrieved series"
# Upper bound on how long a single cover download may block the page.
COVER_TIMEOUT_SECONDS = 10

st.title("Manga Semantic Search")
st.markdown(
    """
An application to search for manga series using text descriptions of its content. Find the name of a series based on a vague recollection about its story.
Performs a semantic retrieve and re-rank search using Sentence Transform models.
Source code for this application can be found in the repo [here](https://github.com/bwconrad/manga-semantic-search).
__Note__: The current database only includes the top-1000 manga series on AniList.
"""
)
query = st.text_input(
    "Enter a description of the manga you are searching for:",
    value="",
)

# os.listdir returns entries in arbitrary order and may include directories;
# keep regular files only and sort them so the selectbox is stable across runs.
try:
    embedding_files = sorted(
        name
        for name in os.listdir(EMBEDDINGS_DIR)
        if os.path.isfile(os.path.join(EMBEDDINGS_DIR, name))
    )
except (FileNotFoundError, NotADirectoryError):
    embedding_files = []
if not embedding_files:
    # Nothing to search against: explain why instead of crashing further down.
    st.error(f"No embeddings files found in the '{EMBEDDINGS_DIR}' directory.")
    st.stop()

embeddings_path = st.selectbox("Embeddings Corpus", embedding_files)
top_k = st.number_input(
    TOP_K_LABEL, value=5, min_value=1, max_value=100, step=1
)
do_rerank = st.checkbox("Re-Rank", value=True)
k_retrieve = None
if do_rerank:
    k_retrieve = st.number_input(
        K_RETRIEVE_LABEL,
        value=50,
        min_value=1,
        max_value=500,
        step=1,
    )

# Convert UI values into the correct function argument values.
# The model name is the dot-separated component just before the file
# extension; a file name without any dot is used as-is rather than raising
# IndexError.
name_parts = str(embeddings_path).split(".")
model_name = name_parts[-2] if len(name_parts) > 1 else name_parts[0]
embeddings_path = os.path.join(EMBEDDINGS_DIR, str(embeddings_path))

# Output UI
if st.button("Search"):
    # Without re-ranking there is no separate retrieval size: retrieve exactly
    # as many series as will be shown.
    if not k_retrieve:
        k_retrieve = top_k

    # Check that query is not empty (whitespace-only counts as empty)
    if not query.strip():
        st.write("Please enter a query")
    # Check that top_k is not > retrieve_k
    elif top_k > k_retrieve:
        st.write(
            f"'{TOP_K_LABEL}' should be less than or equal to '{K_RETRIEVE_LABEL}'"
        )
    else:
        # Load embeddings and corresponding data table.
        # NOTE(security): pickle.load can execute arbitrary code contained in
        # the file; only place embeddings files you created or trust in
        # EMBEDDINGS_DIR.
        # NOTE(review): the unpacking relies on the pickled mapping holding
        # exactly two values, ordered (data table, embeddings) - confirm
        # against the script that writes these files.
        with open(embeddings_path, "rb") as f:
            data, corpus_embeddings = pickle.load(f).values()

        # Retrieve most similar series
        results = retrieve(
            query,
            corpus_embeddings=corpus_embeddings,
            model_name=model_name,
            top_k=int(k_retrieve),
        )

        # Re-rank the retrieved series
        if do_rerank:
            results = add_descriptions_to_results(results)
            results = rerank(query, results, top_k=int(top_k))

        # Display results
        titles, scores, covers, urls = get_data(results, data, do_rerank)
        for title, score, cover, url in zip(titles, scores, covers, urls):
            with st.container():
                col1, col2 = st.columns(2)
                with col1:
                    st.markdown(f"## [{title}]({url})\nScore: {score:.2f}")
                with col2:
                    # A slow or broken cover URL must not hang or abort the
                    # whole result list, so bound the wait and degrade to a
                    # text placeholder on any request failure.
                    try:
                        response = requests.get(
                            cover, timeout=COVER_TIMEOUT_SECONDS
                        )
                        response.raise_for_status()
                    except requests.RequestException:
                        st.write("Cover image unavailable")
                    else:
                        st.image(BytesIO(response.content), width=200)