derat0r's picture
Update app.py
edd84f4
import gradio as gr
from transformers import pipeline
qa_pipeline = pipeline(
"question-answering",
model="Den4ikAI/rubert-large-squad",
tokenizer="Den4ikAI/rubert-large-squad"
)
gen_pipeline = pipeline(
"text-generation",
model='ai-forever/rugpt3small_based_on_gpt2',
tokenizer='ai-forever/rugpt3small_based_on_gpt2'
)
def text_generation(inp):
inp = inp.split('_')
if inp[0] == "qa":
if len(inp) == 3:
predictions = qa_pipeline({
'context': inp[1],
'question': inp[2].split('*')[0]
})['answer']
if len(inp[2].split('*')) > 1:
predictions = inp[2].split('*')[1] + ' ' + predictions
else:
return f'Ожидаемая длина запроса для функции "qa": 3. Длина вашего запроса: {len(inp)}. Проверьте, используете ли Вы в качестве разделителя символ "_" и корректно ли Вы указали параметры context и question'
elif inp[0] == 'tg':
if len(inp) == 2:
predictions = gen_pipeline([inp[1]])[0][0]['generated_text'][len(inp[1])+1:]
else:
return 'Указанная Вами функция не поддерживается. Выберете между "qa" и "tg"'
return predictions
demo = gr.Interface(fn=text_generation,
inputs=gr.inputs.Textbox(lines=5, label="Input Text"),
outputs=gr.outputs.Textbox(label="Generated Text"),
)
demo.launch()