import sys

import gradio as gr
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

sys.path.insert(0, './petals/')
from petals.client.remote_model import DistributedBloomForCausalLM

# Distributed BLOOM models served over the Petals swarm.
MODEL_NAME = "bigscience/test-bloomd-6b3"
# INITIAL_PEERS = ["/ip4/193.106.95.184/tcp/31000/p2p/QmSg7izCDtowVTACbUmWvEiQZNY4wgCQ9T9Doo66K59X6q"]
tokenizer_bloomd_6b3 = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model_bloomd_6b3 = DistributedBloomForCausalLM.from_pretrained(
    MODEL_NAME,
    # initial_peers=INITIAL_PEERS,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
)

MODEL_NAME = "bigscience/bloom-petals"
tokenizer_bloomd = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model_bloomd = DistributedBloomForCausalLM.from_pretrained(
    MODEL_NAME,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
)

# Local DialoGPT baselines in three sizes.
tokenizer_DialoGPT_small = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
model_DialoGPT_small = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
tokenizer_DialoGPT_medium = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model_DialoGPT_medium = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
tokenizer_DialoGPT_large = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
model_DialoGPT_large = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")


def predict(
    input_text,
    history=None,
    person_description=None,
    number_of_new_tokens=1000,
    model_name=None,
    del_hist=None,
):
    """Generate the next bot reply, conditioning on a persona description that
    is prepended for generation but stripped from the stored history."""
    if history is None or del_hist == 'delete history':
        history = []
    if person_description is None:
        person_description = ''

    # Pick the requested model; anything unrecognized falls back to DialoGPT-medium.
    if model_name == 'DialoGPT-small':
        model, tokenizer = model_DialoGPT_small, tokenizer_DialoGPT_small
    elif model_name == 'DialoGPT-large':
        model, tokenizer = model_DialoGPT_large, tokenizer_DialoGPT_large
    elif model_name == 'test-bloomd-6b3':
        model, tokenizer = model_bloomd_6b3, tokenizer_bloomd_6b3
    elif model_name == 'bloom-petals':
        model, tokenizer = model_bloomd, tokenizer_bloomd
    else:
        model, tokenizer = model_DialoGPT_medium, tokenizer_DialoGPT_medium

    person_description_ids = tokenizer.encode(person_description + tokenizer.eos_token, return_tensors='pt')
    new_user_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt')

    # Append the new user turn to the token-id history (guarding the first
    # turn, when `history` is empty), then prepend the persona description
    # for this generation step only.
    if history:
        bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids
    input_with_desc_ids = torch.cat([person_description_ids, bot_input_ids], dim=-1)

    max_token_count = number_of_new_tokens + len(input_with_desc_ids[0])
    history = model.generate(
        input_with_desc_ids,
        max_length=max_token_count,
        pad_token_id=tokenizer.eos_token_id,
    ).tolist()

    # Drop the persona tokens so they are not fed back on the next turn, then
    # split the decoded dialogue into (user, bot) pairs for gr.Chatbot.
    history[0] = history[0][len(person_description_ids[0]):]
    response = tokenizer.decode(history[0]).split("<|endoftext|>")
    response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]
    return response, history


gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label='Input message', lines=1, placeholder="Enter your message..."),
        "state",
        gr.Textbox(label='Person Description', lines=2, placeholder="Enter a description of the person..."),
        gr.Slider(label='Number of new tokens', minimum=2, maximum=100, value=10),
        gr.Radio(
            label='Model name',
            choices=[
                'DialoGPT-small',
                'DialoGPT-medium',
                'DialoGPT-large',
                'test-bloomd-6b3',
                'bloom-petals',
            ],
        ),
        gr.Radio(
            label='Delete history',
            value="Don't delete history",
            choices=['delete history', "Don't delete history"],
        ),
    ],
    outputs=[gr.Chatbot(label='History of the dialogue'), "state"],
).launch()
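
# Usage sketch (kept as comments so the script's behavior is unchanged):
# how `predict` could be called directly, bypassing the Gradio UI. The
# persona text, message, and token budget below are illustrative
# placeholders, not values taken from the app above.
#
#   reply_pairs, token_history = predict(
#       "Hi! What did you do today?",
#       history=None,
#       person_description="A cheerful barista who loves small talk.",
#       number_of_new_tokens=20,
#       model_name='DialoGPT-medium',
#   )
#   # `reply_pairs` is a list of (user, bot) tuples, as rendered by gr.Chatbot;
#   # `token_history` is the flat token-id history to pass back on the next turn.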