jx-yang committed on
Commit
9d21d47
1 Parent(s): d235d9c

+app

app.py ADDED
@@ -0,0 +1,174 @@
+ import json
+ from pathlib import Path
+
+ import gradio as gr
+
+ import torch
+ from torch.nn import functional as F
+ from torch.utils.data import DataLoader
+
+ from common import setup_cpu
+ from models import build_tokenizer, build_model
+ from models.meta_optimizer import AttnOptimWrapper
+ from tasks import load_task
+ from tasks.loader import TokenizedForMCRightPad
+
+ DISPLAY_MAPPING = {
+     "sst2": {"positive": "Pos", "negative": "Neg"},
+     "trec": {},
+ }
+
+
+ @torch.no_grad()
+ def do_infer_probs(model, exemplar_attn_kv, exemplar_attn_mask, batched_choices_input):
+     batched_choices_logprobs = []
+     for batched_one_choice_input in batched_choices_input:
+         batch_input_ids, batch_attention_mask, batch_choice_start, batch_choice_end = batched_one_choice_input
+         bs = len(batch_input_ids)
+
+         merged_attn_mask = torch.cat((exemplar_attn_mask.expand(bs, -1), batch_attention_mask), dim=1)
+         # [B, #Heads, Length, Hidden]
+         expand_exemplar_attn_kv = [[layer_k.expand((bs, -1, -1, -1)), layer_v.expand((bs, -1, -1, -1))] for layer_k, layer_v in exemplar_attn_kv]
+
+         batched_logits = model(
+             input_ids=batch_input_ids,  # [B, L']
+             attention_mask=merged_attn_mask,  # [B, L + L']
+             past_key_values=expand_exemplar_attn_kv,  # num_layers * 2 * [B, num_heads, L, H]
+         ).logits
+         batched_output = F.log_softmax(batched_logits, dim=-1)  # [B, L', Vocab]
+
+         batched_one_choice_logprobs = []
+         for input_ids, choice_start, choice_end, lm_logprobs in zip(batch_input_ids, batch_choice_start, batch_choice_end, batched_output):
+             choice_tokens = input_ids[choice_start:choice_end].unsqueeze(1)  # [L, 1]
+             choice_logprobs = lm_logprobs[choice_start - 1 : choice_end - 1]  # [L, Vocab]
+
+             extracted = torch.gather(choice_logprobs, -1, choice_tokens).squeeze(-1)
+
+             choice_length = choice_end - choice_start
+             lm_log_p = torch.sum(extracted).item()
+             norm_lm_log_p = (lm_log_p / choice_length).item()
+
+             choice_lm_info = {"lm_log_p": lm_log_p, "norm_lm_log_p": norm_lm_log_p}
+             batched_one_choice_logprobs.append(choice_lm_info)
+         batched_choices_logprobs.append(batched_one_choice_logprobs)
+     return batched_choices_logprobs
+
+
+ @torch.no_grad()
+ def process_once(dataset_name, exemplar_str, forward_steps, raw_data):
+     model_name, model_size = "opt", "125m"
+     step_size, momentum = 0.01, 0.9
+
+     setup_cpu(seed=seed)  # `seed` and `prompt_version` are module-level settings defined in the __main__ block below
+     TaskHandler = load_task(dataset_name)
+     task_agent = TaskHandler(prompt_version)
+
+     tokenizer = build_tokenizer(model_name, model_size, padding_side="right")
+     model = build_model(model_name, model_size, False)
+     torch.autograd.set_grad_enabled(False)
+
+     processed_data = task_agent.dataset_preprocess(raw_data)
+     dataset = TokenizedForMCRightPad(processed_data, tokenizer, task_agent.multiple_choice_promptify)
+
+     exemplar_input_ids, exemplar_attn_mask = dataset.tokenize_demonstration(exemplar_str)
+     loader = DataLoader(dataset, shuffle=False, drop_last=False, batch_size=1)
+     meta_optim = AttnOptimWrapper(model, model_name, step_size=step_size, momentum=momentum)
+     meta_optim.init()
+
+     for _ in range(forward_steps):
+         exemplar_kv = meta_optim.step(exemplar_input_ids)
+
+     generated_info = []  # question * [choice0_prob, choice1_prob]
+     for batch_input in loader:
+         batch_output = do_infer_probs(model, exemplar_kv, exemplar_attn_mask.unsqueeze(0), batch_input)  # [batch_of_choice0, batch_of_choice1, ...]
+         zipped_logprobs = list(zip(*batch_output))  # batch * (choice0, choice1, ...)
+         generated_info.extend(zipped_logprobs)
+
+     all_predicted = []
+     for idx, (data, choice_info) in enumerate(zip(processed_data, generated_info)):
+         merged_choice_info = task_agent.merge_choice_info(choice_info)
+         merged_predictions_idx = task_agent.choice_info_to_predictions(merged_choice_info)["lm_log_p"]
+         predicted = task_agent.CHOICES[merged_predictions_idx]
+         ground_truth = task_agent.CHOICES[data["answer_idx"]]
+         res = f"{DISPLAY_MAPPING[dataset_name][predicted]}{'✅' if predicted == ground_truth else '❌'}"
+         all_predicted.append(res)
+     return all_predicted
+
+
+ def transpose(l):
+     return list(map(list, zip(*l)))
+
+
+ def button_pressed(prev_state):
+     dataset_name = prev_state["dataset_name"]
+     exemplar_str = prev_state["exemplar_str"]
+     forward_steps = prev_state["step"] + 2
+     raw_data = prev_state["raw_data"]
+     prev_table_data = prev_state["table_data"]
+
+     current_output = process_once(dataset_name, exemplar_str, forward_steps, raw_data)
+
+     t_prev = transpose(prev_table_data)
+     t_prev.append([f"T={forward_steps}"] + current_output)
+     updated_table_data = transpose(t_prev)
+
+     ret = [
+         {
+             "dataset_name": dataset_name,
+             "exemplar_str": exemplar_str,
+             "raw_data": raw_data,
+             "step": forward_steps,
+             "table_data": updated_table_data,
+         },
+         f"Step + 2, Now: {forward_steps}",
+         updated_table_data,
+     ]
+     return ret
+
+
+ if __name__ == "__main__":
+     dataset_name = "sst2"
+     seed = 0
+     prompt_version = "default"
+     kv_iter = 10
+
+     print(f"Dataset: {dataset_name}")
+     task_root = Path("example_sets").joinpath(dataset_name)
+
+     with task_root.joinpath("demos.txt").open("r") as f:
+         demos = f.read()
+     with task_root.joinpath("sample.pkl").open("r") as f:
+         data = json.load(f)  # despite the .pkl suffix, sample.pkl holds JSON text
+     raw_data = [data[str(i)] for i in range(len(data))]
+
+     css = """ #the-table > div > div > div > table > thead {display: none}"""
+
+     title = "🤔 Iterative Forward Tuning Boosts In-context Learning in Language Models"
+     demo = gr.Blocks(css=css, title="🤔Deep-Thinking")
+     with demo:
+         gr.Markdown(f"<h1 style='text-align: center; margin-bottom: 1rem'>{title}</h1>")
+         with gr.Tab("SST-2"):
+             mapping = ["negative", "positive"]
+
+             init_columns = [[e["sentence"], f"*{DISPLAY_MAPPING['sst2'][mapping[e['label']]]}*"] for e in raw_data]
+             state = gr.State(
+                 {
+                     "dataset_name": "sst2",
+                     "exemplar_str": demos,
+                     "raw_data": raw_data,
+                     "step": 0,
+                     "table_data": [["**Test Input**", "**Golden**"], *init_columns],
+                 }
+             )
+
+             prompt = gr.Textbox(label="Demonstrations (Prompt template formatted)", value=demos)
+             big_table = gr.DataFrame(
+                 value=[["**Test Input**", "**Golden**"], *init_columns],
+                 elem_id="the-table",
+                 datatype=["markdown"] * 50,
+                 headers=None,
+             )
+             step_button = gr.Button("Step + 2, Now: 0")
+             step_button.click(button_pressed, inputs=[state], outputs=[state, step_button, big_table])
+
+     demo.launch(server_name="0.0.0.0")
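
For orientation: button_pressed grows the results table column-wise. It transposes the row-major table, appends one column of predictions labelled T=<step>, and transposes back. A minimal standalone sketch of that bookkeeping, using hypothetical table entries:

    def transpose(l):
        return list(map(list, zip(*l)))

    # hypothetical table contents; the real rows come from example_sets/sst2/sample.pkl
    table = [["**Test Input**", "**Golden**"], ["great movie", "*Pos*"], ["boring plot", "*Neg*"]]
    t = transpose(table)                  # rows -> columns
    t.append(["T=2", "Pos✅", "Neg✅"])    # one new column: predictions after 2 forward-tuning steps
    table = transpose(t)                  # back to row-major, ready for gr.DataFrame
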
common.py ADDED
@@ -0,0 +1,20 @@
+ import os
+ import random
+
+ import numpy as np
+ import torch
+
+ from tasks import task_mapper
+
+
+ def setup_plain_seed(SEED):
+     os.environ["PYTHONHASHSEED"] = str(SEED)
+     random.seed(SEED)
+     np.random.seed(SEED)
+
+
+ def setup_cpu(seed):
+     os.environ["TOKENIZERS_PARALLELISM"] = "false"
+     setup_plain_seed(seed)
+     torch.manual_seed(seed)
+     torch.random.manual_seed(seed)
example_sets/sst2/demos.txt ADDED
@@ -0,0 +1,6 @@
+ Review: and mainly unfunny
+ Sentiment: negative
+
+ Review: filmmakers david weissman and bill weber benefit enormously from the cockettes ' camera craziness -- not only did they film performances , but
+ Sentiment: positive
+
example_sets/sst2/rawdiff.pkl ADDED
The diff for this file is too large to render. See raw diff
 
example_sets/sst2/sample.pkl ADDED
@@ -0,0 +1,10 @@
+ {
+   "0": { "sentence": "the cold turkey would 've been a far better title . ", "label": 0, "idx": 57 },
+   "1": { "sentence": "it 's a cookie-cutter movie , a cut-and-paste job . ", "label": 0, "idx": 28 },
+   "2": { "sentence": "a solid film ... but more conscientious than it is truly stirring . ", "label": 1, "idx": 143 },
+   "3": { "sentence": "it 's slow -- very , very slow . ", "label": 0, "idx": 4 },
+   "4": { "sentence": "filmmakers who can deftly change moods are treasures and even marvels . ", "label": 1, "idx": 679 },
+   "5": { "sentence": "it all adds up to good fun . ", "label": 1, "idx": 393 },
+   "6": { "sentence": "i am sorry that i was unable to get the full brunt of the comedy . ", "label": 0, "idx": 423 },
+   "7": { "sentence": "hilariously inept and ridiculous . ", "label": 1, "idx": 112 }
+ }
models/__init__.py ADDED
@@ -0,0 +1 @@
+ from .huggingface import build_model_signature, build_tokenizer, build_model
models/huggingface.py ADDED
@@ -0,0 +1,41 @@
+ from transformers import AutoTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
+
+
+ def build_model_signature(model_type, model_size):
+     if model_type == "opt":
+         # ["125m", "350m", "1.3b", "2.7b", "6.7b", "13b", "30b", "66b"]
+         return f"facebook/opt-{model_size}"
+     if model_type == "gpt2":
+         # ["sm", "medium", "large", "xl"]
+         if model_size == "sm":
+             return "gpt2"
+         return f"gpt2-{model_size}"
+     if model_type == "e-gpt":
+         # ["neo-125M", "neo-1.3B", "neo-2.7B", "j-6B", "neox-20b"]
+         return f"EleutherAI/gpt-{model_size}"
+     if model_type == "bloom":
+         # ["560m", "1b1", "1b7", "3b", "7b1"]
+         return f"bigscience/bloom-{model_size}"
+
+
+ def build_tokenizer(model_type, model_size, padding_side="left", use_fast=False):
+     sign = build_model_signature(model_type, model_size)
+     if not use_fast:
+         tok = AutoTokenizer.from_pretrained(sign, padding_side=padding_side)
+     else:
+         tok = PreTrainedTokenizerFast.from_pretrained(sign, padding_side=padding_side)
+     if model_type in ["gpt2", "e-gpt"]:
+         tok.pad_token_id = tok.eos_token_id
+         tok.pad_token = tok.eos_token
+     return tok
+
+
+ def build_model(model_type, model_size, in_8bit):
+     sign = build_model_signature(model_type, model_size)
+     model = AutoModelForCausalLM.from_pretrained(
+         sign,
+         device_map="auto",
+         load_in_8bit=in_8bit,
+     )
+     model.eval()
+     return model
models/meta_optimizer.py ADDED
@@ -0,0 +1,78 @@
+ import torch
+
+
+ class MomentumOptim:
+     def __init__(self, step_size=0.01, momentum=0.9):
+         self.step_size = step_size
+         self.momentum = momentum
+         self.m = None  # velocity
+
+     def init(self):
+         self.m = None
+
+     def upd_m(self, old_m, g):
+         return g + self.momentum * old_m
+
+     def upd(self, old_x, m):
+         return old_x + self.step_size * m
+
+     def __call__(self, old_xs, new_xs):
+         pseudo_gs = [new_x - old_x for old_x, new_x in zip(old_xs, new_xs)]  # difference between passes acts as a pseudo-gradient
+
+         if not self.m:
+             self.m = pseudo_gs
+         else:
+             self.m = [self.upd_m(old_m, g) for old_m, g in zip(self.m, pseudo_gs)]
+
+         updated_kv = [self.upd(old_x, m) for old_x, m in zip(old_xs, self.m)]
+         return updated_kv
+
+
+ class AttnOptimWrapper:
+     def __init__(self, llm, model_type, optimizer="momentum", **optimizer_args):
+         self.model = llm
+         self.kv = None
+         self.model_type = model_type
+
+         if optimizer == "momentum":
+             self.optim_k = MomentumOptim(**optimizer_args)
+             self.optim_v = MomentumOptim(**optimizer_args)
+         else:
+             raise ValueError(f"Unsupported optimizer: {optimizer}")
+
+     def init(self):
+         self.optim_k.init()
+         self.optim_v.init()
+
+     @torch.no_grad()
+     def step(self, ctx_ids):
+         L = len(ctx_ids)
+
+         ctx_ids = ctx_ids.unsqueeze(0)  # [1, L]
+         mask = torch.ones_like(ctx_ids)
+         if self.kv is not None:
+             mask = mask.repeat(1, 2)  # [1, 2*L]
+
+         next_kv = self.model(
+             input_ids=ctx_ids,
+             attention_mask=mask,
+             past_key_values=self.kv,
+             use_cache=True,
+         ).past_key_values  # kv @ (old_ctx + new_ctx)
+
+         cur_kv = []
+         for layer_k, layer_v in next_kv:
+             # [B, num_head, 2*L, head_hidden]
+             cur_kv.append([layer_k[:, :, -L:, :], layer_v[:, :, -L:, :]])  # kv @ (new_ctx)
+
+         if not self.kv:
+             self.kv = cur_kv
+         else:
+             old_ks, old_vs = zip(*self.kv)
+             cur_ks, cur_vs = zip(*cur_kv)
+
+             upd_ks = self.optim_k(old_ks, cur_ks)
+             upd_vs = self.optim_v(old_vs, cur_vs)
+             self.kv = list(zip(upd_ks, upd_vs))
+
+         return self.kv
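
For orientation: AttnOptimWrapper is driven by repeatedly feeding the same demonstration tokens, so each call refines the cached keys/values with a momentum update. A condensed sketch of that loop, mirroring how app.py uses it (the review text below is a hypothetical demonstration):

    import torch
    from models import build_tokenizer, build_model
    from models.meta_optimizer import AttnOptimWrapper

    tokenizer = build_tokenizer("opt", "125m", padding_side="right")
    model = build_model("opt", "125m", False)

    # hypothetical one-shot demonstration; app.py reads example_sets/sst2/demos.txt instead
    demo_ids = torch.LongTensor(tokenizer("Review: great fun .\nSentiment: positive")["input_ids"])

    meta_optim = AttnOptimWrapper(model, "opt", step_size=0.01, momentum=0.9)
    meta_optim.init()
    for _ in range(10):  # "deep-thinking" steps; each pass updates the cached KV
        exemplar_kv = meta_optim.step(demo_ids)
    # exemplar_kv is then passed as past_key_values when scoring test inputs (see do_infer_probs in app.py)
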
tasks/__init__.py ADDED
@@ -0,0 +1,11 @@
+ from .sst2 import SST2ProbInferenceForMC
+
+
+ task_mapper = {"sst2": SST2ProbInferenceForMC}
+
+
+ def load_task(name):
+     if name not in task_mapper.keys():
+         raise ValueError(f"Unrecognized dataset `{name}`")
+
+     return task_mapper[name]
tasks/base.py ADDED
@@ -0,0 +1,58 @@
+ import numpy as np
+
+
+ class BaseProbInference:
+     def __init__(self, prompt_version):
+         if prompt_version == "default":
+             self.prompt_version = self.default_prompt_version()
+         else:
+             self.prompt_version = prompt_version
+
+         self.raw_data_result = None
+         self.raw_data_sample = None
+         self.raw_data_dev = None
+
+         self.can_be_stratified = False
+         self.CHOICES = None
+         self.num_base_shot = 1
+
+     def default_prompt_version(self):
+         raise NotImplementedError
+
+     def dataset_signature(self):
+         # {
+         #     "result": (dataset_name, subset, split),  # which produces the final result
+         #     "sample": (dataset_name, subset, split),  # from which we sample ICL few-shot examples
+         # }
+         raise NotImplementedError
+
+     def dataset_part(self, part):
+         return self.dataset_signature()[part]
+
+     def dataset_preprocess(self, raw_data):
+         raise NotImplementedError
+
+     def handcrafted_exemplars(self):
+         raise NotImplementedError
+
+     def exemplar_seperator(self):
+         raise NotImplementedError
+
+     def multiple_choice_promptify(self, query, choice):
+         raise NotImplementedError
+
+     @staticmethod
+     def merge_choice_info(choice_info):
+         merged = {}
+         for k in ["lm_log_p", "norm_lm_log_p"]:
+             one_metric_merged = []
+             for info in choice_info:
+                 one_metric_merged.append(info[k])
+             merged[k] = one_metric_merged
+         return merged
+
+     @staticmethod
+     def choice_info_to_predictions(info):
+         lm_log_p_idx = int(np.argmax(info["lm_log_p"]))
+         norm_lm_log_p_idx = int(np.argmax(info["norm_lm_log_p"]))
+         return {"lm_log_p": lm_log_p_idx, "norm_lm_log_p": norm_lm_log_p_idx}
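
A quick illustration of the two static helpers above, with hypothetical scores: merge_choice_info collects each metric across the choices, and choice_info_to_predictions returns the argmax index per metric.

    from tasks.base import BaseProbInference

    choice_info = [
        {"lm_log_p": -4.2, "norm_lm_log_p": -2.1},   # choice 0, e.g. "negative"
        {"lm_log_p": -3.1, "norm_lm_log_p": -1.55},  # choice 1, e.g. "positive"
    ]
    merged = BaseProbInference.merge_choice_info(choice_info)
    # merged == {"lm_log_p": [-4.2, -3.1], "norm_lm_log_p": [-2.1, -1.55]}
    pred = BaseProbInference.choice_info_to_predictions(merged)
    # pred == {"lm_log_p": 1, "norm_lm_log_p": 1}, i.e. choice 1 ("positive") wins under both metrics
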
tasks/loader.py ADDED
@@ -0,0 +1,96 @@
+ import torch
+ from torch.utils.data import Dataset
+ from transformers import PreTrainedTokenizer
+
+
+ class TokenizedForMCRightPad(Dataset):
+     def __init__(self, data, tok: PreTrainedTokenizer, prompt_fn):
+         # data: [query: str, choices: list(str)]
+         self.tok = tok
+         self.prompt_fn = prompt_fn
+         self.max_length = self._find_max_length(data)
+         self.data = self._build_mc_data(data)
+
+     def _find_max_length(self, data):
+         max_len = 0
+
+         def tok_len(t):
+             return len(self.tok.encode(t))
+
+         for ex in data:
+             query = ex["query"]
+             len_choices = [tok_len(self.prompt_fn(query, c)[1]) for c in ex["choices"]]
+             max_len = max(max_len, *len_choices)
+
+         return max_len
+
+     def _build_mc_data(self, data):
+         processed = []
+         num_choices = set(len(e["choices"]) for e in data)
+         if len(num_choices) != 1:
+             raise ValueError(f"Queries have different numbers of choices, which is not supported! #choices: {num_choices}")
+         for ex in data:
+             query, choices = ex["query"], ex["choices"]
+             processed_input = [self.prompt_fn(query, choice) for choice in choices]
+             processed_input = [self.tokenize(t_query, t_full) for t_query, t_full in processed_input]
+             processed.append(processed_input)
+
+         return processed
+
+     def tokenize_demonstration(self, demonstration):
+         e = self.tok(demonstration)
+         return torch.LongTensor(e["input_ids"]), torch.LongTensor(e["attention_mask"])  # no padding
+
+     def tokenize(self, only_query, full_text):
+         tok_only_query = self.tok(only_query, add_special_tokens=False)
+         tok_full_no_padding = self.tok(full_text, add_special_tokens=False)
+         tok_full = self.tok(
+             full_text,
+             padding="max_length",
+             max_length=self.max_length,
+             add_special_tokens=False,
+         )  # <pad> is not a special token
+         # tok_only_query = self.tok(only_query)
+         # tok_full_no_padding = self.tok(full_text)
+         # tok_full = self.tok(
+         #     full_text,
+         #     padding="max_length",
+         #     max_length=self.max_length,
+         # )  # <pad> is not a special token
+
+         # print(f"tok_only_query: {self.tok.convert_ids_to_tokens(tok_only_query.input_ids)}")
+         # print(f"tok_full_no_padding: {self.tok.convert_ids_to_tokens(tok_full_no_padding.input_ids)}")
+         # print(f"tok_full: {self.tok.convert_ids_to_tokens(tok_full.input_ids)}")
+         # exit(0)
+
+         len_full = len(tok_full_no_padding.input_ids)
+         len_query = len(tok_only_query.input_ids)
+         e = {
+             "input_ids": tok_full.input_ids,
+             "attention_mask": tok_full.attention_mask,
+             "choice_start": len_query,
+             "choice_end": len_full,
+         }
+         # print("Attn:")
+         # print(tok_full.attention_mask)
+         # print("input_ids:")
+         # print(tok_full.input_ids)
+
+         dcd_sp = self.tok.convert_ids_to_tokens(tok_full.input_ids, skip_special_tokens=False)
+
+         # print(f'{e["choice_start"]}: {e["choice_end"]} = [{self.tok.convert_tokens_to_string(dcd_sp[e["choice_start"] : e["choice_end"]])}]')
+
+         return e
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         def _get_one_item(e):
+             return torch.LongTensor(e["input_ids"]), torch.LongTensor(e["attention_mask"]), e["choice_start"], e["choice_end"]
+
+         es = self.data[idx]
+         # num_choices * (input_ids, attn, start_idx, end_idx)
+         # input_ids, attn: [B, L]
+         # start_idx, end_idx: [B, ]
+         return [_get_one_item(e) for e in es]
tasks/sst2.py ADDED
@@ -0,0 +1,43 @@
+ from tasks.base import BaseProbInference
+
+
+ class SST2ProbInferenceForMC(BaseProbInference):
+     def __init__(self, prompt_version):
+         super().__init__(prompt_version)
+
+         self.CHOICES = ["negative", "positive"]
+         self.can_be_stratified = True
+         self.num_base_shot = len(self.CHOICES)
+
+     def default_prompt_version(self):
+         return "sp"
+
+     def dataset_signature(self):
+         return {
+             "result": ("glue", "sst2", "validation"),
+             "sample": ("glue", "sst2", "train"),
+         }
+
+     def dataset_preprocess(self, raw_data):
+         data = []
+         for e in raw_data:
+             # print(e, flush=True)
+             data.append({"query": e["sentence"].strip(), "choices": self.CHOICES, "answer_idx": e["label"]})
+         return data
+
+     def handcrafted_exemplars(self):
+         raise NotImplementedError
+
+     def exemplar_seperator(self):
+         if self.prompt_version.startswith("sp"):
+             return "\n\n"
+         else:
+             raise ValueError(f"SST2: Unsupported prompt_version: {self.prompt_version}")
+
+     def multiple_choice_promptify(self, query, choice):
+         if self.prompt_version.startswith("sp"):
+             with_query = f"Review: {query}\nSentiment:"
+             with_query_and_choice = f"{with_query} {choice}"
+         else:
+             raise ValueError(f"SST2: Unsupported prompt_version: {self.prompt_version}")
+         return with_query, with_query_and_choice
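
A quick check of the prompt template above, with a hypothetical review: under the default "sp" prompt version, multiple_choice_promptify returns the query prompt and the query plus candidate label.

    from tasks.sst2 import SST2ProbInferenceForMC

    task = SST2ProbInferenceForMC("default")  # "default" resolves to prompt version "sp"
    print(task.multiple_choice_promptify("great movie", "positive"))
    # ('Review: great movie\nSentiment:', 'Review: great movie\nSentiment: positive')
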
utils/__init__.py ADDED
File without changes