# Hugging Face Space (page status when captured: Sleeping)
import argparse | |
import json | |
import logging | |
import os | |
import numpy as np | |
from sentence_transformers import SentenceTransformer | |
# Module-level logger used for the skip / embed / save progress messages below.
logger = logging.getLogger(name=__name__)
# Configure the root logger so that INFO-level messages are actually emitted.
logging.basicConfig(level=logging.INFO)
def load_model(model_id: str):
    """Construct and return a ``SentenceTransformer`` for *model_id*.

    *model_id* is passed unchanged to the ``SentenceTransformer`` constructor
    (the CLI default is ``"all-mpnet-base-v2"``).
    """
    model = SentenceTransformer(model_id)
    return model
class EmbeddingWriter:
    """Context manager that collects embedding rows and writes them out on exit.

    Entering the context returns a plain ``list`` of embedding rows, seeded from
    ``embeddings`` if one is given. The caller appends rows to it or replaces them.
    On exit the rows are stacked with ``np.array`` and saved to
    ``output_embedding_filename`` with ``np.save``. ``embedding_to_issue_index`` is
    then dumped as JSON to ``output_index_filename``.

    Args:
        output_embedding_filename: path passed to ``np.save`` for the embedding rows.
        output_index_filename: path the JSON index mapping is written to.
        update: if true and ``output_embedding_filename`` already exists, the rows
            collected here are appended after the rows already stored in that file.
        embedding_to_issue_index: mapping written verbatim as JSON on exit. The
            caller keeps a reference and may mutate it inside the ``with`` block.
            JSON object keys are always strings, so integer keys are stored as
            ``"0"``, ``"1"``, ... and come back as ``str`` when reloaded.
        embeddings: optional initial rows (e.g. a previously loaded array). They
            are copied into a new list, so the caller's object is not modified.
    """

    def __init__(
        self,
        output_embedding_filename,
        output_index_filename,
        update,
        embedding_to_issue_index,
        embeddings=None
    ) -> None:
        self.output_embedding_filename = output_embedding_filename
        self.output_index_filename = output_index_filename
        self.embeddings = [] if embeddings is None else list(embeddings)
        self.embedding_to_issue_index = embedding_to_issue_index
        self.update = update

    def __enter__(self):
        # Hand out the live list: whatever the caller leaves in it is what gets saved.
        return self.embeddings

    def __exit__(self, exc_type, exc_val, exc_tb):
        # NOTE: this runs even when the ``with`` body raised. The rows gathered so
        # far are still written, and the exception is not suppressed (returns None).
        new_rows = np.array(self.embeddings)
        embeddings = new_rows
        if self.update and os.path.exists(self.output_embedding_filename):
            existing = np.load(self.output_embedding_filename)
            # An empty list becomes a 1-D array of shape (0,). np.concatenate refuses
            # to join it with rows of a different rank, so only concatenate when
            # both sides actually hold data.
            if new_rows.size == 0:
                embeddings = existing
            elif existing.size:
                embeddings = np.concatenate([existing, new_rows])
        logger.info(f"Saving embeddings to {self.output_embedding_filename}")
        np.save(self.output_embedding_filename, embeddings)
        logger.info(f"Saving embedding index to {self.output_index_filename}")
        with open(self.output_index_filename, "w") as f:
            json.dump(self.embedding_to_issue_index, f, indent=4)
def _load_saved_embeddings(index_filename, embedding_filename):
    """Return ``(embedding_to_issue_index, embeddings)`` saved by a previous run.

    Either file may be missing (e.g. on the first run). The matching value is then
    an empty dict or ``None``. JSON object keys are always strings, so the saved
    index keys are converted back to ``int`` row positions here.
    """
    embedding_to_issue_index = {}
    if os.path.exists(index_filename):
        with open(index_filename, "r") as f:
            embedding_to_issue_index = {int(k): v for k, v in json.load(f).items()}
    embeddings = np.load(embedding_filename) if os.path.exists(embedding_filename) else None
    return embedding_to_issue_index, embeddings


def embed_issues(
    input_filename: str,
    model_id: str,
    issue_type: str,
):
    """Embed the bodies of issues / pull requests and update the saved embeddings.

    Reads *input_filename* as a JSON object mapping issue ids to issue dicts. The
    dicts are presumably GitHub issue payloads; only ``"body"`` and the presence
    of a ``"pull_request"`` key are used. Each selected body is encoded with the
    model from ``load_model(model_id)``. An issue that already has a row in the
    saved index gets that row overwritten; otherwise a new row is appended. The
    results are written to ``{issue_type}_embeddings.npy`` and
    ``embedding_index_to_{issue_type}.json`` in the working directory.

    Selection by *issue_type*:
        ``"issue"``: entries without a ``"pull_request"`` key.
        ``"pull"`` / ``"pull_request"``: entries with a ``"pull_request"`` key.
        anything else: no type filtering.
    Entries with a missing or null ``"body"`` are always skipped.
    """
    output_embedding_filename = f"{issue_type}_embeddings.npy"
    output_index_filename = f"embedding_index_to_{issue_type}.json"
    model = load_model(model_id)
    with open(input_filename, "r") as f:
        updated_issues = json.load(f)
    embedding_to_issue_index, saved_embeddings = _load_saved_embeddings(
        output_index_filename, output_embedding_filename
    )
    issue_to_embedding_index = {v: k for k, v in embedding_to_issue_index.items()}
    # The CLI choice is "pull"; "pull_request" is accepted as well.
    want_pull_requests = issue_type in ("pull", "pull_request")
    # update=False: every previously saved row is already preloaded into the writer,
    # so the file is rewritten in full rather than appended to.
    with EmbeddingWriter(
        output_embedding_filename=output_embedding_filename,
        output_index_filename=output_index_filename,
        update=False,
        embedding_to_issue_index=embedding_to_issue_index,
        embeddings=saved_embeddings
    ) as embedding_rows:
        for issue_id, issue in updated_issues.items():
            # A null body (not only a missing key) would crash model.encode.
            if issue.get("body") is None:
                logger.info(f"Skipping issue {issue_id} as it has no body")
                continue
            is_pull_request = "pull_request" in issue
            if want_pull_requests and not is_pull_request:
                logger.info(f"Skipping issue {issue_id} as it is not a pull request")
                continue
            if issue_type == "issue" and is_pull_request:
                logger.info(f"Skipping issue {issue_id} as it is a pull request")
                continue
            logger.info(f"Embedding issue {issue_id}")
            embedding = model.encode(issue["body"])
            index = issue_to_embedding_index.get(issue_id)
            if index is None:
                # New issue: append a row and record its position in both directions.
                index = len(embedding_rows)
                embedding_rows.append(embedding)
                issue_to_embedding_index[issue_id] = index
                embedding_to_issue_index[index] = issue_id
            else:
                embedding_rows[index] = embedding
def build_arg_parser():
    """Build the command-line parser for this script.

    ``issue_type`` is an optional positional (``nargs='?'``). argparse ignores
    ``default`` on a required positional and errors out when it is omitted, so
    without ``nargs='?'`` the declared default of ``'issue'`` could never apply.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('issue_type', nargs='?', choices=['issue', 'pull'], default='issue')
    parser.add_argument("--input_filename", type=str, default="updated_issues.json")
    parser.add_argument("--model_id", type=str, default="all-mpnet-base-v2")
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    embed_issues(**vars(args))