# transformers-github-bot / find_similar_issues.py
# Last change: commit c1fc690 "Updates" (Amy Roberts)
# Source page links: raw | history | blame — file size 3.33 kB
import pprint
import json
import argparse
import requests
from defaults import OWNER, REPO, TOKEN
from sentence_transformers import SentenceTransformer
import numpy as np
# Identifier of the sentence-transformers checkpoint used to embed issue text.
# NOTE(review): presumably this must be the same model that produced
# issue_embeddings.npy, otherwise the similarity scores are not comparable — confirm.
model_id = "all-mpnet-base-v2"

# Instantiated once at import time; every query below reuses this instance.
model = SentenceTransformer(model_name_or_path=model_id)
def load_embeddings(path="issue_embeddings.npy"):
    """
    Load the precomputed issue embedding matrix from a ``.npy`` file.

    Args:
        path: Location of the NumPy file. Defaults to ``issue_embeddings.npy``
            relative to the current working directory (the previous
            hard-coded behaviour).

    Returns:
        The array stored in the file, exactly as returned by ``np.load``.
    """
    return np.load(path)
def load_issue_information(
    index_path="embedding_index_to_issue.json",
    issues_path="issues_dict.json",
):
    """
    Load the lookup tables that map embedding rows back to issues.

    Args:
        index_path: JSON file mapping an embedding row index to an issue key.
            JSON object keys are always strings, so row indices are stored as
            strings. Defaults to ``embedding_index_to_issue.json``.
        issues_path: JSON file mapping issue keys to issue information
            (presumably keyed by issue number — confirm against the producer).
            Defaults to ``issues_dict.json``.

    Returns:
        A ``(embedding_index_to_issue, issues)`` tuple of the decoded objects.
    """
    # JSON is UTF-8 by specification; don't rely on the platform's locale
    # encoding, which can fail on non-ASCII issue text (e.g. cp1252 on Windows).
    with open(index_path, "r", encoding="utf-8") as f:
        embedding_index_to_issue = json.load(f)
    with open(issues_path, "r", encoding="utf-8") as f:
        issues = json.load(f)
    return embedding_index_to_issue, issues
def cosine_similarity(a, b):
    """
    Compute pairwise cosine similarity between the rows of ``a`` and ``b``.

    A 1-D input is treated as a single row vector.

    Args:
        a: Array-like of shape ``(d,)`` or ``(n, d)``.
        b: Array-like of shape ``(d,)`` or ``(m, d)``.

    Returns:
        An ``(n, m)`` array whose ``[i, j]`` entry is the cosine similarity of
        row ``a[i]`` and row ``b[j]``. A zero-norm row produces ``nan`` entries
        (0 / 0), as before.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    if a.ndim == 1:
        a = a.reshape(1, -1)
    if b.ndim == 1:
        b = b.reshape(1, -1)
    # Entry [i, j] must be divided by |a_i| * |b_j|, i.e. the OUTER product of
    # the row norms. Multiplying the two norm vectors element-wise only works
    # when one side has a single row; otherwise it fails to broadcast or
    # silently divides by the wrong norms.
    norms = np.outer(np.linalg.norm(a, axis=1), np.linalg.norm(b, axis=1))
    return np.dot(a, b.T) / norms
def get_issue(issue_no, token=TOKEN, owner=OWNER, repo=REPO):
    """
    Fetch a single issue from the GitHub REST API.

    Args:
        issue_no: Number of the issue to fetch.
        token: GitHub access token. Either a bare token, or a value that
            already carries an auth scheme (``"Bearer ..."`` / ``"token ..."``).
            If falsy, no Authorization header is sent.
        owner: Owner (user or organisation) of the repository.
        repo: Name of the repository.

    Returns:
        The decoded JSON payload of the issue.

    Raises:
        ValueError: If GitHub does not answer with HTTP 200.
    """
    # Use the owner/repo arguments; the module-level defaults were previously
    # used here unconditionally, so the parameters had no effect.
    url = f"https://api.github.com/repos/{owner}/{repo}/issues/{issue_no}"
    headers = {
        "Accept": "application/vnd.github+json",
        "X-GitHub-Api-Version": "2022-11-28",
        "User-Agent": "amyeroberts",
    }
    if token:
        auth = str(token)
        # The f-prefix used to sit on the header *key*, so the literal text
        # "{token}" was sent and the request was never authenticated.
        if not auth.lower().startswith(("bearer ", "token ")):
            auth = f"Bearer {auth}"
        headers["Authorization"] = auth
    # A timeout keeps the bot from hanging forever on a stalled connection.
    response = requests.get(url, headers=headers, timeout=30)
    if response.status_code != 200:
        raise ValueError(f"Request failed with status code {response.status_code}")
    return response.json()
def get_similar_issues(issue_no, query, top_k=5, token=TOKEN, owner=OWNER, repo=REPO):
    """
    Return the stored issues whose embeddings are most similar to a query.

    Exactly one of ``issue_no`` and ``query`` must be given. When ``issue_no``
    is given, the issue is fetched from GitHub and its title and body are
    joined with a newline to form the query text.

    Args:
        issue_no: GitHub issue number to use as the query, or ``None``.
        query: Free-text query, or ``None``.
        top_k: Maximum number of issues to return.
        token: Forwarded to ``get_issue`` when ``issue_no`` is used.
        owner: Forwarded to ``get_issue`` when ``issue_no`` is used.
        repo: Forwarded to ``get_issue`` when ``issue_no`` is used.

    Returns:
        A list of up to ``top_k`` entries from the ``issues`` table returned by
        ``load_issue_information``, ordered from most to least similar.

    Raises:
        ValueError: If both or neither of ``issue_no`` and ``query`` are given,
            or if ``get_issue`` fails to fetch the issue.
    """
    if issue_no is not None and query is not None:
        raise ValueError("Only one of issue_no or query can be provided")
    if issue_no is None and query is None:
        raise ValueError("One of issue_no or query must be provided")

    if issue_no is not None:
        issue = get_issue(issue_no, token=token, owner=owner, repo=repo)
        # GitHub returns "body": null for issues without a description.
        query = issue["title"] + "\n" + (issue["body"] or "")

    query_embedding = model.encode(query).reshape(1, -1)
    embeddings = load_embeddings()

    # Row 0 holds the similarity of the query to every stored embedding
    # (assumes embeddings is 2-D: one row per issue — TODO confirm).
    similarities = cosine_similarity(query_embedding, embeddings)[0]
    # Embedding row indices ordered from most to least similar.
    ranked_indices = np.argsort(similarities)[::-1]

    embedding_index_to_issue, issues = load_issue_information()
    # The index table was loaded from JSON, whose object keys are strings.
    return [
        issues[embedding_index_to_issue[str(idx)]]
        for idx in ranked_indices[:top_k]
    ]
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Find stored GitHub issues similar to an issue or a free-text query."
    )
    parser.add_argument(
        "--issue_no", type=int, default=None,
        help="GitHub issue number to use as the query.",
    )
    parser.add_argument(
        "--query", type=str, default=None,
        help="Free-text query (use instead of --issue_no).",
    )
    parser.add_argument(
        "--top_k", type=int, default=5,
        help="Number of similar issues to return.",
    )
    parser.add_argument("--token", type=str, default=TOKEN)
    parser.add_argument("--owner", type=str, default=OWNER)
    parser.add_argument("--repo", type=str, default=REPO)
    args = parser.parse_args()
    # The results used to be computed and then discarded; print them instead.
    similar_issues = get_similar_issues(**vars(args))
    pprint.pprint(similar_issues)