|
# Hugging Face Hub repository id of the checkpoint this app loads.
model_name = "berkeley-nest/Starling-LM-7B-alpha"

# Markdown heading rendered first on the Gradio page.
title = "# 👋🏻Welcome to Tonic's 💫🌠Starling 7B"

# Markdown blurb rendered under the title (model card and Discord links).
description = (
    "You can use [💫🌠Starling 7B](https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha) "
    "or duplicate it for local use or on Hugging Face! "
    "[Join me on Discord to build together](https://discord.gg/VqTxc76K3u)."
)
|
|
|
import transformers |
|
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM |
|
import torch |
|
import gradio as gr |
|
import json |
|
import os |
|
import shutil |
|
import requests |
|
import accelerate |
|
import bitsandbytes |
|
import gc |
|
|
|
# The CUDA caching allocator reads PYTORCH_CUDA_ALLOC_CONF when it is first
# initialised, so this must be set before anything allocates on the GPU
# (in particular before the model below is loaded with device_map="auto").
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:50'

device = "cuda" if torch.cuda.is_available() else "cpu"

# Special-token ids passed explicitly to model.generate(). They must be plain
# ints: a trailing comma here would silently turn the value into a tuple.
bos_token_id = 1
eos_token_id = 32000
pad_token_id = 32001

# Module-level generation defaults.
# NOTE(review): nothing visible reads these; StarlingBot.predict has its own
# defaults and the Gradio block later rebinds these names to slider components.
temperature = 0.4
max_new_tokens = 240
top_p = 0.92
repetition_penalty = 1.7

tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
# bfloat16 weights; device_map="auto" lets accelerate place them on the available device(s).
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
# Inference only: disable dropout and other training-mode behaviour.
model.eval()
|
|
|
class StarlingBot:
    """Answer a single user message using the module-level ``tokenizer`` and ``model``."""

    def __init__(self, assistant_message="I am Starling-7B by Tonic-AI, I am ready to do anything to help my user."):
        # Kept on the instance; note that predict() uses its own per-call
        # assistant_message argument rather than this attribute.
        self.assistant_message = assistant_message

    @staticmethod
    def _build_prompt(user_message, assistant_message, mode):
        """Return the single-turn prompt string for ``mode``.

        "Assistant" selects the "GPT4 Correct" template; any other mode selects
        the "Code" template. A falsy ``assistant_message`` becomes an empty slot.
        """
        prefix = assistant_message if assistant_message else ''
        if mode == "Assistant":
            return f"GPT4 Correct Assistant: {prefix} GPT4 Correct User: {user_message} GPT4 Correct Assistant:"
        return f"Code Assistant: {prefix} Code User: {user_message} Code Assistant:"

    def predict(self, user_message, assistant_message, mode, do_sample, temperature=0.4, max_new_tokens=700, top_p=0.99, repetition_penalty=1.9):
        """Generate a reply to ``user_message``.

        Args:
            user_message: The user's text for this turn.
            assistant_message: Optional text placed in the assistant slot before
                the user turn; ``None``/empty leaves the slot blank.
            mode: "Assistant" for the general template, anything else for the
                code template.
            do_sample: Whether to sample (True) or decode greedily (False).
            temperature, top_p, repetition_penalty: Forwarded to ``model.generate``.
            max_new_tokens: Upper bound on generated tokens; coerced to ``int``
                because UI sliders may deliver it as a float.

        Returns:
            The decoded text of the newly generated tokens only (the prompt is
            stripped), with special tokens removed.
        """
        # Pre-bind so the cleanup in `finally` is safe even if encoding fails.
        input_ids = response = None
        try:
            conversation = self._build_prompt(user_message, assistant_message, mode)
            input_ids = tokenizer.encode(conversation, return_tensors="pt", add_special_tokens=True)
            input_ids = input_ids.to(device)
            response = model.generate(
                input_ids=input_ids,
                use_cache=True,
                early_stopping=False,
                bos_token_id=bos_token_id,
                eos_token_id=eos_token_id,
                pad_token_id=pad_token_id,
                temperature=temperature,
                do_sample=bool(do_sample),
                max_new_tokens=int(max_new_tokens),
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )
            # generate() returns prompt + continuation; keep only the continuation.
            new_tokens = response[0][input_ids.shape[-1]:]
            return tokenizer.decode(new_tokens, skip_special_tokens=True)
        finally:
            # Drop the tensor references first so empty_cache() can actually
            # return their blocks to the device.
            del input_ids, response
            gc.collect()
            torch.cuda.empty_cache()
|
|
|
# Sample inputs: two prompt strings followed by four numeric settings.
# NOTE(review): the numbers look like (temperature, max_new_tokens, top_p,
# repetition_penalty) — confirm; this list is not referenced by the UI code.
examples = [
    [
        "The following dialogue is a conversation between Emmanuel Macron and Elon Musk:",
        "[Emmanuel Macron]: Hello Mr. Musk. Thank you for receiving me today.",
        0.9, 450, 0.90, 1.9,
    ],
]
|
|
|
# One shared bot instance, used by the Gradio callback below.
starling_bot = StarlingBot()


def gradio_starling(user_message, assistant_message, mode, do_sample, temperature, max_new_tokens, top_p, repetition_penalty):
    """Gradio callback: forward the UI inputs to ``starling_bot.predict`` and return its reply."""
    return starling_bot.predict(
        user_message,
        assistant_message,
        mode,
        do_sample,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
|
|
|
with gr.Blocks(theme="ParityError/Anime") as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        assistant_message = gr.Textbox(label="Optional💫🌠Starling Assistant Message", lines=2)
        user_message = gr.Textbox(label="Your Message", lines=3)

    with gr.Row():
        mode = gr.Radio(choices=["Assistant", "Coder"], value="Assistant", label="Mode")
        # Sent to the callback as `do_sample`; also opens/closes the settings panel below.
        do_sample = gr.Checkbox(label="Advanced", value=True)

    # `open` takes a bool, not a callable; the panel starts open because the
    # checkbox starts checked, and the .change handler below keeps them in sync.
    with gr.Accordion("Advanced Settings", open=True) as advanced_settings:
        with gr.Row():
            temperature = gr.Slider(label="Temperature", value=0.4, minimum=0.05, maximum=1.0, step=0.05)
            max_new_tokens = gr.Slider(label="Max new tokens", value=100, minimum=25, maximum=800, step=1)
            # top_p is a probability mass and must lie within [0, 1]; transformers
            # raises ValueError for larger values when sampling.
            top_p = gr.Slider(label="Top-p (nucleus sampling)", value=0.92, minimum=0.05, maximum=1.0, step=0.01)
            repetition_penalty = gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0, step=0.05)

    # Mirror the checkbox state onto the accordion's open/closed state.
    do_sample.change(
        lambda checked: gr.update(open=checked),
        inputs=do_sample,
        outputs=advanced_settings,
    )

    submit_button = gr.Button("Submit")
    output_text = gr.Textbox(label="💫🌠Starling Response")

    submit_button.click(
        gradio_starling,
        inputs=[user_message, assistant_message, mode, do_sample, temperature, max_new_tokens, top_p, repetition_penalty],
        outputs=output_text,
    )

demo.launch()