Petr Tsvetkov
committed on
Commit
β’
303303b
1
Parent(s):
ff76f88
Add distribution charts; add more detailed statistics; compute multi-reference TER as mean of TERs for each reference to improve the performance
Browse files- change_visualizer.py +76 -9
- generation_steps/metrics_analysis.py +6 -4
- requirements.txt +4 -0
- statistics.py +31 -14
change_visualizer.py
CHANGED
@@ -12,8 +12,35 @@ n_diffs_manual = len(df_manual)
|
|
12 |
df_synthetic = generate_annotated_diffs.synthetic_data_with_annotated_diffs()
|
13 |
n_diffs_synthetic = len(df_synthetic)
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
STATISTICS = {"manual": statistics.get_statistics_for_df(df_manual),
|
16 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
|
19 |
def update_dataset_view(diff_idx, df):
|
@@ -92,20 +119,60 @@ if __name__ == '__main__':
|
|
92 |
def layout_for_statistics(statistics_group_name):
|
93 |
gr.Markdown(f"### {statistics_group_name}")
|
94 |
stats = STATISTICS[statistics_group_name]
|
95 |
-
gr.Number(label="
|
96 |
-
value=stats['
|
97 |
-
gr.Number(label="
|
98 |
-
value=stats['
|
99 |
-
gr.Number(label="
|
100 |
-
value=stats['
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
|
102 |
|
103 |
with gr.Row():
|
104 |
-
with gr.Column(scale=1):
|
105 |
layout_for_statistics("manual")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
|
|
107 |
with gr.Column(scale=1):
|
108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
|
110 |
gr.Markdown(f"### Reference-only correlations")
|
111 |
gr.Markdown(value=analysis_util.get_correlations_for_groups(df_synthetic, right_side="ind").to_markdown())
|
|
|
12 |
df_synthetic = generate_annotated_diffs.synthetic_data_with_annotated_diffs()
|
13 |
n_diffs_synthetic = len(df_synthetic)
|
14 |
|
15 |
+
|
16 |
+
def golden():
    """Return the manually-annotated ("golden") DataFrame unchanged."""
    return df_manual
|
18 |
+
|
19 |
+
|
20 |
+
def e2s():
    """Rows of the synthetic set edited end-to-start only (not start-to-end)."""
    mask = (df_synthetic['end_to_start'] == True) & (df_synthetic['start_to_end'] == False)
    return df_synthetic[mask]
|
22 |
+
|
23 |
+
|
24 |
+
def s2e():
    """Rows of the synthetic set edited start-to-end only (not end-to-start)."""
    mask = (df_synthetic['end_to_start'] == False) & (df_synthetic['start_to_end'] == True)
    return df_synthetic[mask]
|
26 |
+
|
27 |
+
|
28 |
+
def e2s_s2e():
    """Rows of the synthetic set edited in both directions."""
    mask = (df_synthetic['end_to_start'] == True) & (df_synthetic['start_to_end'] == True)
    return df_synthetic[mask]
|
30 |
+
|
31 |
+
|
32 |
+
def synthetic():
    """Rows of the synthetic set edited in at least one direction."""
    mask = (df_synthetic['end_to_start'] == True) | (df_synthetic['start_to_end'] == True)
    return df_synthetic[mask]
|
34 |
+
|
35 |
+
|
36 |
# Per-group statistics, keyed by group name.  Insertion order is preserved and
# drives the column order of the UI below.
STATISTICS = {
    group_name: statistics.get_statistics_for_df(frame)
    for group_name, frame in [
        ("manual", df_manual),
        ("e2s", e2s()),
        ("s2e", s2e()),
        ("e2s_s2e", e2s_s2e()),
        ("synthetic", synthetic()),
        ("all", df_synthetic),
    ]
}

# Every group exposes the same stat keys; take them from the manual group.
STAT_NAMES = list(STATISTICS["manual"].keys())
|
44 |
|
45 |
|
46 |
def update_dataset_view(diff_idx, df):
|
|
|
119 |
def layout_for_statistics(statistics_group_name):
    """Render the summary widgets (count + six averages) for one stats group."""
    gr.Markdown(f"### {statistics_group_name}")
    stats = STATISTICS[statistics_group_name]

    # Sample size: every stat array has the same length, use one of them.
    gr.Number(label="Count", interactive=False,
              value=len(stats['deletions_norm']), min_width=0)

    # (label, stat key) pairs rendered as identical gr.Number widgets.
    averaged_fields = [
        ("Avg deletions number (rel to the initial msg length)", 'deletions_norm'),
        ("Avg insertions number (rel to the result length)", 'insertions_norm'),
        ("Avg changes number (rel to the initial msg length)", 'changes_norm'),
        ("Avg deletions number", 'deletions'),
        ("Avg insertions number", 'insertions'),
        ("Avg changes number", 'changes'),
    ]
    for label, key in averaged_fields:
        gr.Number(label=label, interactive=False,
                  value=stats[key].mean().item(), precision=3, min_width=0)
|
136 |
|
137 |
|
138 |
with gr.Row():
    # One narrow column of summary numbers per dataset group.
    for group_name in ("manual", "e2s", "s2e", "e2s_s2e", "synthetic", "all"):
        with gr.Column(scale=1, min_width=100):
            layout_for_statistics(group_name)
|
151 |
|
152 |
+
def _render_distribution_column(stat_name_filter):
    """Render one column of distribution charts for stats matching the filter."""
    with gr.Column(scale=1):
        for stat_name in filter(stat_name_filter, STAT_NAMES):
            chart = statistics.build_plotly_chart(
                stat_golden=STATISTICS['manual'][stat_name],
                stat_e2s=STATISTICS['e2s'][stat_name],
                stat_s2e=STATISTICS['s2e'][stat_name],
                stat_e2s_s2e=STATISTICS['e2s_s2e'][stat_name],
                stat_name=stat_name
            )
            gr.Plot(value=chart)


with gr.Row():
    # Left: absolute counts; right: normalized counts.
    # NOTE(review): the original duplicated the chart-building loop and wrapped
    # the normalized column in two nested gr.Column(scale=1) — the inner column
    # was redundant and is removed here.
    _render_distribution_column(lambda s: "_norm" not in s)
    _render_distribution_column(lambda s: "_norm" in s)
|
176 |
|
177 |
gr.Markdown(f"### Reference-only correlations")
|
178 |
gr.Markdown(value=analysis_util.get_correlations_for_groups(df_synthetic, right_side="ind").to_markdown())
|
generation_steps/metrics_analysis.py
CHANGED
@@ -72,7 +72,8 @@ TER = evaluate.load("ter")
|
|
72 |
|
73 |
def ter_fn(pred, ref, **kwargs):
|
74 |
if "refs" in kwargs:
|
75 |
-
|
|
|
76 |
return TER.compute(predictions=[pred], references=[[ref]])["score"]
|
77 |
|
78 |
|
@@ -130,10 +131,10 @@ def gptscore_noref_5_fn(pred, ref, **kwargs):
|
|
130 |
IND_METRICS = {
|
131 |
"editdist": edit_distance_fn,
|
132 |
"editdist-norm": edit_distance_norm_fn,
|
133 |
-
"gptscore-ref-1-req": gptscore_ref_1_fn,
|
134 |
# "gptscore-ref-3-req": gptscore_ref_3_fn,
|
135 |
# "gptscore-ref-5-req": gptscore_ref_5_fn,
|
136 |
-
"gptscore-noref-1-req": gptscore_noref_1_fn,
|
137 |
# "gptscore-noref-3-req": gptscore_noref_3_fn,
|
138 |
# "gptscore-noref-5-req": gptscore_noref_5_fn,
|
139 |
"bleu": bleu_fn,
|
@@ -174,7 +175,8 @@ def compute_metrics(df):
|
|
174 |
values = []
|
175 |
for i, row in tqdm(df.iterrows(), total=len(df)):
|
176 |
others = df[(df["hash"] == row["hash"]) & (df["repo"] == row["repo"]) & (
|
177 |
-
df["commit_msg_start"] != row["commit_msg_start"])
|
|
|
178 |
others.append(row["reference"])
|
179 |
others = list(set(others))
|
180 |
metric_fn = AGGR_METRICS[metric]
|
|
|
72 |
|
73 |
def ter_fn(pred, ref, **kwargs):
    """TER score of `pred` against `ref`.

    When kwargs carries a non-empty "refs" iterable, the multi-reference score
    is computed as the mean of per-reference TERs (cheaper than a single
    multi-reference TER.compute call).  An empty "refs" now falls through to
    the single-reference path instead of raising ZeroDivisionError.
    """
    refs = kwargs.get("refs")
    if refs:
        # Loop variable renamed so it no longer shadows the `ref` parameter.
        scores = [TER.compute(predictions=[pred], references=[[r]])["score"]
                  for r in refs]
        return sum(scores) / len(scores)
    return TER.compute(predictions=[pred], references=[[ref]])["score"]
|
78 |
|
79 |
|
|
|
131 |
IND_METRICS = {
|
132 |
"editdist": edit_distance_fn,
|
133 |
"editdist-norm": edit_distance_norm_fn,
|
134 |
+
# "gptscore-ref-1-req": gptscore_ref_1_fn,
|
135 |
# "gptscore-ref-3-req": gptscore_ref_3_fn,
|
136 |
# "gptscore-ref-5-req": gptscore_ref_5_fn,
|
137 |
+
# "gptscore-noref-1-req": gptscore_noref_1_fn,
|
138 |
# "gptscore-noref-3-req": gptscore_noref_3_fn,
|
139 |
# "gptscore-noref-5-req": gptscore_noref_5_fn,
|
140 |
"bleu": bleu_fn,
|
|
|
175 |
values = []
|
176 |
for i, row in tqdm(df.iterrows(), total=len(df)):
|
177 |
others = df[(df["hash"] == row["hash"]) & (df["repo"] == row["repo"]) & (
|
178 |
+
df["commit_msg_start"] != row["commit_msg_start"]) & (
|
179 |
+
df["commit_msg_end"] != row["commit_msg_end"])]['commit_msg_end'].to_list()
|
180 |
others.append(row["reference"])
|
181 |
others = list(set(others))
|
182 |
metric_fn = AGGR_METRICS[metric]
|
requirements.txt
CHANGED
@@ -160,3 +160,7 @@ widgetsnbextension==4.0.10
|
|
160 |
xxhash==3.4.1
|
161 |
yarl==1.9.4
|
162 |
zipp==3.18.1
|
|
|
|
|
|
|
|
|
|
160 |
xxhash==3.4.1
|
161 |
yarl==1.9.4
|
162 |
zipp==3.18.1
|
163 |
+
|
164 |
+
plotly==5.22.0
|
165 |
+
tenacity==8.2.3
|
166 |
+
Levenshtein==0.25.1
|
statistics.py
CHANGED
@@ -1,24 +1,27 @@
|
|
|
|
1 |
import numpy as np
|
2 |
import pandas as pd
|
|
|
3 |
|
4 |
|
5 |
def get_statistics(start_msg, end_msg, annotated_msg):
|
6 |
-
|
7 |
-
|
8 |
-
for
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
sum_changes = sum_deletions + sum_insertions
|
15 |
-
end_length = len(end_msg)
|
16 |
-
start_length = len(start_msg)
|
17 |
|
18 |
return {
|
19 |
-
"deletions":
|
20 |
-
"insertions":
|
21 |
-
"changes":
|
|
|
|
|
|
|
|
|
22 |
}
|
23 |
|
24 |
|
@@ -29,3 +32,17 @@ def get_statistics_for_df(df: pd.DataFrame):
|
|
29 |
assert len(stats) > 0
|
30 |
|
31 |
return {stat_name: np.asarray([e[stat_name] for e in stats]) for stat_name in stats[0]}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import Levenshtein
|
2 |
import numpy as np
|
3 |
import pandas as pd
|
4 |
+
import plotly.figure_factory as ff
|
5 |
|
6 |
|
7 |
def get_statistics(start_msg, end_msg, annotated_msg):
    """Count Levenshtein edit operations turning start_msg into end_msg.

    `annotated_msg` is accepted for interface compatibility but not used here.

    A 'replace' op counts as both a deletion and an insertion, but only once
    toward the total number of changes.  Normalized values divide by the
    relevant message length; empty messages previously raised
    ZeroDivisionError and now yield 0-denominator-safe results.
    """
    edit_ops = Levenshtein.editops(start_msg, end_msg)
    n_deletes = sum(op == 'delete' for op, _, _ in edit_ops)
    n_inserts = sum(op == 'insert' for op, _, _ in edit_ops)
    n_replaces = sum(op == 'replace' for op, _, _ in edit_ops)

    n_changes = n_deletes + n_inserts + n_replaces
    n_deletes += n_replaces
    n_inserts += n_replaces

    # Guard against empty messages (both counts are 0 then, so the ratio is 0).
    start_len = max(len(start_msg), 1)
    end_len = max(len(end_msg), 1)

    return {
        "deletions": n_deletes,
        "insertions": n_inserts,
        "changes": n_changes,

        "deletions_norm": n_deletes / start_len,
        "insertions_norm": n_inserts / end_len,
        # NOTE(review): normalized by the END message length, while the UI
        # label reads "rel to the initial msg length" — confirm which is
        # intended; behavior kept as-is.
        "changes_norm": n_changes / end_len,
    }
}
|
26 |
|
27 |
|
|
|
32 |
assert len(stats) > 0
|
33 |
|
34 |
return {stat_name: np.asarray([e[stat_name] for e in stats]) for stat_name in stats[0]}
|
35 |
+
|
36 |
+
|
37 |
+
def build_plotly_chart(stat_golden, stat_e2s, stat_s2e, stat_e2s_s2e, stat_name):
    """Build a plotly distplot comparing one statistic's distribution per group.

    The "Synthetic" trace is the concatenation of the three synthetic groups
    (e2s, s2e, e2s+s2e).
    """
    hist_data = [stat_golden, stat_e2s, stat_s2e, stat_e2s_s2e,
                 np.concatenate((stat_e2s, stat_s2e, stat_e2s_s2e), axis=0)]

    # Label fixed from the garbled 'e2s+s 2e' to match the group key elsewhere.
    group_labels = ['Golden', 'e2s', 's2e', 'e2s+s2e', 'Synthetic']

    fig = ff.create_distplot(hist_data, group_labels,
                             bin_size=.1, show_rug=False, show_hist=True)

    fig.update_layout(title_text=stat_name)

    return fig
|