Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| import json | |
| from tqdm import tqdm | |
| import config | |
| from api_wrappers import hf_data_loader | |
| from generation_steps import synthetic_start_to_end | |
def transform(df):
    """Build the data-for-labeling table and write it to CSV.

    Rows of ``df`` whose ``(hash, repo)`` pair appears in the manually
    rewritten dataset take their start/end commit messages from that
    dataset. All other rows keep their own ``prediction`` and get an
    ``enhanced`` message from
    ``synthetic_start_to_end.generate_end_msg``.

    Args:
        df: DataFrame with at least the columns ``hash``, ``repo``,
            ``prediction`` and ``mods``. It is not modified; the
            function works on a copy.

    Returns:
        A new DataFrame with the columns of ``df`` plus:
        ``manual_sample`` (bool, row found in the manual dataset),
        ``enhanced`` (manual end message or generated message),
        ``prediction`` (replaced by the manual start message for manual
        rows) and ``mods`` (serialized with ``json.dumps``).
        The same frame is also written to
        ``config.DATA_FOR_LABELING_ARTIFACT`` via ``to_csv``.
    """
    print("Generating data for labeling:")
    synthetic_start_to_end.print_config()
    tqdm.pandas()  # enables the ``progress_apply`` calls used below

    manual_df = hf_data_loader.load_raw_rewriting_as_pandas()
    # Shuffle with a fixed seed, then keep one row per (hash, repo): when a
    # commit was rewritten more than once, the survivor is picked
    # reproducibly by the seeded shuffle rather than by file order.
    manual_df = manual_df.sample(
        frac=1, random_state=config.RANDOM_STATE
    ).set_index(['hash', 'repo'])[["commit_msg_start", "commit_msg_end"]]
    manual_df = manual_df[~manual_df.index.duplicated(keep='first')]

    def get_is_manually_rewritten(row):
        """Return True when the row's commit is in the manual dataset."""
        return (row['hash'], row['repo']) in manual_df.index

    def get_prediction_message(row):
        """Return the manual start message, or the row's own prediction."""
        if row['manual_sample']:
            # Index rows and columns in one explicit .loc call; the index is
            # unique after de-duplication, so this yields a scalar.
            return manual_df.loc[(row['hash'], row['repo']), 'commit_msg_start']
        return row['prediction']

    def get_enhanced_message(row):
        """Return the manual end message, or a generated one."""
        if row['manual_sample']:
            return manual_df.loc[(row['hash'], row['repo']), 'commit_msg_end']
        return synthetic_start_to_end.generate_end_msg(
            start_msg=row["prediction"], diff=row["mods"]
        )

    # Work on a copy so the caller's DataFrame is not silently rewritten
    # (columns added, ``prediction`` and ``mods`` overwritten).
    result = df.copy()
    result['manual_sample'] = result.progress_apply(get_is_manually_rewritten, axis=1)

    # ``enhanced`` must be computed before ``prediction`` is overwritten:
    # generation for non-manual rows reads the original ``prediction``.
    result['enhanced'] = result.progress_apply(get_enhanced_message, axis=1)
    result['prediction'] = result.progress_apply(get_prediction_message, axis=1)
    # Serialize ``mods`` last, after generation has consumed the raw value.
    result['mods'] = result['mods'].progress_apply(json.dumps)

    result.to_csv(config.DATA_FOR_LABELING_ARTIFACT)
    print("Done")
    return result
def main():
    """Load the commits with predictions and run ``transform`` on them.

    Sets ``synthetic_start_to_end.GENERATION_ATTEMPTS`` to 3 before any
    generation happens (presumably the retry budget used by that module;
    confirm there).
    """
    synthetic_start_to_end.GENERATION_ATTEMPTS = 3
    commits_with_predictions = (
        hf_data_loader.load_full_commit_with_predictions_as_pandas()
    )
    transform(commits_with_predictions)
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
