import torch
import gradio as gr
from googletrans import Translator
from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    BartForConditionalGeneration,
    BartTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)
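# Summarization model: BART large fine-tuned on CNN/DailyMail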
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
# Question generation pipeline
class E2EQGPipeline:
    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer
    ):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Move the model to the same device the input tensors are sent to during generation
        self.model = model.to(self.device)
        self.tokenizer = tokenizer
        self.model_type = "t5"
        self.kwargs = {
            "max_length": 256,
            "num_beams": 4,
            "length_penalty": 1.5,
            "no_repeat_ngram_size": 3,
            "early_stopping": True,
        }
    def generate_questions(self, context: str):
        inputs = self._prepare_inputs_for_e2e_qg(context)
        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            **self.kwargs
        )
        prediction = self.tokenizer.decode(outs[0], skip_special_tokens=True)
        # The model separates questions with "<sep>"; the trailing element after the final "<sep>" is dropped
        questions = prediction.split("<sep>")
        questions = [question.strip() for question in questions[:-1]]
        return questions
    def _prepare_inputs_for_e2e_qg(self, context):
        source_text = f"generate questions: {context}"
        inputs = self._tokenize([source_text], padding=False)
        return inputs
    def _tokenize(
        self,
        inputs,
        padding=True,
        truncation=True,
        add_special_tokens=True,
        max_length=512
    ):
        inputs = self.tokenizer.batch_encode_plus(
            inputs,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding="max_length" if padding else False,
            return_tensors="pt"
        )
        return inputs
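# Generate study questions in English with valhalla/t5-base-e2e-qg, then translate them to Spanish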
def generate_questions(text):
    qg_model = T5ForConditionalGeneration.from_pretrained('valhalla/t5-base-e2e-qg')
    qg_tokenizer = T5Tokenizer.from_pretrained('valhalla/t5-base-e2e-qg')
    qg_final_model = E2EQGPipeline(qg_model, qg_tokenizer)
    questions = qg_final_model.generate_questions(text)
    translator = Translator()
    translated_questions = [translator.translate(question, dest='es').text for question in questions]
    return translated_questions
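# Summarize the input text with the BART model loaded above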
def generate_summary(text):
    # bart-large-cnn expects no task prefix, so the raw text is encoded directly
    inputs = tokenizer.encode(text, return_tensors="pt", max_length=1024, truncation=True)
    summary_ids = model.generate(inputs, max_length=150, min_length=50, length_penalty=2.0, num_beams=4, early_stopping=True)
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary
# QA
# Load the question-answering model
qa_model_name = "MaRiOrOsSi/t5-base-finetuned-question-answering"
qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
qa_model = AutoModelForSeq2SeqLM.from_pretrained(qa_model_name)
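# Answer a question about the text with the QA model and translate the answer to Spanish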
def generate_question_response(question, context):
    # Build the input in the "question: ... context: ..." format the model expects
    input_text = f"question: {question} context: {context}"
    encoded_input = qa_tokenizer(input_text, return_tensors='pt', max_length=1024, truncation=True)
    output = qa_model.generate(input_ids=encoded_input['input_ids'], attention_mask=encoded_input['attention_mask'])
    response_en = qa_tokenizer.decode(output[0], skip_special_tokens=True)
    translator = Translator()
    translated_response = translator.translate(response_en, dest='es').text
    return f'Respuesta: {translated_response}'
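# Caches the last input so the summary and study questions are only recomputed when the text changes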
class SummarizerAndQA:
    def __init__(self):
        self.input_text = ''
        self.question = ''
        self.summary = ''
        self.study_generated_questions = ''
        self.question_response = ''
    def process(self, text, question):
        # Recompute the summary and study questions only when the input text changes
        if text != self.input_text:
            self.input_text = text
            self.summary = generate_summary(text)
            self.study_generated_questions = generate_questions(text)
        # Answer only when a new question arrives and there is text to answer from
        if question != self.question and text != '':
            self.question = question
            self.question_response = generate_question_response(question, text)
        return self.summary, self.study_generated_questions, self.question_response
summarizer_and_qa = SummarizerAndQA()
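# Gradio UI components (labels are in Spanish, matching the app's output language)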
textbox_input = gr.Textbox(label="Pega el texto acá:", placeholder="Texto...", lines=15)
question_input = gr.Textbox(label="Pregunta sobre el texto acá:", placeholder="Mensaje...", lines=15)
summary_output = gr.Textbox(label="Resumen", lines=15)
questions_output = gr.Textbox(label="Preguntas de guía generadas", lines=5)
questions_response = gr.Textbox(label="Respuestas", lines=5)
demo = gr.Interface(fn=summarizer_and_qa.process, inputs=[textbox_input, question_input], outputs=[summary_output, questions_output, questions_response], allow_flagging="never")
demo.launch()