manga-semantic-search / inference.py
bwconrad
Add application file
d553e7f
import pickle
import torch
import pandas as pd
from sentence_transformers import SentenceTransformer, util, CrossEncoder
def retrieve(
query: str,
corpus_embeddings: torch.Tensor,
top_k: int = 5,
model_name: str = "all-mpnet-base-v2",
):
"""Retrieve the most similar series in a corpus given a query"""
# Embed query
model = SentenceTransformer(model_name)
prompt_embedding = model.encode(query, convert_to_tensor=True)
# Find most similar
results = util.semantic_search(prompt_embedding, corpus_embeddings, top_k=top_k)[0]
results = pd.DataFrame(results, columns=["corpus_id", "score"])
return results
def rerank(
query: str,
retrieved: pd.DataFrame,
top_k: int = 5,
model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
):
"""Re-rank the retrieved series"""
# Create pairs of query and descriptions
inp = [[query, desc] for desc in retrieved["desc"]]
# Get scores for each pair
cross_encoder = CrossEncoder(model_name)
cross_scores = cross_encoder.predict(inp)
retrieved["cross-score"] = cross_scores
# Keep top-k after re-ranking
results = retrieved.sort_values("cross-score", ascending=False).iloc[:top_k]
return results
if __name__ == "__main__":
with open("embeddings/desc-embeddings.all-mpnet-base-v2.pkl", "rb") as f:
data, corpus_embeddings = pickle.load(f).values()
q = "a series about people battling each other in cooking competitions"
results = retrieve(q, corpus_embeddings, top_k=50)
idxs = results["corpus_id"].tolist()
descs = data.iloc[idxs].input.tolist()
results["desc"] = descs
print(results)
reranked = rerank(q, results, top_k=5)
print(reranked)