bwconrad committed on
Commit
d553e7f
1 Parent(s): 57a0722

Add application file

Browse files
Files changed (3) hide show
  1. app.py +109 -0
  2. inference.py +62 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ from io import BytesIO
4
+
5
+ import pandas as pd
6
+ import requests
7
+ import streamlit as st
8
+
9
+ from inference import retrieve, rerank
10
+
11
+
12
def get_data(results: pd.DataFrame, data: pd.DataFrame, reranked=False):
    """Collect the fields the UI shows for each retrieved series.

    Args:
        results: Retrieval output. Must have a ``corpus_id`` column holding
            positional row indices into ``data``, plus a ``score`` column
            (or ``cross-score`` when ``reranked`` is true).
        data: Corpus table with ``romaji``, ``cover`` and ``url`` columns.
        reranked: When true, report the ``cross-score`` column instead of
            ``score``.

    Returns:
        A ``(titles, scores, covers, urls)`` tuple of equally long lists,
        ordered like the rows of ``results``.
    """
    score_column = "cross-score" if reranked else "score"
    scores = results[score_column].tolist()

    # Select every hit with a single positional lookup instead of building
    # a row Series three times per result.
    hits = data.iloc[results["corpus_id"].tolist()]
    titles = hits["romaji"].tolist()
    covers = hits["cover"].tolist()
    urls = hits["url"].tolist()

    return titles, scores, covers, urls
27
+
28
+
29
def add_descriptions_to_results(
    results: pd.DataFrame, corpus: "pd.DataFrame | None" = None
):
    """Add the corresponding description to the retrieval results.

    Args:
        results: Retrieval output with a ``corpus_id`` column of positional
            row indices into the corpus table.
        corpus: Corpus table with an ``input`` column holding the
            descriptions. When omitted, the module-level ``data`` table is
            used, which is how existing callers invoke this function.

    Returns:
        The same ``results`` object, with a ``desc`` column added in place.
    """
    if corpus is None:
        # Backward-compatible fallback: the search handler binds the loaded
        # table to the module-level name ``data`` before calling this.
        corpus = data

    positions = results["corpus_id"].tolist()
    results["desc"] = corpus.iloc[positions]["input"].tolist()
    return results
35
+
36
+
37
# Input UI
st.title("Manga Semantic Search")
query = st.text_input(
    "Enter a description of the manga you are searching for:",
    value="",
)
embeddings_path = st.selectbox("Embeddings Corpus", os.listdir("embeddings"))
top_k = st.number_input(
    "Number of results", value=5, min_value=1, max_value=100, step=1
)
do_rerank = st.checkbox("Re-Rank", value=True)
# Only asked for when re-ranking; otherwise it falls back to top_k below.
k_retrieve = None
if do_rerank:
    k_retrieve = st.number_input(
        "Number of initialy retrieved series",
        value=50,
        min_value=1,
        max_value=500,
        step=1,
    )


# Convert UI values into the correct function argument values.
# The model name is taken as the second-to-last dot-separated part of the
# selected file name (e.g. "<prefix>.<model_name>.pkl").
model_name = str(embeddings_path).split(".")[-2]
embeddings_path = os.path.join("embeddings", str(embeddings_path))


# Output UI
if st.button("Search"):
    if not k_retrieve:
        k_retrieve = top_k

    # Check that query is not empty
    if not query:
        st.write("Please enter a query")
    # Check that top_k is not > k_retrieve
    elif top_k > k_retrieve:
        # The quoted names must match the widget labels above exactly.
        st.write(
            "'Number of results' should be less than or equal to 'Number of initialy retrieved series'"
        )
    else:
        # Load embeddings and corresponding data table.
        # NOTE(security): unpickling can execute arbitrary code; only place
        # trusted files in the "embeddings" directory.
        with open(embeddings_path, "rb") as f:
            data, corpus_embeddings = pickle.load(f).values()

        # Retrieve most similar series
        results = retrieve(
            query,
            corpus_embeddings=corpus_embeddings,
            model_name=model_name,
            top_k=int(k_retrieve),
        )
        # Re-rank the retrieved series
        if do_rerank:
            results = add_descriptions_to_results(results)
            results = rerank(query, results, top_k=int(top_k))

        # Display results
        titles, scores, covers, urls = get_data(results, data, do_rerank)
        for title, score, cover, url in zip(titles, scores, covers, urls):
            with st.container():
                col1, col2 = st.columns(2)
                with col1:
                    st.markdown(
                        f"""
                        ## [{title}]({url})
                        Score: {score:.2f}
                        """
                    )
                with col2:
                    # A slow or broken cover URL must not hang or crash the
                    # whole results page, so bound the request and degrade
                    # to a text placeholder on failure.
                    try:
                        response = requests.get(cover, timeout=10)
                        response.raise_for_status()
                    except requests.RequestException:
                        st.write("Cover image unavailable")
                    else:
                        st.image(BytesIO(response.content), width=200)
