gpt_bot / gradio_gpt_bot.py
Jorj2064's picture
Upload 2 files
418d4b3
# -*- coding: utf-8 -*-
"""Gradio_GPT_bot.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/18CH6wtcr46hWqBqpzieH_oBOmJHecOVl
# Imports
"""
# Commented out IPython magic to ensure Python compatibility.
# %%capture
# # установка gradio для написания веб интерефейса
# # установка transformers для использования языковых моделей с платформы HF
#
# !pip install gradio transformers
import random
import time
from typing import List, Dict, Tuple, Union
from IPython import display
import gradio as gr
import torch
import transformers
# версии библиотек
!pip list | grep -P "torch|transformers|gradio"
"""# Tokenizer and Model
**Инициализация модели**
Страница модели
https://huggingface.co/ai-forever/rugpt3medium_based_on_gpt2
"""
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
# # инициализация через pipeline
# model_name = "ai-forever/rugpt3medium_based_on_gpt2"
# pipe = pipeline("text-generation", model=model_name)
# sample = pipeline('test test', pad_token_id=generator.tokenizer.eos_token_id)
model_name = "ai-forever/rugpt3medium_based_on_gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(DEVICE)
"""Генерация текста"""
# Commented out IPython magic to ensure Python compatibility.
# %%time
#
# # токенизация текста в индексы токенов и маски внимания
# text_promt = 'меня засосала опасная трясина '
# inputs = tokenizer(text_promt, return_tensors="pt").to(DEVICE)
#
# # конфиг словарь для генерации текста
# gen_config_dict = dict(
# do_sample=True, # делать ли случайное семплирование с параметрами ниже (если False то выскочит предупреждение)
# max_new_tokens=30, # сколько максимум новых токенов надо генерировать
# top_k=50, # семплировать только из top_k самых вероятных токенов
# top_p=0.9, # семплировать только из токенов сумма вероятностей которых не больше top_p
# temperature=2.0, # температура для софтмакса
# num_beams=3, # параметр алгоритма Beam search
# repetition_penalty=2.0, # штраф за то что модель повторяется
# pad_token_id=tokenizer.pad_token_id, # установить токен pad чтобы не было предупреждения
# )
# # конфиг для генерации текста из словаря
# generation_config = GenerationConfig(**gen_config_dict)
#
# # генерация текста (индексы токенов)
# output = model.generate(**inputs, generation_config=generation_config)
#
# # сопостовление идексам токенов слов из словаря токенайзера
# generated_text = tokenizer.decode(output[0], skip_special_tokens=False)
#
# # удаление исходного промта из ответа потому что он тоже возвращается
# generated_text = generated_text[len(text_promt):]
# generated_text
"""Функция для генерации"""
# функция принимает текстовый запрос и словарь параметров генерации
def generate_text(text_promt: str, gen_config_dict: Dict[str, Union[float, int]]) -> str:
inputs = tokenizer(text_promt, return_tensors="pt").to(DEVICE)
generation_config = GenerationConfig(**gen_config_dict)
output = model.generate(**inputs, pad_token_id=tokenizer.eos_token_id, generation_config=generation_config)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
generated_text = generated_text[len(text_promt):]
return generated_text
# конфиг словарь для генерации текста
gen_config_dict = dict(
do_sample=True, # делать ли случайное семплирование с параметрами ниже (если False то выскочит предупреждение)
max_new_tokens=20, # сколько максимум новых токенов надо генерировать
top_k=50, # семплировать только из top_k самых вероятных токенов
top_p=0.9, # семплировать только из токенов сумма вероятностей которых не больше top_p
temperature=2.0, # температура для софтмакса
num_beams=3, # параметр алгоритма Beam search
repetition_penalty=2.0, # штраф за то что модель повторяется
pad_token_id=tokenizer.pad_token_id, # установить токен pad чтобы не было предупреждения
)
text_promt = 'в небесной канцелярии выходной'
generated_text = generate_text(text_promt, gen_config_dict)
generated_text
"""# Gradio App
## Новый интерфейс Чат-бота
Вариант с системным промтом и разными входными аргументами и настройками
"""
import gradio as gr
# функция будет вызыватся при нажатии на Enter в окошке вовода текста
# кроме обычных аргументов - сообщения пользователя и истории - принимает еще параметры для конфига генерации
def generate(message, history, *components):
# print(system_promt)
# обновление словаря новыми агрументами и создание конфига генерации текста
gen_config.update(dict(zip(gen_config.keys(), components)))
gen_config['top_k'] = int(gen_config['top_k'])
gen_config['num_beams'] = int(gen_config['num_beams'])
generation_config = GenerationConfig(**gen_config)
# добавить системный промт в начало запроса и сгенерировать текст
promt = message
inputs = tokenizer(promt, return_tensors="pt").to(DEVICE)
output = model.generate(**inputs, generation_config=generation_config, pad_token_id=tokenizer.eos_token_id)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
generated_text = generated_text[len(promt):]
# имитация набора сообщения чат-ботом (посимвольня генерация через yield в цикле)
for i in range(len(generated_text)):
time.sleep(0.05) # задержка с которой бот вводит текст
yield generated_text[:i+1]
# словарь для конфига генерации текста
gen_config = dict(
do_sample=False,
max_length=60,
top_k=50,
top_p=0.9,
temperature=2.0,
num_beams=3,
repetition_penalty=2.0,
)
# компоненты настройки конфига генерации текста
components = [
gr.Checkbox(label="do_sample", value=gen_config["do_sample"]),
gr.Slider(label="max_length", value=gen_config["max_length"], minimum=1, maximum=300, step=10),
gr.Number(label="top_k", value=gen_config["top_k"], minimum=0, maximum=50, step=10),
gr.Number(label="top_p", value=gen_config["top_p"], minimum=0, maximum=1, step=0.1),
gr.Number(label="temperature", value=gen_config["temperature"], minimum=0, maximum=10, step=0.1),
gr.Number(label="num_beams", value=gen_config["num_beams"], minimum=0, maximum=5, step=1),
gr.Number(label="repetition_penalty", value=gen_config["repetition_penalty"], minimum=0, maximum=5, step=0.1),
]
# при нажатии Enter в чате будет вызыватся функция generate
interface = gr.ChatInterface(
generate,
chatbot=gr.Chatbot(height=300), # вход для функции generate: message
textbox=gr.Textbox(placeholder="Задайте любой вопрос", container=False, scale=2), # выходной бокс для текста
# дополнительные входы для функции generate (*components)
additional_inputs=components,
# настройки оформления
title="Чат-бот T10", # название страницы
description="Окно переписки с ботом", # описание окошка переписки
theme="Glass", # темы: Glass, Monochrome, Soft
# examples=["Hello", "Am I cool?", "Are tomatoes vegetables?"], # примеры должны быть множественными если аргументов много
# cache_examples=True, # кешировать примеры
# дполнительные кнопки (если не нужна какая либо кнопка ставим None)
submit_btn='Отправить',
retry_btn='Повторить вопрос',
undo_btn="Удалить предыдущий вопрос",
clear_btn="Очистить историю",
)
# запуск приложения
interface.queue().launch(debug=True)