# Gradio demo for berkeley-nest/Starling-LM-7B-alpha.
import gc
import os

# The CUDA caching allocator reads this setting when it initializes, so it
# must be set before the model is loaded onto the GPU to have any effect.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:50'

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# `accelerate` must be installed for device_map="auto", but it does not need
# to be imported here.

model_name = "berkeley-nest/Starling-LM-7B-alpha"
title = """# 👋🏻Welcome to Tonic's 💫🌠Starling 7B"""
description = """You can use [💫🌠Starling 7B](https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha) or duplicate it for local use or on Hugging Face! [Join me on Discord to build together](https://discord.gg/VqTxc76K3u)."""
device = "cuda" if torch.cuda.is_available() else "cpu"
# Special-token ids for the Starling/OpenChat tokenizer:
# <s> = 1, <|end_of_turn|> = 32000, <|pad_0|> = 32001.
bos_token_id = 1
eos_token_id = 32000
pad_token_id = 32001

# Module-level generation defaults; the sliders in the UI below override
# these per request.
temperature = 0.4
max_new_tokens = 240
top_p = 0.92
repetition_penalty = 1.7
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
model.eval()
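# eval() disables dropout for inference; with bfloat16 weights a 7B model
# occupies roughly 14 GB of GPU memory (2 bytes per parameter), plus
# activation and KV-cache overhead during generation.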
class StarlingBot:
    def __init__(self, assistant_message="I am Starling-7B by Tonic-AI, I am ready to do anything to help my user."):
        self.assistant_message = assistant_message

    def predict(self, user_message, assistant_message, mode, do_sample, temperature=0.4, max_new_tokens=700, top_p=0.99, repetition_penalty=1.9):
        input_ids = None  # bound up front so the finally block can never raise NameError
        try:
            # Build the prompt in the OpenChat-style chat format Starling was tuned on.
            if mode == "Assistant":
                conversation = f"GPT4 Correct Assistant: {assistant_message if assistant_message else ''} GPT4 Correct User: {user_message} GPT4 Correct Assistant:"
            else:  # mode == "Coder"
                conversation = f"Code Assistant: {assistant_message if assistant_message else ''} Code User: {user_message} Code Assistant:"
            input_ids = tokenizer.encode(conversation, return_tensors="pt", add_special_tokens=True)
            input_ids = input_ids.to(device)
            response = model.generate(
                input_ids=input_ids,
                use_cache=True,
                bos_token_id=bos_token_id,
                eos_token_id=eos_token_id,
                pad_token_id=pad_token_id,
                temperature=temperature,
                do_sample=do_sample,
                max_new_tokens=max_new_tokens,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )
            # Decode only the newly generated tokens so the prompt is not echoed back.
            response_text = tokenizer.decode(response[0][input_ids.shape[-1]:], skip_special_tokens=True)
            return response_text
        finally:
            # Release this request's tensors and defragment the CUDA cache.
            del input_ids
            gc.collect()
            torch.cuda.empty_cache()
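# Minimal smoke test, assuming a GPU is available. The STARLING_SMOKE_TEST
# flag is hypothetical (not part of the original app); set it to try one
# generation outside the UI.
if os.environ.get("STARLING_SMOKE_TEST"):
    print(StarlingBot().predict("Say hello in one sentence.", "", "Assistant", True, max_new_tokens=32))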
# Example inputs surfaced in the UI via gr.Examples below; the field order
# must match the inputs wired to the submit button, so the mode and
# do_sample fields the handler expects are included.
examples = [
    [
        "The following dialogue is a conversation between Emmanuel Macron and Elon Musk:",  # user_message
        "[Emmanuel Macron]: Hello Mr. Musk. Thank you for receiving me today.",  # assistant_message
        "Assistant",  # mode
        True,  # do_sample
        0.9,  # temperature
        450,  # max_new_tokens
        0.90,  # top_p
        1.9,  # repetition_penalty
    ]
]
starling_bot = StarlingBot()
def gradio_starling(user_message, assistant_message, mode, do_sample, temperature, max_new_tokens, top_p, repetition_penalty):
    return starling_bot.predict(user_message, assistant_message, mode, do_sample, temperature, max_new_tokens, top_p, repetition_penalty)
with gr.Blocks(theme="ParityError/Anime") as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        assistant_message = gr.Textbox(label="Optional💫🌠Starling Assistant Message", lines=2)
        user_message = gr.Textbox(label="Your Message", lines=3)
    with gr.Row():
        mode = gr.Radio(choices=["Assistant", "Coder"], value="Assistant", label="Mode")
        # This checkbox is passed straight through as `do_sample`, so label it accordingly.
        do_sample = gr.Checkbox(label="Use sampling (do_sample)", value=True)
    # Accordion `open` takes a bool, not a callable; keep it open by default.
    with gr.Accordion("Advanced Settings", open=True):
        with gr.Row():
            temperature = gr.Slider(label="Temperature", value=0.4, minimum=0.05, maximum=1.0, step=0.05)
            max_new_tokens = gr.Slider(label="Max new tokens", value=100, minimum=25, maximum=800, step=1)
            # top_p is a probability and must lie in (0, 1].
            top_p = gr.Slider(label="Top-p (nucleus sampling)", value=0.92, minimum=0.05, maximum=1.0, step=0.01)
            repetition_penalty = gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0, step=0.05)
    submit_button = gr.Button("Submit")
    output_text = gr.Textbox(label="💫🌠Starling Response")
    submit_button.click(
        gradio_starling,
        inputs=[user_message, assistant_message, mode, do_sample, temperature, max_new_tokens, top_p, repetition_penalty],
        outputs=output_text,
    )
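    # The examples list defined above is otherwise unused; gr.Examples is a
    # sketch of the presumably intended wiring, letting users click a sample
    # row to populate every input the submit button reads.
    gr.Examples(
        examples=examples,
        inputs=[user_message, assistant_message, mode, do_sample, temperature, max_new_tokens, top_p, repetition_penalty],
    )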
demo.launch()