Spaces:
Sleeping
Sleeping
File size: 4,299 Bytes
7c96c8f b649128 7c96c8f ba2367e 7601346 7c96c8f 90e227c 7c96c8f 90e227c 7c96c8f 92aff5c 7c96c8f 59db485 7c96c8f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import os
import torch
import time
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer, squad_convert_examples_to_features)
from transformers.data.processors.squad import SquadResult, SquadV2Processor, SquadExample
from transformers.data.metrics.squad_metrics import compute_predictions_logits
import gradio as gr
# Hub id of the ALBERT checkpoint fine-tuned for SQuAD v2 question answering.
model_name_or_path: str = "ktrapeznikov/albert-xlarge-v2-squad-v2"
# Not read anywhere in this file.
output_dir: str = ""

# --- Config: answer post-processing settings --------------------------------
# Passed as `n_best_size` to compute_predictions_logits.
n_best_size: int = 1
# Passed as `max_answer_length` to compute_predictions_logits.
max_answer_length: int = 30
# Lower-casing flag handed to answer post-processing.
do_lower_case: bool = True
# Passed as `null_score_diff_threshold` (presumably the SQuAD v2 no-answer
# score margin — confirm against the transformers squad_metrics docs).
null_score_diff_threshold: float = 0.0
def to_list(tensor):
    """Return ``tensor`` as a (possibly nested) Python list.

    The tensor is detached from the autograd graph and copied to host
    memory before conversion, so GPU tensors and tensors that require
    gradients are both accepted.
    """
    graph_free = tensor.detach()
    on_host = graph_free.cpu()
    return on_host.tolist()
# Setup model: load the ALBERT QA checkpoint once at import time and move it
# to the GPU when CUDA is available, otherwise keep it on the CPU.
config_class, model_class, tokenizer_class = (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer)
config = config_class.from_pretrained(model_name_or_path)
# Use the shared ``do_lower_case`` setting (also handed to answer
# post-processing) instead of a hard-coded literal, so tokenisation and
# post-processing cannot silently disagree if the setting is changed.
tokenizer = tokenizer_class.from_pretrained(model_name_or_path, do_lower_case=do_lower_case)
model = model_class.from_pretrained(model_name_or_path, config=config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# NOTE(review): not referenced anywhere else in this file; kept for
# compatibility in case external code imports it.
processor = SquadV2Processor()
def _build_examples(context_text, question_texts):
    """Wrap each question, paired with the shared context, in a ``SquadExample``."""
    return [
        SquadExample(
            qas_id=str(i),
            question_text=question_text,
            context_text=context_text,
            answer_text=None,
            start_position_character=None,
            title="Predict",
            is_impossible=False,
            answers=None,
        )
        for i, question_text in enumerate(question_texts)
    ]


def _collect_results(features, dataset, batch_size=10):
    """Run the QA model over ``dataset`` and return one ``SquadResult`` per feature row."""
    eval_dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=batch_size)
    model.eval()  # inference mode; nothing below switches back to training
    all_results = []
    for batch in eval_dataloader:
        batch = tuple(t.to(device) for t in batch)
        inputs = {
            "input_ids": batch[0],
            "attention_mask": batch[1],
            "token_type_ids": batch[2],
        }
        # batch[3] is used as an index into ``features`` for each row.
        feature_indices = batch[3]
        with torch.no_grad():
            outputs = model(**inputs)
        # transformers v4+ returns a dict-like ModelOutput (iterating it yields
        # key names, not tensors); older versions return a plain tuple.
        # Normalise to a tuple and pick the two logit tensors explicitly.
        if hasattr(outputs, "to_tuple"):
            outputs = outputs.to_tuple()
        start_logits_batch, end_logits_batch = outputs[0], outputs[1]
        for i, feature_index in enumerate(feature_indices):
            eval_feature = features[feature_index.item()]
            unique_id = int(eval_feature.unique_id)
            all_results.append(
                SquadResult(unique_id, to_list(start_logits_batch[i]), to_list(end_logits_batch[i]))
            )
    return all_results


def _first_answer(predictions, default="empty"):
    """Return the first truthy predicted answer in ``predictions``, else ``default``."""
    return next((text for text in predictions.values() if text), default)


def run_prediction(context_text, question):
    """Answer ``question`` using ``context_text`` with the loaded ALBERT QA model.

    Returns the predicted answer text, or the string ``"empty"`` when either
    input is blank or no non-empty answer is predicted.

    Side effects: prints ``question``; the output file names below are passed
    to ``compute_predictions_logits``, which may write them to the current
    working directory.
    """
    print(question)
    # A blank context/question gives the model nothing to work with and may
    # yield zero features, which downstream post-processing does not handle;
    # return the sentinel instead of risking a crash.
    if not (context_text and context_text.strip() and question and question.strip()):
        return "empty"

    examples = _build_examples(context_text, [question])
    features, dataset = squad_convert_examples_to_features(
        examples=examples,
        tokenizer=tokenizer,
        max_seq_length=384,
        doc_stride=128,
        max_query_length=64,
        is_training=False,
        return_dataset="pt",
        threads=1,
    )
    all_results = _collect_results(features, dataset)

    output_prediction_file = "predictions.json"
    output_nbest_file = "nbest_predictions.json"
    output_null_log_odds_file = "null_predictions.json"
    predictions = compute_predictions_logits(
        examples,
        features,
        all_results,
        n_best_size,
        max_answer_length,
        do_lower_case,
        output_prediction_file,
        output_nbest_file,
        output_null_log_odds_file,
        False,  # verbose_logging
        True,  # version_2_with_negative
        null_score_diff_threshold,
        tokenizer,
    )
    return _first_answer(predictions)
# Default inputs pre-filled in the demo UI.
context = "4/5/2022 · In connection with the closing, Helix Acquisition Corp changed its name to MoonLake Immunotherapeutics (“MoonLake” or the “Company”). Beginning April 6, 2022, MoonLake’s shares will trade on the Nasdaq Stock Market..."
question = "Helix Acquisition Corp change its name to"
title = 'Question Answering demo with Albert QA transformer and gradio'

# Build the UI: two text inputs (context, question) fed positionally to
# run_prediction, and one text output for the returned answer.
demo = gr.Interface(
    run_prediction,
    inputs=[
        gr.inputs.Textbox(lines=7, default=context, label="Context"),
        gr.inputs.Textbox(lines=2, default=question, label="Question"),
    ],
    outputs=[gr.outputs.Textbox(type="auto", label="Answer")],
    title=title,
    theme="peach",
)

# Run method: only start the server when executed as a script, so importing
# this module (e.g. to reuse run_prediction) does not block inside launch().
if __name__ == "__main__":
    demo.launch()