Spaces:
Sleeping
Sleeping
File size: 4,119 Bytes
9b744c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import argparse
import json
import logging
import os
import numpy as np
from sentence_transformers import SentenceTransformer
# Module-level logger used by everything in this file; the root logger is
# configured at INFO so the progress messages emitted below are shown.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
def load_model(model_id: str):
    """Build and return a ``SentenceTransformer`` for ``model_id``.

    Args:
        model_id: Identifier passed straight through to the
            ``SentenceTransformer`` constructor.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    model = SentenceTransformer(model_id)
    return model
class EmbeddingWriter:
    """Context manager that collects embeddings and writes them out on exit.

    Entering the context yields a plain list; the caller appends one
    embedding per item to it. On exit (whether or not the body raised) the
    collected embeddings are stacked into an array and saved with
    ``np.save``, and ``embedding_to_issue_index`` is dumped as JSON. Nothing
    is written when no embedding was collected.

    Args:
        output_embedding_filename: Path of the ``.npy`` file to write.
        output_index_filename: Path of the JSON index file to write.
        update: When truthy and the embedding file already exists, the new
            embeddings are appended to the ones stored in that file instead
            of replacing them.
        embedding_to_issue_index: Mapping dumped to the index file. The
            caller is expected to keep mutating this same object while
            inside the context.
    """

    def __init__(self, output_embedding_filename, output_index_filename, update, embedding_to_issue_index) -> None:
        self.embeddings = []
        self.update = update
        self.embedding_to_issue_index = embedding_to_issue_index
        self.output_index_filename = output_index_filename
        self.output_embedding_filename = output_embedding_filename

    def __enter__(self):
        # The caller gets the raw list, not the writer itself.
        return self.embeddings

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Nothing collected: leave any existing files untouched.
        if not self.embeddings:
            return

        embedding_path = self.output_embedding_filename
        stacked = np.array(self.embeddings)

        append_to_existing = self.update and os.path.exists(embedding_path)
        if append_to_existing:
            previous = np.load(embedding_path)
            stacked = np.concatenate([previous, stacked])

        logger.info(f"Saving embeddings to {self.output_embedding_filename}")
        np.save(embedding_path, stacked)

        logger.info(f"Saving embedding index to {self.output_index_filename}")
        with open(self.output_index_filename, "w") as index_file:
            json.dump(self.embedding_to_issue_index, index_file, indent=4)
        # Implicitly returns None, so an exception raised in the body propagates.
def embed_issues(
    input_filename: str,
    model_id: str,
    issue_type: str,
    n_issues: int = -1,
    update: bool = False
):
    """Embed the title and body of each issue record and save the results.

    The input file is a JSON object mapping an issue id to a record dict.
    For every record that passes the filters below, ``title + "\\n" + body``
    is encoded with the model and appended to the output. Two files are
    written in the current working directory:
    ``{issue_type}_embeddings.npy`` (the stacked embeddings) and
    ``embedding_index_to_{issue_type}.json`` (row index -> issue id).

    Records are skipped when they have no ``"body"`` key, when they do not
    match ``issue_type`` (presence of a ``"pull_request"`` key marks a pull
    request), or, in update mode, when their id is already in the index.

    Args:
        input_filename: Path of the JSON file holding the issue records.
        model_id: Model identifier handed to ``load_model``.
        issue_type: ``"issue"`` keeps only records without a
            ``"pull_request"`` key; ``"pull"`` or ``"pull_request"`` keeps
            only records with one. Any other value disables type filtering.
        n_issues: Maximum number of records to embed in this run. A value
            that is zero or negative means no limit.
        update: When true and the index file already exists, previously
            embedded ids are skipped and new rows are appended after the
            existing ones.
    """
    model = load_model(model_id)

    output_embedding_filename = f"{issue_type}_embeddings.npy"
    output_index_filename = f"embedding_index_to_{issue_type}.json"

    with open(input_filename, "r") as f:
        issues = json.load(f)

    if update and os.path.exists(output_index_filename):
        with open(output_index_filename, "r") as f:
            embedding_to_issue_index = json.load(f)
        embedding_index = len(embedding_to_issue_index)
    else:
        embedding_to_issue_index = {}
        embedding_index = 0

    # Ids embedded by a previous run. Built once so the per-issue membership
    # test is O(1) instead of a linear scan over the index values. Ids added
    # during this run never need checking: the input dict has unique keys.
    already_embedded = set(embedding_to_issue_index.values()) if update else set()

    # The command line spells the pull-request type "pull", so both spellings
    # must enable the pull-request filter.
    only_pull_requests = issue_type in ("pull", "pull_request")
    only_issues = issue_type == "issue"

    max_issues = n_issues if n_issues > 0 else len(issues)
    n_embedded = 0

    with EmbeddingWriter(
        output_embedding_filename=output_embedding_filename,
        output_index_filename=output_index_filename,
        update=update,
        embedding_to_issue_index=embedding_to_issue_index
    ) as embeddings:
        for issue_id, issue in issues.items():
            if n_embedded >= max_issues:
                break

            if issue_id in already_embedded:
                logger.info(f"Skipping issue {issue_id} as it is already embedded")
                continue

            if "body" not in issue:
                logger.info(f"Skipping issue {issue_id} as it has no body")
                continue

            is_pull_request = "pull_request" in issue
            if only_pull_requests and not is_pull_request:
                logger.info(f"Skipping issue {issue_id} as it is not a pull request")
                continue
            if only_issues and is_pull_request:
                logger.info(f"Skipping issue {issue_id} as it is a pull request")
                continue

            # A missing or null title/body is embedded as an empty string.
            title = issue.get("title") or ""
            body = issue["body"] or ""

            logger.info(f"Embedding issue {issue_id}")
            embedding = model.encode(title + "\n" + body)

            # The index is updated in step with the embeddings list so that
            # row i of the saved array corresponds to index entry i.
            embedding_to_issue_index[embedding_index] = issue_id
            embeddings.append(embedding)
            embedding_index += 1
            n_embedded += 1
if __name__ == "__main__":
    # Command-line entry point: parse the options and forward them to
    # embed_issues.
    parser = argparse.ArgumentParser()
    # nargs='?' makes the positional optional so the declared default is
    # actually used; without it argparse requires the argument and the
    # default never applies.
    parser.add_argument('issue_type', nargs='?', choices=['issue', 'pull'], default='issue')
    parser.add_argument("--input_filename", type=str, default="issues_dict.json")
    parser.add_argument("--model_id", type=str, default="all-mpnet-base-v2")
    parser.add_argument("--n_issues", type=int, default=-1)
    parser.add_argument("--update", action="store_true")
    args = parser.parse_args()
    # Argument names match embed_issues' parameter names, so the parsed
    # namespace can be unpacked directly as keyword arguments.
    embed_issues(**vars(args))
|