# transformers-github-bot / update_embeddings.py
# Author: Amy Roberts
# Last commit: c1fc690 ("Updates")
import argparse
import json
import logging
import os
import numpy as np
from sentence_transformers import SentenceTransformer
# Logger for this module; the root logging configuration below lets INFO-level
# progress messages through.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
def load_model(model_id: str):
    """Build and return the ``SentenceTransformer`` identified by ``model_id``.

    Args:
        model_id: Model name or path passed straight to ``SentenceTransformer``.
    """
    model = SentenceTransformer(model_id)
    return model
class EmbeddingWriter:
    """Context manager that collects embedding vectors and saves them on exit.

    Entering the context yields a plain ``list``. Callers append new vectors to
    it or overwrite existing rows by position. On exit the list is turned into a
    NumPy array and written with ``np.save`` to ``output_embedding_filename``.
    ``embedding_to_issue_index`` is written as JSON to ``output_index_filename``.

    The files are written on every exit, including when the ``with`` body
    raised, so work completed before an error is kept. The exception is not
    suppressed.

    Args:
        output_embedding_filename: Path of the ``.npy`` file to write.
        output_index_filename: Path of the JSON index file to write.
        update: If true and ``output_embedding_filename`` already exists, the
            collected vectors are appended after the rows already stored there
            instead of replacing them.
        embedding_to_issue_index: Mapping serialized as-is to the index file.
        embeddings: Optional initial vectors (e.g. the rows of a previously
            saved array) used to seed the yielded list.
    """

    def __init__(
        self,
        output_embedding_filename,
        output_index_filename,
        update,
        embedding_to_issue_index,
        embeddings=None
    ) -> None:
        self.output_embedding_filename = output_embedding_filename
        self.output_index_filename = output_index_filename
        # Copy into a list so callers can append/assign rows even when a NumPy
        # array was passed in.
        self.embeddings = [] if embeddings is None else list(embeddings)
        self.embedding_to_issue_index = embedding_to_issue_index
        self.update = update

    def __enter__(self):
        return self.embeddings

    def __exit__(self, exc_type, exc_val, exc_tb):
        embeddings = np.array(self.embeddings)
        if self.update and os.path.exists(self.output_embedding_filename):
            existing = np.load(self.output_embedding_filename)
            # An empty list becomes a 1-D array of shape (0,), which cannot be
            # concatenated with a 2-D matrix, so an empty side is simply dropped.
            if embeddings.size == 0:
                embeddings = existing
            elif existing.size > 0:
                embeddings = np.concatenate([existing, embeddings])
        logger.info(f"Saving embeddings to {self.output_embedding_filename}")
        np.save(self.output_embedding_filename, embeddings)
        logger.info(f"Saving embedding index to {self.output_index_filename}")
        with open(self.output_index_filename, "w", encoding="utf-8") as f:
            json.dump(self.embedding_to_issue_index, f, indent=4)
        # Returning None lets any exception from the ``with`` body propagate.
def _load_saved_state(output_index_filename, output_embedding_filename):
    """Return the saved ``(embedding_to_issue_index, embeddings)`` pair, or empty state.

    JSON object keys are always strings. The index keys are converted back to
    ``int`` so they can index the embeddings list, and so newly added integer
    keys do not end up duplicating loaded string keys when the index is saved.
    """
    if not os.path.exists(output_index_filename):
        logger.warning(
            f"No embedding index found at {output_index_filename}; starting with empty embeddings"
        )
        return {}, []
    with open(output_index_filename, "r", encoding="utf-8") as f:
        embedding_to_issue_index = {int(k): v for k, v in json.load(f).items()}
    return embedding_to_issue_index, np.load(output_embedding_filename)


def _skip_reason(issue, issue_type):
    """Return why ``issue`` should not be embedded as ``issue_type``, or ``None`` to embed it."""
    # A missing key and a None/empty body both leave nothing to encode.
    if not issue.get("body"):
        return "it has no body"
    is_pull_request = "pull_request" in issue
    # The CLI spells the PR type "pull"; "pull_request" is accepted as well.
    if issue_type in ("pull", "pull_request") and not is_pull_request:
        return "it is not a pull request"
    if issue_type == "issue" and is_pull_request:
        return "it is a pull request"
    return None


def embed_issues(
    input_filename: str,
    model_id: str,
    issue_type: str,
):
    """Embed the bodies of updated issues/PRs and merge them into the saved embeddings.

    Reads ``{issue_type}_embeddings.npy`` and ``embedding_index_to_{issue_type}.json``
    from the working directory. Starts empty if the index file does not exist yet.
    Every issue in ``input_filename`` with a body and of the requested type is
    encoded. Issues that were embedded before have their row replaced; new issues
    are appended and added to the index. Both files are rewritten on completion.

    Args:
        input_filename: JSON file mapping issue id -> issue dict. The dict's
            ``"body"`` is embedded, and the presence of a ``"pull_request"`` key
            marks a pull request.
        model_id: Sentence-transformers model used for encoding.
        issue_type: ``"issue"`` to embed only issues; ``"pull"`` (or
            ``"pull_request"``) to embed only pull requests. Also used to name
            the output files.
    """
    output_embedding_filename = f"{issue_type}_embeddings.npy"
    output_index_filename = f"embedding_index_to_{issue_type}.json"

    with open(input_filename, "r", encoding="utf-8") as f:
        updated_issues = json.load(f)
    embedding_to_issue_index, saved_embeddings = _load_saved_state(
        output_index_filename, output_embedding_filename
    )
    issue_to_embedding_index = {v: k for k, v in embedding_to_issue_index.items()}

    model = load_model(model_id)

    # update=False: the saved rows are already passed in through ``embeddings``,
    # so the writer must overwrite the file rather than append them a second time.
    with EmbeddingWriter(
        output_embedding_filename=output_embedding_filename,
        output_index_filename=output_index_filename,
        update=False,
        embedding_to_issue_index=embedding_to_issue_index,
        embeddings=saved_embeddings,
    ) as embeddings:
        for issue_id, issue in updated_issues.items():
            reason = _skip_reason(issue, issue_type)
            if reason is not None:
                logger.info(f"Skipping issue {issue_id} as {reason}")
                continue
            logger.info(f"Embedding issue {issue_id}")
            embedding = model.encode(issue["body"])
            index = issue_to_embedding_index.get(issue_id)
            if index is not None:
                # Embedded before: replace the stale vector in place.
                embeddings[index] = embedding
            else:
                index = len(embeddings)
                embeddings.append(embedding)
                issue_to_embedding_index[issue_id] = index
                embedding_to_issue_index[index] = issue_id
def parse_args(argv=None):
    """Parse the command-line options for ``embed_issues``.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``issue_type``, ``input_filename`` and ``model_id``.
    """
    parser = argparse.ArgumentParser()
    # nargs="?" makes the positional optional. Without it argparse ignores
    # `default` and the argument is always required.
    parser.add_argument("issue_type", nargs="?", choices=["issue", "pull"], default="issue")
    parser.add_argument("--input_filename", type=str, default="updated_issues.json")
    parser.add_argument("--model_id", type=str, default="all-mpnet-base-v2")
    return parser.parse_args(argv)


if __name__ == "__main__":
    embed_issues(**vars(parse_args()))