abhitopia committed on
Commit
4383992
1 Parent(s): fb8ffc5

Added AWS Sagemaker friendly code

.python-version ADDED
@@ -0,0 +1 @@
+ qa_gen
README.md CHANGED
@@ -25,26 +25,3 @@ For QA
  `question: What is 42 context: 42 is the answer to life, the universe and everything. </s>`
 
  For more details see [this](https://github.com/patil-suraj/question_generation) repo.
-
-
- ### Model in action 🚀
-
- You'll need to clone the [repo](https://github.com/patil-suraj/question_generation).
-
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/question_generation/blob/master/question_generation.ipynb)
-
- ```python3
- from pipelines import pipeline
- nlp = pipeline("multitask-qa-qg", model="valhalla/t5-base-qa-qg-hl")
-
- # to generate questions simply pass the text
- nlp("42 is the answer to life, the universe and everything.")
- => [{'answer': '42', 'question': 'What is the answer to life, the universe and everything?'}]
-
- # for qa pass a dict with "question" and "context"
- nlp({
-     "question": "What is 42 ?",
-     "context": "42 is the answer to life, the universe and everything."
- })
- => 'the answer to life, the universe and everything'
- ```
 
code/inference.py ADDED
@@ -0,0 +1,33 @@
+ import json
+ import logging
+ 
+ from qa_generator_pipeline import QAGeneratorPipeline
+ 
+ logger = logging.getLogger(__name__)
+ 
+ JSON_CONTENT_TYPE = 'application/json'
+ 
+ 
+ def model_fn(model_dir):
+     """Load the QA/QG pipeline once per SageMaker worker."""
+     logger.info(f"model_dir: {model_dir}")
+     model = QAGeneratorPipeline(model_dir=model_dir, use_cuda=True)
+     return model
+ 
+ 
+ def predict_fn(input_data, model):
+     """Run the pipeline on the deserialized request and return its output."""
+     logger.info("input text: {}".format(input_data))
+     prediction = model(input_data)
+     logger.info("prediction: {}".format(prediction))
+     return prediction
+ 
+ 
+ def input_fn(serialized_input_data, content_type=JSON_CONTENT_TYPE):
+     """Deserialize the request body; only JSON is accepted."""
+     if content_type == JSON_CONTENT_TYPE:
+         input_data = json.loads(serialized_input_data)
+         return input_data
+     raise Exception('Unsupported Content Type')
+ 
+ 
+ def output_fn(prediction_output, accept=JSON_CONTENT_TYPE):
+     """Serialize the prediction back to JSON."""
+     if accept == JSON_CONTENT_TYPE:
+         return json.dumps(prediction_output), accept
+     raise Exception('Unsupported Content Type')
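
These handlers follow the SageMaker inference-toolkit contract (`model_fn`, `input_fn`, `predict_fn`, `output_fn`). As a minimal deployment sketch, not part of this commit: assuming the model weights and this `code/` directory are packed into a `model.tar.gz` on S3, they would typically be wired up through the SageMaker Hugging Face SDK roughly as below. The bucket, IAM role, DLC versions, and instance type are placeholders.

```python
# Hypothetical deployment sketch -- bucket, role, DLC versions and instance type are assumptions.
from sagemaker.huggingface import HuggingFaceModel

huggingface_model = HuggingFaceModel(
    model_data="s3://my-bucket/t5-qa-qg/model.tar.gz",  # tarball with the weights plus this code/ dir
    role="arn:aws:iam::123456789012:role/SageMakerExecutionRole",  # placeholder execution role
    transformers_version="4.26",  # pick a DLC combination available in your region
    pytorch_version="1.13",
    py_version="py39",
)

predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g4dn.xlarge",  # GPU instance so use_cuda=True in model_fn takes effect
)

# predict_fn receives the JSON-decoded string and returns a list of {answer, question} dicts.
print(predictor.predict("42 is the answer to life, the universe and everything."))
```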
code/qa_generator_pipeline.py ADDED
@@ -0,0 +1,139 @@
+ import itertools
+ 
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ 
+ # Runtime installation of nltk is no longer needed; it is pinned in code/requirements.txt.
+ # import subprocess
+ # subprocess.call(["pip", "install", "nltk"])
+ # subprocess.call(["python", "-m", "nltk.downloader", "punkt"])
+ 
+ import nltk
+ from nltk import sent_tokenize
+ 
+ nltk.download('punkt')
+ 
+ 
+ class QAGeneratorPipeline:
+     """Poor man's QG pipeline: extracts answers from a passage and generates a question for each."""
+ 
+     def __init__(
+         self,
+         model_dir: str,
+         use_cuda: bool = True
+     ):
+         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
+         self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
+         self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
+         self.model.to(self.device)
+         assert self.model.__class__.__name__ in ["T5ForConditionalGeneration", "BartForConditionalGeneration"]
+         self.model_type = "t5"
+ 
+     def __call__(self, inputs: str):
+         # Normalize whitespace, extract candidate answers per sentence, then generate one question per answer.
+         inputs = " ".join(inputs.split())
+         sents, answers = self._extract_answers(inputs)
+         flat_answers = list(itertools.chain(*answers))
+ 
+         if len(flat_answers) == 0:
+             return []
+ 
+         qg_examples = self._prepare_inputs_for_qg_from_answers_hl(sents, answers)
+         qg_inputs = [example['source_text'] for example in qg_examples]
+         questions = self._generate_questions(qg_inputs)
+         output = [{'answer': example['answer'], 'question': que} for example, que in zip(qg_examples, questions)]
+         return output
+ 
+     def _generate_questions(self, inputs):
+         inputs = self._tokenize(inputs, padding=True, truncation=True)
+ 
+         outs = self.model.generate(
+             input_ids=inputs['input_ids'].to(self.device),
+             attention_mask=inputs['attention_mask'].to(self.device),
+             max_length=32,
+             num_beams=4,
+         )
+ 
+         questions = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
+         return questions
+ 
+     def _extract_answers(self, context):
+         sents, inputs = self._prepare_inputs_for_ans_extraction(context)
+         inputs = self._tokenize(inputs, padding=True, truncation=True)
+ 
+         outs = self.model.generate(
+             input_ids=inputs['input_ids'].to(self.device),
+             attention_mask=inputs['attention_mask'].to(self.device),
+             max_length=32,
+         )
+ 
+         # Decoded outputs look like "ans1 <sep> ans2 <sep>"; split and drop the trailing empty piece.
+         dec = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
+         answers = [item.split('<sep>') for item in dec]
+         answers = [i[:-1] for i in answers]
+ 
+         return sents, answers
+ 
+     def _tokenize(self,
+                   inputs,
+                   padding=True,
+                   truncation=True,
+                   add_special_tokens=True,
+                   max_length=512
+                   ):
+         inputs = self.tokenizer.batch_encode_plus(
+             inputs,
+             max_length=max_length,
+             add_special_tokens=add_special_tokens,
+             truncation=truncation,
+             padding="max_length" if padding else False,
+             return_tensors="pt"
+         )
+         return inputs
+ 
+     def _prepare_inputs_for_ans_extraction(self, text):
+         # Build one "extract answers:" input per sentence, with the target sentence wrapped in <hl> tokens.
+         sents = sent_tokenize(text)
+ 
+         inputs = []
+         for i in range(len(sents)):
+             source_text = "extract answers:"
+             for j, sent in enumerate(sents):
+                 if i == j:
+                     sent = "<hl> %s <hl>" % sent
+                 source_text = "%s %s" % (source_text, sent)
+             source_text = source_text.strip()
+ 
+             if self.model_type == "t5":
+                 source_text = source_text + " </s>"
+             inputs.append(source_text)
+ 
+         return sents, inputs
+ 
+     def _prepare_inputs_for_qg_from_answers_hl(self, sents, answers):
+         # Highlight each extracted answer inside its sentence to build the question-generation inputs.
+         inputs = []
+         for i, answer in enumerate(answers):
+             if len(answer) == 0:
+                 continue
+             for answer_text in answer:
+                 sent = sents[i]
+                 sents_copy = sents[:]
+ 
+                 answer_text = answer_text.strip()
+ 
+                 ans_start_idx = sent.index(answer_text)
+ 
+                 sent = f"{sent[:ans_start_idx]} <hl> {answer_text} <hl> {sent[ans_start_idx + len(answer_text):]}"
+                 sents_copy[i] = sent
+ 
+                 source_text = " ".join(sents_copy)
+                 source_text = f"generate question: {source_text}"
+                 if self.model_type == "t5":
+                     source_text = source_text + " </s>"
+ 
+                 inputs.append({"answer": answer_text, "source_text": source_text})
+ 
+         return inputs
+ 
+     def _prepare_inputs_for_qg_from_answers_prepend(self, context, answers):
+         flat_answers = list(itertools.chain(*answers))
+         examples = []
+         for answer in flat_answers:
+             source_text = f"answer: {answer} context: {context}"
+             if self.model_type == "t5":
+                 source_text = source_text + " </s>"
+ 
+             examples.append({"answer": answer, "source_text": source_text})
+         return examples
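
For a quick local check outside SageMaker, the class can be driven directly. A minimal sketch, assuming a compatible multitask T5 QA/QG checkpoint (the path below is a placeholder; this repo's downloaded weights or `valhalla/t5-base-qa-qg-hl` would fit):

```python
# Local smoke test -- the model path is an assumption; point model_dir at this
# repo's weights or any multitask T5 QA/QG checkpoint.
from qa_generator_pipeline import QAGeneratorPipeline

qg = QAGeneratorPipeline(model_dir="valhalla/t5-base-qa-qg-hl", use_cuda=False)
print(qg("42 is the answer to life, the universe and everything."))
# Expected shape: [{'answer': '42', 'question': '...'}]
```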
code/requirements.txt ADDED
@@ -0,0 +1,6 @@
+ wheel
+ torch==1.12.1
+ transformers==4.21.2
+ nltk==3.7
+ sentencepiece==0.1.97
+ protobuf==3.20
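
Once an endpoint is running with the handlers from `code/inference.py`, any client that sends `application/json` passes cleanly through `input_fn`/`output_fn`. A hedged client sketch (endpoint name and region are placeholders):

```python
# Hypothetical client call -- endpoint name and region are placeholders.
import json
import boto3

runtime = boto3.client("sagemaker-runtime", region_name="us-east-1")
response = runtime.invoke_endpoint(
    EndpointName="t5-qa-qg-endpoint",
    ContentType="application/json",  # matches JSON_CONTENT_TYPE in code/inference.py
    Accept="application/json",
    Body=json.dumps("42 is the answer to life, the universe and everything."),
)
print(json.loads(response["Body"].read()))
```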