Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	
		Petr Tsvetkov
		
	commited on
		
		
					Commit 
							
							Β·
						
						6676c5a
	
1
								Parent(s):
							
							2d03034
								
Generate a dataset for the labeling app
Browse files- api_wrappers/grazie_wrapper.py +1 -1
- api_wrappers/hf_data_loader.py +22 -6
- config.py +5 -1
- generate_annotated_diffs.py +2 -2
- generate_synthetic_dataset.py +1 -1
- generation_steps/examples.py +1 -1
- generation_steps/for_labeling.py +58 -0
- generation_steps/metrics_analysis.py +1 -1
- generation_steps/synthetic_end_to_start.py +1 -1
- generation_steps/synthetic_start_to_end.py +18 -14
    	
        api_wrappers/grazie_wrapper.py
    CHANGED
    
    | @@ -10,7 +10,7 @@ import config | |
| 10 | 
             
            client = GrazieApiGatewayClient(
         | 
| 11 | 
             
                grazie_agent=GrazieAgent(name="commit-rewriting-synthetic-end-to-start", version="dev"),
         | 
| 12 | 
             
                url=GrazieApiGatewayUrls.STAGING,
         | 
| 13 | 
            -
                auth_type=AuthType. | 
| 14 | 
             
                grazie_jwt_token=config.GRAZIE_API_JWT_TOKEN
         | 
| 15 | 
             
            )
         | 
| 16 |  | 
|  | |
| 10 | 
             
            client = GrazieApiGatewayClient(
         | 
| 11 | 
             
                grazie_agent=GrazieAgent(name="commit-rewriting-synthetic-end-to-start", version="dev"),
         | 
| 12 | 
             
                url=GrazieApiGatewayUrls.STAGING,
         | 
| 13 | 
            +
                auth_type=AuthType.USER,
         | 
| 14 | 
             
                grazie_jwt_token=config.GRAZIE_API_JWT_TOKEN
         | 
| 15 | 
             
            )
         | 
| 16 |  | 
    	
        api_wrappers/hf_data_loader.py
    CHANGED
    
    | @@ -3,14 +3,14 @@ from datasets import load_dataset | |
| 3 | 
             
            import config
         | 
| 4 |  | 
| 5 |  | 
| 6 | 
            -
            def  | 
| 7 | 
             
                return load_dataset(config.HF_RAW_DATASET_NAME,
         | 
| 8 | 
             
                                    split=config.HF_RAW_DATASET_SPLIT,
         | 
| 9 | 
             
                                    token=config.HF_TOKEN,
         | 
| 10 | 
             
                                    cache_dir=config.CACHE_DIR).to_pandas()
         | 
| 11 |  | 
| 12 |  | 
| 13 | 
            -
            def  | 
| 14 | 
             
                return load_dataset(path=config.HF_FULL_COMMITS_DATASET_NAME,
         | 
| 15 | 
             
                                    name=config.HF_FULL_COMMITS_DATASET_SUBNAME,
         | 
| 16 | 
             
                                    split=config.HF_FULL_COMMITS_DATASET_SPLIT,
         | 
| @@ -18,19 +18,35 @@ def load_full_commit_dataset_as_pandas(): | |
| 18 | 
             
                    columns={'message': 'reference'})
         | 
| 19 |  | 
| 20 |  | 
| 21 | 
            -
            def  | 
| 22 | 
            -
                manual_rewriting =  | 
| 23 | 
             
                    ["hash", "repo", "commit_msg_start", "commit_msg_end", "session"]]
         | 
| 24 | 
             
                manual_rewriting.set_index(["hash", "repo"], inplace=True)
         | 
| 25 |  | 
| 26 | 
            -
                mods_dataset =  | 
| 27 | 
             
                mods_dataset.set_index(["hash", "repo"], inplace=True)
         | 
| 28 |  | 
| 29 | 
             
                return manual_rewriting.join(other=mods_dataset, how='left').reset_index()
         | 
| 30 |  | 
| 31 |  | 
| 32 | 
            -
            def  | 
