Petr Tsvetkov
Keep the session column
02ebb6e
raw
history blame
3.49 kB
import pandas as pd
from tqdm import tqdm
import config
import generate_annotated_diffs
import statistics
from api_wrappers import grazie_wrapper, hf_data_loader
# Number of few-shot examples sampled from the manually rewritten dataset.
N_EXAMPLES = 5

# Synthetic start messages produced for every input row in transform().
GENERATION_MULTIPLIER = 2
# Maximum number of LLM calls made while producing a single start message.
GENERATION_ATTEMPTS = 5

# A generated candidate whose "insertions" statistic is below this value is
# accepted immediately instead of trying further attempts.
REL_INSERTIONS_THRESHOLD = 0.6
def get_example_prompt(start_msg, end_msg):
    """Render one few-shot example for the "edited -> initial message" task.

    The example presents the edited (final) commit message first and the
    initial commit message the model should output second.

    :param start_msg: the initial commit message (the expected model output).
    :param end_msg: the edited commit message (the model input).
    :return: the example text wrapped in START/END OF THE EXAMPLE markers.
    """
    # The lead-in sentence previously read "For following the edited message:",
    # which is ungrammatical text sent verbatim to the LLM.
    return f"""START OF THE EXAMPLE
For the following edited message:
START OF THE EDITED COMMIT MESSAGE
{end_msg}
END OF THE EDITED COMMIT MESSAGE
You would output the following initial commit message:
START OF THE INITIAL COMMIT MESSAGE
{start_msg}
END OF THE INITIAL COMMIT MESSAGE
END OF THE EXAMPLE"""
def generate_examples():
    """Render the few-shot examples block embedded into every prompt.

    Loads the raw rewriting dataset, keeps only the commit_msg_start and
    commit_msg_end columns, samples N_EXAMPLES rows seeded with
    config.RANDOM_STATE, renders each pair with get_example_prompt and joins
    the rendered examples with newlines.
    """
    dataset = hf_data_loader.load_raw_rewriting_dataset_as_pandas()
    sampled = dataset[['commit_msg_start', 'commit_msg_end']].sample(
        n=N_EXAMPLES, random_state=config.RANDOM_STATE
    )
    rendered = map(get_example_prompt, sampled['commit_msg_start'], sampled['commit_msg_end'])
    return "\n".join(rendered)


# Evaluated once at import time: loads the dataset and fixes the examples that
# build_prompt embeds into every prompt.
EXAMPLES = generate_examples()
def build_prompt(reference, diff, examples=None):
    """Build the prompt asking the LLM to reconstruct an initial commit message.

    The prompt shows the source code changes and the edited commit message,
    embeds few-shot examples, and asks the model to print the initial,
    LLM-generated message after an "OUTPUT" token.

    :param reference: the edited (final) commit message.
    :param diff: the source code changes the message describes.
    :param examples: few-shot examples text to embed; when None (the default)
        the module-level EXAMPLES is used, matching the previous behavior.
    :return: the complete prompt string, ending with "OUTPUT".
    """
    if examples is None:
        examples = EXAMPLES
    # Fixed typos in the prompt text: "a LLM" -> "an LLM", "dome changes" -> "some changes".
    return f"""A software developer uses an LLM to generate commit messages.
They generated a commit message for the following source code changes:
START OF THE SOURCE CODE CHANGES
{diff}
END OF THE SOURCE CODE CHANGES
After generating the commit message the developer understands that it is not perfect. After making some changes,
they come up with an edited version of the message. Here is this edited message:
START OF THE COMMIT MESSAGE
{reference}
END OF THE COMMIT MESSAGE
Your task is to print the initial, LLM-generated commit message.
The message you print must share some fragments with the edited message.
Here are some examples of what you should output:
START OF THE EXAMPLES LIST
{examples}
END OF THE EXAMPLES LIST
Print only the initial commit message's text after the
token "OUTPUT".
OUTPUT"""
def generate_start_msg(end_msg, diff):
    """Generate an initial commit message that could have been edited into end_msg.

    Queries the LLM up to GENERATION_ATTEMPTS times with the prompt from
    build_prompt. For each candidate, statistics.get_statistics is computed
    against end_msg (using the annotated diff between the two messages).
    The first candidate whose "insertions" value is below
    REL_INSERTIONS_THRESHOLD is returned immediately; if none qualifies, the
    candidate with the smallest "insertions" value is returned.

    :param end_msg: the edited (final) commit message.
    :param diff: the source code changes the message describes.
    :return: the selected generated initial commit message.
    :raises ValueError: if GENERATION_ATTEMPTS is below 1, so that no
        candidate was generated.
    """
    prompt = build_prompt(reference=end_msg, diff=diff)
    candidates = []
    for _ in range(GENERATION_ATTEMPTS):
        start_msg_pred = grazie_wrapper.generate_for_prompt(prompt)
        annotated = generate_annotated_diffs.get_annotated_diff(start_msg_pred, end_msg)
        # NOTE: the stdlib `statistics` module has no get_statistics, so this
        # presumably resolves to a project-local module of the same name.
        stats = statistics.get_statistics(start_msg=start_msg_pred, end_msg=end_msg,
                                          annotated_msg=annotated)
        insertions = stats["insertions"]
        if insertions < REL_INSERTIONS_THRESHOLD:
            return start_msg_pred
        candidates.append((insertions, start_msg_pred))
    if not candidates:
        raise ValueError("GENERATION_ATTEMPTS must be at least 1 to generate a message")
    # No candidate passed the threshold: take the one with the fewest insertions.
    # Tuple ordering breaks ties by message text, exactly like sorting did.
    return min(candidates)[1]
# Columns copied unchanged from each input row onto every generated row.
COLS_TO_KEEP = ["hash", "repo", "commit_msg_end", "mods", "session"]


def transform(df):
    """Augment df with synthetic rows whose start message is LLM-generated.

    For every input row, GENERATION_MULTIPLIER new rows are created. Each one
    gets a fresh commit_msg_start from
    generate_start_msg(row["commit_msg_end"], row["mods"]) and copies of the
    COLS_TO_KEEP columns. The original rows are marked end_to_start=False and
    the generated rows end_to_start=True.

    The caller's DataFrame is not modified.

    :param df: DataFrame containing at least the COLS_TO_KEEP columns.
    :return: a new DataFrame with the original rows followed by the generated
        rows, re-indexed from 0.
    """
    # assign() returns a copy; the previous in-place column assignment
    # silently mutated the caller's DataFrame.
    original_df = df.assign(end_to_start=False)

    generated_data = {col: [] for col in ["commit_msg_start", *COLS_TO_KEEP]}
    for _, row in tqdm(original_df.iterrows(), total=len(original_df)):
        for _ in range(GENERATION_MULTIPLIER):
            commit_msg_start_pred = generate_start_msg(end_msg=row["commit_msg_end"],
                                                       diff=row["mods"])
            generated_data["commit_msg_start"].append(commit_msg_start_pred)
            for col in COLS_TO_KEEP:
                generated_data[col].append(row[col])

    generated_df = pd.DataFrame.from_dict(generated_data)
    generated_df['end_to_start'] = True
    return pd.concat([original_df, generated_df], ignore_index=True)