# BLOOMZ_Compare / app.py
import codecs
import gc
from datetime import datetime

import gradio as gr
import torch
from transformers import BloomTokenizerFast
from petals.client import DistributedBloomForCausalLM
# Run the client-side parts of the model on GPU when available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.bfloat16
MODEL_NAMES = ["bigscience/bloom-petals", "bigscience/bloomz-petals"]

# Cache for the currently loaded (tokenizer, model) pair, so switching models
# does not force a reload on every request.
models = {"model": None, "model_name": None}
# Accumulated generation text, keyed by model name.
output = {MODEL_NAMES[0]: "", MODEL_NAMES[1]: ""}
print(DEVICE)
def to_md(text):
    # Convert newlines to HTML line breaks for Markdown rendering.
    return text.replace("\n", "<br />")
def infer(
    prompt,
    model_idx=0,
    max_new_tokens=10,
    temperature=0.1,
    top_p=1.0,
    repetition_penalty=1.0,
    stop="\n",
    num_completions=1,
    seed=42,
):
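    """Stream tokens from the selected Petals-hosted model for the given prompt.

    Loads (or reuses) the distributed BLOOM/BLOOMZ client, then generates one
    token at a time inside a Petals inference session, yielding the partial
    output after every step so Gradio can stream it to the UI.
    """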
    global output
    global models

    print("Loading Models\n")
    model_name = MODEL_NAMES[model_idx]
    # (Re)load the requested model only if it differs from the cached one.
    if models["model_name"] != model_name:
        models = {"model": None, "model_name": None}
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        tokenizer = BloomTokenizerFast.from_pretrained(model_name)
        model = DistributedBloomForCausalLM.from_pretrained(
            model_name, torch_dtype=TORCH_DTYPE, request_timeout=300
        )
        model = model.to(DEVICE)
        models["model"] = tokenizer, model
        models["model_name"] = model_name
    output[model_name] = ""
    max_new_tokens = int(max_new_tokens)
    temperature = float(temperature)
    top_p = float(top_p)
    repetition_penalty = float(repetition_penalty)
    # The stop field is a comma-separated list of (escaped) stop strings.
    stop = [x.strip(' ') for x in stop.split(',')]

    assert 1 <= max_new_tokens <= 384
    assert 1 <= num_completions <= 5
    assert 0.0 <= temperature <= 1.0
    assert 0.0 <= top_p <= 1.0
    assert 0.9 <= repetition_penalty <= 3.0

    # Sampling with temperature 0 is degenerate; nudge it to a small positive value.
    if temperature == 0.0:
        temperature = 0.01
    # An empty prompt would produce zero input tokens, so substitute a single space.
    if prompt == "":
        prompt = " "
    print(f"START -> ({datetime.now()})\n")
    print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}\n")

    token_cnt = 0
    tokenizer, model = models["model"]
    with model.inference_session(max_length=512) as sess:
        print("Encode Input Prompt")
        output[model_name] = ""
        inputs = tokenizer(prompt, return_tensors="pt")["input_ids"].to(DEVICE)
        n_input_tokens = inputs.shape[1]
        done = False
        print(f"Start Inference ({sess})")
        while not done:
            # Generate a single token per step so partial output can be streamed to the UI.
            outputs = model.generate(
                inputs,
                max_new_tokens=1,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                session=sess,
            )
            output[model_name] += tokenizer.decode(outputs[0, n_input_tokens:])
            token_cnt += 1
            print("\n[" + str(model_name) + "]" + output[model_name], end="", flush=True)
            yield output[model_name]

            # Stop once any stop word appears in the output...
            for stop_word in stop:
                stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
                if stop_word != '' and stop_word in output[model_name]:
                    print("\nDONE (stop)")
                    done = True
            # ...or when the token budget is exhausted.
            if token_cnt >= max_new_tokens:
                print("\nDONE (max tokens)")
                done = True

            # The prompt is passed only for the first step; later steps continue the session.
            inputs = None
            n_input_tokens = 0

    print("\nEnd")
    yield output[model_name]
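
# Hypothetical manual check (not used by the Gradio app): infer() is a generator,
# so the streamed output can be consumed directly; the last yielded value is the
# full completion.
#
#   for partial in infer("The capital of Germany is", model_idx=0, max_new_tokens=5):
#       print(partial)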
examples = [
    [
        # Question answering
        '''Please answer the following question:
Question: What is the capital of Germany?
Answer:''', "BLOOMZ", 3, 0.2, 1.0, 1.0, "\\n,</s>"],
    [
        # Chatbot (BLOOM)
        '''This is a conversation between Alex (an AI based on the 2020 GPT-3 language model), and Fritz (an AI based on the 2021 Jurassic-1 language model). They are exploring each other's capabilities, and trying to ask interesting, complex, and 'ungoogleable' questions of one another, to test the limits of the AI...
Alex: Good morning, Fritz!
Fritz:''', "BLOOM", 160, 0.85, 0.9, 1.0, "\\n\\n,</s>"],
    [
        # Chatbot (same prompt, BLOOMZ)
        '''This is a conversation between Alex (an AI based on the 2020 GPT-3 language model), and Fritz (an AI based on the 2021 Jurassic-1 language model). They are exploring each other's capabilities, and trying to ask interesting, complex, and 'ungoogleable' questions of one another, to test the limits of the AI...
Alex: Good morning, Fritz!
Fritz:''', "BLOOMZ", 160, 0.85, 0.9, 1.0, "\\n\\n,</s>"],
    [
        # Expert answers
        '''Expert Questions & Helpful Answers
Ask Research Experts
Question:
Are humans good or bad?
Full Answer:''', "BLOOM", 120, 0.85, 0.9, 1.0, "</s>"],
    [
        # Creative writing outline
        '''You are the writing assistant for Stephen King. You have worked in the fiction/horror genre for 30 years. You are a Pulitzer Prize-winning author, and now you are tasked with developing a skeletal outline for his newest novel, set to be completed in the spring of 2024. Create a title and brief description for the first 5 chapters of this work.\n\nTitle:''', "BLOOM", 120, 0.85, 0.9, 1.0, "</s>"
    ]
]
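
# Each example row supplies: prompt, model choice, max new tokens, temperature,
# top_p, repetition penalty, and stop words, in the same order as the Interface inputs below.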
iface = gr.Interface(
    fn=infer,
    allow_flagging="never",
    inputs=[
        gr.Textbox(lines=20, label="Input Prompt", max_lines=10),  # prompt
        gr.Radio(["BLOOM", "BLOOMZ"], value="BLOOM", type="index", label="Choose a 176-billion-parameter model"),
        gr.Slider(1, 256, value=15, label="Max new tokens"),
        gr.Slider(0.0, 1.0, value=0.2, label="Temperature"),
        gr.Slider(0.0, 1.0, value=0.9, label="Top-p"),
        gr.Slider(0.9, 3.0, value=1.0, label="Repetition penalty"),
        gr.Textbox(lines=1, value="\\n\\n,</s>", label="Stop words (comma-separated, escaped)"),
    ],
    outputs=gr.Textbox(lines=20, label="Generated Output:"),
    examples=examples,
    cache_examples=False,
    title="BLOOM vs BLOOMZ",
    description='''<p>Compare outputs of the BLOOM and BLOOMZ 176-billion-parameter models using the Petals network. <b>WARNING:</b> Initial inference may take a long time. Keep the input prompt short to speed things up.</p>
<p>Please consider contributing your unused GPU cycles to the <a href='https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity'>petals swarm</a> to help speed up inference. Check the <a href='http://health.petals.ml/'>health</a> of the Petals swarm.</p>
<p>Big thanks to <a href='https://www.rftcapital.com/'>RFT Capital</a> for providing initial compute resources.</p>'''
)
iface.queue(concurrency_count=2)
iface.launch()
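
# A minimal sketch of running this app locally (assumes these package names and
# that the public Petals swarm is reachable):
#   pip install gradio torch transformers petals
#   python app.py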