File size: 4,119 Bytes
9b744c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import argparse
import json
import logging
import os

import numpy as np
from sentence_transformers import SentenceTransformer

# Configure the root logger so INFO-level (and higher) records are emitted.
logging.basicConfig(level=logging.INFO)

# Module-level logger used for the progress messages emitted in this file.
logger = logging.getLogger(__name__)


def load_model(model_id: str):
    """Build and return a ``SentenceTransformer`` for the given ``model_id``."""
    model = SentenceTransformer(model_id)
    return model


class EmbeddingWriter:
    """Context manager that gathers embeddings and writes them out on exit.

    Entering the context hands back a plain list. The caller appends one
    embedding per item to it. When the context exits and the list is not
    empty, the embeddings are stored with ``np.save`` and the index mapping
    passed to the constructor is dumped as JSON.
    """

    def __init__(self, output_embedding_filename, output_index_filename, update, embedding_to_issue_index) -> None:
        # List handed to the caller by __enter__; filled inside the with body.
        self.embeddings = []
        # When true, new embeddings are appended to an existing embedding file.
        self.update = update
        # Mapping written to the index file; the caller mutates it in place.
        self.embedding_to_issue_index = embedding_to_issue_index
        self.output_index_filename = output_index_filename
        self.output_embedding_filename = output_embedding_filename

    def __enter__(self):
        """Return the list the caller should append embeddings to."""
        return self.embeddings

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Persist the collected embeddings and the index mapping.

        The exception arguments are not inspected, so whatever was collected
        is written even if the with body raised. Returning ``None`` lets any
        such exception keep propagating.
        """
        # Nothing collected: leave any existing files untouched.
        if not self.embeddings:
            return None

        embedding_path = self.output_embedding_filename
        index_path = self.output_index_filename

        stacked = np.array(self.embeddings)

        # In update mode, keep the previously saved rows in front of the new ones.
        if self.update and os.path.exists(embedding_path):
            previous = np.load(embedding_path)
            stacked = np.concatenate([previous, stacked])

        logger.info(f"Saving embeddings to {embedding_path}")
        np.save(embedding_path, stacked)

        logger.info(f"Saving embedding index to {index_path}")
        with open(index_path, "w") as index_file:
            json.dump(self.embedding_to_issue_index, index_file, indent=4)

        return None


def embed_issues(
    input_filename: str,
    model_id: str,
    issue_type: str,
    n_issues: int = -1,
    update: bool = False
):
    """Embed the title and body of issues from a JSON file and save the result.

    ``input_filename`` must contain a JSON object mapping issue ids to issue
    dicts. For every selected issue the title and body are joined with a
    newline and encoded with the model returned by ``load_model(model_id)``.
    The embeddings go to ``{issue_type}_embeddings.npy`` and the mapping from
    embedding row to issue id goes to ``embedding_index_to_{issue_type}.json``.
    Both paths are relative to the current working directory and are written
    by ``EmbeddingWriter`` when its context exits.

    Args:
        input_filename: Path of the JSON file holding the issues.
        model_id: Identifier passed to ``load_model``.
        issue_type: Names the output files and selects what is embedded.
            ``"issue"`` keeps entries without a ``"pull_request"`` key.
            ``"pull"`` and ``"pull_request"`` keep entries that have it.
            Any other value applies no pull-request filtering.
        n_issues: Maximum number of issues to embed in this run. A value
            that is not positive means no limit.
        update: When true and the index file already exists, issues whose
            id is already in that index are skipped and new embeddings are
            appended after the existing ones.
    """
    model = load_model(model_id)

    output_embedding_filename = f"{issue_type}_embeddings.npy"
    output_index_filename = f"embedding_index_to_{issue_type}.json"

    with open(input_filename, "r") as f:
        issues = json.load(f)

    if update and os.path.exists(output_index_filename):
        with open(output_index_filename, "r") as f:
            embedding_to_issue_index = json.load(f)
        # New rows continue after the ones that are already stored.
        embedding_index = len(embedding_to_issue_index)
    else:
        embedding_to_issue_index = {}
        embedding_index = 0

    # Build the "already embedded" lookup once instead of scanning
    # dict.values() for every issue. Ids added during this run never need to
    # be looked up again because the keys of `issues` are unique.
    already_embedded = set(embedding_to_issue_index.values()) if update else set()

    max_issues = n_issues if n_issues > 0 else len(issues)
    n_embedded = 0

    # The command line offers "pull" while the data marks pull requests with a
    # "pull_request" key, so both spellings select pull requests.
    wants_pull_requests = issue_type in ("pull", "pull_request")
    wants_plain_issues = issue_type == "issue"

    with EmbeddingWriter(
        output_embedding_filename=output_embedding_filename,
        output_index_filename=output_index_filename,
        update=update,
        embedding_to_issue_index=embedding_to_issue_index
    ) as embeddings:
        for issue_id, issue in issues.items():
            if n_embedded >= max_issues:
                break

            if issue_id in already_embedded:
                logger.info(f"Skipping issue {issue_id} as it is already embedded")
                continue

            if "body" not in issue:
                logger.info(f"Skipping issue {issue_id} as it has no body")
                continue

            is_pull_request = "pull_request" in issue

            if wants_pull_requests and not is_pull_request:
                logger.info(f"Skipping issue {issue_id} as it is not a pull request")
                continue

            if wants_plain_issues and is_pull_request:
                logger.info(f"Skipping issue {issue_id} as it is a pull request")
                continue

            # A missing or null title/body is treated as an empty string.
            title = issue.get("title")
            if title is None:
                title = ""
            body = issue["body"]
            if body is None:
                body = ""

            logger.info(f"Embedding issue {issue_id}")
            embedding = model.encode(title + "\n" + body)
            embedding_to_issue_index[embedding_index] = issue_id
            embeddings.append(embedding)
            embedding_index += 1
            n_embedded += 1


if __name__ == "__main__":
    # Command-line entry point: every parsed option is forwarded by name to
    # embed_issues, so the argument names must match its parameters.
    parser = argparse.ArgumentParser()
    # nargs='?' makes the positional optional. Without it argparse requires
    # the argument and the declared default is never used.
    parser.add_argument('issue_type', nargs='?', choices=['issue', 'pull'], default='issue')
    parser.add_argument("--input_filename", type=str, default="issues_dict.json")
    parser.add_argument("--model_id", type=str, default="all-mpnet-base-v2")
    # A non-positive value means "no limit" in embed_issues.
    parser.add_argument("--n_issues", type=int, default=-1)
    parser.add_argument("--update", action="store_true")
    args = parser.parse_args()
    embed_issues(**vars(args))