# Spaces: Runtime error
import json | |
from pathlib import Path | |
import gradio as gr | |
import torch | |
from torch.nn import functional as F | |
from torch.utils.data import DataLoader | |
from common import setup_cpu | |
from models import build_tokenizer, build_model | |
from models.meta_optimizer import AttnOptimWrapper | |
from tasks import load_task | |
from tasks.loader import TokenizedForMCRightPad | |
# Short table labels for each dataset's answer choices.
# Outer key: dataset name; inner key: the task's choice string as returned via
# ``task_agent.CHOICES`` in ``process_once``.
DISPLAY_MAPPING = {
    "sst2": dict(positive="Pos", negative="Neg"),
}
def do_infer_probs(model, exemplar_attn_kv, exemplar_attn_mask, batched_choices_input):
    """Score every answer choice of a batch with the LM, conditioned on a cached prefix.

    Args:
        model: causal LM callable accepting ``input_ids``, ``attention_mask`` and
            ``past_key_values`` keyword arguments and returning an object with a
            ``.logits`` attribute.
        exemplar_attn_kv: iterable of per-layer ``(key, value)`` tensors for the
            demonstration prefix. Each is expanded along dim 0 to the batch size,
            so presumably ``[1, num_heads, L, head_dim]`` (TODO confirm against
            ``AttnOptimWrapper.step``).
        exemplar_attn_mask: prefix attention mask, expanded to ``[B, L]`` and
            concatenated in front of each batch's mask; expected shape ``[1, L]``.
        batched_choices_input: one entry per answer choice. Each entry is a
            4-tuple ``(input_ids [B, L'], attention_mask [B, L'], choice_start,
            choice_end)``. The per-example boundaries may be ints or 0-d tensors.

    Returns:
        A list with one element per choice. Each element is a list with one dict
        per example, holding ``"lm_log_p"`` (sum of the token log-probs of the
        choice span) and ``"norm_lm_log_p"`` (that sum divided by the span length).

    Raises:
        ValueError: if a choice span does not satisfy ``1 <= start < end``.
    """
    batched_choices_logprobs = []
    for batched_one_choice_input in batched_choices_input:
        (
            batch_input_ids,
            batch_attention_mask,
            batch_choice_start,
            batch_choice_end,
        ) = batched_one_choice_input
        bs = len(batch_input_ids)
        # Prefix mask first, then the test-input mask: [B, L + L'].
        merged_attn_mask = torch.cat((exemplar_attn_mask.expand(bs, -1), batch_attention_mask), dim=1)
        # Broadcast the cached prefix KV to the batch: [B, #Heads, Length, Hidden].
        expand_exemplar_attn_kv = [
            [layer_k.expand((bs, -1, -1, -1)), layer_v.expand((bs, -1, -1, -1))]
            for layer_k, layer_v in exemplar_attn_kv
        ]
        batched_logits = model(
            input_ids=batch_input_ids,  # [B, L']
            attention_mask=merged_attn_mask,  # [B, L + L']
            past_key_values=expand_exemplar_attn_kv,  # num_layers * 2 * [B, num_heads, L, H]
        ).logits
        batched_output = F.log_softmax(batched_logits, dim=-1)  # [B, L', Vocab]

        batched_one_choice_logprobs = []
        for input_ids, choice_start, choice_end, lm_logprobs in zip(
            batch_input_ids, batch_choice_start, batch_choice_end, batched_output
        ):
            # Boundaries arrive as 0-d tensors from a DataLoader, or as plain ints.
            start, end = int(choice_start), int(choice_end)
            # The logits cover only the L' input positions. So the first choice
            # token needs a preceding input position (start >= 1), and the span
            # must be non-empty.
            if start < 1 or end <= start:
                raise ValueError(f"choice span [{start}, {end}) must satisfy 1 <= start < end")
            choice_tokens = input_ids[start:end].unsqueeze(1)  # [L, 1]
            # The logits at position t predict token t + 1, so shift the window left by one.
            choice_logprobs = lm_logprobs[start - 1 : end - 1]  # [L, Vocab]
            extracted = torch.gather(choice_logprobs, -1, choice_tokens).squeeze(-1)
            choice_length = end - start
            lm_log_p = extracted.sum().item()
            # Plain float division, so int and tensor boundaries behave the same.
            norm_lm_log_p = lm_log_p / choice_length
            batched_one_choice_logprobs.append({"lm_log_p": lm_log_p, "norm_lm_log_p": norm_lm_log_p})
        batched_choices_logprobs.append(batched_one_choice_logprobs)
    return batched_choices_logprobs
