Lautaro Cardarelli committed
Commit f34392a · 1 Parent(s): 471e321

add question generation

Files changed (2)
  1. app.py +88 -2
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,11 +1,97 @@
 import gradio as gr
+import pandas as pd
+import torch
+from googletrans import Translator
+from transformers import T5Tokenizer
+from transformers import T5ForConditionalGeneration
 from transformers import BartForConditionalGeneration
 from transformers import BartTokenizer
+from transformers import pipeline
 
 tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
 model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
 
 
+
+from transformers import PreTrainedModel
+from transformers import PreTrainedTokenizer
+
+# Question launcher
+class E2EQGPipeline:
+    def __init__(
+        self,
+        model: PreTrainedModel,
+        tokenizer: PreTrainedTokenizer
+    ):
+
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        self.model = model
+        self.tokenizer = tokenizer
+
+        self.model_type = "t5"
+
+        self.kwargs = {
+            "max_length": 256,
+            "num_beams": 4,
+            "length_penalty": 1.5,
+            "no_repeat_ngram_size": 3,
+            "early_stopping": True,
+        }
+
+    def generate_questions(self, context: str):
+        inputs = self._prepare_inputs_for_e2e_qg(context)
+
+        outs = self.model.generate(
+            input_ids=inputs['input_ids'].to(self.device),
+            attention_mask=inputs['attention_mask'].to(self.device),
+            **self.kwargs
+        )
+
+        prediction = self.tokenizer.decode(outs[0], skip_special_tokens=True)
+
+        questions = prediction.split("<sep>")
+        questions = [question.strip() for question in questions[:-1]]
+        return questions
+
+    def _prepare_inputs_for_e2e_qg(self, context):
+        source_text = f"generate questions: {context}"
+
+        inputs = self._tokenize([source_text], padding=False)
+
+        return inputs
+
+    def _tokenize(
+        self,
+        inputs,
+        padding=True,
+        truncation=True,
+        add_special_tokens=True,
+        max_length=512
+    ):
+        inputs = self.tokenizer.batch_encode_plus(
+            inputs,
+            max_length=max_length,
+            add_special_tokens=add_special_tokens,
+            truncation=truncation,
+            padding="max_length" if padding else False,
+            pad_to_max_length=padding,
+            return_tensors="pt"
+        )
+
+        return inputs
+
+
+def generate_questions(text):
+    qg_model = T5ForConditionalGeneration.from_pretrained('valhalla/t5-base-e2e-qg')
+    qg_tokenizer = T5Tokenizer.from_pretrained('valhalla/t5-base-e2e-qg')
+    qg_final_model = E2EQGPipeline(qg_model, qg_tokenizer)
+    questions = qg_final_model.generate_questions(text)
+    translator = Translator()
+    translated_questions = [translator.translate(question, dest='es').text for question in questions]
+    return translated_questions
+
+
 def generate_summary(text):
     inputs = tokenizer.encode("summarize: " + text, return_tensors="pt", max_length=1024, truncation=True)
     summary_ids = model.generate(inputs, max_length=150, min_length=50, length_penalty=2.0, num_beams=4, early_stopping=True)
@@ -14,9 +100,9 @@ def generate_summary(text):
 
 
 def process(text):
-    return generate_summary(text)
+    return generate_summary(text), generate_questions(text)
 
 
 textbox = gr.Textbox(label="Pega el text aca:", placeholder="Texto...", lines=15)
-demo = gr.Interface(fn=process, inputs=textbox, outputs="text")
+demo = gr.Interface(fn=process, inputs=textbox, outputs=["text", "text"])
 demo.launch()
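
A hedged note on the code added above (an editor's observation, not part of the commit): E2EQGPipeline.generate_questions moves the input tensors to self.device, but the model itself is never moved, so model.generate would raise a device-mismatch error on a CUDA machine; the module-level generate_questions(text) also reloads the T5 weights on every request. A minimal sketch of both fixes, assuming the Space may run on GPU:

# Sketch only; names mirror the diff above, not a confirmed fix in this repo.
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Load the question-generation model once at import time instead of on
# every call to generate_questions(text).
qg_model = T5ForConditionalGeneration.from_pretrained('valhalla/t5-base-e2e-qg')
qg_tokenizer = T5Tokenizer.from_pretrained('valhalla/t5-base-e2e-qg')
qg_pipeline = E2EQGPipeline(qg_model, qg_tokenizer)

# And inside E2EQGPipeline.__init__, keep the model on the same device as
# the inputs that generate_questions() later moves there:
#     self.device = "cuda" if torch.cuda.is_available() else "cpu"
#     self.model = model.to(self.device)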
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 transformers
 torch
-accelerate
+accelerate
+googletrans
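
One assumption-laden note on the new dependency (about the build environment, not something the commit states): left unpinned, googletrans resolves to its 3.0.0 stable release, which is widely reported to fail against current httpx/httpcore versions; pinning the release candidate is the usual workaround:

transformers
torch
accelerate
googletrans==4.0.0rc1  # hypothetical pin; the unpinned 3.0.0 release is known to break with newer httpx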