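"""Gradio demo: extractive question answering with an ALBERT xlarge v2 model
fine-tuned on SQuAD v2 (ktrapeznikov/albert-xlarge-v2-squad-v2)."""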
import torch
from torch.utils.data import DataLoader, SequentialSampler
from transformers import (
    AlbertConfig,
    AlbertForQuestionAnswering,
    AlbertTokenizer,
    squad_convert_examples_to_features,
)
from transformers.data.processors.squad import SquadResult, SquadExample
from transformers.data.metrics.squad_metrics import compute_predictions_logits
import gradio as gr
model_name_or_path = "ktrapeznikov/albert-xlarge-v2-squad-v2"
output_dir = ""

# Config
n_best_size = 1          # number of candidate answers to keep per question
max_answer_length = 30   # maximum answer span length, in tokens
do_lower_case = True
null_score_diff_threshold = 0.0
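# With version_2_with_negative=True, a question is predicted unanswerable when
# (null score - best non-null score) exceeds null_score_diff_threshold.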

def to_list(tensor):
    """Detach a tensor and convert it to a plain Python list."""
    return tensor.detach().cpu().tolist()
# Setup model
config_class, model_class, tokenizer_class = (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer)
config = config_class.from_pretrained(model_name_or_path)
tokenizer = tokenizer_class.from_pretrained(model_name_or_path, do_lower_case=do_lower_case)
model = model_class.from_pretrained(model_name_or_path, config=config)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

def run_prediction(context_text, question):
    """Run the model on a single (context, question) pair and return the answer."""
    examples = []
    question_texts = [question]
    for i, question_text in enumerate(question_texts):
        example = SquadExample(
            qas_id=str(i),
            question_text=question_text,
            context_text=context_text,
            answer_text=None,
            start_position_character=None,
            title="Predict",
            is_impossible=False,
            answers=None,
        )
        examples.append(example)
    # Long contexts are split into overlapping windows of max_seq_length tokens
    # (sliding by doc_stride) so the whole passage can be scored.
    features, dataset = squad_convert_examples_to_features(
        examples=examples,
        tokenizer=tokenizer,
        max_seq_length=384,
        doc_stride=128,
        max_query_length=64,
        is_training=False,
        return_dataset="pt",
        threads=1,
    )
    eval_sampler = SequentialSampler(dataset)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=10)

    all_results = []
    for batch in eval_dataloader:
        batch = tuple(t.to(device) for t in batch)
        with torch.no_grad():
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
            }
            example_indices = batch[3]
            outputs = model(**inputs)
            for i, example_index in enumerate(example_indices):
                eval_feature = features[example_index.item()]
                unique_id = int(eval_feature.unique_id)
                # Per-feature start/end logits over the context tokens
                start_logits = to_list(outputs.start_logits[i])
                end_logits = to_list(outputs.end_logits[i])
                all_results.append(SquadResult(unique_id, start_logits, end_logits))
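    # compute_predictions_logits aggregates the per-window logits back into text
    # spans, writes the n-best/null-odds JSON files below, and returns an
    # OrderedDict mapping each qas_id to its predicted answer string.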
    output_prediction_file = "predictions.json"
    output_nbest_file = "nbest_predictions.json"
    output_null_log_odds_file = "null_predictions.json"

    predictions = compute_predictions_logits(
        examples,
        features,
        all_results,
        n_best_size,
        max_answer_length,
        do_lower_case,
        output_prediction_file,
        output_nbest_file,
        output_null_log_odds_file,
        False,  # verbose_logging
        True,  # version_2_with_negative
        null_score_diff_threshold,
        tokenizer,
    )

    # Return the first non-empty prediction (SQuAD v2 may predict "no answer").
    answer = "empty"
    for key in predictions.keys():
        if predictions[key]:
            answer = predictions[key]
            break
    return answer

context = "4/5/2022 · In connection with the closing, Helix Acquisition Corp changed its name to MoonLake Immunotherapeutics (“MoonLake” or the “Company”). Beginning April 6, 2022, MoonLake’s shares will trade on the Nasdaq Stock Market..."
question = "What did Helix Acquisition Corp change its name to?"
title = 'Question Answering demo with Albert QA transformer and gradio'
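# Optional sanity check before launching the UI (hypothetical; the expected
# output is something like "MoonLake Immunotherapeutics"):
# print(run_prediction(context, question))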
# Launch the Gradio UI. Note: this uses the legacy gr.inputs/gr.outputs API,
# deprecated in Gradio 3 and removed in Gradio 4; pin an older Gradio release
# to run this as written.
gr.Interface(
    run_prediction,
    inputs=[
        gr.inputs.Textbox(lines=7, default=context, label="Context"),
        gr.inputs.Textbox(lines=2, default=question, label="Question"),
    ],
    outputs=[gr.outputs.Textbox(type="auto", label="Answer")],
    title=title,
    theme="peach",
).launch()