import sys
import json
import argparse

import gradio as gr
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

sys.path.insert(0, './petals/')
sys.path.insert(0, './personalized-chat-bot/')

from petals.client.remote_model import DistributedBloomForCausalLM
from personalized_chat_bot import PersonalizedChatBot, PersonalityManager
from models.personality_clustering import PersonalityClustering
MODEL_NAME = "bigscience/bloom-petals" | |
tokenizer_bloomd = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME) | |
model_bloomd = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, low_cpu_mem_usage=True) | |
tokenizer_DialoGPT_small = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small") | |
model_DialoGPT_small = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small") | |
tokenizer_DialoGPT_medium = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium") | |
model_DialoGPT_medium = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium") | |
tokenizer_DialoGPT_large = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large") | |
model_DialoGPT_large = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large") | |
def predict_common_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
    new_user_input_ids = tokenizer.encode(input_text + '\n', return_tensors='pt')
    person_description_ids = tokenizer.encode(person_description + '\n', return_tensors='pt')
    print('Started predict_common_bloom')
    print(f'history: {history}')

    # On later turns `history` arrives as the token tensor returned below, so
    # test its length instead of comparing a tensor against [].
    if history is not None and len(history) > 0:
        bot_input_ids = torch.cat([torch.as_tensor(history, dtype=torch.long), new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids
    print(f'bot_input_ids: {bot_input_ids}')

    # Prepend the persona description for generation, then strip it from the output.
    input_with_desc_ids = torch.cat([person_description_ids, bot_input_ids], dim=-1)
    history = model.generate(
        input_with_desc_ids,
        max_new_tokens=number_of_new_tokens,
        pad_token_id=tokenizer.eos_token_id
    ).tolist()
    print(f'history: {history}')

    history[0] = history[0][len(person_description_ids[0]):]
    # Keep the dialogue so far plus the first non-empty line of the new continuation.
    decode_all = tokenizer.decode(history[0][:len(bot_input_ids[0])])
    all_responses = tokenizer.decode(history[0][len(bot_input_ids[0]):]).split('\n')
    if all_responses[0]:
        decode_all += all_responses[0] + '\n'
    else:
        decode_all += all_responses[1] + '\n'
    print(f'decode_all: {decode_all}')

    history_new = tokenizer.encode(decode_all, return_tensors='pt')
    print(f'history_new: {history_new}')

    # Pair alternating user/bot lines for the gr.Chatbot component.
    decode_all_split = decode_all.split('\n')
    print(f'decode_all_split: {decode_all_split}')
    response_new = [(decode_all_split[i], decode_all_split[i + 1]) for i in range(0, len(decode_all_split) - 1, 2)]
    print(f'response_new: {response_new}')
    return response_new, history_new
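# How the conversation state round-trips above (illustrative example, not from
# the original code): the transcript alternates user and bot lines joined by
# '\n', e.g. 'Hi\nHello there!\nHow are you?\nFine.'. It is re-encoded to a
# [1, seq_len] LongTensor for the "state" output, and split back into
# (user, bot) pairs such as [('Hi', 'Hello there!'), ('How are you?', 'Fine.')]
# for the chat widget.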
def load_config(path):
    with open(path, 'r') as f:
        config = json.load(f)
    return argparse.Namespace(**config)
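# A minimal sketch of the JSON shape load_config() expects here. The keys are
# inferred from the attribute accesses in predict_cluster_bloom (MODEL_NAME,
# NUM_PREFIX_TOKENS, TUNING_MODE, DEVICE, MODEL_MAX_LENGTH); the values are
# assumptions for illustration, not the real contents of config_176b.json:
#
# {
#     "MODEL_NAME": "bigscience/bloom-petals",
#     "NUM_PREFIX_TOKENS": 16,
#     "TUNING_MODE": "ptune",
#     "DEVICE": "cpu",
#     "MODEL_MAX_LENGTH": 256
# }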
def predict_cluster_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
    # Note: the `model` and `tokenizer` arguments are ignored here; this path
    # loads its own prefix-tuned Petals client matched to the closest persona.
    personality_clustering = PersonalityClustering()
    personality_clustering.load('personalized-chat-bot/data/models/personality_clustering_500_paraphrase-MiniLM-L6-v2_k-means.pkl')

    # prompt_paths.json maps cluster ids to trained prompt checkpoints; JSON
    # object keys are strings, so convert them back to ints on load.
    hook = lambda dct: {int(k): v for k, v in dct.items()}
    with open('personalized-chat-bot/prompt_paths.json', 'r') as f:
        prompt_paths = json.load(f, object_hook=hook)

    pm = PersonalityManager(prompt_paths, personality_clustering)
    prompt_path, closest_persona = pm.get_prompt(person_description)
    print(f'The closest personality is: {closest_persona}')
    print('Wait a little longer...')

    config = load_config('personalized-chat-bot/scripts/config_176b.json')
    model = DistributedBloomForCausalLM.from_pretrained(
        config.MODEL_NAME,
        pre_seq_len=config.NUM_PREFIX_TOKENS,
        tuning_mode=config.TUNING_MODE,
    ).to(config.DEVICE)

    generation_config = load_config('personalized-chat-bot/generation_config.json')
    generation_config.max_new_tokens = number_of_new_tokens
    print(f'generation_config: {generation_config}')

    tokenizer = transformers.BloomTokenizerFast.from_pretrained(config.MODEL_NAME)
    tokenizer.padding_side = 'right'
    tokenizer.model_max_length = config.MODEL_MAX_LENGTH

    chatbot = PersonalizedChatBot(model, tokenizer, generation_config=generation_config)
    chatbot.load_prompt('personalized-chat-bot/' + prompt_path)

    # Same tensor-vs-list pitfall as above: test length, don't compare to [].
    if history is not None and len(history) > 0:
        input_text = tokenizer.decode(history[0]) + '\n' + input_text
    print(f'INPUT: {input_text}')

    output = chatbot.answer(input_text)
    all_text = input_text + '\n' + output
    print(f'all_text: {all_text}')

    history = tokenizer.encode(all_text, return_tensors='pt')
    print(f'history: {history}')

    response = tokenizer.decode(history[0]).split('\n')
    response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]
    print(f'response: {response}')
    return response, history
def predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
    person_description_ids = tokenizer.encode(person_description + tokenizer.eos_token, return_tensors='pt')
    new_user_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt')

    # Guard the first turn: concatenating an empty history with the 2-D batch
    # of token ids raises a dimension-mismatch error.
    if history is not None and len(history) > 0:
        bot_input_ids = torch.cat([torch.as_tensor(history, dtype=torch.long), new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    # Condition generation on the persona description, then strip it back off.
    input_with_desc_ids = torch.cat([person_description_ids, bot_input_ids], dim=-1)
    history = model.generate(
        input_with_desc_ids,
        max_new_tokens=number_of_new_tokens,
        pad_token_id=tokenizer.eos_token_id
    ).tolist()
    history[0] = history[0][len(person_description_ids[0]):]

    response = tokenizer.decode(history[0]).split("<|endoftext|>")
    response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]
    return response, history
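# Background on the split above: DialoGPT uses its end-of-text token,
# <|endoftext|>, as the turn separator, so decoding the whole history and
# splitting on it yields alternating user/bot segments, e.g.
# 'Hi<|endoftext|>Hello!<|endoftext|>' -> [('Hi', 'Hello!')].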
def predict(
    input_text,
    history=None,
    person_description=None,
    number_of_new_tokens=10,
    model_name=None,
    del_hist=None
):
    if history is None or del_hist == 'delete history':
        history = []

    if model_name == 'DialoGPT-small':
        model = model_DialoGPT_small
        tokenizer = tokenizer_DialoGPT_small
        return predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
    elif model_name == 'DialoGPT-medium':
        model = model_DialoGPT_medium
        tokenizer = tokenizer_DialoGPT_medium
        return predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
    elif model_name == 'DialoGPT-large':
        model = model_DialoGPT_large
        tokenizer = tokenizer_DialoGPT_large
        return predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
    elif model_name == 'bloom-petals':
        model = model_bloomd
        tokenizer = tokenizer_bloomd
        print(f'Lets go history: {history}')
        return predict_common_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
    elif model_name == 'bloom-petals-cluster':
        model = model_bloomd
        tokenizer = tokenizer_bloomd
        print(f'Lets go history: {history}')
        return predict_cluster_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
    else:
        # No model selected in the UI: fall back to DialoGPT-medium.
        model = model_DialoGPT_medium
        tokenizer = tokenizer_DialoGPT_medium
        return predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
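# Hypothetical direct call to the dispatcher, bypassing the Gradio UI (the
# argument values below are made up for illustration):
#
# responses, state = predict(
#     input_text='Hello!',
#     history=None,
#     person_description='I am a friendly bot who loves hiking.',
#     number_of_new_tokens=20,
#     model_name='DialoGPT-small',
#     del_hist="Don't delete history",
# )
# # responses -> list of (user_line, bot_line) pairs rendered by gr.Chatbot
# # state     -> token-id history threaded back in on the next call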
gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label='Input message', lines=1, placeholder="Enter your message..."),
        "state",
        gr.Textbox(label='Person Description', lines=2, placeholder="Enter a description of the person..."),
        gr.Slider(label='Number of new tokens', minimum=2, maximum=100, value=10),
        gr.Radio(
            label='Model name',
            choices=[
                'DialoGPT-small',
                'DialoGPT-medium',
                'DialoGPT-large',
                'bloom-petals',
                'bloom-petals-cluster',
            ]
        ),
        gr.Radio(
            label='Delete history',
            value="Don't delete history",
            choices=[
                'delete history',
                "Don't delete history"
            ]
        ),
    ],
    outputs=[gr.Chatbot(label='History of the dialogue'), "state"],
).launch()