# T5 / app.py
# hf-dongpyo's picture
# Update app.py
# d312a64
# from transformers import AutoModelWithLMHead, AutoTokenizer
# Translate
from transformers import T5ForConditionalGeneration, T5Tokenizer
import gradio as grad
# make a question
# text2text_tkn = AutoTokenizer.from_pretrained('mrm8488/t5-base-finetuned-question-generation-ap')
# mdl = AutoModelWithLMHead.from_pretrained('mrm8488/t5-base-finetuned-question-generation-ap')
# summarize
# text2text_tkn = AutoTokenizer.from_pretrained('deep-learning-analytics/wikihow-t5-small')
# mdl = AutoModelWithLMHead.from_pretrained('deep-learning-analytics/wikihow-t5-small')
# translate, sentiment
# Load the stock 't5-small' checkpoint once, at module import time; the
# text2text_* helpers in this file all read these two module-level objects.
_CHECKPOINT = 't5-small'
text2text_tkn = T5Tokenizer.from_pretrained(_CHECKPOINT)
mdl = T5ForConditionalGeneration.from_pretrained(_CHECKPOINT)
def text2text(context, answer):
    """Generate text for an (answer, context) pair with the loaded T5 model.

    Builds an ``"answer: <answer> context: <context> </s>"`` prompt, runs it
    through the module-level ``text2text_tkn`` / ``mdl`` pair and decodes the
    first generated sequence.

    NOTE(review): this prompt format looks intended for the question-generation
    fine-tune referenced in the commented-out loader
    ('mrm8488/t5-base-finetuned-question-generation-ap'); the stock 't5-small'
    weights may not have been trained on it — confirm which checkpoint is loaded.

    Args:
        context: Passage the answer comes from.
        answer: Answer span to generate a question for.

    Returns:
        The generated text (at most 64 tokens) with special tokens removed.
    """
    # The literal "</s>" mirrors the prompt format the fine-tuned checkpoint
    # expects. NOTE(review): the tokenizer may also append its own EOS token.
    input_text = f"answer: {answer} context: {context} </s>"
    features = text2text_tkn([input_text], return_tensors='pt')
    output = mdl.generate(
        input_ids=features['input_ids'],
        attention_mask=features['attention_mask'],
        max_length=64,
    )
    # skip_special_tokens strips markers such as <pad> and </s>; without it they
    # leak into the text shown to the user (text2text_summary already does this).
    return text2text_tkn.decode(output[0], skip_special_tokens=True)
def text2text_summary(para):
    """Summarize a (possibly multi-line) paragraph with the loaded model.

    The text is whitespace-normalized, encoded, and decoded with beam search
    (5 beams, repetition penalty 2.5, at most 250 tokens).

    NOTE(review): stock 't5-small' usually expects a "summarize: " task prefix;
    this function sends the raw text, matching the commented-out
    'deep-learning-analytics/wikihow-t5-small' setup — confirm which applies.

    Args:
        para: Paragraph to summarize; may contain newlines.

    Returns:
        The decoded summary with special tokens removed.
    """
    # Collapse every whitespace run (newlines included) into a single space.
    # The former replace("\n", "") fused words across line breaks, e.g.
    # "end\nstart" became "endstart".
    initial_txt = " ".join(para.split())
    tkn_text = text2text_tkn.encode(initial_txt, return_tensors='pt')
    tkn_ids = mdl.generate(
        tkn_text,
        max_length=250,
        num_beams=5,
        repetition_penalty=2.5,
        early_stopping=True,
    )
    return text2text_tkn.decode(tkn_ids[0], skip_special_tokens=True)
def text2text_translate(text, max_length=128):
    """Translate English *text* to German using T5's translation task prefix.

    Args:
        text: English input text.
        max_length: Maximum number of generated tokens. Without it, ``generate``
            falls back to its short library default and can truncate longer
            translations.

    Returns:
        The list produced by ``batch_decode`` (one entry for the single input),
        with special tokens such as <pad> and </s> removed.
    """
    # T5's documented prefix is "translate English to German: " with a single
    # colon; the previous "::" did not match the format the model expects.
    inp = "translate English to German: " + text
    enc = text2text_tkn(inp, return_tensors='pt')
    tokens = mdl.generate(**enc, max_length=max_length)
    return text2text_tkn.batch_decode(tokens, skip_special_tokens=True)
def text2text_sentiment(text):
    """Classify the sentiment of *text* via T5's SST-2 task prefix.

    Args:
        text: English sentence to classify.

    Returns:
        The list produced by ``batch_decode`` (one entry for the single input)
        holding the model's label text — presumably "positive"/"negative" for
        the SST-2 task; not validated here. Special tokens are removed.
    """
    inp = "sst2 sentence: " + text
    enc = text2text_tkn(inp, return_tensors='pt')
    tokens = mdl.generate(**enc)
    # Strip <pad>/</s> so only the label text reaches the UI.
    return text2text_tkn.batch_decode(tokens, skip_special_tokens=True)