| 33 | 
             
                return load_dataset(config.HF_SYNTHETIC_DATASET_NAME,
         | 
| 34 | 
             
                                    split=config.HF_SYNTHETIC_DATASET_SPLIT,
         | 
| 35 | 
             
                                    token=config.HF_TOKEN,
         | 
| 36 | 
             
                                    cache_dir=config.CACHE_DIR).to_pandas()
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 3 | 
             
            import config
         | 
| 4 |  | 
| 5 |  | 
| 6 | 
            +
            def load_raw_rewriting_as_pandas():
         | 
| 7 | 
             
                return load_dataset(config.HF_RAW_DATASET_NAME,
         | 
| 8 | 
             
                                    split=config.HF_RAW_DATASET_SPLIT,
         | 
| 9 | 
             
                                    token=config.HF_TOKEN,
         | 
| 10 | 
             
                                    cache_dir=config.CACHE_DIR).to_pandas()
         | 
| 11 |  | 
| 12 |  | 
| 13 | 
            +
            def load_full_commit_as_pandas():
         | 
| 14 | 
             
                return load_dataset(path=config.HF_FULL_COMMITS_DATASET_NAME,
         | 
| 15 | 
             
                                    name=config.HF_FULL_COMMITS_DATASET_SUBNAME,
         | 
| 16 | 
             
                                    split=config.HF_FULL_COMMITS_DATASET_SPLIT,
         | 
|  | |
| 18 | 
             
                    columns={'message': 'reference'})
         | 
| 19 |  | 
| 20 |  | 
| 21 | 
            +
            def load_processed_rewriting_as_pandas():
         | 
| 22 | 
            +
                manual_rewriting = load_raw_rewriting_as_pandas()[
         | 
| 23 | 
             
                    ["hash", "repo", "commit_msg_start", "commit_msg_end", "session"]]
         | 
| 24 | 
             
                manual_rewriting.set_index(["hash", "repo"], inplace=True)
         | 
| 25 |  | 
| 26 | 
            +
                mods_dataset = load_full_commit_as_pandas()[["hash", "repo", "mods"]]
         | 
| 27 | 
             
                mods_dataset.set_index(["hash", "repo"], inplace=True)
         | 
| 28 |  | 
| 29 | 
             
                return manual_rewriting.join(other=mods_dataset, how='left').reset_index()
         | 
| 30 |  | 
| 31 |  | 
| 32 | 
            +
            def load_synthetic_as_pandas():
         | 
| 33 | 
             
                return load_dataset(config.HF_SYNTHETIC_DATASET_NAME,
         | 
| 34 | 
             
                                    split=config.HF_SYNTHETIC_DATASET_SPLIT,
         | 
| 35 | 
             
                                    token=config.HF_TOKEN,
         | 
| 36 | 
             
                                    cache_dir=config.CACHE_DIR).to_pandas()
         | 
| 37 | 
            +
             | 
| 38 | 
            +
             | 
| 39 | 
            +
            def load_full_commit_with_predictions_as_pandas():
         | 
| 40 | 
            +
                full_dataset = load_full_commit_as_pandas()
         | 
| 41 | 
            +
                predictions_dataset = load_dataset(config.HF_PREDICTIONS_DATASET_NAME,
         | 
| 42 | 
            +
                                                   config.HF_PREDICTIONS_DATASET_SUBNAME,
         | 
| 43 | 
            +
                                                   split=config.HF_PREDICTIONS_DATASET_SPLIT,
         | 
| 44 | 
            +
                                                   cache_dir=config.CACHE_DIR
         | 
| 45 | 
            +
                                                   ).to_pandas().sample(frac=1, random_state=config.RANDOM_STATE
         | 
| 46 | 
            +
                                                                        ).set_index(['hash', 'repo'])[["prediction"]]
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                predictions_dataset = predictions_dataset[~predictions_dataset.index.duplicated(keep='first')]
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                dataset = full_dataset.join(other=predictions_dataset, on=('hash', 'repo'))
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                return dataset.reset_index()
         | 
    	
        config.py
    CHANGED
    
    | @@ -15,6 +15,10 @@ HF_FULL_COMMITS_DATASET_NAME = "JetBrains-Research/lca-commit-message-generation | |
