File size: 3,288 Bytes
13e3243
 
 
f1b08a8
13e3243
 
 
 
 
f5faae7
13e3243
f5faae7
13e3243
 
6676c5a
13e3243
 
 
 
 
 
 
6676c5a
13e3243
 
 
 
 
 
9d943c1
13e3243
 
 
 
 
 
 
 
 
 
 
 
6676c5a
 
13e3243
 
 
6676c5a
13e3243
6676c5a
 
 
13e3243
6676c5a
13e3243
6676c5a
13e3243
 
 
 
 
 
 
 
6676c5a
f1b08a8
 
 
 
 
6676c5a
 
 
 
 
13e3243
 
 
 
 
 
 
 
 
 
 
6676c5a
 
13e3243
 
 
 
 
 
 
 
f1b08a8
e027012
f1b08a8
 
 
 
 
 
e027012
 
f1b08a8
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import pandas as pd
from tqdm import tqdm

import config
import generate_annotated_diffs
import statistics
from api_wrappers import grazie_wrapper
from generation_steps import examples

# Number of end messages transform() synthesizes for every input row.
GENERATION_MULTIPLIER = 3
# generate_end_msg() returns a candidate immediately when its "deletions"
# statistic is strictly below this value.
REL_DELETIONS_THRESHOLD = 0.75
# Maximum number of grazie_wrapper.generate_for_prompt() calls that
# generate_end_msg() makes for a single end message.
GENERATION_ATTEMPTS = 3


def build_prompt(prediction, diff):
    """Build the prompt that asks the model to rewrite a generated commit message.

    Args:
        prediction: The generated commit message to improve; interpolated
            between the START/END OF THE COMMIT MESSAGE markers.
        diff: The source code changes; interpolated between the START/END OF
            THE SOURCE CODE CHANGES markers.

    Returns:
        The prompt string. It also embeds examples.EXAMPLES_START_TO_END and
        ends with the token "OUTPUT".
    """
    return f"""A LLM generated a commit message for the following source code changes:
START OF THE SOURCE CODE CHANGES
{diff}
END OF THE SOURCE CODE CHANGES

Here is the message the LLM generated:
START OF THE COMMIT MESSAGE 
{prediction}
END OF THE COMMIT MESSAGE

This generated message is not perfect. Your task is to rewrite and improve it.
You have to simulate a human software developer who manually rewrites the LLM-generated commit message, 
so the message you print must share some fragments with the generated message.   
Your message should be concise. 
Follow the Conventional Commits guidelines.
Here are some examples of what you should output:
START OF THE EXAMPLES LIST
{examples.EXAMPLES_START_TO_END}
END OF THE EXAMPLES LIST


Print only the improved commit message's text after the 
token "OUTPUT".

OUTPUT"""


def generate_end_msg(start_msg, diff):
    """Produce a rewritten ("end") commit message for ``start_msg``.

    Up to GENERATION_ATTEMPTS candidates are requested through
    grazie_wrapper.generate_for_prompt(). The first candidate whose
    "deletions" statistic is below REL_DELETIONS_THRESHOLD is returned right
    away. When no candidate qualifies, the one whose (deletions, message)
    tuple sorts first is returned.

    Args:
        start_msg: The commit message to be rewritten.
        diff: The source code changes passed on to build_prompt().

    Returns:
        The selected candidate message.
    """
    request = build_prompt(prediction=start_msg, diff=diff)
    rejected = []

    for _ in range(GENERATION_ATTEMPTS):
        candidate = grazie_wrapper.generate_for_prompt(request)

        annotated = generate_annotated_diffs.get_annotated_diff(start_msg, candidate)
        candidate_stats = statistics.get_statistics(
            start_msg=start_msg,
            end_msg=candidate,
            annotated_msg=annotated,
        )

        if candidate_stats["deletions"] < REL_DELETIONS_THRESHOLD:
            return candidate
        rejected.append((candidate_stats["deletions"], candidate))

    # Nothing passed the threshold: fall back to the smallest tuple.
    _, fallback = sorted(rejected)[0]
    return fallback


# Columns that transform() copies unchanged from each source row into the
# generated rows.
COLS_TO_KEEP = ["hash", "repo", "commit_msg_start", "mods", "session", "end_to_start"]


def print_config():
    """Print the generation settings of this step, one per line, to stdout."""
    report = (
        f"NUMBER OF EXAMPLES PER PROMPT = {examples.N_EXAMPLES}",
        f"GENERATION_MULTIPLIER = {GENERATION_MULTIPLIER}",
        f"REL_DELETIONS_THRESHOLD = {REL_DELETIONS_THRESHOLD}",
        f"GENERATION_ATTEMPTS = {GENERATION_ATTEMPTS}",
    )
    for entry in report:
        print(entry)


def transform(df):
    """Synthesize end messages for every row of ``df`` and append them.

    For each input row, GENERATION_MULTIPLIER end messages are produced with
    generate_end_msg() from the row's "commit_msg_start" and "mods" values.
    Each generated row carries the COLS_TO_KEEP values of its source row plus
    the new "commit_msg_end". Input rows get ``start_to_end = False`` (this
    column is written onto the passed-in frame itself) and generated rows get
    ``start_to_end = True``. The concatenation of both is written to
    config.START_TO_END_ARTIFACT as CSV.

    Args:
        df: Source DataFrame; must provide the COLS_TO_KEEP columns.

    Returns:
        The concatenated DataFrame (input rows followed by generated rows).
    """
    print("Start -> send synthesis:")
    print_config()

    # Flag the original rows; note this writes to the caller's frame.
    df['start_to_end'] = False

    # One list per output column, with "commit_msg_end" first.
    synthesized = {"commit_msg_end": [], **{name: [] for name in COLS_TO_KEEP}}

    progress = tqdm(df.iterrows(), total=len(df))
    for _index, source_row in progress:
        for _attempt in range(GENERATION_MULTIPLIER):
            end_msg = generate_end_msg(
                start_msg=source_row["commit_msg_start"],
                diff=source_row["mods"],
            )

            synthesized["commit_msg_end"].append(end_msg)
            for name in COLS_TO_KEEP:
                synthesized[name].append(source_row[name])

    generated_df = pd.DataFrame.from_dict(synthesized)
    generated_df['start_to_end'] = True

    combined = pd.concat([df, generated_df], ignore_index=True)
    combined.to_csv(config.START_TO_END_ARTIFACT)

    print("Done")
    return combined


def main():
    """Read the CSV at config.END_TO_START_ARTIFACT and run transform() on it."""
    # The first CSV column is used as the frame's index.
    transform(pd.read_csv(config.END_TO_START_ARTIFACT, index_col=[0]))


# Run the synthesis step only when this file is executed as a script.
if __name__ == '__main__':
    main()