def process_once(dataset_name, exemplar_str, forward_steps, raw_data):
    """Forward-tune on the demonstrations, then predict and grade every example.

    Relies on module-level globals assigned in the ``__main__`` section:
    ``seed``, ``prompt_version``, ``tokenizer``, ``model``, ``model_name``,
    ``step_size`` and ``momentum``.

    Args:
        dataset_name: task name passed to ``load_task``; also the key into
            ``DISPLAY_MAPPING``.
        exemplar_str: demonstration text, tokenized as the shared prefix.
        forward_steps: number of ``AttnOptimWrapper.step`` calls; must be >= 1.
        raw_data: examples handed to the task's ``dataset_preprocess``.

    Returns:
        One display string per example (short label followed by ✅ or ❌), then
        the overall accuracy as a percentage string ("N/A" when there are no
        examples).

    Raises:
        ValueError: if ``forward_steps`` < 1, since no exemplar KV would be produced.
    """
    if forward_steps < 1:
        raise ValueError(f"forward_steps must be >= 1, got {forward_steps}")

    # Called with the same fixed seed on every call (presumably reseeding RNGs),
    # so each button press starts from the same state.
    setup_cpu(seed=seed)
    TaskHandler = load_task(dataset_name)
    task_agent = TaskHandler(prompt_version)
    processed_data = task_agent.dataset_preprocess(raw_data)
    dataset = TokenizedForMCRightPad(processed_data, tokenizer, task_agent.multiple_choice_promptify)
    exemplar_input_ids, exemplar_attn_mask = dataset.tokenize_demonstration(exemplar_str)
    loader = DataLoader(dataset, shuffle=False, drop_last=False, batch_size=1)

    # The optimizer is re-initialized on every call, and the KV returned by the
    # last step is used for scoring.
    meta_optim = AttnOptimWrapper(model, model_name, step_size=step_size, momentum=momentum)
    meta_optim.init()
    for _ in range(forward_steps):
        exemplar_kv = meta_optim.step(exemplar_input_ids)

    # Add the batch dimension once; do_infer_probs expands it to each batch size.
    batched_exemplar_mask = exemplar_attn_mask.unsqueeze(0)
    generated_info = []  # one (choice0_info, choice1_info, ...) tuple per example
    for batch_input in loader:
        batch_output = do_infer_probs(model, exemplar_kv, batched_exemplar_mask, batch_input)
        # batch_output is choice-major; regroup it per example.
        generated_info.extend(zip(*batch_output))

    all_predicted = []
    num_correct = 0
    for data, choice_info in zip(processed_data, generated_info):
        merged_choice_info = task_agent.merge_choice_info(choice_info)
        merged_predictions_idx = task_agent.choice_info_to_predictions(merged_choice_info)["lm_log_p"]
        predicted = task_agent.CHOICES[merged_predictions_idx]
        ground_truth = task_agent.CHOICES[data["answer_idx"]]
        res = f"{DISPLAY_MAPPING[dataset_name][predicted]}"
        if predicted == ground_truth:
            res += " ✅"
            num_correct += 1
        else:
            res += " ❌"
        all_predicted.append(res)

    num_scored = len(all_predicted)
    if num_scored:
        all_predicted.append(f"{100 * num_correct / num_scored:.2f}%")
    else:
        # Avoid ZeroDivisionError when there is nothing to grade.
        all_predicted.append("N/A")
    return all_predicted
def transpose(l):
    """Return the transpose of a list of rows as a list of lists.

    Rows are paired with ``zip``, so ragged input is truncated to the shortest
    row. An empty input yields ``[]``.
    """
    return [list(column) for column in zip(*l)]
def button_pressed(prev_state):
    """Handle a click on the training button: run more forward steps and add a result column.

    Args:
        prev_state: dict held in ``gr.State`` with keys ``dataset_name``,
            ``exemplar_str``, ``raw_data``, ``step`` and ``table_data``
            (``table_data`` is row-major).

    Returns:
        ``[new_state, button_label, new_table]``, in the order of the click
        handler's outputs ``[state, step_button, big_table]``.
    """
    # Each click advances the step counter by 2.
    steps = prev_state["step"] + 2
    name = prev_state["dataset_name"]
    demos_text = prev_state["exemplar_str"]
    samples = prev_state["raw_data"]

    new_column = process_once(name, demos_text, steps, samples)

    # Append the new results as a column: transpose to columns, append, transpose back.
    # NOTE(review): the "**ICL**" header is only used when steps == 1, which a
    # +2 increment from the UI's starting step of 1 never produces.
    header = "**ICL**" if steps == 1 else f"**Step={steps}**"
    columns = transpose(prev_state["table_data"])
    columns.append([header] + new_column)
    new_table = transpose(columns)

    new_state = {
        "dataset_name": name,
        "exemplar_str": demos_text,
        "raw_data": samples,
        "step": steps,
        "table_data": new_table,
    }
    return [new_state, f"Click here to train LLM ! Now Step: {steps}", new_table]