| 15 | 
             
            HF_FULL_COMMITS_DATASET_SUBNAME = "commitchronicle-py-long"
         | 
| 16 | 
             
            HF_FULL_COMMITS_DATASET_SPLIT = "test"
         | 
| 17 |  | 
|  | |
|  | |
|  | |
|  | |
| 18 | 
             
            HF_SYNTHETIC_DATASET_NAME = "petrtsv-jb/synthetic-commit-msg-rewriting"
         | 
| 19 | 
             
            HF_SYNTHETIC_DATASET_SPLIT = 'train'
         | 
| 20 |  | 
| @@ -24,8 +28,8 @@ CACHE_DIR.mkdir(exist_ok=True) | |
| 24 | 
             
            OUTPUT_DIR = Path("output")
         | 
| 25 | 
             
            OUTPUT_DIR.mkdir(exist_ok=True)
         | 
| 26 |  | 
| 27 | 
            -
             | 
| 28 | 
             
            END_TO_START_ARTIFACT = OUTPUT_DIR / "end_to_start.csv"
         | 
| 29 | 
             
            START_TO_END_ARTIFACT = OUTPUT_DIR / "start_to_end.csv"
         | 
| 30 | 
             
            SYNTHETIC_DATASET_ARTIFACT = OUTPUT_DIR / "synthetic.csv"
         | 
| 31 | 
             
            METRICS_CORRELATIONS_ARTIFACT = OUTPUT_DIR / "metrics_correlations.csv"
         | 
|  | 
|  | |
| 15 | 
             
            HF_FULL_COMMITS_DATASET_SUBNAME = "commitchronicle-py-long"
         | 
| 16 | 
             
            HF_FULL_COMMITS_DATASET_SPLIT = "test"
         | 
| 17 |  | 
| 18 | 
            +
            HF_PREDICTIONS_DATASET_NAME = "JetBrains-Research/lca-results"
         | 
| 19 | 
            +
            HF_PREDICTIONS_DATASET_SUBNAME = "cmg_gpt_4_0613"
         | 
| 20 | 
            +
            HF_PREDICTIONS_DATASET_SPLIT = "test"
         | 
| 21 | 
            +
             | 
| 22 | 
             
            HF_SYNTHETIC_DATASET_NAME = "petrtsv-jb/synthetic-commit-msg-rewriting"
         | 
| 23 | 
             
            HF_SYNTHETIC_DATASET_SPLIT = 'train'
         | 
| 24 |  | 
|  | |
| 28 | 
             
            OUTPUT_DIR = Path("output")
         | 
| 29 | 
             
            OUTPUT_DIR.mkdir(exist_ok=True)
         | 
| 30 |  | 
|  | |
| 31 | 
             
            END_TO_START_ARTIFACT = OUTPUT_DIR / "end_to_start.csv"
         | 
| 32 | 
             
            START_TO_END_ARTIFACT = OUTPUT_DIR / "start_to_end.csv"
         | 
| 33 | 
             
            SYNTHETIC_DATASET_ARTIFACT = OUTPUT_DIR / "synthetic.csv"
         | 
| 34 | 
             
            METRICS_CORRELATIONS_ARTIFACT = OUTPUT_DIR / "metrics_correlations.csv"
         | 
| 35 | 
            +
            DATA_FOR_LABELING_ARTIFACT = OUTPUT_DIR / "data_for_labeling.csv"
         | 
    	
        generate_annotated_diffs.py
    CHANGED
    
    | @@ -26,14 +26,14 @@ def annotated_diff_for_row(row): | |
| 26 |  | 
| 27 |  | 
| 28 | 
             
            def manual_data_with_annotated_diffs():
         | 
| 29 | 
            -
                df = hf_data_loader. | 
| 30 | 
             
                annotated = df.apply(annotated_diff_for_row, axis=1)
         | 
| 31 | 
             
                df['annotated_diff'] = annotated
         | 
