transformers-github-bot / utils /build_embeddings.py
Amy Roberts
Move to utils
7d5704e
raw
history blame
No virus
4.86 kB
"""
Module which builds embeddings for issues and pull requests
The module is designed to be run from the command line and takes the following arguments:
--input_filename: The name of the JSON file containing the issues and pull requests
--model_id: The name of the sentence transformer model to use
issue_type (positional): The type of item to embed (either "issue" or "pull")
--n_issues: The maximum number of items to embed; a non-positive value (the default, -1) embeds all of them
--update: Whether to add to the existing embeddings and index files instead of overwriting them
The module saves the embeddings to a file called <issue_type>_embeddings.npy and the index to a file called
embedding_index_to_<issue_type>.json
The index provides a mapping from the index of the embedding to the issue or pull request number.
"""
import argparse
import json
import logging
import os
import numpy as np
from sentence_transformers import SentenceTransformer
# Logger for this module's progress messages.
logger = logging.getLogger(__name__)
# Give the root logger an INFO-level configuration so the progress messages
# are shown (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
def load_model(model_id: str):
    """Build and return a ``SentenceTransformer`` for the given model name.

    Args:
        model_id: identifier passed straight to ``SentenceTransformer``.
    """
    model = SentenceTransformer(model_id)
    return model
