|
import gradio as gr |
|
from huggingface_hub import Repository, InferenceClient |
|
import os |
|
import json |
|
import re |
|
|
|
# --- Secrets and deployment settings, all read from the environment ---------

# Bearer token sent in the Authorization header of every inference client.
API_TOKEN = os.getenv("API_TOKEN")

# Read here but not referenced by the code visible in this file.
API_ENDPOINT = os.getenv("API_ENDPOINT")

# Shared access key that predict() compares against the user-supplied key.
KEY = os.getenv("KEY")

# --- Prompt framing ---------------------------------------------------------

# Left and right delimiter characters wrapped around every prompt segment.
SPECIAL_SYMBOLS = ["⠀", "⠀"]

# Initial values shown in the UI text boxes.
DEFAULT_INPUT = "User: Hi!"
DEFAULT_PREOUTPUT = "Statical: "
DEFAULT_INSTRUCTION = "Statical is a helpful chatbot who is communicating with people."

# Default stop sequences, kept as JSON text because the UI edits them as text.
DEFAULT_STOPS = '["⠀", "⠀"]'

# --- Models -----------------------------------------------------------------

# Display name -> Hugging Face model id. Insertion order is the dropdown order,
# and the first entry is the default selection.
API_ENDPOINTS = {
    "Falcon": "tiiuae/falcon-180B-chat",
    "Llama": "meta-llama/Llama-2-70b-chat-hf",
    "Mistral": "mistralai/Mistral-7B-v0.1",
    "Mistral-2": "mistralai/Mistral-7B-Instruct-v0.1",
}
|
|
|
# Model names offered in the dropdown, in API_ENDPOINTS order.
CHOICES = list(API_ENDPOINTS)

# One inference client per model, keyed by the model's display name.
CLIENTS = {}

for model_name, model_endpoint in API_ENDPOINTS.items():
    CLIENTS[model_name] = InferenceClient(
        model_endpoint,
        headers={"Authorization": f"Bearer {API_TOKEN}"},
    )
|
|
|
def format(instruction, history, input, preoutput):
    """Build the text prompt sent to the model.

    Every segment (the instruction line, each past message, the new input)
    is wrapped in the left/right SPECIAL_SYMBOLS and followed by a newline.
    The prompt then ends with an opening symbol so the model continues the
    next segment.

    Args:
        instruction: Text placed after the "INSTRUCTIONS: " prefix.
        history: Iterable of two-item messages; item 0 and item 1 are each
            emitted as their own segment, in that order.
        input: The new message to append after the history.
        preoutput: Text appended after the final opening symbol.

    Returns:
        A ``(prompt, prompt_base)`` tuple, where ``prompt_base`` is the
        prompt without ``preoutput`` and ``prompt`` is
        ``prompt_base + preoutput``.
    """
    open_mark = SPECIAL_SYMBOLS[0]
    close_mark = SPECIAL_SYMBOLS[1]

    def wrap(text):
        return f"{open_mark}{text}{close_mark}\n"

    segments = [wrap(f"INSTRUCTIONS: {instruction}")]
    for turn in history:
        segments.append(wrap(turn[0]))
        segments.append(wrap(turn[1]))
    segments.append(wrap(input))
    # Leave the last segment open for the model to complete.
    segments.append(f"{open_mark}")

    prompt_base = "".join(segments)
    return f"{prompt_base}{preoutput}", prompt_base
|
|
|
def _is_authorized(access_key):
    """Return True only when *access_key* matches the configured KEY."""
    import hmac  # local import: used only by this check

    # Fail closed. With KEY unset, a null key sent through the open API would
    # otherwise compare equal to None and be accepted.
    if KEY is None or not isinstance(access_key, str):
        return False
    # Constant-time comparison; compare_digest rejects non-ASCII str, so
    # both sides are compared as UTF-8 bytes.
    return hmac.compare_digest(access_key.encode("utf-8"), KEY.encode("utf-8"))


def _parse_stop_sequences(stop_seqs):
    """Decode the JSON stop-sequence text into a list of non-empty strings."""
    stops = json.loads(stop_seqs)
    if not isinstance(stops, list) or not all(isinstance(stop, str) for stop in stops):
        raise ValueError("Stop sequences must be a JSON array of strings.")
    # An empty stop can never match and makes str.split raise
    # "empty separator", so it is dropped.
    return [stop for stop in stops if stop]


