File size: 3,268 Bytes
f5faae7
 
a8a595d
 
 
f1b08a8
7ab7be2
347f566
e027012
13e3243
a8a595d
f5faae7
9d943c1
f5faae7
a8a595d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13e3243
a8a595d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347f566
 
a8a595d
 
 
 
 
 
 
 
 
02ebb6e
 
f5faae7
 
02ebb6e
a8a595d
f1b08a8
 
 
 
 
 
a8a595d
 
 
02ebb6e
a8a595d
 
f5faae7
02ebb6e
 
a8a595d
 
 
 
02ebb6e
a8a595d
02ebb6e
 
a8a595d
f5faae7
 
 
a8a595d
 
 
f1b08a8
e027012
f1b08a8
 
 
 
 
 
6676c5a
e027012
f1b08a8
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from itertools import chain

import pandas as pd
from tqdm import tqdm

import config
import dataset_statistics
import generate_annotated_diffs
from api_wrappers import grazie_wrapper, hf_data_loader
from generation_steps import examples

# How many synthetic start messages are produced for every input row.
GENERATION_MULTIPLIER: int = 3
# A generated candidate is accepted as soon as its "insertions" statistic is below this value.
REL_INSERTIONS_THRESHOLD: float = 0.5
# Maximum LLM calls per synthetic message before falling back to the lowest-insertions candidate.
GENERATION_ATTEMPTS: int = 3


def build_prompt(reference, diff, examples_text=None):
    """Build the LLM prompt that asks for the "initial" version of an edited commit message.

    Args:
        reference: The edited (final) commit message shown to the model.
        diff: The source code changes the commit message describes.
        examples_text: Few-shot examples inserted into the prompt. Defaults to
            ``examples.EXAMPLES_END_TO_START`` when not given.

    Returns:
        The prompt string. It ends with the token ``OUTPUT``, after which the
        model is expected to print the initial message.
    """
    if examples_text is None:
        examples_text = examples.EXAMPLES_END_TO_START

    return f"""A software developer uses an LLM to generate commit messages.

They generated a commit message for the following source code changes:
START OF THE SOURCE CODE CHANGES
{diff}
END OF THE SOURCE CODE CHANGES

After generating the commit message the developer understands that it is not perfect. After making some changes,
they come up with an edited version of the message. Here is this edited message:
START OF THE COMMIT MESSAGE 
{reference}
END OF THE COMMIT MESSAGE

Your task is to print the initial, LLM-generated commit message. 
The message you print must share some fragments with the edited message. 
Here are some examples of what you should output:
START OF THE EXAMPLES LIST
{examples_text}
END OF THE EXAMPLES LIST


Print only the initial commit message's text after the 
token "OUTPUT".

OUTPUT"""


def generate_start_msg(end_msg, diff):
    """Ask the LLM to reconstruct a plausible initial commit message for ``end_msg``.

    The prompt from ``build_prompt`` is sent up to ``GENERATION_ATTEMPTS`` times.
    The first candidate whose ``"insertions"`` statistic is below
    ``REL_INSERTIONS_THRESHOLD`` is returned immediately. The statistic comes from
    ``dataset_statistics.get_statistics_for_sample``. If no candidate passes, the
    one with the lowest ``"insertions"`` value is returned.

    Args:
        end_msg: The edited (final) commit message.
        diff: The source code changes the message describes.

    Returns:
        The selected LLM-generated start message.

    Raises:
        ValueError: If ``GENERATION_ATTEMPTS`` is less than 1, so no candidate was generated.
    """
    prompt = build_prompt(reference=end_msg, diff=diff)
    candidates = []

    for _ in range(GENERATION_ATTEMPTS):
        start_msg_pred = grazie_wrapper.generate_for_prompt(prompt)
        stats = dataset_statistics.get_statistics_for_sample(start_msg=start_msg_pred, end_msg=end_msg)
        insertions = stats["insertions"]

        if insertions < REL_INSERTIONS_THRESHOLD:
            return start_msg_pred
        candidates.append((insertions, start_msg_pred))

    if not candidates:
        raise ValueError(f"GENERATION_ATTEMPTS must be at least 1, got {GENERATION_ATTEMPTS}")

    # Tuple comparison: the lowest insertions wins; ties fall back to comparing the text.
    # This matches what sorting the list and taking the first element would return.
    _, best_msg = min(candidates)
    return best_msg


# Columns copied verbatim from each source row into every synthetic row.
COLS_TO_KEEP = [
    "hash",
    "repo",
    "commit_msg_end",
    "mods",
    "session",
]

# Columns that synthetic rows fill with a fixed value instead of copying it.
COLS_TO_DEFAULT = dict(edit_time=None)


def _print_settings():
    """Print the synthesis settings so each run records the configuration it used."""
    print("End -> start synthesis:")
    print(f"NUMBER OF EXAMPLES PER PROMPT = {examples.N_EXAMPLES}")
    print(f"GENERATION_MULTIPLIER = {GENERATION_MULTIPLIER}")
    print(f"REL_INSERTIONS_THRESHOLD = {REL_INSERTIONS_THRESHOLD}")
    print(f"GENERATION_ATTEMPTS = {GENERATION_ATTEMPTS}")


def _synthesize_rows(df):
    """Return GENERATION_MULTIPLIER synthetic rows per row of ``df``, flagged ``end_to_start=True``."""
    generated_data = {"commit_msg_start": []}
    for col in chain(COLS_TO_KEEP, COLS_TO_DEFAULT):
        generated_data[col] = []

    for _, row in tqdm(df.iterrows(), total=len(df)):
        # These values are the same for every synthetic copy of the row, so read them once.
        kept_values = {col: row[col] for col in COLS_TO_KEEP}
        for _ in range(GENERATION_MULTIPLIER):
            generated_data["commit_msg_start"].append(
                generate_start_msg(end_msg=row["commit_msg_end"], diff=row["mods"])
            )
            for col, value in kept_values.items():
                generated_data[col].append(value)
            for col, default in COLS_TO_DEFAULT.items():
                generated_data[col].append(default)

    generated_df = pd.DataFrame.from_dict(generated_data)
    generated_df["end_to_start"] = True
    return generated_df


def transform(df):
    """Augment ``df`` with LLM-synthesized start messages and save the result.

    For each input row, ``GENERATION_MULTIPLIER`` synthetic rows are generated
    with ``generate_start_msg``. Each synthetic row:

    - copies the ``COLS_TO_KEEP`` columns from its source row;
    - sets the ``COLS_TO_DEFAULT`` columns to their default values;
    - is flagged ``end_to_start=True``.

    The original rows are kept and flagged ``end_to_start=False``. The caller's
    ``df`` itself is left unmodified.

    Args:
        df: Input frame. It must contain every column in ``COLS_TO_KEEP``.

    Returns:
        The original rows followed by the synthetic rows, re-indexed from 0. The
        same frame is written as CSV to ``config.END_TO_START_ARTIFACT``.
    """
    _print_settings()

    # Flag the original rows on a copy rather than writing into the caller's frame.
    original_df = df.assign(end_to_start=False)
    generated_df = _synthesize_rows(df)

    result = pd.concat([original_df, generated_df], ignore_index=True)
    result.to_csv(config.END_TO_START_ARTIFACT)

    print("Done")
    return result


def main():
    """Load the processed rewriting dataset and run end -> start synthesis on it."""
    transform(hf_data_loader.load_processed_rewriting_as_pandas())


if __name__ == "__main__":
    # Run the synthesis only when executed as a script, not when imported.
    main()