def text2text_acceptable_sentence(text):
    """Judge grammatical acceptability of *text* via T5's CoLA task prefix.

    Args:
        text: English sentence to judge.

    Returns:
        The list produced by ``batch_decode`` (one entry for the single input)
        holding the model's label text — presumably "acceptable"/"unacceptable"
        for the CoLA task; not validated here. Special tokens are removed.
    """
    inp = "cola sentence: " + text
    enc = text2text_tkn(inp, return_tensors='pt')
    tokens = mdl.generate(**enc)
    # Strip <pad>/</s> so only the label text reaches the UI.
    return text2text_tkn.batch_decode(tokens, skip_special_tokens=True)
def text2text_paraphrase(sentence1, sentence2):
    """Ask the model whether two sentences are paraphrases (T5's MRPC prefix).

    Args:
        sentence1: First English sentence.
        sentence2: Second English sentence.

    Returns:
        The list produced by ``batch_decode`` (one entry for the single input)
        holding the model's label text — presumably "equivalent" /
        "not_equivalent" for the MRPC task; not validated here. Special tokens
        are removed.
    """
    combined_inp = "mrpc sentence1: " + sentence1 + " " + "sentence2: " + sentence2
    enc = text2text_tkn(combined_inp, return_tensors='pt')
    tokens = mdl.generate(**enc)
    # Strip <pad>/</s> so only the label text reaches the UI.
    return text2text_tkn.batch_decode(tokens, skip_special_tokens=True)
def text2text_deductible(sentence1, sentence2):
    """Ask the model whether *sentence1* entails *sentence2* (T5's RTE prefix).

    Args:
        sentence1: Premise sentence in English.
        sentence2: Hypothesis sentence in English.

    Returns:
        The list produced by ``batch_decode`` (one entry for the single input)
        holding the model's label text — presumably "entailment" /
        "not_entailment" for the RTE task; not validated here. Special tokens
        are removed.
    """
    inp1 = "rte sentence1: " + sentence1
    inp2 = "sentence2: " + sentence2
    # Previously assigned to a misspelled name (`combined_inpu`) while the next
    # line read `combined_inp`, so every call raised NameError.
    combined_inp = inp1 + " " + inp2
    enc = text2text_tkn(combined_inp, return_tensors='pt')
    tokens = mdl.generate(**enc)
    # Strip <pad>/</s> so only the label text reaches the UI.
    return text2text_tkn.batch_decode(tokens, skip_special_tokens=True)
# question
# context = grad.Textbox(lines = 10, label = 'English', placeholder = 'Context')
# ans = grad.Textbox(lines = 1, label = 'Answer')
# out = grad.Textbox(lines = 1, label = 'Generated Question')
# summary
# para = grad.Textbox(lines = 10, label = 'Paragraph', placeholder = 'Copy paragraph')
# out = grad.Textbox(lines = 1, label = 'Summary')
# translate
# para = grad.Textbox(lines = 1, label = 'English Text', placeholder = 'Text in English')
# out = grad.Textbox(lines = 1, label = 'German Translation')
# sentiment
# para = grad.Textbox(lines = 1, label = 'English Text', placeholder = 'Text in English')
# out = grad.Textbox(lines = 1, label = 'Sentiment')
# # grammatical acceptance
# para = grad.Textbox(lines = 1, label = 'English Text', placeholder = 'Text in English')
# out = grad.Textbox(lines = 1, label = 'Whether the sentence is acceptable or not')
# # paraphrase
# sent1 = grad.Textbox(lines = 1, label = 'Sentence1', placeholder = 'Text in English')
# sent2 = grad.Textbox(lines = 1, label = 'Sentence2', placeholder = 'Text in English')
# out = grad.Textbox(lines = 1, label = 'Paraphrase')
# deduction
# Active demo: "deduction" (RTE-style entailment) between two English sentences.
# Other tasks available in this file, each paired with its components from the
# commented-out sections above:
#   text2text                      -> inputs=[context, ans]
#   text2text_summary / text2text_translate / text2text_sentiment /
#   text2text_acceptable_sentence  -> inputs=para
#   text2text_paraphrase           -> inputs=[sent1, sent2]
_english_line = dict(lines=1, placeholder='Text in English')
sent1 = grad.Textbox(label='Sentence1', **_english_line)
sent2 = grad.Textbox(label='Sentence2', **_english_line)
out = grad.Textbox(lines=1, label='Deduction')

demo = grad.Interface(text2text_deductible, inputs=[sent1, sent2], outputs=out)
demo.launch()