File size: 3,798 Bytes
9b744c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import argparse
import json
import logging
import os

import numpy as np
from sentence_transformers import SentenceTransformer

# Configure the root logger so INFO-level progress messages are emitted.
logging.basicConfig(level=logging.INFO)

# Module-level logger, named after this module, used by the code below.
logger = logging.getLogger(__name__)


def load_model(model_id: str):
    """Build and return a ``SentenceTransformer`` for ``model_id``.

    Args:
        model_id: Identifier passed unchanged to the ``SentenceTransformer``
            constructor.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    model = SentenceTransformer(model_id)
    return model


class EmbeddingWriter:
    """Context manager that collects embeddings and persists them on exit.

    Entering the context yields a mutable list of embeddings. When the
    ``with`` block ends, the list is converted to a NumPy array and saved to
    ``output_embedding_filename``, and ``embedding_to_issue_index`` is written
    as JSON to ``output_index_filename``.

    Saving is attempted whether or not the ``with`` body raised; an exception
    from the body still propagates because ``__exit__`` returns ``None``.

    Args:
        output_embedding_filename: Path the embedding array is saved to.
        output_index_filename: Path the JSON index is written to.
        update: When true and ``output_embedding_filename`` already exists,
            the collected embeddings are appended to the stored array instead
            of replacing it.
        embedding_to_issue_index: JSON-serialisable mapping dumped to
            ``output_index_filename``. The same object is kept (not copied),
            so changes made by the caller inside the ``with`` block are saved.
        embeddings: Optional initial embeddings; copied into a new list.
    """

    def __init__(
        self,
        output_embedding_filename,
        output_index_filename,
        update,
        embedding_to_issue_index,
        embeddings=None
    ) -> None:
        self.output_embedding_filename = output_embedding_filename
        self.output_index_filename = output_index_filename
        # list() turns an array into a list of rows so callers can append
        # new rows and overwrite existing ones by position.
        self.embeddings = [] if embeddings is None else list(embeddings)
        self.embedding_to_issue_index = embedding_to_issue_index
        self.update = update

    def __enter__(self):
        """Return the mutable list that the ``with`` body fills in."""
        return self.embeddings

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Save the collected embeddings and the index to disk."""
        embeddings = np.array(self.embeddings)

        if self.update and os.path.exists(self.output_embedding_filename):
            stored = np.load(self.output_embedding_filename)
            # An empty array is 1-D with shape (0,), which np.concatenate
            # rejects next to a 2-D array. Only concatenate when both sides
            # hold data; otherwise keep whichever side is non-empty.
            if embeddings.size == 0:
                embeddings = stored
            elif stored.size != 0:
                embeddings = np.concatenate([stored, embeddings])

        logger.info(f"Saving embeddings to {self.output_embedding_filename}")
        np.save(self.output_embedding_filename, embeddings)

        logger.info(f"Saving embedding index to {self.output_index_filename}")
        with open(self.output_index_filename, "w", encoding="utf-8") as f:
            json.dump(self.embedding_to_issue_index, f, indent=4)


def embed_issues(
    input_filename: str,
    model_id: str,
    issue_type: str,
):
    """Embed updated issues and merge them into the stored embedding files.

    Reads ``{issue_type}_embeddings.npy`` and
    ``embedding_index_to_{issue_type}.json`` from the current directory,
    encodes the ``"body"`` of every matching entry in ``input_filename`` and
    writes both files back. An entry whose id is already in the index has its
    row overwritten; any other entry is appended as a new row and added to the
    index.

    Args:
        input_filename: JSON file holding an object that maps issue ids to
            issue dicts. Each issue dict may have ``"body"`` and
            ``"pull_request"`` keys.
        model_id: Identifier handed to ``load_model``.
        issue_type: ``"issue"`` keeps only entries without a
            ``"pull_request"`` key; ``"pull"`` or ``"pull_request"`` keeps
            only entries with one. Any other value applies no such filter.
            The value is also used verbatim in the two file names above.

    Raises:
        FileNotFoundError: If the input file or either stored file is missing.
    """
    output_embedding_filename = f"{issue_type}_embeddings.npy"
    output_index_filename = f"embedding_index_to_{issue_type}.json"
    model = load_model(model_id)

    with open(input_filename, "r") as f:
        updated_issues = json.load(f)

    with open(output_index_filename, "r") as f:
        embedding_to_issue_index = json.load(f)

    stored_embeddings = np.load(output_embedding_filename)

    # JSON object keys are always strings, so positions come back as "0",
    # "1", ... and must be converted to int before indexing the row list.
    # Issue ids are normalised to str because the keys of ``updated_issues``
    # (also JSON object keys) are strings.
    issue_to_embedding_index = {
        str(issue): int(position)
        for position, issue in embedding_to_issue_index.items()
    }

    # The command line offers "pull"; "pull_request" is accepted as well so
    # callers using that spelling keep working.
    wants_pull_requests = issue_type in ("pull", "pull_request")
    wants_issues = issue_type == "issue"

    with EmbeddingWriter(
        output_embedding_filename=output_embedding_filename,
        output_index_filename=output_index_filename,
        update=False,
        embedding_to_issue_index=embedding_to_issue_index,
        embeddings=stored_embeddings
    ) as embeddings:
        for issue_id, issue in updated_issues.items():
            # A JSON ``null`` body has nothing to embed either, so treat it
            # the same as a missing key.
            if issue.get("body") is None:
                logger.info(f"Skipping issue {issue_id} as it has no body")
                continue

            is_pull_request = "pull_request" in issue

            if wants_pull_requests and not is_pull_request:
                logger.info(f"Skipping issue {issue_id} as it is not a pull request")
                continue

            if wants_issues and is_pull_request:
                logger.info(f"Skipping issue {issue_id} as it is a pull request")
                continue

            logger.info(f"Embedding issue {issue_id}")
            embedding = model.encode(issue["body"])

            index = issue_to_embedding_index.get(issue_id)
            if index is not None:
                # Known issue: replace its existing row.
                embeddings[index] = embedding
            else:
                # New issue: append a row and record it in both mappings.
                index = len(embeddings)
                embeddings.append(embedding)
                issue_to_embedding_index[issue_id] = index
                # String key keeps the dict consistent with the entries
                # loaded from JSON; the dumped file is the same either way.
                embedding_to_issue_index[str(index)] = issue_id


if __name__ == "__main__":
    # Command-line entry point: parse the arguments and forward them to
    # embed_issues as keyword arguments.
    parser = argparse.ArgumentParser()
    # nargs='?' makes the positional optional; without it argparse always
    # requires a value and default='issue' can never take effect.
    # 'pull_request' is accepted alongside 'pull' because embed_issues
    # filters on that spelling.
    parser.add_argument(
        'issue_type',
        nargs='?',
        choices=['issue', 'pull', 'pull_request'],
        default='issue',
    )
    parser.add_argument("--input_filename", type=str, default="updated_issues.json")
    parser.add_argument("--model_id", type=str, default="all-mpnet-base-v2")
    args = parser.parse_args()
    embed_issues(**vars(args))