Jorj2064 commited on
Commit
83d6417
1 Parent(s): 1f45a41

Upload Gradio_GPT_bot.ipynb

Browse files
Files changed (1) hide show
  1. Gradio_GPT_bot.ipynb +459 -0
Gradio_GPT_bot.ipynb ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "collapsed_sections": [
8
+ "j4fr7yW8YPav",
9
+ "E_ZbiAPAYJEF",
10
+ "MgzsPmnx90zW",
11
+ "o_qprIWn6bY5",
12
+ "quC2dkLoJqoG",
13
+ "z41veiFa62k2",
14
+ "ywqlE6TaCrLq",
15
+ "5ZPgocB8XiV8",
16
+ "BA2tieoUVQzp",
17
+ "A5XAUFILVZlU",
18
+ "F3ztWSyGWsQt",
19
+ "BfWA10d0Vf7J",
20
+ "9BiPlUf6YOzO"
21
+ ]
22
+ },
23
+ "kernelspec": {
24
+ "name": "python3",
25
+ "display_name": "Python 3"
26
+ },
27
+ "language_info": {
28
+ "name": "python"
29
+ }
30
+ },
31
+ "cells": [
32
+ {
33
+ "cell_type": "markdown",
34
+ "source": [
35
+ "# Imports"
36
+ ],
37
+ "metadata": {
38
+ "id": "j4fr7yW8YPav"
39
+ }
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "source": [
44
+ "%%capture\n",
45
+ "# установка gradio для написания веб интерефейса\n",
46
+ "# установка transformers для использования языковых моделей с платформы HF\n",
47
+ "\n",
48
+ "!pip install gradio transformers"
49
+ ],
50
+ "metadata": {
51
+ "id": "bUw0V7CaLYou"
52
+ },
53
+ "execution_count": 11,
54
+ "outputs": []
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "source": [
59
+ "import random\n",
60
+ "import time\n",
61
+ "from typing import List, Dict, Tuple, Union\n",
62
+ "from IPython import display\n",
63
+ "\n",
64
+ "import gradio as gr\n",
65
+ "import torch\n",
66
+ "import transformers"
67
+ ],
68
+ "metadata": {
69
+ "id": "QQfY77Hx9by9"
70
+ },
71
+ "execution_count": 12,
72
+ "outputs": []
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "source": [
77
+ "# версии библиотек\n",
78
+ "!pip list | grep -P \"torch|transformers|gradio\""
79
+ ],
80
+ "metadata": {
81
+ "id": "2ev61KO7ne-A",
82
+ "outputId": "3ecf029f-b1b1-448a-e252-b24a08be137e",
83
+ "colab": {
84
+ "base_uri": "https://localhost:8080/"
85
+ }
86
+ },
87
+ "execution_count": 13,
88
+ "outputs": [
89
+ {
90
+ "output_type": "stream",
91
+ "name": "stdout",
92
+ "text": [
93
+ "gradio 4.4.1\n",
94
+ "gradio_client 0.7.0\n",
95
+ "torch 2.1.0+cu118\n",
96
+ "torchaudio 2.1.0+cu118\n",
97
+ "torchdata 0.7.0\n",
98
+ "torchsummary 1.5.1\n",
99
+ "torchtext 0.16.0\n",
100
+ "torchvision 0.16.0+cu118\n",
101
+ "transformers 4.35.2\n"
102
+ ]
103
+ }
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "markdown",
108
+ "source": [
109
+ "# Tokenizer and Model"
110
+ ],
111
+ "metadata": {
112
+ "id": "E_ZbiAPAYJEF"
113
+ }
114
+ },
115
+ {
116
+ "cell_type": "markdown",
117
+ "source": [
118
+ "**Инициализация модели** \n",
119
+ "Страница модели \n",
120
+ "https://huggingface.co/ai-forever/rugpt3medium_based_on_gpt2"
121
+ ],
122
+ "metadata": {
123
+ "id": "mc1bW8QXYZ7z"
124
+ }
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "source": [
129
+ "from transformers import pipeline\n",
130
+ "from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig\n",
131
+ "\n",
132
+ "# # инициализация через pipeline\n",
133
+ "# model_name = \"ai-forever/rugpt3medium_based_on_gpt2\"\n",
134
+ "# pipe = pipeline(\"text-generation\", model=model_name)\n",
135
+ "# sample = pipeline('test test', pad_token_id=generator.tokenizer.eos_token_id)\n",
136
+ "\n",
137
+ "model_name = \"ai-forever/rugpt3medium_based_on_gpt2\"\n",
138
+ "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
139
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
140
+ "tokenizer.pad_token_id = tokenizer.eos_token_id\n",
141
+ "\n",
142
+ "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
143
+ "model = model.to(DEVICE)"
144
+ ],
145
+ "metadata": {
146
+ "id": "AGV9Y2fJHVI1"
147
+ },
148
+ "execution_count": 14,
149
+ "outputs": []
150
+ },
151
+ {
152
+ "cell_type": "markdown",
153
+ "source": [
154
+ "Генерация текста"
155
+ ],
156
+ "metadata": {
157
+ "id": "D7U6M4oyZRqq"
158
+ }
159
+ },
160
+ {
161
+ "cell_type": "code",
162
+ "source": [
163
+ "%%time\n",
164
+ "\n",
165
+ "# токенизация текста в индексы токенов и маски внимания\n",
166
+ "text_promt = 'меня засосала опасная трясина '\n",
167
+ "inputs = tokenizer(text_promt, return_tensors=\"pt\").to(DEVICE)\n",
168
+ "\n",
169
+ "# конфиг словарь для генерации текста\n",
170
+ "gen_config_dict = dict(\n",
171
+ " do_sample=True, # делать ли случайное семплирование с параметрами ниже (если False то выскочит предупреждение)\n",
172
+ " max_new_tokens=30, # сколько максимум новых токенов надо генерировать\n",
173
+ " top_k=50, # семплировать только из top_k самых вероятных токенов\n",
174
+ " top_p=0.9, # семплировать только из токенов сумма вероятностей которых не больше top_p\n",
175
+ " temperature=2.0, # температура для софтмакса\n",
176
+ " num_beams=3, # параметр алгоритма Beam search\n",
177
+ " repetition_penalty=2.0, # штраф за то что модель повторяется\n",
178
+ " pad_token_id=tokenizer.pad_token_id, # установить токен pad чтобы не было предупреждения\n",
179
+ " )\n",
180
+ "# конфиг для генерации текста из словаря\n",
181
+ "generation_config = GenerationConfig(**gen_config_dict)\n",
182
+ "\n",
183
+ "# генерация текста (индексы токенов)\n",
184
+ "output = model.generate(**inputs, generation_config=generation_config)\n",
185
+ "\n",
186
+ "# сопостовление идексам токенов слов из словаря токенайзера\n",
187
+ "generated_text = tokenizer.decode(output[0], skip_special_tokens=False)\n",
188
+ "\n",
189
+ "# удаление исходного промта из ответа потому что он тоже возвращается\n",
190
+ "generated_text = generated_text[len(text_promt):]\n",
191
+ "generated_text"
192
+ ],
193
+ "metadata": {
194
+ "id": "pSWP_0d4IE9_",
195
+ "colab": {
196
+ "base_uri": "https://localhost:8080/",
197
+ "height": 70
198
+ },
199
+ "outputId": "2841a9d8-387d-42f2-e1b1-1fcec48b444d"
200
+ },
201
+ "execution_count": 15,
202
+ "outputs": [
203
+ {
204
+ "output_type": "stream",
205
+ "name": "stdout",
206
+ "text": [
207
+ "CPU times: user 11.4 s, sys: 27.2 ms, total: 11.5 s\n",
208
+ "Wall time: 15.6 s\n"
209
+ ]
210
+ },
211
+ {
212
+ "output_type": "execute_result",
213
+ "data": {
214
+ "text/plain": [
215
+ "'\\nГрязь, гниль, вонь, вонище, хлябь, тухлятина, плесень, чавка'"
216
+ ],
217
+ "application/vnd.google.colaboratory.intrinsic+json": {
218
+ "type": "string"
219
+ }
220
+ },
221
+ "metadata": {},
222
+ "execution_count": 15
223
+ }
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "markdown",
228
+ "source": [
229
+ "Функция для генерации"
230
+ ],
231
+ "metadata": {
232
+ "id": "kXsbzBUFayk7"
233
+ }
234
+ },
235
+ {
236
+ "cell_type": "code",
237
+ "source": [
238
+ "# функция принимает текстовый запрос и словарь параметров генерации\n",
239
+ "def generate_text(text_promt: str, gen_config_dict: Dict[str, Union[float, int]]) -> str:\n",
240
+ " inputs = tokenizer(text_promt, return_tensors=\"pt\").to(DEVICE)\n",
241
+ " generation_config = GenerationConfig(**gen_config_dict)\n",
242
+ " output = model.generate(**inputs, pad_token_id=tokenizer.eos_token_id, generation_config=generation_config)\n",
243
+ " generated_text = tokenizer.decode(output[0], skip_special_tokens=True)\n",
244
+ " generated_text = generated_text[len(text_promt):]\n",
245
+ " return generated_text\n",
246
+ "\n",
247
+ "\n",
248
+ "# конфиг словарь для генерации текста\n",
249
+ "gen_config_dict = dict(\n",
250
+ " do_sample=True, # делать ли случайное семплирование с параметрами ниже (если False то выскочит предупреждение)\n",
251
+ " max_new_tokens=20, # сколько максимум новых токенов надо генерировать\n",
252
+ " top_k=50, # семплировать только из top_k самых вероятных токенов\n",
253
+ " top_p=0.9, # семплировать только из токенов сумма вероятностей которых не больше top_p\n",
254
+ " temperature=2.0, # температура для софтмакса\n",
255
+ " num_beams=3, # параметр алгоритма Beam search\n",
256
+ " repetition_penalty=2.0, # штраф за то что модель повторяется\n",
257
+ " pad_token_id=tokenizer.pad_token_id, # установить токен pad чтобы не было предупреждения\n",
258
+ " )\n",
259
+ "\n",
260
+ "\n",
261
+ "text_promt = 'в небесной канцелярии выходной'\n",
262
+ "generated_text = generate_text(text_promt, gen_config_dict)\n",
263
+ "generated_text"
264
+ ],
265
+ "metadata": {
266
+ "colab": {
267
+ "base_uri": "https://localhost:8080/",
268
+ "height": 36
269
+ },
270
+ "id": "HMg9ckjcZ1hn",
271
+ "outputId": "0640b51e-50f6-4f37-f7ee-1d5b2eeae30d"
272
+ },
273
+ "execution_count": 16,
274
+ "outputs": [
275
+ {
276
+ "output_type": "execute_result",
277
+ "data": {
278
+ "text/plain": [
279
+ "', и только после этого можно заниматься этим вопросом. Но это уже совсем другие истории, которые могут'"
280
+ ],
281
+ "application/vnd.google.colaboratory.intrinsic+json": {
282
+ "type": "string"
283
+ }
284
+ },
285
+ "metadata": {},
286
+ "execution_count": 16
287
+ }
288
+ ]
289
+ },
290
+ {
291
+ "cell_type": "markdown",
292
+ "source": [
293
+ "# Gradio App"
294
+ ],
295
+ "metadata": {
296
+ "id": "MgzsPmnx90zW"
297
+ }
298
+ },
299
+ {
300
+ "cell_type": "markdown",
301
+ "source": [
302
+ "## Новый интерфейс Чат-бота"
303
+ ],
304
+ "metadata": {
305
+ "id": "z41veiFa62k2"
306
+ }
307
+ },
308
+ {
309
+ "cell_type": "markdown",
310
+ "source": [
311
+ "Вариант с системным промтом и разными входными аргументами и настройками"
312
+ ],
313
+ "metadata": {
314
+ "id": "mxsjlHD0_Xxn"
315
+ }
316
+ },
317
+ {
318
+ "cell_type": "code",
319
+ "source": [
320
+ "import gradio as gr\n",
321
+ "\n",
322
+ "# функция будет вызыватся при нажатии на Enter в окошке вовода текста\n",
323
+ "# кроме обычных аргументов - сообщения пользователя и истории - принимает еще параметры для конфига генерации\n",
324
+ "def generate(message, history, *components):\n",
325
+ " # print(system_promt)\n",
326
+ " # обновление словаря новыми агрументами и создание конфига генерации текста\n",
327
+ " gen_config.update(dict(zip(gen_config.keys(), components)))\n",
328
+ " gen_config['top_k'] = int(gen_config['top_k'])\n",
329
+ " gen_config['num_beams'] = int(gen_config['num_beams'])\n",
330
+ " generation_config = GenerationConfig(**gen_config)\n",
331
+ "\n",
332
+ " # добавить системный промт в начало запроса и сгенерировать текст\n",
333
+ " promt = message\n",
334
+ " inputs = tokenizer(promt, return_tensors=\"pt\").to(DEVICE)\n",
335
+ " output = model.generate(**inputs, generation_config=generation_config, pad_token_id=tokenizer.eos_token_id)\n",
336
+ " generated_text = tokenizer.decode(output[0], skip_special_tokens=True)\n",
337
+ " generated_text = generated_text[len(promt):]\n",
338
+ " # имитация набора сообщения чат-ботом (посимвольня генерация через yield в цикле)\n",
339
+ " for i in range(len(generated_text)):\n",
340
+ " time.sleep(0.05) # задержка с которой бот вводит текст\n",
341
+ " yield generated_text[:i+1]\n",
342
+ "\n",
343
+ "# словарь для конфига генерации текста\n",
344
+ "gen_config = dict(\n",
345
+ " do_sample=False,\n",
346
+ " max_length=60,\n",
347
+ " top_k=50,\n",
348
+ " top_p=0.9,\n",
349
+ " temperature=2.0,\n",
350
+ " num_beams=3,\n",
351
+ " repetition_penalty=2.0,\n",
352
+ " )\n",
353
+ "\n",
354
+ "\n",
355
+ "# компоненты настройки конфига генерации текста\n",
356
+ "components = [\n",
357
+ " gr.Checkbox(label=\"do_sample\", value=gen_config[\"do_sample\"]),\n",
358
+ " gr.Slider(label=\"max_length\", value=gen_config[\"max_length\"], minimum=1, maximum=300, step=10),\n",
359
+ " gr.Number(label=\"top_k\", value=gen_config[\"top_k\"], minimum=0, maximum=50, step=10),\n",
360
+ " gr.Number(label=\"top_p\", value=gen_config[\"top_p\"], minimum=0, maximum=1, step=0.1),\n",
361
+ " gr.Number(label=\"temperature\", value=gen_config[\"temperature\"], minimum=0, maximum=10, step=0.1),\n",
362
+ " gr.Number(label=\"num_beams\", value=gen_config[\"num_beams\"], minimum=0, maximum=5, step=1),\n",
363
+ " gr.Number(label=\"repetition_penalty\", value=gen_config[\"repetition_penalty\"], minimum=0, maximum=5, step=0.1),\n",
364
+ " ]\n",
365
+ "\n",
366
+ "# при нажатии Enter в чате будет вызыватся функция generate\n",
367
+ "interface = gr.ChatInterface(\n",
368
+ " generate,\n",
369
+ " chatbot=gr.Chatbot(height=300), # вход для функции generate: message\n",
370
+ " textbox=gr.Textbox(placeholder=\"Задайте любой вопрос\", container=False, scale=2), # выходной бокс для текста\n",
371
+ "\n",
372
+ " # дополнительные входы для функции generate (*components)\n",
373
+ " additional_inputs=components,\n",
374
+ "\n",
375
+ " # настройки оформления\n",
376
+ " title=\"Чат-бот T10\", # название страницы\n",
377
+ " description=\"Окно переписки с ботом\", # описание окошка переписки\n",
378
+ " theme=\"Glass\", # темы: Glass, Monochrome, Soft\n",
379
+ " # examples=[\"Hello\", \"Am I cool?\", \"Are tomatoes vegetables?\"], # примеры должны быть множественными если аргументов много\n",
380
+ " # cache_examples=True, # кешировать примеры\n",
381
+ "\n",
382
+ " # дполнительные кнопки (если не нужна какая либо кнопка ставим None)\n",
383
+ " submit_btn='Отправить',\n",
384
+ " retry_btn='Повторить вопрос',\n",
385
+ " undo_btn=\"Удалить предыдущий вопрос\",\n",
386
+ " clear_btn=\"Очистить историю\",\n",
387
+ ")\n",
388
+ "\n",
389
+ "# запуск приложения\n",
390
+ "interface.queue().launch(debug=True)"
391
+ ],
392
+ "metadata": {
393
+ "colab": {
394
+ "base_uri": "https://localhost:8080/",
395
+ "height": 819
396
+ },
397
+ "id": "8c9BhO7SudbA",
398
+ "outputId": "d1d67aa5-04a6-4129-c7ed-a2e60fe7e45f"
399
+ },
400
+ "execution_count": 17,
401
+ "outputs": [
402
+ {
403
+ "output_type": "stream",
404
+ "name": "stdout",
405
+ "text": [
406
+ "Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).\n",
407
+ "\n",
408
+ "Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().\n",
409
+ "Running on public URL: https://82df78dd1728af38f1.gradio.live\n",
410
+ "\n",
411
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
412
+ ]
413
+ },
414
+ {
415
+ "output_type": "display_data",
416
+ "data": {
417
+ "text/plain": [
418
+ "<IPython.core.display.HTML object>"
419
+ ],
420
+ "text/html": [
421
+ "<div><iframe src=\"https://82df78dd1728af38f1.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
422
+ ]
423
+ },
424
+ "metadata": {}
425
+ },
426
+ {
427
+ "output_type": "stream",
428
+ "name": "stderr",
429
+ "text": [
430
+ "/usr/local/lib/python3.10/dist-packages/transformers/generation/configuration_utils.py:381: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `2.0` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
431
+ " warnings.warn(\n",
432
+ "/usr/local/lib/python3.10/dist-packages/transformers/generation/configuration_utils.py:386: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
433
+ " warnings.warn(\n",
434
+ "/usr/local/lib/python3.10/dist-packages/transformers/generation/configuration_utils.py:381: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `2.0` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
435
+ " warnings.warn(\n",
436
+ "/usr/local/lib/python3.10/dist-packages/transformers/generation/configuration_utils.py:386: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
437
+ " warnings.warn(\n"
438
+ ]
439
+ },
440
+ {
441
+ "output_type": "stream",
442
+ "name": "stdout",
443
+ "text": [
444
+ "Keyboard interruption in main thread... closing server.\n",
445
+ "Killing tunnel 127.0.0.1:7861 <> https://82df78dd1728af38f1.gradio.live\n"
446
+ ]
447
+ },
448
+ {
449
+ "output_type": "execute_result",
450
+ "data": {
451
+ "text/plain": []
452
+ },
453
+ "metadata": {},
454
+ "execution_count": 17
455
+ }
456
+ ]
457
+ }
458
+ ]
459
+ }