| 32 | 
             
                return df
         | 
| 33 |  | 
| 34 |  | 
| 35 | 
             
            def synthetic_data_with_annotated_diffs():
         | 
| 36 | 
            -
                df = hf_data_loader. | 
| 37 | 
             
                annotated = df.apply(annotated_diff_for_row, axis=1)
         | 
| 38 | 
             
                df['annotated_diff'] = annotated
         | 
| 39 | 
             
                return df
         | 
|  | |
| 26 |  | 
| 27 |  | 
| 28 | 
             
            def manual_data_with_annotated_diffs():
         | 
| 29 | 
            +
                df = hf_data_loader.load_raw_rewriting_as_pandas()
         | 
| 30 | 
             
                annotated = df.apply(annotated_diff_for_row, axis=1)
         | 
| 31 | 
             
                df['annotated_diff'] = annotated
         | 
| 32 | 
             
                return df
         | 
| 33 |  | 
| 34 |  | 
| 35 | 
             
            def synthetic_data_with_annotated_diffs():
         | 
| 36 | 
            +
                df = hf_data_loader.load_synthetic_as_pandas()
         | 
| 37 | 
             
                annotated = df.apply(annotated_diff_for_row, axis=1)
         | 
| 38 | 
             
                df['annotated_diff'] = annotated
         | 
| 39 | 
             
                return df
         | 
    	
        generate_synthetic_dataset.py
    CHANGED
    
    | @@ -4,7 +4,7 @@ from generation_steps import synthetic_end_to_start, synthetic_start_to_end, met | |
| 4 |  | 
| 5 |  | 
| 6 | 
             
            def run():
         | 
| 7 | 
            -
                df = hf_data_loader. | 
| 8 |  | 
| 9 | 
             
                df = synthetic_end_to_start.transform(df)
         | 
| 10 | 
             
                df = synthetic_start_to_end.transform(df)
         | 
|  | |
| 4 |  | 
| 5 |  | 
| 6 | 
             
            def run():
         | 
| 7 | 
            +
                df = hf_data_loader.load_processed_rewriting_as_pandas()
         | 
| 8 |  | 
| 9 | 
             
                df = synthetic_end_to_start.transform(df)
         | 
| 10 | 
             
                df = synthetic_start_to_end.transform(df)
         | 
    	
        generation_steps/examples.py
    CHANGED
    
    | @@ -36,7 +36,7 @@ END OF THE IMPROVED COMMIT MESSAGE | |
| 36 | 
             
            END OF THE EXAMPLE"""
         | 
| 37 |  | 
| 38 |  | 
| 39 | 
            -
            manual_df = hf_data_loader. | 
| 40 | 
             
            manual_df = manual_df.sample(n=N_EXAMPLES, random_state=config.RANDOM_STATE)
         | 
| 41 |  | 
| 42 |  | 
|  | |
| 36 | 
             
            END OF THE EXAMPLE"""
         | 
| 37 |  | 
| 38 |  | 
| 39 | 
            +
            manual_df = hf_data_loader.load_raw_rewriting_as_pandas()[['commit_msg_start', 'commit_msg_end']]
         | 
| 40 | 
             
            manual_df = manual_df.sample(n=N_EXAMPLES, random_state=config.RANDOM_STATE)
         | 
| 41 |  | 
| 42 |  | 
    	
        generation_steps/for_labeling.py
    ADDED
    
    | @@ -0,0 +1,58 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import json
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from tqdm import tqdm
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import config
         | 
| 6 | 
            +
            from api_wrappers import hf_data_loader
         | 
| 7 | 
            +
            from generation_steps import synthetic_start_to_end
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            def transform(df):
         | 
| 11 | 
            +
                print(f"Generating data for labeling:")
         | 
| 12 | 
            +
                synthetic_start_to_end.print_config()
         | 
| 13 | 
            +
                tqdm.pandas()
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                manual_df = hf_data_loader.load_raw_rewriting_as_pandas()
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                manual_df = manual_df.sample(frac=1, random_state=config.RANDOM_STATE
         | 
| 18 | 
            +
                                             ).set_index(['hash', 'repo'])[["commit_msg_start", "commit_msg_end"]]
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                manual_df = manual_df[~manual_df.index.duplicated(keep='first')]
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                def get_is_manually_rewritten(row):
         | 
