Walid Aissa committed on
Commit
41cb046
1 Parent(s): 22eefa0

BERT question answering model trained on SQuAD added

Files changed (1)
  1. app.py +75 -4
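
For orientation, the checkpoint this commit wires in (bert-large-uncased-whole-word-masking-finetuned-squad) can also be driven through the high-level Transformers question-answering pipeline. The sketch below is not how app.py implements it (the diff below builds the inputs, segment IDs, and answer span by hand); it is only a minimal reference for what the added model does, with a made-up context string:

# Minimal sketch (not part of the commit): same SQuAD-finetuned checkpoint,
# but via the high-level pipeline helper instead of the manual flow in app.py.
from transformers import pipeline

qa = pipeline(
    "question-answering",
    model="bert-large-uncased-whole-word-masking-finetuned-squad",
)

context = "The Eiffel Tower was completed in 1889 and stands in Paris."
result = qa(question="When was the Eiffel Tower completed?", context=context)
print(result["answer"], result["score"])  # answer text plus a confidence score
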
app.py CHANGED
@@ -1,13 +1,16 @@
 import os
 import gradio as gr
+import numpy as np
 import wikipediaapi as wk
 from transformers import (
     TokenClassificationPipeline,
     AutoModelForTokenClassification,
     AutoTokenizer,
 )
+import torch
 from transformers.pipelines import AggregationStrategy
-import numpy as np
+from transformers import BertForQuestionAnswering
+from transformers import BertTokenizer
 
 # =====[ DEFINE PIPELINE ]===== #
 class KeyphraseExtractionPipeline(TokenClassificationPipeline):
@@ -27,8 +30,10 @@ class KeyphraseExtractionPipeline(TokenClassificationPipeline):
         return np.unique([result.get("word").strip() for result in results])
 
 # =====[ LOAD PIPELINE ]===== #
-model_name = "ml6team/keyphrase-extraction-kbir-inspec"
-extractor = KeyphraseExtractionPipeline(model=model_name)
+keyPhraseExtractionModel = "ml6team/keyphrase-extraction-kbir-inspec"
+extractor = KeyphraseExtractionPipeline(model=keyPhraseExtractionModel)
+model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
+tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
 
 #TODO: add further preprocessing
 def keyphrases_extraction(text: str) -> str:
@@ -53,6 +58,71 @@ def wikipedia_search(input: str) -> str:
         return page.summary
     except:
         return "I cannot answer this question"
+
+def answer_question(question):
+
+    context = wikipedia_search(question)
+    if context == "I cannot answer this question":
+        return context
+
+    # ======== Tokenize ========
+    # Apply the tokenizer to the input text, treating question and context as a text pair.
+    input_ids = tokenizer.encode(question, context)
+
+    # BERT accepts at most 512 tokens; truncate the sequence from the end if it is longer.
+    while len(input_ids) > 512:
+        input_ids.pop()
+
+    print('Query has {:,} tokens.\n'.format(len(input_ids)))
+
+    # ======== Set Segment IDs ========
+    # Search the input_ids for the first instance of the `[SEP]` token.
+    sep_index = input_ids.index(tokenizer.sep_token_id)
+
+    # The number of segment A tokens includes the [SEP] token itself.
+    num_seg_a = sep_index + 1
+
+    # The remainder are segment B.
+    num_seg_b = len(input_ids) - num_seg_a
+
+    # Construct the list of 0s and 1s.
+    segment_ids = [0]*num_seg_a + [1]*num_seg_b
+
+    # There should be a segment_id for every input token.
+    assert len(segment_ids) == len(input_ids)
+
+    # ======== Evaluate ========
+    # Run our example through the model.
+    outputs = model(torch.tensor([input_ids]),  # The tokens representing our input text.
+                    token_type_ids=torch.tensor([segment_ids]),  # The segment IDs that differentiate the question from the context.
+                    return_dict=True)
+
+    start_scores = outputs.start_logits
+    end_scores = outputs.end_logits
+
+    # ======== Reconstruct Answer ========
+    # Find the tokens with the highest `start` and `end` scores.
+    answer_start = torch.argmax(start_scores)
+    answer_end = torch.argmax(end_scores)
+
+    # Get the string versions of the input tokens.
+    tokens = tokenizer.convert_ids_to_tokens(input_ids)
+
+    # Start with the first token.
+    answer = tokens[answer_start]
+
+    # Select the remaining answer tokens and join them with whitespace.
+    for i in range(answer_start + 1, answer_end + 1):
+
+        # If it's a subword token, recombine it with the previous token.
+        if tokens[i][0:2] == '##':
+            answer += tokens[i][2:]
+
+        # Otherwise, add a space and then the token.
+        else:
+            answer += ' ' + tokens[i]
+
+    return 'Answer: "' + answer + '"'
 
 # =====[ DEFINE INTERFACE ]===== #
 title = "Azza Chatbot"
@@ -62,10 +132,11 @@ examples = [
 ]
 
 
+
 demo = gr.Interface(
     title = title,
 
-    fn=wikipedia_search,
+    fn=answer_question,
     inputs = "text",
     outputs = "text",
 
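
The last hunk ends inside the gr.Interface(...) call; the remaining arguments and the launch call sit outside the changed lines and are not shown in this diff. A minimal sketch of how the updated entry point could be exercised, assuming the rest of app.py still builds the interface and calls demo.launch(), and that both checkpoints download successfully on first run:

# Hypothetical local check (not part of the commit): call the new QA path directly.
# Importing app will also load both models and build the Gradio interface.
from app import answer_question

print(answer_question("Who invented the telephone?"))
# Returns a string of the form 'Answer: "..."', or
# "I cannot answer this question" when the Wikipedia lookup fails.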