Christian Koch commited on
Commit
0df07e9
1 Parent(s): fe38db6

question generator

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ .idea/
2
+ model/*.ckpt
3
+ venv/
app.py CHANGED
@@ -1,12 +1,18 @@
1
  import streamlit as st
2
  from transformers import pipeline, PegasusForConditionalGeneration, PegasusTokenizer
 
 
3
  from fill_in_summary import FillInSummary
4
  from paraphrase import PegasusParaphraser
 
 
5
 
6
- def paraphrase(text):
7
- return text
 
8
 
9
 
 
10
  st.title('Question Generator by Eddevs')
11
 
12
  select = st.selectbox('Type', ['Question Generator', 'Paraphrasing', 'Summarization', 'Fill in the blank'])
@@ -18,17 +24,61 @@ if select == "Question Generator":
18
  # left_column.selectbox('Type', ['Question Generator', 'Paraphrasing'])
19
  #st.selectbox('Model', ['T5', 'GPT Neo-X'])
20
 
21
- text_input = st.text_area("Input Text")
 
22
 
23
- submitted = st.form_submit_button("Generate")
 
 
24
 
25
- if submitted:
26
- with st.spinner('Wait for it...'):
27
- result = FillInSummary().summarize(text_input)
28
- st.write(text_input)
29
 
 
30
 
31
- if select == "Summarization":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  with st.form("summarization"):
33
  # left_column, right_column = st.columns(2)
34
  # left_column.selectbox('Type', ['Question Generator', 'Paraphrasing'])
@@ -44,7 +94,7 @@ if select == "Summarization":
44
  st.write(text_input)
45
 
46
 
47
- if select == "Fill in the blank":
48
  with st.form("fill_in_the_blank"):
49
  text_input = st.text_area("Input Text")
50
 
@@ -58,7 +108,7 @@ if select == "Fill in the blank":
58
  st.write(result)
59
 
60
 
61
- if select == "Paraphrasing":
62
  with st.form("paraphrasing"):
63
  # st.selectbox('Model', ['T5', 'GPT Neo-X'])
64
  left_column, right_column = st.columns(2)
1
  import streamlit as st
2
  from transformers import pipeline, PegasusForConditionalGeneration, PegasusTokenizer
3
+ import nltk
4
+
5
  from fill_in_summary import FillInSummary
6
  from paraphrase import PegasusParaphraser
7
+ import question_generator as q
8
+
9
 
10
+ # Question Generator Variables
11
+ ids = {'mt5-small': st.secrets['small'],
12
+ 'mt5-base': st.secrets['base']}
13
 
14
 
15
+ st.set_page_config(layout="centered")
16
  st.title('Question Generator by Eddevs')
17
 
18
  select = st.selectbox('Type', ['Question Generator', 'Paraphrasing', 'Summarization', 'Fill in the blank'])
24
  # left_column.selectbox('Type', ['Question Generator', 'Paraphrasing'])
25
  #st.selectbox('Model', ['T5', 'GPT Neo-X'])
26
 
27
+ # Download all models from drive
28
+ q.download_models(ids)
29
 
30
+ # Model selection
31
+ model_path = st.selectbox('', options=[k for k in ids], index=1, help='Model to use. ')
32
+ model = q.load_model(model_path=f"model/{model_path}.ckpt")
33
 
34
+ text_input = st.text_area("Input Text")
 
 
 
35
 
36
+ submitted = st.form_submit_button("Generate")
37
 
38
+ split = st.checkbox('Split into sentences', value=True)
39
+
40
+ if split:
41
+ # Split into sentences
42
+ sent_tokenized = nltk.sent_tokenize(inputs)
43
+ res = {}
44
+
45
+ with st.spinner('Please wait while the inputs are being processed...'):
46
+ # Iterate over sentences
47
+ for sentence in sent_tokenized:
48
+ predictions = model.multitask([sentence], max_length=512)
49
+ questions, answers, answers_bis = predictions['questions'], predictions['answers'], predictions[
50
+ 'answers_bis']
51
+
52
+ # Build answer dict
53
+ content = {}
54
+ for question, answer, answer_bis in zip(questions[0], answers[0], answers_bis[0]):
55
+ content[question] = {'answer (extracted)': answer, 'answer (generated)': answer_bis}
56
+ res[sentence] = content
57
+
58
+ # Answer area
59
+ st.write(res)
60
+
61
+ else:
62
+ with st.spinner('Please wait while the inputs are being processed...'):
63
+ # Prediction
64
+ predictions = model.multitask([inputs], max_length=512)
65
+ questions, answers, answers_bis = predictions['questions'], predictions['answers'], predictions[
66
+ 'answers_bis']
67
+
68
+ # Answer area
69
+ zip = zip(questions[0], answers[0], answers_bis[0])
70
+ content = {}
71
+ for question, answer, answer_bis in zip:
72
+ content[question] = {'answer (extracted)': answer, 'answer (generated)': answer_bis}
73
+
74
+ st.write(content)
75
+ if submitted:
76
+ with st.spinner('Wait for it...'):
77
+ result = FillInSummary().summarize(text_input)
78
+ st.write(text_input)
79
+
80
+
81
+ elif select == "Summarization":
82
  with st.form("summarization"):
83
  # left_column, right_column = st.columns(2)
84
  # left_column.selectbox('Type', ['Question Generator', 'Paraphrasing'])
94
  st.write(text_input)
95
 
96
 
97
+ elif select == "Fill in the blank":
98
  with st.form("fill_in_the_blank"):
99
  text_input = st.text_area("Input Text")
100
 
108
  st.write(result)
109
 
110
 
111
+ elif select == "Paraphrasing":
112
  with st.form("paraphrasing"):
113
  # st.selectbox('Model', ['T5', 'GPT Neo-X'])
114
  left_column, right_column = st.columns(2)
models/.gitkeep ADDED
File without changes
mt5.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding:utf-8
2
+ """
3
+ Filename: mt5.py
4
+ Author: @DvdNss
5
+ Created on 12/30/2021
6
+ """
7
+
8
+ from typing import List
9
+
10
+ from pytorch_lightning import LightningModule
11
+ from transformers import MT5ForConditionalGeneration, AutoTokenizer
12
+
13
+
14
+ class MT5(LightningModule):
15
+ """
16
+ Google MT5 transformer class.
17
+ """
18
+
19
+ def __init__(self, model_name_or_path: str = None):
20
+ """
21
+ Initialize module.
22
+ :param model_name_or_path: model name
23
+ """
24
+
25
+ super().__init__()
26
+
27
+ # Load model and tokenizer
28
+ self.save_hyperparameters()
29
+ self.model = MT5ForConditionalGeneration.from_pretrained(
30
+ model_name_or_path) if model_name_or_path is not None else None
31
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
32
+ use_fast=True) if model_name_or_path is not None else None
33
+
34
+ def forward(self, **inputs):
35
+ """
36
+ Forward inputs.
37
+ :param inputs: dictionary of inputs (input_ids, attention_mask, labels)
38
+ """
39
+
40
+ return self.model(**inputs)
41
+
42
+ def qa(self, batch: List[dict], max_length: int = 512, **kwargs):
43
+ """
44
+ Question answering prediction.
45
+ :param batch: batch of dict {question: q, context: c}
46
+ :param max_length: max length of output
47
+ """
48
+
49
+ # Transform inputs
50
+ inputs = [f"question: {context['question']} context: {context['context']}" for context in batch]
51
+
52
+ # Predict
53
+ outputs = self.predict(inputs=inputs, max_length=max_length, **kwargs)
54
+
55
+ return outputs
56
+
57
+ def qg(self, batch: List[str] = None, max_length: int = 512, **kwargs):
58
+ """
59
+ Question generation prediction.
60
+ :param batch: batch of context with highlighted elements
61
+ :param max_length: max length of output
62
+ """
63
+
64
+ # Transform inputs
65
+ inputs = [f"generate: {context}" for context in batch]
66
+
67
+ # Predict
68
+ outputs = self.predict(inputs=inputs, max_length=max_length, **kwargs)
69
+
70
+ return outputs
71
+
72
+ def ae(self, batch: List[str], max_length: int = 512, **kwargs):
73
+ """
74
+ Answer extraction prediction.
75
+ :param batch: list of context
76
+ :param max_length: max length of output
77
+ """
78
+
79
+ # Transform inputs
80
+ inputs = [f"extract: {context}" for context in batch]
81
+
82
+ # Predict
83
+ outputs = self.predict(inputs=inputs, max_length=max_length, **kwargs)
84
+
85
+ return outputs
86
+
87
+ def multitask(self, batch: List[str], max_length: int = 512, **kwargs):
88
+ """
89
+ Answer extraction + question generation + question answering.
90
+ :param batch: list of context
91
+ :param max_length: max length of outputs
92
+ """
93
+
94
+ # Build output dict
95
+ dict_batch = {'context': [context for context in batch], 'answers': [], 'questions': [], 'answers_bis': []}
96
+
97
+ # Iterate over context
98
+ for context in batch:
99
+ answers = self.ae(batch=[context], max_length=max_length, **kwargs)[0]
100
+ answers = answers.split('<sep>')
101
+ answers = [ans.strip() for ans in answers if ans != ' ']
102
+ dict_batch['answers'].append(answers)
103
+ for_qg = [f"{context.replace(ans, f'<hl> {ans} <hl> ')}" for ans in answers]
104
+ questions = self.qg(batch=for_qg, max_length=max_length, **kwargs)
105
+ dict_batch['questions'].append(questions)
106
+ new_answers = self.qa([{'context': context, 'question': question} for question in questions],
107
+ max_length=max_length, **kwargs)
108
+ dict_batch['answers_bis'].append(new_answers)
109
+ return dict_batch
110
+
111
+ def predict(self, inputs, max_length, **kwargs):
112
+ """
113
+ Inference processing.
114
+ :param inputs: list of inputs
115
+ :param max_length: max_length of outputs
116
+ """
117
+
118
+ # Tokenize inputs
119
+ inputs = self.tokenizer(inputs, max_length=max_length, padding='max_length', truncation=True,
120
+ return_tensors="pt")
121
+
122
+ # Retrieve input_ids and attention_mask
123
+ input_ids = inputs.input_ids.to(self.model.device)
124
+ attention_mask = inputs.attention_mask.to(self.model.device)
125
+
126
+ # Predict
127
+ outputs = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=max_length,
128
+ **kwargs)
129
+
130
+ # Decode outputs
131
+ predictions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
132
+
133
+ return predictions
question_generator.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gdown as gdown
4
+ import nltk
5
+ import streamlit as st
6
+ import torch
7
+ from transformers import AutoTokenizer
8
+
9
+ from mt5 import MT5
10
+
11
+
12
+ def download_models(ids):
13
+ """
14
+ Download all models.
15
+ :param ids: name and links of models
16
+ :return:
17
+ """
18
+
19
+ # Download sentence tokenizer
20
+ nltk.download('punkt')
21
+
22
+ # Download model from drive if not stored locally
23
+ for key in ids:
24
+ if not os.path.isfile(f"model/{key}.ckpt"):
25
+ url = f"https://drive.google.com/u/0/uc?id={ids[key]}"
26
+ gdown.download(url=url, output=f"model/{key}.ckpt")
27
+
28
+
29
+ @st.cache(allow_output_mutation=True)
30
+ def load_model(model_path):
31
+ """
32
+ Load model and cache it.
33
+ :param model_path: path to model
34
+ :return:
35
+ """
36
+
37
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
38
+
39
+ # Loading model and tokenizer
40
+ model = MT5.load_from_checkpoint(model_path).eval().to(device)
41
+ model.tokenizer = AutoTokenizer.from_pretrained('tokenizer')
42
+
43
+ return model
44
+
45
+ # elif task == 'Question Answering':
46
+ #
47
+ # # Input area
48
+ # inputs = st.text_area('Context:', value="A few years after the First Crusade, in 1107, the Normans under "
49
+ # "the command of Bohemond, Robert\'s son, landed in Valona and "
50
+ # "besieged Dyrrachium using the most sophisticated military "
51
+ # "equipment of the time, but to no avail. Meanwhile, they occupied "
52
+ # "Petrela, the citadel of Mili at the banks of the river Deabolis, "
53
+ # "Gllavenica (Ballsh), Kanina and Jericho. This time, "
54
+ # "the Albanians sided with the Normans, dissatisfied by the heavy "
55
+ # "taxes the Byzantines had imposed upon them. With their help, "
56
+ # "the Normans secured the Arbanon passes and opened their way to "
57
+ # "Dibra. The lack of supplies, disease and Byzantine resistance "
58
+ # "forced Bohemond to retreat from his campaign and sign a peace "
59
+ # "treaty with the Byzantines in the city of Deabolis. ", max_chars=2048,
60
+ # height=250)
61
+ # question = st.text_input('Question:', value="What forced Bohemond to retreat from his campaign? ")
62
+ #
63
+ # # Prediction
64
+ # with st.spinner('Please wait while the inputs are being processed...'):
65
+ # predictions = model.qa([{'question': question, 'context': inputs}], max_length=512)
66
+ # answer = {question: predictions[0]}
67
+ #
68
+ # # Answer area
69
+ # st.write(answer)
70
+ #
71
+ # elif task == 'Question Generation':
72
+ #
73
+ # # Input area
74
+ # inputs = st.text_area('Context (highlight answers with <hl> tokens): ',
75
+ # value="A few years after the First Crusade, in <hl> 1107 <hl>, the <hl> Normans <hl> under "
76
+ # "the command of <hl> Bohemond <hl>, Robert\'s son, landed in Valona and "
77
+ # "besieged Dyrrachium using the most sophisticated military "
78
+ # "equipment of the time, but to no avail. Meanwhile, they occupied "
79
+ # "Petrela, <hl> the citadel of Mili <hl> at the banks of the river Deabolis, "
80
+ # "Gllavenica (Ballsh), Kanina and Jericho. This time, "
81
+ # "the Albanians sided with the Normans, dissatisfied by the heavy "
82
+ # "taxes the Byzantines had imposed upon them. With their help, "
83
+ # "the Normans secured the Arbanon passes and opened their way to "
84
+ # "Dibra. The <hl> lack of supplies, disease and Byzantine resistance <hl> "
85
+ # "forced Bohemond to retreat from his campaign and sign a peace "
86
+ # "treaty with the Byzantines in the city of Deabolis. ", max_chars=2048,
87
+ # height=250)
88
+ #
89
+ # # Split by highlights
90
+ # hl_index = [i for i in range(len(inputs)) if inputs.startswith('<hl>', i)]
91
+ # contexts = []
92
+ # answers = []
93
+ #
94
+ # # Build a context for each highlight pair
95
+ # for i in range(0, len(hl_index), 2):
96
+ # contexts.append(inputs[:hl_index[i]].replace('<hl>', '') +
97
+ # inputs[hl_index[i]: hl_index[i + 1] + 4] +
98
+ # inputs[hl_index[i + 1] + 4:].replace('<hl>', ''))
99
+ # answers.append(inputs[hl_index[i]: hl_index[i + 1] + 4].replace('<hl>', '').strip())
100
+ #
101
+ # # Prediction
102
+ # with st.spinner('Please wait while the inputs are being processed...'):
103
+ # predictions = model.qg(contexts, max_length=512)
104
+ #
105
+ # # Answer area
106
+ # content = {}
107
+ # for pred, ans in zip(predictions, answers):
108
+ # content[pred] = ans
109
+ # st.write(content)
requirements.txt CHANGED
@@ -3,3 +3,7 @@ torch
3
  tensorflow
4
  streamlit~=1.8.1
5
  sentencepiece==0.1.96
 
 
 
 
3
  tensorflow
4
  streamlit~=1.8.1
5
  sentencepiece==0.1.96
6
+ gdown~=4.3.1
7
+ nltk~=3.7
8
+ pytorch-lightning~=1.5.10
9
+ protobuf~=3.19.4
tokenizer/added_tokens.json ADDED
@@ -0,0 +1 @@
 
1
+ {"<hl>": 250100, "<sep>": 250101}
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
1
+ {"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>"}
tokenizer/spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ef78f86560d809067d12bac6c09f19a462cb3af3f54d2b8acbba26e1433125d6
3
+ size 4309802
tokenizer/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
1
+ {"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "extra_ids": 0, "additional_special_tokens": null, "special_tokens_map_file": "C:\\Users\\dvdna/.cache\\huggingface\\transformers\\685ac0ca8568ec593a48b61b0a3c272beee9bc194a3c7241d15dcadb5f875e53.f76030f3ec1b96a8199b2593390c610e76ca8028ef3d24680000619ffb646276", "name_or_path": "google/mt5-small", "sp_model_kwargs": {}, "tokenizer_class": "T5Tokenizer"}