"""Rank stored issues by embedding similarity to a free-text query.

The script embeds the query with a SentenceTransformer model, compares it
against a matrix of precomputed issue embeddings loaded from a ``.npy`` file,
and prints the title and body of the best-matching issues.
"""
import argparse
import json
import pprint

import numpy as np


def cosine_similarity(a, b):
    """Return the pairwise cosine similarity between the rows of *a* and *b*.

    Args:
        a: Array-like of shape ``(d,)`` or ``(m, d)``.
        b: Array-like of shape ``(d,)`` or ``(n, d)``.

    Returns:
        An ``(m, n)`` array whose ``[i, j]`` entry is the cosine similarity
        between row ``i`` of *a* and row ``j`` of *b*. 1-D inputs are treated
        as a single row. Pairs involving a zero-norm row get similarity 0.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    if a.ndim == 1:
        a = a.reshape(1, -1)
    if b.ndim == 1:
        b = b.reshape(1, -1)

    dots = np.dot(a, b.T)
    # The denominator must be the outer product of the row norms so that it
    # lines up with the (m, n) dot-product matrix; an elementwise product of
    # the two norm vectors is only correct when *a* has a single row.
    norms = np.outer(np.linalg.norm(a, axis=1), np.linalg.norm(b, axis=1))

    # Leave zero-norm pairs at 0 rather than NaN: NaN sorts last in argsort,
    # so a reversed ranking would otherwise put those rows first.
    similarities = np.zeros(dots.shape, dtype=np.result_type(dots, norms))
    np.divide(dots, norms, out=similarities, where=norms > 0)
    return similarities


def retrieve_issue_rankings(
    query: str,
    model_id: str,
    input_embedding_filename: str,
) -> np.ndarray:
    """Rank every stored embedding by similarity to *query*.

    Args:
        query: Free-text search query.
        model_id: Identifier passed to ``SentenceTransformer``.
        input_embedding_filename: Path of the ``.npy`` file holding the issue
            embeddings, one row per issue.

    Returns:
        A 1-D array of embedding row indices, most similar first.
    """
    # Imported here because this is the only place the model stack is needed;
    # it keeps ``--help`` and the NumPy-only helpers free of that dependency.
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer(model_id)
    embeddings = np.load(input_embedding_filename)
    query_embedding = model.encode(query)

    # One row of similarities: the query against every stored embedding.
    cosine_similarities = cosine_similarity(query_embedding, embeddings)

    # argsort is ascending, so take the query's row and reverse it to get the
    # full ranking from most to least similar.
    ascending_indices = np.argsort(cosine_similarities)
    return ascending_indices[0][::-1]


def print_issue(issues, issue_id):
    """Print the title line and body of the issue stored under *issue_id*.

    Args:
        issues: Mapping from issue id to a record with ``"title"`` and
            ``"body"`` entries.
        issue_id: Key of the issue to print.
    """
    issue_info = issues[issue_id]
    print(f"#{issue_id}", issue_info["title"])
    print(issue_info["body"])


def main(argv=None):
    """Parse command-line arguments and print the best-matching issues.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("query", type=str)
    parser.add_argument("--model_id", type=str, default="all-mpnet-base-v2")
    parser.add_argument("--input_embedding_filename", type=str, default="issue_embeddings.npy")
    parser.add_argument("--input_index_filename", type=str, default="embedding_index_to_issue.json")
    parser.add_argument("--input_issues_filename", type=str, default="issues_dict.json")
    parser.add_argument("--top_k", type=int, default=3)
    args = parser.parse_args(argv)

    issue_rankings = retrieve_issue_rankings(
        query=args.query,
        model_id=args.model_id,
        input_embedding_filename=args.input_embedding_filename,
    )

    with open(args.input_issues_filename, "r", encoding="utf-8") as f:
        issues = json.load(f)

    with open(args.input_index_filename, "r", encoding="utf-8") as f:
        embedding_index_to_issue = json.load(f)

    # The index file is JSON, so its keys are strings; only the entries that
    # will actually be printed are translated to issue ids.
    top_issue_ids = [embedding_index_to_issue[str(i)] for i in issue_rankings[: args.top_k]]

    for issue_id in top_issue_ids:
        print(issue_id)
        print_issue(issues, issue_id)
        print("\n\n\n")


if __name__ == "__main__":
    main()