def predict(access_key, instruction, history, input, preoutput, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed):
    """Generate one chatbot reply and append it to the conversation history.

    Args:
        access_key: Key supplied by the caller; must match the module-level
            KEY or the request is rejected.
        instruction: System instruction; falls back to DEFAULT_INSTRUCTION
            when empty.
        history: List of ``[input, reply]`` pairs from earlier turns (may be
            empty or None).
        input: The new user message.
        preoutput: Text the reply is forced to start with.
        model: Key into CLIENTS selecting the inference client.
        temperature, top_p, top_k, rep_p, max_tokens, seed: Sampling
            parameters forwarded to ``text_generation``.
        stop_seqs: JSON array (as text) of stop sequences; falls back to
            DEFAULT_STOPS when empty.

    Returns:
        A ``(result, input, history)`` tuple. On a rejected key this is
        ``("[UNAUTHORIZED ACCESS]", input, [])``.

    Raises:
        json.JSONDecodeError: If ``stop_seqs`` is not valid JSON.
        ValueError: If ``stop_seqs`` is not a JSON array of strings.
        KeyError: If ``model`` is not a key of CLIENTS.
    """
    if not _is_authorized(access_key):
        # NOTE(review): this writes the attempted key to the logs, which can
        # expose near-miss credentials -- confirm that this is intended.
        # An f-string is used so a None input/key cannot raise TypeError.
        print(f">>> MODEL FAILED: Input: {input}, Attempted Key: {access_key}")
        return ("[UNAUTHORIZED ACCESS]", input, [])

    instruction = instruction or DEFAULT_INSTRUCTION
    history = history or []
    input = input or ""
    preoutput = preoutput or ""
    stops = _parse_stop_sequences(stop_seqs or DEFAULT_STOPS)

    formatted_input, _ = format(instruction, history, input, preoutput)
    print(seed)
    print(formatted_input)

    response = CLIENTS[model].text_generation(
        formatted_input,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=rep_p,
        stop_sequences=stops,
        do_sample=True,
        seed=seed,
        stream=False,
        details=False,
        return_full_text=False,
    )

    result = preoutput + response

    # Cut the reply at the first occurrence of each stop sequence.
    for stop in stops:
        result = result.split(stop, 1)[0]

    # Strip the prompt delimiters as well. With custom stop sequences they
    # can survive the cut above and would corrupt the framing of the next
    # prompt once this reply is stored in the history.
    for symbol in SPECIAL_SYMBOLS:
        result = result.replace(symbol, "")

    history = history + [[input, result]]

    print(f"---\nUSER: {input}\nBOT: {result}\n---")

    return (result, input, history)
|
|
|
def clear_history():
    """Log the reset and return a fresh empty list as the new chat history."""
    emptied = []
    print(">>> HISTORY CLEARED!")
    return emptied
|
|
|
def maintain_cloud():
    """Log a maintenance ping and return a pair of "SUCCESS!" status strings."""
    print(">>> SPACE MAINTAINED!")
    status = "SUCCESS!"
    return (status, status)
|
|
|
# UI definition: a banner row, a row with the chat controls next to the
# sampling controls, and a row holding the raw output box.
with gr.Blocks() as demo:
    with gr.Row(variant="panel"):
        gr.Markdown("🔯 This is a private LLM CHAT Space owned within STC Holdings!\n\n\nhttps://discord.gg/6JRtGawz7B")

    with gr.Row():
        # Conversation and prompt inputs.
        with gr.Column():
            # `label` was misspelled as `abel`, so the label was never applied.
            history = gr.Chatbot(label="History", elem_id="chatbot")
            input = gr.Textbox(label="Input", value=DEFAULT_INPUT, lines=2)
            preoutput = gr.Textbox(label="Pre-Output", value=DEFAULT_PREOUTPUT, lines=1)
            instruction = gr.Textbox(label="Instruction", value=DEFAULT_INSTRUCTION, lines=4)
            access_key = gr.Textbox(label="Access Key", lines=1)
            run = gr.Button("▶")
            clear = gr.Button("🗑️")
            cloud = gr.Button("☁️")

        # Model selection and sampling parameters.
        with gr.Column():
            model = gr.Dropdown(choices=CHOICES, value=next(iter(API_ENDPOINTS)), interactive=True, label="Model")
            temperature = gr.Slider(minimum=0, maximum=2, value=1, step=0.01, interactive=True, label="Temperature")
            top_p = gr.Slider(minimum=0.01, maximum=0.99, value=0.95, step=0.01, interactive=True, label="Top P")
            top_k = gr.Slider(minimum=1, maximum=2048, value=50, step=1, interactive=True, label="Top K")
            rep_p = gr.Slider(minimum=0.01, maximum=2, value=1.2, step=0.01, interactive=True, label="Repetition Penalty")
            # NOTE(review): the default of 32 is not on the grid implied by
            # minimum=1 / step=64 (1, 65, 129, ...) -- confirm the intended step.
            max_tokens = gr.Slider(minimum=1, maximum=2048, value=32, step=64, interactive=True, label="Max New Tokens")
            stop_seqs = gr.Textbox(label="Stop Sequences ( JSON Array / 4 Max )", lines=1, value=DEFAULT_STOPS)
            seed = gr.Slider(minimum=0, maximum=9007199254740991, value=42, step=1, interactive=True, label="Seed")

    with gr.Row():
        with gr.Column():
            output = gr.Textbox(label="Output", value="", lines=50)

    # Event wiring. The order of `inputs` must match predict()'s parameters.
    run.click(
        predict,
        inputs=[access_key, instruction, history, input, preoutput, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed],
        outputs=[output, input, history],
    )
    clear.click(clear_history, [], history)
    cloud.click(maintain_cloud, inputs=[], outputs=[input, output])

# Enable the request queue and start the server with the API page exposed.
demo.queue(concurrency_count=500, api_open=True).launch(show_api=True)