Spaces:
Runtime error
Runtime error
import json | |
from tqdm import tqdm | |
import config | |
from api_wrappers import hf_data_loader | |
from generation_steps import synthetic_start_to_end | |
def transform(df):
    """Prepare the data-for-labeling table and write it to CSV.

    For every row of ``df`` the function decides whether the commit,
    identified by its ``(hash, repo)`` pair, is present in the manually
    rewritten dataset. Rows that are present take their start and end
    messages from that dataset; all other rows keep their ``prediction`` and
    get an end message from ``synthetic_start_to_end.generate_end_msg``.

    Args:
        df: DataFrame whose rows provide ``hash``, ``repo``, ``prediction``
            and ``mods``. NOTE: it is modified in place -- the columns
            ``manual_sample`` and ``enhanced`` are added, and ``prediction``
            and ``mods`` are overwritten.

    Returns:
        The same DataFrame object that was passed in, after modification.
        It is also saved to ``config.DATA_FOR_LABELING_ARTIFACT``.
    """
    print("Generating data for labeling:")
    synthetic_start_to_end.print_config()
    tqdm.pandas()  # registers ``progress_apply`` on pandas objects

    # Shuffle with a fixed seed *before* de-duplicating, so that when a
    # (hash, repo) pair occurs several times the surviving row is the one
    # the seeded shuffle happened to place first.
    rewrites = hf_data_loader.load_raw_rewriting_as_pandas()
    rewrites = rewrites.sample(frac=1, random_state=config.RANDOM_STATE)
    rewrites = rewrites.set_index(['hash', 'repo'])
    rewrites = rewrites[["commit_msg_start", "commit_msg_end"]]
    rewrites = rewrites[~rewrites.index.duplicated(keep='first')]

    def commit_key(row):
        # Index key used to look a commit up in ``rewrites``.
        return (row['hash'], row['repo'])

    def is_manual(row):
        return commit_key(row) in rewrites.index

    def pick_start_message(row):
        if not row['manual_sample']:
            return row['prediction']
        return rewrites.loc[commit_key(row)]['commit_msg_start']

    def pick_end_message(row):
        if not row['manual_sample']:
            return synthetic_start_to_end.generate_end_msg(
                start_msg=row["prediction"],
                diff=row["mods"],
            )
        return rewrites.loc[commit_key(row)]['commit_msg_end']

    result = df
    result['manual_sample'] = result.progress_apply(is_manual, axis=1)

    # Order matters: 'enhanced' is generated from the original 'prediction'
    # values, so it has to be filled in before 'prediction' is replaced.
    result['enhanced'] = result.progress_apply(pick_end_message, axis=1)
    result['prediction'] = result.progress_apply(pick_start_message, axis=1)

    # Store 'mods' as JSON text in the saved table.
    result['mods'] = result['mods'].progress_apply(json.dumps)

    result.to_csv(config.DATA_FOR_LABELING_ARTIFACT)
    print("Done")
    return result
def main():
    """Load the commits with predictions and run ``transform`` on them.

    ``GENERATION_ATTEMPTS`` on the generation module is set to 3 before any
    data is loaded (presumably the per-message retry budget -- confirm in
    ``synthetic_start_to_end``).
    """
    synthetic_start_to_end.GENERATION_ATTEMPTS = 3
    transform(hf_data_loader.load_full_commit_with_predictions_as_pandas())


if __name__ == '__main__':
    main()