import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from util_funcs import get_length_param
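
# `get_length_param` lives in util_funcs.py, which is not shown here. Judging by the
# call site below, it maps a message onto the same length buckets ('1'/'2'/'3'/'-')
# used for Length_of_the_answer. A minimal sketch of such a helper follows; the exact
# token thresholds are an assumption, not necessarily the repo's actual values:
#
# def get_length_param(text: str, tokenizer) -> str:
#     tokens_count = len(tokenizer.encode(text))
#     if tokens_count <= 15:
#         return '1'   # short
#     elif tokens_count <= 50:
#         return '2'   # medium
#     elif tokens_count <= 256:
#         return '3'   # long
#     return '-'       # no length hint
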
def chat_function(Message, Length_of_the_answer, Who_is_next, Base_to_On_subject_temperature, history):  # model, tokenizer are globals

    input_user = Message

    # map the UI choices onto the control tokens the fine-tuned model expects
    if Length_of_the_answer == 'short':
        next_len = '1'
    elif Length_of_the_answer == 'medium':
        next_len = '2'
    elif Length_of_the_answer == 'long':
        next_len = '3'
    else:
        next_len = '-'

    if Who_is_next == 'Kirill':
        next_who = 'G'
    elif Who_is_next == 'Me':
        next_who = 'H'
    else:
        next_who = 'G'  # default to the bot if the radio value is unexpected

    history = history or []

    # restore the token ids of the previous turn from state, or start empty
    chat_history_ids = torch.zeros((1, 0), dtype=torch.long) if history == [] else torch.tensor(history[-1][2], dtype=torch.long)

    # encode the new user input, add the control parameters, and return a PyTorch tensor
    if len(input_user) != 0:
        new_user_input_ids = tokenizer.encode(f"|0|{get_length_param(input_user, tokenizer)}|"
                                              + input_user + tokenizer.eos_token, return_tensors="pt")
        # append the new user input tokens to the chat history
        chat_history_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
    else:
        input_user = '-'

    if next_who == "G":
        # encode the bot-turn prompt (speaker 1 plus desired answer length) as a PyTorch tensor
        new_user_input_ids = tokenizer.encode(f"|1|{next_len}|", return_tensors="pt")
        # append the prompt tokens to the chat history
        chat_history_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)

        # print(tokenizer.decode(chat_history_ids[-1]))  # uncomment to see the full GPT input

        # remember the prompt length so only newly generated tokens are decoded below
        input_len = chat_history_ids.shape[-1]
        # generate a response; the parameters are explained at hf.co/blog/how-to-generate
        chat_history_ids = model.generate(
            chat_history_ids,
            num_return_sequences=1,  # raise for more variants, but then index [i] when decoding
            max_length=512,
            no_repeat_ngram_size=3,
            do_sample=True,
            top_k=50,
            top_p=0.9,
            temperature=float(Base_to_On_subject_temperature),  # 0 for greedy
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id
        )

        response = tokenizer.decode(chat_history_ids[:, input_len:][0], skip_special_tokens=True)
    else:
        response = '-'

    # store (user message, bot response, full token-id history) in Gradio state
    history.append((input_user, response, chat_history_ids.tolist()))
    # deprecated -- gr.set_state(history)

    html = "<div class='chatbox'>"  # must match the .chatbox selector in the css below
    for user_msg, resp_msg, _ in history:
        if user_msg != '-':
            html += f"<div class='user_msg'>{user_msg}</div>"
        if resp_msg != '-':
            html += f"<div class='resp_msg'>{resp_msg}</div>"
    html += "</div>"
    return html, history
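
# Hypothetical example of driving chat_function directly, outside Gradio. It needs the
# model/tokenizer globals loaded below; the first call passes history=None, and an empty
# Message lets Kirill keep talking:
#
#   html, history = chat_function("Привет!", 'medium', 'Kirill', 0.9, None)
#   html, history = chat_function("", 'short', 'Kirill', 0.9, history)
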
# Download checkpoint:
checkpoint = "Kirili4ik/ruDialoGpt3-medium-finetuned-telegram-6ep"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)
model = model.eval()  # inference mode: disables dropout
# Gradio
checkbox_group = gr.inputs.CheckboxGroup(['Kirill', 'Me'], default=['Kirill'], type="value", label=None)  # note: defined but unused; the Radio input below serves this purpose
title = "Chat with Kirill (in Russian)"
# English translation: "Here you can chat with me. But it's a bot instead of me. Leave the
# message empty to let Kirill keep talking - he loves sending several messages in a row.
# Use the slider to make the answers more generic or more specific (closer to the topic).
# More about the technique at the link below."
description = "Тут можно поболтать со мной. Но вместо меня бот. Оставь сообщение пустым, чтобы Кирилл продолжил говорить - он очень любит писать подряд несколько сообщений в чате. Используй слайдер, чтобы ответы были более общими или более конкретными (ближе к теме). Подробнее о технике по ссылке внизу."
article = "<p style='text-align: center'><a href='https://github.com/Kirili4ik/ruDialoGpt3-finetune-colab'>GitHub repo for fine-tuning ruDialoGpt3 on your own chat</a></p>"
examples = [
["В чем смысл жизни?", 'medium', 'Kirill', 0.95],
["Когда у тебя ближайший собес?", 'medium', 'Kirill', 0.85],
["Сколько тебе лет, Кирилл?", 'medium', 'Kirill', 0.85]
]
iface = gr.Interface(chat_function,
                     [
                         "text",
                         gr.inputs.Radio(["short", "medium", "long"], default='medium'),
                         gr.inputs.Radio(["Kirill", "Me"], default='Kirill'),
                         gr.inputs.Slider(0, 1.5, default=0.5),
                         "state"
                     ],
                     ["html", "state"],
                     title=title, description=description, article=article, examples=examples,
                     css="""
                         .chatbox {display:flex;flex-direction:column}
                         .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
                         .user_msg {background-color:cornflowerblue;color:white;align-self:start}
                         .resp_msg {background-color:lightgray;align-self:self-end}
                     """,
                     allow_screenshot=True,
                     allow_flagging=False,
                     api_mode=True
                     )
if __name__ == "__main__":
    iface.launch()