"""Streamlit UI for semantic search over manga series descriptions.

The user enters a free-text description; the app retrieves the most similar
series from a precomputed embeddings corpus and optionally re-ranks them.
"""

import os
import pickle
from io import BytesIO
from typing import Optional

import pandas as pd
import requests
import streamlit as st

from inference import retrieve, rerank

# Directory holding the pickled corpus/embedding files offered in the UI.
EMBEDDINGS_DIR = "embeddings"
# Seconds to wait for a cover image before giving up on it, so one slow image
# host cannot hang the whole results page.
COVER_REQUEST_TIMEOUT = 10


def get_data(results: pd.DataFrame, data: pd.DataFrame, reranked=False):
    """Given the corpus indices of the top-k series get the required data for the UI.

    Args:
        results: Retrieval (or re-ranked) results. Its ``corpus_id`` column
            holds positional row indices into ``data``; scores are read from
            ``cross-score`` when ``reranked`` is True, else from ``score``.
        data: Corpus table with ``romaji``, ``cover`` and ``url`` columns.
        reranked: Whether ``results`` came from the re-ranker.

    Returns:
        Four parallel lists ``(titles, scores, covers, urls)`` in the row
        order of ``results``.
    """
    score_column = "cross-score" if reranked else "score"
    scores = results[score_column].tolist()
    # A single positional lookup for all hits instead of three per row.
    rows = data.iloc[results["corpus_id"].tolist()]
    titles = rows["romaji"].tolist()
    covers = rows["cover"].tolist()
    urls = rows["url"].tolist()
    return titles, scores, covers, urls


def add_descriptions_to_results(
    results: pd.DataFrame, data: Optional[pd.DataFrame] = None
) -> pd.DataFrame:
    """Add the corresponding description to the retrieval results.

    Args:
        results: Retrieval results with a ``corpus_id`` column of positional
            indices into ``data``. A ``desc`` column is added to it in place.
        data: Corpus table whose ``input`` column provides the descriptions.
            When omitted, falls back to the module-level ``data`` assigned by
            the search handler (the original behaviour); prefer passing it.

    Returns:
        ``results`` with the added ``desc`` column.
    """
    if data is None:
        # Backward-compatible fallback for callers using the old signature.
        data = globals()["data"]
    idxs = results["corpus_id"].tolist()
    results["desc"] = data.iloc[idxs]["input"].tolist()
    return results


def _list_embedding_files(directory: str) -> list:
    """Return the sorted names of regular files with an extension in *directory*."""
    # Files need an extension because the model name is derived from the
    # dot-separated file name; this also skips hidden files like ``.gitkeep``.
    return sorted(
        name
        for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))
        and os.path.splitext(name)[1]
    )


def _fetch_cover(url) -> Optional[BytesIO]:
    """Download a cover image; return its bytes as ``BytesIO``, or None on failure."""
    try:
        response = requests.get(url, timeout=COVER_REQUEST_TIMEOUT)
        response.raise_for_status()
    except requests.RequestException:
        # A missing/broken cover should not abort rendering the other results.
        return None
    return BytesIO(response.content)


def _render_result(title, score, cover, url) -> None:
    """Render one search hit: linked title and score beside its cover image."""
    with st.container():
        col1, col2 = st.columns(2)
        with col1:
            st.markdown(f"## [{title}]({url})\nScore: {score:.2f}")
        with col2:
            img = _fetch_cover(cover)
            if img is None:
                st.write("Cover image unavailable")
            else:
                st.image(img, width=200)


# Input UI
st.title("Manga Semantic Search")
st.markdown(
    """
An application to search for manga series using text descriptions of its content.
Find the name of a series based on a vague recollection about its story.
Performs a semantic retrieve and re-rank search using Sentence Transformer models.
Source code for this application can be found in the repo
[here](https://github.com/bwconrad/manga-semantic-search).

__Note__: The current database only includes the top-1000 manga series on AniList.
"""
)
query = st.text_input(
    "Enter a description of the manga you are searching for:",
    value="",
)

embedding_files = _list_embedding_files(EMBEDDINGS_DIR)
if not embedding_files:
    # Without this guard selectbox returns None and deriving the model name
    # below raises IndexError.
    st.error(f"No embedding files found in '{EMBEDDINGS_DIR}/'.")
    st.stop()
embeddings_file = st.selectbox("Embeddings Corpus", embedding_files)

top_k = st.number_input(
    "Number of results", value=5, min_value=1, max_value=100, step=1
)
do_rerank = st.checkbox("Re-Rank", value=True)
k_retrieve = None
if do_rerank:
    k_retrieve = st.number_input(
        "Number of initially retrieved series",
        value=50,
        min_value=1,
        max_value=500,
        step=1,
    )

# Convert UI values into the correct function argument values.
# The model name is the second-to-last dot-separated part of the file name,
# e.g. "all-MiniLM-L6-v2.pkl" -> "all-MiniLM-L6-v2".
model_name = str(embeddings_file).split(".")[-2]
embeddings_path = os.path.join(EMBEDDINGS_DIR, str(embeddings_file))

# Output UI
if st.button("Search"):
    # Without re-ranking we retrieve exactly the number of results shown.
    if not k_retrieve:
        k_retrieve = top_k

    if not query.strip():
        st.write("Please enter a query")
    elif top_k > k_retrieve:
        st.write(
            "'Number of results' should be less than or equal to "
            "'Number of initially retrieved series'"
        )
    else:
        # NOTE(review): pickle.load can execute arbitrary code; only trusted
        # files should be placed in the embeddings directory.
        with open(embeddings_path, "rb") as f:
            # Assumes the pickle is a mapping with exactly two values, in
            # order: the corpus DataFrame and its embeddings — TODO confirm
            # against the script that writes these files.
            data, corpus_embeddings = pickle.load(f).values()

        # Retrieve most similar series
        results = retrieve(
            query,
            corpus_embeddings=corpus_embeddings,
            model_name=model_name,
            top_k=int(k_retrieve),
        )

        # Re-rank the retrieved series
        if do_rerank:
            results = add_descriptions_to_results(results, data)
            results = rerank(query, results, top_k=int(top_k))

        # Display results
        titles, scores, covers, urls = get_data(results, data, do_rerank)
        for title, score, cover, url in zip(titles, scores, covers, urls):
            _render_result(title, score, cover, url)