if __name__ == "__main__":
    # Configuration. These names are module globals read by process_once().
    dataset_name = "sst2"
    seed = 0
    prompt_version = "default"
    kv_iter = 10  # NOTE(review): not referenced anywhere in this file
    model_name, model_size = "opt", "125m"
    step_size, momentum = 0.01, 0.9
    setup_cpu(seed=seed)
    tokenizer = build_tokenizer(model_name, model_size, padding_side="right")
    model = build_model(model_name, model_size, False)
    # Gradients are disabled globally; all model calls in this app are forward-only.
    torch.autograd.set_grad_enabled(False)

    print(f"Dataset: {dataset_name}")
    task_root = Path("example_sets").joinpath(dataset_name)
    # Explicit UTF-8 so the reads don't depend on the host's locale encoding.
    with task_root.joinpath("demos.txt").open("r", encoding="utf-8") as f:
        demos = f.read()
    # NOTE(review): despite the .pkl suffix, this file is parsed as JSON.
    with task_root.joinpath("sample.pkl").open("r", encoding="utf-8") as f:
        raw_data = json.load(f)

    # Baseline column: plain in-context learning (a single forward step).
    icl_result = process_once(dataset_name, demos, 1, raw_data)
    text = """We utilize a Large Language Model (LLM) to perform in-context learning (ICL) for sentiment classification of movie reviews.
Taking the following two labeled examples as demonstrations, we predict the sentiment of the subsequent test input.
Directly employing ICL results in lower prediction accuracy. However, in our proposed approach, **Deep-Thinking**, we repeatedly apply **Forward Tuning**, leading to improved accuracy of the model."""
    css = """
#the-table { overflow: auto; }
#the-table > div:nth-child(2) { margin: auto; width: fit-content; }
#the-table > div > div > div > table { width: auto; margin: 0; white-space: normal; }
#the-table > div > div > div > table > thead {display: none}
#the-table > div > div > div > table > tbody > tr:last-child {background-color: beige}
#the-table > div > div > div > table > tbody > tr:first-child {background-color: lightgray}
#the-table > div > div > div > table > tbody > tr > td {padding: 0 2px;}
#the-table > div > div > div > table > tbody > tr > td:first-child {min-width: 300px;}
#the-table > div > div > div > table > tbody > tr > td:not(:first-child) {white-space: nowrap; }
#the-text { font-size: large; }
#main-button { max-width: 500px; margin: 0 auto; }
"""
    title = "🤔 Iterative Forward Tuning Boosts In-context Learning in Language Models"
    demo = gr.Blocks(css=css, title="🤔Deep-Thinking")
    with demo:
        gr.Markdown(f"<h1 style='text-align: center; margin-bottom: 1rem'>{title}</h1>")
        # NOTE(review): the [Code] link points to the same arXiv URL as [Paper].
        gr.Markdown(
            """
<h2 style='text-align: center; margin-bottom: 1rem'>
<a href='https://arxiv.org/abs/2305.13016' target="_blank" style='text-decoration: none'>[Paper]</a>
<a href='https://arxiv.org/abs/2305.13016' target="_blank" style='text-decoration: none'>[Code]</a>
</h2>"""
        )
        gr.Markdown(text, elem_id="the-text")
        with gr.Tab("SST-2"):
            # Build the table column-wise: test inputs and the accuracy label
            # first, then the ICL results. Transpose back to rows for display.
            init_columns = [[e["sentence"]] for e in raw_data]
            init_table_result = [["**Test Input**"], *init_columns, ["**Accuracy**"]]
            init_table_result = transpose(init_table_result)
            init_table_result.append(["**ICL**"] + icl_result)
            init_table_result = transpose(init_table_result)

            state = gr.State(
                {
                    "dataset_name": "sst2",
                    "exemplar_str": demos,
                    "raw_data": raw_data,
                    "step": 1,
                    "table_data": init_table_result,
                }
            )
            # NOTE(review): this textbox is not an input to the click handler,
            # so editing it does not change the demonstrations used.
            prompt = gr.Textbox(label="Demonstrations (Prompt template formatted)", value=demos)
            gr.Markdown("<h2 style='text-align: center; margin-bottom: 1rem'>👇 Run forward tuning once !</h2>")
            step_button = gr.Button("Click here to train LLM ! Now Step: 1", variant="primary", elem_id="main-button")
            big_table = gr.DataFrame(
                value=init_table_result,
                elem_id="the-table",
                datatype=["markdown"] * 50,
                headers=None,
            )
            step_button.click(button_pressed, inputs=[state], outputs=[state, step_button, big_table])
    demo.launch(server_name="0.0.0.0")