Spaces:
Runtime error
Runtime error
File size: 5,467 Bytes
ae84b44 21a5dba ae84b44 9283ef4 ae84b44 7f3bae1 ae84b44 7f3bae1 ae84b44 7f3bae1 ae84b44 7f3bae1 ae84b44 7f3bae1 ae84b44 7f3bae1 ae84b44 3e462aa ae84b44 7f3bae1 ae84b44 9283ef4 ae84b44 09f0075 ae84b44 741f4b4 ae84b44 ddc2c87 21a5dba ae84b44 6966380 ae84b44 21a5dba e4ef938 ae84b44 9283ef4 7f3bae1 ae84b44 7f3bae1 ae84b44 ac61535 21a5dba e4ef938 9283ef4 3e462aa 21a5dba ac61535 ae84b44 21a5dba ae84b44 21a5dba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from util_funcs import get_length_param
def chat_function(Message, Length_of_the_answer, Who_is_next, Base_to_On_subject_temperature, history):  # model, tokenizer
    """Run one chat turn and render the whole conversation as HTML.

    Relies on the module-level ``tokenizer`` and ``model`` objects and on
    ``get_length_param`` imported at the top of the file.

    Args:
        Message: Text typed by the user. Empty (or ``None``) means the user
            adds nothing this turn.
        Length_of_the_answer: ``'short'``, ``'medium'`` or ``'long'``; mapped
            to the length codes ``'1'``/``'2'``/``'3'``. Any other value maps
            to ``'-'``.
        Who_is_next: ``'Kirill'`` makes the model generate a reply. Any other
            value (including ``'Me'``) produces no reply this turn.
        Base_to_On_subject_temperature: Sampling temperature. Values ``<= 0``
            switch generation to greedy decoding.
        history: List of ``(user_msg, response, token_ids)`` entries from
            previous turns, or ``None``/empty for a new conversation. A
            non-empty list is appended to in place.

    Returns:
        A ``(html, history)`` pair: the rendered conversation and the history
        list including the new turn.
    """
    # Imported locally so the function does not depend on a new file-level import.
    from html import escape

    input_user = Message or ''

    next_len = {'short': '1', 'medium': '2', 'long': '3'}.get(Length_of_the_answer, '-')
    # Only an explicit 'Kirill' triggers generation; previously an unexpected
    # value left the speaker flag unbound and crashed with UnboundLocalError.
    bot_replies = Who_is_next == 'Kirill'

    history = history or []

    # Resume from the token ids stored with the last turn, or start empty.
    # Both branches use int64 so later torch.cat calls see a single dtype.
    if history:
        chat_history_ids = torch.tensor(history[-1][2], dtype=torch.long)
    else:
        chat_history_ids = torch.zeros((1, 0), dtype=torch.long)

    if input_user:
        # User turn: "|0|<length param>|" prefix, the text, then the EOS token.
        user_text = f"|0|{get_length_param(input_user, tokenizer)}|" + input_user + tokenizer.eos_token
        new_user_input_ids = tokenizer.encode(user_text, return_tensors="pt")
        chat_history_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
    else:
        # '-' is the sentinel for "nothing to display" in the rendering loop.
        input_user = '-'

    if bot_replies:
        # Reply prefix "|1|<length code>|" that the model is asked to continue.
        reply_prefix_ids = tokenizer.encode(f"|1|{next_len}|", return_tensors="pt")
        chat_history_ids = torch.cat([chat_history_ids, reply_prefix_ids], dim=-1)
        # print(tokenizer.decode(chat_history_ids[-1]))  # uncomment to see full gpt input

        # Remember the prompt length so only the newly generated tail is decoded.
        input_len = chat_history_ids.shape[-1]

        # Generation parameters are described at hf.co/blog/how-to-generate
        generate_kwargs = dict(
            num_return_sequences=1,  # use for more variants, but have to print [i]
            max_length=512,
            no_repeat_ngram_size=3,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
        temperature = float(Base_to_On_subject_temperature)
        if temperature > 0:
            generate_kwargs.update(do_sample=True, top_k=50, top_p=0.9, temperature=temperature)
        else:
            # Sampling with temperature 0 is invalid; decode greedily instead.
            generate_kwargs.update(do_sample=False)

        chat_history_ids = model.generate(chat_history_ids, **generate_kwargs)
        response = tokenizer.decode(chat_history_ids[:, input_len:][0], skip_special_tokens=True)
    else:
        response = '-'

    history.append((input_user, response, chat_history_ids.tolist()))

    # The wrapper class matches the ".chatbox" rule in the interface CSS.
    # Messages are escaped because both user input and model output are
    # arbitrary text that must not be interpreted as markup.
    parts = ["<div class='chatbox'>"]
    for user_msg, resp_msg, _ in history:
        if user_msg != '-':
            parts.append(f"<div class='user_msg'>{escape(str(user_msg))}</div>")
        if resp_msg != '-':
            parts.append(f"<div class='resp_msg'>{escape(str(resp_msg))}</div>")
    parts.append("</div>")
    return "".join(parts), history
# Load the fine-tuned dialogue checkpoint; the model is switched to eval mode.
checkpoint = "Kirili4ik/ruDialoGpt3-medium-finetuned-telegram-6ep"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).eval()

# Gradio UI definition.
# NOTE(review): checkbox_group is not passed to the interface below; kept as-is.
checkbox_group = gr.inputs.CheckboxGroup(['Kirill', 'Me'], default=['Kirill'], type="value", label=None)

title = "Chat with Kirill (in Russian)"
description = "Тут можно поболтать со мной. Но вместо меня бот. Оставь сообщение пустым, чтобы Кирилл продолжил говорить - он очень любит писать подряд несколько сообщений в чате. Используй слайдер, чтобы ответы были более общими или более конкретными (ближе к теме). Подробнее о технике по ссылке внизу."
article = "<p style='text-align: center'><a href='https://github.com/Kirili4ik/ruDialoGpt3-finetune-colab'>Github with fine-tuning GPT-3 on your chat</a></p>"

# Each example row is: message, answer length, next speaker, temperature.
examples = [
    [question, 'medium', 'Kirill', temperature]
    for question, temperature in (
        ("В чем смысл жизни?", 0.95),
        ("Когда у тебя ближайший собес?", 0.85),
        ("Сколько тебе лет, Кирилл?", 0.85),
    )
]

# Stylesheet for the chat markup returned by chat_function.
_chat_css = (
    "\n"
    ".chatbox {display:flex;flex-direction:column}\n"
    ".user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}\n"
    ".user_msg {background-color:cornflowerblue;color:white;align-self:start}\n"
    ".resp_msg {background-color:lightgray;align-self:self-end}\n"
)

# Input components, in the order of chat_function's parameters.
_chat_inputs = [
    "text",
    gr.inputs.Radio(["short", "medium", "long"], default='medium'),
    gr.inputs.Radio(["Kirill", "Me"], default='Kirill'),
    gr.inputs.Slider(0, 1.5, default=0.5),
    "state",
]

iface = gr.Interface(
    fn=chat_function,
    inputs=_chat_inputs,
    outputs=["html", "state"],
    title=title,
    description=description,
    article=article,
    examples=examples,
    css=_chat_css,
    allow_screenshot=True,
    allow_flagging=False,
)

if __name__ == "__main__":
    iface.launch()
|