MesonWarrior commited on
Commit
2c8764e
1 Parent(s): 757fd82

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -23
app.py CHANGED
@@ -5,55 +5,51 @@ from huggingface_hub import login
5
  login(token="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")
6
 
7
  def generate_text(
8
- model,
9
  text,
10
  min_length,
11
  max_length
12
- # do_not_truncate
13
  ):
 
 
 
 
 
 
 
 
14
  pipe = pipeline(
15
  'text-generation',
16
- model='MesonWarrior/gpt2-vk-bugro',
17
- tokenizer='MesonWarrior/gpt2-vk-bugro',
18
  min_length=min_length,
19
  max_length=max_length,
20
- # do_not_truncate=do_not_truncate,
21
  use_auth_token=True
22
  )
23
 
24
- print('generating...')
25
-
26
- output = pipe(text)
27
-
28
- print(output)
29
-
30
- return output[0]['generated_text']
31
 
32
  def interface():
33
  with gr.Row():
34
  with gr.Column():
35
  with gr.Row():
36
  model = gr.Dropdown(
37
- ["Бугро", "Юморески", "Калик"], label="Model", value="Бугро",
38
  )
39
- text = gr.Textbox(lines=7, label="Input text")
40
- output = gr.Textbox(lines=12, label="Output text")
41
  with gr.Row():
42
  with gr.Column():
43
  min_length = gr.Slider(
44
  minimum=0, maximum=128, value=32, step=1,
45
  label="Min Length",
 
46
  )
47
-
48
  max_length = gr.Slider(
49
- minimum=0, maximum=512, value=96, step=1,
50
  label="Max Length",
 
51
  )
52
-
53
- # do_not_truncate = gr.Checkbox(
54
- # True,
55
- # label="Do not truncate"
56
- # )
57
  with gr.Column():
58
  with gr.Row():
59
  generate_btn = gr.Button(
@@ -67,7 +63,6 @@ def interface():
67
  text,
68
  min_length,
69
  max_length,
70
- # do_not_truncate
71
  ],
72
  outputs=output,
73
  )
@@ -77,6 +72,10 @@ with gr.Blocks(
77
  gr.Markdown("""
78
  ## GPT2 VK
79
  Файнтюны модели [ai-forever/rugpt3medium_based_on_gpt2](https://huggingface.co/ai-forever/rugpt3medium_based_on_gpt2) по вашим любимым пабликам ВКонтакте.
 
 
 
 
80
  """)
81
  interface()
82
 
 
5
  login(token="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")
6
 
7
  def generate_text(
8
+ model_name,
9
  text,
10
  min_length,
11
  max_length
 
12
  ):
13
+ models_map = {
14
+ "Юморески": "gpt2-vk-aneki",
15
+ "Калик": "gpt2-vk-kalik",
16
+ "Бугро": "gpt2-vk-bugro"
17
+ }
18
+
19
+ model = "MesonWarrior/" + models_map[model_name]
20
+
21
  pipe = pipeline(
22
  'text-generation',
23
+ model=model,
24
+ tokenizer=model,
25
  min_length=min_length,
26
  max_length=max_length,
 
27
  use_auth_token=True
28
  )
29
 
30
+ return pipe(text)[0]['generated_text']
 
 
 
 
 
 
31
 
32
  def interface():
33
  with gr.Row():
34
  with gr.Column():
35
  with gr.Row():
36
  model = gr.Dropdown(
37
+ ["Юморески", "Калик", "Бугро"], label="Модель (Текст какого паблика генерировать)", value="Бугро",
38
  )
39
+ text = gr.Textbox(lines=7, label="Входной текст", placeholder="Введите текст который продолжит нейросеть...")
40
+ output = gr.Textbox(lines=12, label="Выходной текст", placeholder="Здесь будет текст сгенерированный нейросетью...")
41
  with gr.Row():
42
  with gr.Column():
43
  min_length = gr.Slider(
44
  minimum=0, maximum=128, value=32, step=1,
45
  label="Min Length",
46
+ info="Минимальное количество символов в выходном тексте."
47
  )
 
48
  max_length = gr.Slider(
49
+ minimum=0, maximum=512, value=64, step=1,
50
  label="Max Length",
51
+ info="Максимальное количество символов в выходном тексте."
52
  )
 
 
 
 
 
53
  with gr.Column():
54
  with gr.Row():
55
  generate_btn = gr.Button(
 
63
  text,
64
  min_length,
65
  max_length,
 
66
  ],
67
  outputs=output,
68
  )
 
72
  gr.Markdown("""
73
  ## GPT2 VK
74
  Файнтюны модели [ai-forever/rugpt3medium_based_on_gpt2](https://huggingface.co/ai-forever/rugpt3medium_based_on_gpt2) по вашим любимым пабликам ВКонтакте.
75
+ Паблики представленные в моделях:
76
+ - Мои любимые юморески 🎩
77
+ - Калик) 💨
78
+ - бугро тред 🅰
79
  """)
80
  interface()
81