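# chat-gradio / app.py
# Hosting-page metadata captured together with the file (not part of the program):
#   gosha6037 - "Added Cluster Bloom" - commit 91ffd2c
#   page links: raw / history / blame
#   scan status: no virus - size: 8.61 kB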
import sys
import json
import argparse
import gradio as gr
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
sys.path.insert(0, './petals/')
sys.path.insert(0, './personalized-chat-bot/')
from petals.client.remote_model import DistributedBloomForCausalLM
from personalized_chat_bot import PersonalizedChatBot, PersonalityManager
from models.personality_clustering import PersonalityClustering
# Every tokenizer/model below is loaded once, at import time, and kept in a
# module-level name; `predict` selects among these names per request.

# Checkpoint identifier for the BLOOM model served through the petals client.
MODEL_NAME = "bigscience/bloom-petals"
tokenizer_bloomd = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model_bloomd = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, low_cpu_mem_usage=True)

# The three DialoGPT sizes offered in the UI's "Model name" radio group.
tokenizer_DialoGPT_small = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
model_DialoGPT_small = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
tokenizer_DialoGPT_medium = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model_DialoGPT_medium = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
tokenizer_DialoGPT_large = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
model_DialoGPT_large = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
def predict_common_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
    """Produce the next reply of a newline-delimited dialogue.

    The prompt sent to the model is ``person_description``, the stored
    dialogue and the new user message, each terminated by ``'\\n'``.  Only the
    first non-empty line among the first two generated lines is kept as the
    reply; the description is never stored in the returned history.

    Args:
        model: Object exposing ``generate(input_ids, max_new_tokens=...,
            pad_token_id=...)`` and returning a tensor of token ids.
        tokenizer: Object exposing ``encode(text, return_tensors='pt')``,
            ``decode(ids)`` and ``eos_token_id``.
        input_text: The new user message.
        history: ``[]``/``None`` for a fresh dialogue, otherwise the token ids
            returned by the previous call (tensor or nested list with a
            leading batch dimension).
        person_description: Persona text prepended to the prompt.
        number_of_new_tokens: Upper bound on generated tokens.

    Returns:
        ``(pairs, history_ids)`` where ``pairs`` is a list of
        ``(user_line, bot_line)`` tuples for the whole dialogue and
        ``history_ids`` is the re-encoded dialogue text as a tensor.
    """
    new_user_input_ids = tokenizer.encode(input_text + '\n', return_tensors='pt')
    person_description_ids = tokenizer.encode(person_description + '\n', return_tensors='pt')

    print('Started predict_common_bloom')
    print(f'history: {history}')

    # The state is `[]` on the first turn and a tensor afterwards; test the
    # length instead of comparing a tensor against a list.
    if history is not None and len(history) > 0:
        previous_ids = torch.as_tensor(history, dtype=torch.long)
        bot_input_ids = torch.cat([previous_ids, new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids
    print(f'bot_input_ids: {bot_input_ids}')

    input_with_desc_ids = torch.cat([person_description_ids, bot_input_ids], dim=-1)
    history = model.generate(
        input_with_desc_ids,
        max_new_tokens=number_of_new_tokens,
        pad_token_id=tokenizer.eos_token_id
    ).tolist()
    print(f'history: {history}')

    # Drop the persona prefix so it is not accumulated turn after turn.
    history[0] = history[0][person_description_ids.shape[-1]:]

    context_len = bot_input_ids.shape[-1]
    decode_all = tokenizer.decode(history[0][:context_len])
    generated_lines = tokenizer.decode(history[0][context_len:]).split('\n')

    # `str.split` always yields at least one element.  Fall back to the second
    # line when the model opened with a newline, and to an empty reply when
    # nothing (or only an empty line) was generated instead of raising
    # IndexError.
    first_line, *remaining_lines = generated_lines
    reply = first_line or (remaining_lines[0] if remaining_lines else '')
    decode_all += reply + '\n'
    print(f'decode_all: {decode_all}')

    history_new = tokenizer.encode(decode_all, return_tensors='pt')
    print(f'history_new: {history_new}')

    decode_all_split = decode_all.split('\n')
    print(f'decode_all_split: {decode_all_split}')

    # Lines alternate user/bot; the trailing '' produced by the final '\n'
    # is skipped by the `len - 1` bound.
    response_new = [
        (decode_all_split[i], decode_all_split[i + 1])
        for i in range(0, len(decode_all_split) - 1, 2)
    ]
    print(f'response_new: {response_new}')
    return response_new, history_new
def load_config(path):
    """Load a JSON file and expose its top-level keys as attributes.

    Args:
        path: Path to a JSON file whose top-level value is an object.

    Returns:
        An ``argparse.Namespace`` with one attribute per top-level key.
    """
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    return argparse.Namespace(**config)
def predict_cluster_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
    """Answer with a persona-specific prompt chosen by personality clustering.

    The persona closest to ``person_description`` is looked up through
    ``PersonalityManager``; its prompt is loaded into a ``PersonalizedChatBot``
    which then answers the (history + new message) text.

    Args:
        model: Ignored; a model is loaded from ``config_176b.json`` instead.
        tokenizer: Ignored; a tokenizer is loaded from the same config.
        input_text: The new user message.
        history: ``[]``/``None`` for a fresh dialogue, otherwise the token ids
            returned by the previous call (indexed with ``[0]``).
        person_description: Free-text persona used to pick the closest cluster.
        number_of_new_tokens: Stored as ``max_new_tokens`` on the generation
            config.

    Returns:
        ``(pairs, history_ids)`` where ``pairs`` is a list of
        ``(user_line, bot_line)`` tuples and ``history_ids`` is the encoded
        dialogue text as a tensor.
    """
    def keys_to_int(dct):
        # JSON object keys are always strings; the prompt table is keyed by int.
        return {int(k): v for k, v in dct.items()}

    personality_clustering = PersonalityClustering()
    personality_clustering.load('personalized-chat-bot/data/models/personality_clustering_500_paraphrase-MiniLM-L6-v2_k-means.pkl')

    with open('personalized-chat-bot/prompt_paths.json', 'r', encoding='utf-8') as f:
        prompt_paths = json.load(f, object_hook=keys_to_int)

    pm = PersonalityManager(prompt_paths, personality_clustering)
    prompt_path, closest_persona = pm.get_prompt(person_description)
    print(f'The closest personality is: {closest_persona}')
    print('Wait a little longer...')

    # NOTE(review): model, tokenizer and clustering are re-created on every
    # call; caching them looks desirable but depends on whether
    # `load_prompt` may be applied repeatedly to one model - confirm.
    config = load_config('personalized-chat-bot/scripts/config_176b.json')
    model = DistributedBloomForCausalLM.from_pretrained(
        config.MODEL_NAME,
        pre_seq_len=config.NUM_PREFIX_TOKENS,
        tuning_mode=config.TUNING_MODE,
    ).to(config.DEVICE)

    generation_config = load_config('personalized-chat-bot/generation_config.json')
    generation_config.max_new_tokens = number_of_new_tokens
    print(f'generation_config: {generation_config}')

    tokenizer = transformers.BloomTokenizerFast.from_pretrained(config.MODEL_NAME)
    tokenizer.padding_side = 'right'
    tokenizer.model_max_length = config.MODEL_MAX_LENGTH

    chatbot = PersonalizedChatBot(model, tokenizer, generation_config=generation_config)
    chatbot.load_prompt('personalized-chat-bot/' + prompt_path)

    # The state is `[]` on the first turn and a tensor afterwards; test the
    # length instead of comparing a tensor against a list.
    if history is not None and len(history) > 0:
        input_text = tokenizer.decode(history[0]) + '\n' + input_text
    print(f'INPUT: {input_text}')

    output = chatbot.answer(input_text)
    all_text = input_text + '\n' + output
    print(f'all_text: {all_text}')

    history = tokenizer.encode(all_text, return_tensors='pt')
    print(f'history: {history}')

    # Lines alternate user/bot; an unpaired trailing line is dropped.
    lines = tokenizer.decode(history[0]).split("\n")
    response = [(lines[i], lines[i + 1]) for i in range(0, len(lines) - 1, 2)]
    print(f'response: {response}')
    return response, history
def predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
    """Produce the next reply of an EOS-delimited dialogue.

    The prompt sent to the model is ``person_description``, the stored
    dialogue and the new user message, each terminated by the tokenizer's EOS
    token.  The description is stripped from the returned history.

    Args:
        model: Object exposing ``generate(input_ids, max_new_tokens=...,
            pad_token_id=...)`` and returning a tensor of token ids.
        tokenizer: Object exposing ``encode(text, return_tensors='pt')``,
            ``decode(ids)``, ``eos_token`` and ``eos_token_id``.
        input_text: The new user message.
        history: ``[]``/``None`` for a fresh dialogue, otherwise the nested
            list of token ids returned by the previous call.
        person_description: Persona text prepended to the prompt.
        number_of_new_tokens: Upper bound on generated tokens.

    Returns:
        ``(pairs, history_ids)`` where ``pairs`` is a list of
        ``(user_turn, bot_turn)`` tuples for the whole dialogue and
        ``history_ids`` is a nested list ``[[token_id, ...]]``.
    """
    eos_token = tokenizer.eos_token
    eos_token_id = tokenizer.eos_token_id

    person_description_ids = tokenizer.encode(person_description + eos_token, return_tensors='pt')
    new_user_input_ids = tokenizer.encode(input_text + eos_token, return_tensors='pt')

    # Handle the first turn explicitly rather than concatenating with an
    # empty 1-D tensor.
    if history is not None and len(history) > 0:
        previous_ids = torch.as_tensor(history, dtype=torch.long)
        bot_input_ids = torch.cat([previous_ids, new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids
    input_with_desc_ids = torch.cat([person_description_ids, bot_input_ids], dim=-1)

    history = model.generate(
        input_with_desc_ids,
        max_new_tokens=number_of_new_tokens,
        pad_token_id=eos_token_id
    ).tolist()

    # Drop the persona prefix so it is not accumulated turn after turn.
    dialogue_ids = history[0][person_description_ids.shape[-1]:]

    # A reply cut off by `max_new_tokens` (or an empty one) has no closing
    # EOS; without it the next user message would be glued onto this reply
    # and every later (user, bot) pair would be misaligned.
    generated_ids = dialogue_ids[bot_input_ids.shape[-1]:]
    if not generated_ids or generated_ids[-1] != eos_token_id:
        dialogue_ids.append(eos_token_id)
    history[0] = dialogue_ids

    # Turns alternate user/bot; the trailing '' after the final EOS is
    # skipped by the `len - 1` bound.
    turns = tokenizer.decode(dialogue_ids).split(eos_token)
    response = [(turns[i], turns[i + 1]) for i in range(0, len(turns) - 1, 2)]
    return response, history
def predict(
    input_text,
    history=None,
    person_description=None,
    number_of_new_tokens=10,
    model_name=None,
    del_hist=None
):
    """Route one chat turn to the backend selected in the UI.

    Args:
        input_text: The new user message.
        history: Dialogue state from the previous call, or ``None``.
        person_description: Persona text; ``None`` is treated as ``''``.
        number_of_new_tokens: Upper bound on generated tokens (coerced to
            ``int``).
        model_name: One of ``'DialoGPT-small'``, ``'DialoGPT-medium'``,
            ``'DialoGPT-large'``, ``'bloom-petals'``,
            ``'bloom-petals-cluster'``.  Any other value, including ``None``,
            selects DialoGPT-medium.
        del_hist: ``'delete history'`` discards the stored dialogue before
            answering.

    Returns:
        The ``(pairs, history)`` tuple produced by the selected backend.
    """
    if history is None or del_hist == 'delete history':
        history = []
    # Every backend concatenates the description with a string; the signature
    # default of None would otherwise raise TypeError there.
    if person_description is None:
        person_description = ''
    # A slider widget may deliver a float; a token budget must be an int.
    number_of_new_tokens = int(number_of_new_tokens)

    if model_name == 'bloom-petals':
        print(f'Lets go history: {history}')
        return predict_common_bloom(
            model_bloomd, tokenizer_bloomd, input_text, history, person_description, number_of_new_tokens
        )
    if model_name == 'bloom-petals-cluster':
        print(f'Lets go history: {history}')
        return predict_cluster_bloom(
            model_bloomd, tokenizer_bloomd, input_text, history, person_description, number_of_new_tokens
        )

    dialo_gpt_variants = {
        'DialoGPT-small': (model_DialoGPT_small, tokenizer_DialoGPT_small),
        'DialoGPT-medium': (model_DialoGPT_medium, tokenizer_DialoGPT_medium),
        'DialoGPT-large': (model_DialoGPT_large, tokenizer_DialoGPT_large),
    }
    # Unknown or missing names fall back to the medium model.
    model, tokenizer = dialo_gpt_variants.get(model_name, dialo_gpt_variants['DialoGPT-medium'])
    return predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
# Build the UI and start serving.  Inputs are passed to `predict` positionally
# in the order listed here; the "state" entries carry the dialogue history
# between calls.
gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label='Input message', lines=1, placeholder="Enter your message..."),
        "state",
        gr.Textbox(label='Person Description', lines=2, placeholder="Enter a description of the person..."),
        gr.Slider(label='Number of new tokens', minimum=2, maximum=100, value=10),
        gr.Radio(
            label='Model name',
            choices=[
                'DialoGPT-small',
                'DialoGPT-medium',
                'DialoGPT-large',
                'bloom-petals',
                'bloom-petals-cluster',
            ]
        ),
        gr.Radio(
            label='Delete history',
            value="Don't delete history",
            choices=[
                'delete history',
                "Don't delete history"
            ]),
    ],
    outputs=[gr.Chatbot(label='History of the dialogue'), "state"],
).launch()