| 23 | 
            +
                    commit_id = (row['hash'], row['repo'])
         | 
| 24 | 
            +
                    return commit_id in manual_df.index
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                result = df
         | 
| 27 | 
            +
                result['manual_sample'] = result.progress_apply(get_is_manually_rewritten, axis=1)
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                def get_prediction_message(row):
         | 
| 30 | 
            +
                    commit_id = (row['hash'], row['repo'])
         | 
| 31 | 
            +
                    if row['manual_sample']:
         | 
| 32 | 
            +
                        return manual_df.loc[commit_id]['commit_msg_start']
         | 
| 33 | 
            +
                    return row['prediction']
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                def get_enhanced_message(row):
         | 
| 36 | 
            +
                    commit_id = (row['hash'], row['repo'])
         | 
| 37 | 
            +
                    if row['manual_sample']:
         | 
| 38 | 
            +
                        return manual_df.loc[commit_id]['commit_msg_end']
         | 
| 39 | 
            +
                    return synthetic_start_to_end.generate_end_msg(start_msg=row["prediction"],
         | 
| 40 | 
            +
                                                                   diff=row["mods"])
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                result['enhanced'] = result.progress_apply(get_enhanced_message, axis=1)
         | 
| 43 | 
            +
                result['prediction'] = result.progress_apply(get_prediction_message, axis=1)
         | 
| 44 | 
            +
                result['mods'] = result['mods'].progress_apply(json.dumps)
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                result.to_csv(config.DATA_FOR_LABELING_ARTIFACT)
         | 
| 47 | 
            +
                print("Done")
         | 
| 48 | 
            +
                return result
         | 
| 49 | 
            +
             | 
| 50 | 
            +
             | 
| 51 | 
            +
            def main():
         | 
| 52 | 
            +
                synthetic_start_to_end.GENERATION_ATTEMPTS = 3
         | 
| 53 | 
            +
                df = hf_data_loader.load_full_commit_with_predictions_as_pandas()
         | 
| 54 | 
            +
                transform(df)
         | 
| 55 | 
            +
             | 
| 56 | 
            +
             | 
| 57 | 
            +
            if __name__ == '__main__':
         | 
