File size: 4,015 Bytes
71b4b2d
 
 
 
2c8764e
71b4b2d
 
0ec8cbc
 
 
 
71b4b2d
2c8764e
0ec8cbc
b20189a
 
2c8764e
 
 
 
71b4b2d
 
2c8764e
 
71b4b2d
92149d8
71b4b2d
 
1e9527f
71b4b2d
 
 
 
 
 
b20189a
4e1f078
a8cfe5b
71b4b2d
2c8764e
 
71b4b2d
 
 
e834122
4e1f078
2c8764e
71b4b2d
 
e834122
4e1f078
2c8764e
71b4b2d
0ec8cbc
 
 
 
 
 
ebaabfb
0ec8cbc
 
 
 
 
 
71b4b2d
 
 
4e1f078
71b4b2d
 
 
 
 
 
 
 
 
0ec8cbc
 
 
71b4b2d
 
 
 
 
 
 
3059068
a8cfe5b
3059068
7dea28d
 
 
71b4b2d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import gradio as gr
from transformers import pipeline

def generate_text(
    model_name,
    text,
    min_length,
    max_length,
    temperature,
    top_k,
    top_p
):
    models_map = {
        "Мои любимые юморески": "gpt2-vk-aneki",
        "бугро тред": "gpt2-vk-bugro",
        "Калик)": "gpt2-vk-kalik"
    }

    model = "MesonWarrior/" + models_map[model_name]

    pipe = pipeline(
        'text-generation',
        model=model,
        tokenizer=model,
        min_length=min_length,
        max_length=max_length
    )

    return pipe(text, temperature=temperature, top_k=top_k, top_p=top_p, do_sample=True)[0]['generated_text']

def interface():
    with gr.Row():
        with gr.Column():
            with gr.Row():
                model = gr.Dropdown(
                    ["Мои любимые юморески", "бугро тред", "Калик)"],
                    label="Модель (Текст какого паблика генерировать)",
                    value="Мои любимые юморески",
                )
            text = gr.Textbox(lines=7, label="Входной текст", placeholder="Введите текст который продолжит нейросеть...")
        output = gr.Textbox(lines=12, label="Выходной текст", placeholder="Здесь будет текст сгенерированный нейросетью...")
    with gr.Row():
        with gr.Column():
            min_length = gr.Slider(
                minimum=0, maximum=100, value=32, step=1,
                label="Минимальная длина",
                info="Минимальное количество символов в выходном тексте."
            )
            max_length = gr.Slider(
                minimum=0, maximum=200, value=64, step=1,
                label="Максимальная длина",
                info="Максимальное количество символов в выходном тексте."
            )
            temperature = gr.Slider(
                minimum=0.05, maximum=1.95, value=0.9, step=0.05,
                label="Температура",
                info="Чем выше тем рандомнее, чем ниже тем больше повторений."
            )
            top_k = gr.Slider(
                minimum=0, maximum=100, value=50, step=0.05,
                label="Top K",
            )
            top_p = gr.Slider(
                minimum=0, maximum=1, value=0.9, step=0.05,
                label="Top P",
            )
        with gr.Column():
            with gr.Row():
                generate_btn = gr.Button(
                    "Сгенерировать", variant="primary", label="Generate",
                )

        generate_btn.click(
            fn=generate_text,
            inputs=[
                model,
                text,
                min_length,
                max_length,
                temperature,
                top_k,
                top_p
            ],
            outputs=output,
        )

with gr.Blocks(
    title="GPT2 VK") as demo:
        gr.Markdown("""
        # GPT2 VK
        Файнтюны [этой](https://huggingface.co/ai-forever/rugpt3medium_based_on_gpt2) модели по вашим любимым пабликам ВКонтакте.
        #### Паблики представленные в моделях:
        - [Мои любимые юморески 🎩](https://huggingface.co/MesonWarrior/gpt2-vk-aneki)
        - [бугро тред 💥](https://huggingface.co/MesonWarrior/gpt2-vk-bugro)
        - [Калик) 🍏🍎💨](https://huggingface.co/MesonWarrior/gpt2-vk-kalik) <sub><sup>(Обучено на спорном датасете из постов и комментариев, надо бы переобучить на данных получше)</sup></sub>
        """)
        interface()

demo.queue().launch()