{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "collapsed_sections": [ "j4fr7yW8YPav", "E_ZbiAPAYJEF", "MgzsPmnx90zW", "o_qprIWn6bY5", "quC2dkLoJqoG", "z41veiFa62k2", "ywqlE6TaCrLq", "5ZPgocB8XiV8", "BA2tieoUVQzp", "A5XAUFILVZlU", "F3ztWSyGWsQt", "BfWA10d0Vf7J", "9BiPlUf6YOzO" ] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "source": [ "# Imports" ], "metadata": { "id": "j4fr7yW8YPav" } }, { "cell_type": "code", "source": [ "%%capture\n", "# установка gradio для написания веб интерефейса\n", "# установка transformers для использования языковых моделей с платформы HF\n", "\n", "!pip install gradio transformers" ], "metadata": { "id": "bUw0V7CaLYou" }, "execution_count": 11, "outputs": [] }, { "cell_type": "code", "source": [ "import random\n", "import time\n", "from typing import List, Dict, Tuple, Union\n", "from IPython import display\n", "\n", "import gradio as gr\n", "import torch\n", "import transformers" ], "metadata": { "id": "QQfY77Hx9by9" }, "execution_count": 12, "outputs": [] }, { "cell_type": "code", "source": [ "# версии библиотек\n", "!pip list | grep -P \"torch|transformers|gradio\"" ], "metadata": { "id": "2ev61KO7ne-A", "outputId": "3ecf029f-b1b1-448a-e252-b24a08be137e", "colab": { "base_uri": "https://localhost:8080/" } }, "execution_count": 13, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "gradio 4.4.1\n", "gradio_client 0.7.0\n", "torch 2.1.0+cu118\n", "torchaudio 2.1.0+cu118\n", "torchdata 0.7.0\n", "torchsummary 1.5.1\n", "torchtext 0.16.0\n", "torchvision 0.16.0+cu118\n", "transformers 4.35.2\n" ] } ] }, { "cell_type": "markdown", "source": [ "# Tokenizer and Model" ], "metadata": { "id": "E_ZbiAPAYJEF" } }, { "cell_type": "markdown", "source": [ "**Инициализация модели** \n", "Страница модели \n", "https://huggingface.co/ai-forever/rugpt3medium_based_on_gpt2" ], "metadata": { "id": "mc1bW8QXYZ7z" } }, { "cell_type": "code", "source": [ "from transformers import pipeline\n", "from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig\n", "\n", "# # инициализация через pipeline\n", "# model_name = \"ai-forever/rugpt3medium_based_on_gpt2\"\n", "# pipe = pipeline(\"text-generation\", model=model_name)\n", "# sample = pipeline('test test', pad_token_id=generator.tokenizer.eos_token_id)\n", "\n", "model_name = \"ai-forever/rugpt3medium_based_on_gpt2\"\n", "model = AutoModelForCausalLM.from_pretrained(model_name)\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "tokenizer.pad_token_id = tokenizer.eos_token_id\n", "\n", "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "model = model.to(DEVICE)" ], "metadata": { "id": "AGV9Y2fJHVI1" }, "execution_count": 14, "outputs": [] }, { "cell_type": "markdown", "source": [ "Генерация текста" ], "metadata": { "id": "D7U6M4oyZRqq" } }, { "cell_type": "code", "source": [ "%%time\n", "\n", "# токенизация текста в индексы токенов и маски внимания\n", "text_promt = 'меня засосала опасная трясина '\n", "inputs = tokenizer(text_promt, return_tensors=\"pt\").to(DEVICE)\n", "\n", "# конфиг словарь для генерации текста\n", "gen_config_dict = dict(\n", " do_sample=True, # делать ли случайное семплирование с параметрами ниже (если False то выскочит предупреждение)\n", " max_new_tokens=30, # сколько максимум новых токенов надо генерировать\n", " top_k=50, # семплировать только из top_k самых вероятных токенов\n", " top_p=0.9, # семплировать только из токенов сумма вероятностей которых не больше top_p\n", " temperature=2.0, # температура для софтмакса\n", " num_beams=3, # параметр алгоритма Beam search\n", " repetition_penalty=2.0, # штраф за то что модель повторяется\n", " pad_token_id=tokenizer.pad_token_id, # установить токен pad чтобы не было предупреждения\n", " )\n", "# конфиг для генерации текста из словаря\n", "generation_config = GenerationConfig(**gen_config_dict)\n", "\n", "# генерация текста (индексы токенов)\n", "output = model.generate(**inputs, generation_config=generation_config)\n", "\n", "# сопостовление идексам токенов слов из словаря токенайзера\n", "generated_text = tokenizer.decode(output[0], skip_special_tokens=False)\n", "\n", "# удаление исходного промта из ответа потому что он тоже возвращается\n", "generated_text = generated_text[len(text_promt):]\n", "generated_text" ], "metadata": { "id": "pSWP_0d4IE9_", "colab": { "base_uri": "https://localhost:8080/", "height": 70 }, "outputId": "2841a9d8-387d-42f2-e1b1-1fcec48b444d" }, "execution_count": 15, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "CPU times: user 11.4 s, sys: 27.2 ms, total: 11.5 s\n", "Wall time: 15.6 s\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "'\\nГрязь, гниль, вонь, вонище, хлябь, тухлятина, плесень, чавка'" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" } }, "metadata": {}, "execution_count": 15 } ] }, { "cell_type": "markdown", "source": [ "Функция для генерации" ], "metadata": { "id": "kXsbzBUFayk7" } }, { "cell_type": "code", "source": [ "# функция принимает текстовый запрос и словарь параметров генерации\n", "def generate_text(text_promt: str, gen_config_dict: Dict[str, Union[float, int]]) -> str:\n", " inputs = tokenizer(text_promt, return_tensors=\"pt\").to(DEVICE)\n", " generation_config = GenerationConfig(**gen_config_dict)\n", " output = model.generate(**inputs, pad_token_id=tokenizer.eos_token_id, generation_config=generation_config)\n", " generated_text = tokenizer.decode(output[0], skip_special_tokens=True)\n", " generated_text = generated_text[len(text_promt):]\n", " return generated_text\n", "\n", "\n", "# конфиг словарь для генерации текста\n", "gen_config_dict = dict(\n", " do_sample=True, # делать ли случайное семплирование с параметрами ниже (если False то выскочит предупреждение)\n", " max_new_tokens=20, # сколько максимум новых токенов надо генерировать\n", " top_k=50, # семплировать только из top_k самых вероятных токенов\n", " top_p=0.9, # семплировать только из токенов сумма вероятностей которых не больше top_p\n", " temperature=2.0, # температура для софтмакса\n", " num_beams=3, # параметр алгоритма Beam search\n", " repetition_penalty=2.0, # штраф за то что модель повторяется\n", " pad_token_id=tokenizer.pad_token_id, # установить токен pad чтобы не было предупреждения\n", " )\n", "\n", "\n", "text_promt = 'в небесной канцелярии выходной'\n", "generated_text = generate_text(text_promt, gen_config_dict)\n", "generated_text" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 36 }, "id": "HMg9ckjcZ1hn", "outputId": "0640b51e-50f6-4f37-f7ee-1d5b2eeae30d" }, "execution_count": 16, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "', и только после этого можно заниматься этим вопросом. Но это уже совсем другие истории, которые могут'" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" } }, "metadata": {}, "execution_count": 16 } ] }, { "cell_type": "markdown", "source": [ "# Gradio App" ], "metadata": { "id": "MgzsPmnx90zW" } }, { "cell_type": "markdown", "source": [ "## Новый интерфейс Чат-бота" ], "metadata": { "id": "z41veiFa62k2" } }, { "cell_type": "markdown", "source": [ "Вариант с системным промтом и разными входными аргументами и настройками" ], "metadata": { "id": "mxsjlHD0_Xxn" } }, { "cell_type": "code", "source": [ "import gradio as gr\n", "\n", "# функция будет вызыватся при нажатии на Enter в окошке вовода текста\n", "# кроме обычных аргументов - сообщения пользователя и истории - принимает еще параметры для конфига генерации\n", "def generate(message, history, *components):\n", " # print(system_promt)\n", " # обновление словаря новыми агрументами и создание конфига генерации текста\n", " gen_config.update(dict(zip(gen_config.keys(), components)))\n", " gen_config['top_k'] = int(gen_config['top_k'])\n", " gen_config['num_beams'] = int(gen_config['num_beams'])\n", " generation_config = GenerationConfig(**gen_config)\n", "\n", " # добавить системный промт в начало запроса и сгенерировать текст\n", " promt = message\n", " inputs = tokenizer(promt, return_tensors=\"pt\").to(DEVICE)\n", " output = model.generate(**inputs, generation_config=generation_config, pad_token_id=tokenizer.eos_token_id)\n", " generated_text = tokenizer.decode(output[0], skip_special_tokens=True)\n", " generated_text = generated_text[len(promt):]\n", " # имитация набора сообщения чат-ботом (посимвольня генерация через yield в цикле)\n", " for i in range(len(generated_text)):\n", " time.sleep(0.05) # задержка с которой бот вводит текст\n", " yield generated_text[:i+1]\n", "\n", "# словарь для конфига генерации текста\n", "gen_config = dict(\n", " do_sample=False,\n", " max_length=60,\n", " top_k=50,\n", " top_p=0.9,\n", " temperature=2.0,\n", " num_beams=3,\n", " repetition_penalty=2.0,\n", " )\n", "\n", "\n", "# компоненты настройки конфига генерации текста\n", "components = [\n", " gr.Checkbox(label=\"do_sample\", value=gen_config[\"do_sample\"]),\n", " gr.Slider(label=\"max_length\", value=gen_config[\"max_length\"], minimum=1, maximum=300, step=10),\n", " gr.Number(label=\"top_k\", value=gen_config[\"top_k\"], minimum=0, maximum=50, step=10),\n", " gr.Number(label=\"top_p\", value=gen_config[\"top_p\"], minimum=0, maximum=1, step=0.1),\n", " gr.Number(label=\"temperature\", value=gen_config[\"temperature\"], minimum=0, maximum=10, step=0.1),\n", " gr.Number(label=\"num_beams\", value=gen_config[\"num_beams\"], minimum=0, maximum=5, step=1),\n", " gr.Number(label=\"repetition_penalty\", value=gen_config[\"repetition_penalty\"], minimum=0, maximum=5, step=0.1),\n", " ]\n", "\n", "# при нажатии Enter в чате будет вызыватся функция generate\n", "interface = gr.ChatInterface(\n", " generate,\n", " chatbot=gr.Chatbot(height=300), # вход для функции generate: message\n", " textbox=gr.Textbox(placeholder=\"Задайте любой вопрос\", container=False, scale=2), # выходной бокс для текста\n", "\n", " # дополнительные входы для функции generate (*components)\n", " additional_inputs=components,\n", "\n", " # настройки оформления\n", " title=\"Чат-бот T10\", # название страницы\n", " description=\"Окно переписки с ботом\", # описание окошка переписки\n", " theme=\"Glass\", # темы: Glass, Monochrome, Soft\n", " # examples=[\"Hello\", \"Am I cool?\", \"Are tomatoes vegetables?\"], # примеры должны быть множественными если аргументов много\n", " # cache_examples=True, # кешировать примеры\n", "\n", " # дполнительные кнопки (если не нужна какая либо кнопка ставим None)\n", " submit_btn='Отправить',\n", " retry_btn='Повторить вопрос',\n", " undo_btn=\"Удалить предыдущий вопрос\",\n", " clear_btn=\"Очистить историю\",\n", ")\n", "\n", "# запуск приложения\n", "interface.queue().launch(debug=True)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 819 }, "id": "8c9BhO7SudbA", "outputId": "d1d67aa5-04a6-4129-c7ed-a2e60fe7e45f" }, "execution_count": 17, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).\n", "\n", "Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().\n", "Running on public URL: https://82df78dd1728af38f1.gradio.live\n", "\n", "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
" ] }, "metadata": {} }, { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.10/dist-packages/transformers/generation/configuration_utils.py:381: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `2.0` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/transformers/generation/configuration_utils.py:386: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/transformers/generation/configuration_utils.py:381: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `2.0` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/transformers/generation/configuration_utils.py:386: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n", " warnings.warn(\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "Keyboard interruption in main thread... closing server.\n", "Killing tunnel 127.0.0.1:7861 <> https://82df78dd1728af38f1.gradio.live\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [] }, "metadata": {}, "execution_count": 17 } ] } ] }