STC-LLM / app.py
Staticaliza
Update app.py
120bd05
raw
history blame
No virus
3.91 kB
import gradio as gr
from huggingface_hub import Repository, InferenceClient
import os
import json
# --- Configuration, read from environment variables ---
API_TOKEN = os.environ.get("API_TOKEN")  # bearer token for the inference endpoints (may be unset)
API_ENDPOINT = os.environ.get("API_ENDPOINT")  # NOTE(review): not referenced in the visible code
KEY = os.environ.get("KEY")  # access key that predict() compares the user's key against

# Display name (shown in the model dropdown) -> Hugging Face model id.
API_ENDPOINTS = {
    "Falcon": "tiiuae/falcon-180B-chat",
    "Llama": "meta-llama/Llama-2-70b-chat-hf",
}

# Only send an Authorization header when a token is configured; formatting an
# unset token would otherwise send the literal string "Bearer None".
_AUTH_HEADERS = {"Authorization": f"Bearer {API_TOKEN}"} if API_TOKEN else {}

# Dropdown choices, in the same order as API_ENDPOINTS.
CHOICES = list(API_ENDPOINTS)

# One inference client per display name; each gets its own copy of the headers.
CLIENTS = {
    model_name: InferenceClient(model_endpoint, headers=dict(_AUTH_HEADERS))
    for model_name, model_endpoint in API_ENDPOINTS.items()
}
def format(input, chat_history, instructions: str, *, user_name: str = "User", bot_name: str = "Assistant") -> str:
    """Build a plain-text chat prompt from instructions, prior turns and a new message.

    The prompt is laid out line by line as::

        <instructions>
        <user_name>: <user turn>
        <bot_name>: <bot turn>
        ...
        <user_name>: <input>
        <bot_name>:

    Args:
        input: The new user message; it becomes the last user line.
        chat_history: Iterable of ``(user_message, bot_message)`` pairs, or
            ``None`` for no history. A ``None`` bot message is rendered as an
            empty turn.
        instructions: Leading system text. Surrounding whitespace is stripped.
        user_name: Label for user lines.
        bot_name: Label for bot lines.

    Returns:
        The assembled prompt. It ends with ``"<bot_name>:"`` so the model
        continues as the bot.

    NOTE(review): the original referenced undefined globals USER_NAME and
    BOT_NAME. The defaults above are placeholders; confirm the intended
    speaker labels. The function name shadows the builtin ``format`` but is
    kept for compatibility.
    """
    # Collect lines and join once instead of re-concatenating the prompt per turn.
    lines = [instructions.strip()]
    for user_message, bot_message in chat_history or []:
        lines.append(f"{user_name}: {user_message}")
        lines.append(f"{bot_name}: {'' if bot_message is None else bot_message}")
    lines.append(f"{user_name}: {input}")
    lines.append(f"{bot_name}:")
    return "\n".join(lines)
def _optional_int(value):
    """Return ``int(value)``, or ``None`` when *value* is ``None``."""
    return None if value is None else int(value)


def _parse_stop_sequences(raw):
    """Parse a JSON array of stop strings.

    A ``None`` or blank string means "no stop sequences" and yields ``[]``.
    Returns ``None`` when *raw* is not valid JSON or not a list of strings.
    """
    if raw is None or (isinstance(raw, str) and not raw.strip()):
        return []
    try:
        parsed = json.loads(raw)
    except (TypeError, ValueError):  # JSONDecodeError is a ValueError subclass
        return None
    if not isinstance(parsed, list) or not all(isinstance(s, str) for s in parsed):
        return None
    return parsed


