"""Gradio demo that chats in a given persona, comparing local DialoGPT
baselines against BLOOM served over the Petals swarm."""

import sys
import json
import argparse

import gradio as gr
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

sys.path.insert(0, './petals/')
sys.path.insert(0, './personalized-chat-bot/')

from petals.client.remote_model import DistributedBloomForCausalLM
from personalized_chat_bot import PersonalizedChatBot, PersonalityManager
from models.personality_clustering import PersonalityClustering

MODEL_NAME = "bigscience/bloom-petals"
tokenizer_bloomd = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model_bloomd = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, low_cpu_mem_usage=True)

tokenizer_DialoGPT_small = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
model_DialoGPT_small = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
tokenizer_DialoGPT_medium = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model_DialoGPT_medium = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
tokenizer_DialoGPT_large = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
model_DialoGPT_large = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")


def predict_common_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
    # Encode the new user message and the persona description; '\n' separates turns.
    new_user_input_ids = tokenizer.encode(input_text + '\n', return_tensors='pt')
    person_description_ids = tokenizer.encode(person_description + '\n', return_tensors='pt')
    print('Started predict_common_bloom')
    print(f'history: {history}')

    if history != []:
        bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids
    print(f'bot_input_ids: {bot_input_ids}')

    # Prepend the persona description so it conditions the generation;
    # it is stripped back out of the stored history below.
    input_with_desc_ids = torch.cat([person_description_ids, bot_input_ids], dim=-1)
    history = model.generate(
        input_with_desc_ids,
        max_new_tokens=number_of_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    ).tolist()
    print(f'history: {history}')

    # Drop the persona prefix, then keep the dialogue so far plus only the
    # first non-empty generated line as the bot's reply.
    history[0] = history[0][len(person_description_ids[0]):]
    decode_all = tokenizer.decode(history[0][:len(bot_input_ids[0])])
    all_responses = tokenizer.decode(history[0][len(bot_input_ids[0]):]).split('\n')
    if all_responses[0]:
        decode_all += all_responses[0] + '\n'
    else:
        decode_all += all_responses[1] + '\n'
    print(f'decode_all: {decode_all}')

    history_new = tokenizer.encode(decode_all, return_tensors='pt')
    print(f'history_new: {history_new}')

    # Pair consecutive lines into (user, bot) tuples for the Chatbot widget.
    decode_all_split = decode_all.split('\n')
    print(f'decode_all_split: {decode_all_split}')
    response_new = [(decode_all_split[i], decode_all_split[i + 1])
                    for i in range(0, len(decode_all_split) - 1, 2)]
    print(f'response_new: {response_new}')
    return response_new, history_new


def load_config(path):
    with open(path, 'r') as f:
        config = json.load(f)
    return argparse.Namespace(**config)


def predict_cluster_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
    # Find the stored personality prompt whose cluster is closest to the description.
    personality_clustering = PersonalityClustering()
    personality_clustering.load(
        'personalized-chat-bot/data/models/personality_clustering_500_paraphrase-MiniLM-L6-v2_k-means.pkl')

    hook = lambda dct: {int(k): v for k, v in dct.items()}
    with open('personalized-chat-bot/prompt_paths.json', 'r') as f:
        prompt_paths = json.load(f, object_hook=hook)

    pm = PersonalityManager(prompt_paths, personality_clustering)
    prompt_path, closest_persona = pm.get_prompt(person_description)
    print(f'The closest personality is: {closest_persona}')
    print('Wait a little longer...')

    # Re-load the distributed model with prefix tuning enabled; this
    # deliberately shadows the `model` and `tokenizer` arguments.
    config = load_config('personalized-chat-bot/scripts/config_176b.json')
    model = DistributedBloomForCausalLM.from_pretrained(
        config.MODEL_NAME,
        pre_seq_len=config.NUM_PREFIX_TOKENS,
        tuning_mode=config.TUNING_MODE,
        # max_new_tokens=number_of_new_tokens,
    ).to(config.DEVICE)

    generation_config = load_config('personalized-chat-bot/generation_config.json')
    generation_config.max_new_tokens = number_of_new_tokens
    print(f'generation_config: {generation_config}')

    tokenizer = transformers.BloomTokenizerFast.from_pretrained(config.MODEL_NAME)
    tokenizer.padding_side = 'right'
    tokenizer.model_max_length = config.MODEL_MAX_LENGTH

    chatbot = PersonalizedChatBot(model, tokenizer, generation_config=generation_config)
    chatbot.load_prompt('personalized-chat-bot/' + prompt_path)

    if history != []:
        input_text = tokenizer.decode(history[0]) + '\n' + input_text
    print(f'INPUT: {input_text}')

    output = chatbot.answer(input_text)
    all_text = input_text + '\n' + output
    print(f'all_text: {all_text}')

    history = tokenizer.encode(all_text, return_tensors='pt')
    print(f'history: {history}')

    response = tokenizer.decode(history[0]).split('\n')
    response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]
    print(f'response: {response}')
    return response, history


def predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
    person_description_ids = tokenizer.encode(person_description + tokenizer.eos_token, return_tensors='pt')
    new_user_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt')

    # Guard against an empty history on the first turn: concatenating the
    # 1-D empty tensor from torch.LongTensor([]) with a 2-D batch would fail.
    if history != []:
        bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    # Prepend the persona description, generate, then strip the prefix again.
    input_with_desc_ids = torch.cat([person_description_ids, bot_input_ids], dim=-1)
    history = model.generate(
        input_with_desc_ids,
        max_new_tokens=number_of_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    ).tolist()
    history[0] = history[0][len(person_description_ids[0]):]

    response = tokenizer.decode(history[0]).split('<|endoftext|>')
    response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]
    return response, history


def predict(input_text, history=None, person_description=None,
            number_of_new_tokens=10, model_name=None, del_hist=None):
    # Reset the dialogue state on the first call or on explicit request.
    if history is None or del_hist == 'delete history':
        history = []

    if model_name == 'bloom-petals':
        print(f"Let's go, history: {history}")
        return predict_common_bloom(model_bloomd, tokenizer_bloomd, input_text,
                                    history, person_description, number_of_new_tokens)
    if model_name == 'bloom-petals-cluster':
        print(f"Let's go, history: {history}")
        return predict_cluster_bloom(model_bloomd, tokenizer_bloomd, input_text,
                                     history, person_description, number_of_new_tokens)

    if model_name == 'DialoGPT-small':
        model, tokenizer = model_DialoGPT_small, tokenizer_DialoGPT_small
    elif model_name == 'DialoGPT-large':
        model, tokenizer = model_DialoGPT_large, tokenizer_DialoGPT_large
    else:
        # 'DialoGPT-medium' and any unrecognized choice use the medium model.
        model, tokenizer = model_DialoGPT_medium, tokenizer_DialoGPT_medium
    return predict_dialo_gpt(model, tokenizer, input_text, history,
                             person_description, number_of_new_tokens)


gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label='Input message', lines=1, placeholder="Enter your message..."),
        "state",
        gr.Textbox(label='Person Description', lines=2,
                   placeholder="Enter a description of the person..."),
        gr.Slider(label='Number of new tokens', minimum=2, maximum=100, value=10),
        gr.Radio(
            label='Model name',
            choices=[
                'DialoGPT-small',
                'DialoGPT-medium',
                'DialoGPT-large',
                'bloom-petals',
                'bloom-petals-cluster',
            ],
        ),
        gr.Radio(
            label='Delete history',
            value="Don't delete history",
            choices=['delete history', "Don't delete history"],
        ),
    ],
    outputs=[gr.Chatbot(label='History of the dialogue'), "state"],
).launch()