import jax
import jax.numpy as jnp
from bigbird_flax import FlaxBigBirdForNaturalQuestions
from datasets import load_from_disk

from transformers import BigBirdTokenizerFast


# map the category id predicted by the model's answer-type head to its label
CATEGORY_MAPPING = {0: "null", 1: "short", 2: "long", 3: "yes", 4: "no"}
# punctuation that is replaced by spaces when normalizing answers for comparison
PUNCTUATION_SET_TO_EXCLUDE = set("".join(["‘", "’", "´", "`", ".", ",", "-", '"']))
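

# For multi-word answers, build variants with the first or last word dropped, so that a
# prediction missing e.g. a leading article can still be matched.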
def get_sub_answers(answers, begin=0, end=None):
    return [" ".join(x.split(" ")[begin:end]) for x in answers if len(x.split(" ")) > 1]
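

# Normalize answers into a set of lower-cased aliases with underscores and punctuation
# replaced by spaces; optionally also include the sub-answer variants from get_sub_answers.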
def expand_to_aliases(given_answers, make_sub_answers=False):
    if make_sub_answers:
        given_answers = (
            given_answers + get_sub_answers(given_answers, begin=1) + get_sub_answers(given_answers, end=-1)
        )
    answers = []
    for answer in given_answers:
        alias = answer.replace("_", " ").lower()
        alias = "".join(c if c not in PUNCTUATION_SET_TO_EXCLUDE else " " for c in alias)
        answers.append(" ".join(alias.split()).strip())
    return set(answers)
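

# From the top-k start and end logits, return the indices of the highest-scoring pair
# that forms a valid span (end not before start, at most max_size tokens wide).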
def get_best_valid_start_end_idx(start_scores, end_scores, top_k=1, max_size=100):
    best_start_scores, best_start_idx = jax.lax.top_k(start_scores, top_k)
    best_end_scores, best_end_idx = jax.lax.top_k(end_scores, top_k)

    # scores[i, j] pairs the i-th best end with the j-th best start; invalid pairs
    # (negative width or spans wider than max_size) are pushed down by a large penalty
    widths = best_end_idx[:, None] - best_start_idx[None, :]
    mask = jnp.logical_or(widths < 0, widths > max_size)
    scores = (best_end_scores[:, None] + best_start_scores[None, :]) - (1e8 * mask)
    best_score = jnp.argmax(scores).item()

    # argmax over the flattened (top_k, top_k) grid: column (% top_k) -> start, row (// top_k) -> end
    return best_start_idx[best_score % top_k], best_end_idx[best_score // top_k]
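

# Flatten a raw Natural Questions example into plain question/context strings plus its
# gold short/long answers and an answer category ("yes"/"no"/"long_short"/"null").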
def format_dataset(sample):
    question = sample["question"]["text"]
    context = sample["document"]["tokens"]["token"]
    is_html = sample["document"]["tokens"]["is_html"]
    long_answers = sample["annotations"]["long_answer"]
    short_answers = sample["annotations"]["short_answers"]

    # drop HTML tokens from the context
    context_string = " ".join([context[i] for i in range(len(context)) if not is_html[i]])

    # yes/no annotations: 0 -> "no", 1 -> "yes"
    for answer in sample["annotations"]["yes_no_answer"]:
        if answer == 0 or answer == 1:
            return {
                "question": question,
                "context": context_string,
                "short": [],
                "long": [],
                "category": "no" if answer == 0 else "yes",
            }

    short_targets = []
    for s in short_answers:
        short_targets.extend(s["text"])
    short_targets = list(set(short_targets))

    long_targets = []
    for s in long_answers:
        if s["start_token"] == -1:
            continue
        answer = context[s["start_token"] : s["end_token"]]
        html = is_html[s["start_token"] : s["end_token"]]
        new_answer = " ".join([answer[i] for i in range(len(answer)) if not html[i]])
        if new_answer not in long_targets:
            long_targets.append(new_answer)

    category = "long_short" if len(short_targets + long_targets) > 0 else "null"

    return {
        "question": question,
        "context": context_string,
        "short": short_targets,
        "long": long_targets,
        "category": category,
    }
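

# Evaluation driver: prepare the validation data, run the model over every (short enough)
# example, and report the exact-match score.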
def main():
    dataset = load_from_disk("natural-questions-validation")
    dataset = dataset.map(format_dataset).remove_columns(["annotations", "document", "id"])
    print(dataset)

    # heuristic pre-filter: keep examples whose question + context should fit into the
    # 4096-token window, and drop examples without any annotated answer
    short_validation_dataset = dataset.filter(lambda x: (len(x["question"]) + len(x["context"])) < 4 * 4096)
    short_validation_dataset = short_validation_dataset.filter(lambda x: x["category"] != "null")

    model_id = "vasudevgupta/flax-bigbird-natural-questions"
    model = FlaxBigBirdForNaturalQuestions.from_pretrained(model_id)
    tokenizer = BigBirdTokenizerFast.from_pretrained(model_id)
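
    # JIT-compile the forward pass once; the answer-type prediction is reduced to a
    # category id inside the compiled function.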
    @jax.jit
    def forward(*args, **kwargs):
        start_logits, end_logits, pooled_logits = model(*args, **kwargs)
        return start_logits, end_logits, jnp.argmax(pooled_logits, axis=-1)
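
    # Score a single example: predict the answer category and, if needed, extract a span,
    # then compare the normalized prediction against all aliases of the gold answers.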
    def evaluate(example):
        # encode question and context together, padded/truncated to 4096 tokens
        inputs = tokenizer(
            example["question"],
            example["context"],
            return_tensors="np",
            max_length=4096,
            padding="max_length",
            truncation=True,
        )

        start_scores, end_scores, category = forward(**inputs)

        predicted_category = CATEGORY_MAPPING[category.item()]

        # targets are the annotated long + short answers, or just the yes/no/null label
        example["targets"] = example["long"] + example["short"]
        if example["category"] in ["yes", "no", "null"]:
            example["targets"] = [example["category"]]
        example["has_tgt"] = example["category"] != "null"

        # for yes/no/null predictions there is no span to extract
        if predicted_category in ["yes", "no", "null"]:
            example["output"] = [predicted_category]
            example["match"] = example["output"] == example["targets"]
            example["has_pred"] = predicted_category != "null"
            return example

        # decode the best valid span; short answers are capped at 38 tokens, long ones at 1024
        max_size = 38 if predicted_category == "short" else 1024
        start_score, end_score = get_best_valid_start_end_idx(
            start_scores[0], end_scores[0], top_k=8, max_size=max_size
        )

        input_ids = inputs["input_ids"][0].tolist()
        example["output"] = [tokenizer.decode(input_ids[start_score : end_score + 1])]

        answers = expand_to_aliases(example["targets"], make_sub_answers=True)
        predictions = expand_to_aliases(example["output"])

        # compare with all whitespace removed, ignoring predictions that are only quote tokens
        answers = {"".join(a.split()) for a in answers}
        predictions = {"".join(p.split()) for p in predictions}
        predictions = {s for s in predictions if s not in ["``", "''", "`", "'"]}

        # any overlap between prediction aliases and answer aliases counts as an exact match
        example["match"] = len(list(answers & predictions)) > 0
        example["has_pred"] = predicted_category != "null" and len(predictions) > 0

        return example
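
    # run the evaluation over the filtered validation set and report exact-match accuracy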
    short_validation_dataset = short_validation_dataset.map(evaluate)

    total = len(short_validation_dataset)
    matched = len(short_validation_dataset.filter(lambda x: x["match"] == 1))
    print("EM score:", (matched / total) * 100, "%")


if __name__ == "__main__":
    main()