inference.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import torch
3
+
4
+ import pandas as pd
5
+ from sentence_transformers import SentenceTransformer, util, CrossEncoder
6
+
7
+
8
# Bi-encoders already constructed in this process, keyed by model name.
# Building a SentenceTransformer is far more expensive than encoding one
# query, so each model is constructed once and reused across calls.
_BI_ENCODERS = {}


def _get_bi_encoder(model_name: str):
    """Return the SentenceTransformer for ``model_name``, building it on first use."""
    if model_name not in _BI_ENCODERS:
        _BI_ENCODERS[model_name] = SentenceTransformer(model_name)
    return _BI_ENCODERS[model_name]


def retrieve(
    query: str,
    corpus_embeddings: torch.Tensor,
    top_k: int = 5,
    model_name: str = "all-mpnet-base-v2",
):
    """Retrieve the most similar series in a corpus given a query.

    Args:
        query: Free-text description to search for.
        corpus_embeddings: Precomputed embeddings of the corpus; presumably
            produced with the same model as ``model_name`` (verify against
            the script that generated the embeddings file).
        top_k: Maximum number of hits to return.
        model_name: Name of the sentence-transformers model used to embed
            the query.

    Returns:
        A DataFrame with ``corpus_id`` and ``score`` columns, one row per hit.
    """
    # Embed query
    model = _get_bi_encoder(model_name)
    query_embedding = model.encode(query, convert_to_tensor=True)

    # Find most similar; a single query is passed, so take the first (only)
    # list of hits.
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
    return pd.DataFrame(hits, columns=["corpus_id", "score"])
25
+
26
+
27
# Cross-encoders already constructed in this process, keyed by model name,
# so repeated searches do not rebuild the model on every call.
_CROSS_ENCODERS = {}


def _get_cross_encoder(model_name: str):
    """Return the CrossEncoder for ``model_name``, building it on first use."""
    if model_name not in _CROSS_ENCODERS:
        _CROSS_ENCODERS[model_name] = CrossEncoder(model_name)
    return _CROSS_ENCODERS[model_name]


def rerank(
    query: str,
    retrieved: pd.DataFrame,
    top_k: int = 5,
    model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
):
    """Re-rank the retrieved series with a cross-encoder.

    Args:
        query: Free-text description that was searched for.
        retrieved: Retrieval results with a ``desc`` column holding the
            description text of each candidate.
        top_k: Number of best-scoring rows to keep.
        model_name: Name of the cross-encoder model used for scoring.

    Returns:
        The ``top_k`` rows of ``retrieved`` with the highest ``cross-score``,
        sorted in descending order of that score.

    Note:
        A ``cross-score`` column is added to ``retrieved`` in place.
    """
    # Create pairs of query and descriptions
    pairs = [[query, desc] for desc in retrieved["desc"]]

    # Get scores for each pair
    cross_encoder = _get_cross_encoder(model_name)
    retrieved["cross-score"] = cross_encoder.predict(pairs)

    # Keep top-k after re-ranking
    return retrieved.sort_values("cross-score", ascending=False).iloc[:top_k]
47
+
48
+
49
if __name__ == "__main__":
    # Manual smoke test: retrieve 50 candidates for a sample query, attach
    # their descriptions, then re-rank down to the best 5.
    with open("embeddings/desc-embeddings.all-mpnet-base-v2.pkl", "rb") as f:
        data, corpus_embeddings = pickle.load(f).values()

    q = "a series about people battling each other in cooking competitions"
    results = retrieve(q, corpus_embeddings, top_k=50)

    # Look up the description text for every retrieved corpus row.
    results["desc"] = data.iloc[results["corpus_id"].tolist()].input.tolist()
    print(results)

    print(rerank(q, results, top_k=5))
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ pandas==2.0.1
2
+ sentence_transformers==2.2.2
3
+ streamlit==1.22.0
4
+ torch==2.0.0