Spaces:
Runtime error
Runtime error
Duplicate from teven-projects/how_many_data_points
Browse filesCo-authored-by: Julien Chaumond <julien-c@users.noreply.huggingface.co>
- .gitattributes +34 -0
- .gitignore +6 -0
- Dockerfile +11 -0
- README.md +12 -0
- naacl_demo/demo_utils.py +514 -0
- naacl_demo/main.py +294 -0
- naacl_demo/text.md +82 -0
- naacl_demo/text.py +169 -0
- requirements.txt +22 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
*.py[cod]
|
3 |
+
*$py.class
|
4 |
+
|
5 |
+
|
6 |
+
.env/
|
Dockerfile
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.7
|
2 |
+
|
3 |
+
WORKDIR /code
|
4 |
+
|
5 |
+
COPY ./requirements.txt /code/requirements.txt
|
6 |
+
|
7 |
+
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
|
8 |
+
|
9 |
+
COPY . .
|
10 |
+
|
11 |
+
CMD ["bokeh", "serve", "naacl_demo", "--allow-websocket-origin=*"]
|
README.md
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: How Many Data Points
|
3 |
+
emoji: 🦀
|
4 |
+
colorFrom: red
|
5 |
+
colorTo: yellow
|
6 |
+
sdk: docker
|
7 |
+
pinned: false
|
8 |
+
app_port: 5006
|
9 |
+
duplicated_from: teven-projects/how_many_data_points
|
10 |
+
---
|
11 |
+
|
12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
naacl_demo/demo_utils.py
ADDED
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
from itertools import product
|
6 |
+
import shapely
|
7 |
+
from bokeh.models import Span, Label, ColumnDataSource, Whisker
|
8 |
+
from bokeh.plotting import figure, show
|
9 |
+
from shapely.geometry import Polygon
|
10 |
+
import matplotlib as mpl
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
import seaborn
|
13 |
+
|
14 |
+
task_patterns = {
|
15 |
+
"CB": [0, 3],
|
16 |
+
"RTE": [0, 3],
|
17 |
+
"BoolQ": [0, 3, 5],
|
18 |
+
"MNLI": [0, 3],
|
19 |
+
"COPA": [0, 1],
|
20 |
+
"WSC": [0, 1, 2],
|
21 |
+
"WiC": [0, 1],
|
22 |
+
"MultiRC": [0, 1, 2],
|
23 |
+
}
|
24 |
+
task_reps = {"CB": 4, "RTE": 4, "BoolQ": 4, "MNLI": 4, "COPA": 4, "WSC": 4, "WiC": 4, "MultiRC": 4}
|
25 |
+
task_best_pattern = {"CB": 0, "RTE": 0, "BoolQ": 0, "MNLI": 0, "COPA": 1, "WSC": 0, "WiC": 0, "MultiRC": 1}
|
26 |
+
task_metric_short = {
|
27 |
+
"CB": "f1-macro",
|
28 |
+
"RTE": "acc",
|
29 |
+
"BoolQ": "acc",
|
30 |
+
"MNLI": "acc",
|
31 |
+
"COPA": "acc",
|
32 |
+
"WSC": "acc",
|
33 |
+
"WiC": "acc",
|
34 |
+
"MultiRC": "f1",
|
35 |
+
}
|
36 |
+
task_metrics = {
|
37 |
+
"CB": "F1-macro",
|
38 |
+
"RTE": "accuracy",
|
39 |
+
"BoolQ": "accuracy",
|
40 |
+
"MNLI": "accuracy",
|
41 |
+
"COPA": "accuracy",
|
42 |
+
"WSC": "accuracy",
|
43 |
+
"WiC": "accuracy",
|
44 |
+
"MultiRC": "F1",
|
45 |
+
}
|
46 |
+
task_neutral = {
|
47 |
+
"CB": True,
|
48 |
+
"RTE": True,
|
49 |
+
"BoolQ": True,
|
50 |
+
"MNLI": True,
|
51 |
+
"COPA": False,
|
52 |
+
"WSC": False,
|
53 |
+
"multirc": True,
|
54 |
+
"WiC": True,
|
55 |
+
"MultiRC": True,
|
56 |
+
}
|
57 |
+
neutral_tasks = [
|
58 |
+
"BoolQ",
|
59 |
+
"CB",
|
60 |
+
"MNLI",
|
61 |
+
"MultiRC",
|
62 |
+
"RTE",
|
63 |
+
"WiC",
|
64 |
+
]
|
65 |
+
tasks = sorted(task_patterns.keys())
|
66 |
+
|
67 |
+
pvp_colors = ["goldenrod", "blanchedalmond", "floralwhite"]
|
68 |
+
ctl_colors = ["crimson", "salmon", "mistyrose"]
|
69 |
+
clf_colors = ["indigo", "plum", "thistle"]
|
70 |
+
|
71 |
+
|
72 |
+
def prompt_boolq(passage, question, pattern):
|
73 |
+
if pattern == 0:
|
74 |
+
return f"""<span style="color: #0c593d">{passage}</span> <span style="color: #910713"><b>Based on the previous passage,</b></span> <span style="color: #031154">{question}</span> <span style="color: #ba9004"><b>[YES/NO]</b></span>"""
|
75 |
+
if pattern == 1:
|
76 |
+
return f"""<span style="color: #0c593d">{passage}</span><span style="color: #910713"><b> Question:</b></span> <span style="color: #031154">{question}</span><span style="color: #910713"><b> Answer: </b></span><span style="color: #ba9004"><b>[YES/NO]</b></span>"""
|
77 |
+
if pattern == 2:
|
78 |
+
return f"""<span style="color: #910713"><b>Based on the following passage,</b></span> <span style="color: #031154">{question}</span><span style="color: #ba9004"><b> [YES/NO]</b></span> <span style="color: #0c593d">{passage}</span>"""
|
79 |
+
|
80 |
+
|
81 |
+
def advantage_text(advantage):
|
82 |
+
model_type = (
|
83 |
+
"""<span style="color: #4B0082">Head</span>"""
|
84 |
+
if advantage < 0
|
85 |
+
else """<span style="color: #daa520">Prompting</span>"""
|
86 |
+
)
|
87 |
+
return f"""<b>{model_type}</b> advantage: <b>{abs(advantage):.2f}</b> data points"""
|
88 |
+
|
89 |
+
|
90 |
+
def average_advantage_text(advantage):
|
91 |
+
model_type = (
|
92 |
+
"""<span style="color: #4B0082">head</span>"""
|
93 |
+
if advantage < 0
|
94 |
+
else """<span style="color: #daa520">prompting</span>"""
|
95 |
+
)
|
96 |
+
return f"""<b>Average {model_type}</b> advantage: <b>{abs(advantage):.2f}</b> data points"""
|
97 |
+
|
98 |
+
|
99 |
+
def naming_convention(task, seed, pvp_index=None, neutral=False):
|
100 |
+
method = f"PVP {pvp_index}" if pvp_index is not None else "CLF"
|
101 |
+
model = "roberta"
|
102 |
+
if neutral:
|
103 |
+
verbalizer = "neutral"
|
104 |
+
else:
|
105 |
+
verbalizer = None
|
106 |
+
return (
|
107 |
+
f"{method} {model}"
|
108 |
+
+ (f" {verbalizer} verbalizer" if verbalizer is not None else "")
|
109 |
+
+ f" seed {seed} - test-{task_metric_short[task]}-all-p"
|
110 |
+
)
|
111 |
+
|
112 |
+
|
113 |
+
def get_data(task):
|
114 |
+
url = f"https://raw.githubusercontent.com/TevenLeScao/pet/master/exported_results/{task.lower()}/wandb_export.csv"
|
115 |
+
df = pd.read_csv(url)
|
116 |
+
training_points = df["training_points"]
|
117 |
+
|
118 |
+
head_performances = np.transpose(np.array([df[naming_convention(task, i)] for i in range(task_reps[task])]))
|
119 |
+
pattern_performances = {}
|
120 |
+
for pattern in task_patterns[task]:
|
121 |
+
pattern_performances[pattern] = {
|
122 |
+
"normal": np.transpose(np.array([df[naming_convention(task, i, pattern)] for i in range(task_reps[task])]))
|
123 |
+
}
|
124 |
+
if task_neutral[task]:
|
125 |
+
pattern_performances[pattern]["neutral"] = np.transpose(
|
126 |
+
np.array([df[naming_convention(task, i, pattern, True)] for i in range(task_reps[task])])
|
127 |
+
)
|
128 |
+
|
129 |
+
return training_points, head_performances, pattern_performances
|
130 |
+
|
131 |
+
|
132 |
+
def reduct(performances, reduction="accmax", final_pattern=0, verbalizer="normal", exclude=None):
|
133 |
+
# Combining the different runs for each experimental set-up
|
134 |
+
reducted = None
|
135 |
+
|
136 |
+
if isinstance(performances, dict):
|
137 |
+
performances = performances[final_pattern][verbalizer]
|
138 |
+
if exclude is not None:
|
139 |
+
performances = np.delete(performances, exclude, axis=1)
|
140 |
+
|
141 |
+
if reduction == "avg":
|
142 |
+
# Average
|
143 |
+
reducted = np.nanmean(performances, axis=1)
|
144 |
+
|
145 |
+
if reduction == "std":
|
146 |
+
# Standard deviation
|
147 |
+
reducted = np.nanstd(performances, axis=1)
|
148 |
+
|
149 |
+
if reduction == "max":
|
150 |
+
# Maximum
|
151 |
+
reducted = np.nanmax(performances, axis=1)
|
152 |
+
|
153 |
+
if reduction == "accmax":
|
154 |
+
# This makes the maximum curve monotonic
|
155 |
+
max_performance = np.nanmax(performances, axis=1)
|
156 |
+
reducted = np.maximum.accumulate(max_performance)
|
157 |
+
|
158 |
+
assert reducted is not None, "unrecognized reduction method"
|
159 |
+
return reducted
|
160 |
+
|
161 |
+
|
162 |
+
def find_surrounding_points(perf, clf_results, pvp_results):
|
163 |
+
for i, clf_result in enumerate(clf_results):
|
164 |
+
if i - 1 > 0 and clf_result == clf_results[i - 1]:
|
165 |
+
continue
|
166 |
+
if clf_result > perf:
|
167 |
+
if i == 0:
|
168 |
+
raise ValueError(f"value {perf} too small")
|
169 |
+
else:
|
170 |
+
break
|
171 |
+
for j, pvp_result in enumerate(pvp_results):
|
172 |
+
if j - 1 > 0 and pvp_result == pvp_results[j - 1]:
|
173 |
+
continue
|
174 |
+
if pvp_result > perf:
|
175 |
+
if j == 0:
|
176 |
+
raise ValueError(f"value {perf} too small")
|
177 |
+
else:
|
178 |
+
break
|
179 |
+
return i - 1, j - 1
|
180 |
+
|
181 |
+
|
182 |
+
def interpolate(perf, x1, x2, y1, y2):
|
183 |
+
return x1 + (perf - y1) * (x2 - x1) / (y2 - y1)
|
184 |
+
|
185 |
+
|
186 |
+
def interpolate_from_idx(perf, idx, results, training_points):
|
187 |
+
return interpolate(perf, training_points[idx], training_points[idx + 1], results[idx], results[idx + 1])
|
188 |
+
|
189 |
+
|
190 |
+
def interpolate_from_perf(perf, overlapping_range, training_points, clf_results, pvp_results):
|
191 |
+
if not overlapping_range[0] <= perf <= overlapping_range[1]:
|
192 |
+
raise ValueError(f"perf {perf} not in acceptable bounds {overlapping_range}")
|
193 |
+
clf_idx, pvp_idx = find_surrounding_points(perf, clf_results, pvp_results)
|
194 |
+
return interpolate_from_idx(perf, clf_idx, clf_results, training_points), interpolate_from_idx(
|
195 |
+
perf, pvp_idx, pvp_results, training_points
|
196 |
+
)
|
197 |
+
|
198 |
+
|
199 |
+
def data_difference(perf, overlapping_range, training_points, clf_results, pvp_results):
|
200 |
+
x1, x2 = interpolate_from_perf(perf, overlapping_range, training_points, clf_results, pvp_results)
|
201 |
+
return x1 - x2
|
202 |
+
|
203 |
+
|
204 |
+
def calculate_overlap(clf_results, pvp_results, full_range=False):
|
205 |
+
if full_range:
|
206 |
+
return (min(min(clf_results), min(pvp_results)), max(max(clf_results), max(pvp_results)))
|
207 |
+
else:
|
208 |
+
return (max(min(clf_results), min(pvp_results)), min(max(clf_results), max(pvp_results)))
|
209 |
+
|
210 |
+
|
211 |
+
def calculate_range(overlapping_range, number_of_points):
|
212 |
+
integral_range = (
|
213 |
+
overlapping_range[0] + i / (number_of_points + 1) * (overlapping_range[1] - overlapping_range[0])
|
214 |
+
for i in range(1, number_of_points + 1)
|
215 |
+
)
|
216 |
+
return integral_range
|
217 |
+
|
218 |
+
|
219 |
+
def calculate_differences(integral_range, overlapping_range, training_points, clf_results, pvp_results):
|
220 |
+
differences = [
|
221 |
+
data_difference(y, overlapping_range, training_points, clf_results, pvp_results) for y in integral_range
|
222 |
+
]
|
223 |
+
return differences
|
224 |
+
|
225 |
+
|
226 |
+
def calculate_offset(training_points, clf_results, pvp_results, number_of_points=1000):
|
227 |
+
overlapping_range = calculate_overlap(clf_results, pvp_results)
|
228 |
+
integral_range = calculate_range(overlapping_range, number_of_points)
|
229 |
+
differences = calculate_differences(integral_range, overlapping_range, training_points, clf_results, pvp_results)
|
230 |
+
offset = sum(differences) / number_of_points
|
231 |
+
return offset
|
232 |
+
|
233 |
+
|
234 |
+
def intersection_with_range(training_points, results, band):
|
235 |
+
result_polygon = Polygon(
|
236 |
+
[(training_points[i], results[i]) for i in range(len(training_points))]
|
237 |
+
+ [(training_points[-1], 0), (training_points[0], 0)]
|
238 |
+
)
|
239 |
+
return result_polygon.intersection(band)
|
240 |
+
|
241 |
+
|
242 |
+
def fill_polygon(fig, polygon, color, label=None, alpha=1.0):
|
243 |
+
if polygon.is_empty or isinstance(polygon, shapely.geometry.LineString):
|
244 |
+
return
|
245 |
+
if isinstance(polygon, Polygon):
|
246 |
+
xs, ys = polygon.exterior.xy
|
247 |
+
fig.patch(xs, ys, color=color, alpha=alpha)
|
248 |
+
else:
|
249 |
+
for geom in polygon.geoms:
|
250 |
+
if isinstance(geom, shapely.geometry.LineString):
|
251 |
+
continue
|
252 |
+
xs, ys = geom.exterior.xy
|
253 |
+
fig.patch(xs, ys, color=color, alpha=alpha)
|
254 |
+
label = None
|
255 |
+
|
256 |
+
|
257 |
+
label_order = {
|
258 |
+
"head run": 0,
|
259 |
+
"head advantage": 1,
|
260 |
+
"control run": 2,
|
261 |
+
"optimization advantage": 3,
|
262 |
+
"prompting run": 4,
|
263 |
+
"semantics advantage": 5,
|
264 |
+
"region of comparison": 6,
|
265 |
+
}
|
266 |
+
|
267 |
+
|
268 |
+
def metric_tap(
|
269 |
+
event, overlapping_range, training_points, clf_results, pvp_results, advantage_box, advantage_plot
|
270 |
+
):
|
271 |
+
_, metric_value = event.x, event.y
|
272 |
+
try:
|
273 |
+
advantage_value = data_difference(metric_value, overlapping_range, training_points, clf_results, pvp_results)
|
274 |
+
advantage_box.text = advantage_text(advantage_value)
|
275 |
+
if not isinstance(advantage_plot.renderers[-1], Span):
|
276 |
+
metric_line = Span(
|
277 |
+
location=metric_value,
|
278 |
+
line_alpha=0.7,
|
279 |
+
dimension="width",
|
280 |
+
line_color=clf_colors[0] if advantage_value < 0 else pvp_colors[0],
|
281 |
+
line_dash="dashed",
|
282 |
+
line_width=1,
|
283 |
+
)
|
284 |
+
advantage_plot.renderers.extend([metric_line])
|
285 |
+
else:
|
286 |
+
advantage_plot.renderers[-1].location = metric_value
|
287 |
+
advantage_plot.renderers[-1].line_color = clf_colors[0] if advantage_value < 0 else pvp_colors[0]
|
288 |
+
# clicking outside the region
|
289 |
+
except ValueError:
|
290 |
+
pass
|
291 |
+
|
292 |
+
|
293 |
+
def plot_polygons_bokeh(task, training_points, clf_results, pvp_results, clf_colors, pvp_colors, x_log_scale=False):
|
294 |
+
overlapping_range = calculate_overlap(clf_results, pvp_results, False)
|
295 |
+
full_range = calculate_overlap(clf_results, pvp_results, True)
|
296 |
+
middle_y = (full_range[0] + full_range[1]) / 2
|
297 |
+
|
298 |
+
fig = figure(plot_height=400, plot_width=800, max_height=400, max_width=800,
|
299 |
+
x_axis_type="log" if x_log_scale else "linear", title="Performance over training subset sizes of head and prompting methods")
|
300 |
+
|
301 |
+
fig.circle(training_points, clf_results, color=clf_colors[0], legend="head run")
|
302 |
+
fig.circle(training_points, pvp_results, color=pvp_colors[0], legend="prompting run")
|
303 |
+
fig.line(training_points, clf_results, color=clf_colors[0], alpha=1)
|
304 |
+
fig.line(training_points, pvp_results, color=pvp_colors[0], alpha=1)
|
305 |
+
fig.xaxis.axis_label = "training subset size"
|
306 |
+
fig.yaxis.axis_label = task_metrics[task]
|
307 |
+
fig.patch(
|
308 |
+
[training_points[0], training_points[0], training_points[-1], training_points[-1]],
|
309 |
+
[overlapping_range[0], overlapping_range[1], overlapping_range[1], overlapping_range[0]],
|
310 |
+
color="black",
|
311 |
+
fill_alpha=0,
|
312 |
+
line_width=0,
|
313 |
+
legend="comparison region",
|
314 |
+
hatch_alpha=0.14,
|
315 |
+
hatch_scale=40,
|
316 |
+
hatch_pattern="/",
|
317 |
+
)
|
318 |
+
|
319 |
+
band = Polygon(
|
320 |
+
[
|
321 |
+
(training_points[0], overlapping_range[0]),
|
322 |
+
(training_points[0], overlapping_range[1]),
|
323 |
+
(training_points[-1], overlapping_range[1]),
|
324 |
+
(training_points[-1], overlapping_range[0]),
|
325 |
+
]
|
326 |
+
)
|
327 |
+
full_band = Polygon(
|
328 |
+
[
|
329 |
+
(training_points[0], full_range[0]),
|
330 |
+
(training_points[0], full_range[1]),
|
331 |
+
(training_points[-1], full_range[1]),
|
332 |
+
(training_points[-1], full_range[0]),
|
333 |
+
]
|
334 |
+
)
|
335 |
+
clf_polygon = intersection_with_range(training_points, clf_results, band)
|
336 |
+
pvp_polygon = intersection_with_range(training_points, pvp_results, band)
|
337 |
+
full_clf_polygon = intersection_with_range(training_points, clf_results, full_band)
|
338 |
+
full_pvp_polygon = intersection_with_range(training_points, pvp_results, full_band)
|
339 |
+
|
340 |
+
clf_inside_area = clf_polygon.difference(pvp_polygon)
|
341 |
+
pvp_inside_area = pvp_polygon.difference(clf_polygon)
|
342 |
+
clf_outside_area = (full_clf_polygon.difference(full_pvp_polygon)).difference(clf_inside_area)
|
343 |
+
pvp_outside_area = (full_pvp_polygon.difference(full_clf_polygon)).difference(pvp_inside_area)
|
344 |
+
|
345 |
+
fill_polygon(fig, clf_outside_area, clf_colors[1], alpha=0.13)
|
346 |
+
fill_polygon(fig, pvp_outside_area, pvp_colors[1], alpha=0.18)
|
347 |
+
fill_polygon(
|
348 |
+
fig, clf_inside_area, clf_colors[1], alpha=0.4, label="head advantage" if task == "WiC" else None
|
349 |
+
)
|
350 |
+
fill_polygon(fig, pvp_inside_area, pvp_colors[1], alpha=0.4, label="prompting advantage")
|
351 |
+
|
352 |
+
fig.line([training_points[0], training_points[-1]], [overlapping_range[0], overlapping_range[0]], color="dimgrey")
|
353 |
+
fig.line([training_points[0], training_points[-1]], [overlapping_range[1], overlapping_range[1]], color="dimgrey")
|
354 |
+
|
355 |
+
vline = Span(
|
356 |
+
location=training_points[-1], dimension="height", line_color="black", line_width=2.5, line_dash="dashed"
|
357 |
+
)
|
358 |
+
end_label = Label(
|
359 |
+
x=training_points[-1], y=middle_y, text="End of dataset", angle=90, angle_units="deg", text_align="center"
|
360 |
+
)
|
361 |
+
fig.renderers.extend([vline, end_label])
|
362 |
+
|
363 |
+
fig.legend.location = "bottom_right"
|
364 |
+
|
365 |
+
return fig
|
366 |
+
|
367 |
+
|
368 |
+
def plot_three_polygons_bokeh(
|
369 |
+
task, training_points, clf_results, pvp_results, ctl_results, clf_colors, pvp_colors, ctl_colors,
|
370 |
+
x_log_scale=False
|
371 |
+
):
|
372 |
+
overlapping_range = calculate_overlap(clf_results, pvp_results, False)
|
373 |
+
full_range = calculate_overlap(clf_results, pvp_results, True)
|
374 |
+
middle_y = (full_range[0] + full_range[1]) / 2
|
375 |
+
|
376 |
+
fig = figure(plot_height=400, plot_width=800, max_height=400, max_width=800,
|
377 |
+
x_axis_type="log" if x_log_scale else "linear", title="Performance over training subset sizes of head, prompting and prompting with a null verbalizer")
|
378 |
+
fig.xaxis.axis_label = "training subset size"
|
379 |
+
fig.yaxis.axis_label = task_metrics[task]
|
380 |
+
fig.circle(training_points, clf_results, color=clf_colors[0], legend="head run")
|
381 |
+
fig.circle(training_points, pvp_results, color=pvp_colors[0], legend="prompting run")
|
382 |
+
fig.circle(training_points, ctl_results, color=ctl_colors[0], legend="null verbalizer run")
|
383 |
+
fig.line(training_points, clf_results, color=clf_colors[0], alpha=1)
|
384 |
+
fig.line(training_points, pvp_results, color=pvp_colors[0], alpha=1)
|
385 |
+
fig.line(training_points, ctl_results, color=ctl_colors[0], alpha=1)
|
386 |
+
|
387 |
+
fig.patch(
|
388 |
+
[training_points[0], training_points[0], training_points[-1], training_points[-1]],
|
389 |
+
[overlapping_range[0], overlapping_range[1], overlapping_range[1], overlapping_range[0]],
|
390 |
+
color="black",
|
391 |
+
fill_alpha=0,
|
392 |
+
line_width=0,
|
393 |
+
legend="comparison region",
|
394 |
+
hatch_alpha=0.14,
|
395 |
+
hatch_scale=40,
|
396 |
+
hatch_pattern="/",
|
397 |
+
)
|
398 |
+
|
399 |
+
band = Polygon(
|
400 |
+
[
|
401 |
+
(training_points[0], overlapping_range[0]),
|
402 |
+
(training_points[0], overlapping_range[1]),
|
403 |
+
(training_points[-1], overlapping_range[1]),
|
404 |
+
(training_points[-1], overlapping_range[0]),
|
405 |
+
]
|
406 |
+
)
|
407 |
+
full_band = Polygon(
|
408 |
+
[
|
409 |
+
(training_points[0], full_range[0]),
|
410 |
+
(training_points[0], full_range[1]),
|
411 |
+
(training_points[-1], full_range[1]),
|
412 |
+
(training_points[-1], full_range[0]),
|
413 |
+
]
|
414 |
+
)
|
415 |
+
|
416 |
+
clf_polygon = intersection_with_range(training_points, clf_results, band)
|
417 |
+
pvp_polygon = intersection_with_range(training_points, pvp_results, band)
|
418 |
+
ctl_polygon = intersection_with_range(training_points, ctl_results, band)
|
419 |
+
|
420 |
+
full_clf_polygon = intersection_with_range(training_points, clf_results, full_band)
|
421 |
+
full_pvp_polygon = intersection_with_range(training_points, pvp_results, full_band)
|
422 |
+
full_ctl_polygon = intersection_with_range(training_points, ctl_results, full_band)
|
423 |
+
|
424 |
+
clf_inside_area = clf_polygon.difference(ctl_polygon)
|
425 |
+
pvp_inside_area = pvp_polygon.difference(clf_polygon).difference(ctl_polygon)
|
426 |
+
ctl_inside_area = ctl_polygon.difference(clf_polygon)
|
427 |
+
|
428 |
+
clf_outside_area = (full_clf_polygon.difference(full_ctl_polygon)).difference(clf_inside_area)
|
429 |
+
pvp_outside_area = (full_pvp_polygon.difference(full_clf_polygon).difference(ctl_polygon)).difference(
|
430 |
+
pvp_inside_area
|
431 |
+
)
|
432 |
+
ctl_outside_area = (full_ctl_polygon.difference(full_clf_polygon)).difference(pvp_inside_area)
|
433 |
+
|
434 |
+
fill_polygon(
|
435 |
+
fig, clf_inside_area, clf_colors[1], alpha=0.4, label="head advantage" if task == "WiC" else None
|
436 |
+
)
|
437 |
+
fill_polygon(fig, pvp_inside_area, pvp_colors[1], alpha=0.4, label="prompting advantage")
|
438 |
+
fill_polygon(fig, ctl_inside_area, ctl_colors[1], alpha=0.4, label="null verbalizer advantage")
|
439 |
+
fill_polygon(fig, clf_outside_area, clf_colors[1], alpha=0.13)
|
440 |
+
fill_polygon(fig, pvp_outside_area, pvp_colors[1], alpha=0.18)
|
441 |
+
fill_polygon(fig, ctl_outside_area, ctl_colors[1], alpha=0.13)
|
442 |
+
|
443 |
+
fig.line([training_points[0], training_points[-1]], [overlapping_range[0], overlapping_range[0]], color="dimgrey")
|
444 |
+
fig.line([training_points[0], training_points[-1]], [overlapping_range[1], overlapping_range[1]], color="dimgrey")
|
445 |
+
|
446 |
+
vline = Span(
|
447 |
+
location=training_points[-1], dimension="height", line_color="black", line_width=2.5, line_dash="dashed"
|
448 |
+
)
|
449 |
+
end_label = Label(
|
450 |
+
x=training_points[-1], y=middle_y, text="End of dataset", angle=90, angle_units="deg", text_align="center"
|
451 |
+
)
|
452 |
+
fig.renderers.extend([vline, end_label])
|
453 |
+
|
454 |
+
fig.legend.location = "bottom_right"
|
455 |
+
|
456 |
+
return fig
|
457 |
+
|
458 |
+
|
459 |
+
def pattern_graph(task):
|
460 |
+
fig = figure(plot_height=400, plot_width=800, max_height=400, max_width=800, x_axis_type="log", title="Performance over training subset sizes of different prompt patterns")
|
461 |
+
fig.xaxis.axis_label = "training subset size"
|
462 |
+
fig.yaxis.axis_label = task_metrics[task]
|
463 |
+
url = f"https://raw.githubusercontent.com/TevenLeScao/pet/master/exported_results/{task.lower()}/wandb_export.csv"
|
464 |
+
df = pd.read_csv(url)
|
465 |
+
expanded_training_points = np.array(list(df["training_points"]) * task_reps[task] * len(task_patterns[task]))
|
466 |
+
data = np.array(df[[naming_convention(task, seed, pattern) for pattern in task_patterns[task] for seed in
|
467 |
+
range(task_reps[task])]])
|
468 |
+
data = data.reshape(-1, task_reps[task])
|
469 |
+
col_med = np.nanmean(data, axis=1)
|
470 |
+
# Find indices that you need to replace
|
471 |
+
inds = np.where(np.isnan(data))
|
472 |
+
# Place column means in the indices. Align the arrays using take
|
473 |
+
data[inds] = np.take(col_med, inds[0])
|
474 |
+
data = data.reshape(len(df["training_points"]), -1)
|
475 |
+
data = data.transpose().reshape(-1)
|
476 |
+
data = data + np.random.normal(0, 0.01, len(data))
|
477 |
+
pattern = np.array([i // (len(data) // len(task_patterns[task])) for i in range(len(data))])
|
478 |
+
seed = np.array([0, 1, 2, 3] * (len(data) // task_reps[task]))
|
479 |
+
long_df = pd.DataFrame(np.stack((expanded_training_points, pattern, seed, data), axis=1),
|
480 |
+
columns=["training_points", "pattern", "seed", task_metrics[task]])
|
481 |
+
long_df['pattern'] = long_df['pattern'].astype(int).astype(str)
|
482 |
+
gby_pattern = long_df.groupby('pattern')
|
483 |
+
pattern_colors = ["royalblue", "darkturquoise", "darkviolet"]
|
484 |
+
|
485 |
+
for i, (pattern, pattern_df) in enumerate(gby_pattern):
|
486 |
+
gby_training_points = pattern_df.groupby('training_points')
|
487 |
+
x = [training_point for training_point, training_point_df in gby_training_points]
|
488 |
+
y_max = list([np.max(training_point_df[task_metrics[task]]) for training_point, training_point_df in gby_training_points])
|
489 |
+
y_min = list([np.min(training_point_df[task_metrics[task]]) for training_point, training_point_df in gby_training_points])
|
490 |
+
y = list([np.median(training_point_df[task_metrics[task]]) for training_point, training_point_df in gby_training_points])
|
491 |
+
fig.circle(x, y, color=pattern_colors[i], alpha=1, legend=f"Pattern {i}")
|
492 |
+
fig.line(x, y, color=pattern_colors[i], alpha=1)
|
493 |
+
fig.varea(x=x, y1=y_max, y2=y_min, color=pattern_colors[i], alpha=0.11)
|
494 |
+
# source = ColumnDataSource(data=dict(base=x, lower=y_min, upper=y_max))
|
495 |
+
# w = Whisker(source=source, base="base", upper="upper", lower="lower", line_color=pattern_colors[i], line_alpha=0.3)
|
496 |
+
# w.upper_head.line_color = pattern_colors[i]
|
497 |
+
# w.lower_head.line_color = pattern_colors[i]
|
498 |
+
# fig.add_layout(w)
|
499 |
+
|
500 |
+
return fig
|
501 |
+
|
502 |
+
|
503 |
+
|
504 |
+
def cubic_easing(t):
|
505 |
+
if t < 0.5:
|
506 |
+
return 4 * t * t * t
|
507 |
+
p = 2 * t - 2
|
508 |
+
return 0.5 * p * p * p + 1
|
509 |
+
|
510 |
+
|
511 |
+
def circ_easing(t):
|
512 |
+
if t < 0.5:
|
513 |
+
return 0.5 * (1 - math.sqrt(1 - 4 * (t * t)))
|
514 |
+
return 0.5 * (math.sqrt(-((2 * t) - 3) * ((2 * t) - 1)) + 1)
|
naacl_demo/main.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from bokeh.events import Tap
|
2 |
+
from bokeh.io import curdoc
|
3 |
+
from bokeh.layouts import column
|
4 |
+
from bokeh.models import Div, TextInput, RadioButtonGroup, TextAreaInput, Span, Button, Panel, Tabs
|
5 |
+
from bokeh.models.tools import CrosshairTool
|
6 |
+
|
7 |
+
from demo_utils import (
|
8 |
+
get_data,
|
9 |
+
prompt_boolq,
|
10 |
+
pvp_colors,
|
11 |
+
ctl_colors,
|
12 |
+
clf_colors,
|
13 |
+
reduct,
|
14 |
+
task_best_pattern,
|
15 |
+
plot_polygons_bokeh,
|
16 |
+
advantage_text,
|
17 |
+
data_difference,
|
18 |
+
calculate_overlap,
|
19 |
+
circ_easing,
|
20 |
+
average_advantage_text,
|
21 |
+
plot_three_polygons_bokeh,
|
22 |
+
tasks,
|
23 |
+
metric_tap,
|
24 |
+
neutral_tasks, pattern_graph,
|
25 |
+
)
|
26 |
+
from text import text1, text2, text3, text4, initial_passage, initial_question, text5
|
27 |
+
|
28 |
+
########################################################################################################################
|
29 |
+
# Basic dimensions
|
30 |
+
########################################################################################################################
|
31 |
+
|
32 |
+
plot_width = 1200
|
33 |
+
plot_height = 400
|
34 |
+
sidebar_width = 400
|
35 |
+
in_text_plot_height = 300
|
36 |
+
text_width = 800
|
37 |
+
widget_size = 400
|
38 |
+
|
39 |
+
########################################################################################################################
|
40 |
+
# Patternification widget
|
41 |
+
########################################################################################################################
|
42 |
+
|
43 |
+
passage = TextAreaInput(title="Passage", rows=3, value=initial_passage, max_width=text_width)
|
44 |
+
passage.align = "center"
|
45 |
+
question = TextInput(title="Question", value=initial_question, max_width=text_width)
|
46 |
+
question.align = "center"
|
47 |
+
radio_button_group = RadioButtonGroup(labels=["Pattern 1", "Pattern 2", "Pattern 3"], active=0, max_width=text_width)
|
48 |
+
radio_button_group.align = "center"
|
49 |
+
|
50 |
+
box_style = {
|
51 |
+
"display": "block",
|
52 |
+
"margin": "0 auto",
|
53 |
+
"width": f"{text_width}px",
|
54 |
+
"text-align": "center",
|
55 |
+
"white-space": "pre-wrap",
|
56 |
+
"background": "#f4f4f4",
|
57 |
+
"border": "1px solid #ddd",
|
58 |
+
# "border-left": "3px solid #4d4945",
|
59 |
+
"color": "#666",
|
60 |
+
"page-break-inside": "avoid",
|
61 |
+
# "font-family": "monospace",
|
62 |
+
"font-size": "15px",
|
63 |
+
"line-height": "1.6",
|
64 |
+
"max-width": "100%",
|
65 |
+
"overflow": "hidden",
|
66 |
+
"min-height": "30px",
|
67 |
+
"word-wrap": "break-word",
|
68 |
+
}
|
69 |
+
|
70 |
+
prompt_box = Div(
|
71 |
+
text=prompt_boolq(passage.value, question.value, radio_button_group.active),
|
72 |
+
width=text_width,
|
73 |
+
style=box_style,
|
74 |
+
sizing_mode="scale_width",
|
75 |
+
)
|
76 |
+
prompt_box.align = "center"
|
77 |
+
|
78 |
+
|
79 |
+
def update_prompt(attrname, old, new):
|
80 |
+
prompt_box.text = prompt_boolq(passage.value, question.value, radio_button_group.active)
|
81 |
+
|
82 |
+
|
83 |
+
passage.on_change("value", update_prompt)
|
84 |
+
question.on_change("value", update_prompt)
|
85 |
+
radio_button_group.on_change("active", update_prompt)
|
86 |
+
|
87 |
+
patternification = column(passage, question, radio_button_group, prompt_box, sizing_mode="scale_width")
|
88 |
+
patternification.align = "center"
|
89 |
+
|
90 |
+
########################################################################################################################
|
91 |
+
# Advantage diagram
|
92 |
+
########################################################################################################################
|
93 |
+
|
94 |
+
advantage_plots_per_task = []
|
95 |
+
overlapping_range_per_task = []
|
96 |
+
training_points_per_task = []
|
97 |
+
clf_results_per_task = []
|
98 |
+
pvp_results_per_task = []
|
99 |
+
advantage_tabs = []
|
100 |
+
advantage_all_figures = Tabs(tabs=advantage_tabs)
|
101 |
+
|
102 |
+
advantage_box = Div(
|
103 |
+
text="Click within the comparison region to compute the data advantage for a performance level",
|
104 |
+
width=text_width,
|
105 |
+
style=box_style,
|
106 |
+
sizing_mode="scale_width",
|
107 |
+
)
|
108 |
+
advantage_box.align = "center"
|
109 |
+
|
110 |
+
for task in tasks:
|
111 |
+
training_points, classifier_performances, pattern_performances = get_data(task)
|
112 |
+
training_points_per_task.append(list(training_points))
|
113 |
+
clf_results_per_task.append(reduct(classifier_performances, "accmax"))
|
114 |
+
pvp_results_per_task.append(reduct(pattern_performances, "accmax", task_best_pattern[task], "normal"))
|
115 |
+
advantage_plots_per_task.append(plot_polygons_bokeh(
|
116 |
+
task, training_points_per_task[-1], clf_results_per_task[-1], pvp_results_per_task[-1], clf_colors,
|
117 |
+
pvp_colors
|
118 |
+
))
|
119 |
+
advantage_plots_per_task[-1].align = "center"
|
120 |
+
advantage_plots_per_task[-1].add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
|
121 |
+
overlapping_range_per_task.append(calculate_overlap(clf_results_per_task[-1], pvp_results_per_task[-1]))
|
122 |
+
advantage_tabs.append(Panel(child=advantage_plots_per_task[-1], title=task))
|
123 |
+
|
124 |
+
advantage_plots_per_task[-1].on_event(
|
125 |
+
Tap,
|
126 |
+
lambda event: metric_tap(
|
127 |
+
event,
|
128 |
+
overlapping_range_per_task[advantage_all_figures.active],
|
129 |
+
training_points_per_task[advantage_all_figures.active],
|
130 |
+
clf_results_per_task[advantage_all_figures.active],
|
131 |
+
pvp_results_per_task[advantage_all_figures.active],
|
132 |
+
advantage_box,
|
133 |
+
advantage_plots_per_task[advantage_all_figures.active],
|
134 |
+
),
|
135 |
+
)
|
136 |
+
|
137 |
+
if task == "MNLI":
|
138 |
+
training_points_per_task.append(list(training_points))
|
139 |
+
clf_results_per_task.append(reduct(classifier_performances, "accmax"))
|
140 |
+
pvp_results_per_task.append(reduct(pattern_performances, "accmax", task_best_pattern[task], "normal"))
|
141 |
+
advantage_plots_per_task.append(plot_polygons_bokeh(
|
142 |
+
task, training_points_per_task[-1], clf_results_per_task[-1], pvp_results_per_task[-1], clf_colors,
|
143 |
+
pvp_colors, x_log_scale=True
|
144 |
+
))
|
145 |
+
advantage_plots_per_task[-1].align = "center"
|
146 |
+
advantage_plots_per_task[-1].add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
|
147 |
+
overlapping_range_per_task.append(calculate_overlap(clf_results_per_task[-1], pvp_results_per_task[-1]))
|
148 |
+
advantage_tabs.append(Panel(child=advantage_plots_per_task[-1], title="MNLI (log scale)"))
|
149 |
+
|
150 |
+
advantage_plots_per_task[-1].on_event(
|
151 |
+
Tap,
|
152 |
+
lambda event: metric_tap(
|
153 |
+
event,
|
154 |
+
overlapping_range_per_task[advantage_all_figures.active],
|
155 |
+
training_points_per_task[advantage_all_figures.active],
|
156 |
+
clf_results_per_task[advantage_all_figures.active],
|
157 |
+
pvp_results_per_task[advantage_all_figures.active],
|
158 |
+
advantage_box,
|
159 |
+
advantage_plots_per_task[advantage_all_figures.active],
|
160 |
+
),
|
161 |
+
)
|
162 |
+
|
163 |
+
advantage_all_figures = Tabs(tabs=advantage_tabs)
|
164 |
+
advantage_all_figures.align = "center"
|
165 |
+
|
166 |
+
|
167 |
+
def on_integrate_click():
|
168 |
+
frames = 200
|
169 |
+
initial_placement = overlapping_range_per_task[advantage_all_figures.active][0]
|
170 |
+
|
171 |
+
if not isinstance(advantage_plots_per_task[advantage_all_figures.active].renderers[-1], Span):
|
172 |
+
metric_line = Span(
|
173 |
+
location=initial_placement,
|
174 |
+
line_alpha=0.7,
|
175 |
+
dimension="width",
|
176 |
+
line_color=clf_colors[0] if initial_placement < 0 else pvp_colors[0],
|
177 |
+
line_dash="dashed",
|
178 |
+
line_width=1,
|
179 |
+
)
|
180 |
+
advantage_plots_per_task[advantage_all_figures.active].renderers.extend([metric_line])
|
181 |
+
else:
|
182 |
+
advantage_plots_per_task[advantage_all_figures.active].renderers[-1].location = initial_placement
|
183 |
+
advantage_plots_per_task[advantage_all_figures.active].renderers[-1].line_color = clf_colors[
|
184 |
+
0] if initial_placement < 0 else pvp_colors[0]
|
185 |
+
|
186 |
+
average_advantage = 0
|
187 |
+
for i in range(1, frames):
|
188 |
+
metric_value = overlapping_range_per_task[advantage_all_figures.active][0] + (
|
189 |
+
overlapping_range_per_task[advantage_all_figures.active][1] -
|
190 |
+
overlapping_range_per_task[advantage_all_figures.active][0]) * (i / frames)
|
191 |
+
advantage_value = data_difference(metric_value, overlapping_range_per_task[advantage_all_figures.active],
|
192 |
+
training_points_per_task[advantage_all_figures.active],
|
193 |
+
clf_results_per_task[advantage_all_figures.active],
|
194 |
+
pvp_results_per_task[advantage_all_figures.active])
|
195 |
+
average_advantage = ((i - 1) * average_advantage + advantage_value) / i
|
196 |
+
|
197 |
+
advantage_plots_per_task[advantage_all_figures.active].renderers[-1].location = metric_value
|
198 |
+
advantage_plots_per_task[advantage_all_figures.active].renderers[-1].line_color = clf_colors[
|
199 |
+
0] if advantage_value < 0 else pvp_colors[0]
|
200 |
+
advantage_box.text = average_advantage_text(average_advantage)
|
201 |
+
|
202 |
+
|
203 |
+
integrate = Button(width=175, max_width=175, label="Integrate over the whole region!")
|
204 |
+
integrate.align = "center"
|
205 |
+
integrate.on_click(on_integrate_click)
|
206 |
+
|
207 |
+
|
208 |
+
def on_tab_change(attr, old, new):
|
209 |
+
advantage_box.text = "Click within the comparison region to compute the data advantage for a performance level"
|
210 |
+
|
211 |
+
|
212 |
+
advantage_all_figures.on_change('active', on_tab_change)
|
213 |
+
|
214 |
+
advantage_column = column(advantage_all_figures, advantage_box, integrate, sizing_mode="scale_width")
|
215 |
+
|
216 |
+
########################################################################################################################
|
217 |
+
# Null verbalizer diagram
|
218 |
+
########################################################################################################################
|
219 |
+
|
220 |
+
null_tabs = []
|
221 |
+
null_all_figures = Tabs(tabs=null_tabs)
|
222 |
+
|
223 |
+
for task in neutral_tasks:
|
224 |
+
training_points, classifier_performances, pattern_performances = get_data(task)
|
225 |
+
training_points = list(training_points)
|
226 |
+
clf_results = reduct(classifier_performances, "accmax")
|
227 |
+
pvp_results = reduct(pattern_performances, "accmax", task_best_pattern[task], "normal")
|
228 |
+
ctl_results = reduct(pattern_performances, "accmax", task_best_pattern[task], "neutral")
|
229 |
+
null_plot = plot_three_polygons_bokeh(task, training_points, clf_results, pvp_results, ctl_results, clf_colors,
|
230 |
+
pvp_colors, ctl_colors)
|
231 |
+
null_plot.align = "center"
|
232 |
+
null_plot.add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
|
233 |
+
null_tabs.append(Panel(child=null_plot, title=task))
|
234 |
+
|
235 |
+
if task == "MNLI":
|
236 |
+
null_plot = plot_three_polygons_bokeh(task, training_points, clf_results, pvp_results, ctl_results, clf_colors,
|
237 |
+
pvp_colors, ctl_colors, x_log_scale=True)
|
238 |
+
null_plot.align = "center"
|
239 |
+
null_plot.add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
|
240 |
+
null_tabs.append(Panel(child=null_plot, title="MNLI (log scale)"))
|
241 |
+
|
242 |
+
null_all_figures = Tabs(tabs=null_tabs)
|
243 |
+
null_all_figures.align = "center"
|
244 |
+
|
245 |
+
########################################################################################################################
|
246 |
+
# Patterns diagram
|
247 |
+
########################################################################################################################
|
248 |
+
|
249 |
+
pattern_tabs = []
|
250 |
+
pattern_all_figures = Tabs(tabs=pattern_tabs)
|
251 |
+
|
252 |
+
for task in tasks:
|
253 |
+
pattern_plot = pattern_graph(task)
|
254 |
+
pattern_plot.align = "center"
|
255 |
+
pattern_plot.add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
|
256 |
+
pattern_tabs.append(Panel(child=pattern_plot, title=task))
|
257 |
+
|
258 |
+
pattern_all_figures = Tabs(tabs=pattern_tabs)
|
259 |
+
pattern_all_figures.align = "center"
|
260 |
+
|
261 |
+
########################################################################################################################
|
262 |
+
# Add write-up text
|
263 |
+
########################################################################################################################
|
264 |
+
|
265 |
+
main_text_style = {
|
266 |
+
"min-height": "100px",
|
267 |
+
"overflow": "hidden",
|
268 |
+
"display": "block",
|
269 |
+
"margin": "auto",
|
270 |
+
"width": f"{text_width}px",
|
271 |
+
"font-size": "18px",
|
272 |
+
}
|
273 |
+
|
274 |
+
textbox1 = Div(text=text1, style=main_text_style)
|
275 |
+
textbox2 = Div(text=text2, style=main_text_style)
|
276 |
+
textbox3 = Div(text=text3, style=main_text_style)
|
277 |
+
textbox4 = Div(text=text4, style=main_text_style)
|
278 |
+
textbox5 = Div(text=text5, style=main_text_style)
|
279 |
+
textbox1.align = "center"
|
280 |
+
textbox2.align = "center"
|
281 |
+
textbox3.align = "center"
|
282 |
+
textbox4.align = "center"
|
283 |
+
textbox5.align = "center"
|
284 |
+
|
285 |
+
########################################################################################################################
|
286 |
+
# Set up layouts and add to document
|
287 |
+
########################################################################################################################
|
288 |
+
|
289 |
+
main_body = column(textbox1, patternification, textbox2, advantage_column, textbox3, null_all_figures, textbox4, pattern_all_figures,
|
290 |
+
textbox5, sizing_mode="scale_width")
|
291 |
+
main_body.align = "center"
|
292 |
+
|
293 |
+
curdoc().add_root(main_body)
|
294 |
+
curdoc().title = "How many data points is a prompt worth ?"
|
naacl_demo/text.md
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Pre-trained language models, fine-tuned with task-specific heads, are the backbone of applied NLP, and bigger and bigger language models are coming. With this in mind, alternative methods are emerging to compete with the classifier heads used in BERT, UniLM and GPT. In particular, GPT-3 has popularized prompts, natural language inputs designed to steer the pre-trained language model itself into solving the task, rather than a classifier built on top of it.
|
2 |
+
|
3 |
+
Prompts are interesting because they allow a practitioner to give information to the model, although in a very different fashion from standard ML supervision. In our NAACL 2021 paper, we investigate prompt-based fine-tuning, a promising alternative fine-tuning approach, and find that prompts often yield an edge over the standard approach. As we interpret a prompt as additional human-crafted information for the model, we measure that edge in terms of data points and quantify: **how many data points is a prompt worth?**
|
4 |
+
|
5 |
+
## Prompting
|
6 |
+
|
7 |
+
In order to adapt pre-trained language models to a task, the main method is to replace the final token prediction layer of the original model with a randomly initialized linear classifier head. Supervised task data is then used to train the modified model via backpropagation, learning weights for this new head but also modifying weights deeper in the model. In this work, we call this a _head_ model.
|
8 |
+
|
9 |
+
A competing approach is _prompting_: a broad class of methods that attempt to use the initial language model to answer the task by predicting words correlated with the classes instead of a class label. This allows them to perform classification while preserving the language model functionality. For this, _prompts_ are used: input sequences designed to produce the desired answer as textual output.
|
10 |
+
|
11 |
+
Although this may sound abstract, this is a very natural way to reason about text for humans in practice: school exercises, for example, tend to be presented as a text input (for example, an article about Mars) and a question ("Is there life on Mars?") with an expected answer in natural text ("No"<sup>1</sup>) that maps to one of the classes of the task (presumably here, "No" to `False` and "Yes" to `True`). In this paradigm, task-specific data is presented to the model much like a grammar exercise where a student would need to fill in blanks in a fixed way over a list of sequences. Prompting attempts to use the pre-training information contained in the language model explicitly, rather than implicitly through hidden representations that get fed into the linear classifier head.
|
12 |
+
|
13 |
+
|
14 |
+
Here's an example for SuperGLUE task BoolQ, which provides a text <span style="color: #0c593d">passage</span> and a <span style="color: #031154">question</span> and expects a boolean yes-or-no answer. This data is combined with a <span style="color: #910713">**pattern**</span> into a sequence with a single <span style="color: #ba9004">**masked token**</span> that the model must predict. This prediction is turned into a classification prediction with a pre-set *verbalizer*, a mapping between tokens and classes: the model probabilities on this token for *yes* and *no* are compared, with the final prediction being `True` if *yes* dominates and `False` if *no* does.
|
15 |
+
|
16 |
+
![image](mockups/boolqpatterns.png)
|
17 |
+
|
18 |
+
## Fine-tuning
|
19 |
+
|
20 |
+
With this, we have turned our general language model into a task-specific classifier. These language model classifiers based on prompts have been used in very diverse ways:
|
21 |
+
|
22 |
+
- The preserved language modeling functionality from the pre-trained model allows them to perform without additional data, as opposed to linear classifier _heads_ that are initialized from scratch and always start at random performance. A variety of papers have used this for zero-shot classification.
|
23 |
+
- In order to incorporate supervised task data, they can use backpropagation with the usual language modeling cross-entropy loss objective: the verbalizer token associated with the correct class then serves as the correct token prediction. This is a component of PET, and is the objective used by T5 - although T5 uses prefixes to indicate the task rather than describing it with a natural-language prompt.
|
24 |
+
- They can also use _priming_, where the sequence that needs to be filled in is prefixed with a list of correctly-filled examples. No backpropagation is used, and the weights of the language model are never modified: instead, it can attend to correct examples at inference time. This is the method used by GPT3.
|
25 |
+
- Finally, PET uses prompt models to pseudo-label unlabeled data that is then fed to a linear head model.
|
26 |
+
|
27 |
+
In this paper, our goal is to present the fairest comparison possible with head models, so we fine-tune with backpropagation.
|
28 |
+
|
29 |
+
## How many data points is a prompt worth?
|
30 |
+
|
31 |
+
As we have seen, both heads and prompting can be used in a task specific supervised setting. The core difference is that the prompted model is given a specific sentence that roughly describes the task in addition to supervised examples. In some sense, this sentence is supervision as it tells the model about the task, but it is qualitatively a very different form of supervision than is standard in ML. How should we think about this supervision? How do we quantify how “zero-shot” this setup really is?
|
32 |
+
|
33 |
+
We do this by comparing the _head_ and _prompt_ setups on the SuperGLUE tasks and MNLI. For each task, we extract subsets of the dataset of growing size, and repeat fine-tuning on `RoBERTa-large` with both methods on every subset, keeping everything else the same. For fairness, we tune the hyperparameters on the head baseline until they've attained the level of performance of the BERT++ baseline from the SuperGLUE leaderboard, and keep them the same for the _prompt_ model.
|
34 |
+
|
35 |
+
The curves of final performance (on each task's metric) vs dataset size are plotted below for each task <sup>2</sup>. They allow us to contrast the amount of data required to attain a certain level of performance with both setups on a given task. We call this difference the _data advantage_ of a training setup over the other at that level of performance. We call the range of performance that has been attained by both models the _comparison window_. By integrating over it we get the _average data advantage_ of a method over the other on the task. Graphically, that is simply the area between the curves, divided by the height of the comparison window. <sup>3</sup>
|
36 |
+
|
37 |
+
![image](mockups/advantage.png)
|
38 |
+
|
39 |
+
Here's a recapitulative table of the average data advantage of the prompt model over the head model per task, with error bounds obtained by a bootstrapping approach where we hold out one of the 4 head runs and 4 prompt runs (16 combinations total for every data size), and compute the standard deviation of those outcomes. Results are very different from task to task; they even vary for the same task on different dataset, for example for MNLI and RTE, both entailment tasks. However, on every task but WiC <sup>4</sup>, the prompt method has a significant edge. **The additional information provided by the prompt is consistently equivalent to hundreds of data points**.
|
40 |
+
|
41 |
+
| | MNLI | BoolQ | CB | COPA | MultiRC<sup>5</sup> | RTE | WiC | WSC |
|
42 |
+
|----------------|----------|--------|------|---------|----------|--------|---------|---------|
|
43 |
+
| Prompt vs Head | 3506±536 | 752±46 | 90±2 | 288±242 | 384±378 | 282±34 | -424±74 | 281±137 |
|
44 |
+
|
45 |
+
|
46 |
+
## Patterns and verbalizers
|
47 |
+
|
48 |
+
#### Control verbalizers
|
49 |
+
|
50 |
+
Prompting has for now mostly been used as a tool for zero-shot classification, which is a natural use case. However, zero-shot is usually tricky and requires perfectly aligning the prompt and verbalizer. We have already shown that prompting could be applied more generally, including in the full-data regime. In order to contrast the zero-shot and adaptive natures of prompts, we consider a _null verbalizer_, a control with a verbalizer that is completely decorrelated from the task. For tasks that only require filling in one token (thus excluding the more free-form COPA and WSC), we replace the verbalizers, for example, "yes", "no", "maybe", "right" or "wrong", with random first names. This makes the model unusable without training data, much like a head model. We plot the corresponding curves and perform the same advantage analysis below:
|
51 |
+
|
52 |
+
![image](mockups/nullverbalizer.png)
|
53 |
+
|
54 |
+
| | MNLI | BoolQ | CB | MultiRC<sup>4</sup> | RTE | WiC |
|
55 |
+
|----------------|----------|--------|------|----------|--------|---------|
|
56 |
+
| Prompt vs Head | 3506±536 | 752±46 | 90±2 | 384±378 | 282±34 | -424±74 |
|
57 |
+
| Prompt vs Null | 150±252 | 299±81 | 78±2 | 74±56 | 404±68 | -354±166 |
|
58 |
+
| Null vs Head | 3355±612 | 453±90 | 12±1 | 309±320 | -122±62 | -70±160 |
|
59 |
+
|
60 |
+
Results are noisier than for the straight prompt vs head comparison; however, we find that even with a null verbalizer, the language model is able to adapt to the task, generally catching up with the proper prompted model even with a few data points, and generally doing either on par with or better than the head model, showing the inductive bias of the prompt patterns is beneficial even without an informative verbalizer.
|
61 |
+
|
62 |
+
#### Influence of the pattern choice
|
63 |
+
|
64 |
+
Another choice that can make or break zero-shot classification is that of the pattern, and we investigate whether that still holds in our setting. In all of our experiments, we have re-used the pattern choices from PET - two or three quite different formulations per task - and repeated all of our prompt experiments with every pattern available on the task. We plot results below; they show that the choice of prompt does not have a significant influence, being always within random seed variance.
|
65 |
+
|
66 |
+
![image](mockups/prompts.png)
|
67 |
+
|
68 |
+
## Mot de la fin
|
69 |
+
|
70 |
+
In this work, we investigate alternate methods of fine-tuning based on natural language prompts, that aim to use the language modeling ability of pre-trained models explicitly through word predictions, instead of implicitly through linear classifiers based on the model's internal representations. We isolate the problem of fine-tuning prompt-based classifier language models with backpropagation, and find that they generally outperform standard fine-tuned linear classifiers. We estimate this advantage in terms of data point to measure the additional information provided by the human via the prompt, and find that **writing a prompt is consistently worth hundreds of data points**. Furthermore, this advantage holds even with non-informative target tokens and is fairly robust to the choice of prompt.
|
71 |
+
|
72 |
+
For practitioners, we believe that prompt-based fine-tuning should become a standard tool: especially for small- and middle-size task-specific datasets, designing a prompt yourself is a small effort for a sizable data advantage. For researchers, we believe that a lot of questions remain unexplored in this space: Why is the same prompt worth 3500 MNLI data points but only 282 RTE data points? How are prompts related to standard ML supervision? Do they react differently to adversarial or out-of domain examples, since they have some zero-shot behaviour?
|
73 |
+
|
74 |
+
<sup>1</sup>: Or at least not that we know of.
|
75 |
+
|
76 |
+
<sup>2</sup>: A sharp-eyed reader will have noticed that all those curves are monotonous. We've performed 4 runs for every experiment (i.e. every data size of every task for head and prompt models). For clarity, and because fine-tuning can sometimes fail for both methods, resulting in negative outliers, we report for every data size the maximum performance that has been attained at this data size or smaller, which we call the _accumulated maximum_ aggregate. This does not have a big impact on the reported data advantage besides reducing variance, and the graphical interpretation would still hold even with non-monotonous curves.
|
77 |
+
|
78 |
+
<sup>3</sup>: We treat each metric linearly to calculate advantage; alternatively, we could re-parameterize the y axis for each task. This choice does not have a consistent effect for or against prompting. For example, emphasizing gains close to convergence increases prompting advantage on CB and MNLI but decreases it on COPA or BoolQ.
|
79 |
+
|
80 |
+
<sup>4</sup>: where, interestingly, PET had already found prompting to be ineffective
|
81 |
+
|
82 |
+
<sup>5</sup>: The comparison window of MultiRC is too small as the head baseline fails to learn beyond majority class; we use the full region for a lower-bound result.
|
naacl_demo/text.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
text1 = """<h1 id="how-big-should-my-language-model-be">How many data points is a prompt worth?</h1>
|
2 |
+
<img class='center' style='height: 5em; float: right;' src='https://raw.githubusercontent.com/TevenLeScao/transformer-xl/master/pytorch/assets/avatar_logo_joint.png' alt='avatar'>
|
3 |
+
<h4>Published on April 6, 2021.</h4>
|
4 |
+
<h4>Teven Le Scao, researcher at Hugging Face • <a href="https://twitter.com/Fluke_Ellington">@Fluke_Ellington</a> </h4>
|
5 |
+
<p>Pre-trained language models, fine-tuned with task-specific heads, are the backbone of applied NLP, and bigger and bigger language models are coming. With this in mind, alternative methods are emerging to compete with the classifier heads used in <a href="https://arxiv.org/abs/1810.04805">BERT</a>, <a href="https://arxiv.org/abs/1905.03197">UniLM</a> and <a href="https://openai.com/blog/language-unsupervised/">GPT</a>. In particular, GPT-3 has popularized prompts, natural language inputs designed to steer the pre-trained language model itself into solving the task, rather than a classifier built on top of it. </p>
|
6 |
+
<p>Prompts are interesting because they allow a practitioner to give information to the model, although in a very different fashion from standard ML supervision. In our NAACL 2021 <a href="https://arxiv.org/abs/2103.08493">paper</a> with <a href="http://rush-nlp.com/">Sasha Rush</a>, we investigate prompt-based fine-tuning, a promising alternative fine-tuning approach, and find that prompts often yield an edge over the standard approach. As we interpret a prompt as additional human-crafted information for the model, we measure that edge in terms of data points and quantify: <strong>how many data points is a prompt worth?</strong> </p>
|
7 |
+
<h2 id="prompting">Prompting</h2>
|
8 |
+
<p>In order to adapt pre-trained language models to a task, the main method is to replace the final token prediction layer of the original model with a randomly initialized linear classifier head. Supervised task data is then used to train the modified model via backpropagation, learning weights for this new head but also modifying weights deeper in the model. In this work, we call this a <em>head</em> model. </p>
|
9 |
+
<p>A competing approach is <em>prompting</em>: a broad class of methods that attempt to use the initial language model to answer the task by predicting words correlated with the classes instead of a class label. This allows them to perform classification while preserving the language model functionality. For this, <em>prompts</em> are used: input sequences designed to produce the desired answer as textual output. </p>
|
10 |
+
<p id="footnote1back">Although this may sound abstract, this is a very natural way to reason about text for humans in practice: school exercises, for example, tend to be presented as a text input (for example, an article about Mars) and a question ("Is there life on Mars?") with an expected answer in natural text ("No"<a href="#footnote1"><sup>1</sup></a>) that maps to one of the classes of the task (presumably here, "No" to <code>False</code> and "Yes" to <code>True</code>). In this paradigm, task-specific data is presented to the model much like a grammar exercise where a student would need to fill in blanks in a fixed way over a list of sequences. Prompting attempts to use the pre-training information contained in the language model explicitly, rather than implicitly through hidden representations that get fed into the linear classifier head. </p>
|
11 |
+
<p>Here's an example for <a href="https://arxiv.org/abs/1905.00537">SuperGLUE</a> task <a href="https://arxiv.org/abs/1905.10044">BoolQ</a>, which provides a text <span style="color: #0c593d">passage</span> and a <span style="color: #031154">question</span> and expects a boolean yes-or-no answer. This data is combined with a <span style="color: #910713"><strong>pattern</strong></span> into a sequence with a single <span style="color: #ba9004"><strong>masked token</strong></span> that the model must predict. This prediction is turned into a classification prediction with a pre-set <em>verbalizer</em>, a mapping between tokens and classes: the model probabilities on this token for <em>yes</em> and <em>no</em> are compared, with the final prediction being <code>True</code> if <em>yes</em> dominates and <code>False</code> if <em>no</em> does.</p>
|
12 |
+
"""
|
13 |
+
|
14 |
+
text2 = """<h2 id="fine-tuning">Fine-tuning</h2>
|
15 |
+
<p>With this, we have turned our general language model into a task-specific classifier. These language model classifiers based on prompts have been used in very diverse ways: </p>
|
16 |
+
<ul>
|
17 |
+
<li>The preserved language modeling functionality from the pre-trained model allows them to perform without additional data, as opposed to linear classifier <em>heads</em> that are initialized from scratch and always start at random performance. A variety of papers have used this for <a href="https://arxiv.org/abs/1912.10165">zero-shot classification.</a> </li>
|
18 |
+
<li>In order to incorporate supervised task data, they can use backpropagation with the usual language modeling cross-entropy loss objective: the verbalizer token associated with the correct class then serves as the correct token prediction. This is a component of <a href="https://arxiv.org/abs/2001.07676">PET</a>, and is the objective used by <a href="https://arxiv.org/abs/1910.10683">T5</a> - although T5 uses prefixes to indicate the task rather than describing it with a natural-language prompt. </li>
|
19 |
+
<li>They can also use <em>priming</em>, where the sequence that needs to be filled in is prefixed with a list of correctly-filled examples. No backpropagation is used, and the weights of the language model are never modified: instead, it can attend to correct examples at inference time. This is the method used by <a href="https://arxiv.org/abs/2005.14165">GPT-3</a>. </li>
|
20 |
+
<li>Finally, PET uses prompt models to pseudo-label unlabeled data that is then fed to a linear head model. </li>
|
21 |
+
</ul>
|
22 |
+
<p>In this paper, our goal is to present the fairest comparison possible with head models, so we fine-tune with backpropagation.</p>
|
23 |
+
<h2 id="how-many-data-points-is-a-prompt-worth-">How many data points is a prompt worth?</h2>
|
24 |
+
<p>As we have seen, both heads and prompting can be used in a task specific supervised setting. The core difference is that the prompted model is given a specific sentence that roughly describes the task in addition to supervised examples. In some sense, this sentence is supervision as it tells the model about the task, but it is qualitatively a very different form of supervision than is standard in ML. How should we think about this supervision? How do we quantify how “zero-shot” this setup really is? </p>
|
25 |
+
<p>We do this by comparing the <em>head</em> and <em>prompt</em> setups on the SuperGLUE tasks and MNLI. For each task, we extract subsets of the dataset of growing size, and repeat fine-tuning on <a href="https://arxiv.org/abs/1907.11692"><code>RoBERTa-large</code></a> with both methods on every subset, keeping everything else the same. For fairness, we tune the hyperparameters on the head baseline until they've attained the level of performance of the BERT++ baseline from the SuperGLUE leaderboard, and keep them the same for the <em>prompt</em> model. </p>
|
26 |
+
<p id="footnote2back">The curves of final performance (on each task's metric) vs dataset size are plotted below for each task <a href="#footnote2"><sup>2</sup></a>. They allow us to contrast the amount of data required to attain a certain level of performance with both setups on a given task. We call this difference the <em>data advantage</em> of a training setup over the other at that level of performance. We call the range of performance that has been attained by both models the <em>comparison window</em>. By integrating over it we get the <em>average data advantage</em> of a method over the other on the task. Graphically, that is simply the area between the curves, divided by the height of the comparison window. <a href="#footnote3"><sup>3</sup></a></p>
|
27 |
+
"""
|
28 |
+
|
29 |
+
text3 = """<html>
|
30 |
+
<head>
|
31 |
+
<style>
|
32 |
+
table, th, td {
|
33 |
+
border: 1px solid black;
|
34 |
+
border-collapse: collapse;
|
35 |
+
}
|
36 |
+
.styled-table {
|
37 |
+
margin-left: auto;
|
38 |
+
margin-right: auto;
|
39 |
+
}
|
40 |
+
.styled-table {
|
41 |
+
border-collapse: collapse;
|
42 |
+
font-size: 1em;
|
43 |
+
font-family: sans-serif;
|
44 |
+
min-width: 400px;
|
45 |
+
box-shadow: 0 0 20px rgba(0, 0, 0, 0.15);
|
46 |
+
}
|
47 |
+
.styled-table thead tr {
|
48 |
+
background-color: #ffebcd;
|
49 |
+
color: #000000;
|
50 |
+
text-align: left;
|
51 |
+
}
|
52 |
+
.styled-table th,
|
53 |
+
.styled-table td {
|
54 |
+
padding: 6px 8px;
|
55 |
+
font-size: 13px;
|
56 |
+
}
|
57 |
+
.styled-table tbody tr {
|
58 |
+
border-bottom: 1px solid #dddddd;
|
59 |
+
}
|
60 |
+
|
61 |
+
.styled-table tbody tr:nth-of-type(even) {
|
62 |
+
background-color: #f3f3f3;
|
63 |
+
}
|
64 |
+
|
65 |
+
.styled-table tbody tr:last-of-type {
|
66 |
+
border-bottom: 2px solid #29004a;
|
67 |
+
}
|
68 |
+
}
|
69 |
+
|
70 |
+
|
71 |
+
</style>
|
72 |
+
</head>
|
73 |
+
<body>
|
74 |
+
<p id="footnote4back">Here's a recapitulative table of the average data advantage of the prompt model over the head model per task, with error bounds obtained by a bootstrapping approach where we hold out one of the 4 head runs and 4 prompt runs (16 combinations total for every data size), and compute the standard deviation of those outcomes. Results are very different from task to task; they even vary for the same task on different dataset, for example for MNLI and RTE, both entailment tasks. However, on every task but WiC <a href="#footnote4"><sup>4</sup></a>, the prompt method has a significant edge. <strong>The additional information provided by the prompt is consistently equivalent to hundreds of data points</strong>. </p>
|
75 |
+
<table id="footnote5back" class="styled-table">
|
76 |
+
<thead>
|
77 |
+
<tr>
|
78 |
+
<th></th>
|
79 |
+
<th><a href="https://arxiv.org/abs/1704.05426">MNLI</a></th>
|
80 |
+
<th><a href="https://arxiv.org/abs/1905.10044">BoolQ</a></th>
|
81 |
+
<th><a href="https://ojs.ub.uni-konstanz.de/sub/index.php/sub/article/view/601">CB</a></th>
|
82 |
+
<th><a href="https://people.ict.usc.edu/~gordon/publications/AAAI-SPRING11A.PDF">COPA</a></th>
|
83 |
+
<th><a href="https://www.aclweb.org/anthology/N18-1023/">MultiRC</a><sup><a href="#footnote5">5</a></sup></th>
|
84 |
+
<th><a href="https://link.springer.com/chapter/10.1007/978-94-024-0881-2_42">RTE</a></th>
|
85 |
+
<th><a href="https://arxiv.org/abs/1808.09121">WiC</a></th>
|
86 |
+
<th><a href="https://arxiv.org/abs/1808.09121">WSC</a></th>
|
87 |
+
</tr>
|
88 |
+
</thead>
|
89 |
+
<tbody>
|
90 |
+
<tr>
|
91 |
+
<td>Prompt vs Head</td>
|
92 |
+
<td>3506±536</td>
|
93 |
+
<td>752±46</td>
|
94 |
+
<td>90±2</td>
|
95 |
+
<td>288±242</td>
|
96 |
+
<td>384±378</td>
|
97 |
+
<td>282±34</td>
|
98 |
+
<td>-424±74</td>
|
99 |
+
<td>281±137</td>
|
100 |
+
</tr>
|
101 |
+
</tbody>
|
102 |
+
</table>
|
103 |
+
<h2 id="patterns-and-verbalizers">Patterns and verbalizers</h2>
|
104 |
+
<h4 id="control-verbalizers">Control verbalizers</h4>
|
105 |
+
<p>Prompting has for now mostly been used as a tool for zero-shot classification, which is a natural use case. However, zero-shot is usually tricky and requires perfectly aligning the prompt and verbalizer. We have already shown that prompting could be applied more generally, including in the full-data regime. In order to contrast the zero-shot and adaptive natures of prompts, we consider a <em>null verbalizer</em>, a control with a verbalizer that is completely decorrelated from the task. For tasks that only require filling in one token (thus excluding the more free-form COPA and WSC), we replace the verbalizers, for example, "yes", "no", "maybe", "right" or "wrong", with random first names. This makes the model unusable without training data, much like a head model. We plot the corresponding curves and perform the same advantage analysis below:</p>
|
106 |
+
</body>
|
107 |
+
</html>
|
108 |
+
"""
|
109 |
+
|
110 |
+
text4 = """<table id="footnote6back" class="styled-table">
|
111 |
+
<thead>
|
112 |
+
<tr>
|
113 |
+
<th></th>
|
114 |
+
<th>MNLI</th>
|
115 |
+
<th>BoolQ</th>
|
116 |
+
<th>CB</th>
|
117 |
+
<th>MultiRC<a href="#footnote5"><sup>6</sup></a></th>
|
118 |
+
<th>RTE</th>
|
119 |
+
<th>WiC</th>
|
120 |
+
</tr>
|
121 |
+
</thead>
|
122 |
+
<tbody>
|
123 |
+
<tr>
|
124 |
+
<td>Prompt vs Head</td>
|
125 |
+
<td>3506±536</td>
|
126 |
+
<td>752±46</td>
|
127 |
+
<td>90±2</td>
|
128 |
+
<td>384±378</td>
|
129 |
+
<td>282±34</td>
|
130 |
+
<td>-424±74</td>
|
131 |
+
</tr>
|
132 |
+
<tr>
|
133 |
+
<td>Prompt vs Null</td>
|
134 |
+
<td>150±252</td>
|
135 |
+
<td>299±81</td>
|
136 |
+
<td>78±2</td>
|
137 |
+
<td>74±56</td>
|
138 |
+
<td>404±68</td>
|
139 |
+
<td>-354±166</td>
|
140 |
+
</tr>
|
141 |
+
<tr>
|
142 |
+
<td>Null vs Head</td>
|
143 |
+
<td>3355±612</td>
|
144 |
+
<td>453±90</td>
|
145 |
+
<td>12±1</td>
|
146 |
+
<td>309±320</td>
|
147 |
+
<td>-122±62</td>
|
148 |
+
<td>-70±160</td>
|
149 |
+
</tr>
|
150 |
+
</tbody>
|
151 |
+
</table>
|
152 |
+
<p>Results are noisier than for the straight prompt vs head comparison; however, we find that even with a null verbalizer, the language model is able to adapt to the task, generally catching up with the proper prompted model even with a few data points, and generally doing either on par with or better than the head model, showing the inductive bias of the prompt patterns is beneficial even without an informative verbalizer. </p>
|
153 |
+
<h4 id="influence-of-the-pattern-choice">Influence of the pattern choice</h4>
|
154 |
+
<p>Another choice that can make or break zero-shot classification is that of the pattern, and we investigate whether that still holds in our setting. In all of our experiments, we have re-used the pattern choices from PET - two or three quite different formulations per task - and repeated all of our prompt experiments with every pattern available on the task. We plot the median, maximum and minimum performance over the 4 runs for each pattern below; they show that the choice of prompt does not generally have a significant influence, with only the few-shot settings of BoolQ and WiC seeing a pattern consistently above the others. </p>
|
155 |
+
"""
|
156 |
+
|
157 |
+
text5 = """<h2 id="mot-de-la-fin">Mot de la fin</h2>
|
158 |
+
<p>In this work, we investigate alternate methods of fine-tuning based on natural language prompts, that aim to use the language modeling ability of pre-trained models explicitly through word predictions, instead of implicitly through linear classifiers based on the model's internal representations. We isolate the problem of fine-tuning prompt-based classifier language models with backpropagation, and find that they generally outperform standard fine-tuned linear classifiers. We estimate this advantage in terms of data point to measure the additional information provided by the human via the prompt, and find that <strong>writing a prompt is consistently worth hundreds of data points</strong>. Furthermore, this advantage holds even with non-informative target tokens and is fairly robust to the choice of prompt. </p>
|
159 |
+
<p>For practitioners, we believe that prompt-based fine-tuning should become a standard tool: especially for small- and middle-size task-specific datasets, designing a prompt yourself is a small effort for a sizable data advantage. For researchers, we believe that a lot of questions remain unexplored in this space: Why is the same prompt worth 3500 MNLI data points but only 282 RTE data points? How are prompts related to standard ML supervision? Do they react differently to adversarial or out-of domain examples, since they have some zero-shot behaviour?</p>
|
160 |
+
<p id="footnote1"><sup><a href="#footnote1back">1</a></sup>: Or at least not that we know of.</p>
|
161 |
+
<p id="footnote2"><sup><a href="#footnote2back">2</a></sup>: A sharp-eyed reader will have noticed that all those curves are monotonous. We've performed 4 runs for every experiment (i.e. every data size of every task for head and prompt models). For clarity, and because fine-tuning can sometimes fail for both methods, resulting in negative outliers, we report for every data size the maximum performance that has been attained at this data size or smaller, which we call the <em>accumulated maximum</em> aggregate. This does not have a big impact on the reported data advantage besides reducing variance, and the graphical interpretation would still hold even with non-monotonous curves. </p>
|
162 |
+
<p id="footnote3"><sup><a href="#footnote2back">3</a></sup>: We treat each metric linearly to calculate advantage; alternatively, we could re-parameterize the y axis for each task. This choice does not have a consistent effect for or against prompting. For example, emphasizing gains close to convergence increases prompting advantage on CB and MNLI but decreases it on COPA or BoolQ. </p>
|
163 |
+
<p id="footnote4"><sup><a href="#footnote4back">4</a></sup>: where, interestingly, PET had already found prompting to be ineffective</p>
|
164 |
+
<p id="footnote5"><sup><a href="#footnote5back">5</a> <a href="#footnote6back">6</a></sup>: The comparison window of MultiRC is too small as the head baseline fails to learn beyond majority class; we use the full region for a lower-bound result.</p>
|
165 |
+
"""
|
166 |
+
|
167 |
+
initial_passage = "In informal games, it is customary to announce 'check' when making a move that puts the opponent's king in check. In formal competitions, however, check is rarely announced."
|
168 |
+
|
169 |
+
initial_question = "do you always have to say check in chess?"
|
requirements.txt
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
bokeh==2.3.0
|
2 |
+
cycler==0.10.0
|
3 |
+
Jinja2==2.11.2
|
4 |
+
kiwisolver==1.3.1
|
5 |
+
MarkupSafe==1.1.1
|
6 |
+
matplotlib==3.4.1
|
7 |
+
numpy==1.18.4
|
8 |
+
packaging==20.4
|
9 |
+
pandas==1.0.3
|
10 |
+
Pillow==7.1.2
|
11 |
+
pyparsing==2.4.7
|
12 |
+
python-dateutil==2.8.1
|
13 |
+
pytz==2020.1
|
14 |
+
PyYAML==5.3.1
|
15 |
+
randomcolor==0.4.4.5
|
16 |
+
scipy==1.4.1
|
17 |
+
seaborn==0.11.1
|
18 |
+
Shapely==1.7.1
|
19 |
+
six==1.15.0
|
20 |
+
tornado==6.0.4
|
21 |
+
typing-extensions==3.7.4.2
|
22 |
+
virtualenv-clone==0.5.4
|