class EmbeddingWriter:
    """Context manager that collects embedding vectors and saves them on exit.

    ``with EmbeddingWriter(...) as vectors:`` yields a plain list that the
    caller appends embedding vectors to. On exit, provided at least one vector
    was collected, the vectors are stacked with ``np.array`` and written with
    ``np.save`` to ``output_embedding_filename``, and
    ``embedding_to_issue_index`` is written as indented JSON to
    ``output_index_filename``. With ``update`` set and an existing
    ``output_embedding_filename``, the stored array is loaded and the new
    vectors are appended after its rows.

    Saving happens on every exit, including one caused by an exception in the
    ``with`` body, so vectors collected before a failure are still written.
    The exception itself is not suppressed.
    """

    def __init__(self, output_embedding_filename, output_index_filename, update, embedding_to_issue_index) -> None:
        self.update = update
        self.output_index_filename = output_index_filename
        self.output_embedding_filename = output_embedding_filename
        # Kept by reference: entries the caller adds while the context is open
        # are part of what gets written on exit.
        self.embedding_to_issue_index = embedding_to_issue_index
        self.embeddings = []

    def __enter__(self):
        """Return the (initially empty) list to append embedding vectors to."""
        return self.embeddings

    def _stacked_embeddings(self):
        """Return the collected vectors as one array, after any stored rows in update mode."""
        fresh = np.array(self.embeddings)
        if not (self.update and os.path.exists(self.output_embedding_filename)):
            return fresh
        stored = np.load(self.output_embedding_filename)
        return np.concatenate([stored, fresh])

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Write the embeddings and index files; returns None so exceptions propagate."""
        # Nothing collected: leave any existing output files untouched.
        if not self.embeddings:
            return None
        matrix = self._stacked_embeddings()
        logger.info(f"Saving embeddings to {self.output_embedding_filename}")
        np.save(self.output_embedding_filename, matrix)
        logger.info(f"Saving embedding index to {self.output_index_filename}")
        with open(self.output_index_filename, "w") as index_file:
            json.dump(self.embedding_to_issue_index, index_file, indent=4)
        return None
def _skip_reason(issue_id, issue, issue_type, embedded_ids):
    """Return why ``issue`` must not be embedded, or ``None`` if it should be.

    The returned text completes the log line "Skipping issue <id> ...".
    """
    if issue_id in embedded_ids:
        return "as it is already embedded"
    if "body" not in issue:
        return "as it has no body"
    # An item counts as a pull request exactly when it has a "pull_request" key.
    is_pull_request = "pull_request" in issue
    if issue_type == "pull" and not is_pull_request:
        return "as it is not a pull request"
    if issue_type == "issue" and is_pull_request:
        return "as it is a pull request"
    return None


def embed_issues(
    input_filename: str,
    model_id: str,
    issue_type: str,
    n_issues: int = -1,
    update: bool = False,
):
    """Embed issue or pull-request text and save the vectors plus an id index.

    Reads a JSON object mapping ids to records from ``input_filename`` and
    encodes ``title + "\\n" + body`` of each selected record with the
    SentenceTransformer ``model_id``. The vectors go to
    ``<issue_type>_embeddings.npy`` and the row-index -> id mapping to
    ``embedding_index_to_<issue_type>.json``. Both paths are relative to the
    current working directory.

    Args:
        input_filename: JSON file mapping issue ids to issue records.
        model_id: sentence-transformers model name passed to ``load_model``.
        issue_type: "issue" embeds only records without a "pull_request" key,
            "pull" only records with one; any other value applies no type filter.
        n_issues: maximum number of records to embed in this run; a
            non-positive value means no limit.
        update: resume from the existing output files. Ids already in the index
            are skipped and new vectors are appended after the stored ones.
            Only honoured when BOTH output files exist; otherwise both are
            rebuilt from scratch so index keys and embedding rows stay aligned.
    """
    output_embedding_filename = f"{issue_type}_embeddings.npy"
    output_index_filename = f"embedding_index_to_{issue_type}.json"

    # Read the input before loading the model (which may be slow) so a bad
    # path fails fast.
    with open(input_filename, "r", encoding="utf-8") as f:
        issues = json.load(f)

    # EmbeddingWriter appends to the stored .npy whenever it exists in update
    # mode. Resuming with only one of the two files would therefore number new
    # rows from the wrong offset, so fall back to a full rebuild instead.
    if update and not (
        os.path.exists(output_index_filename) and os.path.exists(output_embedding_filename)
    ):
        logger.warning(
            f"Cannot resume: {output_index_filename} and {output_embedding_filename} "
            "must both exist; rebuilding embeddings from scratch"
        )
        update = False

    if update:
        with open(output_index_filename, "r", encoding="utf-8") as f:
            embedding_to_issue_index = json.load(f)
    else:
        embedding_to_issue_index = {}
    embedding_index = len(embedding_to_issue_index)
    # Set for O(1) membership tests; checking dict.values() per issue is O(n).
    embedded_ids = set(embedding_to_issue_index.values())

    max_issues = n_issues if n_issues > 0 else len(issues)
    n_embedded = 0

    model = load_model(model_id)
    with EmbeddingWriter(
        output_embedding_filename=output_embedding_filename,
        output_index_filename=output_index_filename,
        update=update,
        embedding_to_issue_index=embedding_to_issue_index,
    ) as embeddings:
        for issue_id, issue in issues.items():
            if n_embedded >= max_issues:
                break
            reason = _skip_reason(issue_id, issue, issue_type, embedded_ids)
            if reason is not None:
                logger.info(f"Skipping issue {issue_id} {reason}")
                continue

            # Missing or null title/body are treated as empty strings.
            title = issue.get("title")
            if title is None:
                title = ""
            body = issue["body"]
            if body is None:
                body = ""

            logger.info(f"Embedding issue {issue_id}")
            embedding = model.encode(title + "\n" + body)
            # String keys, matching the keys of an index loaded back from JSON.
            embedding_to_issue_index[str(embedding_index)] = issue_id
            embeddings.append(embedding)
            embedding_index += 1
            n_embedded += 1
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Build sentence-transformer embeddings for GitHub issues or pull requests."
    )
    # nargs="?" makes the positional optional so default="issue" actually
    # applies; without it argparse requires the argument and ignores the default.
    parser.add_argument(
        "issue_type",
        nargs="?",
        choices=["issue", "pull"],
        default="issue",
        help='Kind of item to embed: "issue" or "pull" (default: issue).',
    )
    parser.add_argument(
        "--input_filename",
        type=str,
        default="issues_dict.json",
        help="JSON file containing the issues and pull requests.",
    )
    parser.add_argument(
        "--model_id",
        type=str,
        default="all-mpnet-base-v2",
        help="Name of the sentence-transformers model to use.",
    )
    parser.add_argument(
        "--n_issues",
        type=int,
        default=-1,
        help="Maximum number of items to embed; a non-positive value embeds all of them.",
    )
    parser.add_argument(
        "--update",
        action="store_true",
        help="Add to the existing embeddings and index files instead of overwriting them.",
    )
    args = parser.parse_args()
    embed_issues(**vars(args))