# chat-gradio / app.py
# gosha6037's picture
# Added bigscience/bloom-petals
# 776e43c
# raw history blame
# No virus
# 4.26 kB
import sys
import gradio as gr
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
sys.path.insert(0, './petals/')
from petals.client.remote_model import DistributedBloomForCausalLM
# ---------------------------------------------------------------------------
# Eagerly load every tokenizer/model pair the UI can select, at import time.
# The resulting module-level names are looked up by `predict`.
# ---------------------------------------------------------------------------


def _load_distributed_bloom(checkpoint):
    """Return ``(tokenizer, model)`` for a BLOOM checkpoint.

    The tokenizer is a ``BloomTokenizerFast``; the model is loaded through
    ``DistributedBloomForCausalLM`` (from the ``petals`` client) in float32
    with ``low_cpu_mem_usage`` enabled.
    """
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(checkpoint)
    model = DistributedBloomForCausalLM.from_pretrained(
        checkpoint,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float32,
    )
    return tokenizer, model


def _load_dialogpt(size):
    """Return ``(tokenizer, model)`` for ``microsoft/DialoGPT-<size>``."""
    checkpoint = f"microsoft/DialoGPT-{size}"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    return tokenizer, model


# An explicit `initial_peers=INITIAL_PEERS` argument was previously passed when
# loading this checkpoint; the address is kept here for reference only:
# INITIAL_PEERS = ["/ip4/193.106.95.184/tcp/31000/p2p/QmSg7izCDtowVTACbUmWvEiQZNY4wgCQ9T9Doo66K59X6q"]
MODEL_NAME = "bigscience/test-bloomd-6b3"
tokenizer_bloomd_6b3, model_bloomd_6b3 = _load_distributed_bloom(MODEL_NAME)

# NOTE: MODEL_NAME is deliberately rebound; after this point it names the
# full BLOOM checkpoint, not the 6b3 test checkpoint.
MODEL_NAME = "bigscience/bloom-petals"
tokenizer_bloomd, model_bloomd = _load_distributed_bloom(MODEL_NAME)

# The three DialoGPT sizes are loaded in the order small, medium, large.
tokenizer_DialoGPT_small, model_DialoGPT_small = _load_dialogpt("small")
tokenizer_DialoGPT_medium, model_DialoGPT_medium = _load_dialogpt("medium")
tokenizer_DialoGPT_large, model_DialoGPT_large = _load_dialogpt("large")
def predict(
    input_text,
    history=None,
    person_description=None,
    number_of_new_tokens=1000,
    model_name=None,
    del_hist=None,
):
    """Generate the next chatbot reply and return the updated dialogue.

    Args:
        input_text: The user's new message. ``None`` is treated as "".
        history: Token ids of the dialogue so far, as returned by a previous
            call (a nested list whose first row holds the ids), or ``None``
            to start a fresh dialogue.
        person_description: Free text that is prepended to the prompt on every
            call. Its tokens are stripped from the returned ``history``.
            ``None`` is treated as "".
        number_of_new_tokens: Upper bound on the number of tokens generated
            for this turn.
        model_name: One of ``'DialoGPT-small'``, ``'DialoGPT-medium'``,
            ``'DialoGPT-large'``, ``'test-bloomd-6b3'`` or ``'bloom-petals'``.
            Any other value (including ``None``) selects DialoGPT-medium.
        del_hist: When equal to ``'delete history'`` the dialogue is reset
            before this turn is processed.

    Returns:
        A ``(response, history)`` tuple: ``response`` is a list of
        ``(user_text, bot_text)`` pairs decoded from the dialogue, and
        ``history`` is the dialogue's token ids (without the description),
        suitable for passing back in on the next call.
    """
    if history is None or del_hist == 'delete history':
        history = []

    # Name -> (model, tokenizer); unknown names fall back to DialoGPT-medium.
    available = {
        'DialoGPT-small': (model_DialoGPT_small, tokenizer_DialoGPT_small),
        'DialoGPT-medium': (model_DialoGPT_medium, tokenizer_DialoGPT_medium),
        'DialoGPT-large': (model_DialoGPT_large, tokenizer_DialoGPT_large),
        'test-bloomd-6b3': (model_bloomd_6b3, tokenizer_bloomd_6b3),
        'bloom-petals': (model_bloomd, tokenizer_bloomd),
    }
    model, tokenizer = available.get(
        model_name, (model_DialoGPT_medium, tokenizer_DialoGPT_medium)
    )

    eos_token = tokenizer.eos_token
    eos_token_id = tokenizer.eos_token_id

    # Both defaults are None, so guard the string concatenation below.
    person_description_ids = tokenizer.encode(
        (person_description or "") + eos_token, return_tensors='pt'
    )
    new_user_input_ids = tokenizer.encode(
        (input_text or "") + eos_token, return_tensors='pt'
    )

    # Only concatenate when there is prior dialogue: an empty list would become
    # a 1-D empty tensor, which does not match the 2-D shape of the new ids.
    if history:
        bot_input_ids = torch.cat(
            [torch.LongTensor(history), new_user_input_ids], dim=-1
        )
    else:
        bot_input_ids = new_user_input_ids
    input_with_desc_ids = torch.cat([person_description_ids, bot_input_ids], dim=-1)

    # `max_length` counts prompt + generated tokens. The UI control may hand
    # over a float, so coerce to int before using it as a length.
    max_token_count = int(number_of_new_tokens) + input_with_desc_ids.shape[-1]
    history = model.generate(
        input_with_desc_ids,
        max_length=max_token_count,
        pad_token_id=eos_token_id,
    ).tolist()

    # Drop the description prefix so it is not accumulated turn after turn.
    history[0] = history[0][person_description_ids.shape[-1]:]

    # If generation stopped at max_length without emitting EOS, terminate the
    # bot turn explicitly; otherwise the next user message would be glued onto
    # this reply and the (user, bot) pairing below would shift.
    if history[0][-1:] != [eos_token_id]:
        history[0].append(eos_token_id)

    # Split on the selected tokenizer's own EOS string rather than a hard-coded
    # "<|endoftext|>", which is only right for tokenizers using that marker.
    turns = tokenizer.decode(history[0]).split(eos_token)
    # Even positions are user messages, odd positions are bot replies.
    response = list(zip(turns[::2], turns[1::2]))
    return response, history
# Build the chat UI around `predict` and start serving it.
# The `inputs` list is positional: its entries line up, in order, with
# predict's parameters (input_text, history, person_description,
# number_of_new_tokens, model_name, del_hist). The "state" entries carry the
# token-id history between calls.
gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label='Input message', lines=1, placeholder="Enter your message..."),
        "state",
        gr.Textbox(
            label='Person Description',
            lines=2,
            placeholder="Enter a description of the person...",
        ),
        gr.Slider(label='Number of new tokens', minimum=2, maximum=100, value=10),
        # No initial selection is set here; NOTE(review): presumably predict
        # then receives None and uses its DialoGPT-medium fallback - confirm.
        gr.Radio(
            label='Model name',
            choices=[
                'DialoGPT-small',
                'DialoGPT-medium',
                'DialoGPT-large',
                'test-bloomd-6b3',
                'bloom-petals',
            ],
        ),
        gr.Radio(
            label='Delete history',
            value="Don't delete history",
            choices=[
                'delete history',
                "Don't delete history",
            ],
        ),
    ],
    outputs=[gr.Chatbot(label='History of the dialogue'), "state"],
).launch()  # the original trailing comma only built a discarded 1-tuple