"""Private Gradio chat Space that proxies prompts to hosted LLMs.

The conversation is serialised into a delimiter-based prompt (see
``format``). The prompt is sent to the ``InferenceClient`` of the selected
model, and the model's reply is appended to the chat history. Requests are
only served when the supplied access key matches the ``KEY`` environment
variable.
"""

import hmac
import json
import os
import re

import gradio as gr
from huggingface_hub import Repository, InferenceClient

API_TOKEN = os.environ.get("API_TOKEN")
API_ENDPOINT = os.environ.get("API_ENDPOINT")  # read but not referenced below
KEY = os.environ.get("KEY")

# Opening / closing delimiters that wrap every message in the prompt.
SPECIAL_SYMBOLS = ["‹", "›"]

DEFAULT_INPUT = "You: Hi!"
DEFAULT_PREOUTPUT = "AI: "
DEFAULT_INSTRUCTION = "You are an helpful chatbot."

API_ENDPOINTS = {
    "Falcon": "tiiuae/falcon-180B-chat",
    "Llama": "meta-llama/Llama-2-70b-chat-hf",
}

# Dropdown choices, in the same order as API_ENDPOINTS.
CHOICES = list(API_ENDPOINTS)
# One inference client per model, keyed by the dropdown name.
CLIENTS = {
    model_name: InferenceClient(
        model_endpoint, headers={"Authorization": f"Bearer {API_TOKEN}"}
    )
    for model_name, model_endpoint in API_ENDPOINTS.items()
}

# Matches either delimiter. The model emitting one means it has closed its
# message or started a new one, so the reply is cut there. re.escape keeps
# this correct even if SPECIAL_SYMBOLS is changed to regex metacharacters.
_DELIMITER_PATTERN = re.compile("|".join(re.escape(s) for s in SPECIAL_SYMBOLS))


def format(instruction, history, input, preoutput):
    """Serialise the conversation into the delimited prompt format.

    Each message is wrapped as ``‹message›`` on its own line:
    - the system instruction comes first;
    - then both sides of every ``[user, bot]`` history pair;
    - then the new input.

    The prompt ends with an *open* ``‹preoutput`` so that the model
    continues the bot's message.

    Note: the name shadows the builtin ``format`` within this module; it is
    kept for compatibility.
    """
    sy_l, sy_r = SPECIAL_SYMBOLS
    formatted_history = "".join(
        f"{sy_l}{message[0]}{sy_r}\n{sy_l}{message[1]}{sy_r}\n" for message in history
    )
    return (
        f"{sy_l}System: {instruction}{sy_r}\n"
        f"{formatted_history}"
        f"{sy_l}{input}{sy_r}\n"
        f"{sy_l}{preoutput}"
    )


def _extract_reply(text):
    """Return *text* up to (not including) the first delimiter, stripped."""
    return _DELIMITER_PATTERN.split(text, maxsplit=1)[0].strip()


def _parse_stop_sequences(stop_seqs):
    """Parse the stop-sequence field as a JSON array of strings.

    Returns the list, or None if the field is not valid JSON or not a list
    of strings.
    """
    try:
        stops = json.loads(stop_seqs)
    except (TypeError, json.JSONDecodeError):
        return None
    if not isinstance(stops, list) or not all(isinstance(s, str) for s in stops):
        return None
    return stops


def _is_authorized(access_key):
    """Check *access_key* against KEY using a constant-time comparison."""
    if KEY is None:
        # No key configured: refuse everything rather than accept everything.
        return False
    return hmac.compare_digest(
        (access_key or "").encode("utf-8"), KEY.encode("utf-8")
    )


