# chat-gradio / app.py
# Author: gosha6037
# Commit a4f8b32: "Added description for bloom"
import sys
import gradio as gr
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
sys.path.insert(0, './petals/')
from petals.client.remote_model import DistributedBloomForCausalLM
# BLOOM accessed through the Petals distributed client (imported from ./petals/ above).
MODEL_NAME = "bigscience/bloom-petals"
tokenizer_bloomd = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model_bloomd = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, low_cpu_mem_usage=True)


def _load_causal_lm(checkpoint):
    """Load and return ``(tokenizer, model)`` for a Hugging Face causal-LM checkpoint.

    The tokenizer is loaded before the model.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    loaded_model = AutoModelForCausalLM.from_pretrained(checkpoint)
    return loaded_tokenizer, loaded_model


# The three DialoGPT sizes selectable in the UI, loaded eagerly at import time.
tokenizer_DialoGPT_small, model_DialoGPT_small = _load_causal_lm("microsoft/DialoGPT-small")
tokenizer_DialoGPT_medium, model_DialoGPT_medium = _load_causal_lm("microsoft/DialoGPT-medium")
tokenizer_DialoGPT_large, model_DialoGPT_large = _load_causal_lm("microsoft/DialoGPT-large")
def predict_common_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
    """Generate one reply with a BLOOM-style causal LM and return the updated dialogue.

    The conversation is stored as newline-separated text: user messages and
    bot replies alternate, one per line. The person description is prepended
    to the prompt on every call but is stripped from the stored history.

    Args:
        model: causal LM exposing ``generate`` (the Petals BLOOM model in this app).
        tokenizer: tokenizer matching ``model``.
        input_text: the new user message.
        history: previous conversation token ids. This is ``[]`` on the first
            turn; afterwards it is the ``history_new`` tensor returned by the
            previous call. A nested list of ids is accepted as well.
        person_description: text prepended to the prompt; ``None`` is treated
            as an empty description.
        number_of_new_tokens: passed to ``model.generate`` as ``max_new_tokens``.

    Returns:
        ``(pairs, history_new)``:
        - ``pairs`` is a list of ``(user, bot)`` string tuples for the chatbot widget.
        - ``history_new`` is the re-encoded conversation text, to be fed back
          as ``history`` on the next turn. Its shape is presumably ``(1, L)``,
          as produced by ``tokenizer.encode(..., return_tensors='pt')``.
    """
    if person_description is None:
        person_description = ''
    new_user_input_ids = tokenizer.encode(input_text + '\n', return_tensors='pt')
    person_description_ids = tokenizer.encode(person_description + '\n', return_tensors='pt')
    print('Started predict_common_bloom')
    print(f'history: {history}')
    # `history` is a list on the first turn and a tensor afterwards. len() works for
    # both, while `history != []` only "worked" on a tensor via Python's fallback to
    # identity comparison.
    if len(history) > 0:
        bot_input_ids = torch.cat([torch.as_tensor(history, dtype=torch.long), new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids
    print(f'bot_input_ids: {bot_input_ids}')
    input_with_desc_ids = torch.cat([person_description_ids, bot_input_ids], dim=-1)
    generated = model.generate(
        input_with_desc_ids,
        max_new_tokens=number_of_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    ).tolist()
    print(f'history: {generated}')
    # The output echoes the prompt. Drop the description prefix so it never enters
    # the stored history.
    conversation_ids = generated[0][person_description_ids.shape[-1]:]
    prompt_len = bot_input_ids.shape[-1]
    decode_all = tokenizer.decode(conversation_ids[:prompt_len])
    reply_lines = tokenizer.decode(conversation_ids[prompt_len:]).split('\n')
    # The reply is the first generated line, or the second line if the model opened
    # with a newline. An empty continuation splits to [''], which previously made
    # reply_lines[1] raise IndexError.
    if reply_lines[0]:
        reply = reply_lines[0]
    elif len(reply_lines) > 1:
        reply = reply_lines[1]
    else:
        reply = ''
    decode_all += reply + '\n'
    print(f'decode_all: {decode_all}')
    history_new = tokenizer.encode(decode_all, return_tensors='pt')
    print(f'history_new: {history_new}')
    decode_all_split = decode_all.split('\n')
    print(f'decode_all_split: {decode_all_split}')
    # Lines alternate user/bot. zip drops the trailing '' left after the final '\n'.
    response_new = list(zip(decode_all_split[0::2], decode_all_split[1::2]))
    print(f'response_new: {response_new}')
    return response_new, history_new
def predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
    """Generate one DialoGPT reply and return the updated dialogue.

    Turns are separated by ``tokenizer.eos_token``. The person description is
    prepended to the prompt on every call but stripped from the stored history.

    Args:
        model: causal LM exposing ``generate``.
        tokenizer: tokenizer matching ``model``; its ``eos_token`` separates turns.
        input_text: the new user message.
        history: previous conversation token ids as returned by the previous
            call (a nested list such as ``[[...]]``), or ``[]`` on the first turn.
        person_description: text prepended to the prompt; ``None`` is treated
            as an empty description.
        number_of_new_tokens: passed to ``model.generate`` as ``max_new_tokens``.

    Returns:
        ``(pairs, history)``:
        - ``pairs`` is a list of ``(user, bot)`` string tuples for the chatbot widget.
        - ``history`` is the generated ids (description removed) as a nested
          Python list, to be fed back on the next turn.
    """
    if person_description is None:
        person_description = ''
    eos = tokenizer.eos_token
    person_description_ids = tokenizer.encode(person_description + eos, return_tensors='pt')
    new_user_input_ids = tokenizer.encode(input_text + eos, return_tensors='pt')
    if len(history) > 0:
        bot_input_ids = torch.cat([torch.as_tensor(history, dtype=torch.long), new_user_input_ids], dim=-1)
    else:
        # Don't concatenate an empty 1-D tensor with a 2-D one. That only works through
        # torch.cat's deprecated legacy special case for empty tensors.
        bot_input_ids = new_user_input_ids
    input_with_desc_ids = torch.cat([person_description_ids, bot_input_ids], dim=-1)
    history = model.generate(
        input_with_desc_ids,
        max_new_tokens=number_of_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    ).tolist()
    # The output echoes the prompt. Drop the description prefix so it never enters the history.
    history[0] = history[0][person_description_ids.shape[-1]:]
    turns = tokenizer.decode(history[0]).split(eos)
    # Turns alternate user/bot. zip drops the trailing text after the final separator.
    response = list(zip(turns[0::2], turns[1::2]))
    return response, history
def predict(
    input_text,
    history=None,
    person_description=None,
    number_of_new_tokens=10,
    model_name=None,
    del_hist=None
):
    """Gradio callback: run one chat turn on the model selected in the UI.

    Args:
        input_text: the new user message.
        history: conversation state from the previous call (Gradio "state");
            ``None`` starts a new conversation.
        person_description: text prepended to the prompt by the predictors.
        number_of_new_tokens: maximum number of tokens to generate.
        model_name: one of 'DialoGPT-small', 'DialoGPT-medium', 'DialoGPT-large'
            or 'bloom-petals'. Any other value falls back to DialoGPT-medium.
            That includes ``None`` and the UI's 'bloom-petals-cluster', which
            has no backend here.
        del_hist: 'delete history' discards the stored conversation.

    Returns:
        ``(pairs, history)`` from the selected predictor: chat pairs for the
        Chatbot widget and the new state.
    """
    if history is None or del_hist == 'delete history':
        history = []
    # The UI slider may deliver a float; ``max_new_tokens`` is a token count.
    number_of_new_tokens = int(number_of_new_tokens)

    if model_name == 'bloom-petals':
        print(f'Lets go history: {history}')
        return predict_common_bloom(
            model_bloomd, tokenizer_bloomd, input_text, history, person_description, number_of_new_tokens
        )

    dialo_gpt_models = {
        'DialoGPT-small': (model_DialoGPT_small, tokenizer_DialoGPT_small),
        'DialoGPT-medium': (model_DialoGPT_medium, tokenizer_DialoGPT_medium),
        'DialoGPT-large': (model_DialoGPT_large, tokenizer_DialoGPT_large),
    }
    model, tokenizer = dialo_gpt_models.get(model_name, dialo_gpt_models['DialoGPT-medium'])
    return predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens)
# Chat UI. The order of `inputs` must match predict()'s positional parameters:
# (input_text, history, person_description, number_of_new_tokens, model_name, del_hist).
# The two "state" entries carry the conversation history between calls.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label='Input message', lines=1, placeholder="Enter your message..."),
        "state",
        gr.Textbox(label='Person Description', lines=2, placeholder="Enter a description of the person..."),
        gr.Slider(label='Number of new tokens', minimum=2, maximum=100, value=10),
        gr.Radio(
            label='Model name',
            # NOTE: predict() has no branch for 'bloom-petals-cluster' (or no selection);
            # both fall back to DialoGPT-medium.
            choices=[
                'DialoGPT-small',
                'DialoGPT-medium',
                'DialoGPT-large',
                'bloom-petals',
                'bloom-petals-cluster',
            ]
        ),
        gr.Radio(
            label='Delete history',
            value="Don't delete history",
            choices=[
                'delete history',
                "Don't delete history"
            ]),
    ],
    outputs=[gr.Chatbot(label='History of the dialogue'), "state"],
)
demo.launch()