import torch
from torch.utils.data import DataLoader, SequentialSampler
from transformers import (
    AlbertConfig,
    AlbertForQuestionAnswering,
    AlbertTokenizer,
    squad_convert_examples_to_features,
)
from transformers.data.processors.squad import SquadResult, SquadV2Processor, SquadExample
from transformers.data.metrics.squad_metrics import compute_predictions_logits
import gradio as gr

model_name_or_path = "ktrapeznikov/albert-xlarge-v2-squad-v2"
output_dir = ""

# Config
n_best_size = 1
max_answer_length = 30
do_lower_case = True
null_score_diff_threshold = 0.0


def to_list(tensor):
    return tensor.detach().cpu().tolist()


# Setup model
config_class, model_class, tokenizer_class = (
    AlbertConfig,
    AlbertForQuestionAnswering,
    AlbertTokenizer,
)
config = config_class.from_pretrained(model_name_or_path)
tokenizer = tokenizer_class.from_pretrained(model_name_or_path, do_lower_case=do_lower_case)
model = model_class.from_pretrained(model_name_or_path, config=config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
processor = SquadV2Processor()


def run_prediction(context_text, question):
    """Compute a SQuAD v2 prediction for a single (context, question) pair."""
    print(question)

    # Wrap the raw strings in a SquadExample so the standard HF feature
    # conversion pipeline can be reused.
    examples = []
    question_texts = [question]
    for i, question_text in enumerate(question_texts):
        example = SquadExample(
            qas_id=str(i),
            question_text=question_text,
            context_text=context_text,
            answer_text=None,
            start_position_character=None,
            title="Predict",
            is_impossible=False,
            answers=None,
        )
        examples.append(example)

    features, dataset = squad_convert_examples_to_features(
        examples=examples,
        tokenizer=tokenizer,
        max_seq_length=384,
        doc_stride=128,
        max_query_length=64,
        is_training=False,
        return_dataset="pt",
        threads=1,
    )

    eval_sampler = SequentialSampler(dataset)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=10)

    # Run the model over every feature window and collect start/end logits.
    all_results = []
    model.eval()
    for batch in eval_dataloader:
        batch = tuple(t.to(device) for t in batch)
        with torch.no_grad():
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
            }
            example_indices = batch[3]
            # return_dict=False keeps the (start_logits, end_logits) tuple
            # output that the unpacking below expects on recent transformers.
            outputs = model(**inputs, return_dict=False)

            for i, example_index in enumerate(example_indices):
                eval_feature = features[example_index.item()]
                unique_id = int(eval_feature.unique_id)
                output = [to_list(output[i]) for output in outputs]
                start_logits, end_logits = output
                result = SquadResult(unique_id, start_logits, end_logits)
                all_results.append(result)

    output_prediction_file = "predictions.json"
    output_nbest_file = "nbest_predictions.json"
    output_null_log_odds_file = "null_predictions.json"

    # Decode the logits into answer text; version_2_with_negative=True lets
    # the model predict "no answer" for unanswerable questions.
    predictions = compute_predictions_logits(
        examples,
        features,
        all_results,
        n_best_size,
        max_answer_length,
        do_lower_case,
        output_prediction_file,
        output_nbest_file,
        output_null_log_odds_file,
        False,  # verbose_logging
        True,  # version_2_with_negative
        null_score_diff_threshold,
        tokenizer,
    )

    # Return the first non-empty prediction, or "empty" if none was found.
    answer = "empty"
    for key in predictions.keys():
        if predictions[key]:
            answer = predictions[key]
            break
    return answer


context = "4/5/2022 · In connection with the closing, Helix Acquisition Corp changed its name to MoonLake Immunotherapeutics (“MoonLake” or the “Company”). Beginning April 6, 2022, MoonLake’s shares will trade on the Nasdaq Stock Market..."
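# Optional sanity check before wiring up the UI: call run_prediction directly.
# (Illustrative addition; the sample question is an assumption, and the answer
# shown is just what the context above would suggest, not captured output.)
# print(run_prediction(context, "What did Helix Acquisition Corp change its name to?"))
# -> should extract something like "MoonLake Immunotherapeutics"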
question = "What did Helix Acquisition Corp change its name to?"
title = "Question Answering demo with ALBERT QA transformer and Gradio"

# Build and launch the Gradio UI (legacy Gradio 2.x component API)
gr.Interface(
    run_prediction,
    inputs=[
        gr.inputs.Textbox(lines=7, default=context, label="Context"),
        gr.inputs.Textbox(lines=2, default=question, label="Question"),
    ],
    outputs=[gr.outputs.Textbox(type="auto", label="Answer")],
    title=title,
    theme="peach",
).launch()
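# Note: gr.inputs / gr.outputs and theme="peach" are the legacy Gradio 2.x
# API, deprecated in Gradio 3 and removed later. A rough modern equivalent
# (an untested sketch, assuming a Gradio 3+ install) would be:
#
# gr.Interface(
#     run_prediction,
#     inputs=[
#         gr.Textbox(lines=7, value=context, label="Context"),
#         gr.Textbox(lines=2, value=question, label="Question"),
#     ],
#     outputs=gr.Textbox(label="Answer"),
#     title=title,
# ).launch()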