|
|
|
"""Gradio_GPT_bot.ipynb |
|
|
|
Automatically generated by Colaboratory. |
|
|
|
Original file is located at |
|
https://colab.research.google.com/drive/18CH6wtcr46hWqBqpzieH_oBOmJHecOVl |
|
|
|
# Imports |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import random |
|
import time |
|
from typing import List, Dict, Tuple, Union |
|
from IPython import display |
|
|
|
import gradio as gr |
|
import torch |
|
import transformers |
|
|
|
|
|
!pip list | grep -P "torch|transformers|gradio" |
|
|
|
"""# Tokenizer and Model |
|
|
|
**Инициализация модели** |
|
Страница модели |
|
https://huggingface.co/ai-forever/rugpt3medium_based_on_gpt2 |
|
""" |
|
|
|
from transformers import pipeline |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig |
|
|
|
|
|
|
|
|
|
|
|
|
|
# Russian GPT-3 medium (GPT-2 architecture) released by SberAI / ai-forever.
model_name = "ai-forever/rugpt3medium_based_on_gpt2"

model = AutoModelForCausalLM.from_pretrained(model_name)

tokenizer = AutoTokenizer.from_pretrained(model_name)

# GPT-2 tokenizers ship without a pad token; reuse EOS so padding during
# generation does not raise.
tokenizer.pad_token_id = tokenizer.eos_token_id



DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = model.to(DEVICE)
|
|
|
"""Генерация текста""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Функция для генерации""" |
|
|
|
|
|
def generate_text(text_promt: str, gen_config_dict: Dict[str, Union[float, int]]) -> str:
    """Generate a continuation of *text_promt* with the global model.

    Args:
        text_promt: Prompt string fed to the language model.
        gen_config_dict: Keyword arguments for ``transformers.GenerationConfig``
            (e.g. ``max_new_tokens``, ``top_k``, ``temperature``).

    Returns:
        Only the newly generated text, without the prompt.
    """
    inputs = tokenizer(text_promt, return_tensors="pt").to(DEVICE)
    generation_config = GenerationConfig(**gen_config_dict)
    output = model.generate(
        **inputs,
        pad_token_id=tokenizer.eos_token_id,
        generation_config=generation_config,
    )
    # Strip the prompt by token count, not by character count: decode() is
    # not guaranteed to reproduce the prompt string byte-for-byte
    # (normalization, special tokens), so slicing the decoded string with
    # len(text_promt) can cut generated text or leave prompt remnants.
    prompt_len = inputs["input_ids"].shape[1]
    generated_text = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    return generated_text
|
|
|
|
|
|
|
# Sampling-based generation settings for a quick smoke test of generate_text().
gen_config_dict = {
    "do_sample": True,
    "max_new_tokens": 20,
    "top_k": 50,
    "top_p": 0.9,
    "temperature": 2.0,
    "num_beams": 3,
    "repetition_penalty": 2.0,
    "pad_token_id": tokenizer.pad_token_id,
}


text_promt = 'в небесной канцелярии выходной'
generated_text = generate_text(text_promt, gen_config_dict)
# Bare expression: notebook cell output showing the generated continuation.
generated_text
|
|
|
"""# Gradio App |
|
|
|
## Новый интерфейс Чат-бота |
|
|
|
Вариант с системным промтом и разными входными аргументами и настройками |
|
""" |
|
|
|
import gradio as gr |
|
|
|
|
|
|
|
def generate(message, history, *components):
    """Chat callback for ``gr.ChatInterface``: stream a model reply.

    Args:
        message: Current user message, used as the raw prompt.
        history: Chat history supplied by Gradio (unused — the model only
            sees the latest message, so the bot has no memory).
        *components: Current values of the additional-input widgets, in the
            same order as the keys of the module-level ``gen_config``.

    Yields:
        Progressively longer prefixes of the generated reply, which Gradio
        renders as a typing animation.
    """
    # Build a per-call config instead of mutating the module-level
    # ``gen_config`` dict (the original mutated the global on every call,
    # leaking one request's settings into the next).
    config = dict(gen_config)
    config.update(zip(config.keys(), components))
    # Gradio Number widgets return floats; these options must be ints.
    config['top_k'] = int(config['top_k'])
    config['num_beams'] = int(config['num_beams'])
    generation_config = GenerationConfig(**config)

    promt = message
    inputs = tokenizer(promt, return_tensors="pt").to(DEVICE)
    output = model.generate(
        **inputs,
        generation_config=generation_config,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    generated_text = generated_text[len(promt):]

    # Emit the answer one character at a time to simulate typing.
    for i in range(len(generated_text)):
        time.sleep(0.05)
        yield generated_text[:i + 1]
|
|
|
|
|
# Default generation settings for the chat UI. Key ORDER matters:
# generate() zips these keys against the widget values positionally.
gen_config = {
    "do_sample": False,
    "max_length": 60,
    "top_k": 50,
    "top_p": 0.9,
    "temperature": 2.0,
    "num_beams": 3,
    "repetition_penalty": 2.0,
}
|
|
|
|
|
|
|
# One input widget per gen_config entry. ORDER MATTERS: generate() zips the
# widget values against gen_config.keys() positionally, so this list must
# mirror the key order of the gen_config dict exactly.
# Number widgets return floats; generate() casts top_k/num_beams back to int.
components = [

gr.Checkbox(label="do_sample", value=gen_config["do_sample"]),

gr.Slider(label="max_length", value=gen_config["max_length"], minimum=1, maximum=300, step=10),

gr.Number(label="top_k", value=gen_config["top_k"], minimum=0, maximum=50, step=10),

gr.Number(label="top_p", value=gen_config["top_p"], minimum=0, maximum=1, step=0.1),

gr.Number(label="temperature", value=gen_config["temperature"], minimum=0, maximum=10, step=0.1),

gr.Number(label="num_beams", value=gen_config["num_beams"], minimum=0, maximum=5, step=1),

gr.Number(label="repetition_penalty", value=gen_config["repetition_penalty"], minimum=0, maximum=5, step=0.1),

]
|
|
|
|
|
# Chat UI wiring the streaming generate() callback to Gradio.
# NOTE(review): the submit_btn/retry_btn/undo_btn/clear_btn string kwargs are
# Gradio 3.x/4.x ChatInterface API and were removed/changed in Gradio 5 —
# confirm the installed gradio version before upgrading.
interface = gr.ChatInterface(

generate,

chatbot=gr.Chatbot(height=300),

textbox=gr.Textbox(placeholder="Задайте любой вопрос", container=False, scale=2),




# Extra widgets whose values are passed to generate() after (message, history).
additional_inputs=components,




title="Чат-бот T10",

description="Окно переписки с ботом",

theme="Glass",








submit_btn='Отправить',

retry_btn='Повторить вопрос',

undo_btn="Удалить предыдущий вопрос",

clear_btn="Очистить историю",

)




# queue() is required for generator (streaming) callbacks; debug=True keeps
# the call blocking and prints server logs — convenient inside Colab.
interface.queue().launch(debug=True)