File size: 3,268 Bytes
f5faae7 a8a595d f1b08a8 7ab7be2 347f566 e027012 13e3243 a8a595d f5faae7 9d943c1 f5faae7 a8a595d 13e3243 a8a595d 347f566 a8a595d 02ebb6e f5faae7 02ebb6e a8a595d f1b08a8 a8a595d 02ebb6e a8a595d f5faae7 02ebb6e a8a595d 02ebb6e a8a595d 02ebb6e a8a595d f5faae7 a8a595d f1b08a8 e027012 f1b08a8 6676c5a e027012 f1b08a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
from itertools import chain
import pandas as pd
from tqdm import tqdm
import config
import dataset_statistics
import generate_annotated_diffs
from api_wrappers import grazie_wrapper, hf_data_loader
from generation_steps import examples
# Number of synthetic start messages generated for every input row.
GENERATION_MULTIPLIER = 3
# A generated message is accepted immediately when its "insertions" statistic is below this value.
REL_INSERTIONS_THRESHOLD = 0.5
# Maximum number of LLM generations tried per synthetic message before falling back to the best one.
GENERATION_ATTEMPTS = 3
def build_prompt(reference, diff):
    """Build the LLM prompt that asks for a plausible pre-edit commit message.

    The prompt describes a developer who edited an LLM-generated commit
    message, shows the code changes and the edited message, and asks the model
    to print the initial (unedited) message after the "OUTPUT" token.
    Few-shot examples are taken from ``examples.EXAMPLES_END_TO_START``.

    Args:
        reference: The edited (final) commit message; placed between the
            "COMMIT MESSAGE" markers.
        diff: The source code changes; placed between the
            "SOURCE CODE CHANGES" markers using its f-string formatting.

    Returns:
        The complete prompt text, ending with the "OUTPUT" token.
    """
    # The prompt text is sent to the model verbatim, so wording matters:
    # "some changes" below fixes the former "dome changes" typo.
    return f"""A software developer uses a LLM to generate commit messages.
They generated a commit message for the following source code changes:
START OF THE SOURCE CODE CHANGES
{diff}
END OF THE SOURCE CODE CHANGES
After generating the commit message the developer understands that it is not perfect. After making some changes,
they come up with an edited version of the message. Here is this edited message:
START OF THE COMMIT MESSAGE
{reference}
END OF THE COMMIT MESSAGE
Your task is to print the initial, LLM-generated commit message.
The message you print must share some fragments with the edited message.
Here are some examples of what you should output:
START OF THE EXAMPLES LIST
{examples.EXAMPLES_END_TO_START}
END OF THE EXAMPLES LIST
Print only the initial commit message's text after the
token "OUTPUT".
OUTPUT"""
def generate_start_msg(end_msg, diff):
    """Synthesize an initial commit message for an already edited one.

    The same prompt is sent to the LLM up to ``GENERATION_ATTEMPTS`` times.
    The first candidate whose "insertions" statistic (as reported by
    ``dataset_statistics.get_statistics_for_sample``) is below
    ``REL_INSERTIONS_THRESHOLD`` is returned immediately. If no candidate
    qualifies, the one with the smallest "insertions" value is returned;
    ties are resolved by ordinary tuple ordering, i.e. by comparing the
    candidate messages themselves.

    Args:
        end_msg: The edited (final) commit message.
        diff: The source code changes the message describes.

    Returns:
        The selected generated start message.
    """
    prompt = build_prompt(reference=end_msg, diff=diff)

    # (insertions, message) pairs for candidates that missed the threshold.
    rejected = []
    for _ in range(GENERATION_ATTEMPTS):
        candidate = grazie_wrapper.generate_for_prompt(prompt)
        candidate_stats = dataset_statistics.get_statistics_for_sample(
            start_msg=candidate,
            end_msg=end_msg,
        )
        insertions = candidate_stats["insertions"]

        # NOTE(review): the constant's name suggests a *relative* threshold;
        # confirm that the "insertions" statistic is on the same scale.
        if insertions < REL_INSERTIONS_THRESHOLD:
            return candidate

        rejected.append((insertions, candidate))

    # No attempt was good enough: fall back to the least-inserting candidate.
    return sorted(rejected)[0][1]
# Columns whose values are copied from the source row into each synthetic row.
COLS_TO_KEEP = ["hash", "repo", "commit_msg_end", "mods", "session"]
# Columns that synthetic rows receive with the fixed value given here (column name -> value).
COLS_TO_DEFAULT = {"edit_time": None}
def transform(df):
    """Extend ``df`` with synthetic rows whose start message is LLM-generated.

    For every row of ``df``, ``GENERATION_MULTIPLIER`` start messages are
    generated from the row's "commit_msg_end" and "mods" values. Each
    synthetic row copies the ``COLS_TO_KEEP`` columns from its source row and
    takes the fixed values of ``COLS_TO_DEFAULT``.

    Side effects:
        * prints the synthesis settings and a final "Done" to stdout;
        * adds an 'end_to_start' column set to False to ``df`` in place;
        * writes the combined frame to ``config.END_TO_START_ARTIFACT`` as CSV.

    Args:
        df: Input DataFrame; must provide the ``COLS_TO_KEEP`` columns.

    Returns:
        ``df`` followed by the synthetic rows (flagged with
        'end_to_start' = True), re-indexed from zero.
    """
    print("End -> start synthesis:")
    print(f"NUMBER OF EXAMPLES PER PROMPT = {examples.N_EXAMPLES}")
    print(f"GENERATION_MULTIPLIER = {GENERATION_MULTIPLIER}")
    print(f"REL_INSERTIONS_THRESHOLD = {REL_INSERTIONS_THRESHOLD}")
    print(f"GENERATION_ATTEMPTS = {GENERATION_ATTEMPTS}")

    # Original rows are flagged as not synthesized (this mutates the input).
    df['end_to_start'] = False

    # One accumulator list per output column; the generated message comes first.
    generated_data = {
        column: []
        for column in chain(("commit_msg_start",), COLS_TO_KEEP, COLS_TO_DEFAULT)
    }

    for _row_index, row in tqdm(df.iterrows(), total=len(df)):
        for _ in range(GENERATION_MULTIPLIER):
            synthetic_start = generate_start_msg(
                end_msg=row["commit_msg_end"],
                diff=row["mods"],
            )
            generated_data["commit_msg_start"].append(synthetic_start)

            for kept_col in COLS_TO_KEEP:
                generated_data[kept_col].append(row[kept_col])

            for default_col, default_value in COLS_TO_DEFAULT.items():
                generated_data[default_col].append(default_value)

    generated_df = pd.DataFrame.from_dict(generated_data).assign(end_to_start=True)

    result = pd.concat([df, generated_df], ignore_index=True)
    result.to_csv(config.END_TO_START_ARTIFACT)
    print("Done")

    return result
def main():
    """Load the processed rewriting dataset and run the synthesis on it."""
    transform(hf_data_loader.load_processed_rewriting_as_pandas())
# Run the synthesis only when this file is executed as a script.
if __name__ == "__main__":
    main()
|