def predict(instruction, history, input, access_key, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed):
    """Generate a completion for ``input`` with the selected model.

    Args:
        instruction: Supplied by the UI but not used here.
        history: Supplied by the UI but not used here.
            NOTE(review): ``input`` is sent to the model verbatim; the
            instruction and history are not folded into the prompt. Confirm
            this is intended.
        input: Prompt text forwarded to ``text_generation``.
        access_key: Must equal the ``KEY`` setting. Every request is refused
            while ``KEY`` is unset or empty.
        model: Key into ``CLIENTS``.
        temperature, top_p, top_k, rep_p, max_tokens, seed: Sampling settings
            forwarded to ``InferenceClient.text_generation``. ``rep_p`` maps
            to ``repetition_penalty`` and ``max_tokens`` to ``max_new_tokens``.
            ``top_k``, ``max_tokens`` and ``seed`` are coerced to ``int``.
        stop_seqs: JSON array of strings. A blank value means no stop sequences.

    Returns:
        ``(output_text, input)``. When a request is refused, ``output_text``
        is ``"[UNAUTHORIZED ACCESS]"``, ``"[INVALID STOP SEQUENCES]"`` or
        ``"[UNKNOWN MODEL]"``.
    """
    import hmac

    # Fail closed when KEY is not configured. Compare UTF-8 bytes in constant
    # time, so non-ASCII keys don't raise and timing doesn't leak the key.
    authorized = (
        bool(KEY)
        and isinstance(access_key, str)
        and hmac.compare_digest(access_key.encode("utf-8"), KEY.encode("utf-8"))
    )
    if not authorized:
        # Log the prompt only; echoing the attempted key would put secrets in the logs.
        print(f">>> MODEL FAILED: Input: {input}")
        return ("[UNAUTHORIZED ACCESS]", input)

    stops = _parse_stop_sequences(stop_seqs)
    if stops is None:
        return ("[INVALID STOP SEQUENCES]", input)

    client = CLIENTS.get(model)
    if client is None:
        return ("[UNKNOWN MODEL]", input)

    response = client.text_generation(
        input,
        temperature=temperature,
        max_new_tokens=_optional_int(max_tokens),
        top_p=top_p,
        top_k=_optional_int(top_k),
        repetition_penalty=rep_p,
        stop_sequences=stops,
        do_sample=True,
        seed=_optional_int(seed),
        stream=False,
        details=False,
        return_full_text=False,
    )
    print(f"---\nUSER: {input}\nBOT: {response}\n---")
    return (response, input)
def maintain_cloud():
print(">>> SPACE MAINTAINED!")
return ("SUCCESS!", "SUCCESS!")
# Gradio UI. The left column holds the prompt controls, the right column holds
# the model and sampling settings, and the bottom row holds the output. The
# app is launched at import time with a queue and a public API.
with gr.Blocks() as demo:
    with gr.Row(variant="panel"):
        gr.Markdown("🔯 This is a private LLM CHAT Space owned within STC Holdings!\n\n\nhttps://discord.gg/6JRtGawz7B")
    with gr.Row():
        with gr.Column():
            history = gr.Chatbot(elem_id="chatbot")
            # Named input_box so the module does not shadow the builtin input().
            input_box = gr.Textbox(label="Input", lines=2)
            instruction = gr.Textbox(label="Instruction", lines=4)
            # Masked single-line field so the secret key is not displayed on screen.
            access_key = gr.Textbox(label="Access Key", lines=1, max_lines=1, type="password")
            run = gr.Button("▶")
            cloud = gr.Button("☁️")
        with gr.Column():
            model = gr.Dropdown(choices=CHOICES, value=next(iter(API_ENDPOINTS)), interactive=True, label="Model")
            temperature = gr.Slider(minimum=0, maximum=2, value=1, step=0.01, interactive=True, label="Temperature")
            top_p = gr.Slider(minimum=0.01, maximum=0.99, value=0.95, step=0.01, interactive=True, label="Top P")
            top_k = gr.Slider(minimum=1, maximum=2048, value=50, step=1, interactive=True, label="Top K")
            rep_p = gr.Slider(minimum=0.01, maximum=2, value=1.2, step=0.01, interactive=True, label="Repetition Penalty")
            # step=1 keeps the default of 32 on the slider grid. With minimum=1
            # and step=64, the only reachable values were 1, 65, 129, ...
            max_tokens = gr.Slider(minimum=1, maximum=2048, value=32, step=1, interactive=True, label="Max New Tokens")
            stop_seqs = gr.Textbox(label="Stop Sequences ( JSON Array / 4 Max )", lines=1, value='["‹", "›"]')
            seed = gr.Slider(minimum=0, maximum=8192, value=42, step=1, interactive=True, label="Seed")
    with gr.Row():
        with gr.Column():
            output = gr.Textbox(label="Output", value="", lines=50)
    # predict returns (response, prompt); the prompt is written back into the input box.
    run.click(predict, inputs=[instruction, history, input_box, access_key, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed], outputs=[output, input_box])
    # maintain_cloud writes "SUCCESS!" into both the input and output boxes.
    cloud.click(maintain_cloud, inputs=[], outputs=[input_box, output])
demo.queue(concurrency_count=500, api_open=True).launch(show_api=True)