File size: 1,585 Bytes
529d3a2
 
 
 
 
 
edd84f4
 
529d3a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import gradio as gr
from transformers import pipeline


qa_pipeline = pipeline(
    "question-answering",
    model="Den4ikAI/rubert-large-squad",
    tokenizer="Den4ikAI/rubert-large-squad"
)
gen_pipeline = pipeline(
    "text-generation",
    model='ai-forever/rugpt3small_based_on_gpt2',
    tokenizer='ai-forever/rugpt3small_based_on_gpt2'
)


def text_generation(inp):
    inp = inp.split('_')
    if inp[0] == "qa":
        if len(inp) == 3:
            predictions = qa_pipeline({
                'context': inp[1],
                'question': inp[2].split('*')[0]
            })['answer']
            if len(inp[2].split('*')) > 1:
                predictions = inp[2].split('*')[1] + ' ' + predictions
        else:
            return f'Ожидаемая длина запроса для функции "qa": 3. Длина вашего запроса: {len(inp)}. Проверьте, используете ли Вы в качестве разделителя символ "_" и корректно ли Вы указали параметры context и question'
    elif inp[0] == 'tg':
        if len(inp) == 2:
            predictions = gen_pipeline([inp[1]])[0][0]['generated_text'][len(inp[1])+1:]
    else:
        return 'Указанная Вами функция не поддерживается. Выберете между "qa" и "tg"'
    return predictions

demo = gr.Interface(fn=text_generation,
             inputs=gr.inputs.Textbox(lines=5, label="Input Text"),
             outputs=gr.outputs.Textbox(label="Generated Text"),
             )
demo.launch()