Spaces: Runtime error

+app

- app.py +174 -0
- common.py +20 -0
- example_sets/sst2/demos.txt +6 -0
- example_sets/sst2/rawdiff.pkl +0 -0
- example_sets/sst2/sample.pkl +10 -0
- models/__init__.py +1 -0
- models/huggingface.py +41 -0
- models/meta_optimizer.py +78 -0
- tasks/__init__.py +11 -0
- tasks/base.py +58 -0
- tasks/loader.py +96 -0
- tasks/sst2.py +43 -0
- utils/__init__.py +0 -0
app.py
ADDED
@@ -0,0 +1,174 @@
import json
from pathlib import Path

import gradio as gr

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader

from common import setup_cpu
from models import build_tokenizer, build_model
from models.meta_optimizer import AttnOptimWrapper
from tasks import load_task
from tasks.loader import TokenizedForMCRightPad

DISPLAY_MAPPING = {
    "sst2": {"positive": "Pos", "negative": "Neg"},
    "trec": {},
}


@torch.no_grad()
def do_infer_probs(model, exemplar_attn_kv, exemplar_attn_mask, batched_choices_input):
    batched_choices_logprobs = []
    for batched_one_choice_input in batched_choices_input:
        batch_input_ids, batch_attention_mask, batch_choice_start, batch_choice_end = batched_one_choice_input
        bs = len(batch_input_ids)

        merged_attn_mask = torch.cat((exemplar_attn_mask.expand(bs, -1), batch_attention_mask), dim=1)
        # [B, #Heads, Length, Hidden]
        expand_exemplar_attn_kv = [[layer_k.expand((bs, -1, -1, -1)), layer_v.expand((bs, -1, -1, -1))] for layer_k, layer_v in exemplar_attn_kv]

        batched_logits = model(
            input_ids=batch_input_ids,  # [B, L']
            attention_mask=merged_attn_mask,  # [B, L + L']
            past_key_values=expand_exemplar_attn_kv,  # num_layers * 2 * [B, num_heads, L, H]
        ).logits
        batched_output = F.log_softmax(batched_logits, dim=-1)  # [B, L', Vocab]

        batched_one_choice_logprobs = []
        for input_ids, choice_start, choice_end, lm_logprobs in zip(batch_input_ids, batch_choice_start, batch_choice_end, batched_output):
            choice_tokens = input_ids[choice_start:choice_end].unsqueeze(1)  # [L, 1]
            choice_logprobs = lm_logprobs[choice_start - 1 : choice_end - 1]  # [L, Vocab]

            extracted = torch.gather(choice_logprobs, -1, choice_tokens).squeeze(-1)

            choice_length = choice_end - choice_start
            lm_log_p = torch.sum(extracted).item()
            norm_lm_log_p = (lm_log_p / choice_length).item()

            choice_lm_info = {"lm_log_p": lm_log_p, "norm_lm_log_p": norm_lm_log_p}
            batched_one_choice_logprobs.append(choice_lm_info)
        batched_choices_logprobs.append(batched_one_choice_logprobs)
    return batched_choices_logprobs


@torch.no_grad()
def process_once(dataset_name, exemplar_str, forward_steps, raw_data):
    model_name, model_size = "opt", "125m"
    step_size, momentum = 0.01, 0.9

    setup_cpu(seed=seed)
    TaskHandler = load_task(dataset_name)
    task_agent = TaskHandler(prompt_version)

    tokenizer = build_tokenizer(model_name, model_size, padding_side="right")
    model = build_model(model_name, model_size, False)
    torch.autograd.set_grad_enabled(False)

    processed_data = task_agent.dataset_preprocess(raw_data)
    dataset = TokenizedForMCRightPad(processed_data, tokenizer, task_agent.multiple_choice_promptify)

    exemplar_input_ids, exemplar_attn_mask = dataset.tokenize_demonstration(exemplar_str)
    loader = DataLoader(dataset, shuffle=False, drop_last=False, batch_size=1)
    meta_optim = AttnOptimWrapper(model, model_name, step_size=step_size, momentum=momentum)
    meta_optim.init()

    for _ in range(forward_steps):
        exemplar_kv = meta_optim.step(exemplar_input_ids)

    generated_info = []  # question * [choice0_prob, choice1_prob]
    for batch_input in loader:
        batch_output = do_infer_probs(model, exemplar_kv, exemplar_attn_mask.unsqueeze(0), batch_input)  # [batch_of_choice0, batch_of_choice1, ...]
        zipped_logprobs = list(zip(*batch_output))  # batch * (choice0, choice1, ...)
        generated_info.extend(zipped_logprobs)

    all_predicted = []
    for idx, (data, choice_info) in enumerate(zip(processed_data, generated_info)):
        merged_choice_info = task_agent.merge_choice_info(choice_info)
        merged_predictions_idx = task_agent.choice_info_to_predictions(merged_choice_info)["lm_log_p"]
        predicted = task_agent.CHOICES[merged_predictions_idx]
        ground_truth = task_agent.CHOICES[data["answer_idx"]]
        res = f"{DISPLAY_MAPPING[dataset_name][predicted]}{'✅' if predicted == ground_truth else '❌'}"
        all_predicted.append(res)
    return all_predicted


def transpose(l):
    return list(map(list, zip(*l)))


def button_pressed(prev_state):
    dataset_name = prev_state["dataset_name"]
    exemplar_str = prev_state["exemplar_str"]
    forward_steps = prev_state["step"] + 2
    raw_data = prev_state["raw_data"]
    prev_table_data = prev_state["table_data"]

    current_output = process_once(dataset_name, exemplar_str, forward_steps, raw_data)

    t_prev = transpose(prev_table_data)
    t_prev.append([f"T={forward_steps}"] + current_output)
    updated_table_data = transpose(t_prev)

    ret = [
        {
            "dataset_name": dataset_name,
            "exemplar_str": exemplar_str,
            "raw_data": raw_data,
            "step": forward_steps,
            "table_data": updated_table_data,
        },
        f"Step + 2, Now: {forward_steps}",
        updated_table_data,
    ]
    return ret


if __name__ == "__main__":
    dataset_name = "sst2"
    seed = 0
    prompt_version = "default"
    kv_iter = 10

    print(f"Dataset: {dataset_name}")
    task_root = Path("example_sets").joinpath(dataset_name)

    with task_root.joinpath("demos.txt").open("r") as f:
        demos = f.read()
    with task_root.joinpath("sample.pkl").open("r") as f:
        data = json.load(f)
    raw_data = [data[str(i)] for i in range(len(data))]

    css = """ #the-table > div > div > div > table > thead {display: none}"""

    title = "🤔 Iterative Forward Tuning Boosts In-context Learning in Language Models"
    demo = gr.Blocks(css=css, title="🤔Deep-Thinking")
    with demo:
        gr.Markdown(f"<h1 style='text-align: center; margin-bottom: 1rem'>{title}</h1>")
        with gr.Tab("SST-2"):
            mapping = ["negative", "positive"]

            init_columns = [[e["sentence"], f"*{DISPLAY_MAPPING['sst2'][mapping[e['label']]]}*"] for e in raw_data]
            state = gr.State(
                {
                    "dataset_name": "sst2",
                    "exemplar_str": demos,
                    "raw_data": raw_data,
                    "step": 0,
                    "table_data": [["**Test Input**", "**Golden**"], *init_columns],
                }
            )

            prompt = gr.Textbox(label="Demonstrations (Prompt template formatted)", value=demos)
            big_table = gr.DataFrame(
                value=[["**Test Input**", "**Golden**"], *init_columns],
                elem_id="the-table",
                datatype=["markdown"] * 50,
                headers=None,
            )
            step_button = gr.Button("Step + 2, Now: 0")
            step_button.click(button_pressed, inputs=[state], outputs=[state, step_button, big_table])

    demo.launch(server_name="0.0.0.0")
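A minimal sketch (plain Python, no model required) of how button_pressed grows the results table: each click transposes the table into columns, appends one column of predictions for the new step count, and transposes back. The cell values below are invented for illustration; in the app they come from process_once.

def transpose(l):
    return list(map(list, zip(*l)))

table = [
    ["**Test Input**", "**Golden**"],
    ["it all adds up to good fun .", "*Pos*"],
]
cols = transpose(table)
cols.append(["T=2", "Pos✅"])  # hypothetical predictions after forward_steps = 2
table = transpose(cols)
# table == [["**Test Input**", "**Golden**", "T=2"],
#           ["it all adds up to good fun .", "*Pos*", "Pos✅"]]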
common.py
ADDED
@@ -0,0 +1,20 @@
import os
import random

import numpy as np
import torch

from tasks import task_mapper


def setup_plain_seed(SEED):
    os.environ["PYTHONHASHSEED"] = str(SEED)
    random.seed(SEED)
    np.random.seed(SEED)


def setup_cpu(seed):
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    setup_plain_seed(seed)
    torch.manual_seed(seed)
    torch.random.manual_seed(seed)
example_sets/sst2/demos.txt
ADDED
@@ -0,0 +1,6 @@
Review: and mainly unfunny
Sentiment: negative

Review: filmmakers david weissman and bill weber benefit enormously from the cockettes ' camera craziness -- not only did they film performances , but
Sentiment: positive

example_sets/sst2/rawdiff.pkl
ADDED
The diff for this file is too large to render.
See raw diff
example_sets/sst2/sample.pkl
ADDED
@@ -0,0 +1,10 @@
{
  "0": { "sentence": "the cold turkey would 've been a far better title . ", "label": 0, "idx": 57 },
  "1": { "sentence": "it 's a cookie-cutter movie , a cut-and-paste job . ", "label": 0, "idx": 28 },
  "2": { "sentence": "a solid film ... but more conscientious than it is truly stirring . ", "label": 1, "idx": 143 },
  "3": { "sentence": "it 's slow -- very , very slow . ", "label": 0, "idx": 4 },
  "4": { "sentence": "filmmakers who can deftly change moods are treasures and even marvels . ", "label": 1, "idx": 679 },
  "5": { "sentence": "it all adds up to good fun . ", "label": 1, "idx": 393 },
  "6": { "sentence": "i am sorry that i was unable to get the full brunt of the comedy . ", "label": 0, "idx": 423 },
  "7": { "sentence": "hilariously inept and ridiculous . ", "label": 1, "idx": 112 }
}
models/__init__.py
ADDED
@@ -0,0 +1 @@
from .huggingface import build_model_signature, build_tokenizer, build_model
models/huggingface.py
ADDED
@@ -0,0 +1,41 @@
from transformers import AutoTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM


def build_model_signature(model_type, model_size):
    if model_type == "opt":
        # ["125m", "350m", "1.3b", "2.7b", "6.7b", "13b", "30b", "66b"]
        return f"facebook/opt-{model_size}"
    if model_type == "gpt2":
        # ["sm", "medium", "large", "xl"]
        if model_size == "sm":
            return "gpt2"
        return f"gpt2-{model_size}"
    if model_type == "e-gpt":
        # ["neo-125M", "neo-1.3B", "neo-2.7B", "j-6B", "neox-20b"]
        return f"EleutherAI/gpt-{model_size}"
    if model_type == "bloom":
        # ["560m", "1b1", "1b7", "3b", "7b1"]
        return f"bigscience/bloom-{model_size}"


def build_tokenizer(model_type, model_size, padding_side="left", use_fast=False):
    sign = build_model_signature(model_type, model_size)
    if not use_fast:
        tok = AutoTokenizer.from_pretrained(sign, padding_side=padding_side)
    else:
        tok = PreTrainedTokenizerFast.from_pretrained(sign, padding_side=padding_side)
    if model_type in ["gpt2", "e-gpt"]:
        tok.pad_token_id = tok.eos_token_id
        tok.pad_token = tok.eos_token
    return tok


def build_model(model_type, model_size, in_8bit):
    sign = build_model_signature(model_type, model_size)
    model = AutoModelForCausalLM.from_pretrained(
        sign,
        device_map="auto",
        load_in_8bit=in_8bit,
    )
    model.eval()
    return model
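A usage sketch of the builders above, assuming the facebook/opt-125m checkpoint can be downloaded and that accelerate is installed for device_map="auto":

from models import build_tokenizer, build_model

tok = build_tokenizer("opt", "125m", padding_side="right")
model = build_model("opt", "125m", in_8bit=False)
enc = tok("Review: it all adds up to good fun .\nSentiment:", return_tensors="pt")
logits = model(**enc).logits  # shape [1, seq_len, vocab_size]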
models/meta_optimizer.py
ADDED
@@ -0,0 +1,78 @@
import torch


class MomentumOptim:
    def __init__(self, step_size=0.01, momentum=0.9):
        self.step_size = step_size
        self.momentum = momentum
        self.m = None  # velocity

    def init(self):
        self.m = None

    def upd_m(self, old_m, g):
        return g + self.momentum * old_m

    def upd(self, old_x, m):
        return old_x + self.step_size * m

    def __call__(self, old_xs, new_xs):
        pseudo_gs = [new_x - old_x for old_x, new_x in zip(old_xs, new_xs)]

        if not self.m:
            self.m = pseudo_gs
        else:
            self.m = [self.upd_m(old_m, g) for old_m, g in zip(self.m, pseudo_gs)]

        updated_kv = [self.upd(old_x, m) for old_x, m in zip(old_xs, self.m)]
        return updated_kv


class AttnOptimWrapper:
    def __init__(self, llm, model_type, optimizer="momentum", **optimizer_args):
        self.model = llm
        self.kv = None
        self.model_type = model_type

        if optimizer == "momentum":
            self.optim_k = MomentumOptim(**optimizer_args)
            self.optim_v = MomentumOptim(**optimizer_args)
        else:
            raise ValueError(f"Unsupported optimizer: {optimizer}")

    def init(self):
        self.optim_k.init()
        self.optim_v.init()

    @torch.no_grad()
    def step(self, ctx_ids):
        L = len(ctx_ids)

        ctx_ids = ctx_ids.unsqueeze(0)  # [1, L]
        mask = torch.ones_like(ctx_ids)
        if self.kv is not None:
            mask = mask.repeat(1, 2)  # [1, 2*L]

        next_kv = self.model(
            input_ids=ctx_ids,
            attention_mask=mask,
            past_key_values=self.kv,
            use_cache=True,
        ).past_key_values  # kv @ (old_ctx + new_ctx)

        cur_kv = []
        for layer_k, layer_v in next_kv:
            # [B, num_head, 2*L, head_hidden]
            cur_kv.append([layer_k[:, :, -L:, :], layer_v[:, :, -L:, :]])  # kv @ (new_ctx)

        if not self.kv:
            self.kv = cur_kv
        else:
            old_ks, old_vs = zip(*self.kv)
            cur_ks, cur_vs = zip(*cur_kv)

            upd_ks = self.optim_k(old_ks, cur_ks)
            upd_vs = self.optim_v(old_vs, cur_vs)
            self.kv = list(zip(upd_ks, upd_vs))

        return self.kv
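A tiny numeric check of the MomentumOptim update above, detached from any model: the pseudo-gradient is simply (new - old), and on the first call the state moves by step_size times that difference.

import torch
from models.meta_optimizer import MomentumOptim

opt = MomentumOptim(step_size=0.1, momentum=0.9)
opt.init()
old = [torch.tensor([1.0, 2.0])]
new = [torch.tensor([2.0, 4.0])]   # pretend this came from a fresh forward pass
print(opt(old, new))               # [tensor([1.1000, 2.2000])] = old + 0.1 * (new - old)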
tasks/__init__.py
ADDED
@@ -0,0 +1,11 @@
from .sst2 import SST2ProbInferenceForMC


task_mapper = {"sst2": SST2ProbInferenceForMC}


def load_task(name):
    if name not in task_mapper.keys():
        raise ValueError(f"Unrecognized dataset `{name}`")

    return task_mapper[name]
tasks/base.py
ADDED
@@ -0,0 +1,58 @@
import numpy as np


class BaseProbInference:
    def __init__(self, prompt_version):
        if prompt_version == "default":
            self.prompt_version = self.default_prompt_version()
        else:
            self.prompt_version = prompt_version

        self.raw_data_result = None
        self.raw_data_sample = None
        self.raw_data_dev = None

        self.can_be_stratified = False
        self.CHOICES = None
        self.num_base_shot = 1

    def default_prompt_version(self):
        raise NotImplementedError

    def dataset_signature(self):
        # {
        #     "result": (dataset_name, subset, split),  # which produces the final result
        #     "sample": (dataset_name, subset, split),  # from which we sample ICL few-shot examples
        # }
        raise NotImplementedError

    def dataset_part(self, part):
        return self.dataset_signature()[part]

    def dataset_preprocess(self, raw_data):
        raise NotImplementedError

    def handcrafted_exemplars(self):
        raise NotImplementedError

    def exemplar_seperator(self):
        raise NotImplementedError

    def multiple_choice_promptify(self, query, choice):
        raise NotImplementedError

    @staticmethod
    def merge_choice_info(choice_info):
        merged = {}
        for k in ["lm_log_p", "norm_lm_log_p"]:
            one_metric_merged = []
            for info in choice_info:
                one_metric_merged.append(info[k])
            merged[k] = one_metric_merged
        return merged

    @staticmethod
    def choice_info_to_predictions(info):
        lm_log_p_idx = int(np.argmax(info["lm_log_p"]))
        norm_lm_log_p_idx = int(np.argmax(info["norm_lm_log_p"]))
        return {"lm_log_p": lm_log_p_idx, "norm_lm_log_p": norm_lm_log_p_idx}
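A small worked example of the two static helpers above, with made-up log-probabilities for a two-choice task:

from tasks.base import BaseProbInference

choice_info = [
    {"lm_log_p": -4.2, "norm_lm_log_p": -2.10},   # choice 0, e.g. "negative"
    {"lm_log_p": -3.1, "norm_lm_log_p": -1.55},   # choice 1, e.g. "positive"
]
merged = BaseProbInference.merge_choice_info(choice_info)
# {"lm_log_p": [-4.2, -3.1], "norm_lm_log_p": [-2.1, -1.55]}
print(BaseProbInference.choice_info_to_predictions(merged))
# {"lm_log_p": 1, "norm_lm_log_p": 1}  -> argmax picks choice 1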
tasks/loader.py
ADDED
@@ -0,0 +1,96 @@
import torch
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer


class TokenizedForMCRightPad(Dataset):
    def __init__(self, data, tok: PreTrainedTokenizer, prompt_fn):
        # data: [query: str, choices: list(str)]
        self.tok = tok
        self.prompt_fn = prompt_fn
        self.max_length = self._find_max_length(data)
        self.data = self._build_mc_data(data)

    def _find_max_length(self, data):
        max_len = 0

        def tok_len(t):
            return len(self.tok.encode(t))

        for ex in data:
            query = ex["query"]
            len_choices = [tok_len(self.prompt_fn(query, c)[1]) for c in ex["choices"]]
            max_len = max(max_len, *len_choices)

        return max_len

    def _build_mc_data(self, data):
        processed = []
        num_choices = set(len(e["choices"]) for e in data)
        if not len(num_choices) == 1:
            raise ValueError(f"Queries have different number of choices, which is not supported! #choices: {num_choices}")
        for ex in data:
            query, choices = ex["query"], ex["choices"]
            processed_input = [self.prompt_fn(query, choice) for choice in choices]
            processed_input = [self.tokenize(t_query, t_full) for t_query, t_full in processed_input]
            processed.append(processed_input)

        return processed

    def tokenize_demonstration(self, demonstration):
        e = self.tok(demonstration)
        return torch.LongTensor(e["input_ids"]), torch.LongTensor(e["attention_mask"])  # no padding

    def tokenize(self, only_query, full_text):
        tok_only_query = self.tok(only_query, add_special_tokens=False)
        tok_full_no_padding = self.tok(full_text, add_special_tokens=False)
        tok_full = self.tok(
            full_text,
            padding="max_length",
            max_length=self.max_length,
            add_special_tokens=False,
        )  # <pad> is not a special token

        len_full = len(tok_full_no_padding.input_ids)
        len_query = len(tok_only_query.input_ids)
        e = {
            "input_ids": tok_full.input_ids,
            "attention_mask": tok_full.attention_mask,
            "choice_start": len_query,
            "choice_end": len_full,
        }

        return e

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        def _get_one_item(e):
            return torch.LongTensor(e["input_ids"]), torch.LongTensor(e["attention_mask"]), e["choice_start"], e["choice_end"]

        es = self.data[idx]
        # num_choices * (input_ids, attn, start_idx, end_idx)
        # input_ids, attn: [B, L]
        # start_idx, end_idx: [B, ]
        return [_get_one_item(e) for e in es]
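A sketch of what choice_start and choice_end delimit, assuming the facebook/opt-125m tokenizer is available through models.build_tokenizer; decoding the span of the second choice should recover the " positive" completion appended after "Sentiment:".

from models import build_tokenizer
from tasks import load_task
from tasks.loader import TokenizedForMCRightPad

tok = build_tokenizer("opt", "125m", padding_side="right")
task = load_task("sst2")("default")
data = task.dataset_preprocess([{"sentence": "it all adds up to good fun .", "label": 1}])
ds = TokenizedForMCRightPad(data, tok, task.multiple_choice_promptify)
ids, attn, start, end = ds[0][1]     # second choice of the first example
print(tok.decode(ids[start:end]))    # expected: " positive"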
tasks/sst2.py
ADDED
@@ -0,0 +1,43 @@
from tasks.base import BaseProbInference


class SST2ProbInferenceForMC(BaseProbInference):
    def __init__(self, prompt_version):
        super().__init__(prompt_version)

        self.CHOICES = ["negative", "positive"]
        self.can_be_stratified = True
        self.num_base_shot = len(self.CHOICES)

    def default_prompt_version(self):
        return "sp"

    def dataset_signature(self):
        return {
            "result": ("glue", "sst2", "validation"),
            "sample": ("glue", "sst2", "train"),
        }

    def dataset_preprocess(self, raw_data):
        data = []
        for e in raw_data:
            data.append({"query": e["sentence"].strip(), "choices": self.CHOICES, "answer_idx": e["label"]})
        return data

    def handcrafted_exemplars(self):
        raise NotImplementedError

    def exemplar_seperator(self):
        if self.prompt_version.startswith("sp"):
            return "\n\n"
        else:
            raise ValueError(f"SST2: Not supported prompt_version: {self.prompt_version}")

    def multiple_choice_promptify(self, query, choice):
        if self.prompt_version.startswith("sp"):
            with_query = f"Review: {query}\nSentiment:"
            with_query_and_choice = f"{with_query} {choice}"
        else:
            raise ValueError(f"SST2: Not supported prompt_version: {self.prompt_version}")
        return with_query, with_query_and_choice
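The prompt template above is pure string formatting, so it can be checked without a model:

from tasks.sst2 import SST2ProbInferenceForMC

task = SST2ProbInferenceForMC("default")   # resolves to prompt_version "sp"
q, full = task.multiple_choice_promptify("it all adds up to good fun .", "positive")
print(q)                                   # Review: it all adds up to good fun .\nSentiment:
print(full)                                # same query, ending in "Sentiment: positive"
print(repr(task.exemplar_seperator()))     # '\n\n', the separator used in demos.txt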
utils/__init__.py
ADDED
File without changes