# chat-gradio/app.py
import sys
import json
import argparse
import gradio as gr
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
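
# Vendored checkouts of Petals and personalized-chat-bot sit next to this
# script; put them on the import path before importing from them.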
sys.path.insert(0, './petals/')
sys.path.insert(0, './personalized-chat-bot/')
from petals.client.remote_model import DistributedBloomForCausalLM
from personalized_chat_bot import PersonalizedChatBot, PersonalityManager
from models.personality_clustering import PersonalityClustering
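
# All models are instantiated eagerly at startup. The BLOOM weights are served
# by the public Petals swarm, so only the client-side layers are held locally.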
MODEL_NAME = "bigscience/bloom-petals"
tokenizer_bloomd = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model_bloomd = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, low_cpu_mem_usage=True)
tokenizer_DialoGPT_small = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
model_DialoGPT_small = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
tokenizer_DialoGPT_medium = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model_DialoGPT_medium = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
tokenizer_DialoGPT_large = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
model_DialoGPT_large = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
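
# The BLOOM paths treat the dialogue as plain text: user and bot turns
# alternate one per line, separated by '\n'.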
def predict_common_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
    new_user_input_ids = tokenizer.encode(input_text + '\n', return_tensors='pt')
    person_description_ids = tokenizer.encode(person_description + '\n', return_tensors='pt')
    print('Started predict_common_bloom')
    print(f'history: {history}')
    if history != []:
        bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids
    print(f'bot_input_ids: {bot_input_ids}')

    # Prepend the persona description for generation only; it is stripped from
    # the stored history below so it is not duplicated on the next turn.
    input_with_desc_ids = torch.cat([person_description_ids, bot_input_ids], dim=-1)
    history = model.generate(
        input_with_desc_ids,
        max_new_tokens=number_of_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    ).tolist()
    print(f'history: {history}')

    history[0] = history[0][len(person_description_ids[0]):]
    decode_all = tokenizer.decode(history[0][:len(bot_input_ids[0])])
    # Keep only the first non-empty generated line as the bot's reply.
    all_responses = tokenizer.decode(history[0][len(bot_input_ids[0]):]).split('\n')
    if all_responses[0]:
        decode_all += all_responses[0] + '\n'
    else:
        decode_all += all_responses[1] + '\n'
    print(f'decode_all: {decode_all}')

    history_new = tokenizer.encode(decode_all, return_tensors='pt')
    print(f'history_new: {history_new}')
    decode_all_split = decode_all.split('\n')
    print(f'decode_all_split: {decode_all_split}')
    # Pair up alternating (user, bot) lines for the gr.Chatbot component.
    response_new = [(decode_all_split[i], decode_all_split[i + 1]) for i in range(0, len(decode_all_split) - 1, 2)]
    print(f'response_new: {response_new}')
    return response_new, history_new

def load_config(path):
    """Read a JSON file and expose its keys as attributes."""
    with open(path, 'r') as f:
        config = json.load(f)
    return argparse.Namespace(**config)
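
# Sketch of the JSON shapes load_config() expects here; the real files ship
# with the personalized-chat-bot repo, and the values below are illustrative
# only (the keys are the ones actually read in predict_cluster_bloom):
#   config_176b.json       -> {"MODEL_NAME": "bigscience/bloom-petals",
#                              "NUM_PREFIX_TOKENS": 16, "TUNING_MODE": "ptune",
#                              "DEVICE": "cpu", "MODEL_MAX_LENGTH": 256}
#   generation_config.json -> sampling settings passed to PersonalizedChatBot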

def predict_cluster_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
    # Map the free-form description to the closest pre-trained persona prompt.
    personality_clustering = PersonalityClustering()
    personality_clustering.load('personalized-chat-bot/data/models/personality_clustering_500_paraphrase-MiniLM-L6-v2_k-means.pkl')
    hook = lambda dct: {int(k): v for k, v in dct.items()}
    with open('personalized-chat-bot/prompt_paths.json', 'r') as f:
        prompt_paths = json.load(f, object_hook=hook)
    pm = PersonalityManager(prompt_paths, personality_clustering)
    prompt_path, closest_persona = pm.get_prompt(person_description)
    print(f'The closest personality is: {closest_persona}')
    print('Wait a little longer...')

    # Note: this re-instantiates the prompt-tuned Petals client on every call,
    # which is slow; the `model`/`tokenizer` arguments passed in are replaced.
    config = load_config('personalized-chat-bot/scripts/config_176b.json')
    model = DistributedBloomForCausalLM.from_pretrained(
        config.MODEL_NAME,
        pre_seq_len=config.NUM_PREFIX_TOKENS,
        tuning_mode=config.TUNING_MODE,
    ).to(config.DEVICE)
    generation_config = load_config('personalized-chat-bot/generation_config.json')
    generation_config.max_new_tokens = number_of_new_tokens
    print(f'generation_config: {generation_config}')

    tokenizer = transformers.BloomTokenizerFast.from_pretrained(config.MODEL_NAME)
    tokenizer.padding_side = 'right'
    tokenizer.model_max_length = config.MODEL_MAX_LENGTH

    chatbot = PersonalizedChatBot(model, tokenizer, generation_config=generation_config)
    chatbot.load_prompt('personalized-chat-bot/' + prompt_path)
    if history != []:
        input_text = tokenizer.decode(history[0]) + '\n' + input_text
    print(f'INPUT: {input_text}')
    output = chatbot.answer(input_text)

    all_text = input_text + '\n' + output
    print(f'all_text: {all_text}')
    history = tokenizer.encode(all_text, return_tensors='pt')
    print(f'history: {history}')
    response = tokenizer.decode(history[0]).split('\n')
    response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]
    print(f'response: {response}')
    return response, history
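
# DialoGPT separates turns with its EOS token ('<|endoftext|>') rather than
# newlines, so history is built and split on that token below.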
def predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
    person_description_ids = tokenizer.encode(person_description + tokenizer.eos_token, return_tensors='pt')
    new_user_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt')
    if history != []:
        bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
    else:
        # First turn: start from the user input alone instead of relying on
        # torch.cat's legacy handling of empty tensors.
        bot_input_ids = new_user_input_ids
    input_with_desc_ids = torch.cat([person_description_ids, bot_input_ids], dim=-1)
    history = model.generate(
        input_with_desc_ids,
        max_new_tokens=number_of_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    ).tolist()
    # Strip the description tokens so they are not repeated on later turns.
    history[0] = history[0][len(person_description_ids[0]):]
    response = tokenizer.decode(history[0]).split("<|endoftext|>")
    response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]
    return response, history
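
# Gradio calls predict() once per message; `history` (input) and the second
# return value are the hidden "state" that carries tokens between calls.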
def predict(
    input_text,
    history=None,
    person_description=None,
    number_of_new_tokens=10,
    model_name=None,
    del_hist=None,
):
    if history is None or del_hist == 'delete history':
        history = []

    if model_name == 'DialoGPT-small':
        model = model_DialoGPT_small
        tokenizer = tokenizer_DialoGPT_small
        return predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
    elif model_name == 'DialoGPT-medium':
        model = model_DialoGPT_medium
        tokenizer = tokenizer_DialoGPT_medium
        return predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
    elif model_name == 'DialoGPT-large':
        model = model_DialoGPT_large
        tokenizer = tokenizer_DialoGPT_large
        return predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
    elif model_name == 'bloom-petals':
        model = model_bloomd
        tokenizer = tokenizer_bloomd
        print(f"Let's go, history: {history}")
        return predict_common_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
    elif model_name == 'bloom-petals-cluster':
        model = model_bloomd
        tokenizer = tokenizer_bloomd
        print(f"Let's go, history: {history}")
        return predict_cluster_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
    else:
        # No model selected: fall back to DialoGPT-medium.
        model = model_DialoGPT_medium
        tokenizer = tokenizer_DialoGPT_medium
        return predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
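
# The two "state" entries below are paired: Gradio feeds the previous call's
# second output back in as the `history` argument on the next call.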
gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label='Input message', lines=1, placeholder="Enter your message..."),
        "state",
        gr.Textbox(label='Person Description', lines=2, placeholder="Enter a description of the person..."),
        gr.Slider(label='Number of new tokens', minimum=2, maximum=100, value=10),
        gr.Radio(
            label='Model name',
            choices=[
                'DialoGPT-small',
                'DialoGPT-medium',
                'DialoGPT-large',
                'bloom-petals',
                'bloom-petals-cluster',
            ],
        ),
        gr.Radio(
            label='Delete history',
            value="Don't delete history",
            choices=[
                'delete history',
                "Don't delete history",
            ],
        ),
    ],
    outputs=[gr.Chatbot(label='History of the dialogue'), "state"],
).launch()