def predict(instruction, history, input, preoutput, access_key, model,
            temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed):
    """Generate the next bot message and append the turn to the history.

    Args:
        instruction: System instruction; falls back to DEFAULT_INSTRUCTION.
        history: Chatbot value, a list of ``[user, bot]`` pairs (or empty/None).
        input: The new user message.
        preoutput: Text that opens the bot's message (e.g. ``"AI: "``).
        access_key: Must match the ``KEY`` environment variable.
        model: Key into CLIENTS selecting the backend model.
        temperature, top_p, top_k, rep_p, max_tokens, seed: Sampling
            parameters forwarded to ``text_generation``.
        stop_seqs: JSON array of stop strings.

    Returns:
        A 3-tuple ``(output_text, input, history)``, matching the three
        output components wired to the Run button. This holds on every path.
    """
    history = history or []
    input = input or ""

    if not _is_authorized(access_key):
        # The attempted key is deliberately not logged: it is a credential
        # (or close to one) and must not end up in plaintext logs.
        print(">>> MODEL FAILED: Input: " + input)
        return ("[UNAUTHORIZED ACCESS]", input, history)

    instruction = instruction or DEFAULT_INSTRUCTION
    preoutput = preoutput or ""

    stops = _parse_stop_sequences(stop_seqs)
    if stops is None:
        return (
            "[INVALID STOP SEQUENCES: expected a JSON array of strings]",
            input,
            history,
        )

    formatted_input = format(instruction, history, input, preoutput)

    # Slider values may arrive as floats; the integer parameters are cast.
    response = CLIENTS[model].text_generation(
        formatted_input,
        temperature=temperature,
        max_new_tokens=int(max_tokens),
        top_p=top_p,
        top_k=int(top_k),
        repetition_penalty=rep_p,
        stop_sequences=stops,
        do_sample=True,
        seed=int(seed),
        stream=False,
        details=False,
        return_full_text=False,
    )

    # The generated text may end with (or contain) a delimiter, e.g. the
    # stop sequence itself. Cut at the first one so no stray "‹"/"›" leaks
    # into the stored history and corrupts the next prompt.
    reply = _extract_reply(response)
    # Store the bot turn with its preoutput prefix. Past turns are then
    # serialised like the current one ("‹AI: ...›"), matching how the user
    # side already keeps its "You: " prefix.
    history = history + [[input, _extract_reply(preoutput + response)]]
    print(f"---\nUSER: {input}\nBOT: {reply}\n---")
    return (preoutput + response, input, history)


def clear_history():
    """Reset the chat history component."""
    print(">>> HISTORY CLEARED!")
    return []


def maintain_cloud():
    """Keep-alive action: write a marker into the input and output boxes."""
    print(">>> SPACE MAINTAINED!")
    return ("SUCCESS!", "SUCCESS!")


with gr.Blocks() as demo:
    with gr.Row(variant="panel"):
        gr.Markdown("🔯 This is a private LLM CHAT Space owned within STC Holdings!\n\n\nhttps://discord.gg/6JRtGawz7B")

    with gr.Row():
        with gr.Column():
            history = gr.Chatbot(label="History", elem_id="chatbot")
            input = gr.Textbox(label="Input", value=DEFAULT_INPUT, lines=2)
            preoutput = gr.Textbox(label="Pre-Output", value=DEFAULT_PREOUTPUT, lines=1)
            instruction = gr.Textbox(label="Instruction", value=DEFAULT_INSTRUCTION, lines=4)
            access_key = gr.Textbox(label="Access Key", lines=1)
            run = gr.Button("▶")
            clear = gr.Button("🗑️")
            cloud = gr.Button("☁️")

        with gr.Column():
            model = gr.Dropdown(choices=CHOICES, value=next(iter(API_ENDPOINTS)), interactive=True, label="Model")
            # Minimum is 0.01, not 0: sampling backends reject a zero
            # temperature when do_sample=True.
            temperature = gr.Slider(minimum=0.01, maximum=2, value=1, step=0.01, interactive=True, label="Temperature")
            top_p = gr.Slider(minimum=0.01, maximum=0.99, value=0.95, step=0.01, interactive=True, label="Top P")
            top_k = gr.Slider(minimum=1, maximum=2048, value=50, step=1, interactive=True, label="Top K")
            rep_p = gr.Slider(minimum=0.01, maximum=2, value=1.2, step=0.01, interactive=True, label="Repetition Penalty")
            # step=1: with step=64 from minimum=1 the default of 32 was not a
            # reachable slider position.
            max_tokens = gr.Slider(minimum=1, maximum=2048, value=32, step=1, interactive=True, label="Max New Tokens")
            stop_seqs = gr.Textbox(label="Stop Sequences ( JSON Array / 4 Max )", lines=1, value='["‹", "›"]')
            seed = gr.Slider(minimum=0, maximum=8192, value=42, step=1, interactive=True, label="Seed")

    with gr.Row():
        with gr.Column():
            output = gr.Textbox(label="Output", value="", lines=50)

    run.click(
        predict,
        inputs=[instruction, history, input, preoutput, access_key, model,
                temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed],
        outputs=[output, input, history],
    )
    clear.click(clear_history, [], history)
    cloud.click(maintain_cloud, inputs=[], outputs=[input, output])


if __name__ == "__main__":
    demo.queue(concurrency_count=500, api_open=True).launch(show_api=True)