# transformers-github-bot / build_embeddings.py
# Author: Amy Roberts (draft, commit 9b744c5, 4.12 kB)
import argparse
import json
import logging
import os
import numpy as np
from sentence_transformers import SentenceTransformer
# Module logger; the root logger is configured at INFO so the per-issue
# logger.info(...) progress messages below are actually emitted.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
def load_model(model_id: str):
    """Build and return a ``SentenceTransformer`` for ``model_id``.

    ``model_id`` is handed to ``SentenceTransformer`` unchanged (e.g. the CLI
    default ``"all-mpnet-base-v2"``).
    """
    model = SentenceTransformer(model_id)
    return model
class EmbeddingWriter:
    """Collect embeddings inside a ``with`` block and persist them on exit.

    ``__enter__`` returns a plain list to which the caller appends one
    embedding vector per item. On exit the collected vectors are stacked with
    ``np.array``, optionally appended after the rows already stored in
    ``output_embedding_filename``, and written with ``np.save``. The
    ``embedding_to_issue_index`` mapping is then written as JSON to
    ``output_index_filename``.

    Saving also happens when the ``with`` body raises, so work done before the
    error is kept; the exception itself still propagates. Nothing is written
    when no embeddings were collected.
    """

    def __init__(
        self,
        output_embedding_filename: str,
        output_index_filename: str,
        update: bool,
        embedding_to_issue_index: dict,
    ) -> None:
        self.output_embedding_filename = output_embedding_filename
        self.output_index_filename = output_index_filename
        self.embeddings = []
        # Held by reference: the caller keeps filling this dict while it
        # appends to self.embeddings, and it is dumped as-is on exit.
        self.embedding_to_issue_index = embedding_to_issue_index
        # When True, new rows are appended to an existing embeddings file
        # instead of replacing it.
        self.update = update

    def __enter__(self):
        return self.embeddings

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning None (falsy) lets any exception from the with-body propagate.
        if not self.embeddings:
            return

        embeddings = np.array(self.embeddings)
        if self.update and os.path.exists(self.output_embedding_filename):
            embeddings = np.concatenate([np.load(self.output_embedding_filename), embeddings])

        # The caller is expected to keep exactly one index entry per embedding
        # row. If only one of the two files existed in update mode, the row
        # numbers no longer line up with the index; surface that instead of
        # silently writing inconsistent files.
        if len(embeddings) != len(self.embedding_to_issue_index):
            logger.warning(
                "Embedding count (%d) does not match index size (%d); "
                "%s and %s may be out of sync",
                len(embeddings),
                len(self.embedding_to_issue_index),
                self.output_embedding_filename,
                self.output_index_filename,
            )

        logger.info(f"Saving embeddings to {self.output_embedding_filename}")
        np.save(self.output_embedding_filename, embeddings)
        logger.info(f"Saving embedding index to {self.output_index_filename}")
        with open(self.output_index_filename, "w", encoding="utf-8") as f:
            json.dump(self.embedding_to_issue_index, f, indent=4)
def embed_issues(
    input_filename: str,
    model_id: str,
    issue_type: str,
    n_issues: int = -1,
    update: bool = False,
):
    """Embed the title and body of GitHub issues or pull requests.

    Reads a JSON object mapping issue id -> issue dict from ``input_filename``,
    encodes ``title + "\\n" + body`` for each selected entry with the model
    returned by ``load_model(model_id)``, and writes
    ``{issue_type}_embeddings.npy`` plus ``embedding_index_to_{issue_type}.json``
    (embedding row number -> issue id) relative to the current directory.

    Args:
        input_filename: Path to the JSON file of issues.
        model_id: Model identifier passed to ``load_model``.
        issue_type: ``"issue"`` keeps entries without a ``"pull_request"`` key;
            ``"pull"`` (the CLI spelling) or ``"pull_request"`` keeps only
            entries that have one. Any other value applies no issue/PR filter.
        n_issues: Maximum number of new embeddings to compute; ``<= 0`` means
            no limit.
        update: If True, extend the existing output files and skip issues that
            are already present in the index.
    """
    # The CLI offers "pull" while older code compared against "pull_request";
    # accept both so that PR runs actually filter out plain issues.
    want_pull_requests = issue_type in ("pull", "pull_request")
    if issue_type != "issue" and not want_pull_requests:
        logger.warning(f"Unknown issue_type {issue_type!r}; no issue/PR filtering applied")

    model = load_model(model_id)
    output_embedding_filename = f"{issue_type}_embeddings.npy"
    output_index_filename = f"embedding_index_to_{issue_type}.json"

    with open(input_filename, "r", encoding="utf-8") as f:
        issues = json.load(f)

    if update and os.path.exists(output_index_filename):
        with open(output_index_filename, "r", encoding="utf-8") as f:
            embedding_to_issue_index = json.load(f)
    else:
        embedding_to_issue_index = {}
    # Next free embedding row; new rows are appended after the existing ones.
    embedding_index = len(embedding_to_issue_index)

    # Build the already-embedded set once: scanning dict.values() for every
    # issue made the loop quadratic. Ids added during this run come from the
    # unique keys of `issues`, so they never need to be re-checked.
    already_embedded = set(embedding_to_issue_index.values()) if update else set()

    max_issues = n_issues if n_issues > 0 else len(issues)
    n_embedded = 0

    with EmbeddingWriter(
        output_embedding_filename=output_embedding_filename,
        output_index_filename=output_index_filename,
        update=update,
        embedding_to_issue_index=embedding_to_issue_index,
    ) as embeddings:
        for issue_id, issue in issues.items():
            if n_embedded >= max_issues:
                break
            if issue_id in already_embedded:
                logger.info(f"Skipping issue {issue_id} as it is already embedded")
                continue
            if "body" not in issue:
                logger.info(f"Skipping issue {issue_id} as it has no body")
                continue
            is_pull_request = "pull_request" in issue
            if want_pull_requests and not is_pull_request:
                logger.info(f"Skipping issue {issue_id} as it is not a pull request")
                continue
            elif issue_type == "issue" and is_pull_request:
                logger.info(f"Skipping issue {issue_id} as it is a pull request")
                continue

            # Missing or null title/body are treated as empty text.
            title = issue.get("title") or ""
            body = issue["body"] or ""
            logger.info(f"Embedding issue {issue_id}")
            embedding = model.encode(title + "\n" + body)
            # Record the index entry before appending the vector so the two
            # stay aligned even if the loop is interrupted. String keys match
            # what json.load returns for an existing index file.
            embedding_to_issue_index[str(embedding_index)] = issue_id
            embeddings.append(embedding)
            embedding_index += 1
            n_embedded += 1
def parse_args(argv=None):
    """Parse the command-line options of this script.

    ``argv`` defaults to ``None``, in which case argparse reads ``sys.argv[1:]``.
    The returned namespace's attributes match ``embed_issues``'s parameters.
    """
    parser = argparse.ArgumentParser(description="Embed GitHub issues or pull requests.")
    # nargs="?" makes the positional optional so default="issue" takes effect;
    # without it argparse treats the positional as required and the default is dead.
    parser.add_argument("issue_type", nargs="?", choices=["issue", "pull"], default="issue")
    parser.add_argument("--input_filename", type=str, default="issues_dict.json")
    parser.add_argument("--model_id", type=str, default="all-mpnet-base-v2")
    parser.add_argument("--n_issues", type=int, default=-1, help="Max new embeddings; <= 0 means all")
    parser.add_argument("--update", action="store_true", help="Extend existing embedding files")
    return parser.parse_args(argv)


if __name__ == "__main__":
    embed_issues(**vars(parse_args()))