Spaces:
Runtime error
Runtime error
File size: 9,042 Bytes
ae84b44 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 |
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
def get_length_param(text: str, tokenizer) -> str:
    """Classify a text into one of four length buckets by encoded token count.

    Parameters
    ----------
    text: str
        The text whose length bucket is wanted.
    tokenizer: HuggingFace tokenizer
        Object whose ``encode(text)`` result is measured with ``len``.
        For more info see https://huggingface.co/transformers/main_classes/tokenizer.html

    Returns
    -------
    len_param: str
        '1' for at most 15 tokens, '2' for at most 50 tokens,
        '3' for at most 256 tokens and '-' for anything longer.
    """
    n_tokens = len(tokenizer.encode(text))
    # Upper bounds are inclusive and checked in ascending order, so the first
    # bound that fits decides the bucket.
    for upper_bound, bucket in ((15, '1'), (50, '2'), (256, '3')):
        if n_tokens <= upper_bound:
            return bucket
    return '-'
def get_user_param(text: dict, machine_name_in_chat: str) -> str:
    """Tell whether a message was sent by the machine or by a human.

    Parameters
    ----------
    text: Dict[..., 'from', ...]
        Message dict; its 'from' field holds the name of the sender.
    machine_name_in_chat: str
        Name of the participant that plays the machine (the one to be predicted).

    Returns
    -------
    str
        '1' when ``text['from']`` equals ``machine_name_in_chat`` (machine),
        '0' otherwise (human).
    """
    sent_by_machine = text['from'] == machine_name_in_chat
    return '1' if sent_by_machine else '0'
def build_text_file(data_json: dict, dest_path: str,
                    tokenizer, machine_name_in_chat='Кирилл Гельван'):
    """Create a text file for training in special format for ruDialoGPT-3.

    Every pair of consecutive messages produces one line of the form
    ``|<user>|<length>|<message text><eos>``. Pairs in which either message
    has an empty or non-string 'text' are skipped.

    Parameters
    ----------
    data_json: dict
        Collection of message dicts indexed by consecutive integers starting
        at 0 (the code reads ``data_json[i]`` and ``data_json[i + 1]``), each
        containing 'text' (message) and 'from' (user who sent the message).
    dest_path: str
        Path of the file to write; it is created or overwritten.
    tokenizer: HuggingFace tokenizer
        Tokenizer used to compute the length bucket and to supply ``eos_token``.
        For more info see https://huggingface.co/transformers/main_classes/tokenizer.html
    machine_name_in_chat: str
        Name of the participant treated as the machine.
    """
    lines = []
    for i in range(len(data_json) - 1):
        message, next_message = data_json[i], data_json[i + 1]
        if not isinstance(message['text'], str) or message['text'] == '':
            continue
        if not isinstance(next_message['text'], str) or next_message['text'] == '':
            continue

        user = get_user_param(message, machine_name_in_chat=machine_name_in_chat)
        # NOTE(review): the length bucket is computed from the NEXT message
        # while the user flag and text come from the current one. Kept as in
        # the original; confirm this pairing is intended.
        length = get_length_param(next_message['text'], tokenizer)
        # Plain str.replace is equivalent to the former re.sub(r"\n", ". ", ...)
        # and avoids the `re` module, which this file never imports.
        message_text = message['text'].replace("\n", ". ")
        lines.append(f"|{user}|{length}|{message_text}{tokenizer.eos_token}\n")

    # The context manager guarantees the data is flushed and the handle closed.
    # UTF-8 is set explicitly so non-ASCII chat text does not depend on the
    # platform default encoding.
    with open(dest_path, 'w', encoding='utf-8') as f:
        f.writelines(lines)
def load_dataset(train_path, test_path, tokenizer, block_size=256):
    """Creates train and test PyTorch datasets and collate_fn using HuggingFace.

    Parameters
    ----------
    train_path: str
        String containing path to train data
    test_path: str
        String containing path to test data
    tokenizer: HuggingFace tokenizer
        Tokenizer passed to both datasets and to the data collator.
        For more info see https://huggingface.co/transformers/main_classes/tokenizer.html
    block_size: int
        Block size handed to both ``TextDataset`` instances (default 256,
        the value that used to be hard-coded).

    Returns
    -------
    tuple
        ``(train_dataset, test_dataset, data_collator)``; the collator is
        built with ``mlm=False``.
    """
    # Imported here because the file's top-level import only pulls in
    # AutoModelForCausalLM and AutoTokenizer; without this the function
    # raised NameError on first use.
    from transformers import DataCollatorForLanguageModeling, TextDataset

    train_dataset = TextDataset(
        tokenizer=tokenizer,
        file_path=train_path,
        block_size=block_size,
    )
    test_dataset = TextDataset(
        tokenizer=tokenizer,
        file_path=test_path,
        block_size=block_size,
    )
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False
    )
    return train_dataset, test_dataset, data_collator
# Maps the UI's answer-length choice to the length code used in the prompt.
_LENGTH_CODES = {'short': '1', 'medium': '2', 'long': '3'}


def _render_history_html(history):
    """Render (user_msg, resp_msg, ids) history triples as an HTML chat log."""
    from html import escape  # local import: the file does not import html

    parts = ["<div class='chatbot'>"]
    for user_msg, resp_msg, _ in history:
        # '-' is the placeholder stored for "no message on this side".
        if user_msg != '-':
            parts.append(f"<div class='user_msg'>{escape(user_msg)}</div>")
        if resp_msg != '-':
            parts.append(f"<div class='resp_msg'>{escape(resp_msg)}</div>")
    parts.append("</div>")
    return "".join(parts)


