commit-message-editing-visualization / generate_synthetic_dataset.py
Petr Tsvetkov
Synthetic dataset generation for the first 5 samples; visualization fixed
30e165f
raw
history blame
2.32 kB
import time

from grazie.api.client.chat.prompt import ChatPrompt
from grazie.api.client.endpoints import GrazieApiGatewayUrls
from grazie.api.client.gateway import GrazieApiGatewayClient, GrazieAgent, AuthType
from grazie.api.client.profiles import LLMProfile
from tqdm import tqdm

import config
import hf_data_loader
# Module-level Grazie API gateway client (staging endpoint, service auth),
# authenticated with the JWT token from config; used by generate_initial_msg.
client = GrazieApiGatewayClient(
    grazie_agent=GrazieAgent(
        name="commit-rewriting-summary-generation",
        version="dev",
    ),
    url=GrazieApiGatewayUrls.STAGING,
    auth_type=AuthType.SERVICE,
    grazie_jwt_token=config.GRAZIE_API_JWT_TOKEN,
)
def build_prompt(reference, diff):
    """Build the prompt asking the LLM to reconstruct a pre-edit commit message.

    The prompt describes a developer who generated a commit message with an
    LLM for the changes in ``diff`` and then edited it into ``reference``. It
    asks the model to print the original (LLM-generated) message after the
    ``OUTPUT`` token; the prompt itself ends with that token.

    Args:
        reference: The edited (final) commit message.
        diff: The source code changes the commit message describes.

    Returns:
        The prompt text with ``diff`` and ``reference`` interpolated verbatim.
    """
    # The continuation lines below are intentionally at column 0: they are part
    # of the string literal, and any indentation would end up in the prompt.
    return f"""A software developer uses an LLM to generate commit messages.
They generated a commit message for the following source code changes:
START OF THE SOURCE CODE CHANGES
{diff}
END OF THE SOURCE CODE CHANGES
After generating the commit message the developer understands that it is not perfect. After making some changes,
they come up with an edited version of the message. Here is this edited message:
START OF THE COMMIT MESSAGE
{reference}
END OF THE COMMIT MESSAGE
Your task is to print the initial, LLM-generated commit message. Print only the initial commit message's text after the
token "OUTPUT".
OUTPUT"""
def generate_prompt_for_row(row):
    """Return the initial-message prompt for one dataset row.

    Reads ``row['reference']`` (the edited commit message) and ``row['mods']``
    (the code changes) and passes them to build_prompt. Intended for use with
    ``DataFrame.apply(..., axis=1)``, as in generate_synthetic_dataset.
    """
    return build_prompt(reference=row['reference'], diff=row['mods'])
def generate_initial_msg(prompt):
    """Send ``prompt`` to the LLM and return the reply text.

    The conversation consists of a fixed system message followed by ``prompt``
    as the user turn. It is sent through the module-level ``client`` with the
    ``gpt-4-1106-preview`` profile, and the response's ``content`` is returned
    unchanged.
    """
    conversation = ChatPrompt()
    conversation = conversation.add_system("You are a helpful assistant.")
    conversation = conversation.add_user(prompt)
    model_profile = LLMProfile("gpt-4-1106-preview")
    response = client.chat(chat=conversation, profile=model_profile)
    return response.content
def _generate_with_retries(prompt, max_attempts, retry_delay_s, max_delay_s=60.0):
    """Call generate_initial_msg for ``prompt`` until it returns a message.

    Only ``Exception`` subclasses are retried, so KeyboardInterrupt and
    SystemExit still stop the script (a bare ``except:`` would swallow them).
    Each failed attempt is reported with ``tqdm.write`` so the progress bar is
    not garbled. Attempts are spaced by an exponential backoff that starts at
    ``retry_delay_s`` seconds and is capped at ``max_delay_s``.

    Returns:
        The first non-None value returned by generate_initial_msg.

    Raises:
        RuntimeError: if all ``max_attempts`` attempts raise or return None;
            the last exception, if any, is chained as the cause.
    """
    last_error = None
    for attempt in range(1, max_attempts + 1):
        try:
            output = generate_initial_msg(prompt)
        except Exception as err:  # e.g. transient API/network errors; retried
            last_error = err
            tqdm.write(f"Attempt {attempt}/{max_attempts} failed: {err!r}")
        else:
            if output is not None:
                return output
            tqdm.write(f"Attempt {attempt}/{max_attempts} returned no content")
        if attempt < max_attempts:
            time.sleep(min(retry_delay_s * 2 ** (attempt - 1), max_delay_s))
    raise RuntimeError(
        f"No initial message generated after {max_attempts} attempts"
    ) from last_error


def generate_synthetic_dataset(*, n_samples=5, max_attempts=10, retry_delay_s=1.0):
    """Generate LLM "initial" commit messages and save the dataset as CSV.

    Steps:
      1. Loads the full commit dataset.
      2. Stores a prompt for every row in the ``initial_msg_prompt`` column.
      3. Queries the LLM for the first ``n_samples`` rows only.
      4. Fills the remaining rows with the placeholder ``"TBA"``; results go
         in the ``initial_msg_pred`` column.
      5. Writes the frame to ``config.SYNTHETIC_DATASET_ARTIFACT``.

    Args:
        n_samples: Number of leading rows sent to the LLM.
        max_attempts: Attempts per prompt before giving up.
        retry_delay_s: Initial delay in seconds between attempts; doubled
            after each failed attempt, up to 60 seconds.

    Raises:
        RuntimeError: if a prompt still has no message after ``max_attempts``.
    """
    df = hf_data_loader.load_full_commit_dataset_as_pandas()
    df['initial_msg_prompt'] = df.apply(generate_prompt_for_row, axis=1)

    initial_messages_pred = []
    # Iterate over every row (not just the first n_samples) so the progress
    # bar and the output column cover the whole dataset.
    for i, prompt in enumerate(tqdm(df['initial_msg_prompt'])):
        if i < n_samples:
            output = _generate_with_retries(prompt, max_attempts, retry_delay_s)
        else:
            output = "TBA"  # placeholder for rows not generated yet
        initial_messages_pred.append(output)

    df['initial_msg_pred'] = initial_messages_pred
    df.to_csv(config.SYNTHETIC_DATASET_ARTIFACT)
if __name__ == "__main__":
    # Script entry point: build and save the synthetic dataset.
    generate_synthetic_dataset()