MesonWarrior commited on
Commit
71b4b2d
1 Parent(s): 9eb24f9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+
4
+ def generate_text(
5
+ model,
6
+ text,
7
+ min_length,
8
+ max_length,
9
+ do_not_truncate,
10
+ ):
11
+ pipe = pipeline(
12
+ 'text-generation',
13
+ model='MesonWarrior/gpt2-bugro',
14
+ tokenizer='MesonWarrior/gpt2-bugro',
15
+ min_length=min_length,
16
+ max_length=max_length,
17
+ use_auth_token="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"
18
+ )
19
+
20
+ return pipe(text)[0]['generated_text']
21
+
22
+ def interface():
23
+ with gr.Row():
24
+ with gr.Column():
25
+ with gr.Row():
26
+ model = gr.Dropdown(
27
+ ["Бугро", "Юморески", "Калик"], label="Model", value="Бугро",
28
+ )
29
+ text = gr.Textbox(lines=7, label="Input text")
30
+ output = gr.Textbox(lines=12, label="Output text")
31
+ with gr.Row():
32
+ with gr.Column():
33
+ min_length = gr.Slider(
34
+ minimum=0, maximum=128, value=32, step=1,
35
+ label="Min Length",
36
+ )
37
+
38
+ max_length = gr.Slider(
39
+ minimum=0, maximum=512, value=96, step=1,
40
+ label="Max Length",
41
+ )
42
+
43
+ do_not_truncate = gr.Checkbox(
44
+ True,
45
+ label="Do not truncate"
46
+ )
47
+ with gr.Column():
48
+ with gr.Row():
49
+ generate_btn = gr.Button(
50
+ "Generate", variant="primary", label="Generate",
51
+ )
52
+
53
+ generate_btn.click(
54
+ fn=generate_text,
55
+ inputs=[
56
+ model,
57
+ text,
58
+ min_length,
59
+ max_length,
60
+ do_not_truncate
61
+ ],
62
+ outputs=output,
63
+ )
64
+
65
+ with gr.Blocks(
66
+ title="GPT2 VK") as demo:
67
+ gr.Markdown("""
68
+ ## GPT2 VK
69
+ Файнтюны модели [ai-forever/rugpt3medium_based_on_gpt2](https://huggingface.co/ai-forever/rugpt3medium_based_on_gpt2) по вашим любимым пабликам ВКонтакте.
70
+ """)
71
+ interface()
72
+
73
+ demo.queue().launch()