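"""Gradio demo: extractive question answering with an ALBERT model fine-tuned on SQuAD 2.0."""
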
import torch
from torch.utils.data import DataLoader, SequentialSampler
from transformers import (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer, squad_convert_examples_to_features)
from transformers.data.processors.squad import SquadResult, SquadExample
from transformers.data.metrics.squad_metrics import compute_predictions_logits
import gradio as gr

# ALBERT xlarge v2 checkpoint fine-tuned on SQuAD 2.0, which includes unanswerable questions
model_name_or_path = "ktrapeznikov/albert-xlarge-v2-squad-v2"

# Prediction config
n_best_size = 1                   # keep only the single best candidate span
max_answer_length = 30            # maximum answer length in tokens
do_lower_case = True
null_score_diff_threshold = 0.0   # SQuAD 2.0 threshold for predicting "no answer"

def to_list(tensor):
    """Detach a (possibly GPU) tensor and return it as a plain Python list."""
    return tensor.detach().cpu().tolist()

# Set up config, tokenizer, and model, moving the model to GPU if available
config_class, model_class, tokenizer_class = (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer)
config = config_class.from_pretrained(model_name_or_path)
tokenizer = tokenizer_class.from_pretrained(model_name_or_path, do_lower_case=do_lower_case)
model = model_class.from_pretrained(model_name_or_path, config=config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

def run_prediction(context_text, question):
    """Answer `question` against `context_text` using the SQuAD prediction pipeline."""
    example = SquadExample(
        qas_id="0",
        question_text=question,
        context_text=context_text,
        answer_text=None,
        start_position_character=None,
        title="Predict",
        is_impossible=False,
        answers=None,
    )
    examples = [example]

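    # Split the context into overlapping 384-token windows (doc_stride=128)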
    features, dataset = squad_convert_examples_to_features(
        examples=examples,
        tokenizer=tokenizer,
        max_seq_length=384,
        doc_stride=128,
        max_query_length=64,
        is_training=False,
        return_dataset="pt",
        threads=1,
    )

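    # Keep windows in input order so the feature indices in each batch line up with `features`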
    eval_sampler = SequentialSampler(dataset)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=10)

    all_results = []

    for batch in eval_dataloader:
        batch = tuple(t.to(device) for t in batch)

        with torch.no_grad():
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
            }

            example_indices = batch[3]

            outputs = model(**inputs)

            # The model returns (start_logits, end_logits); pick out example i from the batch
            for i, example_index in enumerate(example_indices):
                eval_feature = features[example_index.item()]
                unique_id = int(eval_feature.unique_id)

                start_logits = to_list(outputs[0][i])
                end_logits = to_list(outputs[1][i])
                result = SquadResult(unique_id, start_logits, end_logits)
                all_results.append(result)

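    # compute_predictions_logits writes these files as a side effect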
    output_prediction_file = "predictions.json"
    output_nbest_file = "nbest_predictions.json"
    output_null_log_odds_file = "null_predictions.json"

    predictions = compute_predictions_logits(
        examples,
        features,
        all_results,
        n_best_size,
        max_answer_length,
        do_lower_case,
        output_prediction_file,
        output_nbest_file,
        output_null_log_odds_file,
        False,  # verbose_logging
        True,  # version_2_with_negative
        null_score_diff_threshold,
        tokenizer,
    )
    answer = "empty"
    for key in predictions.keys():
      if predictions[key]:
          answer=predictions[key]
          break
    return answer 
    
    
context = "4/5/2022 · In connection with the closing, Helix Acquisition Corp changed its name to MoonLake Immunotherapeutics (“MoonLake” or the “Company”). Beginning April 6, 2022, MoonLake’s shares will trade on the Nasdaq Stock Market..."
question = "What did Helix Acquisition Corp change its name to?"
title = "Question answering demo with an ALBERT QA transformer and Gradio"

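# A quick sanity check (a sketch, left commented out): with the default context,
# the model should extract "MoonLake Immunotherapeutics".
# print(run_prediction(context, question))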
# Build and launch the Gradio interface
gr.Interface(
    run_prediction,
    inputs=[
        gr.inputs.Textbox(lines=7, default=context, label="Context"),
        gr.inputs.Textbox(lines=2, default=question, label="Question"),
    ],
    outputs=[gr.outputs.Textbox(type="auto", label="Answer")],
    title=title,
    theme="peach",
).launch()