import pprint
import json
import argparse

import requests
from defaults import OWNER, REPO, TOKEN
from sentence_transformers import SentenceTransformer
import numpy as np

GITHUB_API_URL = "https://api.github.com"
# Seconds to wait for the GitHub API before giving up.
REQUEST_TIMEOUT = 30

model_id = "all-mpnet-base-v2"
model = SentenceTransformer(model_id)


def load_embeddings():
    """
    Function to load embeddings from file

    Returns:
        The array stored in ``issue_embeddings.npy``.
    """
    embeddings = np.load("issue_embeddings.npy")
    return embeddings


def load_issue_information():
    """
    Function to load issue information from file

    Returns:
        A tuple ``(embedding_index_to_issue, issues)`` loaded from
        ``embedding_index_to_issue.json`` and ``issues_dict.json``.
    """
    with open("embedding_index_to_issue.json", "r", encoding="utf-8") as f:
        embedding_index_to_issue = json.load(f)

    with open("issues_dict.json", "r", encoding="utf-8") as f:
        issues = json.load(f)

    return embedding_index_to_issue, issues


def cosine_similarity(a, b):
    """
    Compute the pairwise cosine similarity between the rows of ``a`` and ``b``.

    1-D inputs are treated as a single row vector.

    Args:
        a: array of shape (m, d) or (d,).
        b: array of shape (n, d) or (d,).

    Returns:
        Array of shape (m, n) whose entry [i, j] is the cosine similarity of
        ``a[i]`` and ``b[j]``.
    """
    a = np.atleast_2d(a)
    b = np.atleast_2d(b)
    # Outer product of the row norms gives an (m, n) denominator that lines up
    # with the (m, n) dot-product matrix for any m and n (an elementwise
    # product only works when one side has a single row).
    denominator = np.outer(np.linalg.norm(a, axis=1), np.linalg.norm(b, axis=1))
    return np.dot(a, b.T) / denominator


def _build_headers(token):
    """Return GitHub REST API request headers, authenticating with ``token`` if given."""
    headers = {
        "Accept": "application/vnd.github+json",
        "X-GitHub-Api-Version": "2022-11-28",
        "User-Agent": "amyeroberts",
    }
    if token:
        # GitHub expects "Bearer <token>" (or the legacy "token <token>");
        # accept a value that already carries one of those schemes unchanged.
        if token.lower().startswith(("bearer ", "token ")):
            headers["Authorization"] = token
        else:
            headers["Authorization"] = f"Bearer {token}"
    return headers


def _fetch_issue(issue_no, token, owner, repo):
    """Fetch the JSON payload of one issue from ``owner/repo`` via the GitHub REST API."""
    url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/issues/{issue_no}"
    response = requests.get(url, headers=_build_headers(token), timeout=REQUEST_TIMEOUT)
    if response.status_code != 200:
        raise ValueError(f"Request failed with status code {response.status_code}")
    return response.json()


def get_similar_issues(issue_no, top_k=5, token=TOKEN, owner=OWNER, repo=REPO):
    """
    Function to find similar issues

    The issue is fetched from GitHub and its text is embedded with ``model``.
    The embedding is compared by cosine similarity against the embeddings
    stored in ``issue_embeddings.npy``.

    Args:
        issue_no: number of the issue to compare against.
        top_k: maximum number of similar issues to return.
        token: GitHub token used to authenticate the request; may be empty.
        owner: owner of the GitHub repository.
        repo: name of the GitHub repository.

    Returns:
        A list of up to ``top_k`` entries from ``issues_dict.json``, ordered
        from most to least similar. The queried issue itself is skipped.

    Raises:
        ValueError: if the GitHub API does not answer with status 200.
    """
    issue = _fetch_issue(issue_no, token, owner, repo)
    # "body" is null for issues opened without a description; fall back to
    # the title so model.encode always receives a string.
    text = issue.get("body") or issue.get("title") or ""
    query_embedding = model.encode(text).reshape(1, -1)

    embeddings = load_embeddings()

    # Calculate the cosine similarity between the query and all the issues
    similarities = cosine_similarity(query_embedding, embeddings)[0]

    # Embedding indices ordered from most to least similar
    ranked_indices = np.argsort(similarities)[::-1]

    embedding_index_to_issue, issues = load_issue_information()

    similar_issues = []
    for index in ranked_indices:
        if len(similar_issues) >= top_k:
            break
        issue_key = embedding_index_to_issue[str(index)]
        # Skip the queried issue so it is not reported as similar to itself.
        # NOTE(review): assumes the mapped values are issue numbers — confirm
        # against the script that builds embedding_index_to_issue.json.
        if str(issue_key) == str(issue_no):
            continue
        similar_issues.append(issues[issue_key])

    return similar_issues


def main():
    """Parse command-line arguments and pretty-print the similar issues."""
    parser = argparse.ArgumentParser()
    parser.add_argument("issue_no", type=int)
    parser.add_argument("--top_k", type=int, default=5)
    parser.add_argument("--token", type=str, default=TOKEN)
    parser.add_argument("--owner", type=str, default=OWNER)
    parser.add_argument("--repo", type=str, default=REPO)
    args = parser.parse_args()

    similar_issues = get_similar_issues(args.issue_no, args.top_k, args.token, args.owner, args.repo)
    pprint.pprint(similar_issues)


if __name__ == "__main__":
    main()