"""
Module which builds embeddings for issues and pull requests.

The module is designed to be run from the command line and takes the following arguments:

    issue_type: The type of issue to embed, either "issue" or "pull"
        (positional, optional, defaults to "issue")
    --input_filename: The name of the file containing the issues and pull requests
    --model_id: The name of the sentence transformer model to use
    --n_issues: The maximum number of issues to embed (a non-positive value embeds all)
    --update: Whether to update the existing embeddings instead of overwriting them

The module saves the embeddings to a file called <issue_type>_embeddings.npy and the index
to a file called embedding_index_to_<issue_type>.json.

The index provides a mapping from the index of the embedding to the issue or pull request number.
"""

import argparse
import json
import logging
import os

import numpy as np
from sentence_transformers import SentenceTransformer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def load_model(model_id: str):
    """Load and return the sentence-transformer model named ``model_id``."""
    return SentenceTransformer(model_id)


class EmbeddingWriter:
    """Context manager that collects embeddings and persists them on exit.

    ``__enter__`` returns a list that the caller appends embedding vectors to. On exit the
    collected vectors are saved to ``output_embedding_filename``. They are appended after the
    rows already stored there when ``update`` is set and that file exists. The
    ``embedding_to_issue_index`` mapping is then written to ``output_index_filename`` as JSON.

    The save also runs when an exception is propagating out of the ``with`` body, so partial
    progress is kept; the exception itself is not suppressed. Nothing is written if no
    embeddings were collected.
    """

    def __init__(
        self,
        output_embedding_filename,
        output_index_filename,
        update,
        embedding_to_issue_index,
    ) -> None:
        self.output_embedding_filename = output_embedding_filename
        self.output_index_filename = output_index_filename
        self.embeddings = []
        # Shared with the caller, which fills it while appending embeddings.
        self.embedding_to_issue_index = embedding_to_issue_index
        self.update = update

    def __enter__(self):
        return self.embeddings

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.embeddings:
            return
        embeddings = np.array(self.embeddings)
        if self.update and os.path.exists(self.output_embedding_filename):
            # New rows go after the existing ones so previously saved index positions stay valid.
            embeddings = np.concatenate([np.load(self.output_embedding_filename), embeddings])
        logger.info("Saving embeddings to %s", self.output_embedding_filename)
        np.save(self.output_embedding_filename, embeddings)
        logger.info("Saving embedding index to %s", self.output_index_filename)
        with open(self.output_index_filename, "w", encoding="utf-8") as f:
            json.dump(self.embedding_to_issue_index, f, indent=4)


def _skip_reason(issue: dict, issue_type: str):
    """Return why ``issue`` should not be embedded for ``issue_type``, or None to embed it."""
    if "body" not in issue:
        return "it has no body"
    is_pull = "pull_request" in issue
    if issue_type == "pull" and not is_pull:
        return "it is not a pull request"
    if issue_type == "issue" and is_pull:
        return "it is a pull request"
    return None


def embed_issues(
    input_filename: str,
    model_id: str,
    issue_type: str,
    n_issues: int = -1,
    update: bool = False,
):
    """Embed issues or pull requests from ``input_filename`` and save them to disk.

    Writes ``<issue_type>_embeddings.npy`` and ``embedding_index_to_<issue_type>.json``
    (the latter maps each embedding row, as a string key, to its issue id).

    Args:
        input_filename: JSON file holding a dict of issue id -> issue dict. Each issue is
            read for its "title", "body" and, to tell pull requests apart, "pull_request" keys.
        model_id: Name of the sentence-transformer model to load.
        issue_type: "issue" to embed only issues, "pull" to embed only pull requests.
        n_issues: Maximum number of new items to embed; a non-positive value embeds all.
        update: Append to the existing embeddings and index, skipping ids already embedded.
            Resuming requires both output files to exist; if only one of them does, both
            are rebuilt from scratch so row positions and index entries stay aligned.
    """
    output_embedding_filename = f"{issue_type}_embeddings.npy"
    output_index_filename = f"embedding_index_to_{issue_type}.json"

    # Read the input before loading the (expensive) model so a bad path fails fast.
    with open(input_filename, "r", encoding="utf-8") as f:
        issues = json.load(f)

    embedding_to_issue_index = {}
    if update:
        index_exists = os.path.exists(output_index_filename)
        embeddings_exist = os.path.exists(output_embedding_filename)
        if index_exists and embeddings_exist:
            with open(output_index_filename, "r", encoding="utf-8") as f:
                embedding_to_issue_index = json.load(f)
        elif index_exists or embeddings_exist:
            # Appending to only one of the two files would shift row numbers relative to the
            # index, silently pointing embeddings at the wrong issues. Rebuild both instead.
            logger.warning(
                "Only one of %s / %s exists; rebuilding embeddings from scratch",
                output_embedding_filename,
                output_index_filename,
            )
            update = False

    model = load_model(model_id)

    # Built once: O(1) membership per issue instead of scanning the index values each time.
    already_embedded = set(embedding_to_issue_index.values())
    embedding_index = len(embedding_to_issue_index)
    max_issues = n_issues if n_issues > 0 else len(issues)
    n_embedded = 0

    with EmbeddingWriter(
        output_embedding_filename=output_embedding_filename,
        output_index_filename=output_index_filename,
        update=update,
        embedding_to_issue_index=embedding_to_issue_index,
    ) as embeddings:
        for issue_id, issue in issues.items():
            if n_embedded >= max_issues:
                break
            if issue_id in already_embedded:
                logger.info("Skipping issue %s as it is already embedded", issue_id)
                continue
            reason = _skip_reason(issue, issue_type)
            if reason is not None:
                logger.info("Skipping issue %s as %s", issue_id, reason)
                continue

            title = issue.get("title") or ""
            body = issue["body"] or ""
            logger.info("Embedding issue %s", issue_id)
            embedding = model.encode(title + "\n" + body)
            # String keys match what json.load returns for previously saved entries.
            embedding_to_issue_index[str(embedding_index)] = issue_id
            embeddings.append(embedding)
            embedding_index += 1
            n_embedded += 1


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Build sentence-transformer embeddings for issues or pull requests."
    )
    # nargs="?" makes the positional optional so the declared default actually applies.
    parser.add_argument("issue_type", nargs="?", choices=["issue", "pull"], default="issue")
    parser.add_argument("--input_filename", type=str, default="issues_dict.json")
    parser.add_argument("--model_id", type=str, default="all-mpnet-base-v2")
    parser.add_argument("--n_issues", type=int, default=-1)
    parser.add_argument("--update", action="store_true")
    args = parser.parse_args()
    embed_issues(**vars(args))