abhitopia committed
Commit 4383992 • 1 Parent(s): fb8ffc5

Added AWS Sagemaker friendly code

Browse files
- .python-version +1 -0
- README.md +0 -23
- code/inference.py +33 -0
- code/qa_generator_pipeline.py +139 -0
- code/requirements.txt +6 -0
.python-version
ADDED
@@ -0,0 +1 @@
+qa_gen
README.md
CHANGED
@@ -25,26 +25,3 @@
 `question: What is 42 context: 42 is the answer to life, the universe and everything. </s>`
 
 For more details see [this](https://github.com/patil-suraj/question_generation) repo.
-
-
-### Model in action 🚀
-
-You'll need to clone the [repo](https://github.com/patil-suraj/question_generation).
-
-[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/question_generation/blob/master/question_generation.ipynb)
-
-```python3
-from pipelines import pipeline
-nlp = pipeline("multitask-qa-qg", model="valhalla/t5-base-qa-qg-hl")
-
-# to generate questions simply pass the text
-nlp("42 is the answer to life, the universe and everything.")
-=> [{'answer': '42', 'question': 'What is the answer to life, the universe and everything?'}]
-
-# for qa pass a dict with "question" and "context"
-nlp({
-    "question": "What is 42 ?",
-    "context": "42 is the answer to life, the universe and everything."
-})
-=> 'the answer to life, the universe and everything'
-```
code/inference.py
ADDED
@@ -0,0 +1,33 @@
import json
import logging

from qa_generator_pipeline import QAGeneratorPipeline

logger = logging.getLogger(__name__)

JSON_CONTENT_TYPE = 'application/json'


def model_fn(model_dir):
    # Called once by the SageMaker inference toolkit to load the model from
    # the unpacked model.tar.gz directory.
    logger.info(f"model_dir: {model_dir}")
    model = QAGeneratorPipeline(model_dir=model_dir, use_cuda=True)
    return model


def predict_fn(input_data, model):
    # Run the QA-generation pipeline on the deserialized request.
    logger.info("input text: {}".format(input_data))
    prediction = model(input_data)
    logger.info("prediction: {}".format(prediction))
    return prediction


def input_fn(serialized_input_data, content_type=JSON_CONTENT_TYPE):
    # Deserialize the request body; only JSON is accepted.
    if content_type == JSON_CONTENT_TYPE:
        return json.loads(serialized_input_data)
    raise Exception('Unsupported Content Type')


def output_fn(prediction_output, accept=JSON_CONTENT_TYPE):
    # Serialize the prediction back to JSON for the response.
    if accept == JSON_CONTENT_TYPE:
        return json.dumps(prediction_output), accept
    raise Exception('Unsupported Content Type')
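At runtime the SageMaker inference toolkit wires these handlers together, but the chain can be exercised locally before deploying. A minimal smoke-test sketch, assuming it runs from inside `code/` and that `./model` is a hypothetical directory holding an unpacked copy of the model artifacts:

```python
import json
from inference import model_fn, input_fn, predict_fn, output_fn

# Load the pipeline once, as SageMaker does at container start-up.
model = model_fn("./model")  # hypothetical local model directory

# Simulate a single JSON invocation end to end.
payload = json.dumps("42 is the answer to life, the universe and everything.")
data = input_fn(payload, "application/json")
prediction = predict_fn(data, model)
body, content_type = output_fn(prediction, "application/json")
print(content_type, body)
```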
code/qa_generator_pipeline.py
ADDED
@@ -0,0 +1,139 @@
import itertools

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

import nltk
from nltk import sent_tokenize

nltk.download('punkt')


class QAGeneratorPipeline:
    """Poor man's QG pipeline: extracts candidate answers from a passage,
    then generates a question for each answer with the same T5 model."""

    def __init__(
        self,
        model_dir: str,
        use_cuda: bool = True
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
        self.model.to(self.device)
        assert self.model.__class__.__name__ in ["T5ForConditionalGeneration", "BartForConditionalGeneration"]
        self.model_type = "t5"

    def __call__(self, inputs: str):
        # Normalize whitespace, extract answers per sentence, then generate
        # one question per (sentence, answer) pair.
        inputs = " ".join(inputs.split())
        sents, answers = self._extract_answers(inputs)
        flat_answers = list(itertools.chain(*answers))

        if len(flat_answers) == 0:
            return []

        qg_examples = self._prepare_inputs_for_qg_from_answers_hl(sents, answers)
        qg_inputs = [example['source_text'] for example in qg_examples]
        questions = self._generate_questions(qg_inputs)
        output = [{'answer': example['answer'], 'question': que}
                  for example, que in zip(qg_examples, questions)]
        return output

    def _generate_questions(self, inputs):
        inputs = self._tokenize(inputs, padding=True, truncation=True)

        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            max_length=32,
            num_beams=4,
        )

        questions = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
        return questions

    def _extract_answers(self, context):
        sents, inputs = self._prepare_inputs_for_ans_extraction(context)
        inputs = self._tokenize(inputs, padding=True, truncation=True)

        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            max_length=32,
        )

        # The model emits answers separated by <sep>; the final split element
        # is the empty trailing segment, so drop it.
        dec = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
        answers = [item.split('<sep>') for item in dec]
        answers = [i[:-1] for i in answers]

        return sents, answers

    def _tokenize(self,
                  inputs,
                  padding=True,
                  truncation=True,
                  add_special_tokens=True,
                  max_length=512
                  ):
        inputs = self.tokenizer.batch_encode_plus(
            inputs,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding="max_length" if padding else False,
            return_tensors="pt"
        )
        return inputs

    def _prepare_inputs_for_ans_extraction(self, text):
        # Build one "extract answers:" prompt per sentence, marking the
        # target sentence with <hl> tokens.
        sents = sent_tokenize(text)

        inputs = []
        for i in range(len(sents)):
            source_text = "extract answers:"
            for j, sent in enumerate(sents):
                if i == j:
                    sent = "<hl> %s <hl>" % sent
                source_text = "%s %s" % (source_text, sent)
            source_text = source_text.strip()

            if self.model_type == "t5":
                source_text = source_text + " </s>"
            inputs.append(source_text)

        return sents, inputs

    def _prepare_inputs_for_qg_from_answers_hl(self, sents, answers):
        # For each extracted answer, highlight it inside its source sentence
        # and build a "generate question:" prompt over the full passage.
        inputs = []
        for i, answer in enumerate(answers):
            if len(answer) == 0:
                continue
            for answer_text in answer:
                sent = sents[i]
                sents_copy = sents[:]

                answer_text = answer_text.strip()

                # Skip answers that do not occur verbatim in the sentence,
                # which would otherwise raise a ValueError from .index().
                if answer_text not in sent:
                    continue
                ans_start_idx = sent.index(answer_text)

                sent = f"{sent[:ans_start_idx]} <hl> {answer_text} <hl> {sent[ans_start_idx + len(answer_text):]}"
                sents_copy[i] = sent

                source_text = " ".join(sents_copy)
                source_text = f"generate question: {source_text}"
                if self.model_type == "t5":
                    source_text = source_text + " </s>"

                inputs.append({"answer": answer_text, "source_text": source_text})

        return inputs

    def _prepare_inputs_for_qg_from_answers_prepend(self, context, answers):
        # Alternative QG input format: prepend the answer to the context
        # instead of highlighting it in the passage.
        flat_answers = list(itertools.chain(*answers))
        examples = []
        for answer in flat_answers:
            source_text = f"answer: {answer} context: {context}"
            if self.model_type == "t5":
                source_text = source_text + " </s>"

            examples.append({"answer": answer, "source_text": source_text})
        return examples
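The pipeline can also be driven directly for quick local testing, without any SageMaker plumbing. A minimal usage sketch, assuming the fine-tuned model and tokenizer are saved under a hypothetical `./model` directory:

```python
from qa_generator_pipeline import QAGeneratorPipeline

# "./model" is a placeholder for any directory containing the fine-tuned
# weights and tokenizer files for this repo's model.
qa_gen = QAGeneratorPipeline(model_dir="./model", use_cuda=False)

pairs = qa_gen("42 is the answer to life, the universe and everything.")
print(pairs)
# e.g. [{'answer': '42', 'question': 'What is the answer to life, the universe and everything?'}]
```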
code/requirements.txt
ADDED
@@ -0,0 +1,6 @@
wheel
torch==1.12.1
transformers==4.21.2
nltk==3.7
sentencepiece==0.1.97
protobuf==3.20
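SageMaker installs these pins from `code/requirements.txt` inside the endpoint container. For deployment, the Hugging Face inference toolkit expects a `model.tar.gz` with the weights at the archive root and the handler under `code/`; below is a minimal packaging sketch, where the local paths and the exact artifact file names are assumptions about this repo's layout:

```python
import tarfile

with tarfile.open("model.tar.gz", "w:gz") as tar:
    # Model and tokenizer artifacts go at the root of the archive
    # (file names assumed; match whatever save_pretrained produced).
    for name in ("pytorch_model.bin", "config.json",
                 "tokenizer_config.json", "spiece.model"):
        tar.add(f"./model/{name}", arcname=name)
    # The inference handler and its dependencies live under code/.
    for name in ("inference.py", "qa_generator_pipeline.py", "requirements.txt"):
        tar.add(f"./code/{name}", arcname=f"code/{name}")
```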