def chat_function(message, length_of_the_answer, who_is_next, creativity):  # model, tokenizer
    """Handle one chat turn and return the whole conversation as HTML.

    Relies on the module-level ``tokenizer`` and ``model`` and on Gradio
    session state (``gr.get_state`` / ``gr.set_state``), where the history is
    kept as a list of ``(user_msg, resp_msg, token_ids)`` tuples.

    Parameters
    ----------
    message: str
        Text typed by the user. When empty, no user turn is added.
    length_of_the_answer: str
        'short', 'medium' or 'long'; any other value maps to the '-' code.
    who_is_next: str
        'Kirill' makes the model generate a reply; any other value
        (normally 'Me') only records the user's message.
    creativity: float
        Sampling temperature. Values <= 0 switch to greedy decoding.

    Returns
    -------
    str
        HTML markup with every user and model message so far, HTML-escaped.
    """
    input_user = message
    next_len = _LENGTH_CODES.get(length_of_the_answer, '-')

    print(who_is_next)
    # Anything other than 'Kirill' is treated as the human's turn; previously
    # an unexpected value left `next_who` unbound and crashed further down.
    next_who = 'G' if who_is_next == 'Kirill' else 'H'

    history = gr.get_state() or []
    # Use int64 in both branches so the tensor matches the dtype of the
    # encoded ids it is concatenated with.
    if history:
        chat_history_ids = torch.tensor(history[-1][2], dtype=torch.long)
    else:
        chat_history_ids = torch.zeros((1, 0), dtype=torch.long)

    if input_user:
        # encode the new user input, add parameters and return a tensor in Pytorch
        new_user_input_ids = tokenizer.encode(
            f"|0|{get_length_param(input_user, tokenizer)}|"
            + input_user + tokenizer.eos_token,
            return_tensors="pt",
        )
        # append the new user input tokens to the chat history
        chat_history_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
    else:
        # Empty (or missing) message: store the "no message" placeholder.
        input_user = '-'

    if next_who == "G":
        # Open the machine's turn with its user flag and requested length code.
        new_user_input_ids = tokenizer.encode(f"|1|{next_len}|", return_tensors="pt")
        chat_history_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)

        print(tokenizer.decode(chat_history_ids[-1]))  # full gpt input, for debugging

        # save previous len so only the newly generated tokens are decoded
        input_len = chat_history_ids.shape[-1]

        # PS you can read about the parameters at hf.co/blog/how-to-generate
        generation_kwargs = dict(
            num_return_sequences=1,  # use for more variants, but have to print [i]
            max_length=512,
            no_repeat_ngram_size=3,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
        # mask_token_id / unk_token_id / device were dropped: they are not
        # generate() parameters, and newer transformers releases raise on
        # unused keyword arguments.
        temperature = float(creativity)
        if temperature > 0:
            generation_kwargs.update(
                do_sample=True, top_k=50, top_p=0.9, temperature=temperature
            )
        else:
            # A temperature of 0 is not valid for sampling; decode greedily.
            generation_kwargs["do_sample"] = False

        chat_history_ids = model.generate(chat_history_ids, **generation_kwargs)
        response = tokenizer.decode(
            chat_history_ids[:, input_len:][0], skip_special_tokens=True
        )
    else:
        response = '-'

    history.append((input_user, response, chat_history_ids.tolist()))
    gr.set_state(history)

    return _render_history_html(history)
# Download checkpoint: both the tokenizer and the model weights are fetched
# (or read from the local cache) at import time.
checkpoint = "Kirili4ik/ruDialoGpt3-medium-finetuned-telegram"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)
model = model.eval()  # inference only

# NOTE(review): these three widgets are not passed to gr.Interface below;
# kept so the module's names stay unchanged.
checkbox_group = gr.inputs.CheckboxGroup(['Kirill', 'Me'], default=['Kirill'], type="value", label=None)
inputs = gr.inputs.Textbox(lines=1, label="???")
outputs = gr.outputs.Textbox(label="Kirill (GPT-2):")

title = "Chat with Kirill (in Russian)"
description = "Тут можно поболтать со мной. Но вместо меня бот. Оставь message пустым, чтобы Кирилл продолжил говорить. Подбробнее о технике по ссылке внизу."
article = "<p style='text-align: center'><a href='https://github.com/Kirili4ik/ruDialoGpt3-finetune-colab'>Github with fine-tuning GPT-2 on your chat</a></p>"

# Each example row follows the input order of the interface:
# message, answer length, who speaks next, creativity.
examples = [
    ["Привет, как дела?", 'medium', 'Kirill', 0.6],
    ["Сколько тебе лет?", 'medium', 'Kirill', 0.3],
]

iface = gr.Interface(
    chat_function,
    [
        "text",
        gr.inputs.Radio(["short", "medium", "long"], default='medium'),
        gr.inputs.Radio(["Kirill", "Me"], default='Kirill'),
        gr.inputs.Slider(0, 1, default=0.6),
    ],
    "html",
    title=title, description=description, article=article, examples=examples,
    # The container selector must be `.chatbot`: that is the class on the
    # wrapper <div> produced by chat_function. The former `.chatbox` selector
    # matched nothing, so the flex layout and align-self rules never applied.
    css="""
    .chatbot {display:flex;flex-direction:column}
    .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
    .user_msg {background-color:cornflowerblue;color:white;align-self:start}
    .resp_msg {background-color:lightgray;align-self:self-end}
    """,
    allow_screenshot=True,
    allow_flagging=False,
)
iface.launch()
|