| 58 | 
            +
                main()
         | 
    	
        generation_steps/metrics_analysis.py
    CHANGED
    
    | @@ -77,7 +77,7 @@ METRICS = { | |
| 77 |  | 
| 78 |  | 
| 79 | 
             
            def attach_references(df):
         | 
| 80 | 
            -
                reference_df = hf_data_loader. | 
| 81 | 
             
                df = df.set_index(["hash", "repo"])
         | 
| 82 | 
             
                return df.join(other=reference_df, how="left").reset_index()
         | 
| 83 |  | 
|  | |
| 77 |  | 
| 78 |  | 
| 79 | 
             
            def attach_references(df):
         | 
| 80 | 
            +
                reference_df = hf_data_loader.load_full_commit_as_pandas().set_index(["hash", "repo"])[["reference"]]
         | 
| 81 | 
             
                df = df.set_index(["hash", "repo"])
         | 
| 82 | 
             
                return df.join(other=reference_df, how="left").reset_index()
         | 
| 83 |  | 
    	
        generation_steps/synthetic_end_to_start.py
    CHANGED
    
    | @@ -98,7 +98,7 @@ def transform(df): | |
| 98 |  | 
| 99 |  | 
| 100 | 
             
            def main():
         | 
| 101 | 
            -
                df = hf_data_loader. | 
| 102 | 
             
                transform(df)
         | 
| 103 |  | 
| 104 |  | 
|  | |
| 98 |  | 
| 99 |  | 
| 100 | 
             
            def main():
         | 
| 101 | 
            +
                df = hf_data_loader.load_processed_rewriting_as_pandas()
         | 
| 102 | 
             
                transform(df)
         | 
| 103 |  | 
| 104 |  | 
    	
        generation_steps/synthetic_start_to_end.py
    CHANGED
    
    | @@ -12,7 +12,7 @@ REL_DELETIONS_THRESHOLD = 0.75 | |
| 12 | 
             
            GENERATION_ATTEMPTS = 5
         | 
| 13 |  | 
| 14 |  | 
| 15 | 
            -
            def build_prompt( | 
| 16 | 
             
                return f"""A LLM generated a commit message for the following source code changes:
         | 
| 17 | 
             
            START OF THE SOURCE CODE CHANGES
         | 
| 18 | 
             
            {diff}
         | 
| @@ -20,7 +20,7 @@ END OF THE SOURCE CODE CHANGES | |
| 20 |  | 
| 21 | 
             
            Here is the message the LLM generated:
         | 
| 22 | 
             
            START OF THE COMMIT MESSAGE 
         | 
| 23 | 
            -
            { | 
| 24 | 
             
            END OF THE COMMIT MESSAGE
         | 
| 25 |  | 
| 26 | 
             
            This generated message is not perfect. Your task is to rewrite and improve it.
         | 
| @@ -40,20 +40,20 @@ token "OUTPUT". | |
| 40 | 
             
            OUTPUT"""
         | 
| 41 |  | 
| 42 |  | 
| 43 | 
            -
            def  | 
| 44 | 
            -
                prompt = build_prompt( | 
| 45 | 
             
                results = []
         | 
| 46 |  | 
| 47 | 
             
                for i in range(GENERATION_ATTEMPTS):
         | 
| 48 | 
            -
                     | 
| 49 |  | 
| 50 | 
            -
                    stats = statistics.get_statistics(start_msg= | 
| 51 | 
            -
                                                      annotated_msg=generate_annotated_diffs.get_annotated_diff( | 
| 52 | 
            -
                                                                                                                 | 
| 53 | 
             
                    if stats["deletions"] < REL_DELETIONS_THRESHOLD:
         | 
| 54 | 
            -
                        return  | 
| 55 | 
             
                    else:
         | 
| 56 | 
            -
                        results.append((stats["deletions"],  | 
| 57 |  | 
| 58 | 
             
                results.sort()
         | 
| 59 | 
             
                return results[0][1]
         | 
| @@ -62,13 +62,17 @@ def generate_start_msg(end_msg, diff): | |
| 62 | 
             
            COLS_TO_KEEP = ["hash", "repo", "commit_msg_start", "mods", "session", "end_to_start"]
         | 
| 63 |  | 
| 64 |  | 
| 65 | 
            -
            def  | 
| 66 | 
            -
                print(f"Start -> send synthesis:")
         | 
| 67 | 
             
                print(f"NUMBER OF EXAMPLES PER PROMPT = {examples.N_EXAMPLES}")
         | 
| 68 | 
             
                print(f"GENERATION_MULTIPLIER = {GENERATION_MULTIPLIER}")
         | 
| 69 | 
             
                print(f"REL_DELETIONS_THRESHOLD = {REL_DELETIONS_THRESHOLD}")
         | 
| 70 | 
             
                print(f"GENERATION_ATTEMPTS = {GENERATION_ATTEMPTS}")
         | 
| 71 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 72 | 
             
                df['start_to_end'] = False
         | 
| 73 |  | 
| 74 | 
             
                generated_data = {
         | 
| @@ -80,8 +84,8 @@ def transform(df): | |
| 80 |  | 
| 81 | 
             
                for _, row in tqdm(df.iterrows(), total=len(df)):
         | 
| 82 | 
             
                    for i in range(GENERATION_MULTIPLIER):
         | 
| 83 | 
            -
                        commit_msg_end_pred =  | 
| 84 | 
            -
             | 
| 85 |  | 
| 86 | 
             
                        generated_data["commit_msg_end"].append(commit_msg_end_pred)
         | 
| 87 | 
             
                        for col in COLS_TO_KEEP:
         | 
|  | |
| 12 | 
             
            GENERATION_ATTEMPTS = 5
         | 
| 13 |  | 
| 14 |  | 
| 15 | 
            +
            def build_prompt(prediction, diff):
         | 
| 16 | 
             
                return f"""A LLM generated a commit message for the following source code changes:
         | 
| 17 | 
             
            START OF THE SOURCE CODE CHANGES
         | 
| 18 | 
             
            {diff}
         | 
|  | |
| 20 |  | 
| 21 | 
             
            Here is the message the LLM generated:
         | 
| 22 | 
             
            START OF THE COMMIT MESSAGE 
         | 
| 23 | 
            +
            {prediction}
         | 
| 24 | 
             
            END OF THE COMMIT MESSAGE
         | 
| 25 |  | 
| 26 | 
             
            This generated message is not perfect. Your task is to rewrite and improve it.
         | 
|  | |
| 40 | 
             
            OUTPUT"""
         | 
| 41 |  | 
| 42 |  | 
| 43 | 
            +
            def generate_end_msg(start_msg, diff):
         | 
| 44 | 
            +
                prompt = build_prompt(prediction=start_msg, diff=diff)
         | 
| 45 | 
             
                results = []
         | 
| 46 |  | 
| 47 | 
             
                for i in range(GENERATION_ATTEMPTS):
         | 
| 48 | 
            +
                    end_msg_pred = grazie_wrapper.generate_for_prompt(prompt)
         | 
| 49 |  | 
| 50 | 
            +
                    stats = statistics.get_statistics(start_msg=start_msg, end_msg=end_msg_pred,
         | 
| 51 | 
            +
                                                      annotated_msg=generate_annotated_diffs.get_annotated_diff(start_msg,
         | 
| 52 | 
            +
                                                                                                                end_msg_pred))
         | 
| 53 | 
             
                    if stats["deletions"] < REL_DELETIONS_THRESHOLD:
         | 
| 54 | 
            +
                        return end_msg_pred
         | 
| 55 | 
             
                    else:
         | 
| 56 | 
            +
                        results.append((stats["deletions"], end_msg_pred))
         | 
| 57 |  | 
| 58 | 
             
                results.sort()
         | 
| 59 | 
             
                return results[0][1]
         | 
|  | |
| 62 | 
             
            COLS_TO_KEEP = ["hash", "repo", "commit_msg_start", "mods", "session", "end_to_start"]
         | 
| 63 |  | 
| 64 |  | 
| 65 | 
            +
            def print_config():
         | 
|  | |
| 66 | 
             
                print(f"NUMBER OF EXAMPLES PER PROMPT = {examples.N_EXAMPLES}")
         | 
| 67 | 
             
                print(f"GENERATION_MULTIPLIER = {GENERATION_MULTIPLIER}")
         | 
| 68 | 
             
                print(f"REL_DELETIONS_THRESHOLD = {REL_DELETIONS_THRESHOLD}")
         | 
| 69 | 
             
                print(f"GENERATION_ATTEMPTS = {GENERATION_ATTEMPTS}")
         | 
| 70 |  | 
| 71 | 
            +
             | 
| 72 | 
            +
            def transform(df):
         | 
| 73 | 
            +
                print(f"Start -> send synthesis:")
         | 
| 74 | 
            +
                print_config()
         | 
| 75 | 
            +
             | 
| 76 | 
             
                df['start_to_end'] = False
         | 
| 77 |  | 
| 78 | 
             
                generated_data = {
         | 
|  | |
| 84 |  | 
| 85 | 
             
                for _, row in tqdm(df.iterrows(), total=len(df)):
         | 
| 86 | 
             
                    for i in range(GENERATION_MULTIPLIER):
         | 
| 87 | 
            +
                        commit_msg_end_pred = generate_end_msg(start_msg=row["commit_msg_start"],
         | 
| 88 | 
            +
                                                               diff=row["mods"])
         | 
| 89 |  | 
| 90 | 
             
                        generated_data["commit_msg_end"].append(commit_msg_end_pred)
         | 
| 91 | 
             
                        for col in COLS_TO_KEEP:
         | 
